In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by later cells.
# NOTE: relies on the `google.colab` package, so this cell is expected to run inside Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
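# Uncomment and edit: change into the project folder that contains ./Weights, ./Data and ./Results.
# cd drive/My\ Drive/.....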

In [None]:
# Testing
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm

class Conv(nn.Module):
    """A 2-D convolution immediately followed by a channel-wise PReLU.

    Args:
        in_c: number of input channels.
        out_c: number of output channels; PReLU learns one slope per output channel.
        **kwargs: passed through untouched to ``nn.Conv2d``
            (``kernel_size``, ``stride``, ``padding``, ...).
    """

    def __init__(self, in_c, out_c, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(in_c, out_c, **kwargs)
        # The attribute name "actication" (sic) is deliberately left as-is: it is
        # part of this module's parameter names, so renaming it would break
        # loading of checkpoints saved with the current spelling.
        self.actication = nn.PReLU(num_parameters=out_c)

    def forward(self, x):
        """Apply the convolution, then the PReLU, and return the result."""
        return self.actication(self.cnn(x))

class Upsample(nn.Module):
    """Sub-pixel upsampling stage: 3x3 conv -> PixelShuffle -> PReLU.

    The convolution expands the channel count by ``scale_factor ** 2``;
    ``nn.PixelShuffle`` then trades those extra channels for spatial
    resolution, so the output again has ``in_c`` channels (which is why the
    PReLU is created with ``in_c`` slopes).

    Args:
        in_c: number of channels entering (and leaving) the stage.
        scale_factor: spatial upscaling factor given to ``nn.PixelShuffle``.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        expanded_c = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(in_c, expanded_c, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.activation = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        """Return the upsampled, activated feature map."""
        return self.activation(self.pixel_shuffle(self.conv(x)))

class RBlock(nn.Module):
    """Residual block: two 3x3 ``Conv`` layers plus an identity skip connection.

    Args:
        in_c: channel count of the input; both convolutions keep it unchanged
            so the result can be added to the input.
    """

    def __init__(self, in_c):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.b1 = Conv(in_c, in_c, **conv_cfg)
        self.b2 = Conv(in_c, in_c, **conv_cfg)

    def forward(self, x):
        """Return ``b2(b1(x)) + x``."""
        branch = self.b2(self.b1(x))
        return branch + x

class Generator(nn.Module):
    """SRResNet-style generator that upsamples its input by a factor of 2.

    Pipeline: 9x9 ``Conv`` -> ``no_blocks`` residual blocks -> 3x3 ``Conv`` with
    a long skip connection from the first conv -> ``Upsample`` (x2) -> final
    3x3 ``nn.Conv2d`` back to ``in_c`` channels.

    Args:
        in_c: channels of the input image; the output has the same number.
        out_c: width (feature channels) used throughout the network.
        no_blocks: number of residual blocks in the trunk.
    """

    def __init__(self, in_c=1, out_c=64, no_blocks=18):
        super().__init__()
        self.first_conv = Conv(in_c, out_c, kernel_size=9, stride=1, padding=4)
        # The trunk must use the same width as first_conv. This was previously
        # hard-coded to 64, which only worked for the default out_c.
        self.res_blocks = nn.Sequential(*(RBlock(out_c) for _ in range(no_blocks)))
        self.conv1 = Conv(out_c, out_c, kernel_size=3, stride=1, padding=1)

        self.upsampling = Upsample(out_c, scale_factor=2)
        self.last_conv = nn.Conv2d(out_c, in_c, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map an ``(N, in_c, H, W)`` batch to ``(N, in_c, 2H, 2W)``."""
        first_conv = self.first_conv(x)
        x = self.res_blocks(first_conv)
        # Long skip connection around the whole residual trunk.
        x = self.conv1(x) + first_conv
        x = self.upsampling(x)
        return self.last_conv(x)
     
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load the trained generator ---------------------------------------------
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only runtime.
# weights_only=False is required by recent PyTorch releases to unpickle a whole
# nn.Module. SECURITY: this executes pickled code, so only load trusted checkpoints.
checkpoint = torch.load('./Weights/SRResNet.pth', map_location=device, weights_only=False)
if isinstance(checkpoint, nn.Module):
    # The file holds the fully pickled model.
    model = checkpoint
else:
    # Otherwise treat the file as a state dict for the default Generator.
    model = Generator()
    model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()  # inference only

# --- Super-resolve every low-resolution test image ---------------------------
x_test = np.load('./Data/test_LR.npy').astype(np.float32).reshape(-1,1,75,75)

out = []
with torch.no_grad():  # no autograd graph is needed for evaluation
    for i in tqdm(range(x_test.shape[0])):
        # Slice (rather than index) to keep the batch dimension: (1, 1, 75, 75).
        data = torch.from_numpy(x_test[i:i + 1]).to(device)
        recon = model(data)
        out.append(recon.cpu().numpy().reshape(1,150,150))
dataSR = np.asarray(out)
print(dataSR.shape)

x = np.load('./Data/test_HR.npy').astype(np.float32).reshape(-1,1,150,150)
x_out = dataSR.astype(np.float32)


def _report_metrics(header, label, reference, estimate, n_samples, criterion):
    """Print average MSE / SSIM / PSNR of ``estimate`` against ``reference``.

    Args:
        header: line printed before the metrics.
        label: text inserted into each "Average ... samples" line.
        reference: ground-truth array indexed as ``reference[i][0]`` (2-D image).
        estimate: array with the same layout holding the images to score.
        n_samples: number of leading samples to evaluate.
        criterion: callable returning the MSE of two tensors.

    Returns:
        Tuple ``(losses, ssims, psnrs)`` with one Python/NumPy scalar per sample.
    """
    print(header)
    losses, ssims, psnrs = [], [], []
    for i in range(n_samples):
        ref, est = reference[i], estimate[i]
        losses.append(criterion(torch.from_numpy(est), torch.from_numpy(ref)).item())
        # NOTE(review): data_range is taken from the estimate's own min/max, so the
        # SR and interpolated SSIM scores use different ranges, and PSNR relies on
        # skimage's default range. Left unchanged to keep numbers comparable with
        # earlier runs - confirm this is intended.
        ssims.append(ssim(ref[0], est[0], data_range=est[0].max() - est[0].min()))
        psnrs.append(psnr(ref[0], est[0]))
    print("Average MSE " + label + " samples: " + str('%.5f'%np.average(losses)))
    print("Average SSIM " + label + " samples: " + str('%.5f'%np.average(ssims)))
    print("Average PSNR " + label + " samples: " + str('%.5f'%np.average(psnrs)))
    return losses, ssims, psnrs


criteria = nn.MSELoss()
losses, Ssim, Psnr = _report_metrics(
    "Metrics:", "super resolution", x, x_out, x_test.shape[0], criteria)


x_out2 = np.load('./Data/test_LR_Inter.npy').astype(np.float32).reshape(-1,1,150,150)
losses, Ssim, Psnr = _report_metrics(
    "Metrics for Interpolated images:", "Interpolated", x, x_out2, x_test.shape[0], criteria)

In [None]:
# Visualize samples: save side-by-side figures of LR input, bilinear upscaling,
# model output and ground truth for the first few test images.
import os

import cv2

N_PREVIEW = 20                 # maximum number of sample figures to save
RESULTS_DIR = './Results/...'  # placeholder path - point this at the real output folder

# Upscale every LR image to the ground-truth resolution with bilinear interpolation.
hr_h, hr_w = x.shape[-2:]
out = []
for i in tqdm(range(x_test.shape[0])):
    data = x_test[i][0]  # drop the channel axis -> 2-D image
    # cv2.resize expects the target size as (width, height).
    recon = cv2.resize(data, (hr_w, hr_h), interpolation=cv2.INTER_LINEAR) # Bilinear Interpolation
    out.append(recon[np.newaxis])
dataInter = np.asarray(out)
dataLR = x_test
dataHR = x

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(RESULTS_DIR, exist_ok=True)

panels = (
    (dataLR, 'Low Resolution Image (Input)'),
    (dataInter, 'Interpolated Image'),
    (dataSR, 'Model Output'),
    (dataHR, 'Ground Truth'),
)
# Never index past the end of any array when fewer than N_PREVIEW samples exist.
n_figures = min(N_PREVIEW, *(len(images) for images, _ in panels))
for i in range(n_figures):
    f, axarr = plt.subplots(nrows=1,ncols=4,figsize=(16,4))
    for ax, (images, title) in zip(axarr, panels):
        ax.imshow(images[i][0])
        ax.set_title(title)
    f.savefig(RESULTS_DIR + '/Sample' + str(i+1) + '.png', format='png', dpi=300)
    plt.close(f)  # free the figure; many are created in this loop