In [62]:
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import scipy

import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_feasible, init_to_median, init_to_sample, \
    init_to_uniform, init_to_value

from pymc3.gp.util import plot_gp_dist

from tqdm import tqdm

In [45]:
# Render inline figures in the high-resolution 'retina' format and apply
# seaborn's default theme to all subsequent matplotlib plots.
%config InlineBackend.figure_format = 'retina'
sns.set()

# Utility models

In [46]:
def u_pow(x, theta):
    """Power utility: raise the outcome ``x`` to the exponent ``theta``."""
    return x ** theta

In [47]:
def u_exp(x, theta):
    """Exponential utility ``1 - exp(-theta * x)``.

    Parameters
    ----------
    x : scalar or array-like
        Outcome magnitude(s).
    theta : scalar or array-like
        Curvature parameter. A length-1 sequence such as ``(0.5,)`` (the
        default ``theta`` of ``generate_data``) is accepted and broadcast
        against ``x``.

    Returns
    -------
    NumPy scalar or array
        Utility values with the broadcast shape of ``x`` and ``theta``.
    """
    # np.multiply accepts tuple/list parameters, whereas the unary minus in
    # ``-theta * x`` raises a TypeError for them.  expm1 avoids the
    # cancellation that ``1 - exp(.)`` suffers when theta * x is near zero.
    return -np.expm1(-np.multiply(theta, x))

In [48]:
def u_lin(x, theta=None):
    """Linear utility: return ``x`` unchanged (``theta`` is accepted but ignored)."""
    return x

# Generate choice data

In [49]:
def generate_data(u, seed=123, max_x=1, n=100, tau=3.333, theta=(0.5, )):
    """Simulate ``n`` binary choices between two lotteries.

    Each lottery ``i`` pays ``x{i}`` with probability ``p{i}``.  Candidate
    pairs are drawn uniformly, pairs in which one lottery is at least as good
    on both probability and outcome are discarded, and ``n`` of the remaining
    pairs are kept.  Lottery 1 is chosen with probability
    ``expit(tau * (p1 * u(x1, theta) - p0 * u(x0, theta)))``.

    Parameters
    ----------
    u : callable(x, theta)
        Utility function applied to the outcomes.
    seed : int
        Seed for NumPy's global random generator (it is reseeded on every call).
    max_x : float
        Upper bound used to rescale the uniformly drawn outcomes.
    n : int
        Number of choices to return.
    tau : float
        Multiplier applied to the expected-utility difference before ``expit``.
    theta : scalar or sequence
        Parameter(s) forwarded to ``u``.

    Returns
    -------
    pandas.DataFrame
        Columns ``p0, x0, p1, x1, choices`` (``choices`` is 1 when lottery 1
        was picked); the index keeps the row labels of the candidate pool.
    """
    np.random.seed(seed)

    # Draw ten times more candidate pairs than requested so that enough
    # remain after the dominated ones are filtered out.
    pool = pd.DataFrame(
        np.random.uniform(0, 1, size=(n*10, 4)),
        columns=["p0", "x0", "p1", "x1"],
    )
    for outcome_col in ("x0", "x1"):
        pool[outcome_col] = pool[outcome_col].values * max_x

    # A pair is uninformative when one lottery is at least as good as the
    # other on both probability and outcome.
    lottery0_dominates = (pool.p0 >= pool.p1) & (pool.x0 >= pool.x1)
    lottery1_dominates = (pool.p1 >= pool.p0) & (pool.x1 >= pool.x0)
    data = pool[~(lottery0_dominates | lottery1_dominates)].sample(n=n, replace=False)

    # Expected utility of lottery 0, then lottery 1.
    seu = [data[f"p{i}"].values * u(data[f"x{i}"].values, theta) for i in range(2)]

    p_choice_1 = scipy.special.expit(tau * (seu[1] - seu[0]))
    data['choices'] = (p_choice_1 > np.random.random(size=n)).astype(int)

    return data

In [50]:
# Simulate 1000 choices from a power-utility agent (theta = 0.5, tau = 3.333).
data = generate_data(
    u=u_pow,
    n=1000,
    tau=3.333,
    theta=(0.5, ),
)
data

Unnamed: 0,p0,x0,p1,x1,choices
9002,0.696389,0.007647,0.617248,0.282997,1
7863,0.953285,0.292898,0.380948,0.552913,1
7455,0.060678,0.327938,0.685627,0.158464,1
3414,0.138550,0.276750,0.788651,0.125342,0
5796,0.116159,0.356232,0.732859,0.310642,0
...,...,...,...,...,...
2857,0.928569,0.159319,0.006656,0.571188,0
5417,0.621396,0.103035,0.601608,0.871169,1
1854,0.218694,0.560962,0.389658,0.026879,1
8834,0.980679,0.073375,0.877985,0.777087,1


# Likelihood of the whole model given $M$, $\theta_M$

In [51]:
def softplus(x):
    """Map real values onto the positive half-line: ``log(1 + exp(x))``.

    Evaluated as ``logaddexp(0, x)`` so that large positive inputs return
    (approximately) ``x`` instead of overflowing to ``inf``.

    Parameters
    ----------
    x : scalar or array-like

    Returns
    -------
    NumPy scalar or array with the same shape as ``x``.
    """
    return np.logaddexp(0, x)

