In [None]:
import getpass
import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.ndimage
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

In [None]:
def _query_free_memory_mb(gpu_index):
    """Ask nvidia-smi for the free memory of one GPU.

    Args:
        gpu_index: index passed to ``nvidia-smi -i``.

    Returns:
        The free memory as an int (nvidia-smi reports it in MiB with
        ``nounits``), or ``None`` when nvidia-smi cannot be run or its first
        output line is not a number (e.g. 'No devices were found').
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free', '-i', str(gpu_index),
             '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        )
    except OSError:
        # nvidia-smi is not installed / not executable on this machine.
        return None
    lines = result.stdout.splitlines()
    if not lines:
        return None
    try:
        return int(lines[0].strip())
    except ValueError:
        # Any non-numeric answer (no such device, driver error, ...) means
        # there is nothing usable at this index.
        return None


def define_gpu_to_use(minimum_memory_mb = 3800):
    """Pick the first GPU with enough free memory and pin this process to it.

    If ``CUDA_VISIBLE_DEVICES`` is already set, the existing assignment is
    reported and left untouched. Otherwise GPU indices 0..15 are probed in
    order with nvidia-smi and the first one whose free memory exceeds
    ``2 * minimum_memory_mb`` is exported through ``CUDA_VISIBLE_DEVICES``.
    Probing stops at the first index that yields no usable reading.

    Args:
        minimum_memory_mb: requested amount of GPU memory in MB. Note that the
            selection threshold is twice this value, while the failure message
            quotes the value itself.

    Returns:
        None. The result is communicated through the environment variable and
        printed messages.
    """
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        print('GPU already assigned before: ' + str(os.environ['CUDA_VISIBLE_DEVICES']))
        return

    torch.cuda.empty_cache()

    gpu_to_use = None
    for i in range(16):
        free_memory = _query_free_memory_mb(i)
        if free_memory is None:
            break
        if free_memory > 2*minimum_memory_mb:
            gpu_to_use = i
            break

    if gpu_to_use is None:
        print('Could not find any GPU available with the required free memory of ' +str(minimum_memory_mb) + 'MB. Please use a different system for this assignment.')
        return

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_to_use)
    print('Chosen GPU: ' + str(gpu_to_use))
    # A (256, 1024, k) tensor of 4-byte floats is k MiB, so this touches about
    # (minimum_memory_mb - 500) MiB on the chosen GPU (assuming the default
    # float32 dtype). NOTE(review): presumably done so PyTorch's caching
    # allocator keeps that memory reserved for this process - confirm intent.
    x = torch.rand((256,1024,minimum_memory_mb-500)).cuda()
    # Rebinding x drops the large tensor; only a tiny one stays referenced.
    x = torch.rand((1,1)).cuda()
    del x

In [None]:
# Select a GPU for this kernel using the function's default minimum of 3800 MB.
define_gpu_to_use()

In [None]:
# Each CSV has no header row; the file names are read into column 0 of a
# DataFrame (train, test and validation splits, in that order).
train_list, test_list, val_list = (
    pd.read_csv(csv_name, header=None)
    for csv_name in ('train_filenames.csv', 'test_filenames.csv', 'val_filenames.csv')
)

In [None]:
# Root folder of the CelebA working copy. Set the CELEBA_ROOT environment
# variable to run the notebook on another machine; the default is the original
# location. os.path.join(..., '') guarantees a trailing separator because the
# dataset class builds file paths by plain string concatenation.
CELEBA_ROOT = os.path.join(
    os.environ.get('CELEBA_ROOT', '/home/sci/amanpreet/Documents/HW/CV/Project/Codes/celebA'),
    '',
)

# Down-sampled (network input) images, one folder per split.
train_sr = CELEBA_ROOT + 'Down_sampled/Train/'
test_sr = CELEBA_ROOT + 'Down_sampled/Test/'
val_sr = CELEBA_ROOT + 'Down_sampled/Val/'

# High-resolution target images, shared by all splits.
hr = CELEBA_ROOT + 'row_col_trim/'

In [None]:
class SR_Data(Dataset):
    """Paired image dataset: a low-resolution input and its high-resolution target.

    Both images of a pair share the same file name; they differ only in the
    folder they are read from.

    Args:
        sr_folder: folder prefix of the input images. It is concatenated
            directly with the file name, so it must end with a path separator.
        hr_folder: folder prefix of the target images (same convention).
        file_list: indexable collection of file names; ``file_list[idx]`` is
            converted with ``str()`` before use.
    """

    def __init__(self, sr_folder, hr_folder, file_list):
        self.sr_folder = sr_folder
        self.hr_folder = hr_folder
        self.file_list = file_list
        # Built once instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        """Number of image pairs."""
        return len(self.file_list)

    def _load_rgb(self, path):
        """Read one image file, force 3-channel RGB and convert it to a tensor."""
        # The context manager closes the file handle right away; without it
        # every item leaves an open file behind until garbage collection.
        with Image.open(path) as img:
            return self.to_tensor(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return ``(input_tensor, target_tensor)`` for the pair at ``idx``."""
        name = str(self.file_list[idx])
        return self._load_rgb(self.sr_folder + name), self._load_rgb(self.hr_folder + name)

