In [1]:
import nest_asyncio
import stan
import numpy as np
import pandas as pd
import os
from os import path
import glob
from RLmodels import QLearningModel
# from RLmodels import qLearningModel_5params_simNoPlot
from RLmodels import RestlessBanditDecoupled
from RLmodels import QLearningModelSim, myPairPlot, getSessionFitParams
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
from scipy.stats import spearmanr
from scipy.stats import pearsonr
import re
from scipy.stats import norm
from scipy.stats import halfcauchy
from scipy.stats import cauchy
from sklearn.linear_model import LinearRegression
import pickle
import json
nest_asyncio.apply()
from aind_dynamic_foraging_basic_analysis.lick_analysis import plot_lick_analysis, load_data, load_nwb, cal_metrics, plot_met
from aind_dynamic_foraging_basic_analysis import plot_foraging_session
from beh_utils import*
import statsmodels.api as sm
import json
import ast

In [2]:
# load curated session data
# The curated table lists, per session, the session_id and the trial cut
# (`session_cut`) used by the extraction cell below.
animalID = '717121'
animal_dir = f'/root/capsule/scratch/{animalID}'
ani_session_file = os.path.join(animal_dir, f'{animalID}_session_data.csv')

# Fail fast: every later cell needs `ani_session_data`; merely printing and
# continuing would only surface later as a confusing NameError.
if not os.path.exists(ani_session_file):
    raise FileNotFoundError(f'File {ani_session_file} does not exist, session is not curated yet')
ani_session_data = pd.read_csv(ani_session_file)
print(f'{len(ani_session_data)} sessions are curated for animal {animalID}')

10 sessions are curated for animal 717121


In [3]:
# session data load
# For every curated session: open its NWB file, trim it according to the
# curated `session_cut`, and collect the per-trial choices and outcomes.
data_dir = '/root/capsule/data/foraging_nwb_bonsai'
choices = []      # one list of per-trial choices per session
outcomes = []     # one list of per-trial outcomes per session
sessionLens = []  # number of trials kept per session
sessNum = len(ani_session_data)
# Iterate over column values positionally so the loop does not depend on the
# DataFrame having a default 0..n-1 index (label lookup by loop position).
for session_id, session_cut in zip(ani_session_data['session_id'], ani_session_data['session_cut']):
    print('Extracting session data for session ', session_id)
    nwb_file = os.path.join(data_dir, session_id + '.nwb')
    nwb = load_nwb(nwb_file)
    trial_count = len(nwb.trials.to_dataframe())
    # `session_cut` is stored as a Python literal string; copy it into a list so
    # the in-place update below also works if it was saved as a tuple.
    curr_cut = list(ast.literal_eval(session_cut))
    # convert curr_cut to start and end removal: the second entry is replaced by
    # `trial_count - entry`. NOTE(review): confirm this matches the cut
    # convention expected by beh_utils.makeSessionDF.
    curr_cut[1] = trial_count - curr_cut[1]
    choice_tbl = makeSessionDF(nwb, curr_cut)

    choices.append(list(choice_tbl['choices'].values))
    outcomes.append(list(choice_tbl['outcomes'].values))
    sessionLens.append(len(choice_tbl))

Extracting session data for session  717121_2024-05-30_13-24-19
Extracting session data for session  717121_2024-06-03_11-12-58
Extracting session data for session  717121_2024-06-04_12-06-13
Extracting session data for session  717121_2024-06-05_12-05-27
Extracting session data for session  717121_2024-06-06_11-36-13
Extracting session data for session  717121_2024-06-07_11-41-07
Extracting session data for session  717121_2024-06-10_11-37-52
Extracting session data for session  717121_2024-06-14_10-23-49
Extracting session data for session  717121_2024-06-15_10-00-58
Extracting session data for session  717121_2024-06-16_11-45-02


In [4]:
# make data for stan
# Sessions have different lengths, so pad each one with zeros up to the longest
# session to obtain rectangular (N x T) integer arrays. The real per-session
# lengths are passed separately as `Tsesh`; presumably the Stan model only reads
# the first Tsesh[i] trials of session i, so the padding is ignored -- confirm
# against the .stan file.
assert len(choices) == len(outcomes) == len(sessionLens) == sessNum, \
    'per-session choice/outcome/length lists are out of sync'
