# Required terminal commands

In [0]:
# Mount Google Drive so the project folder (datasets, checkpoints) is reachable.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# Work from the project directory: relative paths used later, such as
# 'datasets/mnist_non_binarised.pkl', are resolved against it.
%cd /content/drive/My\ Drive/ATiML
!ls

In [0]:
!python --version

In [0]:
# NOTE(review): `%pip install` targets the running kernel's environment;
# `!pip3` may target a different interpreter.
!pip3 install tqdm

In [0]:
# Legacy PyTorch 0.2 wheel (CUDA 8, CPython 3.6). The code below relies on the
# 0.2-era API (`Variable`, `volatile=True`, `.data[0]`).
!pip3 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl

In [0]:
!pip3 install torchvision==0.2

# Files from lxuechen repository

## Imports

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math

import numpy as np
import numpy.linalg as linalg
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

## Log functions

In [0]:
def log_normal(x, mean=None, logvar=None):
    """Diagonal-Gaussian log density with the -0.5*Z*log(2*pi) constant dropped.

    The constant is omitted because it is identical in p(z) and q(z|x) and
    therefore cancels in their difference.

    Args:
        x: [B, Z] points to evaluate.
        mean: [B, Z] means; zeros when omitted.
        logvar: [B, Z] log-variances; zeros (unit variance) when omitted.

    Returns:
        [B] log densities, summed over the latent axis.
    """
    tensor_type = type(x.data)
    if mean is None:
        mean = Variable(torch.zeros(x.size()).type(tensor_type))
    if logvar is None:
        logvar = Variable(torch.zeros(x.size()).type(tensor_type))

    log_det = logvar.sum(1)
    mahalanobis = ((x - mean).pow(2) / torch.exp(logvar)).sum(1)
    return -0.5 * (log_det + mahalanobis)

In [0]:
def log_bernoulli(logit, target):
    """Bernoulli log-likelihood of `target` given `logit`, one value per row.

    Uses the overflow-safe identity
        t*l - log(1 + exp(l)) = -relu(l) + t*l - log(1 + exp(-|l|)).

    Args:
        logit:  [B, ...] unnormalised log-odds.
        target: [B, ...] targets, same shape as `logit`.

    Returns:
        [B] log-likelihoods, summed over every non-batch axis.
    """
    softplus_tail = torch.log(1. + torch.exp(-logit.abs()))
    per_element = -F.relu(logit) + torch.mul(target, logit) - softplus_tail

    # Collapse trailing axes one at a time until only the batch axis remains.
    while len(per_element.size()) > 1:
        per_element = per_element.sum(-1)
    return per_element

In [0]:
def log_mean_exp(x, dim=1):
    """Numerically stable log(mean(exp(x))) along `dim`.

    The per-slice maximum is subtracted before exponentiating so large
    log-weights do not overflow, then added back afterwards.

    Args:
        x: tensor of log-values, e.g. [B, K] log importance weights.
        dim: axis to average over (default 1, the sample axis of a [B, K] input).

    Returns:
        `x` with `dim` reduced away, e.g. [B] for a [B, K] input.
    """
    max_, _ = torch.max(x, dim, keepdim=True)
    # Squeeze only the reduced axis: a bare squeeze() also drops any other
    # size-1 axis, making the two terms below broadcast to the wrong shape.
    return torch.log(torch.mean(torch.exp(x - max_), dim)) + max_.squeeze(dim)


## Dataset

In [0]:
import numpy as np
from scipy import io
import sys
import os
import time

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.autograd import Variable
import collections
import pickle
from torch.autograd import Variable


