In [1]:
import pandas as pd
import numpy as np
import scipy as sp
from scipy.special import expit

import pymc3 as pm
import arviz as az
import cmdstanpy

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
rng = np.random.default_rng(1234)

In [3]:
db = pd.read_csv('data/sim.csv')


In [4]:
numsubjs = 10

α_true = np.random.beta(4,   3, numsubjs)
β_true = np.random.normal(0, 1, numsubjs)

values    = np.tile(np.array(db.value),     numsubjs)
risk      = np.tile(np.array(db.risk),      numsubjs)
ambiguity = np.tile(np.array(db.ambiguity), numsubjs)

In [5]:
refValue       = 5
refProbability = 1
refAmbiguity   = 0

refProbabilities = np.tile(refProbability, len(values))
refValues        = np.tile(refValue,       len(values))
refAmbiguities   = np.tile(refAmbiguity,   len(values))

In [6]:
riskTol = np.repeat(α_true, len(risk) / numsubjs)
ambTol  = np.repeat(β_true, len(ambiguity) / numsubjs)

In [7]:
uRef = refValues ** riskTol
uLotto = values ** riskTol * (risk - ambTol * ambiguity / 2)
p = sp.special.expit(uLotto - uRef)

choices = np.random.binomial(1, p, len(p))

In [8]:
n_trials = np.arange(len(choices))

In [9]:
sub_idx = np.arange(numsubjs)
sub_idx = np.repeat(sub_idx, 84)
ID = sub_idx+1

In [10]:
AmbiguityModel = cmdstanpy.CmdStanModel(stan_file='/home/nachshon/Documents/Projects/Aging/Aging/RiskandAmbiguity/models/AmbiguityModel.stan')

INFO:cmdstanpy:found newer exe file, not recompiling


In [11]:
standata_ambiguity = {
    'N' : len(refProbabilities),
    'choice' : choices,
    'refProbabilities' : refProbabilities,
    'refAmbiguities' : refAmbiguities,
    'refValues' : refValues,
    'lotteryProbabilities' : risk,
    'lotteryAmbiguities' : ambiguity,
    'ID' : ID,
    'lotteryValues' : values,
    'n_sub': numsubjs
}

In [12]:
fit_ambiguity_model = AmbiguityModel.sample(
  data = standata_ambiguity,
  chains = 4,
  iter_warmup = 1000,
  iter_sampling = 1000,
  adapt_delta = .9,
  inits = 0.2,
)

INFO:cmdstanpy:CmdStan start procesing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan done processing.





In [13]:
cmdstanpy_data_amb = az.from_cmdstanpy(posterior=fit_ambiguity_model,
                                       posterior_predictive="y_hat",
                                       log_likelihood="log_lik")

In [14]:
AmbiguityModel_p = cmdstanpy.CmdStanModel(stan_file='/home/nachshon/Documents/Projects/Aging/Aging/RiskandAmbiguity/models/AmbiguityModel_priors.stan')

INFO:cmdstanpy:found newer exe file, not recompiling


In [15]:
fit_ambiguity_model_info_prior = AmbiguityModel_p.sample(
  data = standata_ambiguity,
  chains = 4,
  iter_warmup = 1000,
  iter_sampling = 1000,
  adapt_delta = .9,
  inits = 0.2,
)

INFO:cmdstanpy:CmdStan start procesing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan done processing.





In [16]:
cmdstanpy_data_amb_info_prior = az.from_cmdstanpy(posterior=fit_ambiguity_model_info_prior,
                                                  posterior_predictive="y_hat",
                                                  log_likelihood="log_lik")

In [17]:
choices = np.array(choices)

In [18]:
with pm.Model() as RiskAmb:
    # hyper
    rMu = pm.Normal('rMu', 0,1)
    rSig = pm.Exponential('rSig', 1)
    aMu = pm.Normal('aMu', 0,1)
    aSig = pm.Exponential('aSig', 1)
    
    nMu = pm.Normal('nMu', 0,1)
    nSig = pm.Exponential('nSig', 1)
    
    
    α = pm.Lognormal('α', rMu, rSig, shape = numsubjs)
    β = pm.Normal('β', aMu, aSig, shape = numsubjs)
    γ = pm.Lognormal('γ',nMu , nSig, shape = numsubjs)
    # Priors for unknown model parameters
    
      
    # Expected value of outcome
    svLotto = (values ** α[sub_idx]) * (risk-(β[sub_idx] * (ambiguity/2)))
    svRef = 5 ** α[sub_idx]
    p = (svLotto - svRef)/γ[sub_idx]
    mu = pm.invlogit(p)
       
    # Likelihood (sampling distribution) of observations
    #Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=db.choice)
    choice = pm.Binomial('choice',1, mu, observed=choices)
    trace2 = pm.sample(2000, return_inferencedata=True, target_accept=0.95)

Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [γ, β, α, nSig, nMu, aSig, aMu, rSig, rMu]
INFO:pymc3:NUTS: [γ, β, α, nSig, nMu, aSig, aMu, rSig, rMu]


Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 66 seconds.
INFO:pymc3:Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 66 seconds.
There were 30 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 30 divergences after tuning. Increase `target_accept` or reparameterize.
There were 54 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 54 divergences after tuning. Increase `target_accept` or reparameterize.
There were 53 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 53 divergences after tuning. Increase `target_accept` or reparameterize.
There were 31 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 31 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some 

