## Import Libraries

In [1]:
import numpy as np
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid
from PIL import Image
from tqdm import tqdm_notebook as tqdm
random.seed(42)
import warnings
warnings.filterwarnings("ignore")

## Import Images

In [2]:
class ImageDl(Dataset):
    """Dataset yielding a low-res / high-res pair for every image in ``files``.

    Each item is a dict ``{"lr": ..., "hr": ...}`` holding the same source
    image pushed through a bicubic resize to ``lr_size`` and ``hr_size``
    respectively, converted to a tensor and normalised with ``stats``.

    Args:
        files: Indexable collection whose items carry the image at position 0
            (only ``files[i][0]`` is ever read).
        stats: ``(mean, std)`` pair forwarded to ``transforms.Normalize``.
            Defaults to 0.5 for every channel, so the class no longer depends
            on a module-level ``stats`` being defined before instantiation.
        lr_size: Side length (pixels) of the square low-resolution image.
        hr_size: Side length (pixels) of the square high-resolution image.
    """

    # Per-channel mean and std used when the caller does not supply ``stats``.
    DEFAULT_STATS = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def __init__(self, files, stats=None, lr_size=64, hr_size=256):
        self.files = files
        if stats is None:
            stats = self.DEFAULT_STATS
        self.lr_transform = self._build_transform(lr_size, stats)
        self.hr_transform = self._build_transform(hr_size, stats)

    @staticmethod
    def _build_transform(size, stats):
        """Return the resize -> crop -> tensor -> normalise pipeline for ``size``."""
        return transforms.Compose([
            transforms.Resize((size, size), interpolation=Image.BICUBIC),
            # The resize already yields size x size, so this crop leaves the
            # image unchanged; kept to mirror the original pipeline.
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize(*stats),
        ])

    def __getitem__(self, index):
        """Return ``{"lr": ..., "hr": ...}`` for ``index`` (wrapped modulo the length)."""
        img = self.files[index % len(self.files)][0]
        return {"lr": self.lr_transform(img), "hr": self.hr_transform(img)}

    def __len__(self):
        """Number of source images."""
        return len(self.files)

In [3]:
# Folder that receives the comparison grids written during training.
os.makedirs("generated-images", exist_ok=True)

# Per-channel (mean, std) pair used to normalise the images.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
batch_size = 20

dataset = ImageFolder('../input/animefacedataset/')
train_dataloader = DataLoader(
    ImageDl(dataset),
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    # Host-to-GPU copies are much faster when they originate from pinned
    # (page-locked) memory.
    pin_memory=True,
)
print('Total batches:',len(train_dataloader))

Total batches: 3179


## Generator

