In [1]:
%load_ext autoreload
%autoreload 2
# Enable 64-bit floats in JAX.
# this only works on startup! (the flag has to be set before JAX is used in this kernel)
from jax import config
config.update("jax_enable_x64", True)

import os
# Number GPUs by PCI bus order and expose only device 2 to this process.
# NOTE(review): these must be set before the CUDA backend is initialized
# (presumably before the first JAX/GPU computation) — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
from gpu_utils import limit_gpu_memory_growth
# Project helper; presumably stops the framework from preallocating all GPU memory — TODO confirm.
limit_gpu_memory_growth()

from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
import jax.scipy.stats
from jax import jit
import numpy as onp
import jax.numpy as np

# Location of the BSCCM dataset. Set the BSCCM_PATH environment variable to point
# elsewhere instead of editing this hard-coded default.
BSCCM_PATH = os.environ.get('BSCCM_PATH', '/home/hpinkard_waller/data/BSCCM/')
bsccm = BSCCM(BSCCM_PATH)

Opening BSCCM
Opened BSCCM


In [2]:
# Compute a "true" reference (stationary) covariance matrix and mean for each patch size
num_images = 1000
num_patches = 1000
edge_crop = 32
channel = 'LED119'
ev_floor = 1e-5  # eigenvalue floor passed to make_positive_definite

patch_sizes = [5, 10]

images = load_bsccm_images(bsccm, channel=channel, num_images=num_images, edge_crop=edge_crop, median_filter=True)

cov_mats_stationary_pd = []
means = []
for patch_size in patch_sizes:
    # NOTE: `patches` stays bound after the loop (patches of the last patch size);
    # the sanity-check cell below uses it as test data.
    patches = extract_patches(images, patch_size, num_patches=num_patches)
    cov_mat_stationary = compute_stationary_cov_mat(patches)
    cov_mat_stationary_pd = make_positive_definite(cov_mat_stationary, eigenvalue_floor=ev_floor, show_plot=False)
    # Store the mean as a plain Python float so downstream code gets a simple scalar
    means.append(float(np.mean(patches)))
    cov_mats_stationary_pd.append(cov_mat_stationary_pd)

