# Training of the Conditional ProGAN - Example

### Imports

In [7]:
import os

import pytorch_lightning as pl
import torch
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, random_split

from syndatagenerators.models.progressive_gan.trainer import TrainerProGAN
from syndatagenerators.models.cond_progressive_gan.trainer import TrainerCPGAN
from syndatagenerators.models.utils.callbacks import DiscriminativeCallback, MMDCallback, ACFCallback
from syndatagenerators.data_preparation.datasets import LondonDataset

### Define paths, household IDs and configuration

#### Directories for the training data, the checkpoints and the configuration file

In [2]:
# All locations below are relative to the directory the notebook is run from.

# Training data used by LondonDataset (built on first use if it does not exist yet).
TRAIN_DATA_DIR = '../train_data'

# Model checkpoints are written to, and can be reloaded from, this directory.
CKPT_DIR = '../ckpt/'

# YAML file with the train/generator/discriminator/data parameter groups.
# NOTE(review): this is the progressive_gan config although the notebook trains the
# conditional model (TrainerCPGAN) — confirm this is intended.
CONFIG_PATH = '../models/progressive_gan/config.yml'

#### Household IDs used for training

In [3]:
# Number of London Smart Meter households used for training.
# NOTE(review): must not exceed the number of available household IDs — TODO: check n_ids.
N_HOUSEHOLDS = 100
# First household ID used; IDs are rendered as 'MAC' + six zero-padded digits, e.g. 'MAC000002'.
FIRST_ASSET_ID = 2
assert N_HOUSEHOLDS >= 1, 'N_HOUSEHOLDS must be at least 1'

# Asset list (London Smart Meter): N_HOUSEHOLDS consecutive IDs starting at FIRST_ASSET_ID.
# Derived from N_HOUSEHOLDS so the constant actually controls how many households are loaded.
list_assets = [f'MAC{i:06d}' for i in range(FIRST_ASSET_ID, FIRST_ASSET_ID + N_HOUSEHOLDS)]

#### Load the model parameters from the configuration file

In [4]:
# Read the YAML configuration. The `with` block guarantees the file handle is closed
# right after parsing instead of being left open until garbage collection.
with open(CONFIG_PATH, 'r') as config_file:
    config = yaml.safe_load(config_file)

# Parameter groups passed to the conditional ProGAN trainer.
train_params = config['train_params']
dis_params = config['dis_params']
gen_params = config['gen_params']

#### Initialize the model and the London Smart Meter dataset

In [5]:
# Number of class labels for the conditional GAN: one per asset in list_assets.
nb_cls = len(list_assets)

model = TrainerCPGAN(
    train_params=train_params,
    dis_params=dis_params,
    gen_params=gen_params,
    nb_classes=nb_cls,
)

# Building the dataset is slow the first time, when the train data has not been prepared yet.
# labels=True presumably makes each sample carry its household label — TODO confirm.
dataset = LondonDataset(
    assets=list_assets,
    window_length=config['data_params']['window_len'],
    overlap=config['data_params']['overlap'],
    train_data_dir=TRAIN_DATA_DIR,
    labels=True,
)

Train data saved


#### Split into training and validation sets: for now the split is random, which will need to be adjusted later

In [6]:
# split into train and validation set
dataset_train, dataset_val = random_split(dataset, [len(dataset) - 1000, 1000])

# dataloaders for training and validation
loader_train = DataLoader(dataset_train, batch_size=config['train_params']['batch_size'], shuffle=True,
                          num_workers=8)
loader_val = DataLoader(dataset_val, batch_size=len(dataset_val), shuffle=False)

#### Define the callbacks used during training

In [9]:
# Checkpoint file name: "<trainer class name>_<model name>".
FILENAME = '{}_{}'.format(type(model).__name__, model.name)

# MMD metric callback; size=1000 is presumably the number of samples used — TODO confirm.
mmd_callback = MMDCallback(size=1000)

# Discriminative score computed with an LSTM (same size argument as the MMD callback).
discriminative_callback = DiscriminativeCallback(size=1000)

# Stop training once the logged 'mmd' value has not improved for 4 validation checks.
early_stopping_callback = EarlyStopping(monitor='mmd', patience=4)

# Save a checkpoint under CKPT_DIR every 10 epochs.
checkpoint_callback = ModelCheckpoint(
    dirpath=CKPT_DIR,
    filename=FILENAME,
    every_n_epochs=10,
)

#### Initialize the trainer

In [10]:
# NOTE(review): discriminative_callback is constructed in the callbacks cell but is not in the
# list below, so it presumably never runs; add it here if the discriminative score is wanted.
trainer = pl.Trainer(
    max_epochs=config['train_params']['epochs'],
    callbacks=[mmd_callback, checkpoint_callback, early_stopping_callback],
    # run the validation loop only every 5th epoch
    check_val_every_n_epoch=5,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train the conditional ProGAN; loader_val feeds the validation loop configured on the trainer.
trainer.fit(model, train_dataloaders=loader_train, val_dataloaders=loader_val)

  rank_zero_deprecation(
Missing logger folder: C:\Users\mjuergen\Documents\SylasKI\Code\syndatagenerators\syndatagenerators\notebooks\lightning_logs

  | Name          | Type            | Params
--------------------------------------------------
0 | loss          | WGANGPLoss      | 0     
1 | generator     | CPGenerator     | 75.7 K
2 | discriminator | CPDiscriminator | 85.0 K
--------------------------------------------------
160 K     Trainable params
0         Non-trainable params
160 K     Total params
0.643     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  value = torch.tensor(value, device=self.device)


Depth on epoch 0: 0, residual: False


Training: 0it [00:00, ?it/s]