# In [None]:
import math
import os

import arviz as az
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from IPython.display import Image, set_matplotlib_formats
from matplotlib.patches import Ellipse, transforms

import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import expit

import numpy as onp
import numpyro as numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size, print_summary
from numpyro.infer import MCMC, NUTS, Predictive

# Configure the JAX/NumPyro backend: run on the CPU and tell XLA to expose
# 4 host devices (presumably so MCMC chains can run in parallel — confirm).
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)

# Plot styling for ArviZ figures.
az.style.use("arviz-darkgrid")
# Population-level parameters used to simulate the cafe data.
a = 3.5  # average morning wait time
b = -1  # average difference afternoon wait time
sigma_a = 1  # std dev in intercepts
sigma_b = 0.5  # std dev in slopes
rho = -0.7  # correlation between intercepts and slopes

# Mean vector of each cafe's (intercept, slope) pair.
Mu = jnp.array([a, b])

# Scalar covariance between intercepts and slopes. It is kept as a module-level
# name for reference; Sigma below is built once from the SD/correlation
# factorisation, and its off-diagonal entries equal cov_ab.
cov_ab = sigma_a * sigma_b * rho

sigmas = jnp.array([sigma_a, sigma_b])  # standard deviations
Rho = jnp.array([[1, rho], [rho, 1]])  # correlation matrix

# Covariance matrix: Sigma = diag(sigmas) @ Rho @ diag(sigmas).
Sigma = jnp.diag(sigmas) @ Rho @ jnp.diag(sigmas)
# Simulation sizes and within-cafe observation noise.
N_cafes = 20
N_visits = 10
sigma = 0.5  # std dev within cafes

# Draw one (intercept, slope) pair per cafe from MVN(Mu, Sigma).
seed = random.PRNGKey(5)  # used to replicate example
vary_effects = dist.MultivariateNormal(Mu, Sigma).sample(seed, (N_cafes,))
a_cafe, b_cafe = vary_effects[:, 0], vary_effects[:, 1]

# Visit design: each cafe index is repeated N_visits times, and the visit
# indicator alternates 0, 1, 0, 1, ... across the whole sequence.
afternoon = jnp.tile(jnp.arange(2), N_visits * N_cafes // 2)
cafe_id = jnp.repeat(jnp.arange(N_cafes), N_visits)

# Observed waits: cafe-specific line plus Gaussian noise, drawn with a new key.
seed = random.PRNGKey(22)
mu = a_cafe[cafe_id] + b_cafe[cafe_id] * afternoon
wait = dist.Normal(mu, sigma).sample(seed)

d = pd.DataFrame({"cafe": cafe_id, "afternoon": afternoon, "wait": wait})

import pathlib
import os
# Prebuilt stanc compiler binary used by stan2tfp later in the notebook.
# NOTE(review): machine-specific absolute path; adjust for your environment.
compiler_path = '/home/sosa/BI/linux-arm64-stanc'

# Platform label. It is hard-coded to "linux" (the binary above is a
# linux-arm64 build), so the support check below always passes.
plat = "linux"
plat_rename = {"darwin": "mac", "win32": "windows", "linux": "linux"}
if plat not in plat_rename:
    raise OSError("OS {} is not supported".format(plat))
plat = plat_rename[plat]

# Make the compiler executable with rwxr-xr-x. The mode must be an octal
# literal: decimal 755 equals 0o1363, which would set the sticky bit, remove the
# owner's read permission and grant group write.
os.chmod(compiler_path, 0o755)

from stan2tfp import Stan2tfp

# Stan program for the varying-intercepts / varying-slopes cafe model; it is
# compiled by Stan2tfp below. Sizes are hard-coded to match the simulation:
# 200 observations (N_cafes * N_visits = 20 * 10) and 20 cafes.
#   - Each cafe's (a_cafe[j], b_cafe[j]) pair ~ multi_normal([a, b], S), where
#     S = quad_form_diag(Rho, sigma_cafe), i.e. diag(sigma_cafe) * Rho * diag(sigma_cafe).
#   - Priors: Rho ~ lkj_corr(2); sigma, sigma_cafe ~ exponential(1);
#     a ~ normal(5, 2); b ~ normal(-1, 0.5).
#   - Likelihood: wait[i] ~ normal(a_cafe[cafe[i]] + b_cafe[cafe[i]] * afternoon[i], sigma).
# Loops run over 1:20 / 1:200, so `cafe` must hold 1-based cafe indices.
# NOTE(review): if N_cafes or N_visits change, the literal sizes in this
# string must be updated by hand.
stan_code = """ 
data{
    vector[200] wait;
    array[200] int afternoon;
    array[200] int cafe;
}
parameters{
    vector[20] b_cafe;
    vector[20] a_cafe;
    real a;
    real b;
    vector<lower=0>[2] sigma_cafe;
    real<lower=0> sigma;
    corr_matrix[2] Rho;
}
model{
    vector[200] mu;
    Rho ~ lkj_corr( 2 );
    sigma ~ exponential( 1 );
    sigma_cafe ~ exponential( 1 );
    b ~ normal( -1 , 0.5 );    
    a ~ normal( 5 , 2 );
    {
        array[20] vector[2] YY;
        vector[2] MU;
        MU = [ a , b ]';
        for ( j in 1:20 ) YY[j] = [ a_cafe[j] , b_cafe[j] ]';
        YY ~ multi_normal( MU , quad_form_diag(Rho , sigma_cafe) );
    }
    for ( i in 1:200 ) {
        mu[i] = a_cafe[cafe[i]] + b_cafe[cafe[i]] * afternoon[i];        
    }
    
    wait ~ normal( mu , sigma );

}
"""

# Data dict for the Stan program. The keys match the `data{}` block names.
# Cafe ids are shifted by +1 because the simulation uses 0-based ids while the
# Stan model indexes cafes 1-based.
data = {
    'wait': d['wait'].to_numpy().astype(float),
    'afternoon': d['afternoon'].to_numpy().astype(int),
    'cafe': d['cafe'].to_numpy().astype(int) + 1,
}

# Reuse compiler_path (the binary chmod'ed earlier) instead of repeating the
# literal, so the two references cannot drift apart.
model = Stan2tfp(stan_model_code=stan_code, data_dict=data, compiler_path=compiler_path)