<a href="https://colab.research.google.com/github/RaghavendraGaleppa/GANs/blob/master/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torchvision
import PIL.Image as Image
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
# -nc (no-clobber) skips the download when horse2zebra.zip is already present,
# so re-running this cell does not fetch a second copy (horse2zebra.zip.1).
!wget -nc https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip

--2019-05-21 15:55:19--  https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.189.73
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.189.73|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 116867962 (111M) [application/zip]
Saving to: ‘horse2zebra.zip’


2019-05-21 15:55:27 (35.5 MB/s) - ‘horse2zebra.zip’ saved [116867962/116867962]



In [0]:
# Safe to re-run: `mkdir -p` does not fail when ./.data/ already exists and
# `unzip -o` overwrites previously extracted files instead of prompting.
!mkdir -p ./.data/ && unzip -qq -o -d ./.data/ horse2zebra.zip

In [0]:
def load_image(root):
  """Load every file found in a directory as a NumPy array.

  Args:
    root: Path of the directory that holds the image files. A trailing
      path separator is optional.

  Returns:
    A list with one array per file, ordered by file name. The arrays are
    not stacked because they may differ in shape (the notebook later
    filters out files that decode to 2-D arrays).
  """
  data = []

  # os.listdir returns entries in arbitrary order; sorting makes the result
  # reproducible across runs and machines.
  for fname in sorted(os.listdir(root)):
    # os.path.join works whether or not `root` ends with a separator.
    path = os.path.join(root, fname)
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied into the array, so handles do not pile up.
    with Image.open(path) as img:
      data.append(np.array(img))

  return data

In [0]:
# Load every image under trainA and stack the per-file arrays into one array.
horse = np.array(load_image('./.data/horse2zebra/trainA/'))

In [0]:
# Load every image under trainB; kept as a plain list here because the
# entries are filtered by shape before being stacked into an array.
zebra = load_image('./.data/horse2zebra/trainB/')

In [29]:
# Drop images that decoded to 2-D arrays (e.g. shape (256, 256), no channel
# axis) so the remaining ones can be stacked into a single 4-D array.
#
# A new list is built instead of calling zebra.remove(z) inside the loop:
# removing while iterating skips the element after each removal, and
# list.remove compares ndarrays with ==, which warns or raises on arrays
# of mismatched shape.
zebra_rgb = []
for z in zebra:
  if(z.ndim == 3):
    zebra_rgb.append(z)
  else:
    print('lala')
zebra = zebra_rgb

lala
lala
lala
lala
lala
lala
lala


  after removing the cwd from sys.path.


In [0]:
# Stack the filtered list of per-image arrays into one NumPy array.
zebra = np.array(zebra)

In [31]:
# Show the stacked shape of each image set, one per line.
print(horse.shape, zebra.shape, sep='\n')

(1067, 256, 256, 3)
(1327, 256, 256, 3)


In [0]:
# Building the Generator for CycleGAN
# Input Image 
# --> Encoding
#     --> Conv Layer 1
#     --> Conv Layer 2
#     --> Conv Layer 3
# --> Transformation
#     --> Resnet Block 1
#     --> Resnet Block 2
#     --> Resnet Block 3
#     --> Resnet Block 4
# --> Decoding
#     --> Deconv Layer 1
#     --> Deconv Layer 2
#     --> Deconv Layer 3
#     --> Deconv Layer 4

class EncodingBlock(nn.Module):
  """Encoder half of the generator.

  Three stages of Conv2d -> LeakyReLU -> BatchNorm2d (the activation is
  applied before the normalization). The first stage keeps the spatial
  size; the second and third use stride 2, so each roughly halves it.
  For a 256x256 input with `inputChannels` channels the output is
  (128, 64, 64) per sample.
  """

  def __init__(self, inputChannels):
    super().__init__()

    # Stage 1: (inputChannels, 256, 256) -> (32, 256, 256)
    self.conv1 = nn.Conv2d(inputChannels, 32, kernel_size=3, stride=1, padding=1)
    self.bn1 = nn.BatchNorm2d(32)
    # Stage 2: (32, 256, 256) -> (64, 128, 128)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
    self.bn2 = nn.BatchNorm2d(64)
    # Stage 3: (64, 128, 128) -> (128, 64, 64)
    self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
    self.bn3 = nn.BatchNorm2d(128)

    # One stateless activation module shared by all stages.
    self.relu = nn.LeakyReLU()

  def forward(self, x):
    """Run the three stages in order and return the encoded feature map."""
    stages = (
      (self.conv1, self.bn1),
      (self.conv2, self.bn2),
      (self.conv3, self.bn3),
    )
    features = x
    for conv, norm in stages:
      features = norm(self.relu(conv(features)))

    return features
  
