#Required Terminal Commands

First, run this cell to mount your Google Drive in the Colab runtime:

In [0]:
from google.colab import drive
# Mount Google Drive at /content/drive so later cells can reach the shared project folder.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Now, assuming the shared folder is present in your Drive, the next cell should work. If it doesn't, make sure your ATiML folder sits directly under "My Drive" rather than being nested inside another folder.

In [0]:
%cd /content/drive/My\ Drive/ATiML
!ls

/content/drive/My Drive/ATiML
 checkpoints   expansion2	       MNIST_notused	   Results.gsheet
 datasets      Experiment_3	       old_checkpoints
 Exp4	       Extension_Planar_Flow   Plan_V2.gdoc
 Exp5	      'Final Code'	      'Poster Plan.gdoc'


All of the following commands install the Python libraries this notebook needs into your runtime. I'm not yet sure whether you need to run them again after closing and re-opening the window; if a later import fails, re-run these cells first.

In [0]:
!python --version

Python 3.6.9


In [0]:
!pip3 install tqdm



In [0]:
!pip3 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl



In [0]:
!pip3 install torchvision==0.2



In [0]:
!pip3 install visdom



#imports

In [0]:
#@title Default title text
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
import os
import math
import argparse
from tqdm import tqdm
import numpy as np
# I've commented out the line below: it only changes how many array elements trigger summarization when printing, and newer NumPy versions reject threshold=np.nan with an error (threshold=sys.maxsize is the working equivalent).
# np.set_printoptions(threshold=np.nan)
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms, utils
import visdom

#math_ops file

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import numpy.linalg as linalg

import torch
from torch.autograd import Variable
import torch.nn.functional as F


def log_normal(x, mean=None, logvar=None, repeat=False): #[1040336] Works on [P,B,Z] sample tensors as well as [B,Z]
    """Diagonal-Gaussian log density WITHOUT the ``-0.5 * D * log(2*pi)`` constant.

    The constant is omitted because it is identical in p(z) and q(z|x) and
    cancels out of the ELBO. This now holds for both 2-D and 3-D inputs.
    Previously only the 2-D path omitted it.

    Args:
        x: [B, Z], or [P, B, Z] for P samples per datapoint.
        mean, logvar: broadcastable against ``x`` (typically [B, Z]). Both
            default to zeros shaped like ``x``, i.e. a standard normal.
        repeat: kept for backward compatibility. The reductions are taken
            over the last (latent) axis, so [B, Z] and [P, B, Z] inputs are
            handled the same way whether or not this flag is set.

    Returns:
        output: [B] for 2-D ``x``, [P, B] for 3-D ``x``.
    """
    if mean is None:
        mean = Variable(torch.zeros(x.size()).type(type(x.data)))
    if logvar is None:
        logvar = Variable(torch.zeros(x.size()).type(type(x.data)))
    # Summing over the last axis lets [B, Z] statistics broadcast against
    # [P, B, Z] samples: logvar.sum(-1) is [B] and adds onto the [P, B] term.
    return -0.5 * (logvar.sum(-1) + ((x - mean).pow(2) / torch.exp(logvar)).sum(-1))


def log_normal_full_cov(x, mean, L):
    """Log density of a full-covariance Gaussian N(mean, L L^T).

    Note: the result omits the constant ``-0.5 * D * log(2*pi)``, since that
    quantity cancels out between p(z) and q(z|x).

    Args:
        x, mean: [B, D].
        L: [B, D, D] scale factors. The log-determinant term uses only the
            diagonal of L, so L is assumed to be triangular (e.g. a Cholesky
            factor) with a positive diagonal.

    Returns:
        [B] log densities (up to the constant above).
    """

    def batch_diag(M):
        # [B, D, D] -> [B, D]
        return torch.stack([t.diag() for t in torch.unbind(M)])

    def batch_inverse(M, damp=False, eps=1e-6):
        damp_matrix = Variable(torch.eye(M[0].size(0)).type(M.data.type())).mul_(eps)
        inverse = []
        for t in torch.unbind(M):
            # damping to ensure invertible due to float inaccuracy
            # this problem is very UNLIKELY when using double
            m = t + damp_matrix if damp else t
            inverse.append(m.inverse())
        return torch.stack(inverse)

    # -log|det L| for triangular L.
    L_diag = batch_diag(L)
    term1 = -torch.log(L_diag).sum(1)

    L_inverse = batch_inverse(L)
    # squeeze(2), not squeeze(): keep the batch/feature axes when B == 1 or D == 1.
    scaled_diff = L_inverse.matmul((x - mean).unsqueeze(2)).squeeze(2)
    term2 = -0.5 * (scaled_diff ** 2).sum(1)

    return term1 + term2


def log_bernoulli(logit, target, repeat = False): #[1040336] Also accepts [P,B,X] tensors (repeat=True) for the memory-saving k-sample path
    """Bernoulli log-likelihood of ``target`` under ``sigmoid(logit)``.

    Uses the stable form ``t*l - max(l, 0) - log(1 + exp(-|l|))``.

    Args:
        logit:  [B, X, ?, ?] (or [P, B, X] when ``repeat`` is True)
        target: same shape as ``logit`` (or broadcastable to it)
        repeat: if True, only sum over axis 2 and return [P, B].

    Returns:
        output: [B] (all trailing axes summed), or [P, B] when ``repeat``.
    """
    softplus_tail = torch.log(1. + torch.exp(-torch.abs(logit)))
    if repeat:
        per_dim = -(torch.clamp(logit, min=0) - logit * target + softplus_tail)
        return per_dim.sum(2)  # sum over data dimensions

    per_dim = -F.relu(logit) + torch.mul(target, logit) - softplus_tail
    # Collapse every axis except the batch axis.
    while per_dim.dim() > 1:
        per_dim = per_dim.sum(-1)
    return per_dim

def log_bernoulli_cont(logit, target, thresh = 1e-4, repeat = False): #[1040336] Continuous Bernoulli log-likelihood from decoder logits
    """Continuous Bernoulli log-likelihood of ``target`` given decoder logits.

    log p(x|lam) = x log lam + (1-x) log(1-lam) + log C(lam), with
    lam = sigmoid(logit) and
        C(lam) = (log(1-lam) - log(lam)) / (1 - 2 lam)   for lam != 0.5,
    which is 0/0 at lam = 0.5. Within ``thresh`` of 0.5 the second-order
    Taylor approximation log 2 + log(1 + (1 - 2 lam)^2 / 3) is used instead.

    All intermediates are built from ``logit`` itself, so this runs on the
    CPU or the GPU, matching whichever device the inputs are on.

    Args:
        logit: decoder logits, [B, ...] (or [P, B, X] when ``repeat``).
        target: values in [0, 1], same shape as ``logit`` (or broadcastable).
        thresh: half-width of the window around 0.5 that uses the Taylor form.
        repeat: if True, only sum over axis 2 and return [P, B].

    Returns:
        [B] (all trailing axes summed), or [P, B] when ``repeat``.
    """
    # Stable log(sigmoid(logit)) and log(1 - sigmoid(logit)).
    logprob = logit - F.relu(logit) - torch.log(1. + torch.exp(-logit.abs()))
    loginvprob = logprob - logit
    prob = torch.exp(logprob)
    # Unnormalised part: identical to the regular Bernoulli.
    logbernoulli = target * logprob + (1. - target) * loginvprob

    # 1 where prob is far enough from 0.5 for the closed form to be stable.
    mask = torch.abs(prob - 0.5).ge(thresh).float()
    inv_mask = 1. - mask
    # Substitute harmless values in the branch that is masked out, so neither
    # formula produces inf/nan there (which would poison the gradients).
    safe_prob = mask * prob + inv_mask * 0.75
    critical_prob = mask * thresh + inv_mask * prob

    safe_values = mask * torch.log(
        (torch.log(1. - safe_prob) - torch.log(safe_prob)).div(1. - 2. * safe_prob))
    taylor_values = inv_mask * (
        math.log(2.) + torch.log(1. + torch.pow(1. - 2. * critical_prob, 2) / 3.))

    loss = logbernoulli + safe_values + taylor_values
    if repeat:
        return loss.sum(2)
    while loss.dim() > 1:
        loss = loss.sum(-1)
    return loss
def mean_cont_bern(prob, thresh = 1e-5):    #[1040336] Mean of the continuous Bernoulli distribution; unlike the binary case it is not just prob.
    """Mean of the continuous Bernoulli distribution CB(prob).

        E[x] = prob / (2 prob - 1) + 1 / (log(1 - prob) - log(prob))

    The expression is 0/0 at prob = 0.5, where its limit is 0.5. Entries with
    |prob - 0.5| < ``thresh`` therefore return exactly 0.5.

    Args:
        prob: tensor of parameters in (0, 1), any shape, on any device.
        thresh: half-width of the window around 0.5 treated as the limit.

    Returns:
        Tensor with the same shape as ``prob``.
    """
    # Built from `prob` itself, so no hard-coded .cuda() is needed.
    mask = torch.abs(prob - 0.5).ge(thresh).float()
    inv_mask = 1. - mask
    # Near-0.5 entries are replaced by 0.75, so that the closed form stays
    # finite there. Those entries are zeroed by `mask` afterwards.
    safe_prob = mask * prob + inv_mask * 0.75
    safe_values = mask * (torch.div(safe_prob, 2. * safe_prob - 1.)
                          + 1. / (torch.log(1. - safe_prob) - torch.log(safe_prob)))
    critical_values = inv_mask * 0.5
    return safe_values + critical_values
def mean_squared_error(prediction, target):
    """Negative per-example sum of squared errors.

    Both tensors are flattened to [B, -1] first, so any trailing shape works.
    Returns a [B] tensor; larger (closer to 0) means a better fit.
    """
    flat_pred = prediction.view(prediction.size(0), -1)
    flat_target = target.view(target.size(0), -1)
    residual = flat_pred - flat_target
    squared = torch.mul(residual, residual)
    return -squared.sum(1)


def discretized_logistic(mu, logs, x):
    """Log probability mass under a discretized logistic distribution.

    See https://arxiv.org/pdf/1606.04934.pdf; follows the OpenAI
    implementation. Pixel values are assumed to be scaled to [0, 1] with bin
    width 1/256. ``logs`` is expanded with two trailing singleton axes before
    broadcasting against ``x``. The result is summed over the last three
    axes of ``x``.
    """
    scale = torch.exp(logs).unsqueeze(-1).unsqueeze(-1)
    cdf_upper = torch.sigmoid((x + 1. / 256. - mu) / scale)
    cdf_lower = torch.sigmoid((x - mu) / scale)
    # Small epsilon keeps the log finite when a bin has ~zero mass.
    logp = torch.log(cdf_upper - cdf_lower + 1e-7)
    for _ in range(3):
        logp = logp.sum(-1)
    return logp


def flatten(x):
    """Collapse every axis after the first: [B, ...] -> [B, prod(...)]."""
    batch_size = x.size(0)
    return x.view(batch_size, -1)


def log_mean_exp(x):
    """Numerically stable ``log(mean(exp(x), dim=1))`` for a 2-D tensor.

    The per-row maximum is subtracted before exponentiating and added back
    afterwards, so large values do not overflow.
    """
    shift = torch.max(x, 1, keepdim=True)[0]
    centered_mean = torch.mean(torch.exp(x - shift), 1)
    return torch.log(centered_mean) + torch.squeeze(shift)


def numpy_nan_guard(arr):
    """Return a truthy value iff ``arr`` contains no NaN.

    Relies on NaN being the only value that compares unequal to itself.
    """
    self_equal = arr == arr
    return np.all(self_equal)


def safe_repeat(x, n):
    """Tile ``x`` ``n`` times along its first axis, leaving other axes as-is.

    A [B, ...] tensor becomes [n*B, ...], laid out as n consecutive copies.
    """
    trailing_ones = [1] * (x.dim() - 1)
    return x.repeat(n, *trailing_ones)


