In [None]:
import numpy as np
import scipy.stats as stats
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import os
import random

import matplotlib
import matplotlib.font_manager as fm

# Global figure styling for every plot in this notebook.
matplotlib.rcParams['font.size'] = 8
# fonttype 42 stores text as TrueType in PDF/PS output (instead of Type 3 outlines)
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Request Arial through rcParams. Calling set_family() on a throwaway
# FontProperties() object does not change any global setting, so it never
# affected the figures. Arial is put first in the sans-serif list so matplotlib
# falls back to the remaining entries when Arial is not installed.
matplotlib.rcParams['font.family'] = 'sans-serif'
matplotlib.rcParams['font.sans-serif'] = ['Arial'] + [
    font_name for font_name in matplotlib.rcParams['font.sans-serif']
    if font_name.lower() != 'arial'
]

%load_ext autoreload
%autoreload 2
%matplotlib inline
# %matplotlib widget
# %matplotlib notebook

In [None]:
session_table_path=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\session_table_v0.268.csv"
# session_table_path="/Users/ethan.mcbride/Data/DR/session_table_v0.265.csv"
session_table=pl.read_csv(session_table_path)

# Session ids that satisfy every inclusion criterion below
# (multiple predicates passed to `filter` are combined with AND).
# NOTE(review): polars `read_csv` parses unquoted empty fields as null by default
# (missing_utf8_is_empty_string=False), and a null never equals "" - confirm the
# `issues` column really holds empty strings for issue-free sessions.
dr_session_list=(
    session_table
    .filter(
        pl.col('project')=="DynamicRouting",
        pl.col('is_production'),
        pl.col('is_annotated'),
        pl.col('issues')=="",
        pl.col('is_engaged'),
        pl.col('is_good_behavior').eq(True),
    )
    .get_column('session_id')
    .to_list()
)

In [None]:
#load results from parquet files
# Every "*_decoding_predict_proba_table.parquet" file in `savepath` is read into
# results_dfs, keyed by the file name with that suffix removed.
savepath = r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\compare-context-lick-stimulus-predict-proba-2025-09-22\separate_aud_vis_stim_trials"
results_dfs = dict()
for filename in os.listdir(savepath):
    if not filename.endswith("_decoding_predict_proba_table.parquet"):
        continue  # not a decoding results table
    key = filename.replace("_decoding_predict_proba_table.parquet", "")
    df = pd.read_parquet(os.path.join(savepath, filename))
    results_dfs[key] = df

In [None]:
# Read the consolidated v0.0.268 performance and trials tables from S3 fully into memory
# (scan_parquet builds a lazy query; collect() materialises it as a polars DataFrame).
all_performance=(
    pl.scan_parquet('s3://aind-scratch-data/dynamic-routing/cache/nwb_components/v0.0.268/consolidated/performance.parquet')
    .collect()
)#.to_pandas()
all_trials=(
    pl.scan_parquet('s3://aind-scratch-data/dynamic-routing/cache/nwb_components/v0.0.268/consolidated/trials.parquet')
    .collect()
)#.to_pandas()

In [None]:
###update the code below to match subsets of trials, i.e. only aud stim trials

In [None]:
#load specific parquet file
# Add this table to results_dfs rather than rebinding results_dfs to a fresh dict:
# resetting it here discarded every table read by the bulk-load cell above, leaving
# only the key loaded below, so a different sel_key in the analysis cell raised KeyError
# on a top-to-bottom run.
if 'results_dfs' not in globals():
    # allows this cell to be run on its own in a fresh kernel
    results_dfs={}
loadpath=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\compare-context-lick-stimulus-predict-proba-2025-09-22"
filename="response_vis_stim_10units_decoding_predict_proba_table.parquet"

# key = file name without the common table suffix
key=filename.replace("_decoding_predict_proba_table.parquet", "")
results_dfs[key] = pd.read_parquet(os.path.join(loadpath, filename))

In [None]:
# Configuration and one-off preprocessing for the per-bin correlation between each
# structure's per-trial predict_proba and the choice labels computed further down.
# sel_structure_1='ORBl'
# sel_structure_2='ACAd'
# sel_session='664851_2023-11-13'

sel_key='context_no_baseline_subtract_vis_stim_10units'
# sel_key='context_500ms_bins'

correction = 'flip_aud' # 'flip_aud', 'subtract_mean', 'none'
excl_instruction_trials = True

# Fail early with an actionable message when the selected table has not been loaded
# (still a KeyError, as the bare dictionary lookup below would raise).
if sel_key not in results_dfs:
    raise KeyError(
        f"sel_key {sel_key!r} has not been loaded into results_dfs; "
        f"available keys: {sorted(results_dfs)}"
    )

