# Import Libraries

In [None]:
import torch
import torch.nn as nn  #neural network
import torchvision #image transformation
import os
import PIL #for image
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid  #make grid of images
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import wandb
import matplotlib.animation as animation
from IPython.display import HTML

### Copy Checkpoints (if Present) from the Input Directory to the Working Directory

In [None]:
# Seed the working directory with checkpoints from a previous run, but only
# when BOTH the generator and discriminator checkpoint files are present.
# `!cp` is an IPython shell escape, so this cell only runs inside Jupyter/IPython.
# NOTE(review): when the files are absent nothing is copied, so ./checkpointsanime/
# may not exist yet for later cells that write into it.
if(os.path.isfile("../input/checkpointsanime/G-Checkpoint.pkl") and os.path.isfile("../input/checkpointsanime/D-Checkpoint.pkl")):
    !cp -r '../input/checkpointsanime/' ./

# Setup Parameters and Hyper-Parameters

In [None]:
# --- Paths ---
datapath = '../input/gananime-lite/'      # ImageFolder root of the training images
checkpt_path = './checkpointsanime/'     # directory the G/D checkpoints are read from / written to

# --- Training schedule ---
epochs = 10000
batch_size = 128
lr = 1e-4
beta1 = 0.5        # Adam beta1 (decay rate of the first-moment estimate)
last_epoch = 0     # epoch to start from; replaced when a checkpoint is resumed

# --- Model dimensions ---
image_size = 64    # images are resized / center-cropped to image_size x image_size
nc = 3             # image channels (RGB)
nz = 200           # length of the latent vector fed to the generator
ngf = ndf = image_size  # base feature-map width of G and D (reuses image_size's value)

# --- Hardware ---
ngpu = 1           # number of GPUs
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

# WANDB Config

In [None]:
# Authenticate with Weights & Biases. The API key is taken from the
# WANDB_API_KEY environment variable (e.g. a notebook secret) instead of being
# hard-coded here, where a placeholder never authenticates and a real key would
# leak with the notebook. If the variable is unset, key=None lets wandb use its
# usual credential lookup.
wandb.login(key=os.environ.get('WANDB_API_KEY'))

In [None]:
%%capture
exp_name = wandb.util.generate_id()
myrun = wandb.init(
        project='AnimeGAN',
        group=exp_name,
        config={
            'Image Size':image_size,
            'Num Channels':nc,
            'nz':nz,
            'ngf':ngf,
            'ndf':ndf,
            'Learning Rate':lr,
            'Beta1':beta1,
            'Epoch': epochs,
            'Batch_size':batch_size,
            'Loss':"BCELoss",            
            'Optimizer':'Adam',
            'Last Epoch':last_epoch,
        }
)
config = wandb.config
print(exp_name)

# Import Dataset and Set Data Loader

Dataset: https://www.kaggle.com/prasoonkottarathil/gananime-lite

