## Import libraries

In [None]:
from collections import OrderedDict
from torch import Tensor
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch import Tensor
from typing import Type, Any, Callable, Union, List, Optional
import glob
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Subset,DataLoader
import torchvision.transforms as transforms
import torchvision
import random
from google.colab import files
from sklearn.metrics import  confusion_matrix
from sklearn.model_selection import ShuffleSplit
import cv2
from google.colab.patches import cv2_imshow
from scipy.ndimage import distance_transform_edt
from torch.autograd import Variable
import skimage.segmentation
import skimage.io
import skimage 
from scipy.optimize import linear_sum_assignment
import skimage.segmentation
import matplotlib.pyplot as plt
import skimage.io
import skimage.segmentation
from skimage import feature
from skimage import filters
import copy
import sklearn
import torchvision
import math
from sklearn.model_selection import train_test_split
import imageio
import tensorflow as tf

## Pre-processing functions

In [None]:
# Image Augmentations
def randomHueSaturationValue(
    image,
    hue_shift_limit=(-40, 40),
    sat_shift_limit=(-10, 10),
    val_shift_limit=(-20, 20),
    u=0.5,
):
    """Randomly jitter the hue, saturation and value of a 3-channel image.

    With probability ``u`` the image is converted with ``cv2.COLOR_BGR2HSV``,
    each HSV channel is shifted by a random amount drawn from its limit
    range, and the result is converted back with ``cv2.COLOR_HSV2BGR``.
    Otherwise the image is returned untouched.

    Args:
        image: 3-channel image array (assumes 8-bit data, for which OpenCV
            stores hue in [0, 180) -- TODO confirm against the caller).
        hue_shift_limit: inclusive ``(low, high)`` integer range for the hue
            shift.
        sat_shift_limit: ``(low, high)`` range for the saturation shift.
        val_shift_limit: ``(low, high)`` range for the value shift.
        u: probability of applying the augmentation.

    Returns:
        The augmented image, or the input object itself when the
        augmentation is skipped.
    """
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)
        hue_shift = int(np.random.randint(
            hue_shift_limit[0], hue_shift_limit[1] + 1))
        # Shift in a signed dtype and wrap at 180: casting a negative shift
        # to uint8 overflows (an error on NumPy 2) and wrapping at 256 would
        # produce hue values outside OpenCV's 8-bit hue range.
        h = ((h.astype(np.int16) + hue_shift) % 180).astype(np.uint8)
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)
        image = cv2.merge((h, s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image
 
 
def randomShiftScaleRotate(
    image,
    shift_limit=(-0.1, 0.1),
    scale_limit=(-0.1, 0.1),
    aspect_limit=(-0.1, 0.1),
    rotate_limit=(-90, 90),
    borderMode=cv2.BORDER_CONSTANT,
    u=0.5,
):
    """Apply a random shift / scale / aspect / rotation warp to an image.

    With probability ``u`` the four image corners are rotated and scaled
    about the image centre, translated, and the image is warped onto the
    resulting quadrilateral with nearest-neighbour sampling. The output keeps
    the input width and height; uncovered pixels are filled with zeros.

    Args:
        image: image array whose first two axes are (height, width).
        shift_limit: ``(low, high)`` translation range as a fraction of the
            image width / height.
        scale_limit: ``(low, high)`` range added to 1 to form the scale.
        aspect_limit: ``(low, high)`` range added to 1 to form the aspect
            ratio change.
        rotate_limit: ``(low, high)`` rotation range in degrees.
        borderMode: OpenCV border mode passed to ``cv2.warpPerspective``.
        u: probability of applying the augmentation.

    Returns:
        The warped image, or the input object itself when skipped.
    """
    if np.random.random() < u:
        # Only the spatial size is needed; not unpacking a channel count also
        # lets single-channel (2-D) images through.
        height, width = image.shape[:2]

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        # Use the stdlib math module: the np.math alias was removed in NumPy 2.
        cc = math.cos(angle / 180 * math.pi) * sx
        ss = math.sin(angle / 180 * math.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Corners of the source image and their transformed positions.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height]])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array(
            [width / 2 + dx, height / 2 + dy]
        )

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(
            image,
            mat,
            (width, height),
            flags=cv2.INTER_NEAREST,
            borderMode=borderMode,
            borderValue=(0, 0, 0),
        )

    return image
 
 
