In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import utils

from tensorboard_logging import ReconstructionsLogger, TensorBoardModelLogger, LatentSpaceReconLogger, RandomReconLogger
from torchbearer import Trial
from torchbearer.callbacks.tensor_board import TensorBoard
from utils import AEDatasetWrapper

from models import ConvVAE

In [3]:
# Parameters:
params = {'batch_size': 256,
              'embedding_dim': 32,
              'image_dim': 784,
              'nEpoch': 100,
              'conv_ch': 32}

    # Dataset construction
transform = transforms.Compose([
        transforms.ToTensor(),  # convert to tensor
    ])

In [4]:
trainset = AEDatasetWrapper(torchvision.datasets.CIFAR10('.', train=True, transform=transform, download=True))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=params['batch_size'], shuffle=True, num_workers=1)
testset = AEDatasetWrapper(torchvision.datasets.CIFAR10('.', train=False, transform=transform, download=True))
testloader = torch.utils.data.DataLoader(testset, batch_size=params['batch_size'], shuffle=False, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# construct the encoder, decoder and optimiser
vae = ConvVAE(params['conv_ch'], params['embedding_dim'])
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

tb_comment = 'cifar-conv-vae'
tbl = TensorBoard(write_graph=True, comment=tb_comment)
tbml = TensorBoardModelLogger(comment=tb_comment)
rsl = ReconstructionsLogger(comment=tb_comment, output_shape=(3, 32, 32))
lsrl = LatentSpaceReconLogger(comment=tb_comment, output_shape=(3, 32, 32), latent_dim=params['embedding_dim'])
rrl = RandomReconLogger(comment=tb_comment, latent_dim=params['embedding_dim'], output_shape=(3, 32, 32))

In [6]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [7]:
trial = Trial(vae, optimizer, ConvVAE.loss_mse, metrics=['loss'], callbacks=[tbl, rsl, tbml, lsrl, rrl]).to(device)
trial.with_generators(trainloader, val_generator=testloader)

--------------------- OPTIMZER ---------------------
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

-------------------- CRITERION ---------------------
<function ConvVAE.loss_mse at 0x000002CBC48697B8>

--------------------- METRICS ----------------------
['loss']

-------------------- CALLBACKS ---------------------
['torchbearer.callbacks.tensor_board.TensorBoard', 'tensorboard_logging.ReconstructionsLogger', 'tensorboard_logging.TensorBoardModelLogger', 'tensorboard_logging.LatentSpaceReconLogger', 'tensorboard_logging.RandomReconLogger']

---------------------- MODEL -----------------------
ConvVAE(
  (enc): Encoder(
    (conv1): Conv2d(3, 3, kernel_size=(2, 2), stride=(1, 1))
    (conv2): Conv2d(3, 32, kernel_size=(2, 2), stride=(2, 2))
    (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (inter): Lin

In [2]:
print("torch.cuda.is_available()   =", torch.cuda.is_available())
print("torch.cuda.device_count()   =", torch.cuda.device_count())
print("torch.cuda.device('cuda')   =", torch.cuda.device('cuda'))
print("torch.cuda.current_device() =", torch.cuda.current_device())

torch.cuda.is_available()   = True
torch.cuda.device_count()   = 1
torch.cuda.device('cuda')   = <torch.cuda.device object at 0x000002CBC3699128>
torch.cuda.current_device() = 0


In [8]:
trial.run(epochs=params['nEpoch'])

0/100(t): 100%|██████████████████████| 196/196 [00:25<00:00, 10.07it/s, loss=142.6, loss_std=27.23, running_loss=117.3]
0/100(v): 100%|████████████████████████████████████| 40/40 [00:01<00:00, 20.12it/s, val_loss=105.1, val_loss_std=4.023]
1/100(t): 100%|██████████████████████| 196/196 [00:19<00:00,  9.93it/s, loss=102.2, loss_std=5.312, running_loss=97.68]
1/100(v): 100%|████████████████████████████████████| 40/40 [00:01<00:00, 23.26it/s, val_loss=91.42, val_loss_std=3.772]
2/100(t): 100%|██████████████████████| 196/196 [00:19<00:00,  9.90it/s, loss=93.96, loss_std=2.765, running_loss=92.83]
2/100(v): 100%|████████████████████████████████████| 40/40 [00:01<00:00, 22.80it/s, val_loss=86.51, val_loss_std=3.464]
3/100(t): 100%|██████████████████████| 196/196 [00:19<00:00,  9.93it/s, loss=90.29, loss_std=2.413, running_loss=89.15]
3/100(v): 100%|█████████████████████████████████████| 40/40 [00:01<00:00, 21.98it/s, val_loss=84.1, val_loss_std=3.111]
4/100(t): 100%|██████████████████████| 1

[((196, 40),
  {'running_loss': 117.30271957397461,
   'loss': 142.56299369189205,
   'loss_std': 27.233514316419765,
   'val_loss': 105.08052577972413,
   'val_loss_std': 4.022959170753057}),
 ((196, 40),
  {'running_loss': 97.68370697021484,
   'loss': 102.20032096395687,
   'loss_std': 5.311553330843156,
   'val_loss': 91.42498798370362,
   'val_loss_std': 3.7724186528207544}),
 ((196, 40),
  {'running_loss': 92.82899322509766,
   'loss': 93.95936374275051,
   'loss_std': 2.765106675331596,
   'val_loss': 86.50629959106445,
   'val_loss_std': 3.4644961538535104}),
 ((196, 40),
  {'running_loss': 89.15134689331055,
   'loss': 90.29133570924097,
   'loss_std': 2.4133886394601696,
   'val_loss': 84.09751434326172,
   'val_loss_std': 3.111270658913168}),
 ((196, 40),
  {'running_loss': 87.22463790893555,
   'loss': 88.02803506656569,
   'loss_std': 2.360915302651138,
   'val_loss': 80.81962165832519,
   'val_loss_std': 2.911472143819502}),
 ((196, 40),
  {'running_loss': 86.080377197265

In [5]:
utils.load(vae, "simple.w")