### Setup


In [10]:
import argparse
import os

import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torchvision import datasets

import torch

from gan import Gan

In [11]:
os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
# Jupyter/Colab start the kernel with `-f <connection-file.json>` (see the printed
# Namespace below); declaring it keeps that flag from being treated as an error.
parser.add_argument("-f", help="kernel connection file passed by Jupyter/Colab (unused)")
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=10, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
parser.add_argument("--seed", type=int, default=None, help="seed for torch's RNG (default: not seeded)")
# parse_known_args tolerates any other launcher-injected flags instead of exiting,
# so the same cell works as a script, in Jupyter and in Colab.
opt, _unknown_args = parser.parse_known_args()
print(opt)

# Seed only when requested, so the default behaviour is unchanged.
if opt.seed is not None:
    torch.manual_seed(opt.seed)

# Shape of one image as (channels, height, width); shared by every generator/discriminator below.
img_shape = (opt.channels, opt.img_size, opt.img_size)

Namespace(batch_size=128, channels=1, clip_value=0.01, f='/home/dany/.local/share/jupyter/runtime/kernel-6c03b70b-3e55-4a8c-aa36-745326411591.json', img_size=28, latent_dim=100, lr=5e-05, n_cpu=8, n_critic=5, n_epochs=200, sample_interval=400)


### Configure data loaders