def sigmoidial_schedule(T, delta=4):
    """Sigmoid-shaped annealing schedule from section 6 of the BDMC paper.

    Returns a list of T temperatures, rescaled so the first is 0 and the last
    is 1. ``delta`` controls how sharply the sigmoid rises in the middle.
    """
    def _squash(u):
        return np.exp(u) / (1. + np.exp(u))

    def _tilde(step):
        return _squash(delta * (2. * step / T - 1.))

    lo, hi = _tilde(1), _tilde(T)
    return [(_tilde(step) - lo) / (hi - lo) for step in range(1, T + 1)]


def linear_schedule(T):
    """Return T temperatures evenly spaced from 0 to 1 (both inclusive)."""
    start, stop = 0., 1.
    return np.linspace(start, stop, num=T)


#hparams file

In [0]:
class HParams(object):
    """Minimal hyper-parameter container.

    Every constructor keyword becomes both an attribute and an entry in
    ``self._items``. ``parse`` returns a NEW HParams with comma-separated
    ``key=value`` overrides applied. Each value is cast to the type of the
    existing default for that key.
    """

    def __init__(self, **kwargs):
        self._items = {}
        for k, v in kwargs.items():
            self._set(k, v)

    def __repr__(self):
        body = ", ".join("%s=%r" % (k, self._items[k]) for k in sorted(self._items))
        return "%s(%s)" % (type(self).__name__, body)

    def _set(self, k, v):
        # Keep the dict and the attribute in sync.
        self._items[k] = v
        setattr(self, k, v)

    def parse(self, str_value):
        """Return a copy with overrides such as ``"z_size=32, has_flow=true"``.

        Whitespace around entries and around ``=`` is ignored. Booleans are
        True only for the (case-insensitive) string "true".

        Raises:
            ValueError: an entry has no ``=``, or an int/float cast fails.
            KeyError: an entry names a key that has no default.
        """
        hps = HParams(**self._items)
        for entry in str_value.strip().split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError("Unable to parse: %s" % entry)
            # Tolerate "key = value" as well as "key=value".
            key, value = key.strip(), value.strip()
            default_value = hps._items[key]
            # bool must be checked before int: bool is a subclass of int.
            if isinstance(default_value, bool):
                hps._set(key, value.lower() == "true")
            elif isinstance(default_value, int):
                hps._set(key, int(value))
            elif isinstance(default_value, float):
                hps._set(key, float(value))
            else:
                hps._set(key, value)
        return hps

def get_default_hparams():
    """Return the default hyper-parameters for the VAE / CVAE models.

    ``VAE.__init__`` and ``CVAE.__init__`` read ``n_flows`` and
    ``hamiltonian_flow``, so defaults for both are included here. n_flows=2
    matches ``Flow``'s own default, and the Hamiltonian flow is off.
    """
    return HParams(
        z_size=50,
        act_func=F.elu,
        has_flow=False,
        n_flows=2,
        hamiltonian_flow=False,
        large_encoder=False,
        wide_encoder=False,
        cuda=True,
        decode_dist = log_bernoulli,
    )

#hmc file

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import math
import torch
import numpy as np

import torch
from torch.autograd import Variable


def hmc_trajectory(current_z, current_v, U, grad_U, epsilon, L=10):
    """Run L leap-frog steps of HMC (https://arxiv.org/pdf/1206.1901.pdf).

    Args:
        current_z: starting position, [N, D].
        current_v: starting momentum, same shape as ``current_z``.
        U: potential energy (minus log-density). Not called here; kept for
            interface symmetry with ``accept_reject``.
        grad_U: function returning dU/dz.
        epsilon: Variable of per-chain (adaptive) step sizes, [N].
        L: number of leap-frog steps.

    Returns:
        (z, v) after the trajectory, both detached. The final momentum is
        negated, which only matters for the math of reversibility.
    """
    # as of `torch-0.3.0.post4`, there still is no proper scalar support
    assert isinstance(epsilon, Variable)

    step = epsilon.view(-1, 1)
    z = current_z
    # Initial half step for the momentum.
    v = current_v - grad_U(z).mul(step).mul_(.5)

    for leap in range(L):
        z = z + v.mul(step)
        # Full momentum steps between position updates, except after the last one.
        if leap != L - 1:
            v = v - grad_U(z).mul(step)

    # Final half step for the momentum.
    v = v - grad_U(z).mul(step).mul_(.5)
    return z.detach(), (-v).detach()


def accept_reject(current_z, current_v, 
                  z, v, 
                  epsilon, 
                  accept_hist, hist_len, 
                  U, K=lambda v: torch.sum(v * v, 1)):
    """Metropolis accept/reject of an HMC proposal, plus step-size adaptation.

    Args:
        current_z, current_v: position / momentum BEFORE the leap-frog steps.
        z, v: position / momentum AFTER the leap-frog steps (the proposal).
        epsilon: per-chain leap-frog step sizes. These are only needed for
            the adaptive update.
        accept_hist: running count of acceptances per chain.
        hist_len: number of transitions so far (divides ``accept_hist``).
        U: potential energy (MINUS log-density).
        K: kinetic energy. NOTE(review): the default is ``sum(v*v)`` without
            the physical factor 1/2. ``ais_trajectory`` passes its own
            normalised kinetic energy instead.

    Returns:
        (z, epsilon, accept_hist) detached from the graph. ``z`` is a new leaf
        with ``requires_grad=True``. Each step size grows by 2% when that
        chain's acceptance rate exceeds 0.65, otherwise it shrinks by 2%. It
        is then clipped to [1e-4, 0.5].
    """
    tensor_type = type(current_z.data)

    def hamiltonian(position, momentum):
        return K(momentum) + U(position)

    accept_prob = torch.exp(hamiltonian(current_z, current_v) - hamiltonian(z, v))
    uniform = Variable(torch.rand(accept_prob.size()).type(tensor_type))
    accepted = (accept_prob > uniform).type(tensor_type)
    keep = accepted.view(-1, 1)
    new_z = z.mul(keep) + current_z.mul(1. - keep)

    accept_hist = accept_hist.add(accepted)
    above_target = (accept_hist / hist_len > 0.65).type(tensor_type)
    scale = 1.02 * above_target + 0.98 * (1. - above_target)
    new_epsilon = epsilon.mul(scale).clamp(1e-4, .5)

    # Re-wrapping drops the autograd history (similar to detach) to save memory.
    return (Variable(new_z.data, requires_grad=True),
            Variable(new_epsilon.data),
            Variable(accept_hist.data))


#ais file

In [0]:
from torch.autograd import Variable
from torch.autograd import grad as torchgrad


def ais_trajectory(
    model,
    loader,
    mode='forward',
    schedule=np.linspace(0., 1., 500),
    n_sample=100,
    log_likelihood_fn = log_bernoulli #[1040336] Enabled using arbitrary log-likelihood functions in order to use continuous bernoulli
):
    """Compute annealed importance sampling (AIS) trajectories for each batch in `loader`.

    Can be used for *both* the forward and the reverse chain of
    bidirectional Monte Carlo (BDMC). The default is a forward chain with a
    linear schedule of 500 temperatures.

    The intermediate distributions are f_t(z) = p(z) p(x|z)^t. Each
    transition t0 -> t1 does two things: it adds log f_t1(z) - log f_t0(z)
    to the log importance weight, and it then moves z with one HMC step
    (leap-frog plus accept/reject) that targets f_t1.

    Args:
        model (vae.VAE): VAE model. It is used through `z_size`, `dtype`,
            `decode` and `eval`.
        loader (iterator): iterator that returns pairs. The first component
            is `x`. The second is a posterior sample `z` (used only in
            'backward' mode) or a label (ignored).
        mode (string): indicates a forward or backward chain; must be either
            'forward' or 'backward'.
        schedule (list or 1D np.ndarray): temperature schedule, i.e. the t in
            `p(z)p(x|z)^t`. The forward chain has increasing values, whereas
            the backward chain has decreasing values.
        n_sample (int): number of importance samples (i.e. number of
            parallel chains for each datapoint).
        log_likelihood_fn: called as log_likelihood_fn(model.decode(z), data).
            It should return one log p(x|z) value per chain. Defaults to
            log_bernoulli.

    Returns:
        A list with one element per batch. Each element is the tensor
        (`.data`) of per-datapoint log-mean-exp importance weights, negated
        in 'backward' mode.
    """

    assert mode == 'forward' or mode == 'backward', 'Should have forward/backward mode'

    def log_f_i(z, data, t, log_likelihood_fn=log_likelihood_fn):
        """Unnormalized density for intermediate distribution `f_i`:
            f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t
        =>  log f_i = log p(z) + t * log p(x|z)
        """
        # `B`, `z_size` and `mdtype` come from the enclosing scope. `B` is
        # (re)assigned for every batch inside the loop below.
        zeros = Variable(torch.zeros(B, z_size).type(mdtype))
        log_prior = log_normal(z, zeros, zeros)
        log_likelihood = log_likelihood_fn(model.decode(z), data)

        return log_prior + log_likelihood.mul_(t)

    model.eval()

    # shorter aliases
    z_size = model.z_size
    mdtype = model.dtype

    _time = time.time()
    logws = []  # for output

    print ('In %s mode' % mode)

    for i, (batch, post_z) in enumerate(loader):

        # total number of chains in this batch: n_sample per datapoint
        B = batch.size(0) * n_sample
        batch = Variable(batch.type(mdtype))
        # n_sample consecutive copies of the batch along dim 0, so chain c of
        # datapoint b sits in row c * batch_size + b (undone by the view below)
        batch = safe_repeat(batch, n_sample)

        # batch of step sizes, one for each chain
        epsilon = Variable(torch.ones(B).type(model.dtype)).mul_(0.01)
        # accept/reject history for tuning step size
        accept_hist = Variable(torch.zeros(B).type(model.dtype))
        # record log importance weight; volatile=True reduces memory greatly
        logw = Variable(torch.zeros(B).type(mdtype), volatile=True)

        # initial sample of z: from the prior (forward) or from the provided
        # posterior samples (backward)
        if mode == 'forward':
            current_z = Variable(torch.randn(B, z_size).type(mdtype), requires_grad=True)
        else:
            current_z = Variable(safe_repeat(post_z, n_sample).type(mdtype), requires_grad=True)

        for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]), 1), position=0, leave=True):
            # update log importance weight
            log_int_1 = log_f_i(current_z, batch, t0)
            log_int_2 = log_f_i(current_z, batch, t1)
            logw.add_(log_int_2 - log_int_1)

            # resample speed
            current_v = Variable(torch.randn(current_z.size()).type(mdtype))

            def U(z):
                # potential energy of the current target f_t1
                return -log_f_i(z, batch, t1)

            def grad_U(z):
                # grad w.r.t. outputs; mandatory in this case
                grad_outputs = torch.ones(B).type(mdtype)
                # torch.autograd.grad default returns volatile
                grad = torchgrad(U(z), z, grad_outputs=grad_outputs)[0]
                # avoid humongous gradients
                grad = torch.clamp(grad, -10000, 10000)
                # needs variable wrapper to make differentiable
                grad = Variable(grad.data, requires_grad=True)
                return grad

            def normalized_kinetic(v):
                zeros = Variable(torch.zeros(B, z_size).type(mdtype))
                # this is superior to the unnormalized version
                return -log_normal(v, zeros, zeros)

            z, v = hmc_trajectory(current_z, current_v, U, grad_U, epsilon)

            # accept-reject step; `j` is the number of transitions so far and
            # normalises accept_hist into an acceptance rate
            current_z, epsilon, accept_hist = accept_reject(
                current_z, current_v,
                z, v,
                epsilon,
                accept_hist, j,
                U, K=normalized_kinetic
            )

        # IWAE lower bound: log-mean-exp over the n_sample chains of each datapoint
        logw = log_mean_exp(logw.view(n_sample, -1).transpose(0, 1))
        if mode == 'backward':
            logw = -logw
        logws.append(logw.data)

        print ('Time elapse %.4f, last batch stats %.4f' % \
            (time.time()-_time, logw.mean().cpu().data.numpy()))

        _time = time.time()
        sys.stdout.flush()  # for debugging

    return logws

