In [None]:
import os
import sys
import yaml
import torch
import logging
import argparse
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pytorch_lightning import seed_everything
from torchmetrics.functional import mean_squared_error as mse
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.functional import structural_similarity_index_measure as ssim

# Make the parent directory importable so the project-local `utils` and
# `core` packages below can be found (assumes the notebook's working
# directory is one level below the repo root — TODO confirm).
sys.path.append('../')
from utils import parameter_manager, model_loader
from core import datamodule, lrn, modulator, propagator

# Global matplotlib style applied to every figure in this notebook.
plt.style.use('seaborn-v0_8')

# Uncomment to enable DEBUG-level logging output.
#logging.basicConfig(level=logging.DEBUG)

## Load parameters and datasets

In [None]:
# Load parameters (context manager so the config file handle is closed).
with open('../config.yaml') as config_file:
    params = yaml.load(config_file, Loader = yaml.FullLoader)
params['batch_size'] = 1  # evaluate one image per batch
# NOTE(review): hard-coded propagation distance overriding the config value;
# units are not visible from here — confirm.
params['distance'] = torch.tensor(0.60264)

pm = parameter_manager.Parameter_Manager(params = params)

# Training splits to load; each name becomes a key of `data_loaders`.
datasets = ['mnist_single0', 'mnist_single1', 'mnist_10_1', 'mnist_10_8', 'mnist_100_1', 'mnist_100_8', 'mnist_1perClass', 'mnist_10perClass', 'mnist_100perClass', 'mnist_1000perClass']

# Build one datamodule per split. The mnist_1000perClass datamodule also
# supplies the test loader, so it is built only once instead of twice.
# NOTE(review): this relies on assigning `pm.data_split` refreshing
# `pm.params_datamodule` (elsewhere `pm.collect_params()` is called after
# changing an attribute) — confirm.
data_loaders = {}
datamodules = {}
for data in datasets:
    pm.data_split = data
    split_datamodule = datamodule.select_data(pm.params_datamodule)
    split_datamodule.setup()
    datamodules[data] = split_datamodule
    data_loaders[data] = split_datamodule.train_dataloader()

# After the loop pm.data_split is the last split, 'mnist_1000perClass'.
datamod = datamodules['mnist_1000perClass']
dataloader_train_1000perClass = data_loaders['mnist_1000perClass']
dataloader_test = datamod.test_dataloader()
    

## Generate testing data statistics

It is unclear whether these statistics are useful, so the cell below is left commented out for reference.

In [None]:
# targets = data[1]

# bins = [0,1,2,3,4,5,6,7,8,9]
# counts, bins = np.histogram(targets, bins=10, range=[0, 10])

# fig,ax = plt.subplots(1,1, figsize=(5,5))

# N, bins, patches = ax.hist(bins[:-1], bins, weights=counts, color='green')

# tick_labels = [i for i in range(0,10)]
# ax.set_xticks([i + 0.5 for i in range(0,10)], tick_labels)

# ax.set_ylabel("Number of Samples")
# ax.grid(False)
# ax.set_xlabel("Class")

## Utility functions to evaluate the models

In [None]:
def run_measures(outputs):
    """Score one model output against its target with MSE, PSNR and SSIM.

    Parameters
    ----------
    outputs : sequence
        Result of ``model.shared_step``. Only ``outputs[4]`` (normalized
        images) and ``outputs[5]`` (target) are used; the earlier entries
        (wavefronts, amplitudes, normalized amplitudes, images) are ignored.

    Returns
    -------
    dict
        ``{'mse', 'psnr', 'ssim'}`` metric tensors, detached and on the CPU.
    """
    # Detach once so none of the metric computations track gradients.
    prediction = outputs[4].detach()
    target = outputs[5].detach()

    return {
        'mse': mse(prediction, target).cpu(),
        'psnr': psnr(prediction, target).cpu(),
        'ssim': ssim(prediction, target).detach().cpu(),
    }


def _eval_batch(model, batch, batch_idx):
    """Run ``model.shared_step`` on one (sample, target) batch and score it.

    Returns ``(record, sample, outputs)``:

    - ``record`` is the ``run_measures`` dict plus a CPU copy of ``target``
      under ``'target'``.
    - ``sample`` is the GPU copy of the input.
    - ``outputs`` is the raw ``shared_step`` result.
    """
    sample, target = batch
    sample = sample.cuda()
    target = target.cuda()
    # Evaluation only: avoid building an autograd graph (and keeping the
    # activations it would need) for every forward pass.
    with torch.no_grad():
        outputs = model.shared_step((sample, target), batch_idx)
    record = run_measures(outputs)
    record['target'] = target.detach().cpu()
    return record, sample, outputs


