In [None]:
# CycleGAN 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
#import torch.utils.data as data
import random
from torchvision import transforms
from PIL import Image
import os
import itertools
import matplotlib.pyplot as plt
%matplotlib inline

# Hyper-parameters shared by data loading, model construction and training.
params = {
    'batch_size': 1,
    'input_size': 256,    # size passed to transforms.Resize
    'resize_scale': 286,  # training images are resized to this before cropping
    'crop_size': 256,     # side length of the random training crop
    'fliplr': True,       # enable random horizontal flips for training images
    'num_epochs': 100,
    # Learning-rate decay only starts once (epoch + 1) exceeds this value, so
    # with decay_epoch == num_epochs the learning rate stays constant.
    'decay_epoch': 100,
    'ngf': 32,            # number of generator filters
    'ndf': 64,            # number of discriminator filters
    'num_resnet': 6,      # number of resnet blocks
    'lrG': 0.0002,        # generator learning rate
    'lrD': 0.0002,        # discriminator learning rate
    'beta1': 0.5,         # beta1 for Adam
    'beta2': 0.999,       # beta2 for Adam
    'lambdaA': 10,        # lambdaA for cycle loss
    'lambdaB': 10,        # lambdaB for cycle loss
}

# Prefer the second GPU, but fall back to the first GPU (or the CPU) when it
# does not exist: "cuda:1" is an invalid device ordinal on single-GPU hosts.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

data_dir = './data/vangogh2photo'
save_dir = './results'
# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(save_dir, exist_ok=True)
    
def to_np(x):
    """Return the values of tensor ``x`` as a NumPy array in host memory.

    ``detach()`` replaces the legacy ``.data`` attribute: both drop the
    autograd history, but ``detach()`` is the supported way to do so.
    """
    return x.detach().cpu().numpy()

def plot_train_result(real_image, gen_image, recon_image, epoch,
                      save=False, show=True, figsize=(15,15)):
    """Draw real / generated / reconstructed images in a 2x3 grid.

    Row 0 shows ``real_image[0]``, ``gen_image[0]`` and ``recon_image[0]``;
    row 1 shows the same three for index 1.

    Args:
        real_image, gen_image, recon_image: indexable pairs of tensors. After
            ``squeeze()`` each must be channel-first 3-D, because it is
            transposed with ``(1, 2, 0)``. Values are rescaled from [-1, 1]
            to [0, 1] for display.
        epoch: zero-based epoch index; shown and saved as ``epoch + 1``.
        save: when true, write the figure to
            ``save_dir + '/Result_epoch_<n>.png'``.
        show: when true, display the figure with ``plt.show()``.
        figsize: figure size forwarded to ``plt.subplots``.
    """
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    imgs = [to_np(group[row])
            for row in range(2)
            for group in (real_image, gen_image, recon_image)]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        # Drop singleton dims, map [-1, 1] -> [0, 1], move channels last.
        img = ((img.squeeze() + 1) / 2).transpose(1, 2, 0)
        ax.imshow(img, cmap=None, aspect='equal')
    fig.subplots_adjust(wspace=0, hspace=0)
    title = 'Epoch {}'.format(epoch+1)
    fig.text(0.5, 0.04, title, ha='center')

    if save:
        save_fn = save_dir + '/Result_epoch_{:d}.png'.format(epoch+1)
        fig.savefig(save_fn)

    if show:
        plt.show()
    # Always release the figure. Previously it was closed only when
    # show=False, so one figure per epoch stayed alive on backends where
    # plt.show() does not close it.
    plt.close(fig)

        
