<a href="https://colab.research.google.com/github/arnegebert/vae-iwae-vralpha/blob/master/code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import datetime
import logging
import os

import numpy as np
import torch
import torch.utils.data
from torch import nn, optim, Tensor as T
from torch.autograd import detect_anomaly
from torch.distributions.multinomial import Multinomial
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [0]:
# Hyperparameters
alpha = -500 # alpha value used in Renyi alpha-divergence, ignored when model_type is vae/iwae/vrmax
K = 5 # number of samples taken per input data point
L = 2 # number of stochastic layers in network architecture; either 1 or 2

algorithm = 'vralpha' # one of ['vae', 'iwae', 'vrmax', 'vralpha', 'general_alpha']
# For details for each of these, especially VR-max, VR-alpha and general alpha, please see the
# Renyi Divergence Variational Inference paper
# But shortly:
# vae
#   = Use L_VI objective (Optimize (1/K)*SUM(log[ p(x,z_k)/q(z_k|x)])).
#   Choice of alpha is ignored.
# iwae
#   = Use IWAE objective L_k (Optimize log[(1/K)*SUM(p(x,z)/q(z|x))]).
#   Choice of alpha ignored.
# vrmax
#   = Use VR-max algorithm.
#   => (Optimize log[ MAX_k(p(x,z_k)/q(z_k|x))] ).
#   Choice of alpha is ignored.
# vralpha
#   = Use VR-alpha
#   => (Optimize log[ p(x,z_k)/q(z_k|x)] where the kth sample is whosen w.p. ~ magnitude^(1-alpha)).
#   Choice of alpha is important.
# general_alpha
#   = Use direct estimate of L_{\alpha,K}
#   => (Optimize [1/(1-alpha)]*log((1/K)*SUM( log[ (p(x,z_k)/q(z_k|x))^(1-alpha) ] ))).
#   Choice of alpha is important.
#   (general_alpha is the same as VR-alpha except that we backpropagate K samples instead of only one.
#   That is, we don't estimate the objective by taking only one of the K samples, but instead utilize all
#   K samples. )

data_name = 'fashion' # one of ['mnist', 'fashion', 'fashionmnist']

epochs = 100
learning_rate = 1e-3

log_interval = 1 # how frequently to log average training loss
test_interval = 5 # how frequently to test
train_batch_size = 256 # batch size during training
test_batch_size = 32 # batch size used during testing, different than training because testing is done with K=5000

seed = 0 # fixed seed
torch.manual_seed(seed)

os.makedirs('results', exist_ok=True)
os.makedirs('models', exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

assert(L in [1, 2]) # we only have networks with 1 or 2 stochastic layers
assert(algorithm in ['vae', 'iwae', 'vrmax', 'vralpha', 'general_alpha'])
assert(not(alpha==1 and algorithm in ['vralpha', 'general_alpha'])) # divide by 0 error otherwise
assert(data_name in ['mnist', 'fashion', 'fashionmnist'])

In [0]:
def load_data_and_initialize_loaders(data_name, train_batch, test_batch):
    data_name = data_name.lower()
    kwargs = {'num_workers': 1, 'pin_memory': True}
    if data_name == 'mnist':
        train_data = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
        test_data = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())
    elif data_name == 'fashion' or data_name == 'fashionmnist':
        train_data = datasets.FashionMNIST('./data', train=True, download=True, transform=transforms.ToTensor())
        test_data = datasets.FashionMNIST('./data', train=False, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_data, batch_size = train_batch, shuffle = True, ** kwargs)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size = test_batch, shuffle = True, ** kwargs)
    return train_loader, test_loader

