## Module 3: CycleGAN

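CycleGAN translates images from one domain to another. For example, it can turn pictures of horses into zebras and back, or colourize grayscale images. It learns these mappings from *unpaired* examples: the training data does not have to match each image in one domain with its counterpart in the other. This makes it useful for many applications where paired data is hard to collect.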

<img src='Images/Cycle.jpg'/>

Let us define some constants and convenience functions:

In [1]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
import scipy.ndimage.interpolation
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
# Use the GPU when one is available and fall back to the CPU otherwise.
# A bare torch.device('cuda') makes every later .to(device) fail on machines
# without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#%matplotlib inline

mb_size = 128     # minibatch size
Z_dim = 100       # NOTE(review): not used by the visible cells
X_dim = 784       # flattened 28x28 MNIST image
y_dim = 10        # NOTE(review): not used by the visible cells
h_dim = 128       # hidden-layer width of every MLP
cnt = 0           # incremented each time the training loop prints
lr = 1e-4         # Adam learning rate for both optimizers
lambda_idt = 0.1  # weight of the identity loss in the generator objective

def plotgrid(img):
    """Draw a batch of 28x28 images as a single grid on the current figure.

    img: a tensor whose elements can be reshaped to (N, 1, 28, 28), e.g. the
    (N, 784) batches returned by mnist_next/mnist_next2 or a generator output.
    N is inferred from the input, so it no longer has to equal mb_size.
    The tensor may require grad and may live on the GPU.
    """
    # Detach from autograd and move to host memory before handing to numpy.
    img = img.detach().cpu().reshape(-1, 1, 28, 28)
    grid = torchvision.utils.make_grid(img)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).numpy())
    
def log(x):
    """Element-wise natural log of ``x + 1e-8``.

    The small offset keeps the result finite when ``x`` is exactly 0, which
    can happen with the Sigmoid outputs of the discriminators.
    """
    shifted = torch.add(x, 1e-8)
    return shifted.log()

Let us load the dataset. We use MNIST. Since CycleGAN converts between two domains, domain A is standard MNIST and domain B is transposed MNIST, where each image has its rows and columns swapped. The two domains are sampled as independent batches, so no image is paired with its own transpose. In effect, we are asking the networks to learn the transpose operation.

In [2]:
dataroot = './data'

# ToTensor scales pixels to [0, 1]; Normalize with mean 0 / std 1 leaves them
# as-is, which matches the [0, 1] range of the generators' Sigmoid outputs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

# NOTE(review): train=False loads the MNIST *test* split even though the
# variable is called trainset -- confirm this is intentional (smaller/faster).
trainset = torchvision.datasets.MNIST(
    root=dataroot,
    train=False,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=mb_size,
    shuffle=True,
    num_workers=2,
)

# Shared iterator that both domain samplers draw batches from.
dataiter = iter(trainloader)

def _next_full_batch(diter):
    """Draw one (images, labels) batch of exactly mb_size samples from `diter`.

    If `diter` is exhausted (StopIteration), or yields the short final batch
    of an epoch, a fresh iterator over `trainloader` is started and its first
    batch is used instead. When `diter` is the module-level `dataiter`, that
    global is rebound to the fresh iterator. Later calls such as
    mnist_next(dataiter) then continue the new epoch instead of building a
    new iterator on every call.
    """
    global dataiter
    current = diter
    try:
        images, labels = next(current)
    except StopIteration:
        current = iter(trainloader)
        images, labels = next(current)
    if images.shape[0] != mb_size:
        # Short last batch of an epoch: start a new (reshuffled) epoch instead.
        current = iter(trainloader)
        images, labels = next(current)
    if current is not diter and diter is dataiter:
        dataiter = current
    return images, labels


def mnist_next(diter):
    """Return the next (images, labels) batch for domain A (plain MNIST).

    images is flattened to (batch, 784); labels are passed through unchanged.
    """
    images, labels = _next_full_batch(diter)
    return images.reshape(images.shape[0], 28 * 28), labels


def mnist_next2(diter):
    """Return the next (images, labels) batch for domain B (transposed MNIST).

    Each 28x28 image has its last two axes swapped (rows <-> columns) before
    being flattened to (batch, 784); labels are passed through unchanged.
    """
    images, labels = _next_full_batch(diter)
    images = images.permute(0, 1, 3, 2)
    # reshape copies the non-contiguous permuted tensor when needed.
    return images.reshape(images.shape[0], 28 * 28), labels


def initialize_loader(trainset):
    """Return an iterator over a new shuffled DataLoader for `trainset`."""
    loader = torch.utils.data.DataLoader(trainset, batch_size=mb_size,
                                         shuffle=True, num_workers=2)
    return iter(loader)


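Let us define the networks. CycleGAN uses two generators and two discriminators. G_AB maps domain A to B and G_BA maps B to A. D_A judges whether an image belongs to domain A, and D_B does the same for domain B.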
<img src='Images/CGAN.jpg'/>

In [3]:
def _generator():
    """Build a one-hidden-layer MLP mapping a flattened image to one in [0, 1]."""
    return torch.nn.Sequential(
        torch.nn.Linear(X_dim, h_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_dim, X_dim),
        torch.nn.Sigmoid(),
    )


def _discriminator():
    """Build a one-hidden-layer MLP scoring a flattened image with a value in (0, 1)."""
    return torch.nn.Sequential(
        torch.nn.Linear(X_dim, h_dim),
        torch.nn.LeakyReLU(),
        torch.nn.Linear(h_dim, 1),
        torch.nn.Sigmoid(),
    )


# Built in the order G_AB, G_BA, D_A, D_B (tuples evaluate left to right).
G_AB, G_BA = _generator(), _generator()
D_A, D_B = _discriminator(), _discriminator()

nets = [G_AB, G_BA, D_A, D_B]
G_params = [*G_AB.parameters(), *G_BA.parameters()]
D_params = [*D_A.parameters(), *D_B.parameters()]


def reset_grad():
    """Zero the accumulated gradients of all four networks."""
    for module in nets:
        module.zero_grad()


# One Adam optimizer for both generators, one for both discriminators.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)


Move all four networks to `device`:

In [4]:
# nn.Module.to moves parameters in place, so the optimizers created above
# keep referencing the same Parameter objects.
for _net in (G_AB, G_BA, D_A, D_B):
    _net.to(device)

Sequential(
  (0): Linear(in_features=784, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.01)
  (2): Linear(in_features=128, out_features=1, bias=True)
  (3): Sigmoid()
)

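Now we train the networks. Each iteration first updates both discriminators, then both generators. The generator loss combines three terms: the adversarial loss, the cycle-consistency loss, and the identity loss.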
<img src='Images/cyclegan.png'/>

In [None]:
# Training: each iteration updates both discriminators, then both generators.
for it in range(5000):
    # Sample an (unpaired) batch from each domain: A = MNIST, B = transposed MNIST.
    X_A = mnist_next(dataiter)[0].to(device)
    X_B = mnist_next2(dataiter)[0].to(device)

    # ---- Discriminator update ----
    # Generator outputs are detached: D_loss only needs gradients for D_A/D_B.
    # Without detach, backward() also walked through both generators, and
    # those gradients were zeroed by reset_grad() before the generator step.
    X_BA = G_BA(X_B).detach()
    X_AB = G_AB(X_A).detach()

    # D_A separates real A images from translated B->A images.
    D_A_real = D_A(X_A)
    D_A_fake = D_A(X_BA)
    L_D_A = -torch.mean(log(D_A_real) + log(1 - D_A_fake))

    # D_B separates real B images from translated A->B images.
    D_B_real = D_B(X_B)
    D_B_fake = D_B(X_AB)
    L_D_B = -torch.mean(log(D_B_real) + log(1 - D_B_fake))

    D_loss = L_D_A + L_D_B
    D_loss.backward()
    D_solver.step()
    reset_grad()

    # ---- Generator update ----
    # Adversarial terms (non-saturating form): fool the target-domain discriminator.
    X_AB = G_AB(X_A)
    X_BA = G_BA(X_B)
    L_adv_B = -torch.mean(log(D_B(X_AB)))
    L_adv_A = -torch.mean(log(D_A(X_BA)))
    L_G = L_adv_B + L_adv_A

    # Cycle consistency: A->B->A and B->A->B should reconstruct the input
    # (squared error summed over pixels, averaged over the batch).
    X_ABA = G_BA(X_AB)
    X_BAB = G_AB(X_BA)
    L_loss_cycle_A = torch.mean(torch.sum((X_A - X_ABA) ** 2, 1))
    L_loss_cycle_B = torch.mean(torch.sum((X_B - X_BAB) ** 2, 1))
    L_cycle = L_loss_cycle_A + L_loss_cycle_B

    # Identity: a generator fed an image already in its target domain should
    # leave it unchanged (G_AB(B) ~ B, G_BA(A) ~ A).
    same_B = G_AB(X_B)
    same_A = G_BA(X_A)
    loss_identity_B = torch.mean(torch.sum((X_B - same_B) ** 2, 1))
    loss_identity_A = torch.mean(torch.sum((X_A - same_A) ** 2, 1))
    Identity_loss = lambda_idt * (loss_identity_B + loss_identity_A)

    # Total generator loss.
    G_loss = L_G + L_cycle + Identity_loss
    G_loss.backward()
    G_solver.step()
    reset_grad()

    # Print progress every 1000 iterations.
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}'.format(it, D_loss.item(), G_loss.item()))
        cnt += 1


Iter-0; D_loss: 2.848; G_loss: 402.2
Iter-1000; D_loss: 1.425; G_loss: 98.59


See what the trained networks produce:

In [None]:
# Translate one fresh batch from each domain and show inputs next to outputs.
# The batches are used directly; wrapping them in torch.tensor(...) only made
# a redundant copy and triggered a PyTorch warning.
input_A = mnist_next(dataiter)[0]
input_B = mnist_next2(dataiter)[0]

# Inference only: no autograd graph is needed for the samples.
with torch.no_grad():
    samples_A = G_BA(input_B.to(device)).cpu()  # B -> A (target: undo the transpose)
    samples_B = G_AB(input_A.to(device)).cpu()  # A -> B (target: apply the transpose)

# Show in order: input B, its translation to A, input A, its translation to B.
# plt.show() must be called; the original ended with a bare `plt.show`,
# which never displayed the last figure outside of inline notebooks.
for batch in (input_B, samples_A, input_A, samples_B):
    plotgrid(batch)
    plt.show()


## Points to ponder
1. What does the output look like without the cycle-consistency loss?
2. What happens when `lambda_idt` is changed to 10?