In [None]:
import numpy as np
import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('ticks')
sns.set_context('paper')

import glambox as gb

%load_ext autoreload
%autoreload 2

In [None]:
## Load the individual parameter estimates from Thomas, Molter et al. (2019)
## and keep only the rows of the Krajbich & Rangel (2011) dataset.
estimates = pd.read_csv('resources/individual_estimates_sec_nhb2019.csv')
is_kr2011 = estimates['dataset'] == 'krajbich2011'
kr2011 = estimates.loc[is_kr2011]
kr2011.head()

In [None]:
## Simulate a clinical dataset, with small groups and few trials.
## Each individual's parameter set is sampled from the Krajbich & Rangel (2011) estimates, so that they are realistic,
## except for the gaze bias parameter gamma, which we explicitly vary between the groups.

np.random.seed(1751)

groups = ['group1', 'group2', 'group3']

# Sample sizes
N = dict(group1=5,
         group2=10,
         group3=15)

# Mean gaze bias parameters for each group
gamma_mu = dict(group1=0.7,     # weak bias
                group2=0.1,     # medium-strong bias
                group3=-0.5)    # very strong bias with leak

# Standard deviation of the gaze bias parameters within each group
# NOTE(review): group2 uses 0.2 while the other groups use 0.3 -- confirm this is intended.
gamma_sd = dict(group1=0.3,
                group2=0.2,
                group3=0.3)

# Sample parameter sets from KR2011 estimates.
# Rows are drawn without replacement within a group; each group is drawn separately.
group_idx = {group: np.random.choice(kr2011.index, size=N[group], replace=False)
             for group in groups}


def _sampled_values(column):
    """Return a dict mapping each group to the sampled KR2011 values of `column`."""
    return {group: kr2011.loc[group_idx[group], column].values
            for group in groups}


v = _sampled_values('v')
s = _sampled_values('s')
tau = _sampled_values('tau')

# Draw normally distributed gaze bias parameters, one draw per individual.
# Values above 1 are clipped to 1 (np.clip with a_max=1), not re-drawn.
# Groups are processed in the order of `groups`, so the random stream is
# consumed in the order group1, group2, group3.
gamma = {group: np.clip(np.random.normal(loc=gamma_mu[group],
                                         scale=gamma_sd[group],
                                         size=N[group]),
                        a_min=None, a_max=1)
         for group in groups}

# Set the number of trials and items in the task
n_trials = 50
n_items  = 3

# Simulate the data using GLAM.
# simulate_group is called once per group on the same GLAM instance; the group's
# name is passed as the label and its position in `groups` as the seed.
glam = gb.GLAM()
for group_seed, group in enumerate(groups):
    group_parameters = dict(v=v[group],
                            gamma=gamma[group],
                            s=s[group],
                            tau=tau[group],
                            t0=np.zeros(N[group]))
    glam.simulate_group(kind='individual',
                        n_individuals=N[group],
                        n_trials=n_trials,
                        n_items=n_items,
                        parameters=group_parameters,
                        label=group,
                        seed=group_seed)

# Work on a renamed copy so that glam.data itself is left unchanged
data = glam.data.rename(columns={'condition': 'group'})
data.to_csv('examples/example_2/data/data.csv', index=False)
data.head(3)

In [None]:
## Plot data generating parameters
from string import ascii_uppercase  # not used in this cell; kept in case other cells rely on the name

fontsize = 7
n_bins = 10

# One column per parameter; row 0 pools all groups, rows 1-3 show one group each.
fig, axs = plt.subplots(4, 4, figsize=gb.plots.cm2inch(18, 9), sharex='col', dpi=330)

for p, (parameter, parameter_name, bins) in enumerate(
    zip([v, s, gamma, tau],
        [r'$v$', r'$\sigma$', r'$\gamma$', r'$\tau$'],
        [np.linspace(0, 2, n_bins + 1),
         np.linspace(0, 0.5, n_bins + 1),
         np.linspace(-2, 1, n_bins + 1),
         np.linspace(0, 2, n_bins + 1)])):
    # Pooled
    axs[0, p].hist(np.concatenate([parameter[group] for group in groups]),
                   bins=bins,
                   color='black',
                   alpha=0.5)
    axs[0, p].set_ylim(0, 20)
    axs[0, p].set_yticks([0, 10, 20])
    axs[0, p].set_title(parameter_name, fontsize=fontsize, fontweight='bold')
    if p == 0:
        axs[0, p].set_ylabel(r'$\bf{Pooled}$' + '\n\nFreq.', fontsize=fontsize)

    # Groups
    for g, group in enumerate(groups):
        axs[g + 1, p].hist(parameter[group],
                           bins=bins,
                           color='C{}'.format(g),
                           alpha=0.7)
        # Axis limits cover the same range as the ticks set on the next line
        axs[g + 1, p].set_ylim(0, 10)
        axs[g + 1, p].set_yticks([0, 5, 10])
        if p == 0:
            axs[g + 1, p].set_ylabel(r'$\bf{Group}$ ' + r'$\bf{{{g}}}$'.format(g=(g + 1)) + '\n\nFreq.',
                                     fontsize=fontsize)

    # Only the bottom row gets an x-axis label
    axs[-1, p].set_xlabel(parameter_name, fontsize=fontsize)

