# Plots of the Relationship Between Mutual Information and Classification Accuracy

In [None]:
# Reload imported modules automatically before each cell runs, so edits to helper modules take effect without restarting the kernel
%load_ext autoreload 
%autoreload 2

In [None]:
import os
import sys

# Select the GPU *before* importing anything that may initialize CUDA (JAX, or
# whatever leyla_fns pulls in). CUDA reads these variables when it initializes,
# so setting them after that point may have no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

from jax import config
config.update("jax_enable_x64", True)  # enable 64-bit (float64/int64) types in JAX
import numpy as np

# Make the project helper modules importable.
# NOTE(review): machine-specific absolute paths.
sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/')
sys.path.insert(0, '/home/lkabuli_waller/workspace/EncodingInformation/imager_experiments')
from leyla_fns import *
print(os.environ.get('PYTHONPATH'))
from cleanplots import *

## Load PSFs
The dataset is converted to units of photons; the mean photon counts used are [20, 40, 60, 80, 100, 150, 200, 250, 300].

In [None]:
# Higher-resolution diffuser PSF crop used for display. Steps:
#   1. Take one channel of the raw PSF image (index 1, presumably green).
#   2. Crop a 300x300 window and upsample it to 400x400.
#   3. Keep the top-left 128x128 corner and normalize it to unit sum.
diffuser_psf = skimage.io.imread('psfs/diffuser_psf.png')[:, :, 1]
diffuser_resize = resize(
    diffuser_psf[200:500, 250:550],
    (400, 400),
    anti_aliasing=True,
)
# `diffuser_region` is a view into `diffuser_resize`, so the in-place
# normalization below also rescales that corner of `diffuser_resize`.
diffuser_region = diffuser_resize[:4 * 32, :4 * 32]
diffuser_region /= diffuser_region.sum()

In [None]:
# Load the `_32` PSF variants used for the CIFAR10 experiments and show them
# side by side with axis ticks hidden. Note the diffuser panel shows the
# higher-resolution `diffuser_region` crop, not the `diffuser_psf` loaded here.
diffuser_psf = load_diffuser_32()
four_psf = load_four_lens_32()
one_psf = load_single_lens_32()