In [0]:
# Implementation of the VAE used on the MNIST dataset with 1 stochastic layer.
class mnist_model_1(nn.Module):
    def __init__(self, alpha):
        super(mnist_model_1, self).__init__()

        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc31 = nn.Linear(200, 50)
        self.fc32 = nn.Linear(200, 50)

        self.fc4 = nn.Linear(50, 200)
        self.fc5 = nn.Linear(200, 200)
        self.fc6 = nn.Linear(200, 784)

        self.K = K
        self.alpha = alpha

    def encode(self, x):
        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        return self.fc31(h2), self.fc32(h2)

    def reparameterize(self, mu, logstd):
        std = torch.exp(logstd)
        eps = torch.randn_like(std)
        # This is the reparametrization trick - represent the sample as a sum rather than black-box generated number
        return mu + eps*std

    def decode(self, z):
        h3 = torch.tanh(self.fc4(z))
        h4 = torch.tanh(self.fc5(h3))
        return torch.sigmoid(self.fc6(h4))

    def forward(self, x):
        mu, logstd = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logstd)
        return self.decode(z), mu, logstd

    def compute_loss_for_batch(self, data, model, K=K, test=False):
        # data = (B, 1, H, W)
        B, _, H, W = data.shape

        # Generate K copies of each observation. Each will get sampled once according to the generated distribution to generate a total of K observation samples
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H*W)

        # Retrieve the estimated mean and log(standard deviation) estimates from the posterior approximator
        mu, logstd = model.encode(data_k_vec)

        # Use the reparametrization trick to generate (mean)+(epsilon)*(standard deviation) for each sample of each observation
        z = model.reparameterize(mu, logstd)

        # Calculate log q(z|x) - how likely are the importance samples given the distribution that generated them?
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)

        # Calculate log p(z) - how likely are the importance samples under the prior N(0,1) assumption?
        log_p_z = compute_log_probabitility_gaussian(z, torch.zeros_like(z, requires_grad=False), torch.zeros_like(z, requires_grad=False))

        # Hand the samples to the decoder network and get a reconstruction of each sample.
        decoded = model.decode(z)

        # Calculate log p(x|z) with a bernoulli distribution - how likely are the recreations given the latents that generated them?
        log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)

        # Begin calculating L_alpha depending on the (a) model type, and (b) optimization method
        # log_p_z + log_p - log_q = log(p(z_i)p(x|z_i)/q(z_i|x)) = log(p(x,z_i)/q(z_i|x)) = L_VI
        #   (for each importance sample i out of K for each observation)
        if algorithm == 'iwae' or test:
            # Re-order the entries so that each row holds the K importance samples for each observation
            log_w_matrix = (log_p_z + log_p - log_q).view(B, K)

        elif algorithm =='vae':
            # Don't reorder, and divide by K in anticipation of taking a batch sum of (1/K)*SUM(log(p(x,z)/q(z|x)))
            log_w_matrix = (log_p_z + log_p - log_q).view(B*K, 1)*1/K

        elif algorithm=='general_alpha' or algorithm=='vralpha':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Multiply by (1-alpha) because (1-alpha)* log(p(x,z_i)/q(z_i|x)) =  log([p(x,z_i)/q(z_i|x)]^(1-alpha))
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K) * (1-alpha)

        elif algorithm == 'vrmax':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Take the max in each row, representing the maximum-weighted sample
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K).max(axis=1, keepdim=True).values

            # immediately return loss = -sum(L_alpha) over each observation
            return -torch.sum(log_w_matrix)

        # Begin using the "max trick". Subtract the maximum log(*) sample value for each observation.
        # log_w_minus_max = log([p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]))
        log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]

        # Exponentiate so that each term is [p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]) (no log)
        ws_matrix = torch.exp(log_w_minus_max)

        # Calculate normalized weights in each row. Max denominators cancel out!
        # ws_norm = [p(z_i,x)/q(z_i|x)]/SUM([p(z_k,x)/q(z_k|x)])
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if algorithm == 'vralpha' and not test:
            # If we're specifically using a VR-alpha model, we want to choose a sample to backprop according to the values in ws_norm above
            # So we make a distribution in each row
            sample_dist = Multinomial(1, ws_norm)

            # Then we choose a sample in each row acccording to this distribution
            ws_sum_per_datapoint = log_w_matrix.gather(1, sample_dist.sample().argmax(1, keepdim=True))
        else:
            # For any other model, we're taking the full sum at this point
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if algorithm in ["general_alpha", "vralpha"] and not test:
            # For both VR-alpha and directly estimating L_alpha with a sum, we have to renormalize the sum with 1-alpha
            ws_sum_per_datapoint /= (1 - alpha)

        # Return a value of loss = -L_alpha as the batch sum.
        loss = -torch.sum(ws_sum_per_datapoint)

        return loss


