In [None]:
# !pip install git+https://github.com/alec-tschantz/pybrid.git

Collecting git+https://github.com/alec-tschantz/pybrid.git
  Cloning https://github.com/alec-tschantz/pybrid.git to /tmp/pip-req-build-z8cfr7ld
  Running command git clone -q https://github.com/alec-tschantz/pybrid.git /tmp/pip-req-build-z8cfr7ld
Building wheels for collected packages: pybrid
  Building wheel for pybrid (setup.py) ... done
  Created wheel for pybrid: filename=pybrid-0.0.1-cp37-none-any.whl size=8649 sha256=870accab5eb5649f71159a3495226ae2dfb04e694732ae5a228cf50cf04a81ab
  Stored in directory: /tmp/pip-ephem-wheel-cache-ecl6esgw/wheels/7f/d2/e5/568382df15abbc70ecfd60b1864b418d0f1e39769d670f503b
Successfully built pybrid


In [None]:
import logging

import torch

from pybrid import utils
from pybrid import datasets
from pybrid import optim
from pybrid.models.hybrid import HybridModel

In [None]:
def main(cfg, log_every=100):
    """Train a pybrid ``HybridModel`` on MNIST and evaluate it periodically.

    Args:
        cfg: Nested experiment configuration (a plain dict with ``exp``,
            ``data``, ``infer``, ``model`` and ``optim`` sections). It is passed
            through ``utils.setup_experiment``; the returned object is then read
            with attribute access (``cfg.optim.lr`` etc.).
        log_every: Log the running-mean training losses every ``log_every``
            batches. Must be a positive integer. Defaults to 100, the value that
            used to be hard-coded.

    Side effects:
        Calls ``datasets.download_mnist``, writes an image of ``model.backward``
        outputs to ``cfg.exp.img_dir + "/{epoch}.png"`` after every test pass,
        and rewrites ``cfg.exp.log_dir + "/metrics.json"`` at the end of every
        epoch.
    """
    cfg = utils.setup_experiment(cfg)
    train_loader, test_loader = _load_mnist(cfg)
    model, optimizer = _build_model(cfg)

    # The loop runs with autograd disabled. No .backward() is called here, so
    # presumably the model/optimizer compute their own updates (confirm in pybrid).
    with torch.no_grad():
        # Only "pc_acc" is filled in below. "hybrid_acc" and "amort_acc" are kept
        # so the key layout of metrics.json stays the same.
        metrics = {"hybrid_acc": [], "pc_acc": [], "amort_acc": []}
        for epoch in range(1, cfg.exp.num_epochs + 1):
            # Epochs run 1..num_epochs inclusive, so the total is num_epochs
            # (not num_epochs + 1).
            logging.info(f"epoch {epoch}/{cfg.exp.num_epochs}")
            _train_epoch(model, optimizer, train_loader, cfg, epoch, log_every)

            if epoch % cfg.exp.test_every == 0:
                pc_acc = _test_epoch(model, test_loader, cfg, epoch)
                metrics["pc_acc"].append(pc_acc)
                _plot_generated(model, test_loader, cfg, epoch)

            if cfg.optim.normalize_weights:
                model.normalize_weights()

            utils.save_json(metrics, cfg.exp.log_dir + "/metrics.json")


def _load_mnist(cfg):
    """Download MNIST if needed and return ``(train_loader, test_loader)``."""
    datasets.download_mnist()
    train_dataset = datasets.MNIST(
        train=True,
        scale=cfg.data.label_scale,
        size=cfg.data.train_size,
        normalize=cfg.data.normalize,
    )
    test_dataset = datasets.MNIST(
        train=False,
        scale=cfg.data.label_scale,
        size=cfg.data.test_size,
        normalize=cfg.data.normalize,
    )
    train_loader = datasets.get_dataloader(train_dataset, cfg.optim.batch_size)
    test_loader = datasets.get_dataloader(test_dataset, cfg.optim.batch_size)
    msg = f"loaded MNIST ({len(train_loader)} train batches {len(test_loader)} test batches)"
    logging.info(msg)
    return train_loader, test_loader


def _build_model(cfg):
    """Construct the ``HybridModel`` and its optimizer from ``cfg``; return ``(model, optimizer)``."""
    model = HybridModel(
        nodes=cfg.model.nodes,
        amort_nodes=cfg.model.amort_nodes,
        mu_dt=cfg.infer.mu_dt,
        act_fn=utils.get_act_fn(cfg.model.act_fn),
        use_bias=cfg.model.use_bias,
        kaiming_init=cfg.model.kaiming_init,
    )
    optimizer = optim.get_optim(
        model.params,
        cfg.optim.name,
        cfg.optim.lr,
        amort_lr=cfg.optim.amort_lr,
        batch_scale=cfg.optim.batch_scale,
        grad_clip=cfg.optim.grad_clip,
        weight_decay=cfg.optim.weight_decay,
    )
    logging.info(f"loaded model {model}")
    return model, optimizer


def _train_epoch(model, optimizer, train_loader, cfg, epoch, log_every):
    """Run one training pass over ``train_loader``, logging running-mean losses every ``log_every`` batches."""
    n_batches = len(train_loader)
    pc_losses, amort_losses = [], []
    logging.info(f"Train @ epoch {epoch} ({n_batches} batches)")

    for batch_id, (img_batch, label_batch) in enumerate(train_loader):
        model.train_batch(
            img_batch,
            label_batch,
            cfg.infer.num_train_iters,
            fixed_preds=cfg.infer.fixed_preds_train,
            use_amort=cfg.model.train_amortised,
        )
        optimizer.step(
            curr_epoch=epoch,
            curr_batch=batch_id,
            n_batches=n_batches,
            batch_size=img_batch.size(0),
        )

        pc_loss, amort_loss = model.get_loss()
        pc_losses.append(pc_loss)
        amort_losses.append(amort_loss)

        if batch_id % log_every == 0:
            # Mean over every batch seen so far in this epoch.
            mean_pc = sum(pc_losses) / len(pc_losses)
            mean_amort = sum(amort_losses) / len(amort_losses)
            msg = f"[{batch_id}/{n_batches}] pc loss {mean_pc:.4f} amort loss {mean_amort:.4f}"
            logging.info(msg)


