In [50]:
from BanditAnalysisFunctions import *
import model_policies as models
import resample_and_model_reps as reps
import cprobs_functions as cprobs
from sklearn.model_selection import train_test_split
import plot_models_v_mouse as bp
from scipy.stats import sem
pd.options.display.max_columns = 50
%matplotlib qt

In [4]:
gronckle = pd.read_csv('sessdf1.csv')
toothless = pd.read_csv('sessdf2.csv')
brat = pd.read_csv('sessdf3.csv')
grump = pd.read_csv('sessdf4.csv')
aggro = pd.read_csv('aggro.csv')
auroma = pd.read_csv('auroma.csv')
buffalord = pd.read_csv('buffalord.csv')
bewilderbeast = pd.read_csv('bewilderbeast.csv')

In [44]:
def convert_to_bool(str):
    '''
    
    input - str of length n
    output - bool (1/0) number representing the upper/lowercase pattern of input str
    '''
    bool_string = ''
    for e in str:
        if e.isupper() == True:
            bool_string += '1'
        else: 
            bool_string += '0'

    return bool_string

In [17]:
def plot_scatter(df_mouse, df_model):
    
    sns.set(style='ticks', font_scale=1.6, rc={'axes.labelsize':18, 'axes.titlesize':18})   
    sns.set_palette('dark')
    
    plt.figure(figsize=(4,4))
    plt.subplot(111, aspect='equal')
    plt.scatter(df_mouse.pswitch, df_model.pswitch, alpha=0.6, edgecolor=None, linewidth=0)
    plt.plot([0, 1], [0, 1], ':k')
    
    plt.xlabel('P(switch)$_{rat}$')
    plt.ylabel('P(switch)')
    plt.xticks(np.arange(0, 1.1, 0.5))
    plt.yticks(np.arange(0, 1.1, 0.5))
    
    plt.tight_layout()
    sns.despine()

In [57]:
def sort_cprobs(conditional_probs, sorted_histories):
    
    '''
    sort conditional probs by reference order for history sequences to use for plotting/comparison
    
    INPUTS:
        - conditional_probs (pandas DataFrame): from calc_conditional_probs
        - sorted_histories (list): ordered history sequences from reference conditional_probs dataframe
    OUTPUTS:
        - (pandas DataFrame): conditional_probs sorted by reference history order
    '''
    
    from pandas.api.types import CategoricalDtype
    
    cat_history_order = CategoricalDtype(sorted_histories, ordered=True) # make reference history ordinal
    
    conditional_probs['history'] = conditional_probs['history'].astype(cat_history_order) # apply reference ordinal values to new df
    
    return conditional_probs.sort_values('history') # sort by reference ordinal values for history

