In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import utils
import models
import tensorboard_logging

from tqdm import tqdm

from utils import AEDatasetWrapper
from models import IntroVAE, ConvVAE

from tensorboard_logging import TensorBoardExtension, ReconstructionsLogger, TensorBoardModelLogger, LatentSpaceReconLogger, RandomReconLogger

In [10]:
# Parameters (dataset is CIFAR-10: 3-channel 32x32 images)
IMAGE_SHAPE = (3, 32, 32)
params = {'batch_size': 256,
          'embedding_dim': 32,
          # flattened image size: 3 * 32 * 32 = 3072 for CIFAR-10
          'image_dim': IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2],
          'image_shape': IMAGE_SHAPE,
          'nEpoch': 10,
          'conv_ch': 32,
          'margin': 0,
          'seed': 0}

# Seed torch's RNG (CPU and CUDA) so model init, DataLoader shuffling and any
# torch-drawn sampling noise are more repeatable across Restart & Run All.
torch.manual_seed(params['seed'])

# Dataset construction: convert images to tensors; no normalisation is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [3]:
def _cifar10_split(is_train):
    """Return one CIFAR-10 split (stored under '.') wrapped in AEDatasetWrapper."""
    raw_split = torchvision.datasets.CIFAR10('.', train=is_train, transform=transform, download=True)
    return AEDatasetWrapper(raw_split)


# Both loaders share batch size and worker count; only the training split is shuffled.
loader_kwargs = dict(batch_size=params['batch_size'], num_workers=1)

trainset = _cifar10_split(True)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testset = _cifar10_split(False)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [13]:
# Construct the IntroVAE model, then pass it through utils.cuda
# (presumably moves it to the GPU when available — confirm in utils).
# NOTE(review): no optimiser is created here; presumably IntroVAE owns its
# optimisers, since the training loop only calls vae.train_iter — confirm.
intro_vae = IntroVAE(params['conv_ch'], params['embedding_dim'], margin=params['margin'])
vae = utils.cuda(intro_vae)

In [14]:
# TensorBoard loggers for this run. All of them are tagged with the same run
# comment; the image loggers render outputs at the CIFAR-10 image shape.
tb_comment = 'intro-vae-2'
recon_shape = (3, 32, 32)

tbl = TensorBoardExtension(comment=tb_comment, write_graph=True)
tbml = TensorBoardModelLogger(comment=tb_comment)
rsl = ReconstructionsLogger(output_shape=recon_shape, comment=tb_comment)
lsrl = LatentSpaceReconLogger(latent_dim=params['embedding_dim'], output_shape=recon_shape, comment=tb_comment)
rrl = RandomReconLogger(output_shape=recon_shape, latent_dim=params['embedding_dim'], comment=tb_comment)

In [6]:
# Write the model graph to TensorBoard. The training loader is passed along
# (presumably to draw an example batch for tracing — confirm in tensorboard_logging).
# The writer is closed in `finally` so it is not left open if graph export raises.
try:
    tbl.direct_call(vae, trainloader)
finally:
    tbl.close_writer()

In [None]:
# Log the model to TensorBoard via TensorBoardModelLogger.
# The writer is closed in `finally` so it is not left open if logging raises.
try:
    tbml.direct_call(vae)
finally:
    tbml.close_writer()

In [15]:
# Training loop: one vae.train_iter step per batch, a per-image reconstruction
# MSE readout in the progress bar, and TensorBoard visualisations each epoch.
#
# testloader is not shuffled, so its first batch is the same every epoch;
# fetch it once rather than creating a new loader iterator per epoch.
valid_input, _ = next(iter(testloader))
valid_input = utils.cuda(valid_input)

try:
    for epoch in range(params['nEpoch']):
        losses = []
        loss_total = 0.0
        # Use a local progress bar. Rebinding `trainloader` itself would nest a
        # new tqdm wrapper every epoch and leak a tqdm object into later cells.
        progress = tqdm(trainloader)

        for i, (inputs, _) in enumerate(progress):
            inputs = utils.cuda(inputs)

            # Monitoring-only forward pass, taken before the update step.
            # no_grad avoids building an autograd graph that is never backpropagated.
            with torch.no_grad():
                y_pred = vae(inputs)
                # summed squared error of the first model output, averaged per image
                loss = F.mse_loss(y_pred[0], inputs, reduction='sum') / inputs.shape[0]

            # The optimisation step itself is performed inside the model.
            vae.train_iter(inputs)

            # keep track of the loss and update the stats (running mean, O(1) per batch)
            losses.append(loss.item())
            loss_total += losses[-1]
            progress.set_postfix(loss=loss_total / len(losses), epoch=epoch)

            if i == 0:
                # Log reconstructions of the fixed validation batch once per epoch.
                # NOTE(review): the model stays in its current (train) mode here;
                # consider vae.eval()/vae.train() if it uses dropout/batchnorm.
                with torch.no_grad():
                    y_pred_valid = vae(valid_input)
                rsl.direct_call(vae, epoch, y_pred_valid)

        # Generate visualisation of reconstructions (each epoch)
        lsrl.direct_call(vae, epoch)
        rrl.direct_call(vae, epoch)
finally:
    # Close writers even if training is interrupted (e.g. KeyboardInterrupt).
    rsl.close_writer()
    lsrl.close_writer()
    rrl.close_writer()

100%|█████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.16it/s, epoch=0, loss=157]
100%|█████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.21it/s, epoch=1, loss=115]
100%|█████████████████████████████████████████████████████████████| 196/196 [01:20<00:00,  3.18it/s, epoch=2, loss=105]
100%|████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.18it/s, epoch=3, loss=97.9]
100%|████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.17it/s, epoch=4, loss=93.7]
100%|████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.19it/s, epoch=5, loss=91.5]
100%|████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.17it/s, epoch=6, loss=87.4]
100%|████████████████████████████████████████████████████████████| 196/196 [01:21<00:00,  3.15it/s, epoch=7, loss=85.5]
100%|███████████████████████████████████