In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import utils
import models

from tensorboard_logging import ReconstructionsLogger, TensorBoardModelLogger, LatentSpaceReconLogger, RandomReconLogger
from torchbearer import Trial
from torchbearer.callbacks.tensor_board import TensorBoard
from utils import AEDatasetWrapper

from models import ConvVAE

In [2]:
# Parameters:
# CIFAR-10 images are 3x32x32 RGB, so the flattened image dimension is
# 3 * 32 * 32 = 3072 (784 = 28 * 28 is the MNIST size and does not apply here).
IMAGE_SHAPE = (3, 32, 32)
params = {'batch_size': 256,
          'embedding_dim': 32,
          'image_dim': IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2],
          'image_shape': IMAGE_SHAPE,
          'nEpoch': 10,
          'conv_ch': 32,
          'seed': 0}

# Seed torch's global RNG before the model and loaders are built so that
# Restart & Run All starts from the same random state each time
# (some CUDA kernels may still be nondeterministic).
torch.manual_seed(params['seed'])

# Dataset construction
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to tensor
])

In [3]:
def _cifar10_split(train, shuffle):
    """Build one CIFAR-10 split wrapped in AEDatasetWrapper plus its DataLoader.

    The data lives in (and, if missing, is downloaded to) the current directory.
    Returns a ``(dataset, loader)`` pair; batch size comes from ``params``.
    """
    split_dataset = AEDatasetWrapper(
        torchvision.datasets.CIFAR10('.', train=train, transform=transform, download=True))
    split_loader = torch.utils.data.DataLoader(
        split_dataset, batch_size=params['batch_size'], shuffle=shuffle, num_workers=1)
    return split_dataset, split_loader


# Training data is shuffled each epoch; the held-out test split keeps a fixed order.
trainset, trainloader = _cifar10_split(train=True, shuffle=True)
testset, testloader = _cifar10_split(train=False, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Shared settings for the model and the image loggers below.
latent_dim = params['embedding_dim']
cifar_shape = (3, 32, 32)  # channels, height, width of a CIFAR-10 image

# construct the encoder, decoder and optimiser
vae = ConvVAE(params['conv_ch'], latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# TensorBoard callbacks: all of them share one run comment, and the image
# loggers are told the CIFAR-10 output shape.
tb_comment = 'cifar-conv-vae-next'
tbl = TensorBoard(write_graph=True, comment=tb_comment)
tbml = TensorBoardModelLogger(comment=tb_comment)
rsl = ReconstructionsLogger(comment=tb_comment, output_shape=cifar_shape)
lsrl = LatentSpaceReconLogger(comment=tb_comment, output_shape=cifar_shape, latent_dim=latent_dim)
rrl = RandomReconLogger(comment=tb_comment, latent_dim=latent_dim, output_shape=cifar_shape)

In [5]:
# Use the first CUDA GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [6]:
# Bundle model, optimiser, loss, metrics and callbacks into a torchbearer Trial,
# move it to the selected device, then attach the train and validation loaders.
trial_metrics = ['loss', models.recons_loss()]
trial_callbacks = [tbl, rsl, tbml, lsrl, rrl]
trial = Trial(vae, optimizer, ConvVAE.loss_mse,
              metrics=trial_metrics,
              callbacks=trial_callbacks).to(device)
trial.with_generators(trainloader, val_generator=testloader)

--------------------- OPTIMZER ---------------------
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

-------------------- CRITERION ---------------------
<function ConvVAE.loss_mse at 0x7f7fc2694840>

--------------------- METRICS ----------------------
['loss', 'recons_loss']

-------------------- CALLBACKS ---------------------
['torchbearer.callbacks.tensor_board.TensorBoard', 'tensorboard_logging.ReconstructionsLogger', 'tensorboard_logging.TensorBoardModelLogger', 'tensorboard_logging.LatentSpaceReconLogger', 'tensorboard_logging.RandomReconLogger']

---------------------- MODEL -----------------------
ConvVAE(
  (enc): Encoder(
    (conv1): Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))
    (conv2): Conv2d(3, 32, kernel_size=(2, 2), stride=(2, 2))
    (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (

In [None]:
# Train for the configured number of epochs; testloader was registered as the
# validation generator when the trial was built.
n_epochs = params['nEpoch']
trial.run(epochs=n_epochs)