In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision.utils import save_image
import torchvision
import os

In [None]:
class ResNet(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, added to the input, then PReLU.

    Every layer keeps the channel count and spatial size of the input
    (3x3 kernels, stride 1, padding 1), so the output has the same shape as `x`.

    Args:
        in_channels: number of channels of the input feature map. It is also
            used as the width of both convolutions so the skip addition is
            always shape-compatible (the block used to hard-code 64 here,
            which only worked when in_channels == 64).
    """
    def __init__(self, in_channels):
        super(ResNet, self).__init__()
        self.resnet = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1),
            nn.BatchNorm2d(in_channels),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1),
            nn.BatchNorm2d(in_channels)
        )
        self.relu = nn.PReLU()

    def forward(self, x):
        """Return PReLU(x + residual_branch(x))."""
        return self.relu(x + self.resnet(x))
        

class Generator(nn.Module):
    """SRGAN-style generator that upsamples an image by a factor of 4.

    Pipeline: 9x9 conv + PReLU -> 16 residual blocks -> 3x3 conv + BN with a
    long skip connection -> two PixelShuffle(2) upsampling stages -> 9x9 conv
    -> tanh. The output has 3 channels and values in [-1, 1].

    Args:
        in_channels: channels of the low-resolution input image.
        hidden_channels: feature width used by every internal layer; it is
            also passed to each ResNet block.
    """
    def __init__(self, in_channels = 3, hidden_channels = 64):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size = 9, stride = 1, padding = 4),
            nn.PReLU(),
        )

        self.layer2 = nn.Sequential(
            *[ResNet(hidden_channels) for _ in range(16)]
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size = 3, stride = 1, padding = 1),
            nn.BatchNorm2d(hidden_channels)
        )

        # Two independent x2 stages. Re-applying one module twice would tie
        # the weights of both upsampling steps together.
        self.layer_up = self._upsample_block(hidden_channels)
        self.layer_up2 = self._upsample_block(hidden_channels)

        self.final_layer = nn.Conv2d(hidden_channels, 3, kernel_size = 9, stride = 1, padding = 4)
        self.tanh = nn.Tanh()

    @staticmethod
    def _upsample_block(channels):
        """x2 upsampling: conv to 4*channels, PixelShuffle(2) back to `channels`, PReLU."""
        return nn.Sequential(
            nn.Conv2d(channels, channels * 4, kernel_size = 3, stride = 1, padding = 1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )

    def forward(self, img):
        """Map a (N, in_channels, H, W) batch to a (N, 3, 4H, 4W) batch."""
        img1 = self.layer1(img)
        x = self.layer2(img1)
        # Long skip connection around the whole residual trunk.
        x = self.layer3(x) + img1
        x = self.layer_up(x)
        x = self.layer_up2(x)
        x = self.final_layer(x)
        return self.tanh(x)

In [None]:
def test():
    """Smoke-test the Generator: feed one random 3x64x64 image and print the output shape."""
    dummy_lr = torch.randn((1, 3, 64, 64))
    generator = Generator()
    print(generator(dummy_lr).shape)

In [None]:
# Run the smoke test defined in the previous cell (prints the output shape).
test()

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier that returns one raw logit per image.

    A 3x3 conv + LeakyReLU stem is followed by seven conv-BN-LeakyReLU blocks.
    Four of those blocks use stride 2, so the spatial size shrinks by 16x.
    The first linear layer expects 512*16*16 features, i.e. a 256x256 input.
    No sigmoid is applied to the output.

    Args:
        in_channels: number of channels of the input image.
    """
    def __init__(self, in_channels = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size = 3, stride = 1, padding = 1)
        self.relu1 = nn.LeakyReLU(0.2)

        # (in_channels, out_channels, stride) of each block; kernel 3, padding 1 throughout.
        block_specs = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )
        stages = [self.con_block(c_in, c_out, 3, step, 1) for c_in, c_out, step in block_specs]
        self.seq = nn.Sequential(*stages)

        self.linear1 = nn.Linear(512*16*16, 1024)
        self.relu2 = nn.LeakyReLU(0.2)
        self.linear2 = nn.Linear(1024, 1)

    def con_block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build one conv -> BatchNorm -> LeakyReLU(0.2) stage."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        return nn.Sequential(*layers)

    def forward(self, img):
        """Return a (N, 1) tensor of logits for a batch of images."""
        features = self.seq(self.relu1(self.conv1(img)))
        flat = features.reshape(features.shape[0], -1)
        hidden = self.relu2(self.linear1(flat))
        return self.linear2(hidden)

