<a href="https://colab.research.google.com/github/TheodorSergeev/optml_gan/blob/main/dcgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Adapted from https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

# Initialisation

In [1]:
try:
    import google.colab
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # packages to generate requirement.txt
    %pip install nbconvert
    %pip install pipreqs
    # for Frechet inception distance
    %pip install pytorch-fid

    %cd drive/My Drive/optml_gan2
    PATH = './'
else:
    PATH = './'

In [2]:
from __future__ import print_function

import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data

import torchvision.utils as vutils

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
%matplotlib inline

In [3]:
%load_ext autoreload
%autoreload 2

# Source code

In [4]:
from src.data_handling import *
from src.utils import *
from src.model import *
from src.losses import *
from src.fid import *

loss_dict = {
    "kl": (loss_dis_kl, loss_gen_kl),
    "wass": (loss_dis_wasser, loss_gen_wasser),
    "hinge": (loss_dis_hinge, loss_gen_hinge)
}

# FID

from src.training import *
from src.visualisation import *
from src.serialisation import *

# https://keras.io/examples/generative/conditional_gan/
from src.architectures import *

from src.gridsearch import *

# Metrics

In [5]:
# Root directory for dataset
dataroot = PATH + "data/"

# Dataset name
dataset_name = 'mnist' # 'cifar10' or 'mnist'

# Number of workers for dataloader
workers = 0

# Spatial size of training images. All images will be resized to this size using a transformer.
image_size = 28 # 28 for mnist, 64 for others

# Size of z latent vector (i.e. size of generator input)
nz = 128

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
create_repo_paths(PATH)
dataset, nc = get_dataset(dataset_name, image_size, dataroot)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [158]:
# Create the dataloader

batch_size_eval = 32 # 128
num_samples = 64 # 1000
set_seeds(manualSeed=123)
which = torch.ones(len(dataset)).multinomial(num_samples, replacement=True)
dataset_subset = torch.utils.data.Subset(dataset, which)

real_dataloader = torch.utils.data.DataLoader(dataset_subset, batch_size=batch_size_eval,
                                         shuffle=False, num_workers=workers) # shuffle=False for reproducibility
                                         
# Load inception model
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
inception_model = InceptionV3([block_idx])
inception_model = inception_model.to(device)

Random Seed:  123


In [None]:
loss_name = 'wass'
netG = init_net(Generator(ngpu, nc, nz), device, ngpu)
netD = init_net(Discriminator(ngpu, nc, loss_name), device, ngpu)

In [12]:
# for i in range(5):
# sample_batch = next(iter(dataloader))
frechet_dist = calculate_fid(num_samples, real_dataloader, batch_size_eval, device, inception_model, netG, nz, workers)
# print()

100%|██████████| 32/32 [00:06<00:00,  4.83it/s]
100%|██████████| 32/32 [00:05<00:00,  5.70it/s]


frechet dist: 365.0568528083575 | time to calculate : 24.826169967651367 s


In [13]:
# Load inception model
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
inception_model = InceptionV3([block_idx])
inception_model = inception_model.to(device)

# Init empty G and D
netG = init_net(Generator(ngpu, nc, nz), device, ngpu)
netD = init_net(Discriminator(ngpu, nc, loss_name), device, ngpu)

# Init paths
create_repo_paths(PATH)
generated_data_path = PATH + 'generated_data/'
generated_data_path 

# Create a sample of the mnist dataset
batch_size_eval = 10 # 128
num_samples = 10 # 1000
set_seeds(manualSeed=123)
which = torch.ones(len(dataset)).multinomial(num_samples, replacement=True)
dataset_subset = torch.utils.data.Subset(dataset, which)

real_dataloader = torch.utils.data.DataLoader(dataset_subset, batch_size=batch_size_eval,
                                         shuffle=False, num_workers=workers) # shuffle=False for reproducibility

Generator(
  (fc1): Linear(in_features=128, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)
Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)
Random Seed:  123


In [30]:
def load_G(ngpu, Generator, save_path_G, device):

    netG = init_net(Generator(ngpu, nc, nz), device,ngpu)
    netG.load_state_dict(torch.load(save_path_G))
    netG.eval()
    return netG

In [170]:
list_subfolders_with_paths = sorted([f.path for f in os.scandir(generated_data_path) if f.is_dir()])

paths_adam = list_subfolders_with_paths[0:7]
paths_rmsprop = list_subfolders_with_paths[7:14]
paths_sgd = list_subfolders_with_paths[14:]

# print( paths_adam)
# print( paths_rmsprop)
# print( paths_sgd)

which_iterations = [0,50,100] #[0,50,100,150,200,250,290] # [0,10,20,30,40,50,60,70,80,90,100,110,120, 130,140,150,160,170,180,190,200,210,220,230,240,250,260,270,280,290]
# desired_learning_rates = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7] # [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
list_paths = paths_rmsprop # paths_adam paths_sgd paths_rmsprop

