In [None]:
import os 
import cv2
import random
import torch
import math
import time

import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

from tqdm import *
from glob import glob
from PIL import Image
from math import log2

from torchvision import transforms
from torch import nn, einsum
from torch.autograd import grad as torch_grad
from torch.nn import Linear, SiLU, Upsample,LeakyReLU, \
                     Conv2d, ReLU, InstanceNorm2d, Tanh, AvgPool2d


In [None]:
class InpaintData(torch.utils.data.Dataset):
    """Dataset of every ``*.jpg`` found (recursively) under ``path``.

    Each item is the image converted to RGB, resized to ``img_size`` with
    bicubic interpolation, converted to a tensor and normalised with mean 0.5 /
    std 0.5 per channel, i.e. mapped from [0, 1] to [-1, 1].

    Args:
        path: root directory searched with ``glob('{path}/**/*.jpg')``.
        img_size: size argument for ``transforms.functional.resize``
            (a (height, width) sequence, or an int for the shorter edge).
    """
    def __init__(self, path, img_size=(64, 64)):
        super().__init__()
        self.img_size = img_size
        self.transforms = transforms.Compose([
            # can put some augmentation here
            transforms.ToTensor(),
            # (x - 0.5) / 0.5: maps ToTensor's [0, 1] output to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        # Sorted so the index -> file mapping does not depend on the
        # filesystem's directory listing order.
        self.file_list = sorted(glob(f'{path}/**/*.jpg', recursive=True))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load, resize and normalise the image at position ``idx`` of ``file_list``."""
        file_name = self.file_list[idx]
        # The context manager guarantees the underlying file handle is closed.
        with Image.open(file_name) as src:
            img = src.convert('RGB')
        img = transforms.functional.resize(
            img, self.img_size, transforms.InterpolationMode.BICUBIC)
        return self.transforms(img)

In [None]:
# CelebA-HQ training split (Kaggle input mount).
path = "/kaggle/input/celebahq/celeba_hq/train"
dataset = InpaintData(path)

# Shuffled batches of 16 images; drop_last discards a final partial batch,
# so every batch has exactly 16 images.
loader_options = {
    "batch_size": 16,
    "shuffle": True,
    "pin_memory": True,
    "drop_last": True,
}
train_data = torch.utils.data.DataLoader(dataset, **loader_options)

In [None]:
class PixelNorm(nn.Module):
    """Rescale inputs to unit root-mean-square along dimension 1.

    Computes ``x / sqrt(mean(x ** 2, dim=1, keepdim=True) + epsilon)`` with
    ``epsilon = 1e-8`` guarding against division by zero.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        denominator = (mean_square + self.epsilon).sqrt()
        return x / denominator
    
class AdaIN(nn.Module):
    """Adaptive instance normalisation.

    Instance-normalises ``x`` and then applies a per-sample, per-channel affine
    transform whose scale and shift are linear projections of the style ``w``.
    ``w`` is expected to be 2-D (batch, w_dim); the projections are reshaped to
    (batch, channels, 1, 1) to broadcast over the spatial dims.
    """

    def __init__(self, channels, w_dim):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable for checkpoints.
        self.instance_norm = InstanceNorm2d(channels)
        self.scale = Linear(w_dim, channels)
        self.bias = Linear(w_dim, channels)

    def forward(self, x, w):
        normalized = self.instance_norm(x)
        gamma = self.scale(w).unsqueeze(-1).unsqueeze(-1)
        beta = self.bias(w).unsqueeze(-1).unsqueeze(-1)
        return normalized * gamma + beta

class AddNoise(nn.Module):
    """Add single-channel Gaussian noise, scaled by a learned per-channel gain.

    One noise map of shape (batch, 1, H, W) is drawn per call and broadcast
    across channels; the gain starts at zero, so the module is initially an
    identity.
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        gaussian = torch.randn((batch, 1, height, width), device=x.device)
        scaled_noise = self.weight * gaussian
        return x + scaled_noise

class MappingNetwork(nn.Module):
    """Map latents z to styles w: PixelNorm, then ``n_layers`` Linear+ReLU layers.

    Accepts ``(batch, z_dim)`` input or any ``(..., z_dim)`` input. For inputs
    with more than two dims the leading dims are flattened first, because
    PixelNorm reduces over dim 1; without flattening a ``(batch, n, z_dim)``
    input would be normalised across ``n`` instead of across the features.

    Args:
        z_dim: input feature size.
        w_dim: width of every hidden/output layer.
        n_layers: number of Linear layers.
    """
    def __init__(self, z_dim, w_dim, n_layers=8):
        super().__init__()
        layers = [PixelNorm(), Linear(z_dim, w_dim), ReLU()]

        for _ in range(n_layers - 1):
            layers.extend([Linear(w_dim, w_dim), ReLU()])
        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        if x.dim() <= 2:
            return self.mapping(x)
        # Map every feature vector independently, then restore the leading dims.
        lead_shape = x.shape[:-1]
        out = self.mapping(x.reshape(-1, x.shape[-1]))
        return out.reshape(*lead_shape, out.shape[-1])

In [None]:
class G_block(nn.Module):
    """Generator block: optional 2x bilinear upsample, then two style-modulated convs.

    Each of the two stages is conv3x3 -> AddNoise -> LeakyReLU(0.2) -> AdaIN,
    and both stages are modulated by the same style ``w``.
    """

    def __init__(self, latent_dim, in_channel, out_channel, upsample=True):
        super().__init__()
        self.upsample = (
            Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            if upsample else None
        )

        # Creation order (and attribute names) kept stable: it fixes parameter
        # initialisation order and the state_dict keys used by checkpoints.
        self.conv1 = Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.conv2 = Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        self.AddNoise1 = AddNoise(out_channel)
        self.AddNoise2 = AddNoise(out_channel)
        self.AdaIN1 = AdaIN(out_channel, latent_dim)
        self.AdaIN2 = AdaIN(out_channel, latent_dim)
        self.leaky = LeakyReLU(0.2, inplace=True)

    def forward(self, x, w):
        if self.upsample is not None:
            x = self.upsample(x)
        stages = (
            (self.conv1, self.AddNoise1, self.AdaIN1),
            (self.conv2, self.AddNoise2, self.AdaIN2),
        )
        for conv, add_noise, adain in stages:
            x = adain(self.leaky(add_noise(conv(x))), w)
        return x
    
class D_block(nn.Module):
    """Residual discriminator block.

    The main path is two 3x3 convs, each followed by LeakyReLU(0.2). The skip
    path is a 1x1 conv. The paths are summed and scaled by 1/sqrt(2), then
    optionally average-pooled with a 2x2 window (halving H and W).
    """

    def __init__(self, in_channel, out_channel, downsample = True):
        super().__init__()
        # 1x1 projection so the skip path has out_channel channels.
        self.conv_res = Conv2d(in_channel, out_channel, kernel_size=1, padding=0)
        self.downsample = AvgPool2d(kernel_size=2, stride=2) if downsample else None

        layers = [
            Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            LeakyReLU(0.2, inplace=True),
            Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            LeakyReLU(0.2, inplace=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        skip = self.conv_res(x)
        merged = (self.net(x) + skip) * (1 / math.sqrt(2))
        if self.downsample is None:
            return merged
        return self.downsample(merged)

In [None]:
class Generator(nn.Module):
    """StyleGAN-like generator.

    Pipeline: learned 4x4 constant -> 3x3 conv + LeakyReLU(0.2) -> ``num_layers``
    G_blocks (every block except the first upsamples by 2) -> 3x3 conv to RGB
    -> Tanh.

    Args:
        img_size: output side length; must be a power of two >= 4.
        z_dim: latent size; the mapping network keeps it unchanged, so styles
            also have ``z_dim`` features.
        fmap_max, fmap_min: bounds on per-block channel counts.

    Raises:
        ValueError: if ``img_size`` is not a power of two >= 4 (otherwise the
            output would silently have the wrong resolution).
    """
    def __init__(self, img_size, z_dim , fmap_max=512, fmap_min=64):
        super().__init__()
        if img_size < 4 or not math.log2(img_size).is_integer():
            raise ValueError(f"img_size must be a power of two >= 4, got {img_size}")
        self.z_dim = z_dim
        self.num_layers = int(math.log2(img_size)-1)
        # Generate filter sizes for each layer
        filters = [min(fmap_max, fmap_min * (2 ** i)) for i in range(self.num_layers + 1)]
        filters = list(reversed(filters))

        self.mapping = MappingNetwork(z_dim, z_dim)
        self.init_input = nn.Parameter(torch.ones(1, filters[0], 4, 4))
        self.init_conv = Conv2d(filters[0], filters[0], 3, padding=1)
        self.leaky = LeakyReLU(0.2, inplace=True)
        self.final_block = nn.Sequential(
            Conv2d(filters[-1], 3, 3, padding=1),
            Tanh()
        )
        self.blocks = nn.ModuleList([])
        for i, (in_channel, out_channel) in enumerate(zip(filters[:-1], filters[1:])):
            # Apply upscaling except for the first block
            up_sample = i != 0
            g_block = G_block(z_dim, in_channel, out_channel, up_sample)
            self.blocks.append(g_block)

    def forward(self, latents):
        """Generate a batch of images.

        Args:
            latents: ``(batch, n_styles, z_dim)`` holding one latent per block
                (``n_styles`` >= number of blocks; extra styles are ignored), or
                ``(batch, z_dim)`` to reuse a single latent for every block.

        Returns:
            ``(batch, 3, img_size, img_size)`` tensor from the final Tanh.
        """
        n_blocks = len(self.blocks)
        if latents.dim() == 2:
            latents = latents.unsqueeze(1).expand(-1, n_blocks, -1)
        if latents.dim() != 3 or latents.shape[-1] != self.z_dim:
            raise ValueError(
                f"latents must be (batch, n_styles, {self.z_dim}), got {tuple(latents.shape)}")
        batch_size, n_styles, _ = latents.shape
        if n_styles < n_blocks:
            raise ValueError(f"need at least {n_blocks} styles per sample, got {n_styles}")

        # Normalise each latent over its feature axis and map every
        # (sample, layer) latent independently.
        flat = F.normalize(latents.reshape(-1, self.z_dim), dim=-1)
        styles = self.mapping(flat).reshape(batch_size, n_styles, self.z_dim)
        # (batch, layer, z) -> (layer, batch, z). A real transpose is required:
        # `.view(-1, batch, z)` would reinterpret memory and hand samples each
        # other's styles.
        styles = styles.transpose(0, 1)

        # Broadcast the learned constant to the batch so the first block's
        # AddNoise draws independent noise for every sample.
        x = self.leaky(self.init_conv(self.init_input))
        x = x.expand(batch_size, -1, -1, -1)
        for block, style in zip(self.blocks, styles):
            x = block(x, style)
        return self.final_block(x)

class Discriminator(nn.Module):
    """Residual discriminator producing one logit per image.

    ``num_layers`` D_blocks (all but the last halve the resolution) bring an
    ``img_size`` input down to 4x4. Then 3x3 conv -> 4x4 valid conv -> 1x1 conv
    yields a (batch, 1, 1, 1) map, which is returned as shape (batch,).

    Args:
        img_size: input side length; must be a power of two >= 4.
        fmap_min, fmap_max: bounds on per-block channel counts.

    Raises:
        ValueError: if ``img_size`` is not a power of two >= 4 (the head would
            otherwise emit more than one value per image).
    """
    def __init__(self, img_size, fmap_min=32, fmap_max=512):
        super().__init__()
        if img_size < 4 or not log2(img_size).is_integer():
            raise ValueError(f"img_size must be a power of two >= 4, got {img_size}")
        self.num_layers = int(log2(img_size) - 1)
        # Generate filter sizes for each layer
        filters = [3] + [min(fmap_max, fmap_min * (2 ** i)) for i in range(self.num_layers)]

        self.blocks = nn.ModuleList([])
        for i, (in_channel, out_channel) in enumerate(zip(filters[:-1], filters[1:])):
            # Apply downsampling except for the last block
            down_sample = i != (len(filters) - 2)
            d_block = D_block(in_channel, out_channel, down_sample)
            self.blocks.append(d_block)

        self.final_block = nn.Sequential(
            # The output shape will become [1, 1, 1]
            Conv2d(filters[-1], filters[-1], kernel_size=3, padding=1),
            LeakyReLU(0.2),
            Conv2d(filters[-1], filters[-1], kernel_size=4, padding=0, stride=1),
            LeakyReLU(0.2),
            Conv2d(filters[-1], 1, kernel_size=1, padding=0, stride=1)
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        # (batch, 1, 1, 1) -> (batch,). Unlike `.squeeze()`, `flatten()` keeps
        # the batch axis when batch == 1.
        return self.final_block(x).flatten()

In [None]:
# Helper function
def requires_grad(model, flag=True):
    """Freeze (``flag=False``) or unfreeze (``flag=True``) every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = flag

def d_loss_fn(real_pred, fake_pred):
    """Logistic discriminator loss on raw logits.

    Returns ``mean(softplus(-real_pred)) + mean(softplus(fake_pred))``.
    """
    loss_on_real = F.softplus(-real_pred).mean()
    loss_on_fake = F.softplus(fake_pred).mean()
    return loss_on_real + loss_on_fake
    
def g_loss_fn(fake_pred):
    """Non-saturating generator loss on raw logits: ``mean(softplus(-fake_pred))``."""
    return F.softplus(-fake_pred).mean()

def gradient_penalty(images, output, weight = 10):
    """Two-sided gradient penalty: ``weight * mean((||d output / d images||_2 - 1) ** 2)``.

    The input gradient is flattened per sample (first dim of ``images``).
    ``create_graph=True`` keeps the penalty differentiable w.r.t. the
    parameters that produced ``output``.

    Args:
        images: tensor that ``output`` was computed from; must require grad.
        output: prediction(s) to differentiate; seeded with ones.
        weight: penalty coefficient.
    """
    seed = torch.ones(output.size(), device=images.device)
    grads = torch_grad(
        outputs=output,
        inputs=images,
        grad_outputs=seed,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    per_sample_norm = grads.reshape(images.shape[0], -1).norm(2, dim=1)
    return weight * (per_sample_norm - 1).pow(2).mean()

def get_g_input(batch, img_size, n_layers, latent_dim, device):
    """Sample one standard-normal latent per image and repeat it for every layer.

    Args:
        batch: number of latents to draw.
        img_size: unused; kept for call-site compatibility.
        n_layers: number of per-layer copies.
        latent_dim: size of each latent vector.
        device: any torch device spec ("cpu", "cuda", "cuda:0", an ordinal,
            or a ``torch.device``).

    Returns:
        Contiguous tensor of shape (batch, n_layers, latent_dim) whose slices
        along dim 1 are identical.
    """
    # Allocate directly on `device`; `.cuda(device)` rejects CPU devices,
    # which broke the Trainer's CPU fallback.
    styles = torch.randn(batch, 1, latent_dim, device=device)
    return torch.tile(styles, [1, n_layers, 1])

In [None]:
class Trainer():
    """Adversarial training loop for the Generator / Discriminator pair.

    Each batch gets one discriminator step (logistic loss + gradient penalty on
    real/fake interpolations); every ``n_critic``-th batch also gets a generator
    step. Checkpoints go to ``ckpt_dir`` and preview grids to ``./results/``.

    Args:
        img_size: side length of real and generated images.
        latent_dim: size of each latent/style vector.
        dataloader: iterable of real-image batches (tensors).
        ckpt_dir: checkpoint directory (created if missing).
        epochs: last epoch to run (inclusive, 1-based).
        save_n_epoch: write a checkpoint and preview every this many epochs.
        lr: Adam learning rate for both optimizers.
        load_path: optional checkpoint to resume from. ``torch.load`` unpickles,
            so only load trusted files.
        max_batches_per_epoch: stop an epoch after this many batches; ``None``
            consumes the whole dataloader. The default 1502 equals the
            ``if i > 1500: break`` cutoff this class has always used.
        n_critic: discriminator steps per generator step.
    """
    def __init__(self, img_size, latent_dim, dataloader, ckpt_dir,
                 epochs, save_n_epoch, lr = 1e-4, load_path=None,
                 max_batches_per_epoch=1502, n_critic=5):
        os.makedirs(ckpt_dir, exist_ok=True)
        os.makedirs('./results/', exist_ok=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.img_size = img_size
        self.latent_dim = latent_dim
        self.num_layers = int(log2(img_size)-1)

        self.g = Generator(img_size, latent_dim).to(self.device)
        self.d = Discriminator(img_size).to(self.device)

        self.dataloader = dataloader
        self.ckpt_dir = ckpt_dir

        # Lowest end-of-epoch G loss seen so far (initial threshold 1.0) and
        # the epoch that produced it.
        self.best_g_loss = 1.0
        self.best_epoch = None
        self.start_epoch = 1
        self.epochs = epochs
        self.save_n_epoch = save_n_epoch
        self.max_batches_per_epoch = max_batches_per_epoch
        self.n_critic = n_critic

        self.G_opt = torch.optim.Adam(self.g.parameters(), lr=lr, betas=(0.5, 0.9))
        self.D_opt = torch.optim.Adam(self.d.parameters(), lr=lr, betas=(0.5, 0.9))

        if load_path is not None:
            self.load_state_dict(load_path)
            print("sucessful load state dict !!!!!!")
            print(f"start from epoch {self.start_epoch}")

    def state_dict(self, epoch):
        """Checkpoint dict: completed ``epoch``, model and optimizer states, best G loss."""
        return {
            "epoch": epoch,
            "g_model": self.g.state_dict(),
            "d_model": self.d.state_dict(),
            "g_opt": self.G_opt.state_dict(),
            "d_opt": self.D_opt.state_dict(),
            "best_g_loss": self.best_g_loss,
        }

    def load_state_dict(self, path):
        """Restore a checkpoint written by ``state_dict`` and resume after its epoch.

        Optimizer states and the best loss are optional, so checkpoints that
        hold only model weights still load.
        """
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        state_dict = torch.load(path, map_location=self.device)
        self.g.load_state_dict(state_dict['g_model'])
        self.d.load_state_dict(state_dict['d_model'])
        if 'g_opt' in state_dict:
            self.G_opt.load_state_dict(state_dict['g_opt'])
        if 'd_opt' in state_dict:
            self.D_opt.load_state_dict(state_dict['d_opt'])
        self.best_g_loss = state_dict.get('best_g_loss', self.best_g_loss)
        # The stored epoch has already finished; continue with the next one.
        self.start_epoch = state_dict['epoch'] + 1

    def d_train_step(self, real_img):
        """One discriminator update; returns the (graph-attached) D loss."""
        batch_size = real_img.size(0)
        requires_grad(self.g, False)
        requires_grad(self.d, True)

        styles = get_g_input(batch_size, self.img_size,
                             self.num_layers, self.latent_dim,
                             self.device)

        fake_img = self.g(styles)
        fake_pred = self.d(fake_img)
        real_pred = self.d(real_img)

        # Gradient penalty on random real/fake interpolations. `mix_img` is a
        # fresh leaf requiring grad, so the caller's `real_img` is not mutated.
        alpha = torch.rand([batch_size, 1, 1, 1], device=self.device)
        mix_img = (alpha * fake_img + (1 - alpha) * real_img).detach().requires_grad_()
        mix_pred = self.d(mix_img)
        gp = gradient_penalty(mix_img, mix_pred)
        d_loss = d_loss_fn(real_pred, fake_pred) + gp

        self.d.zero_grad()
        d_loss.backward()
        self.D_opt.step()
        return d_loss

    def g_train_step(self, real_img):
        """One generator update (``real_img`` only sets the batch size); returns the G loss."""
        batch_size = real_img.size(0)
        requires_grad(self.g, True)
        requires_grad(self.d, False)

        styles = get_g_input(batch_size, self.img_size,
                             self.num_layers, self.latent_dim,
                             self.device)
        fake_img = self.g(styles)
        fake_pred = self.d(fake_img)
        g_loss = g_loss_fn(fake_pred)

        self.g.zero_grad()
        g_loss.backward()
        self.G_opt.step()
        return g_loss

    def _run_epoch(self):
        """Consume the dataloader once (up to ``max_batches_per_epoch`` batches).

        Returns the last (d_loss, g_loss) as detached tensors, or
        (None, None) if the dataloader yielded nothing.
        """
        d_loss = g_loss = None
        for i, img in enumerate(self.dataloader):
            img = img.to(self.device)
            self.G_opt.zero_grad()
            self.D_opt.zero_grad()

            # Detach so stored losses do not keep the step's autograd graph alive.
            d_loss = self.d_train_step(img).detach()
            if i % self.n_critic == 0:
                g_loss = self.g_train_step(img).detach()

            if self.max_batches_per_epoch is not None and i + 1 >= self.max_batches_per_epoch:
                break
        return d_loss, g_loss

    def train(self):
        """Run epochs ``start_epoch..epochs``, saving best/periodic checkpoints and previews."""
        for epoch in tqdm(range(self.start_epoch, self.epochs+1), desc="Training progress"):
            start = time.time()
            print(f'Start of epoch {epoch}')

            d_loss, g_loss = self._run_epoch()
            if g_loss is None:
                print(f"epoch: {epoch}: dataloader produced no batches, skipping ~~~~~~")
                continue

            g_loss_value = g_loss.item()
            if g_loss_value < self.best_g_loss:
                self.best_g_loss = g_loss_value
                self.best_epoch = epoch
                torch.save(self.state_dict(epoch), f"{self.ckpt_dir}/best_epoch.pt")
                print(f"!!!!!!!!!!!!! saving best epoch {epoch} state dict !!!!!!!```````")

            if epoch % self.save_n_epoch == 0:
                self.generate_and_create_image(epoch)
                torch.save(self.state_dict(epoch), f"{self.ckpt_dir}/weight_epoch{epoch}.pt")
                print(f"sucessful saving epoch {epoch} state dict !!!!!!!")

            time_minutes = (time.time() - start) / 60
            print(f"epoch: {epoch}, D loss: {d_loss.item():.4f} ~~~~~~")
            print(f"epoch: {epoch}, G loss: {g_loss_value:.4f} ~~~~~~")
            print(f'Time taken for epoch {epoch} is {time_minutes:.3f} min\n')
        print("finish training: ~~~~~~~~~~~~~~~~~~~~~~~~~~")

    def generate_and_create_image(self, epoch):
        """Sample 16 images, display them as a 4x4 grid and save it to ./results/result_img{epoch}.jpg."""
        num_rows = 4
        num_columns = 4
        self.g.eval()
        try:
            # Sampling only: no autograd graph needed.
            with torch.no_grad():
                styles = get_g_input(num_rows * num_columns, self.img_size,
                                     self.num_layers, self.latent_dim,
                                     self.device)
                fake_img = self.g(styles)
        finally:
            self.g.train()

        # Map [-1, 1] to [0, 1] and move channels last for imshow.
        grid = ((fake_img + 1) / 2).clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()
        fig, axs = plt.subplots(num_rows, num_columns, figsize=(8, 8))
        # axs.flat is row-major, so image k lands at row k // num_columns.
        for index, ax in enumerate(axs.flat):
            ax.imshow(grid[index])
            ax.axis('off')

        fig.savefig(f'./results/result_img{epoch}.jpg')
        plt.show()
        # Close explicitly so figures do not accumulate across epochs.
        plt.close(fig)

In [None]:
# Checkpoints are written here. To resume, pass load_path=<path to a .pt file>.
clpt_dir = './model_weight/'
trainer = Trainer(
    img_size=64,
    latent_dim=256,
    dataloader=train_data,
    ckpt_dir=clpt_dir,
    epochs=100,
    save_n_epoch=5,
)
trainer.train()