#approx_posts file

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class Flow(nn.Module):
    """A combination of R-NVP and auxiliary variables.

    The base sample z0 is augmented with an auxiliary variable v0 ~ q(v|x,z0).
    Both are pushed through `n_flows` coupling steps (`_norm_flow`), and the
    result is scored with a reverse model r(vT|x,zT). `_sample` returns zT
    and log q0(v0) - sum(log-det) - log r(vT). The caller (VAE/CVAE) adds
    this to log q(z0|x).

    NOTE: the submodule names registered in `_construct_weights` ('qz%d',
    'rv%d', 'flow%d_layer%d') end up in state_dict keys, so changing them
    breaks previously saved checkpoints.
    """

    def __init__(self, model, n_flows=2):
        super(Flow, self).__init__()
        # only `z_size` is read from the owning model
        self.z_size = model.z_size
        self.n_flows = n_flows
        self._construct_weights()

    def forward(self, z, grad_fn=lambda x: 1, x_info=None):
        # z: [B, z_size]; x_info: [B, z_size] conditioning features from the encoder.
        # grad_fn is a Hamiltonian-flow hook (default 1 = ignore it).
        return self._sample(z, grad_fn, x_info)

    def _norm_flow(self, params, z, v, grad_fn, x_info):
        """One coupling step: update v given (z, x), then z given (v, x).

        Returns (z, v, logdet), where logdet is [B] and sums log sigmoid(logit)
        over both element-wise scalings.
        """
        h = F.elu(params[0][0](torch.cat((z, x_info), dim=1)))
        mu = params[0][1](h)
        logit = params[0][2](h)
        sig = F.sigmoid(logit)

        # old CIFAR used the one below
        # v = v * sig + mu * grad_fn(z)

        # the more efficient one uses the one below
        v = v * sig - F.elu(mu) * grad_fn(z)
        # log sigmoid(logit) = logit - softplus(logit): log-Jacobian of v * sig
        logdet_v = torch.sum(logit - F.softplus(logit), 1)

        h = F.elu(params[1][0](torch.cat((v, x_info), dim=1)))
        mu = params[1][1](h)
        logit = params[1][2](h)
        sig = F.sigmoid(logit)

        z = z * sig + mu
        logdet_z = torch.sum(logit - F.softplus(logit), 1)
        logdet = logdet_v + logdet_z

        return z, v, logdet

    def _sample(self, z0, grad_fn, x_info, k=1): #[1040336] added the ability to get k samples with the same mean and var
        # NOTE(review): only x_info is tiled k times, not z0. For k > 1 the
        # torch.cat below therefore needs z0 to already have k*B rows.
        # forward() always uses k=1.
        x_info = x_info.repeat(k,1)
        B = z0.size(0)
        z_size = self.z_size
        act_func = F.elu
        qv_weights, rv_weights, params = self.qv_weights, self.rv_weights, self.params

        # q(v0 | x, z0): MLP producing a diagonal Gaussian over v0
        out = torch.cat((z0, x_info), dim=1)
        for i in range(len(qv_weights)-1):
            out = act_func(qv_weights[i](out))
        out = qv_weights[-1](out)
        mean_v0, logvar_v0 = out[:, :z_size], out[:, z_size:]

        # reparameterised sample of v0
        eps = Variable(torch.randn(B, z_size).type( type(out.data) ))
        v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0
        logqv0 = log_normal(v0, mean_v0, logvar_v0)

        zT, vT = z0, v0
        logdetsum = 0.
        for i in range(self.n_flows):
            zT, vT, logdet = self._norm_flow(params[i], zT, vT, grad_fn, x_info)
            logdetsum += logdet

        # reverse model, r(vT|x,zT)
        out = torch.cat((zT, x_info), dim=1)
        for i in range(len(rv_weights)-1):
            out = act_func(rv_weights[i](out))
        out = rv_weights[-1](out)
        mean_vT, logvar_vT = out[:, :z_size], out[:, z_size:]
        logrvT = log_normal(vT, mean_vT, logvar_vT)

        assert logqv0.size() == (B,)
        assert logdetsum.size() == (B,)
        assert logrvT.size() == (B,)

        # correction added to log q(z0|x) by the caller
        logprob = logqv0 - logdetsum - logrvT

        return zT, logprob

    def _construct_weights(self):
        # Layer creation order (and hence random init order) and the
        # add_module names must stay fixed for reproducibility/checkpoints.
        z_size = self.z_size
        n_flows = self.n_flows
        h_s = 200

        # input is cat(z or v, x_info): z_size*2; output is mean and logvar: z_size*2
        qv_arch = rv_arch = [z_size*2, h_s, h_s, z_size*2]
        qv_weights, rv_weights = [], []

        # q(v|x,z)
        id = 0
        for ins, outs in zip(qv_arch[:-1], qv_arch[1:]):
            cur_layer = nn.Linear(ins, outs)
            qv_weights.append(cur_layer)
            self.add_module('qz%d' % id, cur_layer)
            id += 1

        # r(v|x,z)
        id = 0
        for ins, outs in zip(rv_arch[:-1], rv_arch[1:]):
            cur_layer = nn.Linear(ins, outs)
            rv_weights.append(cur_layer)
            self.add_module('rv%d' % id, cur_layer)
            id += 1

        # nf: per flow, two [hidden, mu-head, logit-head] triples (v-update, z-update)
        params = []
        for i in range(n_flows):
            layer_grid = [
                [nn.Linear(z_size*2, h_s),
                 nn.Linear(h_s, z_size),
                 nn.Linear(h_s, z_size)],
                [nn.Linear(z_size*2, h_s),
                 nn.Linear(h_s, z_size),
                 nn.Linear(h_s, z_size)],
            ]

            params.append(layer_grid)

            id = 0
            for layer_list in layer_grid:
                for layer in layer_list:
                    self.add_module('flow%d_layer%d' % (i, id), layer)
                    id += 1

        # plain lists; the layers are registered via add_module above
        self.qv_weights = qv_weights
        self.rv_weights = rv_weights
        self.params = params

        self.sanity_check_param = self.params[0][0][0]._parameters['weight']


#VAE and CVAE files

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
import argparse

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.autograd import Variable
from torch.autograd import grad as torchgrad


class VAE(nn.Module):
    """Generic fully-connected VAE for MNIST and Fashion datasets.

    Inputs are flattened 784-pixel images.
    - Encoder: 784 -> h_s -> h_s, where h_s is 200, or 500 with
      `wide_encoder`. It outputs the mean and log-variance of a diagonal
      Gaussian q(z|x), plus `x_info` features for the optional flow.
    - Decoder: z -> 200 -> 200 -> 784 logits.

    Hyper-parameters read from `hps`: z_size, has_flow, cuda, act_func,
    n_flows, hamiltonian_flow, decode_dist, wide_encoder.
    """
    def __init__(self, hps):
        super(VAE, self).__init__()

        self.z_size = hps.z_size
        self.has_flow = hps.has_flow
        self.use_cuda = hps.cuda
        self.act_func = hps.act_func
        self.n_flows = hps.n_flows
        self.hamiltonian_flow = hps.hamiltonian_flow
        # [1040336] decoder log-likelihood, e.g. log_bernoulli or the continuous Bernoulli
        self.decode_dist = hps.decode_dist

        self._init_layers(wide_encoder=hps.wide_encoder)

        if self.use_cuda:
            self.cuda()
            self.dtype = torch.cuda.FloatTensor
            # NOTE: global side effect. Tensors created afterwards without an
            # explicit type are allocated on the GPU.
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            self.dtype = torch.FloatTensor

    def _init_layers(self, wide_encoder=False):
        h_s = 500 if wide_encoder else 200

        # encoder (input assumed already flattened)
        self.fc1 = nn.Linear(784, h_s)
        self.fc2 = nn.Linear(h_s, h_s)
        self.fc3 = nn.Linear(h_s, self.z_size*2)

        # decoder
        self.fc4 = nn.Linear(self.z_size, 200)
        self.fc5 = nn.Linear(200, 200)
        self.fc6 = nn.Linear(200, 784)

        # Applied to the fc2 output, so its input width must be h_s. It was
        # hard-coded to 200 before, which broke wide_encoder=True.
        self.x_info_layer = nn.Linear(h_s, self.z_size)

        if self.has_flow:
            self.q_dist = Flow(self, n_flows=self.n_flows)
            if self.use_cuda:
                self.q_dist.cuda()

    def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None, k=1):
        """Draw z ~ q(z|x) with the reparameterisation trick.

        Args:
            mu, logvar: [B, Z] posterior mean and log-variance.
            grad_fn: passed to the flow (Hamiltonian flow). The default
                ignores it.
            x_info: [B, Z] encoder features for the flow.
            k: samples per datapoint. [1040336] For k > 1 the noise is drawn
                as [k, B, Z] and mu/logvar broadcast against it, instead of
                copying mu/logvar k times, to save memory.

        Returns:
            z ([B, Z] or [k, B, Z]), log p(z), log q(z|x) ([B] or [k, B]).
        """
        if k > 1:
            eps = Variable(torch.FloatTensor(k, mu.size()[0], mu.size()[1]).normal_().type(self.dtype))
        else:
            eps = Variable(torch.FloatTensor(mu.size()).normal_().type(self.dtype))
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mu)
        logqz = log_normal(z, mu, logvar, repeat=(k > 1))

        # NOTE(review): Flow expects 2-D z, so has_flow together with k > 1
        # is not supported.
        if self.has_flow:
            z, logprob = self.q_dist.forward(z, grad_fn, x_info)
            logqz += logprob

        zeros = Variable(torch.zeros(mu.size()).type(self.dtype))
        logpz = log_normal(z, zeros, zeros, repeat=(k > 1))

        return z, logpz, logqz

    def encode(self, net):
        net = self.act_func(self.fc1(net))
        net = self.act_func(self.fc2(net))
        x_info = self.act_func(self.x_info_layer(net))
        net = self.fc3(net)

        mean, logvar = net[:, :self.z_size], net[:, self.z_size:]

        return mean, logvar, x_info

    def decode(self, net):
        net = self.act_func(self.fc4(net))
        net = self.act_func(self.fc5(net))
        logit = self.fc6(net)

        return logit

    def forward(self, x, k=1, warmup_const=1.):
        """Return the batch-mean (elbo, log p(x|z), log p(z), log q(z|x)).

        [1040336] For k > 1, k samples are drawn per datapoint without
        copying x. The k per-sample ELBOs are combined with a log-mean-exp
        (importance-weighted bound) before averaging over the batch.
        `warmup_const` scales the log q(z|x) term.
        """
        mu, logvar, x_info = self.encode(x)

        # posterior-aware inference
        def U(z):
            logpx = self.decode_dist(self.decode(z), x)
            logpz = log_normal(z)
            return -logpx - logpz  # energy as -log p(x, z)

        def grad_U(z):
            grad_outputs = torch.ones(z.size(0)).type(self.dtype)
            grad = torchgrad(U(z), z, grad_outputs=grad_outputs, create_graph=True)[0]
            # gradient clipping avoid numerical issue
            norm = torch.sqrt(torch.norm(grad, p=2, dim=1))
            # neither grad clip methods consistently outperforms the other
            grad = grad / norm.view(-1, 1)
            # grad = torch.clamp(grad, -10000, 10000)
            return grad.detach()

        if self.hamiltonian_flow:
            z, logpz, logqz = self.sample(mu, logvar, grad_fn=grad_U, x_info=x_info, k=k)
        else:
            z, logpz, logqz = self.sample(mu, logvar, x_info=x_info, k=k)

        logit = self.decode(z)
        logpx = self.decode_dist(logit, x, repeat=(k > 1))
        elbo = logpx + logpz - warmup_const * logqz  # custom warmup
        if k > 1:
            # stable log-mean-exp over the k samples: [k, B] -> [B]
            max_ = torch.max(elbo, 0)[0]
            elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_
        elbo = torch.mean(elbo)
        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)

        return elbo, logpx, logpz, logqz


