# Validation Metrics

In [None]:
import os
from scipy.io import loadmat
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import torch

from utilities.functions import SSIM, PSNR

In [None]:
# Read the ground-truth and noisy validation blocks from the .mat files
# located under <current working directory>/data/gsn_10/.
cwd = os.getcwd()
mat_gt_file = cwd + '/data/gsn_10/ValidationGtBlocks.mat'
mat_noisy_file = cwd + '/data/gsn_10/ValidationNoisyBlocks.mat'
mat_gt, mat_noisy = map(loadmat, (mat_gt_file, mat_noisy_file))

# Shape of the ground-truth array stored under the 'val_gt' key; printed as a
# quick sanity check of what was loaded.
size = mat_gt['val_gt'].shape
print(size)

In [None]:
# Index of the block to visualise; 1 // 2 evaluates to 0.
index0 = 1 // 2

# Explicit spelling of [index0, :, :, :, :]: fix the first axis and keep the
# remaining four axes whole.
block_index = (index0,) + (slice(None),) * 4

# Float copies of the selected ground-truth / noisy block for plotting.
image_mat_gt = mat_gt['val_gt'][block_index].astype(float)
image_mat_noisy = mat_noisy['val_ng'][block_index].astype(float)

In [None]:
# Draw the first nine ground-truth images of the selected block on a 3x3 grid.
fig = make_subplots(3, 3)

for cell in range(9):
    # Row-major placement: zero-based (row, column) of this grid cell.
    i, j = divmod(cell, 3)
    # Subplot coordinates passed to add_trace are one-based.
    fig.add_trace(go.Image(z=image_mat_gt[cell]), i + 1, j + 1)

fig.update_layout(
    autosize=False,
    height=800,
    width=800,
    title_text="Validation GT Samples {i}".format(i=index0),
    showlegend=False,
)
# Turn the x and y axes off so only the images are shown.
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
fig.show()

In [None]:
# Draw the first nine noisy images of the selected block on a 3x3 grid.
fig = make_subplots(3, 3)

for cell in range(9):
    # Row-major placement: zero-based (row, column) of this grid cell.
    i, j = divmod(cell, 3)
    # Subplot coordinates passed to add_trace are one-based.
    fig.add_trace(go.Image(z=image_mat_noisy[cell]), i + 1, j + 1)

fig.update_layout(
    autosize=False,
    height=800,
    width=800,
    title_text="Validation Noisy Samples {i}".format(i=index0),
    showlegend=False,
)
# Turn the x and y axes off so only the images are shown.
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
fig.show()

In [None]:
# Mean SSIM and PSNR of every noisy validation block against its ground truth.
#
# NOTE: the loop rebinds image_mat_gt / image_mat_noisy (NumPy arrays produced
# by the block-selection cell and read by the sample plots) to torch tensors.
# Re-run the block-selection cell before re-drawing those plots.
def _as_scaled_tensor(blocks, block_idx):
    """Return block *block_idx* of *blocks* as a float32 tensor divided by 255.

    The last axis is moved to position 1 (presumably channels-last to
    channels-first -- confirm against what SSIM expects).
    """
    scaled = blocks[block_idx, :, :, :, :] / 255.
    return torch.tensor(scaled, dtype=torch.float).permute(0, 3, 1, 2)


total_ssim = 0
total_psnr = 0
num_blocks = size[0]

for i in range(num_blocks):
    image_mat_gt = _as_scaled_tensor(mat_gt['val_gt'], i)
    image_mat_noisy = _as_scaled_tensor(mat_noisy['val_ng'], i)

    ssim = SSIM(image_mat_gt, image_mat_noisy)
    # PSNR is derived from the mean squared error over the whole block.
    mse = torch.square(image_mat_gt - image_mat_noisy).mean()
    psnr = PSNR(mse)

    # Pull the Python scalars out once; they are both printed and accumulated.
    ssim_value, psnr_value = ssim.item(), psnr.item()
    print('Noisy', i, ':', ssim_value, '-', psnr_value)

    total_ssim += ssim_value
    total_psnr += psnr_value

# Turn the running sums into per-block averages.
total_ssim /= num_blocks
total_psnr /= num_blocks

print('Total Raw', '-', total_ssim, '-', total_psnr)