def _test_epoch(model, test_loader, cfg, epoch):
    """Evaluate with amortisation disabled, log, and return the mean per-batch PC accuracy."""
    logging.info(f"test @ epoch {epoch} ({len(test_loader)} batches)")
    pc_acc = 0
    for img_batch, label_batch in test_loader:
        label_preds = model.test_batch(
            img_batch,
            cfg.infer.num_test_iters,
            init_std=cfg.infer.init_std,
            fixed_preds=cfg.infer.fixed_preds_test,
            use_amort=False,
        )
        pc_acc = pc_acc + datasets.accuracy(label_preds, label_batch)

    # Mean of per-batch accuracies: equals sample-level accuracy only when
    # every batch has the same size.
    pc_acc = pc_acc / len(test_loader)
    msg = "pc acc {:.4f}"
    logging.info(msg.format(pc_acc))
    return pc_acc


def _plot_generated(model, test_loader, cfg, epoch):
    """Plot ``model.backward`` outputs for one test batch's labels to ``{img_dir}/{epoch}.png``."""
    # A fresh iterator each call, so this is the loader's first batch of a new pass.
    _, label_batch = next(iter(test_loader))
    img_preds = model.backward(label_batch)
    datasets.plot_imgs(img_preds, cfg.exp.img_dir + f"/{epoch}.png")

In [None]:
# Experiment configuration for `main`. It is a plain nested dict here;
# `utils.setup_experiment` turns it into the attribute-style config `main` reads.
# NOTE(review): the saved output below this cell was produced with a different
# config (log_dir "results/hybrid", train_amortised=True, grad_clip=50) — re-run
# the notebook to refresh it.
cfg = dict(
    exp=dict(log_dir="results/predcoding", seed=0, num_epochs=20, test_every=1),
    # None sizes are forwarded unchanged to `datasets.MNIST(size=...)`.
    data=dict(train_size=None, test_size=None, label_scale=0.94, normalize=True),
    infer=dict(
        mu_dt=0.01,
        num_train_iters=50,
        num_test_iters=200,
        fixed_preds_train=False,
        fixed_preds_test=False,
        init_std=0.01,
    ),
    model=dict(
        # Layer widths; presumably label (10) -> image (784) for the generative
        # network and the reverse for the amortised one — confirm in pybrid.
        nodes=[10, 500, 500, 784],
        amort_nodes=[784, 500, 500, 10],
        train_amortised=False,
        use_bias=True,
        kaiming_init=False,
        act_fn="tanh",
    ),
    optim=dict(
        name="Adam",
        lr=1e-4,
        amort_lr=1e-4,
        batch_size=64,
        batch_scale=True,
        grad_clip=5,
        weight_decay=None,
        normalize_weights=True,
    ),
)
main(cfg)

2021-03-22 13:00:55,063 [INFO] Starting experiment @ results/hybrid/0 [cuda]


{'data': {'label_scale': 0.94,
          'normalize': True,
          'test_size': None,
          'train_size': None},
 'exp': {'img_dir': 'results/hybrid/0/imgs',
         'log_dir': 'results/hybrid/0',
         'num_epochs': 20,
         'seed': 0,
         'test_every': 1},
 'infer': {'fixed_preds_test': False,
           'fixed_preds_train': False,
           'init_std': 0.01,
           'mu_dt': 0.01,
           'num_test_iters': 200,
           'num_train_iters': 50},
 'model': {'act_fn': 'tanh',
           'amort_nodes': [784, 500, 500, 10],
           'kaiming_init': False,
           'nodes': [10, 500, 500, 784],
           'train_amortised': True,
           'use_bias': True},
 'optim': {'amort_lr': 0.0001,
           'batch_scale': True,
           'batch_size': 64,
           'grad_clip': 50,
           'lr': 0.0001,
           'name': 'Adam',
           'weight_decay': None}}


2021-03-22 13:01:02,188 [INFO] loaded MNIST (937 train batches 156 test batches)
2021-03-22 13:01:02,207 [INFO] loaded model <HybridModel> [10, 500, 500, 784]
2021-03-22 13:01:02,208 [INFO] 
epoch 1/21
2021-03-22 13:01:02,208 [INFO] Train @ epoch 1 (937 batches)
2021-03-22 13:01:02,243 [INFO] [0/937] pc loss 3255.7188 amort loss178.1057
2021-03-22 13:01:04,739 [INFO] [100/937] pc loss 866.8825 amort loss77.3283
2021-03-22 13:01:07,228 [INFO] [200/937] pc loss 616.4335 amort loss64.2682
2021-03-22 13:01:09,706 [INFO] [300/937] pc loss 502.8942 amort loss57.7267
2021-03-22 13:01:12,153 [INFO] [400/937] pc loss 434.8599 amort loss53.4731
2021-03-22 13:01:14,630 [INFO] [500/937] pc loss 387.5079 amort loss50.3895
2021-03-22 13:01:17,078 [INFO] [600/937] pc loss 352.5068 amort loss48.0020
2021-03-22 13:01:19,544 [INFO] [700/937] pc loss 325.1014 amort loss46.1165
2021-03-22 13:01:21,989 [INFO] [800/937] pc loss 302.8834 amort loss44.5506
2021-03-22 13:01:24,441 [INFO] [900/937] pc loss 284.