In [None]:
import os
import sys
import csv
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.functional.image import peak_signal_noise_ratio as psnr
from torchmetrics.functional import mean_squared_error as mse
from torchmetrics.functional.image import structural_similarity_index_measure as ssim
from tqdm.notebook import tqdm
sys.path.append('../')
from datamodule.datamodule import select_data
from models.models import Classifier
from mpl_toolkits.axes_grid1 import make_axes_locatable


In [None]:
# Print the names of the style sheets matplotlib has available, then select
# the one used by the figures in this notebook.
print(plt.style.available)
plt.style.use('seaborn-v0_8-dark-palette')

In [None]:
# Run identifier; it is interpolated into the data, results and checkpoint
# paths used throughout this notebook.
version = "version_2"

In [None]:
# Gather the sample files of the selected run, in sorted order, and split them
# into training / validation lists by the substring found in the file name.
path_data = f'../../data/sim2real/{version}'
files = sorted(os.listdir(path_data))

train_files = []
valid_files = []
for file_name in files:
    full_path = os.path.join(path_data, file_name)
    # Two independent checks: a name containing both substrings lands in both lists.
    if 'train' in file_name:
        train_files.append(full_path)
    if 'valid' in file_name:
        valid_files.append(full_path)

In [None]:
# Load the first training file so its contents can be inspected below.
data = torch.load(train_files[0], weights_only=True)

In [None]:
# Show which entries a sample file contains.
data.keys()

## Training metrics

In [None]:
# Root directory holding the per-version training results (logs, checkpoints).
# Plain literal: the original f-string had no placeholders.
path_results = '../../results/sim2real'

In [None]:
# Parse the training log into a column-oriented dict:
#   header name -> list of raw string cells, one entry per CSV row.
# Cells are kept as strings here (including empty ones); conversion to numbers
# happens in a later cell.
csv_filename = os.path.join(path_results, version, 'logs', 'metrics.csv')
metrics = {}
# newline='' is what the csv module asks for when it is handed a file object.
with open(csv_filename, newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',')
    # The first row is the header; None means the file is empty and `metrics`
    # simply stays an empty dict.
    key_list = next(reader, None)
    if key_list is not None:
        metrics = {header: [] for header in key_list}
        for row in reader:
            # zip pairs each cell with its column name by position.
            for header, value in zip(key_list, row):
                metrics[header].append(value)

In [None]:
# Show the column names that were found in the metrics CSV header.
print(metrics.keys())

In [None]:
# Convert the raw string columns into numeric data, in place:
#   * 'epoch' -> sorted array of the distinct integer epoch numbers
#   * 'step'  -> left untouched
#   * others  -> list of floats, with empty cells dropped
for column in list(metrics):
    if column == 'step':
        continue
    raw_cells = metrics[column]
    if column == 'epoch':
        metrics[column] = np.unique(np.asarray(raw_cells, dtype=int))
    else:
        metrics[column] = [float(cell) for cell in raw_cells if cell != '']

In [None]:
# Plot every logged metric against the epoch index, one panel per metric.

# Select the metric columns explicitly instead of assuming that both 'epoch'
# and 'step' are present (the original used len(metrics) - 2).
plot_keys = [key for key in metrics if key not in ('epoch', 'step')]

# squeeze=False always returns a 2-D array of axes, so indexing also works
# when there is only a single metric to plot.
fig, ax = plt.subplots(len(plot_keys), 1, figsize=(5,5), squeeze=False)
ax = ax.flatten()

for counter, key in enumerate(plot_keys):
    ax[counter].plot(metrics['epoch'], metrics[key])
    ax[counter].set_title(key)
    ax[counter].set_xlabel('epoch')
    # The y-axis label is chosen from the metric name.
    if 'classifier' in key:
        ax[counter].set_ylabel("Cross entropy")
    elif key == 'loss_train' or key == 'loss_val':
        ax[counter].set_ylabel("MCL")
    else:
        ax[counter].set_ylabel("MSE")
plt.tight_layout()

## Plot the learned calibration layer

In [None]:
# Read the modulator parameters of the calibration layer from the last
# checkpoint and convert each one to a squeezed numpy array.
path_checkpoint = f'../../results/sim2real/{version}/checkpoints/last.ckpt'
# map_location='cpu' lets the checkpoint load on machines without a GPU;
# every tensor is moved to the CPU below anyway.
state_dict = torch.load(path_checkpoint, weights_only = True, map_location='cpu')['state_dict']

def _modulator_param(name):
    """Return the state-dict entry 'dom.layers.1.modulator.<name>' as a squeezed numpy array."""
    return state_dict[f'dom.layers.1.modulator.{name}'].squeeze().detach().cpu().numpy()

