In [1]:
from utils.data import load_data
from stimuli.motion_structure import MotionStructure
from analysis.kalman_filter import apply_filters_on_trial
from analysis.model import fit
import numpy as np
import pandas as pd
import pickle
structures={
    'IND': MotionStructure(0, 2),
    'GLO': MotionStructure(1, 1/4),
    'CLU': MotionStructure(0, 1/4),
    'SDH': MotionStructure(0.77, 1/4)
}
s = ['IND', 'GLO', 'CLU_012', 'CLU_120', 'CLU_201', 'SDH_012', 'SDH_120', 'SDH_201']
df_sichao = pd.read_csv(f'../data/pilot_sichao_glo=0.77_v=1.5_0.csv')[['ground_truth', 'Sichao_choice']]
df_johannes = pd.read_csv(f'../data/pilot_johannes_glo=0.77_v=1.5_0.csv')[['ground_truth', 'Johannes_choice']]
data_sichao = load_data('../data/sichao_0729.dat')
data_johannes = load_data('../data/johannes_0729.dat')
x, t = [], []
for trial in data_sichao:
    x.append(trial['φ'][3:])
    t.append(trial['t'][3:])
x_sichao, t_sichao = np.array(x), np.array(t)
x, t = [], []
for trial in data_johannes:
    x.append(trial['φ'][3:])
    t.append(trial['t'][3:])
x_johannes, t_johannes = np.array(x), np.array(t)
rng = np.random.RandomState()
dx = rng.normal(size=x_sichao.shape)

In [2]:
def loss(x, t, σ_R, df_target):
    ℓ = []
    for ii in range(x.shape[0]):
        ℓ.append(np.array(apply_filters_on_trial(x[ii], t[ii], σ_R, structures=structures)))
    df = pd.DataFrame(ℓ, columns=s)
    df['target'] = df_target
    mask = np.column_stack([df['target'] == structure for structure in ['IND', 'GLO'] + ['CLU'] * 3 + ['SDH'] * 3])
    res = fit(df, mask, 'target', np.array([0., 0., 0.]), disp=False)
    # print(res)
    return res.fun

In [3]:
for σ_R in np.arange(0, 5, 0.5):
    print(loss(x_sichao, t_sichao, σ_R, df_sichao['Sichao_choice']))

186.75397313497197
149.09450946010816
139.0977458396738
135.61225615180717
136.2264677977992
138.38059472842986


KeyboardInterrupt: 

In [4]:
for σ_R in np.arange(1.6, 2.4, 0.1):
    print(loss(x_sichao, t_sichao, σ_R, df_sichao['Sichao_choice']))

135.53256839890525
135.57583239777424
135.7170711121753
135.93818787408367
136.22646780132217
136.57238130248376
136.96786716835058


KeyboardInterrupt: 

In [5]:
for σ_R in np.arange(0, 5, 0.5):
    print(loss(x_johannes, t_johannes, σ_R, df_johannes['Johannes_choice']))

203.11549293210172
157.29692151457888
148.10431689524512
148.19266644658052
150.2803423849211
153.53773999161604
157.6460378788276
161.45051530525507
164.76032344431994
167.83193690549842