In [0]:
# Implementation of the VAE used on the MNIST dataset with 2 stochastic layers.
class mnist_model_2(nn.Module):
    def __init__(self, alpha):
        super(mnist_model_2, self).__init__()

        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc31 = nn.Linear(200, 100)  # stochastic 1
        self.fc32 = nn.Linear(200, 100)

        self.fc4 = nn.Linear(100, 100)
        self.fc5 = nn.Linear(100, 100)
        self.fc61 = nn.Linear(100, 50)  # Innermost (stochastic 2)
        self.fc62 = nn.Linear(100, 50)

        self.fc7 = nn.Linear(50, 100)
        self.fc8 = nn.Linear(100, 100)
        self.fc81 = nn.Linear(100, 100)  # stochastic 1
        self.fc82 = nn.Linear(100, 100)

        self.fc9 = nn.Linear(100, 200)
        self.fc10 = nn.Linear(200, 200)
        self.fc11 = nn.Linear(200, 784)  # reconstruction

        self.K = K
        self.alpha = alpha

    def encode(self, x):
        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        mu, log_std = self.fc31(h2), self.fc32(h2)

        z1 = self.reparameterize(mu, log_std)
        h3 = torch.tanh(self.fc4(z1))
        h4 = torch.tanh(self.fc5(h3))

        return self.fc61(h4), self.fc62(h4), [x, z1]

    def reparameterize(self, mu, logstd, test=False):
        std = torch.exp(logstd)
        if test == True:
            eps = torch.zeros_like(mu)
        else:
            eps = torch.randn_like(std)
        # This is the reparametrization trick - represent the sample as a sum rather than black-box generated number
        return mu + eps * std

    def decode(self, z, test=False):
        h5 = torch.tanh(self.fc7(z))
        h6 = torch.tanh(self.fc8(h5))
        mu, log_std = self.fc81(h6), self.fc82(h6)

        z1 = self.reparameterize(mu, log_std, test=test)
        h7 = torch.tanh(self.fc9(z1))
        h8 = torch.tanh(self.fc10(h7))

        return torch.sigmoid(self.fc11(h8))

    def forward(self, x):
        mu, logstd, _ = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logstd)
        return self.decode(z), mu, logstd

    def compute_loss_for_batch(self, data, model, K=K, test=False):
        B, _, H, W = data.shape

        # First repeat the observations K times, representing the data as a flat (M*K, # of pixels)
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        # Encode the model and retrieve estimated distribution parameters mu and log(standard deviation) for each sample of each observation
        # z1 holds the latent samples generated at the first stochastic layer.
        mu, log_std, [x, z1] = self.encode(data_k_vec)

        # Sample from each observation's approximated latent distribution in each row (i.e. once for each of K importance samples, represented by rows)
        # (this uses the reparametrization trick!)
        z = model.reparameterize(mu, log_std)

        # Calculate Log p(z) (prior) - how likely are these values given the prior assumption N(0,1)?
        log_p_z = torch.sum(-0.5 * z ** 2, 1) - .5 * z.shape[1] * T.log(torch.tensor(2 * np.pi))

        # Calculate q (z | h1) - how likely are the generated output latent samples given the distributions they came from?
        log_qz_h1 = compute_log_probabitility_gaussian(z, mu, log_std)

        # Re-Generate the mu and log_std that generated the first-layer latents z1
        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        mu, log_std = self.fc31(h2), self.fc32(h2)

        # Calculate log q(h1|x) - how likely are the first-stochastic-layer latents given the distributions they come from?
        log_qh1_x = compute_log_probabitility_gaussian(z1, mu, log_std)

        # Calculate the distribution parameters that generated the first-layer latents upon decoding
        h5 = torch.tanh(self.fc7(z))
        h6 = torch.tanh(self.fc8(h5))
        mu, log_std = self.fc81(h6), self.fc82(h6)

        # Calculate log p(h1|z) - how likely are the latents z1 under the parameters of the distribution here?
        #   (This directly encourages the decoder to learn the inverse of the map h1->z)
        log_ph1_z = compute_log_probabitility_gaussian(z1, mu, log_std)

        # Finally calculate the reconstructed image
        h7 = torch.tanh(self.fc9(z1))
        h8 = torch.tanh(self.fc10(h7))
        decoded = torch.sigmoid(self.fc11(h8))

        # calculate log p(x | h1) - how likely is the reconstruction given the latent samples that generated it?
        log_px_h1 = compute_log_probabitility_bernoulli(decoded, x)

        # Begin calculating L_alpha depending on the (a) model type, and (b) optimization method
        # log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x =
        #           log([p(z0_i)p(x|z1_i)p(z1_i|z0_i)]/[q(z0_i|z1_i)q(z1_i|x)]) = log(p(x,z0_i,z1_i)/q(z0_i,z1_i|x)) = L_VI
        #   (for each importance sample i out of K for each observation)
        # Note that if test==True then we're always using the IWAE objective!
        if algorithm == 'iwae' or test == True:
            # Re-order the entries so that each row holds the K importance samples for each observation
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, K)

        elif algorithm == 'vae':
            # Don't reorder, and divide by K in anticipation of taking a batch sum of (1/K)*SUM(log(p(x,z)/q(z|x)))
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, 1) * 1 / K
            return -torch.sum(log_w_matrix)

        elif algorithm == 'general_alpha' or algorithm == 'vralpha':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Multiply by (1-alpha) because (1-alpha)* log(p(x,z_i)/q(z_i|x)) =  log([p(x,z_i)/q(z_i|x)]^(1-alpha))
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, K) * (1 - self.alpha)

        elif algorithm == 'vrmax':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Take the max in each row, representing the maximum-weighted sample, then immediately return batch sum loss -L_alpha
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, K).max(axis=1,keepdim=True).values
            return -torch.sum(log_w_matrix)

        # Begin using the "max trick". Subtract the maximum log(*) sample value for each observation.
        # log_w_minus_max = log([p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]))
        log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]

        # Exponentiate so that each term is [p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]) (no log)
        ws_matrix = torch.exp(log_w_minus_max)

        # Calculate normalized weights in each row. Max denominators cancel out!
        # ws_norm = [p(z_i,x)/q(z_i|x)]/SUM([p(z_k,x)/q(z_k|x)])
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if algorithm == 'vralpha' and not test:
            # If we're specifically using a VR-alpha model, we want to choose a sample to backprop according to the values in ws_norm above
            # So we make a distribution in each row
            sample_dist = Multinomial(1, ws_norm)

            # Then we choose a sample in each row acccording to this distribution
            ws_sum_per_datapoint = log_w_matrix.gather(1, sample_dist.sample().argmax(1, keepdim=True))
        else:
            # For any other model, we're taking the full sum at this point
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if algorithm in ["general_alpha", "vralpha"] and not test:
            # For both VR-alpha and directly estimating L_alpha with a sum, we have to renormalize the sum with 1-alpha
            ws_sum_per_datapoint /= (1 - alpha)

        loss = -torch.sum(ws_sum_per_datapoint)

        return loss