def randomHorizontalFlip(image, u=0.5):
    """Mirror ``image`` left-to-right with probability ``u``.

    Returns the flipped image, or the input object itself when the flip is
    skipped.
    """
    return cv2.flip(image, 1) if np.random.random() < u else image
 
def randomVerticleFlip(image, u=0.5):
    """Mirror ``image`` top-to-bottom with probability ``u``.

    Returns the flipped image, or the input object itself when the flip is
    skipped.
    """
    apply_flip = np.random.random() < u
    if not apply_flip:
        return image
    return cv2.flip(image, 0)
 
def randomRotate90(image, u=0.5):
    """Rotate ``image`` by 90 degrees (``np.rot90``) with probability ``u``.

    Returns the rotated array, or the input object itself when the rotation
    is skipped.
    """
    rotate = np.random.random() < u
    return np.rot90(image) if rotate else image

## Load training and testing data

In [None]:
# Dataset Definition
class Train(Dataset):
    """Training dataset yielding (degraded input, ground truth) image pairs.

    Each image is read from disk, resized to ``image_dimension`` square,
    randomly augmented and used as the ground truth. The network input is
    made by shrinking that image to ``lr_dimension`` square (area
    interpolation) and resizing it back to ``image_dimension``.

    Args:
        train_image_paths: sequence of image file paths.
        image_dimension: side length of the ground truth and of the input.
        lr_dimension: side length of the intermediate low-resolution image.
    """

    def __init__(self, train_image_paths, image_dimension, lr_dimension):
        self.image_path = train_image_paths
        self.dim = image_dimension
        self.lr_dim = lr_dimension

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        """Return ``(input, ground_truth)`` channels-first tensors / 255."""
        image = imageio.imread(self.image_path[idx])
        # Use the configured size instead of a hard-coded 256 so the ground
        # truth always matches the size of the re-upsampled input below.
        image = cv2.resize(image, (self.dim, self.dim))
        # Augmentations applied before the pair is derived, so input and
        # ground truth share them.
        image = randomHueSaturationValue(image)
        image = randomShiftScaleRotate(image)
        image = randomVerticleFlip(image)
        image = randomHorizontalFlip(image)
        image = randomRotate90(image)

        # .copy() gives torch a contiguous array (np.rot90 returns a view).
        ground_truth = torch.from_numpy(image.copy()).permute(2, 0, 1)
        input_image = cv2.resize(
            image, (self.lr_dim, self.lr_dim), interpolation=cv2.INTER_AREA)
        input_image = cv2.resize(input_image, (self.dim, self.dim))
        input_image = torch.from_numpy(input_image.copy()).permute(2, 0, 1)

        # Presumably 8-bit images, so this scales both tensors to [0, 1].
        return input_image / 255.0, ground_truth / 255.0

In [None]:
class Test(Dataset):
    """Evaluation dataset yielding (degraded input, ground truth) pairs.

    Each image is read from disk and resized to ``image_dimension`` square to
    form the ground truth; no augmentation is applied. The network input is
    made by shrinking that image to ``lr_dimension`` square (area
    interpolation) and resizing it back to ``image_dimension``.

    Args:
        test_image_paths: sequence of image file paths.
        image_dimension: side length of the ground truth and of the input.
        lr_dimension: side length of the intermediate low-resolution image.
    """

    def __init__(self, test_image_paths, image_dimension, lr_dimension):
        self.image_path = test_image_paths
        self.dim = image_dimension
        self.lr_dim = lr_dimension

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        """Return ``(input, ground_truth)`` channels-first tensors / 255."""
        image = imageio.imread(self.image_path[idx])
        # Use the configured size instead of a hard-coded 256 so the ground
        # truth always matches the size of the re-upsampled input below.
        image = cv2.resize(image, (self.dim, self.dim))
        ground_truth = torch.from_numpy(image.copy()).permute(2, 0, 1)
        input_image = cv2.resize(
            image, (self.lr_dim, self.lr_dim), interpolation=cv2.INTER_AREA)
        input_image = cv2.resize(input_image, (self.dim, self.dim))
        input_image = torch.from_numpy(input_image.copy()).permute(2, 0, 1)

        # Presumably 8-bit images, so this scales both tensors to [0, 1].
        return input_image / 255.0, ground_truth / 255.0