In [4]:
class ResidualBlock(nn.Module):
    """Conv -> BN -> PReLU -> Conv -> BN branch added back onto its input.

    Both convolutions are 3x3, stride 1, padding 1 and keep ``in_channels``
    channels, so the output has exactly the shape of the input.

    Args:
        in_channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self,in_channels):
        super().__init__()
        self.res_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            # The second positional argument of BatchNorm2d is ``eps``; the
            # previous ``BatchNorm2d(in_channels, 0.8)`` therefore added 0.8 to
            # the variance instead of configuring momentum. Use the defaults.
            nn.BatchNorm2d(in_channels),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels),
        )

    def forward(self,x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        return x + self.res_block(x)

In [5]:
class Generator(nn.Module):
    """Super-resolution generator producing a 4x larger image.

    Pipeline: 9x9 conv + PReLU -> ``res_block_size`` residual blocks ->
    3x3 conv + BN (added to the first conv's output as a long skip) ->
    two conv/BN/PixelShuffle(2)/PReLU stages -> 9x9 conv + Tanh.

    Args:
        in_channels: Channels of the low-resolution input image.
        out_channels: Channels of the generated image.
        res_block_size: Number of ``ResidualBlock`` modules in the trunk.
    """

    def __init__(self,in_channels=3,out_channels=3,res_block_size=16):
        super().__init__()
        # first conv block
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )
        # residual trunk
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(res_block_size)]
        )
        # second conv layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            # BatchNorm2d's second positional argument is ``eps``; the former
            # ``BatchNorm2d(64, 0.8)`` set eps=0.8 rather than a momentum.
            nn.BatchNorm2d(64),
        )
        # upsampling layers: each stage expands to 256 channels, then
        # PixelShuffle(2) folds them back to 64 channels at twice the size.
        upsample_layers = []
        for _ in range(2):
            upsample_layers += [
                nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsample_layers = nn.Sequential(*upsample_layers)
        # third conv layer; Tanh bounds the output to [-1, 1]
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )

    def forward(self,x):
        """Map a low-resolution batch to its 4x up-scaled reconstruction."""
        shallow = self.conv1(x)
        deep = self.conv2(self.res_blocks(shallow))
        # long skip connection around the whole residual trunk
        fused = torch.add(shallow, deep)
        return self.conv3(self.upsample_layers(fused))
    

## Discriminator

In [6]:
class Discriminator(nn.Module):
    """Convolutional discriminator emitting a one-channel score map.

    Four stages (64, 128, 256, 512 filters) each apply a stride-1 3x3 conv
    followed by a stride-2 3x3 conv, with LeakyReLU(0.2) activations and
    BatchNorm after every conv except the very first. A final 3x3 conv
    reduces the result to a single channel.

    Args:
        input_shape: ``(channels, height, width)`` of the images to score.

    Attributes:
        input_shape: The ``input_shape`` passed in.
        output_shape: ``(1, h, w)`` shape of the score map produced for one
            image of ``input_shape``; use it to build real/fake targets.
    """

    # Output channels of the four down-sampling stages.
    _STAGE_FILTERS = (64, 128, 256, 512)

    def __init__(self,input_shape):
        super().__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape

        # Every stage contains one stride-2 conv (kernel 3, padding 1), which
        # maps a side of length n to ceil(n / 2).
        patch_h, patch_w = in_height, in_width
        for _ in self._STAGE_FILTERS:
            patch_h = (patch_h + 1) // 2
            patch_w = (patch_w + 1) // 2
        self.output_shape = (1, patch_h, patch_w)

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(self._STAGE_FILTERS):
            layers.extend(self._disc_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # ``in_filters`` now holds the width of the last stage.
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))
        self.disc = nn.Sequential(*layers)

    @staticmethod
    def _disc_block(in_filters, out_filters, first_block=False):
        """Return the layers of one stage; ``first_block`` skips the first BN."""
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1)]
        if not first_block:
            layers.append(nn.BatchNorm2d(out_filters))
        layers += [
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_filters),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        return layers

    def forward(self,img):
        """Return the score map for a batch of images."""
        return self.disc(img)

## Training

In [7]:
# Seed PyTorch too (only Python's `random` was seeded) so weight
# initialisation and data shuffling are repeatable across runs.
torch.manual_seed(42)

generator = Generator()
discriminator = Discriminator(input_shape=(3,256,256))

# Content-loss backbone: the first 18 modules of the pretrained VGG19
# feature stack, used purely as a fixed feature extractor.
vgg19_model = vgg19(pretrained=True)
feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])
feature_extractor.eval()
# Freeze the VGG weights: they are never optimised, so leaving them trainable
# only makes every backward pass compute and store gradients nobody uses.
# Gradients still flow through the network to its input (the generated image).
for vgg_param in feature_extractor.parameters():
    vgg_param.requires_grad = False

adversarial_crit = torch.nn.MSELoss()   # adversarial term (MSE against 1/0 targets)
content_crit = torch.nn.L1Loss()        # distance between VGG feature maps

cuda = torch.cuda.is_available()
if cuda:
    Tensor = torch.cuda.FloatTensor
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    adversarial_crit = adversarial_crit.cuda()
    content_crit = content_crit.cuda()
else:
    Tensor = torch.Tensor

lr = 0.00008
EPOCHS = 4
gen_opt = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
disc_opt = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [8]:
# Put batches on whatever device the generator's weights live on.
device = next(generator.parameters()).device

# Comparison grids are written for the last four batches of every epoch.
# With 3179 batches this equals the previously hard-coded `idx > 3174`.
save_after_idx = len(train_dataloader) - 5

for epoch in range(EPOCHS):
    # Nothing switches the models to eval mode, so once per epoch is enough.
    generator.train()
    discriminator.train()

    tqdm_bar = tqdm(train_dataloader)
    tqdm_bar.set_description(f"Epoch {epoch + 1}/{EPOCHS}")
    for idx, imgs in enumerate(tqdm_bar):
        imgs_lr = imgs["lr"].to(device=device, dtype=torch.float32)
        imgs_hr = imgs["hr"].to(device=device, dtype=torch.float32)

        # ---- generator update ----
        gen_opt.zero_grad()
        gen_hr = generator(imgs_lr)
        pred_gen = discriminator(gen_hr)
        # Targets take their shape from the discriminator output instead of a
        # fixed (N, 1, 16, 16), so other image sizes keep working.
        valid = torch.ones_like(pred_gen)
        fake = torch.zeros_like(pred_gen)
        loss_adv = adversarial_crit(pred_gen, valid)

        gen_features = feature_extractor(gen_hr)
        # The real-image features are a constant target: skip graph building.
        with torch.no_grad():
            real_features = feature_extractor(imgs_hr)
        loss_content = content_crit(gen_features, real_features)

        gen_loss = loss_content + 1e-3 * loss_adv
        gen_loss.backward()
        gen_opt.step()

        # ---- discriminator update ----
        # zero_grad also clears the gradients the generator step left on the
        # discriminator's parameters.
        disc_opt.zero_grad()
        real_loss = adversarial_crit(discriminator(imgs_hr), valid)
        fake_loss = adversarial_crit(discriminator(gen_hr.detach()), fake)
        disc_loss = (real_loss + fake_loss) / 2
        disc_loss.backward()
        disc_opt.step()

        tqdm_bar.set_postfix(gen_loss=gen_loss.item(), disc_loss=disc_loss.item())

        if idx > save_after_idx:
            # Left: low-res input blown up 4x; right: generator output.
            # NOTE(review): the file name depends only on `idx`, so each epoch
            # overwrites the previous epoch's grids - confirm this is intended.
            with torch.no_grad():
                lr_upscaled = nn.functional.interpolate(imgs_lr, scale_factor=4)
                gen_grid = make_grid(gen_hr.detach(), nrow=1, normalize=True)
                lr_grid = make_grid(lr_upscaled, nrow=1, normalize=True)
                img_grid = torch.cat((lr_grid, gen_grid), -1)
                save_image(img_grid, f"generated-images/{idx}.png", normalize=False)



  0%|          | 0/3179 [00:00<?, ?it/s]

  0%|          | 0/3179 [00:00<?, ?it/s]

  0%|          | 0/3179 [00:00<?, ?it/s]

  0%|          | 0/3179 [00:00<?, ?it/s]