class ResnetBlock(nn.Module):
  """Residual block: two 3x3 conv stages plus a skip connection.

  Each stage is Conv2d -> LeakyReLU -> BatchNorm2d and keeps the spatial
  size (stride 1, padding 1). The block input is added to the output of
  the second stage.

  Every conv stage has its own BatchNorm2d (`bn1`, `bn2`). A single
  shared instance would mix the running statistics and affine parameters
  of two different layers. Note that this renames the state-dict keys
  from `bn.*` to `bn1.*` / `bn2.*`.

  Args:
    inp: Number of channels of the block input.
    out: Number of channels produced by the block. When it differs from
      `inp`, the skip path goes through a 1x1 convolution so that the
      addition is shape-compatible.
  """

  def __init__(self, inp, out):
    super().__init__()
    self.res1 = nn.Conv2d(inp, out, 3, stride=1, padding=1)
    self.res2 = nn.Conv2d(out, out, 3, stride=1, padding=1)

    self.relu = nn.LeakyReLU()
    # Separate normalization layers: one per conv stage.
    self.bn1 = nn.BatchNorm2d(out)
    self.bn2 = nn.BatchNorm2d(out)

    # Projection for the skip path, only needed when the channel count
    # changes; otherwise the input is added unchanged.
    self.proj = nn.Conv2d(inp, out, 1) if inp != out else None

  def forward(self, x):
    """Return the two-stage conv output plus the (possibly projected) input."""
    X = self.bn1(self.relu(self.res1(x)))
    X = self.bn2(self.relu(self.res2(X)))

    skip = x if self.proj is None else self.proj(x)

    return X + skip
  
class DecodingBlock(nn.Module):
  """Decoder half of the generator.

  Three stages of ConvTranspose2d -> BatchNorm2d -> LeakyReLU followed by
  a final ConvTranspose2d -> Tanh that produces a 3-channel output. The
  second and third stages (kernel 4, stride 2, padding 1) each double the
  spatial size, so a (inp, 64, 64) input yields a (3, 256, 256) output
  with values in [-1, 1].
  """

  def __init__(self, inp):
    super().__init__()

    # (inp, 64, 64) -> (512, 64, 64)
    self.dconv1 = nn.ConvTranspose2d(inp, 512, kernel_size=3, stride=1, padding=1)
    self.bn1 = nn.BatchNorm2d(512)
    # (512, 64, 64) -> (256, 128, 128)
    self.dconv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
    self.bn2 = nn.BatchNorm2d(256)
    # (256, 128, 128) -> (128, 256, 256)
    self.dconv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
    self.bn3 = nn.BatchNorm2d(128)
    # (128, 256, 256) -> (3, 256, 256); no normalization on the output layer
    self.dconv4 = nn.ConvTranspose2d(128, 3, kernel_size=3, stride=1, padding=1)

    self.relu = nn.LeakyReLU()
    self.tanh = nn.Tanh()

  def forward(self, x):
    """Upsample the feature map and squash the result with tanh."""
    hidden_stages = (
      (self.dconv1, self.bn1),
      (self.dconv2, self.bn2),
      (self.dconv3, self.bn3),
    )
    out = x
    for deconv, norm in hidden_stages:
      out = self.relu(norm(deconv(out)))

    return self.tanh(self.dconv4(out))
    

class Generator(nn.Module):
  """Image-to-image generator: encoder, four residual blocks, decoder.

  Takes a 3-channel image batch and returns a 3-channel batch produced by
  the decoder's tanh output layer.
  """

  def __init__(self):
    super().__init__()

    # Encoder: 3 input channels -> 128 feature channels
    self.enc = EncodingBlock(3)
    # Transformation: four residual blocks at 128 channels
    self.res1 = ResnetBlock(128, 128)
    self.res2 = ResnetBlock(128, 128)
    self.res3 = ResnetBlock(128, 128)
    self.res4 = ResnetBlock(128, 128)
    # Decoder: 128 feature channels -> 3 output channels
    self.dec = DecodingBlock(128)

  def forward(self, image):
    """Encode, apply the residual blocks in order, then decode."""
    features = self.enc(image)
    for block in (self.res1, self.res2, self.res3, self.res4):
      features = block(features)

    return self.dec(features)
    

In [0]:
# Building the Discriminator for GAN

# Base filter count for the discriminator; layer widths are multiples of it.
dnf = 16