plt.figure(figsize=(10, 4))
psf_panels = [
    (one_psf, 'Single Lens'),
    (four_psf, 'Four Lenses'),
    (diffuser_region, 'Diffuser'),
]
for panel_idx, (panel_img, panel_title) in enumerate(psf_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(panel_img, cmap='inferno', interpolation='spline36')
    plt.title(panel_title)
    # hide the axis ticks
    plt.gca().set_xticks([])
    plt.gca().set_yticks([])
plt.suptitle("PSFs for CIFAR10 Dataset")

In [None]:
# The same three PSFs with default interpolation and visible axes. Here the
# diffuser panel uses `diffuser_psf` (as last assigned, by load_diffuser_32()
# in the previous cell).
plt.figure(figsize=(10, 4))
plain_panels = (
    ('Single Lens', one_psf),
    ('Four Lenses', four_psf),
    ('Diffuser', diffuser_psf),
)
for panel_pos, (panel_title, panel_img) in enumerate(plain_panels, start=1):
    plt.subplot(1, 3, panel_pos)
    plt.imshow(panel_img, cmap='inferno')
    plt.title(panel_title)
plt.suptitle("PSFs for CIFAR10 Dataset")

In [None]:

# Models and seeds
model_names = ['cnn']
seed_values = np.arange(1, 10)  # seeds 1..9 (9 runs for CIFAR10)

# Photon settings
bias = 10  # in photons
mean_photon_count_list = [20, 40, 60, 80, 100, 150, 200, 250, 300]
# max() instead of the last element, so the list does not need to be sorted.
max_photon_count = max(mean_photon_count_list)

# PSFs evaluated; every per-PSF array in this notebook is indexed in this order.
psf_names = ['one', 'four', 'diffuser']

# MI estimator parameters
patch_size = 32
num_patches = 10000
bs = 500  # presumably the estimator batch size

## Load MI data and make plots of it
Uses the updated MI data from 01/04/2024, which is essentially identical to the earlier MI data from 11/14/2023 but was run for 50 epochs and with more seeds.

The error bars in this plot are essentially invisible, and the earlier outlier issues no longer occur.

In [None]:
# Re-import the cleanplots plotting helpers (also star-imported in the setup cell)
from cleanplots import *
get_color_cycle()[0]  # displayed as the cell output: first color of the color cycle

In [None]:
mi_folder = ''  # prefix prepended to the MI-estimate .npy paths ('' = relative to the working directory)

### Minimum plot with no error bars

In [None]:
# Load every saved Gaussian and PixelCNN MI estimate for each PSF and photon
# count. Each setting is reduced to its *minimum* estimate, which suppresses
# outlier runs, and both estimators are plotted against photon count.
# NOTE(review): assumes each .npy file holds a 1-D array of repeated
# estimates, so the stacked results are (n_psfs, n_photon_counts).
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

gaussian_mi_estimates_across_psfs = []
pixelcnn_mi_estimates_across_psfs = []
for psf_name in psf_names:
    gaussian_across_photons = [
        np.load(mi_folder + f'cifar10_mi_estimates/gaussian_mi_estimate_{photon_count}_photon_count_{psf_name}_psf.npy')
        for photon_count in mean_photon_count_list
    ]
    pixelcnn_across_photons = [
        np.load(mi_folder + f'cifar10_mi_estimates/pixelcnn_mi_estimate_{photon_count}_photon_count_{psf_name}_psf.npy')
        for photon_count in mean_photon_count_list
    ]
    # minimum over the repeated estimates of each photon count
    gaussian_mins = np.min(gaussian_across_photons, axis=1)
    pixelcnn_mins = np.min(pixelcnn_across_photons, axis=1)

    ax.plot(mean_photon_count_list, gaussian_mins, '-', label=f'Gaussian {psf_name}')
    ax.plot(mean_photon_count_list, pixelcnn_mins, '-', label=f'PixelCNN {psf_name}')

    gaussian_mi_estimates_across_psfs.append(gaussian_mins)
    pixelcnn_mi_estimates_across_psfs.append(pixelcnn_mins)

plt.legend()
plt.title("Gaussian vs. PixelCNN MI Estimates Across Photon Count, CIFAR10, 4 Seeds, Minimums")
plt.ylabel('Estimated Mutual Information')
plt.xlabel('Mean Photon Count')

# Stack into arrays indexed [psf, photon count] for the plots below.
gaussian_mi_estimates_across_psfs = np.array(gaussian_mi_estimates_across_psfs)
pixelcnn_mi_estimates_across_psfs = np.array(pixelcnn_mi_estimates_across_psfs)

In [None]:
# Minimum MI estimates per PSF. Each PSF has one color: dashed = Gaussian
# estimator, solid = PixelCNN estimator.
plt.figure(figsize=(10, 6))
for i, modality in enumerate(psf_names):
    psf_color = get_color_cycle()[i]
    plt.plot(
        mean_photon_count_list,
        gaussian_mi_estimates_across_psfs[i],
        label=f'{modality} Gaussian',
        color=psf_color,
        linestyle='--',
    )
    plt.plot(
        mean_photon_count_list,
        pixelcnn_mi_estimates_across_psfs[i],
        label=f'{modality} PixelCNN',
        color=psf_color,
    )
plt.legend()
plt.xlabel('Mean Photon Count')
plt.ylabel("Estimated Mutual Information")
plt.title('Estimated Mutual Information vs. Mean Photon Count, CIFAR10')

In [None]:
# Show the color assigned to each PSF. A notebook cell displays only its last
# expression, so the three colors are collected into one dict to show them all.
{
    'one lens': get_color_cycle()[0],
    'four lenses': get_color_cycle()[1],
    'diffuser': get_color_cycle()[2],
}

In [None]:
# Paper-style figure: PixelCNN MI vs. photon count for the four-lens and
# diffuser PSFs only. The single lens (index 0) is skipped. Colors are shifted
# down by one, so the two plotted PSFs take the first two cycle colors.
psf_names_verbose = ['One Lens', 'Four Lens', 'Diffuser']
plt.figure(figsize=(6, 5))
ax = plt.axes()
for i, modality in enumerate(psf_names_verbose):
    if i == 0:
        continue  # single lens not shown in this figure
    plt.plot(
        mean_photon_count_list,
        pixelcnn_mi_estimates_across_psfs[i],
        label=f'{modality}',
        color=get_color_cycle()[i - 1],
    )
plt.legend()
plt.xlabel('Mean Photon Count')
plt.ylabel("Mutual Information (bits per pixel)")
clear_spines(ax)
# To export: plt.savefig('mi_vs_photon_count.pdf', bbox_inches='tight', transparent=True)

### Mean plot with error bars included

In [None]:
# Mean Gaussian and PixelCNN MI estimates vs. photon count for every PSF, with
# shaded bands from the project helper `confidence_bars`.
# PixelCNN outlier handling: if the largest estimate for a setting is more than
# twice the smallest, every estimate above 2x the minimum is replaced in place
# by the minimum. Outliers are therefore clipped, not dropped.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
for psf_name in psf_names:
    gaussian_across_photons = []
    pixelcnn_across_photons = []
    for photon_count in mean_photon_count_list:
        gaussian_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/gaussian_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))
        pixelcnn_mi_estimate = np.load(mi_folder + 'cifar10_mi_estimates/pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))
        pixelcnn_min = np.min(pixelcnn_mi_estimate)
        if np.max(pixelcnn_mi_estimate) / pixelcnn_min > 2:
            pixelcnn_mi_estimate[pixelcnn_mi_estimate > 2 * pixelcnn_min] = pixelcnn_min
        gaussian_across_photons.append(gaussian_mi_estimate)
        pixelcnn_across_photons.append(pixelcnn_mi_estimate)
    # NOTE(review): the meaning of confidence_bars' second argument (9) is not visible here.
    error_lo, error_hi, mean = confidence_bars(gaussian_across_photons, 9)
    error_lo_2, error_hi_2, mean_2 = confidence_bars(pixelcnn_across_photons, 9)
    (gaussian_line,) = ax.plot(mean_photon_count_list, mean, '-', label='Gaussian {}'.format(psf_name))
    (pixelcnn_line,) = ax.plot(mean_photon_count_list, mean_2, '-', label='PixelCNN {}'.format(psf_name))
    # Pin each band to its line's color. By default fill_between takes its color
    # from a separate property cycle, which only matches the lines while plot and
    # fill calls stay paired one-to-one.
    ax.fill_between(mean_photon_count_list, error_lo, error_hi, facecolor=gaussian_line.get_color(), alpha=0.4)
    ax.fill_between(mean_photon_count_list, error_lo_2, error_hi_2, facecolor=pixelcnn_line.get_color(), alpha=0.4)
plt.legend()
plt.title("Gaussian vs. PixelCNN MI Estimates Across Photon Count, CIFAR10, 4 Seeds, Means, Outliers Removed")
plt.ylabel('Estimated Mutual Information')
plt.xlabel('Mean Photon Count')

## Load classification data and make plots of it

In [None]:
classifier_folder = ''  # prefix prepended to the classifier-result .npy paths ('' = relative to the working directory)

In [None]:
# Load every CNN test-accuracy result into one array indexed
# [psf, photon count, trial]: 3 PSFs x 9 photon counts x trials per setting
# (presumably 10 trials; confirm against the saved files).
classifier_all_trials_across_psfs = np.array([
    np.array([
        np.load(classifier_folder + f'classifier_results/cifar_test_accuracy_{photon_count}_mean_photon_count_{psf_name}_psf_{bias}_bias_cnn_model.npy')
        for photon_count in mean_photon_count_list
    ])
    for psf_name in psf_names
])

In [None]:
# Mean CNN test accuracy vs. mean photon count per PSF, shaded with the
# `confidence_bars` band. The per-PSF means are stacked row by row (psf_names
# order) into `classifier_across_psfs` (3 PSFs x 9 photon counts).
classifier_across_psfs = []
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
for psf_name in psf_names:
    classifier_across_photons = [
        np.load(classifier_folder + f'classifier_results/cifar_test_accuracy_{photon_count}_mean_photon_count_{psf_name}_psf_{bias}_bias_cnn_model.npy')
        for photon_count in mean_photon_count_list
    ]
    error_lo, error_hi, mean = confidence_bars(classifier_across_photons, 9)
    ax.plot(mean_photon_count_list, mean, '-', label=f'{psf_name}')
    ax.fill_between(mean_photon_count_list, error_lo, error_hi, alpha=0.4)
    classifier_across_psfs.append(mean)
classifier_across_psfs = np.array(classifier_across_psfs)
plt.legend()
plt.title("CNN Classification Accuracy vs. Mean Photon Count, CIFAR10")
plt.ylabel('Classification Accuracy')
plt.xlabel('Mean Photon Count')

## Remake plots as Bar Charts

In [None]:
# Bar chart of the (minimum) PixelCNN MI estimate for each PSF at one fixed
# mean photon count.
photon_level = 300
photon_level_idx = np.flatnonzero(np.asarray(mean_photon_count_list) == photon_level)[0]
mi_photon_val = pixelcnn_mi_estimates_across_psfs[:, photon_level_idx]
bar_positions = np.arange(3)
plt.bar(bar_positions, mi_photon_val)
plt.xticks(bar_positions, ['One Lens', 'Four Lens', 'Diffuser'])
plt.ylabel('Estimated Mutual Information')
plt.xlabel("Encoding PSF")
plt.title(f"Estimated Mutual Information for Different Imaging Modalities, CIFAR10, {photon_level} Mean Photons")

In [None]:
# Same bar chart at the lowest mean photon count.
photon_level = 20
photon_level_idx = np.flatnonzero(np.asarray(mean_photon_count_list) == photon_level)[0]
mi_photon_val = pixelcnn_mi_estimates_across_psfs[:, photon_level_idx]
bar_positions = np.arange(3)
plt.bar(bar_positions, mi_photon_val)
plt.xticks(bar_positions, ['One Lens', 'Four Lens', 'Diffuser'])
plt.ylabel('Estimated Mutual Information')
plt.xlabel("Encoding PSF")
plt.title(f"Estimated Mutual Information for Different Imaging Modalities, CIFAR10, {photon_level} Mean Photons")

## Build the legend labels from pre-made handles (the proper way to label these plots)

In [None]:
import matplotlib.lines as mlines

In [None]:
# Scatter marker and display label per PSF, in psf_names order (one, four, diffuser).
marker_list = list('^sD')
psf_name_labels = ['One Lens', 'Four Lens', 'Diffuser']
marker_size = 15  # markersize used for the scatter points

In [None]:
mi_names = ['Gaussian', 'PixelCNN']  # estimator names used in plot titles, indexed like [gaussian, pixelcnn] estimate lists

## Make the same plots without the No PSF case

In [None]:
# Classification accuracy vs. estimated MI, one figure per MI estimator
# (Gaussian, then PixelCNN). Marker shape = PSF, marker color = mean photon
# count, and grey dashed lines connect each PSF's points. Only photon-count
# indices i with i % mod_idx_plot == mod_idx_shift are drawn (here 2, 5, 8).
for mi_idx, mi_estimate_list in enumerate([gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]):
    # Colormap with max_photon_count discrete entries. Called with an integer
    # photon count, it indexes that entry directly (counts >= max_photon_count
    # get the colormap's "over" color). This matches the 0..max_photon_count
    # colorbar below. plt.get_cmap replaces plt.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    inferno = plt.get_cmap('inferno', max_photon_count)

    mod_idx_plot = 3
    mod_idx_shift = 2

    # Proxy handles, so the legend lists one black marker per PSF.
    psf_label_list = [
        mlines.Line2D([], [], color='black', marker=marker, linestyle='None', markersize=8, label=psf_label)
        for marker, psf_label in zip(marker_list, psf_name_labels)
    ]
    fig, ax = plt.subplots(figsize=(9, 6), layout='constrained')
    first_legend = ax.legend(handles=psf_label_list, loc='center right')
    fig.add_artist(first_legend)

    # One marker per PSF at each selected photon count.
    for photon_level_idx, photon_level in enumerate(mean_photon_count_list):
        if photon_level_idx % mod_idx_plot != mod_idx_shift:
            continue
        mi_photon_val = mi_estimate_list[:, photon_level_idx]
        mean_list_fixed_photon_level = classifier_across_psfs[:, photon_level_idx]
        for i in range(len(psf_name_labels)):
            ax.plot(mi_photon_val[i], mean_list_fixed_photon_level[i], marker_list[i],
                    color=inferno(photon_level), markersize=marker_size)

    # Grey dashed line through the selected photon counts of each PSF.
    selected = slice(mod_idx_shift, None, mod_idx_plot)
    for psf_idx in range(len(psf_names)):
        ax.plot(mi_estimate_list[psf_idx][selected], classifier_across_psfs[psf_idx][selected],
                linestyle='--', color='gray')

    ax.set_xlabel('Estimated Mutual Information')
    ax.set_ylabel('CNN Classification Accuracy')
    ax.set_title("CIFAR10 Classification Accuracy vs. Estimated Mutual Information, {} MI Estimator".format(mi_names[mi_idx]))

    # Colorbar mapping photon counts 0..max_photon_count onto the same colormap.
    norm = mpl.colors.Normalize(vmin=0, vmax=max_photon_count)
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=inferno)
    mappable.set_array([])
    fig.colorbar(mappable, ticks=mean_photon_count_list, aspect=9.5, orientation='vertical', ax=ax, label='Mean Photon Count')

    plt.show()

