In [1]:
import random
import numpy as np
import pandas as pd


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid

import plotly
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split

random.seed(42)
import warnings
warnings.filterwarnings("ignore")

In [2]:
# hyperparams 
class Hyperparameters(object):
      def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

hp = Hyperparameters(
    dataset_path = "data/Train/images_png", 
    n_epochs = 100 , 
    batch_size = 8 , 
    learning_rate =  0.00008, 
    n_cpu = 4, 
    height = 512, # image height 
    width = 512, # image width 
    channels = 3, # num of channels in images 
    b1 = 0.5,   # adam: decay of first order momentum of gradient
    b2 = 0.999, # adam: decay of second order momentum of gradient
    decay_epoch = 100 ,  # epoch from which to start lr decay
    cuda = torch.cuda.is_available(), 
    limit = 10000,
    device = "cuda:0"

)

In [3]:
mean = np.array([0.485, 0.456, 0.406])
std  = np.array([0.229, 0.224, 0.225])


In [4]:

class CelebDataset(Dataset):

    def __init__(self,paths) -> None:
        super().__init__()
   
        self.items = paths

        # transforms for low resolution 
        self.low_res_transforms = transforms.Compose([
            transforms.Resize((hp.height//4, hp.width//4), Image.BICUBIC ), 
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std= std)
        ])
        # transforms for high resoultion 
        self.high_res_transforms = transforms.Compose([
            transforms.Resize((hp.height, hp.width), Image.BICUBIC ), 
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std= std)
        ])        

    def __len__(self):
        return len(self.items)
    

    def __getitem__(self, index):
        img = Image.open(self.items[index % len(self.items)]).convert("RGB")
        img_lr = self.low_res_transforms(img)
        img_hr = self.high_res_transforms(img)

        # return {"lr": img_lr, "hr": img_hr}
        return img_hr

In [5]:
import glob

train_paths, test_paths = train_test_split(sorted(glob.glob(hp.dataset_path + "/*.*"))[:hp.limit], test_size=0.02, random_state=42)
train_dataloader = DataLoader(CelebDataset(train_paths), batch_size=hp.batch_size, shuffle=True, num_workers=hp.n_cpu)
test_dataloader = DataLoader(CelebDataset(test_paths), batch_size=int(hp.batch_size*0.75), shuffle=True, num_workers=hp.n_cpu)

In [6]:
def get_mean_and_std(loader):
    mean = 0.
    std = 0.
    total_images = 0

    for images in loader:
        image_count_in_a_batch = images.size(0)
        images = images.view(image_count_in_a_batch , images.size(1) , -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)

        total_images += image_count_in_a_batch
    
    mean /= total_images
    std /= total_images

    return mean , std

In [7]:
mean , std = get_mean_and_std(train_dataloader)
print(f"mean = {mean}")
print(f"std = {std}")

mean = tensor([-0.8496, -0.6413, -0.4793])
std = tensor([0.6298, 0.5712, 0.5333])


In [8]:
# The generator and discriminator architecture: 

class ResidualBlock(nn.Module):
    def __init__(self, in_features) -> None:
        super(ResidualBlock, self).__init__()

        self.residual_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size= 3, stride=1,padding=1),
            nn.BatchNorm2d(in_features),
            nn.PReLU(num_parameters= in_features),
            nn.Conv2d(in_features, in_features, kernel_size= 3, stride=1,padding=1),
            nn.BatchNorm2d(in_features)
        )
    
    def forward(self,x):
        return x + self.residual_block(x)
    

    
