In [309]:
import os
import sys
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
import seaborn as sns
from tqdm import tqdm

import import_ipynb
from model import Model, ModelLocBias
from model import fast_optimize

In [310]:
# Working directories, all relative to the current working directory
FIG_FOLDER = 'fig'
SOURCE_FOLDER = os.path.join('data', 'source')
BACKUP_FOLDER = os.path.join('data', 'backup')

# Make sure every working directory exists (no-op when it is already there)
for folder in (SOURCE_FOLDER, FIG_FOLDER, BACKUP_FOLDER):
    os.makedirs(folder, exist_ok=True)

# Show where each folder resolves to on this machine
print(f"The source folder is: {os.path.abspath(SOURCE_FOLDER)}")
print(f"The figure folder is: {os.path.abspath(FIG_FOLDER)}")
print(f"The backup folder is: {os.path.abspath(BACKUP_FOLDER)}")

The source folder is: /Users/aureliennioche/Documents/PythonProjects/ProspecTonk/data/source
The figure folder is: /Users/aureliennioche/Documents/PythonProjects/ProspecTonk/fig
The backup folder is: /Users/aureliennioche/Documents/PythonProjects/ProspecTonk/data/backup


In [311]:
# Run configuration shared by the cells below.
# DATASET: suffix of the behavioural file to read (df_bhv{DATASET}.csv); it is also
#          embedded in the names of the fit-result files.
DATASET = "B"
# MODEL: model class to fit; its `param_labels` name the columns of the fit results
#        and its `__name__` is embedded in the result file names.
MODEL = ModelLocBias
# OPTIMIZE: fitting routine, called as OPTIMIZE(model=MODEL, data=<DataFrame>) and
#           expected to return a (best_param, best_value) pair.
OPTIMIZE = fast_optimize

# Import data

In [312]:
# Read the behavioural data for the selected dataset from the backup folder
bhv_path = os.path.join(BACKUP_FOLDER, f"df_bhv{DATASET}.csv")
df_bhv = pd.read_csv(bhv_path)

# The CSV stores dates as text; convert them to proper datetimes
df_bhv["date"] = pd.to_datetime(df_bhv["date"])
df_bhv

Unnamed: 0,monkey,date,c,p0,x0,p1,x1,time_response,left_X,left_Y,...,is_same_x,is_best_left,is_best_right,pair_id,is_control,is_risky,is_neither_risky_nor_control,is_reversed,choose_risky,choose_best
0,Ola,2020-06-25,0,0.75,2,0.75,3,528,1195.0,131.0,...,False,False,True,0,True,False,False,True,False,False
1,Ola,2020-06-25,0,0.50,3,0.25,3,506,1195.0,131.0,...,True,True,False,1,True,False,False,False,False,True
2,Ola,2020-06-25,0,0.75,2,0.75,-2,394,469.0,131.0,...,False,True,False,2,True,False,False,False,False,True
3,Ola,2020-06-25,0,0.25,-2,0.25,-3,396,469.0,131.0,...,False,True,False,3,True,False,False,False,False,True
4,Ola,2020-06-25,0,0.75,-1,0.75,-3,329,469.0,131.0,...,False,True,False,4,True,False,False,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
137292,Alv,2020-10-25,0,0.25,-2,1.00,-1,1074,1195.0,131.0,...,False,False,False,83,False,True,False,False,True,False
137293,Alv,2020-10-25,0,1.00,-2,1.00,-3,1995,469.0,131.0,...,False,True,False,30,True,False,False,False,False,True
137294,Alv,2020-10-25,1,0.25,3,0.75,2,843,1195.0,131.0,...,False,False,False,9,False,True,False,False,False,False
137295,Alv,2020-10-25,0,1.00,1,0.25,1,703,469.0,131.0,...,True,True,False,17,True,False,False,False,False,True


In [313]:
# Label the right option's screen position as "<X>_<Y>" (coordinates cast to int first)
right_x_str = df_bhv["right_X"].astype(int).astype(str)
right_y_str = df_bhv["right_Y"].astype(int).astype(str)
df_bhv["right_XY"] = right_x_str + '_' + right_y_str
print(df_bhv["right_XY"].unique())

['1195_517' '469_517' '1195_412']


In [314]:
# Label the left option's screen position as "<X>_<Y>" (coordinates cast to int first)
left_x_str = df_bhv["left_X"].astype(int).astype(str)
left_y_str = df_bhv["left_Y"].astype(int).astype(str)
df_bhv["left_XY"] = left_x_str + '_' + left_y_str
print(df_bhv["left_XY"].unique())

['1195_131' '469_131' '469_412']


In [315]:
# Combine both option labels into a single "<left>_vs_<right>" layout identifier
layout_label = df_bhv["left_XY"] + '_vs_' + df_bhv["right_XY"]
df_bhv["position"] = layout_label
print(df_bhv["position"].unique())

['1195_131_vs_1195_517' '469_131_vs_469_517' '469_412_vs_1195_412']


In [316]:
# Non-null count of every column within each screen layout
rows_by_position = df_bhv.groupby(by="position")
rows_by_position.count()

