In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import utils

import models

import torchbearer as tb

from tensorboard_logging import ReconstructionsLogger, TensorBoardModelLogger, LatentSpaceReconLogger, RandomReconLogger
from torchbearer import Trial
from torchbearer.callbacks.tensor_board import TensorBoard
from utils import AEDatasetWrapper

In [3]:
# Parameters: every knob a reader might tune for this IntroVAE / CIFAR-10 run.
params = {'batch_size': 512,        # images per DataLoader batch
          'embedding_dim': 64,      # latent size given to IntroVAE and the latent-space loggers
          'image_dim': 3 * 32 * 32, # flattened CIFAR-10 image size (784 was the MNIST value)
          'nEpoch': 50,             # number of training epochs
          'conv_ch': 32,            # first positional argument of models.IntroVAE
          # The four values below are forwarded as keyword arguments to models.IntroVAE;
          # their exact role in the loss is defined there — TODO document once confirmed.
          'margin': 4,
          'alpha': 0.,
          'beta': 3,
          'gamma': 1,
          'seed': 0}                # seed for torch's global RNG (reproducible runs)

# Seed torch's global RNG so torch-driven randomness (e.g. default parameter
# initialisation, DataLoader shuffle order) repeats across Restart & Run All.
torch.manual_seed(params['seed'])

    # Dataset construction
transform = transforms.Compose([transforms.ToTensor()])  # image -> tensor only; no augmentation or normalisation

In [4]:
def _cifar10_split(train):
    """Build one CIFAR-10 split wrapped for autoencoder training, plus its DataLoader.

    The dataset is downloaded into the current directory if needed. Shuffling
    is enabled only for the training split.
    AEDatasetWrapper presumably turns (image, label) into (image, image) targets — TODO confirm.

    Returns:
        (dataset, loader) tuple.
    """
    split = AEDatasetWrapper(torchvision.datasets.CIFAR10('.', train=train, transform=transform, download=True))
    split_loader = torch.utils.data.DataLoader(split, batch_size=params['batch_size'], shuffle=train, num_workers=3)
    return split, split_loader


trainset, trainloader = _cifar10_split(train=True)
testset, testloader = _cifar10_split(train=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Build the IntroVAE model and the TensorBoard logging callbacks used during training.
cifar_shape = (3, 32, 32)                # (channels, height, width) of a CIFAR-10 image
latent_dim = params['embedding_dim']

vae = models.IntroVAE(
    params['conv_ch'],
    latent_dim,
    alpha=params['alpha'],
    beta=params['beta'],
    gamma=params['gamma'],
    margin=params['margin'],
    amsgrad=True,
)

# Every logger receives the same comment string (presumably used to name/group
# the TensorBoard run — confirm in tensorboard_logging).
tb_comment = 'simplulex'
tbl = TensorBoard(write_graph=True, comment=tb_comment)
tbml = TensorBoardModelLogger(comment=tb_comment)
rsl = ReconstructionsLogger(comment=tb_comment, output_shape=cifar_shape)
lsrl = LatentSpaceReconLogger(comment=tb_comment, output_shape=cifar_shape, latent_dim=latent_dim)
rrl = RandomReconLogger(comment=tb_comment, latent_dim=latent_dim, output_shape=cifar_shape)

In [6]:
# The *_step objects are registered as Trial callbacks; the *_loss objects as Trial metrics.
# Tuple assignment keeps the original construction order (left to right, top to bottom).
encoder_step, encoder_loss = models.encoder_step(), models.encoder_loss()
decoder_step, decoder_loss = models.decoder_step(), models.decoder_loss()
forward_step, recons_loss = models.forward_step(), models.recons_loss()

In [7]:
# Train on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [8]:
# No optimizer/criterion is given to Trial: the step callbacks drive the updates,
# and the loss objects are reported as metrics.
trial_metrics = [recons_loss, encoder_loss, decoder_loss]
trial_callbacks = [encoder_step, decoder_step, forward_step, tbl, tbml, rsl, lsrl, rrl]

trial = Trial(
    vae,
    optimizer=None,
    criterion=None,
    metrics=trial_metrics,
    callbacks=trial_callbacks,
    verbose=0,
)
trial = trial.to(device)
trial.with_generators(trainloader, val_generator=testloader)

--------------------- OPTIMZER ---------------------
MockOptimizer ()

-------------------- CRITERION ---------------------
<function Trial.__init__.<locals>.criterion at 0x7f90799a5d90>

--------------------- METRICS ----------------------
['recons_loss', 'encoder_loss', 'decoder_loss']

-------------------- CALLBACKS ---------------------
['models.encoder_step', 'models.decoder_step', 'models.forward_step', 'torchbearer.callbacks.tensor_board.TensorBoard', 'tensorboard_logging.TensorBoardModelLogger', 'tensorboard_logging.ReconstructionsLogger', 'tensorboard_logging.LatentSpaceReconLogger', 'tensorboard_logging.RandomReconLogger']

---------------------- MODEL -----------------------
IntroVAE(
  (enc): Encoder(
    (conv1): Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))
    (conv2): Conv2d(3, 32, kernel_size=(2, 2), stride=(2, 2))
    (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 

In [None]:
# Epoch count comes from the config cell (params['nEpoch']) instead of a hard-coded
# literal that silently overrode it; change the value there to train longer.
trial.run(epochs=params['nEpoch'])

In [None]:
# Persist the trained IntroVAE through the project's utils.save helper.
weights_path = 'intro.w'
utils.save(vae, weights_path)