<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 1067, done.[K
remote: Counting objects: 100% (1067/1067), done.[K
remote: Compressing objects: 100% (744/744), done.[K
remote: Total 1067 (delta 559), reused 776 (delta 309), pack-reused 0[K
Receiving objects: 100% (1067/1067), 277.35 MiB | 33.83 MiB/s, done.
Resolving deltas: 100% (559/559), done.
Checking out files: 100% (205/205), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_mnist, get_batches, get_fashion_mnist
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [3]:
# Opt-in TPU backend for JAX; only attempted inside a Colab runtime.
use_tpu = True
if "google.colab" in sys.modules and use_tpu:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [4]:
# -- Experiment switches (vary these across runs)
is_larger = True          # larger (500, 500) vs smaller (200, 200) encoder
trained_has_flow = True   # whether the trained (amortised) model used a flow
local_has_flow = False    # whether the locally optimised q uses a flow
kl_annealing = True
use_fashion = True        # Fashion-MNIST instead of MNIST
# ---

## Experiment name: change it whenever the switches change, since it decides
## where results are saved. The local-q component must stay LAST and must not
## contain "_": the params-loading cell strips it with rfind("_") to locate the
## trained model's directory.
name = "_".join((
  "fashion" if use_fashion else "mnist",
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  "larger" if is_larger else "smaller",
  "local-flow" if local_has_flow else "local-ffg",
))
print(name)

fashion_flow_anneal_larger_local-ffg


In [5]:
# Where this experiment's outputs live: Google Drive when running on Colab
# (so results outlive the VM), otherwise a local ./experiments directory.
mount_google_drive = True

if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = f"./experiments/{name}"
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = f"/content/drive/My Drive/ATML/{name}"

pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

Mounted at /content/drive


In [11]:
# Model hyperparameters. The encoder width must match the trained model.
# NOTE(review): has_flow is taken from local_has_flow, not trained_has_flow;
# presumably it selects the local variational family. Confirm against local_opt.
if is_larger:
  encoder_size = (500, 500)
else:
  encoder_size = (200, 200)
hps = HyperParams(encoder_hidden=encoder_size, has_flow=local_has_flow)

# NOTE(review): the printed defaults show debug=True and display_epoch=100,
# which floods the optimisation cell's output; consider a larger display_epoch.
local_hps = LocalHyperParams(learning_rate=1e-2)

print(hps)
print(local_hps)

HyperParams(image_size=784, latent_size=50, encoder_hidden=(500, 500), decoder_hidden=(200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.01, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [7]:
# `mnist` holds whichever dataset was selected (Fashion-MNIST or MNIST);
# later cells index it like a dict, e.g. mnist["train_x"].
if use_fashion:
  mnist = get_fashion_mnist()
else:
  mnist = get_mnist()

In [8]:
# Local optimisation is expensive, so only a subset of the training set is
# used (the original note says 1000 images; the progress bar shows 10 batches).
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [9]:
# The trained (amortised) model lives in the experiment directory without the
# trailing "_local-ffg"/"_local-flow" component, so cut at the last "_".
trained_model_dir = save_dir[: save_dir.rfind("_")]
file_name = f"{trained_model_dir}/params.pkl"
trained_params = utils.load_params(file_name)

In [12]:
# Build the VAE from `hps` and run local optimisation against the trained
# weights. Returns ELBO and IWAE values plus the optimised local parameters
# (presumably per-datapoint q parameters; confirm in local_opt).
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(
  local_hps,
  model,
  train_batches,
  trained_params,
)

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -7224.1079
Epoch 200.0000 - ELBO -5245.9985
Epoch 300.0000 - ELBO -3875.8523
Epoch 400.0000 - ELBO -2943.2053
Epoch 500.0000 - ELBO -2329.9399
Epoch 600.0000 - ELBO -1927.4309
Epoch 700.0000 - ELBO -1644.1866
Epoch 800.0000 - ELBO -1431.6392
Epoch 900.0000 - ELBO -1265.5623
Epoch 1000.0000 - ELBO -1129.9637
Epoch 1100.0000 - ELBO -1015.1942
Epoch 1200.0000 - ELBO -915.5876
Epoch 1300.0000 - ELBO -828.4218
Epoch 1400.0000 - ELBO -750.9626
Epoch 1500.0000 - ELBO -681.5993
Epoch 1600.0000 - ELBO -619.3793
Epoch 1700.0000 - ELBO -563.5426
Epoch 1800.0000 - ELBO -513.1965
Epoch 1900.0000 - ELBO -467.7619
Epoch 2000.0000 - ELBO -426.7486
Epoch 2100.0000 - ELBO -389.8808
Epoch 2200.0000 - ELBO -356.8288
Epoch 2300.0000 - ELBO -327.5393
Epoch 2400.0000 - ELBO -301.5681
Epoch 2500.0000 - ELBO -278.4961
Epoch 2600.0000 - ELBO -258.0202
Epoch 2700.0000 - ELBO -239.8146
Epoch 2800.0000 - ELBO -223.6868
Epoch 2900.0000 - ELBO -209.3632
Epoch 3000.0000 - ELBO -196.6585
Epoch 31

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -7399.1616
Epoch 200.0000 - ELBO -5407.7930
Epoch 300.0000 - ELBO -4053.7974
Epoch 400.0000 - ELBO -3149.5994
Epoch 500.0000 - ELBO -2556.6033
Epoch 600.0000 - ELBO -2147.0989
Epoch 700.0000 - ELBO -1841.1697
Epoch 800.0000 - ELBO -1605.5365
Epoch 900.0000 - ELBO -1418.6976
Epoch 1000.0000 - ELBO -1264.1383
Epoch 1100.0000 - ELBO -1133.2194
Epoch 1200.0000 - ELBO -1022.0690
Epoch 1300.0000 - ELBO -925.7477
Epoch 1400.0000 - ELBO -840.5204
Epoch 1500.0000 - ELBO -765.2193
Epoch 1600.0000 - ELBO -698.5584
Epoch 1700.0000 - ELBO -638.9926
Epoch 1800.0000 - ELBO -585.2531
Epoch 1900.0000 - ELBO -536.5612
Epoch 2000.0000 - ELBO -492.4323
Epoch 2100.0000 - ELBO -452.3413
Epoch 2200.0000 - ELBO -415.9761
Epoch 2300.0000 - ELBO -383.0672
Epoch 2400.0000 - ELBO -353.3684
Epoch 2500.0000 - ELBO -326.4597
Epoch 2600.0000 - ELBO -302.0988
Epoch 2700.0000 - ELBO -280.0763
Epoch 2800.0000 - ELBO -260.3958
Epoch 2900.0000 - ELBO -242.8314
Epoch 3000.0000 - ELBO -227.1846
Epoch 3

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6662.2153
Epoch 200.0000 - ELBO -4838.2021
Epoch 300.0000 - ELBO -3615.3762
Epoch 400.0000 - ELBO -2792.8042
Epoch 500.0000 - ELBO -2251.5095
Epoch 600.0000 - ELBO -1881.2513
Epoch 700.0000 - ELBO -1608.9147
Epoch 800.0000 - ELBO -1400.2625
Epoch 900.0000 - ELBO -1233.1030
Epoch 1000.0000 - ELBO -1095.7241
Epoch 1100.0000 - ELBO -980.5211
Epoch 1200.0000 - ELBO -881.9391
Epoch 1300.0000 - ELBO -795.6578
Epoch 1400.0000 - ELBO -719.4138
Epoch 1500.0000 - ELBO -651.4787
Epoch 1600.0000 - ELBO -591.3746
Epoch 1700.0000 - ELBO -537.9699
Epoch 1800.0000 - ELBO -490.1622
Epoch 1900.0000 - ELBO -447.1740
Epoch 2000.0000 - ELBO -408.5529
Epoch 2100.0000 - ELBO -373.7502
Epoch 2200.0000 - ELBO -342.3760
Epoch 2300.0000 - ELBO -314.0552
Epoch 2400.0000 - ELBO -288.8619
Epoch 2500.0000 - ELBO -266.2418
Epoch 2600.0000 - ELBO -245.8802
Epoch 2700.0000 - ELBO -227.6912
Epoch 2800.0000 - ELBO -211.5757
Epoch 2900.0000 - ELBO -197.3605
Epoch 3000.0000 - ELBO -184.8964
Epoch 310

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -7498.3696
Epoch 200.0000 - ELBO -5552.3486
Epoch 300.0000 - ELBO -4135.6304
Epoch 400.0000 - ELBO -3135.5540
Epoch 500.0000 - ELBO -2492.4971
Epoch 600.0000 - ELBO -2083.6389
Epoch 700.0000 - ELBO -1794.8134
Epoch 800.0000 - ELBO -1572.1752
Epoch 900.0000 - ELBO -1390.1051
Epoch 1000.0000 - ELBO -1236.6891
Epoch 1100.0000 - ELBO -1105.5736
Epoch 1200.0000 - ELBO -992.2320
Epoch 1300.0000 - ELBO -893.3493
Epoch 1400.0000 - ELBO -805.9966
Epoch 1500.0000 - ELBO -728.4669
Epoch 1600.0000 - ELBO -659.3691
Epoch 1700.0000 - ELBO -598.2837
Epoch 1800.0000 - ELBO -544.4946
Epoch 1900.0000 - ELBO -496.8925
Epoch 2000.0000 - ELBO -454.7101
Epoch 2100.0000 - ELBO -416.9176
Epoch 2200.0000 - ELBO -383.0963
Epoch 2300.0000 - ELBO -352.6917
Epoch 2400.0000 - ELBO -325.3658
Epoch 2500.0000 - ELBO -300.9970
Epoch 2600.0000 - ELBO -279.0805
Epoch 2700.0000 - ELBO -259.3548
Epoch 2800.0000 - ELBO -241.5979
Epoch 2900.0000 - ELBO -225.6084
Epoch 3000.0000 - ELBO -211.1977
Epoch 31

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6898.7036
Epoch 200.0000 - ELBO -4944.6499
Epoch 300.0000 - ELBO -3704.4353
Epoch 400.0000 - ELBO -2892.6167
Epoch 500.0000 - ELBO -2352.2507
Epoch 600.0000 - ELBO -1985.0679
Epoch 700.0000 - ELBO -1718.5771
Epoch 800.0000 - ELBO -1510.5692
Epoch 900.0000 - ELBO -1341.2028
Epoch 1000.0000 - ELBO -1200.4636
Epoch 1100.0000 - ELBO -1081.9017
Epoch 1200.0000 - ELBO -979.9192
Epoch 1300.0000 - ELBO -891.1597
Epoch 1400.0000 - ELBO -812.8389
Epoch 1500.0000 - ELBO -741.9752
Epoch 1600.0000 - ELBO -677.8644
Epoch 1700.0000 - ELBO -619.6356
Epoch 1800.0000 - ELBO -566.4590
Epoch 1900.0000 - ELBO -518.2698
Epoch 2000.0000 - ELBO -474.7250
Epoch 2100.0000 - ELBO -435.2733
Epoch 2200.0000 - ELBO -399.7174
Epoch 2300.0000 - ELBO -367.7191
Epoch 2400.0000 - ELBO -338.9091
Epoch 2500.0000 - ELBO -313.0100
Epoch 2600.0000 - ELBO -289.8465
Epoch 2700.0000 - ELBO -269.1107
Epoch 2800.0000 - ELBO -250.6430
Epoch 2900.0000 - ELBO -234.2751
Epoch 3000.0000 - ELBO -219.6945
Epoch 31

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6367.3979
Epoch 200.0000 - ELBO -4686.0835
Epoch 300.0000 - ELBO -3534.8213
Epoch 400.0000 - ELBO -2751.6501
Epoch 500.0000 - ELBO -2211.0500
Epoch 600.0000 - ELBO -1827.8765
Epoch 700.0000 - ELBO -1554.7295
Epoch 800.0000 - ELBO -1353.3851
Epoch 900.0000 - ELBO -1195.8794
Epoch 1000.0000 - ELBO -1065.7960
Epoch 1100.0000 - ELBO -955.3390
Epoch 1200.0000 - ELBO -859.6437
Epoch 1300.0000 - ELBO -775.6367
Epoch 1400.0000 - ELBO -701.1993
Epoch 1500.0000 - ELBO -634.8284
Epoch 1600.0000 - ELBO -575.5328
Epoch 1700.0000 - ELBO -522.2466
Epoch 1800.0000 - ELBO -474.1987
Epoch 1900.0000 - ELBO -430.8114
Epoch 2000.0000 - ELBO -391.6126
Epoch 2100.0000 - ELBO -356.3562
Epoch 2200.0000 - ELBO -324.8426
Epoch 2300.0000 - ELBO -296.8098
Epoch 2400.0000 - ELBO -271.9795
Epoch 2500.0000 - ELBO -250.0175
Epoch 2600.0000 - ELBO -230.6990
Epoch 2700.0000 - ELBO -213.6818
Epoch 2800.0000 - ELBO -198.8434
Epoch 2900.0000 - ELBO -185.8034
Epoch 3000.0000 - ELBO -174.4099
Epoch 310

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6398.5205
Epoch 200.0000 - ELBO -4677.0151
Epoch 300.0000 - ELBO -3595.9385
Epoch 400.0000 - ELBO -2853.0664
Epoch 500.0000 - ELBO -2327.9133
Epoch 600.0000 - ELBO -1955.3289
Epoch 700.0000 - ELBO -1677.9646
Epoch 800.0000 - ELBO -1459.0172
Epoch 900.0000 - ELBO -1278.8019
Epoch 1000.0000 - ELBO -1127.9017
Epoch 1100.0000 - ELBO -1001.1575
Epoch 1200.0000 - ELBO -894.9426
Epoch 1300.0000 - ELBO -804.8260
Epoch 1400.0000 - ELBO -726.6878
Epoch 1500.0000 - ELBO -657.8493
Epoch 1600.0000 - ELBO -596.6508
Epoch 1700.0000 - ELBO -541.7601
Epoch 1800.0000 - ELBO -492.7662
Epoch 1900.0000 - ELBO -449.0880
Epoch 2000.0000 - ELBO -409.7791
Epoch 2100.0000 - ELBO -374.3587
Epoch 2200.0000 - ELBO -342.3265
Epoch 2300.0000 - ELBO -313.6196
Epoch 2400.0000 - ELBO -287.9024
Epoch 2500.0000 - ELBO -264.7611
Epoch 2600.0000 - ELBO -243.8724
Epoch 2700.0000 - ELBO -225.0866
Epoch 2800.0000 - ELBO -208.5994
Epoch 2900.0000 - ELBO -194.0831
Epoch 3000.0000 - ELBO -181.3078
Epoch 31

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6825.1216
Epoch 200.0000 - ELBO -5081.0381
Epoch 300.0000 - ELBO -3880.6641
Epoch 400.0000 - ELBO -3039.8674
Epoch 500.0000 - ELBO -2470.9436
Epoch 600.0000 - ELBO -2072.5317
Epoch 700.0000 - ELBO -1770.1094
Epoch 800.0000 - ELBO -1535.9712
Epoch 900.0000 - ELBO -1350.3069
Epoch 1000.0000 - ELBO -1198.9531
Epoch 1100.0000 - ELBO -1073.3336
Epoch 1200.0000 - ELBO -966.4437
Epoch 1300.0000 - ELBO -872.9742
Epoch 1400.0000 - ELBO -789.7737
Epoch 1500.0000 - ELBO -715.1628
Epoch 1600.0000 - ELBO -648.3889
Epoch 1700.0000 - ELBO -588.0304
Epoch 1800.0000 - ELBO -533.2207
Epoch 1900.0000 - ELBO -483.7900
Epoch 2000.0000 - ELBO -439.3954
Epoch 2100.0000 - ELBO -399.6743
Epoch 2200.0000 - ELBO -364.3361
Epoch 2300.0000 - ELBO -332.8531
Epoch 2400.0000 - ELBO -304.7460
Epoch 2500.0000 - ELBO -279.8034
Epoch 2600.0000 - ELBO -257.6383
Epoch 2700.0000 - ELBO -237.8710
Epoch 2800.0000 - ELBO -220.3021
Epoch 2900.0000 - ELBO -204.6910
Epoch 3000.0000 - ELBO -190.8374
Epoch 31

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6851.0649
Epoch 200.0000 - ELBO -5091.7427
Epoch 300.0000 - ELBO -3865.4978
Epoch 400.0000 - ELBO -2994.3633
Epoch 500.0000 - ELBO -2398.1360
Epoch 600.0000 - ELBO -1990.5751
Epoch 700.0000 - ELBO -1697.8730
Epoch 800.0000 - ELBO -1471.4364
Epoch 900.0000 - ELBO -1290.5317
Epoch 1000.0000 - ELBO -1142.6899
Epoch 1100.0000 - ELBO -1018.7374
Epoch 1200.0000 - ELBO -911.1895
Epoch 1300.0000 - ELBO -816.5954
Epoch 1400.0000 - ELBO -732.8184
Epoch 1500.0000 - ELBO -659.1770
Epoch 1600.0000 - ELBO -594.5541
Epoch 1700.0000 - ELBO -537.6733
Epoch 1800.0000 - ELBO -487.4312
Epoch 1900.0000 - ELBO -442.5959
Epoch 2000.0000 - ELBO -402.3523
Epoch 2100.0000 - ELBO -366.3721
Epoch 2200.0000 - ELBO -334.1830
Epoch 2300.0000 - ELBO -305.5301
Epoch 2400.0000 - ELBO -280.1474
Epoch 2500.0000 - ELBO -257.7422
Epoch 2600.0000 - ELBO -237.8235
Epoch 2700.0000 - ELBO -220.0956
Epoch 2800.0000 - ELBO -204.4297
Epoch 2900.0000 - ELBO -190.6846
Epoch 3000.0000 - ELBO -178.6064
Epoch 31

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -6371.5972
Epoch 200.0000 - ELBO -4639.0088
Epoch 300.0000 - ELBO -3505.8196
Epoch 400.0000 - ELBO -2743.5625
Epoch 500.0000 - ELBO -2243.8125
Epoch 600.0000 - ELBO -1903.1404
Epoch 700.0000 - ELBO -1646.6803
Epoch 800.0000 - ELBO -1443.4944
Epoch 900.0000 - ELBO -1276.5747
Epoch 1000.0000 - ELBO -1136.8507
Epoch 1100.0000 - ELBO -1016.5916
Epoch 1200.0000 - ELBO -912.5085
Epoch 1300.0000 - ELBO -822.5716
Epoch 1400.0000 - ELBO -743.7116
Epoch 1500.0000 - ELBO -673.8224
Epoch 1600.0000 - ELBO -611.3022
Epoch 1700.0000 - ELBO -554.9601
Epoch 1800.0000 - ELBO -504.0002
Epoch 1900.0000 - ELBO -458.2492
Epoch 2000.0000 - ELBO -417.1133
Epoch 2100.0000 - ELBO -380.0349
Epoch 2200.0000 - ELBO -346.6129
Epoch 2300.0000 - ELBO -316.7715
Epoch 2400.0000 - ELBO -290.3537
Epoch 2500.0000 - ELBO -267.0606
Epoch 2600.0000 - ELBO -246.4728
Epoch 2700.0000 - ELBO -228.1867
Epoch 2800.0000 - ELBO -212.3129
Epoch 2900.0000 - ELBO -198.4561
Epoch 3000.0000 - ELBO -186.2916
Epoch 31

In [13]:
# Persist the local-optimisation outputs into this experiment's directory.
for suffix, obj in (
  ("/elbos.pkl", elbos),
  ("/iwaes.pkl", iwaes),
  ("/local_params.pkl", local_params),
):
  utils.save_params(save_dir + suffix, obj)