## CycleGAN

In [1]:
import os
from glob import glob
import random
import itertools
import numpy as np
from PIL import Image
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
class ResBlock(nn.Module):
    """Residual block that preserves channel count and spatial size.

    Computes ``x + F(x)`` where ``F`` is
    ReflectionPad(1) -> Conv3x3 -> InstanceNorm -> ReLU ->
    ReflectionPad(1) -> Conv3x3 -> InstanceNorm.
    Because both convolutions map ``features -> features`` with stride 1 and
    one pixel of reflection padding, the skip connection needs no projection.

    Args:
        features: number of input (and output) channels.
    """

    def __init__(self, features):
        super(ResBlock, self).__init__()
        block_list = []
        self.block = self.make_block(block_list, features)

    def make_block(self, modules_list, features):
        """Append the residual branch layers to ``modules_list`` and wrap them in ``nn.Sequential``."""
        modules_list.append(nn.ReflectionPad2d(1))
        modules_list.append(nn.Conv2d(features, features, kernel_size=3, stride=1, bias=True))
        modules_list.append(self.select_normalization(norm='instance', features=features))
        modules_list.append(nn.ReLU(inplace=True))
        modules_list.append(nn.ReflectionPad2d(1))
        modules_list.append(nn.Conv2d(features, features, kernel_size=3, stride=1, bias=True))
        modules_list.append(self.select_normalization(norm='instance', features=features))
        return nn.Sequential(*modules_list)

    def select_normalization(self, norm, features):
        """Return a 2-D normalization layer for ``features`` channels.

        Args:
            norm: ``'batch'`` for ``BatchNorm2d`` or ``'instance'`` for ``InstanceNorm2d``.
            features: number of channels to normalize.

        Raises:
            ValueError: if ``norm`` is not one of the supported names.
        """
        if norm == 'batch':
            return nn.BatchNorm2d(features)
        if norm == 'instance':
            return nn.InstanceNorm2d(features)
        # A real exception (not ``assert``) so validation survives ``python -O``.
        raise ValueError('%s is not supported.' % norm)

    def forward(self, x):
        """Return ``x + block(x)``."""
        return x + self.block(x)

