In [None]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
from torch import nn
import torch.optim as optim
from torchvision import models
from torchvision.utils import save_image
import os
import tqdm
import PIL
import torch.nn.functional as F
import random
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Keep cuDNN switched on and let it benchmark convolution algorithms.
# NOTE(review): benchmark mode may pick different kernels between runs, so the
# seeds below presumably do not give bit-exact reproducibility on GPU.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Seed every random number generator used by the notebook with the same value.
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

In [None]:
# Side length (in pixels) of the square high-resolution training images.
high_res_dim = 256

# Resize the shorter side to high_res_dim, crop the central square, and convert
# the PIL image to a float tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize(high_res_dim),
        transforms.CenterCrop(high_res_dim),
        transforms.ToTensor(),
    ]
)

# ImageFolder derives a class label from each sub-directory; the labels are not
# used by the training loop in this notebook.
dataset = ImageFolder(root='/home/aditya/Datasets/flikr8k/', transform=preprocess)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=40, shuffle=True)

In [None]:
# Preview the first rows*cols images of one batch in a grid.
images, labels = next(iter(dataloader))
rows = 2
cols = 3
fig, axes = plt.subplots(rows, cols)

# Hide the axes of every cell, including cells that stay empty for short batches.
for ax in axes.flat:
    ax.axis('off')

# Pair each grid cell with an image starting at index 0. The previous version
# reused the 1-based subplot index as the tensor index, which skipped images[0]
# and raised IndexError for batches with fewer than rows*cols + 1 images.
for ax, image in zip(axes.flat, images):
    # Tensors are channel-first (C, H, W); imshow wants channel-last (H, W, C).
    ax.imshow(image.permute(1, 2, 0))
plt.show()

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.layer1(x)
        return torch.add(out, x)

In [None]:
class Generator(nn.Module):
    """Super-resolution generator: 3-channel image in, 3-channel image at 2x size out.

    Pipeline:
        1. 9x9 conv + PReLU head (3 -> 64 channels).
        2. Five passes through ``resblock1``. It is the same module instance each
           time, so the five passes share one set of weights and one set of
           BatchNorm statistics.
        3. 3x3 conv + BatchNorm, then a long skip connection adding the head output.
        4. 3x3 conv to 256 channels, PixelShuffle(2) back to 64 channels at twice
           the height/width, PReLU.
        5. 9x9 conv to 3 output channels. There is no final activation, so the
           output values are not bounded.
    """

    def __init__(self):
        super().__init__()
        # Head.
        self.conv = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4, bias=False)
        self.prelu = nn.PReLU()

        # Residual trunk (one block, applied repeatedly in forward()).
        self.resblock1 = ResidualBlock(64, 64)

        # Post-trunk conv applied before the long skip connection.
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )

        # 2x sub-pixel upsampling: 256 channels are rearranged into 64 channels
        # at twice the spatial resolution.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Projection back to an RGB image.
        self.conv4 = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4, bias=False)

    def forward(self, x):
        """Upscale ``x`` of shape (N, 3, H, W) to (N, 3, 2H, 2W)."""
        head = self.prelu(self.conv(x))

        trunk = head
        for _ in range(5):
            trunk = self.resblock1(trunk)

        trunk = self.conv2(trunk) + head

        # A second pass through conv3 (which would give 4x upscaling) is
        # intentionally left out, matching the original disabled line.
        upsampled = self.conv3(trunk)
        return self.conv4(upsampled)

In [None]:
class DiscriminatorConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DiscriminatorConvBlock, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 
                                   nn.BatchNorm2d(out_channels),
                                   nn.LeakyReLU(),
                                 )
    def forward(self, x):
        out = self.conv1(x)
        return out

