In [1]:
import numpy as np
import scipy.stats as stats
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import os
import random

import matplotlib
import matplotlib.font_manager as fm

# Global matplotlib text settings applied to every figure in this notebook.
matplotlib.rcParams['font.size'] = 8
# fonttype 42 embeds text as TrueType so it stays editable in exported PDF/PS files
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# The previous `fm.FontProperties().set_family('arial')` only changed a temporary
# FontProperties object that was discarded immediately, so the font never changed.
# Setting the rcParam applies the family to all subsequently drawn text.
matplotlib.rcParams['font.family'] = 'arial'

# IPython magics: autoreload mode 2 picks up edits to imported modules without
# restarting the kernel; figures are rendered inline in the cell output.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# alternative interactive backends (disabled):
# %matplotlib widget
# %matplotlib notebook

In [2]:
# Read the session table and collect the session ids that pass every criterion below.
session_table_path=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\session_table_v0.272.parquet"
# alternative local copy (disabled):
# session_table_path="/Users/ethan.mcbride/Data/DR/session_table_v0.265.csv"
session_table=pl.read_parquet(session_table_path)

# All predicates are combined with AND by `filter`.
session_criteria=[
    pl.col('project')=="DynamicRouting",
    pl.col('is_production'),
    pl.col('is_annotated'),
    pl.col('issues')==[],
    pl.col('is_engaged'),
    pl.col('is_good_behavior').eq(True),
]

dr_session_list=session_table.filter(*session_criteria)['session_id'].to_list()

In [None]:
#load specific parquet file
# results_dfs maps a short key (filename minus the common suffix) to the loaded table.
results_dfs={}
loadpath=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\compare-context-lick-stimulus-predict-proba-2025-09-22\separate_aud_vis_stim_trials"

# Exactly one `filename` must be active. Previously every candidate was commented
# out, so `filename` was undefined and this cell raised NameError on a fresh kernel.
# NOTE(review): which table this analysis was meant to use cannot be determined from
# the notebook itself -- confirm the selection below (the cross-correlation cells
# assume 10 ms bins).
filename="context_all_trials_decoding_predict_proba_table.parquet"
# filename="context_vis_stim_25ms_bins_w_repeats_decoding_predict_proba_table.parquet"
# filename="response_all_trials_decoding_predict_proba_table.parquet"
# filename="context_baseline_subtract_vis_stim_10units_decoding_predict_proba_table.parquet"

key=filename.replace("_decoding_predict_proba_table.parquet", "")
results_dfs[key] = pd.read_parquet(os.path.join(loadpath, filename))

In [None]:
# Load the consolidated performance and trials tables (v0.0.268 cache) from S3 into
# memory as polars DataFrames.
all_performance=(
    pl.scan_parquet('s3://aind-scratch-data/dynamic-routing/cache/nwb_components/v0.0.268/consolidated/performance.parquet')
    .collect()
)
all_trials=(
    pl.scan_parquet('s3://aind-scratch-data/dynamic-routing/cache/nwb_components/v0.0.268/consolidated/trials.parquet')
    .collect()
)

In [None]:
# Pick one session / structure / alignment and pull its rows, ordered by time bin.
sel_session='742903_2024-10-22'
sel_structure='ACAd'
sel_time_aligned_to='stim_start_time'

# `@name` inside the query string refers to the notebook variables defined above
session_query='session_id==@sel_session and structure==@sel_structure and time_aligned_to==@sel_time_aligned_to'
session_results_df=(
    results_dfs[key]
    .query(session_query)
    .sort_values(['structure','bin_center'])
)

In [None]:
# Select vis1 trials with a response while vis is rewarded, then remove the
# across-trial mean from each time bin to obtain per-trial residuals.
sel_stim_name='vis1'
sel_is_vis_rewarded=True
sel_is_response=True

# The per-trial label arrays are read from the first row only
# (presumably identical in every bin row -- verify against the table schema).
trial_index,stim_name,is_vis_rewarded,is_response=(
    session_results_df[column].iloc[0]
    for column in ('trial_index','stim_name','is_vis_rewarded','is_response')
)

matches_stim=(stim_name==sel_stim_name)
matches_context=(is_vis_rewarded==sel_is_vis_rewarded)
matches_response=(is_response==sel_is_response)
sel_trials=matches_stim & matches_context & matches_response

sel_trial_indices=trial_index[sel_trials]

# one row per dataframe row (time bin) before the transpose; afterwards axis 0 is
# indexed by the trial mask and axis 1 follows the bin_center ordering
full_predict_proba=np.vstack(session_results_df['predict_proba'].values).T
predict_proba_stack=full_predict_proba[sel_trials,:]

