# **Monet2Photo CycleGAN. Demo Implementation by Vasili Karol**

# Chapter 1 - preparing data


In [None]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt

import numpy as np
from numpy.random import uniform as rand_noise
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random
from tqdm.notebook import tqdm

# Global seaborn plotting style applied to every matplotlib figure drawn below.
sns.set(style='darkgrid', font_scale=1.2)

In [None]:
# Root directory of the monet2photo dataset; each sub-directory is read as one class by ImageFolder.
Path = '../input/monet2photo'

In [None]:
# Every image is resized to image_size x image_size, turned into a tensor and
# normalised per channel with mean 0.5 / std 0.5 (ToTensor yields [0, 1], so
# pixels end up in [-1, 1], the same range as the generator's Tanh output).
image_size = 256
mean = (0.5,) * 3
std = (0.5,) * 3
transforms = tt.Compose([
    tt.Resize([image_size] * 2),
    tt.ToTensor(),
    tt.Normalize(mean=mean, std=std),
])

images = ImageFolder(root=Path, transform=transforms)

**Class ids in the Dataset: 0/2 – Monet test/train, 1/3 – Photo test/train**

In [None]:
# Shuffled loader over the whole ImageFolder. batch_size is left at its default
# of 1: the training loop compares each single label with `Id == 2` / `Id == 3`.
Dataset = DataLoader(dataset=images, shuffle=True, num_workers=2, pin_memory=True)

In [None]:
# Balance the Photo and Monet classes down to the size of the smaller one.
def custom_count(list, x, dim=None):
    """Count the elements of an iterable that match ``x``.

    Args:
        list: Iterable to scan. The name shadows the builtin but is kept so
            existing keyword callers keep working.
        x: Value to look for.
        dim: If not ``None``, an element ``a`` matches when ``a[dim] == x``.
            ``dim=0`` is honoured; previously it was treated like ``None``.
            If ``None``, an element matches when it equals ``x`` or, for
            container elements, contains ``x``.

    Returns:
        The number of matching elements.
    """
    if dim is not None:
        return sum(1 for a in list if a[dim] == x)

    n = 0
    for a in list:
        if a == x:
            n += 1
            continue
        try:
            if x in a:
                n += 1
        except TypeError:
            # Non-container element (e.g. an int that differs from x): it can
            # only match by equality, which was already checked above.
            pass
    return n

# Count the training classes from the (path, label) pairs rather than from the
# dataset itself: indexing the ImageFolder would load and transform every image
# just to read its label.
Monet_len = custom_count(Dataset.dataset.imgs, 2, dim=1)
Photo_len = custom_count(Dataset.dataset.imgs, 3, dim=1)

# Drop the first (Photo_len - Monet_len) photo samples (label 3); nothing is
# dropped when photos are not the larger class. The list is rebuilt in one pass
# instead of removing entries while iterating over it (which skips elements),
# then written back with slice assignment so the same list object is mutated:
# torchvision's ImageFolder aliases `imgs` to the `samples` list it indexes.
excess = Photo_len - Monet_len
limit = 0
kept_samples = []
for sample in Dataset.dataset.imgs:
    if sample[1] == 3 and limit < excess:
        limit += 1
        continue
    kept_samples.append(sample)
Dataset.dataset.imgs[:] = kept_samples

In [None]:
# Notebook cell: echo the underlying dataset object so its repr is displayed after balancing.
Dataset.dataset

In [None]:
# Collect the test images as transformed tensors (label 0 = Monet test,
# label 1 = Photo test). Only the test entries are loaded; labels are read from
# the (path, label) pairs so training images are not decoded here.
Test_monet = []
Test_photo = []
for index, (_, label) in enumerate(Dataset.dataset.imgs):
    if label == 0:
        Test_monet.append(Dataset.dataset[index][0])
    elif label == 1:
        Test_photo.append(Dataset.dataset[index][0])

# Remove every test sample from the training dataset in a single pass. The
# number removed is derived from the data instead of assuming a fixed test-set
# size, so the cell always terminates. Slice assignment mutates the same list
# object the dataset indexes.
train_samples = [sample for sample in Dataset.dataset.imgs if sample[1] not in (0, 1)]
limit_test = len(Dataset.dataset.imgs) - len(train_samples)
Dataset.dataset.imgs[:] = train_samples