In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
import argparse

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torch.autograd import Variable
from torch.autograd import grad as torchgrad



class CVAE(nn.Module):
    """Convolutional VAE for CIFAR (3 x 32 x 32 inputs).

    - Encoder: three stride-2 4x4 convolutions (32 -> 15 -> 6 -> 2) with
      batch norm, then a linear layer producing the mean and log-variance of
      q(z|x), plus `x_info` features for the optional flow.
    - Decoder: a linear layer to 256 x 2 x 2, then three transposed
      convolutions back to 3 x 32 x 32 logits.

    Hyper-parameters read from `hps`: z_size, has_flow, hamiltonian_flow,
    n_flows, cuda, act_func, wide_encoder and (optionally) decode_dist.
    """
    def __init__(self, hps):
        super(CVAE, self).__init__()

        self.z_size = hps.z_size
        self.has_flow = hps.has_flow
        self.hamiltonian_flow = hps.hamiltonian_flow
        self.n_flows = hps.n_flows
        self.use_cuda = hps.cuda
        self.act_func = hps.act_func
        # forward() and the Hamiltonian energy U(z) call self.decode_dist, so
        # it must be set here. Fall back to log_bernoulli for hparams objects
        # that do not define it.
        self.decode_dist = getattr(hps, 'decode_dist', log_bernoulli)

        self._init_layers(wide_encoder=hps.wide_encoder)

        if self.use_cuda:
            self.cuda()
            self.dtype = torch.cuda.FloatTensor
        else:
            self.dtype = torch.FloatTensor

    def _init_layers(self, wide_encoder=False):
        # Layer creation order is kept fixed (random init order / checkpoint keys).
        if wide_encoder:
            init_channel = 128
        else:
            init_channel = 64

        # encoder
        self.conv1 = nn.Conv2d(3, init_channel, 4, 2)
        self.conv2 = nn.Conv2d(init_channel, init_channel*2, 4, 2)
        self.conv3 = nn.Conv2d(init_channel*2, init_channel*4, 4, 2)
        self.fc_enc = nn.Linear(init_channel*4*2*2, self.z_size*2)

        self.bn_enc1 = nn.BatchNorm2d(init_channel)
        self.bn_enc2 = nn.BatchNorm2d(init_channel*2)
        self.bn_enc3 = nn.BatchNorm2d(init_channel*4)

        self.x_info_layer = nn.Linear(init_channel*4*2*2, self.z_size)

        # decoder
        self.fc_dec = nn.Linear(self.z_size, 256*2*2)
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2)

        self.bn_dec1 = nn.BatchNorm2d(128)
        self.bn_dec2 = nn.BatchNorm2d(64)

        # plain list used by freeze_decoder/unfreeze_decoder; the layers are
        # registered as submodules through the attribute assignments above
        self.decoder_layers = [
            self.deconv1,
            self.deconv2,
            self.deconv3,
            self.fc_dec,
            self.bn_dec1,
            self.bn_dec2,
        ]

        if self.has_flow:
            self.q_dist = Flow(self, n_flows=self.n_flows)
            if self.use_cuda:
                self.q_dist.cuda()

    def encode(self, net):

        net = self.act_func(self.bn_enc1(self.conv1(net)))
        net = self.act_func(self.bn_enc2(self.conv2(net)))
        net = self.act_func(self.bn_enc3(self.conv3(net)))
        net = net.view(net.size(0), -1)
        x_info = self.act_func(self.x_info_layer(net))
        net = self.fc_enc(net)
        mean, logvar = net[:, :self.z_size], net[:, self.z_size:]

        return mean, logvar, x_info

    def decode(self, net):

        net = self.act_func(self.fc_dec(net))
        net = net.view(net.size(0), -1, 2, 2)
        net = self.act_func(self.bn_dec1(self.deconv1(net)))
        net = self.act_func(self.bn_dec2(self.deconv2(net)))
        logit = self.deconv3(net)

        return logit

    def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None):
        """Reparameterised z ~ q(z|x), optionally refined by the flow.

        Returns (z, log p(z), log q(z|x)), omitting the Gaussian constant.
        """
        # grad_fn default is identity, i.e. don't use grad info
        eps = Variable(torch.randn(mu.size()).type(self.dtype))
        z = eps.mul(logvar.mul(0.5).exp()).add(mu)
        logqz = log_normal(z, mu, logvar)

        if self.has_flow:
            z, logprob = self.q_dist.forward(z, grad_fn, x_info)
            logqz += logprob

        zeros = Variable(torch.zeros(z.size()).type(self.dtype))
        logpz = log_normal(z, zeros, zeros)

        return z, logpz, logqz

    def forward(self, x, k=1, warmup_const=1.):
        """Return the batch-mean (elbo, log p(x|z), log p(z), log q(z|x)).

        For k > 1, x is tiled k times and the k per-sample ELBOs are combined
        with a log-mean-exp (IWAE bound) before averaging over the batch.
        """
        x = x.repeat(k, 1, 1, 1)  # for computing iwae bound
        mu, logvar, x_info = self.encode(x)

        # posterior-aware inference
        def U(z):
            logpx = self.decode_dist(self.decode(z), x)
            logpz = log_normal(z)
            return -logpx - logpz  # energy as -log p(x, z)

        def grad_U(z):
            grad_outputs = torch.ones(z.size(0)).type(self.dtype)
            grad = torchgrad(U(z), z, grad_outputs=grad_outputs, create_graph=True)[0]
            # gradient clipping by norm avoid numerical issue
            norm = torch.sqrt(torch.norm(grad, p=2, dim=1))
            grad = grad / norm.view(-1, 1)
            return grad.detach()

        if self.hamiltonian_flow:
            z, logpz, logqz = self.sample(mu, logvar, grad_fn=grad_U, x_info=x_info)
        else:
            z, logpz, logqz = self.sample(mu, logvar, x_info=x_info)

        logit = self.decode(z)
        logpx = self.decode_dist(logit, x)
        elbo = logpx + logpz - warmup_const * logqz  # custom warmup
        # correction for Tensor.repeat
        elbo = log_mean_exp(elbo.view(k, -1).transpose(0, 1))
        elbo = torch.mean(elbo)

        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)

        return elbo, logpx, logpz, logqz

    def reconstruct_img(self, x):
        """Return sigmoid(decoder logits) for one posterior sample (visualisation)."""
        mu, logvar, x_info = self.encode(x)
        # x_info must be forwarded: the flow (if any) concatenates it with z.
        z, logpz, logqz = self.sample(mu, logvar, x_info=x_info)
        logit = self.decode(z)
        x_hat = torch.sigmoid(logit)

        return x_hat

    def freeze_decoder(self):
        # freeze so that decoder is not optimized
        for layer in self.decoder_layers:
            for param in layer.parameters():
                param.requires_grad = False

    def unfreeze_decoder(self):
        # unfreeze so that decoder is optimized
        for layer in self.decoder_layers:
            for param in layer.parameters():
                param.requires_grad = True



# Utils files

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import gzip
import numpy as np


def load_mnist(path, kind='train'):
    """Read gzipped MNIST-format IDX files for split ``kind`` from ``path``.

    Returns ``(images, labels)``: uint8 images reshaped to ``(n, 784)`` and a
    uint8 label vector. The fixed-size IDX headers (8 bytes for labels,
    16 bytes for images) are skipped, not parsed or validated.
    """
    def _read_uint8(filename, header_bytes):
        # decompress the whole file and view it as raw bytes after the header
        with gzip.open(os.path.join(path, filename), 'rb') as stream:
            return np.frombuffer(stream.read(), dtype=np.uint8, offset=header_bytes)

    labels = _read_uint8('%s-labels-idx1-ubyte.gz' % kind, 8)
    images = _read_uint8('%s-images-idx3-ubyte.gz' % kind, 16)
    return images.reshape(len(labels), 784), labels

#loader file

In [0]:
 # Function to binarise the arrays in place
def static_binarise(d):
    """Threshold array ``d`` at 0.5 in place and return the same object.

    Entries below 0.5 become 0, every other entry (including NaN, which is
    not ``< 0.5``) becomes 1.
    """
    low = d < 0.5
    high = ~low
    d[low] = 0.
    d[high] = 1.
    return d

In [0]:
import numpy as np
from scipy import io
import sys
import os
import time

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.autograd import Variable
import collections
import pickle
from torch.autograd import Variable


class CIFAR10:
    """Bounded batch iterator over CIFAR-10 backed by a torch ``DataLoader``.

    ``part`` selects 'train', 'valid' (the first ``valid_size`` fraction of
    the training indices, the rest being 'train') or 'test'. With
    ``binarize`` the pixels are thresholded at 0.5 and cast back to float.
    When ``partial`` is given, iteration stops after
    ``partial // batch_size`` batches; otherwise after one pass.
    """

    def __init__(self,
                 part='train',
                 batch_size=128,
                 partial=None,
                 binarize=True,
                 valid_size=0.1,
                 num_workers=4,
                 pin_memory=False):

        steps = [transforms.ToTensor()]
        if binarize:
            # threshold to a boolean mask, then cast it back to float
            steps += [lambda x: x >= 0.5, lambda x: x.float()]
        data_transform = transforms.Compose(steps)

        train_set = datasets.CIFAR10('./datasets', train=True, download=True, transform=data_transform)
        valid_set = datasets.CIFAR10('./datasets', train=True, download=True, transform=data_transform)
        test_set = datasets.CIFAR10('./datasets', train=False, download=True, transform=data_transform)

        n_train = len(train_set)
        n_valid = int(np.floor(valid_size * n_train))
        all_idx = list(range(n_train))
        valid_idx, train_idx = all_idx[:n_valid], all_idx[n_valid:]

        shared = dict(num_workers=num_workers, pin_memory=pin_memory, shuffle=False)
        builders = {
            'train': lambda: DataLoader(train_set, batch_size=batch_size,
                                        sampler=SubsetRandomSampler(train_idx), **shared),
            'valid': lambda: DataLoader(valid_set, batch_size=batch_size,
                                        sampler=SubsetRandomSampler(valid_idx), **shared),
            'test': lambda: DataLoader(test_set, batch_size=batch_size, **shared),
        }
        self.loader = builders[part]()

        if partial is None:
            self.size = len(self.loader)
        else:
            self.size = partial // batch_size
        self.batch_size = batch_size
        self._iter = iter(self.loader)
        self.p = 0

    def __iter__(self):
        # restart both the batch counter and the underlying loader
        self.p = 0
        self._iter = iter(self.loader)
        return self

    def __next__(self):
        self.p += 1
        if self.p <= self.size:
            return next(self._iter)
        raise StopIteration

    # Python 2 iterator protocol alias
    def next(self):
        return self.__next__()

