# Paper figure: relationship between mutual information (MI) and deconvolution quality, 2024/10/23

In [None]:
# Reload edited project modules (e.g. lensless_helpers) automatically before executing code
%load_ext autoreload 
%autoreload 2

In [None]:
import os
# GPU selection must be in the environment before any GPU library (JAX here, plus whatever
# lensless_helpers imports) initializes its backend, so set it before those imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import sys

from jax import config
config.update("jax_enable_x64", True)
import numpy as np

# Make the project's src/ directory importable (machine-specific path; adjust for your checkout).
sys.path.append('/home/your_username/EncodingInformation/src')
from lensless_helpers import *
from cleanplots import *

In [None]:
import os

seed_value = 10

# set photon properties
bias = 10  # in photons
mean_photon_count_list = [20, 40, 80, 160, 320]
max_photon_count = mean_photon_count_list[-1]  # list is ascending, so this is the largest level

# set eligible psfs
psf_names = ['one', 'four', 'diffuser']

# MI estimator parameters
patch_size = 32
num_patches = 10000
val_set_size = 1000
test_set_size = 1500

# Root of the EncodingInformation checkout. Set the ENCODING_INFORMATION_ROOT environment variable
# to point elsewhere instead of editing paths; the default is the original placeholder location.
project_root = os.environ.get('ENCODING_INFORMATION_ROOT', '/home/your_username/EncodingInformation')
# The trailing '' keeps a trailing separator, because file names are appended to these with `+`.
mi_dir = os.path.join(project_root, 'lensless_imager', 'mi_estimates_smaller_lr', '')
recon_dir = os.path.join(project_root, 'lensless_imager', 'deconvolutions', '')

## Load MI data and make plots of it

The error bands in this plot are essentially invisible, and there are no outlier issues.

In [None]:
from cleanplots import *
# Display the first entry returned by cleanplots' get_color_cycle() (presumably the default line color)
get_color_cycle()[0]