In [None]:
# Build the datasets: 256-pixel targets, 86-pixel intermediate low-res size.
# NOTE(review): train_image_paths / test_image_paths are not defined in the
# visible part of the notebook -- they must be set before running this cell.
train_ds = Train(train_image_paths, image_dimension=256, lr_dimension=86)
test_ds = Test(test_image_paths, image_dimension=256, lr_dimension=86)

# Only the training loader shuffles; the test loader keeps dataset order.
train_dl = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)
test_dl = DataLoader(
    test_ds,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Dataset sizes, one per line.
print(len(train_ds), len(test_ds), sep="\n")

## MAML-SR Model

In [None]:
class Upsample(nn.Module):
    """Resize a feature map by a fixed scale factor with ``F.interpolate``.

    Args:
        scale_factor: multiplier applied to the spatial size.
        mode: interpolation mode forwarded to ``F.interpolate``.
    """

    # Modes for which F.interpolate accepts the align_corners option.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode="bilinear"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        # align_corners may only be set for interpolating modes; passing it
        # with e.g. "nearest" raises, so fall back to None for other modes.
        align_corners = True if self.mode in self._ALIGNABLE_MODES else None
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
            recompute_scale_factor=True,
        )
        return x


class Upsample_meta(nn.Module):
    """1x1 convolution followed by a fixed-factor ``F.interpolate`` resize.

    The convolution maps ``input_features`` channels to ``output_features``
    channels before the spatial resize.

    Args:
        scale_factor: multiplier applied to the spatial size.
        input_features: number of input channels.
        output_features: number of output channels.
        mode: interpolation mode forwarded to ``F.interpolate``.
    """

    # Modes for which F.interpolate accepts the align_corners option.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, input_features, output_features, mode="bilinear"):
        super(Upsample_meta, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(input_features, output_features, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.conv(x)
        # align_corners may only be set for interpolating modes; passing it
        # with e.g. "nearest" raises, so fall back to None for other modes.
        align_corners = True if self.mode in self._ALIGNABLE_MODES else None
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
            recompute_scale_factor=True,
        )
        return x

In [None]:
# Channel attention
class Channel_Attention(nn.Module):
    """Channel gate built from globally average- and max-pooled statistics.

    The two pooled descriptors are concatenated along the channel axis,
    mixed by a bias-free 1x1 convolution back to ``num_channels`` and passed
    through a sigmoid; the input is multiplied by the resulting gate.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.conv = nn.Conv2d(
            num_channels * 2,
            num_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )

    def forward(self, x):
        pooled = torch.cat((self.avgpool(x), self.maxpool(x)), 1)
        gate = torch.sigmoid(self.conv(pooled))
        return x * gate

In [None]:
# Spatial Attention
class Spatial_Attention(nn.Module):
    """Spatial gate built from per-pixel channel mean and channel max.

    The two single-channel maps are stacked, reduced to one channel by a
    bias-free 3x3 convolution and passed through a sigmoid; the input is
    multiplied by the resulting map.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            2, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        gate = torch.sigmoid(self.conv1(stacked))
        return x * gate

In [None]:
# Attention Blocks
class attention_block(nn.Module):
    """Fuse one encoder level with coarser-level features via attention.

    Every coarser feature map is projected to ``lsets`` channels and resized
    by ``Upsample_meta``, concatenated with the level's own features, gated
    by spatial and channel attention, and merged back to ``lsets`` channels
    as a residual on top of the level's own features.

    Args:
        num_sets: number of levels fused, including the level itself.
        fsets: channel count per level; entries ``1 .. num_sets - 1`` give the
            input channels of the coarser feature maps (entry 0 is not read).
        lsets: channel count of the level the block belongs to.
    """

    def __init__(self, num_sets, fsets, lsets):
        super(attention_block, self).__init__()
        # The i-th coarser level is resized by a factor of 2**i.
        self.upsampler = nn.ModuleList(
            [Upsample_meta(2**i, fsets[i], lsets) for i in range(1, num_sets)])
        self.sp = Spatial_Attention()
        self.ch = Channel_Attention(num_sets * lsets)
        self.conv = nn.Conv2d(2 * num_sets * lsets, lsets, kernel_size=3,
                              stride=1, padding=1, dilation=1)
        self.bn = nn.BatchNorm2d(lsets)

    def forward(self, x1, data):
        """Return the attended features for ``x1``.

        Args:
            x1: features of this level (``lsets`` channels).
            data: list of coarser feature maps, ordered to match
                ``self.upsampler``. The list is left unmodified.
        """
        # Build a new list rather than overwriting the caller's entries.
        upsampled = [self.upsampler[i](feat) for i, feat in enumerate(data)]
        x = torch.cat([x1, *upsampled], 1)
        x_sp = self.sp(x)
        x_ch = self.ch(x)
        x_out = x1 + self.bn(self.conv(torch.cat((x_sp, x_ch), 1)))
        return F.relu(x_out)

In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution and of the built-in shortcut.
        downsample: ``True`` builds a 1x1 conv + batch-norm shortcut; an
            ``nn.Module`` is used as the shortcut directly; anything else
            leaves the shortcut as the identity.
        bias: whether the first convolution and the built-in shortcut
            convolution have a bias term.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=True, bias=False):
        super().__init__()
        self.stride = stride
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # The second convolution keeps nn.Conv2d's default bias setting.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=(3, 3), stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)

        if downsample == True:
            shortcut = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=bias),
                nn.BatchNorm2d(planes),
            )
        elif isinstance(downsample, nn.Module):
            shortcut = downsample
        else:
            shortcut = None
        self.downsample = shortcut

    def forward(self, x):
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        main += shortcut
        return self.relu(main)