# MNIST loaded from the Larochelle pickle; pixels are only binarised when
# ``binarise=True`` is passed (the default keeps the stored values).
class Larochelle_MNIST:
    """Pre-batched iterator over ``datasets/mnist_non_binarised.pkl``.

    Args:
        part: 'train' (the first two pickled splits concatenated), 'test',
            'partial_train' (first ``partial`` rows of the first split) or
            'partial_test' (first ``partial`` rows of the test split).
        batch_size: rows per batch; a trailing incomplete batch is dropped.
        partial: number of rows kept by the 'partial_*' parts.
        binarise: if True, threshold the selected split at 0.5 with
            ``static_binarise`` (previously a hard-coded local flag).

    Raises:
        KeyError: if ``part`` is not one of the names above.
    """

    def __init__(self, part='train', batch_size=128, partial=1000, binarise=False):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open('datasets/mnist_non_binarised.pkl', 'rb') as f:
            if sys.version_info[0] < 3:
                mnist = pickle.load(f)
            else:
                # presumably a Python 2 pickle, hence latin1 decoding on Python 3
                mnist = pickle.load(f, encoding='latin1')

        # Build only the requested split instead of all four.
        splits = {
            'train': lambda: np.concatenate((mnist[0][0], mnist[1][0])),
            'test': lambda: mnist[2][0],
            'partial_train': lambda: mnist[0][0][:partial],
            'partial_test': lambda: mnist[2][0][:partial],
        }
        data = splits[part]()
        if binarise:
            data = static_binarise(data)

        self.data = data
        self.size = self.data.shape[0]
        self.batch_size = batch_size
        self._construct()

    def __iter__(self):
        return iter(self.batch_list)

    def _construct(self):
        """Split ``self.data`` into ``(tensor, None)`` batches."""
        self.batch_list = []
        for i in range(self.size // self.batch_size):
            batch = self.data[self.batch_size*i:self.batch_size*(i+1)]
            batch = torch.from_numpy(batch)
            # placeholder for second entry (no labels are provided)
            self.batch_list.append((batch, None))


class Binarized_Omniglot:
    """Statically binarised Omniglot loaded from ``datasets/chardata.mat``.

    ``part`` is 'train', 'test', 'partial_train' or 'partial_test' (the
    partial parts keep the first ``partial`` rows). The data is split into
    ``(tensor, None)`` batches of ``batch_size`` rows; a trailing incomplete
    batch is dropped.
    """

    def __init__(self, part='train', batch_size=128, partial=1000):
        omni_raw = io.loadmat('datasets/chardata.mat')

        train_data = self._to_rows(omni_raw['data'])
        test_data = self._to_rows(omni_raw['testdata'])
        self._static_binarize(train_data)
        self._static_binarize(test_data)

        assert train_data.shape == (24345, 784)
        assert test_data.shape == (8070, 784)

        self.data = {
            'train': train_data,
            'test':  test_data,
            'partial_train': train_data[:partial],
            'partial_test': test_data[:partial],
        }[part]
        self.size = self.data.shape[0]
        self.batch_size = batch_size
        self._construct()

    @staticmethod
    def _to_rows(raw):
        """Convert a raw .mat matrix into float32 rows of 28*28 pixels.

        The matrix is transposed, viewed as 28x28 images and flattened in
        column-major order. ``order='F'`` is the documented NumPy spelling;
        the former ``order='fortran'`` is non-standard and is deprecated or
        rejected by newer NumPy releases.
        """
        d = raw.T.astype('float32')
        return d.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F')

    @staticmethod
    def _static_binarize(d):
        """Threshold ``d`` at 0.5 in place (element-wise mask)."""
        ids = d < 0.5
        d[ids] = 0.
        d[~ids] = 1.

    def __iter__(self):
        return iter(self.batch_list)

    def _construct(self):
        """Split ``self.data`` into ``(tensor, None)`` batches."""
        self.batch_list = []
        for i in range(self.size // self.batch_size):
            batch = self.data[self.batch_size*i:self.batch_size*(i+1)]
            batch = torch.from_numpy(batch)
            self.batch_list.append((batch, None))


class Binarized_Fashion:
    """Fashion-MNIST scaled to [0, 1], binarised at 0.5 and pre-batched.

    The 60000 training images are split into 'train' (first 55000) and
    'valid' (last 5000); 'test' is the t10k set; 'partial_train' and
    'partial_test' keep the first ``partial`` rows of the training and test
    images. A trailing incomplete batch is dropped.
    """

    def __init__(self, part='train', batch_size=128, partial=1000):
        # load_mnist is defined in the utils cell of this notebook
        train_raw, _ = load_mnist('datasets/fashion', kind='train')
        test_raw, _ = load_mnist('datasets/fashion', kind='t10k')

        def to_unit_interval(pixels):
            return np.float32(pixels / 255.)

        def binarize_inplace(arr):
            low = arr < 0.5
            arr[low] = 0.
            arr[~low] = 1.

        train_data = to_unit_interval(train_raw)
        test_data = to_unit_interval(test_raw)
        for arr in (train_data, test_data):
            binarize_inplace(arr)

        assert train_data.shape == (60000, 784)
        assert test_data.shape == (10000, 784)

        splits = {
            'train': train_data[:55000],
            'valid': train_data[55000:],
            'test':  test_data,
            'partial_train': train_data[:partial],
            'partial_test': test_data[:partial],
        }
        self.data = splits[part]
        self.size = self.data.shape[0]
        self.batch_size = batch_size
        self._construct()

    def __iter__(self):
        return iter(self.batch_list)

    def _construct(self):
        # consecutive full batches only; leftover rows are discarded
        bs = self.batch_size
        self.batch_list = [
            (torch.from_numpy(self.data[start:start + bs]), None)
            for start in range(0, (self.size // bs) * bs, bs)
        ]


def get_default_mnist_loader():
    """Return shuffled torchvision MNIST ``(train_loader, test_loader)``.

    Training batches hold 128 images, test batches 100. With CUDA available
    the loaders use one worker process and pinned memory.
    """
    loader_opts = {}
    if torch.cuda.is_available():
        loader_opts = {'num_workers': 1, 'pin_memory': True}

    train_set = datasets.MNIST('./datasets', train=True, download=True,
                               transform=transforms.ToTensor())
    test_set = datasets.MNIST('./datasets', train=False,
                              transform=transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=128, shuffle=True, **loader_opts)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=100, shuffle=True, **loader_opts)
    return train_loader, test_loader


def get_cifar10_loader(batch_size=100, partial=False, num=1000):
    """Return ``(train_loader, eval_loader)`` built from ``CIFAR10``.

    With ``partial`` the training iterator is capped at ``num`` examples and
    evaluation uses the test split; otherwise evaluation uses the held-out
    'valid' slice of the training set. Evaluation batches hold 4 images.
    """
    if not partial:
        # the second loader is really the validation set
        return (CIFAR10(part='train', batch_size=batch_size),
                CIFAR10(part='valid', batch_size=4))
    return (CIFAR10(part='train', batch_size=batch_size, partial=num),
            CIFAR10(part='test', batch_size=4))


def get_Larochelle_MNIST_loader(batch_size=100, partial=False, num=1000):
    """Return ``(train_loader, test_loader)`` for the Larochelle MNIST pickle.

    With ``partial`` both loaders cover only the first ``num`` examples of
    their split (the test loader keeps the class's default batch size);
    otherwise the full 'train' and 'test' parts are batched with
    ``batch_size``.
    """
    if partial:
        train_loader = Larochelle_MNIST(part='partial_train', batch_size=batch_size, partial=num)
        # forward ``num`` so the partial test split honours the caller's size
        # (it used to be fixed at the class default of 1000)
        test_loader = Larochelle_MNIST(part='partial_test', partial=num)
    else:
        train_loader = Larochelle_MNIST(part='train', batch_size=batch_size)
        test_loader = Larochelle_MNIST(part='test', batch_size=batch_size)

    return train_loader, test_loader


def get_omniglot_loader(batch_size=100, partial=False, num=1000):
    """Return ``(train_loader, test_loader)`` for binarised Omniglot.

    With ``partial`` both loaders cover only the first ``num`` examples of
    their split; otherwise the full training split is batched with
    ``batch_size`` and evaluation uses the test split in batches of 10.
    """
    if partial:
        train_loader = Binarized_Omniglot(part='partial_train', batch_size=batch_size, partial=num)
        test_loader = Binarized_Omniglot(part='partial_test', partial=num)
    else:
        train_loader = Binarized_Omniglot(part='train', batch_size=batch_size)
        # Binarized_Omniglot only defines train/test/partial_* parts, so the
        # former part='valid' always raised KeyError; evaluate on 'test'.
        test_loader = Binarized_Omniglot(part='test', batch_size=10)

    return train_loader, test_loader


def get_fashion_loader(batch_size=100, partial=False, num=1000):
    """Return ``(train_loader, eval_loader)`` for binarised Fashion-MNIST.

    With ``partial`` both loaders cover only the first ``num`` examples of
    their split; otherwise training uses the 55000-image 'train' part and
    evaluation the 5000-image 'valid' part. Evaluation batches hold 10 rows.
    """
    if partial:
        train_loader = Binarized_Fashion(part='partial_train', batch_size=batch_size, partial=num)
        # forward ``num`` like the training loader does (previously ignored)
        test_loader = Binarized_Fashion(part='partial_test', batch_size=10, partial=num)
    else:
        train_loader = Binarized_Fashion(part='train', batch_size=batch_size)
        test_loader = Binarized_Fashion(part='valid', batch_size=10)

    return train_loader, test_loader


# if __name__ == '__main__':
#     # sanity checking
#     train_loader, test_loader = get_cifar10_loader()
#     train_loader, test_loader = get_default_mnist_loader()
#     for i, (batch, _) in enumerate(train_loader):
#         batch = Variable(batch)
#         print (i)

#     for i, (batch, _) in enumerate(train_loader):
#         batch = Variable(batch)
#         print (i)


#helper file

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def get_loaders(dataset='mnist', evaluate=False, batch_size=100):
    """Return ``(train_loader, test_loader)`` for ``dataset``.

    Args:
        dataset: 'mnist', 'fashion' or 'cifar'.
        evaluate: if True, request the partial splits (``num=1000`` for
            MNIST/Fashion, ``num=100`` for CIFAR).
        batch_size: training batch size forwarded to the loader factory.

    Raises:
        ValueError: for any other ``dataset`` (previously this surfaced as
            an ``UnboundLocalError`` on ``train_loader``).
    """
    if dataset == 'mnist':
        return get_Larochelle_MNIST_loader(
            batch_size=batch_size,
            partial=evaluate, num=1000
        )
    if dataset == 'fashion':
        return get_fashion_loader(
            batch_size=batch_size,
            partial=evaluate, num=1000
        )
    if dataset == 'cifar':
        return get_cifar10_loader(
            batch_size=batch_size,
            partial=evaluate, num=100
        )
    raise ValueError('unknown dataset %r; expected mnist, fashion or cifar' % (dataset,))


def get_model(dataset, hps):
    """Instantiate the model matching ``dataset`` from hyper-parameters ``hps``.

    'mnist' and 'fashion' use ``VAE``; 'cifar' uses the convolutional
    ``CVAE``.

    Raises:
        ValueError: for any other ``dataset`` (previously an
            ``UnboundLocalError`` on ``model``).
    """
    if dataset in ('mnist', 'fashion'):
        return VAE(hps)
    if dataset == 'cifar':  # convolutional VAE for CIFAR
        return CVAE(hps)
    raise ValueError('unknown dataset %r; expected mnist, fashion or cifar' % (dataset,))


## local_ffg.py

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
from tqdm import tqdm
import argparse
import numpy as np

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


# Command-line options for local fully-factorized Gaussian optimisation.
# main_ffg parses an argument string with this parser (no sys.argv needed).
parser_ffg = argparse.ArgumentParser(description='local_factorized_gaussian')
# action configuration flags
parser_ffg.add_argument('--no-cuda', '-nc', action='store_true')
parser_ffg.add_argument('--debug', action='store_true', help='debug mode')

# model configuration flags
parser_ffg.add_argument('--cont-bernoulli','-cs',action ='store_true') #[1040336] Selects the continuous Bernoulli log-likelihood instead of the Bernoulli one (main_ffg sets args.decode_dist accordingly)
parser_ffg.add_argument('--z-size', '-zs', type=int, default=50)
parser_ffg.add_argument('--batch-size', '-bs', type=int, default=100)
parser_ffg.add_argument('--eval-path', '-ep', type=str, default='model.pth',
                    help='path to load evaluation ckpt (default: model.pth)')
parser_ffg.add_argument('--dataset', '-d', type=str, default='mnist',
                    choices=['mnist', 'fashion', 'cifar'], 
                    help='dataset to train and evaluate on (default: mnist)')
parser_ffg.add_argument('--has-flow', '-hf', action='store_true', help='inference uses FLOW')
parser_ffg.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows')
parser_ffg.add_argument('--wide-encoder', '-we', action='store_true',
                    help='use wider layer (more hidden units for FC, more channels for CIFAR)')


# args.cuda = not args.no_cuda and torch.cuda.is_available()


def get_default_hparams_local_ffg():
    """Hyper-parameters for the local FFG evaluation model.

    All values come from the module-global ``args`` that ``main_ffg``
    populates; the Hamiltonian flow is always disabled here.
    """
    settings = dict(
        z_size=args.z_size,
        act_func=F.elu,
        has_flow=args.has_flow,
        n_flows=args.n_flows,
        wide_encoder=args.wide_encoder,
        cuda=args.cuda,
        hamiltonian_flow=False,
        # [1040336] log-likelihood chosen via --cont-bernoulli (Bernoulli or continuous Bernoulli)
        decode_dist=args.decode_dist,
    )
    return HParams(**settings)


def optimize_local_gaussian(
    log_likelihood,
    model,
    data_var,
    k=100,
    check_every=100,
    sentinel_thres=10,
    debug=True
):
    """Fit a free factorized Gaussian q(z) per sample; return its ELBO and IWAE bound.

    ``data_var`` is tiled ``k`` times with ``safe_repeat`` and every one of the
    B*k rows gets its own ``mean`` / ``logvar``. Only those two tensors are
    given to Adam, so the model's parameters are not stepped here.

    Args:
        log_likelihood: callable ``(x_logits, data) -> log p(x|z)``; assumed
            to return one value per row — TODO confirm.
        model: trained model providing ``z_size``, ``dtype`` and ``decode``.
        data_var: input batch Variable (expected to live on the model's
            device, "(cuda) variable" in the original note).
        k: importance samples per data point.
        check_every: iterations between early-stopping checks.
        sentinel_thres: stop once more than this many consecutive checks fail
            to improve on the best mean loss.
        debug: write progress lines to stderr.

    Returns:
        ``(vae_elbo, iwae_elbo)`` from one final set of samples: the mean
        single-sample ELBO and the mean k-sample log-mean-exp bound.
    """

    B = data_var.size()[0]
    z_size = model.z_size

    data_var = safe_repeat(data_var, k)
    # all-zero mean/logvar: the standard-normal prior p(z)
    zeros = Variable(torch.zeros(B*k, z_size).type(model.dtype))
    # variational parameters, one row per (data point, sample) pair
    mean = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True)
    logvar = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True)

    optimizer = optim.Adam([mean, logvar], lr=1e-3)
    best_avg, sentinel, prev_seq = 999999, 0, []

    # perform local opt (the range bound is effectively "until early stop")
    time_ = time.time()
    for epoch in range(1, 999999):

        # reparameterised sample z = mean + exp(logvar / 2) * eps
        eps = Variable(torch.FloatTensor(mean.size()).normal_().type(model.dtype))
        z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)
        x_logits = model.decode(z)

        logpz = log_normal(z, zeros, zeros)
        logqz = log_normal(z, mean, logvar)
        logpx = log_likelihood(x_logits, data_var)

        optimizer.zero_grad()
        # negative single-sample ELBO averaged over all B*k rows
        loss = -torch.mean(logpx + logpz - logqz)
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        prev_seq.append(loss_np)
        if epoch % check_every == 0:
            last_avg = np.mean(prev_seq)
            if debug:  # debugging helper
                sys.stderr.write(
                    'Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n' % \
                    (epoch, time.time()-time_, -last_avg, -best_avg)
                )
            if last_avg < best_avg:
                sentinel, best_avg = 0, last_avg
            else:
                sentinel += 1
            if sentinel > sentinel_thres:
                break
            prev_seq = []
            time_ = time.time()

    # evaluation
    eps = Variable(torch.FloatTensor(B*k, z_size).normal_().type(model.dtype))
    z = eps.mul(logvar.mul(0.5).exp_()).add_(mean)

    logpz = log_normal(z, zeros, zeros)
    logqz = log_normal(z, mean, logvar)
    logpx = log_likelihood(model.decode(z), data_var)
    elbo = logpx + logpz - logqz

    vae_elbo = torch.mean(elbo)
    # view(k, -1) groups the k samples per data point; this assumes safe_repeat
    # stacks k copies of the whole batch along dim 0 — TODO confirm
    iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1)))

    # .data[0] extracts a Python scalar in the torch 0.2 API used by this notebook
    return vae_elbo.data[0], iwae_elbo.data[0]