In [None]:
class Discriminator(nn.Module):
    """Binary real/fake classifier for super-resolved images.

    Eight convolution stages (four of them with stride 2, i.e. a total
    down-sampling factor of 16) are followed by two dense layers and a sigmoid,
    so the output is a probability of shape (N, 1).

    Args:
        low_res_dim: side length of the generator *input*. ``dense1`` is sized
            for a ``low_res_dim // 8`` square feature map, which after the 16x
            down-sampling corresponds to input images of side ``2 * low_res_dim``
            (the size produced by the 2x generator).
    """

    def __init__(self, low_res_dim):
        super(Discriminator, self).__init__()
        # Spatial side of the final feature map (see class docstring).
        img_d = low_res_dim // 8
        # First stage has no batch norm.
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.LeakyReLU(),
                                 )
        # Alternate channel widening (stride 1) with spatial halving (stride 2).
        self.conv2 = DiscriminatorConvBlock(64, 64, 2)
        self.conv3 = DiscriminatorConvBlock(64, 128, 1)
        self.conv4 = DiscriminatorConvBlock(128, 128, 2)
        self.conv5 = DiscriminatorConvBlock(128, 256, 1)
        self.conv6 = DiscriminatorConvBlock(256, 256, 2)
        self.conv7 = DiscriminatorConvBlock(256, 512, 1)
        self.conv8 = DiscriminatorConvBlock(512, 512, 2)

        self.dense1 = nn.Linear(512 * img_d * img_d, 1024)
        self.leakyRelu = nn.LeakyReLU()
        self.dense2 = nn.Linear(1024, 1)
        self.drop = nn.Dropout(0.3)

    def forward(self, x):
        """Return the probability, shape (N, 1), that each image in ``x`` is real."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = self.conv8(out)
        # Flatten everything except the batch dimension.
        out = torch.flatten(out, start_dim=1)
        # Dropout regularises the hidden features. It used to be applied to the
        # single output logit, which zeroed it ~30% of the time during training
        # and forced the prediction to exactly 0.5 for those samples.
        out = self.drop(self.leakyRelu(self.dense1(out)))
        out = torch.sigmoid(self.dense2(out))
        return out
        

In [None]:
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
num_epochs = 1000
low_res = 128  # side length of the generator input; the generator upsamples by 2
weights_dir = "/home/aditya/Developer/repro/weights/"

gen_model = Generator().to(device)
disc_model = Discriminator(low_res).to(device)
vgg = models.vgg19(pretrained=True).to(device)
# VGG is only a fixed feature extractor for the perceptual loss: put it in eval
# mode and freeze its weights so no gradients are accumulated for them.
vgg.eval()
for vgg_param in vgg.parameters():
    vgg_param.requires_grad_(False)
# NOTE(review): the images fed to VGG are only scaled to [0, 1]; the pretrained
# weights presumably expect ImageNet mean/std normalisation - confirm.

# gen_model = nn.DataParallel(gen_model, device_ids = [0, 1])
# disc_model = nn.DataParallel(disc_model, device_ids = [0, 1])

gen_optimizer = optim.Adam(gen_model.parameters(),lr=0.0001)
# NOTE(review): this base learning rate (1e-7) is below the scheduler's
# eta_min (1e-6), so the "annealing" raises the rate instead of lowering it.
disc_optimizer = optim.Adam(disc_model.parameters(),lr=0.0000001)
gen_scheduler = CosineAnnealingWarmRestarts(gen_optimizer,
                                        T_0 = 8,  # scheduler steps (epochs here) until the first restart
                                        T_mult = 1,  # factor by which T_i grows after each restart
                                        eta_min = 1e-6)  # minimum learning rate
disc_scheduler = CosineAnnealingWarmRestarts(disc_optimizer,
                                        T_0 = 8,  # scheduler steps (epochs here) until the first restart
                                        T_mult = 1,  # factor by which T_i grows after each restart
                                        eta_min = 1e-6)  # minimum learning rate
mse_loss = nn.MSELoss()
vgg_loss = nn.MSELoss()
disc_loss = nn.BCELoss()
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
for epoch in range(num_epochs):
    # The preview below switches the generator to eval mode; switch back here.
    gen_model.train()
    disc_model.train()

    # Wrap the loader in a fresh progress bar each epoch WITHOUT rebinding
    # `dataloader`; rebinding it nested one tqdm wrapper per epoch.
    progress = tqdm.tqdm(dataloader)
    for input_images, _ in progress:
        input_images = input_images.to(device)
        lowres_images = transforms.Resize(low_res)(input_images)
        gen_highres_images = gen_model(lowres_images)

        # ---- Discriminator step -------------------------------------------
        disc_model.zero_grad()
        # detach(): the discriminator update must not back-propagate into the
        # generator (this also removes the need for retain_graph=True).
        generated_label = disc_model(gen_highres_images.detach())
        actual_label = disc_model(input_images)

        d1_loss = disc_loss(generated_label, torch.zeros_like(generated_label, dtype=torch.float))
        d2_loss = disc_loss(actual_label, torch.ones_like(actual_label, dtype=torch.float))
        discriminator_loss = d1_loss + d2_loss
        discriminator_loss.backward()
        disc_optimizer.step()

        # ---- Generator step -----------------------------------------------
        gen_model.zero_grad()
        mse = mse_loss(gen_highres_images, input_images)
        # Only the target features are computed without autograd. The features
        # of the generated image must stay in the graph, otherwise the
        # perceptual loss contributes no gradient to the generator.
        with torch.no_grad():
            pred1 = vgg.features[:14](input_images)
        pred2 = vgg.features[:14](gen_highres_images)
        v_loss = vgg_loss(pred2, pred1)

        # NOTE(review): the generator objective has no adversarial term, so the
        # discriminator trained above never influences the generator - confirm
        # whether that is intended.
        generator_loss = mse + v_loss
        generator_loss.backward()
        gen_optimizer.step()
        gen_optimizer.zero_grad()
        torch.cuda.empty_cache()
        progress.set_description("Epoch %d Generator Loss %f Discriminator Loss %f Memory %f GB / %f GB" % (epoch, generator_loss.item(), discriminator_loss.item(), torch.cuda.memory_allocated(0)/1024/1024/1024, torch.cuda.max_memory_reserved(0)/1024/1024/1024))

    # Step the schedulers once per epoch, after the optimizers have stepped.
    gen_scheduler.step()
    disc_scheduler.step()

    # ---- Visual preview: originals (left) vs generated images (right) ------
    if epoch % 1 ==0:
        images, labels = next(iter(dataloader))
        hires = transforms.ToPILImage()
        fig = plt.figure()
        rows = 3
        cols = 2
        # Inference only: use BatchNorm running statistics and skip autograd.
        gen_model.eval()
        with torch.no_grad():
            for ii in range(1, rows*cols + 1, 2):
                lowres_image = transforms.Resize(low_res)(images[ii])
                gen_highres_images = gen_model(torch.unsqueeze(lowres_image, 0).to(device))

                fig.add_subplot(rows, cols, ii)
                plt.imshow(images[ii].permute(1, 2, 0))
                plt.axis('off')

                fig.add_subplot(rows, cols, ii+1)
                # The generator output is unbounded; clamp to [0, 1] so the
                # uint8 conversion inside ToPILImage cannot wrap around.
                plt.imshow(hires(gen_highres_images[0].clamp(0, 1).cpu()))
                plt.axis('off')
        plt.show()
        # Close the figure so open figures do not pile up over the epochs.
        plt.close(fig)

    # ---- Checkpoint --------------------------------------------------------
    if epoch % 5 == 0:
        os.makedirs(weights_dir, exist_ok=True)
        # A fresh Generator instance is stored next to the weights so that
        # load_checkpoint() can rebuild the model from the file alone.
        # NOTE(review): input_size/output_size are hard-coded and do not match
        # low_res (128) used for training here - confirm intent.
        checkpoint = {'model': Generator(),
              'input_size': 256,
              'output_size': 512,
              'state_dict': gen_model.state_dict()}
        torch.save(checkpoint,os.path.join(weights_dir,"SR"+str(epoch+1)+".pth"))
        torch.cuda.empty_cache()

In [None]:
def load_checkpoint(filepath, map_location="cpu"):
    """Rebuild a frozen, eval-mode model from a checkpoint file.

    The checkpoint is expected to be a dict with a ``'model'`` entry (a pickled
    module instance) and a ``'state_dict'`` entry holding its weights.

    Args:
        filepath: path of the checkpoint file.
        map_location: forwarded to ``torch.load``. The default ``"cpu"`` lets
            checkpoints whose tensors were saved from a GPU be loaded on a
            machine without CUDA; move the returned model with ``.to(device)``.

    Returns:
        The stored model with the saved weights loaded, ``requires_grad``
        disabled on every parameter, and eval mode enabled.
    """
    # SECURITY: the checkpoint holds a pickled module object, so loading it
    # executes arbitrary pickle code. Only load files from a trusted source.
    try:
        # Recent PyTorch defaults to weights_only=True, which refuses to
        # unpickle the stored module instance.
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only argument.
        checkpoint = torch.load(filepath, map_location=map_location)

    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    # Freeze the weights: the returned model is meant for inference only.
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()
    return model

In [None]:
# Load a saved generator and compare down-scaled inputs (left column) with the
# generator's output (right column).
gen_model = load_checkpoint("/home/aditya/Developer/repro/weights/SR6.pth").to(device)
low_res = 64
high_res = 256
topil = transforms.ToPILImage()
images, labels = next(iter(dataloader))
fig = plt.figure()
rows = 3
cols = 2
# Inference only: no autograd bookkeeping is needed.
with torch.no_grad():
    for ii in range(1, rows*cols + 1, 2):
        lowres_image = transforms.Resize(low_res)(images[ii])
        gen_highres_images = gen_model(torch.unsqueeze(lowres_image, 0).to(device))
        # The generator has no output activation, so values can leave [0, 1];
        # clamp before the uint8 conversion in ToPILImage to avoid wrap-around
        # artefacts. Convert once and reuse the result for saving and plotting.
        generated_image = topil(gen_highres_images[0].clamp(0, 1).cpu())

        fig.add_subplot(rows, cols, ii)
        plt.imshow(lowres_image.permute(1, 2, 0))
        # Plain resize of the low-res input, saved as a baseline. The file is
        # overwritten on every iteration, so only the last sample remains.
        topil(transforms.Resize(high_res)(lowres_image)).save("resized.jpg")
        plt.axis('off')

        fig.add_subplot(rows, cols, ii+1)

        # Also overwritten on every iteration.
        generated_image.save("generated.jpg")
        plt.imshow(generated_image)
        plt.axis('off')