----
Overview of notebook:
----
Set up plotting variables

Section 1: 
    * Behavioral plots

Section 2: 
    * Plots for Parametric Modulation analyses

Section 3:
    * Plots for Decoding analyses

Section 4:
    * Plots for Network Hubs

In [None]:
# import packages we will need
import pandas as pd
import numpy as np
import numpy.matlib as mb
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import pickle
import os
import sys
import re
import glob as glob
import argparse
import nilearn
from nilearn import datasets, plotting, surface
from nilearn.image import resample_img
from nilearn import masking
from nilearn import image
from nilearn.input_data import NiftiMasker
from nilearn.input_data import NiftiLabelsMasker
from nilearn.glm import threshold_stats_img
from nilearn.datasets import load_fsaverage_data
from nilearn.plotting import plot_img_on_surf
from nilearn.plotting import plot_surf_contours
import nibabel as nib

def create_nii(stats_mat, cortical_mask):
    """Map masked values back into a Nifti image.

    A ``NiftiMasker`` is fitted on ``cortical_mask`` and its
    ``inverse_transform`` turns ``stats_mat`` back into an image.
    Presumably ``stats_mat`` holds one value per in-mask voxel (last axis);
    confirm against the callers.

    Returns the image produced by ``inverse_transform``.
    """
    # NiftiMasker.fit returns the fitted masker itself, so the calls chain.
    return NiftiMasker(cortical_mask).fit().inverse_transform(stats_mat)

# define function to save out plots of model outputs
def gen_model_plots(c_tState, c_tTask, c_tProp, c_state, c_task, c_color, plot_filename, tResp, t_belief_alpha, c_belief_alpha):
    """Plot true trial variables against model values and save the figure.

    Three stacked panels (STATE, TASK, COLOR) each draw the true trial-wise
    variable in black and the matching model values in blue / green / gold.
    Below each trace, one dot per trial marks that trial's response code: the
    dot's colour and y position are chosen by ``tResp`` (codes 0-3 index
    ``err_colors`` / ``err_ys``).

    Parameters
    ----------
    c_tState, c_tTask, c_tProp : array-like
        True trial-wise state, task and colour proportion.
    c_state, c_task, c_color : array-like
        Model values drawn over the corresponding true variable.
    plot_filename : str
        Path the figure is saved to.
    tResp : array-like of int
        Trial-wise response code used to place/colour the dots.
    t_belief_alpha, c_belief_alpha : float
        Line transparency of the true variable and of the model values.
    """
    err_colors = ['grey', 'orange', 'red', 'white']
    err_ys = [-0.1, -0.125, -0.15, -.2]
    tResp = np.asarray(tResp)

    # make MODEL ESTIMATE plots
    plt.rcParams["figure.figsize"] = (50, 20)
    fig, axs = plt.subplots(3, 1)  # rows, columns
    panels = [
        (c_tState, c_state, 'blue', "STATE"),
        (c_tTask, c_task, 'green', "TASK"),
        (c_tProp, c_color, 'gold', "COLOR"),
    ]
    for ax, (true_vals, model_vals, model_color, label) in zip(axs, panels):
        ax.plot(true_vals, 'black', alpha=t_belief_alpha)
        ax.plot(model_vals, model_color, alpha=c_belief_alpha)
        # one dot per trial (x = trial index) at the height/colour of its
        # response code; drawn per code instead of one plot call per trial
        resp = tResp[:len(true_vals)]
        for code, (err_color, err_y) in enumerate(zip(err_colors, err_ys)):
            x_coords = np.flatnonzero(resp == code)
            if x_coords.size:
                ax.plot(x_coords, np.full(x_coords.size, err_y), color=err_color,
                        marker='o', linestyle='none', alpha=0.25)
        ax.set_ylabel(label, rotation=0, fontsize=25, labelpad=20)

    # Save BEFORE showing: plt.show() closes the current figure, so a later
    # plt.savefig() would write a new, empty figure.
    fig.savefig(plot_filename)
    plt.show()
    plt.close(fig)

# -------------------------------------------------
# -- project folder and ROI mask directories (absolute NFS paths, trailing "/")
dataset_dir = "/mnt/nfs/lss/lss_kahwang_hpc/data/FPNHIU/"
mask_dir = "/mnt/nfs/lss/lss_kahwang_hpc/ROIs/"

# -- csv and model data directories, both inside the project folder
data_dir = os.path.join(dataset_dir, "CSVs/")
model_data_output = os.path.join(dataset_dir, "model_data/")

# -- cortical mask image and its voxel array
cortical_mask = nib.load(os.path.join(mask_dir, "CorticalMask_RSA_task-Quantum.nii.gz"))
cortical_mask_data = cortical_mask.get_fdata()

# -- set sns theme
sns.set_theme(context='poster', style='whitegrid')

Section 1

# ------------------------------------------------------------------------------------------
# - - - - - - - - - - - - - - - - -
# - -   Behavioral Plots   - -
# - - - - - - - - - - - - - - - - -

In [None]:
""" Load CSV files and create a master file with all subjects """
# Load the output csv file of every subject and stack them into master data
# frames (all subjects in long format), adding the control models' trial-wise
# beliefs and the model input arrays to each subject's rows.
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# subjects dropped entirely (pilots) and subjects excluded from the usable set
pilots_to_exclude = ["10190", "10162", "10260", "10261", "10245"]
subjs_to_exclude = ["10296", "10319", "10275", "10321", "10218", "10318", "10322", "10118", "10282", "10351", "10393", "10358"]
DM_list = glob.glob(os.path.join(model_data_output, "*_dataframe_3dDeconvolve_pSpP.csv"))
print(DM_list)

# only the first N_TRIALS values of every model output / input array are used
N_TRIALS = 200
# control models whose beliefs are merged into the master data frame
CONTROL_MODELS = ["pSdP", "dSpP", "dSdP", "pSjP"]


def _subject_path(subj, suffix):
    """Return the path of file `sub-<subj><suffix>` in model_data_output."""
    return os.path.join(model_data_output, "sub-" + subj + suffix)


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


DM_df_list = []
DM_usable_list = []
for cur_DM in sorted(DM_list):
    # subject id = first 5-digit number in the file name
    sid = re.search(r"[0-9]{5}", os.path.basename(cur_DM))
    if sid is None:
        print("skipping file without a 5-digit subject id:", cur_DM)
        continue
    subj = sid.group(0)
    if subj in pilots_to_exclude:
        continue  # pilots never enter either master data frame

    # load pSpP file
    temp_df = pd.read_csv(cur_DM)
    temp_df["block"] = temp_df["block"] + 1  # fix python index
    temp_df = temp_df[temp_df["block"] < 6].copy()
    temp_df["Participant_ID"] = subj

    # -- control models: state, task and color beliefs
    for model in CONTROL_MODELS:
        state_belief = _load_pickle(_subject_path(subj, "_" + model + "_belief-state.p"))
        color_belief = _load_pickle(_subject_path(subj, "_" + model + "_belief-color.p"))
        task_belief = _load_pickle(_subject_path(subj, "_" + model + "_belief-task.p"))
        # column 0 of the state belief for every model (dSdP included), matching
        # the later per-model cell that reads these same *_belief-state.p files
        temp_df[model + "_sE"] = state_belief[:N_TRIALS, 0]
        temp_df[model + "_tE"] = task_belief[:N_TRIALS]
        temp_df[model + "_cE"] = color_belief[:N_TRIALS]
        # NOTE(review): no entropy file is read; this column is a copy of the
        # task belief (<model>_tE). Point it at the real entropy output once
        # that file's name is confirmed.
        temp_df[model + "_entropy"] = task_belief[:N_TRIALS]

    # -- model input DATA arrays for this subject
    for var_name in ("tState", "tProp", "tResp", "tTask"):
        temp_df[var_name] = np.load(_subject_path(subj, "_" + var_name + ".npy"))[:N_TRIALS]

    DM_df_list.append(temp_df)
    if subj not in subjs_to_exclude:
        DM_usable_list.append(temp_df)

