In [None]:
from google.colab import drive
# Mount Google Drive so the ./Weights, ./Data and ./Results folders used below are reachable.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
# %cd drive/MyDrive/...   (on older mounts the folder is "My Drive": escape the space as My\ Drive)

In [None]:
# Testing: load the trained diffusion UNet and RDN models and super-resolve the test set
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt
%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from modules_conditional import UNet, Diffusion
import torch.nn.functional as F

# Use the GPU when one is present instead of hard-coding "cuda", which crashes on CPU-only
# runtimes. Kept as a string (as originally) since it is passed to the project's Diffusion class.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)
# map_location lets a checkpoint saved on a GPU be restored on whichever device is active.
model.load_state_dict(torch.load('./Weights/Diff_ckpt_1.pt', map_location=device))
model.eval()  # inference only: put any train-mode-dependent layers into evaluation mode
diffusion = Diffusion(img_size=64, device=device)

class DenseLayer(nn.Module):
    """A single densely-connected layer: 3x3 conv followed by ReLU, concatenated onto its input.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of new feature channels this layer produces.

    The returned tensor has ``in_channels + out_channels`` channels along dim 1. The first
    ``in_channels`` of them are the input passed through unchanged. Spatial size is preserved
    (3x3 kernel with padding 1).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)  # in-place is safe: it acts on the fresh conv output

    def forward(self, x):
        new_features = self.relu(self.conv(x))
        return torch.cat((x, new_features), dim=1)


class RDB(nn.Module):
    """Residual dense block: a chain of DenseLayers, a 1x1 fusion conv, and a skip connection.

    Args:
        in_channels: channels of the block input.
        growth_rate: channels each DenseLayer adds, and the output width of the fusion conv.
        num_layers: number of DenseLayers in the chain.

    The fusion conv emits ``growth_rate`` channels and is added to the input. In practice this
    means ``in_channels`` must equal ``growth_rate`` (or broadcast) for the addition to succeed.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        dense = []
        for k in range(num_layers):
            # Layer k sees the original input plus the k * growth_rate channels added so far.
            dense.append(DenseLayer(in_channels + k * growth_rate, growth_rate))
        self.layers = nn.Sequential(*dense)

        # Local feature fusion: 1x1 conv squeezing all concatenated features to growth_rate.
        self.lff = nn.Conv2d(in_channels + num_layers * growth_rate, growth_rate, kernel_size=1)

    def forward(self, x):
        fused = self.lff(self.layers(x))
        return x + fused  # local residual learning