In [None]:
# Preprocessing pipeline:
#   1. Resize so the shorter side is image_size, then center-crop to a square.
#   2. Convert to a float tensor in [0, 1].
#   3. Map each RGB channel to [-1, 1] via (x - 0.5) / 0.5, matching the
#      Tanh output range of the Generator.
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder expects datapath/<class_dir>/<images>; only the images are used.
dataset = datasets.ImageFolder(root=datapath, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Plot Some Training Images

In [None]:
# Display up to the first 64 images of one shuffled batch as a grid (8 per row).
real_batch = next(iter(dataloader))
real_images = real_batch[0].to(device)[:64]
# normalize=True rescales the [-1, 1] tensors into [0, 1] for imshow.
sample_grid = make_grid(real_images, normalize=True).cpu()

plt.figure(figsize=(10, 10))
plt.title('Sample Training Images')
plt.axis("off")
plt.imshow(sample_grid.permute(1, 2, 0).numpy());

# Generator

In [None]:
# Generator architecture (channel widths for ngf = 64; spatial size in brackets):
#   nz (200)  [1x1]   -> ngf*8 (512) [4x4]
#   ngf*8     [4x4]   -> ngf*4 (256) [8x8]
#   ngf*4     [8x8]   -> ngf*2 (128) [16x16]
#   ngf*2     [16x16] -> ngf   (64)  [32x32]
#   ngf       [32x32] -> nc    (3)   [64x64], squashed into [-1, 1] by Tanh

class Generator(nn.Module):
    """DCGAN-style generator that upsamples a latent tensor into an image.

    Layer sizes come from the module-level globals ``nz``, ``ngf`` and ``nc``.
    An input of shape (N, nz, 1, 1) produces an output of shape
    (N, nc, 64, 64) with values in [-1, 1].

    Args:
        ngpu: stored as ``self.ngpu``; not otherwise used by this class.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        def up(in_ch, out_ch, stride, padding):
            # ConvTranspose -> BatchNorm -> ReLU. The conv bias is omitted
            # because the following BatchNorm adds its own shift.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        layers = (
            up(nz, ngf * 8, stride=1, padding=0)
            + up(ngf * 8, ngf * 4, stride=2, padding=1)
            + up(ngf * 4, ngf * 2, stride=2, padding=1)
            + up(ngf * 2, ngf, stride=2, padding=1)
            + [
                nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False),
                nn.Tanh(),
            ]
        )
        # A single nn.Sequential named `gen` keeps state_dict keys ("gen.<idx>.*")
        # compatible with previously saved checkpoints.
        self.gen = nn.Sequential(*layers)

    def forward(self, input):
        return self.gen(input)

# Discriminator

In [None]:
# Discriminator architecture (channel widths for ndf = 64; spatial size in brackets):
#   nc (3)   [64x64] -> ndf//2 (32) [32x32]   (3x3 kernel)
#   ndf//2   [32x32] -> ndf    (64) [16x16]
#   ndf      [16x16] -> ndf*2 (128) [8x8]
#   ndf*2    [8x8]   -> ndf*4 (256) [4x4]
#   ndf*4    [4x4]   -> ndf*8 (512) [2x2]
#   ndf*8    [2x2]   -> 1           [1x1], squashed into (0, 1) by Sigmoid

class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores images as real (near 1) or fake (near 0).

    Layer sizes come from the module-level globals ``nc`` and ``ndf``. An input
    of shape (N, nc, 64, 64) produces an output of shape (N, 1, 1, 1).

    Args:
        ngpu: stored as ``self.ngpu``; not otherwise used by this class.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        def down(in_ch, out_ch, kernel_size=4):
            # Strided Conv (halves H and W) -> BatchNorm -> LeakyReLU(0.2).
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        widths = [ndf // 2, ndf, ndf * 2, ndf * 4, ndf * 8]
        layers = down(nc, widths[0], kernel_size=3)
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += down(in_ch, out_ch)
        layers += [
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        ]
        # A single nn.Sequential named `disc` keeps state_dict keys ("disc.<idx>.*")
        # compatible with previously saved checkpoints.
        self.disc = nn.Sequential(*layers)

    def forward(self, input):
        return self.disc(input)

# Model
## Generator Instance

In [None]:
# Build the generator and move its parameters to `device` (GPU if available, else CPU).
netG = Generator(ngpu)
netG = netG.to(device)
netG  # show the layer summary in the notebook

## Discriminator Instance

In [None]:
# Build the discriminator and move its parameters to `device` (GPU if available, else CPU).
netD = Discriminator(ngpu).to(device)
netD  # show the layer summary in the notebook

# Configure W&B to Watch the Generator and Discriminator

In [None]:
# Register both networks with W&B (discriminator first, then generator) so it
# records model statistics every 100 batches.
for watched_model in (netD, netG):
    wandb.watch(watched_model, log_freq=100)

# Fixed Noise, BCE Loss Function and Labels

In [None]:
# Fixed latent batch reused at every visualization step, so successive
# generator snapshots can be compared image-for-image. The sample count is
# its own constant rather than image_size, which is a spatial size; the value
# (64) is unchanged.
num_fixed_samples = 64
fixed_noise = torch.randn(num_fixed_samples, nz, 1, 1, device=device)

# The discriminator ends in Sigmoid (probabilities), so plain BCE is the matching loss.
loss_fn = nn.BCELoss()

# Target values for loss_fn: 1 = real image, 0 = generated image.
real_label = 1
fake_label = 0

# Optimizers

In [None]:
# Adam for both networks with betas=(beta1, 0.999). beta1 (0.5) is defined with
# the hyper-parameters and logged to W&B as 'Beta1'. Without passing it here,
# Adam silently used its default of 0.9. 0.999 is Adam's default beta2.
# NOTE: Optimizer.load_state_dict restores the saved hyper-parameters, so a run
# resumed from a checkpoint written before this change keeps the old betas.
optim_g = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
optim_d = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))

# Save Checkpoint

In [None]:
def save_chckpt(current_epoch=None):
    """Save generator and discriminator checkpoints into ``checkpt_path``.

    Writes ``G-Checkpoint.pkl`` and ``D-Checkpoint.pkl``. Each holds the epoch,
    the model ``state_dict`` and the optimizer ``state_dict``.

    Args:
        current_epoch: epoch number to record. Defaults to the global ``epoch``
            (the training-loop variable), which is what callers relied on before.
    """
    if current_epoch is None:
        current_epoch = epoch  # global loop variable of the training cell

    # The checkpoint directory only exists when checkpoints were copied in from
    # the input dataset; create it so a fresh run doesn't hit FileNotFoundError.
    os.makedirs(checkpt_path, exist_ok=True)

    for prefix, model, optimizer in (("G", netG, optim_g), ("D", netD, optim_d)):
        torch.save({
            'epoch': current_epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f"{checkpt_path}{prefix}-Checkpoint.pkl")

    print(f"Saved Checkpoint:\n\t Epoch : {current_epoch}")

# Load Checkpoint

In [None]:
def load_chckpt():
    """Restore generator/discriminator weights and optimizer states from ``checkpt_path``.

    Returns:
        The epoch stored in the discriminator checkpoint. ``save_chckpt`` writes
        both files together, so the two epochs normally agree.
    """
    def _restore(filename, model, optimizer):
        # Load one checkpoint file into `model` and `optimizer`; return its epoch.
        # map_location remaps saved tensors onto the current device, so a
        # checkpoint written on GPU can be loaded in a CPU-only session.
        checkpoint = torch.load(f"{checkpt_path}{filename}", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch']

    _restore("G-Checkpoint.pkl", netG, optim_g)
    last_epoch = _restore("D-Checkpoint.pkl", netD, optim_d)

    print(f"Checkpoint Loaded:\n\t Epoch : {last_epoch}")
    return last_epoch

# Load from Previous Checkpoint

In [None]:
# Resume only when both checkpoint files exist. Training then restarts at the
# epoch after the one recorded in the checkpoint.
g_checkpoint_file = f"{checkpt_path}G-Checkpoint.pkl"
d_checkpoint_file = f"{checkpt_path}D-Checkpoint.pkl"
if all(os.path.isfile(path) for path in (g_checkpoint_file, d_checkpoint_file)):
    last_epoch = load_chckpt() + 1

# Log Generated (Fake) Images to W&B

In [None]:
def log_fake(noise):
    """Generate images from ``noise`` and log up to 64 of them to W&B as an 8-per-row grid.

    Args:
        noise: latent batch passed straight to ``netG``. The training loop
            passes shape (N, nz, 1, 1).
    """
    # Visualization only, so no autograd graph is needed.
    with torch.no_grad():
        fake_img = netG(noise).cpu()
    # netG ends in Tanh, so its outputs lie in [-1, 1]. Undo the dataset's
    # Normalize(mean=0.5, std=0.5) to map them back to [0, 1]. Clipping the raw
    # values to [0, 1] instead would turn every negative value black.
    fake_img = fake_img * 0.5 + 0.5
    grid = make_grid(fake_img[:64], nrow=8).permute(1, 2, 0)
    wandb.log({'Generated Image': wandb.Image(grid.numpy().clip(0, 1))})

# Training

In [None]:
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

for epoch in range(last_epoch, epochs):
    for i, data in enumerate(dataloader, 0):    #i - index, data - value
        
        #Discriminator - real image batch
        netD.zero_grad()  #set gradients to zero
        
        real = data[0].to(device)  #bind with device
        batch_s = real.size(0)  #find batch size (since in last iteration batch size may vary)
        label = torch.full((batch_s,), real_label, dtype=torch.float, device=device)   #create tensor with real_label with current batch size
        output = netD(real).view(-1)  #reshaping output from discriminator to 1D (since only a float number is needed)
        errD_real = loss_fn(output, label) #calc loss of discriminator for real images batch
        errD_real.backward()   #calculating gradient
        D_x = output.mean().item()  #mean loss for current batch of real images
        
        #Discriminator - fake image batch
        noise = torch.randn(batch_s, nz, 1, 1, device = device) #generate noise
        fake = netG(noise)  #fake image from noise
        #not needed to find batch size since this is fake can generate batch_size number of images each time
        #label = fill_(fake_label)  #generate fake labels of batch_size
        #label = torch.full((batch_s,), fake_label, dtype=torch.float, device=device)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)  #detach since gradient should not be altered by discriminator(only to find probability)
        errD_fake = loss_fn(output, label)  #calc loss of discriminator for fake images batch
        errD_fake.backward()   #calculating gradient
        D_G_z1 = output.mean().item()  #mean loss for current batch of fake images
        
        #calculate accumulated loss of discriminator
        errD = errD_real + errD_fake
        #update D
        optim_d.step()
        
        #Generator
        netG.zero_grad()  #set gradients to zero
        #should use same noise
        label.fill_(real_label)  #generate real labels of batch_size
        output = netD(fake).view(-1)  #reshaping output from discriminator to 1D (since only a float number is needed)
        errG = loss_fn(output, label)  #calc generator loss
        errG.backward()  #calculating gradient
        D_G_z2 = output.mean().item()
        optim_g.step()
        
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # log in every 200 batches in each epoch
        if (iters % 200 == 0) or ((epoch == epochs-1) and (i == len(dataloader)-1)):
            #log wandb
            wandb.log({'Epoch':epoch, 'Discriminator Loss':errD.item(), 'Generator Loss':errG.item()})
            wandb.save(f"{checkpt_path}G-Checkpoint.pkl")
            wandb.save(f"{checkpt_path}D-Checkpoint.pkl")
            print("Checkpoint Logged")
                
            #accumulate different stages of same noise for animation
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
                grid = make_grid(fake[:25], nrow=5).permute(1, 2, 0)
            img_list.append(make_grid(fake, padding=2, normalize=True))
            
            #wandb log generated images of different noise
            log_fake(torch.randn(batch_size, nz, 1, 1, device = device))
            
        iters += 1
    save_chckpt()
    
    # save animation every 5 epochs
    if(epoch % 5 ==0):
        fig = plt.figure(figsize=(64,64))
        plt.axis("off")
        ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
        ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
        writervideo = animation.FFMpegWriter(fps=2)
        ani.save('Animation.mp4', writer=writervideo)
        wandb.save('./Animation.mp4')
        #HTML(ani.to_jshtml())