class ImagePool():
    """Bounded history of previously queried images.

    With ``pool_size == 0`` the pool is disabled and ``query`` is the
    identity. Otherwise the pool stores up to ``pool_size`` single-image
    tensors; once it is full, each queried image is, with probability 0.5,
    swapped with a randomly chosen stored image.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            self.images = []
            self.num_imgs = 0

    def query(self, images):
        """Return a batch mixing the given images with stored ones.

        Args:
            images: batch tensor; it is iterated along its first dimension.

        Returns:
            ``images`` itself when the pool is disabled, otherwise a new
            tensor concatenated along dimension 0 with one entry per input.
        """
        if self.pool_size == 0:
            return images

        selected = []
        for single in images.data:
            single = single.unsqueeze(0)

            # Pool not full yet: store the image and pass it through.
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(single)
                selected.append(single)
                continue

            if random.uniform(0, 1) > 0.5:
                # Hand back an older image and keep the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                previous = self.images[slot].clone()
                self.images[slot] = single
                selected.append(previous)
            else:
                selected.append(single)

        return torch.cat(selected, 0)

class DatasetFromFolder(Dataset):
    """Dataset over the ``.jpg`` files in ``image_dir/subfolder``.

    Each item is loaded as RGB, then optionally resized to a square, randomly
    cropped, randomly flipped left-right, and passed through ``transform``.

    Args:
        image_dir: root directory of the dataset.
        subfolder: directory below ``image_dir`` holding the images.
        transform: optional callable applied to the PIL image last.
        resize_scale: if truthy, resize to ``(resize_scale, resize_scale)``.
        crop_size: if truthy, take a random ``crop_size`` square crop.
        fliplr: if true, flip horizontally with probability 0.5.
    """

    def __init__(self, image_dir, subfolder='train', transform=None,
                 resize_scale=None, crop_size=None, fliplr=False):
        super(DatasetFromFolder, self).__init__()
        self.input_path = os.path.join(image_dir, subfolder)
        # Sorted so that a given index always refers to the same file.
        self.image_filenames = [
            name for name in sorted(os.listdir(self.input_path))
            if os.path.splitext(name)[1].lower() == '.jpg'
        ]
        self.transform = transform
        self.resize_scale = resize_scale
        self.crop_size = crop_size
        self.fliplr = fliplr

    @staticmethod
    def _random_offset(extent, crop):
        """Pick a start so that ``[start, start + crop)`` lies in ``[0, extent)``.

        Raises:
            ValueError: if ``crop`` is larger than ``extent``.
        """
        if crop > extent:
            raise ValueError(
                'crop size {} exceeds image extent {}'.format(crop, extent))
        # random.randint includes BOTH endpoints, so the largest valid start
        # is extent - crop; the former "+ 1" pushed the crop window one pixel
        # past the image edge.
        return random.randint(0, extent - crop)

    def __getitem__(self, index):
        """Load, augment and transform the image at ``index``."""
        img_fn = os.path.join(self.input_path, self.image_filenames[index])
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_fn) as src:
            img = src.convert('RGB')

        if self.resize_scale:
            img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR)

        if self.crop_size:
            # Use the actual image size so cropping also works without resize.
            width, height = img.size
            x = self._random_offset(width, self.crop_size)
            y = self._random_offset(height, self.crop_size)
            img = img.crop((x, y, x+self.crop_size, y+self.crop_size))

        if self.fliplr and random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        if self.transform is not None:
            img = self.transform(img)

        return img

    def __len__(self):
        """Return the number of ``.jpg`` files found."""
        return len(self.image_filenames)

    
class ConvBlock(nn.Module):
    """Conv2d, optionally followed by InstanceNorm2d, then an activation.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        activation: ``'ReLU'`` or ``'LeakyReLU'`` (negative slope 0.2).
        normalization: insert ``nn.InstanceNorm2d`` after the convolution.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, input_size, output_size, kernel_size=3,
                 stride=2, padding=1, activation='ReLU', normalization=True):
        super(ConvBlock, self).__init__()
        if activation == 'ReLU':
            act_func = nn.ReLU(True)
        elif activation == 'LeakyReLU':
            act_func = nn.LeakyReLU(0.2, True)
        else:
            # Previously an unknown name left act_func unbound and failed
            # later with an unrelated UnboundLocalError.
            raise ValueError(
                "activation must be 'ReLU' or 'LeakyReLU', got {!r}".format(activation))

        layers = [nn.Conv2d(input_size, output_size, kernel_size,
                            stride, padding)]
        if normalization:
            layers.append(nn.InstanceNorm2d(output_size))
        layers.append(act_func)
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.conv_block(x)
        
        
class DeconvBlock(nn.Module):
    """ConvTranspose2d, optionally followed by InstanceNorm2d, then an activation.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        kernel_size, stride, padding, output_padding: forwarded to
            ``nn.ConvTranspose2d``.
        activation: ``'ReLU'`` or ``'LeakyReLU'`` (negative slope 0.2).
        normalization: insert ``nn.InstanceNorm2d`` after the deconvolution.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, input_size, output_size, kernel_size=3,
                 stride=2, padding=1, output_padding=1,
                 activation='ReLU', normalization=True):
        super(DeconvBlock, self).__init__()
        if activation == 'ReLU':
            act_func = nn.ReLU(True)
        elif activation == 'LeakyReLU':
            act_func = nn.LeakyReLU(0.2, True)
        else:
            # Previously an unknown name left act_func unbound and failed
            # later with an unrelated UnboundLocalError.
            raise ValueError(
                "activation must be 'ReLU' or 'LeakyReLU', got {!r}".format(activation))

        layers = [nn.ConvTranspose2d(input_size, output_size, kernel_size,
                                     stride, padding, output_padding)]
        if normalization:
            layers.append(nn.InstanceNorm2d(output_size))
        layers.append(act_func)
        self.deconv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.deconv_block(x)

    
class ResnetBlock(nn.Module):
    """Residual block: pad-conv-norm-ReLU-pad-conv-norm plus a skip connection.

    Args:
        num_filter: number of input and output channels.
        kernel_size, stride, padding: forwarded to both ``nn.Conv2d`` layers.
            Each convolution is preceded by ``nn.ReflectionPad2d(1)``.
    """

    def __init__(self, num_filter, kernel_size=3, stride=1, padding=0):
        super(ResnetBlock, self).__init__()

        def pad_conv_norm():
            # Build fresh layers on every call. The previous code placed one
            # Conv2d instance at both positions, which tied the weights of
            # the two convolutions together.
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding),
                    nn.InstanceNorm2d(num_filter)]

        # Layer positions inside the Sequential are the same as before.
        model = pad_conv_norm() + [nn.ReLU(True)] + pad_conv_norm()
        self.resnet_block = nn.Sequential(*model)

    def forward(self, x):
        """Return ``resnet_block(x) + x``."""
        out = self.resnet_block(x) + x
        return out
    
    
class Generator(nn.Module):
    """ResNet-style generator.

    Layout: ReflectionPad2d(3) + 7x7 ConvBlock, two stride-2 ConvBlocks,
    ``num_resnet`` ResnetBlocks at ``num_filter * 4`` channels, two
    DeconvBlocks, then ReflectionPad2d(3) + 7x7 Conv2d + Tanh.

    Args:
        input_dim: number of input channels.
        num_filter: number of channels after the first convolution.
        output_dim: number of output channels.
        num_resnet: number of residual blocks.
    """

    def __init__(self, input_dim, num_filter, output_dim, num_resnet):
        super(Generator, self).__init__()

        # Downsampling
        model = [nn.ReflectionPad2d(3),
                 ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0),
                 ConvBlock(num_filter, num_filter*2),
                 ConvBlock(num_filter*2, num_filter*4)]
        # Resnet blocks
        model += [ResnetBlock(num_filter*4) for _ in range(num_resnet)]
        # Upsampling
        model += [DeconvBlock(num_filter*4, num_filter*2),
                  DeconvBlock(num_filter*2, num_filter),
                  nn.ReflectionPad2d(3)]
        # Output layer
        model += [nn.Conv2d(num_filter, output_dim, kernel_size=7, stride=1, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the generator."""
        return self.model(x)

    def normal_weight_init(self, mean=0.0, std=0.02):
        """Initialise convolution weights from N(mean, std).

        Every ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weight in the network is
        re-drawn, and the convolution biases inside ResnetBlocks are zeroed.

        The previous version looped over ``self.children()``, which yields
        only the wrapping ``nn.Sequential``, so no layer was ever matched and
        the call did nothing. ``self.modules()`` walks all nested layers.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, mean, std)
            elif isinstance(module, ResnetBlock):
                for layer in module.modules():
                    if isinstance(layer, nn.Conv2d) and layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)

    
class Discriminator(nn.Module):
    """Convolutional discriminator built from stride-2 ConvBlocks.

    Layout: one 4x4 stride-2 ConvBlock without normalization, ``n_layer``
    further 4x4 stride-2 ConvBlocks that double the channel count each time
    (all with LeakyReLU), then a 4x4 stride-1 Conv2d to ``output_dim``
    channels.

    Args:
        input_dim: number of input channels.
        num_filter: number of channels after the first convolution.
        output_dim: number of channels of the output map.
        n_layer: number of channel-doubling ConvBlocks.
    """

    def __init__(self, input_dim, num_filter, output_dim, n_layer=3):
        super(Discriminator, self).__init__()

        # PatchGAN
        model = [ConvBlock(input_dim, num_filter, kernel_size=4, stride=2, padding=1,
                           activation='LeakyReLU', normalization=False)]
        in_channel = num_filter

        for _ in range(n_layer):
            out_channel = in_channel * 2
            model += [ConvBlock(in_channel, out_channel, kernel_size=4,
                                stride=2, padding=1, activation='LeakyReLU')]
            in_channel = out_channel

        # in_channel always holds the current width, including for n_layer=0
        # where out_channel was never assigned.
        model += [nn.Conv2d(in_channel, output_dim, kernel_size=4, stride=1, padding=1)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the discriminator."""
        return self.model(x)

    def normal_weight_init(self, mean=0.0, std=0.02):
        """Initialise every ``nn.Conv2d`` weight from N(mean, std).

        The previous version looped over ``self.children()``, which yields
        only the wrapping ``nn.Sequential``, so no layer was ever matched and
        the call did nothing. ``self.modules()`` walks all nested layers.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean, std)

In [None]:
# Pre-processing: resize, convert to tensor, then (x - 0.5) / 0.5 per channel
transform = transforms.Compose([
    transforms.Resize(size=params['input_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# Load train data (with resize / random crop / random flip augmentation)
train_data_A = DatasetFromFolder(
    data_dir, subfolder='trainA', transform=transform,
    resize_scale=params['resize_scale'], crop_size=params['crop_size'],
    fliplr=params['fliplr'])
train_data_loader_A = DataLoader(dataset=train_data_A, batch_size=params['batch_size'], shuffle=True)
train_data_B = DatasetFromFolder(
    data_dir, subfolder='trainB', transform=transform,
    resize_scale=params['resize_scale'], crop_size=params['crop_size'],
    fliplr=params['fliplr'])
train_data_loader_B = DataLoader(dataset=train_data_B, batch_size=params['batch_size'], shuffle=True)

# Load test data (no augmentation)
test_data_A = DatasetFromFolder(data_dir, subfolder='testA', transform=transform)
test_data_loader_A = DataLoader(dataset=test_data_A, batch_size=params['batch_size'], shuffle=False)
test_data_B = DatasetFromFolder(data_dir, subfolder='testB', transform=transform)
test_data_loader_B = DataLoader(dataset=test_data_B, batch_size=params['batch_size'], shuffle=False)

# Prefer the second GPU, but fall back to the first GPU (or the CPU) when it
# does not exist: "cuda:1" is an invalid device ordinal on single-GPU hosts.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fixed images used for the per-epoch preview.
# NOTE(review): these are drawn from the *training* sets (with one random
# crop/flip applied here), although the names say "test" -- confirm intent.
test_real_A_data = train_data_A[11].unsqueeze(0)  # add batch dim -> 4d tensor (BxCxHxW)
test_real_B_data = train_data_B[91].unsqueeze(0)

# Build the two generators and the two discriminators.
# Generator args: input_dim, num_filter, output_dim, num_resnet
G_A, G_B = (
    Generator(3, params['ngf'], 3, params['num_resnet']).to(device)
    for _ in range(2)
)
# Discriminator args: input_dim, num_filter, output_dim
D_A, D_B = (
    Discriminator(3, params['ndf'], 1).to(device)
    for _ in range(2)
)

G_A.normal_weight_init()
G_B.normal_weight_init()
D_A.normal_weight_init()
D_B.normal_weight_init()

# One optimizer drives both generators; each discriminator has its own.
G_optimizer = optim.Adam(
    itertools.chain(G_A.parameters(), G_B.parameters()),
    lr=params['lrG'],
    betas=(params['beta1'], params['beta2']),
)
D_A_optimizer = optim.Adam(
    D_A.parameters(),
    lr=params['lrD'],
    betas=(params['beta1'], params['beta2']),
)
D_B_optimizer = optim.Adam(
    D_B.parameters(),
    lr=params['lrD'],
    betas=(params['beta1'], params['beta2']),
)

# MSE is used for the adversarial terms, L1 for the cycle terms.
MSE_Loss = nn.MSELoss().to(device)
L1_Loss = nn.L1Loss().to(device)

# Training GAN: per-epoch average losses
D_A_avg_losses, D_B_avg_losses = [], []
G_A_avg_losses, G_B_avg_losses = [], []
cycle_A_avg_losses, cycle_B_avg_losses = [], []

# Pools of previously generated images fed to the discriminators
num_pool = 50
fake_A_pool, fake_B_pool = ImagePool(num_pool), ImagePool(num_pool)

step = 0
# zip() below stops at the shorter loader, so this is the real step count.
steps_per_epoch = min(len(train_data_loader_A), len(train_data_loader_B))
for epoch in range(params['num_epochs']):
    D_A_losses = []
    D_B_losses = []
    G_A_losses = []
    G_B_losses = []
    cycle_A_losses = []
    cycle_B_losses = []

    # Learning rate decay: linear, applied once per epoch after decay_epoch
    if (epoch + 1) > params['decay_epoch']:
        decay_span = params['num_epochs'] - params['decay_epoch']
        D_A_optimizer.param_groups[0]['lr'] -= params['lrD'] / decay_span
        D_B_optimizer.param_groups[0]['lr'] -= params['lrD'] / decay_span
        G_optimizer.param_groups[0]['lr'] -= params['lrG'] / decay_span

    # training
    for i, (real_A, real_B) in enumerate(zip(train_data_loader_A, train_data_loader_B)):

        # input image data
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # train generator G_A: A -> B (target: D_B outputs all ones)
        fake_B = G_A(real_A)
        D_B_fake_decision = D_B(fake_B)
        G_A_loss = MSE_Loss(D_B_fake_decision, torch.ones_like(D_B_fake_decision))

        # forward cycle loss: A -> B -> A
        recon_A = G_B(fake_B)
        cycle_A_loss = L1_Loss(recon_A, real_A) * params['lambdaA']

        # train generator G_B: B -> A (target: D_A outputs all ones)
        fake_A = G_B(real_B)
        D_A_fake_decision = D_A(fake_A)
        G_B_loss = MSE_Loss(D_A_fake_decision, torch.ones_like(D_A_fake_decision))

        # backward cycle loss: B -> A -> B
        recon_B = G_A(fake_A)
        cycle_B_loss = L1_Loss(recon_B, real_B) * params['lambdaB']

        # Back propagation
        G_loss = G_A_loss + G_B_loss + cycle_A_loss + cycle_B_loss
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # train discriminator D_A: real -> ones, pooled fake -> zeros
        D_A_real_decision = D_A(real_A)
        D_A_real_loss = MSE_Loss(D_A_real_decision, torch.ones_like(D_A_real_decision))

        fake_A = fake_A_pool.query(fake_A)

        D_A_fake_decision = D_A(fake_A.detach())
        D_A_fake_loss = MSE_Loss(D_A_fake_decision, torch.zeros_like(D_A_fake_decision))

        # Back propagation
        D_A_loss = (D_A_real_loss + D_A_fake_loss) * 0.5
        D_A_optimizer.zero_grad()
        D_A_loss.backward()
        D_A_optimizer.step()

        # train discriminator D_B: real -> ones, pooled fake -> zeros
        D_B_real_decision = D_B(real_B)
        # The target is shaped after the *real* decision; it used to be sized
        # from the stale D_B_fake_decision computed in the generator step.
        D_B_real_loss = MSE_Loss(D_B_real_decision, torch.ones_like(D_B_real_decision))

        fake_B = fake_B_pool.query(fake_B)

        D_B_fake_decision = D_B(fake_B.detach())
        D_B_fake_loss = MSE_Loss(D_B_fake_decision, torch.zeros_like(D_B_fake_decision))

        # Back propagation
        D_B_loss = (D_B_real_loss + D_B_fake_loss) * 0.5
        D_B_optimizer.zero_grad()
        D_B_loss.backward()
        D_B_optimizer.step()

        # loss values
        D_A_losses.append(D_A_loss.item())
        D_B_losses.append(D_B_loss.item())
        G_A_losses.append(G_A_loss.item())
        G_B_losses.append(G_B_loss.item())
        cycle_A_losses.append(cycle_A_loss.item())
        cycle_B_losses.append(cycle_B_loss.item())

        if i%100 == 0:
            print('Epoch:[{}/{}], Step [{}/{}], D_A_loss: {:.4f}, D_B_loss: {:.4f}, G_A_loss: {:.4f}, G_B_loss: {:.4f}'
                  .format(epoch+1, params['num_epochs'], i+1, steps_per_epoch,
                          D_A_loss.item(), D_B_loss.item(), G_A_loss.item(), G_B_loss.item()))

        step += 1

    D_A_avg_loss = torch.mean(torch.FloatTensor(D_A_losses))
    D_B_avg_loss = torch.mean(torch.FloatTensor(D_B_losses))
    G_A_avg_loss = torch.mean(torch.FloatTensor(G_A_losses))
    G_B_avg_loss = torch.mean(torch.FloatTensor(G_B_losses))
    cycle_A_avg_loss = torch.mean(torch.FloatTensor(cycle_A_losses))
    cycle_B_avg_loss = torch.mean(torch.FloatTensor(cycle_B_losses))

    # avg loss values for plot
    D_A_avg_losses.append(D_A_avg_loss.item())
    D_B_avg_losses.append(D_B_avg_loss.item())
    G_A_avg_losses.append(G_A_avg_loss.item())
    G_B_avg_losses.append(G_B_avg_loss.item())
    cycle_A_avg_losses.append(cycle_A_avg_loss.item())
    cycle_B_avg_losses.append(cycle_B_avg_loss.item())

    # Show result for the fixed preview images. No gradients are needed
    # here, so skip building the autograd graph.
    with torch.no_grad():
        test_real_A = test_real_A_data.to(device)
        test_fake_B = G_A(test_real_A)
        test_recon_A = G_B(test_fake_B)

        test_real_B = test_real_B_data.to(device)
        test_fake_A = G_B(test_real_B)
        test_recon_B = G_A(test_fake_A)

    plot_train_result([test_real_A, test_real_B], [test_fake_B, test_fake_A], [test_recon_A, test_recon_B],
                            epoch, save=True)

# Collect the per-epoch average losses, one column per loss, and write them
# to the file 'avg_losses' in CSV format without the index column.
all_losses = pd.DataFrame({
    'D_A_avg_losses': D_A_avg_losses,
    'D_B_avg_losses': D_B_avg_losses,
    'G_A_avg_losses': G_A_avg_losses,
    'G_B_avg_losses': G_B_avg_losses,
    'cycle_A_avg_losses': cycle_A_avg_losses,
    'cycle_B_avg_losses': cycle_B_avg_losses,
})
all_losses.to_csv('avg_losses', index=False)