In [None]:
import os
import sys
import yaml
import torch
import logging
import argparse
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pytorch_lightning import seed_everything
from torchmetrics.functional import mean_squared_error as mse
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.functional import structural_similarity_index_measure as ssim
from mpl_toolkits.axes_grid1 import make_axes_locatable
sys.path.append('../')
from utils import parameter_manager, model_loader
from core import datamodule, lrn, modulator, propagator

plt.style.use('seaborn-v0_8')

# Seed Python/NumPy/torch RNGs up front so shuffled dataloaders and any randomly
# initialised layers give the same results on every Restart & Run All.
# `seed_everything` is imported in the cell above for exactly this purpose.
SEED = 0
seed_everything(SEED)

# Not used in the cells shown here; presumably filled by a later cell — TODO confirm or remove.
average_errors = []
# Uncomment for verbose debug logging:
#logging.basicConfig(level=logging.DEBUG)

In [None]:
# Load parameters
params = yaml.load(open('../config.yaml'), Loader = yaml.FullLoader)
params['batch_size'] = 1
params['distance'] = torch.tensor(0.04)

pm = parameter_manager.Parameter_Manager(params = params)

# Load in the test dataset
pm.data_split = "mnist_1000perClass"
datamod = datamodule.select_data(pm.params_datamodule)
datamod.setup()
dataloader_train_1000perClass = datamod.train_dataloader()
dataloader_test = datamod.test_dataloader()

datasets = ['mnist_single0', 'mnist_single1', 'mnist_10_1', 'mnist_10_8', 'mnist_100_1', 'mnist_100_8', 'mnist_1perClass', 'mnist_10perClass', 'mnist_100perClass', 'mnist_1000perClass']
 
data_loaders = {}
for data in datasets:
    pm.data_split = data
    datamod = datamodule.select_data(pm.params_datamodule)
    datamod.setup()
    loader = datamod.train_dataloader()
    data_loaders[f'{data}'] = loader

In [None]:
def run_measures(outputs):
    """Score a model's normalized reconstructions against their targets.

    Args:
        outputs: the sequence returned by ``model.shared_step``. Only index 4
            (normalized images) and index 5 (target) are used; indices 0-3
            (wavefronts, amplitudes, normalized amplitudes, images) are ignored.

    Returns:
        dict with CPU tensors under ``'mse'``, ``'psnr'`` and ``'ssim'``.
    """
    # Metrics are for reporting only; detach once so none of them touch autograd.
    preds = outputs[4].detach()
    target = outputs[5].detach()

    mse_vals = mse(preds = preds, target = target)
    # NOTE(review): psnr is called without `data_range`, so torchmetrics infers
    # it from the data — confirm that is the intended normalisation.
    psnr_vals = psnr(preds = preds, target = target)
    ssim_vals = ssim(preds = preds, target = target)

    return {'mse' : mse_vals.cpu(), 'psnr' : psnr_vals.cpu(), 'ssim' : ssim_vals.cpu()}


def _evaluate_batch(model, batch, batch_idx, device='cuda'):
    """Move one ``(sample, target)`` batch to ``device``, run it, and score it.

    Returns:
        ``(record, sample, outputs)``: ``record`` is the ``run_measures`` dict
        plus the batch target (detached, on CPU) under ``'target'``; ``sample``
        is the on-device input; ``outputs`` is what ``model.shared_step`` returned.
    """
    sample, target = batch
    sample = sample.to(device)
    target = target.to(device)
    # Evaluation only: run_measures detaches everything, so skip building the
    # autograd graph (saves memory and time per step).
    with torch.no_grad():
        outputs = model.shared_step((sample, target), batch_idx)
    record = run_measures(outputs)
    record['target'] = target.detach().cpu()
    return record, sample, outputs