In [0]:
# Compute N(obs| mu, sigma) for all K samples and sum over probabilities of the K samples
def compute_log_probabitility_gaussian(obs, mu, logstd, axis=1):
    return torch.sum(-0.5 * ((obs-mu) / torch.exp(logstd)) ** 2 - logstd, axis)-.5*obs.shape[1]*T.log(torch.tensor(2*np.pi))

# Compute Ber(obs| theta) for all K samples and sum over probabilities of the K samples
def compute_log_probabitility_bernoulli(theta, obs, axis=1):
    # 1e-18 needed to avoid numerical errors
    return torch.sum(obs*torch.log(theta+1e-18) + (1-obs)*torch.log(1-theta+1e-18), axis)

In [0]:
# train and test functions
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, labels) in enumerate(train_loader):
        # (B, 1, F1, F2) (e.g. (128, 1, 28, 28) for MNIST with B=128)
        data = data.to(device)
        optimizer.zero_grad()

        loss = model.compute_loss_for_batch(data, model)
        # comment this back in in case of NaNs
        #with detect_anomaly():
        #    loss.backward()
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    if epoch % log_interval == 0:
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')
        logging.info(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss = model.compute_loss_for_batch(data, model, K=5000,test=True)
            test_loss += loss.item()
            if i == 0:
                # Visualizing reconstructions
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(test_batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         f'results/reconstruction_{algorithm}_L={L}_{data_name}_alpha={alpha}_K={K}_epoch={epoch}.png',
                           nrow=n)
                # Visualizing random samples from the latent space
                noise = torch.randn(64, 50).to(device)
                sample = model.decode(noise).cpu()
                save_image(sample.view(64, 1, 28, 28),
                           f'results/sample_{algorithm}_L={L}_{data_name}_alpha={alpha}_K={K}_epoch={epoch}.png')
    test_loss /= len(test_loader.dataset)
    print(f'====> Epoch: {epoch} Test set loss: {test_loss:.4f}')
    logging.info(f'====> Epoch: {epoch} Test set loss: {test_loss:.4f}')
    return test_loss

In [9]:
if L==1:
    model = mnist_model_1(alpha).to(device)
else:
    model = mnist_model_2(alpha).to(device)
train_loader, test_loader = load_data_and_initialize_loaders(data_name, train_batch_size, test_batch_size)


optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!
Training on GPU




In [20]:
if torch.cuda.is_available():
    print("Training on GPU")
    logging.info("Training on GPU")

print(f'{datetime.datetime.now()} \nStarting training')
logging.info(f'{datetime.datetime.now()} \nStarting training')
for e in range(1, epochs+1):
    train(e)
    if e % test_interval == 0:
        test(e)
test(epochs)
print(datetime.datetime.now())
logging.info(datetime.datetime.now())
print("Training finished")
logging.info("Training finished")

print("Saving model")
torch.save(model.state_dict(),
           f'models/{algorithm}_L={L}_{data_name}_alpha={alpha}_K={K}_epochs={epochs}.pt')
print("Saved model")

Training on GPU
2020-06-12 18:12:56.008404 
Starting training
====> Epoch: 1 Average loss: 237.5309
====> Epoch: 1 Test set loss: 234.6304
2020-06-12 18:13:33.955103
Training finished
Saving model