In [63]:
def group_by_reward(df_mouse_symm, symm_cprobs_model, cprobs_str, cprobs_ustr, cprobs_array_str, cprobs_array_ustr,  sort_order, comb_list = ['111', '110', '101', '100', '000', '001', '010', '011']):
    '''
    '''
    sns.set_palette('deep')

    cprobs_str_df = pd.DataFrame(data = {'history':sort_order, 'pswitch': cprobs_str})
    cprobs_ustr_df = pd.DataFrame(data = {'history':sort_order, 'pswitch': cprobs_ustr})     

    fig, ax = plt.subplots(2, 4, layout = 'tight', figsize = (15, 5))
    fig.suptitle('model vs rat')

    cprobs_str_df_bool = pd.DataFrame(data = {'history_bool': cprobs_str_df['history'].map(lambda x: convert_to_bool(x)), 'history':cprobs_str_df['history'] , 'pswitch':cprobs_str_df['pswitch']})     
    cprobs_ustr_df_bool = pd.DataFrame(data = {'history_bool': cprobs_ustr_df['history'].map(lambda x: convert_to_bool(x)),'history': cprobs_ustr_df['history'] ,  'pswitch':cprobs_ustr_df['pswitch']})  
    loc = [(0, 0), (0, 1), (0, 2), (0,3), (1, 0), (1, 1), (1, 2), (1,3)] #list of (row, column) tuples
    i = 0
    
    for comb in comb_list:
        row, column = loc[i][0], loc[i][1]
        comb_str = cprobs_str_df_bool[cprobs_str_df_bool['history_bool']==comb].index.tolist()
        comb_ustr = cprobs_str_df_bool[cprobs_ustr_df_bool['history_bool']==comb].index.tolist()
        # print(cprobs_str_list[1][comb_str])

        err_str = df_mouse_symm.pswitch_err[comb_str]
        err_ustr = symm_cprobs_model.pswitch_err[comb_str]

        df_str = sort_cprobs(cprobs_str_df_bool, ['AAA', 'ABB', 'ABA', 'AAB', 'AAa', 'ABb', 'ABa', 'AAb', 'AaA', 'AbB', 'AbA', 'AaB', 'Aaa', 'Abb', 'Aba', 'Aab', 'aaa', 'abb', 'aba', 'aab', 'aaA', 'abB', 'abA', 'aaB', 'aAa', 'aBb', 'aBa', 'aAb', 'aAA', 'aBB', 'aBA', 'aAB'])
        df_ustr = sort_cprobs(cprobs_ustr_df_bool, ['AAA', 'ABB', 'ABA', 'AAB', 'AAa', 'ABb', 'ABa', 'AAb', 'AaA', 'AbB', 'AbA', 'AaB', 'Aaa', 'Abb', 'Aba', 'Aab', 'aaa', 'abb', 'aba', 'aab', 'aaA', 'abB', 'abA', 'aaB', 'aAa', 'aBb', 'aBa', 'aAb', 'aAA', 'aBB', 'aBA', 'aAB']) 

        ax[row][column].bar(x = (df_str[(df_str['history_bool'])== comb]).history, height = df_str[(df_str['history_bool'])==comb].pswitch, color = 'k', alpha = 0.4, yerr = err_str)
        ax[row][column].bar(x = (df_ustr[(df_ustr['history_bool'])== comb]).history, height = df_ustr[(df_ustr['history_bool'])==comb].pswitch, color = sns.color_palette()[0], alpha = 0.8, yerr = err_ustr)
        ax[row][column].set_ylim(bottom = 0)
        ax[row][column].title.set_text(comb_list[i])
        i+=1
    
        



In [5]:
def add_cols(df, rat):
    '''
    modify df to make it compatible with celiaberon logreg code
    add columns: blockTrial, blockLength, Target 
    modify column names
    input: df, rat(str, eg. R1)
    output: modified df
    '''                                                                                                                                                                                                                                 
    df = df.rename(columns = {'trial#':'Trial', 'port':'Decision', 'reward': 'Reward', 'session#':'Session'})
    df['Decision'] = df['Decision'].replace({1:0, 2:1})
    switches = abs(df['Decision'].shift(-1, ) - df['Decision'])
    switches.iloc[-1] = 0
    df.insert(4, 'Switch', switches )
    #### insert blockTrials:
    sessions = pd.unique(df['Session'])
    df.insert(2, 'blockTrial', np.zeros(len(df)))
    
    for i, session in enumerate(sessions):
        if i > 0:
            blockTrial = df[df['Session'] == session]['Trial'].sub((np.ones(len(df[df['Session'] == session]))*max(df[df['Session'] == sessions[i-1]]['Trial'])))
        elif i == 0:
            blockTrial = df[df['Session']==session]['Trial']
        df.loc[df['Session'] == session, ['blockTrial']] = blockTrial
    
    #### insert blockLength
    df.insert(3, 'blockLength', np.zeros(len(df)))
    
    for session in sessions:
        blockLength = np.ones(len(df[df['Session']==session]))*max(df[df['Session']==session]['blockTrial'])
        df.loc[df['Session']==session, ['blockLength']] = blockLength
    
    #### insert Target
    df.insert(4, 'Target', np.zeros(len(df)))
    for session in sessions:
        if np.mean(df[df['Session']== session]['rewprobfull1']) > np.mean(df[df['Session']== session]['rewprobfull2']):
            target = 1
        else:
            target = 2
        df.loc[df['Session']==session, ['Target']] = target
    df['Session'] = df.Session.map(lambda x: rat + '_' + f'{x}')
    df['Rat'] = rat
    return df
