<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 903, done.[K
remote: Counting objects: 100% (903/903), done.[K
remote: Compressing objects: 100% (628/628), done.[K
remote: Total 903 (delta 477), reused 646 (delta 261), pack-reused 0[K
Receiving objects: 100% (903/903), 172.83 MiB | 34.42 MiB/s, done.
Resolving deltas: 100% (477/477), done.
Checking out files: 100% (129/129), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_mnist, get_batches, get_fashion_mnist
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [3]:
# Attach JAX to Colab's TPU runtime. Nothing happens outside Colab, or when
# use_tpu is switched off.
use_tpu = True
if "google.colab" in sys.modules and use_tpu:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [4]:
# -- Experiment flags: these are the settings varied between runs.
is_larger = False         # larger encoder/decoder hidden layers
trained_has_flow = False  # the trained model being loaded used a flow
local_has_flow = False    # local optimisation uses a flow instead of an FFG
kl_annealing = True       # the trained model used KL annealing
use_fashion = True        # Fashion-MNIST instead of MNIST
# ---

# The experiment name is built from the flags above and becomes the results
# directory, so each flag combination is saved separately.
name = "_".join((
  "fashion" if use_fashion else "mnist",
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  "larger" if is_larger else "smaller",
  "local-flow" if local_has_flow else "local-ffg",
))
print(name)

fashion_ffg_anneal_smaller_local-ffg


In [5]:
mount_google_drive = False

# Results go to a local ./experiments directory unless Google Drive mounting
# is requested while running on Colab, in which case they go to Drive.
if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = "./experiments/" + name
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = "/content/drive/My Drive/ATML/" + name

# Create the results directory and any missing parents; no error if it exists.
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

In [6]:
# Hidden-layer widths: three 500-unit layers for the "larger" model, two
# 200-unit layers otherwise. Encoder and decoder share the same widths.
if is_larger:
  hidden_size = (500, 500, 500)
else:
  hidden_size = (200, 200)

# NOTE(review): has_flow follows local_has_flow, not trained_has_flow --
# presumably it selects the local variational family used by local_opt;
# confirm the trained params still load correctly when the two flags differ.
hps = HyperParams(
  has_flow=local_has_flow,
  encoder_hidden=hidden_size,
  decoder_hidden=hidden_size,
)
local_hps = LocalHyperParams()

print(hps)
print(local_hps)

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [7]:
# Load the dataset chosen by use_fashion. The variable is called `mnist` even
# when it holds Fashion-MNIST, because later cells refer to it by that name.
if use_fashion:
  mnist = get_fashion_mnist()
else:
  mnist = get_mnist()

In [8]:
# Local optimisation is expensive, so smaller_data asks get_batches for a
# reduced subset of the training set (about 1000 images, i.e. 10 batches of
# 100 -- the recorded run below processes 10 batches).
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [9]:
# The trained model lives in the directory named like this experiment but
# without the final "_local-ffg" / "_local-flow" component, so cut save_dir
# at its last underscore to locate the saved parameters.
trained_model_dir = save_dir[: save_dir.rfind("_")]
file_name = f"{trained_model_dir}/params.pkl"
trained_params = utils.load_params(file_name)

In [10]:
# Build the VAE from hps and run local optimisation on train_batches starting
# from the trained parameters. The results are the ELBO and IWAE estimates
# (presumably one entry per batch, matching the "Batch i" log lines) and the
# optimised local variational parameters.
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(
  local_hps, model, train_batches, trained_params
)

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -100.2317
Epoch 200.0000 - ELBO -98.9216
Epoch 300.0000 - ELBO -98.3328
Epoch 400.0000 - ELBO -98.0059
Epoch 500.0000 - ELBO -97.8244
Epoch 600.0000 - ELBO -97.7151
Epoch 700.0000 - ELBO -97.6429
Epoch 800.0000 - ELBO -97.6067
Epoch 900.0000 - ELBO -97.5755
Epoch 1000.0000 - ELBO -97.5592
Epoch 1100.0000 - ELBO -97.5495
Epoch 1200.0000 - ELBO -97.5435
Epoch 1300.0000 - ELBO -97.5388
Epoch 1400.0000 - ELBO -97.5365
Epoch 1500.0000 - ELBO -97.5359
Epoch 1600.0000 - ELBO -97.5353
Epoch 1700.0000 - ELBO -97.5382
Epoch 1800.0000 - ELBO -97.5334
Epoch 1900.0000 - ELBO -97.5328
Batch 1, ELBO -97.5377, IWAE -94.8769


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -107.2895
Epoch 200.0000 - ELBO -105.8323
Epoch 300.0000 - ELBO -105.1795
Epoch 400.0000 - ELBO -104.8237
Epoch 500.0000 - ELBO -104.6177
Epoch 600.0000 - ELBO -104.4929
Epoch 700.0000 - ELBO -104.4217
Epoch 800.0000 - ELBO -104.3702
Epoch 900.0000 - ELBO -104.3344
Epoch 1000.0000 - ELBO -104.3152
Epoch 1100.0000 - ELBO -104.3040
Epoch 1200.0000 - ELBO -104.2974
Epoch 1300.0000 - ELBO -104.2994
Epoch 1400.0000 - ELBO -104.2850
Epoch 1500.0000 - ELBO -104.2868
Epoch 1600.0000 - ELBO -104.2911
Epoch 1700.0000 - ELBO -104.2913
Epoch 1800.0000 - ELBO -104.2874
Epoch 1900.0000 - ELBO -104.2879
Epoch 2000.0000 - ELBO -104.2833
Batch 2, ELBO -104.3212, IWAE -101.5943


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -101.9117
Epoch 200.0000 - ELBO -100.5954
Epoch 300.0000 - ELBO -100.0215
Epoch 400.0000 - ELBO -99.7157
Epoch 500.0000 - ELBO -99.5443
Epoch 600.0000 - ELBO -99.4411
Epoch 700.0000 - ELBO -99.3840
Epoch 800.0000 - ELBO -99.3471
Epoch 900.0000 - ELBO -99.3178
Epoch 1000.0000 - ELBO -99.3062
Epoch 1100.0000 - ELBO -99.3028
Epoch 1200.0000 - ELBO -99.2920
Epoch 1300.0000 - ELBO -99.2895
Epoch 1400.0000 - ELBO -99.2849
Epoch 1500.0000 - ELBO -99.2885
Epoch 1600.0000 - ELBO -99.2811
Epoch 1700.0000 - ELBO -99.2821
Epoch 1800.0000 - ELBO -99.2873
Epoch 1900.0000 - ELBO -99.2796
Batch 3, ELBO -99.2974, IWAE -96.7155


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -101.9212
Epoch 200.0000 - ELBO -100.6227
Epoch 300.0000 - ELBO -100.0530
Epoch 400.0000 - ELBO -99.7339
Epoch 500.0000 - ELBO -99.5535
Epoch 600.0000 - ELBO -99.4469
Epoch 700.0000 - ELBO -99.3942
Epoch 800.0000 - ELBO -99.3532
Epoch 900.0000 - ELBO -99.3232
Epoch 1000.0000 - ELBO -99.3140
Epoch 1100.0000 - ELBO -99.3063
Epoch 1200.0000 - ELBO -99.3025
Epoch 1300.0000 - ELBO -99.2999
Epoch 1400.0000 - ELBO -99.2981
Epoch 1500.0000 - ELBO -99.2954
Epoch 1600.0000 - ELBO -99.2971
Epoch 1700.0000 - ELBO -99.2902
Epoch 1800.0000 - ELBO -99.2859
Epoch 1900.0000 - ELBO -99.2903
Batch 4, ELBO -99.2894, IWAE -96.6396


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -112.7738
Epoch 200.0000 - ELBO -111.3290
Epoch 300.0000 - ELBO -110.6791
Epoch 400.0000 - ELBO -110.3162
Epoch 500.0000 - ELBO -110.1131
Epoch 600.0000 - ELBO -109.9845
Epoch 700.0000 - ELBO -109.9147
Epoch 800.0000 - ELBO -109.8643
Epoch 900.0000 - ELBO -109.8328
Epoch 1000.0000 - ELBO -109.8128
Epoch 1100.0000 - ELBO -109.8031
Epoch 1200.0000 - ELBO -109.7915
Epoch 1300.0000 - ELBO -109.7888
Epoch 1400.0000 - ELBO -109.7819
Epoch 1500.0000 - ELBO -109.7811
Epoch 1600.0000 - ELBO -109.7791
Epoch 1700.0000 - ELBO -109.7805
Epoch 1800.0000 - ELBO -109.7763
Epoch 1900.0000 - ELBO -109.7769
Epoch 2000.0000 - ELBO -109.7766
Batch 5, ELBO -109.7498, IWAE -107.1799


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -102.1475
Epoch 200.0000 - ELBO -100.8337
Epoch 300.0000 - ELBO -100.2483
Epoch 400.0000 - ELBO -99.9288
Epoch 500.0000 - ELBO -99.7504
Epoch 600.0000 - ELBO -99.6371
Epoch 700.0000 - ELBO -99.5773
Epoch 800.0000 - ELBO -99.5372
Epoch 900.0000 - ELBO -99.5147
Epoch 1000.0000 - ELBO -99.5033
Epoch 1100.0000 - ELBO -99.4908
Epoch 1200.0000 - ELBO -99.4896
Epoch 1300.0000 - ELBO -99.4874
Epoch 1400.0000 - ELBO -99.4787
Epoch 1500.0000 - ELBO -99.4839
Epoch 1600.0000 - ELBO -99.4807
Epoch 1700.0000 - ELBO -99.4823
Epoch 1800.0000 - ELBO -99.4740
Epoch 1900.0000 - ELBO -99.4782
Batch 6, ELBO -99.5321, IWAE -96.9923


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -97.2532
Epoch 200.0000 - ELBO -95.9126
Epoch 300.0000 - ELBO -95.3106
Epoch 400.0000 - ELBO -94.9890
Epoch 500.0000 - ELBO -94.8027
Epoch 600.0000 - ELBO -94.6950
Epoch 700.0000 - ELBO -94.6334
Epoch 800.0000 - ELBO -94.5910
Epoch 900.0000 - ELBO -94.5659
Epoch 1000.0000 - ELBO -94.5553
Epoch 1100.0000 - ELBO -94.5449
Epoch 1200.0000 - ELBO -94.5315
Epoch 1300.0000 - ELBO -94.5276
Epoch 1400.0000 - ELBO -94.5310
Epoch 1500.0000 - ELBO -94.5291
Epoch 1600.0000 - ELBO -94.5270
Epoch 1700.0000 - ELBO -94.5205
Epoch 1800.0000 - ELBO -94.5248
Epoch 1900.0000 - ELBO -94.5261
Batch 7, ELBO -94.5579, IWAE -91.8960


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -99.5144
Epoch 200.0000 - ELBO -98.2689
Epoch 300.0000 - ELBO -97.7202
Epoch 400.0000 - ELBO -97.4280
Epoch 500.0000 - ELBO -97.2550
Epoch 600.0000 - ELBO -97.1610
Epoch 700.0000 - ELBO -97.0970
Epoch 800.0000 - ELBO -97.0623
Epoch 900.0000 - ELBO -97.0373
Epoch 1000.0000 - ELBO -97.0309
Epoch 1100.0000 - ELBO -97.0204
Epoch 1200.0000 - ELBO -97.0132
Epoch 1300.0000 - ELBO -97.0144
Epoch 1400.0000 - ELBO -97.0110
Epoch 1500.0000 - ELBO -97.0120
Epoch 1600.0000 - ELBO -97.0073
Epoch 1700.0000 - ELBO -97.0086
Epoch 1800.0000 - ELBO -97.0126
Epoch 1900.0000 - ELBO -97.0115
Batch 8, ELBO -97.0157, IWAE -94.4156


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -96.7485
Epoch 200.0000 - ELBO -95.5155
Epoch 300.0000 - ELBO -94.9653
Epoch 400.0000 - ELBO -94.6753
Epoch 500.0000 - ELBO -94.5137
Epoch 600.0000 - ELBO -94.4151
Epoch 700.0000 - ELBO -94.3643
Epoch 800.0000 - ELBO -94.3303
Epoch 900.0000 - ELBO -94.3074
Epoch 1000.0000 - ELBO -94.2942
Epoch 1100.0000 - ELBO -94.2915
Epoch 1200.0000 - ELBO -94.2841
Epoch 1300.0000 - ELBO -94.2797
Epoch 1400.0000 - ELBO -94.2812
Epoch 1500.0000 - ELBO -94.2820
Epoch 1600.0000 - ELBO -94.2770
Epoch 1700.0000 - ELBO -94.2755
Epoch 1800.0000 - ELBO -94.2800
Epoch 1900.0000 - ELBO -94.2721
Batch 9, ELBO -94.3022, IWAE -91.7395


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -97.7768
Epoch 200.0000 - ELBO -96.4051
Epoch 300.0000 - ELBO -95.7837
Epoch 400.0000 - ELBO -95.4449
Epoch 500.0000 - ELBO -95.2512
Epoch 600.0000 - ELBO -95.1360
Epoch 700.0000 - ELBO -95.0615
Epoch 800.0000 - ELBO -95.0215
Epoch 900.0000 - ELBO -94.9823
Epoch 1000.0000 - ELBO -94.9689
Epoch 1100.0000 - ELBO -94.9584
Epoch 1200.0000 - ELBO -94.9543
Epoch 1300.0000 - ELBO -94.9521
Epoch 1400.0000 - ELBO -94.9411
Epoch 1500.0000 - ELBO -94.9412
Epoch 1600.0000 - ELBO -94.9334
Epoch 1700.0000 - ELBO -94.9401
Epoch 1800.0000 - ELBO -94.9360
Epoch 1900.0000 - ELBO -94.9354
Batch 10, ELBO -94.9480, IWAE -92.3193
Average ELBO -99.0551
Average IWAE -96.4369


In [11]:
# Persist the local-optimisation results alongside this experiment.
utils.save_params(f"{save_dir}/elbos.pkl", elbos)
utils.save_params(f"{save_dir}/iwaes.pkl", iwaes)
utils.save_params(f"{save_dir}/local_params.pkl", local_params)

In [12]:
! zip -r {save_dir + ".zip"} {save_dir + "/"}

  adding: experiments/fashion_ffg_anneal_smaller_local-ffg/ (stored 0%)
  adding: experiments/fashion_ffg_anneal_smaller_local-ffg/local_params.pkl (deflated 7%)
  adding: experiments/fashion_ffg_anneal_smaller_local-ffg/elbos.pkl (deflated 35%)
  adding: experiments/fashion_ffg_anneal_smaller_local-ffg/iwaes.pkl (deflated 35%)
