In [None]:
import os
import sys
import yaml
import torch
import logging
import argparse
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pytorch_lightning import seed_everything
from torchmetrics.functional import mean_squared_error as mse
from torchmetrics.functional import peak_signal_noise_ratio as psnr
from torchmetrics.functional import structural_similarity_index_measure as ssim

sys.path.append('../')
from utils import parameter_manager, model_loader
from core import datamodule, lrn, modulator, propagator

plt.style.use('seaborn-v0_8')

#logging.basicConfig(level=logging.DEBUG)

## Load parameters and test dataset

In [None]:
# Load parameters
# The config file is opened with a context manager so the handle is closed
# right after parsing.
# NOTE(review): yaml.FullLoader is kept as-is; if the config only holds plain
# YAML types, yaml.safe_load would be the safer choice - confirm.
with open('../config.yaml') as config_file:
    params = yaml.load(config_file, Loader = yaml.FullLoader)
params['batch_size'] = 1  # one sample per batch
params['distance'] = torch.tensor(0.60264)  # NOTE(review): units/meaning not visible here - confirm

pm = parameter_manager.Parameter_Manager(params = params)

# Load in the test dataset
pm.data_split = "mnist_1000perClass"
datamod = datamodule.select_data(pm.params_datamodule)
datamod.setup()
dataloader_train_1000perClass = datamod.train_dataloader()
dataloader_test = datamod.test_dataloader()

# Names of the training splits whose train dataloaders are needed below.
datasets = ['mnist_single0', 'mnist_single1', 'mnist_10_1', 'mnist_10_8', 'mnist_100_1', 'mnist_100_8', 'mnist_1perClass', 'mnist_10perClass', 'mnist_100perClass', 'mnist_1000perClass']

# Build one train dataloader per split, keyed by split name.
data_loaders = {}
for data in datasets:
    pm.data_split = data
    datamod = datamodule.select_data(pm.params_datamodule)
    datamod.setup()
    loader = datamod.train_dataloader()
    data_loaders[data] = loader
    

## Generate testing data statistics

I don't know if these are useful.

In [None]:
# targets = data[1]

# bins = [0,1,2,3,4,5,6,7,8,9]
# counts, bins = np.histogram(targets, bins=10, range=[0, 10])

# fig,ax = plt.subplots(1,1, figsize=(5,5))

# N, bins, patches = ax.hist(bins[:-1], bins, weights=counts, color='green')

# tick_labels = [i for i in range(0,10)]
# ax.set_xticks([i + 0.5 for i in range(0,10)], tick_labels)

# ax.set_ylabel("Number of Samples")
# ax.grid(False)
# ax.set_xlabel("Class")

## Utility functions to evaluate the models

In [None]:
def run_measures(outputs):
    """Compute image-quality metrics for one model output.

    Args:
        outputs: Indexable result of ``model.shared_step``. Only element 4
            (the normalized images) and element 5 (the target) are used.

    Returns:
        dict with keys 'mse', 'psnr' and 'ssim', each holding the metric
        tensor moved to the CPU.
    """
    # Detach once so no metric keeps a reference to the autograd graph.
    normalized_images = outputs[4].detach()
    target = outputs[5].detach()

    return {
        'mse' : mse(normalized_images, target).cpu(),
        'psnr' : psnr(normalized_images, target).cpu(),
        'ssim' : ssim(normalized_images, target).cpu(),
    }


def eval_model(model, dataloader):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Args:
        model: Object exposing ``shared_step(batch, batch_idx)``; it must
            already be on the GPU because the inputs are moved with ``.cuda()``.
        dataloader: Iterable yielding ``(sample, target)`` batches.

    Returns:
        Tuple ``(measures, sample, output)`` where ``measures`` is a list whose
        first entry is the module-level ``params`` and whose remaining entries
        are the per-batch dicts from ``run_measures`` with an added 'target'
        key, ``sample`` is the first element of the last input batch, and
        ``output`` is element 0 of the last ``shared_step`` result.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    measures = [params]
    sample = outputs = None
    # Pure evaluation: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            sample, target = batch
            sample = sample.cuda()
            target = target.cuda()
            outputs = model.shared_step((sample, target), i)
            batch_measures = run_measures(outputs)
            batch_measures['target'] = target.detach().cpu()
            measures.append(batch_measures)
    if outputs is None:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError("eval_model received an empty dataloader")
    return (measures, sample[0], outputs[0])

def eval_model_single(model, dataloader):
    """Evaluate ``model`` on the first batch of ``dataloader`` only.

    Args:
        model: Object exposing ``shared_step(batch, batch_idx)``; it must
            already be on the GPU because the inputs are moved with ``.cuda()``.
        dataloader: Iterable yielding ``(sample, target)`` batches.

    Returns:
        Tuple ``(measures, sample, output)`` where ``measures`` is
        ``[params, metrics]`` (module-level ``params`` first, then the dict from
        ``run_measures`` with an added 'target' key), ``sample`` is the first
        element of the input batch, and ``output`` is element 0 of the
        ``shared_step`` result.
    """
    measures = [params]

    sample, target = next(iter(dataloader))
    sample = sample.cuda()
    target = target.cuda()
    # Pure evaluation: skip building an autograd graph.
    with torch.no_grad():
        outputs = model.shared_step((sample, target), 0)
    batch_measures = run_measures(outputs)
    batch_measures['target'] = target.detach().cpu()
    measures.append(batch_measures)

    return (measures, sample[0], outputs[0])

def eval_model_fromBatch(model, batch):
    """Evaluate ``model`` on one already-loaded ``(sample, target)`` batch.

    Args:
        model: Object exposing ``shared_step(batch, batch_idx)``; it must
            already be on the GPU because the inputs are moved with ``.cuda()``.
        batch: ``(sample, target)`` pair.

    Returns:
        Tuple ``(measures, sample, output)`` where ``measures`` is
        ``[params, metrics]`` (module-level ``params`` first, then the dict from
        ``run_measures`` with an added 'target' key), ``sample`` is the first
        element of the input batch, and ``output`` is element 2 of the
        ``shared_step`` result.
    """
    measures = [params]
    sample, target = batch
    sample = sample.cuda()
    target = target.cuda()
    # Pure evaluation: skip building an autograd graph.
    with torch.no_grad():
        outputs = model.shared_step((sample, target), 0)
    batch_measures = run_measures(outputs)
    batch_measures['target'] = target.detach().cpu()
    measures.append(batch_measures)

    # NOTE(review): this returns outputs[2] whereas eval_model and
    # eval_model_single return outputs[0] - confirm the difference is intended.
    return (measures, sample[0], outputs[2])