maxLen = max(sessionLens)
allChoiceArray = np.array([
    np.pad(sessChoices, (0, maxLen - sessLen), mode='constant')
    for sessChoices, sessLen in zip(choices, sessionLens)
]).astype(int)
allOutcomeArray = np.array([
    np.pad(sessOutcomes, (0, maxLen - sessLen), mode='constant')
    for sessOutcomes, sessLen in zip(outcomes, sessionLens)
]).astype(int)
sim_data = {"N": sessNum,
            "T": maxLen,
            "Tsesh": sessionLens,
            "choice": allChoiceArray,
            "outcome": allOutcomeArray}

In [5]:
# Load the Stan program text; it is compiled together with `sim_data` in the
# fitting cell.
model = '/code/stan_qLearning_5params.stan'
with open(model) as stan_file:
    model_code = stan_file.read()

In [6]:
# fitting
# Sampler settings are named here so they are easy to find and tune. A fixed
# random seed makes re-running this notebook reproduce the same posterior draws.
STAN_RANDOM_SEED = 42
NUM_CHAINS = 16
NUM_SAMPLES = 5000  # sampling iterations per chain
NUM_WARMUP = 2500   # warmup iterations per chain
posterior = stan.build(model_code, data=sim_data, random_seed=STAN_RANDOM_SEED)
fit = posterior.sample(num_chains=NUM_CHAINS, num_samples=NUM_SAMPLES, num_warmup=NUM_WARMUP)

Building...



Building: found in cache, done.Messages from stanc:
    control flow statement depends on parameter(s): aF_pr, aN_pr, aP_pr,
    mu_p, sigma.
    control flow statement depends on parameter(s): aF_pr, aN_pr, aP_pr,
    mu_p, sigma.
    20 suggests there may be parameters that are not unit scale; consider
    rescaling with a multiplier (see manual section 22.12).
Sampling:   0%
Sampling:   0% (1/120000)
Sampling:   0% (2/120000)
Sampling:   0% (3/120000)
Sampling:   0% (4/120000)
Sampling:   0% (5/120000)
Sampling:   0% (6/120000)
Sampling:   0% (7/120000)
Sampling:   0% (8/120000)
Sampling:   0% (9/120000)
Sampling:   0% (10/120000)
Sampling:   0% (11/120000)
Sampling:   0% (12/120000)
Sampling:   0% (13/120000)
Sampling:   0% (14/120000)
Sampling:   0% (15/120000)
Sampling:   0% (16/120000)
Sampling:   0% (115/120000)
Sampling:   0% (214/120000)
Sampling:   0% (313/120000)
Sampling:   0% (412/120000)
Sampling:   0% (511/120000)
Sampling:   1% (610/120000)
Sampling:   1% (709/120000)

In [7]:
# summarize
# Convert the PyStan fit to an ArviZ InferenceData once and reuse it for both
# summaries, instead of letting each az.summary call convert the large fit again.
idata = az.from_pystan(posterior=fit)
summaryMean = az.summary(idata, stat_focus='mean')
summaryMedian = az.summary(idata, stat_focus='median')
# Both tables are indexed by variable name; any column name present in both
# gets pandas' default '_x' (mean table) / '_y' (median table) suffix.
summary = pd.merge(summaryMean, summaryMedian, left_index=True, right_index=True)

In [None]:
# save
# Write the fit summary, the session-wise fitted parameters, the curated
# session table, and the posterior draws into one folder per model.
paramNames = ['aN', 'aP', 'aF', 'beta', 'bias']
# Build the output path from `animal_dir` (the folder the curated session table
# was read from) so results land next to the inputs regardless of $HOME.
saveDir = os.path.join(animal_dir, 'stan_qLearning_5params')
os.makedirs(saveDir, exist_ok=True)

# make session-wise dataframe of fitted parameters (focus='mean')
paramsFit = getSessionFitParams(summary, paramNames, focus='mean')

summary.to_csv(os.path.join(saveDir, 'summary.csv'), index=True)
paramsFit.to_csv(os.path.join(saveDir, 'paramsFit.csv'))
ani_session_data.to_csv(os.path.join(saveDir, 'ani_session_data.csv'), index=False)

# Posterior draws (dict(fit)) pickled for later analysis; only load this file
# from trusted locations, since unpickling can execute arbitrary code.
samples = dict(fit)
with open(os.path.join(saveDir, 'samples'), 'wb') as pickle_file:
    pickle.dump(samples, pickle_file)
