In [153]:
import jax.numpy as jnp
from jax import random
import jax

In [2]:
key = random.PRNGKey(42)

In [89]:
n_dims = 2

u = random.normal(key, (1, n_dims))
w = random.normal(key, (1, n_dims))
b = random.normal(key, (1, ))

z = random.normal(key, (10, n_dims))

# without checking for inverse

In [132]:
a = jnp.tanh(jnp.dot(z, w.T) + b)
f_z = z + a * u

In [133]:
psi = (1 - a**2) * w

In [134]:
log_det = jnp.log(jnp.abs(1 + jnp.dot(psi, u.T)))

In [135]:
log_det

DeviceArray([[7.0586407e-01],
             [1.7356923e+00],
             [6.4933568e-01],
             [1.7187123e-03],
             [1.2462626e+00],
             [1.5734894e+00],
             [5.4141873e-01],
             [1.7079745e-01],
             [1.6818756e+00],
             [1.2972219e+00]], dtype=float32)

In [106]:
f_z + log_det

DeviceArray([[-1.9248432 ,  1.3059537 ],
             [ 2.0724218 ,  0.8170076 ],
             [-1.6387651 ,  3.1203785 ],
             [ 4.163963  , -0.81497204],
             [-0.7484112 ,  1.6867993 ],
             [ 0.27362704,  2.0167742 ],
             [ 2.9737387 , -0.97599727],
             [ 2.9625885 , -1.7842555 ],
             [ 0.8964971 ,  2.6955755 ],
             [-0.58546245,  1.8676493 ]], dtype=float32)

In [107]:
u.shape

(1, 2)

In [108]:
w.shape

(1, 2)

In [111]:
jnp.dot(w, u.T) > 1

DeviceArray([[ True]], dtype=bool)

# Putting in check:

In [114]:
m = lambda x: -1 + jnp.log1p(jnp.exp(x))

In [115]:
wu = jnp.dot(w, u.T)

In [125]:
u_hat = u + (m(wu) - wu) * w / jnp.dot(w, w.T)

In [126]:
u_hat

DeviceArray([[-1.7323706 ,  0.37109286]], dtype=float32)

In [127]:
u

DeviceArray([[-2.169826  ,  0.46480063]], dtype=float32)

In [129]:
def planar_flow(u, w, b, z):
    # making sure its invertible
    wu = jnp.dot(w, u.T)
    u_hat = u + (-1 + jnp.log1p(jnp.exp(wu)) - wu) * w / jnp.dot(w, w.T)
    
    # transforming
    a = jnp.tanh(jnp.dot(z, w.T) + b)
    f_z = z + a * u_hat
    psi = (1 - a**2) * w
    log_det = jnp.log(jnp.abs(1 + jnp.dot(psi, u_hat.T)))
    return f_z, log_det

In [130]:
planar_flow(u, w, b, z)

(DeviceArray([[-2.2414646 ,  0.51670957],
              [ 0.2378994 , -0.89751416],
              [-1.8933396 ,  2.3864808 ],
              [ 3.7248652 , -0.72299933],
              [-1.6863042 ,  0.37448058],
              [-1.093035  ,  0.39898   ],
              [ 2.0280344 , -1.4308136 ],
              [ 2.3626885 , -1.8631344 ],
              [-0.6393704 ,  0.98242337],
              [-1.5859914 ,  0.5068725 ]], dtype=float32),
 DeviceArray([[5.9819186e-01],
              [1.5540872e+00],
              [5.4808933e-01],
              [1.3724680e-03],
              [1.0912092e+00],
              [1.3993610e+00],
              [4.5336533e-01],
              [1.3863160e-01],
              [1.5026240e+00],
              [1.1387993e+00]], dtype=float32))

Looks decent - now let's look at doing it in batches using vmap

# batching

In [144]:
n_dims = 2
n_batch = 5
n_samples = 10

u = random.normal(key, (n_batch, 1, n_dims))
w = random.normal(key, (n_batch, 1, n_dims))
b = random.normal(key, (n_batch, 1, ))

x = random.normal(key, (n_batch, n_samples, n_dims))

In [148]:
z = []
log_det = []
for idx in jnp.arange(u.shape[0]):
    z_i, log_det_i = planar_flow(u[idx], w[idx], b[idx], x[idx])
    z.append(z_i)
    log_det.append(log_det_i)
z = jnp.stack(z, axis=0)
log_det = jnp.stack(log_det, axis=0)

In [149]:
z.shape

(5, 10, 2)

In [150]:
log_det.shape

(5, 10, 1)

Now lets try it with vmap

In [163]:
z_vmap, log_det_vmap = jax.vmap(planar_flow)(u, w, b, x)

In [164]:
jnp.allclose(z_vmap, z)

DeviceArray(True, dtype=bool)

