In [4]:
import numpy as np
import nlopt
from statsrat import perform_oat
from statsrat.expr.predef import pvl
from statsrat.rw.predef import basic, Kalman, hrmn
from statsrat.lc.predef import discrete_cpl

# --- Configuration shared by every perform_oat call below ---

# Number of simulations passed to perform_oat for each model.
# Per the original note, the trial sequences are fixed, so there is no need to
# average across sequences and n = 1 should suffice -- yet n is set to 2.
# FIXME(review): confirm whether perform_oat requires n >= 2 before lowering this.
n = 2

# Optimizer time budget passed to perform_oat (presumably seconds, as with
# nlopt's set_maxtime -- confirm). 5 was used previously for quick runs.
max_time = 30

# Optimizer handed to perform_oat. Alternatives tried previously:
# nlopt.GN_AGS, nlopt.GN_ORIG_DIRECT, nlopt.GD_STOGO
# (GD_* algorithms are gradient-based in nlopt's naming scheme; only usable
# if the objective supplies gradients -- confirm before switching).
algorithm = nlopt.GN_DIRECT_L

In [2]:
# Run perform_oat (presumably the statsrat "ordinal adequacy test" -- confirm)
# for the basic latent cause model (discrete_cpl) on pvl.latent_inhib.
# NOTE(review): the recorded output of this cell ends in
# "IndexError: list index out of range", apparently raised during
# perform_oat -- TODO investigate.
print('\n basic latent cause model')
oat_result = perform_oat(discrete_cpl, pvl.latent_inhib, n=n, max_time=max_time, algorithm=algorithm)
# Bare last expression so Jupyter renders the rounded result with its rich
# repr instead of plain-text print().
oat_result.round(2)


 basic latent cause model
['control', 'pre_exp']
{}
{'control': <xarray.Dataset>
Dimensions:     (ident: 2, t: 60, u_name: 1, x_name: 2, z_name: 15)
Coordinates:
    trial_name  (t) <U13 'cs -> us' 'cs -> us' ... 'cs -> nothing'
    stage_name  (t) <U8 'training' 'training' 'training' ... 'test' 'test'
    stage       (t) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1 1 1
    t_name      (t) <U8 'bg' 'bg' 'bg' 'bg' ... 'bg' 'bg' 'pre_main' 'main'
    trial       (t) int64 0 0 0 0 0 0 1 1 1 1 1 1 2 ... 8 8 8 8 8 8 9 9 9 9 9 9
  * t           (t) int64 0 1 2 3 4 5 6 7 8 9 ... 50 51 52 53 54 55 56 57 58 59
  * u_name      (u_name) <U2 'us'
  * x_name      (x_name) <U3 'cs' 'ctx'
  * z_name      (z_name) <U2 '0' '1' '2' '3' '4' ... '10' '11' '12' '13' '14'
  * ident       (ident) object 'sim_0' 'sim_1'
Data variables:
    x           (ident, t, x_name) float64 0.0 1.0 0.0 1.0 ... 0.0 1.0 1.0 1.0
    u           (ident, t, u_name) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
    u_

IndexError: list index out of range

In [5]:
# Run perform_oat (presumably the statsrat "ordinal adequacy test" -- confirm)
# for the basic Rescorla-Wagner model (basic) on pvl.latent_inhib.
# NOTE(review): the recorded output of this cell ends in
# "IndexError: list index out of range" -- TODO investigate. It also shows
# 5 simulations (sim_0..sim_4) even though n = 2 in the config cell, so it
# was likely produced by an earlier run with a different n; re-run from a
# fresh kernel.
print('\n basic Rescorla-Wagner')
oat_result = perform_oat(basic, pvl.latent_inhib, n=n, max_time=max_time, algorithm=algorithm)
# Bare last expression so Jupyter renders the rounded result with its rich
# repr instead of plain-text print().
oat_result.round(2)


 basic Rescorla-Wagner
['control', 'pre_exp']
{}
{'control': <xarray.Dataset>
Dimensions:     (f_name: 2, ident: 5, t: 60, u_name: 1, x_name: 2)
Coordinates:
    trial_name  (t) <U13 'cs -> us' 'cs -> us' ... 'cs -> nothing'
    stage_name  (t) <U8 'training' 'training' 'training' ... 'test' 'test'
    stage       (t) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1 1 1
    t_name      (t) <U8 'bg' 'bg' 'bg' 'bg' ... 'bg' 'bg' 'pre_main' 'main'
    trial       (t) int64 0 0 0 0 0 0 1 1 1 1 1 1 2 ... 8 8 8 8 8 8 9 9 9 9 9 9
  * t           (t) int64 0 1 2 3 4 5 6 7 8 9 ... 50 51 52 53 54 55 56 57 58 59
  * u_name      (u_name) <U2 'us'
  * f_name      (f_name) <U3 'cs' 'ctx'
  * x_name      (x_name) <U3 'cs' 'ctx'
  * ident       (ident) object 'sim_0' 'sim_1' 'sim_2' 'sim_3' 'sim_4'
Data variables:
    x           (ident, t, x_name) float64 0.0 1.0 0.0 1.0 ... 0.0 1.0 1.0 1.0
    u           (ident, t, u_name) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
    u_psb       (ident, 

IndexError: list index out of range

In [4]:
# Run perform_oat (presumably the statsrat "ordinal adequacy test" -- confirm)
# for the Kalman filter model (Kalman) on pvl.latent_inhib.
# NOTE(review): the recorded output of this cell ends in
# "SystemError: <class 'dict'> returned a result with an error set" -- TODO
# investigate.
print('\n Kalman filter')
oat_result = perform_oat(Kalman, pvl.latent_inhib, n=n, max_time=max_time, algorithm=algorithm)
# Bare last expression so Jupyter renders the rounded result with its rich
# repr instead of plain-text print().
oat_result.round(2)


 Kalman filter
<statsrat.expr.learn.schedule object at 0x2920c0eef0>
<statsrat.expr.learn.schedule object at 0x2920c0eef0>
<statsrat.expr.learn.schedule object at 0x2920bd3358>
<statsrat.expr.learn.schedule object at 0x2920bd3358>


SystemError: <class 'dict'> returned a result with an error set

In [None]:
# Run perform_oat (presumably the statsrat "ordinal adequacy test" -- confirm)
# for the harmonically decreasing learning-rate model (hrmn) on pvl.latent_inhib.
# NOTE(review): this cell has never been executed (In [None]); its behavior
# is unverified.
print('\n harmonically decreasing learning rate')
oat_result = perform_oat(hrmn, pvl.latent_inhib, n=n, max_time=max_time, algorithm=algorithm)
# Bare last expression so Jupyter renders the rounded result with its rich
# repr instead of plain-text print().
oat_result.round(2)