In [None]:
# For each PSF and photon level, keep the PixelCNN run with the smallest MI estimate (and that
# run's confidence bounds), then plot MI against photon count with the bounds as a shaded band.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
mis_across_psfs = []
lowers_across_psfs = []
uppers_across_psfs = []
for psf_name in psf_names:
    mis = []
    lowers = []
    uppers = []
    for photon_count in mean_photon_count_list:
        mi_estimates = np.load(mi_dir + 'pixelcnn_mi_estimate_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))
        # Row 0: MI estimates, row 1: lower bounds, row 2: upper bounds, one entry per model run.
        mi_values = mi_estimates[0]
        lower_bounds = mi_estimates[1]
        upper_bounds = mi_estimates[2]
        # Spread of the estimate across model runs (quick check for outlier runs).
        print('{} PSF, {} photons: MI spread across runs = {}'.format(
            psf_name, photon_count, np.max(mi_values) - np.min(mi_values)))
        min_mi_index = np.argmin(mi_values)
        mis.append(mi_values[min_mi_index])
        lowers.append(lower_bounds[min_mi_index])
        uppers.append(upper_bounds[min_mi_index])
    line, = ax.plot(mean_photon_count_list, mis, label=psf_name)
    # Pass the line's color explicitly; without it the band's color is not guaranteed to match the line.
    ax.fill_between(mean_photon_count_list, lowers, uppers, color=line.get_color(), alpha=0.3)
    mis_across_psfs.append(mis)
    lowers_across_psfs.append(lowers)
    uppers_across_psfs.append(uppers)
ax.legend()
ax.set_title("PixelCNN MI estimates across Photon Count, CIFAR10")
ax.set_xlabel("Mean Photon Count")
ax.set_ylabel("Estimated Mutual Information")
# One row per PSF, one column per photon level.
mis_across_psfs = np.array(mis_across_psfs)
lowers_across_psfs = np.array(lowers_across_psfs)
uppers_across_psfs = np.array(uppers_across_psfs)

## Load recon data and make plots of it

In [None]:
# Bootstrap 95% confidence intervals of the unsupervised Wiener deconvolution metrics (MSE, PSNR,
# SSIM) for every PSF and photon level, then plot each metric against photon count.
recon_metric_names = ('mse', 'psnr', 'ssim')
recon_stats = {name: {'mean': [], 'lower': [], 'upper': []} for name in recon_metric_names}

for psf_name in psf_names:
    psf_stats = {name: {'mean': [], 'lower': [], 'upper': []} for name in recon_metric_names}
    for photon_count in mean_photon_count_list:
        metrics = np.load(recon_dir + 'unsupervised_wiener_recon_{}_photon_count_{}_psf.npy'.format(photon_count, psf_name))
        # Rows 0, 1 and 2 of the saved array hold the MSE, PSNR and SSIM values respectively.
        bootstrap_mse, bootstrap_psnr, bootstrap_ssim = compute_bootstraps(metrics[0], metrics[1], metrics[2], test_set_size)
        for name, bootstrap in zip(recon_metric_names, (bootstrap_mse, bootstrap_psnr, bootstrap_ssim)):
            ci_mean, ci_lower, ci_upper = compute_confidence_interval(bootstrap, confidence_interval=0.95)
            psf_stats[name]['mean'].append(ci_mean)
            psf_stats[name]['lower'].append(ci_lower)
            psf_stats[name]['upper'].append(ci_upper)
    for name, stats in psf_stats.items():
        for stat, values in stats.items():
            recon_stats[name][stat].append(values)

# One row per PSF, one column per photon level, under the names used by later cells.
mses_across_psfs = np.array(recon_stats['mse']['mean'])
mse_lowers_across_psfs = np.array(recon_stats['mse']['lower'])
mse_uppers_across_psfs = np.array(recon_stats['mse']['upper'])
psnrs_across_psfs = np.array(recon_stats['psnr']['mean'])
psnr_lowers_across_psfs = np.array(recon_stats['psnr']['lower'])
psnr_uppers_across_psfs = np.array(recon_stats['psnr']['upper'])
ssims_across_psfs = np.array(recon_stats['ssim']['mean'])
ssim_lowers_across_psfs = np.array(recon_stats['ssim']['lower'])
ssim_uppers_across_psfs = np.array(recon_stats['ssim']['upper'])

recon_panels = (
    ("MSE", mses_across_psfs, mse_lowers_across_psfs, mse_uppers_across_psfs),
    ("PSNR", psnrs_across_psfs, psnr_lowers_across_psfs, psnr_uppers_across_psfs),
    ("SSIM", ssims_across_psfs, ssim_lowers_across_psfs, ssim_uppers_across_psfs),
)
recon_fig, recon_axs = plt.subplots(1, 3, figsize=(20, 5))
for metric_ax, (title, means, lowers_ci, uppers_ci) in zip(recon_axs, recon_panels):
    for psf_idx, psf_name in enumerate(psf_names):
        line, = metric_ax.plot(mean_photon_count_list, means[psf_idx], label=psf_name)
        # Pass the line's color explicitly so each confidence band matches its line.
        metric_ax.fill_between(mean_photon_count_list, lowers_ci[psf_idx], uppers_ci[psf_idx],
                               color=line.get_color(), alpha=0.5)
    metric_ax.set_title(title)
    metric_ax.set_xlabel("Mean Photon Count")
    metric_ax.legend()

## Make figures, omitting error bars (they are smaller than the marker size) and reverting to circular markers

### Setup

In [None]:
def marker_for_psf(psf_name):
    """Return the matplotlib marker symbol used to draw a PSF's points.

    'one', 'four' and 'diffuser' all use circles ('o') in the current figures.

    Args:
        psf_name: PSF identifier, e.g. 'one', 'four' or 'diffuser'.

    Returns:
        str: matplotlib marker code.

    Raises:
        ValueError: if ``psf_name`` has no assigned marker.
    """
    markers = {
        'one': 'o',
        'four': 'o',      # previously 's'
        'diffuser': 'o',  # previously '*'
        'uc': 'x',
        'two': 'd',
        'three': 'v',
        'five': 'p',
        'aperture': 'P',
    }
    try:
        return markers[psf_name]
    except KeyError:
        raise ValueError(
            "No marker defined for PSF {!r}; known PSFs: {}".format(psf_name, sorted(markers))
        ) from None

In [None]:
# Truncated 'inferno' colormap used to color points by photon level.
from matplotlib.colors import LinearSegmentedColormap

base_colormap = plt.get_cmap('inferno')
# Keep only the [start, end] portion of the base colormap so the highest photon levels are not
# drawn too light against the white background.
start, end = 0, 0.88
colormap = LinearSegmentedColormap.from_list(
    'trunc({n},{a:.2f},{b:.2f})'.format(n=base_colormap.name, a=start, b=end),
    base_colormap(np.linspace(start, end, 256))
)

# Photon levels are mapped onto the colormap on a log scale spanning the configured range.
min_photons_per_pixel = min(mean_photon_count_list)
max_photons_per_pixel = max(mean_photon_count_list)

min_log_photons = np.log(min_photons_per_pixel)
max_log_photons = np.log(max_photons_per_pixel)

def color_for_photon_level(photons_per_pixel):
    """Map a photon level (photons per pixel) to an RGBA color from the module-level `colormap`.

    The level is placed on a log scale: `min_log_photons` maps to the start of `colormap` and
    `max_log_photons` to its end.

    Args:
        photons_per_pixel: positive photon count (scalar, or array-like for several levels).

    Returns:
        RGBA color(s) as returned by `colormap`.
    """
    log_photons = np.log(photons_per_pixel)
    log_range = max_log_photons - min_log_photons
    if log_range == 0:
        # Only one photon level configured: avoid 0/0 (a NaN color) and use the start of the colormap.
        return colormap(np.zeros_like(log_photons))
    return colormap((log_photons - min_log_photons) / log_range)

In [None]:
# Legacy selection of which PSFs / photon levels to plot (rarely changed now; the defaults keep everything).
metric_type = 1  # 0 for MSE, 1 for PSNR (legacy selector; not read by the plotting cells below)
valid_psfs = [0, 1, 2]
valid_photon_counts = [20, 40, 80, 160, 320]

# The per-PSF result arrays are indexed by position in the full psf_names list, and the plotting
# cells pair enumerate(psf_names) with them, so names and data must be subset together. Subsetting
# only the names would attach one PSF's results to another PSF's label and marker.
# This cell reassigns its inputs: re-run the loading cells before re-running it with a new selection.
psf_names = [psf_names[i] for i in valid_psfs]
mis_across_psfs = mis_across_psfs[valid_psfs]
lowers_across_psfs = lowers_across_psfs[valid_psfs]
uppers_across_psfs = uppers_across_psfs[valid_psfs]
mses_across_psfs = mses_across_psfs[valid_psfs]
mse_lowers_across_psfs = mse_lowers_across_psfs[valid_psfs]
mse_uppers_across_psfs = mse_uppers_across_psfs[valid_psfs]
psnrs_across_psfs = psnrs_across_psfs[valid_psfs]
psnr_lowers_across_psfs = psnr_lowers_across_psfs[valid_psfs]
psnr_uppers_across_psfs = psnr_uppers_across_psfs[valid_psfs]
ssims_across_psfs = ssims_across_psfs[valid_psfs]
ssim_lowers_across_psfs = ssim_lowers_across_psfs[valid_psfs]
ssim_uppers_across_psfs = ssim_uppers_across_psfs[valid_psfs]
print(psf_names)

In [None]:
# Asymmetric error-bar half-widths: distance from each point estimate down to its lower bound and
# up to its upper bound, in the [lower, upper] form accepted by matplotlib's errorbar(xerr=, yerr=).
mse_error_lower = np.abs(mses_across_psfs - mse_lowers_across_psfs)
mse_error_upper = np.abs(mse_uppers_across_psfs - mses_across_psfs)
psnr_error_lower = np.abs(psnrs_across_psfs - psnr_lowers_across_psfs)
psnr_error_upper = np.abs(psnr_uppers_across_psfs - psnrs_across_psfs)
ssim_error_lower = np.abs(ssims_across_psfs - ssim_lowers_across_psfs)
ssim_error_upper = np.abs(ssim_uppers_across_psfs - ssims_across_psfs)
mi_error_lower = np.abs(mis_across_psfs - lowers_across_psfs)
mi_error_upper = np.abs(uppers_across_psfs - mis_across_psfs)

In [None]:
# Reconstruction MSE vs. estimated MI: one point per (PSF, photon level), colored by photon level,
# with a dashed gray line joining the points of each PSF.
from matplotlib.colors import Normalize

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
photon_colors = [color_for_photon_level(photon_count) for photon_count in mean_photon_count_list]
for psf_idx, psf_name in enumerate(psf_names):
    mi_means_across_photons = np.asarray(mis_across_psfs[psf_idx])
    recon_means_across_photons = np.asarray(mses_across_psfs[psf_idx])
    ax.scatter(mi_means_across_photons, recon_means_across_photons, c=photon_colors,
               marker=marker_for_psf(psf_name), s=50, zorder=100)
    ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)
