In [None]:
# Mount the user's Google Drive at /content/drive (requires the Google Colab runtime).
# NOTE(review): presumably the ./Weights, ./Data and ./Results folders used below
# live on this drive -- confirm against the project layout.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
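# cd drive/My\ Drive/.....   <- uncomment and point this at the project folder so the relative ./Weights, ./Data and ./Results paths used below resolve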

In [None]:
!pip3 install e2cnn

In [None]:
# Testing
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
%matplotlib inline
import torch
#import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
import math
import torch
from e2cnn import gspaces
from e2cnn import nn

class Equivariant_FSRCNN(torch.nn.Module):
    """FSRCNN-style super-resolution network built from e2cnn equivariant layers.

    The network has three stages:

    * ``first_part``: input mask, 5x5 ``R2Conv`` to ``d`` regular fields,
      batch norm and ReLU.
    * ``mid_part``: a 1x1 ``R2Conv`` from ``d`` to ``s`` regular fields followed
      by ``m`` 3x3 ``R2Conv`` layers, each with batch norm and ReLU.
    * ``last_part``: a plain ``torch.nn.ConvTranspose2d`` (9x9, stride
      ``scale_factor``) applied to the raw tensor of the last feature map.

    Args:
        sym_group: ``'Dihyderal'`` (historical spelling kept for compatibility;
            ``'Dihedral'`` is accepted as well) selects ``FlipRot2dOnR2``,
            ``'Circular'`` selects ``Rot2dOnR2``.
        N: order parameter forwarded to the chosen gspace.
        scale_factor: stride of the transposed convolution, i.e. the upscaling
            factor of the spatial dimensions.
        num_channels: number of input and output image channels.
        d: number of regular fields produced by ``first_part``.
        s: number of regular fields used throughout ``mid_part``.
        m: number of 3x3 convolution layers in ``mid_part``.
        input_size: side length passed to ``MaskModule`` (the expected spatial
            size of the square input).

    Raises:
        ValueError: if ``sym_group`` is not one of the supported names.
    """

    def __init__(self, sym_group = "Dihyderal", N = 2, scale_factor=2, num_channels=1, d=16, s=64, m=4,
                 input_size=75):

        super(Equivariant_FSRCNN, self).__init__()

        if sym_group in ('Dihyderal', 'Dihedral'):
            self.r2_act = gspaces.FlipRot2dOnR2(N=N)
        elif sym_group == 'Circular':
            self.r2_act = gspaces.Rot2dOnR2(N=N)
        else:
            # Fail early instead of hitting an AttributeError on self.r2_act below.
            raise ValueError(
                "Unknown sym_group %r; expected 'Dihyderal' or 'Circular'" % (sym_group,)
            )

        # The image itself is a stack of scalar (trivial) fields.
        in_type = nn.FieldType(self.r2_act, num_channels*[self.r2_act.trivial_repr])
        self.input_type = in_type

        # Feature extraction: mask -> 5x5 conv -> BN -> ReLU.
        out_type = nn.FieldType(self.r2_act, d*[self.r2_act.regular_repr])
        self.first_part = nn.SequentialModule(
            nn.MaskModule(in_type, input_size, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=5//2, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        )

        # 1x1 conv changing the field count from d to s, then m 3x3 mapping layers.
        mid_part = []
        in_type = self.first_part.out_type
        out_type = nn.FieldType(self.r2_act, s*[self.r2_act.regular_repr])
        mid_part.extend([
            nn.R2Conv(in_type, out_type, kernel_size=1, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True)
        ])
        for _ in range(m):
            in_type = out_type
            out_type = nn.FieldType(self.r2_act, s*[self.r2_act.regular_repr])
            mid_part.extend([
                nn.R2Conv(in_type, out_type, kernel_size=3, padding=3//2, bias=False),
                nn.InnerBatchNorm(out_type),
                nn.ReLU(out_type, inplace=True)
            ])
        self.mid_part = nn.SequentialModule(*mid_part)

        # The number of raw tensor channels is the size of the last field type.
        # It used to be hard-coded as s*4, which only matches the default
        # dihedral group with N=2; reading it from the field type keeps the
        # default unchanged (64*4 = 256) and makes other groups/orders work.
        feature_channels = self.mid_part.out_type.size
        self.last_part = torch.nn.ConvTranspose2d(feature_channels, num_channels, kernel_size=9,
                                                  stride=scale_factor, padding=9//2,
                                                  output_padding=scale_factor-1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Upscale a batch of images.

        Args:
            input: tensor that is wrapped into a ``GeometricTensor`` of
                ``self.input_type`` (``num_channels`` scalar channels).

        Returns:
            A plain ``torch.Tensor`` with ``num_channels`` channels whose
            spatial size is the input size multiplied by ``scale_factor``.
        """
        x = nn.GeometricTensor(input, self.input_type)
        x = self.first_part(x)
        x = self.mid_part(x)
        # Drop the e2cnn wrapper: the final layer is a regular torch module.
        x = x.tensor
        x = self.last_part(x)
        return x
       
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Equivariant_FSRCNN().to(device)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can also be loaded on the CPU.
model.load_state_dict(torch.load('./Weights/HFSRCNN-1.pth', map_location=device))
# Switch to inference mode so the batch-norm layers use their stored running
# statistics instead of per-batch statistics. This is done after loading so the
# state dict is loaded into the model in the same mode as before.
model.eval()

# Low-resolution test set, reshaped to (samples, 1, 75, 75) float32.
x_test = np.load('./Data/test_LR.npy').astype(np.float32).reshape(-1,1,75,75)

out = []
# Gradients are never used during testing; no_grad avoids building an autograd
# graph for every forward pass.
with torch.no_grad():
    for i in tqdm(range(x_test.shape[0])):
        # Feed one sample at a time as a batch of size 1.
        data = torch.from_numpy(x_test[i].reshape(1,1,75,75))
        data = data.to(device)
        recon = model(data)
        out.append(recon.cpu().numpy().reshape(1,150,150))
# Stack the per-sample outputs into one (samples, 1, 150, 150) array.
dataSR = np.asarray(out)
print(dataSR.shape)

def _per_sample_metrics(reference, estimate, n_samples, mse_fn):
    """Compute per-image MSE, SSIM and PSNR of ``estimate`` against ``reference``.

    Args:
        reference: array of ground-truth images, indexed as ``reference[i][0]``.
        estimate: array of images to evaluate, indexed the same way.
        n_samples: number of leading samples to evaluate.
        mse_fn: callable returning the MSE of two tensors as a 0-d tensor.

    Returns:
        Three lists (MSE, SSIM, PSNR) with one float per evaluated sample.
    """
    mse_vals, ssim_vals, psnr_vals = [], [], []
    for idx in range(n_samples):
        ref_img = reference[idx]
        est_img = estimate[idx]
        # .item() stores plain Python floats rather than 0-d tensors.
        mse_vals.append(mse_fn(torch.from_numpy(est_img), torch.from_numpy(ref_img)).item())
        # NOTE(review): data_range is derived from the evaluated image, so it
        # differs between the two methods being compared -- confirm this is intended.
        est_range = est_img[0].max() - est_img[0].min()
        ssim_vals.append(ssim(ref_img[0], est_img[0], data_range=est_range))
        # NOTE(review): no data_range is passed, so PSNR relies on the library's
        # default for the image dtype -- confirm the data scale matches it.
        psnr_vals.append(psnr(ref_img[0], est_img[0]))
    return mse_vals, ssim_vals, psnr_vals


def _print_average_metrics(label, mse_vals, ssim_vals, psnr_vals):
    """Print the mean MSE, SSIM and PSNR for the method named ``label``."""
    print("Average MSE " + label + " samples: " + str('%.5f'%np.average(mse_vals)))
    print("Average SSIM " + label + " samples: " + str('%.5f'%np.average(ssim_vals)))
    print("Average PSNR " + label + " samples: " + str('%.5f'%np.average(psnr_vals)))


# Ground-truth high-resolution images and the model output, both float32.
x = np.load('./Data/test_HR.npy').astype(np.float32).reshape(-1,1,150,150)
x_out = dataSR.astype(np.float32)

print("Metrics:")
criteria = torch.nn.MSELoss()
losses, Ssim, Psnr = _per_sample_metrics(x, x_out, x_test.shape[0], criteria)
_print_average_metrics("super resolution", losses, Ssim, Psnr)

# Baseline: pre-computed interpolated low-resolution images, scored the same way.
x_out2 = np.load('./Data/test_LR_Inter.npy').astype(np.float32).reshape(-1,1,150,150)
print("Metrics for Interpolated images:")
losses, Ssim, Psnr = _per_sample_metrics(x, x_out2, x_test.shape[0], criteria)
_print_average_metrics("Interpolated", losses, Ssim, Psnr)

In [None]:
# Visualize samples
import cv2

import os

# Bilinear upscaling of every low-resolution sample, used as the visual baseline.
out = []
for i in tqdm(range(x_test.shape[0])):
    data = x_test[i].reshape(75,75)
    recon = cv2.resize(data, (150,150), interpolation=cv2.INTER_LINEAR) # Bilinear Interpolation
    out.append(recon.reshape(1,150,150))
dataInter = np.asarray(out)
dataLR = x_test
dataHR = x

# Make sure the target folder exists so savefig does not fail on a fresh checkout.
results_dir = './Results/...'
os.makedirs(results_dir, exist_ok=True)

# One (image stack, title) pair per subplot column, in display order.
panels = [
    (dataLR, 'Low Resolution Image (Input)'),
    (dataInter, 'Interpolated Image'),
    (dataSR, 'Model Output'),
    (dataHR, 'Ground Truth'),
]

# Show at most 20 samples, but never more than are actually available.
n_shown = min(20, len(dataLR), len(dataInter), len(dataSR), len(dataHR))
for i in range(n_shown):
    f, axarr = plt.subplots(nrows=1,ncols=4,figsize=(16,4))
    for ax, (images, title) in zip(axarr, panels):
        ax.imshow(images[i][0])
        ax.set_title(title)
    f.savefig(results_dir + '/Sample' + str(i+1) + '.png', format='png', dpi=300)
    # Close the figure explicitly so figures do not accumulate in memory.
    plt.close(f)