# SRCNN and AE test notebook
This notebook gives you the code to test your trained models on a test dataset.

### This file is distributed under the following license:

Copyright (c) 2021 Idiap Research Institute, http://www.idiap.ch/

Written by Rémi Clerc <remi.clerc@idiap.ch>

Redistribution and use in source and binary forms, with or without modification, are
permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list
of conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

## Libraries

In [None]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

import matplotlib
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

from os import listdir
import math

from AE import AE
from SRCNN import SRCNN

import random
import re

## File paths
Edit these paths to match your own setup.

In [None]:
# Folder into which training wrote its .pth checkpoint file(s)
weights_folder = "../pth/"

# The particular checkpoint to evaluate in the "Test the model" section
weights_file = "{}SRCNN_1mm.pth".format(weights_folder)

# Root folder of the test dataset
image_folder = "path/to/test/dataset/folder/"

# Destination folder for the generated images and figures
output_folder = "./"

## Visualize losses and PSNR
Run this section if you have already run train.py and have a folder containing the .pth checkpoint saved at each epoch. The following code iterates through those files, extracts the training losses and evaluation PSNR, and displays and saves figures of that data.
Your .pth files must be named in the form "epoch_xx.pth" (where xx is the epoch number) to be taken into account.

In [None]:
# Gather the training loss and evaluation PSNR stored in each per-epoch checkpoint.
regex = re.compile('epoch.*')


def _epoch_sort_key(name):
    """Sort key for checkpoint names: (first integer in the name, or -1 if none; the name)."""
    match = re.search(r'\d+', name)
    return (int(match.group()) if match else -1, name)


# os.listdir returns names in arbitrary order, and a plain string sort would put
# "epoch_10" before "epoch_2", so order the checkpoints numerically by epoch.
files = sorted((file for file in listdir(weights_folder) if regex.match(file)),
               key=_epoch_sort_key)
losses = []
val_losses = []  # not populated in this cell
psnr = []

for file in files:
    # load on CPU so that checkpoints saved from a GPU can be read on any machine
    checkpoint = torch.load(weights_folder + file, map_location=torch.device('cpu'))
    losses.append(checkpoint['training_loss'])
    psnr.append(checkpoint['eval_psnr'])

In [None]:
# Training loss for each loaded checkpoint, in the order they were loaded.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses)

ax.set_title('Training Loss')
# Aim for about 10 ticks. The step must be at least 1: round() yields 0 for fewer
# than 6 epochs (round(0.5) == 0), and range() raises ValueError on a zero step.
ax.set_xticks(range(0, len(losses), max(1, round(len(losses) / 10))))
ax.set_xlabel('Epoch')
ax.set_ylabel('Training loss')

fig.savefig(output_folder + 'training_loss.jpg')

In [None]:
# Evaluation PSNR for each loaded checkpoint, in the order they were loaded.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(psnr)

ax.set_title('Eval PSNR for each epoch')
# Aim for about 10 ticks. The step must be at least 1: round() yields 0 for fewer
# than 6 epochs (round(0.5) == 0), and range() raises ValueError on a zero step.
ax.set_xticks(range(0, len(psnr), max(1, round(len(psnr) / 10))))
ax.set_xlabel('Epoch')
ax.set_ylabel('Eval PSNR')
fig.savefig(output_folder + 'eval_psnr.jpg')

## Test the model
In the following cell, write

`architecture = "SRCNN"` if you're testing SRCNN, or

`architecture = "AE"` if you're testing the autoencoder.

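Run the remaining cells to test the model on the test dataset and to display and save figures comparing the input, the prediction and the original image. The test dataset folder must contain two sub-folders, `fresnelet/` (model inputs) and `greyscale/` (ground-truth images), holding images with identical file names.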

In [None]:
# Network to evaluate: "SRCNN" or "AE" (it must match the checkpoint in weights_file,
# whose weights are loaded into this network below)
architecture = "SRCNN"

In [None]:
cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

criterion = nn.MSELoss().to(device)

# Build the network selected by `architecture`; an unknown value raises instead
# of silently falling back to the autoencoder.
if architecture == "SRCNN":
    model = SRCNN().to(device)
elif architecture == "AE":
    model = AE().to(device)
else:
    raise ValueError('architecture must be "SRCNN" or "AE", got {!r}'.format(architecture))

# Load model weights (stored under 'model_state_dict' in the checkpoint)
checkpoint = torch.load(weights_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Images are fed as float64 tensors, so convert the model once, outside the loop.
model = model.double()
model.eval()


# folder definitions: every image in `fresnelet/` needs a ground-truth image with
# the same file name in `greyscale/`
input_folder = image_folder + 'fresnelet/'
orig_folder = image_folder + 'greyscale/'
filenames = sorted(listdir(input_folder))  # listdir order is arbitrary

# Per-image PSNRs. A distinct name so the per-epoch `psnr` list built in the
# "Visualize losses and PSNR" section is not overwritten.
test_psnrs = []

# run tests
for filename in filenames:
    # load images, scaled by 1/255.
    # NOTE(review): matplotlib's imread already returns PNGs as floats in [0, 1];
    # dividing by 255 is only right for 8-bit formats such as JPEG -- confirm.
    input_img = mpimg.imread(input_folder + filename) / 255.
    # (H, W) -> (1, 1, H, W): add batch and channel dims (assumes single-channel images)
    input_img = torch.from_numpy(input_img).double().to(device).unsqueeze(0).unsqueeze(0)
    orig_img = mpimg.imread(orig_folder + filename) / 255.
    # double() so the MSE compares tensors with the prediction's dtype
    orig_img = torch.from_numpy(orig_img).double().to(device)

    # run prediction
    with torch.no_grad():
        pred = model(input_img).clamp(0.0, 1.0).squeeze(0).squeeze(0)
        loss = criterion(pred, orig_img)

    # PSNR relative to the ground truth's peak value; a perfect reconstruction
    # (zero MSE) is reported as infinite PSNR rather than dividing by zero.
    mse = loss.item()
    peak = orig_img.max().item()
    img_psnr = math.inf if mse == 0 else 10 * math.log10(peak ** 2 / mse)
    test_psnrs.append(img_psnr)
    print('{} PSNR: {:.2f}'.format(filename, img_psnr))

    # create figure; tensors are copied to CPU numpy arrays because matplotlib
    # cannot draw tensors that live on a CUDA device
    f, axarr = plt.subplots(1, 3, figsize=(20, 10))
    panels = (
        ('Input', input_img.squeeze(0).squeeze(0)),
        ('Prediction', pred),
        ('Original', orig_img),
    )
    for ax, (title, img) in zip(axarr, panels):
        ax.imshow(img.cpu().numpy(), cmap='gray')
        ax.title.set_text(title)

    f.savefig(output_folder + filename)
    # show the figure now and release it, so a large test set does not keep
    # every figure open in memory
    plt.show()
    plt.close(f)

if test_psnrs:
    print('Mean PSNR over {} images: {:.2f}'.format(
        len(test_psnrs), sum(test_psnrs) / len(test_psnrs)))