In [6]:
def compute_stationary_log_likelihood(samples, cov_mat, mean, prefer_iterative=False):
    """
    Compute the Gaussian log likelihood of a set of images under a stationary process.

    The process is described by the covariance matrix of a patch_size x patch_size
    patch (flattened in row-major order) and a single scalar mean shared by every pixel.
    Pixels are visited in raster order and each one is conditioned on the previously
    visited pixels inside the patch_size x patch_size window that ends at that pixel
    (chain rule with a finite-window Markov approximation).

    :param samples: N x H x W array of samples, with H >= patch_size and W >= patch_size
    :param cov_mat: (patch_size**2) x (patch_size**2) covariance matrix of the process
    :param mean: float (or 1 element array) mean of the process
    :param prefer_iterative: if True, evaluate every pixel with the conditional (chain-rule)
        formula; otherwise evaluate the top-left patch_size x patch_size sub-patch directly
        with the full multivariate normal density and only iterate over the rest

    :return: length-N array of log likelihoods, one per sample
    :raises ValueError: on badly shaped inputs, a non-scalar mean, a covariance matrix with a
        negative eigenvalue, or a negative conditional variance
    """
    from jax.scipy.stats import multivariate_normal, norm

    mean = np.asarray(mean)
    if mean.size != 1:
        raise ValueError('Mean must be a float or a 1 element array')
    mean = mean.reshape(())

    samples = np.asarray(samples)
    if samples.ndim != 3:
        raise ValueError('Samples must be N x H x W')
    n_samples, height, width = samples.shape

    cov_mat = np.asarray(cov_mat)
    patch_size = int(round(cov_mat.shape[0] ** 0.5))
    num_pixels = patch_size ** 2
    if cov_mat.shape != (num_pixels, num_pixels):
        raise ValueError('Covariance matrix must be (patch_size**2) x (patch_size**2)')
    if height < patch_size or width < patch_size:
        raise ValueError('Samples must be at least as large as the covariance matrix patch')
    if np.linalg.eigvalsh(cov_mat).min() < 0:
        raise ValueError('Covariance matrix is not positive definite')

    # Work with zero-mean data: the conditional mean of x_k given x_prev is
    # mu + Sigma_21 Sigma_11^-1 (x_prev - mu), i.e. Sigma_21 Sigma_11^-1 x_prev after centering.
    centered = samples - mean

    # The conditioning set of a pixel depends only on its position inside its window,
    # (min(i, patch_size - 1), min(j, patch_size - 1)), so at most patch_size**2 distinct
    # conditionals exist. Within the window, the previously visited pixels are exactly the
    # flat (row-major) indices smaller than the target's flat index.
    conditionals = {}
    for row in tqdm(range(patch_size), desc='precomputing masks and variances'):
        for col in range(patch_size):
            target = row * patch_size + col
            variance = cov_mat[target, target]
            weights = None  # None: nothing to condition on (top-left pixel of the window)
            if target > 0:
                sigma_11 = cov_mat[:target, :target]
                sigma_12 = cov_mat[:target, target]
                # Sigma_11 is symmetric, so Sigma_21 Sigma_11^-1 == solve(Sigma_11, Sigma_12)^T
                weights = np.linalg.solve(sigma_11, sigma_12)
                variance = variance - sigma_12 @ weights
            if variance < 0:
                raise ValueError('Variance is negative {} {}'.format(row, col))
            conditionals[(row, col)] = (weights, variance)

    log_likelihood_terms = []
    if not prefer_iterative:
        # Evaluate the top-left sub-patch directly. jax's multivariate_normal.logpdf treats a
        # scalar mean as a univariate density, so an explicit mean vector is required.
        top_left = centered[:, :patch_size, :patch_size].reshape(n_samples, -1)
        log_likelihood_terms.append(
            multivariate_normal.logpdf(top_left, mean=np.zeros(num_pixels), cov=cov_mat))

    for i in tqdm(range(height), desc='evaluating likelihood'):
        row = min(i, patch_size - 1)
        top = i - row  # first image row of the conditioning window
        for j in range(width):
            if not prefer_iterative and i < patch_size and j < patch_size:
                continue  # already covered by the direct evaluation above
            col = min(j, patch_size - 1)
            left = j - col  # first image column of the conditioning window
            weights, variance = conditionals[(row, col)]
            if weights is None:
                conditional_mean = 0.0
            else:
                window = centered[:, top:top + patch_size, left:left + patch_size].reshape(n_samples, -1)
                conditional_mean = window[:, :weights.shape[0]] @ weights
            log_likelihood_terms.append(
                norm.logpdf(centered[:, i, j], loc=conditional_mean, scale=np.sqrt(variance)))

    # every term has shape (N,), so stacking and summing gives one value per sample
    return np.sum(np.stack(log_likelihood_terms), axis=0)

In [7]:
# Sanity check: for a sample exactly the size of the covariance patch, the direct density,
# the default evaluation and the fully iterative chain-rule evaluation must all agree.
cov_mat = cov_mats_stationary_pd[-1]
mean = means[-1]

# `patches` is left over from the loop in the covariance cell (last patch size),
# so its patches match cov_mats_stationary_pd[-1]
test_patch = patches[0]
test_patch = test_patch.reshape(1, test_patch.shape[0], test_patch.shape[1])
assert test_patch[0].size == cov_mat.shape[0], 'test patch does not match covariance matrix size'

# jax's multivariate_normal.logpdf treats a scalar mean as a univariate density,
# so broadcast the scalar process mean to a full mean vector
mean_vector = np.full(cov_mat.shape[0], mean)
direct = jax.scipy.stats.multivariate_normal.logpdf(test_patch.reshape(-1, cov_mat.shape[0]), mean=mean_vector, cov=cov_mat)

function = compute_stationary_log_likelihood(test_patch, cov_mat, mean)
function_iterative = compute_stationary_log_likelihood(test_patch, cov_mat, mean, prefer_iterative=True)

onp.testing.assert_allclose(onp.asarray(function), onp.asarray(direct), rtol=1e-5)
onp.testing.assert_allclose(onp.asarray(function_iterative), onp.asarray(direct), rtol=1e-5)
direct, function, function_iterative


In [21]:
np.shape(mean)  # expect () — the process mean must be a scalar (works for floats and arrays)