In [8]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Enable 64-bit precision in JAX.
# this only works on startup!
# NOTE: the name `config` imported here is rebound to a YAML dict by the
# config-loading loop further down in this notebook.
from jax import config
config.update("jax_enable_x64", True)

# Restrict this process to GPU 1, with devices numbered by PCI bus id.
# These variables are set before the GPU-using libraries below are imported.
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# presumably stops the framework from reserving all GPU memory up front —
# TODO confirm against encoding_information.gpu_utils
from encoding_information.gpu_utils import limit_gpu_memory_growth
limit_gpu_memory_growth()

# NOTE(review): the star imports below make it hard to tell where names such as
# load_bsccm_images, extract_patches, add_noise and estimate_mutual_information
# come from; they are assumed to be provided by the encoding_information modules.
from cleanplots import *
from tqdm import tqdm
from encoding_information.information_estimation import *
from encoding_information.image_utils import *
from encoding_information.models.gaussian_process import StationaryGaussianProcess

from encoding_information.bsccm_utils import *
from bsccm import BSCCM
from jax import jit
import numpy as np
import yaml
from led_array.tf_util import prepare_test_dataset
import tensorflow.keras as tfk

# Open the BSCCM dataset; the resulting handle is used as a module-level
# global by the functions defined in later cells.
bsccm = BSCCM('/home/hpinkard_waller/data/BSCCM/')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Opening BSCCM
Opened BSCCM


In [58]:
def get_marker_index(target_row):
    """Return the position of the first non-NaN entry in ``target_row``.

    Raises:
        IndexError: if every entry of ``target_row`` is NaN, since there is
            then no first position to return.
    """
    is_present = np.logical_not(np.isnan(target_row))
    present_positions = np.nonzero(np.ravel(is_present))[0]
    return present_positions[0]

#compute negative log_likelihood over test set
def compute_nlls(model, test_dataset, max_num, markers=None):
    """Compute the per-sample negative log likelihood of a model on a test set.

    Args:
        model: callable applied to a single-image batch (``image[None]``); its
            result is indexed by marker name and the selected entry must expose
            ``log_prob(value)`` whose result has a ``.numpy()`` method.
        test_dataset: iterable of ``(image, target)`` pairs. For each target,
            the first non-NaN entry is the one that gets scored.
        max_num: maximum number of samples to evaluate, or ``None`` to consume
            the whole dataset.
        markers: sequence mapping a marker index to the key used to select the
            model output. When omitted, the module-level name ``markers`` is
            used, which is what this function always relied on previously.

    Returns:
        Tuple ``(nlls, marker_indices)`` of numpy arrays, one entry per
        evaluated sample.

    Raises:
        NameError: if ``markers`` is not passed and no module-level ``markers``
            exists.
    """
    if markers is None:
        # Backward-compatible fallback for callers that do not pass markers.
        try:
            markers = globals()['markers']
        except KeyError:
            raise NameError(
                "name 'markers' is not defined; pass markers=... to compute_nlls"
            ) from None

    negative_log_likelihoods = []
    marker_indices = []
    for i, (image, target) in tqdm(enumerate(test_dataset), total=max_num):
        # enumerate is zero-based, so stop once i reaches max_num; the previous
        # `i > max_num` check evaluated one sample too many.
        if max_num is not None and i >= max_num:
            break
        marker_index = get_marker_index(target)
        marker_indices.append(marker_index)
        marker = markers[marker_index]
        # image[None] adds a leading batch axis of size one
        mixture = model(image[None])[marker]
        nll = -mixture.log_prob(target[marker_index]).numpy()
        negative_log_likelihoods.append(nll)
    return np.array(negative_log_likelihoods), np.array(marker_indices)


