In [10]:
import argparse
import gzip
import math
import os
import sys
import time
from typing import Any, Callable

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from flax import linen as nn
from flax import optim
from jax import grad, jit, vmap
from jax import lax
from jax import random

In [11]:
def log_normal(x, mean=None, logvar=None, repeat=False):
    """Log-density of a diagonal Gaussian, reduced over the feature axis.

    Args:
        x: array of samples (must expose ``.dtype``).
        mean: Gaussian mean, broadcastable against ``x``. Defaults to zeros
            shaped like ``x``.
        logvar: log-variance, broadcastable against ``x``. Defaults to zeros
            shaped like ``x``.
        repeat: when true, ``x`` is treated as 3-D: the quadratic term is
            summed over axis 2 and the ``D * log(2*pi)`` constant
            (``D = x.shape[2]``) is included.

    Returns:
        ``repeat=False``: ``-0.5 * (sum(logvar, 1) + sum((x-mean)^2 / exp(logvar), 1))``.
        Note this branch omits the ``D * log(2*pi)`` normalising constant.
        ``repeat=True``: the same expression with the constant added and the
        quadratic term summed over axis 2.
    """
    if mean is None:
        mean = jnp.zeros(jnp.shape(x), dtype=x.dtype)
    if logvar is None:
        logvar = jnp.zeros(jnp.shape(x), dtype=x.dtype)

    # JAX arrays have no torch-style ``.pow``; use jnp.square instead.
    scaled_sq_err = jnp.square(x - mean) / jnp.exp(logvar)

    if repeat:
        D = jnp.shape(x)[2]
        term1 = D * jnp.log(jnp.array([2. * math.pi]))
        # NOTE(review): logvar is reduced over axis 1 while the quadratic term
        # is reduced over axis 2 -- presumably logvar is 2-D here; confirm the
        # intended shapes against the caller.
        return -.5 * (term1 + logvar.sum(1) + scaled_sq_err.sum(2))

    return -0.5 * (logvar.sum(1) + scaled_sq_err.sum(1))

def log_bernoulli(logit, target, repeat=False):
    """Bernoulli log-likelihood of ``target`` given pre-sigmoid ``logit``.

    Uses the numerically stable form
    ``-max(logit, 0) + target * logit - log(1 + exp(-|logit|))``.

    Args:
        logit: array of logits.
        target: array of targets, broadcastable against ``logit``.
        repeat: when true, the element-wise log-likelihood is summed over
            axis 2 only (inputs must be at least 3-D).

    Returns:
        ``repeat=False``: the log-likelihood summed over every axis except the
        first. ``repeat=True``: the log-likelihood summed over axis 2.
    """
    # log1p keeps precision when exp(-|logit|) is tiny.
    softplus_tail = jnp.log1p(jnp.exp(-jnp.abs(logit)))
    # jnp.maximum works on every JAX version (the ``min=`` keyword of jnp.clip
    # is only accepted by recent releases) and avoids depending on flax here.
    positive_part = jnp.maximum(logit, 0)

    if repeat:
        return -(positive_part - logit * target + softplus_tail).sum(2)  # sum over dimensions

    loss = -positive_part + jnp.multiply(target, logit) - softplus_tail
    # ``.size`` is an int on JAX arrays (not callable as in torch); use ndim to
    # collapse all trailing axes.
    while loss.ndim > 1:
        loss = loss.sum(-1)

    return loss

# VAE

In [12]:
class Encoder(nn.Module):
    """MLP encoder producing the posterior mean, log-variance and a context vector.

    Flax linen modules are dataclasses: hyper-parameters are declared as class
    fields and ``setup`` takes no arguments. The field order matches the
    positional construction ``Encoder(act_func, h_s, z_size)``.

    Attributes:
        act_func: activation applied after the hidden layers and the
            ``x_info`` layer.
        h_s: width of the two hidden layers.
        z_size: latent dimensionality; ``fc3`` emits ``2 * z_size`` values.
    """
    act_func: Callable
    h_s: int
    z_size: int

    def setup(self):
        """Create the dense layers from the module's fields."""
        self.fc1 = nn.Dense(self.h_s)
        self.fc2 = nn.Dense(self.h_s)
        self.fc3 = nn.Dense(self.z_size * 2)
        self.x_info_layer = nn.Dense(self.z_size)

    def __call__(self, x):
        """Encode ``x``.

        Args:
            x: 2-D input (indexed as ``[:, ...]`` below).

        Returns:
            ``(mean, logvar, x_info)`` where ``mean``/``logvar`` are the first
            and second ``z_size`` columns of ``fc3``'s output and ``x_info`` is
            the activated output of ``x_info_layer`` on the last hidden layer.
        """
        x = self.act_func(self.fc1(x))
        x = self.act_func(self.fc2(x))
        x_info = self.act_func(self.x_info_layer(x))
        x = self.fc3(x)

        mean, logvar = x[:, :self.z_size], x[:, self.z_size:]
        return mean, logvar, x_info