##target


In [7]:
sessdf_all = pd.concat([add_cols(gronckle, 'R1'), add_cols(toothless, 'R2'), add_cols(brat, 'R3'), add_cols(grump, 'R4'), add_cols(buffalord, 'R5'), add_cols(bewilderbeast, 'R6'), add_cols(auroma, 'R7'), add_cols(aggro, 'R8') ])
sessdf_str = pd.concat([add_cols(gronckle, 'R1'), add_cols(toothless, 'R2'), add_cols(buffalord, 'R5'), add_cols(bewilderbeast, 'R6')])
sessdf_ustr = pd.concat([add_cols(brat, 'R3'), add_cols(grump, 'R4'), add_cols(auroma, 'R7'), add_cols(aggro, 'R8') ])

In [93]:
sessdf_all

Unnamed: 0.1,Unnamed: 0,Trial,blockTrial,blockLength,Target,trialstart,Decision,Switch,Reward,trialend,Session,eptime,task,rewprobfull1,rewprobfull2,rw,animal,datetime,Rat
0,0,0,0.0,302.0,2.0,8226,1.0,0.0,0,8726,R1_1,1695837108,12,30,70,70,Gronckle,2023-09-27 17:51:48,R1
1,1,1,1.0,302.0,2.0,11117,1.0,0.0,1,11617,R1_1,1695837111,12,30,70,70,Gronckle,2023-09-27 17:51:51,R1
2,2,2,2.0,302.0,2.0,24965,1.0,0.0,1,25465,R1_1,1695837124,12,30,70,70,Gronckle,2023-09-27 17:52:04,R1
3,3,3,3.0,302.0,2.0,30058,1.0,0.0,0,30558,R1_1,1695837130,12,30,70,70,Gronckle,2023-09-27 17:52:10,R1
4,4,4,4.0,302.0,2.0,31770,1.0,0.0,1,32270,R1_1,1695837131,12,30,70,70,Gronckle,2023-09-27 17:52:11,R1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76104,76104,76104,269.0,273.0,2.0,1162679,1.0,0.0,1,1163022,R4_167,1702264773,13,80,90,90,Grump,2023-12-11 03:19:33,R4
76105,76105,76105,270.0,273.0,2.0,1164634,1.0,0.0,1,1164880,R4_167,1702264775,13,80,90,90,Grump,2023-12-11 03:19:35,R4
76106,76106,76106,271.0,273.0,2.0,1166300,1.0,0.0,1,1166421,R4_167,1702264777,13,80,90,90,Grump,2023-12-11 03:19:37,R4
76107,76107,76107,272.0,273.0,2.0,1167738,1.0,0.0,1,1167836,R4_167,1702264779,13,80,90,90,Grump,2023-12-11 03:19:39,R4


In [12]:
data = sessdf_all

# data = pd.read_csv(os.path.join('mouse_data.csv'))

data = data.groupby('Session').filter(lambda x: len(x['Trial'])>50) # make sure there are no empty sessions
seq_nback = 3
train_prop = 0.7
seed = np.random.randint(1000)


data = cprobs.add_history_cols(data, seq_nback) # set history labels up front

train_session_ids, test_session_ids = train_test_split(data.Session.unique(), 
                                                       train_size=train_prop, random_state=seed) # split full df for train/test

data['block_pos_rev'] = data['blockTrial'] - data['blockLength'] # reverse block position from transition
data['model']='rat'
data['highPort'] = data['Decision']==data['Target'] # boolean, chose higher probability port