In [None]:
def test():
    """Smoke-test the Discriminator: feed one random 3x256x256 image and print the output shape."""
    dummy_hr = torch.randn((1, 3, 256, 256))
    critic = Discriminator()
    print(critic(dummy_hr).shape)

In [None]:
# Run the smoke test defined in the previous cell (prints the output shape).
test()

In [None]:
def weights_init(m):
    """DCGAN-style initialiser, intended for use with `module.apply(weights_init)`.

    * Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02); bias untouched.
    * BatchNorm2d (affine): scale drawn from N(1, 0.02), shift set to 0.

    Previously the BatchNorm scale was centred on 0, which nearly zeroes every
    normalised activation, and the shift was drawn from N(0, 1) because
    `normal_(bias, 0)` only sets the mean.

    Args:
        m: any nn.Module; modules of other types are left unchanged.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # Without affine parameters, weight and bias are None and there is nothing to initialise.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0.0)

In [None]:
class ImageDataset(Dataset):
    """Paired low-resolution / high-resolution image dataset read from two folders.

    Item `i` pairs the i-th LR file with the i-th HR file after sorting both
    listings by name. The pairing therefore assumes corresponding images sort
    to the same position in each folder (e.g. identical file names) —
    NOTE(review): confirm against the data layout.

    Args:
        root_dir_lr: directory containing the low-resolution images.
        root_dir_hr: directory containing the high-resolution images.
        lr_transform: optional callable applied to each LR PIL image.
        hr_transform: optional callable applied to each HR PIL image.
    """
    def __init__(self, root_dir_lr, root_dir_hr, lr_transform = None, hr_transform = None):
        self.root_dir_lr = root_dir_lr
        self.root_dir_hr = root_dir_hr
        # os.listdir returns entries in arbitrary order. Sort both listings so
        # the index-based LR/HR pairing is deterministic.
        self.lr_files = sorted(os.listdir(self.root_dir_lr))
        self.hr_files = sorted(os.listdir(self.root_dir_hr))
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform

    def __len__(self):
        """Number of items, taken from the LR folder."""
        return len(self.lr_files)

    def __getitem__(self, index):
        """Return the `(lr_image, hr_image)` pair at `index`, transformed if requested."""
        # The modulo lets either listing wrap around if the folders differ in size.
        lr_img = self.lr_files[index % len(self.lr_files)]
        hr_img = self.hr_files[index % len(self.hr_files)]
        lr_img_path = os.path.join(self.root_dir_lr, lr_img)
        hr_img_path = os.path.join(self.root_dir_hr, hr_img)
        img_lr = Image.open(lr_img_path)
        img_hr = Image.open(hr_img_path)

        # Apply each transform on its own so supplying only one does not try to call None.
        if self.lr_transform is not None:
            img_lr = self.lr_transform(img_lr)
        if self.hr_transform is not None:
            img_hr = self.hr_transform(img_hr)

        return img_lr, img_hr

In [None]:
def _make_transform(side):
    """Resize to side x side, convert to a tensor, and rescale each RGB channel with mean 0.5 / std 0.5."""
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# Low-resolution inputs are 64x64 and high-resolution targets are 256x256.
transforms_lr = _make_transform(64)
transforms_hr = _make_transform(256)

dataset = ImageDataset(
    root_dir_lr="Data/LR/",
    root_dir_hr="Data/HR/",
    lr_transform=transforms_lr,
    hr_transform=transforms_hr,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

In [None]:
# for z, h in dataloader:
#         print(z.shape)
#         save_image(z, "z.png")
#         save_image(h, "h.png")
#         break;

In [None]:
from torchvision.models import vgg19

class VGGFeatures(nn.Module):
    """Frozen feature extractor built from the first 18 layers of a pretrained VGG19.

    The truncated network is put in eval mode, kept on the CPU, and has
    gradients disabled for all of its parameters.
    """
    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained = True)
        extractor = backbone.features[:18]
        self.vgg19 = extractor.eval().to('cpu')

        # Freeze the extractor: it is only used to compute the content loss.
        for weight in self.vgg19.parameters():
            weight.requires_grad = False

    def forward(self, img):
        """Return the VGG feature maps for a batch of images."""
        features = self.vgg19(img)
        return features
    
# Single shared, frozen VGG19 feature extractor used for the content loss.
vgg = VGGFeatures()

In [None]:
# ---- Hyper-parameters ----
device = 'cpu'
lr = 1e-4            # Adam learning rate for both networks
beta_1 = 0.9
beta_2 = 0.999
n_epochs = 100
img_channels = 3     # not referenced by the visible cells
steps = 6            # log / save a sample every `steps` training batches
batch_size = 16      # NOTE(review): the DataLoader above passes its own literal 16

# ---- Models and optimizers ----
gen = Generator().to(device)
disc = Discriminator().to(device)

disc_opt = torch.optim.Adam(disc.parameters(), lr = lr, betas = (beta_1, beta_2))
gen_opt = torch.optim.Adam(gen.parameters(), lr = lr, betas = (beta_1, beta_2))
# weights_init modifies parameters in place, so the optimizers created above
# still reference the same tensors.
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

# ---- Losses ----
mse = nn.MSELoss()
# The discriminator returns raw logits, so use the logits variant of BCE.
# The class name is `BCEWithLogitsLoss`; `BCEwithLogitsLoss` raises AttributeError.
bce = nn.BCEWithLogitsLoss()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def matplotlib_imshow(img, one_channel=False):
    """Display an image tensor with matplotlib.

    The tensor is mapped back with `img / 2 + 0.5` before plotting, undoing
    a mean 0.5 / std 0.5 normalisation.

    Args:
        img: channel-first image tensor.
        one_channel: if True, average over the channel axis and plot the
            result with the "Greys" colormap; otherwise plot as channel-last.
    """
    if one_channel:
        img = img.mean(dim=0)
    pixels = (img / 2 + 0.5).detach().cpu().numpy()
    if not one_channel:
        # matplotlib expects channel-last (H, W, C) arrays.
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
    else:
        plt.imshow(pixels, cmap="Greys")
    plt.show()

In [None]:
# ---- SRGAN training loop ----
# Each batch: one discriminator update (real vs. generated), then one generator
# update using VGG feature MSE plus 1e-3 times the adversarial loss.
current_step = 0
gen_losses = 0
disc_losses = 0
os.makedirs("srgan", exist_ok=True)

for epochs in range(n_epochs):
    # Named lr_imgs / hr_imgs so the loop does not overwrite the global
    # learning rate `lr` with an image tensor.
    for lr_imgs, hr_imgs in dataloader:
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        # One generator forward pass serves both updates: the discriminator
        # step only sees a detached copy, so the generator graph stays intact.
        fake_hr = gen(lr_imgs)

        # ---- Discriminator step ----
        disc_opt.zero_grad()
        disc_preds_fake = disc(fake_hr.detach())
        disc_fake_loss = bce(disc_preds_fake, torch.zeros_like(disc_preds_fake))
        disc_preds_real = disc(hr_imgs)
        disc_real_loss = bce(disc_preds_real, torch.ones_like(disc_preds_real))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        # No retain_graph needed: this graph is never back-propagated again.
        disc_loss.backward()
        disc_opt.step()
        disc_losses += disc_loss.item()

        # ---- Generator step ----
        gen_opt.zero_grad()
        disc_preds_fake = disc(fake_hr)
        adv_loss = bce(disc_preds_fake, torch.ones_like(disc_preds_fake))
        # NOTE(review): images are normalised with mean/std 0.5, while pretrained
        # VGG weights are typically used with ImageNet normalisation. Confirm
        # whether the inputs should be re-normalised before the content loss.
        vgg_loss = mse(vgg(hr_imgs), vgg(fake_hr))
        perceptual_loss = vgg_loss + 1e-3*adv_loss
        perceptual_loss.backward()
        gen_opt.step()
        gen_losses += perceptual_loss.item()

        # Count the batch before checking, so each report averages exactly
        # `steps` batches. The first report used to cover steps + 1 batches.
        current_step += 1
        if current_step % steps == 0:
            print(f"Epochs: {epochs} Step: {current_step} Generator loss: {gen_losses / steps}, discriminator loss: {disc_losses / steps}")
            sample_fake = fake_hr.detach()
            img_grid_hr = torchvision.utils.make_grid(torch.cat([hr_imgs[:1], sample_fake[:1]]), nrow = 2)
            save_image(sample_fake[:3], "srgan/%d.png" % current_step, nrow=3, normalize=True)
            matplotlib_imshow(img_grid_hr, one_channel=False)
            gen_losses = 0
            disc_losses = 0