In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import utils

import models

import torchbearer as tb

from tensorboard_logging import ReconstructionsLogger, TensorBoardModelLogger, LatentSpaceReconLogger, RandomReconLogger
from torchbearer import Trial
from torchbearer.callbacks.tensor_board import TensorBoard
from utils import AEDatasetWrapper

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Training hyper-parameters used by the cells below.
# NOTE(review): 'depth', 'margin', 'alpha', 'beta' and 'gamma' are forwarded to
# models.IntroVAE; their exact meaning is defined there.
params = {
    'batch_size': 128,
    'nEpoch': 50,
    'imgSize': 32,
    'zsize': 64,
    'depth': 0,
    'margin': 2,
    'alpha': 0,
    'beta': 3,
    'gamma': 1,
    'lr': 1e-3,
    'seed': 0,  # seed for torch's global RNG, so a Restart & Run All reproduces the run
}

# Seed torch's default generator (all devices) before any model is built or any data
# is shuffled, e.g. weight initialisation and DataLoader(shuffle=True) ordering.
torch.manual_seed(params['seed'])

# Pre-processing applied to every CIFAR-10 image (the datasets are built in the next cell)
transform = transforms.Compose(
    [transforms.ToTensor()]  # image -> tensor only; no normalisation or augmentation
)

In [3]:
def _cifar10_split(train, shuffle):
    """Build one CIFAR-10 split (downloaded into '.') and a DataLoader over it.

    The torchvision dataset is wrapped in ``AEDatasetWrapper`` (presumably so each item
    is an input/target pair for autoencoder training; see utils). Batch size comes from
    ``params['batch_size']``. Returns ``(dataset, loader)``.
    """
    split = AEDatasetWrapper(
        torchvision.datasets.CIFAR10('.', train=train, transform=transform, download=True))
    batches = torch.utils.data.DataLoader(
        split, batch_size=params['batch_size'], shuffle=shuffle, num_workers=1)
    return split, batches


trainset, trainloader = _cifar10_split(train=True, shuffle=True)   # shuffled training batches
testset, testloader = _cifar10_split(train=False, shuffle=False)   # fixed order for validation

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Construct the IntroVAE model, a (currently unused) optimiser and the TensorBoard loggers.
iSize = params['imgSize']
img_shape = (3, iSize, iSize)  # (channels, height, width) of one CIFAR-10 image

# All model hyper-parameters come from `params`. IntroVAE also receives `lr` and
# `amsgrad`; presumably it builds its own optimisers from them (see models.IntroVAE).
vae = models.IntroVAE(iSize, params['zsize'], params['depth'],
                      alpha=params['alpha'], beta=params['beta'],
                      gamma=params['gamma'], margin=params['margin'],
                      lr=params['lr'], amsgrad=True)

# NOTE(review): this optimiser is never used, because the Trial is constructed with
# optimizer=None. Its learning rate is read from `params`, not a repeated 1e-3 literal,
# so it cannot silently drift from the configured value. Consider deleting it if
# nothing else relies on it.
optimizer = optim.Adam(vae.parameters(), lr=params['lr'])

# One shared comment string for every TensorBoard logger below.
tb_comment = 'simple_incauna'
tbl = TensorBoard(write_graph=True, comment=tb_comment)
tbml = TensorBoardModelLogger(comment=tb_comment)
rsl = ReconstructionsLogger(comment=tb_comment, output_shape=img_shape)
lsrl = LatentSpaceReconLogger(comment=tb_comment, output_shape=img_shape, latent_dim=params['zsize'])
rrl = RandomReconLogger(comment=tb_comment, latent_dim=params['zsize'], output_shape=img_shape)

In [5]:
# Objects from models: the *_step ones are handed to the Trial as callbacks and the
# *_loss ones as metrics. Each tuple is evaluated left to right.
encoder_step, encoder_loss = models.encoder_step(), models.encoder_loss()
decoder_step, decoder_loss = models.decoder_step(), models.decoder_loss()
forward_step, recons_loss = models.forward_step(), models.recons_loss()

In [6]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [7]:
# Metrics and callbacks are named up front so the Trial construction stays readable.
trial_metrics = [recons_loss, encoder_loss, decoder_loss]
trial_callbacks = [
    encoder_step, decoder_step, forward_step,  # step callbacks from models
    tbl, tbml, rsl, lsrl, rrl,                 # TensorBoard loggers
]
# No optimizer/criterion is given, so torchbearer substitutes a MockOptimizer (see the
# printed summary); presumably the step callbacks perform the parameter updates.
trial = Trial(vae, optimizer=None, criterion=None,
              metrics=trial_metrics, callbacks=trial_callbacks).to(device)
trial.with_generators(trainloader, val_generator=testloader)

--------------------- OPTIMZER ---------------------
MockOptimizer ()

-------------------- CRITERION ---------------------
<function Trial.__init__.<locals>.criterion at 0x7f688c9d4bf8>

--------------------- METRICS ----------------------
['recons_loss', 'encoder_loss', 'decoder_loss']

-------------------- CALLBACKS ---------------------
['models.encoder_step', 'models.decoder_step', 'models.forward_step', 'torchbearer.callbacks.tensor_board.TensorBoard', 'tensorboard_logging.TensorBoardModelLogger', 'tensorboard_logging.ReconstructionsLogger', 'tensorboard_logging.LatentSpaceReconLogger', 'tensorboard_logging.RandomReconLogger']

---------------------- MODEL -----------------------
IntroVAE(
  (enc): ImEncoder(
    (encoder): Sequential(
      (0): Block(
        (upchannels): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
        (seq): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU()
          (2): Conv2d(16, 16,

In [8]:
# Train for the configured number of epochs (params['nEpoch']) instead of a repeated literal.
trial.run(epochs=params['nEpoch'])

0/50(t): 100%|██████████| 391/391 [02:50<00:00,  2.63it/s, decoder_loss=575.9, encoder_loss=576.5, recons_loss=191.6, running_decoder_loss=576.8, running_encoder_loss=577.7, running_recons_loss=191.9]
0/50(v): 100%|██████████| 79/79 [00:02<00:00, 28.08it/s, val_decoder_loss=590.7, val_encoder_loss=595.2, val_recons_loss=181.8]
1/50(t): 100%|██████████| 391/391 [02:48<00:00,  2.60it/s, decoder_loss=549.8, encoder_loss=561.1, recons_loss=183.2, running_decoder_loss=488.6, running_encoder_loss=526.9, running_recons_loss=167]
1/50(v): 100%|██████████| 79/79 [00:02<00:00, 30.04it/s, val_decoder_loss=591.2, val_encoder_loss=614.4, val_recons_loss=192.1]
2/50(t): 100%|██████████| 391/391 [02:48<00:00,  2.60it/s, decoder_loss=425, encoder_loss=442.6, recons_loss=141.2, running_decoder_loss=368.5, running_encoder_loss=389.6, running_recons_loss=123.7]
2/50(v): 100%|██████████| 79/79 [00:02<00:00, 29.97it/s, val_decoder_loss=504.8, val_encoder_loss=545.4, val_recons_loss=173.6]
3/50(t): 100%|███

KeyboardInterrupt: 