train_features, _, _ = reps.pull_sample_dataset(train_session_ids, data)
test_features, _, block_pos_core = reps.pull_sample_dataset(test_session_ids, data)
bpos_mouse = bp.get_block_position_summaries(block_pos_core)
bpos_mouse['condition'] = 'rat'

cprobs

In [13]:
df_mouse_symm_reference = cprobs.calc_conditional_probs(data, symm=True, action=['Switch']).sort_values('pswitch')
df_mouse_symm = cprobs.calc_conditional_probs(block_pos_core, symm=True, action=['Switch', 'Decision'])
df_mouse_symm = cprobs.sort_cprobs(df_mouse_symm, df_mouse_symm_reference.history.values)
bp.plot_sequences(df_mouse_symm, alpha=0.5) 

In [14]:
L1 = 1 # choice history
L2 =  5 # choice * reward history
L3 = 0
memories = [L1, L3, L2, 1]
# print(memories)
lr = models.fit_logreg_policy(train_features, memories) # refit model with reduced histories, training set
model_probs = models.compute_logreg_probs(test_features, lr_args=[lr, memories])
# model_probs = models.fit_logreg_policy(test_features, memories )

In [34]:
pd.options.display.max_rows = 1000
model_choices, model_switches = models.model_to_policy(model_probs, test_features, policy='softmax', e = 0.2, temp = 1.25)
block_pos_model = reps.reconstruct_block_pos(block_pos_core, model_choices, model_switches)
bpos_model = bp.get_block_position_summaries(block_pos_model)
# print(bpos_model)
bpos_model['condition'] = 'model' # label model predictions as such
condition_series = pd.DataFrame({'condition':['model']*len(bpos_model)}, index = bpos_model.index)
# print(condition_series)
bpos_model = pd.concat([bpos_model, condition_series])
bpos_model_v_mouse = pd.concat((bpos_mouse, bpos_model)) # agg df with model predictions and mouse data

# print(bpos_model_v_mouse.head(50))

color_dict = {'rat': 'gray', 'model': sns.color_palette()[0]}#plot_config['model_seq_col']}
bp.plot_by_block_position(bpos_model_v_mouse, subset='condition', color_dict = color_dict)

symm_cprobs_model = cprobs.calc_conditional_probs(block_pos_model, symm=True, action=['Switch'])
symm_cprobs_model = cprobs.sort_cprobs(symm_cprobs_model, df_mouse_symm.history.values)
bp.plot_sequences(df_mouse_symm, overlay=symm_cprobs_model, main_label='rat', overlay_label='model')

bp.plot_scatter(df_mouse_symm, symm_cprobs_model)

In [99]:
raw_cm, norm_cm = bp.calc_confusion_matrix(df_mouse_symm, 'pswitch' ,df_model = symm_cprobs_model  )
plt.figure()
sns.heatmap(norm_cm, annot = True, fmt = 'g')
plt.title('')


Text(0.5, 1.0, '')

In [101]:
from sklearn.metrics import RocCurveDisplay
RocCurveDisplay.from_predictions(block_pos_core.Switch, block_pos_model.Switch)

<sklearn.metrics._plot.roc_curve.RocCurveDisplay at 0x23b5ac6b290>

In [37]:
# norm_cm = np.array([[0.95, 0.05], [0.76, 0.24]])

In [25]:
cm = norm_cm
recall_per_class = np.diag(cm)/np.sum(cm)
balanced_accuracy = np.mean(recall_per_class)
balanced_accuracy

0.2634791028514596

In [26]:
cm = norm_cm
precision = cm[0][0]/(cm[0][0]+cm[1][0])
recall = cm[0][0]/(cm[0][0]+cm[0][1])
f1 = 2*(precision*recall/(precision+recall))

f1

0.6721864315029584

In [84]:
coefs_grump = lr.coef_