class Upsample(nn.Module):
    def __init__(self, in_channels, scale_factor) -> None:
        super(Upsample, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * scale_factor ** 2, kernel_size=3, stride=1, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU()

    def forward(self,x):
        return self.act(self.pixel_shuffle(self.conv(x)))
    
# The Generator     
class Generator(nn.Module):

    def __init__(self, in_channels = 3, out_channels= 3, n_residual_block = 16) :
        super(Generator, self).__init__()

        self.conv_inp = nn.Sequential(
                nn.Conv2d(in_channels,64, kernel_size=9, stride=1, padding=4),
                nn.PReLU(num_parameters=64)
        )
        resblocks = [ResidualBlock(64) for _ in range(n_residual_block)]
        self.res_block = nn.Sequential(*resblocks)

        self.mid_conv = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64))

        self.upsamples = nn.Sequential(
            Upsample(64, scale_factor=2),
            Upsample(64, scale_factor=2)

        )

        self.final_conv = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

         

    def forward(self,x):
        out1 = self.conv_inp(x)
        out = self.res_block(out1)
        out2 = self.mid_conv(out)
        out = torch.add(out1, out2)
        out = self.upsamples(out)
        out = self.final_conv(out)

        return out





In [9]:
class DescConv(nn.Module):
    def __init__(self, in_channels, out_channels ,use_bn = True ,**kwargs) -> None:
        super().__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = nn.LeakyReLU(0.2, inplace= True)

    def forward(self,x):
        return self.act(self.bn(self.cnn(x)))
    

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                DescConv(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2, # 1,2,1,2,1,2...... 
                    padding=1,
                    use_bn=False if idx == 0 else True,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512*6*6, 1024), # opt from last DescConv block torch.Size([5, 512, 6, 6]) so 
                                      # 512 * 6 * 6
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)

In [10]:
class VggFeatureExtractor(nn.Module):
    def __init__(self):
        super(VggFeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)

In [11]:
def test():
    low_resolution = 24  # 96x96 -> 24x24
    with torch.cuda.amp.autocast():
        x = torch.randn((5, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)

        print(gen_out.shape)
        print(disc_out.shape)

test()

torch.Size([5, 3, 96, 96])
torch.Size([5, 1])


In [12]:

# # models instantiation 


generator = Generator()
discriminator = Discriminator()

feature_extractor = VggFeatureExtractor()
feature_extractor.eval()

# losses 
# gan_loss = torch.nn.MSELoss()
gan_loss = torch.nn.BCEWithLogitsLoss()
content_loss = torch.nn.L1Loss()


# if hp.cuda:
#     generator = generator.cuda()
#     discriminator = discriminator.cuda()
#     feature_extractor = feature_extractor.cuda()
#     gan_loss = gan_loss.cuda()
#     content_loss = content_loss.cuda()

if hp.cuda:
    generator = generator.to(hp.device)
    discriminator = discriminator.to(hp.device)
    feature_extractor = feature_extractor.to(hp.device)
    gan_loss = gan_loss.to(hp.device)
    content_loss = content_loss.to(hp.device)


# optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=hp.learning_rate, betas=(hp.b1, hp.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=hp.learning_rate, betas=(hp.b1, hp.b2))

Tensor = torch.cuda.FloatTensor if hp.cuda else torch.Tensor

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/rjn/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:14<00:00, 38.5MB/s] 


In [13]:
from tqdm import tqdm

# train losses
train_gen_loss, train_disc_loss, train_counter = [], [], []
# test losses 
test_gen_loss, test_disc_loss = [], []