def main_ffg(arg_string):
    """Run local FFG optimisation over the partial training split and print results.

    ``arg_string`` is split on whitespace and parsed with ``parser_ffg``; the
    result is stored in the module-global ``args`` (read by
    ``get_default_hparams_local_ffg``). Prints per-batch and running mean
    ELBO / IWAE values, then the final averages.
    """
    global args
    args = parser_ffg.parse_args(arg_string.split())   #parser_ffg.parse_args()
    args.decode_dist = log_bernoulli
    if args.cont_bernoulli:
        args.decode_dist = log_bernoulli_cont
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.cuda)
    train_loader, test_loader = get_loaders(
        dataset=args.dataset,
        evaluate=True, batch_size=args.batch_size
    )
    print('Done: train_loader, test_loader')
    model = get_model(args.dataset, get_default_hparams_local_ffg())
    print('Done: model; get_model')
    model.load_state_dict(torch.load(args.eval_path)['state_dict'])
    print('Done: model.load_state_dict')
    model.eval()
    print('Done: model.eval')

    vae_record, iwae_record = [], []
    time_ = time.time()
    print('Done: time_ initalise')
    for i, (batch, _) in tqdm(enumerate(train_loader)):
        batch = Variable(batch.type(model.dtype))
        print('Done: batch')
        # The log-likelihood selected above lives on ``args``; the previous
        # ``self.decode_dist`` raised NameError because there is no ``self``.
        elbo, iwae = optimize_local_gaussian(args.decode_dist, model, batch, debug=args.debug)
        print('Done: elbo, iwae')
        vae_record.append(elbo)
        print('Done: vae_record append elbo')
        iwae_record.append(iwae)
        print('Done: iwae_record append iwae')
        print ('Local opt w/ ffg, batch %d, time elapse %.4f, ELBO %.4f, IWAE %.4f' % \
            (i+1, time.time()-time_, elbo, iwae))
        print ('mean of ELBO so far %.4f, mean of IWAE so far %.4f' % \
            (np.nanmean(vae_record), np.nanmean(iwae_record)))
        time_ = time.time()
        print('Done: time_ loop')

    print ('Finishing...')
    print ('Average ELBO %.4f, IWAE %.4f' % (np.nanmean(vae_record), np.nanmean(iwae_record)))

## local_flow.py

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
from tqdm import tqdm
import argparse
import numpy as np

import torch
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


#   from loader import get_Larochelle_MNIST_loader, get_fashion_loader, get_cifar10_loader
#   from vae import VAE
#   from cvae import CVAE


# Command-line options for local flow-based (expressive posterior) optimisation.
# main_flow parses an argument string with this parser (no sys.argv needed).
parser_flow = argparse.ArgumentParser(description='local_expressive')
# action configuration flags
parser_flow.add_argument('--no-cuda', '-nc', action='store_true')
parser_flow.add_argument('--debug', action='store_true', help='debug mode')

# model configuration flags
parser_flow.add_argument('--z-size', '-zs', type=int, default=50)
parser_flow.add_argument('--batch-size', '-bs', type=int, default=100)
parser_flow.add_argument('--eval-path', '-ep', type=str, default='model.pth',
                    help='path to load evaluation ckpt (default: model.pth)')
parser_flow.add_argument('--dataset', '-d', type=str, default='mnist',
                    choices=['mnist', 'fashion', 'cifar'], 
                    help='dataset to train and evaluate on (default: mnist)')
parser_flow.add_argument('--has-flow', '-hf', action='store_true', help='inference uses FLOW')
parser_flow.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows')
parser_flow.add_argument('--wide-encoder', '-we', action='store_true',
                    help='use wider layer (more hidden units for FC, more channels for CIFAR)')
parser_flow.add_argument('--cont-bernoulli','-cs',action ='store_true') #[1040336] Selects the continuous Bernoulli log-likelihood instead of the Bernoulli one (main_flow sets args.decode_dist accordingly)
#   args = parser_flow.parse_args()
#   args.cuda = not args.no_cuda and torch.cuda.is_available()


def get_default_hparams_local_flow():
    """Hyper-parameters for the local flow evaluation model.

    All values come from the module-global ``args`` that ``main_flow``
    populates; the Hamiltonian flow is always disabled here.
    """
    settings = dict(
        z_size=args.z_size,
        act_func=F.elu,
        has_flow=args.has_flow,
        n_flows=args.n_flows,
        wide_encoder=args.wide_encoder,
        cuda=args.cuda,
        hamiltonian_flow=False,
        # [1040336] log-likelihood chosen via --cont-bernoulli (Bernoulli or continuous Bernoulli)
        decode_dist=args.decode_dist,
    )
    return HParams(**settings)