initial_amplitude = _modulator_param('initial_amplitude')
initial_phase = _modulator_param('initial_phase')
optim_amplitude = _modulator_param('optimizeable_amplitude')
optim_phase = _modulator_param('optimizeable_phase')

In [None]:
# Show the four modulator maps in a 2x2 grid (top row: initial values, bottom
# row: optimizeable values; left: amplitude, right: phase), each with its own
# colorbar. The axes are switched off, so the titles are the only way to tell
# the panels apart.
panels = [
    ('Initial amplitude', initial_amplitude),
    ('Initial phase', initial_phase),
    ('Optimizeable amplitude', optim_amplitude),
    ('Optimizeable phase', optim_phase),
]

fig, ax = plt.subplots(2,2, figsize=(8,5))

# ax.flatten() walks the grid row by row, matching the order of `panels`.
for a, (title, image) in zip(ax.flatten(), panels):
    im = a.imshow(image)
    a.set_title(title)
    # Attach a narrow colorbar axis to the right of the image axis.
    cax = make_axes_locatable(a).append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
    a.axis('off')
plt.tight_layout()

## Plot 3 random training examples from each image set

In [None]:
# Show a few randomly chosen training files, one row per file and one column
# per image stored in the file.
columns = [
    ('resampled_sample', 'Ideal'),
    ('bench_image', 'Bench'),
    ('sim_output', 'Sim'),
    ('sim2real_output', 'Sim2Real'),
]
n_examples = 3

fig, ax = plt.subplots(n_examples, len(columns), figsize = (10,5))

# randperm yields distinct indices, so the same file is never drawn twice
# (randint sampled with replacement). If there are fewer than n_examples
# training files, the remaining rows stay empty.
indices = torch.randperm(len(train_files))[:n_examples].tolist()

for j,i in enumerate(indices):
    data = torch.load(train_files[i], weights_only=True)
    for k,(key, title) in enumerate(columns):
        ax[j][k].imshow(data[key].squeeze().cpu().detach())
        if j == 0:
            # Column headers: the axes are switched off below, so these titles
            # are the only labels on the figure.
            ax[j][k].set_title(title)

for a in ax.flatten():
    a.axis('off')

plt.tight_layout()

## Image metrics

In [None]:
# For every training and validation file, compute MSE / SSIM / PSNR between
# each pair of images stored in the file, plus the min and max of each image.
# Results are collected as lists of 0-dim CPU tensors, keyed by pair name
# (metrics) or image name (min / max).

# pair name -> (first, second) image names, in the argument order passed to
# each metric function.
pairs = {
    'ideal_to_sim': ('ideal', 'sim'),
    'ideal_to_bench': ('ideal', 'bench'),
    'ideal_to_sim2real': ('ideal', 'sim2real'),
    'bench_to_sim': ('bench', 'sim'),
    'bench_to_sim2real': ('bench', 'sim2real'),
    'sim_to_sim2real': ('sim', 'sim2real'),
}
image_names = ('ideal', 'bench', 'sim', 'sim2real')

mse_values = {name: [] for name in pairs}
ssim_values = {name: [] for name in pairs}
psnr_values = {name: [] for name in pairs}
min_values = {name: [] for name in image_names}
max_values = {name: [] for name in image_names}

# Use the GPU when there is one, otherwise fall back to the CPU instead of
# crashing on a hard-coded .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# no_grad: nothing here needs gradients, and it guarantees the stored results
# can later be converted with .numpy() even if a saved tensor carries
# requires_grad.
with torch.no_grad():
    for i,file in enumerate(tqdm(train_files + valid_files)):
        data = torch.load(file, weights_only=True, map_location=device)
        # bench_image and sim_output get two leading singleton dimensions
        # added; the other two images are used as stored.
        images = {
            'ideal': data['resampled_sample'].to(device),
            'bench': data['bench_image'].to(device).unsqueeze(0).unsqueeze(0),
            'sim': data['sim_output'].to(device).unsqueeze(0).unsqueeze(0),
            'sim2real': data['sim2real_output'].to(device),
        }

        for name, (first, second) in pairs.items():
            mse_values[name].append(mse(images[first], images[second]).cpu())
            ssim_values[name].append(ssim(images[first], images[second]).cpu())
            psnr_values[name].append(psnr(images[first], images[second]).cpu())

        for name, image in images.items():
            min_values[name].append(torch.min(image).cpu())
            max_values[name].append(torch.max(image).cpu())


In [None]:
# Clean up the values: turn every list of 0-dim tensors into a 1-D numpy array.