for epoch in range(hp.n_epochs):

    ############################ Training ####################
    gen_loss = 0
    disc_loss = 0
    train_bar = tqdm(train_dataloader, desc=f"Training")
    for batch_idx, imgs in enumerate(train_bar):
        generator.train()
        discriminator.train()

        #get the inputs
        # low_res_ipt = Variable(imgs['lr'].type(Tensor).to(hp.device))
        # high_res_ipt = Variable(imgs['hr'].type(Tensor).to(hp.device))

        low_res_ipt = imgs['lr'].to(hp.device)
        high_res_ipt = imgs['hr'].to(hp.device)
        #################### Generator ######################

        optimizer_G.zero_grad()
        generated_hr = generator(low_res_ipt)
        disc_opt = discriminator(generated_hr)

        # Adverserial loss
        loss_GAN = gan_loss(disc_opt, torch.ones_like(disc_opt))

        # content loss
        generated_features = feature_extractor(generated_hr)
        real_feaures = feature_extractor(high_res_ipt)
        loss_CONTENT = content_loss(generated_features, real_feaures)

        # total loss 
        total_loss_generator = loss_CONTENT + 1e-3 * loss_GAN
        
        # backpropagate
        total_loss_generator.backward()
        optimizer_G.step()
        #################### discriminator ######################

        optimizer_D.zero_grad()

        real_disc_opt = discriminator(high_res_ipt)
        loss_D_real = gan_loss(real_disc_opt, torch.ones_like(real_disc_opt))

        fake_disc_opt = discriminator(generated_hr.detach())
        loss_D_fake = gan_loss(fake_disc_opt, torch.zeros_like(fake_disc_opt))

        # total loss 
        total_disc_loss = (loss_D_real+loss_D_fake) / 2 
        
        # backprop
        total_disc_loss.backward()
        optimizer_D.step()

        ################## Accumulate losses ###############

        gen_loss += total_loss_generator.item()
        disc_loss += total_disc_loss.item()

        train_bar.set_postfix(
            gen_loss = gen_loss/( batch_idx + 1), 
            disc_loss = disc_loss / (batch_idx + 1 )
        )
    train_gen_loss.append(gen_loss/len(train_dataloader))
    train_disc_loss.append(disc_loss/len(train_dataloader))


    ############################ Testing ####################
    gen_loss = 0
    disc_loss = 0
    test_bar = tqdm(test_dataloader,  desc=f"Testing")

    for batch_idx, imgs in enumerate(test_bar):
        generator.eval()
        discriminator.eval()

        #get the inputs
        low_res_ipt = imgs['lr'].to(hp.device)
        high_res_ipt = imgs['hr'].to(hp.device)
        
        ############# Generator Eval ###############

        generated_hr = generator(low_res_ipt)
        disc_opt = discriminator(generated_hr)

        # Adverserial loss
        loss_GAN = gan_loss(disc_opt, torch.ones_like(disc_opt))

        # content loss
        generated_features = feature_extractor(generated_hr)
        real_feaures = feature_extractor(high_res_ipt)
        loss_CONTENT = content_loss(generated_features, real_feaures)

        # total loss 
        total_loss_generator = loss_CONTENT + 1e-3 * loss_GAN

        #################### discriminator eval ######################

        real_disc_opt = discriminator(high_res_ipt)
        loss_D_real = gan_loss(real_disc_opt, torch.ones_like(real_disc_opt))

        fake_disc_opt = discriminator(generated_hr.detach())
        loss_D_fake = gan_loss(fake_disc_opt, torch.zeros_like(fake_disc_opt))

        # total loss 
        total_disc_loss = (loss_D_real+loss_D_fake) / 2


        ############### Accumulate losses ##########################
        gen_loss += total_loss_generator.item()
        disc_loss += total_disc_loss.item()

        
        if random.uniform(0,1)<0.1:

            imgs_lr = nn.functional.interpolate(low_res_ipt, scale_factor=4)
            imgs_hr = make_grid(high_res_ipt, nrow=1, normalize=True)
            gen_hr = make_grid(generated_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_hr, imgs_lr, gen_hr), -1)
            save_image(img_grid, f"images/{batch_idx}.png", normalize=False)


        test_bar.set_postfix(
            gen_loss = gen_loss/( batch_idx + 1), 
            disc_loss = disc_loss / (batch_idx + 1 )
        )
    test_gen_loss.append(gen_loss/len(test_dataloader))
    test_disc_loss.append(disc_loss/len(test_dataloader))

    torch.save(generator.state_dict(), "saved_models/generator.pth")
    torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")
        
       

Training:   0%|          | 0/309 [00:01<?, ?it/s]


TypeError: new(): invalid data type 'str'