In [19]:
with pm.Model() as RiskAmb:
    # hyper
    a = pm.Poisson('a', 1)
    b = pm.Poisson('b', 1)
    
    aMu = pm.Normal('aMu', 0,1)
    aSig = pm.Exponential('aSig', 1)
    
    nMu = pm.Normal('nMu', 0,1)
    nSig = pm.Exponential('nSig', 1)
    
    
    α = pm.Beta('α', a, b, shape = numsubjs)
    β = pm.Normal('β', aMu, aSig, shape = numsubjs)
    γ = pm.Lognormal('γ',nMu , nSig, shape = numsubjs)
    # Priors for unknown model parameters
    
    riskTol = α * 2
    

    # Expected value of outcome
    svLotto = (values ** riskTol[sub_idx]) * (risk-(β[sub_idx] * (ambiguity/2)))
    svRef = 5 ** riskTol[sub_idx]
    p = (svLotto - svRef)/γ[sub_idx]
    mu = pm.invlogit(p)
       
    # Likelihood (sampling distribution) of observations
    #Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=db.choice)
    choice = pm.Binomial('choice',1, mu, observed=choices)
    trace3 = pm.sample(2000, return_inferencedata=True, nuts={'target_accept':0.95})

Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
INFO:pymc3:CompoundStep
>CompoundStep
INFO:pymc3:>CompoundStep
>>Metropolis: [b]
INFO:pymc3:>>Metropolis: [b]
>>Metropolis: [a]
INFO:pymc3:>>Metropolis: [a]
>NUTS: [γ, β, α, nSig, nMu, aSig, aMu]
INFO:pymc3:>NUTS: [γ, β, α, nSig, nMu, aSig, aMu]


Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 71 seconds.
INFO:pymc3:Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 71 seconds.
There were 33 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 33 divergences after tuning. Increase `target_accept` or reparameterize.
There were 44 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 44 divergences after tuning. Increase `target_accept` or reparameterize.
There were 54 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 54 divergences after tuning. Increase `target_accept` or reparameterize.
There were 99 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 99 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0

In [20]:
az.summary(cmdstanpy_data_amb, var_names=['riskTol','ambTol'])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
riskTol[0],0.392,0.116,0.231,0.62,0.027,0.02,20.0,102.0,1.13
riskTol[1],0.748,0.084,0.622,0.931,0.01,0.007,78.0,1630.0,1.05
riskTol[2],0.623,0.132,0.418,0.86,0.031,0.022,19.0,96.0,1.15
riskTol[3],0.73,0.081,0.566,0.884,0.002,0.001,1805.0,1950.0,1.19
riskTol[4],0.532,0.093,0.365,0.711,0.011,0.008,79.0,1717.0,1.04
riskTol[5],0.788,0.086,0.647,0.975,0.003,0.003,722.0,1856.0,1.07
riskTol[6],1.129,0.095,0.961,1.314,0.012,0.009,63.0,1016.0,1.04
riskTol[7],0.823,0.077,0.666,0.953,0.011,0.008,49.0,1965.0,1.06
riskTol[8],0.678,0.087,0.516,0.85,0.002,0.002,1203.0,1681.0,1.06
riskTol[9],0.419,0.104,0.223,0.615,0.014,0.011,59.0,1338.0,1.05