def _as_array(values):
    """Convert a sequence of scalar tensors (or already-numeric scalars) to a numpy array.

    Entries that are torch tensors are detached and converted with .numpy();
    anything else is passed through unchanged. This makes the cell safe to
    re-run after the values have already been converted.
    """
    return np.asarray([
        v.detach().numpy() if isinstance(v, torch.Tensor) else v
        for v in values
    ])

for table in (mse_values, ssim_values, psnr_values, min_values, max_values):
    # Only existing keys are reassigned, so iterating the dict while updating
    # it is safe.
    for key in table:
        table[key] = _as_array(table[key])

## Violin plots of the image metrics and value ranges

In [None]:
# Distribution of the per-file MSE for each image pair, one violin per pair.
mse_series = {
    'Ideal to Bench': mse_values['ideal_to_bench'],
    'Ideal to Sim': mse_values['ideal_to_sim'],
    'Ideal to Sim2Real': mse_values['ideal_to_sim2real'],
    'Bench to Sim': mse_values['bench_to_sim'],
    'Bench to Sim2Real': mse_values['bench_to_sim2real'],
    'Sim to Sim2Real': mse_values['sim_to_sim2real'],
}
labels = list(mse_series)

fig, ax = plt.subplots(1,1, figsize=(8,5))
bp0 = ax.violinplot(list(mse_series.values()), points=1000, showmeans=True)

# Violins are drawn at positions 1..N, so the tick positions start at 1.
ax.set_xticks(list(range(1, len(labels) + 1)), labels, rotation=45)
ax.set_ylabel("MSE")

In [None]:
# Distribution of the per-file SSIM for each image pair, one violin per pair.
ssim_series = {
    'Ideal to Bench': ssim_values['ideal_to_bench'],
    'Ideal to Sim': ssim_values['ideal_to_sim'],
    'Ideal to Sim2Real': ssim_values['ideal_to_sim2real'],
    'Bench to Sim': ssim_values['bench_to_sim'],
    'Bench to Sim2Real': ssim_values['bench_to_sim2real'],
    'Sim to Sim2Real': ssim_values['sim_to_sim2real'],
}
labels = list(ssim_series)

fig, ax = plt.subplots(1,1, figsize=(8,5))
bp0 = ax.violinplot(list(ssim_series.values()), points=1000, showmeans=True)

# Violins are drawn at positions 1..N, so the tick positions start at 1.
ax.set_xticks(list(range(1, len(labels) + 1)), labels, rotation=45)
ax.set_ylabel("SSIM")

In [None]:
# Distribution of the per-file PSNR for each image pair, one violin per pair.
psnr_series = {
    'Ideal to Bench': psnr_values['ideal_to_bench'],
    'Ideal to Sim': psnr_values['ideal_to_sim'],
    'Ideal to Sim2Real': psnr_values['ideal_to_sim2real'],
    'Bench to Sim': psnr_values['bench_to_sim'],
    'Bench to Sim2Real': psnr_values['bench_to_sim2real'],
    'Sim to Sim2Real': psnr_values['sim_to_sim2real'],
}
labels = list(psnr_series)

fig, ax = plt.subplots(1,1, figsize=(8,5))
bp0 = ax.violinplot(list(psnr_series.values()), points=1000, showmeans=True)

# Violins are drawn at positions 1..N, so the tick positions start at 1.
ax.set_xticks(list(range(1, len(labels) + 1)), labels, rotation=45)
ax.set_ylabel("PSNR")

In [None]:
# Distribution of the per-file minimum pixel value of each image type.
min_series = {
    'Ideal': min_values['ideal'],
    'Sim': min_values['sim'],
    'Bench': min_values['bench'],
    'Sim2Real': min_values['sim2real'],
}
labels = list(min_series)

fig, ax = plt.subplots(1,1, figsize=(8,5))
bp0 = ax.violinplot(list(min_series.values()), points=1000, showmeans=True)

# Violins are drawn at positions 1..N, so the tick positions start at 1.
ax.set_xticks(list(range(1, len(labels) + 1)), labels, rotation=45)
ax.set_ylabel("Min values")

In [None]:
# Distribution of the per-file maximum pixel value of each image type.
max_series = {
    'Ideal': max_values['ideal'],
    'Sim': max_values['sim'],
    'Bench': max_values['bench'],
    'Sim2Real': max_values['sim2real'],
}
labels = list(max_series)

fig, ax = plt.subplots(1,1, figsize=(8,5))
bp0 = ax.violinplot(list(max_series.values()), points=1000, showmeans=True)

# Violins are drawn at positions 1..N, so the tick positions start at 1.
ax.set_xticks(list(range(1, len(labels) + 1)), labels, rotation=45)
ax.set_ylabel("Max values")