def estimate_mi(model_name, config, indices, test_dataset_size, patch_size):
    """Estimate mutual information for one experiment config, with on-disk caching.

    Images are loaded for the first ``test_dataset_size`` entries of
    ``indices``, cut into patches, and given synthetic noise at the photon
    level requested in the config. Mutual information is then estimated twice,
    once with the 'pixel_cnn' entropy model and once with the 'gaussian' one.

    Args:
        model_name: experiment name; used to build the cache file name.
        config: parsed experiment config. Reads ``config['data']['channels'][0]``
            and the ``median_filter``, ``edge_crop`` and ``photons_per_pixel``
            entries of ``config['data']['synthetic_noise']``.
        indices: dataset indices; only the first ``test_dataset_size`` are used.
        test_dataset_size: number of images to load.
        patch_size: patch size passed to ``extract_patches``; also part of the
            cache file name.

    Returns:
        The ``np.load`` result for the cache file, with keys ``mi_pixel_cnn``
        and ``mi_gp``.

    Raises:
        ValueError: if the requested photons per pixel exceeds the mean photon
            count of the loaded images.
    """
    saving_name = f'{model_name}_{patch_size}patch_mi_estimates.npz'
    cache_path = f'.cached/{saving_name}'
    # Earlier versions saved to '<name>.npz.npz' while looking up '<name>.npz',
    # so the cache was never hit. Results already written under the old name
    # are still picked up here.
    legacy_cache_path = f'{cache_path}.npz'

    # check if already cached
    for existing_path in (cache_path, legacy_cache_path):
        if os.path.exists(existing_path):
            print(f'Loading cached results for {model_name}')
            return np.load(existing_path)

    noise_config = config['data']['synthetic_noise']
    median_filter = noise_config['median_filter']

    images = load_bsccm_images(bsccm, indices=indices[:test_dataset_size], channel=config['data']['channels'][0],
                convert_units_to_photons=True, edge_crop=noise_config['edge_crop'],
                median_filter=median_filter)

    mean_photons_per_pixel = np.mean(images)
    rescale_fraction = noise_config['photons_per_pixel'] / mean_photons_per_pixel
    if rescale_fraction > 1:
        raise ValueError(
            f'Rescale fraction must be at most 1, got {rescale_fraction}: the requested '
            f'photons_per_pixel exceeds the mean of {mean_photons_per_pixel} in the data'
        )

    patches = extract_patches(images, patch_size=patch_size)

    if median_filter:
        # assume the median-filtered patches are noiseless, so noise is added
        # to the rescaled clean signal
        noisy_patches = add_noise(patches * rescale_fraction)
    else:
        noisy_patches = add_shot_noise_to_experimenal_data(patches, rescale_fraction)

    # Clean images are only supplied when the patches are treated as noiseless.
    clean_images = patches if median_filter else None
    mi_pixel_cnn = estimate_mutual_information(noisy_patches, clean_images=clean_images,
                    entropy_model='pixel_cnn', verbose=True)
    mi_gp = estimate_mutual_information(noisy_patches, clean_images=clean_images,
                     entropy_model='gaussian', verbose=True)

    # save both MI estimates in a single file, at the path the cache check reads
    os.makedirs('.cached', exist_ok=True)
    np.savez(cache_path, mi_pixel_cnn=mi_pixel_cnn, mi_gp=mi_gp)
    return np.load(cache_path)
    

def test_set_phenotyping_nll(model_name, config):
    saving_name = f'{model_name}_phenotyping_nll.npz'

    # check if already cached
    if os.path.exists(f'.cached/{saving_name}'):
        print(f'Loading cached results for {model_name}')
        return np.load(f'.cached/{saving_name}')
    
    markers, image_target_generator, dataset_size, display_range, indices = get_bsccm_image_marker_generator(bsccm, **config['data'])
    test_dataset, test_dataset_size = prepare_test_dataset(config['hyperparameters']['test_fraction'], image_target_generator, dataset_size)
    
    model = tfk.models.load_model(config['saving_dir'] + model_name + os.sep + 'model/saved_model.h5', compile=False)

    nlls, marker_indices = compute_nlls(model, test_dataset, max_num=test_dataset_size)

    # save the cached results (both nlls and marker indices in a single file)
    np.savez(f'.cached/{saving_name}', nlls=nlls, marker_indices=marker_indices)
    return np.load(f'.cached/{saving_name}')
    

## Compute and cache protein prediction performance

In [60]:
patch_size = 25

config_dir = '/home/hpinkard_waller/GitRepos/EncodingInformation/led_array/phenotyping_experiments/config_files/complete/'
# config_name = 'Synthetic_Noise_Brightfield_300_photons_replicate_1.yaml'
# TODO: when they're all in subfolders, this will need to be changed
config_prefix = 'Synthetic_Noise'

# make a cached_results directory if it doesn't exist
os.makedirs('.cached', exist_ok=True)

# Results are keyed by model name. Merging every model's arrays into one flat
# dict made each iteration overwrite the previous one, because all models share
# the same keys (mi_pixel_cnn, mi_gp, nlls, marker_indices).
results = {}
# sorted so the processing order does not depend on the filesystem
files = sorted(os.listdir(config_dir))
for i, file in enumerate(files):
    print(f'file {i} of {len(files)}')
    if not file.startswith(config_prefix):
        continue # TODO remove this once subfolders

    model_name = file.split('.')[0]

    # NOTE: this rebinds the module-level name `config`, which the first cell
    # imported from jax.
    with open(config_dir + file, 'r') as f:
        config = yaml.safe_load(f)

    # NOTE(review): `indices` and `test_dataset_size` are not defined in any
    # cell visible here — they must already exist in the kernel; confirm where
    # they come from so the notebook survives Restart & Run All.
    mi = estimate_mi(model_name, config, indices, test_dataset_size, patch_size)
    phenotyping_nll = test_set_phenotyping_nll(model_name, config)
    results[model_name] = {**mi, **phenotyping_nll}


file 0 of 23


TypeError: PRNG key seed must be an integer; got Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=1/0)>

## Compute and cache MI estimates