# Paired Style Transfer using Pix2Pix GAN

In their work titled [“Image to Image Translation with Conditional Adversarial Networks”](https://arxiv.org/abs/1611.07004), Isola and Zhu et. al. present a conditional GAN network which is able to learn task specific loss functions and thus work across datasets. As the name suggests, this GAN architecture takes a specific type of image as input and transforms it into a different domain. It is called pair-wise style transfer as the training set needs to have samples from both, source and target domains.

## Load Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torchvision.datasets.utils import download_and_extract_archive

import os
import glob
import random
import numpy as np

from PIL import Image

## U-Net Generator
The U-Net architecture uses skip connections to shuttle important features between the input and outputs. In case of pix2pix GAN, skip connections are added between every $ith$ down-sampling and $(n-i)th$ over-sampling layers, where $n$ is the total number of layers in the generator. The skip connection leads to concatenation of all channels from the ith and $(n-i)th$ layers.

In [3]:
class DownSampleBlock(nn.Module):
    def __init__(self, input_channels, output_channels,normalize=True):
        super(DownSampleBlock, self).__init__()
        layers = [
            nn.Conv2d(
                input_channels,
                output_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False)
            ]
        if normalize:
          layers.append(nn.InstanceNorm2d(output_channels))
        layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [4]:
class UpSampleBlock(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UpSampleBlock, self).__init__()
        layers = [
            nn.ConvTranspose2d(
                input_channels,
                output_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False),
        ]
        layers.append(nn.InstanceNorm2d(output_channels))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_connection):
        x = self.model(x)
        x = torch.cat((x, skip_connection), 1)

        return x

In [5]:
class Generator(nn.Module):
    def __init__(self, input_channels=3,out_channels=3):
        super(Generator, self).__init__()

        self.downsample1 = DownSampleBlock(input_channels,64, normalize=False)
        self.downsample2 = DownSampleBlock(64, 128)
        self.downsample3 = DownSampleBlock(128, 256)
        self.downsample4 = DownSampleBlock(256, 512)
        self.downsample5 = DownSampleBlock(512, 512)
        self.downsample6 = DownSampleBlock(512, 512)
        self.downsample7 = DownSampleBlock(512, 512)
        self.downsample8 = DownSampleBlock(512, 512,normalize=False)

        self.upsample1 = UpSampleBlock(512, 512)
        self.upsample2 = UpSampleBlock(1024, 512)
        self.upsample3 = UpSampleBlock(1024, 512)
        self.upsample4 = UpSampleBlock(1024, 512)
        self.upsample5 = UpSampleBlock(1024, 256)
        self.upsample6 = UpSampleBlock(512, 128)
        self.upsample7 = UpSampleBlock(256, 64)

        self.final_layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            # padding left, right, top, bottom
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # downsampling blocks
        d1 = self.downsample1(x)
        d2 = self.downsample2(d1)
        d3 = self.downsample3(d2)
        d4 = self.downsample4(d3)
        d5 = self.downsample5(d4)
        d6 = self.downsample6(d5)
        d7 = self.downsample7(d6)
        d8 = self.downsample8(d7)
        # upsampling blocks with skip connections
        u1 = self.upsample1(d8, d7)
        u2 = self.upsample2(u1, d6)
        u3 = self.upsample3(u2, d5)
        u4 = self.upsample4(u3, d4)
        u5 = self.upsample5(u4, d3)
        u6 = self.upsample6(u5, d2)
        u7 = self.upsample7(u6, d1)

        return self.final_layer(u7)

## Patch-GAN Discriminator
The authors for pix2pix propose a Patch-GAN setup for the discriminator which takes the required inputs and generates an output of size NxN. Each $x_{ij}$ element of the NxN output signifies whether the corresponding patch ij in the generated image is real or fake. Each output patch can be traced back to its initial input patch basis the effective receptive field for each of the layers.