In [None]:
# One dataset per split. Every split reads its inputs from its own folder but
# shares the same high-resolution folder; column 0 of each frame holds the
# file names.
train_data, test_data, val_data = (
    SR_Data(input_dir, hr, names[0])
    for input_dir, names in (
        (train_sr, train_list),
        (test_sr, test_list),
        (val_sr, val_list),
    )
)

# Split sizes, one per line: train, test, validation.
print(len(train_data), len(test_data), len(val_data), sep='\n')

In [None]:
class Generator(nn.Module):
    """Super-resolution generator: residual trunk followed by two pixel-shuffle stages.

    Layout, in execution order:
      * 5x5 convolution (3 -> 64 channels) + PReLU,
      * five residual bodies (conv-BN-PReLU-conv-BN), each added to its input,
      * 3x3 convolution + BatchNorm, added to the output of the first PReLU,
      * two (conv 64 -> 256, PixelShuffle(2)) stages, with a PReLU after the
        first stage only,
      * 5x5 convolution back to 3 channels.

    Every convolution is padded to keep the spatial size, and each
    PixelShuffle(2) doubles height and width, so the output is 4x the input
    resolution. Attribute names are kept stable because they are the keys of
    the saved state dicts.
    """

    @staticmethod
    def _make_residual_body():
        """Build one conv-BN-PReLU-conv-BN stack operating on 64 channels."""
        return nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )

    def __init__(self):
        super().__init__()
        # Feature extraction head.
        self.convolution_layer_0 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
        self.prelu0 = nn.PReLU()

        # Residual trunk (the skip additions happen in forward()).
        self.rb_0 = self._make_residual_body()
        self.rb_1 = self._make_residual_body()
        self.rb_2 = self._make_residual_body()
        self.rb_3 = self._make_residual_body()
        self.rb_4 = self._make_residual_body()

        # Trunk tail, merged with the long skip connection in forward().
        self.convolution_layer_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        # Up-sampling: 256 channels are rearranged into 64 channels at 2x size.
        self.convolution_layer_2 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1)
        self.shuffle_1 = nn.PixelShuffle(2)
        self.prelu1 = nn.PReLU()
        self.convolution_layer_3 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1)
        self.shuffle_2 = nn.PixelShuffle(2)

        # Projection back to an RGB image.
        self.convolution_layer_4 = nn.Conv2d(64, 3, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        """Map a (N, 3, H, W) batch to a (N, 3, 4H, 4W) batch."""
        x = self.prelu0(self.convolution_layer_0(x))
        long_skip = x

        for body in (self.rb_0, self.rb_1, self.rb_2, self.rb_3, self.rb_4):
            body_out = body(x)
            # In-place add on the freshly created output; the input tensor
            # (and therefore long_skip) is never modified.
            body_out += x
            x = body_out

        x = self.bn1(self.convolution_layer_1(x))
        x += long_skip

        x = self.shuffle_1(self.convolution_layer_2(x))
        x = self.prelu1(x)
        x = self.shuffle_2(self.convolution_layer_3(x))

        return self.convolution_layer_4(x)

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into one, keeping the batch axis."""

    def forward(self, x):
        """Return a (N, -1) view of ``x``; no data is copied."""
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

In [None]:
class Discriminator(nn.Module):
    """Convolutional real/fake classifier producing one probability per image.

    Eight unpadded 3x3 convolutions (strides alternating 1 and 2, channels
    growing 64 -> 512), each followed by LeakyReLU and, except for the first,
    preceded by BatchNorm on its output; then a flatten, a 1024-unit hidden
    layer and a sigmoid output.

    The first linear layer is hard-wired to 40960 input features, i.e.
    512 channels times 80 spatial positions, so only one input resolution is
    accepted. NOTE(review): 176 x 216 inputs reduce to 8 x 10 through this
    stack - presumably that is the high-resolution image size; confirm.
    """

    def __init__(self):
        super().__init__()

        # First stage has no batch norm.
        stack = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
        ]

        # (in_channels, out_channels, stride) of the remaining conv stages.
        conv_stages = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )
        for channels_in, channels_out, step in conv_stages:
            stack += [
                nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=step),
                nn.BatchNorm2d(channels_out),
                nn.LeakyReLU(),
            ]

        # Classification head.
        stack += [
            Flatten(),
            nn.Linear(40960, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ]

        # Same module order as a hand-written Sequential, so the state-dict
        # keys ('layers.<index>...') are unchanged.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return a (N, 1) tensor of probabilities in (0, 1)."""
        return self.layers(x)

In [None]:
# Fresh (randomly initialised) networks; generator is constructed first.
Gen, Dis = Generator(), Discriminator()

In [None]:
# Folder holding the checkpoints. Set the SR_MODEL_DIR environment variable to
# use another location; the default is the original one. os.path.join(..., '')
# guarantees a trailing separator because file names are appended to this
# string directly.
model_path = os.path.join(
    os.environ.get('SR_MODEL_DIR', "/home/sci/amanpreet/Documents/HW/CV/Project/Codes/Models/"),
    '',
)

In [None]:
# Resume from the epoch-14 checkpoints. The checkpoints are written from
# models living on the GPU, so without map_location torch.load would recreate
# every tensor on a CUDA device first. The networks are still on the CPU at
# this point, so load straight to the CPU instead.
Gen.load_state_dict(torch.load(model_path+'Shuffler_Gen_14.pth', map_location='cpu'))
Dis.load_state_dict(torch.load(model_path+'Shuffler_Dis_14.pth', map_location='cpu'))

In [None]:
# Move both networks (parameters and buffers) to the GPU.
Gen, Dis = Gen.cuda(), Dis.cuda()

In [None]:
# All loaders use batches of 16 and 8 worker processes; only the training
# split is shuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=16,
    num_workers=8,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=16,
    num_workers=8,
)