class Larochelle_MNIST:
    """In-memory MNIST loader that yields fixed-size batches of image arrays.

    Reads a pickled tuple of splits from `path`. The image array of each split
    is ``mnist[i][0]``. Index 0 is the training split, index 2 is the test
    split, and index 1 is presumably validation. The 'train' part concatenates
    splits 0 and 1.

    Args:
        part: one of 'train', 'test', 'partial_train', 'partial_test'.
            The partial parts keep only the first `partial` examples of
            split 0 or split 2 respectively.
        batch_size: examples per yielded batch.
        partial: number of examples kept by the partial parts.
        binarise: if True, threshold pixel values at 0.5 to {0., 1.}.
        path: location of the pickled dataset.
    """

    def __init__(self, part='train', batch_size=128, partial=1000,
                 binarise=False, path='datasets/mnist_non_binarised.pkl'):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            # This is checking if you are using a version of Python < 3.0
            if sys.version_info[0] < 3:
                mnist = pickle.load(f)
            else:
                # latin1 lets Python 3 read the (presumably Python 2) pickle.
                mnist = pickle.load(f, encoding='latin1')

        # Select the requested split first so only that array gets binarised.
        data = {
            'train': np.concatenate((mnist[0][0], mnist[1][0])),
            'test': mnist[2][0],
            'partial_train': mnist[0][0][:partial],
            'partial_test': mnist[2][0][:partial],
        }[part]
        if binarise:
            # Threshold into a new array (the loaded arrays are left intact);
            # values < 0.5 -> 0., everything else -> 1., original dtype kept.
            data = np.where(data < 0.5, 0., 1.).astype(data.dtype)

        self.data = data
        self.size = self.data.shape[0]
        self.batch_size = batch_size
        self._construct()

    def __iter__(self):
        """Iterate over the precomputed (batch, None) pairs."""
        return iter(self.batch_list)

    def _construct(self):
        """Slice `self.data` into full batches of `batch_size` rows.

        The last `size % batch_size` examples are dropped. Each entry is a
        (tensor, None) pair; the None fills the slot that labels or latent
        samples occupy in other loaders.
        """
        self.batch_list = []
        for i in range(self.size // self.batch_size):
            batch = self.data[self.batch_size*i:self.batch_size*(i+1)]
            batch = torch.from_numpy(batch)
            # placeholder for second entry
            self.batch_list.append((batch, None))


## maths_op

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import numpy.linalg as linalg

import torch
from torch.autograd import Variable
import torch.nn.functional as F


def log_normal_full_cov(x, mean, L):
    """Log density of a full-covariance Gaussian given a triangular factor of the covariance.

    With covariance Sigma = L L^T:
        log N(x; mean, Sigma) = -sum(log diag(L)) - 0.5 * ||L^{-1} (x - mean)||^2 + const
    Using sum(log diag(L)) as log|det L| is only valid when L is triangular
    with a positive diagonal (e.g. a Cholesky factor). The constant
    -0.5 * Z * log(2*pi) is omitted because it cancels between p(z) and q(z|x).

    Args:
        x: [B, Z] points.
        mean: [B, Z] means.
        L: [B, Z, Z] per-example triangular factors.

    Returns:
        [B] log densities (up to the omitted constant).
    """

    def batch_diag(M):
        # [B, Z, Z] -> [B, Z]
        return torch.stack([t.diag() for t in torch.unbind(M)])

    def batch_inverse(M, damp=False, eps=1e-6):
        damp_matrix = Variable(torch.eye(M[0].size(0)).type(M.data.type())).mul_(eps)
        inverse = []
        for t in torch.unbind(M):
            # damping to ensure invertible due to float inaccuracy
            # this problem is very UNLIKELY when using double
            m = t if not damp else t + damp_matrix
            inverse.append(m.inverse())
        return torch.stack(inverse)

    L_diag = batch_diag(L)
    term1 = -torch.log(L_diag).sum(1)

    L_inverse = batch_inverse(L)
    # Squeeze only the trailing column axis: a bare squeeze() would also drop
    # the batch axis when B == 1 (or the latent axis when Z == 1).
    scaled_diff = L_inverse.matmul((x - mean).unsqueeze(2)).squeeze(2)
    term2 = -0.5 * (scaled_diff ** 2).sum(1)

    return term1 + term2


def mean_squared_error(prediction, target):
    """Negated per-example sum of squared errors (despite the name, no mean is taken).

    Both inputs are flattened to [B, -1]. The result is
    -sum((prediction - target)^2) over the flattened features, shape [B].
    The sign is presumably chosen so the value can stand in for a log-likelihood.
    """
    residual = (prediction.view(prediction.size(0), -1)
                - target.view(target.size(0), -1))
    return -torch.sum(torch.mul(residual, residual), 1)


def discretized_logistic(mu, logs, x):
    """Log-probability of `x` under a discretised logistic with bins of width 1/256.

    See https://arxiv.org/pdf/1606.04934.pdf; follows OpenAI's implementation.
    Pixel values are assumed to lie in [0, 1]. The mass of the bin starting at
    x is sigmoid((x + 1/256 - mu) / s) - sigmoid((x - mu) / s), where s = exp(logs).
    A 1e-7 offset guards against log(0).

    Args:
        mu: location, broadcast against `x`.
        logs: log-scale. Two trailing singleton axes are appended before use
            (presumably [B, C] against `x` of [B, C, H, W] — TODO confirm).
        x: observed values.

    Returns:
        log-probabilities summed over the last three axes.
    """
    scale = torch.exp(logs).unsqueeze(-1).unsqueeze(-1)
    upper_cdf = torch.sigmoid((x + 1. / 256. - mu) / scale)
    lower_cdf = torch.sigmoid((x - mu) / scale)
    log_mass = torch.log(upper_cdf - lower_cdf + 1e-7)

    for _ in range(3):
        log_mass = log_mass.sum(-1)
    return log_mass


def flatten(x):
    """Collapse every axis after the first: [B, ...] -> [B, prod(...)]."""
    batch_size = x.size(0)
    return x.view(batch_size, -1)


def numpy_nan_guard(arr):
    """Return True when `arr` contains no NaN.

    Relies on NaN being the only value that compares unequal to itself.
    """
    self_equal = arr == arr
    return np.all(self_equal)


def safe_repeat(x, n):
    """Tile `x` n times along axis 0 only, leaving the other axes unchanged.

    e.g. [B, D] -> [n*B, D], laid out as n consecutive copies of the batch.
    """
    keep_rest = (1,) * (len(x.size()) - 1)
    return x.repeat(n, *keep_rest)


def sigmoidial_schedule(T, delta=4):
    """Sigmoidal annealing schedule from section 6 of the BDMC paper.

    Grosse, Ghahramani & Adams (2015), "Sandwiching the marginal likelihood
    using bidirectional Monte Carlo":
        beta_tilde(t) = sigmoid(delta * (2t/T - 1))
        beta(t) = (beta_tilde(t) - beta_tilde(1)) / (beta_tilde(T) - beta_tilde(1))
    The schedule therefore starts at exactly 0 (t = 1) and ends at exactly 1 (t = T).

    Args:
        T: number of temperatures. Use T >= 2; T == 1 gives 0/0 = nan.
        delta: steepness of the underlying sigmoid.

    Returns:
        list of T floats.
    """

    def sigmoid(x):
        # 1 / (1 + exp(-x)) saturates cleanly to 0 or 1. The previous
        # exp(x) / (1 + exp(x)) became inf / inf = nan once exp(x) overflowed,
        # i.e. for large delta.
        with np.errstate(over='ignore'):
            return 1. / (1. + np.exp(-x))

    def beta_tilde(t):
        return sigmoid(delta * (2. * t / T - 1.))

    # Endpoints are loop-invariant; compute them once.
    start, end = beta_tilde(1), beta_tilde(T)
    return [(beta_tilde(t) - start) / (end - start) for t in range(1, T + 1)]


def linear_schedule(T):
    """Return T evenly spaced temperatures from 0 to 1, both ends included."""
    return np.linspace(start=0., stop=1., num=T)


## AIS

In [0]:
from torch.autograd import Variable
from torch.autograd import grad as torchgrad


def ais_trajectory(
    model,
    loader,
    mode='forward',
    schedule=np.linspace(0., 1., 500),
    n_sample=100
):
    """Compute annealed importance sampling trajectories for a batch of data.
    Could be used for *both* forward and reverse chain in bidirectional Monte Carlo
    (default: forward chain with linear schedule).

    The intermediate targets are f_t(z) = p(z) p(x|z)^t with a Bernoulli likelihood.
    Each transition makes one HMC proposal (hmc_trajectory + accept_reject)
    with per-chain adaptive step sizes.

    Args:
        model (vae.VAE): VAE model; `z_size`, `dtype` and `decode` are used
        loader (iterator): iterator that returns pairs, with first component being `x`.
            The second component is ignored in forward mode. In backward mode
            it is used as the initial `z` of every chain, so it must then be a
            latent tensor (e.g. an exact posterior sample), not None or a label
        mode (string): indicate forward/backward chain; must be either `forward` or
            `backward`
        schedule (list or 1D np.ndarray): temperature schedule,
            i.e. `p(z)p(x|z)^t`; forward chain has increasing values, whereas
            backward has decreasing values
        n_sample (int): number of importance samples (i.e. number of parallel chains
            for each datapoint)

    Returns:
        A list with one tensor per batch (`logw.data`). Each tensor holds, per
        datapoint, the log of the mean importance weight over the `n_sample`
        chains; the value is negated in backward mode. In BDMC the forward
        estimate is a stochastic lower bound on log p(x) and the backward
        estimate a stochastic upper bound.
    """

    assert mode == 'forward' or mode == 'backward', 'Should have forward/backward mode'

    def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli):
        """Unnormalized density for intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        =>  log f_i = log p(z) + t * log p(x|z)
        """
        # `B`, `z_size` and `mdtype` come from the enclosing scope; `B` is
        # (re)assigned for each batch inside the loop below.
        zeros = Variable(torch.zeros(B, z_size).type(mdtype))
        log_prior = log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decode(z), data)

        return log_prior + log_likelihood.mul_(t)

    model.eval()

    # shorter aliases
    z_size = model.z_size
    mdtype = model.dtype

    _time = time.time()
    logws = []  # for output

    print ('In %s mode' % mode)

    for i, (batch, post_z) in enumerate(loader):

        # every datapoint is replicated n_sample times: one row per chain
        B = batch.size(0) * n_sample
        batch = Variable(batch.type(mdtype))
        batch = safe_repeat(batch, n_sample)

        # batch of step sizes, one for each chain
        epsilon = Variable(torch.ones(B).type(model.dtype)).mul_(0.01)
        # accept/reject history for tuning step size
        accept_hist = Variable(torch.zeros(B).type(model.dtype))
        # record log importance weight; volatile=True reduces memory greatly
        logw = Variable(torch.zeros(B).type(mdtype), volatile=True)

        # initial sample of z
        if mode == 'forward':
            # forward chains start from the prior N(0, I)
            current_z = Variable(torch.randn(B, z_size).type(mdtype), requires_grad=True)
        else:
            current_z = Variable(safe_repeat(post_z, n_sample).type(mdtype), requires_grad=True)

        # j counts transitions from 1; accept_reject divides accept_hist by it
        for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]), 1), position=0, leave=True):
            # update log importance weight
            # log w += log f_t1(z) - log f_t0(z), both evaluated at the current z
            log_int_1 = log_f_i(current_z, batch, t0)
            log_int_2 = log_f_i(current_z, batch, t1)
            logw.add_(log_int_2 - log_int_1)

            # resample speed
            current_v = Variable(torch.randn(current_z.size()).type(mdtype))

            def U(z):
                # potential energy of the new target f_t1
                return -log_f_i(z, batch, t1)

            def grad_U(z):
                # grad w.r.t. outputs; mandatory in this case
                grad_outputs = torch.ones(B).type(mdtype)
                # torch.autograd.grad default returns volatile
                grad = torchgrad(U(z), z, grad_outputs=grad_outputs)[0]
                # avoid humongous gradients
                grad = torch.clamp(grad, -10000, 10000)
                # needs variable wrapper to make differentiable
                grad = Variable(grad.data, requires_grad=True)
                return grad

            def normalized_kinetic(v):
                zeros = Variable(torch.zeros(B, z_size).type(mdtype))
                # this is superior to the unnormalized version
                # (equals 0.5 * sum(v**2): minus the constant-free N(0, I) log density)
                return -log_normal(v, zeros, zeros)

            z, v = hmc_trajectory(current_z, current_v, U, grad_U, epsilon)

            # accept-reject step
            current_z, epsilon, accept_hist = accept_reject(
                current_z, current_v,
                z, v,
                epsilon,
                accept_hist, j,
                U, K=normalized_kinetic
            )

        # IWAE lower bound
        # safe_repeat laid rows out as [copy 1 of the batch, copy 2, ...], so
        # view(n_sample, -1).transpose(0, 1) gives [datapoints, n_sample]
        logw = log_mean_exp(logw.view(n_sample, -1).transpose(0, 1))
        if mode == 'backward':
            logw = -logw
        logws.append(logw.data)

        print ('Time elapse %.4f, last batch stats %.4f' % \
            (time.time()-_time, logw.mean().cpu().data.numpy()))

        _time = time.time()
        sys.stdout.flush()  # for debugging

    return logws