if not DM_usable_list:
    raise ValueError("no usable subject data frames found in " + model_data_output)
DM_df = pd.concat(DM_df_list, ignore_index=True)  # merge dfs in list
DM_df_usable = pd.concat(DM_usable_list, ignore_index=True)
print("master data frames generated")
print(len(DM_df.Participant_ID.unique()), "subjects completed the Quantum task")
print(len(DM_df_usable.Participant_ID.unique()), "subjects are usable")

DM_df_usable.to_csv(os.path.join(model_data_output, "MASTER_dataframe__date-20250823.csv"))

# keep trials with 0.199 s < RT < 3.5 s
df = DM_df_usable[(DM_df_usable['RT'] > .199) & (DM_df_usable['RT'] < 3.5)]
# In Jupyter only a cell's last expression is echoed, so this bare `df` is not
# displayed; the last string literal below is what the cell shows.
df

# Reference notes on the computational models. These bare string literals do
# not affect any computation.
""" Brief description of computational models """

# NOTE(review): the cells below read 5 theta keys from pRange/mDist
# ('lfsp', 'lfip', 'fter', 'ster', 'diff_at'), whereas these notes mention
# "4 thetas" and "rRange"; confirm and update the notes.
"""
MPE MODEL
maximum posterior estimator of joint distribution of state and hyperparameters

input:
tState: trial wise task-set (0: color 0 = face, color 1 = scene; 1: color 0 = scene, color 1 = face)
tProp: trial-wise proportion of dots (0 - 1, for the first color)
tResp: trial-wise response: 0 = correct, 1 = correct task wrong answer, 2 = wrong task

output:
jd: joint distribution of parameters at the end of experiment
sE: trial-wise estimate of state, encoding probability of state 0
tE: trial-wise estimate of task, encoding probability of task 0
mDist: marginal distribution for the 4 thetas,
rRange: values of the thetas"""

# How the model input arrays (tState, tTask, tProp, tResp) are encoded.
"""
need to re-format data for computational model ... MODIFIED 4/21/2024 to flip tTask
tState
  tState = np.where(cur_df['state']==-1, 1, 0)  # recode state (1 -> 0 AND -1 -> 1 ... this flip matches model better)
tTask
  tTask   # face=0, scene=1
tProp
  tProp = np.array(cur_df['amb_r']).flatten()  # set as proportion of red
tResp
  0=correct , 1=right task but wrong answer , 2=wrong task
"""

# NOTE(review): apart from its title this text repeats the MPE description
# word for word; confirm how the control models (dSpP, pSdP, dSdP, pSjP)
# differ and document that here.
"""
CONTROL MODEL(S)
maximum posterior estimator of joint distribution of state and hyperparameters

input:
tState: trial wise task-set (0: color 0 = face, color 1 = scene; 1: color 0 = scene, color 1 = face)
tProp: trial-wise proportion of dots (0 - 1, for the first color)
tResp: trial-wise response: 0 = correct, 1 = correct task wrong answer, 2 = wrong task

output:
jd: joint distribution of parameters at the end of experiment
sE: trial-wise estimate of state, encoding probability of state 0
tE: trial-wise estimate of task, encoding probability of task 0
mDist: marginal distribution for the 4 thetas,
rRange: values of the thetas"""

In [None]:
master_df_list = []
theta_df = {'sub': [], 'lfsp': [], 'lfip': [], 'fter': [], 'ster': [], 'diff_at': []}

# -- theta parameters, in the order of the axes of each model's summed joint
#    distribution, plus settings for the per-subject MPE theta plot
pRange_keys = ['lfsp', 'lfip', 'fter', 'ster', 'diff_at']
pRange_titles = {'lfsp': 'slope param', 'lfip': 'intercept param', 'fter': 'face error param', 'ster': 'scene error param', 'diff_at': 'diffusion param'}
x_ranges = [[1.5, 5], [-3.0, 3.0], [0.0, 0.35], [0.0, 0.35], [0.0, 0.9]]
pRange_colors = ['deepskyblue', 'steelblue', 'orange', 'green', 'red']  # set colors so I know what is what

# control models: (name of the *_outputs.p file, tag of the *_belief-*.p files)
control_models = [("dState_pPercept", "dSpP"),  # control model 1
                  ("pState_dPercept", "pSdP"),  # control model 2
                  ("dState_dPercept", "dSdP")]  # control model 3


def _load_subject_pickle(subj, suffix):
    """Return the object pickled in file `sub-<subj><suffix>` in model_data_output."""
    with open(os.path.join(model_data_output, "sub-" + subj + suffix), 'rb') as handle:
        return pickle.load(handle)


def _map_thetas(model_dict):
    """Return the maximum-marginal value of every theta of a fitted model.

    model_dict['jd'] is summed over its first axis. For the theta on axis i
    (i indexing pRange_keys) the result is summed over the other theta axes and
    the model_dict['pRange'][key] value at the argmax of that marginal is taken.
    """
    jd = np.sum(model_dict['jd'], axis=0)
    p_range = model_dict['pRange']
    n_axes = len(pRange_keys)
    return [p_range[key][np.argmax(np.sum(jd, axis=tuple(a for a in range(n_axes) if a != i)))]
            for i, key in enumerate(pRange_keys)]


