<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 1137, done.[K
remote: Counting objects: 100% (1137/1137), done.[K
remote: Compressing objects: 100% (798/798), done.[K
remote: Total 1137 (delta 582), reused 839 (delta 325), pack-reused 0[K
Receiving objects: 100% (1137/1137), 326.79 MiB | 17.47 MiB/s, done.
Resolving deltas: 100% (582/582), done.
Checking out files: 100% (238/238), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_batches, get_dataset
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [4]:
# Opt-in TPU backend. The setup below only runs under Colab
# (when `google.colab` has been imported) and when `use_tpu` is switched on.
use_tpu = False

if "google.colab" in sys.modules and use_tpu:
  # Imported lazily so non-Colab / non-TPU runs never load the Colab TPU tooling.
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [34]:
# -- Vary across experiment
encoder_size = (200, 200)
decoder_size = (200, 200)
trained_has_flow = True
local_has_flow = False
kl_annealing = True
dataset_name = "kmnist" # mnist, fashion, kmnist
# ---


def fmt_size(sizes):
  """Abbreviate layer sizes to the first digit of each, e.g. (200, 200) -> "22".

  NOTE(review): this is lossy -- (200, 200) and (256, 256) both become "22",
  so such runs would share a results directory. The format is kept because
  existing experiment directories, and later cells that derive the
  trained-model directory from `name`, depend on it.
  """
  return "".join([str(layer)[0] for layer in sizes])


# Name of this experiment. It becomes the results directory, so every setting
# above that distinguishes runs must be encoded here (important to change for
# saving results).
name = "_".join((
  dataset_name,
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  f"e{fmt_size(encoder_size)}d{fmt_size(decoder_size)}",
  "local-flow" if local_has_flow else "local-ffg",
))
print(name)

kmnist_flow_anneal_e22d22_local-ffg


In [36]:
# Persist results to Google Drive when running on Colab; otherwise save locally.
mount_google_drive = True

if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = f"./experiments/{name}"
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = f"/content/drive/My Drive/ATML/{name}"

# Create the experiment directory (and any missing parents) if needed.
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [37]:
# Architecture used for local optimisation. `has_flow` comes from
# `local_has_flow` (the family being locally optimised), not from
# `trained_has_flow`.
hps = HyperParams(
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
  has_flow=local_has_flow,
)
local_hps = LocalHyperParams()  # library defaults; printed below for the record

for _cfg in (hps, local_hps):
  print(_cfg)

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [38]:
# Load the dataset selected by `dataset_name` (mnist / fashion / kmnist).
# NOTE: the variable is called `mnist` regardless of which dataset is loaded;
# later cells index it as mnist["train_x"].
mnist = get_dataset(dataset_name)

In [39]:
# Local optimisation is expensive, so `smaller_data=True` restricts it to a
# subset of the training set (1000 examples per the original note).
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [40]:
# The trained base model lives in this experiment's directory minus the
# trailing "_local-ffg" / "_local-flow" component of `name`.
# Split explicitly and validate, rather than slicing at rfind("_"): rfind returns
# -1 when there is no underscore, which would silently drop the last character.
base_dir, _, local_suffix = save_dir.rpartition("_")
if not local_suffix.startswith("local-"):
  raise ValueError(
    f"Expected save_dir to end with '_local-ffg' or '_local-flow', got {save_dir!r}"
  )
trained_model_dir = base_dir
file_name = trained_model_dir + "/params.pkl"
if not pathlib.Path(file_name).is_file():
  raise FileNotFoundError(
    f"No trained parameters at {file_name!r}; "
    "train the base model for this configuration first."
  )
# NOTE(review): load_params presumably unpickles the file -- only load
# parameter files you produced yourself.
trained_params = utils.load_params(file_name)

In [41]:
# Build a VAE from `hps` and run local optimisation over `train_batches`,
# given the trained parameters loaded above.
# Returns three results, unpacked as ELBOs, IWAE estimates and the optimised
# local parameters -- TODO confirm their exact structure in local_opt.
# NOTE: expensive -- each batch gets its own progress bar of up to
# local_hps.max_epochs iterations (early stopping presumably governed by
# local_hps.patience / es_epsilon; verify in local_opt).
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(local_hps, model, train_batches, trained_params)

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -527.5261
Epoch 200.0000 - ELBO -438.5320
Epoch 300.0000 - ELBO -386.4560
Epoch 400.0000 - ELBO -350.4790
Epoch 500.0000 - ELBO -323.4475
Epoch 600.0000 - ELBO -302.0044
Epoch 700.0000 - ELBO -284.6550
Epoch 800.0000 - ELBO -270.2166
Epoch 900.0000 - ELBO -258.1154
Epoch 1000.0000 - ELBO -247.6982
Epoch 1100.0000 - ELBO -238.7284
Epoch 1200.0000 - ELBO -230.9707
Epoch 1300.0000 - ELBO -224.0934
Epoch 1400.0000 - ELBO -218.0130
Epoch 1500.0000 - ELBO -212.6104
Epoch 1600.0000 - ELBO -207.7287
Epoch 1700.0000 - ELBO -203.4167
Epoch 1800.0000 - ELBO -199.5110
Epoch 1900.0000 - ELBO -195.9565
Epoch 2000.0000 - ELBO -192.7091
Epoch 2100.0000 - ELBO -189.8771
Epoch 2200.0000 - ELBO -187.2141
Epoch 2300.0000 - ELBO -184.8753
Epoch 2400.0000 - ELBO -182.6529
Epoch 2500.0000 - ELBO -180.6500
Epoch 2600.0000 - ELBO -178.8340
Epoch 2700.0000 - ELBO -177.1672
Epoch 2800.0000 - ELBO -175.6674
Epoch 2900.0000 - ELBO -174.2391
Epoch 3000.0000 - ELBO -172.9867
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -481.8532
Epoch 200.0000 - ELBO -399.5970
Epoch 300.0000 - ELBO -351.4345
Epoch 400.0000 - ELBO -319.2633
Epoch 500.0000 - ELBO -295.3203
Epoch 600.0000 - ELBO -276.6627
Epoch 700.0000 - ELBO -261.5229
Epoch 800.0000 - ELBO -248.8519
Epoch 900.0000 - ELBO -238.2399
Epoch 1000.0000 - ELBO -228.9766
Epoch 1100.0000 - ELBO -221.0243
Epoch 1200.0000 - ELBO -214.0875
Epoch 1300.0000 - ELBO -207.9143
Epoch 1400.0000 - ELBO -202.5055
Epoch 1500.0000 - ELBO -197.5537
Epoch 1600.0000 - ELBO -193.1322
Epoch 1700.0000 - ELBO -189.2480
Epoch 1800.0000 - ELBO -185.6967
Epoch 1900.0000 - ELBO -182.4528
Epoch 2000.0000 - ELBO -179.5902
Epoch 2100.0000 - ELBO -177.0683
Epoch 2200.0000 - ELBO -174.6888
Epoch 2300.0000 - ELBO -172.6070
Epoch 2400.0000 - ELBO -170.6676
Epoch 2500.0000 - ELBO -168.9130
Epoch 2600.0000 - ELBO -167.3270
Epoch 2700.0000 - ELBO -165.8671
Epoch 2800.0000 - ELBO -164.5071
Epoch 2900.0000 - ELBO -163.2788
Epoch 3000.0000 - ELBO -162.1682
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -536.7238
Epoch 200.0000 - ELBO -444.7072
Epoch 300.0000 - ELBO -391.1285
Epoch 400.0000 - ELBO -354.7888
Epoch 500.0000 - ELBO -327.7115
Epoch 600.0000 - ELBO -306.3120
Epoch 700.0000 - ELBO -288.9781
Epoch 800.0000 - ELBO -274.5559
Epoch 900.0000 - ELBO -262.2700
Epoch 1000.0000 - ELBO -251.7450
Epoch 1100.0000 - ELBO -242.4875
Epoch 1200.0000 - ELBO -234.4438
Epoch 1300.0000 - ELBO -227.3613
Epoch 1400.0000 - ELBO -221.0082
Epoch 1500.0000 - ELBO -215.4029
Epoch 1600.0000 - ELBO -210.3108
Epoch 1700.0000 - ELBO -205.8832
Epoch 1800.0000 - ELBO -201.7874
Epoch 1900.0000 - ELBO -198.0324
Epoch 2000.0000 - ELBO -194.7022
Epoch 2100.0000 - ELBO -191.7210
Epoch 2200.0000 - ELBO -189.0007
Epoch 2300.0000 - ELBO -186.5281
Epoch 2400.0000 - ELBO -184.2578
Epoch 2500.0000 - ELBO -182.1656
Epoch 2600.0000 - ELBO -180.2870
Epoch 2700.0000 - ELBO -178.5834
Epoch 2800.0000 - ELBO -176.9854
Epoch 2900.0000 - ELBO -175.5396
Epoch 3000.0000 - ELBO -174.2249
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -496.1851
Epoch 200.0000 - ELBO -407.8427
Epoch 300.0000 - ELBO -358.9245
Epoch 400.0000 - ELBO -326.4176
Epoch 500.0000 - ELBO -302.3512
Epoch 600.0000 - ELBO -283.4489
Epoch 700.0000 - ELBO -268.0498
Epoch 800.0000 - ELBO -255.1496
Epoch 900.0000 - ELBO -244.2137
Epoch 1000.0000 - ELBO -234.7809
Epoch 1100.0000 - ELBO -226.5928
Epoch 1200.0000 - ELBO -219.4602
Epoch 1300.0000 - ELBO -213.1064
Epoch 1400.0000 - ELBO -207.5267
Epoch 1500.0000 - ELBO -202.4471
Epoch 1600.0000 - ELBO -197.9926
Epoch 1700.0000 - ELBO -193.9815
Epoch 1800.0000 - ELBO -190.3591
Epoch 1900.0000 - ELBO -187.1025
Epoch 2000.0000 - ELBO -184.1086
Epoch 2100.0000 - ELBO -181.5131
Epoch 2200.0000 - ELBO -179.0578
Epoch 2300.0000 - ELBO -176.8918
Epoch 2400.0000 - ELBO -174.8958
Epoch 2500.0000 - ELBO -173.0599
Epoch 2600.0000 - ELBO -171.3934
Epoch 2700.0000 - ELBO -169.8843
Epoch 2800.0000 - ELBO -168.4709
Epoch 2900.0000 - ELBO -167.1745
Epoch 3000.0000 - ELBO -166.0122
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -555.8429
Epoch 200.0000 - ELBO -458.5266
Epoch 300.0000 - ELBO -402.4526
Epoch 400.0000 - ELBO -364.9962
Epoch 500.0000 - ELBO -337.2940
Epoch 600.0000 - ELBO -315.5012
Epoch 700.0000 - ELBO -297.7897
Epoch 800.0000 - ELBO -282.8614
Epoch 900.0000 - ELBO -270.2162
Epoch 1000.0000 - ELBO -259.2239
Epoch 1100.0000 - ELBO -249.5774
Epoch 1200.0000 - ELBO -241.2121
Epoch 1300.0000 - ELBO -233.7678
Epoch 1400.0000 - ELBO -227.0538
Epoch 1500.0000 - ELBO -221.0471
Epoch 1600.0000 - ELBO -215.6080
Epoch 1700.0000 - ELBO -210.8698
Epoch 1800.0000 - ELBO -206.4611
Epoch 1900.0000 - ELBO -202.5211
Epoch 2000.0000 - ELBO -198.9026
Epoch 2100.0000 - ELBO -195.6959
Epoch 2200.0000 - ELBO -192.7486
Epoch 2300.0000 - ELBO -190.0922
Epoch 2400.0000 - ELBO -187.6499
Epoch 2500.0000 - ELBO -185.4254
Epoch 2600.0000 - ELBO -183.3804
Epoch 2700.0000 - ELBO -181.5291
Epoch 2800.0000 - ELBO -179.8655
Epoch 2900.0000 - ELBO -178.3261
Epoch 3000.0000 - ELBO -176.8990
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -558.0010
Epoch 200.0000 - ELBO -461.0184
Epoch 300.0000 - ELBO -405.4485
Epoch 400.0000 - ELBO -367.6380
Epoch 500.0000 - ELBO -339.0504
Epoch 600.0000 - ELBO -316.4496
Epoch 700.0000 - ELBO -297.8887
Epoch 800.0000 - ELBO -282.4521
Epoch 900.0000 - ELBO -269.4491
Epoch 1000.0000 - ELBO -258.1685
Epoch 1100.0000 - ELBO -248.3810
Epoch 1200.0000 - ELBO -239.8713
Epoch 1300.0000 - ELBO -232.3614
Epoch 1400.0000 - ELBO -225.6841
Epoch 1500.0000 - ELBO -219.6598
Epoch 1600.0000 - ELBO -214.2784
Epoch 1700.0000 - ELBO -209.4757
Epoch 1800.0000 - ELBO -205.1709
Epoch 1900.0000 - ELBO -201.2094
Epoch 2000.0000 - ELBO -197.5944
Epoch 2100.0000 - ELBO -194.4958
Epoch 2200.0000 - ELBO -191.5385
Epoch 2300.0000 - ELBO -188.9249
Epoch 2400.0000 - ELBO -186.4992
Epoch 2500.0000 - ELBO -184.2589
Epoch 2600.0000 - ELBO -182.2479
Epoch 2700.0000 - ELBO -180.4267
Epoch 2800.0000 - ELBO -178.7509
Epoch 2900.0000 - ELBO -177.1990
Epoch 3000.0000 - ELBO -175.7663
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -490.8085
Epoch 200.0000 - ELBO -404.0134
Epoch 300.0000 - ELBO -354.8467
Epoch 400.0000 - ELBO -321.8548
Epoch 500.0000 - ELBO -297.2668
Epoch 600.0000 - ELBO -277.9870
Epoch 700.0000 - ELBO -262.2566
Epoch 800.0000 - ELBO -249.1930
Epoch 900.0000 - ELBO -238.1828
Epoch 1000.0000 - ELBO -228.6064
Epoch 1100.0000 - ELBO -220.4830
Epoch 1200.0000 - ELBO -213.3366
Epoch 1300.0000 - ELBO -207.0464
Epoch 1400.0000 - ELBO -201.4292
Epoch 1500.0000 - ELBO -196.3876
Epoch 1600.0000 - ELBO -191.8587
Epoch 1700.0000 - ELBO -187.9277
Epoch 1800.0000 - ELBO -184.3406
Epoch 1900.0000 - ELBO -181.0177
Epoch 2000.0000 - ELBO -178.0672
Epoch 2100.0000 - ELBO -175.4216
Epoch 2200.0000 - ELBO -173.0071
Epoch 2300.0000 - ELBO -170.8018
Epoch 2400.0000 - ELBO -168.7837
Epoch 2500.0000 - ELBO -166.9390
Epoch 2600.0000 - ELBO -165.3046
Epoch 2700.0000 - ELBO -163.7552
Epoch 2800.0000 - ELBO -162.3832
Epoch 2900.0000 - ELBO -161.1072
Epoch 3000.0000 - ELBO -159.9178
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -560.1804
Epoch 200.0000 - ELBO -457.0871
Epoch 300.0000 - ELBO -399.5212
Epoch 400.0000 - ELBO -361.5452
Epoch 500.0000 - ELBO -333.6320
Epoch 600.0000 - ELBO -311.8161
Epoch 700.0000 - ELBO -293.9215
Epoch 800.0000 - ELBO -278.9782
Epoch 900.0000 - ELBO -266.2934
Epoch 1000.0000 - ELBO -255.2335
Epoch 1100.0000 - ELBO -245.6616
Epoch 1200.0000 - ELBO -237.3189
Epoch 1300.0000 - ELBO -229.8732
Epoch 1400.0000 - ELBO -223.2391
Epoch 1500.0000 - ELBO -217.2821
Epoch 1600.0000 - ELBO -211.9180
Epoch 1700.0000 - ELBO -207.1993
Epoch 1800.0000 - ELBO -202.8877
Epoch 1900.0000 - ELBO -198.9595
Epoch 2000.0000 - ELBO -195.4121
Epoch 2100.0000 - ELBO -192.2618
Epoch 2200.0000 - ELBO -189.3395
Epoch 2300.0000 - ELBO -186.7414
Epoch 2400.0000 - ELBO -184.2954
Epoch 2500.0000 - ELBO -182.0993
Epoch 2600.0000 - ELBO -180.1078
Epoch 2700.0000 - ELBO -178.2533
Epoch 2800.0000 - ELBO -176.6253
Epoch 2900.0000 - ELBO -175.0761
Epoch 3000.0000 - ELBO -173.6159
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -519.2791
Epoch 200.0000 - ELBO -431.3160
Epoch 300.0000 - ELBO -379.4533
Epoch 400.0000 - ELBO -343.7879
Epoch 500.0000 - ELBO -317.2527
Epoch 600.0000 - ELBO -296.5571
Epoch 700.0000 - ELBO -279.7106
Epoch 800.0000 - ELBO -265.7782
Epoch 900.0000 - ELBO -254.0402
Epoch 1000.0000 - ELBO -243.8837
Epoch 1100.0000 - ELBO -235.0435
Epoch 1200.0000 - ELBO -227.3591
Epoch 1300.0000 - ELBO -220.6163
Epoch 1400.0000 - ELBO -214.4703
Epoch 1500.0000 - ELBO -209.1026
Epoch 1600.0000 - ELBO -204.1957
Epoch 1700.0000 - ELBO -199.8526
Epoch 1800.0000 - ELBO -195.9200
Epoch 1900.0000 - ELBO -192.2485
Epoch 2000.0000 - ELBO -189.0067
Epoch 2100.0000 - ELBO -186.1186
Epoch 2200.0000 - ELBO -183.4259
Epoch 2300.0000 - ELBO -181.0133
Epoch 2400.0000 - ELBO -178.8034
Epoch 2500.0000 - ELBO -176.7027
Epoch 2600.0000 - ELBO -174.8617
Epoch 2700.0000 - ELBO -173.1394
Epoch 2800.0000 - ELBO -171.5837
Epoch 2900.0000 - ELBO -170.1224
Epoch 3000.0000 - ELBO -168.8020
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -531.6653
Epoch 200.0000 - ELBO -438.0012
Epoch 300.0000 - ELBO -383.9403
Epoch 400.0000 - ELBO -347.9203
Epoch 500.0000 - ELBO -321.0793
Epoch 600.0000 - ELBO -300.3082
Epoch 700.0000 - ELBO -283.3850
Epoch 800.0000 - ELBO -269.1936
Epoch 900.0000 - ELBO -257.1126
Epoch 1000.0000 - ELBO -246.7480
Epoch 1100.0000 - ELBO -237.7432
Epoch 1200.0000 - ELBO -229.8297
Epoch 1300.0000 - ELBO -222.8743
Epoch 1400.0000 - ELBO -216.6259
Epoch 1500.0000 - ELBO -211.1386
Epoch 1600.0000 - ELBO -206.1418
Epoch 1700.0000 - ELBO -201.7076
Epoch 1800.0000 - ELBO -197.6730
Epoch 1900.0000 - ELBO -194.0426
Epoch 2000.0000 - ELBO -190.7053
Epoch 2100.0000 - ELBO -187.7771
Epoch 2200.0000 - ELBO -185.0153
Epoch 2300.0000 - ELBO -182.6082
Epoch 2400.0000 - ELBO -180.3400
Epoch 2500.0000 - ELBO -178.2451
Epoch 2600.0000 - ELBO -176.3928
Epoch 2700.0000 - ELBO -174.6496
Epoch 2800.0000 - ELBO -173.1164
Epoch 2900.0000 - ELBO -171.6415
Epoch 3000.0000 - ELBO -170.3220
Epoch 3100.0000 - E

In [42]:
# Persist the local-optimisation results in this experiment's directory,
# one file per result: elbos.pkl, iwaes.pkl, local_params.pkl.
for _stem, _value in (
  ("elbos", elbos),
  ("iwaes", iwaes),
  ("local_params", local_params),
):
  utils.save_params(save_dir + "/" + _stem + ".pkl", _value)