class Decoder(nn.Module):
    """MLP decoder mapping a latent ``z`` to 784 output logits.

    Attributes:
        act_func: default activation applied after the two hidden layers.
    """
    act_func: Callable

    def setup(self):
        """Create the dense layers (two hidden layers of 200 units, 784 outputs)."""
        self.fc4 = nn.Dense(200)
        self.fc5 = nn.Dense(200)
        self.fc6 = nn.Dense(784)

    def __call__(self, z, act_func=None):
        """Decode ``z`` into logits.

        Args:
            z: latent input.
            act_func: optional activation override; when ``None`` the module's
                own ``act_func`` field is used, so ``decoder(z)`` works.

        Returns:
            Output of ``fc6`` (no activation applied).
        """
        activation = self.act_func if act_func is None else act_func
        z = activation(self.fc4(z))
        z = activation(self.fc5(z))
        z = self.fc6(z)
        return z
    
class VAE(nn.Module):
    """Variational auto-encoder with an optional flow on the approximate posterior.

    Attributes:
        hps: hyper-parameter object. ``setup`` reads ``has_flow``,
            ``hamiltonian_flow``, ``act_func``, ``h_s``, ``z_size`` and, when
            ``has_flow`` is true, ``n_flows``. When left as ``None`` the result
            of ``get_default_hparams()`` is used.

    The reparameterisation noise is drawn from the Flax RNG stream named
    ``"sample"``; callers are expected to provide it, e.g.
    ``vae.apply(variables, x, rngs={"sample": key})``.
    """
    hps: Any = None

    def setup(self):
        """Build the encoder, decoder and (optionally) the flow."""
        hps = self.hps if self.hps is not None else get_default_hparams()
        self.has_flow = hps.has_flow
        self.hamiltonian_flow = hps.hamiltonian_flow

        self.encode = Encoder(hps.act_func, hps.h_s, hps.z_size)
        self.decode = Decoder(hps.act_func)

        if hps.has_flow:
            # NOTE(review): Flow is defined outside this section; the call is
            # kept as written (it receives this module as first argument).
            self.q_dist = Flow(self, n_flows=hps.n_flows)

    def __call__(self, x, k=1, warmup_const=1.):
        """Compute the ``k``-sample ELBO and its components.

        Args:
            x: 2-D batch of inputs.
            k: number of samples per input.
            warmup_const: multiplier applied to ``logqz`` in the ELBO.

        Returns:
            ``(elbo, logpx, logpz, logqz)``, each a mean over the batch.
        """
        # torch's ``x.repeat(k, 1)`` tiles the batch k times; in JAX
        # ``x.repeat(k, 1)`` would instead repeat elements along axis 1.
        x = jnp.tile(x, (k, 1))
        mean, logvar, x_info = self.encode(x)

        if self.hamiltonian_flow:
            # Bind the current batch so the flow can call grad_fn(z).
            z, logpz, logqz = self.sample(
                mean, logvar, grad_fn=lambda z_: self.grad_U(z_, x), x_info=x_info)
        else:
            z, logpz, logqz = self.sample(mean, logvar, x_info=x_info)

        logit = self.decode(z)
        logpx = log_bernoulli(logit, x)
        elbo = logpx + logpz - warmup_const * logqz

        # jnp.tile stacks k copies of the batch, so reshape to (k, batch) and
        # transpose to (batch, k) before averaging over the k samples.
        # NOTE(review): log_mean_exp is defined outside this section;
        # presumably it reduces over the last axis -- confirm.
        elbo = log_mean_exp(elbo.reshape(k, -1).T)
        elbo = jnp.mean(elbo)

        logpx = jnp.mean(logpx)
        logpz = jnp.mean(logpz)
        logqz = jnp.mean(logqz)

        return elbo, logpx, logpz, logqz

    def U(self, z, x):
        """Per-example energy ``-log p(x, z)`` under the decoder and a standard-normal prior."""
        logpx = log_bernoulli(self.decode(z), x)
        logpz = log_normal(z)
        return -logpx - logpz  # energy as -log p(x, z)

    # If hamiltonian flow
    def grad_U(self, z, x):
        """Normalised gradient of the energy w.r.t. ``z``, with gradients stopped.

        Mirrors the original torch logic: differentiate the per-example energy
        with an all-ones cotangent, divide each row by the square root of its
        L2 norm, and detach the result.
        """
        def energy(decoder, z_):
            return -log_bernoulli(decoder(z_), x) - log_normal(z_)

        # Lifted vjp: raw jax.grad cannot be applied across a bound Flax module.
        # NOTE(review): verify flax.linen.vjp is available in the installed
        # Flax version and that the returned cotangents are ordered
        # (variables, *primals).
        energy_vals, backward = nn.vjp(energy, self.decode, z)
        _, grad_z = backward(jnp.ones_like(energy_vals))
        norm = jnp.sqrt(jnp.linalg.norm(grad_z, ord=2, axis=1))
        grad_z = grad_z / norm.reshape(-1, 1)
        return lax.stop_gradient(grad_z)

    def sample(self, mean, logvar, grad_fn=lambda x: 1, x_info=None):
        """Draw ``z`` by reparameterisation and evaluate ``log p(z)`` / ``log q(z|x)``.

        Args:
            mean: posterior mean.
            logvar: posterior log-variance.
            grad_fn: callable forwarded to the flow (only used when
                ``has_flow`` is true).
            x_info: context vector forwarded to the flow.

        Returns:
            ``(z, logpz, logqz)``.
        """
        # Noise has the shape of ``mean``; the key comes from the "sample" stream.
        eps = random.normal(self.make_rng("sample"), mean.shape)
        z = jnp.exp(0.5 * logvar) * eps + mean
        logqz = log_normal(z, mean, logvar)

        if self.has_flow:
            z, logprob = self.q_dist.forward(z, grad_fn, x_info)
            logqz += logprob

        zeros = jnp.zeros_like(z)
        logpz = log_normal(z, zeros, zeros)

        return z, logpz, logqz

    @staticmethod
    def model():
        """Return a ``VAE`` built with the default hyper-parameters."""
        return VAE()