for ax in axs.ravel():
    # Show x ticklabels on every panel, despite sharex='col'
    ax.xaxis.set_tick_params(labelbottom=True)
    ax.yaxis.set_tick_params(labelleft=True)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)

sns.despine()

fig.tight_layout(pad=1.02)

plt.savefig('examples/example_2/figures/generating_parameters.png', dpi=330)

In [None]:
## Check if the generating parameters, due to sampling, differ between groups, other than the gaze bias.
## Runs an independent-samples t-test for every parameter and every pair of groups.
from scipy.stats import ttest_ind

parameters = ['v', 'gamma', 's', 'tau']
comparisons = [('group1', 'group2'),
               ('group1', 'group3'),
               ('group2', 'group3')]

# Pair each parameter name with its per-group generating values
generating_values = dict(zip(parameters, [v, gamma, s, tau]))

for parameter, values in generating_values.items():
    print(parameter)
    for g1, g2 in comparisons:
        result = ttest_ind(values[g1], values[g2])
        print('{} vs {}\t'.format(g1, g2), result)

In [None]:
# Apply glambox's group-level aggregation to the trials of each group separately
trials_by_group = data.groupby('group')
trials_by_group.apply(gb.analysis.aggregate_group_level_data, n_items=n_items)

In [None]:
## Plot aggregate data across groups and for each group separately, to visualize behavioural differences between groups.
## Bars show the full dataset; one line is drawn per group.
gb.plots.plot_aggregate(bar_data=data,
                        line_data=[data.loc[data['group'] == group] for group in groups],
                        line_labels=['Group 1', 'Group 2', 'Group 3'],
                        value_bins=7, gaze_bins=7,
                        limits=dict(rt=(0, 5)));
sns.despine()
# The resolution keyword of savefig is `dpi`
plt.savefig('examples/example_2/figures/aggregate_data.png', dpi=330)

In [None]:
# Build a hierarchical GLAM on the simulated data in which each of
# v, gamma, s and tau depends on the 'group' column.
group_dependent_parameters = ['v', 'gamma', 's', 'tau']

hglam = gb.GLAM(data=data)
hglam.make_model(kind='hierarchical',
                 depends_on={parameter: 'group'
                             for parameter in group_dependent_parameters})

In [None]:
# Use the same seed for numpy's global RNG and for the sampler
mcmc_seed = 1958
np.random.seed(mcmc_seed)

# NOTE(review): tune is nine times the number of draws -- confirm this ratio is intended.
hglam.fit(method='MCMC',
          draws=10000,
          tune=90000,
          random_seed=mcmc_seed)

In [None]:
# Write the fitted model's estimates table to disk (without the index column)
estimates_path = 'examples/example_2/results/estimates.csv'
hglam.estimates.to_csv(estimates_path, index=False)

In [None]:
# Select the trace variables to summarise: skip names ending in '__'
# and the variables 'b', 'p_error' and 't0'.
excluded = ['b', 'p_error', 't0']
variables = [name for name in hglam.trace[0].varnames
             if not (name.endswith('__') or name in excluded)]

pm.summary(hglam.trace[0], var_names=variables).head()

In [None]:
# Restrict the plot to variables whose names end in '_mu' or '_sd'
variables = [name for name in variables
             if name.endswith('_mu')
             or name.endswith('_sd')]

# Reference values for the '<parameter>_<group>_mu' variables:
# the mean of the generating parameter values in each group.
generating = dict(v=v, gamma=gamma, s=s, tau=tau)
ref_val = {'{}_{}_mu'.format(parameter, group): values[group].mean()
           for parameter, values in generating.items()
           for group in groups}

gb.plots.traceplot(hglam.trace[0],
                   varnames=variables,
                   ref_val=ref_val)
# Remove the spines before saving, so that the saved figure is despined too
sns.despine()
plt.savefig('examples/example_2/figures/traceplot.png', dpi=330)

In [None]:
# Parameters to show and the pairs of groups to compare for each of them
parameters = ['v', 'gamma', 's', 'tau']
comparisons = [('group1', 'group2'),
               ('group1', 'group3'),
               ('group2', 'group3')]

node_plot_options = dict(parameters=parameters,
                         comparisons=comparisons)
fig, axs = gb.plots.plot_node_hierarchical(hglam, **node_plot_options)

plt.savefig('examples/example_2/figures/node_comparison.png', dpi=330)