Unnamed: 0_level_0,monkey,date,c,p0,x0,p1,x1,time_response,left_X,left_Y,...,is_best_right,pair_id,is_control,is_risky,is_neither_risky_nor_control,is_reversed,choose_risky,choose_best,right_XY,left_XY
position,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1195_131_vs_1195_517,67224,67224,67224,67224,67224,67224,67224,67224,67224,67224,...,67224,67224,67224,67224,67224,67224,67224,67224,67224,67224
469_131_vs_469_517,67018,67018,67018,67018,67018,67018,67018,67018,67018,67018,...,67018,67018,67018,67018,67018,67018,67018,67018,67018,67018
469_412_vs_1195_412,3055,3055,3055,3055,3055,3055,3055,3055,3055,3055,...,3055,3055,3055,3055,3055,3055,3055,3055,3055,3055


# Fit every position, monkey and condition at once

In [319]:
# Fit MODEL once per (screen layout, monkey, condition) on risky trials only, and
# write one CSV of fitted parameters per screen layout into BACKUP_FOLDER.
#
# Requires `warnings` to be imported in the notebook's import cell.

# Loop-invariant inputs: computed once instead of once per position
cond = "gain", "loss"
monkeys = df_bhv.monkey.unique()

for pos, df_bhv_pos in df_bhv.groupby(by="position"):

    # One single-row frame per fit; concatenated once after the loops instead of
    # re-copying a growing DataFrame on every iteration
    fit_rows = []

    for m in tqdm(monkeys, file=sys.stdout, total=len(monkeys)):

        for cd in cond:

            # Select the data
            df_m = df_bhv_pos[(df_bhv_pos.monkey == m)
                              & (df_bhv_pos.is_risky == True)
                              & (df_bhv_pos[f"is_{cd}"] == True)]

            # Nothing to fit for this monkey/condition at this position
            if df_m.empty:
                continue

            # Optimize (warnings raised while fitting are deliberately silenced)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                best_param, best_value = OPTIMIZE(model=MODEL, data=df_m)

            # Backup
            df_fit_m = pd.DataFrame(best_param.reshape(1, -1), columns=MODEL.param_labels)
            # min/max rather than first/last unique value, so the recorded range
            # does not depend on the row order of the behavioural file
            df_fit_m["date_begin"] = df_m.date.min()
            df_fit_m["date_end"] = df_m.date.max()
            df_fit_m["monkey"] = m
            df_fit_m["condition"] = cd
            df_fit_m["n"] = len(df_m)
            # Negated optimizer value, normalised by the number of trials
            df_fit_m["loss"] = -best_value / len(df_m)
            fit_rows.append(df_fit_m)

    # pd.concat raises on an empty list; fall back to an empty frame so that a
    # position without any fit still produces a (blank) file
    df_fit_overall = pd.concat(fit_rows) if fit_rows else pd.DataFrame()
    df_fit_overall.to_csv(os.path.join(BACKUP_FOLDER, f"df_fit_overall{DATASET}_{MODEL.__name__}_{pos}.csv"))

100%|██████████| 15/15 [00:02<00:00,  6.70it/s]
100%|██████████| 15/15 [00:02<00:00,  7.00it/s]
100%|██████████| 15/15 [00:00<00:00, 307.06it/s]


### Load the results

In [320]:
# Screen layouts for which a fit-result file was written
positions = ['1195_131_vs_1195_517', '469_131_vs_469_517', '469_412_vs_1195_412']

# Load the fit results of the first layout
fit_path = os.path.join(BACKUP_FOLDER, f"df_fit_overall{DATASET}_{MODEL.__name__}_{positions[0]}.csv")
df_fit_overall = pd.read_csv(fit_path)

# Dates come back from the CSV as text; restore them as datetimes
for c in ("date_begin", "date_end"):
    df_fit_overall[c] = pd.to_datetime(df_fit_overall[c])

# Discard the index column(s) that the CSV round-trip adds under an "Unnamed..." header
unnamed_cols = df_fit_overall.filter(regex="Unname").columns
df_fit_overall.drop(columns=unnamed_cols, inplace=True)
df_fit_overall

Unnamed: 0,distortion,precision,risk_aversion,loc_bias,date_begin,date_end,monkey,condition,n,loss
0,3.231039,0.84358,0.685802,-4.177722,2020-06-25,2020-10-25,Ola,gain,410,-0.052239
1,0.25,0.499682,-0.495693,-2.046484,2020-06-25,2020-10-25,Ola,loss,407,-0.099468
2,0.546107,0.422205,0.453638,1.243981,2020-06-25,2020-10-25,Abr,gain,944,-0.222489
3,0.837528,1.481038,-1.758502,2.215694,2020-06-25,2020-10-25,Abr,loss,895,-0.467859
4,1.96946,5.545476,-0.374935,0.421067,2020-06-25,2020-10-25,Nem,gain,402,-0.685007
5,3.470907,5.693589,-0.365398,1.803686,2020-06-25,2020-10-25,Nem,loss,411,-0.663135
6,1.563822,0.528037,-0.088305,0.964425,2020-06-25,2020-10-25,Alv,gain,917,-0.395378
7,0.910996,0.346152,-0.37693,0.37319,2020-06-25,2020-10-25,Alv,loss,933,-0.512958
8,0.25,0.891291,0.75,3.774533,2020-06-25,2020-10-25,Ner,gain,251,-0.080205
9,4.0,1.844709,0.171274,5.534946,2020-06-25,2020-10-25,Ner,loss,219,-0.19458
