In [1]:
# warning settings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Data management
import pandas as pd
import numpy as np
import pickle

# Plotting
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
from scipy import stats
# HDDM
import hddm
from scipy.stats import norm
from sklearn.neighbors import KernelDensity
from scipy.stats import gaussian_kde


# Load the trial-level behavioural + ROI table and prepare it for HDDM.
# TODO: make the data location configurable instead of an absolute local path.
DATA_PATH = '/data/victoria/HDDM_code/full_brain_behav.csv'
LATE_RUNS = [5, 6]  # runs analysed as the "late learning" phase

df_orig = pd.read_csv(DATA_PATH)
# Work on a copy: the previous `df = df_orig` only aliased the frame, so the rt
# rescaling below also silently rewrote df_orig, which then no longer held raw data.
df = df_orig.copy()
# Divide rt by 1000 — presumably ms -> s, since HDDM expects seconds (TODO confirm units).
df['rt'] = df['rt'] / 1000
# HDDM helper that recodes error trials in the rt column (see hddm.utils.flip_errors docs).
df = hddm.utils.flip_errors(df)
# NOTE: drops rows with a NaN in ANY column, not only in rt/response.
df = df.dropna()
late_learn = df[df['run'].isin(LATE_RUNS)].copy()

In [7]:
# ROI indices to model; each entry r refers to a column named f"roi{r}" in late_learn.
# Author's note: only DMN regions this time.
# (For a quick single-ROI debug run, use: roi_ls = [39])
roi_ls = [
    1, 2, 51, 52, 53,
    14, 15, 16, 17, 18, 20,
    23, 24, 25,
    65, 67, 68, 70, 74, 77,
    31, 32, 33, 34, 35,
    38, 39, 40, 41, 42, 43, 44, 48,
    80, 82, 83, 88,
    90, 91, 92, 93, 94,
]

In [None]:
# Fit one HDDM regression per ROI on the late-learning runs: v (drift rate), a and t are
# each regressed on the ROI signal, stimulus type (treatment-coded with 'prototype' as the
# reference level) and their interaction.
N_SAMPLES = 200  # MCMC iterations per model (short chain — TODO check convergence)
N_BURN = 100     # burn-in iterations to discard
P_OUTLIER = 0.05
SEED = 42
TYPE_TERM = "C(type,Treatment('prototype'))"

# NOTE(review): assumes HDDM's sampler draws from NumPy's global RNG — confirm that this
# makes re-runs reproducible.
np.random.seed(SEED)

# roi -> fitted model. Previously each iteration overwrote m_reg, so every fit except the
# last was thrown away after sampling.
models = {}
for roi in roi_ls:
    m_reg = hddm.HDDMRegressor(
        late_learn,
        [f"{param} ~ roi{roi} * {TYPE_TERM}" for param in ("v", "a", "t")],
        p_outlier=P_OUTLIER,
    )
    m_reg.find_starting_values()
    m_reg.sample(N_SAMPLES, burn=N_BURN)
    models[roi] = m_reg
# After the loop, `roi` and `m_reg` still refer to the LAST ROI in roi_ls; later cells use them.

In [14]:
# Evaluation grid per ROI: N_GRID evenly spaced values spanning the observed range of
# that ROI's signal in the late-learning runs (one column per ROI).
N_GRID = 20

grid_columns = {}
for roi in roi_ls:
    roi_values = late_learn[f'roi{roi}']
    # pandas reductions instead of builtin min/max: vectorised and NaN-aware.
    x = np.linspace(roi_values.min(), roi_values.max(), N_GRID)
    # Column-vector form for broadcasting against posterior traces. Like `roi` and `x`,
    # it is left holding the LAST ROI's grid, which the drift-rate cell relies on.
    x_2d = x[:, np.newaxis]
    grid_columns[f'roi{roi}'] = x

df = pd.DataFrame(grid_columns)
    


In [17]:
# Persist the current `df` (with its index) to roi_betas.csv.
# NOTE(review): when run in notebook order, `df` here is the per-ROI np.linspace grid of
# ROI values, not regression betas — confirm the filename is intended.
df.to_csv(path_or_buf='roi_betas.csv', index=True)

In [6]:
# Posterior drift-rate (v) regression lines per stimulus type, evaluated on the ROI grid.
# This cell only computes the lines and their summaries; it does not plot anything.
# NOTE(review): `m_reg`, `roi` and `x_2d` are whatever the fitting loop and grid cell left
# behind, i.e. the LAST ROI in roi_ls — confirm that is the intended ROI.


def _posterior_lines(intercept, slope, x_grid):
    """Evaluate ``intercept + slope * x`` once per posterior sample.

    Parameters
    ----------
    intercept, slope : np.ndarray
        Posterior samples of the line's intercept and slope (assumed 1-D with one
        entry per retained MCMC sample — TODO confirm against hddm traces).
    x_grid : np.ndarray
        Column vector of x values, shape (n_points, 1).

    Returns
    -------
    np.ndarray
        Shape (n_samples, n_points): row i is the line implied by sample i.
    """
    # (n_points, 1) broadcast against (n_samples,) -> (n_points, n_samples);
    # transpose so rows index posterior samples.
    return (intercept + x_grid * slope).T


def _median_and_interval(lines, percentiles=(2.5, 97.5)):
    """Return the pointwise median and percentile band of ``lines`` across samples (axis 0)."""
    return np.median(lines, axis=0), np.percentile(lines, list(percentiles), axis=0)


v_Intercept = m_reg.nodes_db.node["v_Intercept"]
v_roi = m_reg.nodes_db.node[f"v_roi{roi}"]

