In [None]:
# Mount Google Drive into the Colab VM; the project folder (presumably holding
# Weights/, Data/ and Results/) is expected to live on this drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
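# %cd drive/My\ Drive/...   # uncomment and point at the project folder; later cells read ./Weights, ./Data and write ./Results relative to it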

In [None]:
# Testing
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
import math

class FSRCNN(nn.Module):
    """FSRCNN super-resolution network.

    Pipeline: 5x5 feature extraction (d maps) -> 1x1 shrink to s maps ->
    m 3x3 mapping layers -> 1x1 expand back to d maps -> 9x9 transposed
    convolution that upsamples by ``scale_factor``. Every convolution except
    the final transposed one is followed by a PReLU.

    Args:
        scale_factor: spatial upsampling factor of the final layer.
        num_channels: number of input/output image channels.
        d: feature maps in the extraction/expansion stages.
        s: shrunk feature maps used inside the mapping stage.
        m: number of 3x3 mapping layers.

    The attribute names ``first_part`` / ``mid_part`` / ``last_part`` define the
    state_dict keys (and the layout of pickled models), so keep them stable.
    """

    def __init__(self, scale_factor=2, num_channels=1, d=64, s=12, m=4):
        super().__init__()

        # Feature extraction; padding=2 keeps H x W for the 5x5 kernel.
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=5, padding=2),
            nn.PReLU(d),
        )

        # Shrinking (1x1) -> m mapping layers (3x3, same padding) -> expanding (1x1).
        layers = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
        for _ in range(m):
            layers += [nn.Conv2d(s, s, kernel_size=3, padding=1), nn.PReLU(s)]
        layers += [nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]
        self.mid_part = nn.Sequential(*layers)

        # Upsampling: with padding=4 and output_padding=scale_factor-1 the output
        # is exactly scale_factor * (H, W).
        self.last_part = nn.ConvTranspose2d(
            d, num_channels, kernel_size=9, stride=scale_factor,
            padding=4, output_padding=scale_factor - 1,
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise parameters in place.

        Every Conv2d in ``first_part`` and ``mid_part`` gets normal weights with
        std = sqrt(2 / (out_channels * kH * kW)) and zero bias; the transposed
        convolution gets N(0, 0.001^2) weights and zero bias.
        """
        convs = [layer
                 for stage in (self.first_part, self.mid_part)
                 for layer in stage
                 if isinstance(layer, nn.Conv2d)]
        for conv in convs:
            k_h, k_w = conv.kernel_size
            fan = conv.out_channels * k_h * k_w
            nn.init.normal_(conv.weight.data, mean=0.0, std=math.sqrt(2 / fan))
            nn.init.zeros_(conv.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x):
        """Map a (N, num_channels, H, W) batch to (N, num_channels, s*H, s*W)."""
        return self.last_part(self.mid_part(self.first_part(x)))
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths are relative to the working directory (the project folder).
WEIGHTS_PATH = './Weights/FSRCNN.pth'
LR_PATH = './Data/test_LR.npy'
HR_PATH = './Data/test_HR.npy'
INTERP_PATH = './Data/test_LR_Inter.npy'
SCALE = 2
LR_SIZE = 75               # LR inputs are single-channel LR_SIZE x LR_SIZE images
HR_SIZE = LR_SIZE * SCALE  # ground truth / outputs are HR_SIZE x HR_SIZE


def _load_model(path, dev, scale=SCALE):
    """Load the trained FSRCNN from `path` onto `dev`, in eval mode.

    Accepts a fully pickled model (``torch.save(model)``, which is what this
    notebook's checkpoint is) or a plain ``state_dict`` for :class:`FSRCNN`.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects - only load
    # checkpoints you trust. It is required to unpickle a whole nn.Module on
    # PyTorch >= 2.6, where the default flipped to weights_only=True.
    # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
    try:
        checkpoint = torch.load(path, map_location=dev, weights_only=False)
    except TypeError:  # PyTorch < 1.13 has no `weights_only` argument
        checkpoint = torch.load(path, map_location=dev)
    if isinstance(checkpoint, nn.Module):
        net = checkpoint
    else:
        net = FSRCNN(scale_factor=scale)
        net.load_state_dict(checkpoint)
    return net.to(dev).eval()


def _super_resolve(net, lr_images, dev):
    """Upscale each (1, H, W) image of `lr_images` (N, 1, H, W) with `net`.

    Images are processed one at a time; returns a float32 array (N, 1, H', W').
    Raises ValueError if `lr_images` is empty.
    """
    if len(lr_images) == 0:
        raise ValueError("no low-resolution images to super-resolve")
    outputs = []
    with torch.no_grad():  # pure inference: don't build an autograd graph
        for lr in tqdm(lr_images):
            batch = torch.from_numpy(lr[np.newaxis]).to(dev)  # (1, 1, H, W)
            outputs.append(net(batch)[0].cpu().numpy())       # (1, H', W')
    return np.stack(outputs).astype(np.float32)


def _report_metrics(label, preds, targets):
    """Print mean MSE, SSIM and PSNR of `preds` against `targets`.

    Both arrays must have the same (N, 1, H, W) shape; channel 0 is scored.
    Raises ValueError on a shape mismatch (e.g. differing sample counts).
    """
    if preds.shape != targets.shape:
        raise ValueError("prediction/target shape mismatch: %s vs %s" % (preds.shape, targets.shape))
    mses, ssims, psnrs = [], [], []
    for pred, target in zip(preds[:, 0], targets[:, 0]):
        mses.append(float(np.mean((pred - target) ** 2)))
        # SSIM's data_range must describe the reference image. Taking it from
        # the prediction (as before) gave SR and interpolation different scales
        # for the same ground truth, so their SSIMs were not comparable.
        ssims.append(ssim(target, pred, data_range=float(target.max() - target.min())))
        psnrs.append(psnr(target, pred))
    print("Average MSE " + label + " samples: " + '%.5f' % np.average(mses))
    print("Average SSIM " + label + " samples: " + '%.5f' % np.average(ssims))
    print("Average PSNR " + label + " samples: " + '%.5f' % np.average(psnrs))


model = _load_model(WEIGHTS_PATH, device)

x_test = np.load(LR_PATH).astype(np.float32).reshape(-1, 1, LR_SIZE, LR_SIZE)
dataSR = _super_resolve(model, x_test, device)
print(dataSR.shape)

x = np.load(HR_PATH).astype(np.float32).reshape(-1, 1, HR_SIZE, HR_SIZE)
x_out = dataSR  # already float32

print("Metrics:")
_report_metrics("super resolution", x_out, x)

x_out2 = np.load(INTERP_PATH).astype(np.float32).reshape(-1, 1, HR_SIZE, HR_SIZE)
print("Metrics for Interpolated images:")
_report_metrics("Interpolated", x_out2, x)

In [None]:
# Visualize samples: LR input | bilinear interpolation | model output | ground truth.
import os

import cv2

# Bilinear upsampling of each LR input to the HR size, for the side-by-side panels.
# NOTE(review): recomputed here with cv2, so it may differ from the
# './Data/test_LR_Inter.npy' images scored in the metrics cell - confirm.
hr_height, hr_width = x.shape[-2:]
out = []
for i in tqdm(range(x_test.shape[0])):
    # cv2.resize takes dsize as (width, height).
    recon = cv2.resize(x_test[i][0], (hr_width, hr_height), interpolation=cv2.INTER_LINEAR)
    out.append(recon[np.newaxis])
dataInter = np.asarray(out)
dataLR = x_test
dataHR = x

RESULTS_DIR = './Results/...'  # placeholder sub-folder kept from the original path; set per run
N_PLOT = min(20, len(dataLR))   # avoid IndexError when the test set has < 20 samples
PANELS = (
    (dataLR, 'Low Resolution Image (Input)'),
    (dataInter, 'Interpolated Image'),
    (dataSR, 'Model Output'),
    (dataHR, 'Ground Truth'),
)

os.makedirs(RESULTS_DIR, exist_ok=True)  # savefig does not create missing folders
for i in range(N_PLOT):
    fig, axes = plt.subplots(nrows=1, ncols=len(PANELS), figsize=(16, 4))
    for ax, (images, title) in zip(axes, PANELS):
        ax.imshow(images[i][0])
        ax.set_title(title)
    fig.savefig(os.path.join(RESULTS_DIR, 'Sample' + str(i + 1) + '.png'), format='png', dpi=300)
    plt.close(fig)  # many figures in a loop: release each one explicitly