class RDN(nn.Module):
    """Residual dense network used here as a same-resolution image-to-image refiner.

    Args:
        scale_factor: factor for the ``upscale`` head (must be 2, 3 or 4). The head is built but
            NOT applied in ``forward``.
        num_channels: image channels in and out.
        num_features: width G0 of the shallow and global-fusion features.
        growth_rate: width G of each residual dense block's output.
        num_blocks: number D of residual dense blocks.
        num_layers: number C of DenseLayers inside each block.

    ``forward`` returns a tensor with ``num_channels`` channels and the same height/width as
    its input, because every conv is 3x3/padding 1 or 1x1 and no upsampling is applied.
    The first block adds its G0-channel input to a G-channel output, so G0 must equal G.
    """

    def __init__(self, scale_factor=2, num_channels=1, num_features=32, growth_rate=32, num_blocks=6, num_layers=4):
        super().__init__()
        self.G0, self.G, self.D, self.C = num_features, growth_rate, num_blocks, num_layers
        pad = 3 // 2  # "same" padding for the 3x3 convolutions

        # Shallow feature extraction: two stacked 3x3 convs.
        self.sfe1 = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=pad)
        self.sfe2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=pad)

        # Residual dense blocks: the first consumes G0 channels, every later one consumes G.
        self.rdbs = nn.ModuleList(
            [RDB(self.G0, self.G, self.C)] + [RDB(self.G, self.G, self.C) for _ in range(self.D - 1)]
        )

        # Global feature fusion over the concatenated outputs of all D blocks.
        self.gff = nn.Sequential(
            nn.Conv2d(self.G * self.D, self.G0, kernel_size=1),
            nn.Conv2d(self.G0, self.G0, kernel_size=3, padding=pad),
        )

        # Up-sampling head (sub-pixel convolution). Kept so the module layout matches saved
        # models, presumably; forward() does not use it.
        assert 2 <= scale_factor <= 4
        if scale_factor in (2, 4):
            stages = []
            for _ in range(scale_factor // 2):
                stages += [
                    nn.Conv2d(self.G0, self.G0 * (2 ** 2), kernel_size=3, padding=pad),
                    nn.PixelShuffle(2),
                ]
            self.upscale = nn.Sequential(*stages)
        else:
            self.upscale = nn.Sequential(
                nn.Conv2d(self.G0, self.G0 * (scale_factor ** 2), kernel_size=3, padding=pad),
                nn.PixelShuffle(scale_factor),
            )

        self.output = nn.Conv2d(self.G0, num_channels, kernel_size=3, padding=pad)

    def forward(self, x):
        shallow = self.sfe1(x)
        feat = self.sfe2(shallow)

        block_outputs = []
        for rdb in self.rdbs[: self.D]:
            feat = rdb(feat)
            block_outputs.append(feat)

        # Global residual learning: fused block features plus the first shallow feature map.
        fused = self.gff(torch.cat(block_outputs, 1)) + shallow
        # self.upscale is deliberately skipped, so the output keeps the input's spatial size.
        return self.output(fused)

# Same device rule as the diffusion setup: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# RDN.pth stores a whole pickled nn.Module (not a state_dict), which is why the RDN classes are
# defined above. PyTorch >= 2.6 defaults to weights_only=True and refuses such files, so opt out
# explicitly. Unpickling can execute arbitrary code, so only load checkpoints you trust.
# map_location allows a GPU-saved model to be restored on a CPU-only runtime.
model2 = torch.load('./Weights/RDN.pth', map_location=device, weights_only=False).to(device)
model2.eval()

BATCH = 10      # images per inference batch (the .npy arrays are grouped in runs of 10)
N_BATCHES = 10  # number of batches evaluated from the test set
IMG = 64        # image height and width in pixels

x_test = np.load('./Data/.../test_LR.npy').astype(np.float32).reshape(-1, BATCH, 1, IMG, IMG)[:N_BATCHES]

out = []   # diffusion reconstructions, one (BATCH, 1, IMG, IMG) array per batch
out2 = []  # RDN reconstructions, same layout
for i in tqdm(range(x_test.shape[0])):
    data = torch.from_numpy(x_test[i]).to(device)  # x_test[i] is already (BATCH, 1, IMG, IMG)
    recon = diffusion.sample(model, n=data.shape[0], lat=data)
    # RDN inference only: skip building an autograd graph that would be discarded anyway.
    with torch.no_grad():
        recon2 = model2(data)
    out.append(recon.detach().cpu().numpy().reshape(BATCH, 1, IMG, IMG))
    out2.append(recon2.cpu().numpy().reshape(BATCH, 1, IMG, IMG))
dataSR = np.asarray(out)
dataSR2 = np.asarray(out2)
print(dataSR.shape)
print(dataSR2.shape)

x = np.load('./Data/.../test_HR.npy').astype(np.float32).reshape(-1, BATCH, 1, IMG, IMG)[:N_BATCHES]
x_out = dataSR.astype(np.float32)
x_out2 = dataSR2.astype(np.float32)

# Flatten the batch grouping so every array is (num_images, 1, IMG, IMG).
x_test = x_test.reshape(-1, 1, IMG, IMG)  # LR
x = x.reshape(-1, 1, IMG, IMG)            # HR
x_out = x_out.reshape(-1, 1, IMG, IMG)    # SR Diffusion
x_out2 = x_out2.reshape(-1, 1, IMG, IMG)  # SR RDN

# Downstream cells pair these arrays index-by-index, so they must line up exactly.
assert x_test.shape == x.shape == x_out.shape == x_out2.shape, (
    x_test.shape, x.shape, x_out.shape, x_out2.shape)

In [None]:
# Visualize samples: LR input, RDN output, diffusion output and ground truth side by side
import cv2
import os

dataLR = x_test
dataHR = x

RESULTS_DIR = './Results/...'
os.makedirs(RESULTS_DIR, exist_ok=True)  # savefig fails if the target folder is missing

# Never index past the shortest array; 20 was the original hard-coded sample count.
N_SHOW = min(20, len(x_test), len(x_out2), len(x_out), len(x))

# (images, panel title) in left-to-right display order.
panels = [
    (x_test, 'Low Resolution Image (Input)'),
    (x_out2, 'RDN Model Output'),
    (x_out, 'Diffusion Model Output'),
    (x, 'Ground Truth'),
]

for i in range(N_SHOW):
    fig, axarr = plt.subplots(nrows=1, ncols=len(panels), figsize=(20, 3))
    for ax, (images, title) in zip(axarr, panels):
        ax.imshow(images[i][0], cmap='gray')
        ax.set_title(title)
    fig.savefig(os.path.join(RESULTS_DIR, 'Sample' + str(i + 1) + '.png'), format='png', dpi=300)
    plt.close(fig)  # close this specific figure so many saved figures don't accumulate in memory