# Optimal Transport Project

*Authors: Romain Avouac, Slimane Thabet*

In [0]:
from time import time
from multiprocessing import cpu_count
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython import display
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, transforms
from torchvision.transforms.functional import to_pil_image, resize, to_tensor
from torchvision.transforms.functional import normalize
import imageio

In [0]:
# GPU configuration: run on the first CUDA device when one is visible,
# otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if use_gpu else torch.device("cpu")

## Load and preprocess MNIST data

In [0]:
def load_mnist(batch_size=128, img_size=28):
    """Download the MNIST training set and return a shuffling DataLoader over it.

    All preprocessing happens once, up front: each raw image is converted to
    PIL, optionally resized to ``img_size``, turned into a float tensor and
    normalized with mean 0.5 / std 0.5 (i.e. from [0, 1] to [-1, 1]). The
    DataLoader then only slices an in-memory tensor, which makes epochs much
    faster than transforming inside the loader.

    Args:
        batch_size: number of images per batch.
        img_size: target side length; 28 keeps the native MNIST resolution.

    Returns:
        A ``torch.utils.data.DataLoader`` over the stacked image tensor
        (presumably shaped (n_images, 1, img_size, img_size) for grayscale
        MNIST — TODO confirm).
    """
    raw_images = datasets.MNIST('data', train=True, download=True).data

    def _preprocess(raw):
        # One raw image -> normalized float tensor.
        image = to_pil_image(raw)
        if img_size != 28:
            image = resize(image, img_size)
        return normalize(to_tensor(image), 0.5, 0.5)

    images = torch.stack([_preprocess(raw) for raw in raw_images])
    return torch.utils.data.DataLoader(images, batch_size=batch_size, shuffle=True)

## Vanilla GAN

In [0]:
class GANGenerator(nn.Module):
    """Fully connected GAN generator mapping latent vectors to images.

    Four linear layers of widths d, 2*d, 4*d and prod(output_shape), with
    LeakyReLU(0.2) between them and a final tanh, so pixel values lie in
    [-1, 1].

    Args:
        input_size: dimension of the latent input vectors.
        d: width of the first hidden layer (the next two use 2*d and 4*d).
        output_shape: shape of one generated image, e.g. (channels, rows, cols).
            Any sequence of ints is accepted (tuple, list or ``torch.Size``).
    """
    def __init__(self, input_size, d, output_shape):
        super(GANGenerator, self).__init__()

        # Normalize to a plain tuple so forward() can concatenate it with (-1,);
        # a list here used to raise a TypeError at the first forward pass.
        self.output_shape = tuple(int(s) for s in output_shape)
        n_outputs = int(np.prod(self.output_shape))

        self.map1 = nn.Linear(input_size, d)
        self.map2 = nn.Linear(self.map1.out_features, d*2)
        self.map3 = nn.Linear(self.map2.out_features, d*4)
        self.map4 = nn.Linear(self.map3.out_features, n_outputs)

        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        """Map latents of shape (batch, input_size) to images of shape (batch, *output_shape)."""
        x = self.act(self.map1(x))
        x = self.act(self.map2(x))
        x = self.act(self.map3(x))
        x = torch.tanh(self.map4(x))

        return x.view((-1,) + self.output_shape)

