In [1]:
import numpy as np
import xarray as xr
import pandas as pd
import nlopt
import seaborn as sns
from statsrat import perform_oat, oat_grid, make_sim_data, learn_plot
from statsrat.expr.predef import kitten
import statsrat.bayes_regr as br
from statsrat.rw.predef import drva
from statsrat.rw.fbase import elem, elem_intercept
from scipy import stats
from plotnine import theme, scale_x_continuous

In [2]:
# Bayesian regression model: linear link over elemental features, with the
# weight-precision (tausq_inv) prior following the ARD derived-attention scheme.
bdrva = br.model(name = 'linear ARD derived attention', fbase = elem, link = br.link.linear, tausq_inv_dist = br.tausq_inv_dist.ard_drv_atn)

# Settings forwarded to every perform_oat call below.
n = 10
max_time = 10
print(bdrva.pars)

# Fixed parameter values for the illustrative simulations in later cells.
# Keyed by name so the resulting list always follows the order of bdrva.pars;
# a KeyError here means the model's parameters changed and this table needs updating.
# NOTE(review): original note said "alpha0 = 2 + 1 = 3; beta0 = -(-0.5) = 0.5", i.e. the
# two prior hyperparameters presumably map to gamma shape/rate — confirm in statsrat.
par_by_name = {
    'prior_tausq_inv_hpar0': -0.5,
    'prior_tausq_inv_hpar1': 2,
    'u_var': 1.5,
    'resp_scale': 5,
}
par_val = [par_by_name[name] for name in bdrva.pars.index]
assert len(par_val) == len(par_by_name), 'par_by_name contains names that are not in bdrva.pars'

# Optimizer used by perform_oat. Other nlopt choices previously tried:
# nlopt.GN_DIRECT_L, nlopt.GN_AGS, nlopt.GD_STOGO.
algorithm = nlopt.GN_ORIG_DIRECT
# Set to False to skip the OAT searches and only run the fixed-parameter simulations.
run_oat = True

                        min   max  default
prior_tausq_inv_hpar0 -10.0   0.0     -2.0
prior_tausq_inv_hpar1   1.0  11.0      3.0
u_var                   0.0  10.0      0.1
resp_scale              0.0  10.0      1.0


In [3]:
print('original derived attention: value effect')

# Simulate the value/salience design with the original (non-Bayesian) derived-attention model.
# NOTE(review): par_val is positional in drva.pars order — check drva.pars for what 0.1, 0.99, 5 mean.
trials = kitten.value_sal.make_trials()
ds = drva.simulate(trials = trials, par_val = [0.1, 0.99, 5])
value_oat = kitten.value_sal.oats['value']
print(value_oat.mean_resp(ds))

# Restrict the plots to the 'value' and 'transfer' stages.
index = ds.stage_name.isin(['value', 'transfer'])
atn_plot = learn_plot(ds, 'atn', {'t': index}, text_size = 12)
atn_plot.save(filename = 'original_value_atn.png')
w_plot = learn_plot(ds, 'w', {'u_name': ['cat1', 'cat2'], 't': index}, text_size = 12)
w_plot.save(filename = 'original_value_w.png')
learn_plot(ds, 'w', {'u_name': ['cat1', 'cat2'], 't': index}, text_size = 12).save(filename = 'original_value_w.png')

original derived attention: value effect
  schedule      trial_name u_name  mean_resp
0   design  a.y -> nothing   cat3   0.984278
1   design  a.y -> nothing   cat4   0.015722
2   design  b.x -> nothing   cat3   0.015722
3   design  b.x -> nothing   cat4   0.984278




In [5]:
print('Bayesian derived attention: learned predictiveness')

# Optional parameter search on the learned-predictiveness design
# (minimize = False, presumably maximising the OAT score); shows the first two result elements.
if run_oat:
    oat_result = perform_oat(bdrva, kitten.lrn_pred, minimize = False, n = n, max_time = max_time, algorithm = algorithm)
    print(oat_result[0])
    print(oat_result[1])
    print()

# One simulation at the fixed par_val from the config cell, scored with the 'rel_irl' OAT.
trials = kitten.lrn_pred.make_trials()
ds = bdrva.simulate(trials = trials, par_val = par_val)
print(kitten.lrn_pred.oats['rel_irl'].mean_resp(ds))

# Plot only the 'relevance' and 'transfer' stages.
index = ds.stage_name.isin(['relevance', 'transfer'])
learn_plot(ds, 'mean_tausq', {'u_name': ['cat1'], 't': index}, text_size = 12).save(filename = 'lrn_pred_tausq.png')
lrn_pred_w_plot = learn_plot(ds, 'mean_w', {'u_name': ['cat1', 'cat2'], 't': index}, text_size = 12)
# Replaces learn_plot's x scale with stage labels (plotnine prints an "Adding another scale" note).
# NOTE(review): break at 40 is hard-coded as the assumed first transfer trial — confirm against kitten.lrn_pred.
lrn_pred_w_plot += scale_x_continuous(name = 'stage', breaks = [0, 40], labels = ['rel', 'transfer'])
lrn_pred_w_plot.save(filename = 'lrn_pred_w.png')

Bayesian derived attention: learned predictiveness
  schedule      trial_name u_name  mean_resp