## hmc trajectory

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import math
import torch
import numpy as np

import torch
from torch.autograd import Variable


def hmc_trajectory(current_z, current_v, U, grad_U, epsilon, L=10):
    """Run L leapfrog steps of Hamiltonian dynamics (https://arxiv.org/pdf/1206.1901.pdf).

    Args:
        current_z: starting position (presumably [N, Z]).
        current_v: starting momentum, same shape as `current_z`.
        U: potential energy (minus log-density); not called here.
        grad_U: callable returning dU/dz at a position.
        epsilon: [N] per-chain step sizes; must be a Variable.
        L: number of leapfrog steps.

    Returns:
        (z, v): the final position and the *negated* final momentum, both detached.
    """
    # as of `torch-0.3.0.post4`, there still is no proper scalar support
    assert isinstance(epsilon, Variable)

    step = epsilon.view(-1, 1)  # one step size per row, broadcast over the latent axis

    # Opening half-step for the momentum.
    momentum = current_v - grad_U(current_z).mul(step).mul_(.5)
    position = current_z
    for leap in range(L):
        position = position + momentum.mul(step)
        # Full momentum steps between position updates, but not after the last one.
        if leap != L - 1:
            momentum = momentum - grad_U(position).mul(step)

    # Closing half-step, then flip the momentum (not needed; it mirrors the math).
    momentum = momentum - grad_U(position).mul(step).mul_(.5)
    momentum = -momentum

    return position.detach(), momentum.detach()


def accept_reject(current_z, current_v,
                  z, v,
                  epsilon,
                  accept_hist, hist_len,
                  U, K=lambda v: torch.sum(v * v, 1)):
    """Metropolis accept/reject of an HMC proposal, plus per-chain step-size adaptation.

    Args:
        current_z: position BEFORE the leapfrog steps.
        current_v: momentum BEFORE the leapfrog steps.
        z: proposed position AFTER the leapfrog steps.
        v: proposed momentum AFTER the leapfrog steps.
        epsilon: [N] per-chain leapfrog step sizes (only used for adaptation).
        accept_hist: [N] running count of accepted proposals per chain.
        hist_len: number of proposals made so far (divides accept_hist).
        U: potential energy (MINUS log-density).
        K: kinetic energy. NOTE(review): the default sum(v*v) has no 1/2
            factor, so it is twice the usual mass-1 kinetic energy; the AIS
            caller in this notebook passes its own K.

    Returns:
        (z, epsilon, accept_hist):
        - z: the kept positions, as a fresh leaf Variable with requires_grad=True.
        - epsilon, accept_hist: the adapted step sizes and updated counts,
          rewrapped as fresh Variables cut from the previous graph.

    A chain's step size is multiplied by 1.02 when its acceptance rate exceeds
    0.65 and by 0.98 otherwise, then clamped to [1e-4, 0.5].
    NOTE(review): acceptance is applied by multiplying by 0/1, so a non-finite
    proposal still leaks into the kept z even when it is rejected.
    """
    tensor_type = type(current_z.data)

    # Metropolis test on the change in total energy H = K + U.
    energy_before = K(current_v) + U(current_z)
    energy_after = K(v) + U(z)
    accept_prob = torch.exp(energy_before - energy_after)
    coin = Variable(torch.rand(accept_prob.size()).type(tensor_type))
    accepted = (accept_prob > coin).type(tensor_type)
    keep = accepted.view(-1, 1)
    kept_z = z.mul(keep) + current_z.mul(1. - keep)

    # Grow or shrink each chain's step size according to its acceptance rate.
    accept_hist = accept_hist.add(accepted)
    above_target = (accept_hist / hist_len > 0.65).type(tensor_type)
    factor = 1.02 * above_target + 0.98 * (1. - above_target)
    epsilon = epsilon.mul(factor).clamp(1e-4, .5)

    # Rewrap to drop the old graph and save memory (similar to detach).
    return (Variable(kept_z.data, requires_grad=True),
            Variable(epsilon.data),
            Variable(accept_hist.data))


## VAE model

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
import argparse

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.autograd import Variable
from torch.autograd import grad as torchgrad