In [12]:
os.makedirs("data/mnist", exist_ok=True)
# MNIST training split. ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
# Resize honours --img_size (it is a no-op at the default size of 28).
mnist_dataloader = DataLoader(
    datasets.MNIST(
        "data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(opt.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,  # use the --n_cpu option for parallel batch loading
)


In [5]:
os.makedirs("data/celebA", exist_ok=True)

# torchvision's automatic CelebA download (download=True) fails with
# `BadZipFile: File is not a zip file`, so fetch the files manually from
# https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg?resourcekey=0-rJlzl934LzC-Xp28GeIBzQ
# and save them under data/celebA.

# Bring CelebA images to `img_shape`: resize the short edge, centre-crop to a square,
# and drop colour when a single channel is configured (the MNIST default).
celeba_transform_steps = [
    transforms.Resize(opt.img_size),
    transforms.CenterCrop(opt.img_size),
]
if opt.channels == 1:
    celeba_transform_steps.append(transforms.Grayscale(num_output_channels=1))
celeba_transform_steps += [
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # broadcasts over any number of channels
]

try:
    celeba_dataloader = DataLoader(
        datasets.CelebA(
            "data/celebA",
            split="train",
            download=False,
            transform=transforms.Compose(celeba_transform_steps),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
except RuntimeError as err:
    # CelebA is optional (MNIST is the default dataset); without the manually
    # downloaded files, keep the rest of the notebook runnable instead of aborting.
    celeba_dataloader = None
    print(f"CelebA unavailable, skipping: {err}")


94694.0 bytes

Using downloaded and verified file: data/celebA/celeba/identity_CelebA.txt


94756.0 bytes

BadZipFile: File is not a zip file

### BGAN

In [14]:
from nets.bgan_net import BGanGenerator, BGanDiscriminator

# Single switch for the dataset: it selects both the data loader and the
# `dataset_name` given to `Gan`, so the two cannot drift apart.
dataset_name = "mnist"  # "mnist" or "celeba"

if dataset_name == "mnist":
    dataloader = mnist_dataloader
elif dataset_name == "celeba":
    dataloader = celeba_dataloader
else:
    raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'mnist' or 'celeba'.")
if dataloader is None:
    raise RuntimeError(f"No data loader is available for dataset {dataset_name!r}.")

bGan_G = BGanGenerator(img_shape, opt.latent_dim)
bGan_D = BGanDiscriminator(img_shape)

bGan = Gan(bGan_G, bGan_D, opt.lr, dataset_name=dataset_name, loss_name="bgan")

# NOTE(review): clip_value / n_critic are passed for the BGAN run too; whether
# `Gan.train` applies weight clipping for loss_name="bgan" is not visible here — confirm.
b_list_loss_G, b_list_loss_D, b_list_Frech_dist = bGan.train(
    dataloader, opt.n_epochs, opt.clip_value, opt.n_critic, opt.sample_interval
)

[Epoch 0/200] [Batch 0/469] [D loss: 0.12472975254058838] [G loss: 0.12375310063362122]
[FID 724.468384]
[Epoch 0/200] [Batch 5/469] [D loss: 0.09683245420455933] [G loss: 0.1236024722456932]
[Epoch 0/200] [Batch 10/469] [D loss: 0.034609079360961914] [G loss: 0.12319678068161011]
[Epoch 0/200] [Batch 15/469] [D loss: -0.05227452516555786] [G loss: 0.12194396555423737]
[Epoch 0/200] [Batch 20/469] [D loss: -0.1337328553199768] [G loss: 0.11932303011417389]
[Epoch 0/200] [Batch 25/469] [D loss: -0.2018526941537857] [G loss: 0.1158469170331955]
[Epoch 0/200] [Batch 30/469] [D loss: -0.23955893516540527] [G loss: 0.111631840467453]
[Epoch 0/200] [Batch 35/469] [D loss: -0.2631893754005432] [G loss: 0.10731492936611176]
[Epoch 0/200] [Batch 40/469] [D loss: -0.2761201560497284] [G loss: 0.10256332159042358]
[Epoch 0/200] [Batch 45/469] [D loss: -0.2842242121696472] [G loss: 0.0978941097855568]
[Epoch 0/200] [Batch 50/469] [D loss: -0.28751662373542786] [G loss: 0.09309809654951096]
[Epoch 

### Wasserstein

In [15]:
from nets.wasserstein_net import WassersteinGenerator, WassersteinDiscriminator

# Single switch for the dataset: it selects both the data loader and the
# `dataset_name` given to `Gan`, so the two cannot drift apart.
dataset_name = "mnist"  # "mnist" or "celeba"

if dataset_name == "mnist":
    dataloader = mnist_dataloader
elif dataset_name == "celeba":
    dataloader = celeba_dataloader
else:
    raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'mnist' or 'celeba'.")
if dataloader is None:
    raise RuntimeError(f"No data loader is available for dataset {dataset_name!r}.")

wasserstein_G = WassersteinGenerator(img_shape, opt.latent_dim)
wasserstein_D = WassersteinDiscriminator(img_shape)

wasserstein_Gan = Gan(wasserstein_G, wasserstein_D, opt.lr, dataset_name=dataset_name, loss_name="wgan")

# Returns per-run histories of generator loss, discriminator loss and FID values.
w_list_loss_G, w_list_loss_D, w_list_Frech_dist = wasserstein_Gan.train(
    dataloader, opt.n_epochs, opt.clip_value, opt.n_critic, opt.sample_interval
)

[Epoch 0/200] [Batch 0/469] [D loss: 0.02023226022720337] [G loss: -0.49788546562194824]
[FID 724.034058]
[Epoch 0/200] [Batch 5/469] [D loss: -0.028502970933914185] [G loss: -0.4981392025947571]
[Epoch 0/200] [Batch 10/469] [D loss: -0.09236612915992737] [G loss: -0.4990048110485077]
[Epoch 0/200] [Batch 15/469] [D loss: -0.1835019588470459] [G loss: -0.5018697381019592]
[Epoch 0/200] [Batch 20/469] [D loss: -0.268688440322876] [G loss: -0.5072141885757446]
[Epoch 0/200] [Batch 25/469] [D loss: -0.32264918088912964] [G loss: -0.5144681930541992]
[Epoch 0/200] [Batch 30/469] [D loss: -0.35816723108291626] [G loss: -0.5229469537734985]
[Epoch 0/200] [Batch 35/469] [D loss: -0.3793531060218811] [G loss: -0.5323294997215271]
[Epoch 0/200] [Batch 40/469] [D loss: -0.38370227813720703] [G loss: -0.5429894924163818]
[Epoch 0/200] [Batch 45/469] [D loss: -0.38563036918640137] [G loss: -0.5522460341453552]
[Epoch 0/200] [Batch 50/469] [D loss: -0.3863444924354553] [G loss: -0.561630368232727]


In [28]:
import pandas as pd


def _columns_to_frame(columns: dict) -> pd.DataFrame:
    """Build a DataFrame from named sequences, padding shorter ones with NaN.

    Wrapping each sequence in ``pd.Series`` avoids the "All arrays must be of the
    same length" error when, e.g., one run was interrupted early.
    """
    return pd.DataFrame({name: pd.Series(values) for name, values in columns.items()})


# `to_csv` does not create missing directories, so make sure `results/` exists.
os.makedirs("results", exist_ok=True)

b_loss_dict = {'b_loss_G': b_list_loss_G, 'b_loss_D': b_list_loss_D}
df_b = _columns_to_frame(b_loss_dict)
df_b.to_csv(path_or_buf='results/bGAN.csv')

wass_loss_dict = {'wass_loss_G': w_list_loss_G, 'wass_loss_D': w_list_loss_D}
df_w = _columns_to_frame(wass_loss_dict)
df_w.to_csv(path_or_buf='results/WGAN.csv')

dict_frech_dist = {'b_frech_dist': b_list_Frech_dist, 'wass_frech_dist': w_list_Frech_dist}
df_frech_dist = _columns_to_frame(dict_frech_dist)
df_frech_dist.to_csv(path_or_buf='results/frech_dist.csv')

## Results and plots

In [26]:
import pandas as pd
import plotly.express as px

# index_col=0 restores the index written by `to_csv` instead of adding an "Unnamed: 0" column.
fid_df = pd.read_csv('results/frech_dist.csv', index_col=0)

fig = px.line(
    fid_df,
    y=['wass_frech_dist', 'b_frech_dist'],
    title='Frechet Distance',
    labels={'index': 'record #', 'value': 'FID', 'variable': 'model'},
)
fig.show()








In [38]:
# index_col=0 keeps the saved row index as the index, so the two frames do not both
# contribute a duplicate "Unnamed: 0" column to the side-by-side concatenation.
bgan_losses = pd.read_csv('results/bGAN.csv', index_col=0)
wgan_losses = pd.read_csv('results/WGAN.csv', index_col=0)
losses_df = pd.concat([bgan_losses, wgan_losses], axis=1)

fig = px.line(
    losses_df,
    y=["b_loss_D", "b_loss_G", "wass_loss_G", "wass_loss_D"],
    title='Losses',
    labels={'index': 'record #', 'value': 'loss', 'variable': 'series'},
)
fig.show()