# HParams

In [13]:
class HParams(object):
    """Small attribute-style container for hyper-parameters.

    Every keyword passed to the constructor is stored both in ``_items`` and
    as an instance attribute of the same name.
    """

    def __init__(self, **kwargs):
        self._items = {}
        for k, v in kwargs.items():
            self._set(k, v)

    def _set(self, k, v):
        """Store ``v`` under ``k`` in ``_items`` and as an attribute."""
        self._items[k] = v
        setattr(self, k, v)

    def parse(self, str_value):
        """Return a copy of these hyper-parameters with overrides applied.

        Args:
            str_value: comma-separated ``key=value`` pairs, e.g.
                ``"z_size=20,has_flow=true"``. Blank entries are skipped and
                whitespace around keys and values is ignored.

        Returns:
            A new ``HParams``; ``self`` is left untouched. Each value is
            converted to the type of the existing default (bool, int or
            float); any other type is stored as the raw string.

        Raises:
            ValueError: an entry has no ``=``.
            KeyError: the key is not an existing hyper-parameter.
        """
        hps = HParams(**self._items)
        for entry in str_value.strip().split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            if not sep:
                raise ValueError("Unable to parse: %s" % entry)
            key, value = key.strip(), value.strip()
            default_value = hps._items[key]
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(default_value, bool):
                hps._set(key, value.lower() == "true")
            elif isinstance(default_value, int):
                hps._set(key, int(value))
            elif isinstance(default_value, float):
                hps._set(key, float(value))
            else:
                hps._set(key, value)
        # Return only after every entry has been applied (previously the
        # return sat inside the loop, so only the first override took effect
        # and an empty string yielded None).
        return hps

def get_default_hparams():
    """Return the default ``HParams`` for the VAE in this notebook."""
    return HParams(
        z_size=50,
        # Hidden width read by VAE.setup for the Encoder; 200 mirrors the
        # Decoder's hard-coded hidden width -- TODO confirm intended value.
        h_s=200,
        # ``F.elu`` was a torch leftover (``F`` is never imported here);
        # flax.linen exposes the same activation as ``nn.elu``.
        act_func=nn.elu,
        has_flow=False,
        # Read unconditionally by VAE.setup, so it needs a default.
        hamiltonian_flow=False,
        large_encoder=False,
        wide_encoder=False,
        cuda=True,
        decode_dist=log_bernoulli,
    )

# Utils

In [14]:
def load_mnist(path, kind='train'):
    """Read a gzipped MNIST-format image/label pair from the directory `path`.

    Args:
        path: directory holding ``<kind>-labels-idx1-ubyte.gz`` and
            ``<kind>-images-idx3-ubyte.gz``.
        kind: file-name prefix selecting the split (default ``'train'``).

    Returns:
        ``(images, labels)`` as uint8 arrays: ``labels`` is 1-D (the first
        8 bytes of the file are skipped) and ``images`` has shape
        ``(len(labels), 784)`` (the first 16 bytes of the file are skipped).
    """

    def _read_uint8(file_name, header_bytes):
        # Decompress the whole file and view everything past the header as uint8.
        with gzip.open(os.path.join(path, file_name), 'rb') as handle:
            return np.frombuffer(handle.read(), dtype=np.uint8, offset=header_bytes)

    labels = _read_uint8('%s-labels-idx1-ubyte.gz' % kind, 8)
    pixels = _read_uint8('%s-images-idx3-ubyte.gz' % kind, 16)
    images = pixels.reshape(len(labels), 784)

    return images, labels