# Preprocess data once
# Rounding makes bin_center values safe to use as exact lookup keys later on.
results_dfs[sel_key]['bin_center'] = np.round(results_dfs[sel_key]['bin_center'], 3)
dr_session_set = set(dr_session_list)  # Convert to set for O(1) lookup

# Pre-filter all_trials if needed (convert polars -> pandas a single time)
all_trials_filtered = all_trials.to_pandas()
if excl_instruction_trials:
    all_trials_filtered = all_trials_filtered.query('~is_instruction')

# Pre-compute choice predict proba logic
def compute_choice_predict_proba(df):
    """Label every row of ``df`` with a signed choice value.

    Only rows whose stimulus is paired with the other rewarded modality are
    labelled ('vis1' while 'aud' is rewarded, or 'sound1' while 'vis' is
    rewarded):

    * -1 where ``is_response`` is True  (false alarm)
    * +1 where ``is_response`` is False (correct reject)

    All remaining rows are NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns 'is_response', 'stim_name' and
        'rewarded_modality'.

    Returns
    -------
    numpy.ndarray
        Float array with one entry per row of ``df``, in row order.
    """
    stim = df['stim_name']
    rewarded = df['rewarded_modality']
    # Stimulus belongs to the modality that is not currently rewarded
    nontarget = ((stim == 'vis1') & (rewarded == 'aud')) | ((stim == 'sound1') & (rewarded == 'vis'))

    labels = np.full(len(df), np.nan)
    labels[(df['is_response'] == True) & nontarget] = -1   # false alarms
    labels[(df['is_response'] == False) & nontarget] = 1   # correct rejects
    return labels

# Initialize results storage
results_data = []
# Fixed schema so the output frame keeps its columns even when no pair qualifies
result_columns = ['structure_1', 'structure_2', 'session_id', 'results_key',
                  'bin_center', 'pearson_r', 'p_value']

# Main processing loop: for every included session, correlate (across trials) each
# structure's corrected predict_proba with the choice labels, separately per bin center.
for sel_session in results_dfs[sel_key]['session_id'].unique():
    if sel_session not in dr_session_set:
        continue

    # Get session data once
    session_df = results_dfs[sel_key].query('session_id == @sel_session').copy()
    session_trials = all_trials_filtered.query('session_id == @sel_session').copy()

    if session_trials.empty:
        continue

    # Compute choice predict proba once
    session_trials['choice_predict_proba'] = compute_choice_predict_proba(session_trials)

    # One predict_proba entry per (bin_center, structure) for fast lookup
    session_df_grouped = session_df.groupby(['bin_center', 'structure'])['predict_proba'].first()

    trial_indices = session_trials['trial_index'].values

    #assume all structures have the same trial indices
    predict_proba_trial_indices = np.asarray(session_df['trial_index'].iloc[0])

    trial_intersection = np.intersect1d(trial_indices, predict_proba_trial_indices)

    # Boolean mask over the decoder's trials: True where the trial is also in the trials table
    temp_predict_proba_index = np.isin(predict_proba_trial_indices, trial_intersection)

    session_trials = session_trials[session_trials['trial_index'].isin(trial_intersection)]

    # The per-trial arrays below are paired with session_trials by position, so both
    # sides must list the same trials in the same order.
    kept_trial_indices = predict_proba_trial_indices[temp_predict_proba_index]
    if not np.array_equal(session_trials['trial_index'].to_numpy(), kept_trial_indices):
        print(f'warning! trial order mismatch between trials table and predict_proba, session {sel_session}; skipping session')
        continue

    structure_list = np.concatenate([session_df['structure'].unique(), np.array(['choice'])])
    bin_centers = session_df['bin_center'].unique()

    # Loop-invariant per-trial vectors, as plain numpy arrays for positional indexing
    choice_values = session_trials['choice_predict_proba'].values
    if correction == 'flip_aud':
        aud_mask = (session_trials['rewarded_modality'] == 'aud').to_numpy()
    elif correction == 'subtract_mean':
        block_labels = session_trials['block_index'].to_numpy()

    # Process all structures and bin centers for this session
    session_predict_proba_cols = {}

    for bin_center in bin_centers:
        # Add choice column for each bin center
        session_predict_proba_cols[f'choice_predict_proba_{bin_center}'] = choice_values

        # Process structures for this bin center
        bin_structures = session_df.query('bin_center == @bin_center')['structure'].unique()

        for structure in bin_structures:
            # Only the lookup is guarded, so a missing trials column is not silently skipped
            try:
                temp_predict_proba = session_df_grouped.loc[(bin_center, structure)]
            except KeyError:
                continue

            temp_predict_proba = np.asarray(temp_predict_proba)
            if len(temp_predict_proba) != len(temp_predict_proba_index):
                # this structure breaks the "same trial indices for all structures" assumption
                print(f'warning! predict_proba length mismatch for {structure}, session {sel_session}, bin {bin_center}')
                continue

            temp_predict_proba = temp_predict_proba[temp_predict_proba_index]

            # Vectorized correction computation
            if correction == 'flip_aud':
                # Replace p with 1 - p on trials where 'aud' is the rewarded modality
                corrected_proba = temp_predict_proba.copy()
                corrected_proba[aud_mask] = 1 - corrected_proba[aud_mask]
            elif correction == 'subtract_mean':
                # Subtract each block's mean; write results back in place so trial
                # order is preserved even if a block's trials are not contiguous
                corrected_proba = np.full(len(temp_predict_proba), np.nan)
                for block_idx in pd.unique(block_labels):
                    block_mask = block_labels == block_idx
                    block_proba = temp_predict_proba[block_mask]
                    corrected_proba[block_mask] = block_proba - np.nanmean(block_proba)
            else:  # correction == 'none'
                corrected_proba = temp_predict_proba.copy()

            session_predict_proba_cols[f'{structure}_predict_proba_{bin_center}'] = corrected_proba

    # Add all computed columns to session_trials at once (one concat avoids the
    # fragmentation / SettingWithCopy warnings of inserting columns one by one)
    if session_predict_proba_cols:
        session_trials = pd.concat(
            [session_trials, pd.DataFrame(session_predict_proba_cols, index=session_trials.index)],
            axis=1,
        )

    # Compute correlations for all structure pairs and bin centers
    for bin_center in bin_centers:
        for structure_1 in structure_list:
            for structure_2 in ['choice']:  # structure_list:
                if structure_1 == structure_2:
                    continue

                col1 = f'{structure_1}_predict_proba_{bin_center}'
                col2 = f'{structure_2}_predict_proba_{bin_center}'

                if col1 not in session_trials.columns or col2 not in session_trials.columns:
                    continue

                result1 = session_trials[col1].values
                result2 = session_trials[col2].values

                if len(result1) != len(result2):
                    continue

                # Vectorized NaN handling: keep trials where both values are defined
                valid_mask = ~(np.isnan(result1) | np.isnan(result2))

                if np.sum(valid_mask) < 3:  # Need at least 3 points for correlation
                    continue

                try:
                    r, p = stats.pearsonr(result1[valid_mask], result2[valid_mask])
                except Exception as err:
                    # Best effort: skip this pair but say why, and let
                    # KeyboardInterrupt / SystemExit propagate (a bare except hid both)
                    print(f'warning! pearsonr failed for {structure_1} vs {structure_2}, session {sel_session}, bin {bin_center}: {err}')
                    continue

                results_data.append({
                    'structure_1': structure_1,
                    'structure_2': structure_2,
                    'session_id': sel_session,
                    'results_key': sel_key,
                    'bin_center': bin_center,
                    'pearson_r': r,
                    'p_value': p
                })