for subj_opt in df['Participant_ID'].unique():
    # ----- load input data for this subject
    print("\nloading dataframe for subject", str(subj_opt))
    cur_df = pd.read_csv(os.path.join(model_data_output, 'sub-' + subj_opt + '_dataframe_3dDeconvolve_pSpP.csv'))
    theta_df['sub'].append(subj_opt)

    # ----- mpe model: store its thetas
    mpe_model_dict = _load_subject_pickle(subj_opt, "_pState_pPercept_outputs.p")
    s_theta = _map_thetas(mpe_model_dict)
    for key, value in zip(pRange_keys, s_theta):
        theta_df[key].append(value)

    # ----- control models: thetas (kept per model) and beliefs (added to cur_df)
    control_thetas = {}
    for outputs_name, model in control_models:
        control_thetas[model] = _map_thetas(_load_subject_pickle(subj_opt, "_" + outputs_name + "_outputs.p"))
        cur_df[model + '_sE'] = _load_subject_pickle(subj_opt, "_" + model + "_belief-state.p")[:, 0]
        cur_df[model + '_tE'] = _load_subject_pickle(subj_opt, "_" + model + "_belief-task.p")[:]
        cur_df[model + '_cE'] = _load_subject_pickle(subj_opt, "_" + model + "_belief-color.p")[:]

    cur_df['sub'] = subj_opt
    master_df_list.append(cur_df)

    # -- plot the mpe model's marginal distribution of each theta
    pRange = mpe_model_dict['pRange']
    mDist = mpe_model_dict['mDist']
    plt.rcParams["figure.figsize"] = (40, 10)
    fig_thetas, ax_thetas = plt.subplots(1, len(pRange_keys))
    for ind, cur_key in enumerate(pRange_keys):
        ax_thetas[ind].set_xlim(x_ranges[ind][0], x_ranges[ind][1])
        ax_thetas[ind].set_yticklabels([])
        ax_thetas[ind].plot(pRange[cur_key], mDist[cur_key], color=pRange_colors[ind])
        ax_thetas[ind].set_title(pRange_titles[cur_key])
    # save BEFORE showing: plt.show() closes the current figure, so a later
    # plt.savefig() would write an empty image
    fig_thetas.savefig(os.path.join(model_data_output, ("sub-" + subj_opt + "_model-pSpP__thetas.png")))
    plt.show()
    plt.close(fig_thetas)

master_df = pd.concat(master_df_list, ignore_index=True)  # merge dfs in list

# ----- save out master data frames
# NOTE(review): both file names use the LAST subj_opt of the loop although the
# frames hold every subject; kept unchanged for existing readers of these files.
master_df.to_csv(os.path.join(model_data_output, 'sub-' + subj_opt + '_master_dataframe_allmodels.csv'))
theta_df = pd.DataFrame(theta_df)
theta_df.to_csv(os.path.join(model_data_output, 'sub-' + subj_opt + '_master_dataframe_thetas.csv'))

In [None]:
# ---- create model plots
# For each subject, draw two figures over the first 200 trials with the pSpP
# model: (1) its estimates and (2) their absolute prediction error, each against
# the true trial variables.
for subj_opt in df['Participant_ID'].unique():
    print("\nloading files for subject", str(subj_opt))
    # true trial-wise model inputs saved for this subject
    c_tState, c_tProp, c_tResp, c_tTask = (
        np.load(os.path.join(model_data_output, "sub-" + subj_opt + "_" + var_name + ".npy"))
        for var_name in ("tState", "tProp", "tResp", "tTask")
    )
    cur_df = pd.read_csv(os.path.join(model_data_output, 'sub-' + subj_opt + '_dataframe_3dDeconvolve_pSpP.csv'))
    c_tProp_bin = c_tProp.round()  # binary colour variable

    sE_belief, tE_belief, cE_belief = cur_df['pSpP_sE'], cur_df['pSpP_tE'], cur_df['pSpP_cE']

    # first 200 trials of each true variable
    true_state = c_tState[:200]
    true_task = c_tTask[:200]
    true_prop = c_tProp[:200]
    true_resp = c_tResp[:200]
    # sE / tE encode the probability of state 0 / task 0; flip to state 1 / task 1
    p_state1 = 1 - sE_belief[:200]
    p_task1 = 1 - tE_belief[:200]
    p_color = cE_belief[:200]

    # -- model estimates
    t_belief_alpha = 0.5
    c_belief_alpha = 0.5
    c_sE, c_tE, c_cE = p_state1, p_task1, p_color
    gen_model_plots(true_state, true_task, true_prop, c_sE, c_tE, c_cE,
                    os.path.join(model_data_output, subj_opt + "_param-ModelEstimates_model-pSpP.png"),
                    true_resp, t_belief_alpha, c_belief_alpha)

    # -- prediction error: |true value - model estimate|
    t_belief_alpha = 0.15
    c_belief_alpha = 0.5
    c_sE = np.abs(true_state - p_state1)
    c_tE = np.abs(true_task - p_task1)
    c_cE = np.abs(c_tProp_bin[:200] - p_color)
    gen_model_plots(true_state, true_task, true_prop, c_sE, c_tE, c_cE,
                    os.path.join(model_data_output, subj_opt + "_param-PredictionError_model-pSpP.png"),
                    true_resp, t_belief_alpha, c_belief_alpha)

The cells below compute each model's trial-wise task-prediction score (saved and plotted as prediction error) and then generate the prediction-error and BIC plots.
----

In [None]:
"""
CODE FOR CALCULATING MODEL FIT
...
"""

master_df = DM_df_usable  # all trials
master_df_reduced = df  # trials with 0.199 s < RT < 3.5 s

df_to_use = master_df

# get the actual task performed: the cued task (tTask), or the other task on
# wrong-task trials (tResp == 2). np.where works positionally, so this is also
# correct for frames whose index is not 0..n-1 (e.g. master_df_reduced).
# NOTE: df_to_use aliases master_df / DM_df_usable, so the column is added there too.
df_to_use["taskPerformed"] = np.where(df_to_use["tResp"] == 2,
                                      1 - df_to_use["tTask"],
                                      df_to_use["tTask"])

# get subject level prediction error counts.
# tE encodes the probability of task 0, so round(tE) == 1 means the model
# predicts task 0. A trial where taskPerformed equals round(tE) is therefore a
# mispredicted trial (assumes every *_tE column uses that encoding — confirm).
model_names = ["pSpP", "dSpP", "pSdP", "dSdP", "pSjP"]
n_errors = {model: [] for model in model_names}
subList = []
for subj_opt in df_to_use["Participant_ID"].unique():
    subList.append(subj_opt)
    sub_df = df_to_use[df_to_use["Participant_ID"] == subj_opt]
    for model in model_names:
        n_err = int((sub_df["taskPerformed"] == sub_df[model + "_tE"].round()).sum())
        n_errors[model].append(n_err)
        print(model + " prediction err for subject ", subj_opt, " is ", n_err)

tmp_df = pd.DataFrame({'sub': subList, **n_errors})
tmp_df.to_csv(os.path.join(model_data_output, "prediction_error_by_model__date-20250823.csv"))

In [None]:
# -- plot of prediction error per model
# long format (one row per subject x model); each count is divided by 200,
# assuming 200 trials per subject.
plot_df = tmp_df.melt("sub")
plot_df["value"] = plot_df["value"] / 200
plot_df.columns = ["sub", "Models", "Prediction Error"]
sns.set(rc={'figure.figsize': (8, 6)}, font_scale=1.75)
# catplot builds its own figure, so no separate plt.figure() is needed
# (one would only produce an extra, empty figure)
g = sns.catplot(data=plot_df, x="Models", y="Prediction Error", kind="violin", height=6, aspect=1.5, inner="quart", palette="tab10", saturation=0.5)
# overlay one white dot per subject on the catplot's own axes
sns.swarmplot(data=plot_df, x="Models", y="Prediction Error", dodge=True, color='w', size=4, ax=g.ax)
g.set(ylim=(0, 0.75))