def eval_model(model, dataloader, device='cuda'):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Returns:
        ``(measures, sample, wavefront)``: ``measures[0]`` is a snapshot of the
        global ``params`` config and ``measures[1:]`` hold one record per batch
        (see ``_evaluate_batch``); ``sample`` and ``wavefront`` are element 0 of
        the last batch's input and of its ``outputs[0]``.

    Raises:
        ValueError: if ``dataloader`` yields no batches (previously this
            surfaced as an UnboundLocalError at the return statement).
    """
    # Copy so later edits to the global `params` don't rewrite stored results.
    measures = [dict(params)]
    sample = outputs = None

    progress = tqdm(dataloader, total=len(dataloader), desc='Evaluating Model', leave=False)
    for i, batch in enumerate(progress):
        record, sample, outputs = _evaluate_batch(model, batch, i, device)
        measures.append(record)

    if outputs is None:
        raise ValueError('eval_model: dataloader yielded no batches')
    return (measures, sample[0], outputs[0])


def eval_model_single(model, dataloader, device='cuda'):
    """Evaluate ``model`` on only the first batch of ``dataloader``.

    Same return layout as ``eval_model_fromBatch``. An empty dataloader
    raises StopIteration from ``next``.
    """
    return eval_model_fromBatch(model, next(iter(dataloader)), device)


def eval_model_fromBatch(model, batch, device='cuda'):
    """Evaluate ``model`` on one already-loaded ``(sample, target)`` batch.

    Returns:
        ``([params_snapshot, record], sample, wavefront)`` — the same layout as
        ``eval_model`` but with exactly one record.
    """
    record, sample, outputs = _evaluate_batch(model, batch, 0, device)
    return ([dict(params), record], sample[0], outputs[0])


def _split_metric_by_class(measures, key, num_classes=10):
    """Group ``entry[key]`` values by the class stored in ``entry['target']``.

    Entries without a ``'target'`` key (such as the config dict that
    ``eval_model`` stores at ``measures[0]``) are skipped instead of raising
    KeyError. Each target is compared with ``== label``, so it must be a plain
    number or a single-element tensor.

    Returns:
        ``{label: [values in input order]}`` for labels ``0 .. num_classes-1``.
    """
    by_class = {label: [] for label in range(num_classes)}
    for entry in measures:
        if 'target' not in entry:
            continue
        for label, values in by_class.items():
            if entry['target'] == label:
                values.append(entry[key])
    return by_class


def split_mse_by_class(measures, num_classes=10):
    """Return ``{class_label: [mse, ...]}`` for labels ``0 .. num_classes-1``."""
    return _split_metric_by_class(measures, 'mse', num_classes)


def split_psnr_by_class(measures, num_classes=10):
    """Return ``{class_label: [psnr, ...]}`` for labels ``0 .. num_classes-1``."""
    return _split_metric_by_class(measures, 'psnr', num_classes)

In [None]:
# Build two LRNs with phase_initialization = 1 (presumably the analytical lens
# phase, given the variable names — TODO confirm) that differ only in the
# propagation distance. Construction order matches the original.
pm.phase_initialization = 1

# Short distance (0.04): titled "Large NA" in the figures below.
pm.distance = torch.tensor(0.04)
constrained_analytical_lrn = lrn.LRN(pm.params_model_lrn, pm.params_propagator, pm.params_modulator)
constrained_analytical_lrn.eval()

# Long distance (0.60264): titled "Small NA" in the figures below.
pm.distance = torch.tensor(0.60264)
analytical_lrn = lrn.LRN(pm.params_model_lrn, pm.params_propagator, pm.params_modulator)
analytical_lrn.eval()

# Switch back to phase_initialization = 0 for anything constructed later.
pm.phase_initialization = 0

# Layer-1 phase of each model (detached, singleton dims dropped), then the same
# phase wrapped into [0, 2*pi) via torch's sign-of-divisor remainder.
analytical_phase, constrained_analytical_phase = (
    model.layers[1].phase.detach().squeeze()
    for model in (analytical_lrn, constrained_analytical_lrn)
)
wrapped_analytical_phase, wrapped_constrained_analytical_phase = (
    phase % (2 * torch.pi)
    for phase in (analytical_phase, constrained_analytical_phase)
)

In [None]:
# 2x2 grid: raw phase (top row) and wrapped phase (bottom row) for the
# Small-NA and Large-NA analytical models, each panel with its own colorbar.
fig, ax = plt.subplot_mosaic("ab;de", figsize=(10, 10))

im0 = ax['a'].imshow(analytical_phase, cmap='viridis')
ax['a'].set_title("Small NA Phase")
im1 = ax['b'].imshow(constrained_analytical_phase, cmap='viridis')
ax['b'].set_title("Large NA Phase")
im2 = ax['d'].imshow(wrapped_analytical_phase, cmap='viridis')
ax['d'].set_title("Small NA Phase (wrapped)")
im3 = ax['e'].imshow(wrapped_constrained_analytical_phase, cmap='viridis')
ax['e'].set_title("Large NA Phase (wrapped)")

# Turn off the style-sheet grid over the images and attach a slim vertical
# colorbar to the right of every panel.
for key, image in zip('abde', (im0, im1, im2, im3)):
    ax[key].grid(False)
    divider = make_axes_locatable(ax[key])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(image, cax=cax, orientation='vertical')

plt.tight_layout()

In [None]:
# 1-D vertical cuts through the centre column of each analytical phase map.
# NOTE(review): the centre is hard-coded for 1080-px-wide maps — TODO confirm
# (phase.shape[1] // 2 would follow the actual width).
fig, ax = plt.subplot_mosaic("aabb;ccdd", figsize=(26, 10))
mid_col = 1080 // 2

# Top row: raw phase.
ln0, = ax['a'].plot(analytical_phase[:, mid_col], color='green')
ax['a'].set_title("Small NA")
ln1, = ax['b'].plot(constrained_analytical_phase[:, mid_col], color='purple')
ax['b'].set_title("Large NA")

# Bottom row: wrapped phase.
ln2, = ax['c'].plot(wrapped_analytical_phase[:, mid_col], color='green')
ax['c'].set_title("Small NA (wrapped)")
ln3, = ax['d'].plot(wrapped_constrained_analytical_phase[:, mid_col], color='purple')
ax['d'].set_title("Large NA (wrapped)")


---
## Learned phase: unconstrained model, high-variety training split (mnist_1000perClass)
---

In [None]:
# Checkpoint whose directory name indicates the mnist_1000perClass split.
learned_lrn = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_1000perClass/epoch=4-step=6250.ckpt')
learned_lrn.eval()

# Layer-1 phase, detached and moved to CPU for plotting.
learned_phase = learned_lrn.layers[1].phase.detach().squeeze().cpu()
# Shift by pi, then wrap into [0, 2*pi).
# NOTE(review): the analytical phases are wrapped *without* this +pi shift, so
# "analytical - learned" comparisons carry a pi offset — confirm intended.
wrapped_learned_phase = (learned_phase + torch.pi) % (2 * torch.pi)

print(torch.min(learned_phase))
print(torch.min(wrapped_learned_phase))

In [None]:
# Learned phase map (raw and wrapped) on top, with centre-column cuts below.
fig, ax = plt.subplot_mosaic("xy;aa;cc", figsize=(10, 14))
mid_col = 1080 // 2  # NOTE(review): assumes 1080-px-wide phase maps — TODO confirm

im0 = ax['x'].imshow(learned_phase, cmap='viridis')
ax['x'].set_title("Learned Small NA")
im1 = ax['y'].imshow(wrapped_learned_phase, cmap='viridis')
ax['y'].set_title("Learned Small NA (wrapped)")

# The raw cut is shifted by +pi, presumably so it lines up with the wrapped
# cut, which was computed as (phase + pi) % 2pi.
ln0, = ax['a'].plot(learned_phase[:, mid_col] + np.pi, color='green')
ax['a'].set_title("Learned Small NA")
ln2, = ax['c'].plot(wrapped_learned_phase[:, mid_col], color='green')
ax['c'].set_title("Learned Small NA (wrapped)")

# No grid over the image panels; one colorbar to the right of each.
for key, image in (('x', im0), ('y', im1)):
    ax[key].grid(False)
    divider = make_axes_locatable(ax[key])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(image, cax=cax, orientation='vertical')

plt.tight_layout()

In [None]:
# Wrapped analytical (Small NA) phase vs. wrapped learned phase, plus their
# raw difference. Panels now carry titles so the figure reads on its own.
fig, ax = plt.subplots(3, 1, figsize=(8, 26))

im0 = ax[0].imshow(wrapped_analytical_phase, cmap='viridis')
ax[0].set_title("Analytical Small NA (wrapped)")
im1 = ax[1].imshow(wrapped_learned_phase, cmap='viridis')
ax[1].set_title("Learned Small NA (wrapped)")
# NOTE(review): the difference of two [0, 2*pi) maps is not re-wrapped, so it
# spans (-2*pi, 2*pi). Also, the learned map was wrapped after a +pi shift and
# the analytical one was not — confirm that offset is intended.
im2 = ax[2].imshow(wrapped_analytical_phase - wrapped_learned_phase, cmap='viridis')
ax[2].set_title("Analytical - Learned (wrapped)")

# No grid over the images; a slim colorbar beside each panel.
for axis, image in zip(ax, (im0, im1, im2)):
    axis.grid(False)
    divider = make_axes_locatable(axis)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(image, cax=cax, orientation='vertical')

plt.tight_layout()

In [None]:
# Centre-column cuts of the wrapped Small-NA phases and their difference.
# NOTE(review): assumes 1080-px-wide phase maps — TODO confirm.
mid_col = 1080 // 2
fig, ax = plt.subplot_mosaic("aa;bb;cc", figsize=(10, 14))
ax['a'].plot(wrapped_analytical_phase[:, mid_col], color='green')
ax['a'].set_title("Analytical Small NA (wrapped)")
ax['b'].plot(wrapped_learned_phase[:, mid_col], color='purple')
ax['b'].set_title("Learned Small NA (wrapped)")
ax['c'].plot(wrapped_analytical_phase[:, mid_col] - wrapped_learned_phase[:, mid_col])
ax['c'].set_title("Analytical - Learned (wrapped)")
plt.tight_layout()

---
## Learned phase: constrained (4 cm) model, high-variety training split (mnist_1000perClass)
---

In [None]:
# Checkpoint whose directory name indicates the 4 cm, mnist_1000perClass run.
# This rebinds learned_lrn / learned_phase / wrapped_learned_phase, so the
# cells after this one compare against this model.
learned_lrn = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_1000perClass_4cm_mse_random/epoch=4-step=6250.ckpt')
learned_lrn.eval()

learned_phase = learned_lrn.layers[1].phase.detach().squeeze().cpu()
# Shift by pi, then wrap into [0, 2*pi) — same convention as the unconstrained model.
wrapped_learned_phase = (learned_phase + torch.pi) % (2 * torch.pi)

In [None]:
# Constrained 4 cm model. The analytical phase at the same 0.04 distance is
# titled "Large NA" in the earlier figures, so these panels are labelled
# Large NA too (they were mislabelled "Small NA").
fig, ax = plt.subplot_mosaic("xy;aa;cc", figsize=(10, 14))
mid_col = 1080 // 2  # NOTE(review): assumes 1080-px-wide phase maps — TODO confirm

im0 = ax['x'].imshow(learned_phase, cmap='viridis')
ax['x'].set_title("Learned Large NA")

im1 = ax['y'].imshow(wrapped_learned_phase, cmap='viridis')
ax['y'].set_title("Learned Large NA (wrapped)")

# NOTE(review): unlike the unconstrained-model figure, this raw cut is not
# shifted by +pi, so it is not directly comparable with that one — confirm.
ln0, = ax['a'].plot(learned_phase[:, mid_col], color='green')
ax['a'].set_title("Learned Large NA")

ln2, = ax['c'].plot(wrapped_learned_phase[:, mid_col], color='green')
ax['c'].set_title("Learned Large NA (wrapped)")

# No grid over the image panels; one colorbar beside each, matching the
# unconstrained-model figure.
for key, image in (('x', im0), ('y', im1)):
    ax[key].grid(False)
    divider = make_axes_locatable(ax[key])
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(image, cax=cax, orientation='vertical')

plt.tight_layout()

In [None]:
# Wrapped analytical (Large NA) phase vs. wrapped learned constrained phase,
# plus their raw difference — titled and with colorbars, like the Small-NA comparison.
fig, ax = plt.subplots(1, 3, figsize=(26, 8))

im0 = ax[0].imshow(wrapped_constrained_analytical_phase, cmap='viridis')
ax[0].set_title("Analytical Large NA (wrapped)")
im1 = ax[1].imshow(wrapped_learned_phase, cmap='viridis')
ax[1].set_title("Learned Large NA (wrapped)")
# NOTE(review): raw difference of two [0, 2*pi) maps, so it spans (-2*pi, 2*pi);
# the learned map was wrapped after a +pi shift and the analytical one was not.
im2 = ax[2].imshow(wrapped_constrained_analytical_phase - wrapped_learned_phase, cmap='viridis')
ax[2].set_title("Analytical - Learned (wrapped)")

for axis, image in zip(ax, (im0, im1, im2)):
    axis.grid(False)
    divider = make_axes_locatable(axis)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(image, cax=cax, orientation='vertical')

plt.tight_layout()

In [None]:
# Centre-column cuts of the wrapped Large-NA phases and their difference.
# NOTE(review): assumes 1080-px-wide phase maps — TODO confirm.
mid_col = 1080 // 2
fig, ax = plt.subplot_mosaic("aa;bb;cc", figsize=(20, 14))
ax['a'].plot(wrapped_constrained_analytical_phase[:, mid_col], color='green')
ax['a'].set_title("Analytical Large NA (wrapped)")
ax['b'].plot(wrapped_learned_phase[:, mid_col], color='purple')
ax['b'].set_title("Learned Large NA (wrapped)")
ax['c'].plot(wrapped_constrained_analytical_phase[:, mid_col] - wrapped_learned_phase[:, mid_col])
ax['c'].set_title("Analytical - Learned (wrapped)")
plt.tight_layout()