<a href="https://colab.research.google.com/github/ATML-2022-Group6/inference-suboptimality/blob/main/run_local_opt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import sys

if "google.colab" in sys.modules:
  ! git clone https://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@github.com/ATML-2022-Group6/inference-suboptimality
  ! cp -r inference-suboptimality/* .
  ! tar -xvf datasets/mnist.pkl.tar.gz && mv mnist.pkl datasets/

fatal: destination path 'inference-suboptimality' already exists and is not an empty directory.
mnist.pkl


In [5]:
%load_ext autoreload
%autoreload 2

import pathlib

import utils
import vae
from datasets import get_batches, get_dataset
from local_opt import LocalHyperParams, local_opt
from utils import HyperParams

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Flip to True to run JAX on a Colab TPU runtime (has no effect outside Colab).
use_tpu = False
if use_tpu:
  if "google.colab" in sys.modules:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()

In [7]:
# --- Experiment settings (vary these across runs) ---
encoder_size = (200, 200)
decoder_size = (200, 200, 200, 200)
trained_has_flow = False  # picks the "ffg"/"flow" tag of the trained model's name
local_has_flow = False    # picks the "local-ffg"/"local-flow" suffix
kl_annealing = True
dataset_name = "mnist"    # one of: mnist, fashion, kmnist
# ---

def fmt_size(sizes):
  """Encode layer widths as the first digit of each width, e.g. (200, 200) -> "22".

  The encoding is lossy (200 and 256 both map to "2"). It is kept as-is
  because it is part of the experiment directory names.
  """
  first_digits = [str(width)[0] for width in sizes]
  return "".join(first_digits)

# Experiment name: it determines save_dir, so it must change whenever the
# settings above change, otherwise results are overwritten.
name = "_".join([
  dataset_name,
  "flow" if trained_has_flow else "ffg",
  "anneal" if kl_annealing else "regular",
  f"e{fmt_size(encoder_size)}d{fmt_size(decoder_size)}",
  "local-flow" if local_has_flow else "local-ffg",
])
print(name)

mnist_ffg_anneal_e22d2222_local-ffg


In [8]:
# Optionally store results on Google Drive (Colab only); otherwise use a local folder.
mount_google_drive = False

if not (mount_google_drive and "google.colab" in sys.modules):
  save_dir = "./experiments/" + name
else:
  from google.colab import drive
  drive.mount("/content/drive")
  save_dir = "/content/drive/My Drive/ATML/" + name

# Create the experiment folder (and parents) if it does not exist yet.
pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)

In [9]:
# NOTE(review): has_flow follows local_has_flow rather than trained_has_flow,
# presumably because it selects the local variational family; confirm in local_opt.
hps = HyperParams(
  encoder_hidden=encoder_size,
  decoder_hidden=decoder_size,
  has_flow=local_has_flow,
)
local_hps = LocalHyperParams()

for hp in (hps, local_hps):
  print(hp)

HyperParams(image_size=784, latent_size=50, encoder_hidden=(200, 200), decoder_hidden=(200, 200, 200, 200), has_flow=False, num_flows=2, flow_hidden_size=200)
LocalHyperParams(learning_rate=0.001, mc_samples=100, display_epoch=100, debug=True, iwae_samples=100, max_epochs=100000, patience=10, es_epsilon=0.05)


In [10]:
# Load the dataset selected by `dataset_name`. The variable is named `mnist`
# even when dataset_name selects "fashion" or "kmnist".
mnist = get_dataset(dataset_name)

In [11]:
# smaller_data=True: only locally optimise 1000 images, due to computational cost.
smaller_data = True
batch_size = 100
train_batches = get_batches(mnist["train_x"], batch_size, smaller_data)

In [12]:
# The trained model's directory is save_dir without its final "_..." component,
# i.e. without the "local-ffg"/"local-flow" suffix.
# Fail loudly instead of silently chopping a character when there is no "_"
# (str.rfind would return -1 and the slice would drop the last character).
if "_" not in save_dir:
  raise ValueError(f"Cannot derive trained model directory from save_dir={save_dir!r}")
trained_model_dir = save_dir.rsplit("_", 1)[0]
file_name = trained_model_dir + "/params.pkl"
if not pathlib.Path(file_name).is_file():
  raise FileNotFoundError(f"No trained parameters found at {file_name!r}; train the model first.")
trained_params = utils.load_params(file_name)

In [13]:
# Build the VAE from `hps`, then run local optimisation batch by batch against
# the trained weights. The three returned values are saved to disk at the end
# of the notebook.
model = vae.VAE(hps)
elbos, iwaes, local_params = local_opt(
  local_hps,
  model,
  train_batches,
  trained_params,
)

Optimising Local FFG ...


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -88.3701
Epoch 200.0000 - ELBO -87.8807
Epoch 300.0000 - ELBO -87.7324
Epoch 400.0000 - ELBO -87.6582
Epoch 500.0000 - ELBO -87.6177
Epoch 600.0000 - ELBO -87.5937
Epoch 700.0000 - ELBO -87.5785
Epoch 800.0000 - ELBO -87.5676
Epoch 900.0000 - ELBO -87.5592
Epoch 1000.0000 - ELBO -87.5563
Epoch 1100.0000 - ELBO -87.5549
Epoch 1200.0000 - ELBO -87.5512
Epoch 1300.0000 - ELBO -87.5470
Epoch 1400.0000 - ELBO -87.5445
Epoch 1500.0000 - ELBO -87.5452
Epoch 1600.0000 - ELBO -87.5447
Batch 1, ELBO -87.5200, IWAE -86.3114


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -87.3992
Epoch 200.0000 - ELBO -86.9077
Epoch 300.0000 - ELBO -86.7526
Epoch 400.0000 - ELBO -86.6788
Epoch 500.0000 - ELBO -86.6412
Epoch 600.0000 - ELBO -86.6144
Epoch 700.0000 - ELBO -86.6033
Epoch 800.0000 - ELBO -86.5961
Epoch 900.0000 - ELBO -86.5924
Epoch 1000.0000 - ELBO -86.5832
Epoch 1100.0000 - ELBO -86.5816
Epoch 1200.0000 - ELBO -86.5808
Epoch 1300.0000 - ELBO -86.5780
Epoch 1400.0000 - ELBO -86.5800
Epoch 1500.0000 - ELBO -86.5803
Epoch 1600.0000 - ELBO -86.5762
Batch 2, ELBO -86.5973, IWAE -85.5345


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -88.9981
Epoch 200.0000 - ELBO -88.5158
Epoch 300.0000 - ELBO -88.3641
Epoch 400.0000 - ELBO -88.2910
Epoch 500.0000 - ELBO -88.2541
Epoch 600.0000 - ELBO -88.2341
Epoch 700.0000 - ELBO -88.2191
Epoch 800.0000 - ELBO -88.2140
Epoch 900.0000 - ELBO -88.2083
Epoch 1000.0000 - ELBO -88.2010
Epoch 1100.0000 - ELBO -88.2016
Epoch 1200.0000 - ELBO -88.2035
Epoch 1300.0000 - ELBO -88.1995
Epoch 1400.0000 - ELBO -88.2020
Epoch 1500.0000 - ELBO -88.2007
Epoch 1600.0000 - ELBO -88.1986
Batch 3, ELBO -88.1886, IWAE -87.0449


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -85.4450
Epoch 200.0000 - ELBO -85.0472
Epoch 300.0000 - ELBO -84.9263
Epoch 400.0000 - ELBO -84.8688
Epoch 500.0000 - ELBO -84.8399
Epoch 600.0000 - ELBO -84.8267
Epoch 700.0000 - ELBO -84.8193
Epoch 800.0000 - ELBO -84.8093
Epoch 900.0000 - ELBO -84.8043
Epoch 1000.0000 - ELBO -84.8004
Epoch 1100.0000 - ELBO -84.8019
Epoch 1200.0000 - ELBO -84.7994
Epoch 1300.0000 - ELBO -84.8015
Epoch 1400.0000 - ELBO -84.7947
Epoch 1500.0000 - ELBO -84.7954
Epoch 1600.0000 - ELBO -84.7924
Epoch 1700.0000 - ELBO -84.7947
Epoch 1800.0000 - ELBO -84.7936
Batch 4, ELBO -84.7732, IWAE -83.7460


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -88.1241
Epoch 200.0000 - ELBO -87.6429
Epoch 300.0000 - ELBO -87.4997
Epoch 400.0000 - ELBO -87.4273
Epoch 500.0000 - ELBO -87.3921
Epoch 600.0000 - ELBO -87.3740
Epoch 700.0000 - ELBO -87.3607
Epoch 800.0000 - ELBO -87.3538
Epoch 900.0000 - ELBO -87.3462
Epoch 1000.0000 - ELBO -87.3443
Epoch 1100.0000 - ELBO -87.3440
Epoch 1200.0000 - ELBO -87.3427
Epoch 1300.0000 - ELBO -87.3438
Epoch 1400.0000 - ELBO -87.3447
Epoch 1500.0000 - ELBO -87.3398
Epoch 1600.0000 - ELBO -87.3418
Batch 5, ELBO -87.3141, IWAE -86.2557


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -89.9150
Epoch 200.0000 - ELBO -89.4271
Epoch 300.0000 - ELBO -89.2692
Epoch 400.0000 - ELBO -89.1929
Epoch 500.0000 - ELBO -89.1558
Epoch 600.0000 - ELBO -89.1333
Epoch 700.0000 - ELBO -89.1248
Epoch 800.0000 - ELBO -89.1133
Epoch 900.0000 - ELBO -89.1101
Epoch 1000.0000 - ELBO -89.1080
Epoch 1100.0000 - ELBO -89.1038
Epoch 1200.0000 - ELBO -89.1052
Epoch 1300.0000 - ELBO -89.1002
Epoch 1400.0000 - ELBO -89.1026
Epoch 1500.0000 - ELBO -89.0992
Epoch 1600.0000 - ELBO -89.1057
Batch 6, ELBO -89.0886, IWAE -87.9331


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -87.1138
Epoch 200.0000 - ELBO -86.6275
Epoch 300.0000 - ELBO -86.4747
Epoch 400.0000 - ELBO -86.3999
Epoch 500.0000 - ELBO -86.3636
Epoch 600.0000 - ELBO -86.3387
Epoch 700.0000 - ELBO -86.3217
Epoch 800.0000 - ELBO -86.3168
Epoch 900.0000 - ELBO -86.3125
Epoch 1000.0000 - ELBO -86.3094
Epoch 1100.0000 - ELBO -86.3015
Epoch 1200.0000 - ELBO -86.3005
Epoch 1300.0000 - ELBO -86.2965
Epoch 1400.0000 - ELBO -86.2930
Epoch 1500.0000 - ELBO -86.2943
Epoch 1600.0000 - ELBO -86.2949
Batch 7, ELBO -86.2803, IWAE -85.1698


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -92.0182
Epoch 200.0000 - ELBO -91.5048
Epoch 300.0000 - ELBO -91.3484
Epoch 400.0000 - ELBO -91.2736
Epoch 500.0000 - ELBO -91.2367
Epoch 600.0000 - ELBO -91.2125
Epoch 700.0000 - ELBO -91.2047
Epoch 800.0000 - ELBO -91.2037
Epoch 900.0000 - ELBO -91.1957
Epoch 1000.0000 - ELBO -91.1922
Epoch 1100.0000 - ELBO -91.1887
Epoch 1200.0000 - ELBO -91.1895
Epoch 1300.0000 - ELBO -91.1893
Epoch 1400.0000 - ELBO -91.1853
Epoch 1500.0000 - ELBO -91.1872
Epoch 1600.0000 - ELBO -91.1874
Batch 8, ELBO -91.1682, IWAE -90.0381


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -87.9636
Epoch 200.0000 - ELBO -87.4457
Epoch 300.0000 - ELBO -87.2904
Epoch 400.0000 - ELBO -87.2143
Epoch 500.0000 - ELBO -87.1762
Epoch 600.0000 - ELBO -87.1557
Epoch 700.0000 - ELBO -87.1481
Epoch 800.0000 - ELBO -87.1361
Epoch 900.0000 - ELBO -87.1294
Epoch 1000.0000 - ELBO -87.1264
Epoch 1100.0000 - ELBO -87.1255
Epoch 1200.0000 - ELBO -87.1254
Epoch 1300.0000 - ELBO -87.1245
Epoch 1400.0000 - ELBO -87.1186
Epoch 1500.0000 - ELBO -87.1232
Epoch 1600.0000 - ELBO -87.1234
Batch 9, ELBO -87.1308, IWAE -85.9948


  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch 100.0000 - ELBO -88.0875
Epoch 200.0000 - ELBO -87.6009
Epoch 300.0000 - ELBO -87.4628
Epoch 400.0000 - ELBO -87.4017
Epoch 500.0000 - ELBO -87.3683
Epoch 600.0000 - ELBO -87.3516
Epoch 700.0000 - ELBO -87.3431
Epoch 800.0000 - ELBO -87.3389
Epoch 900.0000 - ELBO -87.3342
Epoch 1000.0000 - ELBO -87.3328
Epoch 1100.0000 - ELBO -87.3318
Epoch 1200.0000 - ELBO -87.3291
Epoch 1300.0000 - ELBO -87.3272
Epoch 1400.0000 - ELBO -87.3281
Epoch 1500.0000 - ELBO -87.3283
Epoch 1600.0000 - ELBO -87.3259
Batch 10, ELBO -87.3103, IWAE -86.2141
Average ELBO -87.5371
Average IWAE -86.4243


In [15]:
# Persist the outputs of local_opt inside this experiment's folder.
for suffix, value in (
  ("/elbos.pkl", elbos),
  ("/iwaes.pkl", iwaes),
  ("/local_params.pkl", local_params),
):
  utils.save_params(save_dir + suffix, value)