In [1]:
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch.distributions as dists

# Dataset preparation
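We use the FashionMNIST dataset and binarize it stochastically: each pixel intensity is used as the probability of a Bernoulli draw, and a new binary image is sampled every time an item is accessed. The binarization matches the Bernoulli output distribution of the architecture presented in the paper.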

In [2]:
class Binarized_FashionMNIST(torchvision.datasets.FashionMNIST):
    """
    FashionMNIST whose images are stochastically binarized on every access.

    Each pixel value of the transformed image is used as the probability of a Bernoulli draw.
    This assumes `transform` yields a float tensor in [0, 1], e.g. `ToTensor()`.
    The class label is dropped, so an item is just the binary image tensor.
    """

    def __init__(self, root, train, transform=None, target_transform=None, download=False):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, idx):
        image, _label = super().__getitem__(idx)
        # A fresh binary sample is drawn each time, so repeated epochs see different binarizations.
        binary_image = dists.Bernoulli(image).sample()
        return binary_image.type(torch.float32)


# ToTensor is stateless, so one instance can serve both splits.
_to_tensor = torchvision.transforms.ToTensor()

# Training split: large shuffled batches.
data_loader_train = DataLoader(
    Binarized_FashionMNIST('./data', train=True, transform=_to_tensor, download=True),
    batch_size=128,
    shuffle=True,
)

# Test split: used for NLL / active-unit evaluation.
data_loader_test = DataLoader(
    Binarized_FashionMNIST('./data', train=False, transform=_to_tensor, download=True),
    batch_size=32,
    shuffle=True,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/Binarized_FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/Binarized_FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/Binarized_FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/Binarized_FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/Binarized_FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/Binarized_FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/Binarized_FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/Binarized_FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/Binarized_FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/Binarized_FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/Binarized_FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/Binarized_FashionMNIST/raw



# Single stochastic layer VAE and IWAE
Each instance holds a boolean flag indicating whether the model should use the VAE objective or the IWAE objective.<br>
When the model is defined as a VAE, the objective function is computed with k=1. When the model is an IWAE, the lower bound is computed as specified in Eq. 14 of the paper.
The lower bound is computed from the log importance weights $\log w$. These are built from the terms below, written per dimension; the code sums them over the dimensions of $h_1$ and $x$. With the reparameterization $h_1 = \mu_h + \epsilon\sigma_h$: <br>
$ \log q(h_1 | x) = \log\left(\frac{1}{\sqrt{2\pi}\sigma_h}e^{-\frac{(h_1 - \mu_h)^2}{2\sigma_h^2}}\right) = \log\left(\frac{1}{\sqrt{2\pi}\sigma_h}e^{-\frac{(\mu_h + \epsilon\sigma_h - \mu_h)^2}{2\sigma_h^2}}\right) = -\log(\sqrt{2\pi}) - \log(\sigma_h) - \frac{\epsilon^2}{2} $ <br>
$ \log p(h_1) = \log\left(\frac{1}{\sqrt{2\pi}}e^{-\frac{h_1^2}{2}}\right) = -\log(\sqrt{2\pi}) - \frac{h_1^2}{2} $<br>
$ \log p(x | h_1) = \log\left(p^{x}(1-p)^{1-x}\right) = x\log(p) + (1-x)\log(1-p) $, where $p$ is the decoder output <br>
The $-\log(\sqrt{2\pi})$ constants of $\log p(h_1)$ and $\log q(h_1 | x)$ cancel, because both are taken over the same latent dimensions. Therefore: <br>
$ \log w = \log p(x | h_1) + \log p(h_1) - \log q(h_1 | x) = x\log(p) + (1-x)\log(1-p) - \frac{h_1^2}{2} + \frac{\epsilon^2}{2} + \log(\sigma_h) $

In [5]:
class StochasticLayer(nn.Module):
    """
    Two tanh hidden layers followed by two linear heads giving the parameters of a diagonal Gaussian.

    `forward(x)` returns `(mu, sigma)`, where `sigma = exp(to_logvar(hidden))` is strictly positive.
    Although the second head is named `to_logvar`, its output is exponentiated directly, with no 0.5
    factor, and the result is returned as sigma. The head therefore effectively parameterizes log(sigma).
    """

    def __init__(self, input_dim, hid_dim, output_dim):
        super().__init__()
        hidden_layers = [
            nn.Linear(input_dim, hid_dim),
            nn.Tanh(),
            nn.Linear(hid_dim, hid_dim),
            nn.Tanh(),
        ]
        self.layer = nn.Sequential(*hidden_layers)
        self.to_mu = nn.Linear(hid_dim, output_dim)
        # Attribute name kept as-is so existing checkpoints/state_dicts still load.
        self.to_logvar = nn.Linear(hid_dim, output_dim)

    def forward(self, x):
        features = self.layer(x)
        log_sigma = self.to_logvar(features)
        return self.to_mu(features), log_sigma.exp()


class VAEorIWAE(nn.Module):
    """
    A class representing the model of the network.
    The VAE and IWAE are identical in architecture and differ only in the objective function, so we can use one class to represent both.

    Architecture:
      * Encoder: a StochasticLayer (input_dim -> 200 -> 200) producing (mu, sigma) of a diagonal
        Gaussian over h1 (size h1_dim).
      * Decoder: a tanh MLP (h1_dim -> 200 -> 200 -> input_dim) ending in a Sigmoid, whose outputs
        are used as Bernoulli means.

    The `vae` flag selects the objective computed by `loss`:
      * True  - single-sample ELBO (Eq. 7 of the paper); `loss` returns a scalar tensor.
      * False - k-sample importance weighted bound L_k (Eq. 14); `loss` returns (loss, nll).
    """

    def __init__(self, input_dim, h1_dim=50, VAE=False):
        super(VAEorIWAE, self).__init__()
        self.vae = VAE
        self.input_dim = input_dim
        self.h1_dim = h1_dim
        self.middle_layer = 200
        self.encoder_module = StochasticLayer(input_dim, self.middle_layer, h1_dim)
        # The Sigmoid must remain the LAST module: `_log_weights` drops it to obtain logits.
        self.decoder_module = nn.Sequential(nn.Linear(h1_dim, self.middle_layer),
                                            nn.Tanh(),
                                            nn.Linear(self.middle_layer, self.middle_layer),
                                            nn.Tanh(),
                                            nn.Linear(self.middle_layer, input_dim),
                                            nn.Sigmoid())

    def encoder(self, x):
        """
        Receiving an input and using the reparameterization trick outputs the latent space representation and the mus and sigmas.
        :param x: batch of samples.
        :return: latent space representation and mus and sigmas and epsilon used to make output
        """
        mu_h1, sigma_h1 = self.encoder_module(x)
        eps = torch.randn_like(mu_h1)
        output = mu_h1 + sigma_h1 * eps
        return output, mu_h1, sigma_h1, eps

    def decoder(self, output):
        """Map latent samples to Bernoulli means (Sigmoid outputs) over the input dimensions."""
        return self.decoder_module(output)

    def forward(self, x):
        """Encode `x`, decode the sampled latent, and return (h1, mu, sigma, eps, decoded means)."""
        output, mu_h1, sigma_h1, eps = self.encoder(x)
        return output, mu_h1, sigma_h1, eps, self.decoder(output)

    def _log_weights(self, inputs, k):
        """
        Compute log importance weights log w = log p(x|h1) + log p(h1) - log q(h1|x).

        :param inputs: batch of samples, assumed shape (batch, input_dim).
        :param k: number of latent samples drawn per input.
        :return: tensor of shape (k, batch) holding one log-weight per (sample, input).
        """
        # Replicate every input k times along a new leading sample axis: (k, batch, input_dim).
        inputs = inputs.expand(k, inputs.size(0), self.input_dim)
        h1, mu_h1, sigma_h1, eps = self.encoder(inputs)
        # log q(h1|x): diagonal Gaussian. Since h1 = mu + sigma * eps, the exponent reduces to -eps^2 / 2.
        log_qh1gx = torch.sum(-0.5 * eps ** 2 - torch.log(sigma_h1), -1)
        # log p(h1): standard normal prior.
        log_ph1 = torch.sum(-0.5 * h1 ** 2, -1)
        # log p(x|h1): Bernoulli log-likelihood. It is computed from the pre-Sigmoid logits, so a
        # saturated sigmoid (p == 0 or 1 in float32) cannot produce log(0) = -inf or NaN gradients.
        logits = self.decoder_module[:-1](h1)
        log_pxgh1 = -torch.sum(
            nn.functional.binary_cross_entropy_with_logits(logits, inputs, reduction="none"), -1)
        # The -0.5*log(2*pi) per-dimension constants of log p(h1) and log q(h1|x) cancel here.
        return log_ph1 + log_pxgh1 - log_qh1gx

    def loss(self, inputs, k=100):
        """
        Calculating the loss and the log-likelihood estimator of the model.
        :param inputs: input to calculate loss to, assumed shape (batch, input_dim).
        :param k: Relevant in the IWAE case, k for the tightness of the bound (must be >= 1).
        :return: if VAE returns the lower-bound loss as specified in eq.7 in the paper (scalar tensor);
                 if IWAE returns a tuple (loss whose gradient equals that of -L_k as in eq. 14,
                 batch-averaged negative L_k estimate of the log-likelihood).
        :raises ValueError: in IWAE mode when k < 1.
        """
        import math

        if self.vae:
            # VAE objective: a single latent sample per input (k = 1), i.e. the plain ELBO.
            log_w = self._log_weights(inputs, 1)
            return -torch.mean(log_w)

        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        log_w = self._log_weights(inputs, k)
        # Normalized importance weights w~_i = w_i / sum_j w_j (a softmax over the sample axis).
        # Detaching them makes the gradient of sum_i w~_i * log w_i equal the gradient of L_k.
        w_tilde = torch.softmax(log_w, dim=0).detach()
        loss = -torch.mean(torch.sum(w_tilde * log_w, 0))
        # L_k = log((1/k) * sum_i w_i), evaluated with log-sum-exp to avoid underflow.
        log_likelihood = torch.mean(torch.logsumexp(log_w, 0) - math.log(k))
        return loss, -log_likelihood

    def covert_iwae_vae(self, vea_bool):
        """
        A method to convert an IWAE to a VAE and vice versa
        :param vea_bool: True - indicates turning IWAE to VAE, False - VAE to IWAE
        :return:
        """
        self.vae = vea_bool

    def Au(self, inputs):
        """
        Returns the mu parameters given an input
        :param inputs: A batch of samples, assumed shape (batch, input_dim)
        :return: The mus of the latent variables, shape (1, batch, h1_dim)
        """
        inputs = inputs.expand(1, inputs.size(0), self.input_dim)
        _, mu_h1, _, _ = self.encoder(inputs)
        return mu_h1




In [4]:
def metrics_eval(iwae=None, vae=None, data_loader=None, k=5000):
    """
    Calculating the NLL loss and Active unit count on the test set for a given IWAE and/or VAE model and printing them out.

    Both models are evaluated with the IWAE bound L_k (k=5000 by default, as in the paper). Each model
    is switched to the IWAE objective for the evaluation, and its previous mode is restored afterwards.
    A latent unit counts as active when the variance of its posterior mean over the data exceeds 0.01.

    :param iwae: Trained IWAE model, or None to skip it.
    :param vae: Trained VAE model, or None to skip it.
    :param data_loader: Iterable of input batches; defaults to the global ``data_loader_test``.
    :param k: Number of importance samples used for the NLL estimate.
    :return: Active unit count of ``iwae``, or None when ``iwae`` is not given.
    :raises ValueError: if ``data_loader`` yields no batches while a model is given.
    """
    if data_loader is None:
        data_loader = data_loader_test
    # Dict insertion order fixes the report order: VAE first, then IWAE.
    models = {name: model for name, model in (("VAE", vae), ("IWAE", iwae)) if model is not None}
    previous_modes = {name: model.vae for name, model in models.items()}
    mus = {name: [] for name in models}
    nll_sums = {name: 0.0 for name in models}
    n_samples = 0
    try:
        for model in models.values():
            # Only the IWAE branch of `loss` returns the L_k log-likelihood estimate.
            model.covert_iwae_vae(False)
        with torch.no_grad():
            for data in tqdm(data_loader):
                data = torch.flatten(data, start_dim=1)
                n_samples += data.size(0)
                for name, model in models.items():
                    batch = data.to(next(model.parameters()).device)
                    mus[name].append(model.Au(batch).squeeze(0).cpu())
                    # Weight each batch mean by its size so a smaller final batch is not over-weighted.
                    nll_sums[name] += model.loss(batch, k)[1].item() * batch.size(0)
    finally:
        # Evaluation must not permanently change which objective a model trains with.
        for name, model in models.items():
            model.covert_iwae_vae(previous_modes[name])

    if models and n_samples == 0:
        raise ValueError("data_loader yielded no batches")

    active_units = {}
    for name in models:
        mus_np = torch.cat(mus[name]).numpy()
        active_units[name] = int((np.var(mus_np, axis=0) > 0.01).sum())

    def _report(name):
        # Print the active-unit count and NLL for `name` if that model was evaluated.
        if name in models:
            print(f"{name} active units:{active_units[name]}")
            print(f"{name} NLL loss:{nll_sums[name] / n_samples}")

    _report("VAE")
    print()
    _report("IWAE")
    return active_units.get("IWAE")

# Experiment 1
For k=1, both an IWAE and a VAE model are trained for 100 epochs; for k=5 and k=50, only an IWAE model is trained for 100 epochs. <br> After training, the NLL loss and active-unit count are computed on the test set and reported.<br>
As in the paper, the NLL is estimated using $\mathcal{L}_{5000}$.

In [None]:
# Experiment 1: for each k, train an IWAE (plus a VAE baseline when k == 1) and report
# the L_5000 NLL and active-unit counts on the test set.
torch.manual_seed(0)  # make initialization, latent sampling and shuffling repeatable
num_epoches = 100
iwae_active_units_by_k = {}
for k in [1, 5, 50]:
    print(f"K={k}")
    print()
    iwae = VAEorIWAE(784, VAE=False).cuda()
    optimizer_iwae = optim.Adam(iwae.parameters(), lr=0.001)
    # The VAE baseline is only part of the k == 1 setting, so it is built and trained only then.
    train_vae = k == 1
    if train_vae:
        vae = VAEorIWAE(784, VAE=True).cuda()
        optimizer_vae = optim.Adam(vae.parameters(), lr=0.001)
    for epoch in tqdm(range(num_epoches)):
        for data in data_loader_train:
            data = torch.flatten(data, start_dim=1).cuda()
            optimizer_iwae.zero_grad()
            loss, _ = iwae.loss(data, k)
            loss.backward()
            optimizer_iwae.step()
            if train_vae:
                optimizer_vae.zero_grad()
                loss = vae.loss(data, k)
                loss.backward()
                optimizer_vae.step()
    if train_vae:
        iwae_au = metrics_eval(iwae, vae)
    else:
        iwae_au = metrics_eval(iwae)
    iwae_active_units_by_k[k] = iwae_au

K=1



100%|██████████| 100/100 [35:56<00:00, 21.57s/it]
100%|██████████| 313/313 [02:07<00:00,  2.45it/s]


VAE active units:7
VAE NLL loss:236.45030227027382

IWAE active units:7
IWAE NLL loss:236.3812139834078
K=5



100%|██████████| 100/100 [32:29<00:00, 19.49s/it]
100%|██████████| 313/313 [01:05<00:00,  4.75it/s]



IWAE active units:10
IWAE NLL loss:234.33937443254854
K=50



100%|██████████| 100/100 [32:37<00:00, 19.58s/it]
100%|██████████| 313/313 [01:05<00:00,  4.80it/s]


IWAE active units:13
IWAE NLL loss:233.14258548626884





# Experiment 2
For k = 100, an IWAE model and a VAE model are trained for 100 epochs, and their NLL values and active-unit counts are computed.<br>
After 100 epochs the models swap objective functions: each is trained for 100 more epochs using the other model's objective.

In [6]:
# Experiment 2 (k = 100, the default of VAEorIWAE.loss): train an IWAE and a VAE for half
# of the epochs, evaluate both, then swap their objectives for the remaining half.
torch.manual_seed(0)  # make initialization, latent sampling and shuffling repeatable
iwae = VAEorIWAE(784, VAE=False).cuda()
optimizer_iwae = optim.Adam(iwae.parameters(), lr=0.001)
vae = VAEorIWAE(784, VAE=True).cuda()
optimizer_vae = optim.Adam(vae.parameters(), lr=0.001)
num_epoches = 200
train_loss_epoch_iwae = []
train_loss_epoch_vae = []


def _train_step(model, optimizer, batch):
    """Run one optimizer step on `batch` using the model's current objective; return the loss value."""
    optimizer.zero_grad()
    # VAE mode: `loss` returns a scalar. IWAE mode: it returns (loss, nll) and only the loss is trained on.
    loss = model.loss(batch) if model.vae else model.loss(batch)[0]
    loss.backward()
    optimizer.step()
    return loss.item()


for epoch in tqdm(range(num_epoches)):
    if epoch == num_epoches // 2:
        _ = metrics_eval(iwae, vae)
        # Swap objectives: the IWAE-trained model continues as a VAE and vice versa.
        iwae.covert_iwae_vae(True)
        vae.covert_iwae_vae(False)
    running_loss_iwae = []
    running_loss_vae = []
    for data in data_loader_train:
        data = torch.flatten(data, start_dim=1).cuda()
        running_loss_iwae.append(_train_step(iwae, optimizer_iwae, data))
        running_loss_vae.append(_train_step(vae, optimizer_vae, data))
    train_loss_epoch_iwae.append(np.mean(running_loss_iwae))
    train_loss_epoch_vae.append(np.mean(running_loss_vae))

# After the swap, `vae` trains with the IWAE objective and `iwae` with the VAE objective, so each
# model is reported under the label of the objective it finished training with.
_ = metrics_eval(iwae=vae, vae=iwae)



 50%|█████     | 100/200 [33:53<34:10, 20.51s/it]
  0%|          | 0/313 [00:00<?, ?it/s][A
  0%|          | 1/313 [00:00<00:59,  5.27it/s][A
  1%|          | 2/313 [00:00<01:00,  5.15it/s][A
  1%|          | 3/313 [00:00<01:00,  5.12it/s][A
  1%|▏         | 4/313 [00:00<01:00,  5.12it/s][A
  2%|▏         | 5/313 [00:00<01:00,  5.06it/s][A
  2%|▏         | 6/313 [00:01<01:00,  5.06it/s][A
  2%|▏         | 7/313 [00:01<01:00,  5.05it/s][A
  3%|▎         | 8/313 [00:01<01:00,  5.05it/s][A
  3%|▎         | 9/313 [00:01<00:59,  5.07it/s][A
  3%|▎         | 10/313 [00:01<00:59,  5.05it/s][A
  4%|▎         | 11/313 [00:02<00:59,  5.05it/s][A
  4%|▍         | 12/313 [00:02<00:59,  5.05it/s][A
  4%|▍         | 13/313 [00:02<00:59,  5.05it/s][A
  4%|▍         | 14/313 [00:02<00:59,  5.05it/s][A
  5%|▍         | 15/313 [00:02<00:59,  5.05it/s][A
  5%|▌         | 16/313 [00:03<00:58,  5.04it/s][A
  5%|▌         | 17/313 [00:03<00:58,  5.04it/s][A
  6%|▌         | 18/313 [00:03<0

VAE active units:7
VAE NLL loss:236.5087579599204

IWAE active units:19
IWAE NLL loss:233.13913997369832


100%|██████████| 200/200 [1:08:52<00:00, 20.66s/it]
100%|██████████| 313/313 [01:00<00:00,  5.16it/s]

VAE active units:14
VAE NLL loss:233.62889284371568

IWAE active units:28
IWAE NLL loss:233.0190677338134