ax.set_xlabel("Mutual Information (bits per pixel)")
ax.set_ylabel("Mean Squared Error")
clear_spines(ax)
# No ax.legend(): none of the artists above carries a label, so it would only warn
# ("No artists with labels found to put in legend") and draw an empty legend.
ax.set_xlim([0, None])

# Colorbar on the same log-photon scale as color_for_photon_level, ticked at the plotted levels.
norm = Normalize(vmin=min_log_photons, vmax=max_log_photons)
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, ticks=np.log(mean_photon_count_list))
cbar.ax.set_yticklabels(mean_photon_count_list)
cbar.set_label('Photons per pixel')

#plt.savefig('mse_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)

In [None]:
# Reconstruction PSNR vs. estimated MI: one point per (PSF, photon level), colored by photon level,
# with a dashed gray line joining the points of each PSF.
from matplotlib.colors import Normalize

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
photon_colors = [color_for_photon_level(photon_count) for photon_count in mean_photon_count_list]
for psf_idx, psf_name in enumerate(psf_names):
    mi_means_across_photons = np.asarray(mis_across_psfs[psf_idx])
    recon_means_across_photons = np.asarray(psnrs_across_psfs[psf_idx])
    ax.scatter(mi_means_across_photons, recon_means_across_photons, c=photon_colors,
               marker=marker_for_psf(psf_name), s=50, zorder=100)
    ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)
