In [1]:
import os
import time
import copy
from collections import defaultdict
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms, utils
from torch import nn
import albumentations as A
from albumentations import (HorizontalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)
import cv2
from torch.optim import Adam, SGD
import torch.nn.functional as F
from torch import nn
import random
import itertools
from tqdm import tqdm
import math
from torch import optim
%matplotlib inline

In [None]:
# Install valuation metrics
! pip install torchmetrics

In [None]:
# Install valuation metrics
! pip install lpips

In [None]:
# Install model summary 
! pip install pytorch-model-summary

## Build datasets and dataloaders

In [12]:
# Mount google drive
# Colab-only: makes "My Drive" reachable under /content/drive; the data paths
# used below ("/content/drive/My Drive/Colab Notebooks/SR_fluo_data/...")
# assume this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# visualize data: show one low-res training input, then its high-res target,
# each in its own figure.
img1_path = "/content/drive/My Drive/Colab Notebooks/SR_fluo_data/train/input/nuclei_03.tif"
sr1_path = "/content/drive/My Drive/Colab Notebooks/SR_fluo_data/train/target/nuclei_03.tif"

img = io.imread(img1_path)
sr = io.imread(sr1_path)
for shown in (img, sr):
    plt.imshow(shown)
    plt.show()

Build dataset

In [14]:
# Transform for data augmentation
def get_transforms():
    """Build the random augmentation pipeline applied jointly to image and mask.

    A random 90-degree rotation, a horizontal flip and a vertical flip are
    each applied independently with probability 0.5, in that order.
    """
    return Compose([
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ])

# Data class
class Fluo_data(Dataset):
    """Paired low-res / high-res fluorescence image dataset.

    ``path`` must contain an ``input/`` folder (low-resolution images) and a
    ``target/`` folder (high-resolution images). Samples are paired by their
    position in the sorted file listing of each folder, so both folders are
    expected to hold identically named files.

    Each image is read, resized to ``size`` x ``size``, min-max scaled to
    [0, 1], passed together with its target through ``transforms`` and
    returned as a float32 tensor with a new leading axis of length 1.

    Args:
        path: dataset root containing ``input/`` and ``target/``.
        transforms: callable invoked as ``transforms(image=..., mask=...)``
            returning a dict with ``'image'`` and ``'mask'`` entries
            (albumentations style). ``None`` (default) uses the random
            augmentation from ``get_transforms()``; pass an identity pipeline
            such as ``Compose([])`` to disable augmentation (e.g. test data).
        size: side length the images are resized to (default 256).
    """

    def __init__(self, path, transforms=None, size=256):
        self.path = path
        self.folders = os.listdir(path)
        # The ``transforms`` argument used to be ignored (augmentation was
        # always applied, including to evaluation data); it is now honoured.
        self.transforms = get_transforms() if transforms is None else transforms
        self.size = size
        self.image_dir = os.path.join(path, 'input/')
        self.mask_dir = os.path.join(path, 'target/')
        # List both folders once here instead of on every __len__/__getitem__ call.
        self.image_folder = sorted(os.listdir(self.image_dir))
        self.mask_files = sorted(os.listdir(self.mask_dir))
        if len(self.image_folder) != len(self.mask_files):
            import warnings
            warnings.warn(
                "{}: {} input files but {} target files; samples are paired by "
                "sorted position, so pairs may be mismatched.".format(
                    path, len(self.image_folder), len(self.mask_files)))

    def __len__(self):
        return len(self.image_folder)

    def __getitem__(self, idx):
        # "image" is the low-res input, "mask" the high-res target for it.
        img = self._load(os.path.join(self.image_dir, self.image_folder[idx]))
        mask = self._load(os.path.join(self.mask_dir, self.mask_files[idx]))

        augmented = self.transforms(image=img, mask=mask)  # data augmentation

        # convert everything into a torch.Tensor with a leading channel axis
        img = torch.as_tensor(augmented['image'], dtype=torch.float32).unsqueeze(0)
        mask = torch.as_tensor(augmented['mask'], dtype=torch.float32).unsqueeze(0)
        return img, mask

    def _load(self, file_path):
        """Read one image as float32, resize it to (size, size) and min-max scale it to [0, 1]."""
        arr = io.imread(file_path).astype('float32')
        arr = transform.resize(arr, (self.size, self.size))
        return cv2.normalize(arr, None, 0, 1.0, cv2.NORM_MINMAX, dtype=cv2.CV_32F)

In [15]:
# loading the data
data_path = "/content/drive/My Drive/Colab Notebooks/SR_fluo_data/train"
data = Fluo_data(data_path)
test_data_path = "/content/drive/My Drive/Colab Notebooks/SR_fluo_data/test"
# Evaluation should be deterministic, so request an identity pipeline (an
# empty Compose) instead of the random rotations/flips used for training.
# NOTE(review): this only takes effect if Fluo_data uses its `transforms` argument.
test_data = Fluo_data(test_data_path, transforms=Compose([]))

In [None]:
# Check length: number of training samples found under data_path.
n_train_samples = len(data)
print(n_train_samples)

In [None]:
# Check getitem: inspect the first test sample - the target's shape, then the
# maximum value of the input and of the target (both are min-max scaled by
# Fluo_data, so these are normally 1).
image, mask = test_data[0]
print(mask.shape)
for sample in (image, mask):
    print(torch.max(sample))

In [None]:
# Some utility functions to show images

# Convert torch tensor to image
def image_convert(image):
    """Convert a (C, H, W) torch tensor into a uint8 image array for display or saving.

    The tensor is detached, copied to host memory, moved to channel-last
    layout and min-max rescaled to the full 0..255 range with cv2.normalize.
    The input tensor itself is not modified.
    """
    # detach(): .numpy() raises on tensors that require grad, e.g. a model
    # output produced outside torch.no_grad().
    array = image.detach().clone().cpu().numpy()
    array = array.transpose((1, 2, 0))
    return cv2.normalize(array, None, 255, 0, cv2.NORM_MINMAX, cv2.CV_8U)

def plot_img(no_, dataset=None):
    """Plot ``no_`` random sample pairs: low-res inputs on the top row and the
    corresponding high-res targets on the bottom row.

    Args:
        no_: number of samples (columns) to show.
        dataset: dataset to draw from; defaults to the global training
            dataset ``data``.
    """
    if dataset is None:
        dataset = data
    # Draw distinct indices whenever possible (sampling with replacement could
    # show the same pair twice); fall back to replacement only if more samples
    # are requested than exist.
    random_idx = np.random.choice(len(dataset), no_, replace=no_ > len(dataset))
    samples = [dataset[i] for i in random_idx]

    # squeeze=False keeps a 2-D axes grid even when no_ == 1.
    fig, axes = plt.subplots(2, no_, figsize=(15, 10), squeeze=False)
    for col, (image, mask) in enumerate(samples):
        axes[0, col].set_title('Low-res image')
        axes[0, col].imshow(image_convert(image), cmap='gray', vmin=0, vmax=255)
        axes[1, col].set_title('High-res image')
        axes[1, col].imshow(image_convert(mask), cmap='gray', vmin=0, vmax=255)
    plt.show()

In [None]:
plot_img(no_=5)  # five random training pairs

Build dataloaders

In [16]:
# Build the dataloaders
# Split the training dataset into ~0.9 train / ~0.1 validation. Sizes are
# derived from len(data) (43/5 for the 48 training images) so the split does
# not break when the number of images changes, and a seeded generator makes
# the split reproducible across kernel restarts.
VAL_FRACTION = 0.1
SPLIT_SEED = 42
n_val = max(1, int(round(VAL_FRACTION * len(data))))
n_train = len(data) - n_val
trainset, valset = random_split(data, [n_train, n_val],
                                generator=torch.Generator().manual_seed(SPLIT_SEED))

# build train dataloader
train_loader = DataLoader(dataset=trainset, batch_size=2, shuffle=True, num_workers=0)

# build validation dataloader
val_loader = DataLoader(dataset=valset, batch_size=2, shuffle=False, num_workers=0)

## Model construction

In [None]:
# Check GPU: use the first CUDA device when one is available, else the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


EDSR Model

In [None]:
# Build EDSR model. This code is adapted from https://github.com/soapisnotfat/super-resolution.git
class Net(nn.Module):
    """EDSR-style super-resolution network.

    input_conv -> ``num_residuals`` ResnetBlocks -> mid_conv (plus a global
    skip from input_conv) -> int(log2(upscale_factor)) x2 PixelShuffle stages
    -> output_conv. With ``upscale_factor=1`` there are no upsampling stages,
    so the output has the input's spatial size.

    Args:
        num_channels: image channels in and out.
        base_channel: feature width of the internal convolutions.
        upscale_factor: spatial upscaling; intended to be a power of two.
        num_residuals: number of residual blocks.
    """

    def __init__(self, num_channels, base_channel, upscale_factor, num_residuals):
        super(Net, self).__init__()

        self.input_conv = nn.Conv2d(num_channels, base_channel, kernel_size=3, stride=1, padding=1)

        resnet_blocks = []
        for _ in range(num_residuals):
            resnet_blocks.append(ResnetBlock(base_channel, kernel=3, stride=1, padding=1))
        self.residual_layers = nn.Sequential(*resnet_blocks)

        self.mid_conv = nn.Conv2d(base_channel, base_channel, kernel_size=3, stride=1, padding=1)

        # One x2 PixelShuffle stage per factor of two.
        upscale = []
        for _ in range(int(math.log2(upscale_factor))):
            upscale.append(PixelShuffleBlock(base_channel, base_channel, upscale_factor=2))
        self.upscale_layers = nn.Sequential(*upscale)

        self.output_conv = nn.Conv2d(base_channel, num_channels, kernel_size=3, stride=1, padding=1)

    def weight_init(self, mean=0.0, std=0.02):
        """Re-initialise every Conv2d / ConvTranspose2d weight with N(mean, std)
        and zero its bias.

        Walks all sub-modules recursively. Iterating only over the direct
        children skipped every convolution nested inside the
        ``residual_layers`` and ``upscale_layers`` Sequential containers.
        """
        for module in self.modules():
            normal_init(module, mean, std)

    def forward(self, x):
        x = self.input_conv(x)
        residual = x
        x = self.residual_layers(x)
        x = self.mid_conv(x)
        x = torch.add(x, residual)  # global skip connection around the trunk
        x = self.upscale_layers(x)
        x = self.output_conv(x)
        return x


def normal_init(m, mean, std):
    """Draw conv weights from N(mean, std) and zero the bias, in place.

    Only nn.Conv2d / nn.ConvTranspose2d modules are touched; any other module
    is returned unchanged.
    """
    if not isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        return
    m.weight.data.normal_(mean, std)
    if m.bias is None:
        return
    m.bias.data.zero_()


class ResnetBlock(nn.Module):
    """Residual block: conv -> BN -> ReLU -> conv -> BN, plus an identity skip.

    NOTE(review): a single BatchNorm layer (``self.bn``) is applied after both
    convolutions, so they share affine parameters and running statistics.
    A separate second BN would be the usual design, but adding one changes the
    state_dict keys and would break loading existing checkpoints.
    """

    def __init__(self, num_channel, kernel=3, stride=1, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channel, num_channel, kernel, stride, padding)
        self.conv2 = nn.Conv2d(num_channel, num_channel, kernel, stride, padding)
        self.bn = nn.BatchNorm2d(num_channel)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.activation(self.bn(self.conv1(x)))
        branch = self.bn(self.conv2(branch))
        return branch + x


class PixelShuffleBlock(nn.Module):
    """Upsample by ``upscale_factor`` with a conv followed by PixelShuffle.

    The conv expands the features to ``out_channel * upscale_factor**2``
    channels, which PixelShuffle rearranges into an ``upscale_factor``-times
    larger map with ``out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel, upscale_factor, kernel=3, stride=1, padding=1):
        super().__init__()
        expanded_channels = out_channel * upscale_factor ** 2
        self.conv = nn.Conv2d(in_channel, expanded_channels, kernel, stride, padding)
        self.ps = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        features = self.conv(x)
        return self.ps(features)


In [None]:
model_edsr = Net(
    num_channels=1,
    base_channel=64,
    upscale_factor=1,  # same-size output: inputs and targets are both resized to 256x256
    num_residuals=16,
).to(device)
model_edsr.weight_init(mean=0.0, std=0.02)

In [None]:
from pytorch_model_summary import summary
# Build the dummy batch on the model's device (a CPU tensor fed to a CUDA model
# fails) and probe in eval/no_grad mode so the BatchNorm running statistics
# are not updated by the all-zero input; the previous mode is restored after.
was_training = model_edsr.training
model_edsr.eval()
with torch.no_grad():
    print(summary(model_edsr, torch.zeros((1, 1, 256, 256), device=device), show_input=True))
model_edsr.train(was_training)

SRGAN model

In [7]:
# Build SRGAN model. This code is adapted from https://github.com/soapisnotfat/super-resolution.git
def swish(x):
    """Swish / SiLU activation: x * sigmoid(x), element-wise.

    Uses torch.sigmoid; torch.nn.functional.sigmoid is deprecated.
    """
    return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    """SRGAN residual block: conv -> BN -> swish -> conv -> BN, plus identity skip.

    Both convs use ``padding = kernel // 2``. The skip connection adds the
    input back, so the output shape must equal the input shape.
    """

    def __init__(self, in_channels, kernel, out_channels, stride):
        super().__init__()
        pad = kernel // 2
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=pad)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel, stride=stride, padding=pad)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = swish(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        return out + x


class UpsampleBlock(nn.Module):
    """Double the spatial size: conv to 4x channels, PixelShuffle(2), then swish."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 4 * in_channels, kernel_size=3, stride=1, padding=1)
        self.shuffler = nn.PixelShuffle(2)

    def forward(self, x):
        shuffled = self.shuffler(self.conv(x))
        return swish(shuffled)


class Generator(nn.Module):
    """SRGAN generator.

    conv1 (9x9) + swish -> ``n_residual_blocks`` ResidualBlocks -> conv2 + BN
    with a global skip from conv1 -> one x2 UpsampleBlock per factor of two in
    ``upsample_factor`` -> conv3 (9x9) back to ``num_channel`` channels.
    ``upsample_factor=1`` gives a same-size output.

    Sub-modules are registered as ``residual_block{i}`` / ``upsample{i}``
    (1-based); these names appear in saved state_dicts.
    """

    def __init__(self, n_residual_blocks, upsample_factor, num_channel=1, base_filter=64):
        super(Generator, self).__init__()
        self.n_residual_blocks = n_residual_blocks
        self.upsample_factor = upsample_factor

        self.conv1 = nn.Conv2d(num_channel, base_filter, kernel_size=9, stride=1, padding=4)

        for i in range(self.n_residual_blocks):
            self.add_module('residual_block' + str(i + 1),
                            ResidualBlock(in_channels=base_filter, out_channels=base_filter, kernel=3, stride=1))

        self.conv2 = nn.Conv2d(base_filter, base_filter, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(base_filter)

        for i in range(self._num_upsample_blocks()):
            self.add_module('upsample' + str(i + 1), UpsampleBlock(base_filter))

        self.conv3 = nn.Conv2d(base_filter, num_channel, kernel_size=9, stride=1, padding=4)

    def _num_upsample_blocks(self):
        """Number of x2 UpsampleBlocks needed for ``upsample_factor``.

        Each block doubles the resolution, so this is int(log2(factor)) (0 for
        factors <= 1). The previous ``upsample_factor // 2`` agreed for
        factors 1, 2 and 4 but e.g. built 4 blocks (x16) for a factor of 8.
        Computed from ``upsample_factor`` on demand so that whole-module
        pickles saved by older code still run.
        """
        if self.upsample_factor <= 1:
            return 0
        return int(math.log2(self.upsample_factor))

    def forward(self, x):
        x = swish(self.conv1(x))

        y = x.clone()
        for i in range(self.n_residual_blocks):
            y = getattr(self, 'residual_block' + str(i + 1))(y)

        # Global skip connection around the residual trunk.
        x = self.bn2(self.conv2(y)) + x

        for i in range(self._num_upsample_blocks()):
            x = getattr(self, 'upsample' + str(i + 1))(x)

        return self.conv3(x)

    def weight_init(self, mean=0.0, std=0.02):
        """Apply normal_init to the direct children only.

        NOTE(review): convs nested inside the residual/upsample blocks are not
        re-initialised (they are ResidualBlock/UpsampleBlock children, not
        Conv2d). Left unchanged because SRGANTrainer calls this with std=0.2,
        which would be far too large for every layer.
        """
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


class Discriminator(nn.Module):
    """SRGAN discriminator, fully convolutional variant.

    conv1 + swish, then seven conv -> BN -> swish stages (conv2..conv8) with
    output widths base, base, 2x, 2x, 4x, 4x, 8x, 8x (counting conv1) and
    stride 2 on the even-numbered convs, then a 1x1 conv9 down to
    ``num_channel`` maps. These are globally average-pooled and squashed by a
    sigmoid, giving a ``(batch, num_channel)`` tensor of probabilities.
    """

    def __init__(self, num_channel=1, base_filter=64):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channel, base_filter, kernel_size=3, stride=1, padding=1)

        # Output width of conv1..conv8 (index 0 belongs to conv1).
        widths = [base_filter, base_filter,
                  base_filter * 2, base_filter * 2,
                  base_filter * 4, base_filter * 4,
                  base_filter * 8, base_filter * 8]
        for idx in range(2, 9):
            stride = 2 if idx % 2 == 0 else 1
            setattr(self, 'conv{}'.format(idx),
                    nn.Conv2d(widths[idx - 2], widths[idx - 1], kernel_size=3, stride=stride, padding=1))
            setattr(self, 'bn{}'.format(idx), nn.BatchNorm2d(widths[idx - 1]))

        # Replaced original paper FC layers with FCN
        self.conv9 = nn.Conv2d(base_filter * 8, num_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = swish(self.conv1(x))
        for idx in range(2, 9):
            conv = getattr(self, 'conv{}'.format(idx))
            bn = getattr(self, 'bn{}'.format(idx))
            x = swish(bn(conv(x)))

        x = self.conv9(x)
        pooled = F.avg_pool2d(x, x.size()[2:])
        return torch.sigmoid(pooled).view(x.size()[0], -1)

    def weight_init(self, mean=0.0, std=0.02):
        for child in self.children():
            normal_init(child, mean, std)
            
def normal_init(m, mean, std):
    """In-place init for conv layers: weights ~ N(mean, std), bias set to 0.

    Non-convolutional modules are ignored. (Same helper as in the EDSR cell,
    repeated so this cell can be run on its own.)
    """
    conv_types = (nn.ConvTranspose2d, nn.Conv2d)
    if isinstance(m, conv_types):
        with torch.no_grad():
            m.weight.normal_(mean, std)
            if m.bias is not None:
                m.bias.zero_()


## Training

Training of EDSR model

In [None]:
# Setting up parameters
num_epochs = 50
learning_rate = 1e-4
save_interval = 1  # train() saves a checkpoint every `save_interval` epochs
scale = 1          # NOTE(review): not read by any cell below

# Define the loss function and optimizer: L1 loss, Adam, and a step schedule
# that multiplies the learning rate by 0.2 every 10 epochs.
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(
    model_edsr.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=1e-8,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)

In [None]:
# Sanity-check the output shape on one sample. eval() + no_grad() so the probe
# neither updates the BatchNorm running statistics nor builds an autograd
# graph; the previous train/eval mode is restored afterwards.
was_training = model_edsr.training
model_edsr.eval()
with torch.no_grad():
    output = model_edsr(image.to(device).unsqueeze(0))
model_edsr.train(was_training)
print(output.shape)

torch.Size([1, 1, 256, 256])


In [None]:
# Create the checkpoint directory; exist_ok makes re-running this cell harmless
# (the shell `mkdir` errored when the directory already existed).
os.makedirs("checkpoints", exist_ok=True)

In [None]:
# Directory where train() writes the per-epoch EDSR checkpoints (state_dicts).
model_outputs = "checkpoints"

In [8]:
# Train the model
def train(model, train_loader, val_loader, optimizer, criterion, num_epochs, scheduler):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Each epoch runs one optimisation pass over ``train_loader``, steps the LR
    ``scheduler`` once, then evaluates ``criterion`` on ``val_loader`` without
    gradients.

    Uses the notebook globals ``device`` (where batches are moved),
    ``save_interval`` and ``model_outputs``. Every ``save_interval`` epochs
    the state_dict is written to ``<model_outputs>/edsr_epoch<k>.pth``, where
    ``k`` is the 0-based epoch index (``edsr_epoch19.pth`` is the 20th epoch).

    Returns:
        (model, train_loss, val_loss): the trained model and the per-epoch
        mean training / validation losses.
    """
    train_loss = []
    val_loss = []
    for epoch in range(num_epochs):
        # --- training pass ---
        model.train()
        running_train_loss = []
        for inputs, targets in tqdm(train_loader):
            # Move the inputs and targets to the device (CPU or GPU)
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            running_train_loss.append(loss.item())

        # Use learning-rate scheduler (stepped once per epoch)
        scheduler.step()

        # --- validation pass ---
        # no_grad: validation needs no autograd graph (it was commented out,
        # so every validation batch built and held a graph).
        model.eval()
        running_val_loss = []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                running_val_loss.append(criterion(model(inputs), targets).item())

        # Print loss. get_last_lr() reports the learning rate currently in use;
        # get_lr() is not meant to be called directly and can show a
        # wrongly-decayed value.
        epoch_train_loss = np.mean(running_train_loss)
        print("Epoch [{}/{}], Train Loss: {:.4f}, LearningRate: {} ".format(epoch+1, num_epochs, epoch_train_loss, scheduler.get_last_lr()))
        train_loss.append(epoch_train_loss)

        epoch_val_loss = np.mean(running_val_loss)
        print("Epoch [{}/{}], Val Loss: {:.4f}".format(epoch+1, num_epochs, epoch_val_loss))
        val_loss.append(epoch_val_loss)

        # Save the trained model
        if (epoch + 1) % save_interval == 0:
            torch.save(model.state_dict(), os.path.join(model_outputs, "edsr_epoch" + str(epoch) + ".pth"))

    return model, train_loss, val_loss

In [None]:
model_edsr, train_loss, val_loss = train(model_edsr, train_loader, val_loader, optimizer, criterion, num_epochs,lr_scheduler)

In [None]:
# Plot training and validation loss (one point per epoch) and save as PDF.
epochs_axis = np.arange(len(train_loss))
fig, ax = plt.subplots()
ax.plot(epochs_axis, train_loss, label="Train")
ax.plot(epochs_axis, val_loss, label="Validation")
ax.set_xlabel("Number of epochs")
ax.set_ylabel("L1 loss")
ax.legend(loc="upper right")
fig.savefig("Training_edsr.pdf")
plt.show()

In [None]:
# Load checkpoint saved for 0-based epoch index 19 (the 20th training epoch).
# map_location lets a checkpoint written on a GPU runtime load on a CPU-only one.
checkpoint_best = torch.load('checkpoints/edsr_epoch19.pth', map_location=device)
model_edsr.load_state_dict(checkpoint_best)

Training of SRGAN model

In [9]:
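# This code is modified from the SRGAN trainer in https://github.com/soapisnotfat/super-resolution.git (NOTE(review): source was left blank here; confirm)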
from __future__ import print_function
from math import log10
import torch.backends.cudnn as cudnn
from torchvision.models.vgg import vgg19
#import progress_bar


class SRGANTrainer(object):
    """Pre-trains, then adversarially trains, an SRGAN generator/discriminator.

    ``run()`` builds the models, pre-trains the generator with the MSE loss
    alone for ``epoch_pretrain`` epochs, then for ``config.nEpochs`` epochs
    alternates discriminator and generator updates (``train``) followed by
    validation (``test``). Per-epoch generator training loss, validation MSE
    and validation PSNR are appended to ``train_loss``, ``val_loss`` and
    ``val_psnr``.

    Args:
        config: object with ``lr``, ``nEpochs``, ``seed`` and ``upscale_factor``.
        training_loader: iterable of ``(low_res, high_res)`` training batches.
        testing_loader: iterable of ``(low_res, high_res)`` validation batches.
    """

    def __init__(self, config, training_loader, testing_loader):
        super(SRGANTrainer, self).__init__()
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
        self.netG = None
        self.netD = None
        self.lr = config.lr
        self.nEpochs = config.nEpochs
        self.epoch_pretrain = 10
        self.criterionG = None
        self.criterionD = None
        self.optimizerG = None
        self.optimizerD = None
        self.feature_extractor = None
        self.schedulerG = None
        self.schedulerD = None
        self.seed = config.seed
        self.upscale_factor = config.upscale_factor
        self.num_residuals = 16
        self.training_loader = training_loader
        self.testing_loader = testing_loader
        self.train_loss = []
        self.val_loss = []
        self.val_psnr = []

    def build_model(self):
        """Create G and D with their losses, optimisers and LR schedulers."""
        # Seed before the networks are constructed and initialised so that
        # the initial weights are reproducible (seeding after weight_init, as
        # before, did not affect them).
        torch.manual_seed(self.seed)
        if self.GPU_IN_USE:
            torch.cuda.manual_seed(self.seed)
            cudnn.benchmark = True

        self.netG = Generator(n_residual_blocks=self.num_residuals, upsample_factor=self.upscale_factor, base_filter=64, num_channel=1).to(self.device)
        self.netD = Discriminator(base_filter=64, num_channel=1).to(self.device)
        self.netG.weight_init(mean=0.0, std=0.2)
        self.netD.weight_init(mean=0.0, std=0.2)
        # NOTE: a pretrained VGG19 used to be downloaded and moved to the GPU
        # here, but nothing in this trainer reads it (no perceptual loss), so
        # it is no longer built and ``feature_extractor`` stays None.
        self.criterionG = nn.MSELoss()
        self.criterionD = nn.BCELoss()

        self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.optimizerD = optim.SGD(self.netD.parameters(), lr=self.lr / 100, momentum=0.9, nesterov=True)
        self.schedulerG = optim.lr_scheduler.MultiStepLR(self.optimizerG, milestones=[10, 20, 30], gamma=0.2)  # lr decay
        self.schedulerD = optim.lr_scheduler.MultiStepLR(self.optimizerD, milestones=[10, 20, 30], gamma=0.2)  # lr decay

    @staticmethod
    def to_data(x):
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def save(self, epoch):
        """Pickle the whole G and D modules (not just state_dicts) for ``epoch``."""
        g_model_out_path = "SRGAN_Generator_model_path{}.pth".format(epoch)
        d_model_out_path = "SRGAN_Discriminator_model_path{}.pth".format(epoch)
        torch.save(self.netG, g_model_out_path)
        torch.save(self.netD, d_model_out_path)
        print("Checkpoint saved to {}".format(g_model_out_path))
        print("Checkpoint saved to {}".format(d_model_out_path))

    def pretrain(self):
        """One epoch of generator-only training with the MSE content loss."""
        self.netG.train()
        for data, target in self.training_loader:
            data, target = data.to(self.device), target.to(self.device)
            self.netG.zero_grad()
            loss = self.criterionG(self.netG(data), target)
            loss.backward()
            self.optimizerG.step()

    def train(self):
        """One adversarial epoch: a D step then a G step per batch, then LR decay."""
        self.netG.train()
        self.netD.train()
        g_train_loss = 0
        d_train_loss = 0
        for data, target in self.training_loader:
            # Discriminator targets: 1 for real HR images, 0 for generated ones.
            real_label = torch.ones(data.size(0), data.size(1)).to(self.device)
            fake_label = torch.zeros(data.size(0), data.size(1)).to(self.device)
            data, target = data.to(self.device), target.to(self.device)

            # Train Discriminator. The generated batch is detached so D's loss
            # does not back-propagate into G (those gradients were discarded
            # by optimizerG.zero_grad() anyway; this only saves work).
            self.optimizerD.zero_grad()
            d_real_loss = self.criterionD(self.netD(target), real_label)
            d_fake_loss = self.criterionD(self.netD(self.netG(data).detach()), fake_label)
            d_total = d_real_loss + d_fake_loss
            d_train_loss += d_total.item()
            d_total.backward()
            self.optimizerD.step()

            # Train generator: content (MSE) loss + 1e-3 * adversarial loss.
            self.optimizerG.zero_grad()
            g_real = self.netG(data)
            gan_loss = self.criterionD(self.netD(g_real), real_label)
            mse_loss = self.criterionG(g_real, target)
            g_total = mse_loss + 1e-3 * gan_loss
            g_train_loss += g_total.item()
            g_total.backward()
            self.optimizerG.step()

        self.schedulerG.step()
        self.schedulerD.step()
        avg_g_loss = g_train_loss / len(self.training_loader)
        self.train_loss.append(avg_g_loss)
        print("Average G_Loss: {:.4f}, Lr_Generator:{}, Lr_discriminator: {}".format(avg_g_loss,
                                                                                     self.optimizerG.param_groups[0]["lr"],
                                                                                     self.optimizerD.param_groups[0]["lr"] ))

    def test(self):
        """Validate G and record the mean per-batch MSE and mean per-batch PSNR.

        PSNR uses a peak signal of 1 (the dataset min-max scales images to
        [0, 1]); a batch with zero MSE has infinite PSNR.

        Raises:
            ValueError: if ``testing_loader`` yields no batches.
        """
        self.netG.eval()
        batch_mse = []
        batch_psnr = []

        with torch.no_grad():
            for data, target in self.testing_loader:
                data, target = data.to(self.device), target.to(self.device)
                mse = self.criterionG(self.netG(data), target).item()
                batch_mse.append(mse)
                batch_psnr.append(10 * log10(1 / mse) if mse > 0 else float('inf'))

        if not batch_mse:
            raise ValueError("testing_loader yielded no batches")
        # Average over all validation batches (previously only the last
        # batch's MSE and PSNR were recorded).
        avg_mse = float(np.mean(batch_mse))
        avg_psnr = float(np.mean(batch_psnr))
        self.val_loss.append(avg_mse)
        self.val_psnr.append(avg_psnr)
        print("    Average PSNR: {:.4f} dB, val_MSEloss:{}".format(avg_psnr, avg_mse))

    def run(self):
        """Build, pre-train, then train/validate; save G and D for epochs > 40."""
        self.build_model()
        for epoch in range(1, self.epoch_pretrain + 1):
            self.pretrain()
            print("{}/{} pretrained".format(epoch, self.epoch_pretrain))

        for epoch in range(1, self.nEpochs + 1):
            print("\n===> Epoch {} starts:".format(epoch))
            self.train()
            self.test()
            if epoch > 40:
                self.save(epoch)

In [10]:
# Training parameters for SRGAN
class Config(object):
    """Plain attribute container for the SRGANTrainer hyper-parameters."""

    def __init__(self, lr, nEpochs, seed, upscale_factor):
        self.lr, self.nEpochs = lr, nEpochs
        self.seed, self.upscale_factor = seed, upscale_factor


config = Config(
    lr=1e-3,
    nEpochs=50,
    seed=42,
    upscale_factor=1,
)

In [None]:
# The validation split is passed as the trainer's "testing" loader.
model_srgan_trainner = SRGANTrainer(config=config, training_loader=train_loader, testing_loader=val_loader)
model_srgan_trainner.run()  # pre-train G, then adversarial training with per-epoch validation

In [None]:
# Plot SRGAN per-epoch losses. The two curves are different quantities:
# "Train" is the generator objective (MSE + 1e-3 * adversarial BCE) and
# "Validation" is the plain MSE on the validation split.
epochs_axis = np.arange(len(model_srgan_trainner.train_loss))
plt.figure()
plt.plot(epochs_axis, model_srgan_trainner.train_loss, label="Train (G total loss)")
plt.plot(epochs_axis, model_srgan_trainner.val_loss, label="Validation (MSE)")
plt.xlabel("Number of epochs")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.savefig("Training_srgan.pdf")
plt.show()

## Performance evaluation on the test dataset

In [None]:
# Import evaluation metrics
from torchmetrics import PeakSignalNoiseRatio
from torchmetrics import StructuralSimilarityIndexMeasure
import lpips
# data_range=1.0 because Fluo_data min-max scales inputs and targets to [0, 1].
# Without it torchmetrics derives the range from the data, so the PSNR/SSIM
# scale would vary from image to image.
psnr = PeakSignalNoiseRatio(data_range=1.0).to(device)
ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
loss_fn = lpips.LPIPS(net='alex').to(device)  # LPIPS with AlexNet features

In [None]:
# Evaluate model performance on test dataset
def test_eval(test_data, model):
    """Evaluate ``model`` sample by sample on ``test_data`` and visualise results.

    For each sample ``i``: saves the prediction as ``output{i}.tiff``,
    computes L1, MSE, PSNR, SSIM and LPIPS against the target, and saves and
    shows an LR | output | HR figure as ``test_image{i}.pdf``.

    Relies on the notebook globals ``device``, ``psnr``, ``ssim``,
    ``loss_fn`` (LPIPS) and ``image_convert``.

    Returns:
        (l1_loss, mse_loss, psnr_list, ssim_list, lpips_list): per-sample lists.
    """
    l1_loss, mse_loss, psnr_list, ssim_list, lpips_list = [], [], [], [], []
    criterion1 = nn.L1Loss()
    criterion2 = nn.MSELoss()

    model.eval()
    with torch.no_grad():
        for i in range(len(test_data)):
            lr_img, target = test_data[i]
            lr_img = lr_img.to(device)
            target = target.to(device)
            target_batch = target.unsqueeze(0)      # add the batch axis
            output = model(lr_img.unsqueeze(0))
            output_img = output[0]                  # drop the batch axis

            # Write output to image
            Image.fromarray(image_convert(output_img)).save("output{}.tiff".format(i))

            l1_loss.append(criterion1(output, target_batch).item())
            mse_loss.append(criterion2(output, target_batch).item())
            psnr_list.append(psnr(output, target_batch).detach().cpu().numpy())
            ssim_list.append(ssim(output, target_batch).detach().cpu().numpy())
            # LPIPS expects inputs in [-1, 1]; normalize=True rescales our
            # [0, 1] images accordingly (they used to be passed unscaled).
            lpips_list.append(loss_fn(output, target_batch, normalize=True).detach().cpu().numpy())

            # Display the LR image, output image and HR image.
            fig, axs = plt.subplots(1, 3, figsize=(15, 10))
            axs[0].imshow(image_convert(lr_img))
            axs[0].set_title('LR image')
            axs[1].imshow(image_convert(output_img))
            axs[1].set_title('Output image')
            axs[2].imshow(image_convert(target))
            axs[2].set_title('HR image (Ground Truth)')
            fig.savefig("test_image{}.pdf".format(i), format="pdf", bbox_inches="tight")
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across samples
    return l1_loss, mse_loss, psnr_list, ssim_list, lpips_list

In [None]:
# EDSR: per-sample test-set metrics.
l1_loss, mse_loss, psnr_list, ssim_list, lpips_list = test_eval(test_data=test_data, model=model_edsr)

In [None]:
# SRGAN generator: per-sample test-set metrics.
l1_loss2, mse_loss2, psnr_list2, ssim_list2, lpips_list2 = test_eval(test_data=test_data, model=model_srgan_trainner.netG)

In [None]:
# Mean test-set metrics for the EDSR model.
edsr_report = (
    ("psnr of output: {}", psnr_list),
    ("lpips of output: {}", lpips_list),
    ("ssim of output: {}", ssim_list),
)
for template, values in edsr_report:
    print(template.format(np.mean(values)))

In [None]:
# Mean test-set metrics for the SRGAN generator.
srgan_report = (
    ("psnr of output: {}", psnr_list2),
    ("lpips of output: {}", lpips_list2),
    ("ssim of output: {}", ssim_list2),
)
for template, values in srgan_report:
    print(template.format(np.mean(values)))