<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

In [None]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_batches, get_dataset
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [None]:
# Set to False to skip TPU initialisation even when running on Colab.
use_tpu = True
if "google.colab" in sys.modules and use_tpu:
  # Only attempted on Colab; the import is deferred so that local runs
  # never touch the Colab-specific TPU helper.
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [None]:
# -- Experiment knobs: edit these to choose which run this notebook performs
encoder_size = (200, 200)
decoder_size = (200, 200)
trained_has_flow = False
local_has_flow = False
kl_annealing = True
dataset_name = "mnist"  # one of: mnist, fashion, kmnist
# ---

def fmt_size(sizes):
  """Concatenate the first character of each size's string form.

  Example: (200, 200) -> "22".
  """
  leading_chars = [str(width)[0] for width in sizes]
  return "".join(leading_chars)

## Name of this experiment (important to change for saving results)
# Fields, in order: dataset, trained posterior type, KL schedule,
# encoder/decoder size tag, local posterior type.
name = "_".join((
  dataset_name,
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  f"e{fmt_size(encoder_size)}d{fmt_size(decoder_size)}",
  "local-flow" if local_has_flow else "local-ffg",
))
print(name)

mnist_ffg_anneal_smaller_local-ffg


In [None]:
# Set to True to persist results on Google Drive (only honoured on Colab).
mount_google_drive = False

if "google.colab" in sys.modules and mount_google_drive:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = "/content/drive/My Drive/ATML/" + name
else:
  # Default: keep results next to the notebook.
  save_dir = "./experiments/" + name

# save_dir stays a plain string because later cells build paths by concatenation.
results_path = pathlib.Path(save_dir)
results_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Model hyper-parameters: has_flow follows the *local* posterior choice,
# hidden sizes come from the experiment configuration.
hps = HyperParams(
  has_flow=local_has_flow,
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
)
# Local-optimisation settings are left at their defaults.
local_hps = LocalHyperParams()

print(hps, local_hps, sep="\n")

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [None]:
# Load the dataset selected by `dataset_name`. The variable keeps the name
# `mnist` even when a different dataset (e.g. fashion, kmnist) is chosen.
mnist = get_dataset(dataset_name)

In [None]:
# Local optimisation is computationally expensive, so only a reduced subset
# (1000, per the original note) is optimised.
# NOTE(review): presumably `smaller_data=True` is what selects that subset;
# confirm against get_batches.
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [None]:
# The globally trained model lives in the directory whose name is this
# experiment's save_dir without its final "_local-ffg" / "_local-flow" part,
# i.e. everything before the last underscore.
suffix_start = save_dir.rfind("_")
trained_model_dir = save_dir[:suffix_start]
file_name = f"{trained_model_dir}/params.pkl"
trained_params = utils.load_params(file_name)

In [None]:
# Build the VAE and run local optimisation over the training batches,
# starting from the previously trained parameters.
model = vae.VAE(hps)
local_results = local_opt(local_hps, model, train_batches, trained_params)
elbos, iwaes, local_params = local_results

Optimising Local Flow ...


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Batch 1, ELBO -110.7355, IWAE -109.9645


  0%|          | 0/100000 [00:00<?, ?it/s]

Batch 2, ELBO -100.9888, IWAE -100.1293


  0%|          | 0/100000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Persist every result of the local optimisation under save_dir.
for result_file, result_value in (
  ("/elbos.pkl", elbos),
  ("/iwaes.pkl", iwaes),
  ("/local_params.pkl", local_params),
):
  utils.save_params(save_dir + result_file, result_value)