In [None]:
import argparse
import os
import pickle as pkl
import random
import shutil
import subprocess
import sys
import time
from copy import deepcopy

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import  Image
from tensorboardX import SummaryWriter
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image

!mkdir checkpoint
!mkdir generated_imgs
!mkdir output_imgs
!mkdir real_imgs
!mkdir fake_imgs
!mkdir real_test_imgs
!pip install tensorboardX
!pip install pytorch-fid

**Store the dataset file paths in lists**

In [None]:
# Root folders of the dataset splits that are scanned or used to build paths below.
train_a_root = '/kaggle/input/custom/celebA_dataset/Training/TrainA/'
train_b_root = '/kaggle/input/custom/celebA_dataset/Training/TrainB/'
test_b_root = '/kaggle/input/custom/celebA_dataset/Test/TestB/'

TrainA, TrainB, TestA, TestB = [], [], [], []

# Training split: every file found under TrainA is paired with the same file name in TrainB.
for folder, _, names in os.walk(train_a_root):
    TrainA.extend(os.path.join(folder, name) for name in names)
    TrainB.extend(os.path.join(train_b_root, name) for name in names)

# Test split.
# NOTE(review): this walks TestB (not TestA), so TestA ends up holding TestB paths as well;
# this looks like a copy-paste slip but only TestB is read later in the notebook — confirm.
for folder, _, names in os.walk(test_b_root):
    TestA.extend(os.path.join(folder, name) for name in names)
    TestB.extend(os.path.join(test_b_root, name) for name in names)

**Preprocess the dataset with a random rectangular mask**

In [None]:
class InpaintDataset(Dataset):
    """Dataset yielding ``(real_image, masked_image)`` pairs for inpainting.

    The masked image is the real image with one randomly placed white square
    painted over it; a new square is drawn every time an item is fetched.

    Args:
        real_image_paths: sequence of image file paths.
        mask_image_paths: accepted for interface compatibility only; masks are
            generated on the fly and this argument is never read.
        transform: optional callable applied to both images. Any non-callable
            value (``None`` or the default ``False``) disables transformation.
    """

    def __init__(self, real_image_paths, mask_image_paths, transform=False):
        self.real_image_paths = real_image_paths
        self.transform = transform

    def __len__(self):
        return len(self.real_image_paths)

    def __getitem__(self, idx):
        real_image_filepath = self.real_image_paths[idx]
        real_img = Image.open(real_image_filepath).convert("RGB")
        mask_img = self.read_val_mask(real_image_filepath)

        # The default is `False`, so an `is not None` check would try to call `False(...)`.
        if callable(self.transform):
            real_img = self.transform(real_img)
            mask_img = self.transform(mask_img)

        return real_img, mask_img

    @staticmethod
    def _random_mask_box(height, width, min_side=50, max_side=120):
        """Pick a random square ``(x0, y0, x1, y1)`` that fits inside the image.

        The side length is drawn from ``[min_side, max_side]`` and clamped to the
        image size so that small images do not make ``random.randint`` raise.
        ``x`` is bounded by ``width`` and ``y`` by ``height``.
        """
        side = min(random.randint(min_side, max_side), height, width)
        x0 = random.randint(0, width - side)
        y0 = random.randint(0, height - side)
        return x0, y0, x0 + side, y0 + side

    def read_val_mask(self, path):
        """Load the image at ``path`` and return it as a PIL image with a white square painted on.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # Reverse the channel order (BGR -> RGB) and copy into a contiguous uint8 array
        # so the rectangle can be drawn on it in place.
        image = np.ascontiguousarray(image[..., ::-1], dtype=np.uint8)
        height, width = image.shape[:2]
        x0, y0, x1, y1 = self._random_mask_box(height, width)
        # cv2.rectangle takes points as (x, y), i.e. (column, row); thickness -1 fills the box.
        cv2.rectangle(image, (x0, y0), (x1, y1), (255, 255, 255), -1)
        return Image.fromarray(image)

In [None]:
class LinearLrDecay(object):
    """Linearly decay an optimizer's learning rate between two step counts.

    The rate is ``start_lr`` up to ``decay_start_step``, ``end_lr`` from
    ``decay_end_step`` onwards, and linearly interpolated in between.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        start_lr: learning rate before decay starts; must be greater than ``end_lr``.
        end_lr: learning rate once decay has finished.
        decay_start_step: last step at which ``start_lr`` is still used.
        decay_end_step: first step at which ``end_lr`` is used; must be greater
            than ``decay_start_step``.
    """

    def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):

        assert start_lr > end_lr
        assert decay_end_step > decay_start_step
        self.optimizer = optimizer
        self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step)
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_end_step
        self.start_lr = start_lr
        self.end_lr = end_lr

    def step(self, current_step):
        """Set the learning rate for ``current_step`` on every param group and return it."""
        if current_step <= self.decay_start_step:
            lr = self.start_lr
        elif current_step >= self.decay_end_step:
            lr = self.end_lr
        else:
            lr = self.start_lr - self.delta * (current_step - self.decay_start_step)
        # Apply the rate in every branch; previously it was only written while decaying,
        # so the optimizer never actually reached `end_lr`.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

def inits_weight(m):
    """Xavier-uniform initialise the weight of ``m`` when it is an ``nn.Linear``.

    Intended for use with ``Module.apply``; modules of any other type are left untouched.
    Uses the in-place ``xavier_uniform_`` rather than the deprecated ``xavier_uniform`` alias.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data, gain=1.)