# mean over the selected trials for each bin, ignoring NaNs
mean_predict_proba=np.nanmean(predict_proba_stack,axis=0)
residual_predict_proba=predict_proba_stack - mean_predict_proba

In [None]:
# Heatmaps of the selected trials (rows) by time bin (columns): raw decoder
# probabilities first, then the residuals after removing the across-trial mean.
# Colorbars and axis labels are included so each figure can be read on its own.
fig,ax=plt.subplots(1,1,figsize=(6,4))
im=ax.imshow(predict_proba_stack,aspect='auto',cmap='bwr',vmin=0.2,vmax=0.8)
ax.set_xlabel('time bin index')
ax.set_ylabel('selected trial')
ax.set_title('raw predict proba')
fig.colorbar(im, ax=ax, label='Predict proba')
fig.tight_layout()

fig,ax=plt.subplots(1,1,figsize=(6,4))
im=ax.imshow(residual_predict_proba,aspect='auto',cmap='bwr',vmin=-0.3,vmax=0.3)
ax.set_xlabel('time bin index')
ax.set_ylabel('selected trial')
ax.set_title('residual predict proba')
fig.colorbar(im, ax=ax, label='Residual predict proba')
fig.tight_layout()

In [None]:
# For every structure recorded in the selected session, collect the flattened raw
# and mean-subtracted (residual) decoder probabilities of the selected trials.
sel_session='742903_2024-10-22'
# sel_struct_1='ACAd'
# sel_struct_2='MOs'
sel_time_aligned_to='stim_start_time'

example_plot_area='ACAd'

# Trial-selection criteria are identical for every structure, so set them once.
sel_stim_name='vis1'
sel_is_vis_rewarded=True
sel_is_response=True

residual_predict_proba_by_structure={
    'structure':[],
    'flattened_residual_predict_proba':[],
    'raw_predict_proba':[]
}

# Filter for the session once instead of re-scanning the full table per structure.
session_all_df=results_dfs[key].query('session_id==@sel_session')

for sel_structure in session_all_df['structure'].unique():
    session_results_df=session_all_df.query(
        'structure==@sel_structure and time_aligned_to==@sel_time_aligned_to'
        ).sort_values(['structure','bin_center'])

    # The structure list is not restricted to this alignment, so a structure may
    # have no matching rows; skip it instead of failing on `.iloc[0]` below.
    if session_results_df.empty:
        continue

    # per-trial label arrays are read from the first bin row
    trial_index=session_results_df['trial_index'].iloc[0]
    stim_name=session_results_df['stim_name'].iloc[0]
    is_vis_rewarded=session_results_df['is_vis_rewarded'].iloc[0]
    is_response=session_results_df['is_response'].iloc[0]

    sel_trials=(
        (stim_name==sel_stim_name) &
        (is_vis_rewarded==sel_is_vis_rewarded) &
        (is_response==sel_is_response)
    )

    sel_trial_indices=trial_index[sel_trials]

    # after the transpose axis 0 is indexed by the trial mask, axis 1 follows bin_center order
    predict_proba_stack=np.vstack(session_results_df['predict_proba'].values).T
    predict_proba_stack=predict_proba_stack[sel_trials,:]

    # across-trial mean per time bin (NaNs ignored) and the per-trial deviation from it
    mean_predict_proba=np.nanmean(predict_proba_stack,axis=0)

    residual_predict_proba=predict_proba_stack - mean_predict_proba

    # row-major flatten: the selected trials are concatenated one after another
    flattened_residuals=residual_predict_proba.flatten()

    residual_predict_proba_by_structure['structure'].append(sel_structure)
    residual_predict_proba_by_structure['flattened_residual_predict_proba'].append(flattened_residuals)
    residual_predict_proba_by_structure['raw_predict_proba'].append(predict_proba_stack.flatten())

    if sel_structure==example_plot_area:
        # single-trial traces in gray with the across-trial mean on top
        fig,ax=plt.subplots(1,1,figsize=(6,4))
        ax.plot(session_results_df['bin_center'],predict_proba_stack.T, color='lightgray', alpha=0.25)
        ax.plot(session_results_df['bin_center'],mean_predict_proba, color='black', linewidth=2)
        ax.set_xlabel(f'bin center (aligned to {sel_time_aligned_to})')
        ax.set_ylabel('predict proba')
        ax.set_title(f'Predict proba variation\nStructure: {sel_structure}')

In [None]:
# Convert the dict of parallel lists into a table: one row per structure, with the
# flattened residual and raw vectors stored as array-valued cells.
residual_predict_proba_by_structure_df=pd.DataFrame(residual_predict_proba_by_structure)

In [None]:
# Bare last expression: display the per-structure table as the cell output.
residual_predict_proba_by_structure_df