In [None]:
bic_df = pd.read_csv(os.path.join(model_data_output, "38subjs__logit_PE-binary_BIC_cross-valid.csv"))  # 38subjs__logit_PE-continuous_BIC_cross-valid.csv
# -- long format: one row per (subject, model). The unnamed first column is
#    presumably the subject index written by to_csv — confirm.
plot_df = bic_df.melt("Unnamed: 0")
plot_df.columns = ["sub", "Models", "BIC"]
# -- BIC plots
sns.set(rc={'figure.figsize': (8, 6)}, font_scale=1.75)
# catplot builds its own figure, so no separate plt.figure() is needed
g = sns.catplot(data=plot_df, x="Models", y="BIC", kind="violin", height=6, aspect=1.5, inner="quart", palette="tab10", saturation=0.75)
# overlay one white dot per subject on the catplot's own axes
sns.swarmplot(data=plot_df, x="Models", y="BIC", dodge=True, color='w', size=4, ax=g.ax)

In [None]:
# ---- subject by subject breakdown of BIC scores
# round BIC to 2 decimals for the cell annotations (the result must be assigned;
# Series.round returns a new Series)
plot_df['BIC'] = plot_df['BIC'].round(2)
plot_df['Models'] = pd.Categorical(plot_df['Models'], categories=["pSpP", "dSpP", "pSdP", "dSdP", "pSjP"], ordered=True)
plot_df = plot_df.sort_values('Models')

# long -> wide: one row per subject, one column per model (in category order);
# observed=False keeps every category as a column and silences pandas'
# deprecation warning about the changing default
heatmap_data = plot_df.pivot_table(index='sub', columns='Models', values='BIC', observed=False)

plt.figure(figsize=(60, 40))
sns.heatmap(heatmap_data, annot=True, fmt='g', cmap='viridis', vmin=80, vmax=250)
plt.title('Heatmap of BIC scores by subject')
plt.show()


Set up fMRI plot variables
----

In [None]:
# --------------------------------------------------------------------------------------------------------------------------
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# - - - -    Set up Some Variables for fMRI Plots   - - - -
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# -- load atlas: Schaefer 2018, 400 parcels, 17 networks, 2 mm
atlas_data = datasets.fetch_atlas_schaefer_2018(n_rois=400, yeo_networks=17, resolution_mm=2)  # fetch_atlas_yeo_2011()
atlas = atlas_data.maps
atlas_labels = atlas_data.labels
# Network name of every parcel (3rd "_"-separated field of its label).
# atlas_labels[0] is skipped, assumed to be the background entry (TODO confirm
# for the installed nilearn version), so networks[i] is the network of the
# parcel stored with value i + 1 in `atlas`.
networks = [label.split("_")[2] for label in atlas_labels[1:]]

# -- load the fsaverage cortical surface mesh and project the atlas onto it
fsaverage = datasets.fetch_surf_fsaverage('fsaverage')
left_texture = surface.vol_to_surf(atlas, fsaverage['pial_left'], inner_mesh=fsaverage['white_left'], interpolation='nearest', n_samples=1, radius=0.0)
right_texture = surface.vol_to_surf(atlas, fsaverage['pial_right'], inner_mesh=fsaverage['white_right'], interpolation='nearest', n_samples=1, radius=0.0)
# two identical textures per hemisphere (left, left, right, right) — presumably
# one per plotted view; the projection is computed once and copied
textures = [left_texture, left_texture.copy(), right_texture, right_texture.copy()]

# -- set up rois as 17 networks (remove specific roi info)
n_color_dict = {'ContA': 'indianred',
                'ContB': 'firebrick',
                'ContC': 'maroon',
                'DefaultA': 'royalblue',
                'DefaultB': 'mediumblue',
                'DefaultC': 'navy',
                'DorsAttnA': 'forestgreen',
                'DorsAttnB': 'darkgreen',
                'LimbicA': 'lightyellow',
                'LimbicB': 'beige',
                'SalVentAttnA': 'darkorchid',
                'SalVentAttnB': 'mediumorchid',
                'SomMotA': 'saddlebrown',
                'SomMotB': 'sienna',
                'TempPar': 'peachpuff',
                'VisCent': 'lightgrey',
                'VisPeri': 'darkgrey'}
# network textures: 0 = no network, k = k-th network in alphabetical order
n_textures = [np.zeros(len(ctext)) for ctext in textures]
custom_color_list = []
for n_idx, c_network in enumerate(sorted(set(networks))):
    custom_color_list.append(n_color_dict[c_network])
    # atlas values (1-based parcel labels) of the parcels in this network;
    # using the 0-based list index here would paint background vertices and
    # shift every parcel onto its neighbour's network
    matching_indices = [index + 1 for index, value in enumerate(networks) if value == c_network]
    for idx, ctext in enumerate(textures):
        n_textures[idx][np.isin(ctext, matching_indices)] = n_idx + 1.0
# -- set up custom color map for networks
discrete_cmap = mcolors.ListedColormap(custom_color_list)

# load the Schaefer 2018 atlas (400 ROIs)
atlas_img = nib.load(atlas)
roi_labels = atlas_data.labels  # list of 400 ROI labels
roi_indices = np.arange(1, 401)  # assuming atlas labels are 1-indexed
# Use NiftiLabelsMasker with the Schaefer atlas to extract ROI betas.
masker = NiftiLabelsMasker(labels_img=atlas_img, standardize=False)
resampled_atlas_img = resample_img(atlas_img, target_affine=cortical_mask.affine, target_shape=cortical_mask_data.shape, interpolation='nearest', force_resample=True)

Section 2

# ------------------------------------------------------------------------------------------
# - - - - - - - - - - - - - - - - -
# - -   Parametric Modulation   - -
# - - - - - - - - - - - - - - - - -

In [None]:
# Folder with the 3dMEMA parametric-modulation group results.
PM_path = os.path.join(dataset_dir, "3dMEMA")

# - - - - lists for loading my own data
# mask_list[i] is the significant-cluster mask belonging to the group t-map model_list[i].
mask_list = [
    "cue_zEntropy_masked.BRIK",
    "feedback_zTaskPE_masked.BRIK",
    "feedback_zStatePE_masked.BRIK",
    "feedback_zColorPE_masked.BRIK",
    "cue_stateD_masked.BRIK",
    "cue_zState_masked.BRIK",
]
model_list = [
    "cue__zEntropy_SPMGmodel_stats_REML__tval.nii.gz",
    "feedback__zTaskPE_SPMGmodel_stats_REML__tval.nii.gz",
    "feedback__zStatePE_SPMGmodel_stats_REML__tval.nii.gz",
    "feedback__zColorPE_SPMGmodel_stats_REML__tval.nii.gz",
    "cue__StateD_SPMGmodel_stats_REML__tval.nii.gz",
    "cue__zState_SPMGmodel_stats_REML__tval.nii.gz",
]
# Same symmetric colour range (t units) for every map.
vmin_list = [-6] * len(model_list)
vmax_list = [6] * len(model_list)