class Discriminator(nn.Module):
  """Convolutional discriminator that reduces an image to a single score.

  Seven Conv2d -> BatchNorm2d -> LeakyReLU stages (the first keeps the
  spatial size, the next six use stride 2), then an unpadded 4x4 conv, a
  1x1 conv down to one channel, and a final ReLU. For a 256x256 input the
  output has shape (N, 1, 1, 1) and is non-negative.

  NOTE(review): the final ReLU clamps the score at zero and passes no
  gradient for negative pre-activations. Confirm this is intended for the
  loss used in training.
  """

  def __init__(self):
    super().__init__()

    # Channel widths used below (32, 64, 128, 256 with dnf = 16).
    w2, w4, w8, w16 = dnf*2, dnf*4, dnf*8, dnf*16

    # (3, 256, 256) -> (32, 256, 256)
    self.conv1 = nn.Conv2d(3, w2, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(w2)
    # (32, 256, 256) -> (64, 128, 128)
    self.conv2 = nn.Conv2d(w2, w4, kernel_size=3, stride=2, padding=1)
    self.bn2 = nn.BatchNorm2d(w4)
    # (64, 128, 128) -> (128, 64, 64)
    self.conv3 = nn.Conv2d(w4, w8, kernel_size=3, stride=2, padding=1)
    self.bn3 = nn.BatchNorm2d(w8)
    # (128, 64, 64) -> (256, 32, 32)
    self.conv4 = nn.Conv2d(w8, w16, kernel_size=3, stride=2, padding=1)
    self.bn4 = nn.BatchNorm2d(w16)

    # (256, 32, 32) -> (256, 16, 16)
    self.conv5 = nn.Conv2d(w16, w16, kernel_size=3, stride=2, padding=1)
    self.bn5 = nn.BatchNorm2d(w16)

    # (256, 16, 16) -> (256, 8, 8)
    self.conv6 = nn.Conv2d(w16, w16, kernel_size=3, stride=2, padding=1)
    self.bn6 = nn.BatchNorm2d(w16)

    # (256, 8, 8) -> (256, 4, 4)
    self.conv7 = nn.Conv2d(w16, w16, kernel_size=3, stride=2, padding=1)
    self.bn7 = nn.BatchNorm2d(w16)

    # (256, 4, 4) -> (256, 1, 1): unpadded 4x4 kernel collapses the map
    self.conv8 = nn.Conv2d(w16, w16, kernel_size=4)

    # (256, 1, 1) -> (1, 1, 1): 1x1 conv down to a single score channel
    self.conv9 = nn.Conv2d(w16, 1, kernel_size=1)

    self.relu = nn.LeakyReLU()
    self.r = nn.ReLU()

  def forward(self,x):
    """Return the non-negative score map for a batch of images."""
    normalized_stages = (
      (self.conv1, self.bn1),
      (self.conv2, self.bn2),
      (self.conv3, self.bn3),
      (self.conv4, self.bn4),
      (self.conv5, self.bn5),
      (self.conv6, self.bn6),
      (self.conv7, self.bn7),
    )
    out = x
    for conv, norm in normalized_stages:
      out = self.relu(norm(conv(out)))

    # The last two convs have no batch norm.
    out = self.relu(self.conv8(out))
    out = self.r(self.conv9(out))

    return out
    
       

In [0]:
# One generator per translation direction and one discriminator per domain,
# constructed in the order gen_AtoB, gen_BtoA, disc_A, disc_B.
gen_AtoB, gen_BtoA = Generator(), Generator()
disc_A, disc_B = Discriminator(), Discriminator()



In [0]:
# Data Loading part

import torchvision.transforms as transforms
import torch.utils.data as data

# Load both image sets as lists of per-file arrays.
horse = load_image('./.data/horse2zebra/trainA/')
zebra = load_image('./.data/horse2zebra/trainB/')

# Drop zebra images that decoded to 2-D arrays (e.g. shape (256, 256), no
# channel axis). A new list is built rather than calling zebra.remove(z)
# inside the loop: removing while iterating skips elements, and
# list.remove compares ndarrays with ==, which fails on mismatched shapes.
zebra_rgb = []
for z in zebra:
  if(z.ndim == 3):
    zebra_rgb.append(z)
  else:
    print('lala')

# Stack into arrays before using .shape or arithmetic; a plain Python list
# supports neither.
horse = np.array(horse)
zebra = np.array(zebra_rgb)

print(horse.shape)
print(zebra.shape)
# Normalize pixel values from [0, 255] to [-1, 1]. Converting to float32
# first halves the memory of the default float64 result.
horse = horse.astype(np.float32)/255.0
horse = (horse-0.5)/0.5

zebra = zebra.astype(np.float32)/255.0
zebra = (zebra-0.5)/0.5