Test_monet = DataLoader(Test_monet, num_workers=2, shuffle=True, pin_memory=True)   
Test_photo = DataLoader(Test_photo, num_workers=2, shuffle=True, pin_memory=True)   

In [None]:
def show_grid(grid_images, grid_size, figsize=8):
    """Display images tiled in a grid, without axis ticks.

    Args:
        grid_images: Sequence of image tensors accepted by ``make_grid``; only
            the first ``prod(grid_size)`` are shown (normalised for display).
        grid_size: ``[rows, cols]``; ``cols`` is passed as ``nrow`` (images per row).
        figsize: Side length of the square matplotlib figure, in inches.
    """
    shown = grid_images[:np.prod(grid_size)]
    tiled = make_grid(shown, nrow=grid_size[1], normalize=True)
    # make_grid returns channels-first; imshow wants channels-last.
    channels_last = torch.moveaxis(tiled, 0, -1)

    _, axis = plt.subplots(figsize=(figsize, figsize))
    axis.set_xticks([])
    axis.set_yticks([])
    axis.imshow(channels_last)
    plt.show()

# Preview 8 Monet paintings (label 2) and 8 photos (label 3) from the training set.
# The label variable is named `label` so the builtin `id` is not shadowed for the
# rest of the notebook; the list lengths themselves serve as the counters.
list_demonstration = [[], []]
for Image, label in Dataset:
    if label == 2 and len(list_demonstration[0]) < 8:
        # Drop the leading batch dimension of the size-1 batch.
        list_demonstration[0].append(Image.reshape(Image.size()[1:]))
    if label == 3 and len(list_demonstration[1]) < 8:
        list_demonstration[1].append(Image.reshape(Image.size()[1:]))
    if len(list_demonstration[0]) >= 8 and len(list_demonstration[1]) >= 8:
        show_grid(list_demonstration[0] + list_demonstration[1], [2, 8], 20)
        break


In [None]:
# Target device for all networks; the rest of the notebook assumes a CUDA GPU.
device = torch.device(type='cuda')

# Chapter 2 - building networks

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator mapping a 3-channel image to a 3-channel image.

    Layout: a 7x7 convolution and two stride-2 convolutions (64 -> 128 -> 256
    channels, spatial size divided by 4), ``n_res_blocks`` residual blocks at
    256 channels, then two stride-2 transposed convolutions and a final 7x7
    transposed convolution back to 3 channels. The closing Tanh bounds the
    output to [-1, 1].

    Args:
        n_res_blocks: Number of residual blocks: 9 for 256x256 images, 6 for
            128x128. Defaults to 9, the original architecture.
    """

    def __init__(self, n_res_blocks=9):
        super(Generator, self).__init__()
        self.n_res_blocks = n_res_blocks
        self.downsample = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Every residual block gets its own freshly built layers. Wrapping one
        # shared list of layer instances in several nn.Sequential containers
        # would make all blocks share the same weights. The attribute names
        # Res_block1..Res_blockN keep the state_dict keys identical to earlier
        # checkpoints, so saved weights still load.
        for i in range(1, n_res_blocks + 1):
            setattr(self, f'Res_block{i}', self._make_res_block())

        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 3, kernel_size=7, stride=1, padding=3, bias=False),
            nn.Tanh()
        )

    @staticmethod
    def _make_res_block():
        """Return a new residual branch: conv-norm-ReLU-conv-norm at 256 channels."""
        return nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(256),
        )

    def forward(self, x):
        """Translate a batch of images.

        For height and width divisible by 4, the output has the same shape as ``x``.
        """
        x = self.downsample(x)
        for i in range(1, self.n_res_blocks + 1):
            # Residual connection: add the block's output to its input.
            x = x + getattr(self, f'Res_block{i}')(x)
        return self.upsample(x)

In [None]:
# Two independent generators moved to `device`. In the training loop
# Generator_x turns Monet paintings into photos and Generator_z photos into paintings.
Generator_x, Generator_z = (Generator().to(device) for _ in range(2))

In [None]:
#70*70 PatchGAN architecture
class Discriminator(nn.Module):
    """PatchGAN discriminator that scores image patches as real or fake.

    Four 4x4 convolutions (64 -> 128 -> 256 -> 512 channels, the first three
    with stride 2) are followed by a 4x4 convolution to a single channel. The
    output is a map of unbounded scores; the sigmoid is deliberately omitted
    because the notebook trains it with an MSE (least-squares) loss.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Each InstanceNorm2d now declares the channel count of the conv before
        # it. With affine=False the value is not used in the computation, but a
        # mismatch misdescribes the layer and makes PyTorch emit a warning. The
        # norm layers hold no parameters, so state_dict keys are unchanged.
        self.downsample = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=False),
        )

    def forward(self, x):
        """Return a (batch, 1, h, w) map of patch scores for images ``x``."""
        return self.downsample(x)