0   design  a.y -> nothing   cat3   0.569777
1   design  a.y -> nothing   cat4   0.430223
2   design  b.x -> nothing   cat3   0.430001
3   design  b.x -> nothing   cat4   0.569999


Adding another scale for 'x',
which will replace the existing scale.



In [5]:
print('Bayesian derived attention: blocking')

# The blocking/inattention design is scored by two OATs; handle both with one loop
# instead of copy-pasted stanzas (printed output is unchanged).
blk_inatn_oat_names = ['blocking', 'inattention']

if run_oat:
    for oat_name in blk_inatn_oat_names:
        oat_result = perform_oat(bdrva, kitten.blk_inatn, oat = oat_name, minimize = False, n = n, max_time = max_time, algorithm = algorithm)
        print('OAT for ' + oat_name)
        print(oat_result[0])
        print(oat_result[1])
    print()

# One simulation at the fixed par_val from the config cell, scored by each OAT.
trials = kitten.blk_inatn.make_trials()
ds = bdrva.simulate(trials = trials, par_val = par_val)
for oat_name in blk_inatn_oat_names:
    print(oat_name)
    print(kitten.blk_inatn.oats[oat_name].mean_resp(ds))

# Plot only the 'single_cue' and 'double_cue' stages.
index = ds.stage_name.isin(['single_cue', 'double_cue'])
learn_plot(ds, 'mean_tausq', {'u_name': ['cat1'], 't': index}, text_size = 12).save(filename = 'blk_inatn_tausq.png')
learn_plot(ds, 'mean_w', {'u_name': ['cat1', 'cat2'], 't': index}, text_size = 12).save(filename = 'blk_inatn_w.png')

Bayesian derived attention: blocking
blocking
  schedule      trial_name u_name  mean_resp
0   design  e.y -> nothing   cat1   0.793112
1   design  e.y -> nothing   cat2   0.206888
2   design  g.x -> nothing   cat1   0.206886
3   design  g.x -> nothing   cat2   0.793114
inattention
  schedule      trial_name u_name  mean_resp
0   design  a.y -> nothing   cat3   0.631141
1   design  a.y -> nothing   cat4   0.368859
2   design  b.x -> nothing   cat3   0.368338
3   design  b.x -> nothing   cat4   0.631662




In [6]:
print('Bayesian derived attention: value effect')

value_design = kitten.value_sal

# Optional parameter search on the value/salience design; shows the first two result elements.
if run_oat:
    oat_result = perform_oat(bdrva, value_design, minimize = False, n = n, max_time = max_time, algorithm = algorithm)
    for k in (0, 1):
        print(oat_result[k])
    print()

# One simulation at the fixed par_val from the config cell, scored with the 'value' OAT.
trials = value_design.make_trials()
ds = bdrva.simulate(trials = trials, par_val = par_val)
print(value_design.oats['value'].mean_resp(ds))

# Plot only the 'value' and 'transfer' stages.
index = ds.stage_name.isin(['value', 'transfer'])
tausq_plot = learn_plot(ds, 'mean_tausq', {'u_name': ['cat1'], 't': index}, text_size = 12)
tausq_plot.save(filename = 'value_sal_tausq.png')
w_plot = learn_plot(ds, 'mean_w', {'u_name': ['cat1', 'cat2'], 't': index}, text_size = 12)
w_plot.save(filename = 'value_sal_w.png')

Bayesian derived attention: value effect
  schedule      trial_name u_name  mean_resp
0   design  a.y -> nothing   cat3   0.575667
1   design  a.y -> nothing   cat4   0.424333
2   design  b.x -> nothing   cat3   0.427451
3   design  b.x -> nothing   cat4   0.572549




In [7]:
print('Bayesian derived attention: retrospective revaluation')

bkwd_design = kitten.bkwd_blk

# Optional parameter search on the kitten.bkwd_blk design; shows the first two result elements.
if run_oat:
    oat_result = perform_oat(bkwd_design is None or bdrva, bkwd_design, minimize = False, n = n, max_time = max_time, algorithm = algorithm) if False else perform_oat(bdrva, bkwd_design, minimize = False, n = n, max_time = max_time, algorithm = algorithm)
    for k in (0, 1):
        print(oat_result[k])
    print()

# One simulation at the fixed par_val from the config cell, scored with the 'blocking' OAT.
trials = bkwd_design.make_trials()
ds = bdrva.simulate(trials = trials, par_val = par_val)
print(bkwd_design.oats['blocking'].mean_resp(ds))

# Plot only the 'single_cue' and 'double_cue' stages.
index = ds.stage_name.isin(['single_cue', 'double_cue'])
tausq_plot = learn_plot(ds, 'mean_tausq', {'u_name': ['cat1'], 't': index}, text_size = 12)
tausq_plot.save(filename = 'bkwd_blk_tausq.png')
# NOTE(review): unlike the sibling cells, this weight plot has no u_name filter — confirm intended.
w_plot = learn_plot(ds, 'mean_w', {'t': index}, text_size = 12)
w_plot.save(filename = 'bkwd_blk_w.png')

Bayesian derived attention: retrospective revaluation
  schedule      trial_name u_name  mean_resp
0   design  e.y -> nothing   cat1   0.792506
1   design  e.y -> nothing   cat2   0.207494
2   design  g.x -> nothing   cat1   0.206906
3   design  g.x -> nothing   cat2   0.793094