class VAE(nn.Module):
    """Fully connected VAE with a Bernoulli decoder over 784-dimensional inputs.

    Encoder: 784 -> h -> h -> 2*z_size (mean and log-variance), where h is 500
    with `wide_encoder` and 200 otherwise. A side branch maps the second hidden
    layer to an `x_info` feature of size z_size, which is passed to the
    optional flow.
    Decoder: z_size -> 200 -> 200 -> 784 logits.

    Args:
        hps: hyperparameter object exposing z_size, has_flow, cuda, act_func,
            n_flows, hamiltonian_flow and wide_encoder.
    """
    def __init__(self, hps):
        super(VAE, self).__init__()

        self.z_size = hps.z_size
        self.has_flow = hps.has_flow
        self.use_cuda = hps.cuda
        self.act_func = hps.act_func
        self.n_flows = hps.n_flows
        self.hamiltonian_flow = hps.hamiltonian_flow

        self._init_layers(wide_encoder=hps.wide_encoder)

        if self.use_cuda:
            self.cuda()
            self.dtype = torch.cuda.FloatTensor
            # NOTE: process-wide side effect; later torch constructors default to CUDA.
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.dtype = torch.FloatTensor

    def _init_layers(self, wide_encoder=False):
        """Create the encoder, decoder and x_info layers (and the flow if enabled)."""
        h_s = 500 if wide_encoder else 200

        self.fc1 = nn.Linear(784, h_s)  # assume flattened
        self.fc2 = nn.Linear(h_s, h_s)
        self.fc3 = nn.Linear(h_s, self.z_size*2)

        self.fc4 = nn.Linear(self.z_size, 200)
        self.fc5 = nn.Linear(200, 200)
        self.fc6 = nn.Linear(200, 784)

        # Reads fc2's output, so its input width must follow h_s rather than
        # assume the narrow encoder's 200 units.
        self.x_info_layer = nn.Linear(h_s, self.z_size)

        if self.has_flow:
            # NOTE(review): `Flow` is not defined in the cells shown here; it must
            # exist before a model with has_flow=True is built.
            self.q_dist = Flow(self, n_flows=self.n_flows)
            if self.use_cuda:
                self.q_dist.cuda()

    def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None):
        """Draw z = mu + exp(0.5 * logvar) * eps (reparameterisation), optionally refined by the flow.

        Returns:
            (z, logpz, logqz): the sample, its log density under the N(0, I)
            prior, and its log density under q. When has_flow is set, logqz
            includes the flow's log-prob term. Both densities omit the Gaussian
            normalising constant.
        """
        eps = Variable(torch.FloatTensor(mu.size()).normal_().type(self.dtype))
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mu)
        logqz = log_normal(z, mu, logvar)

        if self.has_flow:
            z, logprob = self.q_dist.forward(z, grad_fn, x_info)
            logqz += logprob

        zeros = Variable(torch.zeros(z.size()).type(self.dtype))
        logpz = log_normal(z, zeros, zeros)

        return z, logpz, logqz

    def encode(self, net):
        """Map a [B, 784] batch to (mean, logvar, x_info), each [B, z_size]."""
        net = self.act_func(self.fc1(net))
        net = self.act_func(self.fc2(net))
        x_info = self.act_func(self.x_info_layer(net))
        net = self.fc3(net)

        mean, logvar = net[:, :self.z_size], net[:, self.z_size:]

        return mean, logvar, x_info

    def decode(self, net):
        """Map [B, z_size] latents to [B, 784] Bernoulli logits."""
        net = self.act_func(self.fc4(net))
        net = self.act_func(self.fc5(net))
        logit = self.fc6(net)

        return logit

    def forward(self, x, k=1, warmup_const=1.):
        """Importance-weighted bound with k samples per datapoint.

        Args:
            x: [B, 784] batch.
            k: importance samples per datapoint (k=1 gives the standard ELBO).
            warmup_const: weight on log q(z|x) in the bound (1. for the true bound).

        Returns:
            (elbo, logpx, logpz, logqz):
            - elbo: the k-sample bound averaged over the batch.
            - logpx, logpz, logqz: means of the three log terms over all B*k samples.
        """
        x = x.repeat(k, 1)
        mu, logvar, x_info = self.encode(x)

        # posterior-aware inference
        def U(z):
            logpx = log_bernoulli(self.decode(z), x)
            logpz = log_normal(z)
            return -logpx - logpz  # energy as -log p(x, z)

        def grad_U(z):
            grad_outputs = torch.ones(z.size(0)).type(self.dtype)
            grad = torchgrad(U(z), z, grad_outputs=grad_outputs, create_graph=True)[0]
            # gradient clipping avoid numerical issue
            # NOTE(review): divides by sqrt(||grad||), not ||grad||; a zero
            # gradient row would divide by zero.
            norm = torch.sqrt(torch.norm(grad, p=2, dim=1))
            # neither grad clip methods consistently outperforms the other
            grad = grad / norm.view(-1, 1)
            # grad = torch.clamp(grad, -10000, 10000)
            return grad.detach()

        if self.hamiltonian_flow:
            z, logpz, logqz = self.sample(mu, logvar, grad_fn=grad_U, x_info=x_info)
        else:
            z, logpz, logqz = self.sample(mu, logvar, x_info=x_info)

        logit = self.decode(z)
        logpx = log_bernoulli(logit, x)
        elbo = logpx + logpz - warmup_const * logqz  # custom warmup

        # need correction for Tensor.repeat
        # rows are [copy 1 of the batch, copy 2, ...] -> [B, k] for the log-mean-exp
        elbo = log_mean_exp(elbo.view(k, -1).transpose(0, 1))
        elbo = torch.mean(elbo)

        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)

        return elbo, logpx, logpz, logqz


## Loaders file

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def get_Larochelle_MNIST_loader(batch_size=100, partial=False, num=1000):
    """Build the (train_loader, test_loader) pair of Larochelle_MNIST loaders.

    Args:
        batch_size: batch size for both loaders.
        partial: if True, use the 'partial_train' / 'partial_test' parts, i.e.
            only the first `num` examples of each split.
        num: number of examples kept per split when `partial` is True.

    Returns:
        (train_loader, test_loader)
    """
    if partial:
        # Forward batch_size and num to the test split as well, so both
        # loaders use the same settings instead of the class defaults.
        train_loader = Larochelle_MNIST(part='partial_train', batch_size=batch_size, partial=num)
        test_loader = Larochelle_MNIST(part='partial_test', batch_size=batch_size, partial=num)
    else:
        train_loader = Larochelle_MNIST(part='train', batch_size=batch_size)
        test_loader = Larochelle_MNIST(part='test', batch_size=batch_size)

    return train_loader, test_loader


def get_loaders(dataset='mnist', evaluate=False, batch_size=100):
    """Return (train_loader, test_loader) for `dataset`.

    Args:
        dataset: only 'mnist' is implemented.
        evaluate: if True, load the 1000-example partial splits.
        batch_size: batch size for both loaders.

    Raises:
        ValueError: for any dataset other than 'mnist'.
    """
    if dataset != 'mnist':
        # Fail clearly instead of hitting an unbound local variable.
        raise ValueError("Unsupported dataset: %r (only 'mnist' is implemented)" % (dataset,))
    return get_Larochelle_MNIST_loader(
        batch_size=batch_size,
        partial=evaluate, num=1000
    )


def get_model(dataset, hps):
    """Instantiate the model for `dataset` from hyperparameters `hps`.

    Raises:
        ValueError: for any dataset other than 'mnist'.
    """
    if dataset != 'mnist':
        # Fail clearly instead of hitting an unbound local variable.
        raise ValueError("Unsupported dataset: %r (only 'mnist' is implemented)" % (dataset,))
    return VAE(hps)

## HParams

In [0]:
class HParams(object):
    """Minimal hyperparameter container: keyword arguments become attributes.

    `parse` returns a *copy* with overrides applied from a string such as
    "z_size=20,has_flow=true". Each value is converted to the type of the
    existing default.
    """

    def __init__(self, **kwargs):
        self._items = {}
        for k, v in kwargs.items():
            self._set(k, v)

    def _set(self, k, v):
        # Store each value both in _items (read by parse) and as an attribute.
        self._items[k] = v
        setattr(self, k, v)

    def parse(self, str_value):
        """Return a copy of these hparams with comma-separated key=value overrides applied.

        Whitespace around entries, keys and values is ignored.
        - Booleans: only "true" (any case) means True; anything else is False.
        - bool is checked before int because bool is a subclass of int.
        - Types other than bool, int and float are set to the raw string.

        Raises:
            ValueError: if a non-empty entry has no "=".
            KeyError: if a key is not an existing hyperparameter.
        """
        hps = HParams(**self._items)
        for entry in str_value.strip().split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError("Unable to parse: %s" % entry)
            key, value = key.strip(), value.strip()
            if key not in hps._items:
                raise KeyError("Unknown hyperparameter: %s" % key)
            default_value = hps._items[key]
            if isinstance(default_value, bool):
                hps._set(key, value.lower() == "true")
            elif isinstance(default_value, int):
                hps._set(key, int(value))
            elif isinstance(default_value, float):
                hps._set(key, float(value))
            else:
                hps._set(key, value)
        return hps