In [None]:
# Classification accuracy vs. estimated MI, one figure per MI estimator
# (Gaussian, then PixelCNN). Marker shape = PSF, marker color = mean photon
# count, and grey dashed lines connect each PSF's points. With mod_idx_plot = 1
# and mod_idx_shift = 0, every photon count is drawn.
for mi_idx, mi_estimate_list in enumerate([gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]):
    # Colormap with max_photon_count discrete entries. Called with an integer
    # photon count, it indexes that entry directly (counts >= max_photon_count
    # get the colormap's "over" color). This matches the 0..max_photon_count
    # colorbar below. plt.get_cmap replaces plt.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    inferno = plt.get_cmap('inferno', max_photon_count)

    mod_idx_plot = 1
    mod_idx_shift = 0

    # Proxy handles, so the legend lists one black marker per PSF.
    psf_label_list = [
        mlines.Line2D([], [], color='black', marker=marker, linestyle='None', markersize=8, label=psf_label)
        for marker, psf_label in zip(marker_list, psf_name_labels)
    ]
    fig, ax = plt.subplots(figsize=(9, 6), layout='constrained')
    first_legend = ax.legend(handles=psf_label_list, loc='center right')
    fig.add_artist(first_legend)

    # One marker per PSF at each selected photon count.
    for photon_level_idx, photon_level in enumerate(mean_photon_count_list):
        if photon_level_idx % mod_idx_plot != mod_idx_shift:
            continue
        mi_photon_val = mi_estimate_list[:, photon_level_idx]
        mean_list_fixed_photon_level = classifier_across_psfs[:, photon_level_idx]
        for i in range(len(psf_name_labels)):
            ax.plot(mi_photon_val[i], mean_list_fixed_photon_level[i], marker_list[i],
                    color=inferno(photon_level), markersize=marker_size)

    # Grey dashed line through the selected photon counts of each PSF.
    selected = slice(mod_idx_shift, None, mod_idx_plot)
    for psf_idx in range(len(psf_names)):
        ax.plot(mi_estimate_list[psf_idx][selected], classifier_across_psfs[psf_idx][selected],
                linestyle='--', color='gray')

    ax.set_xlabel('Estimated Mutual Information')
    ax.set_ylabel('CNN Classification Accuracy')
    ax.set_title("CIFAR10 Classification Accuracy vs. Estimated Mutual Information, {} MI Estimator".format(mi_names[mi_idx]))

    # Colorbar mapping photon counts 0..max_photon_count onto the same colormap.
    norm = mpl.colors.Normalize(vmin=0, vmax=max_photon_count)
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=inferno)
    mappable.set_array([])
    fig.colorbar(mappable, ticks=mean_photon_count_list, aspect=9.5, orientation='vertical', ax=ax, label='Mean Photon Count')

    plt.show()

