In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 939, done.[K
remote: Counting objects: 100% (939/939), done.[K
remote: Compressing objects: 100% (649/649), done.[K
remote: Total 939 (delta 497), reused 677 (delta 276), pack-reused 0[K
Receiving objects: 100% (939/939), 215.67 MiB | 32.74 MiB/s, done.
Resolving deltas: 100% (497/497), done.
Checking out files: 100% (153/153), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_mnist, get_batches, get_fashion_mnist
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

from jax.config import config
config.update("jax_debug_nans", True)

In [3]:
use_tpu = False  # set to True to run on a Colab TPU runtime

# TPU initialisation is only attempted inside Colab, and only when requested.
if "google.colab" in sys.modules and use_tpu:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [15]:
# -- Switches that vary across experiments
is_larger = True
trained_has_flow = False
local_has_flow = False
kl_annealing = True
use_fashion = True
# ---

# Experiment name: one token per switch above, joined with "_". It determines
# the results directory, so each configuration must produce a distinct name.
name = "_".join(
  choices[flag]
  for choices, flag in (
    (("mnist", "fashion"), use_fashion),
    (("ffg", "flow"), trained_has_flow),
    (("regular", "anneal"), kl_annealing),
    (("smaller", "larger"), is_larger),
    (("local-ffg", "local-flow"), local_has_flow),
  )
)
print(name)

fashion_ffg_anneal_larger_local-ffg


In [16]:
mount_google_drive = False

# Results go under ./experiments by default; when running in Colab with
# mount_google_drive enabled, they are written to Google Drive instead.
if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = f"./experiments/{name}"
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = f"/content/drive/My Drive/ATML/{name}"

pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

In [6]:
# Encoder and decoder use the same hidden-layer widths: three 500-unit layers
# for the "larger" model, two 200-unit layers otherwise.
if is_larger:
  hidden_size = (500, 500, 500)
else:
  hidden_size = (200, 200)

# NOTE(review): has_flow follows local_has_flow (the family used for local
# optimisation), not trained_has_flow; presumably only the trained decoder
# parameters are reused by local_opt. Confirm against local_opt/vae.
hps = HyperParams(
  encoder_hidden=hidden_size,
  decoder_hidden=hidden_size,
  has_flow=local_has_flow,
)
local_hps = LocalHyperParams()

print(hps, local_hps, sep="\n")

HyperParams(image_size=784, latent_size=50, encoder_hidden=(500, 500, 500), decoder_hidden=(500, 500, 500), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [7]:
# The variable keeps the name `mnist` even when Fashion-MNIST is selected.
if use_fashion:
  mnist = get_fashion_mnist()
else:
  mnist = get_mnist()

In [8]:
# Local optimisation is expensive, so it runs on a reduced training subset
# (smaller_data=True; 1000 examples) split into batches of 100.
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [9]:
# The trained model lives in save_dir minus the trailing local-optimisation
# token ("_local-ffg" / "_local-flow", the last component of `name`). Strip
# that exact suffix instead of cutting at the last "_", which would silently
# truncate a path lacking "_" (rfind -> -1) and load the wrong directory.
local_suffix = "_" + ["local-ffg", "local-flow"][local_has_flow]
assert save_dir.endswith(local_suffix), f"unexpected save_dir: {save_dir!r}"
trained_model_dir = save_dir[:-len(local_suffix)]
file_name = trained_model_dir + "/params.pkl"
# NOTE(review): presumably utils.load_params unpickles; only load trusted files.
trained_params = utils.load_params(file_name)


In [10]:
# Build the VAE from `hps` and run local optimisation over the training
# batches, starting from the loaded trained parameters. local_opt returns the
# ELBO and IWAE estimates (presumably per batch) and the optimised local params.
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(
  local_hps,
  model,
  train_batches,
  trained_params,
)

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -83.9463
Epoch 200.0000 - ELBO -82.7349
Epoch 300.0000 - ELBO -82.2681
Epoch 400.0000 - ELBO -82.0265
Epoch 500.0000 - ELBO -81.8844
Epoch 600.0000 - ELBO -81.8085
Epoch 700.0000 - ELBO -81.7561
Epoch 800.0000 - ELBO -81.7241
Epoch 900.0000 - ELBO -81.7023
Epoch 1000.0000 - ELBO -81.6875
Epoch 1100.0000 - ELBO -81.6766
Epoch 1200.0000 - ELBO -81.6741
Epoch 1300.0000 - ELBO -81.6685
Epoch 1400.0000 - ELBO -81.6634
Epoch 1500.0000 - ELBO -81.6646
Epoch 1600.0000 - ELBO -81.6602
Epoch 1700.0000 - ELBO -81.6590
Epoch 1800.0000 - ELBO -81.6575
Epoch 1900.0000 - ELBO -81.6585
Batch 1, ELBO -81.6598, IWAE -78.3467


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -88.9666
Epoch 200.0000 - ELBO -87.7669
Epoch 300.0000 - ELBO -87.3235
Epoch 400.0000 - ELBO -87.0935
Epoch 500.0000 - ELBO -86.9708
Epoch 600.0000 - ELBO -86.8991
Epoch 700.0000 - ELBO -86.8611
Epoch 800.0000 - ELBO -86.8355
Epoch 900.0000 - ELBO -86.8251
Epoch 1000.0000 - ELBO -86.8046
Epoch 1100.0000 - ELBO -86.7999
Epoch 1200.0000 - ELBO -86.7965
Epoch 1300.0000 - ELBO -86.7945
Epoch 1400.0000 - ELBO -86.7937
Epoch 1500.0000 - ELBO -86.7923
Epoch 1600.0000 - ELBO -86.7931
Epoch 1700.0000 - ELBO -86.7928
Epoch 1800.0000 - ELBO -86.7877
Batch 2, ELBO -86.7775, IWAE -83.2561


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -88.3407
Epoch 200.0000 - ELBO -87.0666
Epoch 300.0000 - ELBO -86.5846
Epoch 400.0000 - ELBO -86.3293
Epoch 500.0000 - ELBO -86.1873
Epoch 600.0000 - ELBO -86.0981
Epoch 700.0000 - ELBO -86.0441
Epoch 800.0000 - ELBO -86.0136
Epoch 900.0000 - ELBO -85.9962
Epoch 1000.0000 - ELBO -85.9775
Epoch 1100.0000 - ELBO -85.9735
Epoch 1200.0000 - ELBO -85.9733
Epoch 1300.0000 - ELBO -85.9624
Epoch 1400.0000 - ELBO -85.9635
Epoch 1500.0000 - ELBO -85.9552
Epoch 1600.0000 - ELBO -85.9536
Epoch 1700.0000 - ELBO -85.9584
Epoch 1800.0000 - ELBO -85.9545
Epoch 1900.0000 - ELBO -85.9539
Epoch 2000.0000 - ELBO -85.9525
Batch 3, ELBO -85.9202, IWAE -82.4228


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -84.4191
Epoch 200.0000 - ELBO -83.1176
Epoch 300.0000 - ELBO -82.6245
Epoch 400.0000 - ELBO -82.3797
Epoch 500.0000 - ELBO -82.2453
Epoch 600.0000 - ELBO -82.1769
Epoch 700.0000 - ELBO -82.1260
Epoch 800.0000 - ELBO -82.0996
Epoch 900.0000 - ELBO -82.0771
Epoch 1000.0000 - ELBO -82.0772
Epoch 1100.0000 - ELBO -82.0649
Epoch 1200.0000 - ELBO -82.0591
Epoch 1300.0000 - ELBO -82.0612
Epoch 1400.0000 - ELBO -82.0582
Epoch 1500.0000 - ELBO -82.0543
Epoch 1600.0000 - ELBO -82.0516
Epoch 1700.0000 - ELBO -82.0479
Epoch 1800.0000 - ELBO -82.0555
Epoch 1900.0000 - ELBO -82.0503
Epoch 2000.0000 - ELBO -82.0511
Epoch 2100.0000 - ELBO -82.0552
Batch 4, ELBO -82.0029, IWAE -78.4910


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -94.2702
Epoch 200.0000 - ELBO -92.9326
Epoch 300.0000 - ELBO -92.4253
Epoch 400.0000 - ELBO -92.1700
Epoch 500.0000 - ELBO -92.0191
Epoch 600.0000 - ELBO -91.9327
Epoch 700.0000 - ELBO -91.8807
Epoch 800.0000 - ELBO -91.8552
Epoch 900.0000 - ELBO -91.8349
Epoch 1000.0000 - ELBO -91.8188
Epoch 1100.0000 - ELBO -91.8150
Epoch 1200.0000 - ELBO -91.8153
Epoch 1300.0000 - ELBO -91.8063
Epoch 1400.0000 - ELBO -91.8076
Epoch 1500.0000 - ELBO -91.8005
Epoch 1600.0000 - ELBO -91.8016
Epoch 1700.0000 - ELBO -91.8038
Epoch 1800.0000 - ELBO -91.8006
Epoch 1900.0000 - ELBO -91.7979
Epoch 2000.0000 - ELBO -91.8031
Batch 5, ELBO -91.8026, IWAE -88.3498


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -86.9500
Epoch 200.0000 - ELBO -85.7299
Epoch 300.0000 - ELBO -85.2737
Epoch 400.0000 - ELBO -85.0282
Epoch 500.0000 - ELBO -84.8916
Epoch 600.0000 - ELBO -84.8112
Epoch 700.0000 - ELBO -84.7627
Epoch 800.0000 - ELBO -84.7262
Epoch 900.0000 - ELBO -84.7161
Epoch 1000.0000 - ELBO -84.7002
Epoch 1100.0000 - ELBO -84.6954
Epoch 1200.0000 - ELBO -84.6843
Epoch 1300.0000 - ELBO -84.6828
Epoch 1400.0000 - ELBO -84.6719
Epoch 1500.0000 - ELBO -84.6801
Epoch 1600.0000 - ELBO -84.6779
Epoch 1700.0000 - ELBO -84.6767
Epoch 1800.0000 - ELBO -84.6764
Epoch 1900.0000 - ELBO -84.6791
Epoch 2000.0000 - ELBO -84.6763
Epoch 2100.0000 - ELBO -84.6757
Epoch 2200.0000 - ELBO -84.6873
Epoch 2300.0000 - ELBO -84.6689
Epoch 2400.0000 - ELBO -84.6716
Batch 6, ELBO -84.6387, IWAE -81.2869


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -83.2782
Epoch 200.0000 - ELBO -82.0397
Epoch 300.0000 - ELBO -81.5624
Epoch 400.0000 - ELBO -81.3071
Epoch 500.0000 - ELBO -81.1642
Epoch 600.0000 - ELBO -81.0776
Epoch 700.0000 - ELBO -81.0253
Epoch 800.0000 - ELBO -80.9903
Epoch 900.0000 - ELBO -80.9703
Epoch 1000.0000 - ELBO -80.9574
Epoch 1100.0000 - ELBO -80.9487
Epoch 1200.0000 - ELBO -80.9378
Epoch 1300.0000 - ELBO -80.9435
Epoch 1400.0000 - ELBO -80.9390
Epoch 1500.0000 - ELBO -80.9367
Epoch 1600.0000 - ELBO -80.9266
Epoch 1700.0000 - ELBO -80.9286
Epoch 1800.0000 - ELBO -80.9285
Epoch 1900.0000 - ELBO -80.9315
Batch 7, ELBO -80.9158, IWAE -77.5602


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -83.8075
Epoch 200.0000 - ELBO -82.4855
Epoch 300.0000 - ELBO -81.9795
Epoch 400.0000 - ELBO -81.7128
Epoch 500.0000 - ELBO -81.5569
Epoch 600.0000 - ELBO -81.4745
Epoch 700.0000 - ELBO -81.4190
Epoch 800.0000 - ELBO -81.3813
Epoch 900.0000 - ELBO -81.3650
Epoch 1000.0000 - ELBO -81.3477
Epoch 1100.0000 - ELBO -81.3431
Epoch 1200.0000 - ELBO -81.3386
Epoch 1300.0000 - ELBO -81.3372
Epoch 1400.0000 - ELBO -81.3307
Epoch 1500.0000 - ELBO -81.3315
Epoch 1600.0000 - ELBO -81.3340
Epoch 1700.0000 - ELBO -81.3323
Epoch 1800.0000 - ELBO -81.3286
Epoch 1900.0000 - ELBO -81.3333
Batch 8, ELBO -81.3245, IWAE -77.9907


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -83.9814
Epoch 200.0000 - ELBO -82.7694
Epoch 300.0000 - ELBO -82.3003
Epoch 400.0000 - ELBO -82.0551
Epoch 500.0000 - ELBO -81.9188
Epoch 600.0000 - ELBO -81.8364
Epoch 700.0000 - ELBO -81.7834
Epoch 800.0000 - ELBO -81.7544
Epoch 900.0000 - ELBO -81.7403
Epoch 1000.0000 - ELBO -81.7228
Epoch 1100.0000 - ELBO -81.7182
Epoch 1200.0000 - ELBO -81.7141
Epoch 1300.0000 - ELBO -81.7161
Epoch 1400.0000 - ELBO -81.7063
Epoch 1500.0000 - ELBO -81.7053
Epoch 1600.0000 - ELBO -81.6991
Epoch 1700.0000 - ELBO -81.7001
Epoch 1800.0000 - ELBO -81.7018
Epoch 1900.0000 - ELBO -81.7029
Epoch 2000.0000 - ELBO -81.7044
Batch 9, ELBO -81.7555, IWAE -78.3622


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -81.3218
Epoch 200.0000 - ELBO -79.9979
Epoch 300.0000 - ELBO -79.5019
Epoch 400.0000 - ELBO -79.2407
Epoch 500.0000 - ELBO -79.0914
Epoch 600.0000 - ELBO -79.0097
Epoch 700.0000 - ELBO -78.9501
Epoch 800.0000 - ELBO -78.9195
Epoch 900.0000 - ELBO -78.9031
Epoch 1000.0000 - ELBO -78.8908
Epoch 1100.0000 - ELBO -78.8825
Epoch 1200.0000 - ELBO -78.8744
Epoch 1300.0000 - ELBO -78.8733
Epoch 1400.0000 - ELBO -78.8662
Epoch 1500.0000 - ELBO -78.8647
Epoch 1600.0000 - ELBO -78.8709
Epoch 1700.0000 - ELBO -78.8650
Epoch 1800.0000 - ELBO -78.8672
Epoch 1900.0000 - ELBO -78.8644
Epoch 2000.0000 - ELBO -78.8704
Batch 10, ELBO -78.8232, IWAE -75.4614
Average ELBO -83.5621
Average IWAE -80.1528


In [17]:
# Persist every local-optimisation result next to each other in save_dir.
for file_suffix, result in (
  ("/elbos.pkl", elbos),
  ("/iwaes.pkl", iwaes),
  ("/local_params.pkl", local_params),
):
  utils.save_params(save_dir + file_suffix, result)

print(save_dir)

./experiments/fashion_ffg_anneal_larger_local-ffg
