In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import utils

import introvae_model, models

import torchbearer as tb

from tensorboard_logging import ReconstructionsLogger, TensorBoardModelLogger, LatentSpaceReconLogger, RandomReconLogger
from torchbearer import Trial
from torchbearer.callbacks.tensor_board import TensorBoard
from utils import AEDatasetWrapper

import warnings
warnings.filterwarnings('ignore')

In [3]:
# Parameters:
params = {'batch_size': 128,
              'nEpoch': 50,
              'imgSize': 32,
              'zsize': 64,
              'depth': 0,
              'margin': 40,
              'alpha': 0.,
              'beta': 3,
              'gamma': 1,
              'lr': 1e-3,
              'pre': False,
              'com': 'simplule'
         }

    # Dataset construction
transform = transforms.Compose([
        transforms.ToTensor(),  # convert to tensor
    ])

device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [4]:
trainset = AEDatasetWrapper(torchvision.datasets.CIFAR10('.', train=True, transform=transform, download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=4)
testset = AEDatasetWrapper(torchvision.datasets.CIFAR10('.', train=False, transform=transform, download=False))
testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=4)

In [5]:
# construct the encoder, decoder and optimiser
iSize = params['imgSize']

vae = introvae_model.IntroVAE(iSize, params['zsize'], params['depth'], 
                      alpha=params['alpha'], beta=params['beta'], 
                      gamma=params['gamma'], margin=params['margin'],
                      lr=params['lr'])

tb_comment = params['com']
tbl = TensorBoard(write_graph=True, comment=tb_comment)
tbml = TensorBoardModelLogger(comment=tb_comment)
rsl = ReconstructionsLogger(comment=tb_comment, output_shape=(3, iSize, iSize))
lsrl = LatentSpaceReconLogger(comment=tb_comment, output_shape=(3, iSize, iSize), latent_dim=params['zsize'])
rrl = RandomReconLogger(comment=tb_comment, latent_dim=params['zsize'], output_shape=(3, iSize, iSize))

Encoder: [2, 3, 32, 32]=>[2, 16, 16, 16]=>[2, 32, 8, 8]=>[2, 64, 4, 4]=>[2, 64]  x 2
Decoder: [64] => [2, 64, 4, 4]=>[2, 64, 8, 8]=>[2, 64, 16, 16]=>[2, 3, 32, 32]
IntroVAE x: [2, 3, 32, 32] mu_: [2, 64] z: [2, 64] out: [2, 3, 32, 32]


In [5]:
encoder_step = models.encoder_step(pretrain=params['pre'])
encoder_loss = models.encoder_loss()

decoder_step = models.decoder_step()
decoder_loss = models.decoder_loss()

forward_step = models.forward_step()
recons_loss = models.recons_loss()

In [6]:
trial = Trial(vae, optimizer=None, criterion=None, 
              metrics=[recons_loss, encoder_loss, decoder_loss], 
              callbacks=[encoder_step, decoder_step, forward_step, tbl, tbml, rsl, lsrl, rrl]).to(device)
trial.with_generators(trainloader, val_generator=testloader)

--------------------- OPTIMZER ---------------------
MockOptimizer ()

-------------------- CRITERION ---------------------
<function Trial.__init__.<locals>.criterion at 0x7f9fedf560d0>

--------------------- METRICS ----------------------
['recons_loss', 'encoder_loss', 'decoder_loss']

-------------------- CALLBACKS ---------------------
['models.encoder_step', 'models.decoder_step', 'models.forward_step', 'torchbearer.callbacks.tensor_board.TensorBoard', 'tensorboard_logging.TensorBoardModelLogger', 'tensorboard_logging.ReconstructionsLogger', 'tensorboard_logging.LatentSpaceReconLogger', 'tensorboard_logging.RandomReconLogger']

---------------------- MODEL -----------------------
IntroVAE(
  (enc): Encoder(
    (net): Sequential(
      (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
     

In [7]:
trial.run(params['nEpoch'])

0/50(t): 100%|██████████| 391/391 [01:42<00:00,  4.25it/s, decoder_loss=296.9, encoder_loss=321.6, recons_loss=98.01, running_decoder_loss=240.6, running_encoder_loss=266, running_recons_loss=79.39]
0/50(v): 100%|██████████| 79/79 [00:01<00:00, 40.91it/s, val_decoder_loss=234.1, val_encoder_loss=259.1, val_recons_loss=74.43]
1/50(t): 100%|██████████| 391/391 [01:42<00:00,  4.21it/s, decoder_loss=217.8, encoder_loss=244.1, recons_loss=71.92, running_decoder_loss=200.1, running_encoder_loss=228.1, running_recons_loss=66.05]
1/50(v): 100%|██████████| 79/79 [00:01<00:00, 41.62it/s, val_decoder_loss=211.1, val_encoder_loss=240.2, val_recons_loss=60.99]
2/50(t): 100%|██████████| 391/391 [01:42<00:00,  4.26it/s, decoder_loss=191, encoder_loss=218.2, recons_loss=63.13, running_decoder_loss=185.2, running_encoder_loss=212.1, running_recons_loss=61.17]
2/50(v): 100%|██████████| 79/79 [00:01<00:00, 41.88it/s, val_decoder_loss=168.9, val_encoder_loss=194.8, val_recons_loss=57.69]
3/50(t): 100%|███

[((391, 79),
  {'running_recons_loss': 79.39180374145508,
   'running_encoder_loss': 266.0420022583008,
   'running_decoder_loss': 240.60554107666016,
   'recons_loss': 98.0088844396879,
   'encoder_loss': 321.5534504063599,
   'decoder_loss': 296.865479569301,
   'val_recons_loss': 74.42650188977206,
   'val_encoder_loss': 259.1110534667969,
   'val_decoder_loss': 234.1494140625}),
 ((391, 79),
  {'running_recons_loss': 66.04617942810059,
   'running_encoder_loss': 228.07298065185546,
   'running_decoder_loss': 200.07075561523436,
   'recons_loss': 71.91610066055337,
   'encoder_loss': 244.08907538606687,
   'decoder_loss': 217.80542952203385,
   'val_recons_loss': 60.993735880791384,
   'val_encoder_loss': 240.205322265625,
   'val_decoder_loss': 211.05740356445312}),
 ((391, 79),
  {'running_recons_loss': 61.17003349304199,
   'running_encoder_loss': 212.1060299682617,
   'running_decoder_loss': 185.16772430419923,
   'recons_loss': 63.128317781726416,
   'encoder_loss': 218.2003424