In [1]:
from torchvision.datasets import utils
import torch.utils.data as data_utils
import torch
import os
import numpy as np
from torch import nn
from torch.nn.modules import upsampling
from torch.functional import F
from torch.optim import Adam

In [2]:
def get_data_loader(dataset_location, batch_size):
    URL = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
    # start processing
    def lines_to_np_array(lines):
        return np.array([[int(i) for i in line.split()] for line in lines])
    splitdata = []
    for splitname in ["train", "valid", "test"]:
        filename = "binarized_mnist_%s.amat" % splitname
        filepath = os.path.join(dataset_location, filename)
        utils.download_url(URL + filename, dataset_location)
        with open(filepath) as f:
            lines = f.readlines()
        x = lines_to_np_array(lines).astype('float32')
        x = x.reshape(x.shape[0], 1, 28, 28)
        # pytorch data loader
        dataset = data_utils.TensorDataset(torch.from_numpy(x))
        dataset_loader = data_utils.DataLoader(x, batch_size=batch_size, shuffle=splitname == "train")
        splitdata.append(dataset_loader)
    return splitdata

In [3]:
train, valid, test = get_data_loader("binarized_mnist", 64)

Using downloaded and verified file: binarized_mnist/binarized_mnist_train.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_valid.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_test.amat


In [4]:
class Encoder(nn.Module):
    def __init__(self, latent_size):
        super(Encoder, self).__init__()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, 3),
            nn.ELU(),
            nn.AvgPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ELU(),
            nn.AvgPool2d(2),
            nn.Conv2d(64, 256, 5),
            nn.ELU()
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 2 * latent_size),
        )
    def forward(self, x):
        convolved = self.conv_stack(x)
        flattened = convolved.view(x.size(0), -1)
        z_mean, z_logvar = self.mlp(flattened).chunk(2, dim=-1)
        return z_mean, z_logvar

class Decoder(nn.Module):
    def __init__(self, latent_size):
        super(Decoder, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(latent_size, 256),
            nn.ELU()
        )
        
        self.upsample_stack = nn.Sequential(
            nn.Conv2d(256, 64, 5, padding=(4, 4)),
            nn.ELU(),
            upsampling.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(64, 32, 3, padding=(2, 2)),
            nn.ELU(),
            upsampling.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 16, 3, padding=(2, 2)),
            nn.ELU(),
            nn.Conv2d(16, 1, 3, padding=(2, 2)),
        )
        
    def forward(self, z):
        flattened = self.mlp(z)
        convolved = self.upsample_stack(flattened.view(z.size(0), 256, 1, 1))
        return convolved - 5.
    

def kl_divergence(mean, logvar, prior_mean, prior_logvar):
    output = 0.5 * torch.sum(
        prior_logvar - logvar +
        ((torch.exp(logvar) + (mean - prior_mean)**2) /
         torch.exp(prior_logvar)) - 1., dim=-1
    )
    return output

class VAE(nn.Module):
    def __init__(self, latent_size):
        super(VAE, self).__init__()
        self.encode = Encoder(latent_size)
        self.decode = Decoder(latent_size)
    def forward(self, x):
        z_mean, z_logvar = self.encode(x)
        z_sample = z_mean + torch.exp(z_logvar / 2.) * torch.randn_like(z_logvar)
        x_mean = self.decode(z_sample)
        return z_mean, z_logvar, x_mean
    def loss(self, x, z_mean, z_logvar, x_mean):
        ZERO = torch.from_numpy(np.array(0.))
        kl = kl_divergence(z_mean, z_logvar, ZERO, ZERO).mean()
        recon_loss = F.binary_cross_entropy_with_logits(
            x_mean.view(x.size(0), -1),
            x.view(x.size(0), -1),
            reduction='none'
        ).sum(1).mean()
        return recon_loss + kl

In [None]:
vae = VAE(100)
params = vae.parameters()
optimizer = Adam(params, lr=3e-4)
print(vae)

In [None]:

for i in range(20):
    for x in train:
        z_mean, z_logvar, x_mean = vae(x)
        loss = vae.loss(x, z_mean, z_logvar, x_mean)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    with torch.no_grad():
        total_loss = 0.
        total_count = 0
        for x in valid:
            total_loss += vae.loss(x, *vae(x)) * x.size(0)
            total_count += x.size(0)
        print(total_loss / total_count)

In [None]:
torch.save(vae, 'model.pt')

In [5]:
x =  next(iter(valid))
vae = torch.load('model.pt')

In [6]:
c = - 0.5 * np.log(2*np.pi)

def log_normal(x, mean=None, logvar=None):
    # Log prob of scalar gaussian
    if mean is None:
        mean = 0. #torch.zeros_like(x)
    if logvar is None:
        logvar = torch.from_numpy(np.array(0.))
    sqr_dist = (x - mean)**2
    var = torch.exp(logvar)
    return -sqr_dist / (2. * var) - logvar/2. + c

In [7]:
with torch.no_grad():
    K = 200

    # Sample
    z_mean, z_logvar = vae.encode(x)
    eps = torch.randn(z_mean.size(0), K, z_mean.size(1))
    # Broadcast the noise over the mean and variance
    z_samples = z_mean[:, None, :] + torch.exp(z_logvar / 2.)[:, None, :] * eps

    # Decode samples

    # Flatten out the z samples
    z_samples_flat = z_samples.view(-1, z_samples.size(-1)) 
    x_mean = vae.decode(z_samples_flat) # Push it through
    # Bring it back to the original shape
    x_mean = x_mean.view(x.size(0), K, x_mean.size(-3), x_mean.size(-2), x_mean.size(-1))

    # Probabilities

    # Repeat images so they're the same shape as the reconstruction
    x_flat = x[:, None].repeat(1, K, 1, 1, 1)

    # Calculate all the probabilities!
    log_p_x_z = -F.binary_cross_entropy_with_logits(x_mean, x_flat, reduction='none').sum(dim=(-1, -2, -3))
    log_q_z_x = log_normal(z_samples, z_mean[:, None, :], z_logvar[:, None, :]).sum(-1) # Broadcasting again.
    log_p_z = log_normal(z_samples).sum(-1)

    # Recombine them.
    w = log_p_x_z + log_p_z - log_q_z_x
    k, _ = torch.max(w, dim=1, keepdim=True)
    arith_mean = torch.log(torch.mean(torch.exp(w - k))) + k[:, 0]



In [8]:
arith_mean

tensor([-101.7955,  -70.1837, -124.5817,  -75.1420,  -94.6216,  -98.7999,
        -108.0888,  -75.3675,  -47.6263,  -70.6153,  -99.8721, -101.5540,
        -105.6319,  -94.1859,  -79.0731,  -86.6333,  -52.8111,  -76.0696,
        -110.7590,  -93.9893,  -78.7240,  -87.4326,  -98.1920,  -94.1972,
        -113.9483,  -53.7177,  -96.9957,  -48.2595, -120.0450,  -88.6475,
         -46.4143, -114.3459,  -71.4471, -115.1691,  -52.1049,  -92.5448,
        -116.2071,  -97.9242,  -57.5351, -112.6738, -101.1358, -113.8887,
        -109.4086, -109.7504,  -84.0596, -100.4736, -110.9115, -123.2812,
        -119.9335,  -87.7900,  -77.1799,  -46.4592, -115.0935,  -64.8117,
         -65.3133,  -84.7353,  -41.6219, -116.8076,  -82.0027,  -76.8602,
         -99.6423,  -70.2371, -117.4422, -108.7642])