In [None]:
from torchvision import models

# VGG-19 with pretrained weights (fetched by torchvision if not cached).
vgg = models.vgg19(pretrained=True)
class Identity(torch.nn.Module):
    """Pass-through module: returns its input object unchanged."""

    def forward(self, x):
        return x

In [None]:
# Replace the classifier head with a pass-through module, so vgg(x) returns
# whatever is fed into the classifier instead of class scores.
vgg.classifier = Identity()
# Move the network to the GPU; the training cells call it on CUDA tensors.
vgg = vgg.cuda()

In [None]:
# One Adam optimizer per network, both with learning rate 1e-4.
optimizerG = torch.optim.Adam(params=Gen.parameters(), lr=1e-4)
optimizerD = torch.optim.Adam(params=Dis.parameters(), lr=1e-4)

# loss_clf: binary cross-entropy on the discriminator's sigmoid output.
# loss_content: mean squared error.
loss_clf, loss_content = nn.BCELoss(), nn.MSELoss()

In [None]:
for epoch in range(15,35):
    count = 0
    for sr_img, hr_img in train_loader:
               
        # Train the Discriminator
        
        Dis.zero_grad()
        
        sr_img = sr_img.cuda()
        hr_img = hr_img.cuda()
        
        out_D = Dis(hr_img)
        
        bsize = sr_img.size(0)
        label = torch.full((bsize,1), 1, device='cuda')
        
        errD_real = loss_clf(out_D, label)
        
        D_x = out_D.mean().item()
        
        # Get the Fake images 
        
        fake = Gen(sr_img)
        
        out_D_G = Dis(fake)
        
        label = torch.full((bsize,1), 0, device='cuda')

        errD_fake = loss_clf(out_D_G, label)
        errD = errD_fake + errD_real
        
        errD.backward(retain_graph=True)
        D_G_x = out_D_G.mean().item()
        
        # Update Discriminator
        
        optimizerD.step()
        
        # Train the Generator
        
        Gen.zero_grad()
        
        # pass both the Hr and fake to VGG and compute mse on that.
        
        with torch.no_grad():
            fake_vgg = vgg(fake)

            hr_vgg = vgg(hr_img)
        
        errG = torch.dist(fake_vgg, hr_vgg)
        
        err_g_mse = loss_content(fake,hr_img)
        
        errG = 0.006*errG + err_g_mse
        
        out_D_G_1 = Dis(fake)
        label = torch.full((bsize,1), 1, device='cuda')
        e_g = loss_clf(out_D_G_1, label)
        D_G_1 = out_D_G_1.mean().item()
        
        errG = errG + 0.001*e_g # Check paper perceptual loss
        
        errG.backward(retain_graph=True)
        
        # Update the Generator
        
        optimizerG.step()
        
        count = count + 1
        
        if count%100 == 0:
            with torch.no_grad():
                
                print(count)
                
                print("#########################Stats#####################")
                
                print("D_x : ", D_x, " D_G_x : ", D_G_x, " D_G_x_1 : ", D_G_1, " Error : ", errG.item())
                
                f, (ax1, ax2, ax3) = plt.subplots(1, 3)

                c = 0
                for sr, hr in val_loader:
                    if c == 0:
                        sr_rel = sr.cuda()
                        hr_rel = hr.cuda()
                        fk = Gen(sr_rel)
                    c = c + 1
                
                img = np.transpose(hr_rel.cpu().numpy()[0,:,:,:], (1, 2, 0))
                img = (img)*255
                img = img.astype(np.uint8)
                ax1.imshow(img)
                ax1.set_title('Ground truth')
                
                imgf = np.transpose(fk.cpu().numpy()[0,:,:,:], (1, 2, 0))
                imgf = (imgf)*255
                imgf = imgf.astype(np.uint8)
                ax2.imshow(imgf)
                ax2.set_title('Super Resolution')
                
                imgi = np.transpose(sr_rel.cpu().numpy()[0,:,:,:], (1, 2, 0))
                imgi = (imgi)*255
                imgi = imgi.astype(np.uint8)
                ax3.imshow(imgi)
                ax3.set_title('Input')
                
                plt.show()
    
    torch.save(Gen.state_dict(),model_path+'Shuffler_Gen_'+str(epoch)+'.pth')
    torch.save(Dis.state_dict(),model_path+'Shuffler_Dis_'+str(epoch)+'.pth')