def get_default_hparams():
    """Baseline hyperparameters that do not depend on command-line arguments.

    NOTE: later cells in this notebook redefine get_default_hparams() to read
    the global `args`; the most recently executed definition wins.
    """
    return HParams(
        z_size=50,
        act_func=F.elu,
        has_flow=False,
        large_encoder=False,
        wide_encoder=False,
        cuda=True,
        # VAE.__init__ reads these two attributes, so they must be present.
        n_flows=2,
        hamiltonian_flow=False,
    )

## Local FFG

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
from tqdm import tqdm
import argparse
import numpy as np

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


# Command-line interface for main_ffg (local factorised-Gaussian evaluation).
# NOTE: a later cell rebinds the global name `parser` to the VAE training
# parser. main_ffg() looks `parser` up when it runs, so once that cell has
# executed, main_ffg() parses with the VAE parser, which has no --debug flag.
parser = argparse.ArgumentParser(description='local_factorized_gaussian')
# action configuration flags
parser.add_argument('--no-cuda', '-nc', action='store_true')
parser.add_argument('--debug', action='store_true', help='debug mode')

# model configuration flags
parser.add_argument('--z-size', '-zs', type=int, default=50)
parser.add_argument('--batch-size', '-bs', type=int, default=100)
parser.add_argument('--eval-path', '-ep', type=str, default='model.pth',
                    help='path to load evaluation ckpt (default: model.pth)')
# NOTE: only 'mnist' is handled by get_loaders()/get_model() in this notebook.
parser.add_argument('--dataset', '-d', type=str, default='mnist',
                    choices=['mnist', 'fashion', 'cifar'], 
                    help='dataset to train and evaluate on (default: mnist)')
parser.add_argument('--has-flow', '-hf', action='store_true', help='inference uses FLOW')
parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows')
parser.add_argument('--wide-encoder', '-we', action='store_true',
                    help='use wider layer (more hidden units for FC, more channels for CIFAR)')


def get_default_hparams():
    """Hyperparameters for the FFG evaluation, read from the global `args`.

    `args` is the namespace stored by main_ffg(); hamiltonian_flow is always off.
    """
    options = dict(
        z_size=args.z_size,
        act_func=F.elu,
        has_flow=args.has_flow,
        n_flows=args.n_flows,
        wide_encoder=args.wide_encoder,
        cuda=args.cuda,
        hamiltonian_flow=False,
    )
    return HParams(**options)


def optimize_local_gaussian(
    log_likelihood,
    model,
    data_var,
    k=100,
    check_every=100,
    sentinel_thres=10,
    debug=False,
    lr=1e-3
):
    """Fit local factorised Gaussians to a fixed decoder with Adam, then report their bounds.

    The batch is repeated k times. Every row of the repeated batch gets its own
    variational mean and log-variance, and these are the only parameters passed
    to Adam. Optimisation stops once the average loss over a window of
    `check_every` steps has failed to beat the best window for more than
    `sentinel_thres` consecutive windows.

    data_var should be (cuda) variable.

    Args:
        log_likelihood: callable(logits, data) -> per-row log p(x|z), e.g. log_bernoulli.
        model: provides z_size, dtype and decode().
        data_var: [B, ...] batch.
        k: samples (and variational replicas) per datapoint.
        check_every: optimisation steps per convergence window.
        sentinel_thres: tolerated consecutive non-improving windows.
        debug: write per-window progress to stderr.
        lr: Adam learning rate.

    Returns:
        (vae_elbo, iwae_elbo): batch means of the single-sample ELBO and of the
        k-sample IWAE bound. `.data[0]` follows the PyTorch 0.2 API installed at
        the top of the notebook.
    """

    B = data_var.size()[0]
    z_size = model.z_size

    # Row r of the repeated batch is datapoint r % B.
    data_var = safe_repeat(data_var, k)
    zeros = Variable(torch.zeros(B*k, z_size).type(model.dtype))
    mean = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True)
    logvar = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True)

    # Only mean/logvar are optimised. The decoder weights are not updated,
    # although their .grad still accumulates.
    optimizer = optim.Adam([mean, logvar], lr=lr)
    # Start from +inf so the first window always becomes the best; a finite
    # sentinel is never beaten if the initial loss exceeds it.
    best_avg, sentinel, prev_seq = float('inf'), 0, []

    # perform local opt
    time_ = time.time()
    for epoch in range(1, 999999):

        # reparameterised sample z ~ N(mean, exp(logvar))
        eps = Variable(torch.FloatTensor(mean.size()).normal_().type(model.dtype))
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)
        x_logits = model.decode(z)

        logpz = log_normal(z, zeros, zeros)
        logqz = log_normal(z, mean, logvar)
        logpx = log_likelihood(x_logits, data_var)

        optimizer.zero_grad()
        loss = -torch.mean(logpx + logpz - logqz)
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        prev_seq.append(loss_np)
        if epoch % check_every == 0:
            # early stopping on the window-averaged loss
            last_avg = np.mean(prev_seq)
            if debug:  # debugging helper
                sys.stderr.write(
                    'Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n' % \
                    (epoch, time.time()-time_, -last_avg, -best_avg)
                )
            if last_avg < best_avg:
                sentinel, best_avg = 0, last_avg
            else:
                sentinel += 1
            if sentinel > sentinel_thres:
                break
            prev_seq = []
            time_ = time.time()

    # evaluation
    eps = Variable(torch.FloatTensor(B*k, z_size).normal_().type(model.dtype))
    z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)

    logpz = log_normal(z, zeros, zeros)
    logqz = log_normal(z, mean, logvar)
    logpx = log_likelihood(model.decode(z), data_var)
    elbo = logpx + logpz - logqz

    vae_elbo = torch.mean(elbo)
    # [k*B] -> [B, k], then log-mean-exp over the k samples of each datapoint
    iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1)))

    return vae_elbo.data[0], iwae_elbo.data[0]


