In [1]:
%load_ext autoreload
%autoreload 2
# this only works on startup!
from jax import config
config.update("jax_enable_x64", True)

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

from gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

from cleanplots import *
from tqdm import tqdm
from information_estimation import *
from image_utils import *

from led_array.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
import numpy as onp
import jax.numpy as np

bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

Opening BSCCM
Opened BSCCM


In [2]:
# load images, extract patches, and compute cov mats
import functools

# Import explicitly: `jax.scipy` is not guaranteed to be loaded by `import jax.numpy` / `from jax import ...`,
# and this also binds the `jax` name used below.
import jax.scipy.stats

# Set True for a fast smoke-test run with tiny settings.
quick_test = False

if quick_test:
    num_images = 100
    num_patches = 100
    num_test_set_patches = 500
    edge_crop = 32
    channels = ['LED119']
    num_bootstrap_samples = 3
    patch_sizes = np.array([2, 10])
else:
    num_images = 10000
    num_patches = 10000
    num_test_set_patches = 5000
    edge_crop = 32
    # channels = ['LED119', 'DPC_Right', 'Brightfield']
    channels = ['LED119', 'DPC_Right']
    num_bootstrap_samples = 10
    patch_sizes = np.array([2, 5, 8, 10, 12, 20, 32, 45, 50])

# Number of held-out images loaded after the training images (kept equal to the
# number of test patches, as before); test patches are drawn only from these.
num_test_images = num_test_set_patches


def compute_test_log_likelihood(train_images, test_patches, patch_size, num_train_patches):
    """Mean Gaussian log-likelihood of held-out patches under a model fit to `train_images`.

    Extracts `num_train_patches` patches of side `patch_size` from `train_images`, fits a
    stationary covariance (`compute_stationary_cov_mat`), floors its eigenvalues at 1e-4 to
    make it positive definite, and uses a constant mean vector (the mean over all training
    patch values). Each test patch is flattened and scored with a multivariate normal logpdf.

    Args:
        train_images: images to fit the model on (the bootstrap resamples these).
        test_patches: held-out patches; first axis indexes patches.
        patch_size: side length of the square patches (int).
        num_train_patches: number of training patches to extract.

    Returns:
        Scalar mean log-likelihood over the test patches (per patch, not per pixel).
    """
    train_patches = extract_patches(train_images, patch_size, num_patches=num_train_patches, verbose=False)
    cov_mat_stationary = compute_stationary_cov_mat(train_patches)
    cov_mat_stationary_pd = make_positive_definite(cov_mat_stationary, eigenvalue_floor=1e-4, show_plot=False)
    # Constant mean vector: every pixel position gets the mean over all training patch values
    mean_vector = np.ones(patch_size ** 2) * np.mean(train_patches)
    test_data = test_patches.reshape(test_patches.shape[0], -1)
    return np.mean(jax.scipy.stats.multivariate_normal.logpdf(test_data, mean=mean_vector, cov=cov_mat_stationary_pd))


log_likelihood_means_by_channel = {}
log_likelihood_confidence_intervals_by_channel = {}

for channel in channels:
    print(f'Channel: {channel}')

    all_images = load_bsccm_images(bsccm, channel=channel, num_images=num_images + num_test_images, edge_crop=edge_crop, median_filter=False)
    images = all_images[:num_images]
    test_set_images = all_images[num_images:]

    log_likelihood_means = []
    log_likelihood_confidence_intervals = []
    for patch_size in tqdm(patch_sizes):
        patch_size = int(patch_size)  # plain Python int for shape arguments
        # Extract the held-out patches once per patch size so every bootstrap sample is scored
        # on the same test set (only the training images are resampled); this also avoids
        # redoing the extraction inside each bootstrap iteration.
        test_set_patches = extract_patches(test_set_images, patch_size, num_patches=num_test_set_patches, verbose=False)
        # Bind the current loop values explicitly rather than relying on a late-binding closure
        statistic = functools.partial(
            compute_test_log_likelihood,
            test_patches=test_set_patches,
            patch_size=patch_size,
            num_train_patches=num_patches,
        )
        ll_mean, ll_interval = run_bootstrap(images, statistic, num_bootstrap_samples=num_bootstrap_samples, confidence_interval=90)
        log_likelihood_means.append(ll_mean)
        log_likelihood_confidence_intervals.append(ll_interval)
    log_likelihood_means_by_channel[channel] = np.array(log_likelihood_means)
    log_likelihood_confidence_intervals_by_channel[channel] = np.array(log_likelihood_confidence_intervals)

Channel: LED119


Running bootstraps:  30%|███       | 3/10 [05:01<11:43, 100.45s/it]
  0%|          | 0/9 [05:02<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Held-out log likelihood vs. patch size: one panel per channel, bootstrap mean with its interval band
fig, axes = plt.subplots(len(channels), 1, figsize=(4, 4 * len(channels)), sharex=True, squeeze=False)
axes = axes[:, 0]  # squeeze=False always returns a 2D array; keep the single column
for ax, channel in zip(axes, channels):
    # Use this channel's results (previously the leaked last-channel list was plotted on every panel)
    ll_means = log_likelihood_means_by_channel[channel]
    ll_intervals = log_likelihood_confidence_intervals_by_channel[channel]
    ax.plot(patch_sizes, ll_means, '.-')
    ax.fill_between(patch_sizes, ll_intervals[:, 0], ll_intervals[:, 1], alpha=0.5)
    ax.set(ylabel='Log likelihood', title=channel)
    clear_spines(ax)
axes[-1].set(xlabel='Patch size (pixels)')