In [None]:
# - - - - - - - - -
# plot all 17 networks together
# - - - - - - - - -
# Only the 17-network atlas is drawn on the surfaces (the stat-map overlay below is commented out),
# so every model would produce the same image. The loop therefore stops after the first model;
# `break` replaces the old `"c"+2`, which did the same thing by raising a TypeError.
for m_idx, c_model in enumerate(model_list):
    model_name = c_model[:-34]  # drop the "_SPMGmodel_stats_REML__tval.nii.gz" suffix
    # -- load current stat image sig clusters mask
    brik_mask = nib.load(os.path.join(PM_path, "nii3D", mask_list[m_idx])) # load sig rois for state or color
    brik_data = brik_mask.get_fdata()
    # NOTE(review): ">1" also drops cluster label 1 (the PM slice-plot cell uses ">0"); confirm intent.
    roi_data = np.where(np.squeeze(brik_data)>1,1,0)
    roi_data = np.where(np.squeeze(cortical_mask_data)>0,roi_data,0) # apply cortical mask too
    mask_img = nib.Nifti1Image(roi_data, brik_mask.affine, brik_mask.header)

    stat_img0 = nib.load(os.path.join(PM_path, "nii3D", c_model))
    stat_masked = nilearn.masking.apply_mask(stat_img0, mask_img)
    stat_img = masking.unmask(stat_masked, mask_img)
    print("cur stat: ", os.path.join(PM_path, "nii3D", c_model))

    # Surface projections of the masked stat map (only used by the commented-out overlay below).
    # Every view uses the pial/white pair, like the other surface plots in this notebook.
    stat_img_list = [surface.vol_to_surf(stat_img, fsaverage[f'pial_{hemi}'], inner_mesh=fsaverage[f'white_{hemi}'])
                     for hemi in ('left', 'left', 'right', 'right')]

    # ------------------ plot sagital surface plots w/ 17 networks ------------------ #
    figures, axes = plt.subplots(figsize=(25,25),nrows=1,ncols=4,subplot_kw={'projection': '3d'},sharex=True,sharey=True)
    # Title with the current model (previously the leaked setup-loop variable `c_network`, i.e. always 'VisPeri').
    figures.suptitle(model_name, fontsize=16)

    # plot atlas
    plotting.plot_surf_roi(fsaverage.pial_left, colorbar=False, roi_map=n_textures[0], hemi='left', view='lateral',
                        darkness=0.3, cmap=discrete_cmap, bg_on_data=True, bg_map=fsaverage.sulc_left, axes=axes[0])
    plotting.plot_surf_roi(fsaverage.pial_left, colorbar=False, roi_map=n_textures[1], hemi='left', view='medial',
                        darkness=0.3, cmap=discrete_cmap, bg_on_data=True, bg_map=fsaverage.sulc_left, axes=axes[1])
    plotting.plot_surf_roi(fsaverage.pial_right, colorbar=False, roi_map=n_textures[2], hemi='right', view='medial',
                        darkness=0.3, cmap=discrete_cmap, bg_on_data=True, bg_map=fsaverage.sulc_right, axes=axes[2])
    plotting.plot_surf_roi(fsaverage.pial_right, colorbar=False, roi_map=n_textures[3], hemi='right', view='lateral',
                        darkness=0.3, cmap=discrete_cmap, bg_on_data=True, bg_map=fsaverage.sulc_right, axes=axes[3])

    # #plot maps (disabled overlay of the stat map on top of the atlas)
    # plotting.plot_surf_stat_map(fsaverage.pial_left, colorbar=False, stat_map=stat_img_list[0], bg_map=fsaverage.sulc_left, hemi='left', view='lateral',
    #                         darkness=0.1, alpha=0.005, axes=axes[0], threshold=2.985, vmin=-10, vmax=10)
    # plotting.plot_surf_stat_map(fsaverage.pial_left, colorbar=False, stat_map=stat_img_list[1], bg_map=fsaverage.sulc_left, hemi='left', view='medial',
    #                         darkness=0.1, alpha=0.005, axes=axes[1], threshold=2.985, vmin=-10, vmax=10)
    # plotting.plot_surf_stat_map(fsaverage.pial_right, colorbar=False, stat_map=stat_img_list[2], bg_map=fsaverage.sulc_right, hemi='right', view='medial',
    #                         darkness=0.1, alpha=0.005, axes=axes[2], threshold=2.985, vmin=-10, vmax=10)
    # plotting.plot_surf_stat_map(fsaverage.pial_right, colorbar=False, stat_map=stat_img_list[3], bg_map=fsaverage.sulc_right, hemi='right', view='lateral',
    #                         darkness=0.1, alpha=0.005, axes=axes[3], threshold=2.985, vmin=-10, vmax=10)

    figures.subplots_adjust(wspace=0.01,hspace=0.00)
    figures.savefig(os.path.join(PM_path,"nilearn_plots", ("17networks_"+model_name+'.png')), bbox_inches='tight')
    plt.show() # plt.close('all')
    break  # atlas-only figure is identical for every model; one copy is enough

In [None]:
# - - - - - - - - -
# bar plots: proportion of significant (positive) voxels falling in each of the 17 networks
# - - - - - - - - -
# Network order matches custom_color_list (alphabetically sorted network names).
network_names = ['ContA', 'ContB', 'ContC',
                 'DefaultA', 'DefaultB', 'DefaultC',
                 'DorsAttnA', 'DorsAttnB',
                 'LimbicA', 'LimbicB',
                 'SalVentAttnA', 'SalVentAttnB',
                 'SomMotA', 'SomMotB',
                 'TempPar',
                 'VisCent', 'VisPeri']
# Per-mask figure size (width, height in inches) and y-limits; unlisted masks use the .get defaults below.
bar_figsize = {"feedback_zTaskPE_masked.BRIK": (5.5, 3.25),
               "feedback_zStatePE_masked.BRIK": (5.5, 3.25),
               "cue_stateD_masked.BRIK": (5.5, 3.25),
               "feedback_zColorPE_masked.BRIK": (5.5, 9)}
bar_ylim = {"feedback_zTaskPE_masked.BRIK": (0, 0.3),
            "feedback_zStatePE_masked.BRIK": (0, 0.3),
            "feedback_zColorPE_masked.BRIK": (0, 0.99),
            "cue_stateD_masked.BRIK": (0, 0.3)}