def objective(param, data, u_m):
    """Negative log-likelihood of the observed choices under utility ``u_m``.

    Parameters
    ----------
    param : array-like
        Unconstrained optimiser variables.  They are passed through
        ``softplus`` to make them positive; index 0 becomes ``tau`` and
        index 1 becomes ``theta``.
    data : pandas.DataFrame
        Must provide the columns ``p0, x0, p1, x1, choices``.
    u_m : callable(x, theta)
        Utility function of the model being evaluated.

    Returns
    -------
    float
        ``-sum(log P(choice))`` over all rows of ``data``.
    """
    param = softplus(param)  # All parameters are supposed to be in R+

    tau, theta = param[0], param[1]

    p0 = data.p0.values
    p1 = data.p1.values
    x0 = data.x0.values
    x1 = data.x1.values
    y = data.choices.values

    seu0 = p0 * u_m(x0, theta)
    seu1 = p1 * u_m(x1, theta)

    # Logit of choosing lottery 1.
    logit_1 = tau * (seu1 - seu0)

    # -log sigmoid(z) = logaddexp(0, -z).  Working in log space keeps the
    # value finite when the sigmoid saturates, where log(p**y * (1-p)**(1-y))
    # would return -inf.
    nll = y * np.logaddexp(0, -logit_1) + (1 - y) * np.logaddexp(0, logit_1)
    return nll.sum()

In [52]:
def optimize(data, u_m=u_pow, x0=None):
    """Fit ``(tau, theta)`` by minimising ``objective`` on ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Choice data forwarded to ``objective``.
    u_m : callable(x, theta)
        Utility function of the model to fit.
    x0 : sequence, optional
        Starting point in the unconstrained space; defaults to ``(0.0, 0.0)``,
        i.e. two parameters are assumed.

    Returns
    -------
    numpy.ndarray
        The optimiser's solution mapped back through ``softplus``.
    """
    start = (0.0, 0.0) if x0 is None else x0
    result = scipy.optimize.minimize(objective, x0=start, args=(data, u_m))
    return softplus(result.x)

In [53]:
# Sanity check: fit tau and theta by maximum likelihood on the simulated data.
optimize(data, u_m=u_pow)

array([3.28496491, 0.53506426])

# Measure discrepancy

In [100]:
# Utility, parameter and tau used to simulate the choices.
u_data, theta_data, tau_data = u_pow, 0.5, 3.333
# Utility, parameter and tau assumed by the model (here identical to the above).
u_model, theta_model, tau_model = u_pow, 0.5, 3.333

In [101]:
# Simulate 200 choices from the data-generating utility.
data = generate_data(
    u=u_data, tau=tau_data, theta=theta_data,
    n=200, seed=123,
)

In [102]:
p0, p1 = data.p0.values, data.p1.values
x0, x1 = data.x0.values, data.x1.values

# Observed choices (1 = lottery 1 was picked).
y = jnp.array(data.choices.values)

# Stack lottery 0 on top of lottery 1: with n choices, entries [:n] belong to
# lottery 0 and entries [n:] to lottery 1.
x = np.hstack([x0, x1])
p = np.hstack([p0, p1])

X = jnp.array(x)
# compute M(x): the model utility evaluated at every outcome
uX = jnp.array(u_model(x, theta_model))
P = jnp.array(p)

n = y.shape[0]

In [103]:
# Maximum-likelihood reference fit of tau and theta on the current data set.
optimize(data, u_m=u_pow)

array([3.97130767, 0.48028541])

In [104]:
def mean_function(x):
    """Prior mean of the latent function: the assumed utility model at ``x``.

    Reads the notebook-level ``u_model`` and ``theta_model``.
    """
    prior_mean = u_model(x, theta_model)
    return prior_mean

def kernel(X, Xs, var, length):
    """Squared-exponential covariance ``var * exp(-0.5 * ((x - x') / length)**2)``.

    For 1-D inputs ``X`` and ``Xs`` the result has shape ``(len(X), len(Xs))``.
    """
    scaled_diff = (X[:, None] - Xs) / length
    return var * jnp.exp(-0.5 * jnp.power(scaled_diff, 2.0))


def compute_f(var, length, eta, mu, X, jitter=1.0e-6):
    """Build latent function values as ``mu + L @ eta``.

    ``L`` is the Cholesky factor of ``kernel(X, X, var, length)`` with
    ``jitter`` added to the diagonal before factorising.
    """
    n_points = X.shape[0]
    cov = kernel(X, X, var, length) + jitter * jnp.eye(n_points)
    chol = jnp.linalg.cholesky(cov)
    return jnp.matmul(chol, eta) + mu


