<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_ais.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import jax
from jax import numpy as jnp
from jax import random

from tqdm.notebook import tqdm

from ais import batch_ais_iwelbo, AISHyperParams
from datasets import get_batches, get_dataset
from utils import HyperParams, load_params
from vae import VAE

In [3]:
# -- Vary across experiments
encoder_size = (200, 200)
decoder_size = (200, 200)
has_flow = False
kl_annealing = True
dataset_name = "mnist" # mnist, fashion, kmnist
# --- 

def fmt_size(sizes): return "".join(str(size)[0] for size in list(sizes))

## Name of this experiment (important to change for saving results)
name = "_".join([
  dataset_name,
  ["ffg","flow"][has_flow],
  ["regular","anneal"][kl_annealing],
  "e"+fmt_size(encoder_size)+"d"+fmt_size(decoder_size),
])
print(name)

mnist_ffg_anneal_smaller


In [4]:
mount_google_drive = False
in_colab = "google.colab" in sys.modules

# Results go to Google Drive only when explicitly requested AND running on Colab;
# otherwise they go to ./experiments/ under the current working directory.
if mount_google_drive and in_colab:
  from google.colab import drive
  drive.mount("/content/drive")
  experiments_root = "/content/drive/My Drive/ATML/"
else:
  experiments_root = "./experiments/"
save_dir = experiments_root + name

# Idempotent: creates any missing parent directories and tolerates an existing one.
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

In [5]:
# Model hyperparameters follow the experiment config; AIS uses its library defaults.
hps = HyperParams(
  has_flow=has_flow,
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
)
ais_hps = AISHyperParams()

print(hps, ais_hps, sep="\n")

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
AISHyperParams(num_iwae_samples=10, annealing_steps=10000)


In [6]:
# Opt-in TPU backend setup; it is only attempted inside a Colab runtime.
use_tpu = True
in_colab = "google.colab" in sys.modules
if use_tpu and in_colab:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [7]:
# Load the dataset selected by `dataset_name`. NOTE: despite the variable name,
# `mnist` holds whichever dataset was chosen (mnist / fashion / kmnist); the name is
# kept because later cells index it as `mnist["train_x"]`.
mnist = get_dataset(dataset_name)



In [8]:
# AIS is expensive, so only a reduced subset of the training images is evaluated:
# `smaller_data` asks get_batches for that subset (1000 images per the original note,
# i.e. 100 batches of 10 -- confirm against get_batches).
batch_size = 10
smaller_data = True
train_images = mnist["train_x"]
train_batches = get_batches(train_images, batch_size, smaller_data)

In [9]:
# Restore the trained parameters saved for this experiment. AIS only needs the
# decoder; index 1 presumably selects the decoder from a saved (encoder, decoder)
# pair -- TODO confirm against how params.pkl is written.
# NOTE(review): if load_params unpickles the file, only load trusted files.
file_name = f"{save_dir}/params.pkl"
params = load_params(file_name)
decoder_params = params[1]

In [11]:
# Run AIS on every batch using only the trained decoder. Each batch gets its own
# deterministic PRNG key (PRNGKey(batch index)), so reruns are reproducible.
# batch_ais_iwelbo appears to return one scalar per batch (see the printed output).
model = VAE(hps)
ais_iwelbos = []
for i, batch in enumerate(tqdm(train_batches)):
  rng = random.PRNGKey(i)
  ais_iwelbo = batch_ais_iwelbo(ais_hps, model, decoder_params, batch, rng)
  ais_iwelbos.append(ais_iwelbo)
  print("AIS", i, "-", ais_iwelbo)

# Average over *all* collected batch estimates. The previous code passed only the
# last batch's value (`ais_iwelbo`) to nanmean. nanmean skips batches whose estimate
# is NaN; an empty list yields NaN rather than raising.
print("Average AIS:", jnp.nanmean(jnp.asarray(ais_iwelbos)))

  0%|          | 0/100 [00:00<?, ?it/s]

AIS 0 - -84.48168