In [None]:
# One heatmap per stored quantity: structures on the y axis, flattened samples on the x axis.
structure_labels=residual_predict_proba_by_structure['structure']

# (dict key, color min, color max, title, colorbar label)
heatmap_specs=(
    ('raw_predict_proba',0,1.0,'Raw predict proba by structure','Predict proba'),
    ('flattened_residual_predict_proba',-0.5,0.5,'Residual predict proba by structure','Residual predict proba'),
)

for value_key,color_min,color_max,title_text,colorbar_label in heatmap_specs:
    fig,ax=plt.subplots(1,1,figsize=(6,4))
    im=ax.imshow(
        np.vstack(residual_predict_proba_by_structure[value_key]),
        aspect='auto',
        cmap='bwr',
        vmin=color_min,
        vmax=color_max,
        interpolation='none',
    )
    ax.set_yticks(range(len(structure_labels)))
    ax.set_yticklabels(structure_labels)
    ax.set_xlabel('time bins')
    ax.set_title(title_text)
    fig.colorbar(im, ax=ax, label=colorbar_label)
    fig.tight_layout()

In [None]:
#get correlation between structures

def _finite_pair_corrcoef(vec_a, vec_b):
    """Pearson correlation of two equal-length 1-D vectors.

    Only samples that are finite in BOTH vectors are used, so a few NaN bins do
    not turn the whole coefficient into NaN. Returns NaN when fewer than two
    such samples exist.
    """
    vec_a=np.asarray(vec_a,dtype=float)
    vec_b=np.asarray(vec_b,dtype=float)
    both_finite=np.isfinite(vec_a) & np.isfinite(vec_b)
    if both_finite.sum()<2:
        return np.nan
    return np.corrcoef(vec_a[both_finite],vec_b[both_finite])[0,1]


# Pull the vectors out of the table once instead of doing a boolean-mask lookup
# for every pair (rows are addressed by position; one row per structure).
residual_vectors=residual_predict_proba_by_structure_df['flattened_residual_predict_proba'].tolist()
raw_vectors=residual_predict_proba_by_structure_df['raw_predict_proba'].tolist()
n_structures=len(residual_predict_proba_by_structure_df)

# The diagonal is intentionally left as NaN (self-correlation is not computed).
corr_matrix=np.full((n_structures,n_structures),np.nan)

corr_matrix_raw=np.full((n_structures,n_structures),np.nan)

for s1 in range(n_structures):
    # correlation is symmetric: compute each unordered pair once and mirror it
    for s2 in range(s1+1,n_structures):
        corr_coef=_finite_pair_corrcoef(residual_vectors[s1],residual_vectors[s2])
        corr_matrix[s1,s2]=corr_coef
        corr_matrix[s2,s1]=corr_coef

        #also compute for raw predict proba
        raw_corr_coef=_finite_pair_corrcoef(raw_vectors[s1],raw_vectors[s2])
        corr_matrix_raw[s1,s2]=raw_corr_coef
        corr_matrix_raw[s2,s1]=raw_corr_coef


In [None]:
#plot heatmaps of correlation matrices
# Three structure-by-structure heatmaps: raw, residual, and |residual| - |raw|.
structure_labels=residual_predict_proba_by_structure['structure']
tick_positions=range(len(structure_labels))

# (matrix to draw, symmetric color limit, figure title)
corr_heatmap_specs=(
    (corr_matrix_raw,0.5,f'Raw predict proba correlation matrix\nSession: {sel_session}'),
    (corr_matrix,0.5,f'Residual predict proba correlation matrix\nSession: {sel_session}'),
    (np.abs(corr_matrix)-np.abs(corr_matrix_raw),0.3,f'Residual minus raw predict proba correlation matrix\nSession: {sel_session}'),
)

