In [None]:
%load_ext tensorboard
%matplotlib widget
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader
import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from torchvision.datasets import MNIST
from torchvision import transforms
from os.path import isdir
import matplotlib.pyplot as plt
import seaborn as sns

from networks import VAE
from optimizers.kfac import KFACOptimizer

import logging

# Number of GPUs passed to the Lightning Trainer; raise it when training on GPU
NUM_GPUS = 0

# Only let warnings and more severe records through from PyTorch Lightning
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

## Initial experiment on MNIST VAE task

In [None]:
# Experiment-wide settings
MAX_EPOCHS = 1
DATALOADER_WORKERS = 8
SGD_MOMENTUM = 0.99

# One entry per optimizer under comparison:
#   label            - run name (also used for the log directory and the plot legend)
#   optimizer        - optimizer class
#   optimizer_params - hyper-parameters for that optimizer
#                      (presumably forwarded by VAE to the constructor; verify in networks.VAE)
configs = [
    dict(
        label=f'SGD with momentum $\\eta={SGD_MOMENTUM}$',
        optimizer=torch.optim.SGD,
        optimizer_params=dict(
            lr=1e-3,
            weight_decay=1e-5,
            momentum=SGD_MOMENTUM,
            nesterov=True,
        ),
    ),
    dict(
        label='Adam',
        optimizer=torch.optim.Adam,
        optimizer_params=dict(lr=1e-3),
    ),
    dict(
        label='KFAC',
        optimizer=KFACOptimizer,
        optimizer_params=dict(
            lr=1e-2,
            momentum=0,
            stat_decay=0.99,
            damping=1e-3,
            kl_clip=1e-2,
            weight_decay=1e-5,
            TCov=10,
            TInv=100,
        ),
    ),
]

# Load dataset
PATH_DATASETS = 'data/'

# Request a download only when the MNIST folder is not on disk yet. With
# download=False torchvision raises if the dataset files are missing, so the
# flag must be True exactly when the folder is absent.
download = not isdir(PATH_DATASETS + 'MNIST')

# Training split, served in batches of 256 images converted to tensors
train_ds = MNIST(PATH_DATASETS, train=True, download=download, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=256, num_workers=DATALOADER_WORKERS)

# Held-out split used for validation, with the same batching and transform
val_ds = MNIST(PATH_DATASETS, train=False, download=download, transform=transforms.ToTensor())
val_loader = DataLoader(val_ds, batch_size=256, num_workers=DATALOADER_WORKERS)

In [None]:
# Train one VAE per optimizer configuration; each run logs its metrics to CSV
for config in configs:

    # The label doubles as the run name handed to the logger
    config_name = config['label']

    # Fresh model built from the current configuration
    vae = VAE(config)

    # Metrics end up under logs/MNIST_VAE/<config_name>/V0/
    logger = CSVLogger(name=config_name, version="V0", save_dir='logs/MNIST_VAE/')

    trainer = Trainer(
        logger=logger,
        max_epochs=MAX_EPOCHS,
        gpus=NUM_GPUS,
        enable_model_summary=False,
    )

    # Fit on the training loader, validating on the validation loader
    trainer.fit(vae, train_loader, val_loader)

## Plots

In [None]:
# Collect the logged losses of every run as flat lists of records, one dict per
# logged row with the loss value, the step and the optimizer label.
training_logs = []
validation_logs = []
for config in configs:
    config_name = config['label']

    # Load the metrics file written for this run
    train_df = pd.read_csv(f'logs/MNIST_VAE/{config_name}/V0/metrics.csv')

    # Rows that carry a training loss. .loc + .assign builds a new frame rather
    # than writing a column into a filtered slice, which avoids pandas'
    # SettingWithCopyWarning.
    train_loss = (
        train_df
        .loc[train_df['train_loss'].notna(), ['train_loss', 'step']]
        .assign(Optimizer=config_name)
    )
    training_logs += train_loss.to_dict(orient='records')

    # Rows that carry a validation loss, handled the same way
    valid_loss = (
        train_df
        .loc[train_df['val_loss'].notna(), ['val_loss', 'step']]
        .assign(Optimizer=config_name)
    )
    validation_logs += valid_loss.to_dict(orient='records')

# Long-format frames (loss, step, Optimizer) consumed by the plotting cells
training_df = pd.DataFrame(training_logs)
validation_df = pd.DataFrame(validation_logs)

In [None]:
# Training loss per optimizer.
# Close the currently active figure so the plot starts from a clean one.
plt.close()
ax = sns.lineplot(data=training_df, x='step', y='train_loss', hue='Optimizer')
# Label the figure so it can be read on its own
ax.set(title='Training loss per optimizer', xlabel='Step', ylabel='Training loss')
plt.show()

In [None]:
# Validation loss per optimizer.
# Close the currently active figure so the plot starts from a clean one.
plt.close()
# Markers keep a run visible when it has only one validation point, which a
# bare line would not draw (likely the case with a single epoch; confirm
# against how often VAE logs val_loss).
ax = sns.lineplot(data=validation_df, x='step', y='val_loss', hue='Optimizer', marker='o')
# Label the figure so it can be read on its own
ax.set(title='Validation loss per optimizer', xlabel='Step', ylabel='Validation loss')
plt.show()