v_exception = m_reg.nodes_db.node["v_C(type, Treatment('prototype'))[T.exception]"]
v_rf = m_reg.nodes_db.node["v_C(type, Treatment('prototype'))[T.rule follower]"]
v_roi_exception = m_reg.nodes_db.node[f"v_roi{roi}:C(type, Treatment('prototype'))[T.exception]"]
v_roi_rulefollower = m_reg.nodes_db.node[f"v_roi{roi}:C(type, Treatment('prototype'))[T.rule follower]"]

# Treatment coding with 'prototype' as reference: each other type adds its own intercept
# offset and ROI-slope offset to the reference line.
v_Prototype = _posterior_lines(v_Intercept.trace(), v_roi.trace(), x_2d)
v_RuleFollower = _posterior_lines(v_Intercept.trace() + v_rf.trace(),
                                  v_roi.trace() + v_roi_rulefollower.trace(), x_2d)
v_Exception = _posterior_lines(v_Intercept.trace() + v_exception.trace(),
                               v_roi.trace() + v_roi_exception.trace(), x_2d)

# Pointwise posterior medians and central 95% credible intervals (2.5th / 97.5th
# percentiles across samples) — Bayesian credible bands, not frequentist confidence intervals.
medians_p, CI_P = _median_and_interval(v_Prototype)
medians_rf, CI_rf = _median_and_interval(v_RuleFollower)
medians_e, CI_e = _median_and_interval(v_Exception)
# Assuming you have medians and confidence intervals for the three datasets

In [7]:
# Per-condition posterior drift-rate line matrices, keyed in a fixed order
# (prototype, exception, rule follower); the stacking cell relies on this order.
trace_dict = dict(
    v_proto=v_Prototype,
    v_excep=v_Exception,
    v_rf=v_RuleFollower,
)

In [27]:
# Concatenate the per-condition matrices row-wise, following trace_dict's insertion order.
stacked_matrix = np.vstack(tuple(trace_dict.values()))

# Write to CSV with neither index nor column header.
# NOTE(review): the file carries no condition labels; readers must know the row-block
# order (v_proto, v_excep, v_rf as built in the cell above) and the per-condition sample count.
df = pd.DataFrame(stacked_matrix)
df.to_csv('output.csv', header=False, index=False)

In [31]:
# Inspect the stacked matrix written to output.csv: rows are posterior samples
# (condition blocks stacked one after another), columns are x-grid points.
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,-0.251125,-0.223345,-0.195566,-0.167787,-0.140008,-0.112228,-0.084449,-0.056670,-0.028891,-0.001111,0.026668,0.054447,0.082226,0.110006,0.137785,0.165564,0.193343,0.221123,0.248902,0.276681
1,-0.319451,-0.273488,-0.227525,-0.181562,-0.135600,-0.089637,-0.043674,0.002289,0.048252,0.094215,0.140178,0.186141,0.232104,0.278067,0.324030,0.369993,0.415956,0.461919,0.507882,0.553845
2,-0.438277,-0.388411,-0.338546,-0.288680,-0.238815,-0.188949,-0.139084,-0.089218,-0.039353,0.010513,0.060378,0.110244,0.160109,0.209975,0.259840,0.309706,0.359571,0.409437,0.459302,0.509168
3,-0.406466,-0.362360,-0.318255,-0.274149,-0.230044,-0.185939,-0.141833,-0.097728,-0.053623,-0.009517,0.034588,0.078693,0.122799,0.166904,0.211010,0.255115,0.299220,0.343326,0.387431,0.431536
4,-0.286809,-0.249465,-0.212122,-0.174778,-0.137435,-0.100092,-0.062748,-0.025405,0.011939,0.049282,0.086625,0.123969,0.161312,0.198656,0.235999,0.273342,0.310686,0.348029,0.385373,0.422716
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,-0.596801,-0.530469,-0.464137,-0.397805,-0.331474,-0.265142,-0.198810,-0.132478,-0.066146,0.000186,0.066518,0.132850,0.199182,0.265514,0.331846,0.398178,0.464510,0.530842,0.597174,0.663506
296,-0.404140,-0.364342,-0.324544,-0.284746,-0.244948,-0.205149,-0.165351,-0.125553,-0.085755,-0.045957,-0.006158,0.033640,0.073438,0.113236,0.153034,0.192833,0.232631,0.272429,0.312227,0.352025
297,-0.848879,-0.759027,-0.669175,-0.579323,-0.489472,-0.399620,-0.309768,-0.219916,-0.130065,-0.040213,0.049639,0.139491,0.229342,0.319194,0.409046,0.498898,0.588750,0.678601,0.768453,0.858305
298,-0.814649,-0.727297,-0.639946,-0.552594,-0.465243,-0.377891,-0.290539,-0.203188,-0.115836,-0.028485,0.058867,0.146219,0.233570,0.320922,0.408273,0.495625,0.582977,0.670328,0.757680,0.845031


In [32]:
# Posterior prototype drift-rate lines: rows are posterior samples, columns are x-grid points.
v_Prototype

array([[-0.25112467, -0.22334543, -0.19556618, ...,  0.2211225 ,
         0.24890175,  0.27668099],
       [-0.3194514 , -0.27348844, -0.22752547, ...,  0.46191907,
         0.50788204,  0.55384501],
       [-0.43827666, -0.38841116, -0.33854566, ...,  0.40943685,
         0.45930235,  0.50916785],
       ...,
       [-0.43423367, -0.38243088, -0.33062809, ...,  0.4464137 ,
         0.49821648,  0.55001927],
       [-0.21358958, -0.18649851, -0.15940743, ...,  0.24695869,
         0.27404977,  0.30114084],
       [-0.24463366, -0.21553413, -0.18643459, ...,  0.25005849,
         0.27915803,  0.30825757]])