In [21]:
az.summary(cmdstanpy_data_amb_info_prior, var_names=['riskTol','ambTol'])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
riskTol[0],0.302,0.133,0.057,0.536,0.003,0.002,1800.0,1461.0,1.0
riskTol[1],0.791,0.099,0.601,0.978,0.002,0.001,2443.0,1828.0,1.0
riskTol[2],0.666,0.171,0.322,1.001,0.005,0.003,1350.0,866.0,1.0
riskTol[3],0.758,0.106,0.558,0.955,0.002,0.002,2752.0,2265.0,1.0
riskTol[4],0.489,0.114,0.261,0.691,0.003,0.002,1787.0,2160.0,1.0
riskTol[5],0.79,0.094,0.614,0.967,0.002,0.001,2069.0,1995.0,1.0
riskTol[6],1.196,0.109,1.002,1.411,0.002,0.002,2168.0,2210.0,1.0
riskTol[7],0.828,0.085,0.677,0.996,0.002,0.001,3033.0,2710.0,1.0
riskTol[8],0.694,0.104,0.5,0.896,0.002,0.002,2443.0,2137.0,1.0
riskTol[9],0.329,0.126,0.089,0.563,0.003,0.002,1401.0,1297.0,1.0


In [22]:
az.compare({'informed':fit_ambiguity_model_info_prior, 'uninformed':fit_ambiguity_model}, ic='waic')

Unnamed: 0,rank,waic,p_waic,d_waic,weight,se,dse,warning,waic_scale
informed,0,-842.854376,0.610055,0.0,1.0,2.738401,0.0,False,log
uninformed,1,-843.293017,0.537266,0.43864,0.0,2.754678,0.276079,False,log


In [23]:
az.summary(trace2, var_names=['α','β'])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
α[0],0.371,0.1,0.181,0.557,0.002,0.001,3128.0,2918.0,1.0
α[1],0.617,0.087,0.457,0.783,0.001,0.001,3686.0,3243.0,1.0
α[2],0.544,0.114,0.327,0.759,0.003,0.002,1951.0,1520.0,1.0
α[3],0.59,0.085,0.429,0.744,0.001,0.001,3692.0,3653.0,1.0
α[4],0.511,0.091,0.338,0.676,0.002,0.001,3122.0,3455.0,1.0
α[5],0.759,0.084,0.597,0.912,0.001,0.001,3388.0,3374.0,1.0
α[6],0.956,0.087,0.797,1.123,0.001,0.001,4104.0,4413.0,1.0
α[7],0.693,0.08,0.538,0.839,0.001,0.001,3399.0,4577.0,1.0
α[8],0.6,0.089,0.437,0.772,0.002,0.001,3257.0,3421.0,1.0
α[9],0.405,0.093,0.227,0.581,0.002,0.001,3305.0,2690.0,1.0


In [24]:
az.summary(trace3, var_names=['α','β'])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
α[0],0.112,0.061,0.002,0.212,0.001,0.001,2003.0,2809.0,1.0
α[1],0.31,0.046,0.221,0.392,0.001,0.001,3521.0,3541.0,1.0
α[2],0.255,0.077,0.102,0.4,0.002,0.002,1131.0,1508.0,1.0
α[3],0.294,0.045,0.207,0.376,0.001,0.001,2651.0,4940.0,1.0
α[4],0.241,0.053,0.142,0.341,0.001,0.001,2806.0,2731.0,1.0
α[5],0.387,0.041,0.309,0.464,0.001,0.0,4627.0,4684.0,1.0
α[6],0.49,0.042,0.41,0.568,0.001,0.001,2809.0,4580.0,1.0
α[7],0.352,0.041,0.276,0.429,0.001,0.0,3858.0,4926.0,1.0
α[8],0.301,0.048,0.21,0.391,0.001,0.001,2710.0,3634.0,1.0
α[9],0.154,0.06,0.035,0.262,0.001,0.001,2736.0,2344.0,1.0


In [25]:
print(α_true,'\n', β_true)

[0.43088793 0.75851713 0.58670421 0.59103367 0.47015016 0.82167002
 0.88334426 0.74470406 0.5737788  0.25437644] 
 [ 1.61137062 -0.20330181 -0.04021609 -0.94167306  1.13501365  1.39153644
 -1.44795524  0.10565118 -0.03147334  1.0931832 ]


In [26]:
α_pred2 = az.summary(trace2, var_names=['α'])['mean']
α_pred3 = az.summary(trace3, var_names=['α'])['mean']
β_pred2 = az.summary(trace2, var_names=['β'])['mean']
β_pred3 = az.summary(trace3, var_names=['β'])['mean']

In [27]:
print(sum((α_pred2-α_true)**2))
print(sum((α_pred3*2-α_true)**2))

print(sum((β_pred2-β_true)**2))
print(sum((β_pred3-β_true)**2))

0.06236179600708265
0.08496627445197395
3.031885433412442
3.296619745887324
