# **Anime CycleGAN. Implementation by Vasili Karol**

# Chapter 1 - preparing data


In [None]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.autograd import Variable

from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt

import numpy as np
from numpy.random import uniform as rand_noise
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm
import glob, itertools
from PIL import Image as PILImage

sns.set(style='darkgrid', font_scale=1.2)

***All work was done on the Kaggle platform. Links to the data and the networks' weights:***

https://www.kaggle.com/arnaud58/selfie2anime

https://www.kaggle.com/shadowedtomb/anigan

In [None]:
# Dataset root; ImageDataset globs the "<mode>A" and "<mode>B" sub-folders beneath it.
Path = '../input/selfie2anime'

In [None]:
# Per-channel statistics handed to Normalize below.
mean = (0.5,) * 3
std = (0.5,) * 3

# Preprocessing applied to every loaded image: tensor conversion followed by
# per-channel normalisation. No resizing happens here (an optional Resize step
# was left disabled), so images keep their on-disk resolution.
transforms = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(mean=mean, std=std),
])

**ImageDataset** is inspired by Balraj Ashwath's work:

https://www.kaggle.com/balraj98/cyclegan-translating-paintings-photos-pytorch

In [None]:
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Files are collected from ``<root>/<mode>A`` and ``<root>/<mode>B`` (sorted
    by path). Item ``index`` pairs the ``index``-th file of each folder, with
    the index wrapping around (modulo) in the shorter folder, so the dataset
    length is the size of the larger folder.

    Args:
        root: Directory that contains the ``<mode>A`` / ``<mode>B`` folders.
        transforms_: Callable applied to each RGB PIL image.
        mode: Folder-name prefix, e.g. ``"train"`` or ``"test"``.
    """

    def __init__(self, root, transforms_=transforms, mode="train"):
        self.transform = transforms_

        self.files_A = sorted(glob.glob(os.path.join(root, f"{mode}A") + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, f"{mode}B") + "/*.*"))

    @staticmethod
    def _load_rgb(path):
        """Return the image at *path* as an RGB PIL image, closing the file."""
        # convert() reads the pixel data and returns a new image, so the file
        # handle can be released as soon as the ``with`` block ends. It also
        # turns grayscale / palette / RGBA inputs into 3-channel RGB.
        with PILImage.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        """Return ``{"A": tensor, "B": tensor}`` for the given index."""
        image_A = self._load_rgb(self.files_A[index % len(self.files_A)])
        image_B = self._load_rgb(self.files_B[index % len(self.files_B)])

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [None]:
# Loaders over the "test" and "train" splits. batch_size is not passed, so the
# DataLoader default applies; samples are reshuffled on every pass.
TestData = DataLoader(ImageDataset(Path, mode='test'), num_workers=2, shuffle=True, pin_memory=True)   
ImgDataset = DataLoader(ImageDataset(Path), num_workers=2, shuffle=True, pin_memory=True)   

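**Ids in the Dataset: 0/2 - Monet test/train, 1/3 - Photo test/train** *(note: this line mentions Monet/Photo data and looks like a leftover from another notebook; the dataset in this notebook yields dicts with keys 'A' and 'B')*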

In [None]:
def show_grid(grid_images, grid_size, figsize=8):
    """Display images tiled into one picture on a square, tick-less figure.

    Args:
        grid_images: Sequence of image tensors accepted by ``make_grid``.
        grid_size: Grid dimensions; the product of its entries caps how many
            images are shown and ``grid_size[1]`` is the number per row.
        figsize: Side length of the square matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(figsize, figsize))
    for hide_ticks in (ax.set_xticks, ax.set_yticks):
        hide_ticks([])

    max_images = np.prod(grid_size)
    tiled = make_grid(grid_images[:max_images], nrow=grid_size[1], normalize=True)
    # make_grid returns a channel-first image; imshow wants the channel axis last.
    ax.imshow(tiled.permute(1, 2, 0))
    plt.show()

# Preview a few training pairs: one row of domain-A images above one row of
# domain-B images.
n_examples = 8
Anime_ex = []
Real_ex = []
for img in itertools.islice(ImgDataset, n_examples):
    # The loader yields batches of one image, so dropping the batch axis leaves
    # a single (C, H, W) tensor whatever the image resolution is.
    Anime_ex.append(img['B'].squeeze(0))
    Real_ex.append(img['A'].squeeze(0))
if Real_ex:
    # Still show the grid when the loader holds fewer than n_examples batches.
    show_grid(Real_ex + Anime_ex, [2, len(Real_ex)], 15)
del Anime_ex
del Real_ex

# Chapter 2 - building networks


In [None]:
# Device the networks are moved to.
# NOTE(review): other cells call .cuda() / torch.cuda.FloatTensor directly, so
# the notebook needs a CUDA GPU regardless of this setting.
device = torch.device('cuda')

The networks are implemented following the official paper:

https://arxiv.org/pdf/1703.10593.pdf

In [None]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: a 7x7 stem plus two stride-2 convolutions (3 -> 64 -> 128 -> 256
    channels), nine residual blocks at 256 channels, then two stride-2
    transposed convolutions and a 7x7 output layer followed by ``Tanh``. The
    output has the same spatial size as the input when that size is divisible
    by 4.

    Every residual block owns its own layers. Previously all nine blocks were
    built from one shared list of layer instances and therefore shared a
    single set of weights. ``state_dict`` keys are unchanged
    (``Res_block1`` .. ``Res_block9``), so older model checkpoints still load.
    NOTE: optimizer state saved before this fix covers fewer parameters and
    will not match an optimizer built over this model.
    """

    # 9 blocks are used for 256*256 inputs (the paper uses 6 for 128*128).
    N_RES_BLOCKS = 9

    @staticmethod
    def _make_res_block():
        """Build the body of one residual block (the skip is added in forward)."""
        return nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(256),
        )

    def __init__(self):
        super().__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Registered under the original attribute names so that checkpoint keys
        # stay the same; each call creates fresh, independent layers.
        for block_no in range(1, self.N_RES_BLOCKS + 1):
            setattr(self, f"Res_block{block_no}", self._make_res_block())

        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 3, kernel_size=7, stride=1, padding=3, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate a batch of shape (N, 3, H, W); the result is Tanh-bounded."""
        x = self.downsample(x)

        for block_no in range(1, self.N_RES_BLOCKS + 1):
            # Residual connection: add the block's output to its input.
            x = x + getattr(self, f"Res_block{block_no}")(x)

        x = self.upsample(x)
        return x

In [None]:
# Two generators, one per translation direction. In the training loop
# Generator_x is fed the domain-A ("Real") images and Generator_z the
# domain-B ("Anime") images.
Generator_x = Generator().to(device)
Generator_z = Generator().to(device)

The GaussianNoise code is taken from:

https://github.com/ShivamShrirao/facegan_pytorch/blob/main/facegan_pytorch.ipynb


In [None]:
class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise to the input while in training mode.

    Args:
        std: Standard deviation of the noise.
        decay_rate: Amount subtracted from ``std`` by each ``decay_step`` call.
    """

    def __init__(self, std=0.1, decay_rate=0):
        super().__init__()
        self.decay_rate = decay_rate
        self.std = std

    def decay_step(self):
        """Lower ``std`` by ``decay_rate``, never letting it drop below zero."""
        self.std = max(self.std - self.decay_rate, 0)

    def forward(self, x):
        # Outside training mode the input is passed through untouched.
        if not self.training:
            return x
        noise = torch.empty_like(x).normal_(std=self.std)
        return x + noise

In [None]:
#70*70 PatchGAN architecture
class Discriminator(nn.Module):
    """Convolutional patch discriminator.

    Five 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 -> 1 channels), each
    preceded by a ``GaussianNoise`` layer. The first four are followed by
    instance normalisation and LeakyReLU(0.2). The output is a one-channel map
    of raw scores; no sigmoid is applied.

    ``InstanceNorm2d`` is now given the channel count of the convolution it
    follows (it was 64 everywhere). These layers hold no parameters, so the
    layer indices and ``state_dict`` keys are unchanged.
    """

    def __init__(self):
        super().__init__()
        self.downsample = nn.Sequential(
            GaussianNoise(),
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            GaussianNoise(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            GaussianNoise(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            GaussianNoise(),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            GaussianNoise(),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=False),
            #nn.Sigmoid()
        )

    def forward(self, x):
        """Return the score map for a batch of shape (N, 3, H, W)."""
        return self.downsample(x)

In [None]:
# Two discriminators. In the training loop Discriminator_x scores Generator_x
# outputs against "Anime" images, and Discriminator_z scores Generator_z
# outputs against "Real" images.
Discriminator_x = Discriminator().to(device)
Discriminator_z = Discriminator().to(device)

The **weights_init** function and the idea behind it come from the official PyTorch DCGAN tutorial:

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [None]:
def weights_init(m):
    """Re-draw a module's weights from N(0, 0.02) if it is a convolution.

    Meant to be passed to ``nn.Module.apply``. A module is selected when its
    class name contains ``'Conv'``; every other module is left untouched.
    """
    if 'Conv' in type(m).__name__:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        
# Initialise the convolution weights of all four networks with weights_init.
Generator_x.apply(weights_init)   
Generator_z.apply(weights_init)   
Discriminator_x.apply(weights_init)
Discriminator_z.apply(weights_init)

In [None]:
# Loss modules: mean-squared error for the adversarial terms, L1 for the
# cycle-consistency and identity terms.
Adversarial = nn.MSELoss().cuda()
Consistency = nn.L1Loss().cuda()
Identify = nn.L1Loss().cuda()

# Lookup table used by the training loop.
criterion = {
    # NOTE(review): the original note here read "*0.5 for discriminators", but
    # the training loop applies no 0.5 factor to the discriminator losses.
    'Adversarial' : Adversarial, # compares D(real/fake) with target labels 1/0
    
    # multiplied by Consistency_a (10) in the training loop
    'Consistency' : Consistency, # compares G_z(G_x(x)) with x
    
    'Identify' : Identify # the training loop compares G_x(x) with x
}

In [None]:
lr = 1e-04

# One Adam optimiser per network, all sharing the same hyper-parameters.
optims = {
    name: torch.optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for name, net in (
        ('generator_x', Generator_x),
        ('generator_z', Generator_z),
        ('discriminator_x', Discriminator_x),
        ('discriminator_z', Discriminator_z),
    )
}

# One StepLR schedule per optimiser: the learning rate is halved after every
# 5 scheduler steps.
scheduler = {
    name: lr_scheduler.StepLR(optims[name], step_size=5, gamma=0.5)
    for name in ("discriminator_x", "discriminator_z", "generator_x", "generator_z")
}

# Chapter 3 - train loop


In [None]:
# Number of epochs the training loop runs for.
Epochs = 200
# Multiplier applied to the cycle-consistency loss in the training loop.
Consistency_a = 10
# Per-epoch mean losses, appended to by the training loop and plotted afterwards.
# Keys: gx/gz for the generators, dx/dz for the discriminators.
history = {'gx':[],
           'gz':[],
           'dx':[],
           'dz':[]
          }

In [None]:
PATH_load = './AnimeGAN+_100_params'
LOAD = True

# When enabled, resume training: restore the weights of all four networks and
# the state of their optimisers from the checkpoint file.
if LOAD:
    checkpoint = torch.load(PATH_load)

    for ckpt_key, net in (
        ('dz', Discriminator_z),
        ('gz', Generator_z),
        ('dx', Discriminator_x),
        ('gx', Generator_x),
    ):
        net.load_state_dict(checkpoint[ckpt_key])

    for optim_name, ckpt_key in (
        ('generator_x', 'optim_gx'),
        ('generator_z', 'optim_gz'),
        ('discriminator_x', 'optim_dx'),
        ('discriminator_z', 'optim_dz'),
    ):
        optims[optim_name].load_state_dict(checkpoint[ckpt_key])

In [None]:
# GPU float tensor type; the training loop uses it to cast image batches and to
# build the label tensors.
Tensor = torch.cuda.FloatTensor 

In [None]:
# Pools of previously generated images, one per generator; the training loop
# draws the discriminators' "fake" samples from them.
# NOTE(review): the training loop uses Buffer_x / Buffer_z unconditionally, so
# setting BUFFER to False would leave them undefined there.
BUFFER = True
if BUFFER:
    Buffer_x = [] 
    Buffer_z = [] 

The **output mechanism** is based on information from this tqdm issue:

https://github.com/tqdm/tqdm/issues/818

In [None]:
#training
# Each iteration handles one "Real" (A) and one "Anime" (B) batch. For each
# domain the matching generator gets three separate optimiser steps
# (adversarial, cycle consistency, identity), then both discriminators are
# updated. Per-epoch mean losses are stored in `history`, and a preview of both
# translations is redrawn after every epoch.
from IPython import display
from ipywidgets import Output

# Output widget that receives the per-epoch report and preview so they can be
# redrawn in place.
out = Output()
display.display(out)

Discriminator_x.train()
Discriminator_z.train()
Generator_x.train()
Generator_z.train()

for epoch in tqdm(range(1, Epochs+1)):

    # The StepLR schedules are only advanced once epoch 100 has passed.
    if epoch > 100:
        scheduler['generator_x'].step()
        scheduler['generator_z'].step()
        scheduler['discriminator_x'].step()
        scheduler['discriminator_z'].step()

    history_per_epoch = {'gx':[], 'gz':[], 'dx':[], 'dz':[]}

    for Image in ImgDataset:
        Anime = Image['B'].type(Tensor)
        Real = Image['A'].type(Tensor)

        ##  Real images  ##
        # Adversarial step: Generator_x tries to make Discriminator_x output ones.
        optims['generator_x'].zero_grad()
        Z_fake = Generator_x(Real)
        # Pool entries are stored detached so that they do not keep the
        # generator's autograd graph alive between iterations.
        if len(Buffer_x) < 50:
            Buffer_x.append(Z_fake.detach())
        Buffer_id = random.randint(0, len(Buffer_x)-1)
        Z_fake_label = Discriminator_x(Z_fake)
        # Target maps with the same shape, dtype and device as the discriminator output.
        Zeros = torch.zeros_like(Z_fake_label)
        Ones = torch.ones_like(Z_fake_label)
        gx_loss = criterion['Adversarial'](Z_fake_label, Ones)
        gx_loss.backward()
        optims['generator_x'].step()

        # Cycle consistency: Real -> Generator_x -> Generator_z should give Real back.
        optims['generator_x'].zero_grad()
        optims['generator_z'].zero_grad()
        Regen = Generator_z(Generator_x(Real))
        cons_loss = Consistency_a*criterion['Consistency'](Regen, Real)
        cons_loss.backward()
        optims['generator_x'].step()
        optims['generator_z'].step()

        # Identity term.
        # NOTE(review): this pulls Generator_x(Real) towards Real itself. The
        # CycleGAN paper's identity loss instead feeds the generator a sample of
        # its *target* domain. Left unchanged; confirm that this is intended.
        optims['generator_x'].zero_grad()
        ident_loss = criterion['Identify'](Generator_x(Real), Real)
        ident_loss.backward()
        optims['generator_x'].step()

        # Discriminator_x on a pooled fake. One random scalar is added to the
        # whole target map (noisy labels).
        optims['discriminator_x'].zero_grad()
        Fake_ = Buffer_x[Buffer_id]
        dx_loss = criterion['Adversarial'](Discriminator_x(Fake_.detach()), Zeros + rand_noise(0, 0.3))
        # The pooled image that was just used is replaced by the newest fake.
        Buffer_x[Buffer_id] = Z_fake.detach()
        dx_loss.backward()
        optims['discriminator_x'].step()

        # Discriminator_z on a real "Real" image.
        optims['discriminator_z'].zero_grad()
        dz_loss = criterion['Adversarial'](Discriminator_z(Real), Ones + rand_noise(-0.2, 0.2))
        dz_loss.backward()
        optims['discriminator_z'].step()

        history_per_epoch['gx'].append(gx_loss.item())
        history_per_epoch['dx'].append(dx_loss.item())
        history_per_epoch['dz'].append(dz_loss.item())

        ##  Anime images  ##
        # Adversarial step: Generator_z tries to make Discriminator_z output ones.
        optims['generator_z'].zero_grad()
        X_fake = Generator_z(Anime)
        if len(Buffer_z) < 50:
            Buffer_z.append(X_fake.detach())
        Buffer_id = random.randint(0, len(Buffer_z)-1)
        X_fake_label = Discriminator_z(X_fake)
        Zeros = torch.zeros_like(X_fake_label)
        Ones = torch.ones_like(X_fake_label)
        gz_loss = criterion['Adversarial'](X_fake_label, Ones)
        gz_loss.backward()
        optims['generator_z'].step()

        # Cycle consistency: Anime -> Generator_z -> Generator_x should give Anime back.
        optims['generator_z'].zero_grad()
        optims['generator_x'].zero_grad()
        Regen = Generator_x(Generator_z(Anime))
        cons_loss = Consistency_a*criterion['Consistency'](Regen, Anime)
        cons_loss.backward()
        optims['generator_z'].step()
        optims['generator_x'].step()

        # Identity term (see the NOTE(review) in the "Real images" part).
        optims['generator_z'].zero_grad()
        ident_loss = criterion['Identify'](Generator_z(Anime), Anime)
        ident_loss.backward()
        optims['generator_z'].step()

        # Discriminator_x on a real "Anime" image.
        optims['discriminator_x'].zero_grad()
        dx_loss = criterion['Adversarial'](Discriminator_x(Anime), Ones + rand_noise(-0.2, 0.2))
        dx_loss.backward()
        optims['discriminator_x'].step()

        # Discriminator_z on a pooled fake.
        optims['discriminator_z'].zero_grad()
        Fake_ = Buffer_z[Buffer_id]
        dz_loss = criterion['Adversarial'](Discriminator_z(Fake_.detach()), Zeros + rand_noise(0, 0.3))
        Buffer_z[Buffer_id] = X_fake.detach()
        dz_loss.backward()
        optims['discriminator_z'].step()

        history_per_epoch['gz'].append(gz_loss.item())
        history_per_epoch['dx'].append(dx_loss.item())
        history_per_epoch['dz'].append(dz_loss.item())

    history['gx'].append(np.mean(history_per_epoch['gx']))
    history['gz'].append(np.mean(history_per_epoch['gz']))
    history['dx'].append(np.mean(history_per_epoch['dx']))
    history['dz'].append(np.mean(history_per_epoch['dz']))

    # Per-epoch report: losses plus one translated sample in each direction.
    with torch.no_grad(), out:
        display.clear_output(wait=True)
        # Column labels follow the grid order below: Generator_x output, then
        # Generator_z output.
        print(f"""Epoch: {epoch}/{Epochs}  Discr Loss: {history['dx'][-1]},{history['dz'][-1]}   
              Gen Loss: {history['gx'][-1]}, {history['gz'][-1]} \n Real2Anime \t\t\t\t Anime2Real""")
        Image_test = next(iter(ImgDataset))
        Z_fake = Generator_x(Image_test['A'].cuda())
        X_fake = Generator_z(Image_test['B'].cuda())
        # squeeze(0) drops the batch axis of the single-image batch without
        # hard-coding the image resolution.
        show_grid([Z_fake.squeeze(0).cpu(), X_fake.squeeze(0).cpu()], [1, 2])
        del Z_fake
        del X_fake

# Chapter 4 - analysing the results and saving the model parameters


In [None]:
# Plot the per-epoch mean loss of each of the four networks on one figure.
figure = plt.figure(figsize=(12, 7))
for key, label in (
    ('dx', 'Discriminator_x'),
    ('gx', 'Generator_x'),
    ('dz', 'Discriminator_z'),
    ('gz', 'Generator_z'),
):
    plt.plot(history[key], label=label)
plt.legend()
plt.show()

In [None]:
PATH_save = 'AnimeGAN+_100_params'

# Write one checkpoint file holding the weights of all four networks and the
# state of their optimisers.
torch.save(
    {
        **{
            ckpt_key: net.state_dict()
            for ckpt_key, net in (
                ('dx', Discriminator_x),
                ('dz', Discriminator_z),
                ('gx', Generator_x),
                ('gz', Generator_z),
            )
        },
        **{
            ckpt_key: optims[optim_name].state_dict()
            for ckpt_key, optim_name in (
                ('optim_gx', 'generator_x'),
                ('optim_gz', 'generator_z'),
                ('optim_dx', 'discriminator_x'),
                ('optim_dz', 'discriminator_z'),
            )
        },
    },
    PATH_save,
)


<a href='AnimeGAN_30_params'> AnimeGAN_30_params </a>

In [None]:
# Load the generator weights of an earlier checkpoint into two fresh Generator
# instances, so that their outputs can be compared with the current ones.
PATH_load = '../input/anigan/AnimeGAN_70_params'
checkpoint = torch.load(PATH_load)

Generator_old_x, Generator_old_z = Generator().to(device), Generator().to(device)
for old_net, ckpt_key in ((Generator_old_z, 'gz'), (Generator_old_x, 'gx')):
    old_net.load_state_dict(checkpoint[ckpt_key])
    

In [None]:
# Compare the old and the current Generator_z on a few test images from
# domain B. Grid rows: inputs, old generator outputs, current generator outputs.
n_samples = 6
Real = []       # untouched input images (img['B'])
Fakes = []      # outputs of the current Generator_z
Old_Fakes = []  # outputs of the checkpointed Generator_old_z
# Switch the generators that are actually used below to eval mode (eval() was
# previously called on the *_x generators, which this cell never runs).
Generator_z.eval()
Generator_old_z.eval()
with torch.no_grad():
    for img in itertools.islice(TestData, n_samples):
        # squeeze(0) drops the batch axis of the single-image batch without
        # hard-coding the image resolution.
        Real.append(img['B'].squeeze(0))
        Old_Fakes.append(Generator_old_z(img['B'].cuda()).squeeze(0).cpu())
        Fakes.append(Generator_z(img['B'].cuda()).squeeze(0).cpu())
if Real:
    # One column per collected sample keeps the three rows aligned even when
    # the loader holds fewer than n_samples batches.
    show_grid(Real + Old_Fakes + Fakes, [3, len(Real)], 18)


        