In [86]:
coefs_list = [coefs_gronckle, coefs_toothless, coefs_brat, coefs_grump]

In [58]:
fig, ax = plt.subplots()
colors = ['g', 'g', 'r', 'r']
c = 0
for coefs in coefs_list:
    print(c)
    ax.plot(np.arange(1, 7, 1), coefs[0][1:], color = colors[c])
    c+=1
ax.plot(np.arange(1, 7, 1), np.zeros(6), linestyle = '--', color = 'k', alpha = 0.5)
sns.despine()

NameError: name 'coefs_list' is not defined

In [19]:
sort_order = pd.read_csv('cprobs sort order.csv')
sort_order = list(sort_order['0'].values)

In [20]:
def group_choices_v_model(cprobs_str, cprobs_model, cprobs_str_err, cprobs_model_err, sort_order, comb_list = ['aaa', 'aab', 'aba', 'abb']):
    '''
    group conditinal probabilities by choice history
    cprobs_str, cprobs_model : averaged dataframes for str and unstr environment. Insert empty dataframe to exclude.
    comb_list: list of choice history combinaitons
    sort_order: sorting for history combinations, list
    '''
    
    
    sns.set_palette('deep')
    
    cprobs_str_list = cprobs_str
    cprobs_ustr_list = cprobs_model
    
    cprobs_str_df = pd.DataFrame(data = {'history':sort_order, 'pswitch': cprobs_str})
    cprobs_ustr_df = pd.DataFrame(data = {'history':sort_order, 'pswitch': cprobs_model})
    
    cprobs_str_lower = cprobs_str_df['history'].map(lambda x : x.lower())
    cprobs_ustr_lower = cprobs_ustr_df['history'].map(lambda x : x.lower())


    fig, ax = plt.subplots(2, 2, figsize = (12, 6), layout = 'tight')
    fig.suptitle('Model vs Rat')
    loc = [(0, 0), (0, 1), (1, 0), (1,1)] #list of (row, column) tuples
    i = 0
    for comb in comb_list:
        row, column = loc[i][0], loc[i][1] 
        #store indices of each combination:
        comb_str = cprobs_str_lower[cprobs_str_lower== comb].index.tolist()
        comb_ustr = cprobs_ustr_lower[cprobs_ustr_lower== comb].index.tolist()
        #calculate standard error
        err_str = cprobs_str_err.iloc[comb_str]
        err_ustr = cprobs_model_err.iloc[comb_ustr]
        #plot
        df_str = cprobs_str_df
        df_ustr = cprobs_ustr_df

        
        ax[row][column].bar(x =df_str[(df_str['history'].map(lambda x : x.lower())) == comb].history , height = df_str[(df_str['history'].map(lambda x : x.lower())) == comb].pswitch, color = sns.color_palette()[0], alpha = 1, yerr = err_str)
        ax[row][column].bar(x = df_ustr[(df_ustr['history'].map(lambda x : x.lower())) == comb].history, height = df_ustr[(df_ustr['history'].map(lambda x : x.lower())) == comb].pswitch, color = 'k', alpha = 0.4, yerr = err_ustr)
        ax[row][column].set_ylim(bottom = 0)
        ax[row][column].title.set_text(comb_list[i].upper())

        i+=1
    
    return

In [65]:

group_choices_v_model(df_mouse_symm.pswitch.values, symm_cprobs_model.pswitch.values, (df_mouse_symm.pswitch_err), (symm_cprobs_model.pswitch_err), sort_order, comb_list = ['aaa', 'aab', 'aba', 'abb'])


In [46]:
cprobs_str_list = list(df_mouse_symm.pswitch.values)
cprobs_ustr_list = list(df_mouse_symm.pswitch.values)

In [64]:

group_by_reward(df_mouse_symm, symm_cprobs_model, df_mouse_symm.pswitch, symm_cprobs_model.pswitch, df_mouse_symm.pswitch.values, symm_cprobs_model.pswitch.values, sort_order)