In [23]:
import os

import numpy as np
import torch
from pyro.contrib.examples.util import MNIST
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.nn import PyroSample, PyroModule
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [24]:
# encoder

# self.fc1 = nn.Linear(784, 400)
# self.fc1a = nn.Linear(400, 100)
# self.fc21 = nn.Linear(100, 2)  # Latent space of 2D
# self.fc22 = nn.Linear(100, 2)  # Latent space of 2D
# ...
# h1 = F.relu(self.fc1(x))
# h2 = F.relu(self.fc1a(h1))
# return self.fc21(h2), self.fc22(h2)


In [25]:
distbsdsdsasfs = dist.Normal(0., 1.).expand([784,400]).to_event(2)

In [26]:
distbsdsdsasfs.sample().shape


torch.Size([784, 400])

In [27]:
class Encoder(PyroModule):
    def __init__(self):
        super().__init__()
        # setup the three linear transformations used
        self.fc1 = PyroModule[nn.Linear](400, 784)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([400, 784]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([400]).to_event(1))
        self.fc1a = PyroModule[nn.Linear](100, 400)
        self.fc1a.weight = PyroSample(dist.Normal(0., 1.).expand([100, 400]).to_event(2))
        self.fc1a.bias = PyroSample(dist.Normal(0., 1.).expand([100]).to_event(1))
        self.fc21 = PyroModule[nn.Linear](2, 100)
        self.fc21.weight = PyroSample(dist.Normal(0., 1.).expand([2, 100]).to_event(2))
        self.fc21.bias = PyroSample(dist.Normal(0., 1.).expand([2]).to_event(1))
        self.fc22 = PyroModule[nn.Linear](2, 100)
        self.fc22.weight = PyroSample(dist.Normal(0., 1.).expand([2, 100]).to_event(2))
        self.fc22.bias = PyroSample(dist.Normal(0., 1.).expand([2]).to_event(1))
        # setup the non-linearities
        self.relu = nn.ReLU()

    def forward(self, x):
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        x = x.reshape(-1, 784)
        # then compute the hidden units
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc1a(h1))
        return self.fc21(h2), torch.exp(self.fc22(h2))

In [28]:
# decoder

# self.fc3 = nn.Linear(2, 100) # Latent space of 2D
# self.fc3a = nn.Linear(100, 400)
# self.fc4 = nn.Linear(400, 784)
# ...
# h3 = F.relu(self.fc3(z))
# h4 = F.relu(self.fc3a(h3))
# return torch.sigmoid(self.fc4(h4))


In [29]:
class Decoder(PyroModule):
    def __init__(self):
        super().__init__()
        # setup the two linear transformations used
        self.fc3 = PyroModule[nn.Linear](2, 100)
        self.fc3.weight = PyroSample(dist.Normal(0., 1.).expand([2, 100]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0., 1.).expand([100]).to_event(1))
        self.fc3a = PyroModule[nn.Linear](100, 400)
        self.fc3a.weight = PyroSample(dist.Normal(0., 1.).expand([100, 400]).to_event(2))
        self.fc3a.bias = PyroSample(dist.Normal(0., 1.).expand([400]).to_event(1))
        self.fc4 = PyroModule[nn.Linear](400, 784)
        self.fc4.weight = PyroSample(dist.Normal(0., 1.).expand([400, 784]).to_event(2))
        self.fc4.bias = PyroSample(dist.Normal(0., 1.).expand([784]).to_event(1))
        # setup the non-linearities
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        # define the forward computation on the latent z
        # first compute the hidden units
        h1 = self.relu(self.fc3(z))
        h2 = self.relu(self.fc3a(h1))
        loc_img = self.sigmoid(self.fc4(h2))
        # return the parameter for the output Bernoulli
        # each is of size batch_size x 784
        return loc_img

In [35]:
z_test = pyro.sample("latent", dist.Normal(0, 1))
z_test

tensor(0.7289)

In [30]:
class BAE(PyroModule):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder()
        self.decoder = Decoder()

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)

        #with pyro.plate("name", size) as ind:
            # ...do conditionally independent stuff with ind...

            
        with pyro.plate("data", x.shape[0]) as ind:
            # setup hyperparameters for prior p(z)
            #z_loc = x.new_zeros(torch.Size((x.shape[0], 2)))
            #z_scale = x.new_ones(torch.Size((x.shape[0], 2)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            #z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            z = pyro.sample("latent", dist.Normal(0, 1))
            # decode the latent code z
            print(z.shape)
            loc_img = self.decoder(z)
            # score against actual images
            pyro.sample("obs", dist.Bernoulli(loc_img, validate_args=False).to_event(1), obs=x.reshape(-1, 784))

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        #pyro.module("encoder", self.encoder)
        z_loc, z_scale = self.encoder(x)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            #z_loc, z_scale = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale, validate_args=False).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img1

In [31]:
# for loading and batching MNIST dataset
def setup_data_loaders(batch_size=128, use_cuda=False):
    root = './data'
    download = True
    trans = transforms.ToTensor()
    train_set = MNIST(root=root, train=True, transform=trans,
                      download=download)
    test_set = MNIST(root=root, train=False, transform=trans)

    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader

In [32]:
def train(svi, train_loader, use_cuda=False):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)

    # return epoch loss
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train

def evaluate(svi, test_loader, use_cuda=False):
    # initialize loss accumulator
    test_loss = 0.
    # compute the loss over the entire test set
    for x, _ in test_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)
    normalizer_test = len(test_loader.dataset)
    total_epoch_loss_test = test_loss / normalizer_test
    return total_epoch_loss_test

In [33]:

# Run options
LEARNING_RATE = 1.0e-3
USE_CUDA = False
smoke_test = True

# Run only for a single iteration for testing
NUM_EPOCHS = 1 if smoke_test else 100
TEST_FREQUENCY = 5
train_loader, test_loader = setup_data_loaders(batch_size=128, use_cuda=USE_CUDA)

# clear param store
pyro.clear_param_store()

# setup the VAE
bae = BAE(use_cuda=USE_CUDA)

# setup the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(bae.model, bae.guide, optimizer, loss=Trace_ELBO())

train_elbo = []
test_elbo = []
# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

torch.Size([128, 2])


RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
          Trace Shapes:                  
           Param Sites:                  
          Sample Sites:                  
              data dist         |        
                  value     128 |        
            latent dist     128 |        
                  value 128   2 |        
decoder.fc3.weight dist     128 |   2 100
                  value     128 |   2 100
  decoder.fc3.bias dist     128 | 100    
                  value     128 | 100    