In [None]:
!unzip summer2winter_yosemite.zip

In [None]:
import os
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# visualizing data
import matplotlib.pyplot as plt
import numpy as np
import warnings

%matplotlib inline

### Data setup
Defines `get_data_loader`, which creates `DataLoader`s for the train and test images of one domain. It is then used to create the loaders for the summer (X) and winter (Y) data.

In [None]:
def get_data_loader(image_type, image_dir='summer2winter_yosemite',
                    image_size=128, batch_size=16, num_workers=0):
    """Build train and test DataLoaders for one image domain.

    Expects ``<image_dir>/<image_type>`` (training images) and
    ``<image_dir>/test_<image_type>`` (test images), each laid out as
    ``torchvision.datasets.ImageFolder`` requires.

    Args:
        image_type: domain sub-directory name, e.g. 'summer' or 'winter'.
        image_dir: root directory of the extracted dataset (relative to '.').
        image_size: side length in pixels of the square output images.
        batch_size: number of images per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        (train_loader, test_loader). The training loader shuffles, the test
        loader does not. Images come out of ``ToTensor`` with no further
        normalization (rescaling is done later by ``scale``).
    """
    # Pass an explicit (H, W) so every image becomes image_size x image_size.
    # A bare int only fixes the shorter edge, so non-square inputs would keep
    # differing shapes and fail to stack into a batch.
    transform = transforms.Compose([transforms.Resize((image_size, image_size)),
                                    transforms.ToTensor()])

    # Training and test directories for this domain.
    image_path = os.path.join('.', image_dir)
    train_path = os.path.join(image_path, image_type)
    test_path = os.path.join(image_path, 'test_{}'.format(image_type))

    train_dataset = datasets.ImageFolder(train_path, transform)
    test_dataset = datasets.ImageFolder(test_path, transform)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [None]:
# Domain X = summer photos, domain Y = winter photos. Each get_data_loader
# call loads the named sub-directory and returns (train_loader, test_loader).
(dataloader_X, test_dataloader_X), (dataloader_Y, test_dataloader_Y) = (
    get_data_loader(image_type=domain) for domain in ('summer', 'winter')
)

### Display a batch of images to check that the data is loaded and batched properly

In [None]:
def imshow(img):
    """Plot an image tensor laid out as (channels, height, width).

    matplotlib wants channels last, so the array's axes are reordered to
    (height, width, channels) before drawing.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    

# Grab one batch from the summer (X) training loader and show it as a grid.
# Use the built-in next(): DataLoader iterators no longer provide a .next()
# method in recent PyTorch releases.
dataiter = iter(dataloader_X)
images, _ = next(dataiter)

# show images
fig = plt.figure(figsize=(12, 8))
imshow(torchvision.utils.make_grid(images))

In [None]:
# Grab one batch from the winter (Y) training loader and show it as a grid.
# Use the built-in next(): DataLoader iterators no longer provide a .next()
# method in recent PyTorch releases.
dataiter = iter(dataloader_Y)
images, _ = next(dataiter)

# show images
fig = plt.figure(figsize=(12, 8))
imshow(torchvision.utils.make_grid(images))

### Scaling function
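GANs are commonly found to train best with a `tanh` activation at the end of the generator. Because `tanh` outputs lie in $[-1, 1]$ while `ToTensor` produces values in $[0, 1]$, the real images must be rescaled to $[-1, 1]$ so that real and generated images share the same range.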

In [None]:
def scale(x, feature_range=(-1, 1)):
    """Linearly map values from [0, 1] to ``feature_range``.

    Real images come out of ``ToTensor`` in [0, 1], while the generator ends
    in ``tanh``. This function rescales real images so both share one range.

    Args:
        x: tensor, array or number whose values are assumed to lie in [0, 1].
        feature_range: (low, high) target interval; defaults to (-1, 1).

    Returns:
        ``x * (high - low) + low``.
    """
    # low/high rather than min/max so the builtins are not shadowed.
    low, high = feature_range
    return x * (high - low) + low
# `img` is not assigned until the next cell, so on a fresh kernel scale the
# first image of the batch shown above directly (the next cell inspects the
# same tensor as `images[0]`).
scaled_img = scale(images[0])

In [None]:
img = images[0]

# Compare the pixel value range before and after scaling.
for label, value in (('Min: ', img.min()), ('Max: ', img.max())):
    print(label, value)

print('=' * 18)

for label, value in (('Scaled min: ', scaled_img.min()), ('Scaled max: ', scaled_img.max())):
    print(label, value)

### Creates discriminator

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# discriminator function
class Discriminator(nn.Module):
    """Convolutional discriminator for 3-channel images.

    Four stride-2 4x4 convolutions each halve height and width, with batch
    norm on all but the first. A stride-1 4x4 convolution then reduces to one
    channel. For a 128x128 input the output is a (N, 1, 7, 7) map of raw
    scores rather than a single value; the MSE losses take its mean.

    Args:
        conv_dim: filters in the first layer; each later layer doubles it.
    """

    def __init__(self, conv_dim=64):
        super(Discriminator, self).__init__()
        d = conv_dim
        # Keep this creation order: it fixes the parameter registration order
        # (state_dict keys) and the order of random weight initialisation.
        self.conv1 = nn.Conv2d(3, d, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, stride=2, padding=1)
        self.norm1 = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, stride=2, padding=1)
        self.norm2 = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, stride=2, padding=1)
        self.norm3 = nn.BatchNorm2d(d * 8)
        self.conv_final = nn.Conv2d(d * 8, 1, 4, stride=1, padding=1)

    def forward(self, x):
        """Return the un-activated score map for a batch of images."""
        h = F.relu(self.conv1(x))  # first layer is not batch-normalised
        stages = ((self.conv2, self.norm1),
                  (self.conv3, self.norm2),
                  (self.conv4, self.norm3))
        for conv, norm in stages:
            h = F.relu(norm(conv(h)))
        return self.conv_final(h)

### Defines a residual block, following the generator architecture of the original CycleGAN paper

In [None]:
# residual block class
class ResidualBlock(nn.Module):
    """Shape-preserving residual block that returns ``x + F(x)``.

    ``F`` is 3x3 conv -> ReLU -> 3x3 conv -> batch norm. Both convolutions use
    stride 1 and padding 1 with equal in/out channels, so ``F(x)`` has the
    same shape as ``x``.

    Args:
        conv_dim: number of input (and output) channels.
    """

    def __init__(self, conv_dim):
        super(ResidualBlock, self).__init__()
        # Creation order (conv1, conv2, norm) fixes state_dict keys and the
        # order of random weight initialisation.
        self.conv1 = nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(conv_dim, conv_dim, 3, stride=1, padding=1)
        self.norm = nn.BatchNorm2d(conv_dim)
        # NOTE(review): only the second conv is normalised; reference CycleGAN
        # residual blocks usually also normalise after the first conv. Adding
        # that would change the parameter set, so confirm before changing.

    def forward(self, x):
        branch = F.relu(self.conv1(x))
        branch = self.norm(self.conv2(branch))
        # Skip connection, added out-of-place.
        return x + branch

### Creates generator using the residual blocks from above

In [None]:
class CycleGenerator(nn.Module):
    """Encoder / residual / decoder generator mapping images between domains.

    Stages:
    - Encoder: three stride-2 convolutions, taking channels 3 -> conv_dim ->
      2*conv_dim -> 4*conv_dim and dividing height/width by 8.
    - Residual stage: ``n_res_blocks`` ResidualBlocks at that resolution.
    - Decoder: three stride-2 transposed convolutions, restoring the input
      resolution (for height/width divisible by 8) with 3 output channels.
    - Output: a final ``tanh``, so values lie in [-1, 1].

    Args:
        conv_dim: channels after the first convolution; deeper layers use 2x
            and 4x this value.
        n_res_blocks: number of residual blocks in the middle stage.
    """

    def __init__(self, conv_dim=64, n_res_blocks=6):
        super(CycleGenerator, self).__init__()
        # 1. Encoder. Each BatchNorm width must equal the preceding conv's
        # output channels, so it is derived from conv_dim. Hard-coding it
        # only worked for conv_dim=64.
        self.conv1 = nn.Conv2d(3, conv_dim, 5, stride=2, padding=2)
        self.norm1 = nn.BatchNorm2d(conv_dim)
        self.conv2 = nn.Conv2d(conv_dim, conv_dim * 2, 3, stride=2, padding=1)
        self.norm2 = nn.BatchNorm2d(conv_dim * 2)
        self.conv3 = nn.Conv2d(conv_dim * 2, conv_dim * 4, 3, stride=2, padding=1)
        self.norm3 = nn.BatchNorm2d(conv_dim * 4)

        # 2. Residual stage at the bottleneck resolution.
        self.residual = nn.Sequential(
            *[ResidualBlock(conv_dim * 4) for _ in range(n_res_blocks)])

        # 3. Decoder: each transposed conv (k=4, s=2, p=1) doubles height/width.
        self.trans1 = nn.ConvTranspose2d(conv_dim * 4, conv_dim * 2, 4,
                                         stride=2, padding=1, bias=False)
        self.norm4 = nn.BatchNorm2d(conv_dim * 2)
        self.trans2 = nn.ConvTranspose2d(conv_dim * 2, conv_dim, 4,
                                         stride=2, padding=1, bias=False)
        self.norm5 = nn.BatchNorm2d(conv_dim)
        self.trans3 = nn.ConvTranspose2d(conv_dim, 3, 4,
                                         stride=2, padding=1, bias=False)

    def forward(self, x):
        """Translate a batch of images; output values lie in [-1, 1]."""
        x = F.relu(self.norm1(self.conv1(x)))
        x = F.relu(self.norm2(self.conv2(x)))
        x = F.relu(self.norm3(self.conv3(x)))

        x = self.residual(x)

        x = F.relu(self.norm4(self.trans1(x)))
        x = F.relu(self.norm5(self.trans2(x)))
        # torch.tanh instead of the deprecated F.tanh (same result).
        return torch.tanh(self.trans3(x))

### Creates the models

In [None]:
def create_model(g_conv_dim=64, d_conv_dim=64, n_res_blocks=6):
    """Build the two CycleGAN generators and the two discriminators.

    Args:
        g_conv_dim: base channel width for both generators.
        d_conv_dim: base channel width for both discriminators.
        n_res_blocks: residual blocks in each generator.

    Returns:
        Tuple (G_XtoY, G_YtoX, D_X, D_Y), moved to "cuda:0" when CUDA is
        available and left on the CPU otherwise.
    """
    generators = [CycleGenerator(conv_dim=g_conv_dim, n_res_blocks=n_res_blocks)
                  for _ in range(2)]
    discriminators = [Discriminator(conv_dim=d_conv_dim) for _ in range(2)]
    models = (*generators, *discriminators)

    if not torch.cuda.is_available():
        print('Models stay on CPU')
        return models

    gpu = torch.device("cuda:0")
    for model in models:
        model.to(gpu)
    print('Models moved to GPU')
    return models

# Build the models with the default hyperparameters.
G_XtoY, G_YtoX, D_X, D_Y = create_model()

### Defines loss functions

In [None]:
# loss functions
def real_mse_loss(D_out):
    """Mean squared distance of discriminator outputs from the 'real' target 1."""
    return (D_out - 1).pow(2).mean()

def fake_mse_loss(D_out):
    """Mean squared distance of discriminator outputs from the 'fake' target 0."""
    return D_out.pow(2).mean()

def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight=10):
    """Mean absolute reconstruction error, multiplied by ``lambda_weight``."""
    mean_abs_error = (real_im - reconstructed_im).abs().mean()
    return lambda_weight * mean_abs_error

### Defines optimizers

In [None]:
import torch.optim as optim

# Adam hyperparameters; values follow the original CycleGAN paper.
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

# One optimizer updates both generators jointly; each discriminator gets its own.
g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())
g_optimizer = optim.Adam(g_params, lr=lr, betas=[beta1, beta2])
d_x_optimizer = optim.Adam(D_X.parameters(), lr=lr, betas=[beta1, beta2])
d_y_optimizer = optim.Adam(D_Y.parameters(), lr=lr, betas=[beta1, beta2])

### Training loop

In [None]:
# import save code. This was provided by Udacity
from helpers import save_samples, checkpoint

In [None]:
# train the network
def training_loop(dataloader_X, dataloader_Y, test_dataloader_X, test_dataloader_Y,
                  n_epochs=1000, print_every=10, sample_every=100, checkpoint_every=1000):
    """Train the module-level CycleGAN models G_XtoY, G_YtoX, D_X and D_Y.

    Each "epoch" here is one training iteration:
    - take one batch from each domain;
    - update D_X and then D_Y;
    - make one joint update of both generators, using adversarial plus
      cycle-consistency losses.

    Args:
        dataloader_X, dataloader_Y: training loaders for domains X and Y.
        test_dataloader_X, test_dataloader_Y: loaders from which one fixed
            batch per domain is drawn for the periodic sample images.
        n_epochs: number of training iterations.
        print_every: log and record losses every this many iterations.
        sample_every: save sample translations every this many iterations.
        checkpoint_every: save model parameters every this many iterations.

    Returns:
        List of (d_X_loss, d_Y_loss, g_total_loss) tuples, one per log step.

    Raises:
        ValueError: if either training loader yields no batches.
    """
    losses = []
    # Loop-invariant, so resolve the device once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fixed test batches, held constant throughout training so the saved
    # samples show how the models evolve. The built-in next() is used because
    # DataLoader iterators no longer provide a .next() method in recent
    # PyTorch releases.
    fixed_X = scale(next(iter(test_dataloader_X))[0])  # scale to [-1, 1]
    fixed_Y = scale(next(iter(test_dataloader_Y))[0])

    iter_X = iter(dataloader_X)
    iter_Y = iter(dataloader_Y)
    batches_per_epoch = min(len(iter_X), len(iter_Y))
    if batches_per_epoch == 0:
        raise ValueError('both training dataloaders must yield at least one batch')

    for epoch in range(1, n_epochs + 1):
        # Restart both iterators before the shorter loader is exhausted.
        if epoch % batches_per_epoch == 0:
            iter_X = iter(dataloader_X)
            iter_Y = iter(dataloader_Y)

        images_X, _ = next(iter_X)
        images_Y, _ = next(iter_Y)
        # Match the generators' [-1, 1] output range, then move to the device.
        images_X = scale(images_X).to(device)
        images_Y = scale(images_Y).to(device)

        # ============================================
        #            TRAIN THE DISCRIMINATORS
        # ============================================
        # Generated images are detached: this step only updates the
        # discriminators. Any generator gradients it produced would be cleared
        # by g_optimizer.zero_grad() before the generator update anyway.

        ## D_X: real X images should score 1, Y->X fakes should score 0
        d_out_real = D_X(images_X)
        fake_im_X = G_YtoX(images_Y).detach()
        d_out_fake = D_X(fake_im_X)
        d_x_loss = real_mse_loss(d_out_real) + fake_mse_loss(d_out_fake)

        d_x_optimizer.zero_grad()
        d_x_loss.backward()
        d_x_optimizer.step()

        ## D_Y: real Y images should score 1, X->Y fakes should score 0
        d_out_real = D_Y(images_Y)
        fake_im_Y = G_XtoY(images_X).detach()
        d_out_fake = D_Y(fake_im_Y)
        d_y_loss = real_mse_loss(d_out_real) + fake_mse_loss(d_out_fake)

        d_y_optimizer.zero_grad()
        d_y_loss.backward()
        d_y_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATORS
        # =========================================

        # Y -> X -> Y: fool D_X, and reconstruct the original Y images.
        gen_YtoX = G_YtoX(images_Y)
        d_x_genloss = real_mse_loss(D_X(gen_YtoX))
        recon_y = G_XtoY(gen_YtoX)
        recon_y_loss = cycle_consistency_loss(images_Y, recon_y)

        # X -> Y -> X: fool D_Y, and reconstruct the original X images.
        gen_XtoY = G_XtoY(images_X)
        d_y_genloss = real_mse_loss(D_Y(gen_XtoY))
        recon_x = G_YtoX(gen_XtoY)
        recon_x_loss = cycle_consistency_loss(images_X, recon_x)

        g_total_loss = d_x_genloss + recon_y_loss + d_y_genloss + recon_x_loss

        g_optimizer.zero_grad()
        g_total_loss.backward()
        g_optimizer.step()

        # Log and record discriminator and generator losses.
        if epoch % print_every == 0:
            losses.append((d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))
            print('Epoch [{:5d}/{:5d}] | d_X_loss: {:6.4f} | d_Y_loss: {:6.4f} | g_total_loss: {:6.4f}'.format(
                    epoch, n_epochs, d_x_loss.item(), d_y_loss.item(), g_total_loss.item()))

        # Save sample translations of the fixed test batches.
        if epoch % sample_every == 0:
            # Eval mode while sampling (batch norm uses running statistics);
            # training mode is restored even if saving raises.
            G_YtoX.eval()
            G_XtoY.eval()
            try:
                save_samples(epoch, fixed_Y, fixed_X, G_YtoX, G_XtoY, batch_size=16)
            finally:
                G_YtoX.train()
                G_XtoY.train()

        # Save the model parameters.
        if epoch % checkpoint_every == 0:
            checkpoint(epoch, G_XtoY, G_YtoX, D_X, D_Y)

    return losses


In [None]:
# Number of training iterations (one batch from each domain per iteration).
n_epochs = 6000
losses = training_loop(
    dataloader_X, dataloader_Y,
    test_dataloader_X, test_dataloader_Y,
    n_epochs=n_epochs,
)

### A function to view the saved sample images after a given number of training iterations

In [None]:
import matplotlib.image as mpimg

# helper visualization code
def view_samples(iteration, sample_dir='samples_cyclegan'):
    """Show the saved X->Y and Y->X sample images for one training iteration.

    Reads ``sample-<iteration:06d>-X-Y.png`` and ``sample-<iteration:06d>-Y-X.png``
    from ``sample_dir`` and plots them one above the other. If either file
    cannot be read, a message is printed and nothing is plotted.

    Args:
        iteration: training iteration whose samples should be shown.
        sample_dir: directory holding the sample images.
    """
    # samples are named by iteration
    path_XtoY = os.path.join(sample_dir, 'sample-{:06d}-X-Y.png'.format(iteration))
    path_YtoX = os.path.join(sample_dir, 'sample-{:06d}-Y-X.png'.format(iteration))

    # Catch only read failures (missing or unreadable file), and return after
    # the message: continuing would use x2y/y2x while they are unbound.
    try:
        x2y = mpimg.imread(path_XtoY)
        y2x = mpimg.imread(path_YtoX)
    except (OSError, ValueError):
        print('Invalid number of iterations.')
        return

    fig, (ax1, ax2) = plt.subplots(figsize=(18, 20), nrows=2, ncols=1, sharey=True, sharex=True)
    ax1.imshow(x2y)
    ax1.set_title('X to Y')
    ax2.imshow(y2x)
    ax2.set_title('Y to X')

In [None]:
# Samples saved after 100 training iterations.
view_samples(iteration=100, sample_dir='samples_cyclegan')

In [None]:
# Samples saved after 1000 training iterations.
view_samples(iteration=1000, sample_dir='samples_cyclegan')