In [6]:
# Load the autoreload extension AND switch it on: loading it alone leaves automatic reloading disabled.
# %reload_ext also works on a fresh kernel and avoids the "already loaded" notice when this cell is re-run.
%reload_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before executing each cell,
# so edits to locally installed packages (e.g. an hddm checkout) are picked up without a kernel restart.
%autoreload 2

# MODULE IMPORTS ----

# warning settings
import warnings
# Silence FutureWarning only; every other warning category is still shown.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Standard library
import os
import pickle

# Data management
import numpy as np
import pandas as pd

# Plotting
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Stats functionality
from statsmodels.distributions.empirical_distribution import ECDF

# HDDM
import hddm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# Metadata / run configuration
nmcmc = 300       # total MCMC iterations, passed to sample() below
burn = 100        # leading iterations dropped as burn-in, passed as sample(burn=...)
model = 'levy'    # simulator/likelihood name; also the key into hddm.simulators.model_config
seed = 42         # RNG seed so Restart & Run All reproduces the same chains

# Fail fast on an inconsistent configuration rather than after a long sampling run.
assert 0 <= burn < nmcmc, f"burn ({burn}) must be in [0, nmcmc={nmcmc})"

# NOTE(review): assumes the MCMC sampler draws from NumPy's global RNG — TODO confirm for the LAN backend.
np.random.seed(seed)

In [19]:
# Load one of the datasets shipping with HDDM (a CSV from hddm's bundled examples directory)
cav_data = hddm.load_csv(hddm.__path__[0] + '/examples/cavanagh_theta_nn.csv')

# Fail fast if any column the regressor below relies on is missing:
# rt/response are the modelled outcomes, 'stim' feeds depends_on, 'theta' feeds the regression formula.
missing_cols = {'rt', 'response', 'stim', 'theta'} - set(cav_data.columns)
assert not missing_cols, f"cav_data is missing required columns: {sorted(missing_cols)}"

In [20]:
# Inspect the loaded data (pandas truncates the rendered frame to its first and last rows).
cav_data

Unnamed: 0,subj_idx,stim,rt,response,theta,dbs,conf
0,0,LL,1.210,1.0,0.656275,1,HC
1,0,WL,1.630,1.0,-0.327889,1,LC
2,0,WW,1.030,1.0,-0.480285,1,HC
3,0,WL,2.770,1.0,1.927427,1,LC
4,0,WW,1.140,0.0,-0.213236,1,HC
...,...,...,...,...,...,...,...
3983,13,LL,1.450,0.0,-1.237166,0,HC
3984,13,WL,0.711,1.0,-0.377450,0,LC
3985,13,WL,0.784,1.0,-0.694194,0,LC
3986,13,LL,2.350,0.0,-0.546536,0,HC


In [26]:
# One regression descriptor: drift rate v is a linear function of theta (intercept + slope),
# passed through an identity link so the linear predictor is used as-is.
reg_descrs = [
    dict(model='v ~ 1 + theta',
         link_func=lambda x: x),
]

In [27]:
# Parameters this simulator model expects HDDM to include, taken from its config entry.
model_include = hddm.simulators.model_config[model]['hddm_include']

# Non-hierarchical LAN regression model: v follows reg_descrs and additionally gets a
# separate intercept per level of 'stim' via depends_on.
hddmnn_model_cav = hddm.HDDMnnRegressor(
    cav_data,
    reg_descrs,
    model=model,
    include=model_include,
    is_group_model=False,
    informative=False,        # the constructor reports LANs only work with uninformative priors
    depends_on={'v': 'stim'},
    p_outlier=0.05,
)

Setting priors uninformative (LANs only work with uninformative priors for now)
Includes supplied:  ['z', 'alpha']
Reg Model:
{'outcome': 'v', 'model': ' 1 + theta', 'params': ['v_Intercept', 'v_theta'], 'link_func': <function <lambda> at 0x13f4cda70>}
Uses Identity Link


In [28]:
# Sample the posterior into a pickle-backed trace database on disk.
db_path = 'data/test/test_db.db'
# Make sure the target directory exists so writing the database cannot fail on a fresh checkout
# after the (slow) sampling run has already finished.
os.makedirs(os.path.dirname(db_path), exist_ok=True)
# Trailing ';' suppresses the uninformative <pymc.MCMC.MCMC ...> repr in the cell output.
hddmnn_model_cav.sample(nmcmc, burn=burn, dbname=db_path, db='pickle');

 [-----------------100%-----------------] 301 of 300 complete in 63.8 sec