def optimize_local_expressive(
    log_likelihood,
    model,
    data_var,
    k=100,
    check_every=100,
    sentinel_thres=10,
    n_flows=2,
    debug=False
):
    """Fit a per-sample auxiliary-variable flow posterior; return ELBO and IWAE bound.

    For each of the B*k tiled rows the posterior samples an auxiliary v0 from
    a free Gaussian, then z0 ~ q(z|v0) via an MLP, and applies ``n_flows``
    coupled (z, v) transforms. The reverse model r(vT|zT) is a second MLP.
    All of these parameters are optimised with Adam until early stopping;
    the model's own parameters are not passed to the optimizer.

    Args:
        log_likelihood: callable ``(x_logits, data) -> log p(x|z)``.
        model: trained model providing ``dtype`` and ``decode``.
        data_var: input batch Variable (expected on the model's device).
        k: importance samples per data point.
        check_every: iterations between early-stopping checks.
        sentinel_thres: stop once more than this many consecutive checks fail
            to improve on the best mean loss.
        n_flows: number of coupled flow steps.
        debug: write progress lines to stderr.

    Returns:
        ``(vae_elbo, iwae_elbo)`` from one final set of samples.

    Note: reads ``args.z_size`` and ``args.cuda`` from the module-global
    ``args``.
    """

    def log_joint(x_logits, x, z):
        """log p(x,z) = log p(x|z) + log p(z) with an all-zero mean/logvar prior."""
        zeros = Variable(torch.zeros(z.size()).type(model.dtype))
        logpz = log_normal(z, zeros, zeros)
        logpx = log_likelihood(x_logits, x)

        return logpx + logpz

    def norm_flow(params, z, v):
        """One coupled step: update v from z, then z from the new v; return log|det|."""

        h = F.tanh(params[0][0](z))
        mew_ = params[0][1](h)
        logit_ = params[0][2](h)
        sig_ = F.sigmoid(logit_)

        v = v*sig_ + mew_
        # numerically stable: log (sigmoid(logit)) = logit - softplus(logit)
        logdet_v = torch.sum(logit_ - F.softplus(logit_), 1)

        h = F.tanh(params[1][0](v))
        mew_ = params[1][1](h)
        logit_ = params[1][2](h)
        sig_ = F.sigmoid(logit_)

        z = z*sig_ + mew_
        logdet_z = torch.sum(logit_ - F.softplus(logit_), 1)

        logdet = logdet_v + logdet_z

        return z, v, logdet

    def sample(mean_v0, logvar_v0):
        """Draw zT and return it with its log-density estimate log q."""

        B = mean_v0.size()[0]
        eps = Variable(torch.FloatTensor(B, z_size).normal_().type(model.dtype))
        v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0
        logqv0 = log_normal(v0, mean_v0, logvar_v0)

        # q(z0|v0): MLP output split into mean and log-variance halves
        out = v0
        for i in range(len(qz_weights)-1):
            out = act_func(qz_weights[i](out))
        out = qz_weights[-1](out)
        mean_z0, logvar_z0 = out[:, :z_size], out[:, z_size:]

        eps = Variable(torch.FloatTensor(B, z_size).normal_().type(model.dtype))
        z0 = eps.mul(logvar_z0.mul(0.5).exp_()) + mean_z0
        logqz0 = log_normal(z0, mean_z0, logvar_z0)

        zT, vT = z0, v0
        logdetsum = 0.
        for i in range(n_flows):
            zT, vT, logdet = norm_flow(params[i], zT, vT)
            logdetsum += logdet

        # reverse model, r(vT|x,zT)
        out = zT
        for i in range(len(rv_weights)-1):
            out = act_func(rv_weights[i](out))
        out = rv_weights[-1](out)
        mean_vT, logvar_vT = out[:, :z_size], out[:, z_size:]
        logrvT = log_normal(vT, mean_vT, logvar_vT)

        logq = logqz0 + logqv0 - logdetsum - logrvT

        return zT, logq

    def get_params():
        """Create all variational parameters and the flat list handed to Adam."""

        all_params = []

        mean_v = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True)
        logvar_v = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True)

        all_params.append(mean_v)
        all_params.append(logvar_v)

        # Every layer contributes weight AND bias: previously only `.weight`
        # was optimised, leaving all biases frozen at their random init.
        qz_weights = []  # q(z|x,v)
        for ins, outs in zip(qz_arch[:-1], qz_arch[1:]):
            cur_layer = nn.Linear(ins, outs)
            if args.cuda:
                cur_layer.cuda()
            qz_weights.append(cur_layer)
            all_params.extend(cur_layer.parameters())

        rv_weights = []  # r(v|x,z)
        for ins, outs in zip(rv_arch[:-1], rv_arch[1:]):
            cur_layer = nn.Linear(ins, outs)
            if args.cuda:
                cur_layer.cuda()
            rv_weights.append(cur_layer)
            all_params.extend(cur_layer.parameters())

        params = []
        for i in range(n_flows):
            layers = [
                [nn.Linear(z_size, h_s),
                 nn.Linear(h_s, z_size),
                 nn.Linear(h_s, z_size)],
                [nn.Linear(z_size, h_s),
                 nn.Linear(h_s, z_size),
                 nn.Linear(h_s, z_size)],
            ]

            params.append(layers)

            for sublist in layers:
                for item in sublist:
                    # .cuda() moves parameter data in place, so the objects
                    # collected here remain the ones the layer uses
                    all_params.extend(item.parameters())
                    if args.cuda:
                        item.cuda()

        return (mean_v, logvar_v), all_params, params, qz_weights, rv_weights

    # main body
    B = data_var.size(0)
    z_size = args.z_size
    qz_arch = rv_arch = [args.z_size, 200, 200, args.z_size*2]
    h_s = 200
    act_func = F.elu

    # tile input for IS
    data_var = safe_repeat(data_var, k)
    (mean_v, logvar_v), all_params, params, qz_weights, rv_weights = get_params()

    optimizer = optim.Adam(all_params, lr=1e-3)
    best_avg, sentinel, prev_seq = 999999, 0, []

    # perform local opt (the range bound is effectively "until early stop")
    time_ = time.time()
    for epoch in range(1, 999999):
        z, logqz = sample(mean_v, logvar_v)
        x_logits = model.decode(z)
        logpxz = log_joint(x_logits, data_var, z)

        optimizer.zero_grad()
        loss = -torch.mean(logpxz - logqz)
        loss_np = loss.data.cpu().numpy()
        loss.backward()
        optimizer.step()

        prev_seq.append(loss_np)
        if epoch % check_every == 0:
            last_avg = np.mean(prev_seq)
            if debug:  # debugging helper
                sys.stderr.write(
                    'Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n' % \
                    (epoch, time.time()-time_, -last_avg, -best_avg)
                )
            if last_avg < best_avg:
                sentinel, best_avg = 0, last_avg
            else:
                sentinel += 1
            if sentinel > sentinel_thres:
                break

            prev_seq = []
            time_ = time.time()

    # evaluation
    z, logqz = sample(mean_v, logvar_v)
    x_logits = model.decode(z)
    logpxz = log_joint(x_logits, data_var, z)
    elbo = logpxz - logqz

    vae_elbo = torch.mean(elbo)
    iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1)))

    return vae_elbo.data[0], iwae_elbo.data[0]


def main_flow(arg_string):
    """Run local flow optimisation over the partial training split and print results.

    ``arg_string`` is split on whitespace and parsed with ``parser_flow``;
    the result is stored in the module-global ``args`` (read by the helpers
    above). Data is processed one example per batch.
    """
    global args
    args = parser_flow.parse_args(arg_string.split())
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.decode_dist = log_bernoulli
    if args.cont_bernoulli:
        args.decode_dist = log_bernoulli_cont
    print(args.cuda)

    train_loader, _ = get_loaders(dataset=args.dataset, evaluate=True, batch_size=1)
    model = get_model(args.dataset, get_default_hparams_local_flow())
    checkpoint = torch.load(args.eval_path)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    vae_record, iwae_record = [], []
    tic = time.time()
    for step, (batch, _) in tqdm(enumerate(train_loader)):
        batch = Variable(batch.type(model.dtype))
        elbo, iwae = optimize_local_expressive(
            args.decode_dist, model, batch,
            n_flows=args.n_flows, debug=args.debug,
        )
        vae_record.append(elbo)
        iwae_record.append(iwae)
        print('Local opt w/ flow, batch %d, time elapse %.4f, ELBO %.4f, IWAE %.4f'
              % (step + 1, time.time() - tic, elbo, iwae))
        print('mean of ELBO so far %.4f, mean of IWAE so far %.4f'
              % (np.nanmean(vae_record), np.nanmean(iwae_record)))
        tic = time.time()

    print('Finishing...')
    print('Average ELBO %.4f, IWAE %.4f' % (np.nanmean(vae_record), np.nanmean(iwae_record)))

#Params File


In [0]:
# Command-line options for training / evaluating the VAE models.
parser = argparse.ArgumentParser(description='VAE')
# action configuration flags
parser.add_argument('--train', '-t', action='store_true')
parser.add_argument('--load-path', '-lp', type=str, default='NA',
                    help='path to load checkpoint to retrain')
parser.add_argument('--load-epoch', '-le', type=int, default=0,
                    help='epoch number to start recording when retraining')
parser.add_argument('--display-epoch', '-de', type=int, default=10,
                    help='print status every so many epochs (default: 10)')
parser.add_argument('--eval-iwae', '-ei', action='store_true')
parser.add_argument('--eval-ais', '-ea', action='store_true')
parser.add_argument('--n-iwae', '-ni', type=int, default=5000,
                    help='number of samples for IWAE evaluation (default: 5000)')
parser.add_argument('--n-ais-iwae', '-nai', type=int, default=100,
                    help='number of IMPORTANCE samples for AIS evaluation (default: 100). \
                          This is different from MC samples.')
parser.add_argument('--n-ais-dist', '-nad', type=int, default=10000,
                    help='number of distributions for AIS evaluation (default: 10000)')
parser.add_argument('--ais-schedule', type=str, default='linear', help='schedule for AIS')

parser.add_argument('--no-cuda', '-nc', action='store_true', help='force not use CUDA')
parser.add_argument('--visdom', '-v', action='store_true', help='visualize samples')
parser.add_argument('--port', '-p', type=int, default=8097, help='port for visdom')
parser.add_argument('--save-visdom', default='test', help='visdom save path')
parser.add_argument('--encoder-more', action='store_true', help='train the encoder more (5 vs 1)')
parser.add_argument('--early-stopping', '-es', action='store_true', help='apply early stopping')
parser.add_argument('--epochs', '-e', type=int, default=3280,
                    help='total num of epochs for training (default: 3280)')
parser.add_argument('--lr-schedule', '-lrs', action='store_true',
                    help='apply learning rate schedule')

# model configuration flags
parser.add_argument('--z-size', '-zs', type=int, default=50,
                    help='dimensionality of latent code (default: 50)')
parser.add_argument('--batch-size', '-bs', type=int, default=100,
                    help='batch size (default: 100)')
parser.add_argument('--save-name', '-sn', type=str, default='model.pth',
                    help='name to save trained ckpt (default: model.pth)')
parser.add_argument('--eval-path', '-ep', type=str, default='model.pth',
                    help='path to load evaluation ckpt (default: model.pth)')
parser.add_argument('--dataset', '-d', type=str, default='mnist',
                    choices=['mnist', 'fashion', 'cifar'],
                    help='dataset to train and evaluate on (default: mnist)')
parser.add_argument('--wide-encoder', '-we', action='store_true',
                    help='use wider layer (more hidden units for FC, more channels for CIFAR)')
parser.add_argument('--has-flow', '-hf', action='store_true',
                    help='use flow for training and eval')
parser.add_argument('--adadelta', '-ad', action='store_true', #[1040336] Optimizer switch (AdaDelta); added separately due to a miscommunication over task ownership
                    help='use AdaDelta optimizer')
parser.add_argument('--sgd', '-sgd', action='store_true')#[1040336] Optimizer switch (plain SGD); added separately due to a miscommunication over task ownership
parser.add_argument('--hamiltonian-flow', '-hamil-f', action='store_true')
parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows')
parser.add_argument('--warmup', '-w', action='store_true',
                    help='apply warmup during training')
parser.add_argument('--cont-bernoulli','-cs',action ='store_true') #[1040336] Use the continuous Bernoulli log-likelihood instead of the Bernoulli one
parser.add_argument('--get-samples','-gs',action ='store_true') #[1040336] Draw sample images from the model
parser.add_argument('--samples-name','-sna',type = str, default = 'Samples')


def get_default_hparams():
    """Hyper-parameters for the main VAE script.

    Every value is read from the module-global ``args``, including the
    Hamiltonian-flow switch and the log-likelihood stored on
    ``args.decode_dist`` by the caller.
    """
    config = dict(
        z_size=args.z_size,
        act_func=F.elu,
        has_flow=args.has_flow,
        hamiltonian_flow=args.hamiltonian_flow,
        n_flows=args.n_flows,
        wide_encoder=args.wide_encoder,
        cuda=args.cuda,
        # [1040336] pluggable log-likelihood (e.g. continuous Bernoulli)
        decode_dist=args.decode_dist,
    )
    return HParams(**config)


def _train_eval_elbos(model, loader, k):
    """Return the per-batch ``elbo.data[0]`` values of ``model(batch, k=k)`` over ``loader``.

    Batches are wrapped in ``Variable`` and moved to the GPU when
    ``args.cuda`` is set; ``model`` must return a 4-tuple whose first item
    is the bound.
    """
    values = []
    for batch, _ in loader:
        batch = Variable(batch)
        if args.cuda:
            batch = batch.cuda()
        elbo, _, _, _ = model(batch, k=k)
        values.append(elbo.data[0])
    return values


