In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 868, done.[K
remote: Counting objects: 100% (868/868), done.[K
remote: Compressing objects: 100% (608/608), done.[K
remote: Total 868 (delta 454), reused 619 (delta 246), pack-reused 0[K
Receiving objects: 100% (868/868), 148.37 MiB | 22.99 MiB/s, done.
Resolving deltas: 100% (454/454), done.
Checking out files: 100% (113/113), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_mnist, get_batches, get_fashion_mnist
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [3]:
use_tpu = True  # set False to skip the Colab TPU setup below

# The TPU setup helper lives in jax.tools.colab_tpu, so it is only attempted
# when running inside Colab.
running_on_colab = "google.colab" in sys.modules
if running_on_colab and use_tpu:
  import jax.tools.colab_tpu
  try:
    jax.tools.colab_tpu.setup_tpu()
  except KeyError:
    # setup_tpu reports a missing TPU via KeyError; warn and continue without one.
    print("Warning: No TPU access available.")



In [4]:
# -- Vary across experiment
is_larger = False
trained_has_flow = True
local_has_flow = True
kl_annealing = False
use_fashion = True
# ---

## Name of this experiment (important to change for saving results)
# Each flag selects one label, and the labels are joined with "_" in a fixed
# order. The labels themselves use "-" rather than "_", so the trailing
# local-* component can be cut off later to locate the trained model.
name_parts = (
  "fashion" if use_fashion else "mnist",
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  "larger" if is_larger else "smaller",
  "local-flow" if local_has_flow else "local-ffg",
)
name = "_".join(name_parts)
print(name)

fashion_flow_regular_smaller_local-flow


In [5]:
mount_google_drive = True

# On Colab (with mounting enabled) results go under Google Drive; anywhere else
# they go under a local ./experiments folder. Either way the folder is named
# after the experiment.
use_drive = mount_google_drive and "google.colab" in sys.modules
if use_drive:
  from google.colab import drive
  drive.mount("/content/drive")
  results_root = "/content/drive/My Drive/ATML/"
else:
  results_root = "./experiments/"
save_dir = results_root + name

pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

Mounted at /content/drive


In [6]:
# The "larger" architecture uses three 500-unit hidden layers; the default uses
# two 200-unit layers. Encoder and decoder share the same widths.
if is_larger:
  hidden_size = (500,) * 3
else:
  hidden_size = (200,) * 2

# has_flow follows the *local* variational family being optimised in this run.
hps = HyperParams(
  has_flow=local_has_flow,
  encoder_hidden=hidden_size,
  decoder_hidden=hidden_size,
)
local_hps = LocalHyperParams()

for config in (hps, local_hps):
  print(config)

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=True, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [7]:
# The variable is called `mnist` even for Fashion-MNIST because later cells read it by that name.
if use_fashion:
  mnist = get_fashion_mnist()
else:
  mnist = get_mnist()

In [8]:
# Local optimisation is expensive, so `smaller_data` asks get_batches for a
# reduced subset. Per the original note this is 1000 images; the progress bar
# shows 10 batches of 100.
smaller_data = True
batch_size = 100
train_images = mnist["train_x"]
train_batches = get_batches(train_images, batch_size, smaller_data)

In [10]:
# This local-optimisation run reuses the parameters saved by the matching
# training run. That run's directory is this run's save_dir without the
# "_local-ffg" / "_local-flow" suffix.
local_suffix = "_" + ("local-flow" if local_has_flow else "local-ffg")
if not save_dir.endswith(local_suffix):
  # A blind `save_dir[:save_dir.rfind("_")]` would silently point at the wrong
  # directory here (or drop the last character when no "_" is present).
  raise ValueError(f"save_dir {save_dir!r} does not end with {local_suffix!r}")
trained_model_dir = save_dir[: -len(local_suffix)]
file_name = trained_model_dir + "/params.pkl"
if not pathlib.Path(file_name).is_file():
  raise FileNotFoundError(
    f"Trained parameters not found at {file_name!r}; "
    "the matching training experiment must be run first."
  )
# NOTE(review): the .pkl extension suggests utils.load_params unpickles this
# file; only load parameter files from a trusted source.
trained_params = utils.load_params(file_name)

In [11]:
# Rebuild the VAE from the hyper-parameters configured above, then run local
# optimisation over the training batches using the loaded trained parameters.
model = vae.VAE(hps)
local_opt_outputs = local_opt(local_hps, model, train_batches, trained_params)
# local_opt returns three results; they are kept separately so they can be saved.
elbos, iwaes, local_params = local_opt_outputs

Optimising Local Flow ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -305.9972
Epoch 200.0000 - ELBO -123.5582
Epoch 300.0000 - ELBO -114.7578
Epoch 400.0000 - ELBO -109.7302
Epoch 500.0000 - ELBO -107.1064
Epoch 600.0000 - ELBO -105.3846
Epoch 700.0000 - ELBO -104.2227
Epoch 800.0000 - ELBO -103.3557
Epoch 900.0000 - ELBO -102.6652
Epoch 1000.0000 - ELBO -102.0697
Epoch 1100.0000 - ELBO -101.6228
Epoch 1200.0000 - ELBO -101.2295
Epoch 1300.0000 - ELBO -100.9615
Epoch 1400.0000 - ELBO -100.6366
Epoch 1500.0000 - ELBO -100.4249
Epoch 1600.0000 - ELBO -100.2113
Epoch 1700.0000 - ELBO -100.0493
Epoch 1800.0000 - ELBO -99.8971
Epoch 1900.0000 - ELBO -99.7403
Epoch 2000.0000 - ELBO -99.6268
Epoch 2100.0000 - ELBO -99.4891
Epoch 2200.0000 - ELBO -99.4165
Epoch 2300.0000 - ELBO -99.3187
Epoch 2400.0000 - ELBO -99.2300
Epoch 2500.0000 - ELBO -99.1557
Epoch 2600.0000 - ELBO -99.0849
Epoch 2700.0000 - ELBO -99.0312
Epoch 2800.0000 - ELBO -98.9575
Epoch 2900.0000 - ELBO -98.9173
Epoch 3000.0000 - ELBO -98.8598
Epoch 3100.0000 - ELBO -98.8089


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -386.8521
Epoch 200.0000 - ELBO -133.3002
Epoch 300.0000 - ELBO -123.5401
Epoch 400.0000 - ELBO -118.3021
Epoch 500.0000 - ELBO -115.3762
Epoch 600.0000 - ELBO -113.5714
Epoch 700.0000 - ELBO -112.3205
Epoch 800.0000 - ELBO -111.4736
Epoch 900.0000 - ELBO -110.7805
Epoch 1000.0000 - ELBO -110.1638
Epoch 1100.0000 - ELBO -109.7661
Epoch 1200.0000 - ELBO -109.4127
Epoch 1300.0000 - ELBO -109.0830
Epoch 1400.0000 - ELBO -108.8341
Epoch 1500.0000 - ELBO -108.6108
Epoch 1600.0000 - ELBO -108.4067
Epoch 1700.0000 - ELBO -108.2471
Epoch 1800.0000 - ELBO -108.0690
Epoch 1900.0000 - ELBO -107.9724
Epoch 2000.0000 - ELBO -107.8176
Epoch 2100.0000 - ELBO -107.6981
Epoch 2200.0000 - ELBO -107.6104
Epoch 2300.0000 - ELBO -107.5100
Epoch 2400.0000 - ELBO -107.4396
Epoch 2500.0000 - ELBO -107.3680
Epoch 2600.0000 - ELBO -107.2886
Epoch 2700.0000 - ELBO -107.2256
Epoch 2800.0000 - ELBO -107.1179
Epoch 2900.0000 - ELBO -107.1377
Epoch 3000.0000 - ELBO -107.0345
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -275.2344
Epoch 200.0000 - ELBO -125.1395
Epoch 300.0000 - ELBO -115.6614
Epoch 400.0000 - ELBO -111.2074
Epoch 500.0000 - ELBO -108.9254
Epoch 600.0000 - ELBO -107.5037
Epoch 700.0000 - ELBO -106.5335
Epoch 800.0000 - ELBO -105.7633
Epoch 900.0000 - ELBO -105.1910
Epoch 1000.0000 - ELBO -104.7813
Epoch 1100.0000 - ELBO -104.3569
Epoch 1200.0000 - ELBO -104.0452
Epoch 1300.0000 - ELBO -103.7859
Epoch 1400.0000 - ELBO -103.5696
Epoch 1500.0000 - ELBO -103.3553
Epoch 1600.0000 - ELBO -103.2043
Epoch 1700.0000 - ELBO -103.0331
Epoch 1800.0000 - ELBO -102.9094
Epoch 1900.0000 - ELBO -102.8015
Epoch 2000.0000 - ELBO -102.7107
Epoch 2100.0000 - ELBO -102.5766
Epoch 2200.0000 - ELBO -102.5172
Epoch 2300.0000 - ELBO -102.4336
Epoch 2400.0000 - ELBO -102.3621
Epoch 2500.0000 - ELBO -102.2755
Epoch 2600.0000 - ELBO -102.2365
Epoch 2700.0000 - ELBO -102.1841
Epoch 2800.0000 - ELBO -102.1359
Epoch 2900.0000 - ELBO -102.0576
Epoch 3000.0000 - ELBO -102.0191
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -303.5396
Epoch 200.0000 - ELBO -123.9873
Epoch 300.0000 - ELBO -115.6767
Epoch 400.0000 - ELBO -111.2145
Epoch 500.0000 - ELBO -108.7231
Epoch 600.0000 - ELBO -107.1659
Epoch 700.0000 - ELBO -106.0488
Epoch 800.0000 - ELBO -105.2454
Epoch 900.0000 - ELBO -104.6280
Epoch 1000.0000 - ELBO -104.0933
Epoch 1100.0000 - ELBO -103.7089
Epoch 1200.0000 - ELBO -103.3438
Epoch 1300.0000 - ELBO -103.0467
Epoch 1400.0000 - ELBO -102.8149
Epoch 1500.0000 - ELBO -102.6139
Epoch 1600.0000 - ELBO -102.3565
Epoch 1700.0000 - ELBO -102.2402
Epoch 1800.0000 - ELBO -102.0926
Epoch 1900.0000 - ELBO -101.9357
Epoch 2000.0000 - ELBO -101.8322
Epoch 2100.0000 - ELBO -101.7148
Epoch 2200.0000 - ELBO -101.6222
Epoch 2300.0000 - ELBO -101.5340
Epoch 2400.0000 - ELBO -101.4660
Epoch 2500.0000 - ELBO -101.3863
Epoch 2600.0000 - ELBO -101.3111
Epoch 2700.0000 - ELBO -101.2431
Epoch 2800.0000 - ELBO -101.2008
Epoch 2900.0000 - ELBO -101.1445
Epoch 3000.0000 - ELBO -101.0864
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -406.0774
Epoch 200.0000 - ELBO -139.2126
Epoch 300.0000 - ELBO -127.9104
Epoch 400.0000 - ELBO -122.7359
Epoch 500.0000 - ELBO -119.9905
Epoch 600.0000 - ELBO -118.3110
Epoch 700.0000 - ELBO -117.2172
Epoch 800.0000 - ELBO -116.4371
Epoch 900.0000 - ELBO -115.7820
Epoch 1000.0000 - ELBO -115.3778
Epoch 1100.0000 - ELBO -114.9391
Epoch 1200.0000 - ELBO -114.6159
Epoch 1300.0000 - ELBO -114.3639
Epoch 1400.0000 - ELBO -114.1157
Epoch 1500.0000 - ELBO -113.9464
Epoch 1600.0000 - ELBO -113.7369
Epoch 1700.0000 - ELBO -113.5998
Epoch 1800.0000 - ELBO -113.4791
Epoch 1900.0000 - ELBO -113.3371
Epoch 2000.0000 - ELBO -113.2085
Epoch 2100.0000 - ELBO -113.1434
Epoch 2200.0000 - ELBO -113.0616
Epoch 2300.0000 - ELBO -112.9483
Epoch 2400.0000 - ELBO -112.8788
Epoch 2500.0000 - ELBO -112.8038
Epoch 2600.0000 - ELBO -112.7259
Epoch 2700.0000 - ELBO -112.6964
Epoch 2800.0000 - ELBO -112.6344
Epoch 2900.0000 - ELBO -112.5803
Epoch 3000.0000 - ELBO -112.5280
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -343.7636
Epoch 200.0000 - ELBO -125.8680
Epoch 300.0000 - ELBO -116.9015
Epoch 400.0000 - ELBO -111.6759
Epoch 500.0000 - ELBO -108.8351
Epoch 600.0000 - ELBO -106.9949
Epoch 700.0000 - ELBO -105.7271
Epoch 800.0000 - ELBO -104.8234
Epoch 900.0000 - ELBO -104.1298
Epoch 1000.0000 - ELBO -103.6253
Epoch 1100.0000 - ELBO -103.2205
Epoch 1200.0000 - ELBO -102.8314
Epoch 1300.0000 - ELBO -102.5836
Epoch 1400.0000 - ELBO -102.2584
Epoch 1500.0000 - ELBO -102.1055
Epoch 1600.0000 - ELBO -101.8792
Epoch 1700.0000 - ELBO -101.7129
Epoch 1800.0000 - ELBO -101.5687
Epoch 1900.0000 - ELBO -101.4442
Epoch 2000.0000 - ELBO -101.3168
Epoch 2100.0000 - ELBO -101.1938
Epoch 2200.0000 - ELBO -101.0997
Epoch 2300.0000 - ELBO -101.0220
Epoch 2400.0000 - ELBO -100.9347
Epoch 2500.0000 - ELBO -100.8613
Epoch 2600.0000 - ELBO -100.8009
Epoch 2700.0000 - ELBO -100.7154
Epoch 2800.0000 - ELBO -100.6583
Epoch 2900.0000 - ELBO -100.6032
Epoch 3000.0000 - ELBO -100.5645
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -357.1285
Epoch 200.0000 - ELBO -119.3045
Epoch 300.0000 - ELBO -109.5099
Epoch 400.0000 - ELBO -105.2618
Epoch 500.0000 - ELBO -102.9339
Epoch 600.0000 - ELBO -101.4639
Epoch 700.0000 - ELBO -100.4000
Epoch 800.0000 - ELBO -99.6132
Epoch 900.0000 - ELBO -98.9672
Epoch 1000.0000 - ELBO -98.4665
Epoch 1100.0000 - ELBO -98.0605
Epoch 1200.0000 - ELBO -97.7015
Epoch 1300.0000 - ELBO -97.4147
Epoch 1400.0000 - ELBO -97.1477
Epoch 1500.0000 - ELBO -96.9391
Epoch 1600.0000 - ELBO -96.7535
Epoch 1700.0000 - ELBO -96.5693
Epoch 1800.0000 - ELBO -96.4484
Epoch 1900.0000 - ELBO -96.3135
Epoch 2000.0000 - ELBO -96.1861
Epoch 2100.0000 - ELBO -96.0997
Epoch 2200.0000 - ELBO -96.0304
Epoch 2300.0000 - ELBO -95.9491
Epoch 2400.0000 - ELBO -95.8538
Epoch 2500.0000 - ELBO -95.7789
Epoch 2600.0000 - ELBO -95.7062
Epoch 2700.0000 - ELBO -95.6698
Epoch 2800.0000 - ELBO -95.6076
Epoch 2900.0000 - ELBO -95.5699
Epoch 3000.0000 - ELBO -95.5276
Epoch 3100.0000 - ELBO -95.4675
Epoch 3200

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -357.8299
Epoch 200.0000 - ELBO -125.4376
Epoch 300.0000 - ELBO -116.2650
Epoch 400.0000 - ELBO -111.3394
Epoch 500.0000 - ELBO -108.2646
Epoch 600.0000 - ELBO -106.1501
Epoch 700.0000 - ELBO -104.6912
Epoch 800.0000 - ELBO -103.6241
Epoch 900.0000 - ELBO -102.8131
Epoch 1000.0000 - ELBO -102.2333
Epoch 1100.0000 - ELBO -101.7234
Epoch 1200.0000 - ELBO -101.2998
Epoch 1300.0000 - ELBO -101.0089
Epoch 1400.0000 - ELBO -100.7544
Epoch 1500.0000 - ELBO -100.5057
Epoch 1600.0000 - ELBO -100.3242
Epoch 1700.0000 - ELBO -100.1428
Epoch 1800.0000 - ELBO -100.0167
Epoch 1900.0000 - ELBO -99.9083
Epoch 2000.0000 - ELBO -99.7739
Epoch 2100.0000 - ELBO -99.6409
Epoch 2200.0000 - ELBO -99.5741
Epoch 2300.0000 - ELBO -99.4867
Epoch 2400.0000 - ELBO -99.4099
Epoch 2500.0000 - ELBO -99.3294
Epoch 2600.0000 - ELBO -99.2798
Epoch 2700.0000 - ELBO -99.1979
Epoch 2800.0000 - ELBO -99.1210
Epoch 2900.0000 - ELBO -99.0825
Epoch 3000.0000 - ELBO -99.0409
Epoch 3100.0000 - ELBO -98.9635

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -276.4076
Epoch 200.0000 - ELBO -119.7841
Epoch 300.0000 - ELBO -112.0542
Epoch 400.0000 - ELBO -107.2758
Epoch 500.0000 - ELBO -104.3802
Epoch 600.0000 - ELBO -102.5248
Epoch 700.0000 - ELBO -101.3217
Epoch 800.0000 - ELBO -100.4418
Epoch 900.0000 - ELBO -99.8221
Epoch 1000.0000 - ELBO -99.2666
Epoch 1100.0000 - ELBO -98.8740
Epoch 1200.0000 - ELBO -98.5438
Epoch 1300.0000 - ELBO -98.2694
Epoch 1400.0000 - ELBO -98.0078
Epoch 1500.0000 - ELBO -97.8205
Epoch 1600.0000 - ELBO -97.6251
Epoch 1700.0000 - ELBO -97.4882
Epoch 1800.0000 - ELBO -97.3272
Epoch 1900.0000 - ELBO -97.2185
Epoch 2000.0000 - ELBO -97.1469
Epoch 2100.0000 - ELBO -97.0061
Epoch 2200.0000 - ELBO -96.9245
Epoch 2300.0000 - ELBO -96.8707
Epoch 2400.0000 - ELBO -96.7457
Epoch 2500.0000 - ELBO -96.7175
Epoch 2600.0000 - ELBO -96.6377
Epoch 2700.0000 - ELBO -96.6150
Epoch 2800.0000 - ELBO -96.5429
Epoch 2900.0000 - ELBO -96.4841
Epoch 3000.0000 - ELBO -96.4444
Epoch 3100.0000 - ELBO -96.3820
Epoch 320

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -290.7695
Epoch 200.0000 - ELBO -119.9870
Epoch 300.0000 - ELBO -110.4842
Epoch 400.0000 - ELBO -106.1395
Epoch 500.0000 - ELBO -103.8478
Epoch 600.0000 - ELBO -102.4111
Epoch 700.0000 - ELBO -101.4397
Epoch 800.0000 - ELBO -100.7067
Epoch 900.0000 - ELBO -100.1036
Epoch 1000.0000 - ELBO -99.7257
Epoch 1100.0000 - ELBO -99.2890
Epoch 1200.0000 - ELBO -99.0135
Epoch 1300.0000 - ELBO -98.7420
Epoch 1400.0000 - ELBO -98.5533
Epoch 1500.0000 - ELBO -98.3628
Epoch 1600.0000 - ELBO -98.1908
Epoch 1700.0000 - ELBO -98.0827
Epoch 1800.0000 - ELBO -97.9287
Epoch 1900.0000 - ELBO -97.8311
Epoch 2000.0000 - ELBO -97.7436
Epoch 2100.0000 - ELBO -97.6619
Epoch 2200.0000 - ELBO -97.5378
Epoch 2300.0000 - ELBO -97.4854
Epoch 2400.0000 - ELBO -97.4173
Epoch 2500.0000 - ELBO -97.3483
Epoch 2600.0000 - ELBO -97.3095
Epoch 2700.0000 - ELBO -97.2487
Epoch 2800.0000 - ELBO -97.2018
Epoch 2900.0000 - ELBO -97.1617
Epoch 3000.0000 - ELBO -97.1017
Epoch 3100.0000 - ELBO -97.0658
Epoch 32

In [12]:
# Persist the local-optimisation outputs into this experiment's directory.
# utils.save_params is used for every object, not only for model parameters.
results_to_save = (
  ("elbos", elbos),
  ("iwaes", iwaes),
  ("local_params", local_params),
)
for stem, value in results_to_save:
  utils.save_params(save_dir + "/" + stem + ".pkl", value)