# GAN using PyTorch Lightning 

See: 
- https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html
- https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/


Note: this notebook requires
```pip install ipywidgets lightning tqdm```

## Step 1 - Init and parameters
#### Python init

In [None]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar
from lightning.pytorch.callbacks.progress.base          import ProgressBarBase
from lightning.pytorch.callbacks                        import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard              import TensorBoardLogger

from tqdm import tqdm
from torch.utils.data import DataLoader

import fidle

from modules.SmartProgressBar    import SmartProgressBar
from modules.QuickDrawDataModule import QuickDrawDataModule

from modules.GAN                 import GAN
from modules.Generators          import *
from modules.Discriminators      import *

# ---- Initialise the Fidle environment for run 'SHEEP3'
#      (the result is unpacked into run id, run directory and datasets directory)
(run_id, run_dir, datasets_dir) = fidle.init('SHEEP3')

#### A few parameters

In [None]:
# ---- Model
latent_dim          = 128                  # length of the random noise vector z fed to the generator
generator_class     = 'Generator_1'        # name of a class from modules.Generators (passed to GAN by name)
discriminator_class = 'Discriminator_1'    # name of a class from modules.Discriminators (passed to GAN by name)

# ---- Training
scale               = .1                   # passed to QuickDrawDataModule - presumably the fraction of the dataset used; TODO confirm
epochs              = 10
batch_size          = 32
num_img             = 36                   # NOTE(review): not referenced elsewhere in this notebook - confirm whether still needed
fit_verbosity       = 2                    # verbosity level for training output

# ---- Data
# Build the path with os.path.join rather than string concatenation, so the
# separator is correct on every OS and a trailing '/' in datasets_dir is harmless.
dataset_file        = os.path.join(datasets_dir, 'QuickDraw', 'origine', 'sheep.npy')
data_shape          = (28,28,1)            # presumably (height, width, channels) - TODO confirm

## Step 2 - Get some nice data

#### Get a Nice DataModule
Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py)   
This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)

In [None]:
# DataLoader worker processes: keep the original 8, but never request more
# workers than the machine has CPUs (os.cpu_count() may return None -> fall back to 1).
num_workers = min(8, os.cpu_count() or 1)

dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=num_workers)
# Call setup() now so the dataloaders can be inspected before training starts
dm.setup()

#### Have a look

In [None]:
# Fetch one batch from the training dataloader and display it as a grid of images.
dl         = dm.train_dataloader()
batch_data = next(iter(dl))

# Reshape to 2D images using the configured data_shape (height, width) instead of a
# hard-coded 28x28, and index by the actual batch length: a small dataset can yield
# a first batch shorter than batch_size.
img_h, img_w = data_shape[0], data_shape[1]
fidle.scrawler.images( batch_data.reshape(-1, img_h, img_w), indices=range(len(batch_data)), columns=12, x_size=1, y_size=1,
                       y_padding=0, spines_alpha=0, save_as='01-Sheeps')

## Step 3 - Get a nice GAN model

Our Generators are defined in [./modules/Generators.py](./modules/Generators.py)  
Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py)  


Our GAN is defined in [./modules/GAN.py](./modules/GAN.py)  

#### Basic test - just to be sure it (could) work... ;-)

In [None]:
print('\nInstantiation :\n')
# Build the networks from the configured class names (generator_class /
# discriminator_class) rather than hard-coded classes, so this smoke test
# checks the same architectures that are passed to the GAN.
generator     = globals()[generator_class](latent_dim=latent_dim, data_shape=data_shape)
discriminator = globals()[discriminator_class](latent_dim=latent_dim, data_shape=data_shape)

print('\nFew tests :\n')
z = torch.randn(batch_size, latent_dim)
print('z size        : ',z.size())

# Shape check only: no_grad avoids building an autograd graph. Modules are called
# directly (not via .forward()) so any registered hooks run as they would in training.
with torch.no_grad():
    fake_img = generator(z)
    print('fake_img      : ', fake_img.size())

    p = discriminator(fake_img)
    print('pred fake     : ', p.size())

    print('batch_data    : ',batch_data.size())

    p = discriminator(batch_data)
    print('pred real     : ', p.size())

# Display the (untrained) generator output. Use a distinct save_as name so this
# figure does not overwrite the real-data figure saved as '01-Sheeps'.
img_h, img_w = data_shape[0], data_shape[1]
nimg = fake_img.detach().cpu().numpy()
fidle.scrawler.images( nimg.reshape(-1, img_h, img_w), indices=range(len(nimg)), columns=12, x_size=1, y_size=1,
                       y_padding=0, spines_alpha=0, save_as='02-Fake-sheeps')

#### GAN model
To simplify our code, the GAN class is defined separately in the module [./modules/GAN.py](./modules/GAN.py)  
Passing the generator/discriminator class names as parameters keeps the code modular and makes it possible to use PyTorch Lightning (PL) checkpoints.

In [None]:
# Assemble the Lightning GAN module. Generator and discriminator are given by
# class *name* (not instances), which keeps the model reloadable from checkpoints.
gan = GAN(
    latent_dim          = latent_dim,
    data_shape          = data_shape,
    generator_class     = generator_class,
    discriminator_class = discriminator_class,
    batch_size          = batch_size,
)

## Step 5 - Train it !
#### Instantiate Callbacks, Logger & co.
More about:
- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)
- [ModelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)

In [None]:

# ---- TensorBoard logger: event files are written under <run_dir>/tb_logs
#
logger       = TensorBoardLogger(       save_dir       = f'{run_dir}',
                                        name           = 'tb_logs'  )

# ---- Checkpoints in <run_dir>/models: keep the single best model according to
#      the logged "g_loss" (lowest value wins, mode="min"), plus the last one.
#
callback_checkpoints = ModelCheckpoint( dirpath        = f'{run_dir}/models',
                                        filename       = 'bestModel',
                                        save_top_k     = 1,
                                        save_last      = True,
                                        every_n_epochs = 1,
                                        monitor        = "g_loss",
                                        mode           = "min")

# ---- To have a nice progress bar
#      Verbosity comes from the fit_verbosity parameter instead of being hard-coded.
#
callback_progressBar = SmartProgressBar(verbosity=fit_verbosity)   # Usable everywhere
# callback_progressBar = TQDMProgressBar(refresh_rate=1)           # Usable in real Jupyter Lab (bug in VS Code)

#### Train it

In [None]:

# Build the Lightning Trainer and run the training loop on the QuickDraw DataModule.
trainer = Trainer(
    max_epochs         = epochs,
    accelerator        = "auto",          # let Lightning pick the available accelerator
    logger             = logger,
    callbacks          = [callback_progressBar, callback_checkpoints],
    log_every_n_steps  = batch_size,      # NOTE(review): reuses batch_size as the logging interval (in steps)
#    devices            = 1 if torch.cuda.is_available() else None,  # limiting got iPython runs
)

trainer.fit(gan, dm)

## Step 6 - Reload a checkpoint

In [None]:
# Reload the best model written by the ModelCheckpoint callback
# (dirpath=f'{run_dir}/models', filename='bestModel'; use 'last.ckpt' for the final epoch):
# gan = GAN.load_from_checkpoint(f'{run_dir}/models/bestModel.ckpt')