all_optimizer_scores = {}
fix_extension = False
calculate_frechet_bool = True

all_lr_scores = {}
get_stats = True
# all_scores['adam'] = [desired_learning_rates, score_list]

for path in list_paths:
    folder = path[17:]
    print(folder)
    param_list = folder.split('_')

    optimizer_name = param_list[0]
    loss_name = param_list[3][:-4]
    lr = param_list[4][3:]
    # print()
    print(optimizer_name, lr)
    score_list = []
    if get_stats:
        stats_path = path + '/stat.pickle'
        stats = pickle_load(stats_path)
        img_list = stats['img_list'] # 8x8 images fake generatred images in one picture
        G_losses = stats['G_losses'] 
        D_losses = stats['D_losses'] 
        img_list_nogrid  = stats['img_list_nogrid'] # 64 fake generatred images in a list
    else:
        for file in os.listdir(path+'/models/'):
            # print(file[:7])
                if file[:7] == 'model_G':
                    # print(file)
                    number = int(file[8:].split('.')[0]) # split because sometimes file has extension .zip sometimes doesn't
                    # print(number)
                    if number in which_iterations: # epochs
                        # score_list[float(lr)] = number
                        
                        print(number)
                        # print(path+'/models/'+file)
                        # print(len(file[8:].split('.')))

                        if fix_extension: # if files are not in .zip extension, use this
                            if len(file[8:].split('.'))==1: 
                                print(path+'/models/'+file)
                                os.rename(path+'/models/'+file, path+'/models/'+file+'.zip') 

                        if calculate_frechet_bool:
                            net_G = load_G(ngpu, Generator, path+'/models/'+file,device)
                            frechet_dist = calculate_fid(num_samples, real_dataloader, batch_size_eval, device, inception_model, netG, nz, workers)
                            score_list.append(frechet_dist)
    all_lr_scores[float(lr)] = score_list
all_optimizer_scores[optimizer_name] = all_lr_scores
all_optimizer_stats[optimizer_name]

rmsprop_mG0_mD0_wassLoss_lrd0.0001_lrg0.0001_b1b0.9_itd5_itg1_gpv10.0_
rmsprop 0.0001
0


100%|██████████| 2/2 [00:00<00:00,  5.16it/s]
100%|██████████| 2/2 [00:00<00:00,  5.69it/s]


frechet dist: 375.0156716629199 | time to calculate : 9.843644380569458 s
100


100%|██████████| 2/2 [00:00<00:00,  5.37it/s]
100%|██████████| 2/2 [00:00<00:00,  5.62it/s]


frechet dist: 375.7824242718262 | time to calculate : 9.964108228683472 s
50


100%|██████████| 2/2 [00:00<00:00,  5.35it/s]
100%|██████████| 2/2 [00:00<00:00,  5.63it/s]


frechet dist: 376.8077026061756 | time to calculate : 10.002908945083618 s
rmsprop_mG0_mD0_wassLoss_lrd0.001_lrg0.001_b1b0.9_itd5_itg1_gpv10.0_
rmsprop 0.001
0


100%|██████████| 2/2 [00:00<00:00,  5.34it/s]
100%|██████████| 2/2 [00:00<00:00,  5.59it/s]


frechet dist: 378.0283057411398 | time to calculate : 10.097379207611084 s
100


100%|██████████| 2/2 [00:00<00:00,  4.93it/s]
100%|██████████| 2/2 [00:00<00:00,  5.62it/s]


frechet dist: 376.1479737473754 | time to calculate : 10.048278570175171 s
50


100%|██████████| 2/2 [00:00<00:00,  5.30it/s]
100%|██████████| 2/2 [00:00<00:00,  5.07it/s]


frechet dist: 375.83017605963806 | time to calculate : 9.873766422271729 s
rmsprop_mG0_mD0_wassLoss_lrd0.01_lrg0.01_b1b0.9_itd5_itg1_gpv10.0_
rmsprop 0.01
0


100%|██████████| 2/2 [00:00<00:00,  5.34it/s]
100%|██████████| 2/2 [00:00<00:00,  5.63it/s]


frechet dist: 376.31020434258375 | time to calculate : 9.776933193206787 s
100


100%|██████████| 2/2 [00:00<00:00,  5.27it/s]
100%|██████████| 2/2 [00:00<00:00,  5.62it/s]


frechet dist: 374.7157998532318 | time to calculate : 9.813695669174194 s
50


100%|██████████| 2/2 [00:00<00:00,  4.94it/s]
100%|██████████| 2/2 [00:00<00:00,  5.62it/s]


frechet dist: 377.2082279357388 | time to calculate : 9.87267255783081 s