# Create DataFrame from results
predict_proba_corr_by_bin_center_df = pd.DataFrame(results_data, columns=result_columns)

In [None]:
# Write the per-session, per-bin correlation table to parquet; the file name embeds sel_key.
savepath=r'\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\predict_proba_corr_across_bin_centers_2025-10-01'
# to_parquet does not create missing folders, so make sure the dated output folder exists
os.makedirs(savepath, exist_ok=True)
predict_proba_corr_by_bin_center_df.to_parquet(os.path.join(savepath, "predict_proba_corr_across_bin_centers_"+sel_key+"_choice_corrected.parquet"))

In [None]:
#compare correlation values across different structure pairs
# Plots Pearson r against bin center for one structure pair: one gray line per
# session plus the across-session mean in black.
structure1='MOs'
structure2='choice'

sel_key='context_no_baseline_subtract_vis_stim_10units'

pair_df=predict_proba_corr_by_bin_center_df.query('structure_1==@structure1 and structure_2==@structure2 and results_key==@sel_key')

if pair_df.empty:
    raise ValueError(f"no correlation results for {structure1} vs {structure2} with results_key {sel_key!r}")

# sessions x bin_center matrix. pivot aligns every session on the same sorted
# bin centers and leaves NaN where a session lacks a bin, so rows and the x axis
# always match (stacking per-session arrays required identical bins in every
# session, and unique() returns bins in order of appearance, not sorted).
corr_table=pair_df.pivot(index='session_id', columns='bin_center', values='pearson_r')
corr_by_session=corr_table.to_numpy()

# NOTE(review): the 0.025 offset presumably shifts bin_center to another point of
# the bin - confirm, since the x label still reads "bin center".
bins=np.round(corr_table.columns.to_numpy(),3)+0.025