def eval_model(model, dataloader):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Returns ``(measures, sample0, wavefront)``:

    - ``measures[0]`` is the module-level ``params`` dict, which callers
      ``pop(0)``.
    - ``measures[1:]`` holds one metrics record per batch.
    - ``sample0`` is the first item of the LAST batch's input.
    - ``wavefront`` is ``outputs[0]`` of the last batch.

    Raises ValueError if ``dataloader`` yields no batches.
    """
    measures = [params]
    sample = outputs = None
    for i, batch in enumerate(tqdm(dataloader)):
        record, sample, outputs = _eval_batch(model, batch, i)
        measures.append(record)
    if len(measures) == 1:
        raise ValueError('eval_model: dataloader yielded no batches')
    return (measures, sample[0], outputs[0])


def eval_model_single(model, dataloader):
    """Evaluate ``model`` on the first batch of ``dataloader`` only.

    Same return layout as ``eval_model``: ``([params, record], sample0, outputs[0])``.
    Raises ValueError if ``dataloader`` yields no batches.
    """
    batch = next(iter(dataloader), None)
    if batch is None:
        raise ValueError('eval_model_single: dataloader yielded no batches')
    record, sample, outputs = _eval_batch(model, batch, 0)
    return ([params, record], sample[0], outputs[0])


def eval_model_fromBatch(model, batch):
    """Evaluate ``model`` on an explicit ``(sample, target)`` batch.

    Returns ``([params, record], sample0, outputs[2])``. Unlike the two
    functions above, the third element is ``outputs[2]`` (normalized
    amplitudes), not ``outputs[0]`` (wavefronts).
    """
    record, sample, outputs = _eval_batch(model, batch, 0)
    return ([params, record], sample[0], outputs[2])

In [None]:
def _split_metric_by_class(measures, metric, num_classes=10):
    """Group ``record[metric]`` values by each record's class label.

    Returns a dict with one key per class in ``range(num_classes)``. Every
    key is present, possibly with an empty list. Values keep the order in
    which records appear in ``measures``. Records whose label is outside
    that range are ignored.
    """
    by_class = {label: [] for label in range(num_classes)}
    for record in measures:
        target = record['target']
        # Targets are stored as tensors (presumably single-element with
        # batch_size=1); .item() turns them into hashable Python numbers.
        label = target.item() if hasattr(target, 'item') else target
        if label in by_class:
            by_class[label].append(record[metric])
    return by_class


def split_mse_by_class(measures, num_classes=10):
    """Return ``{class: [mse, ...]}`` for classes ``0 .. num_classes-1``."""
    return _split_metric_by_class(measures, 'mse', num_classes)


def split_psnr_by_class(measures, num_classes=10):
    """Return ``{class: [psnr, ...]}`` for classes ``0 .. num_classes-1``."""
    return _split_metric_by_class(measures, 'psnr', num_classes)

---
---
---
---
---
---


In [None]:
# Accumulators for averaged MSE / PSNR values. The analytical_* pair
# collects the analytical-lens baseline computed further down.
average_mse_values, average_psnr_values = [], []
analytical_average_mse_values, analytical_average_psnr_values = [], []

## Analytical LRN to compare against

In [None]:
# Build the analytical-lens LRN baseline: temporarily switch the phase
# initialization to 1, collect params, construct the model, then restore.
pm.phase_initialization = 1
pm.collect_params()
try:
    analytical_lrn = lrn.LRN(pm.params_model_lrn, pm.params_propagator, pm.params_modulator).cuda()
    analytical_lrn.eval()
finally:
    # Restore the flag even if construction fails, and re-collect so the
    # params dicts on `pm` no longer carry phase_initialization = 1.
    pm.phase_initialization = 0
    pm.collect_params()

In [None]:
# Phase profile of the analytical lens, wrapped to [0, 2*pi), saved to good_lens.pdf.
lens_phase = analytical_lrn.layers[1].phase.detach().cpu().squeeze()
fig, ax = plt.subplots()
ax.imshow(lens_phase % (2 * np.pi), cmap = 'viridis')
ax.grid(False)
ax.axis('off')
fig.savefig('good_lens.pdf')

In [None]:
# Evaluate the analytical lens on the whole test set.
analytical_measures, analytical_sample, analytical_output = eval_model(analytical_lrn, dataloader_test)
analytical_params = analytical_measures.pop(0)  # eval_model prepends the params dict

analytical_mse_by_class = split_mse_by_class(analytical_measures)
analytical_psnr_by_class = split_psnr_by_class(analytical_measures)

# Per-class lists, ordered by class label.
analytical_mse_by_class = list(analytical_mse_by_class.values())
analytical_psnr_by_class = list(analytical_psnr_by_class.values())

# Mean over every sample. The per-class lists may have different lengths,
# and np.average on such a ragged nested list raises on current NumPy, so
# flatten first (for equal-length lists this equals the original average).
analytical_average_mse_values.append(
    np.mean([float(v) for class_vals in analytical_mse_by_class for v in class_vals]))
analytical_average_psnr_values.append(
    np.mean([float(v) for class_vals in analytical_psnr_by_class for v in class_vals]))

In [None]:
# Pick one test-set sample (by batch index) as the showcase image.
# If dataloader_test has fewer than `number + 1` batches, the loop runs out
# and the last batch is used instead.
# NOTE(review): if dataloader_test shuffles, a different image is picked on
# every run — confirm the loader order is fixed.
number = 500
batch1 = None  # reset so a batch left over from an earlier run cannot leak through
for i, batch1 in enumerate(dataloader_test):
    if i == number:
        break
if batch1 is None:
    raise ValueError('dataloader_test yielded no batches')
fig,ax = plt.subplots(1,1)
ax.imshow(batch1[0].abs().squeeze(), cmap='viridis')
ax.grid(False)


In [None]:
# Pick one sample from the mnist_single0 training split as the showcase
# training image. If that loader has fewer than `number + 1` batches, the
# loop runs out and its last batch is used instead.
number = 1
batch2 = None  # reset so a batch left over from an earlier run cannot leak through
for i, batch2 in enumerate(data_loaders['mnist_single0']):
    if i == number:
        break
if batch2 is None:
    raise ValueError("data_loaders['mnist_single0'] yielded no batches")
fig,ax = plt.subplots(1,1)
ax.imshow(batch2[0].abs().squeeze(), cmap='viridis')
ax.grid(False)

In [None]:
# Load the saved LRN checkpoint (epoch 0, step 1250 per its filename),
# switch it to eval mode and move it to the GPU. load_from_checkpoint
# unpickles the file, so only load checkpoints you trust.
checkpoint_path = '../my_models/lrn/test_lrn/epoch=0-step=1250-v2.ckpt'
model = lrn.LRN.load_from_checkpoint(checkpoint_path).eval().cuda()

In [None]:
# Run the learned and analytical lenses on the first batch of mnist_single0.
# NOTE(review): this reassigns `analytical_measures`, replacing the full
# test-set measures computed earlier in the notebook.
# NOTE(review): each call draws next(iter(loader)) separately; if the loader
# shuffles, the two models may see different images — confirm.
learned_measures, learned_example_input, learned_example_output = eval_model_single(model, data_loaders['mnist_single0'])
del learned_measures[0]  # drop the params dict eval_model_single prepends
learned_phase = model.layers[1].phase.detach().cpu().squeeze()

analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(analytical_lrn, data_loaders['mnist_single0'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().cpu().squeeze()

print(learned_measures)
print(analytical_measures)

In [None]:
# Learned model on the showcase test image (batch1); the leading params
# entry that eval_model_fromBatch prepends is split off into *_params.
test_measures, test_example_input, test_example_output = eval_model_fromBatch(model, batch1)
test_params = test_measures.pop(0)

# Same for the showcase training image (batch2).
train_measures, train_example_input, train_example_output = eval_model_fromBatch(model, batch2)
train_params = train_measures.pop(0)

In [None]:
# Two-row comparison figure. Each row shows three panels:
#   column 0: the learned lens phase, wrapped to [0, 2*pi)
#   column 1: an input image
#   column 2: the learned model's output for that input
# Row 0 uses the training image (batch2), row 1 the test image (batch1).
ANNOTATE_METRICS = False  # set True to overlay PSNR/MSE on the output panels

comparison_rows = [
    (train_example_input, train_example_output, train_measures),
    (test_example_input, test_example_output, test_measures),
]

fig,ax = plt.subplots(2,3, figsize=(13,8))
for row, (example_input, example_output, row_measures) in enumerate(comparison_rows):
    ax[row][0].imshow(learned_phase % (2*np.pi), cmap='viridis')
    ax[row][1].imshow(example_input.detach().cpu().abs().squeeze(), cmap='viridis')
    ax[row][2].imshow(example_output.detach().cpu().abs().squeeze(), cmap='viridis')
    for panel in ax[row]:
        panel.grid(False)
        panel.axis('off')
    if ANNOTATE_METRICS:
        # Text positions are in image pixel coordinates.
        ax[row][2].text(680,75,'PSNR: {:.2f}'.format(row_measures[0]['psnr']), color='white', fontsize=12)
        ax[row][2].text(680,150,'MSE: {:.2f}'.format(row_measures[0]['mse']), color='white', fontsize=12)

plt.tight_layout()
fig.savefig('testLrn_output.pdf')