In [12]:
import os

import numpy as np
import torch
from pyro.contrib.examples.util import MNIST
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pickle

In [1]:
# This notebook adapts the Pyro tutorial on Variational Autoencoders (https://pyro.ai/examples/vae.html), replacing its MNIST data with Sort-of-CLEVR images.

In [13]:
# Display the installed Pyro version so the saved notebook output records it.
pyro.__version__

'1.5.2'

In [14]:
# Display the installed PyTorch version so the saved notebook output records it.
torch.__version__

'1.6.0'

In [15]:
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open("data/sort-of-clevr.pickle", "rb") as fp:
    train_data, test_data = pickle.load(fp)


def _images_to_nchw(samples):
    """Stack element 0 (the image) of every sample into a float32 channels-first tensor.

    The stored images are channels-last: their last axis holds the 3 colour channels
    (presumably the usual (H, W, C) NumPy/OpenCV layout -- confirm against the generator).
    ``permute(0, 3, 1, 2)`` moves channels to dim 1 while keeping the two spatial axes in
    order. A plain ``transpose(3, 1)`` would also swap the spatial axes, silently
    transposing every image.

    Args:
        samples: sequence of tuples whose first element is an image array.

    Returns:
        A ``(N, C, d1, d2)`` float32 tensor.
    """
    # np.stack builds one contiguous array; torch.tensor(list_of_arrays) is very slow.
    stacked = np.stack([sample[0] for sample in samples])
    return torch.as_tensor(stacked, dtype=torch.float32).permute(0, 3, 1, 2)


train_images = _images_to_nchw(train_data)
test_images = _images_to_nchw(test_data)

# Sanity checks: the encoder/decoder are built for 3x75x75 inputs, and the model's
# Bernoulli likelihood needs pixel values in [0, 1].
for split_name, split_images in (("train", train_images), ("test", test_images)):
    assert split_images.shape[1:] == (3, 75, 75), f"{split_name} images have shape {tuple(split_images.shape)}"
    assert 0.0 <= split_images.min().item() and split_images.max().item() <= 1.0, \
        f"{split_name} pixel values fall outside [0, 1]"

In [25]:
# Shape of a single training image tensor (leading batch axis removed).
train_images[0].size()

torch.Size([3, 75, 75])

