# Unpaired Style Transfer using CycleGAN
[CycleGAN](https://arxiv.org/abs/1703.10593) improves upon paired style transfer architectures by relaxing the requirement that input and output images come in matched pairs. It explores the unpaired style transfer paradigm, in which the model learns the stylistic differences between the source and target domains without any explicit pairing between input and output images. Zhu and Park et al. compare this to our ability to imagine how Van Gogh or Monet would have painted a particular scene without ever having seen a side-by-side example. Quoting from the paper:

> Instead, we have knowledge of the set of Monet paintings and of the set of landscape photographs. We can reason about the stylistic differences between these two sets, and thereby imagine what a scene might look like if we were to “translate” it from one set into the other.

This is an advantage in its own right, and it also opens up use cases where exactly paired source and target images are unavailable or where there are too few paired training examples.



## Load Libraries

In [1]:
import torch
import itertools
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_and_extract_archive

import os
import glob
import random
import numpy as np

from PIL import Image

## U-Net Generator
The U-Net architecture uses skip connections to shuttle important features from the down-sampling path to the up-sampling path. As in the pix2pix GAN, skip connections are added between the $i$-th down-sampling layer and the $(n-i)$-th up-sampling layer, where $n$ is the total number of layers in the generator. Each skip connection concatenates all channels from the $i$-th and $(n-i)$-th layers.
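As a minimal sketch of what one skip connection does to tensor shapes, the snippet below up-samples a bottleneck tensor and concatenates it with an encoder tensor along the channel dimension. The tensor names and sizes are hypothetical and chosen only for illustration; they are not taken from a trained model.

```python
import torch
import torch.nn as nn

# Hypothetical tensors standing in for one encoder/decoder pair.
encoder_features = torch.randn(1, 512, 2, 2)   # e.g. output of a down-sampling block
bottleneck = torch.randn(1, 512, 1, 1)         # e.g. output of the deepest block

# A transposed convolution with kernel 4, stride 2, padding 1 doubles height and width.
upsample = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False)
decoder_features = upsample(bottleneck)

# The skip connection concatenates along the channel dimension (dim=1).
merged = torch.cat((decoder_features, encoder_features), dim=1)
print(decoder_features.shape, merged.shape)
```

Because concatenation adds the channel counts, the next up-sampling block has to accept twice as many input channels as the previous block produced.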

In [2]:
class UpSampleBlock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UpSampleBlock, self).__init__()
        layers = [
            nn.ConvTranspose2d(
                input_channels,
                output_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False),
        ]
        layers.append(nn.InstanceNorm2d(output_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_connection):
        x = self.model(x)
        x = torch.cat((x, skip_connection), 1)

        return x

In [3]:
class DownSampleBlock(nn.Module):
    def __init__(self, input_channels, output_channels,normalize=True):
        super(DownSampleBlock, self).__init__()
        layers = [
            nn.Conv2d(
                input_channels,
                output_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False)
            ]
        if normalize:
            layers.append(nn.InstanceNorm2d(output_channels))
        layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [4]:
class Generator(nn.Module):
    def __init__(self, input_channels=3,out_channels=3):
        super(Generator, self).__init__()
        self.downsample1 = DownSampleBlock(input_channels,64, normalize=False)
        self.downsample2 = DownSampleBlock(64, 128)
        self.downsample3 = DownSampleBlock(128, 256)
        self.downsample4 = DownSampleBlock(256, 512)
        self.downsample5 = DownSampleBlock(512, 512)
        self.downsample6 = DownSampleBlock(512, 512)
        self.downsample7 = DownSampleBlock(512, 512)
        self.downsample8 = DownSampleBlock(512, 512,normalize=False)

        self.upsample1 = UpSampleBlock(512, 512)
        self.upsample2 = UpSampleBlock(1024, 512)
        self.upsample3 = UpSampleBlock(1024, 512)
        self.upsample4 = UpSampleBlock(1024, 512)
        self.upsample5 = UpSampleBlock(1024, 256)
        self.upsample6 = UpSampleBlock(512, 128)
        self.upsample7 = UpSampleBlock(256, 64)

        self.final_layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            # padding left, right, top, bottom
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # downsampling blocks
        d1 = self.downsample1(x)
        d2 = self.downsample2(d1)
        d3 = self.downsample3(d2)
        d4 = self.downsample4(d3)
        d5 = self.downsample5(d4)
        d6 = self.downsample6(d5)
        d7 = self.downsample7(d6)
        d8 = self.downsample8(d7)
        # upsampling blocks with skip connections
        u1 = self.upsample1(d8, d7)
        u2 = self.upsample2(u1, d6)
        u3 = self.upsample3(u2, d5)
        u4 = self.upsample4(u3, d4)
        u5 = self.upsample5(u4, d3)
        u6 = self.upsample6(u5, d2)
        u7 = self.upsample7(u6, d1)

        return self.final_layer(u7)
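The input channel counts of the up-sampling blocks follow directly from the skip connections. The sketch below is plain-Python bookkeeping, assuming a 256x256 input, that tracks the spatial size after each down-sampling block and the channel count after each concatenation. The list names are hypothetical; the numbers are copied from the `Generator` definition.

```python
# Output channels of the eight down-sampling blocks, in order (d1..d8).
down_channels = [64, 128, 256, 512, 512, 512, 512, 512]
# Output channels of the seven up-sampling blocks (u1..u7), before concatenation.
up_channels = [512, 512, 512, 512, 256, 128, 64]

size = 256
down_sizes = []
for _ in down_channels:
    size //= 2          # kernel 4, stride 2, padding 1 halves height and width
    down_sizes.append(size)

# Decoder block k is concatenated with encoder output d(8-k): u1 with d7, u2 with d6, ...
concat_channels = []
for k, up_out in enumerate(up_channels, start=1):
    skip = down_channels[8 - k - 1]   # zero-based index of d(8-k)
    concat_channels.append(up_out + skip)

print(down_sizes)        # spatial size after d1..d8
print(concat_channels)   # channels after each skip concatenation
```

The concatenated counts match the first argument of the next block (`1024`, `1024`, `1024`, `1024`, `512`, `256`) and the `128` input channels of the final convolution. The deepest feature map is 1x1, which is consistent with `downsample8` being built with `normalize=False`.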

## Patch-GAN Discriminator
The pix2pix authors propose a Patch-GAN setup for the discriminator, which takes an input image and produces an output of size $N \times N$. Each element $x_{ij}$ of this output signifies whether the corresponding patch $ij$ of the input image is real or fake. Each output element can be traced back to its input patch through the effective receptive field of the layers.
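The receptive field of one output element can be estimated with a short calculation. The sketch below assumes the layer stack used by the discriminator in this section (four 4x4 convolutions with stride 2, followed by one 4x4 convolution with stride 1) and ignores padding, which affects border patches but not the size of the field. The helper name `receptive_field` is hypothetical.

```python
def receptive_field(layers):
    """Receptive field of one output unit for a stack of (kernel, stride) conv layers."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump   # each layer widens the field by (k - 1) input steps
        jump *= stride              # distance between adjacent units, in input pixels
    return rf

# Assumed layer stack: four stride-2 convolutions, then one stride-1 convolution, all 4x4.
patchgan_layers = [(4, 2)] * 4 + [(4, 1)]
print(receptive_field(patchgan_layers))
```

Under these assumptions each output element judges a 94x94 region of the input, so neighbouring patches overlap heavily.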

In [5]:
class Discriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(input_filters, output_filters):
            layers = [
                nn.Conv2d(
                    input_filters,
                    output_filters,
                    kernel_size=4,
                    stride=2,
                    padding=1)
                ]
            layers.append(nn.InstanceNorm2d(output_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(input_channels, output_filters=64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            # padding left, right, top, bottom
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img):
        return self.model(img)

## Prepare Dataset Class

In [6]:
class ImageDataset(Dataset):
    def __init__(self, dataset_path, is_train=False, image_transformations=None):
        self.transform = transforms.Compose(image_transformations)
        if is_train:
            dataset_path += '/train'
        else:
            dataset_path += '/test'
        self.files_a = sorted(glob.glob(os.path.join(dataset_path) + "A/*.*"))
        self.files_b = sorted(glob.glob(os.path.join(dataset_path) + "B/*.*"))

    def __getitem__(self, index):

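        # convert to RGB so grayscale or RGBA files still yield 3-channel tensors
        img_a = Image.open(self.files_a[index % len(self.files_a)]).convert("RGB")
        img_b = Image.open(self.files_b[index % len(self.files_b)]).convert("RGB")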

        img_a = self.transform(img_a)
        img_b = self.transform(img_b)

        return {"A": img_a, "B": img_b}

    def __len__(self):
        return max(len(self.files_a), len(self.files_b))
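Since the two domains are unpaired and may differ in size, the dataset length is the larger of the two file counts and the smaller list wraps around through the modulo index. A minimal sketch of that indexing, using hypothetical file names:

```python
files_a = ["a0", "a1", "a2"]              # hypothetical file names, domain A
files_b = ["b0", "b1", "b2", "b3", "b4"]  # hypothetical file names, domain B

# Same rule as __len__ and __getitem__ above: the longer list sets the length,
# and the shorter list is reused cyclically.
length = max(len(files_a), len(files_b))
pairs = [(files_a[i % len(files_a)], files_b[i % len(files_b)]) for i in range(length)]
print(pairs)
```

The pairing is arbitrary: nothing relates `a0` to `b0` beyond their sorted position. A common variation is to pick the domain-B index at random for each item so that the same images are not always presented together.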

In [7]:
def sample_images(batches_done):
    imgs = next(iter(val_dataloader))
    generator_AB.eval()
    generator_BA.eval()
    # real A-> B
    real_A = Variable(imgs["A"].type(Tensor))
    # generated A-> B
    fake_B = generator_AB(real_A)
    # real B-> A
    real_B = Variable(imgs["B"].type(Tensor))
    # generated B-> A
    fake_A = generator_BA(real_B)

    # reconstruct original samples from both generators
    reconstruct_A = generator_BA(fake_B)
    reconstruct_B = generator_AB(fake_A)

    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    reconstruct_A = make_grid(reconstruct_A, nrow=5, normalize=True)
    reconstruct_B = make_grid(reconstruct_B, nrow=5, normalize=True)

    img_sample = torch.cat((real_A, fake_B, reconstruct_A,real_B, fake_A,reconstruct_B), 1)
    save_image(img_sample, f"images/{DATASET}_b_{batches_done}.png", nrow=5, normalize=False)

## Download Dataset

In [8]:
# set experiment by choosing dataset as
# 1. apple2orange or
# 2. vangogh2photo
DATASET = 'vangogh2photo'
DATASET_PATH = f"./images/{DATASET}"

In [9]:
os.makedirs(f"{DATASET_PATH}", exist_ok=True)
os.makedirs("saved_models/", exist_ok=True)

In [10]:
!wget -P images/ https://efrosgans.eecs.berkeley.edu/cyclegan/datasets/{DATASET}.zip
!unzip {DATASET_PATH}.zip -d ./images/

Streaming output truncated to the last 5000 lines.
  inflating: ./images/vangogh2photo/trainB/2016-08-26 19:25:29.jpg  
  inflating: ./images/vangogh2photo/trainB/2016-02-20 16:54:25.jpg  
  inflating: ./images/vangogh2photo/trainB/2015-05-19 19:43:44.jpg  
  inflating: ./images/vangogh2photo/trainB/2015-10-26 21:18:58.jpg  
  inflating: ./images/vangogh2photo/trainB/2014-07-19 15:39:53.jpg  
  inflating: ./images/vangogh2photo/trainB/2015-06-13 21:21:54.jpg  
  inflating: ./images/vangogh2photo/trainB/2015-06-04 14:10:33.jpg  
  inflating: ./images/vangogh2photo/trainB/2014-06-18 12:46:34.jpg  
  inflating: ./images/vangogh2photo/trainB/2016-03-17 18:07:22.jpg  
  inflating: ./images/vangogh2photo/trainB/2015-01-24 17:24:29.jpg  
  inflating: ./images/vangogh2photo/trainB/2016-01-27 12:45:41.jpg  
  inflating: ./images/vangogh2photo/trainB/2016-06-14 05:57:28.jpg  
  inflating: ./images/vangogh2photo/trainB/2015-12-24 01:48:36.jpg  
  inflating: ./images/vangogh2photo/tr

## Set Parameters

In [11]:
IMG_WIDTH = 256
IMG_HEIGHT = 256
NUM_CHANNELS = 3
BATCH_SIZE = 16
N_EPOCHS = 200
SAMPLE_INTERVAL = 18

In [12]:
CUDA = torch.cuda.is_available()

## Calculate Receptive Field for Patch

In [13]:
# prepare patch size for our setup
patch = int(IMG_HEIGHT / 2**4)
patch_gan_shape = (1,patch, patch)
print("Patch Shape={}".format(patch_gan_shape))

Patch Shape=(1, 16, 16)
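The value 16 can be cross-checked against the discriminator's layers with the standard convolution output-size formula. The sketch below assumes a 256x256 input and the layer settings of the `Discriminator` class; the helper `conv_out` is hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size = 256
for _ in range(4):                      # four 4x4, stride-2, padding-1 blocks
    size = conv_out(size, kernel=4, stride=2, padding=1)
size += 1                               # ZeroPad2d((1, 0, 1, 0)) adds one row and one column
size = conv_out(size, kernel=4, stride=1, padding=1)
print(size)
```

The four stride-2 blocks account for the division by `2**4`, and the zero padding before the last convolution keeps the final stride-1 layer from shrinking the map.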


## Get Generator and Discriminator Model Objects

In [14]:
# Initialize generator and discriminator
generator_AB = Generator()
generator_BA = Generator()
discriminator_A = Discriminator()
discriminator_B = Discriminator()


# Loss functions
adversarial_loss = torch.nn.MSELoss()
cycle_loss = torch.nn.L1Loss()
identity_loss = torch.nn.L1Loss()

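# Loss weight of L1 pixel-wise loss between translated image and real image
# (not referenced in the training loop below, which hard-codes the
# cycle-consistency weight 10 and the identity weight 5)
weight_pixel_wise_identity = 100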

# Optimizers
optimizer_G = torch.optim.Adam(
    itertools.chain(generator_AB.parameters(), generator_BA.parameters()), lr=0.0002, betas=(0.5, 0.999)
    )
optimizer_D_A = torch.optim.Adam(discriminator_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(discriminator_B.parameters(), lr=0.0002, betas=(0.5, 0.999))
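For reference, the loss objects defined above are combined in the training loop roughly as follows. This is a summary read off the code, not a quotation from the paper. Here $a$ and $b$ denote images from domains A and B, $G_{AB}$ and $G_{BA}$ are the two generators, $D_A$ and $D_B$ are the two discriminators, and $\mathcal{L}_{idt}$ stands for the averaged L1 identity term computed in the loop. Note that `MSELoss` and `L1Loss` average over all elements, so the norms below are per-element means.

$$
\mathcal{L}_{adv} = \tfrac{1}{2}\Big(\mathbb{E}_{a}\big[(D_B(G_{AB}(a)) - 1)^2\big] + \mathbb{E}_{b}\big[(D_A(G_{BA}(b)) - 1)^2\big]\Big)
$$

$$
\mathcal{L}_{cyc} = \tfrac{1}{2}\Big(\mathbb{E}_{a}\big[\lVert G_{BA}(G_{AB}(a)) - a \rVert_1\big] + \mathbb{E}_{b}\big[\lVert G_{AB}(G_{BA}(b)) - b \rVert_1\big]\Big)
$$

$$
\mathcal{L}_{G} = \mathcal{L}_{adv} + 10\,\mathcal{L}_{cyc} + 5\,\mathcal{L}_{idt}
$$

$$
\mathcal{L}_{D_A} = \tfrac{1}{2}\Big(\mathbb{E}_{a}\big[(D_A(a) - 1)^2\big] + \mathbb{E}_{b}\big[D_A(G_{BA}(b))^2\big]\Big)
$$

$\mathcal{L}_{D_B}$ is defined symmetrically, with the roles of A and B swapped.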

In [15]:
if CUDA:
    generator_AB = generator_AB.cuda()
    generator_BA = generator_BA.cuda()
    discriminator_A = discriminator_A.cuda()
    discriminator_B = discriminator_B.cuda()
    adversarial_loss.cuda()
    cycle_loss.cuda()
    identity_loss.cuda()
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor
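The `Tensor` alias works, but recent PyTorch versions generally favour a `torch.device` object, and `Variable` has been a thin compatibility wrapper since PyTorch 0.4. The sketch below shows that alternative idiom for creating the adversarial targets. The batch size and the small stand-in layer are hypothetical, and this is not how the rest of the notebook is written.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 4                      # hypothetical batch size for illustration
patch_gan_shape = (1, 16, 16)       # matches the patch shape printed earlier

# Targets for the adversarial loss, created directly on the chosen device.
valid = torch.ones((batch_size, *patch_gan_shape), device=device)
fake = torch.zeros((batch_size, *patch_gan_shape), device=device)

# Modules move with .to(device); a small layer stands in for the real networks.
layer = torch.nn.Conv2d(3, 8, kernel_size=3).to(device)
print(valid.shape, fake.sum().item(), next(layer.parameters()).device)
```

Creating the targets with `torch.ones` and `torch.zeros` avoids the round trip through NumPy and keeps the code identical on CPU and GPU.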

## Define Transformations and Dataloaders

In [16]:
image_transformations = [
    transforms.Resize(int(IMG_HEIGHT * 1.12), Image.BICUBIC),
    transforms.RandomCrop((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
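Two details of this pipeline are easy to verify by hand: the resize target that precedes the random crop, and the value range produced by `Normalize`. The sketch below redoes the arithmetic in plain Python; the helper names `normalize` and `denormalize` are hypothetical.

```python
IMG_HEIGHT = 256

# Resize target: the shorter image side is scaled to 112% of the crop size.
resize_to = int(IMG_HEIGHT * 1.12)

def normalize(x, mean=0.5, std=0.5):
    """What transforms.Normalize computes per channel."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, useful before viewing a generated image."""
    return y * std + mean

print(resize_to)
print(normalize(0.0), normalize(1.0))
print(denormalize(-1.0), denormalize(1.0))
```

Images are therefore resized to 286 pixels on the shorter side and randomly cropped back to 256x256, and pixel values move from `[0, 1]` to `[-1, 1]`, the same range as the generator's final `Tanh`.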

In [17]:
train_dataloader = DataLoader(
    ImageDataset(DATASET_PATH, is_train=True,image_transformations=image_transformations),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [18]:
val_dataloader = DataLoader(
    ImageDataset(DATASET_PATH,image_transformations=image_transformations),
    batch_size=5,
    shuffle=True
)

## Training Begins!

In [None]:
for epoch in range(0, N_EPOCHS):
    for i, batch in enumerate(train_dataloader):

        # prepare inputs
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # ground truth
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch_gan_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch_gan_shape))), requires_grad=False)

        #  Train Generator
        generator_AB.train()
        generator_BA.train()
        optimizer_G.zero_grad()

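        # identity loss: a generator fed an image that is already in its target
        # domain should return it unchanged (G_BA maps to A, G_AB maps to B)
        idn_loss_A = identity_loss(generator_BA(real_A), real_A)
        idn_loss_B = identity_loss(generator_AB(real_B), real_B)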

        idn_loss = (idn_loss_A + idn_loss_B) / 2


        # generator loss
        fake_B = generator_AB(real_A)
        pred_fake = discriminator_B(fake_B)
        adv_loss_AB = adversarial_loss(pred_fake, valid)

        fake_A = generator_BA(real_B)
        pred_fake = discriminator_A(fake_A)
        adv_loss_BA = adversarial_loss(pred_fake, valid)


        # GAN loss
        adv_loss = (adv_loss_AB + adv_loss_BA)/2

        # Cycle loss
        reconstruction_A = generator_BA(fake_B)
        cycle_loss_A = cycle_loss(reconstruction_A, real_A)
        reconstruction_B = generator_AB(fake_A)
        cycle_loss_B = cycle_loss(reconstruction_B, real_B)

        overall_cycle_loss = (cycle_loss_A + cycle_loss_B) / 2

        # Overall loss
        g_loss = adv_loss + 10 * overall_cycle_loss + 5 * idn_loss


        g_loss.backward()
        optimizer_G.step()

        #  Train Discriminator A
        optimizer_D_A.zero_grad()

        pred_real_A = discriminator_A(real_A)
        loss_real_A = adversarial_loss(pred_real_A, valid)
        pred_fake_A = discriminator_A(fake_A.detach())
        loss_fake_A = adversarial_loss(pred_fake_A, fake)

        # Discriminator_A loss
        d_loss_A = 0.5 * (loss_real_A + loss_fake_A)

        d_loss_A.backward()
        optimizer_D_A.step()

        #  Train Discriminator B
        optimizer_D_B.zero_grad()

        pred_real_B = discriminator_B(real_B)
        loss_real_B = adversarial_loss(pred_real_B, valid)
        pred_fake_B = discriminator_B(fake_B.detach())
        loss_fake_B = adversarial_loss(pred_fake_B, fake)

        # Discriminator_B loss
        d_loss_B = 0.5 * (loss_real_B + loss_fake_B)

        d_loss_B.backward()
        optimizer_D_B.step()

        ## Overall Discriminator Loss
        d_loss = (d_loss_A + d_loss_B) / 2

        # Progress Report
        batches_done = epoch * len(train_dataloader) + i
        print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(train_dataloader)}--D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}--Adv.Loss:{adv_loss.item():.4f}')

        # generate samples
        if batches_done % SAMPLE_INTERVAL == 0:
            sample_images(batches_done)



Epoch: 0/200-Batch: 0/393--D.loss:0.7753,G.loss:9.0990--Adv.Loss:1.4075
Epoch: 0/200-Batch: 1/393--D.loss:0.6744,G.loss:7.5540--Adv.Loss:0.5799
Epoch: 0/200-Batch: 2/393--D.loss:1.1776,G.loss:7.1868--Adv.Loss:0.6048
Epoch: 0/200-Batch: 3/393--D.loss:0.9952,G.loss:6.9940--Adv.Loss:0.7873
Epoch: 0/200-Batch: 4/393--D.loss:0.4806,G.loss:6.3515--Adv.Loss:0.4175
Epoch: 0/200-Batch: 5/393--D.loss:0.4426,G.loss:5.9172--Adv.Loss:0.5024
Epoch: 0/200-Batch: 6/393--D.loss:0.3651,G.loss:5.3290--Adv.Loss:0.4427
Epoch: 0/200-Batch: 7/393--D.loss:0.3167,G.loss:5.1697--Adv.Loss:0.3929
Epoch: 0/200-Batch: 8/393--D.loss:0.3060,G.loss:4.9910--Adv.Loss:0.3688
Epoch: 0/200-Batch: 9/393--D.loss:0.3028,G.loss:4.6442--Adv.Loss:0.3414
Epoch: 0/200-Batch: 10/393--D.loss:0.3018,G.loss:4.9163--Adv.Loss:0.3537
Epoch: 0/200-Batch: 11/393--D.loss:0.3155,G.loss:4.4283--Adv.Loss:0.3692
Epoch: 0/200-Batch: 12/393--D.loss:0.3131,G.loss:4.5722--Adv.Loss:0.3873
Epoch: 0/200-Batch: 13/393--D.loss:0.2802,G.loss:4.2028--Adv.