# [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593)

### Jun-Yan Zhu, Taesung Park, Phillip Isola & Alexei A. Efros

Algorithm for unpaired image-to-image translation.

It consists of two CNNs $G$ and $F$, where $G$ maps an image $x$ to the domain $D_Y$ and $F$ maps an image $y$ back to the domain $D_X$. The cycle-consistency loss then forces the CNNs to produce images such that the input $x$ and its reconstruction $\hat{x}=F(G(x))$, as well as the input $y$ and its reconstruction $\hat{y}=G(F(y))$, are pixel-wise close to each other. The figure below, from the original paper cited above, illustrates the setup.


<img src="images/cycleGAN-formulation.png">

## Set Up the cycleGAN Model

In [None]:
%pylab inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from random import sample

import os

# ---- Compute device ---------------------------------------------------------
# Prefer CUDA when it is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device_txt = "cuda"
else:
    device_txt = "cpu"
device = torch.device(device_txt)
# Report which GPU will be used (nothing is printed on CPU).
if device_txt == "cuda":
    print(torch.cuda.get_device_name(device))

# ---- User settings ----------------------------------------------------------
# Dataset to download and train on. Available choices:
#   "ae_photos", "apple2orange", "cezanne2photo", "cityscapes", "facades",
#   "horse2zebra", "iphone2dslr_flower", "maps", "mini", "mini_colorization",
#   "mini_pix2pix", "monet2photo", "summer2winter_yosemite", "ukiyoe2photo",
#   "vangogh2photo"
use_dataset = 'horse2zebra'
# -----------------------------------------------------------------------------

# Side length (pixels) of the square training crops; cityscapes uses 128.
if use_dataset == "cityscapes":
    img_size = 128
else:
    img_size = 256

# Training schedule
epochs = 200
batch_size = 1
# Adam learning rate and beta1 (beta2 is always 0.999)
lr = 0.0002
beta1 = 0.5
# Epoch at which the linear learning-rate decay starts
epoch_lr_decline = 100

# Loss weights
## Cycle-consistency loss
cyc_lamb = 10.
## Identity loss, enabled only for the painting/photo datasets listed here
if use_dataset in ["monet2photo", "iphone2dslr_flower"]:
    ident_lamb = 0.5 * cyc_lamb
else:
    ident_lamb = 0.
lambs = (cyc_lamb, ident_lamb)

### Download data

In [None]:
import urllib.request

# Download the zipped dataset selected by `use_dataset` (set in the configuration
# cell) from the CycleGAN dataset server and store it as ./data.zip.
url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'+use_dataset+'.zip'
urllib.request.urlretrieve(url, './data.zip')
# Create ./data (extraction target) and ./saved_models (where training writes its
# checkpoints), extract the archive quietly (-qq) overwriting existing files (-o),
# and delete the zip only if unzip succeeded (&&).
# NOTE(review): plain `mkdir` reports an error when the folders already exist
# (e.g. on a re-run); IPython does not stop on a failing `!` command, so the
# unzip line still runs.
!mkdir data saved_models
!unzip -qq -o data.zip -d data/ && rm data.zip

## Define the two Generator Networks

Following the architecture of the transformation network from [Johnson et al.](https://cs.stanford.edu/people/jcjohns/eccv16/):

In [None]:
class Basic_Layer(nn.Module):
    """Reflection padding, convolution, normalization and an optional ReLU.

    Args:
        in_ch (int): number of input channels.
        out_ch (int): number of output channels.
        kernel_size (int): convolution kernel size; the input is reflection-padded
            by ``kernel_size // 2`` on every side before the convolution.
        stride (int): convolution stride.
        norm_layer (callable): normalization factory, called as ``norm_layer(out_ch)``.
        use_relu (bool): apply an in-place ReLU after the normalization.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, norm_layer, use_relu=True):
        super(Basic_Layer, self).__init__()
        self.use_relu = use_relu
        # Attribute names (pad / conv / norm) are the state_dict keys, so they
        # must stay stable for saved checkpoints to load.
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride)
        self.norm = norm_layer(out_ch)

    def forward(self, input):
        normed = self.norm(self.conv(self.pad(input)))
        if not self.use_relu:
            return normed
        return F.relu(normed, inplace=True)
        
class Res_Block(nn.Module):
    """Residual block: two 3x3 ``Basic_Layer``s plus an identity skip connection.

    Only the first layer applies a ReLU; the second layer's output is added to
    the block input and returned without activation.

    Args:
        n_ch (int): channel count of both the input and the output.
        norm_layer (callable): normalization factory passed to both layers.
    """

    def __init__(self, n_ch, norm_layer):
        super(Res_Block, self).__init__()
        conv_cfg = dict(kernel_size=3, stride=1, norm_layer=norm_layer)
        self.layer1 = Basic_Layer(n_ch, n_ch, **conv_cfg)
        self.layer2 = Basic_Layer(n_ch, n_ch, use_relu=False, **conv_cfg)

    def forward(self, input):
        residual = self.layer2(self.layer1(input))
        return residual + input

    
class Upsample(nn.Module):
    """Transposed convolution upsampling by ``stride``, then normalization and ReLU.

    For an odd ``kernel_size`` the padding / output_padding below make the
    spatial output exactly ``stride`` times the input size (2x for stride=2).

    Args:
        in_ch (int): number of input channels.
        out_ch (int): number of output channels.
        kernel_size (int): kernel size of the transposed convolution (odd).
        stride (int): upsampling factor. stride=2 builds exactly the layer this
            module always built before, when the argument was ignored.
        norm_layer (callable): normalization factory, called as ``norm_layer(out_ch)``.
    """

    def __init__(self, in_ch, out_ch, kernel_size, stride, norm_layer):
        super(Upsample, self).__init__()
        # out = (in - 1) * stride - 2 * (k // 2) + k + output_padding
        #     = stride * in   for odd k with output_padding = stride - 1
        # (output_padding must be smaller than stride, which stride - 1 satisfies.)
        self.deconv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride,
                                         padding=kernel_size // 2,
                                         output_padding=stride - 1)
        self.norm   = norm_layer(out_ch)

    def forward(self, input):
        out = F.relu(self.norm(self.deconv(input)), inplace=True)
        return out

class G_net(nn.Module):
    """ResNet-style generator (transformation network of Johnson et al.).

    Layout: 7x7 stride-1 conv -> two 3x3 stride-2 convs (each doubling the
    channels) -> ``n_res_blocks`` residual blocks -> two stride-2 upsampling
    layers -> 7x7 conv to ``out_nc`` channels -> tanh. For spatial sizes
    divisible by 4 the output has the same height/width as the input.

    Args:
        in_nc (int): channels of the input images.
        out_nc (int): channels of the generated images.
        ngf (int): number of filters in the first layer.
        n_res_blocks (int): number of residual blocks in the bottleneck.
        norm_layer (callable): normalization factory, called with a channel count.
    """
    def __init__(self, in_nc, out_nc, ngf=64, n_res_blocks=6, norm_layer=nn.InstanceNorm2d):
        super(G_net, self).__init__()
        # Define Encoding layers
        self.enco1  = Basic_Layer(in_nc, ngf, kernel_size=7, stride=1, norm_layer=norm_layer)
        self.enco2  = Basic_Layer(ngf, ngf*2, kernel_size=3, stride=2, norm_layer=norm_layer)
        self.enco3  = Basic_Layer(ngf*2, ngf*4, kernel_size=3, stride=2, norm_layer=norm_layer)
        # Define Residual layers. Each block must be a separate instance:
        # ``[block] * n`` repeats ONE module n times, so every block would share
        # the same weights. The state_dict keys (residual.0 ... residual.N-1)
        # are the same either way, so existing checkpoints still load.
        self.residual = nn.Sequential(*[Res_Block(ngf*4, norm_layer=norm_layer)
                                        for _ in range(n_res_blocks)])
        # Define Decoding layers
        self.deco1  = Upsample(ngf*4, ngf*2, kernel_size=3, stride=2, norm_layer=norm_layer)
        self.deco2  = Upsample(ngf*2, ngf, kernel_size=3, stride=2, norm_layer=norm_layer)
        self.deco3  = nn.Conv2d(ngf, out_nc, kernel_size=7, stride=1, padding=3)

    def forward(self, input):
        # Encoding
        input = self.enco1(input)
        input = self.enco2(input)
        input = self.enco3(input)
        # Residual
        input = self.residual(input)
        # Decoding
        input = self.deco1(input)
        input = self.deco2(input)
        input = self.deco3(input)
        # Squash to [-1, 1]
        return torch.tanh(input)

### Define Discriminator

In [None]:
# Define 70x70 PatchGAN Discriminator
class D_patch(nn.Module):
    """70x70 PatchGAN discriminator.

    Four 4x4 convolutions (strides 2, 2, 2, 1; padding 1) with LeakyReLU(0.2),
    normalization on all but the first, then a 4x4 stride-1 convolution to one
    channel followed by a sigmoid. Each output value scores one 70x70 patch of
    the input image.

    Args:
        in_nc (int): channels of the input images.
        ndf (int): filters in the first layer (doubled by each following conv).
        norm_layer (callable): normalization factory, called with a channel count.
    """
    def __init__(self, in_nc, ndf=64, norm_layer = nn.InstanceNorm2d):
        super(D_patch, self).__init__()
        self.conv1 = nn.Conv2d(in_nc, ndf, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
        self.norm2 = norm_layer(ndf*2)
        self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
        self.norm3 = norm_layer(ndf*4)
        # Stride 1 (not 2) in the last hidden conv is what yields the 70x70
        # receptive field; with stride 2 every output would see 94x94 pixels.
        # Parameter shapes do not depend on the stride, so checkpoints still load.
        self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=1, padding=1)
        self.norm4 = norm_layer(ndf*8)
        self.final = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=1)
    
    def forward(self, input):
        input = F.leaky_relu(self.conv1(input), negative_slope=0.2, inplace=True)
        input = F.leaky_relu(self.norm2(self.conv2(input)), negative_slope=0.2, inplace=True)
        input = F.leaky_relu(self.norm3(self.conv3(input)), negative_slope=0.2, inplace=True)
        input = F.leaky_relu(self.norm4(self.conv4(input)), negative_slope=0.2, inplace=True)
        # Per-patch score squashed to (0, 1).
        # NOTE(review): the reference LSGAN-based CycleGAN discriminator returns raw
        # scores without a sigmoid; it is kept here for compatibility with saved runs.
        return torch.sigmoid(self.final(input))

#### Weights Initialization

In [None]:
# custom weights initialization
def weights_init(m):
    """Initialize one module in place; meant to be used via ``model.apply``.

    Conv2d / ConvTranspose2d: weights ~ N(0, 0.02), bias (if any) set to 0.
    BatchNorm2d: weights ~ N(1, 0.02), bias set to 0.
    Any other module is left untouched.
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)
        return
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    # Convolutions may be built with bias=False
    if m.bias is not None:
        nn.init.zeros_(m.bias)

## Load Data

In [None]:
class DatasetFromFolder(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>A/`` and ``<root>/<mode>B/``.

    Every entry of both folders is treated as an image file. The dataset length
    is the size of the larger domain; domain A is indexed cyclically.
    """

    def __init__(self, root, mode='train', unaligned=True, transform=None):
        """
        Args:
            root (str): Dataset directory holding the ``<mode>A`` and ``<mode>B`` folders.
            mode (str): Folder prefix, e.g. 'train' or 'test'.
            unaligned (bool): If True (unpaired data), pair each A image with a
                uniformly random B image; otherwise pair them by index.
            transform (callable, optional): Preprocessing applied to A and B separately.
        """
        super(DatasetFromFolder, self).__init__()
        self.files_A   = self._list_dir(os.path.join(root, '%sA/' % mode))
        self.files_B   = self._list_dir(os.path.join(root, '%sB/' % mode))
        self.transform = transform
        self.unaligned = unaligned

    @staticmethod
    def _list_dir(folder):
        """Return ``folder + name`` for every entry of *folder*, sorted by name."""
        return [folder + name for name in sorted(os.listdir(folder))]

    @staticmethod
    def _load_rgb(path):
        """Read the image at *path* as RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        path_A = self.files_A[index % len(self.files_A)]
        if self.unaligned:
            # torch.randint excludes its upper bound, so every index 0..len-1 can be
            # drawn. The bare ``random`` name in this notebook is numpy.random (from
            # %pylab), whose randint is also upper-exclusive, so
            # ``random.randint(0, len - 1)`` would never pick the last B file.
            path_B = self.files_B[torch.randint(len(self.files_B), (1,)).item()]
        else:
            path_B = self.files_B[index % len(self.files_B)]

        A = self._load_rgb(path_A)
        B = self._load_rgb(path_B)

        # preprocessing
        if self.transform is not None:
            A = self.transform(A)
            B = self.transform(B)

        return A, B

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [None]:
# Select dataset
data_folder = './data/'+use_dataset+'/'

train_ds = DatasetFromFolder(root = data_folder, mode='train',
                             transform = transforms.Compose([
                                 transforms.Resize(int(img_size*1.12), Image.BICUBIC),
                                 transforms.RandomCrop(img_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                             ])
                            )

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)

## Training

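Before training the model we need a class that stores a buffer of previously generated fakes, see [Shrivastava et al.](https://arxiv.org/abs/1612.07828). The discriminators are then trained on a mix of current and older fakes instead of only the latest ones.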

In [None]:
class DHistBuffer():
    """History buffer of generated ("fake") images, after Shrivastava et al.

    ``get_data`` takes a freshly generated batch and returns a batch of the same
    size in which part of the images are replaced by fakes stored from earlier
    iterations. The stored history is updated with the new fakes as it goes.

    NOTE(review): relies on the notebook-level names ``random`` (numpy.random via
    ``%pylab``, judging by the ``random.rand()`` calls) and ``sample`` (Python's
    ``random.sample``).
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        # Stored fakes, one tensor per image, each with a leading batch dim of 1
        self.data = []

    def get_data(self, batch):
        """Mix the current fakes in ``batch`` with stored ones and update the history.

        Args:
            batch (torch.Tensor): 4-D batch of generated images with the batch
                dimension first. Callers in this notebook pass detached tensors.

        Returns:
            torch.Tensor: as many images as ``batch`` holds, concatenated along
            dim 0 in shuffled order.
        """
        b_size, _, _, _ = batch.shape
        to_return = []

        # The paper is not clear about batch size 1, so (following aitorzip's code)
        # once the buffer is full, return with probability 1/2 a random stored fake
        # (replaced by the current one), otherwise the current fake itself.
        if b_size == 1:
            if len(self.data) < self.max_size:
                self.data.append(batch)
                to_return.append(batch)
            elif random.rand() > 0.5:
                i = random.randint(0, self.max_size)
                to_return.append(self.data[i].clone())
                self.data[i] = batch
            else:
                to_return.append(batch)
        # Larger batches: half of the returned images are new fakes, half come from the history.
        elif len(self.data) == 0:
            # First call: nothing stored yet, so store and return the batch as-is.
            for b_element in batch:
                self.data.append(b_element.unsqueeze(0))
                to_return.append(b_element.unsqueeze(0))
        else:
            n_half = b_size // 2
            # Current fakes returned directly vs. those pushed into the history
            take_id_batch  = sample(range(b_size), n_half)
            leave_id_batch = [x for x in range(b_size) if x not in take_id_batch]
            random.shuffle(leave_id_batch)
            # History slots returned (and overwritten) in the swap loop below ...
            take_id_hist   = sample(range(len(self.data)), n_half)
            # ... and the history slots NOT touched by that loop. The odd-size case
            # must draw from these: a slot in take_id_hist already holds a fake from
            # the current batch, which would be returned instead of a historical
            # image and then immediately overwritten.
            leave_id_hist  = [x for x in range(len(self.data)) if x not in take_id_hist]
            random.shuffle(leave_id_hist)

            for b_element in batch[take_id_batch]:
                to_return.append(b_element.unsqueeze(0))

            for idx in range(n_half):
                to_return.append(self.data[take_id_hist[idx]].clone())
                self.data[take_id_hist[idx]] = batch[leave_id_batch[idx]].unsqueeze(0)

            # Odd batch sizes leave one current fake over; again pick at random
            # between a stored fake (swapped with the leftover) and the leftover.
            if b_size % 2 != 0:
                if leave_id_hist and random.rand() > 0.5:
                    to_return.append(self.data[leave_id_hist[0]].clone())
                    self.data[leave_id_hist[0]] = batch[leave_id_batch[-1]].unsqueeze(0)
                else:
                    to_return.append(batch[leave_id_batch[-1]].unsqueeze(0))

            # Grow the history with the current batch while it is below max_size
            for b_element in batch:
                if len(self.data) < self.max_size:
                    self.data.append(b_element.unsqueeze(0))

        # Shuffle so that new fakes are not always at the beginning of the batch
        random.shuffle(to_return)

        return torch.cat(to_return)

### Compact function to build the models and optimizers

In [None]:
def get_cycleGAN(in_nc, out_nc, ngf, ndf, device, use_dataset, img_size, lr = 0.0002, beta1 = 0.5):
    """Build both generators, both discriminators and their Adam optimizers.

    If ./saved_models/cycleGAN_<use_dataset>_<img_size>_saved_model.tar exists,
    the user is asked (via ``input``) whether to resume from it. Answering "y"
    loads the four networks' weights, the epoch to resume from and the stored
    fake-image buffers. Optimizer states are not restored.

    Args:
        in_nc (int): channels of domain-A images.
        out_nc (int): channels of domain-B images.
        ngf (int): first-layer filters of the generators.
        ndf (int): first-layer filters of the discriminators.
        device (torch.device): device the models are moved to.
        use_dataset (str): dataset name, used in the checkpoint file name.
        img_size (int): image size; <= 128 selects 6 residual blocks, else 9.
        lr (float): Adam learning rate.
        beta1 (float): Adam beta1 (beta2 is fixed at 0.999).

    Returns:
        tuple: ``((G_A2B, G_B2A), (D_A, D_B), opt_G, opt_D, epoch_start, D_buffer)``
        where ``D_buffer`` is ``([], [])`` for a fresh run or the loaded buffers.
    """
    # 6 residual blocks for images up to 128 px, 9 for larger ones
    n_res_blocks = 6 if img_size <= 128 else 9

    # Generators A->B and B->A
    model_G_A2B = G_net(in_nc, out_nc, ngf, n_res_blocks).to(device)
    model_G_B2A = G_net(out_nc, in_nc, ngf, n_res_blocks).to(device)
    # Discriminators for domain A and domain B
    model_D_A = D_patch(in_nc, ndf).to(device)
    model_D_B = D_patch(out_nc, ndf).to(device)

    # Re-initialize all weights (a plain loop: apply() is called for its side effect)
    for model in (model_G_A2B, model_G_B2A, model_D_B, model_D_A):
        model.apply(weights_init)

    # Defaults for a fresh run: start at epoch 0 with empty fake-image buffers
    epoch_start = 0
    D_buffer = ([], [])

    ckpt_path = './saved_models/cycleGAN_'+use_dataset+'_'+str(img_size)+'_saved_model.tar'
    if os.path.isfile(ckpt_path):
        pretrained = "Users_answer"
        while pretrained not in ["y", "n"]:
            pretrained = input("Pretrained Model available, use it? [y/n]:")
        # If the user says "y", resume from the checkpoint
        if pretrained == "y":
            saved_data = torch.load(ckpt_path, map_location=device)
            model_G_A2B.load_state_dict(saved_data['G_A2B_state_dict'])
            model_G_B2A.load_state_dict(saved_data['G_B2A_state_dict'])
            model_D_A.load_state_dict(saved_data['D_A_state_dict'])
            model_D_B.load_state_dict(saved_data['D_B_state_dict'])
            epoch_start = saved_data['current_epoch']
            D_buffer = saved_data['fake_buffer']

    # One Adam optimizer over both generators, one over both discriminators
    opt_G = optim.Adam(list(model_G_A2B.parameters()) + list(model_G_B2A.parameters()),
                       lr=lr, betas=(beta1, 0.999))
    opt_D = optim.Adam(list(model_D_A.parameters()) + list(model_D_B.parameters()),
                       lr=lr, betas=(beta1, 0.999))

    return (model_G_A2B, model_G_B2A), (model_D_A, model_D_B), opt_G, opt_D, epoch_start, D_buffer

### Train Model

In [None]:
def fit_cycleGAN(epochs, G_models, D_models, opt_G, opt_D, train_dl, device, dis_buffer,
                 lambs = (10., 5.), epoch_start = 0, epoch_lr_decline = 100, show_iter=None):
    """Train the CycleGAN and save a checkpoint after every epoch.

    Each iteration first updates both discriminators: real images get target 1,
    and fakes mixed with buffered history get target 0, with an MSE criterion.
    It then updates both generators with the adversarial (MSE, target 1),
    cycle-consistency (L1) and identity (L1) losses. After every epoch the four
    state_dicts, the next epoch index and the fake buffers are saved to
    ./saved_models/cycleGAN_<use_dataset>_<img_size>_saved_model.tar. That path
    uses the notebook-global ``use_dataset`` and ``img_size``.

    Args:
        epochs (int): total number of epochs; epochs ``epoch_start..epochs-1`` are run.
        G_models (tuple): ``(G_A2B, G_B2A)`` generators.
        D_models (tuple): ``(D_A, D_B)`` discriminators.
        opt_G, opt_D (torch.optim.Optimizer): optimizers of generators / discriminators.
        train_dl (DataLoader): yields ``(A, B)`` image batches.
        device (torch.device): device the models live on.
        dis_buffer (tuple): ``(fakes_A, fakes_B)`` lists used to seed the history buffers.
        lambs (tuple): ``(cycle_weight, identity_weight)``.
        epoch_start (int): first epoch to run (non-zero when resuming).
        epoch_lr_decline (int): epoch after which the LR decays linearly towards 0 at ``epochs``.
        show_iter (int, optional): print metrics every ``show_iter`` iterations and plot
            samples every ``5 * show_iter``; defaults to once per epoch.
    """
    # `time` is not imported at notebook level, so import it here
    import time

    # Nothing to do if the resumed epoch is already past the target
    if epochs < epoch_start:
        return

    # Least-squares GAN loss, L1 for the cycle and identity terms
    GAN_crit   = nn.MSELoss()
    Cycle_crit = nn.L1Loss()
    Iden_crit  = nn.L1Loss()

    # Extract models and loss weights
    model_G_A2B, model_G_B2A = G_models
    model_D_A, model_D_B     = D_models
    cyc_lamb, iden_lamb = lambs

    # History buffers of generated images, seeded with the (possibly resumed) stored fakes
    out_fake_A_buffer = DHistBuffer(max_size=50)
    out_fake_B_buffer = DHistBuffer(max_size=50)
    out_fake_A_buffer.data, out_fake_B_buffer.data = dis_buffer

    # Linear LR decay. LambdaLR counts its own steps from 0, so epoch_start is
    # added to get the absolute epoch; otherwise a resumed run would restart at
    # full LR and decay on the wrong schedule.
    if epochs > epoch_lr_decline:
        def lambda_opt(step):
            epoch = step + epoch_start
            return 1.0 - max(0, (epoch - epoch_lr_decline) / (epochs - epoch_lr_decline))
    else:
        def lambda_opt(step):
            return 1.0
    LR_scheduler_G = optim.lr_scheduler.LambdaLR(opt_G, lr_lambda=lambda_opt)
    LR_scheduler_D = optim.lr_scheduler.LambdaLR(opt_D, lr_lambda=lambda_opt)

    # If not otherwise defined: show_iter = one epoch
    n_batches = len(train_dl)
    if show_iter is None:
        show_iter = n_batches

    start_time = time.time()

    for epoch in range(epoch_start, epochs):
        for i, (A, B) in enumerate(train_dl):
            real_A, real_B = A.to(device), B.to(device)
            global_iter = i + epoch * n_batches

            ### Discriminator Training ###
            opt_D.zero_grad()

            # Real images should be scored as 1
            out_real_A   = model_D_A(real_A)
            lossD_real_A = GAN_crit(out_real_A, torch.ones(out_real_A.size(), device=device))
            out_real_B   = model_D_B(real_B)
            lossD_real_B = GAN_crit(out_real_B, torch.ones(out_real_B.size(), device=device))
            lossD_real = lossD_real_A + lossD_real_B

            # Fakes (detached from the generators, mixed with buffered history) should be scored as 0
            out_fake_D_A = model_D_A(out_fake_A_buffer.get_data(model_G_B2A(real_B).detach()))
            lossD_fake_A = GAN_crit(out_fake_D_A, torch.zeros(out_fake_D_A.size(), device=device))
            out_fake_D_B = model_D_B(out_fake_B_buffer.get_data(model_G_A2B(real_A).detach()))
            lossD_fake_B = GAN_crit(out_fake_D_B, torch.zeros(out_fake_D_B.size(), device=device))
            lossD_fake = lossD_fake_A + lossD_fake_B

            lossD = (lossD_real + lossD_fake)/2

            lossD.backward()
            opt_D.step()

            ### Generator Training ###
            opt_G.zero_grad()

            # Adversarial loss: the generators want their fakes scored as 1
            fake_B = model_G_A2B(real_A)
            out_fake_B = model_D_B(fake_B)
            GAN_loss_G_A2B = GAN_crit(out_fake_B, torch.ones(out_fake_B.size(), device=device))
            fake_A = model_G_B2A(real_B)
            out_fake_A = model_D_A(fake_A)
            GAN_loss_G_B2A = GAN_crit(out_fake_A, torch.ones(out_fake_A.size(), device=device))
            GAN_loss_G = GAN_loss_G_A2B + GAN_loss_G_B2A

            # Cycle loss: A->B->A and B->A->B should reproduce the inputs
            recov_A  = model_G_B2A(fake_B)
            Cycle_loss_A = Cycle_crit(recov_A, real_A)
            recov_B  = model_G_A2B(fake_A)
            Cycle_loss_B = Cycle_crit(recov_B, real_B)
            Cycle_loss = Cycle_loss_A + Cycle_loss_B

            lossG = GAN_loss_G + cyc_lamb*Cycle_loss

            # Identity loss: G_A2B(B) ~ B and G_B2A(A) ~ A. Skipped when its weight
            # is 0, which saves two generator passes that would not change the gradient.
            if iden_lamb != 0:
                Iden_loss_B = Iden_crit(model_G_A2B(real_B), real_B)
                Iden_loss_A = Iden_crit(model_G_B2A(real_A), real_A)
                Iden_loss = Iden_loss_B + Iden_loss_A
                lossG = lossG + iden_lamb*Iden_loss

            lossG.backward()
            opt_G.step()

            # Show some optimization metrics
            if global_iter % show_iter == 0:
                hours, rem = divmod(time.time()-start_time, 3600)
                minutes, seconds = divmod(rem, 60)
                print('({:0>2}:{:0>2}:{:0>2}) [{}/{}][{}/{}] -> {:.2f}%\tLoss_D: {:.4f}, D(x): {:.4f}\tLoss_G: {:.4f}, D(G(z)): {:.4f}'.format(
                    int(hours),int(minutes), int(seconds), epoch, epochs, i, n_batches,
                    100*global_iter/(epochs*n_batches),
                    lossD, (out_real_A.mean()+out_real_B.mean())/2,
                    lossG, (out_fake_D_A.mean()+out_fake_D_B.mean())/2))

            # Plot current inputs, translations and reconstructions (no autograd needed)
            if global_iter % (show_iter*5) == 0:
                with torch.no_grad():
                    fakes_B  = model_G_A2B(real_A)
                    recovs_A = model_G_B2A(fakes_B)
                    fakes_A  = model_G_B2A(real_B)
                    recovs_B = model_G_A2B(fakes_A)
                img_tmp = torch.cat([real_A, fakes_B, recovs_A, real_B, fakes_A, recovs_B], dim=0).cpu()
                plt.figure(figsize=(8,8))
                plt.axis("off")
                plt.imshow(np.transpose(vutils.make_grid(img_tmp, nrow=3, padding=1, normalize=True),(1,2,0)))
                plt.pause(0.001)

        # Save current state and the next epoch to resume from.
        # NOTE(review): optimizer states are not part of the checkpoint, so Adam's
        # moment estimates restart on resume.
        torch.save({'G_A2B_state_dict': model_G_A2B.state_dict(),
                    'G_B2A_state_dict': model_G_B2A.state_dict(),
                    'D_A_state_dict': model_D_A.state_dict(),
                    'D_B_state_dict': model_D_B.state_dict(),
                    'current_epoch': epoch+1,
                    'fake_buffer': (out_fake_A_buffer.data, out_fake_B_buffer.data),
                   },'./saved_models/cycleGAN_'+use_dataset+'_'+str(img_size)+'_saved_model.tar')

        # One LR-scheduler step per epoch
        LR_scheduler_G.step()
        LR_scheduler_D.step()

In [None]:
# Get initialized models and optimizers
G_models, D_models, opt_G, opt_D, epoch_start, buffers = get_cycleGAN(in_nc=3, out_nc=3,
                                                                      ngf=64, ndf=64,
                                                                      device=device,
                                                                      use_dataset=use_dataset,
                                                                      img_size=img_size,
                                                                      lr=lr, beta1=beta1)

# Fit CycleGAN
fit_cycleGAN(epochs, G_models, D_models, opt_G, opt_D, train_dl, device, buffers,
             lambs=lambs, epoch_start=epoch_start, epoch_lr_decline=epoch_lr_decline)

### Testing the Model

In [None]:
# Get Test data
test_ds = DatasetFromFolder(root = data_folder, mode='test',
                             transform = transforms.Compose([
                                 transforms.Resize(int(img_size*1.12), Image.BICUBIC),
                                 transforms.RandomCrop(img_size),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                             ])
                            )

test_dl = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)

# Get generating Models
model_G_A2B, model_G_B2A = G_models
# Load data
saved_data = torch.load('./saved_models/cycleGAN_'+use_dataset+'_'+str(img_size)+'_saved_model.tar',
                        map_location=device)
model_G_A2B.load_state_dict(saved_data['G_A2B_state_dict'])
model_G_B2A.load_state_dict(saved_data['G_B2A_state_dict'])

# Push models to cpu
model_G_A2B = model_G_A2B.cpu()
model_G_B2A = model_G_B2A.cpu()

test_img = next(iter(test_dl))
# For A
A_test   = test_img[0].cpu()
fakes_B  = model_G_A2B(A_test).detach()
recovs_A = model_G_B2A(fakes_B).detach()
# For B
B_test   = test_img[1].cpu()
fakes_A  = model_G_B2A(B_test).detach()
recovs_B = model_G_A2B(fakes_A).detach()

img_tmp = torch.cat([A_test, fakes_B, recovs_A, B_test, fakes_A, recovs_B], dim=0).cpu()
plt.figure(figsize=(12,12))
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(img_tmp, nrow=3, padding=1, normalize=True),(1,2,0)))
plt.pause(0.001)