# GAN Training for DataChallenge

In [1]:
import os
import sys
import yaml
sys.path.append('../../')

import torch
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, loggers
from syndatagenerators.data_preparation.datachallenge import lcl_to_hdf5, load_households_from_hdf, DataChallengeDataModule

from syndatagenerators.models.progressive_gan.trainer import TrainerProGAN
from syndatagenerators.models.utils.callbacks import MMDCallback, ACFCallback, DiscriminativeCallback
from syndatagenerators.models.utils.plot_functions import plot_sample_grid

### Load Data

In [2]:
# Data locations. The defaults point at the shared cluster storage; set the
# LSM_DATA_DIR / LSM_HDF5_FILE environment variables to run the notebook on
# another machine without editing this cell.
lsmpath = os.environ.get('LSM_DATA_DIR', '/share/data1/bschaefermeier/datasets/Small_LCL_Data/')
hdf5file = os.environ.get('LSM_HDF5_FILE', '/share/data1/bschaefermeier/datasets/londonSmartMeter.h5')

In [3]:
# Build the HDF5 store from the raw LSM data. With overwrite=False an already
# existing file is kept as it is (see the message printed below).
lcl_to_hdf5(
    root=lsmpath,
    file_out=hdf5file,
    overwrite=False,
)

HDF5 file /share/data1/bschaefermeier/datasets/Small_LCL_Data/ already exists. Delete it or set overwrite=True if you want to recreate it.


In [4]:
# Build the data module on top of the HDF5 file and pull out its train and test loaders.
datamodule = DataChallengeDataModule(
    hdf5file,
    sample_dim=128,
    num_workers=8,
)
full_train_data = datamodule.train_dataloader()
test_data = datamodule.test_dataloader()
print(f"Batches in train data: {len(full_train_data)} Batches in test data: {len(test_data)}")

120it [00:14,  8.14it/s]


Batches in train data: 608 Batches in test data: 711


In [28]:
# Shape of the data behind the training loader (recorded output: torch.Size([38877, 1, 128])).
full_train_data.dataset.shape

torch.Size([38877, 1, 128])

### Load configuration file with data, train and network parameters

In [15]:
# Path to the YAML configuration file. The path is relative, so it is resolved
# against the working directory the notebook is started from.
CONFIG_PATH = '../models/utils/config_progan2.yml'

In [16]:
# Parse the YAML configuration. The context manager closes the file handle
# right after parsing instead of leaving it open until garbage collection.
with open(CONFIG_PATH, 'r') as config_file:
    config = yaml.safe_load(config_file)
config

{'train_params': {'batch_size': 64,
  'lr': 1e-05,
  'epochs': 3500,
  'sample_cycle': 10,
  'lambda_gp': 10,
  'n_critic': 5,
  'epochs_per_step': 5,
  'nb_fade_in_epochs': 150,
  'schedule': [800, 1500, 2100, 2600],
  'name': 13,
  'feature_dim': 1,
  'target_len': 128,
  'nb_labels': 1},
 'dis_params': {'kernel_size': 11, 'channel_dim': 32},
 'gen_params': {'kernel_size': 11, 'channel_dim': 32},
 'data_params': {'window_len': 64,
  'overlap': 0,
  'n_households': 1000,
  'train_data_dir': '/share/data1/mjuergens/SyLasKI/train_data',
  'ckpt_dir': '/share/data1/bschaefermeier/chkpt/progan/'}}

### Split data into training and validation sets

In [29]:
from torch.utils.data import random_split, DataLoader

# Hold out 5% of the training samples for validation.
train_size = int(len(full_train_data.dataset) * 0.95)
val_size = len(full_train_data.dataset) - train_size

# Seed the split so the same samples land in the validation set on every run;
# an unseeded split makes validation metrics incomparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(full_train_data.dataset, [train_size, val_size],
                                    generator=split_generator)

# dataloaders for training and validation
# Training batches are reshuffled every epoch; the validation set is served as
# one single, fixed batch.
loader_train = DataLoader(train_data, batch_size=config['train_params']['batch_size'], shuffle=True,
                          num_workers=8)
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

print(f'Training size: {len(train_data)} Validation size: {len(val_data)}')

Train size: 36933 Validation size: 1944
Training size: 36933 Validation size: 1944


### Create the model

In [17]:
# Instantiate the ProGAN trainer from the three parameter groups of the config file.
model = TrainerProGAN(
    train_params=config["train_params"],
    dis_params=config["dis_params"],
    gen_params=config["gen_params"],
)

### Callbacks to be used during training

In [18]:
# Checkpoint file stem (<trainer class>_<model name>) and a per-run directory
# below the checkpoint root, named after the configured run name.
FILENAME = '{}_{}'.format(model.__class__.__name__, model.name)
MODEL_PATH = os.path.join(
    config['data_params']['ckpt_dir'],
    str(config['train_params']['name']),
)

# callback that calculates MMD every ith epoch
mmd_callback = MMDCallback(size=1000)

# callback that calculates discriminative score using an LSTM
discriminative_callback = DiscriminativeCallback(size=1000)

# callback for early stopping on the generator loss
# (alternative that was tried before: EarlyStopping('mmd', patience=4))
early_stopping_callback = EarlyStopping(
    monitor='generator loss',
    patience=30,
    stopping_threshold=-5,
)

# checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath=config['data_params']['ckpt_dir'],
    filename=FILENAME,
    monitor="generator loss",
    save_top_k=1,
    every_n_epochs=10,
)