In [None]:
# Two independent discriminators on `device`: Discriminator_x scores photos and
# Discriminator_z scores Monet paintings (as used in the training loop).
Discriminator_x, Discriminator_z = (Discriminator().to(device) for _ in range(2))

In [None]:
# Initialisation scheme taken from the PyTorch DCGAN tutorial:
# https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
def weights_init(m):
    """Redraw the weights of convolution layers from N(0, 0.02).

    Meant to be passed to ``nn.Module.apply``. Any module whose class name
    contains ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``) gets a new
    ``weight``; all other modules, and all biases, are left untouched.
    """
    if 'Conv' not in type(m).__name__:
        return
    nn.init.normal_(m.weight.data, 0.0, 0.02)
        
# Re-initialise every convolution weight of the four networks, in the same
# order as before: Generator_x, Generator_z, Discriminator_x, Discriminator_z.
for network in (Generator_x, Generator_z, Discriminator_x, Discriminator_z):
    network.apply(weights_init)

In [None]:
# Loss functions used in training:
#   Adversarial - MSE (least-squares GAN) between discriminator scores and real/fake targets
#   Consistency - L1 cycle-reconstruction loss (scaled by Consistency_a in the loop)
#   Identify    - L1 identity loss between a generator's output and its input
Adversarial, Consistency, Identify = (
    nn.MSELoss().cuda(),
    nn.L1Loss().cuda(),
    nn.L1Loss().cuda(),
)

criterion = dict(
    Adversarial=Adversarial,
    Consistency=Consistency,
    Identify=Identify,
)

