<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

Cloning into 'inference-suboptimality'...
remote: Enumerating objects: 1167, done.[K
remote: Counting objects: 100% (1167/1167), done.[K
remote: Compressing objects: 100% (821/821), done.[K
remote: Total 1167 (delta 601), reused 857 (delta 332), pack-reused 0[K
Receiving objects: 100% (1167/1167), 327.89 MiB | 35.44 MiB/s, done.
Resolving deltas: 100% (601/601), done.
Checking out files: 100% (250/250), done.
mnist.pkl


In [1]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_batches, get_dataset
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

In [4]:
# Optionally switch JAX to a Colab TPU runtime; outside Colab (or when
# `use_tpu` is False) no TPU setup is performed.
use_tpu = True
if "google.colab" in sys.modules and use_tpu:
  # Imported lazily: presumably only usable inside a Colab TPU runtime.
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

In [5]:
# -- Vary across experiment
encoder_size = (200, 200)
decoder_size = (200, 200)
trained_has_flow = False
local_has_flow = True
kl_annealing = True
dataset_name = "kmnist" # mnist, fashion, kmnist
# ---

def fmt_size(sizes):
  """Return a compact tag built from the first digit of each layer width.

  Example: (200, 200) -> "22".

  NOTE(review): this is lossy — e.g. (200, 200) and (250, 250) both map to
  "22". It is kept unchanged because the trained-model directory loaded later
  in this notebook is derived from `name`; changing the format would break
  existing experiment paths.
  """
  return "".join(map(lambda width: str(width)[0], sizes))

## Name of this experiment (important to change for saving results)
name = "_".join([
  dataset_name,
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  f"e{fmt_size(encoder_size)}d{fmt_size(decoder_size)}",
  "local-flow" if local_has_flow else "local-ffg",
])
print(name)

kmnist_ffg_anneal_e22d22_local-flow


In [6]:
mount_google_drive = True

# Results go to Google Drive when running on Colab (so they outlive the
# session); otherwise they go under a local ./experiments directory.
if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = "./experiments/" + name
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = "/content/drive/My Drive/ATML/" + name

# Create the directory and any missing parents; no-op if it already exists.
pathlib.Path(save_dir).mkdir(exist_ok=True, parents=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
# Model hyperparameters. Note that `has_flow` comes from `local_has_flow`, not
# `trained_has_flow` — presumably it selects the variational family that is
# locally optimised (TODO confirm against local_opt / vae.VAE).
hps = HyperParams(
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
  has_flow=local_has_flow,
)
# Local-optimisation settings, all at their defaults.
local_hps = LocalHyperParams()

print(hps, local_hps, sep="\n")

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200), has_flow=True, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [8]:
# Load the dataset selected by `dataset_name` (one of mnist / fashion / kmnist).
# NOTE: despite its name, `mnist` holds whichever dataset was selected; the name
# is kept because a later cell reads it as `mnist["train_x"]`.
mnist = get_dataset(dataset_name)

In [9]:
# Local optimisation is expensive, so only a subset of the training data is
# used (smaller_data=True; the original note says 1000 examples).
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [10]:
# This experiment's directory is the trained model's directory name plus a
# trailing "_local-ffg" / "_local-flow" component; cut that component off to
# find where the trained parameters were saved.
trained_model_dir = save_dir[: save_dir.rfind("_")]
file_name = f"{trained_model_dir}/params.pkl"
trained_params = utils.load_params(file_name)

In [11]:
# Build the VAE from `hps` and run local optimisation over `train_batches`,
# starting from `trained_params` (loaded from the trained model's params.pkl).
# How local_opt uses the trained parameters (e.g. which parts are frozen) is not
# visible here — TODO confirm in local_opt.
# Long-running: each run is capped by local_hps.max_epochs with early stopping
# (patience / es_epsilon); the log shows one progress bar per batch.
# NOTE(review): no random seed is set in this notebook; results are only
# reproducible if local_opt seeds its own RNG.
# Returns the per-run ELBOs, IWAE estimates, and the optimised local parameters.
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(local_hps, model, train_batches, trained_params)

Optimising Local Flow ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -236.1785
Epoch 200.0000 - ELBO -150.5003
Epoch 300.0000 - ELBO -146.9034
Epoch 400.0000 - ELBO -145.4398
Epoch 500.0000 - ELBO -144.5657
Epoch 600.0000 - ELBO -144.1561
Epoch 700.0000 - ELBO -143.5932
Epoch 800.0000 - ELBO -143.3139
Epoch 900.0000 - ELBO -143.1120
Epoch 1000.0000 - ELBO -142.9018
Epoch 1100.0000 - ELBO -142.8071
Epoch 1200.0000 - ELBO -142.6584
Epoch 1300.0000 - ELBO -142.5801
Epoch 1400.0000 - ELBO -142.4940
Epoch 1500.0000 - ELBO -142.4163
Epoch 1600.0000 - ELBO -142.3246
Epoch 1700.0000 - ELBO -142.3454
Epoch 1800.0000 - ELBO -142.2675
Epoch 1900.0000 - ELBO -142.2616
Epoch 2000.0000 - ELBO -142.1731
Epoch 2100.0000 - ELBO -142.1476
Epoch 2200.0000 - ELBO -142.1855
Epoch 2300.0000 - ELBO -142.1171
Epoch 2400.0000 - ELBO -142.1210
Epoch 2500.0000 - ELBO -142.0634
Epoch 2600.0000 - ELBO -142.0851
Epoch 2700.0000 - ELBO -142.0215
Epoch 2800.0000 - ELBO -142.0039
Epoch 2900.0000 - ELBO -142.0060
Epoch 3000.0000 - ELBO -141.9642
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -219.7541
Epoch 200.0000 - ELBO -140.0282
Epoch 300.0000 - ELBO -136.7682
Epoch 400.0000 - ELBO -135.5122
Epoch 500.0000 - ELBO -134.7336
Epoch 600.0000 - ELBO -134.1944
Epoch 700.0000 - ELBO -133.8031
Epoch 800.0000 - ELBO -133.5839
Epoch 900.0000 - ELBO -133.3273
Epoch 1000.0000 - ELBO -133.1412
Epoch 1100.0000 - ELBO -133.0656
Epoch 1200.0000 - ELBO -132.8441
Epoch 1300.0000 - ELBO -132.8281
Epoch 1400.0000 - ELBO -132.7344
Epoch 1500.0000 - ELBO -132.6644
Epoch 1600.0000 - ELBO -132.6012
Epoch 1700.0000 - ELBO -132.5767
Epoch 1800.0000 - ELBO -132.5287
Epoch 1900.0000 - ELBO -132.4938
Epoch 2000.0000 - ELBO -132.4691
Epoch 2100.0000 - ELBO -132.4155
Epoch 2200.0000 - ELBO -132.4003
Epoch 2300.0000 - ELBO -132.3851
Epoch 2400.0000 - ELBO -132.3402
Epoch 2500.0000 - ELBO -132.3301
Epoch 2600.0000 - ELBO -132.3011
Epoch 2700.0000 - ELBO -132.2757
Epoch 2800.0000 - ELBO -132.2710
Epoch 2900.0000 - ELBO -132.2514
Epoch 3000.0000 - ELBO -132.2251
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -238.4444
Epoch 200.0000 - ELBO -150.6028
Epoch 300.0000 - ELBO -146.7102
Epoch 400.0000 - ELBO -145.1680
Epoch 500.0000 - ELBO -144.2468
Epoch 600.0000 - ELBO -143.6125
Epoch 700.0000 - ELBO -143.2150
Epoch 800.0000 - ELBO -142.8904
Epoch 900.0000 - ELBO -142.6612
Epoch 1000.0000 - ELBO -142.4993
Epoch 1100.0000 - ELBO -142.3025
Epoch 1200.0000 - ELBO -142.1601
Epoch 1300.0000 - ELBO -142.1048
Epoch 1400.0000 - ELBO -141.9741
Epoch 1500.0000 - ELBO -141.9045
Epoch 1600.0000 - ELBO -141.9335
Epoch 1700.0000 - ELBO -141.7932
Epoch 1800.0000 - ELBO -141.7509
Epoch 1900.0000 - ELBO -141.7305
Epoch 2000.0000 - ELBO -141.7340
Epoch 2100.0000 - ELBO -141.6497
Epoch 2200.0000 - ELBO -141.6184
Epoch 2300.0000 - ELBO -141.6185
Epoch 2400.0000 - ELBO -141.5904
Epoch 2500.0000 - ELBO -141.5475
Epoch 2600.0000 - ELBO -141.5305
Epoch 2700.0000 - ELBO -141.5410
Epoch 2800.0000 - ELBO -141.4787
Epoch 2900.0000 - ELBO -141.4711
Epoch 3000.0000 - ELBO -141.4717
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -224.3361
Epoch 200.0000 - ELBO -142.6082
Epoch 300.0000 - ELBO -139.0896
Epoch 400.0000 - ELBO -137.7831
Epoch 500.0000 - ELBO -136.9319
Epoch 600.0000 - ELBO -136.3979
Epoch 700.0000 - ELBO -135.9221
Epoch 800.0000 - ELBO -135.6846
Epoch 900.0000 - ELBO -135.4687
Epoch 1000.0000 - ELBO -135.2766
Epoch 1100.0000 - ELBO -135.1672
Epoch 1200.0000 - ELBO -135.0682
Epoch 1300.0000 - ELBO -134.9408
Epoch 1400.0000 - ELBO -134.8849
Epoch 1500.0000 - ELBO -134.7863
Epoch 1600.0000 - ELBO -134.7543
Epoch 1700.0000 - ELBO -134.6954
Epoch 1800.0000 - ELBO -134.6551
Epoch 1900.0000 - ELBO -134.6385
Epoch 2000.0000 - ELBO -134.5631
Epoch 2100.0000 - ELBO -134.5996
Epoch 2200.0000 - ELBO -134.5263
Epoch 2300.0000 - ELBO -134.4925
Epoch 2400.0000 - ELBO -134.4872
Epoch 2500.0000 - ELBO -134.4508
Epoch 2600.0000 - ELBO -134.4397
Epoch 2700.0000 - ELBO -134.4146
Epoch 2800.0000 - ELBO -134.4042
Epoch 2900.0000 - ELBO -134.3776
Epoch 3000.0000 - ELBO -134.3713
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -240.9662
Epoch 200.0000 - ELBO -153.7899
Epoch 300.0000 - ELBO -150.0539
Epoch 400.0000 - ELBO -148.6395
Epoch 500.0000 - ELBO -147.7385
Epoch 600.0000 - ELBO -147.3735
Epoch 700.0000 - ELBO -146.6984
Epoch 800.0000 - ELBO -146.4234
Epoch 900.0000 - ELBO -146.1868
Epoch 1000.0000 - ELBO -146.0429
Epoch 1100.0000 - ELBO -145.8987
Epoch 1200.0000 - ELBO -145.8071
Epoch 1300.0000 - ELBO -145.6953
Epoch 1400.0000 - ELBO -145.5891
Epoch 1500.0000 - ELBO -145.5894
Epoch 1600.0000 - ELBO -145.4717
Epoch 1700.0000 - ELBO -145.4675
Epoch 1800.0000 - ELBO -145.4116
Epoch 1900.0000 - ELBO -145.3745
Epoch 2000.0000 - ELBO -145.3259
Epoch 2100.0000 - ELBO -145.2951
Epoch 2200.0000 - ELBO -145.2816
Epoch 2300.0000 - ELBO -145.2450
Epoch 2400.0000 - ELBO -145.2113
Epoch 2500.0000 - ELBO -145.2081
Epoch 2600.0000 - ELBO -145.1492
Epoch 2700.0000 - ELBO -145.1660
Epoch 2800.0000 - ELBO -145.1602
Epoch 2900.0000 - ELBO -145.1203
Epoch 3000.0000 - ELBO -145.1008
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -238.8889
Epoch 200.0000 - ELBO -150.9845
Epoch 300.0000 - ELBO -147.1973
Epoch 400.0000 - ELBO -145.6719
Epoch 500.0000 - ELBO -144.8357
Epoch 600.0000 - ELBO -144.5201
Epoch 700.0000 - ELBO -143.7124
Epoch 800.0000 - ELBO -143.4760
Epoch 900.0000 - ELBO -143.2645
Epoch 1000.0000 - ELBO -143.0629
Epoch 1100.0000 - ELBO -142.9487
Epoch 1200.0000 - ELBO -142.8238
Epoch 1300.0000 - ELBO -142.6614
Epoch 1400.0000 - ELBO -142.6266
Epoch 1500.0000 - ELBO -142.5057
Epoch 1600.0000 - ELBO -142.5117
Epoch 1700.0000 - ELBO -142.4228
Epoch 1800.0000 - ELBO -142.3895
Epoch 1900.0000 - ELBO -142.3778
Epoch 2000.0000 - ELBO -142.3175
Epoch 2100.0000 - ELBO -142.2367
Epoch 2200.0000 - ELBO -142.2322
Epoch 2300.0000 - ELBO -142.2434
Epoch 2400.0000 - ELBO -142.2066
Epoch 2500.0000 - ELBO -142.1659
Epoch 2600.0000 - ELBO -142.1154
Epoch 2700.0000 - ELBO -142.1368
Epoch 2800.0000 - ELBO -142.1204
Epoch 2900.0000 - ELBO -142.0979
Epoch 3000.0000 - ELBO -142.0578
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -217.0124
Epoch 200.0000 - ELBO -137.1897
Epoch 300.0000 - ELBO -133.9845
Epoch 400.0000 - ELBO -132.6491
Epoch 500.0000 - ELBO -131.9294
Epoch 600.0000 - ELBO -131.2629
Epoch 700.0000 - ELBO -130.9224
Epoch 800.0000 - ELBO -130.6534
Epoch 900.0000 - ELBO -130.5381
Epoch 1000.0000 - ELBO -130.2360
Epoch 1100.0000 - ELBO -130.1417
Epoch 1200.0000 - ELBO -130.0502
Epoch 1300.0000 - ELBO -129.9686
Epoch 1400.0000 - ELBO -129.8705
Epoch 1500.0000 - ELBO -129.8076
Epoch 1600.0000 - ELBO -129.7625
Epoch 1700.0000 - ELBO -129.7107
Epoch 1800.0000 - ELBO -129.6551
Epoch 1900.0000 - ELBO -129.6434
Epoch 2000.0000 - ELBO -129.5899
Epoch 2100.0000 - ELBO -129.5961
Epoch 2200.0000 - ELBO -129.5541
Epoch 2300.0000 - ELBO -129.5240
Epoch 2400.0000 - ELBO -129.5134
Epoch 2500.0000 - ELBO -129.4380
Epoch 2600.0000 - ELBO -129.4543
Epoch 2700.0000 - ELBO -129.4372
Epoch 2800.0000 - ELBO -129.4442
Epoch 2900.0000 - ELBO -129.3898
Epoch 3000.0000 - ELBO -129.3944
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -235.2356
Epoch 200.0000 - ELBO -148.3804
Epoch 300.0000 - ELBO -144.6977
Epoch 400.0000 - ELBO -143.1732
Epoch 500.0000 - ELBO -142.5482
Epoch 600.0000 - ELBO -141.7343
Epoch 700.0000 - ELBO -141.3712
Epoch 800.0000 - ELBO -141.1022
Epoch 900.0000 - ELBO -140.8761
Epoch 1000.0000 - ELBO -140.6527
Epoch 1100.0000 - ELBO -140.4661
Epoch 1200.0000 - ELBO -140.3959
Epoch 1300.0000 - ELBO -140.3014
Epoch 1400.0000 - ELBO -140.2090
Epoch 1500.0000 - ELBO -140.1670
Epoch 1600.0000 - ELBO -140.0690
Epoch 1700.0000 - ELBO -140.0430
Epoch 1800.0000 - ELBO -139.9888
Epoch 1900.0000 - ELBO -139.9718
Epoch 2000.0000 - ELBO -139.8952
Epoch 2100.0000 - ELBO -139.9121
Epoch 2200.0000 - ELBO -139.8582
Epoch 2300.0000 - ELBO -139.8394
Epoch 2400.0000 - ELBO -139.8043
Epoch 2500.0000 - ELBO -139.7952
Epoch 2600.0000 - ELBO -139.7339
Epoch 2700.0000 - ELBO -139.8082
Epoch 2800.0000 - ELBO -139.7194
Epoch 2900.0000 - ELBO -139.6968
Epoch 3000.0000 - ELBO -139.6988
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -227.2186
Epoch 200.0000 - ELBO -144.2200
Epoch 300.0000 - ELBO -140.8138
Epoch 400.0000 - ELBO -139.4886
Epoch 500.0000 - ELBO -138.6947
Epoch 600.0000 - ELBO -138.1130
Epoch 700.0000 - ELBO -137.7375
Epoch 800.0000 - ELBO -137.4173
Epoch 900.0000 - ELBO -137.2534
Epoch 1000.0000 - ELBO -136.9859
Epoch 1100.0000 - ELBO -136.8892
Epoch 1200.0000 - ELBO -136.8147
Epoch 1300.0000 - ELBO -136.6778
Epoch 1400.0000 - ELBO -136.6260
Epoch 1500.0000 - ELBO -136.6142
Epoch 1600.0000 - ELBO -136.4657
Epoch 1700.0000 - ELBO -136.4626
Epoch 1800.0000 - ELBO -136.3951
Epoch 1900.0000 - ELBO -136.3905
Epoch 2000.0000 - ELBO -136.3090
Epoch 2100.0000 - ELBO -136.2874
Epoch 2200.0000 - ELBO -136.3376
Epoch 2300.0000 - ELBO -136.2508
Epoch 2400.0000 - ELBO -136.2112
Epoch 2500.0000 - ELBO -136.1920
Epoch 2600.0000 - ELBO -136.2050
Epoch 2700.0000 - ELBO -136.1663
Epoch 2800.0000 - ELBO -136.1521
Epoch 2900.0000 - ELBO -136.1691
Epoch 3000.0000 - ELBO -136.1008
Epoch 3100.0000 - E

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -232.3789
Epoch 200.0000 - ELBO -146.4945
Epoch 300.0000 - ELBO -143.0434
Epoch 400.0000 - ELBO -141.3934
Epoch 500.0000 - ELBO -140.7554
Epoch 600.0000 - ELBO -140.0810
Epoch 700.0000 - ELBO -139.6636
Epoch 800.0000 - ELBO -139.3981
Epoch 900.0000 - ELBO -139.1978
Epoch 1000.0000 - ELBO -138.9705
Epoch 1100.0000 - ELBO -138.8649
Epoch 1200.0000 - ELBO -138.6908
Epoch 1300.0000 - ELBO -138.8091
Epoch 1400.0000 - ELBO -138.4989
Epoch 1500.0000 - ELBO -138.4930
Epoch 1600.0000 - ELBO -138.4123
Epoch 1700.0000 - ELBO -138.4240
Epoch 1800.0000 - ELBO -138.3565
Epoch 1900.0000 - ELBO -138.3228
Epoch 2000.0000 - ELBO -138.2829
Epoch 2100.0000 - ELBO -138.2705
Epoch 2200.0000 - ELBO -138.2055
Epoch 2300.0000 - ELBO -138.1991
Epoch 2400.0000 - ELBO -138.2215
Epoch 2500.0000 - ELBO -138.1363
Epoch 2600.0000 - ELBO -138.1224
Epoch 2700.0000 - ELBO -138.1268
Epoch 2800.0000 - ELBO -138.0737
Epoch 2900.0000 - ELBO -138.0731
Epoch 3000.0000 - ELBO -138.0592
Epoch 3100.0000 - E

In [12]:
# Save the local-optimisation outputs next to this experiment's other files,
# one pickle per result (same order: elbos, iwaes, local_params).
results = {"elbos": elbos, "iwaes": iwaes, "local_params": local_params}
for stem, value in results.items():
  utils.save_params(f"{save_dir}/{stem}.pkl", value)