def model(X, y, P):
    """NumPyro model: latent utility function with a Bernoulli choice likelihood.

    The latent function ``f`` has prior mean ``mean_function(X)`` and a
    squared-exponential covariance; it replaces the utility when computing
    each lottery's expected utility ``P * f``.

    Parameters
    ----------
    X : 1-D array of even length
        Outcomes of both lotteries stacked: first half lottery 0, second half
        lottery 1.
    y : array or None
        Observed choices (1 = lottery 1), one per pair of stacked entries.
    P : 1-D array
        Probabilities, stacked in the same layout as ``X``.

    Notes
    -----
    Reads the notebook-level ``tau_model`` (and, through ``mean_function``,
    ``u_model`` and ``theta_model``).
    """
    N = X.shape[0]
    # X and P stack the two lotteries, so half of their entries is the number
    # of choices.  Deriving it here keeps the model independent of a
    # notebook-level ``n`` that may be stale or belong to another data set.
    n_choices = N // 2

    # Priors on the two kernel hyperparameters.
    var = numpyro.sample('kernel_var', dist.HalfCauchy(5))
    length = numpyro.sample('kernel_length', dist.InverseGamma(2, 2))

    # Standard-normal weights; compute_f turns them into f = mu + L @ eta.
    eta = numpyro.sample('eta', dist.Normal(np.zeros(N), 1))

    mu = mean_function(X)
    f = compute_f(var=var, length=length, mu=mu, eta=eta, X=X)

    est_eu = P * f
    # Expected utility of lottery 1 minus that of lottery 0.
    diff_eu_hat = est_eu[n_choices:] - est_eu[:n_choices]

    numpyro.sample("obs", dist.Bernoulli(logits=tau_model*diff_eu_hat), obs=y)

In [None]:
# Fixed PRNG key so the MCMC run is reproducible.
rng_key = random.PRNGKey(0)

# NOTE(review): the recorded warm-up output shows a step size of ~3e-08, which
# suggests the sampler is struggling -- possibly a numerical-precision issue
# in the Cholesky factorisation; worth investigating before trusting the draws.
nuts_kernel = NUTS(model, target_accept_prob=0.9, max_tree_depth=10)
nuts = MCMC(nuts_kernel, num_samples=1000, num_warmup=1000)
nuts.run(rng_key, X=X, y=y, P=P)

warmup:   3%|▎         | 62/2000 [00:11<08:02,  4.01it/s, 2 steps of size 3.18e-08. acc. prob=0.76]   

In [None]:
def gp_predict(X, samples, i, Xnew=None, jitter=1.0e-6):
    """Evaluate the latent function for posterior draw ``i``.

    Parameters
    ----------
    X : 1-D array
        Training inputs the draws in ``samples`` were conditioned on.
    samples : mapping
        Posterior draws with keys ``'kernel_var'``, ``'kernel_length'`` and
        ``'eta'``.
    i : int
        Index of the draw to use.
    Xnew : 1-D array, optional
        Inputs at which to predict.  When omitted, the function values at
        ``X`` itself are returned.
    jitter : float
        Value added to the diagonal of the training covariance before the
        Cholesky factorisation.

    Returns
    -------
    array
        ``f`` at ``X`` if ``Xnew`` is None, otherwise the conditional mean of
        the latent function at ``Xnew`` given ``f``.
    """
    from jax.scipy.linalg import solve_triangular

    var = samples['kernel_var'][i]
    length = samples['kernel_length'][i]
    eta = samples['eta'][i]

    N = X.shape[0]
    K = kernel(X, X, var, length) + jitter*jnp.eye(N)
    L = jnp.linalg.cholesky(K)

    if Xnew is None:
        # Function values at the training inputs: f = mu + L @ eta.
        return jnp.matmul(L, eta) + mean_function(X)

    k_pX = kernel(Xnew, X, var, length)

    # Conditional mean with a non-zero prior mean m:
    #     m(Xnew) + k_pX @ K^{-1} @ (f - m(X)).
    # Since f - m(X) = L @ eta and K = L @ L.T, the weights K^{-1} (f - m(X))
    # reduce to L^{-T} @ eta: one triangular solve, no explicit inverse.
    # (The previous version applied K^{-1} to f itself and never added
    # m(Xnew), which is only correct for a zero-mean prior.)
    weights = solve_triangular(L.T, eta, lower=False)
    return mean_function(Xnew) + jnp.matmul(k_pX, weights)


In [None]:
nuts_samples = nuts.get_samples()

# Evenly spaced evaluation grid on [0, 1].
Xtest = jnp.array(np.linspace(0, 1, 100))
mu = mean_function(Xtest)

# One row of predictions per posterior draw.
n_samples = nuts_samples['kernel_var'].shape[0]
n_test = Xtest.shape[0]
f_predict = np.zeros((n_samples, n_test))

for i in tqdm(range(n_samples)):
    f_predict[i] = gp_predict(X=X, samples=nuts_samples, i=i, Xnew=Xtest)

In [None]:
# Posterior distribution of the latent function over the evaluation grid.
fig, ax = plt.subplots()
plot_gp_dist(ax, f_predict, Xtest)
# Label the figure so it can be read on its own; the trailing semicolon
# suppresses the repr of the value returned by ax.set.
ax.set(xlabel="x", ylabel="f(x)", title="Posterior of the latent utility function");