In [7]:
import os 
import yaml
import argparse
import numpy as np
import torch.backends.cudnn as cudnn

from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from models import *
from experimenter import VAE_experimenter
from datasets.dataset import VAEDataset


In [8]:
# Build the argument parser; its only option is the path to the YAML config file.
parser = argparse.ArgumentParser(description='Generic runner for VAE models')
parser.add_argument('--config', '-c',
                    dest='filename',
                    metavar='FILE',
                    help='path to the config file',
                    default='configs/vae.yaml')

# Parse an explicit empty argument list so sys.argv is not consulted and the
# default config path above is always used (inside a notebook kernel, sys.argv
# holds the kernel's own arguments, not ours).
args = parser.parse_args(args=[])
with open(args.filename, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as error:
        # Show the parser's message, then re-raise: continuing would leave
        # `config` unbound (or stale from an earlier run of this cell) and
        # the following cells would fail in a much less obvious way.
        print(error)
        raise

In [9]:
# Display the configuration loaded from the YAML file so the values used by the
# following cells are visible in the notebook.
config

{'model_params': {'name': 'VanillaVAE', 'in_channels': 3, 'latent_dim': 128},
 'data_params': {'data_path': 'datasets/',
  'train_batch_size': 64,
  'val_batch_size': 64,
  'patch_size': 64,
  'num_workers': 4},
 'exp_params': {'LR': 0.005,
  'weight_decay': 0.0,
  'scheduler_gamma': 0.95,
  'kld_weight': 0.00025,
  'manual_seed': 1265},
 'trainer_params': {'gpus': [1], 'max_epochs': 100},
 'logging_params': {'save_dir': 'logs/', 'name': 'VanillaVAE'}}

In [10]:
# TensorBoard logger: writes under the configured save_dir and uses the model
# name as the experiment name.
tb_logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
                              name=config['model_params']['name'],)

# Look up the entry for the configured model name in `vae_models` and call it
# with the whole model section of the config (including 'name') as keyword
# arguments.
model = vae_models[config['model_params']['name']](**config['model_params'])
experimenter = VAE_experimenter(model, config['exp_params'])

# Enable pinned memory only when GPUs are requested. Truthiness is used instead
# of len() so that every form of `gpus` is handled: a list or string (same
# result as len(...) != 0), an int such as 0 or 1, or a missing/None entry for
# CPU-only runs.
data = VAEDataset(**config['data_params'],
                  pin_memory=bool(config['trainer_params'].get('gpus')))

In [5]:
# Prepare the datasets before handing the datamodule to the trainer.
data.setup()

# Trainer with two callbacks:
#   * LearningRateMonitor records the learning rate to the logger.
#   * ModelCheckpoint keeps the two best checkpoints ranked by the monitored
#     metric, plus the most recent one (save_last=True), under
#     <log_dir>/checkpoints.
# NOTE(review): assumes the experimenter logs a metric named 'val_loss' -
# confirm against VAE_experimenter.
trainer = Trainer(logger=tb_logger,
                  callbacks=[
                      LearningRateMonitor(),
                      ModelCheckpoint(save_top_k=2,
                                      dirpath=os.path.join(tb_logger.log_dir, 'checkpoints'),
                                      monitor='val_loss',
                                      save_last=True),
                  ],
                  **config['trainer_params'])

# Create output directories next to the logs; presumably the experimenter
# writes sample and reconstruction images here - verify against VAE_experimenter.
Path(f'{tb_logger.log_dir}/samples').mkdir(exist_ok=True, parents=True)
Path(f'{tb_logger.log_dir}/reconstructions').mkdir(exist_ok=True, parents=True)

# Double quotes on the outside so the single-quoted dict keys inside the
# replacement field are valid on Python versions before 3.12 as well.
print(f"===== Training {config['model_params']['name']} =====")
trainer.fit(experimenter, datamodule=data)