In [None]:
def _split_by_class(measures, key, num_classes=10):
    """Group the values stored under ``key`` by target class.

    Args:
        measures: Iterable of dicts that each hold a 'target' entry and a
            ``key`` entry. 'target' is compared with ``==`` against each class
            index, so it must be a scalar or a single-element tensor.
        key: Name of the metric to collect, e.g. 'mse' or 'psnr'.
        num_classes: Number of class indices (0 .. num_classes - 1).

    Returns:
        dict mapping every class index to the list of matching values, in the
        order they appear in ``measures``. Classes without samples map to [].
    """
    return {
        label: [entry[key] for entry in measures if entry['target'] == label]
        for label in range(num_classes)
    }


def split_mse_by_class(measures):
    """Return ``{class index: [mse values]}`` for the classes 0-9."""
    return _split_by_class(measures, 'mse')


def split_psnr_by_class(measures):
    """Return ``{class index: [psnr values]}`` for the classes 0-9."""
    return _split_by_class(measures, 'psnr')

---
---
---
---
---
---


In [None]:
# Running summaries: every evaluated checkpoint appends one test-set average.
# Evaluation cells append on each run, so re-run this cell first to reset them.
average_mse_values, average_psnr_values = [], []

# Matching averages for the analytical reference model.
analytical_average_mse_values, analytical_average_psnr_values = [], []

## Analytical LRN to compare against

In [None]:
# Build the analytical reference model. phase_initialization is set to 1 only
# while the model is constructed (presumably this selects the analytical lens
# phase - confirm against lrn.LRN). The try/finally resets it to 0 even when
# construction fails, so later cells never see the temporary setting.
pm.phase_initialization = 1
try:
    analytical_lrn = lrn.LRN(pm.params_model_lrn, pm.params_propagator, pm.params_modulator).cuda()
    analytical_lrn.eval()
finally:
    pm.phase_initialization = 0

In [None]:
# Save the phase of analytical_lrn.layers[1], wrapped to [0, 2*pi), as an
# image without grid or axes.
_lens_phase = analytical_lrn.layers[1].phase.detach().cpu().squeeze()
fig,ax = plt.subplots(1,1)
ax.imshow(_lens_phase % (2*np.pi), cmap = 'viridis')
ax.grid(False)
ax.axis('off')
fig.savefig('good_lens.pdf')

In [None]:
# Score the analytical reference model on every test batch.
analytical_measures, analytical_sample, analytical_output = eval_model(analytical_lrn, dataloader_test)
# eval_model stores the params dict first; remove it so only per-batch metrics remain.
analytical_params = analytical_measures.pop(0)