In [17]:
def setup_data_loaders(batch_size=128, use_cuda=False, train_dataset=None, test_dataset=None):
    """Build a shuffled training DataLoader and an ordered test DataLoader.

    Args:
        batch_size: number of examples per mini-batch.
        use_cuda: if True, pin host memory for faster host-to-GPU copies.
        train_dataset: dataset to train on. When omitted, falls back to the
            notebook-level ``train_images`` tensor (the original behaviour).
        test_dataset: dataset to evaluate on. When omitted, falls back to the
            notebook-level ``test_images`` tensor.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # Resolve the notebook globals only when no dataset is passed explicitly.
    # This lets callers and tests avoid hidden global state.
    if train_dataset is None:
        train_dataset = train_images
    if test_dataset is None:
        test_dataset = test_images

    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
        batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader

In [18]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a diagonal Gaussian q(z | x).

    Two stages of 3x3 conv (padding 1) -> softplus -> 2x2 max-pool halve the spatial
    size twice (with flooring). A fully-connected softplus layer follows, and two
    linear heads emit the latent mean and (via ``exp``) a strictly positive scale.

    Args:
        input_channels: number of channels in the input image.
        hidden_channels: pair ``(c1, c2)`` of output channels for the two conv layers.
        hidden_dim: width of the fully-connected hidden layer.
        z_dim: dimensionality of the latent code.
        image_size: side length of the input image. Square images are assumed;
            the default of 75 matches the original hard-coded layer size.
    """

    def __init__(self, input_channels, hidden_channels, hidden_dim, z_dim, image_size=75):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, hidden_channels[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels[0], hidden_channels[1], kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Each pool floors the side length by 2 (75 -> 37 -> 18 by default).
        # The flattened feature size is therefore c2 * 18 * 18 (= 5184 for c2 = 16).
        pooled_side = (image_size // 2) // 2
        self.fc1 = nn.Linear(hidden_channels[1] * pooled_side ** 2, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

        self.softplus = nn.Softplus()

    def forward(self, x):
        """Map a batch of images ``(B, input_channels, S, S)`` to ``(mu, sigma)``, each ``(B, z_dim)``."""
        x = self.pool(self.softplus(self.conv1(x)))
        x = self.pool(self.softplus(self.conv2(x)))
        x = x.flatten(1)
        x = self.softplus(self.fc1(x))

        mu = self.fc21(x)
        # exp keeps the scale strictly positive, as required by dist.Normal.
        sigma = torch.exp(self.fc22(x))

        return mu, sigma

In [19]:
class Decoder(nn.Module):
    """Decoder mapping a latent code to per-pixel means in (0, 1).

    A linear softplus layer produces a ``(c2, 18, 18)`` feature map. Two stride-2,
    kernel-3 transposed convolutions then upsample it: 18 -> 37 -> 75, since each
    maps s to 2 * (s - 1) + 3. A sigmoid squashes the output into (0, 1).

    Args:
        z_dim: dimensionality of the latent code.
        hidden_channels: pair ``(c1, c2)`` of channel counts, mirroring the encoder.
        output_channels: number of channels in the reconstructed image.
    """

    def __init__(self, z_dim, hidden_channels, output_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        # Side length of the initial feature map; 75 // 4 = 18 upsamples back to exactly 75.
        self.base_size = 75 // 4

        self.fc1 = nn.Linear(z_dim, hidden_channels[1] * self.base_size ** 2)
        self.convt1 = nn.ConvTranspose2d(hidden_channels[1], hidden_channels[0], kernel_size=3, stride=2)
        self.convt2 = nn.ConvTranspose2d(hidden_channels[0], output_channels, kernel_size=3, stride=2)

        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Map latents ``(B, z_dim)`` to images ``(B, output_channels, 75, 75)`` with values in (0, 1)."""
        x = self.softplus(self.fc1(z))
        x = x.view(x.shape[0], self.hidden_channels[1], self.base_size, self.base_size)
        x = self.softplus(self.convt1(x))
        x = self.sigmoid(self.convt2(x))
        return x

In [20]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images, trained with Pyro SVI.

    ``model`` is the generative model p(z) p(x | z). It uses a standard-normal prior
    over a ``z_dim``-dimensional latent and a per-pixel Bernoulli likelihood whose
    means come from the decoder. ``guide`` is the amortised variational posterior
    q(z | x) produced by the encoder. Both register their network with Pyro via
    ``pyro.module`` so SVI optimises its parameters.

    Args:
        z_dim: latent dimensionality.
        hidden_dim: width of the encoder's fully-connected hidden layer.
        hc: channel counts ``(c1, c2)`` used by both encoder and decoder.
        use_cuda: if True, move all parameters to the GPU.
    """

    def __init__(self, z_dim=64, hidden_dim=2048, hc=(8,16), use_cuda=False):
        super().__init__()
        self.encoder = Encoder(input_channels=3, hidden_channels=hc, hidden_dim=hidden_dim, z_dim=z_dim)
        self.decoder = Decoder(z_dim=z_dim, hidden_channels=hc, output_channels=3)

        if use_cuda:
            self.cuda()

        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Generative model: sample z ~ N(0, I), then score ``x`` under Bernoulli(decoder(z))."""
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            z_loc = x.new_zeros(torch.Size((batch_size, self.z_dim)))
            z_scale = x.new_ones(torch.Size((batch_size, self.z_dim)))

            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            loc_img = self.decoder(z)
            # Pixel values are continuous intensities (assumed in [0, 1]), not {0, 1}. A
            # Bernoulli likelihood is therefore an approximation, and its support check
            # would reject the observations whenever distribution validation is enabled.
            # flatten(1) keeps the batch axis and merges C*H*W, replacing the former
            # hard-coded reshape(-1, 16875).
            pyro.sample(
                "obs",
                dist.Bernoulli(loc_img.flatten(1), validate_args=False).to_event(1),
                obs=x.flatten(1),
            )

    def guide(self, x):
        """Variational posterior: z ~ N(mu(x), sigma(x)) with the encoder's outputs."""
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    def reconstruct_img(self, x):
        """Encode ``x``, draw one latent sample, and return the decoder's pixel means."""
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        loc_img = self.decoder(z)
        return loc_img

In [21]:
def train(svi, train_loader, use_cuda=False):
    """Run one epoch of SVI updates and return the loss averaged per training example.

    ``svi.step`` is called once per mini-batch (after moving it to the GPU when
    ``use_cuda`` is set). The summed batch losses are divided by the number of
    examples in ``train_loader.dataset``.
    """
    running_total = 0.
    for batch in train_loader:
        batch = batch.cuda() if use_cuda else batch
        running_total += svi.step(batch)
    return running_total / len(train_loader.dataset)

In [22]:
def evaluate(svi, test_loader, use_cuda=False):
    """Return the test-set loss averaged per test example.

    ``svi.evaluate_loss`` is applied to every mini-batch (moved to the GPU when
    ``use_cuda`` is set). The accumulated total is divided by
    ``len(test_loader.dataset)``.
    """
    batch_losses = (
        svi.evaluate_loss(batch.cuda() if use_cuda else batch)
        for batch in test_loader
    )
    return sum(batch_losses, 0.) / len(test_loader.dataset)

In [23]:
# Optimiser settings: Adam step size, and whether to run on the GPU.
LEARNING_RATE, USE_CUDA = 1e-4, False

# Schedule: total epochs, and how often (in epochs) the test set is evaluated.
NUM_EPOCHS, TEST_FREQUENCY = 100, 10

In [24]:
# Seed the Python, NumPy and torch RNGs. This covers weight init, latent sampling
# and DataLoader shuffling, so a re-run reproduces the same losses.
SEED = 0
pyro.set_rng_seed(SEED)

train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)

# clear param store so parameters left over from an earlier run in this kernel are not reused
pyro.clear_param_store()

# setup the VAE
vae = VAE(use_cuda=USE_CUDA)

# setup the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

train_elbo = []  # negated per-example training loss, one entry per epoch
test_elbo = []   # negated per-example test loss, one entry every TEST_FREQUENCY epochs
# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
        # NOTE(review): this pickles the whole nn.Module, which ties the checkpoint to the
        # VAE class defined in this notebook. torch.save(vae.state_dict(), ...) is the more
        # portable format.
        with open(f"trained_vae_epoch_{epoch}.p", "wb") as fp:
            pickle.dump(vae, fp)

[epoch 000]  average training loss: 12841.5535
[epoch 000] average test loss: 12503.8160
[epoch 001]  average training loss: 12166.9941
[epoch 002]  average training loss: 11329.0304
[epoch 003]  average training loss: 10372.3705
[epoch 004]  average training loss: 9412.5936
[epoch 005]  average training loss: 8476.5856
[epoch 006]  average training loss: 7609.6482


KeyboardInterrupt: 

In [None]:
# NOTE(review): this rebinds `vae` to a freshly initialised, untrained model and discards
# the one trained above. Load a saved checkpoint instead if trained reconstructions are wanted.
vae = VAE(use_cuda=False)

In [None]:
# Reconstruct the first test image; slicing keeps a leading batch dimension of size 1.
vae.reconstruct_img(test_images[:1])