In [6]:
class Discriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(input_filters, output_filters):
            layers = [
                nn.Conv2d(
                    input_filters,
                    output_filters,
                    kernel_size=4,
                    stride=2,
                    padding=1)
                ]
            layers.append(nn.InstanceNorm2d(output_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(input_channels * 2, output_filters=64),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            # padding left, right, top, bottom
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

## Prepare Dataset Class

In [7]:
class ImageDataset(Dataset):
    def __init__(self, dataset_path, is_train=False, image_transformations=None):
        self.transform = transforms.Compose(image_transformations)
        if is_train:
          dataset_path += '/train'
        else:
          dataset_path += '/val'
        self.files = sorted(glob.glob(os.path.join(dataset_path) + "/*.*"))

    def __getitem__(self, index):

        img = Image.open(self.files[index % len(self.files)])
        w, h = img.size
        img_A = img.crop((0, 0, w / 2, h))
        img_B = img.crop((w / 2, 0, w, h))

        # flip images randomly
        if np.random.random() < 0.5:
            img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")
            img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")

        img_A = self.transform(img_A)
        img_B = self.transform(img_B)

        return {"A": img_A, "B": img_B}

    def __len__(self):
        return len(self.files)

In [8]:
def sample_images(batches_done):
    imgs = next(iter(val_dataloader))
    # condition
    real_A = Variable(imgs["B"].type(Tensor))
    # real
    real_B = Variable(imgs["A"].type(Tensor))
    # generated
    fake_B = generator(real_A)
    img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    save_image(img_sample, f"images/{batches_done}.png", nrow=4, normalize=True)

## Download Dataset

In [None]:
!wget -P images/ http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz
!tar -zxf ./images/maps.tar.gz --directory ./images/

--2024-02-04 19:21:47--  http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 250242400 (239M) [application/x-gzip]
Saving to: ‘images/maps.tar.gz.1’


2024-02-04 19:23:56 (1.85 MB/s) - ‘images/maps.tar.gz.1’ saved [250242400/250242400]



## Set Parameters

In [9]:
URL="http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz"
DATASET_PATH = "./images/maps"
IMG_WIDTH = 256
IMG_HEIGHT = 256
NUM_CHANNELS = 3
BATCH_SIZE = 64
N_EPOCHS = 200
SAMPLE_INTERVAL = 18

In [10]:
CUDA = True if torch.cuda.is_available() else False

In [None]:
os.makedirs("images/maps/", exist_ok=True)
os.makedirs("saved_models/", exist_ok=True)

## Calculate Receptive Field for Patch

In [11]:
# prepare patch size for our setup
patch = int(IMG_HEIGHT / 2**4)
patch_gan_shape = (1,patch, patch)
print("Patch Shape={}".format(patch_gan_shape))

Patch Shape=(1, 16, 16)


## Get Generator and Discriminator Model Objects

In [12]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# Loss functions
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Loss weight of L1 pixel-wise loss between translated image and real image
weight_pixel_wise_identity = 100

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [13]:
if CUDA:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()
    Tensor = torch.cuda.FloatTensor
else:
  Tensor = torch.FloatTensor

In [14]:
first_parameter = next(generator.parameters())
input_shape = first_parameter.size()
input_shape

torch.Size([64, 3, 4, 4])

## Define Transformations and Dataloaders

In [None]:
image_transformations = [
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

In [None]:
train_dataloader = DataLoader(
    ImageDataset(DATASET_PATH, is_train=True,image_transformations=image_transformations),
    batch_size=BATCH_SIZE,
    shuffle=True
)

In [None]:
val_dataloader = DataLoader(
    ImageDataset(DATASET_PATH,image_transformations=image_transformations),
    batch_size=BATCH_SIZE//8,
    shuffle=True
)

## Training Begins!

In [None]:
for epoch in range(0, N_EPOCHS):
    for i, batch in enumerate(train_dataloader):

        # prepare inputs
        real_A = Variable(batch["B"].type(Tensor))
        real_B = Variable(batch["A"].type(Tensor))

        # ground truth
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch_gan_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch_gan_shape))), requires_grad=False)

        #  Train Generator
        optimizer_G.zero_grad()

        # generator loss
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        adv_loss = adversarial_loss(pred_fake, valid)
        loss_pixel = pixelwise_loss(fake_B, real_B)

        # Overall Generator loss
        g_loss = adv_loss + weight_pixel_wise_identity * loss_pixel

        g_loss.backward()

        optimizer_G.step()

        #  Train Discriminator
        optimizer_D.zero_grad()

        pred_real = discriminator(real_B, real_A)
        loss_real = adversarial_loss(pred_real, valid)
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = adversarial_loss(pred_fake, fake)

        # Overall Discriminator loss
        d_loss = 0.5 * (loss_real + loss_fake)

        d_loss.backward()
        optimizer_D.step()

        # Progress Report
        batches_done = epoch * len(train_dataloader) + i
        print(f'Epoch: {epoch}/{N_EPOCHS}-Batch: {i}/{len(train_dataloader)}--D.loss:{d_loss.item():.4f},G.loss:{g_loss.item():.4f}--Adv.Loss:{adv_loss.item():.4f}')

        # generate samples
        if batches_done % SAMPLE_INTERVAL == 0:
            sample_images(batches_done)

  valid = Variable(Tensor(np.ones((real_A.size(0), *patch_gan_shape))), requires_grad=False)


Epoch: 0/200-Batch: 0/18--D.loss:0.7406,G.loss:57.0879--Adv.Loss:1.3402
Epoch: 0/200-Batch: 1/18--D.loss:0.4748,G.loss:49.7242--Adv.Loss:0.4805
Epoch: 0/200-Batch: 2/18--D.loss:0.7498,G.loss:42.0825--Adv.Loss:0.5350
Epoch: 0/200-Batch: 3/18--D.loss:0.4612,G.loss:37.5967--Adv.Loss:0.2657
Epoch: 0/200-Batch: 4/18--D.loss:0.2983,G.loss:34.0058--Adv.Loss:0.3665
Epoch: 0/200-Batch: 5/18--D.loss:0.3036,G.loss:32.1782--Adv.Loss:0.4951
Epoch: 0/200-Batch: 6/18--D.loss:0.2849,G.loss:30.7707--Adv.Loss:0.4436
Epoch: 0/200-Batch: 7/18--D.loss:0.2765,G.loss:30.4441--Adv.Loss:0.3500
Epoch: 0/200-Batch: 8/18--D.loss:0.2622,G.loss:30.6459--Adv.Loss:0.3463
Epoch: 0/200-Batch: 9/18--D.loss:0.2508,G.loss:30.2438--Adv.Loss:0.4105
Epoch: 0/200-Batch: 10/18--D.loss:0.2253,G.loss:28.3003--Adv.Loss:0.4241
Epoch: 0/200-Batch: 11/18--D.loss:0.2179,G.loss:28.3821--Adv.Loss:0.4165
Epoch: 0/200-Batch: 12/18--D.loss:0.2142,G.loss:28.2538--Adv.Loss:0.4410
Epoch: 0/200-Batch: 13/18--D.loss:0.2162,G.loss:27.8362--Adv.

KeyboardInterrupt: 