<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 1137, done.[K
remote: Counting objects: 100% (1137/1137), done.[K
remote: Compressing objects: 100% (798/798), done.[K
remote: Total 1137 (delta 582), reused 839 (delta 325), pack-reused 0[K
Receiving objects: 100% (1137/1137), 326.79 MiB | 17.47 MiB/s, done.
Resolving deltas: 100% (582/582), done.
Checking out files: 100% (238/238), done.
mnist.pkl


In [2]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_batches, get_dataset
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [4]:
# Opt-in switch for a Colab TPU runtime; it has no effect outside Colab.
use_tpu = False

if "google.colab" in sys.modules and use_tpu:
  # This import also binds the top-level name `jax` used on the next line.
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [26]:
# -- Experiment configuration (vary across experiments) --
encoder_size = (200, 200)
decoder_size = (200, 200)
trained_has_flow = False
local_has_flow = False
kl_annealing = True
dataset_name = "kmnist"  # one of: mnist, fashion, kmnist
# ---


def fmt_size(sizes):
  """Abbreviate a sequence of layer widths to a compact tag.

  Keeps only the first character of each width, e.g. (200, 200) -> "22".
  NOTE(review): different widths that share a leading digit, e.g. (200, 250)
  and (200, 200), produce the same tag. The experiment names (and save
  directories) built from it can therefore collide.
  """
  return "".join(str(width)[0] for width in sizes)


## Name of this experiment (important to change for saving results).
## Its last "_"-separated part is the local-* tag; the prefix names the trained model.
name = "_".join([
  dataset_name,
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  f"e{fmt_size(encoder_size)}d{fmt_size(decoder_size)}",
  "local-flow" if local_has_flow else "local-ffg",
])
print(name)

kmnist_ffg_anneal_e22d22_local-ffg


In [27]:
# On Colab, results go to Google Drive so they survive the session;
# everywhere else they go under a local ./experiments directory.
mount_google_drive = True

if mount_google_drive and "google.colab" in sys.modules:
  from google.colab import drive
  drive.mount("/content/drive")
  results_root = "/content/drive/My Drive/ATML/"
else:
  results_root = "./experiments/"
save_dir = results_root + name

# Create the experiment directory (and any missing parents) if needed.
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [28]:
# Model hyperparameters. `has_flow` follows `local_has_flow`, the flag that also
# selects the local-ffg / local-flow tag in the experiment name, rather than
# `trained_has_flow`.
# NOTE(review): when trained_has_flow != local_has_flow, confirm that the loaded
# trained_params are compatible with a VAE built from these hps.
hps = HyperParams(
  has_flow=local_has_flow,
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
)
local_hps = LocalHyperParams()  # defaults; printed below for the record

for hyper_params in (hps, local_hps):
  print(hyper_params)

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [29]:
# Load the dataset selected by `dataset_name` (mnist, fashion or kmnist).
dataset = get_dataset(dataset_name)
# Backward-compatible alias: later cells index `mnist`, which may actually
# hold a different dataset when `dataset_name` is not "mnist".
mnist = dataset

In [30]:
# Local optimisation runs per batch and is expensive, so only a subset of the
# training set is used. With smaller_data=True the run below processes
# 10 batches of 100 examples (1000 images).
smaller_data = True  # presumably False batches the whole training set; confirm in get_batches
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [31]:
# remove "local-ffg" or "local-flow" suffix 
trained_model_dir = save_dir[:save_dir.rfind("_")]
file_name = trained_model_dir + "/params.pkl"
trained_params = utils.load_params(file_name)

In [32]:
# Build the VAE from `hps` and run local optimisation over every training batch,
# starting from the trained parameters loaded above. The results are unpacked
# into ELBOs, IWAE bounds and the locally optimised parameters (presumably one
# entry per batch; confirm against local_opt).
model = vae.VAE(hps)
local_opt_result = local_opt(local_hps, model, train_batches, trained_params)
elbos, iwaes, local_params = local_opt_result

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -146.8355
Epoch 200.0000 - ELBO -144.3323
Epoch 300.0000 - ELBO -143.2755
Epoch 400.0000 - ELBO -142.7046
Epoch 500.0000 - ELBO -142.3848
Epoch 600.0000 - ELBO -142.1963
Epoch 700.0000 - ELBO -142.0842
Epoch 800.0000 - ELBO -142.0111
Epoch 900.0000 - ELBO -141.9634
Epoch 1000.0000 - ELBO -141.9315
Epoch 1100.0000 - ELBO -141.9191
Epoch 1200.0000 - ELBO -141.9140
Epoch 1300.0000 - ELBO -141.8996
Epoch 1400.0000 - ELBO -141.8949
Epoch 1500.0000 - ELBO -141.8997
Epoch 1600.0000 - ELBO -141.8917
Epoch 1700.0000 - ELBO -141.8937
Epoch 1800.0000 - ELBO -141.8963
Epoch 1900.0000 - ELBO -141.8877
Epoch 2000.0000 - ELBO -141.8858
Batch 1, ELBO -141.9303, IWAE -137.7619


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -136.4505
Epoch 200.0000 - ELBO -134.3091
Epoch 300.0000 - ELBO -133.3971
Epoch 400.0000 - ELBO -132.9045
Epoch 500.0000 - ELBO -132.6226
Epoch 600.0000 - ELBO -132.4563
Epoch 700.0000 - ELBO -132.3377
Epoch 800.0000 - ELBO -132.2751
Epoch 900.0000 - ELBO -132.2278
Epoch 1000.0000 - ELBO -132.2120
Epoch 1100.0000 - ELBO -132.1910
Epoch 1200.0000 - ELBO -132.1781
Epoch 1300.0000 - ELBO -132.1674
Epoch 1400.0000 - ELBO -132.1693
Epoch 1500.0000 - ELBO -132.1644
Epoch 1600.0000 - ELBO -132.1576
Epoch 1700.0000 - ELBO -132.1646
Epoch 1800.0000 - ELBO -132.1572
Epoch 1900.0000 - ELBO -132.1620
Epoch 2000.0000 - ELBO -132.1588
Epoch 2100.0000 - ELBO -132.1593
Epoch 2200.0000 - ELBO -132.1592
Epoch 2300.0000 - ELBO -132.1572
Epoch 2400.0000 - ELBO -132.1588
Epoch 2500.0000 - ELBO -132.1629
Epoch 2600.0000 - ELBO -132.1571
Batch 2, ELBO -132.0840, IWAE -127.9872


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -146.3714
Epoch 200.0000 - ELBO -143.8591
Epoch 300.0000 - ELBO -142.7990
Epoch 400.0000 - ELBO -142.2336
Epoch 500.0000 - ELBO -141.8953
Epoch 600.0000 - ELBO -141.6941
Epoch 700.0000 - ELBO -141.5757
Epoch 800.0000 - ELBO -141.4982
Epoch 900.0000 - ELBO -141.4483
Epoch 1000.0000 - ELBO -141.4125
Epoch 1100.0000 - ELBO -141.3967
Epoch 1200.0000 - ELBO -141.3820
Epoch 1300.0000 - ELBO -141.3757
Epoch 1400.0000 - ELBO -141.3727
Epoch 1500.0000 - ELBO -141.3644
Epoch 1600.0000 - ELBO -141.3598
Epoch 1700.0000 - ELBO -141.3674
Epoch 1800.0000 - ELBO -141.3554
Epoch 1900.0000 - ELBO -141.3602
Epoch 2000.0000 - ELBO -141.3620
Epoch 2100.0000 - ELBO -141.3568
Epoch 2200.0000 - ELBO -141.3651
Epoch 2300.0000 - ELBO -141.3607
Epoch 2400.0000 - ELBO -141.3631
Epoch 2500.0000 - ELBO -141.3600
Epoch 2600.0000 - ELBO -141.3619
Batch 3, ELBO -141.2820, IWAE -137.0411


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -138.9006
Epoch 200.0000 - ELBO -136.6224
Epoch 300.0000 - ELBO -135.6417
Epoch 400.0000 - ELBO -135.1167
Epoch 500.0000 - ELBO -134.8121
Epoch 600.0000 - ELBO -134.6293
Epoch 700.0000 - ELBO -134.5224
Epoch 800.0000 - ELBO -134.4589
Epoch 900.0000 - ELBO -134.4192
Epoch 1000.0000 - ELBO -134.3860
Epoch 1100.0000 - ELBO -134.3678
Epoch 1200.0000 - ELBO -134.3599
Epoch 1300.0000 - ELBO -134.3486
Epoch 1400.0000 - ELBO -134.3496
Epoch 1500.0000 - ELBO -134.3447
Epoch 1600.0000 - ELBO -134.3409
Epoch 1700.0000 - ELBO -134.3398
Epoch 1800.0000 - ELBO -134.3474
Epoch 1900.0000 - ELBO -134.3408
Epoch 2000.0000 - ELBO -134.3375
Batch 4, ELBO -134.2534, IWAE -130.0074


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -149.9985
Epoch 200.0000 - ELBO -147.5358
Epoch 300.0000 - ELBO -146.4627
Epoch 400.0000 - ELBO -145.8839
Epoch 500.0000 - ELBO -145.5510
Epoch 600.0000 - ELBO -145.3464
Epoch 700.0000 - ELBO -145.2181
Epoch 800.0000 - ELBO -145.1443
Epoch 900.0000 - ELBO -145.0957
Epoch 1000.0000 - ELBO -145.0634
Epoch 1100.0000 - ELBO -145.0418
Epoch 1200.0000 - ELBO -145.0301
Epoch 1300.0000 - ELBO -145.0199
Epoch 1400.0000 - ELBO -145.0177
Epoch 1500.0000 - ELBO -145.0097
Epoch 1600.0000 - ELBO -145.0137
Epoch 1700.0000 - ELBO -145.0132
Epoch 1800.0000 - ELBO -145.0165
Epoch 1900.0000 - ELBO -145.0147
Epoch 2000.0000 - ELBO -145.0109
Epoch 2100.0000 - ELBO -145.0054
Epoch 2200.0000 - ELBO -145.0060
Epoch 2300.0000 - ELBO -145.0071
Epoch 2400.0000 - ELBO -145.0049
Epoch 2500.0000 - ELBO -145.0040
Batch 5, ELBO -144.9953, IWAE -140.7339


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -146.7659
Epoch 200.0000 - ELBO -144.3494
Epoch 300.0000 - ELBO -143.3312
Epoch 400.0000 - ELBO -142.7818
Epoch 500.0000 - ELBO -142.4783
Epoch 600.0000 - ELBO -142.2908
Epoch 700.0000 - ELBO -142.1712
Epoch 800.0000 - ELBO -142.1071
Epoch 900.0000 - ELBO -142.0645
Epoch 1000.0000 - ELBO -142.0376
Epoch 1100.0000 - ELBO -142.0175
Epoch 1200.0000 - ELBO -142.0099
Epoch 1300.0000 - ELBO -142.0009
Epoch 1400.0000 - ELBO -141.9906
Epoch 1500.0000 - ELBO -141.9950
Epoch 1600.0000 - ELBO -141.9922
Epoch 1700.0000 - ELBO -141.9904
Epoch 1800.0000 - ELBO -141.9965
Epoch 1900.0000 - ELBO -142.0005
Epoch 2000.0000 - ELBO -141.9911
Batch 6, ELBO -141.9807, IWAE -137.7146


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -133.5528
Epoch 200.0000 - ELBO -131.3560
Epoch 300.0000 - ELBO -130.4603
Epoch 400.0000 - ELBO -129.9794
Epoch 500.0000 - ELBO -129.7191
Epoch 600.0000 - ELBO -129.5611
Epoch 700.0000 - ELBO -129.4694
Epoch 800.0000 - ELBO -129.4162
Epoch 900.0000 - ELBO -129.3865
Epoch 1000.0000 - ELBO -129.3550
Epoch 1100.0000 - ELBO -129.3421
Epoch 1200.0000 - ELBO -129.3396
Epoch 1300.0000 - ELBO -129.3308
Epoch 1400.0000 - ELBO -129.3268
Epoch 1500.0000 - ELBO -129.3226
Epoch 1600.0000 - ELBO -129.3261
Epoch 1700.0000 - ELBO -129.3205
Epoch 1800.0000 - ELBO -129.3205
Epoch 1900.0000 - ELBO -129.3226
Epoch 2000.0000 - ELBO -129.3185
Batch 7, ELBO -129.2908, IWAE -125.1722


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -144.3593
Epoch 200.0000 - ELBO -141.9700
Epoch 300.0000 - ELBO -140.9642
Epoch 400.0000 - ELBO -140.4292
Epoch 500.0000 - ELBO -140.1117
Epoch 600.0000 - ELBO -139.9219
Epoch 700.0000 - ELBO -139.8054
Epoch 800.0000 - ELBO -139.7282
Epoch 900.0000 - ELBO -139.6827
Epoch 1000.0000 - ELBO -139.6551
Epoch 1100.0000 - ELBO -139.6334
Epoch 1200.0000 - ELBO -139.6210
Epoch 1300.0000 - ELBO -139.6150
Epoch 1400.0000 - ELBO -139.6064
Epoch 1500.0000 - ELBO -139.6022
Epoch 1600.0000 - ELBO -139.5975
Epoch 1700.0000 - ELBO -139.5978
Epoch 1800.0000 - ELBO -139.6015
Epoch 1900.0000 - ELBO -139.5950
Epoch 2000.0000 - ELBO -139.5931
Epoch 2100.0000 - ELBO -139.6012
Epoch 2200.0000 - ELBO -139.5992
Epoch 2300.0000 - ELBO -139.5955
Epoch 2400.0000 - ELBO -139.5976
Epoch 2500.0000 - ELBO -139.5980
Batch 8, ELBO -139.5885, IWAE -135.1989


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -140.5182
Epoch 200.0000 - ELBO -138.2551
Epoch 300.0000 - ELBO -137.2884
Epoch 400.0000 - ELBO -136.7689
Epoch 500.0000 - ELBO -136.4700
Epoch 600.0000 - ELBO -136.2997
Epoch 700.0000 - ELBO -136.1833
Epoch 800.0000 - ELBO -136.1117
Epoch 900.0000 - ELBO -136.0741
Epoch 1000.0000 - ELBO -136.0470
Epoch 1100.0000 - ELBO -136.0387
Epoch 1200.0000 - ELBO -136.0282
Epoch 1300.0000 - ELBO -136.0213
Epoch 1400.0000 - ELBO -136.0148
Epoch 1500.0000 - ELBO -136.0160
Epoch 1600.0000 - ELBO -136.0117
Epoch 1700.0000 - ELBO -136.0097
Epoch 1800.0000 - ELBO -136.0148
Epoch 1900.0000 - ELBO -136.0104
Epoch 2000.0000 - ELBO -136.0055
Batch 9, ELBO -136.0036, IWAE -131.6384


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -142.7559
Epoch 200.0000 - ELBO -140.3525
Epoch 300.0000 - ELBO -139.3367
Epoch 400.0000 - ELBO -138.7830
Epoch 500.0000 - ELBO -138.4727
Epoch 600.0000 - ELBO -138.2755
Epoch 700.0000 - ELBO -138.1550
Epoch 800.0000 - ELBO -138.0901
Epoch 900.0000 - ELBO -138.0359
Epoch 1000.0000 - ELBO -138.0034
Epoch 1100.0000 - ELBO -137.9944
Epoch 1200.0000 - ELBO -137.9780
Epoch 1300.0000 - ELBO -137.9705
Epoch 1400.0000 - ELBO -137.9667
Epoch 1500.0000 - ELBO -137.9708
Epoch 1600.0000 - ELBO -137.9607
Epoch 1700.0000 - ELBO -137.9589
Epoch 1800.0000 - ELBO -137.9616
Epoch 1900.0000 - ELBO -137.9660
Epoch 2000.0000 - ELBO -137.9530
Epoch 2100.0000 - ELBO -137.9610
Epoch 2200.0000 - ELBO -137.9549
Batch 10, ELBO -137.9586, IWAE -133.7064
Average ELBO -137.9367
Average IWAE -133.6962


In [33]:
# Save the local-optimisation results in this experiment's directory.
for file_suffix, result in (
    ("/elbos.pkl", elbos),
    ("/iwaes.pkl", iwaes),
    ("/local_params.pkl", local_params),
):
  utils.save_params(save_dir + file_suffix, result)