def save_checkpoint(states,is_best, output_dir, epoch,
                    filename='checkpoint.pth'):
    """Serialise ``states`` with ``torch.save`` into ``output_dir``.

    The file is named ``filename + str(epoch)`` (e.g. ``checkpoint.pth3``). When
    ``is_best`` is true a second copy is written as ``checkpoint_best.pth``.
    ``output_dir`` is created if it does not exist yet.
    """
    os.makedirs(output_dir, exist_ok=True)
    torch.save(states, os.path.join(output_dir, filename+str(epoch)))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

In [None]:
def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentations named in ``policy`` to the batch ``x``.

    ``policy`` is a comma-separated list of keys of ``AUGMENT_FNS``; every function
    registered under each key is applied in order. An empty policy returns ``x``
    unchanged. With ``channels_first=False`` the input is permuted to channels-first
    for the augmentations and permuted back afterwards.
    """
    if not policy:
        return x

    if not channels_first:
        x = x.permute(0, 3, 1, 2)

    for group_name in policy.split(','):
        for augment in AUGMENT_FNS[group_name]:
            x = augment(x)

    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()


def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) per sample to every element of that sample."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each sample's deviation from its per-pixel channel mean by a random factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean


def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a random factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean


def rand_translation(x, ratio=0.2):
    """Shift every sample of the 4-D batch ``x`` by a random integer offset, zero-filling the gap.

    Offsets are drawn independently per sample from ``[-r, r]`` along each of the last
    two axes, where ``r`` is ``ratio`` times that axis' size, rounded to the nearest int.
    """
    batch, height, width = x.size(0), x.size(2), x.size(3)
    max_shift_x = int(height * ratio + 0.5)
    max_shift_y = int(width * ratio + 0.5)
    shift_rows = torch.randint(-max_shift_x, max_shift_x + 1, size=[batch, 1, 1], device=x.device)
    shift_cols = torch.randint(-max_shift_y, max_shift_y + 1, size=[batch, 1, 1], device=x.device)

    # Index tensors shaped to broadcast to (batch, height, width).
    sample_idx = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1, 1)
    row_idx = torch.arange(height, dtype=torch.long, device=x.device).view(1, height, 1)
    col_idx = torch.arange(width, dtype=torch.long, device=x.device).view(1, 1, width)

    # The "+ 1" accounts for the one-pixel zero border added below; indices that fall
    # outside the image are clamped onto that border and therefore read zeros.
    rows = torch.clamp(row_idx + shift_rows + 1, 0, height + 1)
    cols = torch.clamp(col_idx + shift_cols + 1, 0, width + 1)

    padded = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]).permute(0, 2, 3, 1).contiguous()
    return padded[sample_idx, rows, cols].permute(0, 3, 1, 2)


def rand_cutout(x, ratio=0.5):
    """With probability 0.3, zero out one random rectangle per sample of the 4-D batch ``x``.

    The rectangle spans ``ratio`` of each of the last two axes (rounded) and is clipped
    at the image border. When the cutout is not applied, ``x`` itself is returned.
    """
    if random.random() >= 0.3:
        return x

    batch, height, width = x.size(0), x.size(2), x.size(3)
    cut_h = int(height * ratio + 0.5)
    cut_w = int(width * ratio + 0.5)
    centre_rows = torch.randint(0, height + (1 - cut_h % 2), size=[batch, 1, 1], device=x.device)
    centre_cols = torch.randint(0, width + (1 - cut_w % 2), size=[batch, 1, 1], device=x.device)

    # Index tensors shaped to broadcast to (batch, cut_h, cut_w).
    sample_idx = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1, 1)
    row_idx = torch.arange(cut_h, dtype=torch.long, device=x.device).view(1, cut_h, 1)
    col_idx = torch.arange(cut_w, dtype=torch.long, device=x.device).view(1, 1, cut_w)
    rows = torch.clamp(row_idx + centre_rows - cut_h // 2, min=0, max=height - 1)
    cols = torch.clamp(col_idx + centre_cols - cut_w // 2, min=0, max=width - 1)

    keep = torch.ones(batch, height, width, dtype=x.dtype, device=x.device)
    keep[sample_idx, rows, cols] = 0
    # The same spatial mask is applied to every channel.
    return x * keep.unsqueeze(1)

def rand_rotate(x, ratio=0.5):
    """With probability ``ratio``, rotate ``x`` by 1 to 3 quarter turns in the plane of dims 2 and 3."""
    # The turn count is drawn before the coin flip, matching the order of random draws.
    quarter_turns = random.randint(1, 3)
    rotate = random.random() < ratio
    return torch.rot90(x, quarter_turns, [2, 3]) if rotate else x

# Registry used by DiffAugment: each policy name maps to the list of augmentation
# functions that are applied, in order, when that name appears in the policy string.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
    'rotate': [rand_rotate],
}

In [None]:
# Training hyperparameters given by the code author.

lr_gen = 0.0001 # Learning rate for the generator optimizer
lr_dis = 0.0001 # Learning rate for the discriminator optimizer
latent_dim = 1024 # Latent dimension (forwarded to train() as a default argument)
gener_batch_size = 32 # Batch size for generator (forwarded to train() as a default argument)
dis_batch_size = 32 # Batch size for discriminator
epoch = 10 # Number of epochs. NOTE(review): the training cell loops over range(50) and rebinds `epoch`.
weight_decay = 1e-3 # Weight decay; only passed to the RMSprop optimizers
drop_rate = 0.5 # Dropout probability
n_critic = 5 # train() updates the generator only when the global step is a multiple of this
max_iter = 500000 # Together with n_critic, sets the step at which the linear LR decay ends
img_name = "img_name"
lr_decay = True # When True, the LR schedulers are handed to train()

# Architecture details by the authors.
image_size = 64 # H,W size of image for discriminator
initial_size = 8 # Initial size for generator
patch_size = 4 # Patch size for generated image
num_classes = 1 # Number of outputs of the discriminator head
output_dir = 'checkpoint' # Directory where checkpoints are saved
dim = 768 # Embedding dimension
optimizer = 'Adam' # Optimizer name: 'Adam', 'SGD' or 'RMSprop'
loss = "wgangp_eps" # Loss function; train() handles 'hinge' and 'wgangp_eps'
phi = 1 # Target gradient norm used by compute_gradient_penalty
beta1 = 0 # First Adam beta
beta2 = 0.99 # Second Adam beta
diff_aug = "translation,cutout,color" # Comma-separated DiffAugment policy names (keys of AUGMENT_FNS)


In [None]:
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> Dropout.

    ``hid_feat`` and ``out_feat`` fall back to ``in_feat`` when left unset (or falsy).
    """

    def __init__(self, in_feat, hid_feat=None, out_feat=None,
                 dropout=0.):
        super().__init__()
        hid_feat = hid_feat or in_feat
        out_feat = out_feat or in_feat
        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.droprateout(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` input.

    Note that the attention scores are scaled by ``1 / sqrt(dim)`` (the full
    embedding size), not by the per-head size.
    """

    def __init__(self, dim, heads=4, attention_dropout=0., proj_dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = 1./dim**0.5

        # One projection produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim*3, bias=False)
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.heads, head_dim)
        # -> three tensors of shape (batch, heads, tokens, head_dim)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.attention_dropout(scores.softmax(dim=-1))

        merged = torch.matmul(weights, values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.out(merged)

class ImgPatches(nn.Module):
    """Embed non-overlapping ``patch_size`` x ``patch_size`` image patches as ``dim``-sized tokens."""

    def __init__(self, input_channel=3, dim=768, patch_size=4):
        super().__init__()
        # Kernel size equal to the stride means every patch is projected independently.
        self.patch_embed = nn.Conv2d(input_channel, dim,
                                     kernel_size=patch_size, stride=patch_size)

    def forward(self, img):
        feature_map = self.patch_embed(img)
        # Flatten the spatial grid and move the embedding axis last: (B, tokens, dim).
        return feature_map.flatten(2).transpose(1, 2)

def UpSampling(x, H, W):
    """Double the token grid with a pixel shuffle.

    ``x`` is a ``(B, N, C)`` token tensor that is laid out as a ``(C, H, W)`` feature
    map, pixel-shuffled by a factor of 2 and flattened back into tokens.
    Returns the new tokens together with the new grid height and width.
    """
    channels = x.size(2)
    feature_map = x.permute(0, 2, 1).view(-1, channels, H, W)
    feature_map = nn.PixelShuffle(2)(feature_map)
    _, channels, H, W = feature_map.size()
    tokens = feature_map.view(-1, channels, H * W).permute(0, 2, 1)
    return tokens, H, W

class Encoder_Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual connection."""

    def __init__(self, dim, heads, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        # The same drop rate is used for the attention weights and the output projection.
        self.attn = Attention(dim, heads, drop_rate, drop_rate)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim*mlp_ratio, dropout=drop_rate)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))


class TransformerEncoder(nn.Module):
    """A stack of ``depth`` identical ``Encoder_Block`` layers applied one after another."""

    def __init__(self, depth, dim, heads, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        self.Encoder_Blocks = nn.ModuleList(
            Encoder_Block(dim, heads, mlp_ratio, drop_rate) for _ in range(depth)
        )

    def forward(self, x):
        for block in self.Encoder_Blocks:
            x = block(x)
        return x

class Generator(nn.Module):
    """Image-to-image transformer generator used for inpainting.

    The (masked) input image is split into patches, processed by three transformer
    stages and up-sampled twice with a pixel shuffle, so the output has
    ``4 * (image_size // patch_size)`` pixels per side (64 with the defaults).

    Args:
        depth1, depth2, depth3: number of encoder blocks in each of the three stages.
        initial_size: only used to size the legacy ``mlp`` layer (see below).
        dim: embedding size of the first stage; stages two and three use ``dim // 4``
            and ``dim // 16``, so ``dim`` must be divisible by 16.
        heads: attention heads per block.
        mlp_ratio: hidden-size multiplier of the block MLPs.
        drop_rate: dropout probability inside the blocks.
        image_size: side length of the (square) input image.
        patch_size: side length of one input patch.
    """
    def __init__(self, depth1=5, depth2=4, depth3=2, initial_size=8, dim=768, heads=4, mlp_ratio=4, drop_rate=0.,
                 image_size=64, patch_size=4):
        super(Generator, self).__init__()
        if image_size % patch_size != 0:
            raise ValueError('Image size must be divisible by patch size.')

        self.initial_size = initial_size
        self.dim = dim
        self.depth1 = depth1
        self.depth2 = depth2
        self.depth3 = depth3
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.droprate_rate =drop_rate
        self.patch_size = patch_size
        # Side length of the token grid produced by the patch embedding (16 by default).
        self.grid_size = image_size // patch_size
        self.patches = ImgPatches(3, dim, patch_size)

        # Not used in forward(); kept so the parameter set (and therefore state dicts
        # saved from this class) stays unchanged.
        self.mlp = nn.Linear(1024, (self.initial_size ** 2) * self.dim)

        # One learned positional embedding per stage; the grid doubles and the channel
        # count drops by 4 at every pixel shuffle. With the defaults these have the same
        # shapes as the previously hard-coded (16, 768) variants.
        grid = self.grid_size
        self.positional_embedding_1 = nn.Parameter(torch.zeros(1, grid**2, dim))
        self.positional_embedding_2 = nn.Parameter(torch.zeros(1, (grid*2)**2, dim//4))
        self.positional_embedding_3 = nn.Parameter(torch.zeros(1, (grid*4)**2, dim//16))

        self.TransformerEncoder_encoder1 = TransformerEncoder(depth=self.depth1, dim=self.dim,heads=self.heads, mlp_ratio=self.mlp_ratio, drop_rate=self.droprate_rate)
        self.TransformerEncoder_encoder2 = TransformerEncoder(depth=self.depth2, dim=self.dim//4, heads=self.heads, mlp_ratio=self.mlp_ratio, drop_rate=self.droprate_rate)
        self.TransformerEncoder_encoder3 = TransformerEncoder(depth=self.depth3, dim=self.dim//16, heads=self.heads, mlp_ratio=self.mlp_ratio, drop_rate=self.droprate_rate)


        # 1x1 convolution mapping the final embedding to 3 image channels.
        self.linear = nn.Sequential(nn.Conv2d(self.dim//16, 3, 1, 1, 0))

    def forward(self, img):
        """Map a ``(B, 3, H, W)`` image batch to a ``(B, 3, 4*H/patch, 4*W/patch)`` batch.

        Raises:
            ValueError: if the input does not produce the number of patches the
                positional embeddings were sized for.
        """
        x = self.patches(img)

        # The token grid comes from the input image, not from `initial_size`: using the
        # latter made the reshape in UpSampling fail because the token count did not match.
        H, W = img.shape[-2] // self.patch_size, img.shape[-1] // self.patch_size
        if H * W != self.positional_embedding_1.shape[1]:
            raise ValueError(
                f'Expected {self.positional_embedding_1.shape[1]} patches per image, got {H * W}.')

        x = x + self.positional_embedding_1
        x = self.TransformerEncoder_encoder1(x)

        x,H,W = UpSampling(x,H,W)
        x = x + self.positional_embedding_2
        x = self.TransformerEncoder_encoder2(x)

        x,H,W = UpSampling(x,H,W)
        x = x + self.positional_embedding_3

        x = self.TransformerEncoder_encoder3(x)
        # Tokens back to a (B, C, H, W) feature map before the final projection.
        x = self.linear(x.permute(0, 2, 1).view(-1, self.dim//16, H, W))

        return x

class Discriminator(nn.Module):
    """ViT-style discriminator that scores an image batch from a class token.

    Args:
        diff_aug: DiffAugment policy string applied to every input in ``forward``.
        image_size: side length of the (square) input image; must be divisible by ``patch_size``.
        patch_size: side length of one patch.
        input_channel: number of image channels.
        num_classes: size of the output head.
        dim: embedding size.
        depth: number of encoder blocks.
        heads: attention heads per block.
        mlp_ratio: hidden-size multiplier of the block MLPs.
        drop_rate: dropout probability.

    Raises:
        ValueError: if ``image_size`` is not divisible by ``patch_size``.
    """
    def __init__(self, diff_aug, image_size=64, patch_size=4, input_channel=3, num_classes=1,
                 dim=768, depth=7, heads=4, mlp_ratio=4,
                 drop_rate=0.):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError('Image size must be divisible by patch size.')
        num_patches = (image_size//patch_size) ** 2
        self.diff_aug = diff_aug
        self.patch_size = patch_size
        self.depth = depth
        # Image patches and embedding layer
        self.patches = ImgPatches(input_channel, dim, self.patch_size)

        # Embedding for patch position and class. One position per patch plus one for the
        # class token; the previous hard-coded 65 only matched configurations with 64 patches.
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.class_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        nn.init.trunc_normal_(self.positional_embedding, std=0.2)
        nn.init.trunc_normal_(self.class_embedding, std=0.2)

        self.droprate = nn.Dropout(p=drop_rate)
        self.TransfomerEncoder = TransformerEncoder(depth, dim, heads,
                                      mlp_ratio, drop_rate)
        self.norm = nn.LayerNorm(dim)
        self.out = nn.Linear(dim, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights / zero bias for Linear layers, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return a ``(B, num_classes)`` score for the image batch ``x``."""
        # Augmentation is applied on every call, regardless of train/eval mode.
        x = DiffAugment(x, self.diff_aug)
        b = x.shape[0]
        cls_token = self.class_embedding.expand(b, -1, -1)

        x = self.patches(x)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.positional_embedding
        x = self.droprate(x)
        x = self.TransfomerEncoder(x)
        x = self.norm(x)
        # Only the class token feeds the output head.
        x = self.out(x[:, 0])
        return x

In [None]:
# Run on the first GPU when one is visible, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

generator = Generator(
    depth1=5, depth2=4, depth3=2, initial_size=8,
    dim=768, heads=4, mlp_ratio=4, drop_rate=0.5,
)
generator.to(device)

discriminator = Discriminator(
    diff_aug=diff_aug, image_size=64, patch_size=4, input_channel=3, num_classes=1,
    dim=768, depth=7, heads=4, mlp_ratio=4, drop_rate=0.5,
)
discriminator.to(device)

# Re-initialise every nn.Linear weight of both networks with Xavier-uniform values.
generator.apply(inits_weight)
discriminator.apply(inits_weight)

In [None]:
# Disabled snippet for resuming training from a previously saved checkpoint.
# It is wrapped in a bare string literal, so none of it executes; remove the
# surrounding quotes to enable it.
# NOTE(review): the printed message says 130 while the path points at checkpoint 125 — confirm which is meant.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
'''
whole_model_path = '/kaggle/input/modelcheckpoint125/checkpoint.pth125'
nets = torch.load(whole_model_path)
print("Loaded Model Checkpoint 130")
netG_state_dict, netD_state_dict = nets['generator_state_dict'], nets['discriminator_state_dict']

if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"

device = torch.device(dev)

generator= Generator()
generator.load_state_dict(netG_state_dict)
generator.to(device)

discriminator = Discriminator(diff_aug = diff_aug)
discriminator.load_state_dict(netD_state_dict)
discriminator.to(device)

'''

In [None]:
# Build one optimizer per network from the `optimizer` name chosen in the config cell.
# Only parameters that require gradients are handed to the optimizers.
gen_params = [p for p in generator.parameters() if p.requires_grad]
dis_params = [p for p in discriminator.parameters() if p.requires_grad]

if optimizer == 'Adam':
    optim_gen = optim.Adam(gen_params, lr=lr_gen, betas=(beta1, beta2))
    optim_dis = optim.Adam(dis_params, lr=lr_dis, betas=(beta1, beta2))
elif optimizer == 'SGD':
    optim_gen = optim.SGD(gen_params, lr=lr_gen, momentum=0.9)
    optim_dis = optim.SGD(dis_params, lr=lr_dis, momentum=0.9)
elif optimizer == 'RMSprop':
    # The generator optimizer must receive the generator's parameters and learning rate;
    # it was previously built from the discriminator's, so the generator was never updated.
    optim_gen = optim.RMSprop(gen_params, lr=lr_gen, eps=1e-08, weight_decay=weight_decay, momentum=0, centered=False)
    optim_dis = optim.RMSprop(dis_params, lr=lr_dis, eps=1e-08, weight_decay=weight_decay, momentum=0, centered=False)
else:
    # Fail here rather than with a NameError on `optim_gen` further down.
    raise ValueError(f"Unsupported optimizer: {optimizer!r} (expected 'Adam', 'SGD' or 'RMSprop')")

# Linear decay of both learning rates to 0 between step 0 and step max_iter * n_critic.
gen_scheduler = LinearLrDecay(optim_gen, lr_gen, 0.0, 0, max_iter * n_critic)
dis_scheduler = LinearLrDecay(optim_dis, lr_dis, 0.0, 0, max_iter * n_critic)

print("optimizer:",optimizer)

# NOTE(review): not referenced anywhere else in the visible notebook.
fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'

# TensorBoard writer plus the global step counters shared with train().
writer=SummaryWriter()
writer_dict = {'writer':writer}
writer_dict["train_global_steps"]=0
writer_dict["valid_global_steps"]=0

In [None]:
# Per-epoch mean losses; train() appends one value to each list per call.
GeneratorLoss = []
DiscriminatorLoss = []

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Calculates the gradient penalty loss for WGAN GP.

    ``D`` is evaluated on random per-sample interpolations between ``real_samples`` and
    ``fake_samples`` (both 4-D batches of the same shape); the penalty is the mean
    squared difference between the gradient norm at those points and ``phi``.

    Works on whatever device and dtype ``real_samples`` lives on, including the CPU.
    """
    # One mixing coefficient per sample, broadcast over the remaining dimensions.
    # Created directly on the samples' device: `get_device()` returns -1 for CPU
    # tensors, which is not a valid target for `.to()`.
    alpha = torch.rand(real_samples.size(0), 1, 1, 1,
                       dtype=real_samples.dtype, device=real_samples.device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    # d(sum of outputs)/d(inputs): a tensor of ones matching D's output shape.
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.contiguous().view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
    return gradient_penalty


def train(noise,generator, discriminator, optim_gen, optim_dis,
        epoch, writer, schedulers, img_size=64, latent_dim = latent_dim,
        n_critic = n_critic,
        gener_batch_size=gener_batch_size, device="cuda:0"):
    """Run one training epoch over the ``TrainB`` images.

    Every batch updates the discriminator; the generator is updated whenever the
    global step (kept in ``writer_dict['train_global_steps']``) is a multiple of
    ``n_critic``. Per-sample images are written below ``/kaggle/working`` and the
    epoch's mean losses are appended to ``GeneratorLoss`` / ``DiscriminatorLoss``.

    Args:
        noise, img_size, latent_dim, gener_batch_size: accepted for interface
            compatibility; not read by this function.
        generator, discriminator: the two networks (switched to train mode here).
        optim_gen, optim_dis: their optimizers.
        epoch: zero-based epoch index, used in output paths and log lines.
        writer: ignored; the writer stored in the global ``writer_dict`` is used.
        schedulers: ``(gen_scheduler, dis_scheduler)`` or a falsy value to disable LR decay.
        n_critic: number of discriminator steps per generator step.
        device: device the batches are moved to.

    Raises:
        ValueError: if the global ``loss`` is neither ``'hinge'`` nor ``'wgangp_eps'``.
    """
    writer = writer_dict['writer']
    gen_step = 0

    generator = generator.train()
    discriminator = discriminator.train()

    transform = transforms.Compose(
        [
            transforms.Resize(size=(64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    train_set = InpaintDataset(TrainB, TrainB, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=10, shuffle=True)
    total_gener_loss = 0
    total_dis_loss = 0

    # Per-epoch output folders (the last two are filled later by validation()).
    for folder in ('output_imgs', 'generated_imgs', 'real_imgs', 'fake_imgs', 'real_test_imgs'):
        os.makedirs(f'/kaggle/working/{folder}/{epoch}', exist_ok=True)

    for index, (real_imgs, mask_imgs) in enumerate(train_loader):

        global_steps = writer_dict['train_global_steps']

        real_imgs = real_imgs.to(device=device, dtype=torch.float32)
        mask_imgs = mask_imgs.to(device=device, dtype=torch.float32)

        # ---- Discriminator step -------------------------------------------------
        optim_dis.zero_grad()
        real_valid = discriminator(real_imgs)
        # Detached: the discriminator step must not back-propagate into the generator.
        fake_imgs = generator(mask_imgs).detach()
        fake_valid = discriminator(fake_imgs)

        if loss == 'hinge':
            loss_dis = torch.mean(F.relu(1.0 - real_valid)) + torch.mean(F.relu(1 + fake_valid))
        elif loss == 'wgangp_eps':
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs, phi)
            loss_dis = -torch.mean(real_valid) + torch.mean(fake_valid) + gradient_penalty * 10 / (phi ** 2)
        else:
            raise ValueError(f"Unsupported loss: {loss!r} (expected 'hinge' or 'wgangp_eps')")

        loss_dis.backward()
        optim_dis.step()

        writer.add_scalar("loss_dis", loss_dis.item(), global_steps)
        total_dis_loss += loss_dis.item()

        # ---- Generator step (every n_critic-th global step) ----------------------
        if global_steps % n_critic == 0:

            optim_gen.zero_grad()
            if schedulers:
                gen_scheduler, dis_scheduler = schedulers
                g_lr = gen_scheduler.step(global_steps)
                d_lr = dis_scheduler.step(global_steps)
                writer.add_scalar('LR/g_lr', g_lr, global_steps)
                writer.add_scalar('LR/d_lr', d_lr, global_steps)

            generated_imgs = generator(mask_imgs)
            fake_valid = discriminator(generated_imgs)

            gener_loss = -torch.mean(fake_valid)
            gener_loss.backward()
            optim_gen.step()
            writer.add_scalar("gener_loss", gener_loss.item(), global_steps)

            gen_step += 1
            total_gener_loss += gener_loss.item()

        # ---- Image dumps --------------------------------------------------------
        # `fake_imgs` is the generator output for *this* batch; the tensor from the
        # generator step may stem from an earlier batch when n_critic > 1.
        for ind in range(real_imgs.shape[0]):
            grid_imgs = [real_imgs[ind],  mask_imgs[ind], fake_imgs[ind]]
            save_image(grid_imgs, f'/kaggle/working/output_imgs/{epoch}/output_img_{(index * 30) + ind}.jpg',nrow=3, normalize=True, scale_each=True)
            save_image(real_imgs[ind], f'/kaggle/working/real_imgs/{epoch}/{(index * 30) + ind}.jpg',nrow=1, normalize=True, scale_each=True)
            save_image(fake_imgs[ind], f'/kaggle/working/generated_imgs/{epoch}/{(index * 30) + ind}.jpg',nrow=1, normalize=True, scale_each=True)

        # `gen_step` guards against reading `gener_loss` before the first generator step.
        if gen_step and index % 100 == 0:
            sample_imgs = fake_imgs[:25]
            save_image(sample_imgs, f'/kaggle/working/generated_imgs/generated_img_{epoch}_{index % len(train_loader)}.jpg', nrow=5, normalize=True, scale_each=True)

            tqdm.write("[Epoch %d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                (epoch+1, index % len(train_loader), len(train_loader), loss_dis.item(), gener_loss.item()))

        # Advance the shared step counter. Without this the step stayed at 0 forever, so
        # the n_critic schedule, the LR decay and the TensorBoard x-axis never progressed.
        writer_dict['train_global_steps'] = global_steps + 1

    # Average each loss over the number of steps that actually produced it.
    GeneratorLoss.append(total_gener_loss / max(gen_step, 1))
    DiscriminatorLoss.append(total_dis_loss / max(len(train_loader), 1))

**Evaluation of the model using FID score and SSIM score**

In [None]:
import sys
sys.path.append('/kaggle/input/evaluation/evaluation')
from inception_score.inception_score import inception_score
from fid.fid import calculate_fid_given_paths
from ssim.ssim import ssim
from psnr.psnr import psnr
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import cv2
# Signatures of the helpers imported from the evaluation package, kept for reference
# (a bare string literal: it has no effect when the cell runs).
"""
def inception_score(imgs, cuda=True, batch_size=64, resize=False, splits=1):
def calculate_fid_given_paths(paths, batch_size, cuda, dims):
def ssim(img1, img2, window_size = 11, size_average = True):
"""
# Target size handed to cv2.resize by the PSNR and mean-L1 metrics.
SIZE = (64,64)
# Preprocessing used by the inception-score and SSIM metrics: resize to 64x64, convert to a
# tensor and normalise with fixed per-channel mean/std values.
_transforms_fun = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
def _inception_score(path, cuda=True, batch_size=1, resize=True, splits=1):
    """Compute the inception score of every file in ``path`` whose name ends in "png".

    Images are loaded as RGB, preprocessed with ``_transforms_fun`` and stacked into one
    batch that is passed to ``inception_score`` with the remaining arguments.
    NOTE(review): the training and validation code in this notebook writes ``.jpg``
    files, which this filter skips — confirm the intended extension.
    """
    png_names = [name for name in os.listdir(path) if name.endswith("png")]
    tensors = [
        _transforms_fun(Image.open(os.path.join(path, name)).convert("RGB"))
        for name in png_names
    ]
    batch = torch.stack(tensors)
    return inception_score(batch, cuda, batch_size, resize, splits)

def _fid(paths, batch_size=100, cuda=True, dims=768):
    """Forward ``paths`` and the remaining options to ``calculate_fid_given_paths`` and return its result."""
    fid_value = calculate_fid_given_paths(paths, batch_size, cuda, dims)
    return fid_value

def _ssim(paths, window_size=11, size_average=True):
    """Mean SSIM between same-named ``.jpg`` files of two directories.

    Args:
        paths: ``(path1, path2)``; file names are listed from ``path1`` and each one
            must also exist in ``path2``.
        window_size: window size forwarded to ``ssim``.
        size_average: accepted for interface compatibility. ``ssim`` is always called
            with ``size_average=True`` because the per-batch means are combined below.

    Returns:
        The image-count-weighted mean of the per-batch SSIM values, as returned by
        ``ssim`` (the training loop calls ``.item()`` on it).

    Raises:
        ValueError: if ``path1`` contains no ``.jpg`` files.
    """
    path1, path2 = paths
    batch_size = 10
    names = [name for name in os.listdir(path1) if name.endswith("jpg")]
    if not names:
        raise ValueError(f"No .jpg images found in {path1}")

    def _load(folder, name):
        # Close the file handle promptly; convert() returns an independent image.
        with Image.open(os.path.join(folder, name)) as handle:
            return _transforms_fun(handle.convert("RGB"))

    weighted_sum = 0
    for start in range(0, len(names), batch_size):
        chunk = names[start:start + batch_size]
        imgs1 = torch.stack([_load(path1, name) for name in chunk])
        imgs2 = torch.stack([_load(path2, name) for name in chunk])
        # Weight every batch mean by the number of images it actually contains, so the
        # final division yields the mean over all images.
        weighted_sum = weighted_sum + len(chunk) * ssim(imgs1, imgs2, window_size=window_size, size_average=True)
    return weighted_sum / len(names)

def _psnr(paths):
    """Mean PSNR between same-named files ending in "png" of two directories.

    Both images are loaded as RGB and resized to ``SIZE`` before ``psnr`` is applied.
    Returns ``0.0`` when no matching file is found.
    NOTE(review): the training and validation code in this notebook writes ``.jpg``
    files, which this filter skips — confirm the intended extension.
    """
    path1, path2 = paths
    psnr_total = 0
    count = 0  # starts at 0 so the mean divides by the number of images, not one more
    for file in os.listdir(path1):
        if file.endswith("png"):
            img1 = Image.open(os.path.join(path1, file)).convert("RGB")
            img2 = Image.open(os.path.join(path2, file)).convert("RGB")
            psnr_total = psnr_total + psnr(cv2.resize(np.array(img1),SIZE), cv2.resize(np.array(img2), SIZE))
            count = count + 1

    if count == 0:
        return 0.0
    return psnr_total / count

def _meanl1(paths):
    """Mean absolute pixel difference between same-named files ending in "png" of two directories.

    Both images are loaded as RGB and resized to ``SIZE``; the per-image error is the
    mean of the absolute differences over all pixels and channels (0-255 scale).
    Returns ``0.0`` when no matching file is found.
    NOTE(review): the training and validation code in this notebook writes ``.jpg``
    files, which this filter skips — confirm the intended extension.
    """
    path1, path2 = paths
    total_error = 0
    count = 0  # starts at 0 so the mean divides by the number of images, not one more
    for file in os.listdir(path1):
        if file.endswith("png"):
            img1 = Image.open(os.path.join(path1, file)).convert("RGB")
            img2 = Image.open(os.path.join(path2, file)).convert("RGB")

            # Cast to float before subtracting: uint8 arithmetic wraps around
            # (e.g. 10 - 20 becomes 246), which corrupted the error.
            arr1 = cv2.resize(np.array(img1), SIZE).astype(np.float64)
            arr2 = cv2.resize(np.array(img2), SIZE).astype(np.float64)
            total_error = total_error + np.mean(np.abs(arr1 - arr2))
            count = count + 1

    if count == 0:
        return 0.0
    return total_error / count

# Lookup table from metric name to the helper that computes it.
metrics = {"is":_inception_score, "fid":_fid, "ssim":_ssim, "psnr":_psnr, "meanl1":_meanl1}

In [None]:
# Per-epoch validation scores; the training loop appends one value to each list per epoch.
fid_scores = []
ssim_scores = []

In [None]:
def validate(generator, epoch):
    """Evaluate ``generator`` on the test images and return ``(fid_score, ssim_score)``.

    Switches the generator to eval mode, writes real and generated test images via
    ``validation``, computes FID by running ``pytorch_fid`` as a subprocess on the two
    image folders, and SSIM through ``metrics['ssim']``.

    Raises:
        RuntimeError: if the ``pytorch_fid`` run fails or prints no parsable score.
    """
    generator = generator.eval()
    validation(generator, epoch)

    real_dir = f'/kaggle/working/real_test_imgs/{epoch}'
    fake_dir = f'/kaggle/working/fake_imgs/{epoch}'

    # Run with the kernel's own interpreter and without a shell; both output streams are
    # captured so a failure can be reported instead of surfacing as an IndexError.
    result = subprocess.run(
        [sys.executable, '-m', 'pytorch_fid', real_dir, fake_dir],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True,
    )
    tokens = result.stdout.split()
    try:
        if result.returncode != 0:
            raise ValueError('non-zero exit status')
        # The score is the token that follows "FID:" in the tool's output.
        fid_score = float(tokens[tokens.index('FID:') + 1])
    except (ValueError, IndexError) as exc:
        raise RuntimeError(
            f"pytorch_fid failed (exit code {result.returncode}).\n"
            f"stdout: {result.stdout}\nstderr: {result.stderr}") from exc

    ssim_score = metrics['ssim']([real_dir, fake_dir])
    print(f"FID score: {fid_score}")
    return fid_score, ssim_score

def validation(generator, epoch):
    """Write real and generated versions of every ``TestB`` image for FID/SSIM evaluation.

    Real images go to ``/kaggle/working/real_test_imgs/{epoch}`` and generator outputs
    to ``/kaggle/working/fake_imgs/{epoch}``, using the same file name for each pair.
    The generator is run without gradient tracking on the device its parameters live on.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(size=(64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    test_set = InpaintDataset(TestB, TestB, transform=transform)

    batch_size = 30
    test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True)

    real_dir = f'/kaggle/working/real_test_imgs/{epoch}'
    fake_dir = f'/kaggle/working/fake_imgs/{epoch}'
    # Do not rely on train() having created the folders beforehand.
    os.makedirs(real_dir, exist_ok=True)
    os.makedirs(fake_dir, exist_ok=True)

    # Follow the generator's device instead of assuming CUDA.
    device = next(generator.parameters()).device

    # No gradients are needed here; skipping the autograd graph saves memory.
    with torch.no_grad():
        for index, (real_imgs, mask_imgs) in enumerate(test_loader):

            real_imgs = real_imgs.to(device=device, dtype=torch.float32)
            mask_imgs = mask_imgs.to(device=device, dtype=torch.float32)
            fake_imgs = generator(mask_imgs).detach()
            for ind in range(real_imgs.shape[0]):
                # Unique running index across batches; identical names pair real and fake files.
                file_name = f'{(index * batch_size) + ind}.jpg'
                save_image(real_imgs[ind], os.path.join(real_dir, file_name),nrow=1, normalize=True, scale_each=True)
                save_image(fake_imgs[ind], os.path.join(fake_dir, file_name),nrow=1, normalize=True, scale_each=True)

In [None]:
# Lowest FID seen so far; a checkpoint is written whenever an epoch improves on it.
best = 1e4

for epoch in range(50):
    print("Epoch {}".format(epoch+1))

    lr_schedulers = (gen_scheduler, dis_scheduler) if lr_decay else None

    # train() never reads its first (`noise`) argument, and no `noise` variable exists in
    # this notebook, so None is passed instead of the undefined name.
    train(None, generator, discriminator, optim_gen, optim_dis,
    epoch, writer, lr_schedulers,img_size = 64, latent_dim = latent_dim,
    n_critic = n_critic,
    gener_batch_size = gener_batch_size)

    fid_score, ssim_score = validate(generator, epoch)

    print(ssim_score.item())
    fid_scores.append(fid_score)
    ssim_scores.append(ssim_score.item())

    # Keep the image dumps of every 5th epoch only, to limit disk usage.
    if epoch % 5 != 0:
        for folder in ('output_imgs', 'real_imgs', 'generated_imgs', 'fake_imgs', 'real_test_imgs'):
            shutil.rmtree(f'/kaggle/working/{folder}/{epoch}', ignore_errors=True)
        print("Deleted files at epoch {}".format(epoch + 1))

    if fid_score < best:
        # Update `best` first so the checkpoint records the FID it was saved for.
        best = fid_score
        checkpoint = {'epoch': epoch, 'best_fid': best}
        checkpoint['generator_state_dict'] = generator.state_dict()
        checkpoint['discriminator_state_dict'] = discriminator.state_dict()
        save_checkpoint(checkpoint,  is_best=True, epoch=epoch, output_dir=output_dir)
        print("Saved Latest Model!")


# State after the final epoch.
# NOTE(review): this dict is built but never written to disk — confirm whether a final save is intended.
checkpoint = {'epoch':epoch, 'best_fid':best}
checkpoint['generator_state_dict'] = generator.state_dict()
checkpoint['discriminator_state_dict'] = discriminator.state_dict()

In [None]:
fid_scores

In [None]:
ssim_scores

In [None]:
GeneratorLoss

In [None]:
DiscriminatorLoss

In [None]:
# Mean generator loss per epoch, labelled so the saved figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(range(1, len(GeneratorLoss) + 1), GeneratorLoss)
ax.set(title='Generator loss', xlabel='Epoch', ylabel='Mean loss')
fig.savefig('GeneratorLoss.png')

In [None]:
# Mean discriminator loss per epoch, labelled so the saved figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(range(1, len(DiscriminatorLoss) + 1), DiscriminatorLoss)
ax.set(title='Discriminator loss', xlabel='Epoch', ylabel='Mean loss')
fig.savefig('DiscriminatorLoss.png')

In [None]:
!zip output_imgs.zip output_imgs/*
!zip generated_imgs.zip generated_imgs/*
!zip checkpoints.zip checkpoint/*