In [0]:
class GANCritic(nn.Module):
    """Fully connected GAN critic (discriminator) returning one probability per image.

    Inputs are flattened from dim 1, then go through linear layers of widths
    d, d//2, d//4 and 1. Each hidden layer is followed by LeakyReLU(0.2) and
    dropout; the output goes through a sigmoid, giving shape (batch, 1).

    Args:
        input_size: number of features per flattened input (e.g. pixels per image).
        d: width of the first hidden layer.
        dropout: dropout probability after each hidden layer (default 0.3).
    """
    def __init__(self, input_size, d, dropout=0.3):
        super(GANCritic, self).__init__()

        self.map1 = nn.Linear(input_size, d)
        self.map2 = nn.Linear(self.map1.out_features, d//2)
        self.map3 = nn.Linear(self.map2.out_features, d//4)
        self.map4 = nn.Linear(self.map3.out_features, 1)

        self.act = nn.LeakyReLU(negative_slope=0.2)
        self.dropout_p = dropout

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        for layer in (self.map1, self.map2, self.map3):
            x = self.act(layer(x))
            # F.dropout defaults to training=True; tie it to the module mode so
            # dropout is disabled after .eval().
            x = F.dropout(x, self.dropout_p, training=self.training)
        return torch.sigmoid(self.map4(x))

In [0]:
class GAN():
    """Adversarial trainer pairing a generator with a probabilistic critic.

    The critic is trained with binary cross-entropy to output 1 on real images
    and 0 on generated ones. The generator is trained with BCE against label 1
    on its own samples, i.e. it tries to make the critic call them real. Both
    networks are moved to the module-level ``device`` and optimized with Adam
    (betas=(0.5, 0.999)).

    Args:
        dataloader: DataLoader whose dataset items are single images shaped
            (channels, rows, cols).
        generator: module mapping latent vectors of shape (batch, z_dim) to images.
        critic: module mapping images to values in [0, 1], as ``nn.BCELoss`` requires.
        lr: learning rate shared by both Adam optimizers.
        z_dim: latent size. When None (default), it is read from
            ``generator.z_dim`` if that attribute exists, otherwise from the
            ``in_features`` of the first ``nn.Linear`` inside the generator.
    """

    def __init__(self, dataloader, generator, critic, lr=0.0001, z_dim=None):

        self.dataloader = dataloader

        # Image geometry is read from the first item of the dataset.
        first_shape = dataloader.dataset[0].shape
        self.img_channels = first_shape[0]
        self.img_rows = first_shape[1]
        self.img_cols = first_shape[2]
        self.img_shape = (self.img_channels, self.img_rows, self.img_cols)
        self.z_dim = self._infer_z_dim(generator) if z_dim is None else z_dim
        self.lr = lr

        self.generator = generator.to(device)
        self.critic = critic.to(device)

    @staticmethod
    def _infer_z_dim(generator):
        """Return the latent size expected by ``generator`` (raises ValueError if unknown)."""
        z_dim = getattr(generator, 'z_dim', None)
        if z_dim is not None:
            return z_dim
        for module in generator.modules():
            if isinstance(module, nn.Linear):
                return module.in_features
        raise ValueError("Cannot infer z_dim from the generator; pass z_dim explicitly.")

    def sample_data(self, n_sample):
        """Generate ``n_sample`` images from N(0, I) latents and return them as a NumPy array.

        No gradients are tracked. The generator is left in whatever train/eval
        mode it is currently in.
        """
        with torch.no_grad():
            z_random = torch.randn(n_sample, self.z_dim, device=device)
            samples = self.generator(z_random)
        return samples.cpu().numpy()

    def _critic_step(self, real, criterion, d_optimizer):
        """Run one critic update on a real batch plus an equally sized fake batch; return the loss."""
        self.critic.zero_grad()

        # Real images should be classified as 1.
        real_decision = self.critic(real)
        real_error = criterion(real_decision, torch.ones_like(real_decision))

        # Generated images should be classified as 0; detach so the generator
        # receives no gradient from the critic's loss.
        z = torch.randn(real.shape[0], self.z_dim, device=device)
        fake = self.generator(z).detach()
        fake_decision = self.critic(fake)
        fake_error = criterion(fake_decision, torch.zeros_like(fake_decision))

        loss = real_error + fake_error
        loss.backward()
        d_optimizer.step()  # only updates the critic's parameters
        return loss

    def _generator_step(self, batch_size, criterion, g_optimizer):
        """Run one generator update (fool the critic into predicting 1); return the loss."""
        self.generator.zero_grad()
        z = torch.randn(batch_size, self.z_dim, device=device)
        fake_decision = self.critic(self.generator(z))
        loss = criterion(fake_decision, torch.ones_like(fake_decision))
        loss.backward()
        g_optimizer.step()  # only updates the generator's parameters
        return loss

    def train(self, epochs=100, print_interval=10, save_generator_path=None,
              d_steps=1, g_steps=1):
        """Train both networks for ``epochs`` passes over the dataloader.

        Args:
            epochs: number of passes over the dataloader.
            print_interval: losses are printed every ``print_interval`` epochs
                (skipping epoch 0) and at the last epoch.
            save_generator_path: if not None, the generator's state_dict is
                saved there after training.
            d_steps: critic updates per batch.
            g_steps: generator updates per batch.
        """
        criterion = nn.BCELoss()
        d_optimizer = optim.Adam(self.critic.parameters(), lr=self.lr, betas=(0.5, 0.999))
        g_optimizer = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(0.5, 0.999))

        # Make sure dropout / batch-norm layers behave as in training.
        self.generator.train()
        self.critic.train()

        d_loss = g_loss = None  # stay None if no step ran (e.g. empty dataloader)
        t = time()

        for epoch in range(epochs):
            for batch in self.dataloader:
                real = batch.float().to(device)
                for _ in range(d_steps):
                    d_loss = self._critic_step(real, criterion, d_optimizer)
                for _ in range(g_steps):
                    g_loss = self._generator_step(real.shape[0], criterion, g_optimizer)

            is_report_epoch = (epoch > 0 and epoch % print_interval == 0) or epoch + 1 == epochs
            if is_report_epoch and d_loss is not None and g_loss is not None:
                de = d_loss.detach().cpu().numpy()
                ge = g_loss.detach().cpu().numpy()
                print("Epoch %s: C_loss =  %s ;  G_loss = %s;  time = %s" %
                      (epoch, de, ge, time()-t))

        if save_generator_path is not None:
            torch.save(self.generator.state_dict(), save_generator_path)


In [0]:
# Vanilla GAN hyper-parameters
# -- data
img_size = 28        # keep the native MNIST resolution
batch_size = 128
# -- architecture
z_dim = 100          # latent vector size fed to the generator
G_dim_init = 128     # width of the generator's first hidden layer
C_dim_init = 1024    # width of the critic's first hidden layer
# -- optimisation
lr = 0.0002
n_epochs = 100

save_generator_path = 'vanilla_gan_gen.pt'

In [0]:
# Load MNIST at the vanilla-GAN resolution as a shuffled Torch DataLoader
mnist_dataloader = load_mnist(img_size=img_size, batch_size=batch_size)
img_shape = mnist_dataloader.dataset[0].shape  # (channels, rows, cols) of one image
n_pixels = img_shape.numel()                   # flattened size fed to the critic

In [9]:
TRAIN_MODE = True  # False: reload the generator weights saved by a previous run

gan_generator = GANGenerator(z_dim, G_dim_init, img_shape)
gan_critic = GANCritic(n_pixels, C_dim_init)  # only used when training
if not TRAIN_MODE:
    # map_location lets weights saved from a GPU run be reloaded on a CPU-only machine.
    gan_generator.load_state_dict(torch.load(save_generator_path, map_location=device))
gan = GAN(mnist_dataloader, gan_generator, gan_critic, lr=lr)
if TRAIN_MODE:
    # Train the GAN and save the generator weights (set the path to None to skip saving).
    gan.train(n_epochs, save_generator_path=save_generator_path)

Epoch 10: C_loss =  0.4421556 ;  G_loss = 2.6215782;  time = 34.57757544517517
Epoch 20: C_loss =  0.6888797 ;  G_loss = 1.6280651;  time = 66.4431402683258
Epoch 30: C_loss =  0.9987544 ;  G_loss = 1.8691449;  time = 97.94661331176758
Epoch 40: C_loss =  0.8366662 ;  G_loss = 1.7507833;  time = 129.45234441757202
Epoch 50: C_loss =  0.826293 ;  G_loss = 1.4914486;  time = 160.96948385238647
Epoch 60: C_loss =  0.844104 ;  G_loss = 1.2613406;  time = 192.32516193389893
Epoch 70: C_loss =  0.98980445 ;  G_loss = 1.2515063;  time = 223.7138316631317
Epoch 80: C_loss =  1.1156857 ;  G_loss = 1.732059;  time = 255.320326089859
Epoch 90: C_loss =  1.0733883 ;  G_loss = 1.140757;  time = 286.79772424697876
Epoch 99: C_loss =  1.0112252 ;  G_loss = 1.2990302;  time = 315.51535511016846


In [0]:
# Plot some generated images as a GIF
samples = gan.sample_data(1000) * 0.5 + 0.5  # map tanh outputs from [-1, 1] to [0, 1]
# Scale to [0, 255] and clip before the uint8 cast: a saturated pixel (1.0)
# scaled by 256 would become 256 and wrap around to 0 (black).
samples = np.clip(samples * 255, 0, 255).astype(np.uint8)
samples = np.squeeze(samples, 1)  # drop the single channel axis -> (n, rows, cols)

gif_path = 'vanilla_gan.gif'
imageio.mimwrite(gif_path, samples, fps=5)
gifPath = Path(gif_path)
with open(gifPath, 'rb') as f:
    # NOTE(review): the bytes are GIF but labelled 'png'; kept for compatibility with
    # older IPython releases that reject format='gif' — confirm rendering in your frontend.
    display.display(display.Image(data=f.read(), format='png', width=200, height=200))

## DC-GAN

In [0]:
def normal_init(m, mean, std):
    """Initialize a 2-D (transposed) convolution in place: N(mean, std) weights, zero bias.

    Modules of any other type are left untouched, so this can be applied to
    every sub-module of a network. Convolutions built with ``bias=False`` are
    supported (their missing bias is simply skipped).
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean=mean, std=std)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [0]:
class DCGANGenerator(nn.Module):
    """DCGAN generator: transposed convolutions upsampling a latent vector to an image.

    The latent input is reshaped to (batch, z_dim, 1, 1) and upsampled
    1 -> 4 -> 8 -> 16 -> 32 -> 64, with channels d, d//2, d//4, d//8 and
    finally 1. Hidden layers use batch norm + ReLU; the output uses tanh, so
    images have shape (batch, 1, 64, 64) with values in [-1, 1].

    Args:
        z_dim: size of the latent vector.
        d: number of channels after the first transposed convolution.
    """
    def __init__(self, z_dim, d):
        super(DCGANGenerator, self).__init__()

        self.z_dim = z_dim

        # Creation order and attribute names are kept stable: they fix the
        # state_dict keys and the order in which parameters are initialized.
        self.deconv1 = nn.ConvTranspose2d(z_dim, d, kernel_size=4, stride=1, padding=0)         # 1 -> 4
        self.deconv1_bn = nn.BatchNorm2d(d)
        self.deconv2 = nn.ConvTranspose2d(d, d // 2, kernel_size=4, stride=2, padding=1)        # 4 -> 8
        self.deconv2_bn = nn.BatchNorm2d(d // 2)
        self.deconv3 = nn.ConvTranspose2d(d // 2, d // 4, kernel_size=4, stride=2, padding=1)   # 8 -> 16
        self.deconv3_bn = nn.BatchNorm2d(d // 4)
        self.deconv4 = nn.ConvTranspose2d(d // 4, d // 8, kernel_size=4, stride=2, padding=1)   # 16 -> 32
        self.deconv4_bn = nn.BatchNorm2d(d // 8)
        self.deconv5 = nn.ConvTranspose2d(d // 8, 1, kernel_size=4, stride=2, padding=1)        # 32 -> 64

        self.activ = nn.ReLU()

    def weight_init(self, mean, std):
        """Apply ``normal_init(module, mean, std)`` to every direct sub-module."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, x):
        """Map latent vectors (batch, z_dim) to images (batch, 1, 64, 64)."""
        h = x.view(-1, self.z_dim, 1, 1)
        upsampling_blocks = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
        )
        for deconv, bn in upsampling_blocks:
            h = self.activ(bn(deconv(h)))
        return torch.tanh(self.deconv5(h))