In [165]:
jnp.allclose(log_det, log_det_vmap)

DeviceArray(True, dtype=bool)

Okay that works, great. Now let's try and get it working from a matrix with multiple layers etc.

# batching, hyper network

In [227]:
n_dims = 1
n_batch = 5
n_samples = 10
n_layers = 15

params = random.normal(key, (n_batch, (2 * n_dims + 1) * n_layers))
z = random.normal(key, (n_batch, n_samples, n_dims))

In [228]:
z.shape

(5, 10, 1)

In [229]:
u, w, b = jnp.split(params, [n_layers * n_dims, 2 * n_layers * n_dims],axis=1)
u = jnp.swapaxes(u.reshape(n_batch, n_layers, 1, n_dims), 1, 0) # (n_layers, n_batch, 1, n_dim)
w = jnp.swapaxes(w.reshape(n_batch, n_layers, 1, n_dims), 1, 0) # (n_layers, n_batch, 1, n_dim)
b = jnp.swapaxes(b.reshape(n_batch, n_layers, 1, ), 1, 0) # (n_layers, n_batch, 1, )

In [230]:
log_jacob = jnp.zeros((n_batch, n_samples, 1))
for idx in jnp.arange(n_layers):
    z, log_jacob_i = jax.vmap(planar_flow)(u[idx], w[idx], b[idx], z)
    log_jacob += log_jacob_i

In [231]:
z.shape

(5, 10, 1)

In [232]:
log_jacob.shape

(5, 10, 1)

In [233]:
from jax.scipy.stats import norm

In [236]:
jnp.sum(norm.logpdf(z) + log_jacob)

DeviceArray(-1088.2534, dtype=float32)

In [None]:
norm.logpdf()

In [239]:
z.size

50

In [240]:
from jax.lax import scan

In [241]:
log_jacob = jnp.zeros((n_batch, n_samples, 1))
for idx in jnp.arange(n_layers):
    z, log_jacob_i = jax.vmap(planar_flow)(u[idx], w[idx], b[idx], z)
    log_jacob += log_jacob_i

In [287]:
n_dims = 1
n_batch = 5
n_samples = 10
n_layers = 15

params = random.normal(key, (n_batch, (2 * n_dims + 1) * n_layers))
x = random.normal(key, (n_batch, n_samples, n_dims))

In [272]:
u, w, b = jnp.split(params, [n_layers * n_dims, 2 * n_layers * n_dims],axis=1)
u = jnp.swapaxes(u.reshape(n_batch, n_layers, 1, n_dims), 1, 0) # (n_layers, n_batch, 1, n_dim)
w = jnp.swapaxes(w.reshape(n_batch, n_layers, 1, n_dims), 1, 0) # (n_layers, n_batch, 1, n_dim)
b = jnp.swapaxes(b.reshape(n_batch, n_layers, 1, ), 1, 0) # (n_layers, n_batch, 1, )

In [273]:
z, log_jac = scan(lambda z, params: jax.vmap(planar_flow)(*params, z), x, (u, w, b))

In [274]:
z.shape

(5, 10, 1)

In [275]:
log_jac.shape

(15, 5, 10, 1)

In [256]:
log_jac = jnp.sum(log_jac, axis=0)

In [257]:
jnp.sum(norm.logpdf(z) + log_jacob)

DeviceArray(-1037.6753, dtype=float32)

What if we do the vmap over the scan?

In [278]:
u, w, b = jnp.split(params, [n_layers * n_dims, 2 * n_layers * n_dims],axis=1)
u = u.reshape(n_batch, n_layers, 1, n_dims)
w = w.reshape(n_batch, n_layers, 1, n_dims)
b = b.reshape(n_batch, n_layers, 1, )

In [294]:
z, log_jac = scan(planar_flow, x[0], (u[0], w[0], b[0]));

In [295]:
z.shape

(10, 1)

In [286]:
log_jac.shape

(15, 10, 1)

In [301]:
z, log_jac = jax.vmap(lambda batch_params, batch_x: scan(planar_flow, batch_x, batch_params))((u, w, b), x)

In [302]:
z.shape

(5, 10, 1)

In [303]:
log_jac.shape

(5, 15, 10, 1)

In [306]:
jnp.sum(log_jac, axis=-3).shape

(5, 10, 1)

In [293]:
def planar_flow(z, params):
    u, w, b = params
    # making sure its invertible
    wu = jnp.dot(w, u.T)
    u_hat = u + (-1 + jnp.log1p(jnp.exp(wu)) - wu) * w / jnp.dot(w, w.T)
    
    # transforming
    a = jnp.tanh(jnp.dot(z, w.T) + b)
    f_z = z + a * u_hat
    psi = (1 - a**2) * w
    log_det = jnp.log(jnp.abs(1 + jnp.dot(psi, u_hat.T)))
    return f_z, log_det