In [None]:
# One Adam optimiser per network (lr 2e-4, betas (0.5, 0.999)), each paired with
# a StepLR schedule that multiplies its learning rate by 0.8 every 10 scheduler steps.
optims = {
    name: torch.optim.Adam(network.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for name, network in (
        ('generator_x', Generator_x),
        ('generator_z', Generator_z),
        ('discriminator_x', Discriminator_x),
        ('discriminator_z', Discriminator_z),
    )
}

scheduler = {
    name: lr_scheduler.StepLR(optims[name], step_size=10, gamma=0.8)
    for name in ('discriminator_x', 'discriminator_z', 'generator_x', 'generator_z')
}

# Chapter 3 - training loop

In [None]:
Epochs = 100
Consistency_a = 10  # weight applied to the cycle-consistency loss
# Per-epoch mean losses of the generators (gx, gz) and discriminators (dx, dz).
history = {key: [] for key in ('gx', 'gz', 'dx', 'dz')}

In [None]:
# Optional: resume from previously saved weights (set LOAD = True to enable).
# torch.load unpickles the file, so only point these paths at trusted checkpoints.
Pathdx = '../input/cyclegan-weights/CycleGAN_50_dx'
Pathgx = '../input/cyclegan-weights/CycleGAN_50_gx'
Pathdz = '../input/cyclegan-weights/CycleGAN_50_dz'
Pathgz = '../input/cyclegan-weights/CycleGAN_50_gz'
LOAD = False

if LOAD:
    for network, weights_path in (
        (Discriminator_z, Pathdz),
        (Generator_z, Pathgz),
        (Discriminator_x, Pathdx),
        (Generator_x, Pathgx),
    ):
        network.load_state_dict(torch.load(weights_path))

In [None]:
# Legacy tensor-type alias for float32 CUDA tensors; used as `Image.type(Tensor)`
# and `Tensor(...)` in the training loop.
Tensor = torch.cuda.FloatTensor 

In [None]:
BUFFER = True
# Pools of previously generated fakes that the training loop samples from when
# updating the discriminators; both start empty and grow to at most 50 entries.
if BUFFER:
    Buffer_x, Buffer_z = [], []


In [None]:
#training
# Live-updating output widget (approach adapted from a Stack Overflow answer).
from IPython import display
from ipywidgets import Output
from numpy.random import uniform as rand_noise
out = Output()
display.display(out)

Discriminator_x.train()
Discriminator_z.train()
Generator_x.train()
Generator_z.train()

# Label 2 = Monet painting (train split), label 3 = photo (train split).
# Generator_x: Monet -> photo, scored by Discriminator_x (photo domain).
# Generator_z: photo -> Monet, scored by Discriminator_z (Monet domain).
for epoch in tqdm(range(1, Epochs+1)):
    history_per_epoch = {'gx':[], 'gz':[], 'dx':[], 'dz':[]}

    for Image, Id in Dataset:
        Image = Image.type(Tensor)
        if Id == 2:
            # ---- Monet painting ----
            # Adversarial: Generator_x wants Discriminator_x to score its fake photo as real.
            optims['generator_x'].zero_grad()
            Z_fake = Generator_x(Image)
            # Pool of past fakes for the discriminator. Detached copies are
            # stored so the pool does not keep autograd graphs alive.
            if len(Buffer_x) < 50:
                Buffer_x.append(Z_fake.detach())
            Buffer_id = random.randint(0, len(Buffer_x)-1)
            Z_fake_label = Discriminator_x(Z_fake)
            Zeros = torch.zeros_like(Z_fake_label)
            Ones = torch.ones_like(Z_fake_label)
            gx_loss = criterion['Adversarial'](Z_fake_label, Ones)
            gx_loss.backward()
            optims['generator_x'].step()

            # Cycle consistency: Monet -> photo -> Monet must reconstruct the painting.
            optims['generator_x'].zero_grad()
            optims['generator_z'].zero_grad()
            Regen = Generator_z(Generator_x(Image))
            cons_loss = Consistency_a*criterion['Consistency'](Regen, Image)
            cons_loss.backward()
            optims['generator_x'].step()
            optims['generator_z'].step()

            # Identity: a Monet painting is already in Generator_z's OUTPUT domain,
            # so Generator_z should return it unchanged. (Applying Generator_x
            # to its own input domain would fight the translation objective.)
            optims['generator_z'].zero_grad()
            ident_loss = criterion['Identify'](Generator_z(Image), Image)
            ident_loss.backward()
            optims['generator_z'].step()

            # Discriminator_x on a pooled fake photo; "fake" target is 0 plus uniform noise.
            optims['discriminator_x'].zero_grad()
            Fake_ = Buffer_x[Buffer_id]
            dx_loss = criterion['Adversarial'](Discriminator_x(Fake_.detach()), Zeros + rand_noise(0, 0.3))
            Buffer_x[Buffer_id] = Z_fake.detach()
            dx_loss.backward()
            optims['discriminator_x'].step()

            # Discriminator_z on the real Monet painting; "real" target is 1 plus uniform noise.
            optims['discriminator_z'].zero_grad()
            dz_loss = criterion['Adversarial'](Discriminator_z(Image), Ones + rand_noise(-0.2, 0.2))
            dz_loss.backward()
            optims['discriminator_z'].step()

            history_per_epoch['gx'].append(gx_loss.item())
            history_per_epoch['dx'].append(dx_loss.item())
            history_per_epoch['dz'].append(dz_loss.item())

        elif Id == 3:
            # ---- Real photo ----
            # Adversarial: Generator_z wants Discriminator_z to score its fake painting as real.
            optims['generator_z'].zero_grad()
            X_fake = Generator_z(Image)
            if len(Buffer_z) < 50:
                Buffer_z.append(X_fake.detach())
            Buffer_id = random.randint(0, len(Buffer_z)-1)
            X_fake_label = Discriminator_z(X_fake)
            Zeros = torch.zeros_like(X_fake_label)
            Ones = torch.ones_like(X_fake_label)
            gz_loss = criterion['Adversarial'](X_fake_label, Ones)
            gz_loss.backward()
            optims['generator_z'].step()

            # Cycle consistency: photo -> Monet -> photo must reconstruct the photo.
            optims['generator_z'].zero_grad()
            optims['generator_x'].zero_grad()
            Regen = Generator_x(Generator_z(Image))
            cons_loss = Consistency_a*criterion['Consistency'](Regen, Image)
            cons_loss.backward()
            optims['generator_z'].step()
            optims['generator_x'].step()

            # Identity: a photo is already in Generator_x's OUTPUT domain,
            # so Generator_x should return it unchanged.
            optims['generator_x'].zero_grad()
            ident_loss = criterion['Identify'](Generator_x(Image), Image)
            ident_loss.backward()
            optims['generator_x'].step()

            # Discriminator_x on the real photo; "real" target is 1 plus uniform noise.
            optims['discriminator_x'].zero_grad()
            dx_loss = criterion['Adversarial'](Discriminator_x(Image), Ones + rand_noise(-0.2, 0.2))
            dx_loss.backward()
            optims['discriminator_x'].step()

            # Discriminator_z on a pooled fake painting; "fake" target is 0 plus uniform noise.
            optims['discriminator_z'].zero_grad()
            Fake_ = Buffer_z[Buffer_id]
            dz_loss = criterion['Adversarial'](Discriminator_z(Fake_.detach()), Zeros + rand_noise(0, 0.3))
            Buffer_z[Buffer_id] = X_fake.detach()
            dz_loss.backward()
            optims['discriminator_z'].step()

            history_per_epoch['gz'].append(gz_loss.item())
            history_per_epoch['dx'].append(dx_loss.item())
            history_per_epoch['dz'].append(dz_loss.item())

    history['gx'].append(np.mean(history_per_epoch['gx']))
    history['gz'].append(np.mean(history_per_epoch['gz']))
    history['dx'].append(np.mean(history_per_epoch['dx']))
    history['dz'].append(np.mean(history_per_epoch['dz']))

    # Advance every StepLR schedule once per epoch; without this the
    # schedulers created earlier never change the learning rates.
    for sched in scheduler.values():
        sched.step()

    # Show the epoch's losses and one translated test image per direction.
    with torch.set_grad_enabled(False):
        with out:
            display.clear_output(wait=True)
            print(f"""Epoch: {epoch}/{Epochs}  Discr Loss: {history['dx'][-1]},{history['dz'][-1]}   
              Gen Loss: {history['gx'][-1]}, {history['gz'][-1]} \n Monet2Photo \t\t\t\t Photo2Monet""")
            Z_fake = Generator_x.forward(random.choice(Test_monet.dataset)[np.newaxis, :].cuda())
            X_fake = Generator_z.forward(random.choice(Test_photo.dataset)[np.newaxis, :].cuda())
            show_grid([Z_fake.view(3,256,256).cpu(), X_fake.view(3,256,256).cpu()], [1, 2])
            del Z_fake
            del X_fake
    


# Chapter 4 - analysing the results and saving the model parameters


In [None]:
# Plot the per-epoch mean loss of every network on one figure.
figure = plt.figure(figsize=(12, 7))
for history_key, curve_label in (
    ('dx', 'Discriminator_x'),
    ('gx', 'Generator_x'),
    ('dz', 'Discriminator_z'),
    ('gz', 'Generator_z'),
):
    plt.plot(history[history_key], label=curve_label)
plt.legend()
plt.show()

In [None]:
# Save the state_dict of every network to the working directory.
Pathdx = './CycleGAN++_82_dx'
Pathgx = './CycleGAN++_82_gx'
Pathdz = './CycleGAN++_82_dz'
Pathgz = './CycleGAN++_82_gz'
for network, target_path in (
    (Discriminator_x, Pathdx),
    (Generator_x, Pathgx),
    (Discriminator_z, Pathdz),
    (Generator_z, Pathgz),
):
    torch.save(network.state_dict(), target_path)

<a href="CycleGAN++_75_dx"> CycleGAN++_75_dx </a>

<a href="CycleGAN++_75_dz"> CycleGAN++_75_dz </a>

<a href="CycleGAN++_75_gx"> CycleGAN++_75_gx </a>

<a href="CycleGAN++_75_gz"> CycleGAN++_75_gz </a>

In [None]:
# Show up to 6 test photos (top row) above their Monet-style translations (bottom row).
Generator_z.eval()
n_samples = 6
real = []
fakes = []
for test_photo in Test_photo:
    with torch.no_grad():
        real.append(test_photo.view(3, 256, 256))
        translated = Generator_z(test_photo.cuda())
        fakes.append(translated.view(3, 256, 256).cpu())
    n_samples -= 1
    if n_samples <= 0:
        break
show_grid(real + fakes, [2, 6], 18)
        