def main_ffg(arg_string):
    """Evaluate a trained VAE with local (per-datapoint) factorised-Gaussian optimisation.

    Steps:
    1. Parse `arg_string` with the global `parser` and store the result in the
       global `args`.
    2. Load the checkpoint at --eval-path.
    3. Run optimize_local_gaussian on every batch of the partial (first 1000)
       training split.
    4. Print running and final mean ELBO / IWAE values.

    Args:
        arg_string: whitespace-separated command-line flags, e.g. "-ep model.pth".
    """
    global args
    args = parser.parse_args(arg_string.split())
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    train_loader, test_loader = get_loaders(
        dataset=args.dataset,
        evaluate=True, batch_size=args.batch_size
    )
    # Build the hparams here instead of calling get_default_hparams(). Later
    # cells redefine that function, and the later definition reads
    # args.hamiltonian_flow, which the FFG parser does not provide.
    hps = HParams(
        z_size=args.z_size,
        act_func=F.elu,
        has_flow=args.has_flow,
        n_flows=args.n_flows,
        wide_encoder=args.wide_encoder,
        cuda=args.cuda,
        hamiltonian_flow=False,
    )
    model = get_model(args.dataset, hps)
    # Without CUDA, map GPU-saved storages onto the CPU while unpickling.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    map_location = None if args.cuda else (lambda storage, loc: storage)
    checkpoint = torch.load(args.eval_path, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    # A later cell rebinds `parser` to the VAE training parser, which has no
    # --debug flag; default to quiet mode in that case.
    debug = getattr(args, 'debug', False)

    vae_record, iwae_record = [], []
    time_ = time.time()
    for i, (batch, _) in tqdm(enumerate(train_loader)):
        batch = Variable(batch.type(model.dtype))
        elbo, iwae = optimize_local_gaussian(log_bernoulli, model, batch, debug=debug)
        vae_record.append(elbo)
        iwae_record.append(iwae)
        print ('Local opt w/ ffg, batch %d, time elapse %.4f, ELBO %.4f, IWAE %.4f' % \
            (i+1, time.time()-time_, elbo, iwae))
        print ('mean of ELBO so far %.4f, mean of IWAE so far %.4f' % \
            (np.nanmean(vae_record), np.nanmean(iwae_record)))
        time_ = time.time()

    print ('Finishing...')
    print ('Average ELBO %.4f, IWAE %.4f' % (np.nanmean(vae_record), np.nanmean(iwae_record)))


# Candidate 1039199: What I coded and adapted

In [0]:
# Command-line interface for VAE training and evaluation.
# NOTE: this rebinds the global `parser` (and, through it, what main_ffg() parses with).
parser = argparse.ArgumentParser(description='VAE')
# action configuration flags
parser.add_argument('--train', '-t', action='store_true')
parser.add_argument('--load-path', '-lp', type=str, default='NA',
                    help='path to load checkpoint to retrain')
parser.add_argument('--load-epoch', '-le', type=int, default=0,
                    help='epoch number to start recording when retraining')
parser.add_argument('--display-epoch', '-de', type=int, default=10,
                    help='print status every so many epochs (default: 10)')
parser.add_argument('--eval-iwae', '-ei', action='store_true')
parser.add_argument('--eval-ais', '-ea', action='store_true')
parser.add_argument('--n-iwae', '-ni', type=int, default=5000,
                    help='number of samples for IWAE evaluation (default: 5000)')
parser.add_argument('--n-ais-iwae', '-nai', type=int, default=100,
                    help='number of IMPORTANCE samples for AIS evaluation (default: 100). \
                          This is different from MC samples.')
parser.add_argument('--n-ais-dist', '-nad', type=int, default=10000,
                    help='number of distributions for AIS evaluation (default: 10000)')
parser.add_argument('--ais-schedule', type=str, default='linear', help='schedule for AIS')

parser.add_argument('--no-cuda', '-nc', action='store_true', help='force not use CUDA')
parser.add_argument('--visdom', '-v', action='store_true', help='visualize samples')
parser.add_argument('--port', '-p', type=int, default=8097, help='port for visdom')
parser.add_argument('--save-visdom', default='test', help='visdom save path')
parser.add_argument('--encoder-more', action='store_true', help='train the encoder more (5 vs 1)')
parser.add_argument('--early-stopping', '-es', action='store_true', help='apply early stopping')
parser.add_argument('--epochs', '-e', type=int, default=3280,
                    help='total num of epochs for training (default: 3280)')
parser.add_argument('--lr-schedule', '-lrs', action='store_true',
                    help='apply learning rate schedule')

# model configuration flags
parser.add_argument('--z-size', '-zs', type=int, default=50,
                    help='dimensionality of latent code (default: 50)')
parser.add_argument('--batch-size', '-bs', type=int, default=100,
                    help='batch size (default: 100)')
parser.add_argument('--save-name', '-sn', type=str, default='model.pth',
                    help='name to save trained ckpt (default: model.pth)')
parser.add_argument('--eval-path', '-ep', type=str, default='model.pth',
                    help='path to load evaluation ckpt (default: model.pth)')
# NOTE: only 'mnist' is handled by get_loaders()/get_model() in this notebook.
parser.add_argument('--dataset', '-d', type=str, default='mnist',
                    choices=['mnist', 'fashion', 'cifar'],
                    help='dataset to train and evaluate on (default: mnist)')
parser.add_argument('--wide-encoder', '-we', action='store_true',
                    help='use wider layer (more hidden units for FC, more channels for CIFAR)')
parser.add_argument('--has-flow', '-hf', action='store_true',
                    help='use flow for training and eval')
parser.add_argument('--hamiltonian-flow', '-hamil-f', action='store_true')
parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows')
parser.add_argument('--warmup', '-w', action='store_true',
                    help='apply warmup during training')

# 1039199: I added these lines
parser.add_argument('--optimiser', type=str, default='Adam',
                    choices=['Adam', 'AdaDelta', 'SGD'],
                    help='Choose an optimiser for the VAE')
# Without a default, train() received lr=None whenever the flag was omitted.
parser.add_argument('--init_lr', type=float, default=1e-3,
                    help='Initial learning rate of the optimiser (default: 1e-3)')



def get_default_hparams():
    """Hyperparameters for the training/evaluation VAE, read from the global `args`.

    `args` is the module-level namespace produced by parsing the VAE parser;
    it must be set before this is called.
    """
    config = {
        'z_size': args.z_size,
        'act_func': F.elu,
        'has_flow': args.has_flow,
        'hamiltonian_flow': args.hamiltonian_flow,
        'n_flows': args.n_flows,
        'wide_encoder': args.wide_encoder,
        'cuda': args.cuda,
    }
    return HParams(**config)


In [0]:
def load_checkpoint(model, optimizer, filename, map_location=None):
    """Restore model/optimizer state and the epoch counter from `filename`.

    Note: Input model & optimizer should be pre-defined. This routine only
    updates their states. A missing file is reported and leaves both untouched.

    Args:
        model: module restored from checkpoint['state_dict'].
        optimizer: optimizer restored from checkpoint['optimizer'].
        filename: checkpoint path.
        map_location: forwarded to torch.load. For example, use
            `lambda storage, loc: storage` (or 'cpu' on newer PyTorch) to open
            a GPU-saved checkpoint without CUDA.

    Returns:
        (model, optimizer, start_epoch); start_epoch is 0 when no file was found.
    """
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename, map_location=map_location)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch

In [0]:
 # Function to binarise the arrays in place
def static_binarise(d):
    """Threshold `d` in place at 0.5 and return it.

    Entries below 0.5 become 0. and all others become 1.
    """
    below_half = d < 0.5
    d[below_half] = 0.
    d[~below_half] = 1.
    return d

In [0]:
# NOTE(review): this only binds a Python variable, and nothing in the cells
# shown reads it. To make CUDA kernel launches synchronous for debugging, set
# the environment variable before CUDA is initialised, e.g.
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'.
CUDA_LAUNCH_BLOCKING=1

def _make_optimiser(model, optimiser_str, lr):
  """Build the optimiser named `optimiser_str` over `model.parameters()`.

  Args:
    model: module whose parameters are optimised.
    optimiser_str: one of 'Adam', 'AdaDelta' or 'SGD'.
    lr: initial learning rate.

  Raises:
    ValueError: if `lr` is None or `optimiser_str` is not recognised.
  """
  import torch.optim as optim  # local import so this cell does not depend on another one
  if lr is None:
    raise ValueError('init_lr must be given (got None)')
  if optimiser_str == 'Adam':
    return optim.Adam(model.parameters(), lr=lr)
  if optimiser_str == 'AdaDelta':
    return optim.Adadelta(model.parameters(), lr=lr,
                          rho=0.9, eps=1e-06, weight_decay=0)
  if optimiser_str == 'SGD':
    return optim.SGD(model.parameters(), lr=lr,
                     momentum=0.9, dampening=1e-3)
  raise ValueError("Unknown optimiser '%s'; expected 'Adam', 'AdaDelta' or 'SGD'"
                   % (optimiser_str,))