ax.set_xlabel("Mutual Information (bits per pixel)")
ax.set_ylabel("Peak Signal-to-Noise Ratio (dB)")
clear_spines(ax)
# No ax.legend(): none of the artists above carries a label, so it would only warn
# ("No artists with labels found to put in legend") and draw an empty legend.
ax.set_xlim([0, None])

# Colorbar on the same log-photon scale as color_for_photon_level, ticked at the plotted levels.
norm = Normalize(vmin=min_log_photons, vmax=max_log_photons)
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, ticks=np.log(mean_photon_count_list))
cbar.ax.set_yticklabels(mean_photon_count_list)
cbar.set_label('Photons per pixel')

#plt.savefig('psnr_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)
    

In [None]:
# Reconstruction SSIM vs. estimated MI: one point per (PSF, photon level), colored by photon level,
# with a dashed gray line joining the points of each PSF.
from matplotlib.colors import Normalize

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
photon_colors = [color_for_photon_level(photon_count) for photon_count in mean_photon_count_list]
for psf_idx, psf_name in enumerate(psf_names):
    mi_means_across_photons = np.asarray(mis_across_psfs[psf_idx])
    recon_means_across_photons = np.asarray(ssims_across_psfs[psf_idx])
    ax.scatter(mi_means_across_photons, recon_means_across_photons, c=photon_colors,
               marker=marker_for_psf(psf_name), s=50, zorder=100)
    ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)
ax.set_xlabel("Mutual Information (bits per pixel)")
ax.set_ylabel("Structural Similarity Index Measure (SSIM)")
clear_spines(ax)
# No ax.legend(): none of the artists above carries a label, so it would only warn
# ("No artists with labels found to put in legend") and draw an empty legend.
ax.set_xlim([0, None])

# Colorbar on the same log-photon scale as color_for_photon_level, ticked at the plotted levels.
norm = Normalize(vmin=min_log_photons, vmax=max_log_photons)
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, ticks=np.log(mean_photon_count_list))
cbar.ax.set_yticklabels(mean_photon_count_list)
cbar.set_label('Photons per pixel')

#plt.savefig('ssim_vs_MI_with_confidence_intervals_log_photons.pdf', bbox_inches='tight', transparent=True)
    

Combine all three metric plots (MSE, SSIM, PSNR) into one figure

In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt
from cleanplots import *
from matplotlib.ticker import ScalarFormatter

figs, axs = plt.subplots(1, 3, figsize=(12, 4), sharex=True)

# (axis, per-PSF metric means, panel title); panel order for the paper figure is MSE, SSIM, PSNR.
metric_panels = [
    (axs[0], mses_across_psfs, "Mean Squared Error"),
    (axs[1], ssims_across_psfs, "Structural Similarity Index Measure (SSIM)"),
    (axs[2], psnrs_across_psfs, "Peak Signal-to-Noise Ratio (dB)"),
]
photon_colors = [color_for_photon_level(photon_count) for photon_count in mean_photon_count_list]

for panel_ax, metric_across_psfs, panel_title in metric_panels:
    # One point per (PSF, photon level), colored by photon level; dashed gray line joins each PSF's points.
    for psf_idx, psf_name in enumerate(psf_names):
        mi_means_across_photons = np.asarray(mis_across_psfs[psf_idx])
        recon_means_across_photons = np.asarray(metric_across_psfs[psf_idx])
        panel_ax.scatter(mi_means_across_photons, recon_means_across_photons, c=photon_colors,
                         marker=marker_for_psf(psf_name), s=50, zorder=100)
        panel_ax.plot(mi_means_across_photons, recon_means_across_photons, '--', color='gray', alpha=1, linewidth=2)
    panel_ax.set_title(panel_title)
    clear_spines(panel_ax)

# The x-axis is shared (sharex=True), so it is labeled once, under the middle panel.
axs[1].set_xlabel("Mutual Information (bits per pixel)")

#plt.savefig("metrics_vs_MI_with_confidence_intervals_log_photons.pdf", bbox_inches='tight', transparent=True)