## Plots in Henry's style, including classifier error bars

### Setup

In [None]:
def marker_for_psf(psf_name):
    """Return the Matplotlib marker style used to draw a given PSF.

    Args:
        psf_name: PSF identifier: 'one', 'four', 'diffuser', 'uc' or 'two'.

    Returns:
        The marker string ('o', 's', '*', 'x' or 'd' respectively).

    Raises:
        ValueError: If ``psf_name`` is not one of the known PSF identifiers.
    """
    markers = {
        'one': 'o',
        'four': 's',
        'diffuser': '*',
        'uc': 'x',
        'two': 'd',
    }
    try:
        return markers[psf_name]
    except KeyError:
        raise ValueError(
            f'Unknown PSF name {psf_name!r}; expected one of {sorted(markers)}'
        ) from None

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Truncated inferno colormap: drop its lightest end so that high photon counts
# are not too light against a white background.
base_colormap = plt.get_cmap('inferno')  # plt.cm.get_cmap was removed in Matplotlib 3.9
start, end = 0, 0.88  # fraction range of the base colormap that is kept
colormap = LinearSegmentedColormap.from_list(
    'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),
    base_colormap(np.linspace(start, end, 256)),
)

# Photon counts are mapped onto the colormap on a log scale, from the smallest
# to the largest mean photon count.
min_photons_per_pixel = min(mean_photon_count_list)
max_photons_per_pixel = max(mean_photon_count_list)