for matrix_values,color_limit,figure_title in corr_heatmap_specs:
    fig,ax=plt.subplots(1,1,figsize=(8,7))
    im=ax.imshow(
        matrix_values,
        aspect='auto',
        cmap='bwr',
        vmin=-color_limit,
        vmax=color_limit,
        interpolation='none',
    )
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(structure_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(structure_labels)
    fig.suptitle(figure_title)
    fig.colorbar(im, ax=ax, label='Correlation coefficient')
    fig.tight_layout()


In [None]:
#run cross-correlation on each pair of structures
# Cross-correlate the flattened vectors of two structures and mark the lag with
# the largest correlation inside +/- lag_limit.
sel_structure_1='ORBvl' #structure 1 shifts
sel_structure_2='MOs'
plot_col='flattened_residual_predict_proba'
# plot_col='raw_predict_proba'
lag_limit=25 #seconds
bin_size=0.01 #seconds per sample -- assuming 10ms bins

struct1_values=residual_predict_proba_by_structure_df.query('structure==@sel_structure_1')[plot_col].values[0]

struct2_values=residual_predict_proba_by_structure_df.query('structure==@sel_structure_2')[plot_col].values[0]

corr=np.correlate(struct1_values, struct2_values, mode='same')

# For equal-length inputs the zero-lag term of a mode='same' result sits at index
# n//2. The previous `np.arange(-n//2, n//2)` evaluates `-n//2` as floor(-n/2),
# which shifts every lag by one bin when n is odd; this form is right for both parities.
n_samples=len(struct1_values)
lag_values=(np.arange(n_samples)-n_samples//2)*bin_size

lag_mask=(lag_values>=-lag_limit) & (lag_values<=lag_limit)
corr=corr[lag_mask]
lag_values=lag_values[lag_mask]

max_corr_index=np.argmax(corr)
max_corr_lag=lag_values[max_corr_index]

fig,ax=plt.subplots(1,1,figsize=(6,4))
ax.axvline(0, color='gray', linestyle='--')
ax.plot(lag_values,corr)
ax.plot(max_corr_lag, corr[max_corr_index], 'ro')
ax.set_xlim(-1,1)
ax.set_xlabel(f'lag in {sel_structure_1} relative to {sel_structure_2} (s)')
ax.set_ylabel('cross-correlation (unnormalized)')

ax.set_title(f'Cross-correlation of {plot_col}\n{sel_structure_1} vs {sel_structure_2}: max at {max_corr_lag:.2f}s')


In [None]:
#TODO: create trial-shuffled controls (not implemented yet)

In [None]:
#get lag for all structure pairs

lag_limit=25 #seconds
bin_size=0.01 #seconds per sample -- assuming 10ms bins


def _peak_lag(vec_a, vec_b, bin_size, lag_limit):
    """Lag (in seconds) of the cross-correlation maximum within +/- lag_limit.

    `vec_a` is the shifted signal, matching `np.correlate(vec_a, vec_b)`.
    Returns NaN when the windowed correlation contains NaN, because `argmax`
    would then return the position of a NaN rather than a real peak.
    """
    corr=np.correlate(vec_a, vec_b, mode='same')
    # Zero lag of a mode='same' result for equal-length inputs is at index n//2.
    # `np.arange(-n//2, n//2)` was off by one bin for odd n since `-n//2` floors.
    n_samples=len(vec_a)
    lag_values=(np.arange(n_samples)-n_samples//2)*bin_size
    lag_mask=(lag_values>=-lag_limit) & (lag_values<=lag_limit)
    corr=corr[lag_mask]
    lag_values=lag_values[lag_mask]
    if np.isnan(corr).any():
        return np.nan
    return lag_values[np.argmax(corr)]


# Pull the vectors out of the table once instead of doing a boolean-mask lookup
# for every pair (rows are addressed by position; one row per structure).
residual_vectors=residual_predict_proba_by_structure_df['flattened_residual_predict_proba'].tolist()
raw_vectors=residual_predict_proba_by_structure_df['raw_predict_proba'].tolist()
n_structures=len(residual_predict_proba_by_structure_df)

# The diagonal is intentionally left as NaN.
lag_matrix=np.full((n_structures,n_structures),np.nan)

lag_matrix_raw=np.full((n_structures,n_structures),np.nan)

for s1 in range(n_structures):
    for s2 in range(n_structures):
        if s1==s2:
            continue
        # row structure is the shifted signal, column structure the reference
        lag_matrix[s1,s2]=_peak_lag(residual_vectors[s1],residual_vectors[s2],bin_size,lag_limit)

        #also compute for raw predict proba
        lag_matrix_raw[s1,s2]=_peak_lag(raw_vectors[s1],raw_vectors[s2],bin_size,lag_limit)

In [None]:
#plot heatmaps of lag values
# Structure-by-structure heatmaps of the peak cross-correlation lag, raw then residual.
structure_labels=residual_predict_proba_by_structure['structure']
tick_positions=range(len(structure_labels))

# (matrix to draw, symmetric color limit in seconds, figure title)
lag_heatmap_specs=(
    (lag_matrix_raw,0.5,f'lag between raw predict proba\nSession: {sel_session}'),
    (lag_matrix,10,f'lag between residual predict proba\nSession: {sel_session}'),
)

for matrix_values,color_limit,figure_title in lag_heatmap_specs:
    fig,ax=plt.subplots(1,1,figsize=(8,7))
    im=ax.imshow(
        matrix_values,
        aspect='auto',
        cmap='bwr',
        vmin=-color_limit,
        vmax=color_limit,
        interpolation='none',
    )
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(structure_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(structure_labels)
    fig.suptitle(figure_title)
    fig.colorbar(im, ax=ax, label='Lag (s)')
    fig.tight_layout()