## Train the model

In [None]:
# NOTE(review): scratch expression with no effect on the training run; looks
# like a leftover kernel check — confirm and remove the cell.
1+1

In [None]:
import time

# tensorboard logger
tb_logger = pl.loggers.TensorBoardLogger(save_dir=MODEL_PATH)

# NOTE(review): `discriminative_callback` and `early_stopping_callback` are
# created in the callbacks cell but are not registered here — confirm that
# leaving them out is intended.
callbacks = [mmd_callback, checkpoint_callback]
trainer = pl.Trainer(logger=tb_logger, accelerator='cuda', max_epochs=config['train_params']['epochs'],
                     callbacks=callbacks,
                     check_val_every_n_epoch=10)

# perf_counter() is monotonic, so the measured duration cannot be distorted by
# system clock adjustments during a long training run.
start_time = time.perf_counter()
trainer.fit(model, loader_train, loader_val)
duration = time.perf_counter() - start_time
print(f"Training for {trainer.current_epoch} epochs took {duration:.1f} seconds.")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name          | Type         | Params
-----------------------------------------------
0 | loss          | WGANGPLoss   | 0     
1 | generator     | ProGenerator | 95.7 K
2 | discriminator | ProCritic    | 113 K 
-----------------------------------------------
209 K     Trainable params
0         Non-trainable params
209 K     Total params
0.836     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## Evaluate the model

In [45]:
# Best value of the monitored quantity ("generator loss") seen by the checkpoint callback.
checkpoint_callback.best_model_score

tensor(-6.0336, device='cuda:0')

In [None]:
# NOTE(review): bare display of the callback object; the cell has no recorded
# execution — confirm whether it is still needed.
checkpoint_callback

In [34]:
def generate_synthetic_samples(generator, depth: int = 4, size: int = 5000,
                               clip: bool = True):
    """Draw ``size`` synthetic samples from ``generator``.

    The generator is switched to eval mode and called exactly once on standard
    normal noise of shape ``(size, generator.nb_features, generator.target_len)``.

    Args:
        generator: object exposing ``eval()``, ``nb_features`` and ``target_len``
            that is callable as ``generator(z, depth=..., residual=...)``.
        depth: value forwarded to the generator as ``depth``.
        size: number of noise vectors, i.e. number of generated samples.
        clip: if True, clamp the generated values to the interval [0, 1].

    Returns:
        The generator output as a detached CPU tensor.
    """
    generator.eval()

    # Create the noise on the device the generator lives on; otherwise a
    # generator that is still on the GPU fails with a device mismatch.
    parameters = getattr(generator, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else torch.device('cpu')
    z = torch.randn(size, generator.nb_features, generator.target_len, device=device).float()

    # One forward pass only (the previous version ran the generator twice and
    # threw the first result away); no_grad avoids building an autograd graph.
    with torch.no_grad():
        x_gen = generator(z, depth=depth, residual=False)
    x_gen = x_gen.detach().cpu()

    if clip:
        x_gen = torch.clamp(x_gen, 0, 1)
    return x_gen

In [36]:
def load_checkpoint(path):
    """Restore a ``TrainerProGAN`` from the checkpoint at ``path`` and put it into eval mode."""
    restored = TrainerProGAN.load_from_checkpoint(path)
    restored.eval()
    return restored


# Set to True to replace the in-memory model with a stored checkpoint.
# The checkpoint below belongs to the first model that was trained with 1000
# households instead of 100.
LOAD_CHECKPOINT = False
if LOAD_CHECKPOINT:
    chkpt = '/share/data1/bschaefermeier/chkpt/progan/TrainerProGAN_12-v39.ckpt'
    model = load_checkpoint(chkpt)

In [37]:
# Generate samples with the default sample count; clip=False keeps the raw
# generator output instead of clamping it to [0, 1].
x_gen = generate_synthetic_samples(model.generator, depth=4, clip=False)

In [38]:
# Sanity check on the shape of the generated batch (recorded output: torch.Size([5000, 1, 128])).
x_gen.shape

torch.Size([5000, 1, 128])

In [1]:
# Show randomly chosen real samples next to randomly chosen generated ones.
# NOTE(review): this cell used to read `dataset.data`, but no `dataset` is
# defined anywhere in the notebook, so it cannot run on a fresh kernel. The data
# behind `full_train_data` is used instead — confirm that this is the intended
# source of real samples and that plot_sample_grid accepts it.
plot_sample_grid(full_train_data.dataset, random_sample=True, title='Real Samples')
fig = plot_sample_grid(x_gen, random_sample=True, title='Generated Samples')

NameError: name 'plot_sample_grid' is not defined

In [None]:
from syndatagenerators.metrics.visualization import plot_TSNE
from syndatagenerators.models.utils.plot_functions import plot_sample_grid

In [40]:
# t-SNE comparison of the first 5000 entries of `dataset_val` with the generated samples.
# NOTE(review): `dataset_val` is not defined in any cell of this notebook (the
# split cell creates `val_data` / `loader_val`) — confirm where it comes from;
# as written the cell fails on a fresh kernel.
tsne_plot = plot_TSNE(dataset_val[:5000], x_gen, use_seaborn=True)

In [42]:
# MMD between `dataset_val` and the generated samples.
# NOTE(review): `dataset_val` is not defined in any cell of this notebook —
# confirm its origin; as written the cell fails on a fresh kernel.
from syndatagenerators.metrics.mmd_score import mmd
mmd(dataset_val[:], x_gen)

tensor(1.1658)