<pymc.MCMC.MCMC at 0x13f4b7bd0>

In [29]:
# Posterior summary table for every node (mean, std, quantiles, MC error), followed by DIC.
hddmnn_model_cav.print_stats()

                       mean         std       2.5q       25q         50q        75q      97.5q       mc err
a                   1.28717   0.0342684    1.22682   1.25991     1.28531    1.31617    1.34779   0.00332518
z                  0.526181  0.00992487   0.506236  0.520062     0.52562   0.532233    0.54732  0.000892705
alpha                1.6231   0.0781073    1.49116   1.55237     1.62809    1.68958    1.75149   0.00757602
t                  0.445397  0.00743025   0.433506  0.439854    0.444862   0.450702   0.459837  0.000685366
v_Intercept(LL)    0.235146   0.0480488   0.150995  0.196873    0.236946    0.27317   0.323212    0.0038841
v_Intercept(WL)    0.807615   0.0428057   0.734275  0.777274    0.806756   0.833351   0.913376   0.00378458
v_Intercept(WW)    0.164434   0.0496131  0.0558535  0.131526     0.16567   0.196162   0.268726   0.00386317
v_theta         -0.00220532   0.0212122 -0.0443581  -0.01588 -0.00312366  0.0109834  0.0419554   0.00142847
DIC: 12125.495212
deviance: 

In [30]:
# Persist the fitted model so the next cell can reload it from disk.
saved_model_path = 'data/test/test.pickle'
hddmnn_model_cav.save(saved_model_path)

In [31]:
# Round-trip check: reload the model that was just saved.
# NOTE(review): the file is a pickle and loading it presumably unpickles it, which can execute
# arbitrary code — only load files this notebook (or another trusted source) produced.
saved_model_path = 'data/test/test.pickle'
model_loaded = hddm.load(saved_model_path)

Reg Model:
{'outcome': 'v', 'model': ' 1 + theta', 'params': ['v_Intercept', 'v_theta'], 'link_func': <function <lambda> at 0x13f266cb0>}
Uses Identity Link


In [32]:
# Traces of the reloaded model as a DataFrame, one column per node.
# NOTE(review): z appears as 'z_trans', presumably on a transformed scale (unlike print_stats' 'z') — confirm.
model_loaded.get_traces()

Unnamed: 0,a,z_trans,alpha,t,v_Intercept(LL),v_Intercept(WL),v_Intercept(WW),v_theta
0,1.265049,0.211672,1.562756,0.440244,0.136568,0.752086,0.134633,-0.020984
1,1.269996,0.199928,1.594422,0.447853,0.166228,0.771398,0.134447,0.023638
2,1.269258,0.212592,1.539039,0.451826,0.252434,0.836718,0.188656,0.031301
3,1.283988,0.201851,1.581850,0.448658,0.233061,0.772118,0.192688,-0.018283
4,1.273850,0.199400,1.593444,0.449779,0.175591,0.772763,0.129692,-0.016577
...,...,...,...,...,...,...,...,...
195,1.220842,0.216124,1.483616,0.453027,0.167301,0.753453,0.202706,-0.010353
196,1.220319,0.193029,1.488009,0.457440,0.192744,0.753681,0.157244,0.046480
197,1.246261,0.164901,1.504067,0.459534,0.235708,0.758586,0.260084,0.030466
198,1.268438,0.139141,1.525533,0.446940,0.235962,0.788010,0.096457,-0.065089


In [None]:
# NOTE(review): disabled scratch code kept for reference. It sketches rebuilding regression
# descriptors by mapping string link-function names ('id', 'vlink') to callables.
# `reg_models`, `id_link`, `v_link_func`, `new_reg_descrs` and `deepcopy` are not defined in the
# visible cells, so this cannot run as-is.
# TODO: delete this cell or move the logic into a documented helper once the design is settled.
#     for reg_model in reg_models:
#         tmp_reg_model = {}
        
#         if reg_model['link_func'] == 'id':
#             tmp_reg_model['link_func'] = id_link
#         elif reg_model['link_func'] == 'vlink':
#             tmp_reg_model['link_func'] = v_link_func
        
#         tmp_reg_model['model'] = reg_model['model']
#         new_reg_descrs.append(deepcopy(tmp_reg_model))
        