def train(
    model,
    train_loader,
    test_loader,
    k_train=1,  # num iwae sample for training
    k_eval=1,  # num iwae sample for eval
    epochs=3280,
    display_epoch=10,
    adadelta = False,
    sgd = False,
    lr_schedule=True,
    warmup=True,
    warmup_thres=None,
    encoder_more=False,
    checkpoints=None,
    early_stopping=True,
    save=True,
    save_path='checkpoints/',
    patience=10  # for early-stopping
):
    """Train ``model`` by maximising its ELBO, with periodic evaluation and checkpoints.

    Every ``display_epoch`` epochs, the mean bound (k=1) over ``train_loader``
    and the mean ``k_eval``-sample bound over ``test_loader`` are printed.
    With ``early_stopping``, training stops after ``patience`` consecutive
    evaluations whose test bound does not reach the best seen so far.

    Checkpoints are files ``'<save_path><epoch>_<args.save_name>'``. If the
    next checkpoint file already exists, training epochs up to it are
    skipped and its weights are loaded at that epoch, which lets an
    interrupted run resume.

    Args:
        k_train: importance samples per training step.
        k_eval: importance samples for the test-set bound.
        adadelta / sgd: use AdaDelta (lr 0.9) / SGD (lr 4e-3); these take
            precedence over the default Adam optimizer.
        lr_schedule: with the default Adam optimizer, start at lr 1e-3 and
            decay it by 10**(-1/7) after 3**i epochs (IWAE schedule);
            otherwise Adam uses a fixed lr of 1e-4.
        warmup: scale the KL-related term by min(1, epoch / warmup_thres)
            (passed to ``model.forward`` as ``warmup_const``).
        warmup_thres: warmup length; defaults to 50 for cifar and 400 for
            mnist/fashion, inferred from ``save_path``.
        encoder_more: before each step, run 10 extra steps with the decoder
            frozen.
        checkpoints: epochs to checkpoint at; defaults to epoch 1, every
            ``display_epoch`` epochs and the final epoch.
        save: write checkpoint files (resume/skip logic runs regardless).

    Reads module-level ``args`` (``cuda``, ``load_path``, ``load_epoch``,
    ``save_name``).

    Raises:
        ValueError: if ``warmup`` is on, ``warmup_thres`` is None and
            ``save_path`` names none of cifar/mnist/fashion.
    """
    print('Training')

    if args.load_path != 'NA':
        model.load_state_dict(torch.load(args.load_path)['state_dict'])

    # default warmup schedule, keyed on the dataset name embedded in save_path
    if warmup_thres is None:
        if 'cifar' in save_path:
            warmup_thres = 50.
        elif 'mnist' in save_path or 'fashion' in save_path:
            warmup_thres = 400.
        elif warmup:
            raise ValueError(
                'warmup_thres must be given when save_path (%r) does not name '
                'a known dataset' % save_path)

    if checkpoints is None:
        # epoch 1, every display_epoch, and the last epoch; the set removes the
        # duplicate that display_epoch == 1 would otherwise create.
        checkpoints = sorted({1, epochs} | set(range(display_epoch, epochs, display_epoch)))

    def ckpt_path(ckpt_epoch):
        return '%s%d_%s' % (save_path, ckpt_epoch, args.save_name)

    time_ = time.time()

    # An explicitly requested optimizer wins; otherwise Adam, optionally with
    # the IWAE lr schedule (https://arxiv.org/pdf/1509.00519.pdf).
    use_lr_schedule = False
    if adadelta:
        optimizer = optim.Adadelta(model.parameters(), lr=0.9)
    elif sgd:
        optimizer = optim.SGD(model.parameters(), lr=4e-3)
    elif lr_schedule:
        use_lr_schedule = True
        current_lr = 1e-3
        lr_pow = 0
        epoch_elapsed = 0
        # pth default: beta_1 = .9, beta_2 = .999, eps = 1e-8
        optimizer = optim.Adam(model.parameters(), lr=current_lr, eps=1e-4)
    else:
        optimizer = optim.Adam(model.parameters(), lr=1e-4, eps=1e-4)

    num_worse = 0  # compare against `patience` for early-stopping
    best_valid_elbo = None

    for epoch in tqdm(range(1, epochs + 1), position=0, leave=True):
        warmup_const = min(1., epoch / warmup_thres) if warmup else 1.
        if use_lr_schedule:
            if epoch_elapsed >= 3 ** lr_pow:
                current_lr *= 10. ** (-1. / 7.)
                lr_pow += 1
                epoch_elapsed = 0
                # correct way to do lr decay; also possible w/ `torch.optim.lr_scheduler`
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr
            epoch_elapsed += 1

        next_ckpt = checkpoints[0] if checkpoints else None
        above_checkpoint_exists = (
            next_ckpt is not None and os.path.exists(ckpt_path(next_ckpt)))

        if above_checkpoint_exists:
            # Already trained up to next_ckpt in an earlier run: skip training
            # and restore those weights when that epoch is reached.
            if epoch == next_ckpt:
                model.load_state_dict(torch.load(ckpt_path(next_ckpt))['state_dict'])
        else:
            model.train()  # crucial for BN to work properly
            for batch, _ in train_loader:
                batch = Variable(batch)
                if args.cuda:
                    batch = batch.cuda()

                # train the encoder more
                if encoder_more:
                    model.freeze_decoder()
                    for _ in range(10):
                        optimizer.zero_grad()
                        elbo, _, _, _ = model.forward(batch, k_train, warmup_const)
                        loss = -elbo
                        loss.backward()
                        optimizer.step()
                    model.unfreeze_decoder()

                optimizer.zero_grad()
                elbo, _, _, _ = model.forward(batch, k_train, warmup_const)
                loss = -elbo
                loss.backward()
                optimizer.step()

        if epoch % display_epoch == 0:
            model.eval()  # crucial for BN to work properly
            train_stats = _train_eval_elbos(model, train_loader, k=1)
            # early stopping with iwae bound
            test_stats = _train_eval_elbos(model, test_loader, k=k_eval)
            print (
                'Train Epoch: [{}/{}]'.format(epoch, epochs),
                'Train set ELBO {:.4f}'.format(np.mean(np.asarray(train_stats))),
                'Test/Validation set IWAE {:.4f}'.format(np.mean(np.asarray(test_stats))),
                'Time: {:.2f}'.format(time.time()-time_),
            )
            time_ = time.time()

            if early_stopping:
                curr_valid_elbo = np.mean(test_stats)
                if best_valid_elbo is None:  # don't have history yet
                    best_valid_elbo = curr_valid_elbo
                elif curr_valid_elbo >= best_valid_elbo:  # higher bound = improved
                    best_valid_elbo = curr_valid_elbo
                    num_worse = 0
                else:
                    num_worse += 1

                if num_worse >= patience:
                    print("Stopped early")
                    break

        # Advance the checkpoint list even when not saving, so the resume
        # check above never gets stuck on an old checkpoint.
        if epoch in checkpoints:
            assert epoch + args.load_epoch == checkpoints[0]
            if save and not above_checkpoint_exists:
                torch.save({
                    'epoch': epoch + args.load_epoch,
                    'state_dict': model.state_dict(),
                }, ckpt_path(epoch + args.load_epoch))
            checkpoints = checkpoints[1:]


def test_iwae(
    model,
    loader,
    k=5000,
    f='model.pth',
    print_res=True
):
    """Average the bound returned by ``model(batch, k=k)`` over every batch of ``loader``.

    ``model`` must return a 4-tuple whose first item is the bound; only that
    item is kept (presumably the k-sample IWAE bound — confirm against the
    model definition). This function neither loads ``f`` nor calls
    ``model.eval()``; the caller is responsible for both. ``f`` is kept for
    call-site compatibility.

    Returns:
        The mean of the per-batch bound values; it is also printed together
        with the elapsed time when ``print_res`` is true.
    """
    print('Testing with %d importance samples' % k)
    started = time.time()
    batch_bounds = []
    for inputs, _labels in loader:
        inputs = Variable(inputs)
        if args.cuda:
            inputs = inputs.cuda()
        bound, _logpx, _logpz, _logqz = model(inputs, k=k)
        batch_bounds.append(bound.data[0])

    average = np.mean(batch_bounds)
    if print_res:
        print(average, 'T:', time.time() - started)
    return average


def run(arg_string):
    global args
    args = parser.parse_args(arg_string.split()) #parser.parse_args()
    # args.eval_path = '/content/drive/My Drive/ATiML' + args.eval_path
    # Sanity check to know which arguments are being parsed
    print(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args.cuda)
    args.decode_dist = log_bernoulli #[1040336]
    if args.cont_bernoulli:
        args.decode_dist = log_bernoulli_cont
    train_loader, test_loader = get_loaders(
        dataset=args.dataset,
        evaluate=args.eval_iwae or args.eval_ais, # HERE
        batch_size=args.batch_size
    )
    model = get_model(args.dataset, get_default_hparams())

    if args.train:
        save_path = 'checkpoints/%s/%s/%s%s%s/' % ( #[1040336] Added proper naming to alternate models
                        args.dataset,
                        'warmup' if args.warmup else 'no_warmup',
                        'wide_' if args.wide_encoder else '',
                        'hamiltonian_flow' if args.hamiltonian_flow else
                            'flow' if args.has_flow else 'ffg',
                        '_adadelta' if args.adadelta else '_sgd' if args.sgd else '_contbern' if args.cont_bernoulli else ''
                    )
        if not os.path.exists(save_path):
          os.makedirs(save_path)

        train(
            model, train_loader, test_loader,
            display_epoch=args.display_epoch, epochs=args.epochs,
            lr_schedule=args.lr_schedule,
            adadelta = args.adadelta, #[1040336]
            sgd = args.sgd, #[1040336]
            warmup=args.warmup,
            early_stopping=args.early_stopping,
            encoder_more=args.encoder_more,
            save=True, save_path=save_path
        )

    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=args.port)
        model.load_state_dict(torch.load(args.eval_path)['state_dict'])

        # plot original images
        batch, _ = train_loader.next()
        images = list(batch.numpy())
        win_samples = vis.images(images, 10, 2, opts={'caption': 'original images'}, win=None)

        # plot reconstructions
        batch = Variable(batch.type(model.dtype))
        reconstruction = model.reconstruct_img(batch)
        images = list(reconstruction.data.cpu().numpy())
        win_samples = vis.images(images, 10, 2, opts={'caption': 'reconstruction'}, win=None)

    if args.eval_iwae: #[1040336] Fixed crash from appending list to itself, improved downstream memory management to get it to run on colab
        # VAE bounds computed w/ 100 MC samples to reduce variance
        model.load_state_dict(torch.load(args.eval_path)['state_dict'])
        model.eval()
        train_res, test_res = [], []
        for _ in range(1):
            train_val = test_iwae(model, train_loader, k=1, f=args.eval_path)
            test_val = test_iwae(model, test_loader, k=1, f=args.eval_path)
            #print("about to append")
            train_res.append(train_val)
            #print("appended train_res")
            test_res.append(test_val)
            #print("appended test_res")

        #print("exited for loop")
        #print("length of train_res = {}".format(len(train_res)))
        train_mean = np.mean(train_res)
        #print("finished calculating mean")
        print ('Training set VAE ELBO w/ 100 MC samples: %.4f' % train_mean)
        print ('Test set VAE ELBO w/ 100 MC samples: %.4f' % np.mean(test_res))

        # IWAE bounds
        print('Computing training set IWAE:')
        test_iwae(model, train_loader, k=args.n_iwae, f=args.eval_path)
        print('Computing test set IWAE:')
        test_iwae(model, test_loader, k=args.n_iwae, f=args.eval_path)

    if args.eval_ais:
        model.load_state_dict(torch.load(args.eval_path)['state_dict'])
        schedule_fn = linear_schedule if args.ais_schedule == 'linear' else sigmoidial_schedule
        schedule = schedule_fn(args.n_ais_dist)
        logws = ais_trajectory(
            model, train_loader,
            mode='forward', schedule=schedule, n_sample=args.n_ais_iwae, log_likelihood_fn = args.decode_dist
        )
        print('Average importance weight: {}'.format(np.mean(logws).mean())) #[1040336] This print was missing
    if args.get_samples: #[1040336] Added this function to print mnist images
        n = 10
        model.load_state_dict(torch.load(args.eval_path)['state_dict'])
        z = Variable(torch.randn(n,args.z_size))
        logits = model.decode(z)
        prob = torch.sigmoid(logits)
        if args.cont_bernoulli:
            prob = mean_cont_bern(prob)
        images = prob.view(n, 1, 28, 28).detach()
        image = utils.make_grid( images.data, n)
        npimg = image.cpu().numpy() # BlacK background
        plt.figure(figsize = (10, 10))
        plt.title(args.samples_name, fontsize = 20)
        plt.axis('off')
        plt.imshow(np.transpose(npimg, (1, 2, 0)), 
                  interpolation = 'nearest')