min_log_photons = np.log(min_photons_per_pixel)
max_log_photons = np.log(max_photons_per_pixel)


def color_for_photon_level(photons_per_pixel):
    """Return the RGBA color of `colormap` for a mean photon count.

    The count is placed on a log scale, so the minimum photon count maps to the
    start of `colormap` and the maximum to its end.
    """
    log_photons = np.log(photons_per_pixel)
    return colormap((log_photons - min_log_photons) / (max_log_photons - min_log_photons))

### Set the parameters in the cell below to choose what to display, then run the following cell to make the figure

In [None]:
estimator_type = 1  # index into [Gaussian, PixelCNN] estimates: 0 -> Gaussian, 1 -> PixelCNN
# Indices into psf_names ['one', 'four', 'diffuser']: 0 -> one lens, 1 -> four lens, 2 -> diffuser
valid_psfs = list(range(3))
# Photon counts to draw (here: all of them)
valid_photon_counts = [20, 40, 60, 80, 100, 150, 200, 250, 300]

In [None]:
# Mean CNN classification accuracy vs. (minimum-valued) MI estimate.
#   - Marker shape = PSF (marker_for_psf).
#   - Marker color = mean photon count, log scale (color_for_photon_level).
#   - Per PSF: a grey dashed trend line and a shaded band covering the central
#     `confidence_level` percentile range of the classifier trials.
confidence_level = 0.9
mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]
# classifier_all_trials_across_psfs is indexed [psf, photon count, trial].
psf_display_names = {'uc': 'No PSF', 'one': 'One Lens', 'four': 'Four Lens', 'diffuser': 'Diffuser'}
upper_percentile = 100 * (1 + confidence_level) / 2
lower_percentile = 100 - upper_percentile

