In [1]:
import sys
sys.path.insert(0, '/Users/richardgrumitt/Documents/blackjax/')
import blackjax
import arviz as az
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

import scipy
from scipy.stats import multivariate_normal as n_mvn
import corner
import torch
import warnings
import pickle

import jax
import jax.numpy as jnp

import getdist
from getdist import plots, MCSamples

seed=1234
np.random.seed(seed)
rng_key = jax.random.PRNGKey(seed)



# German Credit

In [2]:
data = np.genfromtxt("./fiducial_samples/german.data-numeric")
x = data[:, :-1]
y = (data[:, -1] - 1).astype(np.int32)

x_min = np.min(x, 0, keepdims=True)
x_max = np.max(x, 0, keepdims=True)

x /= (x_max - x_min)
x = 2.0 * x - 1.0

x = np.concatenate([x, np.ones([x.shape[0], 1])], -1)

x = jnp.asarray(x, dtype=jnp.float32)
y = jnp.asarray(y, dtype=jnp.float32)

In [13]:
n = x.shape[0]
d = x.shape[1]
n_dim_data = d

def gamma_log_prob(x, alpha=0.5, beta=0.5):
    return (alpha - 1) * jnp.log(x) - beta * x

def log_like(params, data=y, covariates=x):
    
    tau = jnp.exp(params[..., 0])
    lam = jnp.exp(params[..., 1:d+1])
    beta = params[..., d+1:]
    
    f = jnp.dot(covariates, (tau * beta * lam).T)
    
    return jnp.sum(jax.scipy.stats.bernoulli.logpmf(k=data, p=jax.scipy.special.expit(f))).squeeze()

def log_prior(params):
    
    tau = jnp.exp(params[..., 0])
    lam = jnp.exp(params[..., 1:d+1])
    beta = params[..., d+1:]
    logp_tau = jnp.log(tau) + jax.scipy.stats.gamma.logpdf(x=tau, a=0.5, scale=1.0/0.5)
    logp_lam = jnp.sum(jnp.log(lam) + jax.scipy.stats.gamma.logpdf(x=lam, a=0.5, scale=1.0/0.5))
    logp_beta = jax.scipy.stats.multivariate_normal.logpdf(beta, mean=jnp.zeros(d), cov=jnp.eye(d))
    
    return (logp_tau + logp_lam + logp_beta).squeeze()
    
def prior_rvs(key, N):
    
    tau = 0.5 * jax.random.gamma(key, a=0.5, shape=(N,))
    lam = 0.5 * jax.random.gamma(key, a=0.5, shape=(N, d))
    beta = jax.random.normal(key, shape=(N, d))
    
    return jnp.concatenate([jnp.log(tau).reshape(N, 1), jnp.log(lam), beta], axis=1)


vlog_prior = jax.vmap(log_prior)
vlog_like = jax.vmap(log_like)
target_logp = lambda x: log_like(x) + log_prior(x)

initial_position = prior_rvs(rng_key, 1)

print(vlog_like(initial_position))
#initial_state = nuts.init(initial_position)

[-629.21674]


In [None]:
warmup = blackjax.window_adaptation(
    blackjax.nuts,
    target_logp,
    1000,
)
state, kernel, _ = warmup.run(
    rng_key,
    initial_position,
)

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [None]:
samples = inference_loop(rng_key, kernel, state, 1_000)