fig,ax=plt.subplots(1,1,figsize=(5,3))
ax.axhline(0,color='k',linewidth=0.5,linestyle='--')
ax.axvline(0,color='k',linewidth=0.5,linestyle='--')
ax.plot(bins, corr_by_session.T, color='gray', alpha=0.3)
ax.plot(bins, np.nanmean(corr_by_session,axis=0),'k')
ax.set_xlabel('bin center (s)')
ax.set_ylabel('Pearson r')
ax.set_title(f"{structure1} vs {structure2}; {sel_key}\n(n={corr_by_session.shape[0]} sessions)")

In [None]:
#plot list of areas against single area
# For each named set of structures, draws one figure with the across-session mean
# Pearson r (structure vs sel_structure) per bin center, one line per structure,
# and saves it as a PNG in `savepath`.

# sel_structure='SCm'

# struct_list=['SCm','MRN','CP','MOs','AId','ACAd','FRP','ORBl','PL','ILA','SSp','MOp','VISp','CA1','MD','RT','choice',]
# struct_list=['SCm','MRN','CP','MOs','AId','ACAd','FRP','ORBl','PL','choice',]
# struct_list=['SCm','MRN','CP','MOs','MOp','SSp','AId','ACAd','FRP','ORBl','VTA','GPe','RT','MD','choice',]
# struct_list=['VISp','AUDp','LGd','MGd','MD','RT','VAL']
# struct_list=['choice']
# struct_list=['SCm','MRN','CP','MOs','AId','ACAd','ORBl','VTA','GPe','RT','MD','choice',]


# struct_list=['ACAd','MOs','PL','ORBl','FRP','AId','MOp','SSp','AUDp','VISp'] #ctx
# struct_list=['SCm','SCs','MRN','PAG','APN','CP','GPe','VTA','SNr'] #mb
# struct_list=['MD','RT','VAL','VPL','ZI','VL','POL','LGd','MGd'] #thal

structure_sets={
    'cortex': ['ACAd','MOs','PL','ORBl','FRP','AId','MOp','SSp','AUDp','VISp'],
    'mb_bg': ['SCm','SCs','MRN','PAG','APN','CP','GPe','VTA','SNr'],
    'thalamus': ['MD','RT','VAL','VPL','ZI','VL','POL','VPM','LGd','MGd'],
}


# savepath=r"C:\Users\ethan.mcbride\OneDrive - Allen Institute\quick figures\2025-10-01-decoding_latency"
# NOTE(review): with the line above commented out, `savepath` is whatever an earlier
# cell last assigned - confirm the figures are meant to go there.

color_list=plt.cm.tab10.colors

sel_key='context_no_baseline_subtract_vis_stim_10units'

# Every structure in a set is compared against this one reference
# (previously forced inside a loop that ran a single iteration and then hit `break`).
sel_structure='choice'

for structure_set_name, struct_list in structure_sets.items():

    fig,ax=plt.subplots(1,1,figsize=(5.5,3))
    ax.axhline(0,color='k',linewidth=0.5,linestyle='--')
    ax.axvline(0,color='k',linewidth=0.5,linestyle='--')
    for ss,sel_structure_2 in enumerate(struct_list):
        if sel_structure_2 == sel_structure:
            continue
        pair_df=predict_proba_corr_by_bin_center_df.query('structure_2==@sel_structure and structure_1==@sel_structure_2 and results_key==@sel_key')

        # need more than one session to show a mean across sessions
        if pair_df['session_id'].nunique()<=1:
            continue

        # sessions x bin_center matrix; pivot aligns sessions on sorted bin centers
        # (NaN where a session lacks a bin) so the means and the x axis always match
        corr_table=pair_df.pivot(index='session_id', columns='bin_center', values='pearson_r')
        corr_by_session=corr_table.to_numpy()

        bins=np.round(corr_table.columns.to_numpy(),3)+0.025

        # cycle through the palette; structures past the first pass are drawn dashed
        line_color=color_list[ss % len(color_list)]
        line_style='-' if ss<len(color_list) else '--'

        ax.plot(bins, np.nanmean(corr_by_session,axis=0), label=sel_structure_2+f' ({corr_by_session.shape[0]})', color=line_color, linestyle=line_style)
    ax.set_xlabel('bin center (s)')
    ax.set_ylabel('Pearson r')
    # ax.set_ylim([0.1,0.9])
    # ax.set_xlim([-0.05,0.25])
    ax.set_title(f"{sel_structure} vs other areas; {sel_key}")
    ax.legend(loc='upper right', bbox_to_anchor=(1.35, 1))
    fig.tight_layout()
    fig.savefig(os.path.join(savepath, f"{sel_structure}_vs_other_areas_{sel_key}_{structure_set_name}.png"), dpi=300)
    # close after saving so figures do not accumulate in memory (also means no inline display)
    plt.close(fig)