fig, ax = plt.subplots(1, 1, figsize=(7, 5))

mi_list_use = mi_estimate_lists[estimator_type]  # 0 -> Gaussian, 1 -> PixelCNN

for psf_idx, psf_name in enumerate(psf_names):
    if psf_idx not in valid_psfs:
        continue
    mi_values_across_photons = []  # MI value per plotted photon count (trend line x)
    classifier_means_across_photons = []  # mean accuracy per plotted photon count (trend line y)
    classifier_lower_across_photons = []  # lower percentile bound
    classifier_upper_across_photons = []  # upper percentile bound

    for photon_idx, photon_count in enumerate(mean_photon_count_list):
        if photon_count not in valid_photon_counts:
            continue
        mi_value = mi_list_use[psf_idx][photon_idx]
        classifier_trials = classifier_all_trials_across_psfs[psf_idx][photon_idx]
        classifier_mean = np.mean(classifier_trials)
        ax.scatter(mi_value, classifier_mean, color=color_for_photon_level(photon_count),
                   marker=marker_for_psf(psf_name), s=50, zorder=100)
        mi_values_across_photons.append(mi_value)
        classifier_means_across_photons.append(classifier_mean)
        classifier_lower_across_photons.append(np.percentile(classifier_trials, lower_percentile))
        classifier_upper_across_photons.append(np.percentile(classifier_trials, upper_percentile))

    mi_values_across_photons = np.array(mi_values_across_photons)
    classifier_means_across_photons = np.array(classifier_means_across_photons)
    ax.plot(mi_values_across_photons, classifier_means_across_photons, '--', color='grey', alpha=1, linewidth=2)
    ax.fill_between(mi_values_across_photons, classifier_lower_across_photons, classifier_upper_across_photons,
                    color='grey', alpha=0.3, linewidth=0, zorder=-100)

ax.set_xlabel('Mutual Information (bits per pixel)')
ax.set_ylabel('Classification Accuracy')
clear_spines(ax)

# Legend: one black proxy marker per displayed PSF. It uses the same markers as
# the scatter points, so the legend and the data cannot disagree.
for psf_idx, psf_name in enumerate(psf_names):
    if psf_idx in valid_psfs:
        ax.scatter([], [], color='k', marker=marker_for_psf(psf_name),
                   label=psf_display_names.get(psf_name, psf_name))

ax.legend(loc='lower right', frameon=True)
ax.set_xlim([0, None])