In [5]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline:
      * 7x7 reflection-padded conv, ``in_features -> 64`` channels, InstanceNorm, ReLU;
      * ``n_down`` stride-2 3x3 convs, each doubling the channel count;
      * ``n_res`` ``ResBlock``s at constant width;
      * ``n_up`` stride-2 3x3 transposed convs, each halving the channel count;
      * 7x7 reflection-padded conv to 3 channels followed by Tanh, so outputs
        lie in (-1, 1).

    Each down-sampling conv maps a side of length ``s`` to ``ceil(s / 2)`` and
    each transposed conv doubles it. The output therefore has the input's
    spatial size when ``n_up == n_down`` and H, W are divisible by ``2 ** n_down``.

    All conv / transposed-conv layers (including those inside the ResBlocks)
    are initialised with N(0, 0.02) weights and zero bias.

    Args:
        n_down: number of down-sampling stages.
        n_up: number of up-sampling stages.
        n_res: number of residual blocks.
        in_features: number of input image channels.
    """

    def __init__(self, n_down, n_up, n_res, in_features):
        super(Generator, self).__init__()

        out_features = 64
        first_conv = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, out_features, kernel_size=7, stride=1, padding=0, bias=True),
            self.select_normalization(norm='instance', features=out_features),
            nn.ReLU(inplace=True)]

        down_block = []
        for _ in range(n_down):
            in_features = out_features
            out_features = in_features * 2
            down_block += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1, bias=True),
                self.select_normalization(norm='instance', features=out_features),
                nn.ReLU(inplace=True)]

        res_features = out_features
        res_block = [ResBlock(res_features) for _ in range(n_res)]

        up_block = []
        in_features = res_features
        out_features = in_features // 2
        for _ in range(n_up):
            up_block += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1,
                                   output_padding=1, bias=True),
                self.select_normalization(norm='instance', features=out_features),
                nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2

        last_conv = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, 3, kernel_size=7, stride=1, padding=0, bias=True),
            nn.Tanh()]

        self.first_conv = nn.Sequential(*first_conv)
        self.down_block = nn.Sequential(*down_block)
        self.res_block = nn.Sequential(*res_block)
        self.up_block = nn.Sequential(*up_block)
        self.last_conv = nn.Sequential(*last_conv)
        # ``apply`` visits every submodule recursively. Calling ``init_weights``
        # on the Sequential containers directly would be a no-op, because their
        # class name ('Sequential') does not contain 'Conv'.
        self.apply(self.init_weights)

    def init_weights(self, net):
        """Initialise one conv-like module with N(0, 0.02) weights and zero bias.

        Meant to be passed to ``nn.Module.apply``. Modules whose class name
        does not contain ``'Conv'`` are left unchanged.
        """
        if 'Conv' in type(net).__name__:
            nn.init.normal_(net.weight, 0.0, 0.02)
            if getattr(net, 'bias', None) is not None:
                nn.init.zeros_(net.bias)

    def select_normalization(self, norm, features):
        """Return ``BatchNorm2d`` for ``'batch'`` or ``InstanceNorm2d`` for ``'instance'``.

        Raises:
            ValueError: if ``norm`` is not one of the supported names.
        """
        if norm == 'batch':
            return nn.BatchNorm2d(features)
        if norm == 'instance':
            return nn.InstanceNorm2d(features)
        raise ValueError('%s is not supported.' % norm)

    def forward(self, x):
        """Translate an image batch ``x`` (N, in_features, H, W) to a 3-channel output in (-1, 1)."""
        h = self.first_conv(x)
        h = self.down_block(h)
        h = self.res_block(h)
        h = self.up_block(h)
        return self.last_conv(h)

In [6]:
class Discriminator(nn.Module):
    """Fully-convolutional (PatchGAN-style) discriminator for 3-channel images.

    Layers: Conv4x4(3 -> 64, stride 2) + LeakyReLU(0.2), then ``n_layers``
    blocks of Conv4x4 (doubling channels) + InstanceNorm + LeakyReLU(0.2).
    These blocks use stride 2, except the last one, which uses stride 1.
    A final Conv4x4 maps to a single channel. The output is an unbounded
    1-channel score map; there is no final activation or pooling.

    All conv layers are initialised with N(0, 0.02) weights and zero bias.

    Args:
        n_layers: number of normalized conv blocks after the first conv.
    """

    def __init__(self, n_layers=3):
        super(Discriminator, self).__init__()
        out_features = 64
        modules = [nn.Conv2d(3, out_features, kernel_size=4, stride=2, padding=1, bias=True),
                   nn.LeakyReLU(negative_slope=0.2, inplace=True)]

        for i in range(n_layers):
            in_features = out_features
            out_features = in_features * 2
            # The last block keeps resolution (stride 1); earlier ones halve it.
            stride = 1 if i == n_layers - 1 else 2
            modules += [nn.Conv2d(in_features, out_features, kernel_size=4, stride=stride, padding=1, bias=True),
                        self.select_normalization(norm='instance', features=out_features),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True)]

        modules += [nn.Conv2d(out_features, 1, kernel_size=4, stride=1, padding=1, bias=True)]
        self.layers = nn.Sequential(*modules)
        # ``apply`` recurses into ``self.layers``. Passing the Sequential itself
        # to ``init_weights`` would be a no-op (its class name has no 'Conv').
        self.apply(self.init_weights)

    def init_weights(self, net):
        """Initialise one conv-like module with N(0, 0.02) weights and zero bias.

        Meant to be passed to ``nn.Module.apply``. Modules whose class name
        does not contain ``'Conv'`` are left unchanged.
        """
        if 'Conv' in type(net).__name__:
            nn.init.normal_(net.weight, 0.0, 0.02)
            if getattr(net, 'bias', None) is not None:
                nn.init.zeros_(net.bias)

    def select_normalization(self, norm, features):
        """Return ``BatchNorm2d`` for ``'batch'`` or ``InstanceNorm2d`` for ``'instance'``.

        Raises:
            ValueError: if ``norm`` is not one of the supported names.
        """
        if norm == 'batch':
            return nn.BatchNorm2d(features)
        if norm == 'instance':
            return nn.InstanceNorm2d(features)
        raise ValueError('%s is not supported.' % norm)

    def forward(self, x):
        """Return the 1-channel score map for image batch ``x``."""
        return self.layers(x)

In [8]:
class CycleGAN_Dataset(torch.utils.data.Dataset):
    """Unpaired two-domain image dataset for CycleGAN training.

    File lists for domain A and domain B are found by glob, sorted, shuffled
    once with Python's ``random`` module, and truncated to the shorter of the
    two lengths. Item ``i`` is therefore a fixed (A, B) pair for the lifetime
    of the dataset.

    Args:
        datapath: accepted for interface compatibility. NOTE(review): it is not
            used; the default image folders are hard-coded below.
        transforms: optional callable applied to each PIL image.
        A_glob: glob pattern for domain A images
            (default ``fantasy2/*.jpg``).
        B_glob: glob pattern for domain B images
            (default ``monet2photo/trainB/*.jpg``).

    Raises:
        FileNotFoundError: if either pattern matches no files. Without this,
            the dataset would be empty and training would silently do nothing.
    """

    def __init__(self, datapath, transforms=None, A_glob=None, B_glob=None):
        self.transforms = transforms
        if A_glob is None:
            A_glob = os.path.join('fantasy2', '*.jpg')
        if B_glob is None:
            B_glob = os.path.join('monet2photo', 'trainB', '*.jpg')

        # Sorting first makes the shuffled order reproducible under a fixed
        # ``random.seed`` (raw glob order depends on the filesystem).
        self.A_path = sorted(glob(A_glob))
        self.B_path = sorted(glob(B_glob))
        if not self.A_path or not self.B_path:
            raise FileNotFoundError(
                'No images found for domain A (%r: %d files) or domain B (%r: %d files).'
                % (A_glob, len(self.A_path), B_glob, len(self.B_path)))

        random.shuffle(self.A_path)
        random.shuffle(self.B_path)
        self.datalength = min(len(self.A_path), len(self.B_path))
        self.dataA = self.A_path[:self.datalength]
        self.dataB = self.B_path[:self.datalength]

    def __len__(self):
        return self.datalength

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, force 3-channel RGB, and close the underlying file."""
        with Image.open(path) as img:
            # ``convert`` reads the pixels and returns a new image, so it stays
            # valid after the file is closed. It also turns grayscale/RGBA
            # inputs into the 3 channels the networks expect.
            return img.convert('RGB')

    def __getitem__(self, i):
        imgA = self._load_rgb(self.dataA[i])
        imgB = self._load_rgb(self.dataB[i])

        if self.transforms:
            imgA = self.transforms(imgA)
            imgB = self.transforms(imgB)

        return imgA, imgB

In [9]:
class Image_History_Buffer:
    """History pool of generated images used when updating a discriminator.

    The pool is filled with the first ``pool_size`` images it sees. Once it
    is full, each incoming image is handled as follows:
      * with probability 0.5, it replaces a random stored image, and a copy
        of that older image is returned in its place;
      * otherwise, the incoming image is returned unchanged.

    ``pool_size=0`` disables the pool, and ``get_images`` returns its input
    unchanged.

    Args:
        pool_size: maximum number of stored images.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.buffer = []

    def get_images(self, pre_images):
        """Return a batch mixing ``pre_images`` with previously stored images.

        Args:
            pre_images: tensor whose first dimension indexes images
                (e.g. a generator output batch).

        Returns:
            Tensor with the same shape as ``pre_images``. Images taken from the
            pool, or added to it, are detached from the autograd graph.
        """
        if self.pool_size <= 0:
            # Without this, randint(0, pool_size - 1) below raises for pool_size == 0.
            return pre_images

        return_imgs = []
        for img in pre_images:
            # Detach so that stored images do not keep the generator's
            # autograd graph alive across training iterations.
            img = torch.unsqueeze(img.detach(), 0)
            if len(self.buffer) < self.pool_size:
                self.buffer.append(img)
                return_imgs.append(img)
            elif random.random() < 0.5:
                i = random.randint(0, self.pool_size - 1)
                tmp = self.buffer[i].clone()
                self.buffer[i] = img
                return_imgs.append(tmp)
            else:
                return_imgs.append(img)
        return torch.cat(return_imgs, dim=0)

In [10]:
class loss_scheduler():
    """Linear learning-rate decay multiplier, for use as ``LambdaLR(lr_lambda=...)``.

    The multiplier ``f(epoch)`` is 1 up to and including ``epoch_decay``. It
    then falls linearly to 0 over ``decay_length`` further epochs and stays
    at 0 afterwards, so it never becomes negative.

    Args:
        epoch_decay: last epoch with the full learning rate.
        n_epochs: optional total epoch count. When given, the decay spans
            ``n_epochs - epoch_decay`` epochs. When omitted, the decay spans
            ``epoch_decay`` epochs (the original behaviour).
    """

    def __init__(self, epoch_decay, n_epochs=None):
        self.epoch_decay = epoch_decay
        if n_epochs is None:
            self.decay_length = epoch_decay
        else:
            self.decay_length = n_epochs - epoch_decay

    def f(self, epoch):
        """Return the LR multiplier for ``epoch``."""
        if epoch <= self.epoch_decay:
            return 1
        if self.decay_length <= 0:
            # No decay window (e.g. epoch_decay == 0): avoid division by zero.
            return 0.0
        scaling = 1 - (epoch - self.epoch_decay) / float(self.decay_length)
        # Clamp so running past the decay window cannot produce a negative LR.
        return max(0.0, scaling)

In [11]:
# ---- Hyper-parameters ----
lr = 0.0002            # initial Adam learning rate for generators and discriminators
img_size = 256         # Resize target for the shorter image side (pixels)
betas = (0.5, 0.999)   # Adam betas
batchsize = 1
imgsize = 256          # side of the square RandomCrop applied after resizing
n_epochs = 200
decay_epoch = 100      # LR multiplier from loss_scheduler is 1 up to this epoch
lambda_val = 10        # weight of the cycle-consistency loss
lambda_id_val = 0      # weight of the identity loss (0 disables it)
# NOTE(review): CycleGAN_Dataset does not use `datapath`; its image folders are hard-coded there.
datapath = 'apple2orange'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---- Training data ----
# Normalize with mean = std = 0.5 maps ToTensor's [0, 1] range to [-1, 1],
# matching the generators' Tanh output range.
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
transform = transforms.Compose([
    # InterpolationMode is torchvision's documented way to choose the
    # resampling filter; passing PIL integer constants is the legacy form.
    transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(imgsize),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
train_data = CycleGAN_Dataset(datapath=datapath, transforms=transform)
training_dataset = DataLoader(train_data, batch_size=batchsize, shuffle=True)

# ---- Networks ----
G_A2B = Generator(n_down=2, n_up=2, n_res=9, in_features=3).to(device)
G_B2A = Generator(n_down=2, n_up=2, n_res=9, in_features=3).to(device)
D_A = Discriminator(n_layers=3).to(device)
D_B = Discriminator(n_layers=3).to(device)

# ---- Optimizers and LR schedules ----
# One optimizer updates both generators jointly; each discriminator has its own.
g_opt = optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()), lr=lr, betas=betas)
d_A_opt = optim.Adam(D_A.parameters(), lr=lr, betas=betas)
d_B_opt = optim.Adam(D_B.parameters(), lr=lr, betas=betas)

g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(g_opt, lr_lambda=loss_scheduler(decay_epoch).f)
d_a_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(d_A_opt, lr_lambda=loss_scheduler(decay_epoch).f)
d_b_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(d_B_opt, lr_lambda=loss_scheduler(decay_epoch).f)

# ---- Losses ----
adv_loss = nn.MSELoss()      # least-squares adversarial loss
l1_norm = nn.L1Loss()        # cycle-consistency loss
criterion_idn = nn.L1Loss()  # identity loss

# History pools of generated images fed to the discriminators.
buffer_for_fakeA = Image_History_Buffer()
buffer_for_fakeB = Image_History_Buffer()

## training

In [1]:
def discriminator_step(D, d_opt, real, fake):
    """Run one update of discriminator ``D`` and return its loss.

    ``D`` is pushed towards 1 on ``real`` and towards 0 on ``fake``. The two
    ``adv_loss`` terms are averaged. ``fake`` is detached, so no gradient
    reaches the generators.
    """
    d_opt.zero_grad()
    out_real = D(real)
    out_fake = D(fake.detach())
    loss_real = adv_loss(out_real, torch.ones_like(out_real))
    loss_fake = adv_loss(out_fake, torch.zeros_like(out_fake))
    loss = (loss_real + loss_fake) * 0.5
    loss.backward()
    d_opt.step()
    return loss


for epoch in range(1, n_epochs + 1):
    for net in (G_B2A, G_A2B, D_A, D_B):
        net.train()
    for idx, (imgA, imgB) in enumerate(training_dataset):
        imgA = imgA.to(device)
        imgB = imgB.to(device)

        # Translations and cycle reconstructions (A -> B -> A, B -> A -> B).
        imgA_fake, imgB_fake = G_B2A(imgB), G_A2B(imgA)
        imgA_rec, imgB_rec = G_B2A(imgB_fake), G_A2B(imgA_fake)
        if lambda_id_val > 0:
            iden_imgA, iden_imgB = G_B2A(imgA), G_A2B(imgB)

        # Update the discriminators (D_A, D_B) on real images vs. fakes drawn
        # from the history buffers. The shared helper builds each target from
        # its own discriminator's output (the old D_B code reused D_A's shape).
        disA_loss = discriminator_step(D_A, d_A_opt, imgA, buffer_for_fakeA.get_images(imgA_fake))
        disB_loss = discriminator_step(D_B, d_B_opt, imgB, buffer_for_fakeB.get_images(imgB_fake))

        # Update the generators (G_A2B, G_B2A): fool both discriminators, plus
        # cycle-consistency and optional identity terms.
        g_opt.zero_grad()
        disB_out_fake = D_B(imgB_fake)
        disA_out_fake = D_A(imgA_fake)
        g_lossA = adv_loss(disA_out_fake, torch.ones_like(disA_out_fake))
        g_lossB = adv_loss(disB_out_fake, torch.ones_like(disB_out_fake))
        gen_adv_loss = g_lossA + g_lossB

        cycle_consistency_loss = l1_norm(imgA_rec, imgA) + l1_norm(imgB_rec, imgB)
        gen_loss = gen_adv_loss + lambda_val * cycle_consistency_loss
        if lambda_id_val > 0:
            identity_loss = criterion_idn(iden_imgA, imgA) + criterion_idn(iden_imgB, imgB)
            gen_loss = gen_loss + lambda_id_val * identity_loss
        gen_loss.backward()
        g_opt.step()

        if idx % 100 == 0:
            print('Training epoch: {} [{}/{} ({:.0f}%)] | D loss (A): {:.6f} | D loss (B): {:.6f} | G loss: {:.6f} | Consistency: {:.6f}  |'
                  .format(epoch, idx * len(imgA), len(training_dataset.dataset),
                          100. * idx / len(training_dataset), disA_loss.item(), disB_loss.item(),
                          gen_loss.item(), cycle_consistency_loss.item()))

    # Advance the LambdaLR schedules once per epoch, after the optimizer steps.
    # Without these calls, the linear LR decay configured for the optimizers
    # never takes effect.
    g_lr_scheduler.step()
    d_a_lr_scheduler.step()
    d_b_lr_scheduler.step()

    if epoch % 10 == 0:
        torch.save(G_A2B.state_dict(), f'fant_G_A2B_{epoch}.pth')
        torch.save(G_B2A.state_dict(), f'fant_G_B2A_{epoch}.pth')
        torch.save(D_A.state_dict(), f'fant_D_A_{epoch}.pth')
        torch.save(D_B.state_dict(), f'fant_D_B_{epoch}.pth')

## inference

In [None]:
# Load saved weights into the four networks.
# NOTE(review): the training loop saves 'fant_<name>_<epoch>.pth' every 10
# epochs, whereas these files use no prefix and epoch 12. They presumably come
# from a different run; adjust ckpt_prefix / ckpt_epoch as needed.
ckpt_prefix = ''
ckpt_epoch = 12
for _model, _name in ((G_A2B, 'G_A2B'), (G_B2A, 'G_B2A'), (D_A, 'D_A'), (D_B, 'D_B')):
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    _model.load_state_dict(torch.load(f'{ckpt_prefix}{_name}_{ckpt_epoch}.pth', map_location=device))

In [None]:
# Pre-processing matching training: Normalize(0.5, 0.5) maps [0, 1] -> [-1, 1].
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
# Despite its name, `crop` only resizes: the shorter side becomes 900 px and
# the aspect ratio is kept.
crop = transforms.Resize(900)
normalize = transforms.Normalize(mean=mean, std=std)
to_tensor = transforms.ToTensor()
transform = transforms.Compose([crop, to_tensor, normalize])

# sorted() makes the chosen sample independent of filesystem listing order.
A_path = sorted(glob(os.path.join('fantasy', '*', '*.jpg')))[10]
B_path = sorted(glob(os.path.join('Photos', '*.jpg')))[6]


def _load_rgb(path):
    """Open an image file as 3-channel RGB and close the file handle."""
    with Image.open(path) as im:
        return im.convert('RGB')


def _to_pil(t):
    """Convert a (1, 3, H, W) tensor with values in [0, 1] to an 8-bit RGB PIL image."""
    # Scale by 255 (not 256) so that 1.0 maps exactly to the maximum pixel value.
    arr = (t * 255.).round().clamp(min=0, max=255).detach().cpu().squeeze(0).permute(1, 2, 0)
    return Image.fromarray(arr.numpy().astype(np.uint8))


imgA = _load_rgb(A_path)
imgB = _load_rgb(B_path)
imgA_tensor = transform(imgA).to(device).unsqueeze(0)
imgB_tensor = transform(imgB).to(device).unsqueeze(0)
G_A2B.eval()
G_B2A.eval()
with torch.no_grad():
    fake_B = G_A2B(imgA_tensor)
    fake_A = G_B2A(imgB_tensor)
    rec_B = G_A2B(fake_A)
    rec_A = G_B2A(fake_B)
    # mean/std are rebound to (1, 3, 1, 1) tensors to undo Normalize.
    # Later cells reuse these tensor versions.
    mean = torch.tensor(mean, dtype=torch.float32)[None, :, None, None].to(device)
    std = torch.tensor(std, dtype=torch.float32)[None, :, None, None].to(device)
    fake_B = (fake_B * std) + mean
    fake_A = (fake_A * std) + mean

fake_imgA = _to_pil(fake_A)
fake_imgB = _to_pil(fake_B)
plt_items = [fake_imgB, fake_imgA]
title_list = ['Fake_B', 'Fake_A']
rows = 2
cols = 1
axes = []
fig = plt.figure(figsize=(16, 9))

for i, (item, title) in enumerate(zip(plt_items, title_list)):
    ax = fig.add_subplot(rows, cols, i + 1)
    ax.set_title(title)
    ax.axis('off')
    ax.imshow(item)
    axes.append(ax)
fig.tight_layout()
plt.show()

In [None]:
# Undo the [-1, 1] normalization for display. A new name is used so that
# re-running this cell does not denormalize `rec_A` a second time.
rec_A_vis = (rec_A * std) + mean
# Scale by 255 (not 256) so that 1.0 maps exactly to the maximum pixel value.
Image.fromarray((rec_A_vis * 255.).round().clamp(min=0, max=255).detach().cpu().squeeze().permute(1, 2, 0).numpy().astype(np.uint8))