In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
from torchvision.utils import save_image
from torchvision.models import vgg19
from PIL import Image
print(torch.__version__)

In [None]:
# https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/srgan/models.py

class FeatureExtractor(nn.Module):
    """Fixed slice of a pretrained VGG19 used for the perceptual (content) loss.

    Keeps the first 18 modules of ``vgg19(pretrained=True).features``. The weights are
    frozen: this network is only a fixed feature map and no optimizer in the notebook steps
    or zeroes its parameters, so letting them require gradients would just allocate and
    accumulate gradient buffers on every backward pass. Gradients still flow *through* the
    network to its input, which is what the generator's content loss needs.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
        # Freeze VGG weights (see class docstring).
        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)

    def forward(self, img):
        """Return the VGG feature maps of ``img``.

        NOTE(review): torchvision's VGG weights were trained on ImageNet-normalised input;
        ``img`` is passed through unnormalised here — confirm this is intended.
        """
        return self.feature_extractor(img)

# Load the training dataset into two loaders at different resolutions: a 64 × 64 low-resolution (LR) version used as the generator input, and a 256 × 256 high-resolution (HR) version used as the target.
We treat the 256 × 256 image as the HR image instead of the original, because the original images are too large for the resources currently available.
The batch size is only two images at a time due to memory constraints: with more than two images, the generator's convolutions exhaust main memory.
The training dataset was downloaded from
https://data.vision.ee.ethz.ch/cvl/DIV2K/

In [None]:
batchSize = 2
# Root folder of the DIV2K training images (ImageFolder expects the images inside class sub-folders).
DATA_DIR = 'C:\\Users\\saura\\Downloads\\DIV2K_train_HR'


def make_resize_transform(size):
    """Return a transform that bicubically resizes a PIL image to ``size`` x ``size``
    and converts it to a float tensor with values in [0, 1] (``ToTensor``)."""
    return transforms.Compose([
        transforms.Resize((size, size), Image.BICUBIC),
        transforms.ToTensor(),
    ])


transform_low = make_resize_transform(64)     # generator input (LR)
transform_High = make_resize_transform(256)   # training target (HR)

# Both datasets read the same folder; they differ only in the output resolution.
lowResSet = dset.ImageFolder(root=DATA_DIR, transform=transform_low)
highResSet = dset.ImageFolder(root=DATA_DIR, transform=transform_High)

# shuffle=False is deliberate: the training loop consumes these two loaders in lock-step, so
# the i-th LR batch must come from the same files as the i-th HR batch.
lowResLoader = torch.utils.data.DataLoader(lowResSet, batch_size=batchSize, shuffle=False, num_workers=2)
highResLoader = torch.utils.data.DataLoader(highResSet, batch_size=batchSize, shuffle=False, num_workers=2)


# Save Images as grid

In [None]:
def SaveFigGrid(input_imgs, name):
    """Tile a batch of images into a single grid and save it to ``name`` with matplotlib.

    Args:
        input_imgs: image batch tensor, presumably (N, C, H, W) with values in [0, 1]
            (as produced by ``ToTensor``); imshow clips RGB values outside that range.
        name: output path; matplotlib picks the file format from its extension
            (PNG when there is none).

    The figure is closed after saving. This function is called repeatedly inside the
    training loop, and unclosed figures stay registered with pyplot and accumulate memory.
    """
    # detach/cpu so tensors that require grad or live on a GPU can be converted to numpy.
    grid = torchvision.utils.make_grid(input_imgs.detach().cpu())
    fig, ax = plt.subplots(figsize=(8, 8))
    # make_grid always returns a 3-channel image, so no colormap applies here.
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(name)
    plt.close(fig)


# Test that the loaders and datasets work correctly by saving a sample batch as an image grid

In [None]:
# Pull one HR batch to check the loader. Use the builtin next(): DataLoader iterators no
# longer provide a .next() method in recent PyTorch releases.
trainIter = iter(highResLoader)
imgs, labels = next(trainIter)
# SaveFigGrid(imgs, "test")   # uncomment to write this sample batch to disk as a grid
imgs.shape

# Generator Network Model

In [None]:
class G(nn.Module):
    """Generator: upsamples an RGB image by 4x using two 2x PixelShuffle stages.

    Pipeline:
      * ``layer1`` - 9x9 conv stem + PReLU,
      * ``layer2`` - transposed-conv/BN/ReLU/transposed-conv/BN block applied 18 times
        with the SAME weights,
      * ``layer3`` - conv/BN, added back to the stem output (one long skip connection),
      * ``layer4`` - two conv/BN/PixelShuffle(2)/PReLU upsampling stages,
      * ``layer5`` - 9x9 conv + Tanh, so outputs lie in [-1, 1].

    Attribute names and ``nn.Sequential`` indices are what saved ``state_dict``s key on.

    NOTE(review): unlike the SRGAN paper's residual blocks, ``layer2`` is weight-shared and
    has no per-block skip connection; also the Tanh range [-1, 1] differs from the [0, 1]
    ``ToTensor`` targets — confirm both are intended.
    """

    def __init__(self):
        super(G, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 64, 9, 1, 4, bias=False), nn.PReLU())
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
        )
        self.layer3 = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1, bias=False), nn.BatchNorm2d(64))
        self.layer4 = nn.Sequential(*self._upsample_x2(), *self._upsample_x2())
        self.layer5 = nn.Sequential(nn.Conv2d(64, 3, 9, 1, 4), nn.Tanh())

    @staticmethod
    def _upsample_x2():
        """Modules of one 2x upsampling stage: 64 -> 256 channels, PixelShuffle back to 64."""
        return [
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.PixelShuffle(upscale_factor=2),
            nn.PReLU(),
        ]

    def forward(self, x):
        stem = self.layer1(x)
        features = stem
        for _ in range(18):  # the same layer2 weights on every pass
            features = self.layer2(features)
        merged = stem + self.layer3(features)
        return self.layer5(self.layer4(merged))

In [None]:
# Build a freshly initialised generator for training.
netG = G()

# Discriminator Network Model

In [None]:
class D(nn.Module):
    """Convolutional discriminator producing a map of real/fake probabilities.

    A 3->64 conv stem is followed by four stages. Each stage pairs a stride-1 conv with a
    stride-2 conv, and each conv after the very first one is followed by BatchNorm.
    LeakyReLU slopes alternate 0.1 (after stride 1) / 0.2 (after stride 2). Channel widths
    are 64, 128, 256, 512. A final 4x4 conv to one channel plus Sigmoid yields per-patch
    probabilities in (0, 1).

    ``self.model`` is built so its indices match saved ``state_dict`` keys (``model.<i>.*``).
    """

    def __init__(self):
        super(D, self).__init__()
        layers = [
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Conv2d(c_out, c_out, 3, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Conv2d(512, 1, 4, 1, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
netD = D()
# Set to True to resume training from the checkpoint loaded in the next cell.
flag = False

In [None]:
if flag:
    # map_location="cpu" lets a checkpoint written on a GPU machine load on a CPU-only one.
    # NOTE: torch.load unpickles the file (save_model also stores whole module objects),
    # so only load checkpoints from a trusted source.
    checkpoint = torch.load("srgan/models/model_20_iteration_9399.pth", map_location="cpu")
    netG.load_state_dict(checkpoint['G_State'])
    netD.load_state_dict(checkpoint['D_State'])

# Optimizers for the generator and discriminator (TTUR: separate learning rates), with binary cross-entropy as the adversarial loss function

In [None]:
# Adversarial loss (D ends in a Sigmoid, so BCE on probabilities) and VGG-feature content loss.
criterion = nn.BCELoss()
criterion_content = nn.MSELoss()

# Two time-scale update rule: the discriminator's learning rate is twice the generator's.
ADAM_BETAS = (0.5, 0.999)
optD = optim.Adam(netD.parameters(), lr=2e-4, betas=ADAM_BETAS)
optG = optim.Adam(netG.parameters(), lr=1e-4, betas=ADAM_BETAS)

In [None]:
# eval() switches the module to inference mode and returns the module itself.
feature_extractor = FeatureExtractor().eval()
feature_extractor  # display the network as the cell output

# Save Models Code

In [None]:
def save_model(name, generator, discriminator):
    """Write a training checkpoint to ``name + ".pth"``.

    The file stores the pickled modules under 'G'/'D' and their state dicts under
    'G_State'/'D_State' (the keys the loading cells read back).
    """
    checkpoint = {
        'G': generator,
        'D': discriminator,
        'G_State': generator.state_dict(),
        'D_State': discriminator.state_dict(),
    }
    torch.save(checkpoint, name + ".pth")


In [None]:
tb = SummaryWriter("srgan/run4")
i = -1  # global iteration counter (continues across epochs); used in output file names
try:
    for epoch in range(100):
        print("epoch", epoch)
        # Both loaders read the same files in the same unshuffled order, so zipping them
        # yields matching (LR, HR) batches.
        for (low, labelLow), (high, labelHigh) in zip(lowResLoader, highResLoader):
            i += 1

            # ---------------- Generator update ----------------
            optG.zero_grad()
            SR = netG(low)
            # SR must NOT be detached here: the adversarial term has to back-propagate
            # through D into G (detaching would make adv_loss a constant for G).
            desSR = netD(SR)
            true_labels = torch.ones(desSR.shape)
            adv_loss = criterion(desSR, true_labels)

            # Content loss: MSE between VGG feature maps of the SR and HR images.
            # NOTE(review): netG ends in Tanh ([-1, 1]) while `high` comes from ToTensor
            # ([0, 1]); confirm the two ranges are meant to differ.
            sr_feature = feature_extractor(SR)
            hr_feature = feature_extractor(high)
            content_loss = criterion_content(sr_feature, hr_feature.detach())

            # SRGAN perceptual loss = content loss + 1e-3 * adversarial loss
            # https://medium.com/@jonathan_hui/gan-super-resolution-gan-srgan-b471da7270ec
            lossG = content_loss + 1e-3 * adv_loss
            lossG.backward()
            optG.step()

            # -------------- Discriminator update --------------
            # zero_grad() also discards the gradients lossG.backward() deposited in D.
            optD.zero_grad()
            false_labels = torch.zeros(desSR.shape)
            desHR = netD(high)
            desSR = netD(SR.detach())  # detached: the D step must not reach G
            loss_real_des = criterion(desHR, true_labels)
            loss_fake_des = criterion(desSR, false_labels)
            lossD = loss_real_des + loss_fake_des
            lossD.backward()
            optD.step()

            if (i + 1) % 100 == 0:
                suffix = "epoch_" + str(epoch) + "_iteration_" + str(i)
                SaveFigGrid(SR.detach(), "srgan/SR/" + suffix)
                SaveFigGrid(high, "srgan/HR/" + suffix)
                SaveFigGrid(low, "srgan/LR/" + suffix)
                save_model("srgan/models/model_" + str(epoch) + "_iteration_" + str(i), netG, netD)

        # Per-epoch scalars come from the epoch's last batch; .item() logs plain floats.
        tb.add_scalar("Adversarial_loss", adv_loss.item(), epoch)
        tb.add_scalar("content_loss", content_loss.item(), epoch)
        tb.add_scalar("Total_Gen_Loss", lossG.item(), epoch)
        tb.add_scalar("Total_Des_Loss", lossD.item(), epoch)
        tb.add_scalars('Discriminator loss vs Generator loss', {
            'Discriminator loss': lossD.item(),
            'Generator loss': lossG.item(),
        }, epoch)
        tb.add_scalars('Adversial loss vs content loss', {
            'Adversarial loss': adv_loss.item(),
            'Content loss': content_loss.item(),
        }, epoch)
finally:
    # Flush and close the event file even if training is interrupted (e.g. KeyboardInterrupt).
    tb.close()

# Evaluation

In [None]:
# scikit-image >= 0.16 provides these metrics in skimage.metrics; the old
# skimage.measure.compare_* names were removed in 0.18, so fall back to them only
# on older installations.
try:
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import peak_signal_noise_ratio as psnr
except ImportError:
    from skimage.measure import compare_ssim as ssim
    from skimage.measure import compare_psnr as psnr
import numpy as np
import os, sys
from PIL import Image


# Generate super-resolution images for the BSD100 dataset and save them to disk for computing SSIM and PSNR
The BSD100 dataset was downloaded from https://github.com/jbhuang0604/SelfExSR/tree/master/data/BSD100

In [None]:
netG = G()
# map_location="cpu" lets a checkpoint written on a GPU machine load on a CPU-only one.
# NOTE: torch.load unpickles the file (the checkpoint also stores whole module objects),
# so only load checkpoints from a trusted source.
checkpoint = torch.load("srgan/models/model_20_iteration_9399.pth", map_location="cpu")
netG.load_state_dict(checkpoint['G_State'])
# Inference mode: BatchNorm must use its stored running statistics instead of the
# statistics of the single evaluation image (which would also overwrite them).
netG.eval()


def SaveFig(input_imgs, name):
    """Save the first image of a batch tensor to ``name`` at its native resolution.

    The pixels are written directly (values clamped to [0, 1]) with
    ``torchvision.utils.save_image`` instead of through a matplotlib figure. The saved file
    therefore has no axes, margins or figure-size rescaling and can be compared
    pixel-for-pixel against the reference image when computing SSIM/PSNR. The file format
    is chosen from ``name``'s extension.
    """
    save_image(input_imgs[0].detach().cpu(), name)



low_trans = transforms.Compose([
    transforms.Resize((64, 64), Image.BICUBIC),
    transforms.ToTensor(),
])

path = "C:\\Users\\saura\\Deep Learning\\Proj 2\\valImage\\BSD"
sr_dir = "C:\\Users\\saura\\Deep Learning\\Proj 2\\valImage\\SR"
os.makedirs(sr_dir, exist_ok=True)

# Keep only image files (skips e.g. Thumbs.db) and sort for a deterministic order.
IMG_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
dirs = sorted(f for f in os.listdir(path) if f.lower().endswith(IMG_EXTS))

netG.eval()  # BatchNorm must use running statistics at inference time
with torch.no_grad():
    for item in dirs:
        # Load each reference image by name, so the LR input and the output filename always
        # refer to the same picture (no reliance on two independently ordered listings).
        with Image.open(os.path.join(path, item)) as img:
            low = low_trans(img.convert('RGB')).unsqueeze(0)  # batch of one
        SR = netG(low)
        SaveFig(SR.detach(), os.path.join(sr_dir, item))
    

# Calculate SSIM and PSNR. Please set the correct paths for both image sets before running the calculation.

In [None]:
pathH = "C:\\Users\\saura\\Deep Learning\\Proj 2\\valImage\\BSD"
pathL = "C:\\Users\\saura\\Deep Learning\\Proj 2\\valImage\\SR"
IMG_EXTS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
# List pathH itself (not a variable left over from another cell), keep only image files,
# and sort so TensorBoard steps map to files deterministically.
dirs = sorted(f for f in os.listdir(pathH) if f.lower().endswith(IMG_EXTS))


def _load_luma_256(file_path):
    """Open an image, resize it bicubically to 256x256 and return its 8-bit luminance
    channel as a 2-D uint8 array.

    Only luminance ('L') is compared. An 'LA' conversion would add an alpha channel that is
    identical (fully opaque) in both images and artificially inflates SSIM and PSNR.
    """
    with Image.open(file_path) as im:
        return np.asarray(im.resize((256, 256), Image.BICUBIC).convert('L'))


tb = SummaryWriter("srgan/runVal")
ssimT = 0
psnrT = 0
try:
    for i, item in enumerate(dirs):
        imH = _load_luma_256(os.path.join(pathH, item))
        imL = _load_luma_256(os.path.join(pathL, item))
        # Single-channel uint8 arrays: both metrics infer data_range=255 from the dtype.
        ssimVal = ssim(imH, imL)
        psnrVal = psnr(imH, imL)
        ssimT += ssimVal
        psnrT += psnrVal
        tb.add_scalar("ssim", ssimVal, i)
        tb.add_scalar("psnr", psnrVal, i)

    if dirs:  # avoid ZeroDivisionError when the folder holds no images
        tb.add_scalar("average_ssim", ssimT / len(dirs))
        tb.add_scalar("average_psnr", psnrT / len(dirs))
finally:
    tb.close()