class fcn_out(nn.Module):
    """Auxiliary output head: conv-BN-ReLU, upsample, conv-ReLU, 1x1 conv.

    Produces a 3-channel map squashed by a sigmoid, spatially enlarged by
    ``upsample_factor`` relative to the input.

    Args:
        input_channels: number of channels of the incoming feature map.
        upsample_factor: scale factor handed to ``Upsample``.
    """

    def __init__(self, input_channels, upsample_factor):
        super().__init__()
        self.conv1 = nn.Conv2d(
            input_channels, 64, kernel_size=3, stride=1, padding=1, dilation=1)
        self.norm1 = nn.BatchNorm2d(64)
        self.upsample = Upsample(upsample_factor)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=1, padding=1, dilation=1)
        self.conv3 = nn.Conv2d(
            64, 3, kernel_size=1, stride=1, padding=0, dilation=1)

    def forward(self, x):
        features = F.relu(self.norm1(self.conv1(x)))
        enlarged = self.upsample(features)
        refined = F.relu(self.conv2(enlarged))
        return torch.sigmoid(self.conv3(refined))

class MAMLSR(nn.Module):
    """Encoder-decoder network with multi-scale attention and side outputs.

    Four residual encoder stages (32, 64, 128, 256 channels) are separated by
    2x max-pooling. The three shallowest encoder outputs are refined by
    ``attention_block`` using the deeper encoder features. A bridge feeds a
    decoder of transposed-convolution upsamplers and residual stages with
    skip connections. ``forward`` returns the final prediction plus three
    auxiliary predictions produced by ``fcn_out`` heads on decoder levels.

    NOTE(review): the four pooling / upsampling steps presumably require the
    input height and width to be divisible by 16 -- confirm.
    """

    def __init__(self, num_channels=3):
        super().__init__()

        self.res_conv = ResidualBlock

        # Encoder.
        self.down1 = self.res_conv(num_channels, 32)
        self.down2 = self.res_conv(32, 64)
        self.down3 = self.res_conv(64, 128)
        self.down4 = self.res_conv(128, 256)

        # Multi-scale attention on the three shallowest encoder levels.
        self.attn1 = attention_block(4, [32, 64, 128, 256], 32)
        self.attn2 = attention_block(3, [64, 128, 256], 64)
        self.attn3 = attention_block(2, [128, 256], 128)

        self.bridge = self.conv_stage(256, 256)

        # Decoder stages; each consumes an upsampled map concatenated with
        # the matching encoder skip connection.
        self.up4 = self.res_conv(512, 256)
        self.up3 = self.res_conv(256, 128)
        self.up2 = self.res_conv(128, 64)
        self.up1 = self.res_conv(64, 32)

        self.trans4 = self.upsample(256, 256)
        self.trans3 = self.upsample(256, 128)
        self.trans2 = self.upsample(128, 64)
        self.trans1 = self.upsample(64, 32)

        self.conv_last = nn.Sequential(
            nn.Conv2d(32, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

        self.max_pool = nn.MaxPool2d(2)

        # Auxiliary heads on the concatenated decoder inputs.
        self.fcn3 = fcn_out(512, 8)
        self.fcn2 = fcn_out(256, 4)
        self.fcn1 = fcn_out(128, 2)

        # Start every (transposed) convolution bias at zero.
        for m in self.modules():
            if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                continue
            if m.bias is not None:
                m.bias.data.zero_()

    def conv_stage(self, dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True):
        """Return two conv + batch-norm + LeakyReLU(0.1) layers in sequence."""
        layers = [
            nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(dim_out),
            nn.LeakyReLU(0.1),
            nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(dim_out),
            nn.LeakyReLU(0.1),
        ]
        return nn.Sequential(*layers)

    def upsample(self, ch_coarse, ch_fine):
        """Return a stride-2 transposed convolution followed by ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return ``(out, out_2, out_3, out_4)``: final and auxiliary outputs."""
        pool = self.max_pool

        enc1 = self.down1(x)
        enc2 = self.down2(pool(enc1))
        enc3 = self.down3(pool(enc2))
        enc4 = self.down4(pool(enc3))  # 256 channels

        # Multi-scale attention over the encoder features; each block sees the
        # un-attended deeper maps.
        enc1 = self.attn1(enc1, [enc2, enc3, enc4])
        enc2 = self.attn2(enc2, [enc3, enc4])
        enc3 = self.attn3(enc3, [enc4])

        dec = self.bridge(pool(enc4))  # 256 channels

        dec = self.trans4(dec)
        out_4 = self.fcn3(torch.cat((dec, enc4), 1))
        dec = self.up4(torch.cat((dec, enc4), 1))

        dec = self.trans3(dec)
        out_3 = self.fcn2(torch.cat((dec, enc3), 1))
        dec = self.up3(torch.cat((dec, enc3), 1))

        dec = self.trans2(dec)
        out_2 = self.fcn1(torch.cat((dec, enc2), 1))
        dec = self.up2(torch.cat((dec, enc2), 1))

        dec = self.trans1(dec)
        dec = self.up1(torch.cat((dec, enc1), 1))

        return self.conv_last(dec), out_2, out_3, out_4

In [None]:
# Build the network with default arguments and move it to the current CUDA
# device (a GPU is required by the training / testing cells).
model = MAMLSR()
model = model.cuda()

## Optimization objective

In [None]:
# Mean absolute error, averaged over all elements.
loss_l1 = nn.L1Loss(reduction="mean")
# Cosine similarity taken along dim 1.
cos = nn.CosineSimilarity(dim=1, eps=1e-8)

def loss_cosine(input,target):
    """Return the mean cosine similarity between ``input`` and ``target``.

    The module-level ``cos`` is applied to the pair, the result is averaged
    over its dims 1 and 2, and then over the remaining dim.

    Note: the value returned is a similarity (larger means more alike), not a
    distance.
    """
    similarity = cos(input, target)
    per_sample = similarity.mean(dim=(1, 2))
    return per_sample.mean()

def loss_deep1(pred2,pred3,pred4,target):
    """Return the weighted sum of L1 losses of three predictions vs ``target``.

    Each prediction is compared with the same ``target`` using the
    module-level ``loss_l1``.
    """
    weights = (1, 1, 1)  # change lambdas according to need
    terms = [loss_l1(pred, target) for pred in (pred2, pred3, pred4)]
    return weights[0] * terms[0] + weights[1] * terms[1] + weights[2] * terms[2]

def loss_fn(output,target):
    """Return the L1 loss (module-level ``loss_l1``) of ``output`` vs ``target``."""
    return loss_l1(output, target)

## Meta-training MAML-SR

In [None]:
# Meta Training Step
def train_meta(model,train_dl,learn):
    """Run one pass over ``train_dl`` with two optimisation steps per batch.

    For every batch the model is first updated on the auxiliary multi-scale
    loss (``loss_deep1`` on the three side outputs), then run again and
    updated on the final-output loss (``loss_fn``).

    Args:
        model: network returning ``(output, pred2, pred3, pred4)``; expected
            to already live on the CUDA device.
        train_dl: iterable of ``(input, target)`` batches.
        learn: learning rate for the Adam optimiser created here.

    Returns:
        ``(model, loss_1, loss_2)`` where the losses are the per-batch means
        of the first-step and second-step losses, as detached tensors.

    Note:
        A new optimiser is created on every call, so optimiser state is not
        carried over between calls. An empty ``train_dl`` raises
        ``ZeroDivisionError`` when the means are taken.
    """
    opt_cl = torch.optim.Adam(model.parameters(), lr=learn)  # can change to SGD
    loss_1 = 0.0  # running sum of the first-step losses
    loss_2 = 0.0  # running sum of the second-step losses
    for a, b in train_dl:
        # Move each batch to the GPU once instead of once per forward pass.
        a = a.float().cuda()
        b = b.float().cuda()

        # First step of multi-scale optimisation: side outputs only.
        _, pred2, pred3, pred4 = model(a)
        loss3 = loss_deep1(pred2, pred3, pred4, b)
        opt_cl.zero_grad()
        loss3.backward()
        opt_cl.step()
        # Accumulate a detached value so the running sum does not keep every
        # batch's autograd graph alive.
        loss_1 = loss_1 + loss3.detach()

        # Second stage of optimisation: final output, using updated weights.
        output, _, _, _ = model(a)
        loss_final = loss_fn(output, b)
        opt_cl.zero_grad()
        loss_final.backward()
        opt_cl.step()
        loss_2 = loss_2 + loss_final.detach()

    loss_1 = loss_1 / len(train_dl)
    loss_2 = loss_2 / len(train_dl)
    return model, loss_1, loss_2

In [None]:
# Training Step

def step_train_epochs(model,train_dl,epochs,learn,path):
    """Meta-train for ``epochs`` epochs, saving a checkpoint after each one.

    Args:
        model: network to train; passed to and returned by ``train_meta``.
        train_dl: training data loader.
        epochs: number of epochs to run.
        learn: learning rate forwarded to ``train_meta``.
        path: directory that receives one ``.pth`` state dict per epoch; it
            is created if missing.
    """
    # torch.save does not create missing directories.
    if path:
        os.makedirs(path, exist_ok=True)

    for i in range(1, epochs + 1):
        model, loss_1, loss_2 = train_meta(model, train_dl, learn)
        loss_1 = loss_1.detach().cpu()
        loss_2 = loss_2.detach().cpu()
        print("Epoch: " + str(i))
        print("Train_loss_meta_step1: " + str(loss_1) + " Train_loss_meta_step2: " + str(loss_2))
        print("--------------------")

        # The checkpoint name records the epoch and both mean losses.
        path_final = os.path.join(
            path, f"epoch{i}_loss1{loss_1:.4f}_loss2{loss_2:.4f}.pth")
        torch.save(model.state_dict(), path_final)

In [None]:
# Create the checkpoint directory (IPython magic), then meta-train for
# 20 epochs with learning rate 0.001, saving checkpoints to /content/sisr.
# NOTE(review): the directory is created relative to the current working
# directory while the path passed below is absolute -- presumably the
# notebook runs from /content; confirm.
%mkdir sisr
step_train_epochs(model,train_dl,20,0.001,'/content/sisr')

## Meta-testing MAML-SR

In [None]:
# Meta Testing Step
def test_meta(model,test_dl,learn,max_steps=None):
    """Adapt the model on each test batch, then score it against ground truth.

    For every batch an initial prediction is taken. The model is then
    repeatedly updated on the L1 loss between its final output and its first
    auxiliary output (``loss_fn(s1, s2)``); after each update the new
    prediction is compared (PSNR / SSIM) with the current reference
    prediction. When both metrics exceed the best values so far, the new
    prediction becomes the reference; when both fall below them, adaptation
    for the batch stops. The reference prediction is finally scored against
    the ground truth.

    Args:
        model: network returning ``(output, pred2, pred3, pred4)``; expected
            to already live on the CUDA device. Its weights are modified and
            are not reset between batches.
        test_dl: iterable of ``(input, target)`` batches.
        learn: learning rate for the Adam optimiser created here.
        max_steps: optional cap on adaptation steps per batch. ``None``
            (default) keeps the loop unbounded: it only ends once both
            metrics drop below their best values.

    Returns:
        ``(model, opsnr, ossim)``: the adapted model and the per-batch means
        of PSNR and SSIM against the ground truth.

    Note:
        An empty ``test_dl`` raises ``ZeroDivisionError`` when the means are
        taken.
    """
    opt_cl = torch.optim.Adam(model.parameters(), lr=learn)
    opsnr = 0
    ossim = 0
    for a, b in test_dl:
        # Move the input to the GPU once per batch; the target stays on the
        # CPU because it is only consumed as a NumPy array.
        a = a.float().cuda()
        b = b.float()
        with torch.no_grad():
            sr_init, _, _, _ = model(a)
        # Only the NumPy view of the reference prediction is ever needed.
        ref = sr_init.cpu().numpy()
        psnr_init, ssim_init = 0, 0

        steps = 0
        while max_steps is None or steps < max_steps:
            steps += 1
            s1, s2, _, _ = model(a)
            loss = loss_fn(s1, s2)
            opt_cl.zero_grad()
            loss.backward()
            opt_cl.step()

            # Evaluation-only forward pass: no autograd graph is required.
            with torch.no_grad():
                sr_p, _, _, _ = model(a)
            cand = sr_p.cpu().numpy()
            # NOTE(review): PSNR receives channels-first arrays while SSIM
            # receives channels-last ones, as in the original -- confirm.
            psnr_pr = tf.image.psnr(cand, ref, max_val=1.0)
            psnr_pr = tf.reduce_mean(psnr_pr)
            ssim_pr = tf.image.ssim(
                cand.transpose(0, 2, 3, 1), ref.transpose(0, 2, 3, 1), max_val=1.0)
            ssim_pr = tf.reduce_mean(ssim_pr)
            if psnr_pr > psnr_init and ssim_pr > ssim_init:
                psnr_init = psnr_pr
                ssim_init = ssim_pr
                ref = cand
            elif psnr_pr < psnr_init and ssim_pr < ssim_init:
                break

        target = b.numpy()
        psnr_final = tf.image.psnr(ref, target, max_val=1.0)
        psnr_final = tf.reduce_mean(psnr_final)
        ssim_final = tf.image.ssim(
            ref.transpose(0, 2, 3, 1), target.transpose(0, 2, 3, 1), max_val=1.0)
        ssim_final = tf.reduce_mean(ssim_final)
        opsnr = opsnr + psnr_final
        ossim = ossim + ssim_final
    opsnr = opsnr / len(test_dl)
    ossim = ossim / len(test_dl)
    return model, opsnr, ossim

In [None]:
# Restore meta-trained weights. The checkpoint path is left empty in the
# notebook and must be filled in before running this cell.
model.load_state_dict(state_dict=torch.load(f=''))

In [None]:
# Meta Testing Step
# Evaluate on the held-out test loader: the training loader shuffles and
# augments its images, so metrics computed on it are not test results.
model, opsnr, ossim = test_meta(model, test_dl, 0.001)
print("OPSNR: " + str(opsnr))
print("OSSIM: " + str(ossim))

In [None]:
# Persist the weights of the (test-time adapted) model to the working directory.
torch.save(obj=model.state_dict(), f='final_model_sisrx4.pth')