# Load package

In [1]:
%matplotlib inline

# General packages for system, time, etc
import os, time, csv, sys
import datetime
from datetime import date
import glob

# scientific computing and plotting
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# HDDM related packages
import pymc as pm
import hddm
import kabuki
import arviz as az
# Report the versions of the modelling stack this notebook was run with.
for label, module in (("HDDM", hddm), ("kabuki", kabuki), ("PyMC", pm), ("ArviZ", az)):
    print(f"The current {label} version is: ", module.__version__)

# parallel processing related
from p_tqdm import p_map
from functools import partial

from sklearn.metrics import r2_score

The current HDDM version is:  0.9.8RC
The current kabuki version is:  0.6.5RC3
The current PyMC version is:  2.3.8
The current ArviZ version is:  0.14.0


# Load data

In [2]:
# Build the CSV path relative to the directory the notebook is run from.
cpath = os.getcwd()
lname = f"{cpath}/ior_stroop_alldata.csv"

# Read the trials, rescale `rt` by 1/1000 (presumably milliseconds -> seconds;
# confirm against the raw data) and rename `sub_idx` to `subj_idx`.
data = (
    pd.read_csv(lname)
    .assign(rt=lambda trials: trials['rt'] / 1000)
    .rename(columns={"sub_idx": "subj_idx"})
)

# Run model 3

In [3]:
from IOR_stroop_model_utils import m3_id

In [4]:
# Sampler settings handed to m3_id in the fitting cell.
nsample = 10_000  # passed as `samples`
burns = 2_000     # passed as `burn`
thins = 2         # passed as `thin`
chains = 4        # the fit is mapped over range(chains)

In [None]:
%%time

file_names = glob.glob("m3_id_tmp" + "_chain_*[!db]", recursive=False)

if file_names:
    file_names = sorted(file_names, key=lambda x: x[-1]) # sort filenames by chain ID
    m1res = []
    for fname in file_names:
        print('current loading: ', fname, '\n')
        m1res.append(hddm.load(fname))
else:
    m1res = p_map(partial(m3_id, df=data, samples=nsample, burn=burns, thin=thins,save_name="m3_id_tmp"), range(chains))

# 2.Load Model

In [6]:
# Load every saved chain of model 3 from disk into the list `m3`.
# glob returns matches in arbitrary order, so sort them (lexicographically) to keep
# the chain order, and anything built from `m3`, reproducible between runs.
file_names = sorted(glob.glob("m3_id_tmp" + "_chain_*[!db]", recursive=False))
if not file_names:
    # Fail here with a clear message instead of later on an empty model list.
    raise FileNotFoundError(
        "no saved chains matching 'm3_id_tmp_chain_*[!db]' found in " + os.getcwd()
    )
m3 = [hddm.load(f) for f in file_names]

# 3.Model convergence

#### Use the R-hat statistic to check whether the chains have converged: R-hat < 1.01 for every parameter indicates convergence

In [None]:
# Gelman-Rubin (R-hat) statistics computed across the chains loaded in `m3`.
from kabuki.analyze import gelman_rubin

rhat_by_param = gelman_rubin(m3)
rhat_by_param

In [8]:
# Largest R-hat over all entries returned by gelman_rubin; compare it with the
# 1.01 threshold stated in the heading of this section.
# NOTE(review): the recorded output below (about 1.25) is above that threshold.
rhat_values = list(gelman_rubin(m3).values())
np.max(rhat_values)

1.2455632429796069

#### Combine the chains loaded above to get a better approximation of the posterior distribution.

In [9]:
# Merge the per-chain models in `m3` (the same ones used for the convergence
# check) into a single model object `m`, which the cells below plot and summarise.
m = kabuki.utils.concat_models(m3)

### visual trace

In [None]:
# Plot the posteriors of the combined model `m`.
# save=True presumably also writes the figures to disk -- TODO confirm where.
m.plot_posteriors(save=True)

In [10]:
# Display the `dic` attribute of the combined model `m`
# (presumably the deviance information criterion -- confirm in the HDDM docs).
m.dic

-20765.090045300985