In [None]:
# Mount Google Drive at /content/drive so files stored there can be read and
# written from this runtime. Requires the google.colab package.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
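# %cd drive/My\ Drive/.....   <- uncomment and fill in the project folder; the relative ./Weights, ./Data and ./Results paths used below are resolved from it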

In [None]:
# Testing
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm

# Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention: rescale every feature map by a learned gate in (0, 1).

    The input is squeezed to one value per channel with global average
    pooling, passed through a two-layer 1x1-conv bottleneck
    (``channel -> channel // reduction -> channel``) that ends in a sigmoid,
    and the resulting per-channel weights are multiplied back onto the input.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channel`` to get the bottleneck width.
            The width is clamped to at least 1, so ``channel < reduction``
            no longer builds an empty (zero-channel) bottleneck.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # channel // reduction is 0 whenever channel < reduction; a 0-channel
        # conv would make the gate independent of the input, so clamp to 1.
        # For channel >= reduction this is identical to channel // reduction.
        squeezed = max(channel // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
                nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise; the result has the same shape as ``x``."""
        y = self.avg_pool(x)  # one value per channel
        y = self.ca(y)        # per-channel gate in (0, 1)
        return x * y          # gate broadcasts over the spatial dimensions

# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual channel attention block.

    Runs conv3x3 -> ReLU -> conv3x3 -> channel attention and adds the block
    input back onto the result (identity skip connection).

    Args:
        channel: Number of input and output channels.
        reduction: Channel-reduction ratio forwarded to ``CALayer``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        stages = [
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            CALayer(channel, reduction),
        ]
        self.body = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``body(x) + x``; the sum is accumulated in place on the body output."""
        return self.body(x).add_(x)

# Residual Group (RG)
class RG(nn.Module):
    """Residual group: a stack of RCABs closed by a 3x3 conv, wrapped in a skip connection.

    Args:
        channel: Number of input and output channels.
        num_rcab: How many ``RCAB`` blocks to stack before the closing conv.
        reduction: Channel-reduction ratio forwarded to every ``RCAB``.
    """

    def __init__(self, channel, num_rcab, reduction):
        super().__init__()
        blocks = [RCAB(channel, reduction) for _ in range(num_rcab)]
        closing_conv = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
        self.body = nn.Sequential(*blocks, closing_conv)

    def forward(self, x):
        """Return ``body(x) + x``; the sum is accumulated in place on the body output."""
        return self.body(x).add_(x)

# Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    """Residual channel attention network for image super-resolution.

    Pipeline: a 3x3 ``head`` conv lifts the image to ``num_features`` channels,
    ``body`` applies ``num_rg`` residual groups with a long skip connection
    around them, and ``tail`` (two 3x3 convs + ``PixelShuffle``) produces an
    image enlarged by ``scale_factor`` with ``num_channels`` channels.

    Args:
        num_rg: Number of residual groups in the body.
        num_rcab: Number of RCABs inside each residual group.
        num_channels: Channels of the input and output image.
        num_features: Width of the internal feature maps.
        scale_factor: Upscaling factor passed to ``PixelShuffle``.
        reduction: Channel-reduction ratio used by the attention layers.
    """

    def __init__(self, num_rg=2, num_rcab=4, num_channels=1, num_features=64, scale_factor=2, reduction=16):
        super().__init__()
        self.sf = scale_factor
        self.num_features = num_features

        # Shallow feature extraction.
        self.head = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)

        # Deep feature extraction.
        groups = [RG(num_features, num_rcab, reduction) for _ in range(num_rg)]
        self.body = nn.Sequential(*groups)

        # PixelShuffle folds scale_factor**2 channels into each output channel,
        # so the last conv has to emit that many channels per image channel.
        shuffle_channels = num_channels * (scale_factor ** 2)
        self.tail = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.Conv2d(num_features, shuffle_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Upscale ``x`` by ``scale_factor`` and return the reconstructed image."""
        shallow = self.head(x)
        # Long skip connection: shallow features are added onto the body output in place.
        deep = self.body(shallow).add_(shallow)
        return self.tail(deep)

# Side lengths (in pixels) of the square low-resolution inputs and of the
# square high-resolution targets / model outputs.
LR_SIZE = 75
HR_SIZE = 150

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The checkpoint is loaded as a whole pickled module (it is used directly as
# the model), so the RCAN classes defined above must exist before this line.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only runtime.
# NOTE(review): unpickling runs arbitrary code -- only load checkpoints you trust.
# NOTE(review): on PyTorch versions where torch.load defaults to
# weights_only=True this call needs weights_only=False -- confirm against the
# installed version.
model = torch.load('./Weights/RCAN.pth', map_location=device).to(device)
model.eval()  # inference only

x_test = np.load('./Data/test_LR.npy').astype(np.float32).reshape(-1, 1, LR_SIZE, LR_SIZE)

# Super-resolve the test set one sample at a time. no_grad() stops autograd
# from recording a graph for every forward pass, which only costs memory here.
out = []
with torch.no_grad():
    for i in tqdm(range(x_test.shape[0])):
        data = torch.from_numpy(x_test[i].reshape(1, 1, LR_SIZE, LR_SIZE)).to(device)
        recon = model(data)
        out.append(recon.cpu().numpy().reshape(1, HR_SIZE, HR_SIZE))
dataSR = np.asarray(out)
print(dataSR.shape)

x = np.load('./Data/test_HR.npy').astype(np.float32).reshape(-1, 1, HR_SIZE, HR_SIZE)
x_out = dataSR.astype(np.float32)
# Every metric below pairs x[i] with x_out[i]; fail early if the two sets do not line up.
assert x.shape == x_out.shape, (
    "ground truth %s and model output %s do not match" % (x.shape, x_out.shape)
)

print("Metrics:")
criteria = nn.MSELoss()
losses = []
Ssim = []
Psnr = []
for i in range(x_test.shape[0]):
    # .item() stores plain floats instead of 0-d tensors.
    losses.append(criteria(torch.from_numpy(x_out[i]), torch.from_numpy(x[i])).item())
    # NOTE(review): SSIM takes its data_range from the *model output*, while
    # PSNR passes none and lets skimage derive it from the dtype. The two
    # metrics may therefore assume different ranges -- confirm against how the
    # data was normalised before comparing these numbers with other work.
    Ssim.append(ssim(x[i][0], x_out[i][0], data_range=x_out[i][0].max() - x_out[i][0].min()))
    Psnr.append(psnr(x[i][0], x_out[i][0]))
print("Average MSE super resolution samples: " + '%.5f' % np.average(losses))
print("Average SSIM super resolution samples: " + '%.5f' % np.average(Ssim))
print("Average PSNR super resolution samples: " + '%.5f' % np.average(Psnr))

In [None]:
# Visualize samples
import cv2

import os

dataLR = x_test
dataHR = x

# Baseline for comparison: plain bilinear upsampling of every low-resolution
# sample to the size of the ground-truth images. cv2.resize takes the target
# size as (width, height).
hr_height, hr_width = dataHR.shape[-2:]
dataInter = np.asarray([
    cv2.resize(lr[0], (hr_width, hr_height), interpolation=cv2.INTER_LINEAR)[np.newaxis]
    for lr in tqdm(dataLR)
])

# NOTE(review): '...' looks like a placeholder for the real sub-folder -- confirm.
RESULTS_DIR = './Results/...'
os.makedirs(RESULTS_DIR, exist_ok=True)  # savefig does not create missing folders

# One (image stack, title) pair per column of the comparison figure.
panels = [
    (dataLR, 'Low Resolution Image (Input)'),
    (dataInter, 'Interpolated Image'),
    (dataSR, 'Model Output'),
    (dataHR, 'Ground Truth'),
]

# Save at most 20 comparison figures, fewer when the test set is smaller.
num_figures = min(20, dataLR.shape[0])
for i in range(num_figures):
    fig, axarr = plt.subplots(nrows=1, ncols=4, figsize=(16, 4))
    for ax, (images, title) in zip(axarr, panels):
        ax.imshow(images[i][0])
        ax.set_title(title)
    fig.savefig(os.path.join(RESULTS_DIR, 'Sample' + str(i + 1) + '.png'), format='png', dpi=300)
    plt.close(fig)  # release the figure so memory does not grow across iterations