# Per-class metric lists, ordered by digit 0..9.
analytical_mse_by_class = list(split_mse_by_class(analytical_measures).values())
analytical_psnr_by_class = list(split_psnr_by_class(analytical_measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
analytical_average_mse_values.append(np.mean([float(v) for per_class in analytical_mse_by_class for v in per_class]))
analytical_average_psnr_values.append(np.mean([float(v) for per_class in analytical_psnr_by_class for v in per_class]))

## Model trained on single MNIST digit : 1

How does this model perform when tested on images of the same class and images of different classes? Compare to the analytical lens.

In [None]:
# Restore the checkpoint, move it to the GPU in eval mode and score it on
# every test batch.
model = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_1/epoch=4-step=6250.ckpt').cuda()
model.eval()
measures, sample, output = eval_model(model, dataloader_test)
# eval_model stores the params dict first; pull it off the metrics list.
params = measures.pop(0)

In [None]:
# Per-class metric lists for the current model, ordered by digit 0..9.
mse_by_class = list(split_mse_by_class(measures).values())
psnr_by_class = list(split_psnr_by_class(measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
average_mse_values.append(np.mean([float(v) for per_class in mse_by_class for v in per_class]))
average_psnr_values.append(np.mean([float(v) for per_class in psnr_by_class for v in per_class]))

# Box plot of per-sample MSE per digit; box positions start at 1, hence ticks+1.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths = 0.6)
plt.xticks(ticks = ticks+1, labels = ticks)
plt.ylabel("Mean squared error")
plt.xlabel("MNIST class")
plt.ylim(0,0.12)
plt.title("Mean squared error by class\n Model trained on single 1")
plt.text(7,0.115, 'Average Model Error : {:.3f}'.format(average_mse_values[-1]))
plt.text(7,0.11, 'Average Analytical Error : {:.3f}'.format(analytical_average_mse_values[-1]))

In [None]:
# Learned model on one batch from the 'mnist_single0' train loader
# (presumably the single-1 split this model was trained on - confirm).
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_single0'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_single0'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('single1_trainImage_vsAnalytical.pdf')

In [None]:
# Learned model on the first batch drawn from the test loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, dataloader_test)
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the first batch drawn from the test loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, dataloader_test)
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('single1_testImage_vsAnalytical.pdf')

## Model trained with 10 1s

In [None]:
# Restore the checkpoint, move it to the GPU in eval mode and score it on
# every test batch.
model = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_10_1/epoch=4-step=6250.ckpt').cuda()
model.eval()
measures, sample, output = eval_model(model, dataloader_test)
# eval_model stores the params dict first; pull it off the metrics list.
params = measures.pop(0)

In [None]:
# Per-class metric lists for the current model, ordered by digit 0..9.
mse_by_class = list(split_mse_by_class(measures).values())
psnr_by_class = list(split_psnr_by_class(measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
average_mse_values.append(np.mean([float(v) for per_class in mse_by_class for v in per_class]))
average_psnr_values.append(np.mean([float(v) for per_class in psnr_by_class for v in per_class]))

# Box plot of per-sample MSE per digit; box positions start at 1, hence ticks+1.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths = 0.6)
plt.xticks(ticks = ticks+1, labels = ticks)
plt.ylabel("Mean squared error")
plt.xlabel("MNIST class")
plt.ylim(0,0.12)
# The title names the model loaded in this section (trained on 10 1s).
plt.title("Mean squared error by class\n Model trained on 10 1s")
plt.text(7,0.115, 'Average Model Error : {:.3f}'.format(average_mse_values[-1]))
plt.text(7,0.11, 'Average Analytical Error : {:.3f}'.format(analytical_average_mse_values[-1]))

In [None]:
# Learned model on one batch from the 'mnist_10_1' train loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_10_1'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_10_1'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('10_1_trainImage_vsAnalytical.pdf')

In [None]:
# Learned model on the first batch drawn from the test loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, dataloader_test)
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the first batch drawn from the test loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, dataloader_test)
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('10_1_testImage_vsAnalytical.pdf')

## Model trained on 100 1s

In [None]:
# Restore the checkpoint, move it to the GPU in eval mode and score it on
# every test batch.
model = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_100_1/epoch=4-step=6250.ckpt').cuda()
model.eval()
measures, sample, output = eval_model(model, dataloader_test)
# eval_model stores the params dict first; pull it off the metrics list.
params = measures.pop(0)


In [None]:
# Per-class metric lists for the current model, ordered by digit 0..9.
mse_by_class = list(split_mse_by_class(measures).values())
psnr_by_class = list(split_psnr_by_class(measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
average_mse_values.append(np.mean([float(v) for per_class in mse_by_class for v in per_class]))
average_psnr_values.append(np.mean([float(v) for per_class in psnr_by_class for v in per_class]))

# Box plot of per-sample MSE per digit; box positions start at 1, hence ticks+1.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths = 0.6)
plt.xticks(ticks = ticks+1, labels = ticks)
plt.ylabel("Mean squared error")
plt.xlabel("MNIST class")
plt.ylim(0,0.12)
plt.title("Mean squared error by class\n Model trained on 100 1s")
plt.text(7,0.115, 'Average Model Error : {:.3f}'.format(average_mse_values[-1]))
plt.text(7,0.11, 'Average Analytical Error : {:.3f}'.format(analytical_average_mse_values[-1]))

In [None]:
# Learned model on one batch from the 'mnist_100_1' train loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_100_1'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_100_1'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('100_1_trainImage_vsAnalytical.pdf')

In [None]:
# Learned model on the first batch drawn from the test loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, dataloader_test)
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the first batch drawn from the test loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, dataloader_test)
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('100_1_testImage_vsAnalytical.pdf')

## Model trained on single MNIST digit : 8

How does this model perform when tested on images of the same class and images of different classes? Compare to the analytical lens.

In [None]:
# Restore the checkpoint, move it to the GPU in eval mode and score it on
# every test batch.
model = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_8/epoch=4-step=6250.ckpt').cuda()
model.eval()
measures, sample, output = eval_model(model, dataloader_test)
# eval_model stores the params dict first; pull it off the metrics list.
params = measures.pop(0)

In [None]:
# Per-class metric lists for the current model, ordered by digit 0..9.
mse_by_class = list(split_mse_by_class(measures).values())
psnr_by_class = list(split_psnr_by_class(measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
average_mse_values.append(np.mean([float(v) for per_class in mse_by_class for v in per_class]))
average_psnr_values.append(np.mean([float(v) for per_class in psnr_by_class for v in per_class]))

# Box plot of per-sample MSE per digit; box positions start at 1, hence ticks+1.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths = 0.6)
plt.xticks(ticks = ticks+1, labels = ticks)
plt.ylabel("Mean squared error")
plt.xlabel("MNIST class")
plt.ylim(0,0.12)
plt.title("Mean squared error by class\n Model trained on single 8")
plt.text(7,0.115, 'Average Model Error : {:.3f}'.format(average_mse_values[-1]))
plt.text(7,0.11, 'Average Analytical Error : {:.3f}'.format(analytical_average_mse_values[-1]))

In [None]:
# Learned model on one batch from the 'mnist_single1' train loader
# (presumably the single-8 split this model was trained on - confirm).
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_single1'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_single1'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('single8_trainImage_vsAnalytical.pdf')

In [None]:
# Learned model on the first batch drawn from the test loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, dataloader_test)
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the first batch drawn from the test loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, dataloader_test)
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('single8_testImage_vsAnalytical.pdf')

## Model trained on 10 8s

In [None]:
# Restore the checkpoint, move it to the GPU in eval mode and score it on
# every test batch.
model = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_10_8/epoch=4-step=6250.ckpt').cuda()
model.eval()
measures, sample, output = eval_model(model, dataloader_test)
# eval_model stores the params dict first; pull it off the metrics list.
params = measures.pop(0)

In [None]:
# Per-class metric lists for the current model, ordered by digit 0..9.
mse_by_class = list(split_mse_by_class(measures).values())
psnr_by_class = list(split_psnr_by_class(measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
average_mse_values.append(np.mean([float(v) for per_class in mse_by_class for v in per_class]))
average_psnr_values.append(np.mean([float(v) for per_class in psnr_by_class for v in per_class]))

# Box plot of per-sample MSE per digit; box positions start at 1, hence ticks+1.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths = 0.6)
plt.xticks(ticks = ticks+1, labels = ticks)
plt.ylabel("Mean squared error")
plt.xlabel("MNIST class")
plt.ylim(0,0.12)
plt.title("Mean squared error by class\n Model trained on 10 8s")
plt.text(7,0.115, 'Average Model Error : {:.3f}'.format(average_mse_values[-1]))
plt.text(7,0.11, 'Average Analytical Error : {:.3f}'.format(analytical_average_mse_values[-1]))

In [None]:
# Learned model on one batch from the 'mnist_10_8' train loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_10_8'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_10_8'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('10_8_trainImage_vsAnalytical.pdf')

In [None]:
# Learned model on one batch from the 'mnist_single0' train loader.
# NOTE(review): the next cell saves this as the "testImage" figure, but the
# sections for the models trained on 1s use dataloader_test here - confirm
# that 'mnist_single0' is the intended held-out example for this model.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_single0'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_single0'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('10_8_testImage_vsAnalytical.pdf')

## Model trained on 100 8s

In [None]:
# Restore the checkpoint, move it to the GPU in eval mode and score it on
# every test batch.
model = lrn.LRN.load_from_checkpoint('../my_models/LRN/model_mnist_100_8/epoch=4-step=6250.ckpt').cuda()
model.eval()
measures, sample, output = eval_model(model, dataloader_test)
# eval_model stores the params dict first; pull it off the metrics list.
params = measures.pop(0)

In [None]:
# Per-class metric lists for the current model, ordered by digit 0..9.
mse_by_class = list(split_mse_by_class(measures).values())
psnr_by_class = list(split_psnr_by_class(measures).values())

# Average over every test sample. The per-class lists are flattened first:
# np.average on the nested list only works when all classes hold the same
# number of samples, whereas the flattened mean also handles unequal counts.
average_mse_values.append(np.mean([float(v) for per_class in mse_by_class for v in per_class]))
average_psnr_values.append(np.mean([float(v) for per_class in psnr_by_class for v in per_class]))

# Box plot of per-sample MSE per digit; box positions start at 1, hence ticks+1.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths = 0.6)
plt.xticks(ticks = ticks+1, labels = ticks)
plt.ylabel("Mean squared error")
plt.xlabel("MNIST class")
plt.ylim(0,0.12)
plt.title("Mean squared error by class\n Model trained on 100 8s")
plt.text(7,0.115, 'Average Model Error : {:.3f}'.format(average_mse_values[-1]))
plt.text(7,0.11, 'Average Analytical Error : {:.3f}'.format(analytical_average_mse_values[-1]))

In [None]:
# Learned model on one batch from the 'mnist_100_8' train loader.
learned_measures, learned_example_input, learned_example_output = eval_model_single(
    model, data_loaders['mnist_100_8'])
del learned_measures[0]  # drop the params entry that eval_model_single puts first
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Analytical reference model on the same loader.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_single(
    analytical_lrn, data_loaders['mnist_100_8'])
del analytical_measures[0]
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

print(learned_measures, analytical_measures, sep='\n')

In [None]:
# Row 0: learned model, row 1: analytical model. Columns: phase wrapped to
# [0, 2*pi), magnitude of the example input, magnitude of the example output.
fig,ax = plt.subplots(2,3, figsize=(13,8))
_panel_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)
for _axes_row, (_phase, _field_in, _field_out) in zip(ax, _panel_rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _axis, _image in zip(_axes_row, _images):
        _axis.imshow(_image, cmap='viridis')
        _axis.grid(False)
        _axis.axis('off')

# PSNR/MSE annotations are not drawn on the panels; the preceding cell prints them.

fig.savefig('100_8_trainImage_vsAnalytical.pdf')

In [None]:
# Run the trained model and the analytical lens over the same loader, keep
# the phase of layers[1] of each on the CPU for plotting, and print the measures.
_loader = data_loaders['mnist_single0']

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_single(model, _loader)
(analytical_measures,
 analytical_example_input,
 analytical_example_output) = eval_model_single(analytical_lrn, _loader)

learned_phase = model.layers[1].phase.detach().squeeze().cpu()
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

# The first entry of each measures list is removed before printing.
for _measures in (learned_measures, analytical_measures):
    _measures.pop(0)

print(learned_measures)
print(analytical_measures)

In [None]:
# 2x3 comparison figure: top row = learned model, bottom row = analytical lens.
# Columns: phase modulo 2*pi, |input|, |output|.
_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_phase, _field_in, _field_out) in zip(ax, _rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
fig.savefig('100_8_testImage_vsAnalytical.pdf')

## Model trained on 1 image per class

In [None]:
# Restore the 1-per-class checkpoint, move it to the GPU in eval mode and
# evaluate it over the test loader.
_checkpoint = '../my_models/LRN/model_mnist_1perClass/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

measures, sample, output = eval_model(model, dataloader_test)
# NOTE: this rebinds `params`, which held the YAML configuration dict at the
# top of the notebook, to the first entry of `measures`.
params = measures.pop(0)

In [None]:
# Per-class error statistics for the most recently evaluated model.
mse_by_class = split_mse_by_class(measures)
psnr_by_class = split_psnr_by_class(measures)

# Turn each per-class container into a plain list, in its iteration order.
mse_by_class = [mse_by_class[_label] for _label in mse_by_class]
psnr_by_class = [psnr_by_class[_label] for _label in psnr_by_class]

# Caution: executing this cell again appends another entry; the summary
# cells further down index these running lists by position.
average_mse_values.append(np.average(mse_by_class))
average_psnr_values.append(np.average(psnr_by_class))

# Boxplot positions are 1-based, the class labels 0-based.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths=0.6)
plt.xticks(ticks=ticks + 1, labels=ticks)
plt.xlabel("MNIST class")
plt.ylabel("Mean squared error")
plt.ylim(0, 0.12)
plt.title("Mean squared error by class\n Model trained on 1 per class")
plt.text(7, 0.115, f'Average Model Error : {average_mse_values[-1]:.3f}')
plt.text(7, 0.11, f'Average Analytical Error : {analytical_average_mse_values[-1]:.3f}')

In [None]:
# Run the trained model and the analytical lens over the same loader, keep
# the phase of layers[1] of each on the CPU for plotting, and print the measures.
_loader = data_loaders['mnist_1perClass']

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_single(model, _loader)
(analytical_measures,
 analytical_example_input,
 analytical_example_output) = eval_model_single(analytical_lrn, _loader)

learned_phase = model.layers[1].phase.detach().squeeze().cpu()
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

# The first entry of each measures list is removed before printing.
for _measures in (learned_measures, analytical_measures):
    _measures.pop(0)

print(learned_measures)
print(analytical_measures)

In [None]:
# 2x3 comparison figure: top row = learned model, bottom row = analytical lens.
# Columns: phase modulo 2*pi, |input|, |output|.
_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_phase, _field_in, _field_out) in zip(ax, _rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
fig.savefig('1perClass_trainImage_vsAnalytical.pdf')

In [None]:
# Run the trained model and the analytical lens over the test loader, keep
# the phase of layers[1] of each on the CPU for plotting, and print the measures.
_loader = dataloader_test

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_single(model, _loader)
(analytical_measures,
 analytical_example_input,
 analytical_example_output) = eval_model_single(analytical_lrn, _loader)

learned_phase = model.layers[1].phase.detach().squeeze().cpu()
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

# The first entry of each measures list is removed before printing.
for _measures in (learned_measures, analytical_measures):
    _measures.pop(0)

print(learned_measures)
print(analytical_measures)

In [None]:
# 2x3 comparison figure: top row = learned model, bottom row = analytical lens.
# Columns: phase modulo 2*pi, |input|, |output|.
_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_phase, _field_in, _field_out) in zip(ax, _rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
fig.savefig('1perClass_testImage_vsAnalytical.pdf')

## Model trained on 10 images per class

In [None]:
# Restore the 10-per-class checkpoint, move it to the GPU in eval mode and
# evaluate it over the test loader.
_checkpoint = '../my_models/LRN/model_mnist_10perClass/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

measures, sample, output = eval_model(model, dataloader_test)
# NOTE: this rebinds `params`, which held the YAML configuration dict at the
# top of the notebook, to the first entry of `measures`.
params = measures.pop(0)

In [None]:
# Per-class error statistics for the most recently evaluated model.
mse_by_class = split_mse_by_class(measures)
psnr_by_class = split_psnr_by_class(measures)

# Turn each per-class container into a plain list, in its iteration order.
mse_by_class = [mse_by_class[_label] for _label in mse_by_class]
psnr_by_class = [psnr_by_class[_label] for _label in psnr_by_class]

# Caution: executing this cell again appends another entry; the summary
# cells further down index these running lists by position.
average_mse_values.append(np.average(mse_by_class))
average_psnr_values.append(np.average(psnr_by_class))

# Boxplot positions are 1-based, the class labels 0-based.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths=0.6)
plt.xticks(ticks=ticks + 1, labels=ticks)
plt.xlabel("MNIST class")
plt.ylabel("Mean squared error")
plt.ylim(0, 0.12)
plt.title("Mean squared error by class\n Model trained on 10 per class")
plt.text(7, 0.115, f'Average Model Error : {average_mse_values[-1]:.3f}')
plt.text(7, 0.11, f'Average Analytical Error : {analytical_average_mse_values[-1]:.3f}')

In [None]:
# Run the trained model and the analytical lens over the test loader, keep
# the phase of layers[1] of each on the CPU for plotting, and print the measures.
_loader = dataloader_test

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_single(model, _loader)
(analytical_measures,
 analytical_example_input,
 analytical_example_output) = eval_model_single(analytical_lrn, _loader)

learned_phase = model.layers[1].phase.detach().squeeze().cpu()
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

# The first entry of each measures list is removed before printing.
for _measures in (learned_measures, analytical_measures):
    _measures.pop(0)

print(learned_measures)
print(analytical_measures)

In [None]:
# 2x3 comparison figure: top row = learned model, bottom row = analytical lens.
# Columns: phase modulo 2*pi, |input|, |output|.
_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_phase, _field_in, _field_out) in zip(ax, _rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
fig.savefig('10perClass_testImage_vsAnalytical.pdf')

## Model trained on 100 images per class

In [None]:
# Restore the 100-per-class checkpoint, move it to the GPU in eval mode and
# evaluate it over the test loader.
_checkpoint = '../my_models/LRN/model_mnist_100perClass/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

measures, sample, output = eval_model(model, dataloader_test)
# NOTE: this rebinds `params`, which held the YAML configuration dict at the
# top of the notebook, to the first entry of `measures`.
params = measures.pop(0)

In [None]:
# Per-class error statistics for the most recently evaluated model.
mse_by_class = split_mse_by_class(measures)
psnr_by_class = split_psnr_by_class(measures)

# Turn each per-class container into a plain list, in its iteration order.
mse_by_class = [mse_by_class[_label] for _label in mse_by_class]
psnr_by_class = [psnr_by_class[_label] for _label in psnr_by_class]

# Caution: executing this cell again appends another entry; the summary
# cells further down index these running lists by position.
average_mse_values.append(np.average(mse_by_class))
average_psnr_values.append(np.average(psnr_by_class))

# Boxplot positions are 1-based, the class labels 0-based.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths=0.6)
plt.xticks(ticks=ticks + 1, labels=ticks)
plt.xlabel("MNIST class")
plt.ylabel("Mean squared error")
plt.ylim(0, 0.12)
plt.title("Mean squared error by class\n Model trained on 100 per class")
plt.text(7, 0.115, f'Average Model Error : {average_mse_values[-1]:.3f}')
plt.text(7, 0.11, f'Average Analytical Error : {analytical_average_mse_values[-1]:.3f}')

In [None]:
# Run the trained model and the analytical lens over the test loader, keep
# the phase of layers[1] of each on the CPU for plotting, and print the measures.
_loader = dataloader_test

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_single(model, _loader)
(analytical_measures,
 analytical_example_input,
 analytical_example_output) = eval_model_single(analytical_lrn, _loader)

learned_phase = model.layers[1].phase.detach().squeeze().cpu()
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

# The first entry of each measures list is removed before printing.
for _measures in (learned_measures, analytical_measures):
    _measures.pop(0)

print(learned_measures)
print(analytical_measures)

In [None]:
# 2x3 comparison figure: top row = learned model, bottom row = analytical lens.
# Columns: phase modulo 2*pi, |input|, |output|.
_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_phase, _field_in, _field_out) in zip(ax, _rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
fig.savefig('100perClass_testImage_vsAnalytical.pdf')

## Model trained on 1000 images per class

In [None]:
# Restore the 1000-per-class checkpoint, move it to the GPU in eval mode and
# evaluate it over the test loader.
_checkpoint = '../my_models/LRN/model_mnist_1000perClass/epoch=4-step=6250-v1.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

measures, sample, output = eval_model(model, dataloader_test)
# NOTE: this rebinds `params`, which held the YAML configuration dict at the
# top of the notebook, to the first entry of `measures`.
params = measures.pop(0)

In [None]:
# Per-class error statistics for the most recently evaluated model.
mse_by_class = split_mse_by_class(measures)
psnr_by_class = split_psnr_by_class(measures)

# Turn each per-class container into a plain list, in its iteration order.
mse_by_class = [mse_by_class[_label] for _label in mse_by_class]
psnr_by_class = [psnr_by_class[_label] for _label in psnr_by_class]

# Caution: executing this cell again appends another entry; the summary
# cells further down index these running lists by position.
average_mse_values.append(np.average(mse_by_class))
average_psnr_values.append(np.average(psnr_by_class))

# Boxplot positions are 1-based, the class labels 0-based.
ticks = np.arange(10)
bp0 = plt.boxplot(mse_by_class, widths=0.6)
plt.xticks(ticks=ticks + 1, labels=ticks)
plt.xlabel("MNIST class")
plt.ylabel("Mean squared error")
plt.ylim(0, 0.12)
plt.title("Mean squared error by class\n Model trained on 1000 per class")
plt.text(7, 0.115, f'Average Model Error : {average_mse_values[-1]:.3f}')
plt.text(7, 0.11, f'Average Analytical Error : {analytical_average_mse_values[-1]:.3f}')

In [None]:
# Run the trained model and the analytical lens over the test loader, keep
# the phase of layers[1] of each on the CPU for plotting, and print the measures.
_loader = dataloader_test

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_single(model, _loader)
(analytical_measures,
 analytical_example_input,
 analytical_example_output) = eval_model_single(analytical_lrn, _loader)

learned_phase = model.layers[1].phase.detach().squeeze().cpu()
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()

# The first entry of each measures list is removed before printing.
for _measures in (learned_measures, analytical_measures):
    _measures.pop(0)

print(learned_measures)
print(analytical_measures)

In [None]:
# 2x3 comparison figure: top row = learned model, bottom row = analytical lens.
# Columns: phase modulo 2*pi, |input|, |output|.
_rows = (
    (learned_phase, learned_example_input, learned_example_output),
    (analytical_phase, analytical_example_input, analytical_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_phase, _field_in, _field_out) in zip(ax, _rows):
    _images = (
        _phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
fig.savefig('1000perClass_testImage_vsAnalytical.pdf')

## Summary plot
We can roughly summarize the box plots above by looking at the average error as a function of variety, where we define variety as
$$
\text{Variety} = \text{log}_{10}\left(\frac{\text{# Unique Samples}}{\text{# Total Samples}}\right)
$$

In [None]:
# Collapse the per-model average MSE values into one value per variety level.
# Order noted for average_mse_values: 1, 10, 100, 1, 10, 100, 1, 10, 100, 1000
# (ten entries give five merged values; at least eight entries are required).
_per_model = average_mse_values
print(len(_per_model))
print(list(_per_model))

# Step 1: average entries 0-2 with entries 3-5, position by position.
_paired = [(_per_model[_k] + _per_model[_k + 3]) / 2 for _k in range(3)]
# Sample counts noted for this intermediate list: 1, 10, 100, 10, 100, 1000, 10000
print(_paired + list(_per_model[6:]))

# Step 2: fold entries 6 and 7 into the second and third paired values;
# everything from entry 8 onwards is carried over unchanged.
average_errors = np.asarray([
    _paired[0],
    (_paired[1] + _per_model[6]) / 2,
    (_paired[2] + _per_model[7]) / 2,
    *_per_model[8:],
])
print(len(average_errors))
print(average_errors)

In [None]:
# Average MSE of the learned lens against variety (log-scaled x axis), with
# the analytical lens drawn as a horizontal reference.
plt.style.use('default')

# x positions: 1, 10, 100, 1000 and 10000, each divided by 10000.
variety = np.asarray([1, 10, 100, 1000, 10000]) / 10000

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
(line1,) = ax.plot(
    variety, average_errors,
    marker='v', color='purple', label='Learned lens',
)
ax.set_ylim(0, 0.065)
ax.set_xscale('log', base=10)
ax.grid(True)
ax.tick_params(axis='x', which='both', right=False, top=False,
               direction='out', color='black')
line2 = ax.hlines(
    y=analytical_average_mse_values, xmin=1.e-4, xmax=1,
    color='green', label='Analytical lens',
)
ax.set(ylabel='Average MSE', xlabel='Variety')

ax.legend(handles=[line1, line2], loc='upper right')
ax.set_title('Average MSE by variety')
plt.tight_layout()
fig.savefig('mse_by_variety.pdf')

In [None]:
# Collapse the per-model average PSNR values into one value per variety level,
# using the same merging scheme as for the MSE values
# (ten entries give five merged values; at least eight entries are required).
_per_model = average_psnr_values
print(len(_per_model))
print(list(_per_model))

# Step 1: average entries 0-2 with entries 3-5, position by position.
_paired = [(_per_model[_k] + _per_model[_k + 3]) / 2 for _k in range(3)]
# Sample counts noted for this intermediate list: 1, 10, 100, 10, 100, 1000, 10000
print(_paired + list(_per_model[6:]))

# Step 2: fold entries 6 and 7 into the second and third paired values;
# everything from entry 8 onwards is carried over unchanged.
average_psnr = np.asarray([
    _paired[0],
    (_paired[1] + _per_model[6]) / 2,
    (_paired[2] + _per_model[7]) / 2,
    *_per_model[8:],
])
print(len(average_psnr))
print(average_psnr)

In [None]:
# Average PSNR of the learned lens against variety (log-scaled x axis), with
# the analytical lens drawn as a horizontal reference.
# x positions: 1, 10, 100, 1000 and 10000, each divided by 10000.
variety = np.asarray([1, 10, 100, 1000, 10000]) / 10000

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
(line1,) = ax.plot(
    variety, average_psnr,
    marker='v', color='purple', label='Learned lens',
)

ax.set_ylim(6, 16)
line2 = ax.hlines(
    y=analytical_average_psnr_values, xmin=1.e-4, xmax=1,
    color='green', label='Analytical lens',
)
ax.set(ylabel='Average PSNR', xlabel='Variety')
ax.set_xscale('log', base=10)
ax.tick_params(axis='x', which='both', right=False, top=False,
               direction='out', color='black')
ax.grid(True)

ax.legend(handles=[line1, line2], loc='lower right')
ax.set_title('Average PSNR by variety')
plt.tight_layout()
fig.savefig('psnr_by_variety.pdf')

## Now we need to get some images

In [None]:
# Get a good image for comparisons: step through the test loader and keep
# the batch at index `number` (or the last batch produced, should the loader
# run out first).
number = 500
for i, batch in zip(range(number + 1), dataloader_test):
    pass
fig, ax = plt.subplots(1, 1)
ax.imshow(batch[0].abs().squeeze())

In [None]:
# Analytical outputs
# Evaluate the analytical lens on the comparison batch and keep the phase of
# layers[1] (detached, squeezed, on the CPU) for plotting.
analytical_measures, analytical_example_input, analytical_example_output = eval_model_fromBatch(analytical_lrn, batch)
analytical_phase = analytical_lrn.layers[1].phase.detach().squeeze().cpu()
# The first entry of the measures list is removed before printing.
analytical_measures.pop(0)

# Only the analytical measures are reported: this cell does not compute
# `learned_measures`, so printing it would show a stale value left over from
# an earlier section (or raise NameError on a kernel where that section was
# not run).
print(analytical_measures)

In [None]:
# 1x3 figure for the analytical lens: phase modulo 2*pi, |input|, |output|.
_images = (
    analytical_phase % (2*np.pi),
    analytical_example_input.detach().cpu().abs().squeeze(),
    analytical_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('ideal_lens_output.pdf')

In [None]:
# Restore the model_mnist_1 checkpoint on the GPU in eval mode and evaluate
# it on the comparison batch.
_checkpoint = '../my_models/LRN/model_mnist_1/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_fromBatch(model, batch)
learned_phase = model.layers[1].phase.detach().squeeze().cpu()
# Last expression: the cell displays the entry removed from the front.
learned_measures.pop(0)

In [None]:
# 1x3 figure for the learned model: phase modulo 2*pi, |input|, |output|.
_images = (
    learned_phase % (2*np.pi),
    learned_example_input.detach().cpu().abs().squeeze(),
    learned_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('single1_output.pdf')

In [None]:
# Restore the model_mnist_8 checkpoint on the GPU in eval mode and evaluate
# it on the comparison batch.
_checkpoint = '../my_models/LRN/model_mnist_8/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_fromBatch(model, batch)
learned_phase = model.layers[1].phase.detach().squeeze().cpu()
# Last expression: the cell displays the entry removed from the front.
learned_measures.pop(0)

In [None]:
# 1x3 figure for the learned model: phase modulo 2*pi, |input|, |output|.
_images = (
    learned_phase % (2*np.pi),
    learned_example_input.detach().cpu().abs().squeeze(),
    learned_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('single8_output.pdf')

In [None]:
# Restore the model_mnist_1perClass checkpoint on the GPU in eval mode and
# evaluate it on the comparison batch.
_checkpoint = '../my_models/LRN/model_mnist_1perClass/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_fromBatch(model, batch)
learned_phase = model.layers[1].phase.detach().squeeze().cpu()
# Last expression: the cell displays the entry removed from the front.
learned_measures.pop(0)

In [None]:
# 1x3 figure for the learned model: phase modulo 2*pi, |input|, |output|.
_images = (
    learned_phase % (2*np.pi),
    learned_example_input.detach().cpu().abs().squeeze(),
    learned_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('1perClass_output.pdf')

In [None]:
# Restore the model_mnist_10perClass checkpoint on the GPU in eval mode and
# evaluate it on the comparison batch.
_checkpoint = '../my_models/LRN/model_mnist_10perClass/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_fromBatch(model, batch)
learned_phase = model.layers[1].phase.detach().squeeze().cpu()
# Last expression: the cell displays the entry removed from the front.
learned_measures.pop(0)

In [None]:
# 1x3 figure for the learned model: phase modulo 2*pi, |input|, |output|.
_images = (
    learned_phase % (2*np.pi),
    learned_example_input.detach().cpu().abs().squeeze(),
    learned_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('10perClass_output.pdf')

In [None]:
# Restore the model_mnist_100perClass checkpoint on the GPU in eval mode and
# evaluate it on the comparison batch.
_checkpoint = '../my_models/LRN/model_mnist_100perClass/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_fromBatch(model, batch)
learned_phase = model.layers[1].phase.detach().squeeze().cpu()
# Last expression: the cell displays the entry removed from the front.
learned_measures.pop(0)

In [None]:
# 1x3 figure for the learned model: phase modulo 2*pi, |input|, |output|.
_images = (
    learned_phase % (2*np.pi),
    learned_example_input.detach().cpu().abs().squeeze(),
    learned_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('100perClass_output.pdf')

In [None]:
# Restore the model_mnist_1000perClass checkpoint on the GPU in eval mode and
# evaluate it on the comparison batch.
_checkpoint = '../my_models/LRN/model_mnist_1000perClass/epoch=4-step=6250-v1.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()

(learned_measures,
 learned_example_input,
 learned_example_output) = eval_model_fromBatch(model, batch)
learned_phase = model.layers[1].phase.detach().squeeze().cpu()
# Last expression: the cell displays the entry removed from the front.
learned_measures.pop(0)

In [None]:
# 1x3 figure for the learned model: phase modulo 2*pi, |input|, |output|.
_images = (
    learned_phase % (2*np.pi),
    learned_example_input.detach().cpu().abs().squeeze(),
    learned_example_output.detach().cpu().abs().squeeze(),
)

fig, ax = plt.subplots(1, 3, figsize=(13, 8))
for _panel, _image in zip(ax, _images):
    _panel.imshow(_image, cmap='viridis')
    _panel.grid(False)
    _panel.axis('off')

# The PSNR/MSE text overlays on the output panel stay disabled.
fig.savefig('1000perClass_output.pdf')

---
---
---

## Single 1 in and out of the dataset

In [None]:
# Restore the model_mnist_1 checkpoint on the GPU in eval mode and keep the
# phase of layers[1] on the CPU for plotting.
_checkpoint = '../my_models/LRN/model_mnist_1/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

In [None]:
# Get a good image for comparisons: step through the test loader and keep
# the batch at index `number` (or the last batch produced, should the loader
# run out first).
number = 500
for i, batch1 in zip(range(number + 1), dataloader_test):
    pass
fig, ax = plt.subplots(1, 1)
ax.imshow(batch1[0].abs().squeeze())

In [None]:
# Get a good image for comparisons: step through the 'mnist_single0'
# training loader and keep the batch at index `number` (or the last batch
# produced, should the loader run out first).
number = 1
for i, batch2 in zip(range(number + 1), data_loaders['mnist_single0']):
    pass
fig, ax = plt.subplots(1, 1)
ax.imshow(batch2[0].abs().squeeze())

In [None]:
# Evaluate the current model on the test-loader batch (batch1) and on the
# training-loader batch (batch2), in that order.
_test_result = eval_model_fromBatch(model, batch1)
_train_result = eval_model_fromBatch(model, batch2)
test_measures, test_example_input, test_example_output = _test_result
train_measures, train_example_input, train_example_output = _train_result

# Take the first entry off each measures list and keep it separately.
test_params, train_params = test_measures.pop(0), train_measures.pop(0)

In [None]:
# 2x3 figure for one learned model: top row = training batch, bottom row =
# test batch. Columns: learned phase modulo 2*pi, |input|, |output|.
_rows = (
    (train_example_input, train_example_output),
    (test_example_input, test_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_field_in, _field_out) in zip(ax, _rows):
    _images = (
        learned_phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
plt.tight_layout()
fig.savefig('single1_overfit.pdf')

---

In [None]:
# Restore the model_mnist_10_1 checkpoint on the GPU in eval mode, keep the
# phase of layers[1] on the CPU, then evaluate the test-loader batch (batch1)
# and the training-loader batch (batch2).
_checkpoint = '../my_models/LRN/model_mnist_10_1/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

_test_result = eval_model_fromBatch(model, batch1)
_train_result = eval_model_fromBatch(model, batch2)
test_measures, test_example_input, test_example_output = _test_result
train_measures, train_example_input, train_example_output = _train_result

# Take the first entry off each measures list and keep it separately.
test_params, train_params = test_measures.pop(0), train_measures.pop(0)

In [None]:
# 2x3 figure for one learned model: top row = training batch, bottom row =
# test batch. Columns: learned phase modulo 2*pi, |input|, |output|.
_rows = (
    (train_example_input, train_example_output),
    (test_example_input, test_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_field_in, _field_out) in zip(ax, _rows):
    _images = (
        learned_phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
plt.tight_layout()
fig.savefig('10_1_overfit.pdf')

---

In [None]:
# Restore the model_mnist_8 checkpoint on the GPU in eval mode and keep the
# phase of layers[1] on the CPU for plotting.
_checkpoint = '../my_models/LRN/model_mnist_8/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

In [None]:
# Get a good image for comparisons: step through the 'mnist_single1'
# training loader and keep the batch at index `number` (or the last batch
# produced, should the loader run out first).
number = 1
for i, batch2 in zip(range(number + 1), data_loaders['mnist_single1']):
    pass
fig, ax = plt.subplots(1, 1)
ax.imshow(batch2[0].abs().squeeze())

In [None]:
# Evaluate the current model on the test-loader batch (batch1) and on the
# training-loader batch (batch2), in that order.
_test_result = eval_model_fromBatch(model, batch1)
_train_result = eval_model_fromBatch(model, batch2)
test_measures, test_example_input, test_example_output = _test_result
train_measures, train_example_input, train_example_output = _train_result

# Take the first entry off each measures list and keep it separately.
test_params, train_params = test_measures.pop(0), train_measures.pop(0)

In [None]:
# 2x3 figure for one learned model: top row = training batch, bottom row =
# test batch. Columns: learned phase modulo 2*pi, |input|, |output|.
_rows = (
    (train_example_input, train_example_output),
    (test_example_input, test_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_field_in, _field_out) in zip(ax, _rows):
    _images = (
        learned_phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
plt.tight_layout()
fig.savefig('single8_overfit.pdf')

---

In [None]:
# Restore the model_mnist_10_8 checkpoint on the GPU in eval mode, keep the
# phase of layers[1] on the CPU, then evaluate the test-loader batch (batch1)
# and the training-loader batch (batch2).
_checkpoint = '../my_models/LRN/model_mnist_10_8/epoch=4-step=6250.ckpt'
model = lrn.LRN.load_from_checkpoint(_checkpoint).cuda()
model.eval()
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

_test_result = eval_model_fromBatch(model, batch1)
_train_result = eval_model_fromBatch(model, batch2)
test_measures, test_example_input, test_example_output = _test_result
train_measures, train_example_input, train_example_output = _train_result

# Take the first entry off each measures list and keep it separately.
test_params, train_params = test_measures.pop(0), train_measures.pop(0)

In [None]:
# 2x3 figure for one learned model: top row = training batch, bottom row =
# test batch. Columns: learned phase modulo 2*pi, |input|, |output|.
_rows = (
    (train_example_input, train_example_output),
    (test_example_input, test_example_output),
)

fig, ax = plt.subplots(2, 3, figsize=(13, 8))
for _row_axes, (_field_in, _field_out) in zip(ax, _rows):
    _images = (
        learned_phase % (2*np.pi),
        _field_in.detach().cpu().abs().squeeze(),
        _field_out.detach().cpu().abs().squeeze(),
    )
    for _panel, _image in zip(_row_axes, _images):
        _panel.imshow(_image, cmap='viridis')
        _panel.grid(False)
        _panel.axis('off')

# The PSNR/MSE text overlays on the output panels stay disabled.
plt.tight_layout()
fig.savefig('10_8_overfit.pdf')

---

In [None]:
# Get a good image for comparisons: take the batch at position `number`
# from the 'mnist_1perClass' training loader and keep it as batch2.
number = 1
for i,batch2 in enumerate(data_loaders['mnist_1perClass']):
    if i == number:
        break
else:
    # The loader ran out before reaching `number`. Without this guard batch2
    # would silently be the loader's last batch (or a stale value left over
    # from an earlier cell if the loader is empty), and every figure built
    # from it afterwards would show the wrong sample.
    raise IndexError(
        f"'mnist_1perClass' loader has no batch at index {number}"
    )

# Preview the selected sample by magnitude.
# NOTE(review): presumably batch2[0] is the input image of the batch -- confirm.
fig,ax = plt.subplots(1,1)
ax.imshow(batch2[0].abs().squeeze())

In [None]:
# Restore the LRN checkpoint trained on the 'mnist_10_1' split, switch it to
# evaluation mode and move it onto the GPU in a single chained expression.
model = lrn.LRN.load_from_checkpoint(
    '../my_models/LRN/model_mnist_10_1/epoch=4-step=6250.ckpt'
).eval().cuda()

# Phase of the layer at index 1, detached and copied to the CPU for plotting.
learned_phase = model.layers[1].phase.detach().squeeze().cpu()

# Evaluate on the two batches chosen by earlier cells: batch1 feeds the
# "test" results and batch2 the "train" results.
(test_measures,
 test_example_input,
 test_example_output) = eval_model_fromBatch(model, batch1)
(train_measures,
 train_example_input,
 train_example_output) = eval_model_fromBatch(model, batch2)

# Remove the leading entry of each measures list and keep it separately.
# NOTE(review): presumably that entry holds run parameters rather than a
# per-sample measure -- confirm against eval_model_fromBatch.
test_params = test_measures.pop(0)
train_params = train_measures.pop(0)

In [None]:
# 2x3 comparison figure.
#   rows    : 0 = training example, 1 = test example
#   columns : 0 = learned phase wrapped into [0, 2*pi), 1 = |input|, 2 = |output|
fig, ax = plt.subplots(2, 3, figsize=(13, 8))

# The same wrapped phase map is shown in both rows.
phase_wrapped = learned_phase % (2*np.pi)

panel_rows = (
    (train_example_input, train_example_output),
    (test_example_input, test_example_output),
)
for row_axes, (sample_in, sample_out) in zip(ax, panel_rows):
    row_images = (
        phase_wrapped,
        sample_in.detach().cpu().abs().squeeze(),
        sample_out.detach().cpu().abs().squeeze(),
    )
    for panel_ax, panel_img in zip(row_axes, row_images):
        panel_ax.imshow(panel_img, cmap='viridis')
        # Pure image panels: no grid lines, ticks or frame.
        panel_ax.grid(False)
        panel_ax.axis('off')

# PSNR/MSE text overlays for the output panels (currently disabled):
#ax[0][2].text(680,75,'PSNR: {:.2f}'.format(train_measures[0]['psnr']), color='white', fontsize=12)
#ax[0][2].text(680,150,'MSE: {:.2f}'.format(train_measures[0]['mse']), color='white', fontsize=12)
#ax[1][2].text(680,75,'PSNR: {:.2f}'.format(test_measures[0]['psnr']), color='white', fontsize=12)
#ax[1][2].text(680,150,'MSE: {:.2f}'.format(test_measures[0]['mse']), color='white', fontsize=12)

plt.tight_layout()
# NOTE(review): the previous figure is named after its checkpoint
# ('10_8_overfit.pdf'), but this one is named '1perClass' although the model
# plotted here is loaded from the mnist_10_1 checkpoint -- confirm the file
# name is intended and does not clash with a later 1perClass figure.
fig.savefig('1perClass_overfit.pdf')