def _default_warmup_thres(save_path):
  """Dataset-dependent warm-up length in epochs, or None if the dataset is unknown.

  `save_path` is checked first; `args.dataset` is the fallback because paths
  built by run() ('expansion2/<opt>/lr_<lr>/bs_<bs>/') contain no dataset name.
  """
  for hint in (save_path, str(getattr(args, 'dataset', ''))):
    if 'cifar' in hint:
      return 50.
    if 'mnist' in hint or 'fashion' in hint:
      return 400.
  return None


def _checkpoint_path(save_path, epoch):
  """File name used both to save and to look up the checkpoint of `epoch`."""
  return '%s%d_%s' % (save_path, epoch + args.load_epoch, args.save_name)


def _latest_checkpoint_epoch(save_path, checkpoints, epochs):
  """Largest epoch in `checkpoints` (<= `epochs`) whose checkpoint file exists, else 0."""
  found = [e for e in checkpoints
           if e <= epochs and os.path.isfile(_checkpoint_path(save_path, e))]
  return max(found) if found else 0


def _optimisation_step(model, optimizer, batch, k, warmup_const):
  """One gradient step that maximises the bound returned by `model.forward`."""
  optimizer.zero_grad()
  elbo, _, _, _ = model.forward(batch, k, warmup_const)
  loss = -elbo
  loss.backward()
  optimizer.step()


def _mean_bound(model, loader, k):
  """Mean over batches of the first output of `model(batch, k=k)`."""
  values = []
  for batch, _ in loader:
    batch = Variable(batch)
    if args.cuda:
      batch = batch.cuda()
    elbo, _, _, _ = model(batch, k=k)
    values.append(elbo.data[0])
  return np.mean(np.asarray(values))


def train(model, train_loader, test_loader,
          optimiser_str, init_lr,  # 1039199: I added these arguments
          k_train=1,  # num iwae sample for training
          k_eval=1,  # num iwae sample for eval
          epochs=3280, display_epoch=10, lr_schedule=True,
          warmup=True, warmup_thres=None,
          encoder_more=False,
          checkpoints=None, early_stopping=True,
          save=True, save_path='checkpoints/',
          patience=10):
  """Train `model` by maximising its bound, with periodic evaluation and checkpoints.

  Args:
    model: the VAE; `model.forward(batch, k, warmup_const)` must return a
      4-tuple whose first element is the (scalar) bound to maximise.
    train_loader, test_loader: iterables of (batch, _) pairs; the second item
      is ignored.
    optimiser_str: 'Adam', 'AdaDelta' or 'SGD'.
    init_lr: initial learning rate.
    k_train, k_eval: number of importance samples for training / evaluation.
    epochs: number of epochs to run.
    display_epoch: evaluate and print every this many epochs.
    lr_schedule: multiply the learning rate by 10**(-1/20) at epochs
      2, 4, 10, 28, ... (whenever elapsed epochs reach the next power of 3).
    warmup: scale by warmup_const = min(1, epoch / warmup_thres).
    warmup_thres: warm-up length in epochs; inferred from the dataset if None.
    encoder_more: before each regular step, take 10 extra steps with the
      decoder frozen.
    checkpoints: epochs at which to save; defaults to 1, every
      `display_epoch`, and `epochs`.
    early_stopping: stop after `patience` consecutive evaluations in which the
      mean test bound did not reach the best value seen so far.
    save: whether to write checkpoints.
    save_path: prefix of checkpoint file names.
    patience: see `early_stopping`.

  If checkpoint files from an earlier run exist, training resumes from the
  latest one; the learning-rate schedule is replayed up to that epoch.

  Raises:
    ValueError: unknown optimiser, missing `init_lr`, or `warmup` requested
      without a known or inferable `warmup_thres`.
  """
  print('Training')

  if args.load_path != 'NA':
    model.load_state_dict(torch.load(args.load_path)['state_dict'])

  # default warmup schedule
  if warmup_thres is None:
    warmup_thres = _default_warmup_thres(save_path)
  if warmup and warmup_thres is None:
    raise ValueError('warmup_thres could not be inferred from save_path %r or '
                     'args.dataset; pass it explicitly' % (save_path,))

  if checkpoints is None:  # save a checkpoint every display_epoch
    checkpoints = [1] + list(range(0, epochs, display_epoch))[1:] + [epochs]

  current_lr = init_lr
  optimizer = _make_optimiser(model, optimiser_str, current_lr)

  # Resume from the most recent checkpoint of a previous run, if one exists.
  resume_epoch = _latest_checkpoint_epoch(save_path, checkpoints, epochs)

  time_ = time.time()
  num_worse = 0  # consecutive evaluations without improvement; compared to `patience`
  best_valid = None  # best mean test bound seen so far (higher is better)

  # 1039199: I fixed the lr_scheduler
  # 1039199: I fixed the load_checkpoints
  power = 0
  epoch_elapsed = 0

  for epoch in tqdm(range(1, epochs + 1), position=0, leave=True):
    warmup_const = min(1., epoch / warmup_thres) if warmup else 1.
    # lr schedule from IWAE: https://arxiv.org/pdf/1509.00519.pdf
    if lr_schedule:
      if epoch_elapsed >= 3 ** power:
        # 1039199: I slowed the reduction of the learning rate
        current_lr *= 10. ** (-1. / 20.)
        power += 1
        for param_group in optimizer.param_groups:
          param_group['lr'] = current_lr
      epoch_elapsed += 1

    # Epochs covered by the resumed checkpoint are skipped; the schedule above
    # still runs so `current_lr` stays in step with the loaded state.
    if epoch < resume_epoch:
      continue
    if epoch == resume_epoch:
      model, optimizer, _ = load_checkpoint(
          model, optimizer, _checkpoint_path(save_path, epoch))
      continue

    model.train()  # crucial for BN to work properly
    for batch, _ in train_loader:
      batch = Variable(batch)
      if args.cuda:
        batch = batch.cuda()

      # train the encoder more
      if encoder_more:
        model.freeze_decoder()
        for _ in range(10):
          _optimisation_step(model, optimizer, batch, k_train, warmup_const)
        model.unfreeze_decoder()

      _optimisation_step(model, optimizer, batch, k_train, warmup_const)

    stop_now = False
    if epoch % display_epoch == 0:
      model.eval()  # crucial for BN to work properly
      train_bound = _mean_bound(model, train_loader, 1)
      # early stopping with iwae bound
      test_bound = _mean_bound(model, test_loader, k_eval)
      print(
          'Train Epoch: [{}/{}]'.format(epoch, epochs),
          'Train set ELBO {:.4f}'.format(train_bound),
          'Test/Validation set IWAE {:.4f}'.format(test_bound),
          'Time: {:.2f}'.format(time.time() - time_),
      )
      time_ = time.time()

      if early_stopping:
        if best_valid is None or test_bound >= best_valid:  # first or improved
          best_valid = test_bound
          num_worse = 0
        else:
          num_worse += 1
        stop_now = num_worse >= patience

    if save and epoch in checkpoints:
      state = {'epoch': epoch + args.load_epoch, 'state_dict': model.state_dict(),
               'optimizer': optimizer.state_dict()}
      torch.save(state, _checkpoint_path(save_path, epoch))

    # Checked after saving so the model of an early-stopped run is kept.
    if stop_now:
      print("Stopped early")
      break