# loop through each sig roi mask
for m_idx, c_mask in enumerate(mask_list):
    sns.set(rc={'figure.figsize': bar_figsize.get(c_mask, (5.5, 3.75))})
    print("working with ", c_mask)
    # -- voxel counts per network, later turned into proportions
    results_dict = dict.fromkeys(network_names, 0)
    total_voxel_count = 0
    # -- load current stat image sig clusters mask
    stat_img0 = nib.load(os.path.join(PM_path, "nii3D", model_list[m_idx]))
    brik_mask = nib.load(os.path.join(PM_path, "nii3D", c_mask)) # load sig rois for state or color
    brik_data = brik_mask.get_fdata()
    print("number of rois in current mask: ", int(brik_data.max()))
    for roi in range(int(brik_data.max())):
        # -- set sig roi cluster mask (cluster labels are 1..max)
        sig_clust_data = np.where(np.squeeze(brik_data)==(roi+1),1,0) # binarize current roi
        sig_clust_data = np.where(np.squeeze(cortical_mask_data)>0,sig_clust_data,0) # apply cortical mask too
        sig_clust_img = nib.Nifti1Image(sig_clust_data, brik_mask.affine, brik_mask.header)

        try:
            # mask stats data to see if cluster is pos or neg
            stat_masked = nilearn.masking.apply_mask(stat_img0, sig_clust_img)
            if stat_masked.mean() < 0:
                continue # skip this roi if it is negative
            # -- mask atlas data so we can see which networks this cluster overlaps with
            atlas_masked = nilearn.masking.apply_mask(resampled_atlas_img, sig_clust_img)
        except Exception as err:  # e.g. nilearn rejects a mask that is empty after cortical masking
            print("error probably due to no data after masking.. just move on and skip this roi", err)
            continue
        atlas_masked = atlas_masked[atlas_masked>0] # only use voxels with atlas labels so it all sums to 1

        # -- atlas label i belongs to networks[i-1]; count one voxel per overlap
        for label in atlas_masked.astype(int):
            results_dict[networks[label - 1]] += 1
            total_voxel_count += 1

    if total_voxel_count == 0:
        # Nothing to normalise by (no positive clusters overlapping the atlas); avoid ZeroDivisionError.
        print("no positive atlas-labelled voxels for", c_mask, "- skipping bar plot")
        continue
    results_dict = {net_key: count / total_voxel_count for net_key, count in results_dict.items()}

    results_df = pd.DataFrame(results_dict, index=[0])
    melt_df = pd.melt(results_df)
    sns.barplot(x='variable', y='value', data=melt_df, palette=custom_color_list, edgecolor='black', linewidth=1.5)
    plt.title(c_mask[:-11] + "\n")
    plt.xlabel("Network")
    plt.xticks(rotation=45, ha='right')
    plt.ylabel("Proportion of sig. voxels\noverlapping with network")
    plt.ylim(bar_ylim.get(c_mask, (0, 0.45)))

    plt.savefig(os.path.join(PM_path,"nilearn_plots",("17networks__bar_plot__"+c_mask[:-11]+".png")), bbox_inches='tight')

    plt.show()

In [None]:
vox_thresh = 2.026
# Output suffix of each axial (display_mode='z') slice montage, in cut_cords_list order.
axial_suffixes = ['__Inferior_Ax-1.png', '__Inferior_Ax-2.png', '__Superior_Ax-1.png', '__Superior_Ax-2.png']
# (hemisphere, view) of the four surface panels, left to right.
surface_views = [('left', 'lateral'), ('left', 'medial'), ('right', 'medial'), ('right', 'lateral')]

for m_idx, c_model in enumerate(model_list):
    stem = c_model[:-34]  # model name without the "_SPMGmodel_stats_REML__tval.nii.gz" suffix
    color_range = {'vmin': vmin_list[m_idx], 'vmax': vmax_list[m_idx]}

    # -- significance mask: any non-zero cluster label, restricted to cortex
    brik_mask = nib.load(os.path.join(PM_path, "nii3D", mask_list[m_idx]))
    brik_data = brik_mask.get_fdata()
    in_cluster = np.where(np.squeeze(brik_data) > 0, 1, 0)
    roi_data = np.where(np.squeeze(cortical_mask_data) > 0, in_cluster, 0)
    mask_img = nib.Nifti1Image(roi_data, brik_mask.affine, brik_mask.header)

    # -- keep only the t-values inside the significance mask
    stat_path = os.path.join(PM_path, "nii3D", c_model)
    stat_img0 = nib.load(stat_path)
    stat_img = masking.unmask(nilearn.masking.apply_mask(stat_img0, mask_img), mask_img)
    print("cur stat: ", stat_path)

    cut_cords_list = [[-16, -11, -9, -2], [1, 6, 17], [20, 23, 32], [42, 52, 64]]

    # ------------------ axial slice montages ------------------ #
    for suffix, cut_coords in zip(axial_suffixes, cut_cords_list):
        axial_display = plotting.plot_stat_map(stat_map_img=stat_img, threshold=vox_thresh, cmap=plt.cm.RdBu_r,
                                               display_mode='z', cut_coords=cut_coords, **color_range)
        axial_display.savefig(os.path.join(PM_path, "nilearn_plots", stem + suffix))
        plt.show()

    # ------------------ surface plots ------------------ #
    stat_img_list = [surface.vol_to_surf(stat_img, fsaverage[f'pial_{hemi}'], inner_mesh=fsaverage[f'white_{hemi}'])
                     for hemi, _ in surface_views]

    figures, axes = plt.subplots(figsize=(25,25), nrows=1, ncols=4, subplot_kw={'projection': '3d'}, sharex=True, sharey=True)
    for ax, texture, (hemi, view) in zip(axes, stat_img_list, surface_views):
        plotting.plot_surf_stat_map(fsaverage[f'pial_{hemi}'], colorbar=False, stat_map=texture, bg_map=fsaverage[f'sulc_{hemi}'],
                                    hemi=hemi, view=view, darkness=0.95, alpha=0.71, axes=ax, threshold=vox_thresh, **color_range)

    figures.subplots_adjust(wspace=0.01, hspace=0.00)
    figures.savefig(os.path.join(PM_path, "nilearn_plots", stem + '.png'), bbox_inches='tight')
    plt.show()

    # NOTE(review): the 17-network overlay for this figure is disabled, so this re-saves the same
    # surface figure under a "_plus17networks" name.
    figures.subplots_adjust(wspace=0.01, hspace=0.00)
    figures.savefig(os.path.join(PM_path, "nilearn_plots", stem + '_plus17networks.png'), bbox_inches='tight')
    plt.show()

Section 3

# -----------------------------------------------------------------------------------------------------

# - - - - - - - - - - - - - - - - - - -
# - - -   Probabilistic Decoding  - - -
# - - - - - - - - - - - - - - - - - - -

In [None]:
# Decoding group statistics; PM_path is reused as the "current results folder" by the cells below.
PM_path = os.path.join(dataset_dir, "Decoding", "GroupStats")

# ---- set up plot variables
model_list = ["jointP", "color", "state", "task"]
vmin_list = [-6] * len(model_list)
vmax_list = [6] * len(model_list)

# Network palette plus one extra colour (gold) for an 18th category. The commented-out overlay code
# marks significant vertices with the value 18.
custom_color_list_stats = [*custom_color_list, 'gold']
discrete_cmap_stats = mcolors.ListedColormap(custom_color_list_stats)

In [None]:
vox_thresh = 2.026
# Output suffix of each axial (display_mode='z') slice montage, in cut_cords_list order.
axial_suffixes = ['__Inferior_Ax-1.png', '__Inferior_Ax-2.png', '__Superior_Ax-1.png', '__Superior_Ax-2.png']
# (hemisphere, view) of the four surface panels, left to right.
surface_views = [('left', 'lateral'), ('left', 'medial'), ('right', 'medial'), ('right', 'lateral')]

