<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 1137, done.[K
remote: Counting objects: 100% (1137/1137), done.[K
remote: Compressing objects: 100% (798/798), done.[K
remote: Total 1137 (delta 582), reused 839 (delta 325), pack-reused 0[K
Receiving objects: 100% (1137/1137), 326.79 MiB | 32.41 MiB/s, done.
Resolving deltas: 100% (582/582), done.
Checking out files: 100% (238/238), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_batches, get_dataset
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [4]:
# Opt-in: switch JAX to the Colab TPU runtime. Does nothing unless running inside Colab.
use_tpu = False
if use_tpu and "google.colab" in sys.modules:
  import jax.tools.colab_tpu
  # NOTE(review): presumably this must run before JAX initialises its backend; vae and
  # local_opt were already imported in an earlier cell — confirm they don't touch devices at import time.
  jax.tools.colab_tpu.setup_tpu()

In [19]:
# -- Vary across experiment
encoder_size = (200, 200)
decoder_size = ()
trained_has_flow = False
local_has_flow = False
kl_annealing = True
dataset_name = "mnist" # mnist, fashion, kmnist
# ---

def fmt_size(sizes):
  """Compact tag for a sequence of layer sizes: the first character of each size.

  e.g. (200, 200) -> "22" and () -> "". The tag is lossy ((200,) and (256,) both
  give "2"), but it is part of the saved-experiment naming scheme, so keep it stable.
  """
  return "".join(str(size)[0] for size in sizes)

## Name of this experiment (important to change for saving results).
## It is used as the results sub-directory, and a later cell locates the trained
## model's directory by dropping the final "local-..." component.
name = "_".join([
  dataset_name,
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  f"e{fmt_size(encoder_size)}d{fmt_size(decoder_size)}",
  "local-flow" if local_has_flow else "local-ffg",
])
print(name)

mnist_ffg_anneal_e22d_local-ffg


In [20]:
mount_google_drive = True

# Results are written to one sub-directory per experiment `name`: under Google Drive
# when running on Colab with mounting enabled, otherwise under ./experiments.
if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = "./experiments/" + name
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = "/content/drive/My Drive/ATML/" + name

pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:
# Model hyper-parameters mirror the experiment config above. The flow flag follows
# `local_has_flow` (the "local-ffg"/"local-flow" choice), not `trained_has_flow`.
# NOTE(review): confirm the trained params loaded later are compatible with this choice.
hps = HyperParams(
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
  has_flow=local_has_flow,
)
local_hps = LocalHyperParams()

print(hps, local_hps, sep="\n")

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [22]:
# Despite the variable name, this holds whichever dataset `dataset_name` selects
# (mnist, fashion or kmnist); the next cell reads its "train_x" entry.
mnist = get_dataset(dataset_name)

In [23]:
batch_size = 100
# Only locally optimise 1000 images (10 batches of 100) due to computational cost.
# NOTE(review): presumably `smaller_data=True` makes get_batches truncate the training
# set to that subset — confirm in datasets.get_batches.
smaller_data = True
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [24]:
# The previously trained model lives in the experiment directory without the trailing
# "_local-ffg"/"_local-flow" component. Strip that exact suffix rather than cutting at the
# last "_", so a change to the naming scheme fails loudly instead of silently loading
# from the wrong directory (rfind returning -1 would drop the last character).
local_suffix = "_" + ["local-ffg", "local-flow"][local_has_flow]
if not save_dir.endswith(local_suffix):
  raise ValueError(f"save_dir {save_dir!r} does not end with expected suffix {local_suffix!r}")
trained_model_dir = save_dir[:-len(local_suffix)]
file_name = trained_model_dir + "/params.pkl"
# NOTE(review): utils.load_params presumably unpickles this file — only load trusted files.
trained_params = utils.load_params(file_name)

In [25]:
# Build the VAE from `hps` and run local optimisation over `train_batches`, using the
# trained parameters loaded in the previous cell. The log prints an ELBO and IWAE per
# batch followed by their averages.
# NOTE(review): presumably `elbos`/`iwaes` hold those per-batch values — confirm in local_opt.py.
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(local_hps, model, train_batches, trained_params)

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -111.6068
Epoch 200.0000 - ELBO -111.4507
Epoch 300.0000 - ELBO -111.4142
Epoch 400.0000 - ELBO -111.4038
Epoch 500.0000 - ELBO -111.3959
Epoch 600.0000 - ELBO -111.4040
Epoch 700.0000 - ELBO -111.3973
Epoch 800.0000 - ELBO -111.4027
Epoch 900.0000 - ELBO -111.3999
Epoch 1000.0000 - ELBO -111.4010
Epoch 1100.0000 - ELBO -111.4012
Epoch 1200.0000 - ELBO -111.4012
Epoch 1300.0000 - ELBO -111.3992
Epoch 1400.0000 - ELBO -111.3981
Epoch 1500.0000 - ELBO -111.3988
Batch 1, ELBO -111.3702, IWAE -108.8787


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -111.5322
Epoch 200.0000 - ELBO -111.3870
Epoch 300.0000 - ELBO -111.3543
Epoch 400.0000 - ELBO -111.3470
Epoch 500.0000 - ELBO -111.3448
Epoch 600.0000 - ELBO -111.3452
Epoch 700.0000 - ELBO -111.3439
Epoch 800.0000 - ELBO -111.3391
Epoch 900.0000 - ELBO -111.3421
Epoch 1000.0000 - ELBO -111.3426
Epoch 1100.0000 - ELBO -111.3396
Epoch 1200.0000 - ELBO -111.3427
Batch 2, ELBO -111.3571, IWAE -108.8526


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -112.7938
Epoch 200.0000 - ELBO -112.6433
Epoch 300.0000 - ELBO -112.6110
Epoch 400.0000 - ELBO -112.6032
Epoch 500.0000 - ELBO -112.6034
Epoch 600.0000 - ELBO -112.6027
Epoch 700.0000 - ELBO -112.6014
Epoch 800.0000 - ELBO -112.5989
Epoch 900.0000 - ELBO -112.6004
Epoch 1000.0000 - ELBO -112.5987
Epoch 1100.0000 - ELBO -112.5982
Epoch 1200.0000 - ELBO -112.5982
Batch 3, ELBO -112.6443, IWAE -110.0219


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -108.7966
Epoch 200.0000 - ELBO -108.6598
Epoch 300.0000 - ELBO -108.6324
Epoch 400.0000 - ELBO -108.6270
Epoch 500.0000 - ELBO -108.6189
Epoch 600.0000 - ELBO -108.6177
Epoch 700.0000 - ELBO -108.6205
Epoch 800.0000 - ELBO -108.6198
Epoch 900.0000 - ELBO -108.6191
Epoch 1000.0000 - ELBO -108.6165
Epoch 1100.0000 - ELBO -108.6173
Epoch 1200.0000 - ELBO -108.6155
Batch 4, ELBO -108.6012, IWAE -106.2071


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -113.8542
Epoch 200.0000 - ELBO -113.7040
Epoch 300.0000 - ELBO -113.6712
Epoch 400.0000 - ELBO -113.6573
Epoch 500.0000 - ELBO -113.6536
Epoch 600.0000 - ELBO -113.6518
Epoch 700.0000 - ELBO -113.6530
Epoch 800.0000 - ELBO -113.6531
Epoch 900.0000 - ELBO -113.6548
Epoch 1000.0000 - ELBO -113.6559
Epoch 1100.0000 - ELBO -113.6555
Epoch 1200.0000 - ELBO -113.6513
Epoch 1300.0000 - ELBO -113.6520
Epoch 1400.0000 - ELBO -113.6510
Epoch 1500.0000 - ELBO -113.6530
Batch 5, ELBO -113.6741, IWAE -110.9930


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -115.8328
Epoch 200.0000 - ELBO -115.6885
Epoch 300.0000 - ELBO -115.6519
Epoch 400.0000 - ELBO -115.6447
Epoch 500.0000 - ELBO -115.6364
Epoch 600.0000 - ELBO -115.6393
Epoch 700.0000 - ELBO -115.6380
Epoch 800.0000 - ELBO -115.6397
Epoch 900.0000 - ELBO -115.6372
Epoch 1000.0000 - ELBO -115.6341
Epoch 1100.0000 - ELBO -115.6395
Epoch 1200.0000 - ELBO -115.6412
Epoch 1300.0000 - ELBO -115.6401
Epoch 1400.0000 - ELBO -115.6370
Epoch 1500.0000 - ELBO -115.6396
Batch 6, ELBO -115.5863, IWAE -112.9992


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -112.0582
Epoch 200.0000 - ELBO -111.9044
Epoch 300.0000 - ELBO -111.8716
Epoch 400.0000 - ELBO -111.8537
Epoch 500.0000 - ELBO -111.8520
Epoch 600.0000 - ELBO -111.8499
Epoch 700.0000 - ELBO -111.8525
Epoch 800.0000 - ELBO -111.8503
Epoch 900.0000 - ELBO -111.8494
Epoch 1000.0000 - ELBO -111.8463
Epoch 1100.0000 - ELBO -111.8516
Epoch 1200.0000 - ELBO -111.8480
Epoch 1300.0000 - ELBO -111.8522
Epoch 1400.0000 - ELBO -111.8480
Batch 7, ELBO -111.8503, IWAE -109.2739


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -115.9176
Epoch 200.0000 - ELBO -115.7562
Epoch 300.0000 - ELBO -115.7192
Epoch 400.0000 - ELBO -115.7051
Epoch 500.0000 - ELBO -115.7017
Epoch 600.0000 - ELBO -115.7013
Epoch 700.0000 - ELBO -115.7049
Epoch 800.0000 - ELBO -115.6982
Epoch 900.0000 - ELBO -115.7002
Epoch 1000.0000 - ELBO -115.6986
Epoch 1100.0000 - ELBO -115.7043
Epoch 1200.0000 - ELBO -115.7039
Epoch 1300.0000 - ELBO -115.7016
Epoch 1400.0000 - ELBO -115.7035
Batch 8, ELBO -115.6624, IWAE -113.1382


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -112.3042
Epoch 200.0000 - ELBO -112.1448
Epoch 300.0000 - ELBO -112.1179
Epoch 400.0000 - ELBO -112.1040
Epoch 500.0000 - ELBO -112.0999
Epoch 600.0000 - ELBO -112.1007
Epoch 700.0000 - ELBO -112.0996
Epoch 800.0000 - ELBO -112.0976
Epoch 900.0000 - ELBO -112.0972
Epoch 1000.0000 - ELBO -112.0979
Epoch 1100.0000 - ELBO -112.1006
Epoch 1200.0000 - ELBO -112.1001
Batch 9, ELBO -112.0877, IWAE -109.6011


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -114.5190
Epoch 200.0000 - ELBO -114.3631
Epoch 300.0000 - ELBO -114.3217
Epoch 400.0000 - ELBO -114.3101
Epoch 500.0000 - ELBO -114.3019
Epoch 600.0000 - ELBO -114.3058
Epoch 700.0000 - ELBO -114.3006
Epoch 800.0000 - ELBO -114.3019
Epoch 900.0000 - ELBO -114.3063
Epoch 1000.0000 - ELBO -114.2993
Epoch 1100.0000 - ELBO -114.3017
Epoch 1200.0000 - ELBO -114.3000
Epoch 1300.0000 - ELBO -114.3034
Epoch 1400.0000 - ELBO -114.3020
Batch 10, ELBO -114.3012, IWAE -111.7310
Average ELBO -112.7135
Average IWAE -110.1697


In [26]:
# Persist the ELBO / IWAE results and the locally optimised parameters returned by
# local_opt into this experiment's directory.
utils.save_params(f"{save_dir}/elbos.pkl", elbos)
utils.save_params(f"{save_dir}/iwaes.pkl", iwaes)
utils.save_params(f"{save_dir}/local_params.pkl", local_params)