In [0]:
# Ask the CUDA runtime to launch kernels synchronously so device-side errors are
# reported at the call that caused them (a debugging aid that slows GPU work).
# The runtime reads this from the process *environment*, and only before the CUDA
# context is created; binding a Python variable alone has no effect.
CUDA_LAUNCH_BLOCKING = 1
os.environ['CUDA_LAUNCH_BLOCKING'] = str(CUDA_LAUNCH_BLOCKING)

def run(arg_string):
    """Parse `arg_string` as command-line flags, then train and/or visualise a model.

    The parsed namespace is stored in the global `args`, which
    get_default_hparams() and train() read.

    Args:
        arg_string: whitespace-separated flags, e.g.
            "--train --dataset mnist --optimiser SGD --init_lr 0.001".
    """
    global args
    args = parser.parse_args(arg_string.split())
    # Sanity check to know which arguments are being parsed
    print(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.cuda)
    train_loader, test_loader = get_loaders(
        dataset=args.dataset,
        evaluate=args.eval_iwae or args.eval_ais,
        batch_size=args.batch_size
    )
    model = get_model(args.dataset, get_default_hparams())

    if args.train:
        # '%g' keeps learning rates below 1 distinct (e.g. 'lr_0.001'); '%d'
        # truncated all of them to 'lr_0', so different runs overwrote each other.
        save_path = 'expansion2/%s/lr_%g/bs_%s/' % (
            args.optimiser,
            args.init_lr,
            args.batch_size
        )
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        # Pass the optimiser and learning rate chosen on the command line.
        train(
            model, train_loader, test_loader,
            args.optimiser, args.init_lr,
            display_epoch=args.display_epoch, epochs=args.epochs,
            lr_schedule=args.lr_schedule,
            warmup=args.warmup,
            early_stopping=args.early_stopping,
            encoder_more=args.encoder_more,
            save=True, save_path=save_path
        )

    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=args.port)
        model.load_state_dict(torch.load(args.eval_path)['state_dict'])

        # plot original images; a DataLoader has no .next(), so take the first
        # batch from a fresh iterator
        batch, _ = next(iter(train_loader))
        images = list(batch.numpy())
        vis.images(images, 10, 2, opts={'caption': 'original images'}, win=None)

        # plot reconstructions
        batch = Variable(batch.type(model.dtype))
        reconstruction = model.reconstruct_img(batch)
        images = list(reconstruction.data.cpu().numpy())
        vis.images(images, 10, 2, opts={'caption': 'reconstruction'}, win=None)

In [0]:
def eval_mod(arg_string):
    """Parse `arg_string` as command-line flags and evaluate a saved model.

    The parsed namespace is stored in the global `args` (read by
    get_default_hparams()).
    - With --eval-iwae, test_iwae is called 100 times with k=1 on the train
      and test sets and the results are averaged. It is then called once more
      on each set with k=args.n_iwae.
    - With --eval-ais, the weights in args.eval_path are loaded and
      ais_trajectory is run in 'forward' mode on the training loader.

    Args:
        arg_string: whitespace-separated flags, e.g.
            "--eval-ais --dataset mnist --eval-path <checkpoint>".
    """
    global args
    args = parser.parse_args(arg_string.split())
    # Sanity check to know which arguments are being parsed
    print(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.cuda)
    train_loader, test_loader = get_loaders(
        dataset=args.dataset,
        evaluate=args.eval_iwae or args.eval_ais,
        batch_size=args.batch_size
    )
    model = get_model(args.dataset, get_default_hparams())

    if args.eval_iwae:
        # VAE bounds computed w/ 100 MC samples to reduce variance.
        # Collect test_iwae's return value for each repetition.
        # NOTE(review): assumes test_iwae returns the mean bound over the loader.
        train_res, test_res = [], []
        for _ in range(100):
            train_res.append(test_iwae(model, train_loader, k=1, f=args.eval_path))
            test_res.append(test_iwae(model, test_loader, k=1, f=args.eval_path))
        print('Training set VAE ELBO w/ 100 MC samples: %.4f' % np.mean(train_res))
        print('Test set VAE ELBO w/ 100 MC samples: %.4f' % np.mean(test_res))

        # IWAE bounds
        test_iwae(model, train_loader, k=args.n_iwae, f=args.eval_path)
        test_iwae(model, test_loader, k=args.n_iwae, f=args.eval_path)

    if args.eval_ais:
        print("Start evaluating")
        model.load_state_dict(torch.load(args.eval_path)['state_dict'])
        schedule_fn = linear_schedule if args.ais_schedule == 'linear' else sigmoidial_schedule
        schedule = schedule_fn(args.n_ais_dist)
        ais_trajectory(
            model, train_loader,
            mode='forward', schedule=schedule, n_sample=args.n_ais_iwae
        )

# Experiments

In [0]:
# Report whether PyTorch can see a CUDA device. Whether a run actually uses the
# GPU is decided by `args.cuda` inside run() / eval_mod() (from --no-cuda and
# torch.cuda.is_available()).
# NOTE(review): `use_cuda` is not read anywhere in this part of the notebook;
# confirm whether any earlier cell depends on it.
use_cuda = True
print(torch.cuda.is_available())

In [0]:
# Train on MNIST with SGD (initial lr 1e-3), using the lr schedule, warm-up and
# early stopping.
run("--train --dataset mnist --lr-schedule --warmup --early-stopping --optimiser SGD --init_lr 0.001")

In [0]:
# Train on MNIST with AdaDelta (initial lr 1e-3), using the lr schedule, warm-up
# and early stopping.
# NOTE(review): torch.optim.Adadelta's own default lr is 1.0; confirm that 1e-3
# is intended.
run("--train --dataset mnist --lr-schedule --warmup --early-stopping --optimiser AdaDelta --init_lr 0.001")

In [0]:
# AIS evaluation (eval_mod with --eval-ais) of the SGD-trained MNIST model.
# NOTE(review): run() saves checkpoints as
# 'expansion2/<optimiser>/lr_<init_lr>/bs_<batch_size>/<epoch>_<save_name>', which
# this path does not follow; confirm the file exists.
eval_mod("--eval-ais  --dataset mnist --eval-path ./expansion2/SGD/lr_1e-3/bs100_3280_model.pth")

In [0]:
# AIS evaluation (eval_mod with --eval-ais) of the AdaDelta-trained MNIST model.
# NOTE(review): run() saves checkpoints as
# 'expansion2/<optimiser>/lr_<init_lr>/bs_<batch_size>/<epoch>_<save_name>', which
# this path does not follow; confirm the file exists.
eval_mod("--eval-ais  --dataset mnist --eval-path ./expansion2/AdaDelta/lr_1e-3/3280_model.pth")

In [0]:
# NOTE(review): main_ffg is not defined in this part of the notebook (presumably
# an earlier cell). The checkpoint path does not follow the layout run() writes
# ('expansion2/<optimiser>/lr_<init_lr>/bs_<batch_size>/<epoch>_<save_name>');
# confirm the file exists.
main_ffg("--dataset mnist --eval-path ./expansion2/SGD/lr_1e-3/bs100_3280_model.pth --debug")

In [0]:
# NOTE(review): main_ffg is not defined in this part of the notebook (presumably
# an earlier cell). The checkpoint path does not follow the layout run() writes
# ('expansion2/<optimiser>/lr_<init_lr>/bs_<batch_size>/<epoch>_<save_name>');
# confirm the file exists.
main_ffg("--dataset mnist --eval-path ./expansion2/AdaDelta/lr_1e-3/3280_model.pth --debug")