for epoch in ["cue","probe"]:
    for m_idx, c_model in enumerate(model_list):
        tag = c_model + "_" + epoch
        color_range = {'vmin': vmin_list[m_idx], 'vmax': vmax_list[m_idx]}

        # -- significance mask restricted to cortex. jointP and task keep every cluster label (>0);
        #    color and state additionally drop cluster label 1 (>1).
        brik_mask = nib.load(os.path.join(PM_path, (epoch+"_"+c_model+"_masked.BRIK")))
        brik_data = brik_mask.get_fdata()
        label_floor = 0 if c_model in ("jointP", "task") else 1
        in_cluster = np.where(np.squeeze(brik_data) > label_floor, 1, 0)
        roi_data = np.where(np.squeeze(cortical_mask_data) > 0, in_cluster, 0)
        mask_img = nib.Nifti1Image(roi_data, brik_mask.affine, brik_mask.header)

        # -- keep only the t-values inside the significance mask
        stat_img0 = nib.load(os.path.join(PM_path, ("GroupAnalysis_38subjs__"+c_model+"_"+epoch+"__r__tval.nii")))
        stat_img = masking.unmask(nilearn.masking.apply_mask(stat_img0, mask_img), mask_img)
        print("cur stat: ", c_model)

        cut_cords_list = [[-16, -11, -9, -2], [1, 6, 17], [20, 23, 32], [42, 52, 64]]

        # ------------------ axial slice montages ------------------ #
        for suffix, cut_coords in zip(axial_suffixes, cut_cords_list):
            axial_display = plotting.plot_stat_map(stat_map_img=stat_img, threshold=vox_thresh, cmap=plt.cm.RdBu_r,
                                                   display_mode='z', cut_coords=cut_coords, **color_range)
            axial_display.savefig(os.path.join(PM_path, "nilearn_plots", tag + suffix))
            plt.show()

        # ------------------ surface plots ------------------ #
        stat_img_list = [surface.vol_to_surf(stat_img, fsaverage[f'pial_{hemi}'], inner_mesh=fsaverage[f'white_{hemi}'])
                         for hemi, _ in surface_views]

        figures, axes = plt.subplots(figsize=(15,15), nrows=1, ncols=4, subplot_kw={'projection': '3d'}, sharex=True, sharey=True)
        for ax, texture, (hemi, view) in zip(axes, stat_img_list, surface_views):
            plotting.plot_surf_stat_map(fsaverage[f'pial_{hemi}'], colorbar=False, stat_map=texture, bg_map=fsaverage[f'sulc_{hemi}'],
                                        hemi=hemi, view=view, darkness=0.95, alpha=0.71, axes=ax, threshold=vox_thresh, **color_range)

        figures.subplots_adjust(wspace=0.01, hspace=0.00)
        figures.savefig(os.path.join(PM_path, "nilearn_plots", tag + '.png'), bbox_inches='tight')
        plt.show()
        

In [None]:
# - - - - - - - - -
# bar plots: proportion of significant (positive) decoding voxels falling in each of the 17 networks
# - - - - - - - - -
# Network order matches custom_color_list (alphabetically sorted network names).
network_names = ['ContA', 'ContB', 'ContC',
                 'DefaultA', 'DefaultB', 'DefaultC',
                 'DorsAttnA', 'DorsAttnB',
                 'LimbicA', 'LimbicB',
                 'SalVentAttnA', 'SalVentAttnB',
                 'SomMotA', 'SomMotB',
                 'TempPar',
                 'VisCent', 'VisPeri']

# loop through each sig roi mask
for epoch in ["cue","probe"]:
    for m_idx, c_mask in enumerate(model_list):
        # jointP and task get taller panels and a higher y-limit than color/state
        tall_panel = c_mask in ("jointP", "task")
        sns.set(rc={'figure.figsize': (5.5, 3.75) if tall_panel else (5.5, 2.5)})
        print("working with ", c_mask)
        # -- voxel counts per network, later turned into proportions
        results_dict = dict.fromkeys(network_names, 0)
        total_voxel_count = 0

        # -- load current stat image sig clusters mask
        stat_img0 = nib.load(os.path.join(PM_path, ("GroupAnalysis_38subjs__"+c_mask+"_"+epoch+"__r__tval.nii")))
        brik_mask = nib.load(os.path.join(PM_path, (epoch+"_"+c_mask+"_masked.BRIK"))) # load sig rois for state or color
        brik_data = brik_mask.get_fdata()
        for roi in range(int(brik_data.max())):
            # -- set sig roi cluster mask (cluster labels are 1..max)
            sig_clust_data = np.where(np.squeeze(brik_data)==(roi+1),1,0) # binarize current roi
            sig_clust_data = np.where(np.squeeze(cortical_mask_data)>0,sig_clust_data,0) # apply cortical mask too
            sig_clust_img = nib.Nifti1Image(sig_clust_data, brik_mask.affine, brik_mask.header)

            try:
                # mask stats data to see if cluster is pos or neg
                stat_masked = nilearn.masking.apply_mask(stat_img0, sig_clust_img)
                if stat_masked.mean() < 0:
                    continue # skip this roi if it is negative
                atlas_masked = nilearn.masking.apply_mask(resampled_atlas_img, sig_clust_img)
            except Exception as err:  # e.g. nilearn rejects a mask that is empty after cortical masking
                print("error probably due to no data after masking.. just move on and skip this roi", err)
                continue
            atlas_masked = atlas_masked[atlas_masked>0] # only use voxels with atlas labels so it all sums to 1

            # -- atlas label i belongs to networks[i-1]; count one voxel per overlap
            for label in atlas_masked.astype(int):
                results_dict[networks[label - 1]] += 1
                total_voxel_count += 1

        if total_voxel_count == 0:
            # Nothing to normalise by (no positive clusters overlapping the atlas); avoid ZeroDivisionError.
            print("no positive atlas-labelled voxels for", epoch, c_mask, "- skipping bar plot")
            continue
        results_dict = {net_key: count / total_voxel_count for net_key, count in results_dict.items()}

        results_df = pd.DataFrame(results_dict, index=[0])
        melt_df = pd.melt(results_df)
        sns.barplot(x='variable', y='value', data=melt_df, palette=custom_color_list, edgecolor='black', linewidth=1.5)
        plt.title((epoch + " " + c_mask) + "\n")
        plt.xlabel("Network")
        plt.xticks(rotation=45, ha='right')
        plt.ylabel("Proportion of sig. voxels\noverlapping with network")
        plt.ylim((0, 0.45) if tall_panel else (0, 0.25))

        plt.savefig(os.path.join(PM_path, "nilearn_plots", ("17networks__bar_plot__"+c_mask+"_"+epoch+".png")), bbox_inches='tight')

        plt.show()

Section 4

# -----------------------------------------------------------------------------------------------------

# - - - - - - - - - - - - - - - - - - -
# - - -   Network Hubs  - - -
# - - - - - - - - - - - - - - - - - - -

In [None]:
# Network-hub results folder; PM_path is reused as the "current results folder" by the next cell.
PM_path = os.path.join(dataset_dir, "Hubs")

# ---- set up plot variables: a single voxelwise map, so one colour range at index m_idx = 0
c_model = "Voxelwise_4mm_MGH_PC"
m_idx = 0
vmin_list, vmax_list = [0], [1]

In [None]:
vox_thresh = 0.65

# -- load the voxelwise hub map as-is (no cortical masking is applied here)
#mask_img = cortical_mask
stat_img = nib.load(os.path.join(PM_path, c_model + ".nii"))
print("cur stat: ", c_model)

cut_cords_list = [[-16, -11, -9, -2], [1, 6, 17], [20, 23, 32], [42, 52, 64]]

