In [6]:
%load_ext autoreload
%autoreload 2
import torch
import torch.optim as optim
import tqdm.notebook as tqdm
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms import ToPILImage
from IPython.display import Image
from scipy.stats import pearsonr

import copy
import os
import sys
sys.path.append('/Users/Matt/projects/gmmvae/')

import gmmvae

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# CIFAR-10H annotation files (a raw CSV and an .npy array of probabilities, per
# the filenames). The two loads are independent, so their order does not matter.
# NOTE(review): neither `hprobs` nor `rawhdata` is used in the cells visible here.
hprobs = np.load('../data/CIFAR10/cifar10h-probs.npy')
rawhdata = pd.read_csv('../data/CIFAR10/cifar10h-raw.csv')

In [26]:
# Convert images to tensors, then normalise each channel as (x - 0.5) / 0.5.
# NOTE(review): ToTensor yields values in [0, 1] for PIL/uint8 input, so this
# maps pixels into [-1, 1]; confirm the gmmvae CIFAR10 likelihood expects that range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# NOTE(review): train=False selects the CIFAR-10 *test* split, and the VAE below
# is trained on it. Presumably this is deliberate so images line up with the
# CIFAR-10H files loaded above; confirm.
dataset = datasets.CIFAR10(
    '../data',
    train=False,
    download=True,
    transform=transform,
)

# No random seed is set in the visible cells, so batch order differs between runs.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
)

# Display names for CIFAR-10 label indices 0-9 (short forms of the torchvision
# class names). Index 7 is 'horse'; it was previously misspelled as 'house'.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified


In [30]:
def train_vae(model, loader, optimiser, epoch, num_samples=1):
    """Run one training epoch of ``model`` over ``loader`` and report the loss.

    Each batch is optimised against the negative of ``model.elbo``. The labels
    yielded by the loader are ignored.

    Args:
        model: module with ``train()`` and ``elbo(x, num_samples=...)``.
        loader: iterable of ``(x, label)`` batches with a ``.dataset`` attribute
            (e.g. a ``torch.utils.data.DataLoader``).
        optimiser: optimiser over ``model``'s parameters.
        epoch: epoch number, used only in the printed summary.
        num_samples: number of samples passed to ``model.elbo``. Defaults to 1,
            the value previously hard-coded.

    Returns:
        The summed batch losses divided by ``len(loader.dataset)``, i.e. the same
        value that is printed.
    """
    model.train()
    total_loss = 0.0
    # Wrap the loader itself, not ``enumerate(loader)``. tqdm can then read
    # ``len(loader)`` and draw a sized progress bar instead of an indeterminate one.
    batch_iter = tqdm.tqdm(loader, desc='Batch')

    for x, _ in batch_iter:
        optimiser.zero_grad()
        loss = -model.elbo(x, num_samples=num_samples)
        loss.backward()
        optimiser.step()

        batch_loss = loss.item()  # convert to a Python float once per batch
        total_loss += batch_loss
        batch_iter.set_postfix(loss=batch_loss)

    # NOTE(review): dividing by the number of examples assumes ``model.elbo``
    # returns the ELBO *summed* over the batch. If it returns a batch mean,
    # divide by ``len(loader)`` instead. TODO confirm against gmmvae.models.
    average_loss = total_loss / len(loader.dataset)
    print('====> Epoch: {} Average loss: {:.3f}'.format(epoch, average_loss))
    return average_loss

## Construct the CIFAR-10 VAE

In [28]:
# NOTE(review): z_dim is presumably the latent dimensionality. It is passed to the
# likelihood, the variational distribution and the VAE, so all three agree on it.
z_dim = 16

# Model components from the gmmvae CIFAR10 templates. Their roles (likelihood
# and variational distribution) are inferred from the class names; confirm in gmmvae.
likelihood = gmmvae.templates.CIFAR10Likelihood(z_dim)
variational_dist = gmmvae.templates.CIFAR10VariationalDist(z_dim)


# Define the VAE from the likelihood and variational distribution above.
vae = gmmvae.models.VAE(likelihood, variational_dist, z_dim)

In [None]:
NUM_EPOCHS = 14  # epochs are numbered 1..NUM_EPOCHS in the printed log

# Adam with PyTorch's default hyper-parameters over all VAE parameters.
optimiser = optim.Adam(vae.parameters())
for epoch in range(1, NUM_EPOCHS + 1):
    train_vae(vae, loader, optimiser, epoch)

HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 1 Average loss: 33.230


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 2 Average loss: 32.992


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 3 Average loss: 32.909


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 4 Average loss: 32.855


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…


====> Epoch: 5 Average loss: 32.829


HBox(children=(HTML(value='Batch'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), ma…