# Base Import

In [None]:
# Clone the repository providing the GAN.Models and GAN.Trainings packages
# imported in the "Model" section below.
!git clone https://github.com/Edouard99/GAN.git

In [None]:
import os
import torch
import torch.nn
import torch.backends.cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np 
from IPython.display import clear_output
from IPython.display import HTML
import random
from PIL import Image
import gc
# Start from a clean memory state: collect unreachable Python objects, then let
# PyTorch hand back cached-but-unused GPU memory (nothing to do without CUDA).
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
"""This cell import the dataset and according to the image type and size
For now the designed GAN are for 64px and 256px images
Available dataset are Pokemon 64px and 256px


"""

image_type="pokemon" # Only type of data available for now
image_size=64 # Choose between 64 and 256


os.mkdir("/content/data")
os.mkdir("/content/data/img_ds")
os.mkdir("/content/results")
os.mkdir("/content/net")

# Download the zipped dataset matching image_size and extract it into
# /content/data/img_ds (the ImageFolder root used as `dataroot` below).
# Each download uses two wget calls: the inner one fetches the Drive download
# page while saving session cookies, and sed extracts its `confirm=` token;
# the outer one requests the file again with that token and those cookies.
# NOTE(review): presumably a workaround for Google Drive's confirmation page on
# large files; it may break if Drive changes that flow. The inner call also
# uses --no-check-certificate, which disables TLS certificate verification.
# Nothing is downloaded if image_size is neither 64 nor 256.
if image_type=="pokemon":
    if image_size==64:
        !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1mKWPRvdYg6jfN6G8AFxsHpzjo3608QaJ' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1mKWPRvdYg6jfN6G8AFxsHpzjo3608QaJ" -O /content/poke_ds_64.zip && rm -rf /tmp/cookies.txt
        !unzip -q /content/poke_ds_64.zip -d /content/data/img_ds
        !rm /content/poke_ds_64.zip
    if image_size==256:
        !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=19sQKN9H4gmNPxQLLtjDmw5SV5KZ4Q1n1' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=19sQKN9H4gmNPxQLLtjDmw5SV5KZ4Q1n1" -O /content/poke_ds_256.zip && rm -rf /tmp/cookies.txt
        !unzip -q /content/poke_ds_256.zip -d /content/data/img_ds
        !rm /content/poke_ds_256.zip

# Parameters

In [4]:
# ---- Data loading ----
dataroot = "/content/data/img_ds"  # root folder of the image dataset (ImageFolder layout)
workers = 2      # number of DataLoader worker processes
batch_size = 32  # images per training batch

# ---- Network dimensions ----
nc = 3     # channels per training image (3 for colour/RGB images)
nz = 16    # size of the latent space, i.e. length of the generator's input vector
ngf = 128  # number of filters used in the generator (see doc)
ndf = 128  # number of filters used in the discriminator (see doc)

# ---- Hardware ----
ngpu = 1  # 1 to use the GPU when one is available, 0 to force the CPU

# Dataset Creation

In [None]:
""" This cell ensures reproducibility of the results, I used the seed 666 in my project, feel free to change the seed for new results """

manualSeed = 666

random.seed(manualSeed)
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
def seed_worker(worker_id):
    """Seed Python's ``random`` and NumPy from PyTorch's seed in a DataLoader worker.

    Intended as the DataLoader's ``worker_init_fn``, following the recipe in
    PyTorch's reproducibility notes. Depending on the PyTorch version, these
    libraries may not otherwise be seeded per worker.

    Args:
        worker_id: index of the worker process. It is unused, but the
            ``worker_init_fn`` calling convention requires it.
    """
    # Inside a worker, torch.initial_seed() is that worker's seed. Reduce it to
    # the 32-bit range accepted by np.random.seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated, seeded RNG handed to the DataLoader (generator=g) so that its
# shuffling order is reproducible. manual_seed returns the generator itself.
g = torch.Generator().manual_seed(manualSeed)

In [None]:
""" This cell creates the dataloader that will provide the images to the neural network, images are centered, resized (if necessary) and normalized"""

dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers,pin_memory=True,worker_init_fn=seed_worker,generator=g)

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") # GPU or CPU

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))


# Model

In [None]:
""" This cell allows to select the model and the training type used"""
gan_mode="dcgan" #Gan Type, choose between "dcgan" and "wgan"
training_mode="boosting" #Training mode, choose between {"classic", "monitoring", "boosting"} for DCGAN and {"wgan"} for Wgan

# Learning rate for optimizers:
lr_dcgan = 0.00015
lr_wgan= 5e-5
# Beta1 hyperparam for Adam optimizers (dcgan)
beta1 = 0.5


# Select and import the model and trainer regarding what you selected for gan_mode and training_mode
if gan_mode=="wgan":
  training_mode=gan_mode
else:
  if training_mode=="wgan":
    training_mode="classic"
if image_size==64:
    from GAN.Models import gan64 as model
elif image_size==256:
    from GAN.Models import gan256 as model

if training_mode=="classic":
    from GAN.Trainings import training_classic as trainer
if training_mode=="boosting":
    from GAN.Trainings import training_boosting as trainer
if training_mode=="monitoring":
    from GAN.Trainings import training_monitoring as trainer
if gan_mode=="wgan":
    from GAN.Trainings import training_wgan as trainer


# Create the generator
netG = model.Generator(nz,ngf,nc).to(device)
netG.apply(model.weights_init)

# Create the Discriminator
netD = model.Discriminator(nc,ndf,device,mode=gan_mode).to(device)
netD.apply(model.weights_init)

fixed_noise = torch.randn(64, nz, 1, 1, device=device) # This is a fixed noise that will be used to generate a grid of images saved evey 4 epochs

#Create the optimizers
if gan_mode=="dcgan":
    optimizerD = optim.Adam(netD.parameters(), lr=lr_dcgan, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr_dcgan, betas=(beta1, 0.999))
if gan_mode=="wgan":
    optimizerD = optim.RMSprop(netD.parameters(), lr=lr_wgan)
    optimizerG = optim.RMSprop(netG.parameters(), lr=lr_wgan)

# Training

In [9]:
""" This cell trains the networks 
    During the training, the network's weights are saved(can be disabled) and a grid of images(based on fixed_noise) is generated and saved at each epoch

Please use train function as:
    train(dataloader,discriminant_net,generator_net,discriminant_optimizer,generator_optimizer,number_of_epochs,
                device(CPU/GPU),save_the_net_parameters(True/False),net_saving_path,image_grid_saving_path,fixed_noise)

For monitored training there is 2 more parameters that can be set by the user, please refer to training_monitoring.py
"""

img_list,G_losses,D_losses=trainer.train(dataloader,netD,netG,optimizerD,optimizerG,150,device,True,"/content/net/","/content/results/",fixed_noise)

In [None]:
"""Generate an animation of all the image grid saved during training"""


fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list[0:len(img_list)]]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())