<a href="https://colab.research.google.com/github/TheodorSergeev/optml_gan/blob/main/gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Adapted from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

# Initialisation

* First, create an empty directory `optml_gan` in the main directory of your Google Drive. Then download and extract the contents of our main repository ([link to repo](https://github.com/TheodorSergeev/optml_gan)) and place them in the `optml_gan` directory on your Google Drive. Then run the following cell.
* If you only want to run the code, you only need to copy the `src` directory.
* If you want to reproduce our plots without going through the training process, we recommend opening the notebook in the Google Drive folder that we host ([link to drive](https://drive.google.com/drive/folders/1C-8I8Z3hHlfn-q-5P34SjtJbyep-u9UT?usp=sharing)); it contains all the save files from the experiments and the code to reproduce our plots.

In [1]:
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # packages to generate requirement.txt
    %pip install nbconvert
    %pip install pipreqs
    # for Frechet inception distance
    %pip install pytorch-fid

    %cd drive/My Drive/optml_gan
    PATH = './'
else:
    PATH = './'

Mounted at /content/drive
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pipreqs
  Downloading pipreqs-0.4.11-py2.py3-none-any.whl (32 kB)
Collecting yarg
  Downloading yarg-0.1.9-py2.py3-none-any.whl (19 kB)
Installing collected packages: yarg, pipreqs
Successfully installed pipreqs-0.4.11 yarg-0.1.9
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-fid
  Downloading pytorch-fid-0.2.1.tar.gz (14 kB)
Building wheels for collected packages: pytorch-fid
  Building wheel for pytorch-fid (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-fid: filename=pytorch_fid-0.2.1-py3-none-any.whl size=14835 sha256=0c4488d187f3669520f3151912fa47d197001cb32e70678644b0f8dbf87f111d
  Stored in directory: /root/.cache/pip/wheels/24/ac/03/c5634775c8a64f702343ef5923278f8d3

In [2]:
from __future__ import print_function

import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data

import torchvision.utils as vutils

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
%matplotlib inline

In [3]:
# Enable IPython's autoreload extension in mode 2: modules are reloaded
# automatically before each cell executes, so edits to imported modules
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Source code

In [4]:
from src.data_handling import *
from src.utils import *
from src.model import *
from src.losses import *
from src.fid import *

# Maps a loss name to its (discriminator loss, generator loss) pair of functions.
loss_dict = dict(
    kl=(loss_dis_kl, loss_gen_kl),
    wass=(loss_dis_wasser, loss_gen_wasser),
    hinge=(loss_dis_hinge, loss_gen_hinge),
)

# Training, visualisation and serialisation

from src.training import *
from src.visualisation import *
from src.serialisation import *

# https://keras.io/examples/generative/conditional_gan/
from src.architectures import *

from src.gridsearch import *

# Training example

## Parameters

In [5]:
# --- Dataset ---------------------------------------------------------------
dataset_name = 'mnist'      # dataset name: 'cifar10' or 'mnist'
dataroot = f"{PATH}data/"   # root directory for the dataset
# Spatial size of training images. All images will be resized to this size
# using a transformer (28 for mnist, 64 for others).
image_size = 28

# --- Data loading ----------------------------------------------------------
workers = 2                 # number of workers for the dataloader
batch_size = 128            # batch size during training

# --- Model / hardware ------------------------------------------------------
nz = 128                    # size of the z latent vector (i.e. size of generator input)
ngpu = 1                    # number of GPUs available; use 0 for CPU mode

In [6]:
# Number of training epochs.
# Note: the run cell further down assigns num_epochs again before training starts.
num_epochs = 3

# Learning rates for the optimizers: discriminator (lrD) and generator (lrG)
# start from the same value.
lrD = lrG = 2e-4

# Beta1 hyperparameter for the Adam optimizers (0.9 is the default)
beta1 = 0.9

In [7]:
# Load the dataset. get_dataset also returns nc, which is forwarded to the
# experiment runner later (presumably the number of image channels -- confirm
# in src/data_handling).
dataset, nc = get_dataset(dataset_name, image_size, dataroot)

# Create the dataloader: shuffled mini-batches served by `workers` worker processes
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Decide which device to run on: the first CUDA device when CUDA is available
# and at least one GPU was requested, otherwise the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Run

In [None]:
# Configuration for a single training run.
loss_name = 'wass'          # must be a name understood by set_loss_params / run_experiment

shuffle = True
num_epochs = 20             # replaces the num_epochs value assigned in the parameters section
plot = True
save_stats = True
create_dir = True
# NOTE(review): the grid-search cell passes an integer interval here
# (save_epochs=10, "save the model every save_epochs epochs"); True presumably
# behaves like 1 -- confirm against run_experiment.
save_epochs = True
# Previously this name was passed to run_experiment without ever being assigned
# in the notebook, so the cell depended on hidden kernel state (or on a name that
# a star import happened to provide). Define it explicitly; True matches the
# value used by the grid-search cell.
save_models = True
momentumD, momentumG = 0.0, 0.0
optimizer_name = 'adam'

# Loss-specific settings: discriminator/generator iterations per step and the
# gradient-penalty coefficient are derived from the loss name.
iter_per_epoch_dis, iter_per_epoch_gen, grad_penalty_coef = set_loss_params(
    loss_name)

# Train the GAN. Returns the collected statistics, the dataloader that was used,
# and the trained generator / discriminator networks. Positional argument order
# is kept exactly as run_experiment expects it.
stats, dataloader, netG, netD = run_experiment(
    ngpu, device, dataset, workers, batch_size,
    shuffle, num_epochs, plot, lrD, lrG, beta1, nc, nz, loss_name, '',
    save_stats, create_dir,
    iter_per_epoch_dis, iter_per_epoch_gen, grad_penalty_coef,
    save_epochs, save_models, momentumD, momentumG, optimizer_name, PATH,
    count_params=True)


Generator parameters 1493520
Discriminator parameters 1460225
./ adam_mG0.0_mD0.0_ wassLoss_lrd0.0002_lrg0.0002_b1b0.9_itd5_itg1_gpv10.0_
Starting Training Loop...
[0/20][0/469]	Loss_D: -0.5729	Loss_G: -0.0335	D(x): 1.6005	D(G(z)): 0.0345 / 0.0335
[0/20][50/469]	Loss_D: -2567.3757	Loss_G: -5441.0566	D(x): 8946.7539	D(G(z)): 5576.6924 / 5441.0566
[0/20][100/469]	Loss_D: -30515.3223	Loss_G: 72398.7188	D(x): -13460.2734	D(G(z)): -72471.7656 / -72398.7188
[0/20][150/469]	Loss_D: -6868.7910	Loss_G: 4306.2422	D(x): 8168.0684	D(G(z)): -4858.0981 / -4306.2422
[0/20][200/469]	Loss_D: -12450.2861	Loss_G: 14821.1113	D(x): 6068.2803	D(G(z)): -15204.1309 / -14821.1113
[0/20][250/469]	Loss_D: 294.9125	Loss_G: -13247.3867	D(x): 13634.1758	D(G(z)): 13203.7832 / 13247.3867
[0/20][300/469]	Loss_D: -2972.6931	Loss_G: -16284.5488	D(x): 22966.1621	D(G(z)): 17177.3008 / 16284.5488
[0/20][350/469]	Loss_D: -6371.9805	Loss_G: 25654.4102	D(x): -12069.4854	D(G(z)): -26027.7617 / -25654.4102
[0/20][400/469]	Loss_

In [None]:
# Pull the pieces needed for visualisation out of the training statistics:
# the stored image grids and the generator / discriminator loss histories.
img_list, G_losses, D_losses = (
    stats['img_list'],
    stats['G_losses'],
    stats['D_losses'],
)

# Visualisation

In [None]:
# NOTE(review): presumably creates the output directories (e.g. img/) under PATH
# that the plotting cells below write to -- confirm in src.
create_repo_paths(PATH)

In [None]:
save_path = PATH + 'img/real_vs_fake'
plot_loss(G_losses, D_losses, save_path, save=False)

In [None]:
plot_realvsfake(dataloader, device, img_list, PATH + 'img/loss', save=False)

## G’s progression



In [None]:
# Animate the image grids stored in img_list, one animation frame per entry.
# Use an explicit figure/axes pair instead of the implicit pyplot state.
fig, ax = plt.subplots(figsize=(8, 8))
ax.axis("off")
# Move the first axis of every stored grid to the end so imshow receives a
# channel-last array (assumes the grids are stored channel-first -- confirm).
ims = [[ax.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render the animation to HTML first, then close the figure: this frees the
# figure and stops the inline backend from also emitting a static copy of it.
ani_html = ani.to_jshtml()
plt.close(fig)

HTML(ani_html)

# Hyperparameter optimisation (gridsearch)

In [None]:
# --- Dataset ---------------------------------------------------------------
dataset_name = 'mnist'      # dataset name: 'cifar10' or 'mnist'
dataroot = f"{PATH}data/"   # root directory for the dataset
# Spatial size of training images. All images will be resized to this size
# using a transformer (28 for mnist, 64 for others).
image_size = 28

# --- Data loading ----------------------------------------------------------
workers = 2                 # number of workers for the dataloader

# --- Model / hardware ------------------------------------------------------
nz = 128                    # size of the z latent vector (i.e. size of generator input)
ngpu = 1                    # number of GPUs available; use 0 for CPU mode

In [None]:
# NOTE(review): presumably creates the output directories under PATH that the
# grid search saves its models and stats into -- confirm in src.
create_repo_paths(PATH)

In [None]:
# Load the dataset. get_dataset also returns nc, which is passed on to
# grid_search (presumably the number of image channels -- confirm in src/data_handling).
dataset, nc = get_dataset(dataset_name, image_size, dataroot)

# Decide which device to run on: the first CUDA device when CUDA is available
# and at least one GPU was requested, otherwise the CPU.
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Launch the hyperparameter grid search over the value lists given below.
# NOTE(review): presumably one experiment is run per combination of the list
# entries (here 3 optimizers x 7 learning rates at 300 epochs each) -- confirm
# in src/gridsearch before starting, as this may take a long time.
grid_search(ngpu, device, dataset, workers,
            experiment_prefix='',           # extra word added at the beginning of the save path of the models and stats
            batch_size_list=[128],
            shuffle_list=[True],
            num_epochs_list=[300],
            loss_name_list=['wass'],        # wass, hinge
            optimizer_name_list=['adam', 'sgd', 'rmsprop'],   # 'adam' 'sgd' 'rmsprop'
            beta1_list=[0.9],               # Beta1 hyperparameter for Adam optimizers (0.9 == default)
            lr_list=[1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7],
            momentums_list=[(0, 0)],        # [(momentumD, momentumG)]
            plot=False,
            save_stats=True,                # save the stats to disk
            create_dir=True,                # create the directories to save files
            save_epochs=10,                 # save the model every save_epochs epochs
            save_models=True,               # save the models to disk
            manualSeed=123,                 # keep at 123
            nc=nc, nz=nz,
            PATH=PATH
            )