In [None]:
for epoch in range(35,65):
    count = 0
    for sr_img, hr_img in train_loader:
               
        # Train the Discriminator
        
        Dis.zero_grad()
        
        sr_img = sr_img.cuda()
        hr_img = hr_img.cuda()
        
        out_D = Dis(hr_img)
        
        bsize = sr_img.size(0)
        label = torch.full((bsize,1), 1, device='cuda')
        
        errD_real = loss_clf(out_D, label)
        
        D_x = out_D.mean().item()
        
        # Get the Fake images 
        
        fake = Gen(sr_img)
        
        out_D_G = Dis(fake)
        
        label = torch.full((bsize,1), 0, device='cuda')

        errD_fake = loss_clf(out_D_G, label)
        errD = errD_fake + errD_real
        
        errD.backward(retain_graph=True)
        D_G_x = out_D_G.mean().item()
        
        # Update Discriminator
        
        optimizerD.step()
        
        # Train the Generator
        
        Gen.zero_grad()
        
        # pass both the Hr and fake to VGG and compute mse on that.
        
        with torch.no_grad():
            fake_vgg = vgg(fake)

            hr_vgg = vgg(hr_img)
        
        errG = torch.dist(fake_vgg, hr_vgg)
        
        err_g_mse = loss_content(fake,hr_img)
        
        errG = 0.006*errG + err_g_mse
        
        out_D_G_1 = Dis(fake)
        label = torch.full((bsize,1), 1, device='cuda')
        e_g = loss_clf(out_D_G_1, label)
        D_G_1 = out_D_G_1.mean().item()
        
        errG = errG + 0.001*e_g # Check paper perceptual loss
        
        errG.backward(retain_graph=True)
        
        # Update the Generator
        
        optimizerG.step()
        
        count = count + 1
        
        if count%100 == 0:
            with torch.no_grad():
                
                print(count)
                
                print("#########################Stats#####################")
                
                print("D_x : ", D_x, " D_G_x : ", D_G_x, " D_G_x_1 : ", D_G_1, " Error : ", errG.item())
                
                f, (ax1, ax2, ax3) = plt.subplots(1, 3)

                c = 0
                for sr, hr in val_loader:
                    if c == 0:
                        sr_rel = sr.cuda()
                        hr_rel = hr.cuda()
                        fk = Gen(sr_rel)
                    c = c + 1
                
                img = np.transpose(hr_rel.cpu().numpy()[0,:,:,:], (1, 2, 0))
                img = (img)*255
                img = img.astype(np.uint8)
                ax1.imshow(img)
                ax1.set_title('Ground truth')
                
                imgf = np.transpose(fk.cpu().numpy()[0,:,:,:], (1, 2, 0))
                imgf = (imgf)*255
                imgf = imgf.astype(np.uint8)
                ax2.imshow(imgf)
                ax2.set_title('Super Resolution')
                
                imgi = np.transpose(sr_rel.cpu().numpy()[0,:,:,:], (1, 2, 0))
                imgi = (imgi)*255
                imgi = imgi.astype(np.uint8)
                ax3.imshow(imgi)
                ax3.set_title('Input')
                
                plt.show()
    
    torch.save(Gen.state_dict(),model_path+'Shuffler_Gen_'+str(epoch)+'.pth')
    torch.save(Dis.state_dict(),model_path+'Shuffler_Dis_'+str(epoch)+'.pth')