In [0]:
class DCGANCritic(nn.Module):
    """DCGAN critic (discriminator): strided convolutions reducing an image to one probability.

    For a (batch, 1, 64, 64) input the spatial size goes 64 -> 32 -> 16 -> 8
    -> 4 -> 1, with channels d, 2d, 4d, 8d and finally 1. Hidden layers use
    LeakyReLU(0.2), with batch norm on all but the first; the output goes
    through a sigmoid and is returned as (batch, 1).

    Args:
        d: number of channels after the first convolution.
    """
    def __init__(self, d):
        super(DCGANCritic, self).__init__()

        # Creation order and attribute names are kept stable: they fix the
        # state_dict keys and the order in which parameters are initialized.
        self.conv1 = nn.Conv2d(1, d, kernel_size=4, stride=2, padding=1)             # 64 -> 32
        self.conv2 = nn.Conv2d(d, d * 2, kernel_size=4, stride=2, padding=1)         # 32 -> 16
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, kernel_size=4, stride=2, padding=1)     # 16 -> 8
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(d * 4, d * 8, kernel_size=4, stride=2, padding=1)     # 8 -> 4
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        self.conv5 = nn.Conv2d(d * 8, 1, kernel_size=4, stride=1, padding=0)         # 4 -> 1

        self.activ = nn.LeakyReLU(negative_slope=0.2)

    def weight_init(self, mean, std):
        """Apply ``normal_init(module, mean, std)`` to every direct sub-module."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, x):
        """Return one probability per input image, shape (batch, 1)."""
        h = self.activ(self.conv1(x))  # first layer has no batch norm
        downsampling_blocks = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
        )
        for conv, bn in downsampling_blocks:
            h = self.activ(bn(conv(h)))
        probs = torch.sigmoid(self.conv5(h))
        return probs.view((-1, 1))

In [0]:
# DC-GAN hyper-parameters (batch_size is reused from the vanilla GAN cell)
img_size = 64                  # the generator's deconvolutions produce 64x64 images
z_dim = 100                    # latent vector size
G_dim_init = 512               # generator channels after the first deconvolution (1024 or 512)
C_dim_init = G_dim_init // 8   # critic channels after the first convolution
lr, n_epochs = 0.0002, 20

path_weights_dcgan = 'dc_gan_gen.pt'

In [0]:
# Reload MNIST at the DC-GAN resolution as a shuffled Torch DataLoader
mnist_dataloader = load_mnist(img_size=img_size, batch_size=batch_size)
img_shape = mnist_dataloader.dataset[0].shape  # shape of one preprocessed image

In [16]:
TRAIN_MODE = True  # False: reload the generator weights saved by a previous run

if TRAIN_MODE:
    # Train the DC-GAN from N(0, 0.02)-initialized conv weights and save the
    # generator (set the path to None to skip saving).
    dcgan_generator = DCGANGenerator(z_dim, G_dim_init)
    dcgan_generator.weight_init(mean=0.0, std=0.02)
    dcgan_critic = DCGANCritic(C_dim_init)
    dcgan_critic.weight_init(mean=0.0, std=0.02)
    dcgan = GAN(mnist_dataloader, dcgan_generator, dcgan_critic, lr=lr)
    dcgan.train(n_epochs, save_generator_path=path_weights_dcgan,
                print_interval=1)
else:
    # Load previously trained weights into the DC-GAN generator; map_location
    # lets weights saved from a GPU run be reloaded on a CPU-only machine.
    dcgan_generator = DCGANGenerator(z_dim, G_dim_init)
    dcgan_generator.load_state_dict(torch.load(path_weights_dcgan, map_location=device))
    dcgan_critic = DCGANCritic(C_dim_init)  # Not used
    dcgan = GAN(mnist_dataloader, dcgan_generator, dcgan_critic, lr=lr)

Epoch 1: C_loss =  0.48784274 ;  G_loss = 2.0608;  time = 93.96772336959839
Epoch 2: C_loss =  0.2936925 ;  G_loss = 4.147394;  time = 140.93153953552246
Epoch 3: C_loss =  0.19924375 ;  G_loss = 2.7586212;  time = 187.89247012138367
Epoch 4: C_loss =  0.22999239 ;  G_loss = 3.4658613;  time = 234.85290145874023
Epoch 5: C_loss =  0.11085105 ;  G_loss = 5.0173283;  time = 281.80547976493835
Epoch 6: C_loss =  0.25720948 ;  G_loss = 4.0861044;  time = 328.77236771583557
Epoch 7: C_loss =  0.15680411 ;  G_loss = 4.3106117;  time = 375.74854826927185
Epoch 8: C_loss =  0.5543426 ;  G_loss = 2.4007347;  time = 422.71264481544495
Epoch 9: C_loss =  0.07505908 ;  G_loss = 4.055755;  time = 469.6685416698456
Epoch 10: C_loss =  0.6098065 ;  G_loss = 2.4486609;  time = 516.6242158412933
Epoch 11: C_loss =  0.37127236 ;  G_loss = 3.9166965;  time = 563.5892198085785
Epoch 12: C_loss =  0.050106414 ;  G_loss = 3.867641;  time = 610.55885887146
Epoch 13: C_loss =  0.40300643 ;  G_loss = 4.5084662

In [0]:
# Plot some generated images as a GIF
samples = dcgan.sample_data(1000) * 0.5 + 0.5  # map tanh outputs from [-1, 1] to [0, 1]
# Scale to [0, 255] and clip before the uint8 cast: a saturated pixel (1.0)
# scaled by 256 would become 256 and wrap around to 0 (black).
samples = np.clip(samples * 255, 0, 255).astype(np.uint8)
samples = np.squeeze(samples, 1)  # drop the single channel axis -> (n, rows, cols)

gif_path = 'dc_gan.gif'
imageio.mimwrite(gif_path, samples, fps=5)
gifPath = Path(gif_path)
with open(gifPath, 'rb') as f:
    # NOTE(review): the bytes are GIF but labelled 'png'; kept for compatibility with
    # older IPython releases that reject format='gif' — confirm rendering in your frontend.
    display.display(display.Image(data=f.read(), format='png', width=200, height=200))