# Log-scaled colorbar whose ticks are labelled with the plotted photon counts.
norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, ticks=np.log(valid_photon_counts))
cbar.ax.set_yticklabels(valid_photon_counts)
cbar.set_label('Mean Photon Count')

## Modified version with fewer photon counts

In [None]:
estimator_type = 1  # index into [Gaussian, PixelCNN] estimates: 0 -> Gaussian, 1 -> PixelCNN
# Indices into psf_names ['one', 'four', 'diffuser']: 0 -> one lens, 1 -> four lens, 2 -> diffuser
valid_psfs = list(range(3))
# Reduced set of photon counts (each roughly double the previous one)
valid_photon_counts = [20, 40, 80, 150, 300]

In [None]:
# Mean CNN classification accuracy vs. (minimum-valued) MI estimate.
#   - Marker shape = PSF (marker_for_psf).
#   - Marker color = mean photon count, log scale (color_for_photon_level).
#   - Per PSF: a grey dashed trend line and a shaded band covering the central
#     `confidence_level` percentile range of the classifier trials.
confidence_level = 0.9
mi_estimate_lists = [gaussian_mi_estimates_across_psfs, pixelcnn_mi_estimates_across_psfs]
# classifier_all_trials_across_psfs is indexed [psf, photon count, trial].
psf_display_names = {'uc': 'No PSF', 'one': 'One Lens', 'four': 'Four Lens', 'diffuser': 'Diffuser'}
upper_percentile = 100 * (1 + confidence_level) / 2
lower_percentile = 100 - upper_percentile

fig, ax = plt.subplots(1, 1, figsize=(7, 5))

mi_list_use = mi_estimate_lists[estimator_type]  # 0 -> Gaussian, 1 -> PixelCNN

for psf_idx, psf_name in enumerate(psf_names):
    if psf_idx not in valid_psfs:
        continue
    mi_values_across_photons = []  # MI value per plotted photon count (trend line x)
    classifier_means_across_photons = []  # mean accuracy per plotted photon count (trend line y)
    classifier_lower_across_photons = []  # lower percentile bound
    classifier_upper_across_photons = []  # upper percentile bound

    for photon_idx, photon_count in enumerate(mean_photon_count_list):
        if photon_count not in valid_photon_counts:
            continue
        mi_value = mi_list_use[psf_idx][photon_idx]
        classifier_trials = classifier_all_trials_across_psfs[psf_idx][photon_idx]
        classifier_mean = np.mean(classifier_trials)
        ax.scatter(mi_value, classifier_mean, color=color_for_photon_level(photon_count),
                   marker=marker_for_psf(psf_name), s=50, zorder=100)
        mi_values_across_photons.append(mi_value)
        classifier_means_across_photons.append(classifier_mean)
        classifier_lower_across_photons.append(np.percentile(classifier_trials, lower_percentile))
        classifier_upper_across_photons.append(np.percentile(classifier_trials, upper_percentile))

    mi_values_across_photons = np.array(mi_values_across_photons)
    classifier_means_across_photons = np.array(classifier_means_across_photons)
    ax.plot(mi_values_across_photons, classifier_means_across_photons, '--', color='grey', alpha=1, linewidth=2)
    ax.fill_between(mi_values_across_photons, classifier_lower_across_photons, classifier_upper_across_photons,
                    color='grey', alpha=0.3, linewidth=0, zorder=-100)

ax.set_xlabel('Mutual Information (bits per pixel)')
ax.set_ylabel('Classification Accuracy')
clear_spines(ax)

# Legend: one black proxy marker per displayed PSF. It uses the same markers as
# the scatter points, so the legend and the data cannot disagree.
for psf_idx, psf_name in enumerate(psf_names):
    if psf_idx in valid_psfs:
        ax.scatter([], [], color='k', marker=marker_for_psf(psf_name),
                   label=psf_display_names.get(psf_name, psf_name))

ax.legend(loc='lower right', frameon=True)
ax.set_xlim([0, None])

# Log-scaled colorbar whose ticks are labelled with the plotted photon counts.
norm = mpl.colors.Normalize(vmin=min_log_photons, vmax=max_log_photons)
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, ticks=np.log(valid_photon_counts))
cbar.ax.set_yticklabels(valid_photon_counts)
cbar.set_label('Mean Photon Count')
# To export: plt.savefig('mi_vs_classification.pdf', bbox_inches='tight', transparent=True)