# ------------------ axial slice montages ------------------ #
# Each slice set is saved directly into PM_path under its own suffix.
for suffix, cut_coords in zip(['__Inferior_Ax-1.png', '__Inferior_Ax-2.png', '__Superior_Ax-1.png', '__Superior_Ax-2.png'],
                              cut_cords_list):
    axial_display = plotting.plot_stat_map(stat_map_img=stat_img, threshold=vox_thresh, cmap=plt.cm.RdBu_r,
                                           display_mode='z', cut_coords=cut_coords,
                                           vmin=vmin_list[m_idx], vmax=vmax_list[m_idx])
    axial_display.savefig(os.path.join(PM_path, c_model + suffix))
    plt.show()


# # ------------------ plot lateral/medial surface views ------------------ #
# stat_img_list = []
# stat_img_list.append(surface.vol_to_surf(stat_img, fsaverage['pial_left'], inner_mesh=fsaverage['white_left']))
# stat_img_list.append(surface.vol_to_surf(stat_img, fsaverage['pial_left'], inner_mesh=fsaverage['white_left']))
# stat_img_list.append(surface.vol_to_surf(stat_img, fsaverage['pial_right'], inner_mesh=fsaverage['white_right']))
# stat_img_list.append(surface.vol_to_surf(stat_img, fsaverage['pial_right'], inner_mesh=fsaverage['white_right']))

# figures, axes = plt.subplots(figsize=(15,15),nrows=1,ncols=4,subplot_kw={'projection': '3d'},sharex=True,sharey=True)
    
# #plot maps
# plotting.plot_surf_stat_map(fsaverage.pial_left, colorbar=False, stat_map=stat_img_list[0], bg_map=fsaverage.sulc_left, hemi='left', view='lateral',
#                         darkness=0.95, alpha=0.71, axes=axes[0], threshold=vox_thresh, vmin=vmin_list[m_idx], vmax=vmax_list[m_idx])
# plotting.plot_surf_stat_map(fsaverage.pial_left, colorbar=False, stat_map=stat_img_list[1], bg_map=fsaverage.sulc_left, hemi='left', view='medial',
#                         darkness=0.95, alpha=0.71, axes=axes[1], threshold=vox_thresh, vmin=vmin_list[m_idx], vmax=vmax_list[m_idx])
# plotting.plot_surf_stat_map(fsaverage.pial_right, colorbar=False, stat_map=stat_img_list[2], bg_map=fsaverage.sulc_right, hemi='right', view='medial',
#                         darkness=0.95, alpha=0.71, axes=axes[2], threshold=vox_thresh, vmin=vmin_list[m_idx], vmax=vmax_list[m_idx])
# plotting.plot_surf_stat_map(fsaverage.pial_right, colorbar=False, stat_map=stat_img_list[3], bg_map=fsaverage.sulc_right, hemi='right', view='lateral',
#                         darkness=0.95, alpha=0.71, axes=axes[3], threshold=vox_thresh, vmin=vmin_list[m_idx], vmax=vmax_list[m_idx])

# figures.subplots_adjust(wspace=0.01,hspace=0.00)
# figures.savefig(os.path.join(PM_path, (c_model+'.png')), bbox_inches='tight')
# plt.show()
# #plt.close('all')


In [None]:
# - - - - - - - - -
# bar plot: proportion of supra-threshold hub voxels that fall in each of the 17 networks
# - - - - - - - - -
sns.set(rc={'figure.figsize':(5.5, 3.75)}) # width=5.5 inches, height=3.75 inches

# voxels whose hub value exceeds this threshold are counted (same value as vox_thresh for the slice plots)
hub_thresh = 0.65

print("working with ", c_model)
# -- set up dict to count voxels per network for the bar plot (insertion order = bar order on the x-axis)
results_dict = {'ContA':0, 'ContB':0, 'ContC':0,
            'DefaultA':0, 'DefaultB':0, 'DefaultC':0,
            'DorsAttnA':0, 'DorsAttnB':0,
            'LimbicA':0, 'LimbicB':0,
            'SalVentAttnA':0, 'SalVentAttnB':0,
            'SomMotA':0, 'SomMotB':0,
            'TempPar':0,
            'VisCent':0, 'VisPeri':0}
total_voxel_count = 0

# -- binarize the hub map: 1 where the value exceeds hub_thresh, 0 elsewhere (no cortical mask is applied here)
stat_img = nib.load(os.path.join(PM_path, (c_model+".nii")))
stat_data = stat_img.get_fdata()
sig_clust_data = np.where(np.squeeze(stat_data)>hub_thresh,1,0)
sig_clust_img = nib.Nifti1Image(sig_clust_data, stat_img.affine, stat_img.header)

# NOTE(review): `masker` is not used in this cell; kept only in case later cells rely on it
masker = NiftiLabelsMasker(labels_img=atlas_img, standardize=False)
# resample the atlas onto the hub map's voxel grid; nearest-neighbour keeps the integer parcel labels intact.
# resample_img expects a 3-D target shape, so drop any trailing (e.g. singleton 4th) dimension.
atlas_4mm_img = resample_img(atlas_img, target_affine=stat_img.affine, target_shape=stat_img.shape[:3],
                             interpolation='nearest', force_resample=True)

# -- atlas label of every supra-threshold voxel
try:
    atlas_masked = nilearn.masking.apply_mask(atlas_4mm_img, sig_clust_img)
except ValueError as err:
    # nilearn raises ValueError for an invalid mask, e.g. when no voxel passed hub_thresh (empty mask).
    # Fall back to "no voxels" instead of silently reusing a stale `atlas_masked` from an earlier cell.
    print("error probably due to no data after masking.. nothing to plot for", c_model, "->", err)
    atlas_masked = np.array([])
atlas_masked = atlas_masked[atlas_masked>0] # only use voxels with atlas labels so it all sums to 1

# -- okay pull networks that match rois with data
# presumably `networks[k]` is the network name of atlas label k+1 (labels are 1-based) -- confirm where defined
c_networks_list = [networks[(i-1)] for i in atlas_masked.astype(int)]
for c_net in c_networks_list:
    results_dict[c_net] += 1 # add 1 to show 1 voxel overlapped with this sig cluster
total_voxel_count = len(c_networks_list)

if total_voxel_count == 0:
    # nothing left to count: skip the plot rather than dividing by zero
    print("no supra-threshold voxels with atlas labels for", c_model, "- skipping bar plot")
else:
    # convert voxel counts into proportions of all labelled supra-threshold voxels
    for net_key in results_dict:
        results_dict[net_key] = results_dict[net_key] / total_voxel_count

    results_df = pd.DataFrame(results_dict, index=[0])
    melt_df = pd.melt(results_df)
    plt.figure() # draw on a fresh figure rather than whatever figure happens to be current
    sns.barplot(x='variable', y='value', data=melt_df, palette=custom_color_list, edgecolor='black', linewidth=1.5)
    plt.title((c_model) + "\n")
    plt.xlabel("Network")
    plt.xticks(rotation=45, ha='right')
    plt.ylabel("Proportion of sig. voxels\noverlapping with network")
    plt.ylim((0,0.45))

    plt.savefig(os.path.join(PM_path, ("17networks__bar_plot__"+c_model+".png")), bbox_inches='tight')
    plt.show()