In [None]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from networks import VAE
from torchvision.datasets import MNIST
from torchvision import transforms

from os.path import isdir

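## Initial experiment: VAE on MNIST

Train the `VAE` from `networks` on MNIST with Adam (learning rate $10^{-3}$) for two epochs, then inspect the run in TensorBoard.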

In [None]:
# Load dataset
PATH_DATASETS = 'data/'
BATCH_SIZE = 256

# torchvision's MNIST does not download again when the dataset files are
# already present under PATH_DATASETS, so always allowing the download is safe.
# Gating it on `isdir(...)` would need a negation: passing False on a fresh
# checkout makes MNIST raise "Dataset not found".
download = True
to_tensor = transforms.ToTensor()

train_ds = MNIST(PATH_DATASETS, train=True, download=download, transform=to_tensor)
# Reshuffle the training set every epoch; validation order does not matter.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_ds = MNIST(PATH_DATASETS, train=False, download=download, transform=to_tensor)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

# Setup network
SEED = 42
# Seed torch's global RNG before building the model so torch-driven randomness
# (e.g. the shuffled training DataLoader order, and VAE parameter init / sampling
# assuming networks.VAE uses torch's default RNG — TODO confirm) is reproducible.
torch.manual_seed(SEED)

# presumably VAE builds `optimizer(params, **optimizer_params)` from these keys — verify in networks.py
config_adam = {
    'optimizer': torch.optim.Adam,
    'optimizer_params': {'lr': 1e-3}
}
vae_adam = VAE(config_adam)

# Train network
# TensorBoard event files go under lightning_logs/VAE-Adam/version_<n>;
# the placeholder hp_metric entry is disabled.
logger = TensorBoardLogger(save_dir='lightning_logs', name='VAE-Adam', default_hp_metric=False)

trainer_kwargs = dict(
    gpus=0,         # CPU only. NOTE(review): `gpus` was removed in Lightning 2.0 — use accelerator='cpu' if upgrading.
    max_epochs=2,   # short initial run
    logger=logger,
)
trainer = Trainer(**trainer_kwargs)

# Train the model ⚡
trainer.fit(vae_adam, train_loader, val_loader)

In [None]:
# Show TensorBoard inline, reading the same directory the TensorBoardLogger writes to (save_dir='lightning_logs').
%tensorboard --logdir lightning_logs