### Setup

In [15]:
import pandas as pd
import numpy as np
from pathlib import Path

from one.api import ONE
from iblatlas.atlas import AllenAtlas
from iblatlas.regions import BrainRegions
from ibllib.atlas.plots import plot_swanson_vector
#from brainwidemap.manifold.state_space_bwm import (plot_traj_and_dist,
#                                                   plot_all)

from statsmodels.stats.multitest import multipletests

from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import to_hex
from matplotlib import gridspec
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import dataframe_image as dfi
from matplotlib.gridspec import GridSpec

from PIL import Image


#sns.set(font_scale=1.5)
#sns.set_style('ticks')

# Atlas and brain-region lookup objects used by the helper functions below.
ba = AllenAtlas()
br = BrainRegions()

# Connection to the public openalyx ONE server.
one = ONE(base_url='https://openalyx.internationalbrainlab.org',
          password='international', silent=True)

# Output directories inside the ONE cache (pooled results live here);
# each one is created if it does not exist yet.
meta_pth = Path(one.cache_dir, 'meta')
pth_res = Path(one.cache_dir, 'manifold', 'res')
pth_avg = Path(one.cache_dir, 'manifold', 'avgs')
for _out_dir in (meta_pth, pth_res, pth_avg):
    _out_dir.mkdir(parents=True, exist_ok=True)


# Presumably the significance level for p-value thresholds (not referenced
# within this section of the notebook).
sigl = 0.05

# Names of the analysis time windows.
variables = ['duringstim', 'duringchoice', 'duringfback', 'intertrial']

# Groups of split names that are analysed together, keyed by meta-split name.
# NOTE: a later cell rebinds `meta_splits` to a smaller dict.
meta_splits = { # suffix _0: correct trials; suffix _1: incorrect trials
    'duringstim_all_0': ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1',
                         'durings_srcrbl_slclbl', 'durings_srcrbr_slclbr',
                         'durings_srcrbl_slclbr','durings_slclbl_srcrbr'],
    'duringstim_subset': ['durings_srcrbr_slclbr', 'durings_srcrbl_slclbl',
                         ],
    'duringstim_all': ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1',
                         'durings_srcrbl_slclbl', 'durings_srcrbr_slclbr',
                         'durings_srcrbl_slclbr','durings_slclbl_srcrbr',
                         'block_duringstim_r_choice_l_f2', 'block_duringstim_l_choice_r_f2', 
                         'durings_slcrbl_srclbl','durings_slcrbr_srclbr',
                         'durings_slcrbr_srclbl', 'durings_slcrbl_srclbr'],
    'duringstim_all_1': ['block_duringstim_r_choice_l_f2', 'block_duringstim_l_choice_r_f2', 
                         'durings_slcrbl_srclbl','durings_slcrbr_srclbr',
                         'durings_slcrbr_srclbl', 'durings_slcrbl_srclbr'],

    #'duringstim_old': ['block_stim_l_all', 'block_stim_r_all', 
    #               'stim_block_l', 'stim_block_r',
    #               'block_concordant', 'block_discordant', 'concordant'],
    'duringstim': ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1',
                   'block_duringstim_l_choice_r_f2', 'block_duringstim_r_choice_l_f2'],
    'act_duringstim': ['act_block_duringstim_r_choice_r_f1', 'act_block_duringstim_l_choice_l_f1',
                   'act_block_duringstim_l_choice_r_f2', 'act_block_duringstim_r_choice_l_f2'],
    'duringchoice': ['block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1',
                     'block_stim_l_duringchoice_r_f2', 'block_stim_r_duringchoice_l_f2'],
    'act_duringchoice': ['act_block_stim_r_duringchoice_r_f1', 'act_block_stim_l_duringchoice_l_f1',
                     'act_block_stim_l_duringchoice_r_f2', 'act_block_stim_r_duringchoice_l_f2'],
    #'duringchoice_old': ['block_choice_l', 'block_choice_r',
    #                 'choice_block_l', 'choice_block_r',
    #                 'block_concordant_duringchoice', 'block_discordant_duringchoice',
    #                'concordant_duringchoice'],
    'duringchoice_all_0': ['block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1',
                         'duringc_srcrbl_slclbl', 'duringc_srcrbr_slclbr',
                         'duringc_srcrbl_slclbr','duringc_slclbl_srcrbr'],
    'duringchoice_subset': ['duringc_srcrbl_slclbl', 'duringc_srcrbr_slclbr',
                         ],
    'duringchoice_all': ['block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1',
                         'duringc_srcrbl_slclbl', 'duringc_srcrbr_slclbr',
                         'duringc_srcrbl_slclbr','duringc_slclbl_srcrbr',
                         'block_stim_r_duringchoice_l_f2', 'block_stim_l_duringchoice_r_f2', 
                         'duringc_slcrbl_srclbl','duringc_slcrbr_srclbr',
                         'duringc_slcrbr_srclbl', 'duringc_slcrbl_srclbr'],
    'duringchoice_all_1': ['block_stim_r_duringchoice_l_f2', 'block_stim_l_duringchoice_r_f2', 
                         'duringc_slcrbl_srclbl','duringc_slcrbr_srclbr',
                         'duringc_slcrbr_srclbl', 'duringc_slcrbl_srclbr'],
    'intertrial': ['block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1',
                  'block_stim_l_choice_r_f2', 'block_stim_r_choice_l_f2'],
    'act_intertrial': ['act_block_stim_r_choice_r_f1', 'act_block_stim_l_choice_l_f1',
                  'act_block_stim_l_choice_r_f2', 'act_block_stim_r_choice_l_f2'],
    'intertrial_all_0': ['block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1',
                         'srcrbl_slclbl', 'srcrbr_slclbr',
                         'srcrbl_slclbr','slclbl_srcrbr'],
    'intertrial_all': ['block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1',
                         'srcrbl_slclbl', 'srcrbr_slclbr',
                         'srcrbl_slclbr','slclbl_srcrbr',
                         'block_stim_r_choice_l_f2', 'block_stim_l_choice_r_f2', 
                         'slcrbl_srclbl','slcrbr_srclbr',
                         'slcrbr_srclbl', 'slcrbl_srclbr'],
    'intertrial_all_1': ['block_stim_r_choice_l_f2', 'block_stim_l_choice_r_f2', 
                         'slcrbl_srclbl','slcrbr_srclbr',
                         'slcrbr_srclbl', 'slcrbl_srclbr'],
    'sc_duringstim': ['stim', 'choice_duringstim'], # sc splits: always stim first, choice last
    'sc_duringchoice': ['stim_duringchoice', 'choice'],
    'sc_duringstim1': ['stim_choice_l', 'stim_choice_r', 
                       'choice_duringstim_l', 'choice_duringstim_r'],
    'sc_duringchoice1': ['stim_duringchoice_l', 'stim_duringchoice_r', 
                         'choice_stim_l', 'choice_stim_r'],
    'sc_duringfback': ['stim_duringfback', 'choice_duringfback'],
    'sc_intertrial': ['stim_intertrial', 'choice_intertrial']
}


# [pre_time, post_time] analysis window per split / meta-split name.
def _window_from_name(name, fallback=None):
    '''
    Return a fresh [pre_time, post_time] list for the split name `name`.

    Rules, checked in order: 'durings' -> [0, 0.15],
    'duringc' -> [-0.2, 0.35], 'intertrial' -> [0.4, -0.1].
    If none matches, return a copy of `fallback` (None when not given).
    '''
    if 'durings' in name:
        return [0,0.15]
    if 'duringc' in name:
        return [-0.2,0.35]
    if 'intertrial' in name:
        return [0.4,-0.1]
    return None if fallback is None else list(fallback)


pre_post = {}

# Meta-splits: a name matching none of the markers gets no entry.
for meta_split in meta_splits:
    window = _window_from_name(meta_split)
    if window is not None:
        pre_post[meta_split] = window

# Individual block splits: anything not stim- or choice-aligned uses the
# intertrial window.
for split in ('block_duringstim_l_choice_l_f1', 'block_duringstim_r_choice_r_f1',
              'block_stim_l_duringchoice_l_f1', 'block_stim_r_duringchoice_r_f1',
              'block_stim_l_choice_l_f1', 'block_stim_r_choice_r_f1',
              'block_duringstim_l_choice_r_f2', 'block_duringstim_r_choice_l_f2',
              'block_stim_l_duringchoice_r_f2', 'block_stim_r_duringchoice_l_f2',
              'block_stim_l_choice_r_f2', 'block_stim_r_choice_l_f2'):
    pre_post[split] = _window_from_name(split, fallback=[0.4,-0.1])

# Longer during-choice window variants.
for split in ('block_stim_l_duringchoice_l_f1_long',
              'block_stim_r_duringchoice_r_f1_long'):
    pre_post[split] = [-0.2,0.45]


def swanson_to_beryl_hex(beryl_acronym,br):
    '''
    Return the '#rrggbb' colour that `br` stores for a Beryl acronym.

    `br` is a BrainRegions-like object exposing `.id`, `.acronym` and
    `.get(ids=...)`; the first matching region's 'rgb' entry is used.
    '''
    matching_ids = br.id[br.acronym == beryl_acronym]
    region_rgb = br.get(ids=matching_ids)['rgb'][0].astype(int)
    red, green, blue = region_rgb[:3]
    return '#' + rgb_to_hex((red, green, blue))

def beryl_to_cosmos(beryl_acronym,br):
    '''
    Map a Beryl-level acronym to the acronym of its Cosmos-level region.

    `br` is a BrainRegions-like object; ids are translated with
    `br.remap(..., source_map='Beryl', target_map='Cosmos')` and the first
    resulting acronym is returned.
    '''
    beryl_ids = br.id[br.acronym == beryl_acronym]
    cosmos_ids = br.remap(beryl_ids, source_map='Beryl',
                          target_map='Cosmos')
    return br.get(ids=cosmos_ids)['acronym'][0]

def rgb_to_hex(rgb):
    '''
    Format an (r, g, b) triple of 0-255 integers as 'rrggbb' (no leading '#').

    Any length-3 sequence of integers is accepted (tuple, list, numpy array).
    Plain %-formatting with a non-tuple right operand raised TypeError, so the
    value is converted to a tuple first.

    >>> rgb_to_hex((255, 0, 16))
    'ff0010'
    '''
    return '%02x%02x%02x' % tuple(rgb)


def get_name(brainregion):
    '''
    Return the full region name for the acronym `brainregion`.

    Uses the module-level `br` (BrainRegions). The first region whose acronym
    matches is used; an unknown acronym raises IndexError.
    '''
    acronym_hits = np.argwhere(br.acronym == brainregion)
    region_id = br.id[acronym_hits][0, 0]
    id_hits = np.argwhere(br.id == region_id)
    return br.name[id_hits[0, 0]]


def get_cmap_(meta_split):
    '''
    Build the colormap used for `meta_split` (colours defined by Yanliang,
    updated by Chris).

    The palette comes from the first substring rule that matches:
    - 'int_mov' or 'move_shape': a 256-entry '#ffffb3' -> '#d55607' ramp whose
      lowest entry is replaced by blue '#57C1EB';
    - 'stim_': greens;
    - 'choice_': yellows to reds;
    - 'sc': blue followed by yellows/oranges;
    - anything else: purples.
    '''
    if 'int_mov' in meta_split or 'move_shape' in meta_split:
        ramp = LinearSegmentedColormap.from_list(
            "orange_yellow",
            ['#ffffb3', '#ffed6f', '#feda7e', '#feb23f', '#d55607'])
        palette = ramp(np.linspace(0, 1, 256))
        # lowest value drawn in blue (#57C1EB)
        palette[0] = [0x57 / 255, 0xC1 / 255, 0xEB / 255, 1.0]
    elif 'stim_' in meta_split:
        palette = ["#EAF4B3", "#D5E1A0", "#A3C968",
                   "#86AF40", "#517146", "#33492E"]
    elif 'choice_' in meta_split:
        palette = ["#F8E4AA", "#F9D766", "#E8AC22",
                   "#DA4727", "#96371D"]
    elif 'sc' in meta_split:
        palette = ['#57C1EB', '#ffffb3', '#ffed6f',
                   '#feda7e', '#feb23f']
    else:
        palette = ["#D0CDE4", "#998DC3", "#6159A6",
                   "#42328E", "#262054"]

    return LinearSegmentedColormap.from_list("mycmap", palette)


In [None]:
# Named groups of split names; each list is one set of splits run together.
# NOTE(review): not referenced elsewhere in this section. Entries prefixed
# 'act_' or suffixed '_act' appear to be variants of the plain splits (confirm
# what "act" denotes).
run_align = {
    'intertrial': ['block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1', 
                   'block_stim_l_choice_r_f2', 'block_stim_r_choice_l_f2'
                   ],
    'intertrial0': ['block_only'],
    'block_duringstim': ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1', 
                     'block_duringstim_l_choice_r_f2', 'block_duringstim_r_choice_l_f2'
                     ],
    'block_duringchoice': ['block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1', 
                            'block_stim_l_duringchoice_r_f2', 'block_stim_r_duringchoice_l_f2'
                            ],
    'intertrial1': ['block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1', 
                   ],
    'block_duringstim1': ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1', 
                     ],
    'block_duringchoice1': ['block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1', 
                            ],
    'act_intertrial': ['act_block_stim_r_choice_r_f1', 'act_block_stim_l_choice_l_f1', 
                   'act_block_stim_l_choice_r_f2', 'act_block_stim_r_choice_l_f2'
                   ],
    'act_intertrial0': ['act_block_only'],
    'act_block_duringstim': ['act_block_duringstim_r_choice_r_f1', 'act_block_duringstim_l_choice_l_f1', 
                     'act_block_duringstim_l_choice_r_f2', 'act_block_duringstim_r_choice_l_f2'
                     ],
    'act_block_duringchoice': ['act_block_stim_r_duringchoice_r_f1', 'act_block_stim_l_duringchoice_l_f1', 
                            'act_block_stim_l_duringchoice_r_f2', 'act_block_stim_r_duringchoice_l_f2'
                            ],
    'stim_duringstim0': ['stim_choice_r_block_r', 'stim_choice_l_block_l', 
             'stim_choice_r_block_l', 'stim_choice_l_block_r'],
    'choice_duringchoice0': ['choice_stim_r_block_r', 'choice_stim_l_block_l', 
               'choice_stim_r_block_l', 'choice_stim_l_block_r'],
    'stim_duringchoice0': ['stim_duringchoice_r_block_r', 
                          'stim_duringchoice_l_block_l', 
                          'stim_duringchoice_r_block_l', 
                          'stim_duringchoice_l_block_r'],
    'choice_duringstim0': ['choice_duringstim_r_block_r', 
                          'choice_duringstim_l_block_l', 
                          'choice_duringstim_r_block_l', 
                          'choice_duringstim_l_block_r'],
    'act_stim_duringstim0': ['stim_choice_r_block_r_act', 'stim_choice_l_block_l_act', 
             'stim_choice_r_block_l_act', 'stim_choice_l_block_r_act'],
    'act_choice_duringchoice0': ['choice_stim_r_block_r_act', 'choice_stim_l_block_l_act', 
               'choice_stim_r_block_l_act', 'choice_stim_l_block_r_act'],
    'act_stim_duringchoice0': ['stim_duringchoice_r_block_r_act', 
                          'stim_duringchoice_l_block_l_act', 
                          'stim_duringchoice_r_block_l_act', 
                          'stim_duringchoice_l_block_r_act'],
    'act_choice_duringstim0': ['choice_duringstim_r_block_r_act', 
                          'choice_duringstim_l_block_l_act', 
                          'choice_duringstim_r_block_l_act', 
                          'choice_duringstim_l_block_r_act'],
    'stim_duringstim': ['stim_choice_l', 'stim_choice_r'], 
    'choice_duringchoice': ['choice_stim_l', 'choice_stim_r'],
    'choice_duringstim': ['choice_duringstim_l', 'choice_duringstim_r'],
    'stim_duringchoice': ['stim_duringchoice_l', 'stim_duringchoice_r'],
    'stim_duringstim1': ['stim_block_l', 'stim_block_r'],
    'act_stim_duringstim1': ['stim_block_l_act', 'stim_block_r_act'],
    # 'stim_duringstim1': ['stim_choice_r_block_r_short', 'stim_choice_l_block_l_short', 
    #                      'stim_choice_r_block_l_short', 'stim_choice_l_block_r_short'], 
}


In [None]:
# Base split groups; each is expanded below with a numeric suffix.
# NOTE(review): this cell rebinds `meta_splits` and `pre_post`, replacing the
# versions built in the setup cell. Afterwards `pre_post` only has the
# suffixed '<meta_split>_<value>' keys.
meta_splits = {
    'duringstim': ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1',
                   'block_duringstim_l_choice_r_f2', 'block_duringstim_r_choice_l_f2'],
    'duringchoice': ['block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1',
                     'block_stim_l_duringchoice_r_f2', 'block_stim_r_duringchoice_l_f2'],
    'intertrial': ['block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1',
                   'block_stim_l_choice_r_f2', 'block_stim_r_choice_l_f2'],
}

values = [1.0, 0.25, 0.125, 0.0625, 0.0]

# '<meta_split>_<value>' -> ['<split>_<value>', ...], meta-split-major order.
expanded_meta_splits = {
    f"{key}_{value}": [f"{entry}_{value}" for entry in entries]
    for key, entries in meta_splits.items()
    for value in values
}

# [pre_time, post_time] per expanded meta-split; first matching marker wins,
# and names matching no marker get no entry.
_window_rules = (
    ('durings', [0,0.15]),
    ('duringc', [-0.2,0.35]),
    ('intertrial', [0.4,-0.1]),
)
pre_post = {}
for meta_split in expanded_meta_splits:
    for marker, window in _window_rules:
        if marker in meta_split:
            pre_post[meta_split] = list(window)
            break


### plot sc region histogram

In [None]:
# Build the per-region sc table.
# NOTE(review): get_sc_table, times, ptype and alpha are not defined in this
# part of the notebook — they must exist from an earlier cell.
sc_threshold = 0.6
res = get_sc_table(times, ptype, alpha,
                   sc_threshold=sc_threshold)

In [None]:
# Strip plots of per-region sc values during stim (left) and during choice
# (right), coloured by the *_int_mov category:
# 0.5 -> '#feda7e', 1 -> '#d55607', 0 -> '#57C1EB'.
sc_threshold = 0.6  # only used by the (disabled) threshold line below

sc_duringstim = np.array(res['sc_duringstim'])
sc_duringchoice = np.array(res['sc_duringchoice'])
# sc_duringstim_mv = np.array(res['sc_duringstim_move_shape'].fillna(0))
sc_duringchoice_mv = np.array(res['sc_duringchoice_move_shape'].fillna(0))

# Fill NaNs only in the binary mv variables
sc_duringchoice_mv = np.nan_to_num(sc_duringchoice_mv, nan=0)

# The during-stim panel has no x variable: every region sits in one strip at 0.
sc_duringstim_mv = np.zeros_like(sc_duringstim, dtype=float)

df_stim = pd.DataFrame({
    "sc_duringstim_mv": sc_duringstim_mv,
    "sc_duringstim": sc_duringstim
})
df_choice = pd.DataFrame({
    "sc_duringchoice_mv": sc_duringchoice_mv,
    "sc_duringchoice": sc_duringchoice
})

# Positional masks: df_stim/df_choice carry a fresh 0..n-1 index, so masking
# them with a pandas Series indexed like `res` would align by label, not
# position (misaligned or "Unalignable boolean Series" error).
choice_int_mov = np.asarray(res['sc_duringchoice_int_mov'])
stim_int_mov = np.asarray(res['sc_duringstim_int_mov'])

move_choice = df_choice[choice_int_mov == 1]
stim_choice = df_choice[choice_int_mov == 0]
int_choice = df_choice[choice_int_mov == 0.5]
move_stim = df_stim[stim_int_mov == 1]
stim_stim = df_stim[stim_int_mov == 0]
int_stim = df_stim[stim_int_mov == 0.5]

fig, axes = plt.subplots(1, 2, figsize=(4, 2), dpi=120,
                         gridspec_kw={'width_ratios': [1, 2]}, sharey=True)

# Per-panel draw order matches the original: 0.5, then 1, then 0.
for color, stim_subset, choice_subset in (('#feda7e', int_stim, int_choice),
                                          ('#d55607', move_stim, move_choice),
                                          ('#57C1EB', stim_stim, stim_choice)):
    sns.stripplot(x="sc_duringstim_mv", y="sc_duringstim", data=stim_subset,
                  color=color, jitter=0.25, ax=axes[0], size=3)
    sns.stripplot(x="sc_duringchoice_mv", y="sc_duringchoice", data=choice_subset,
                  color=color, jitter=0.25, ax=axes[1], size=3)

axes[0].set_title("s_a")
axes[1].set_title("m_a")

for ax in axes:
    # ax.axhline(sc_threshold, color="black", linestyle="--", linewidth=1)
    ax.set_ylabel(r'$\sum$ choice', fontsize=10)
    ax.tick_params(labelsize=9)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_facecolor('none')

axes[0].set_xlabel('')
axes[0].set_xticks([])
axes[0].set_xticklabels([])

axes[1].set_xlabel('mv-init ramp', fontsize=10)

save_dir = '/Users/ariliu/Desktop/ibl-figures'
fig.savefig(f'{save_dir}/sc_strip.pdf', transparent=True)

In [None]:
# Histograms of per-region sc values during stim and during choice.
sc_duringstim = np.array(res['sc_duringstim'])
sc_duringchoice = np.array(res['sc_duringchoice'])
# NOTE(review): the two *_mv arrays are not used by the histograms below; if
# 'sc_duringstim_move_shape' is absent from `res` this line raises KeyError.
sc_duringstim_mv = np.array(res['sc_duringstim_move_shape'].fillna(0))
sc_duringchoice_mv = np.array(res['sc_duringchoice_move_shape'].fillna(0))


fig, axs = plt.subplots(1, 2, sharey=True, figsize=(4, 2), dpi=120)

axs[0].hist(sc_duringstim)
axs[1].hist(sc_duringchoice)
for ax in axs:
    ax.set_xlabel(r'$\sum$ choice', fontsize=10)
    ax.tick_params(labelsize=9)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

axs[0].set_title('during stim')
axs[1].set_title('during choice')
axs[0].set(ylabel='num regions')
# Must be *called*; a bare `fig.tight_layout` only looks the method up and
# leaves the layout unadjusted.
fig.tight_layout()

save_dir = '/Users/ariliu/Desktop/ibl-figures'
fig.savefig(f'{save_dir}/sc_hist.pdf')

# decoding & projection of trajectories

In [None]:
def _proj_direction_vector(reg, direction, method, control):
    '''
    Load the 1-D projection direction for `reg`, or return None if `method`
    is not recognised (after printing a message).
    '''
    if method == 'manifold':
        pth = Path(one.cache_dir, 'manifold', 'traj_new')
        r = np.load(Path(pth, f'traj_{reg}_{direction}.npy'), allow_pickle=True)
        # r[:, k, :] is one condition's trajectory, shape (ncells, ntimes)
        sq_diff = (r[:, 0, :] - r[:, 1, :]) ** 2
        # time bin where the two conditions are furthest apart (first max)
        t_max = int(np.argmax(np.sum(sq_diff, axis=0)))
        # NOTE(review): squared per-cell differences are all >= 0, so this
        # direction drops the sign of r0 - r1 — confirm this is intended.
        return sq_diff[:, t_max]

    pth = Path(one.cache_dir, 'decoding')
    if method == 'decoding':
        fname = f'{direction}_{reg}_original.npy'
    elif method == 'decoding_pc':
        fname = f'{direction}_{reg}.npy'
    else:
        print('what method to define proj direction?')
        return None
    x = np.load(Path(pth, fname), allow_pickle=True).flatten()[0]
    if control:
        return np.asarray(x['avg_peth'])
    # readout of the best-accuracy decoder; np.argmax works for lists and
    # arrays alike (list.index did not accept ndarrays)
    best = int(np.argmax(x['accuracy']))
    return np.concatenate(x['readout'][best])


def plot_proj_along_direction(reg, splits, colors, direction, pc=False,
                              method='manifold', control=False):
    '''
    Plot the trajectories of `reg` for each split, projected onto one direction.

    Parameters
    ----------
    reg : str
        Region acronym used in the file names.
    splits : list of str
        Split names; each must be a key of the global `pre_post` dict.
    colors : sequence
        Two colours per split (condition 0, then condition 1).
    direction : str
        For method='manifold', the split whose two trajectories define the
        direction (their per-cell squared difference at the time of maximal
        separation). For the decoding methods, the decoder target used in
        the file name (e.g. 'choice').
    pc : bool
        If True, each split's trajectories and the direction are transformed
        together with PCA before projecting.
    method : {'manifold', 'decoding', 'decoding_pc'}
        Source of the direction. Any method containing 'pc' also reads the
        trajectories from the per-region folder instead of 'traj_new'.
    control : bool
        Decoding methods only: use the decoder's 'avg_peth' instead of the
        best-accuracy readout.
    '''
    direction_vec = _proj_direction_vector(reg, direction, method, control)
    if direction_vec is None:
        return

    if 'pc' in method:
        pth_r = Path(one.cache_dir, 'manifold', reg)
    else:
        pth_r = Path(one.cache_dir, 'manifold', 'traj_new')

    labels = ['clbl', 'clbr', 'crbl', 'crbr']

    for i, split in enumerate(splits):
        r = np.load(Path(pth_r, f'traj_{reg}_{split}.npy'), allow_pickle=True)
        ntimes = r.shape[2]
        traj0 = r[:, 0, :].T  # condition 0, shape (ntimes, ncells)
        traj1 = r[:, 1, :].T  # condition 1, shape (ntimes, ncells)
        # Separate name per split so the loaded direction is never replaced
        # by a previous split's PC-space vector.
        proj_vec = direction_vec

        if pc:
            from sklearn.decomposition import PCA
            # fit PCA on both trajectories plus the direction (as one extra
            # row) so all three land in the same space
            stacked = np.vstack([traj0, traj1, direction_vec[None, :]])
            stacked_pc = PCA(n_components=None).fit_transform(stacked)
            traj0 = stacked_pc[:ntimes]
            traj1 = stacked_pc[ntimes:2 * ntimes]
            proj_vec = stacked_pc[-1]

        proj0 = np.dot(traj0, proj_vec)
        proj1 = np.dot(traj1, proj_vec)
        xx = np.linspace(-pre_post[split][0],
                         pre_post[split][1],
                         len(proj0))
        plt.plot(xx, proj0, c=colors[2 * i])
        plt.plot(xx, proj1, c=colors[2 * i + 1])

    plt.legend(labels)
    plt.title(reg)
    plt.show()

In [None]:
# Sanity check: print the two condition-trajectory shapes for each split.
reg = 'GRN'
# l/r pairs per time window (the during-choice pair previously listed the
# r split twice and never checked 'block_stim_l_duringchoice_l_f1').
splits = ['block_duringstim_r_choice_r_f1', 'block_duringstim_l_choice_l_f1',
          'block_stim_r_duringchoice_r_f1', 'block_stim_l_duringchoice_l_f1',
          'block_stim_r_choice_r_f1', 'block_stim_l_choice_l_f1'
          ]
#pth_r = Path(one.cache_dir, 'manifold', reg)
pth_r = Path(one.cache_dir, 'manifold', 'traj_new')
for split in splits:
    r = np.load(Path(pth_r, f'traj_{reg}_{split}.npy'), allow_pickle=True)
    r0 = r[:, 0, :]
    r1 = r[:, 1, :]
    print(split, r0.shape, r1.shape)

In [None]:
# Project GRN trajectories onto the stim and the choice decoder directions,
# first in the original space and then after a joint PCA.
reg = 'GRN'
splits = [
    'block_duringstim_l_choice_l_f1', 'block_duringstim_r_choice_r_f1',
    'block_stim_l_duringchoice_l_f1', 'block_stim_r_duringchoice_r_f1',
    # other pairs (disabled):
    # 'block_stim_l_duringchoice_l_f1_long', 'block_stim_r_duringchoice_r_f1_long',
    # 'block_stim_l_choice_l_f1', 'block_stim_r_choice_r_f1'
    # 'block_duringstim_l_choice_r_f2', 'block_duringstim_r_choice_l_f2',
    # 'block_stim_l_duringchoice_r_f2', 'block_stim_r_duringchoice_l_f2',
]
colors = ['b', 'orange', 'c', 'r'] * 3

method = 'decoding'
for direction in ['stim', 'choice']:
    for use_pc in (False, True):
        plot_proj_along_direction(reg, splits, colors, direction, pc=use_pc,
                                  method=method)

In [None]:
def plot_avg_proj_along_dir(regions, timespan, colors, direction='choice', pc=False,
                            original=True, control=False, test=False):
    '''
    Scatter the time-averaged projection of manifold trajectories onto a
    decoder direction, one colour per region.

    Each region gets four points: the two conditions of each of the two
    splits selected by `timespan`, in the order given by ``labels``
    ('clbl', 'crbl', 'clbr', 'crbr').

    Parameters
    ----------
    regions : list of str
        Region acronyms.
    timespan : str
        One of 'duringchoice', 'duringstim', 'duringchoice1', 'duringstim1',
        'intertrial', 'intertrial1'. Anything else prints 'timespan?' and
        returns.
    colors : sequence
        One colour per region.
    direction : str
        Decoder target used in the decoding file name.
    pc : bool
        Project in a PCA space fitted to each split's trajectories, keeping
        only the first ``ndim`` entries of the direction.
    original : bool
        Load ``{direction}_{reg}_original.npy`` rather than
        ``{direction}_{reg}.npy``.
    control : bool
        Use the decoder's 'avg_peth' as the direction.
    test : bool
        Use a vector of ones as the direction (takes precedence over `control`).
    '''
    split_pairs = {
        'duringchoice': ['block_stim_l_duringchoice_l_f1', 'block_stim_r_duringchoice_r_f1'],
        'duringstim': ['block_duringstim_l_choice_l_f1', 'block_duringstim_r_choice_r_f1'],
        'duringchoice1': ['block_stim_r_duringchoice_l_f2', 'block_stim_l_duringchoice_r_f2'],
        'duringstim1': ['block_duringstim_r_choice_l_f2', 'block_duringstim_l_choice_r_f2'],
        'intertrial': ['block_stim_l_choice_l_f1', 'block_stim_r_choice_r_f1'],
        'intertrial1': ['block_stim_r_choice_l_f2', 'block_stim_l_choice_r_f2'],
    }
    if timespan not in split_pairs:
        print('timespan?')
        return
    splits = split_pairs[timespan]

    pth = Path(one.cache_dir, 'decoding')
    labels = ['clbl', 'crbl', 'clbr', 'crbr']
    xx = [0, 1, 2, 3]  # defined up front so an empty `regions` still works

    for i, reg in enumerate(regions):
        # load decoding result to get the projection direction
        suffix = '_original' if original else ''
        x = np.load(Path(pth, f'{direction}_{reg}{suffix}.npy'),
                    allow_pickle=True).flatten()[0]
        if test:
            c_vec = np.ones(len(x['avg_peth']))
        elif control:
            c_vec = np.asarray(x['avg_peth'])
        else:
            # best-accuracy readout (np.argmax handles lists and arrays)
            best = int(np.argmax(x['accuracy']))
            c_vec = np.concatenate(x['readout'][best])

        print(reg)

        datal, datar = [], []
        for split in splits:
            # load manifold trajectories; r[:, k, :] has shape (ncells, ntimes)
            pth_r = Path(one.cache_dir, 'manifold', reg)
            r = np.load(Path(pth_r, f'traj_{reg}_{split}.npy'), allow_pickle=True)
            r0 = r[:, 0, :]
            r1 = r[:, 1, :]
            print(r0.shape, np.sum(r0), np.sum(r1))
            ntimes = r0.shape[1]
            # Rows = time bins, columns = cells, so np.dot with a per-cell
            # direction gives one value per time bin (the untransposed
            # (ncells, ntimes) array did not match a per-cell direction).
            r0 = r0.T
            r1 = r1.T

            if pc:
                from sklearn.decomposition import PCA
                a = np.vstack([r0, r1])  # shape (2*ntimes, ncells)
                ndim = min(len(c_vec), a.shape[0], a.shape[1])
                # NOTE(review): keeping the first ndim entries treats c_vec as
                # PC coordinates — presumably only meaningful for original=False.
                c_vec = c_vec[:ndim]
                pca = PCA(n_components=ndim)
                a_pc = pca.fit_transform(a)
                print('var_exp', sum(pca.explained_variance_ratio_))
                r0 = a_pc[:ntimes]  # shape (ntimes, npcs)
                r1 = a_pc[ntimes:]

            bl = np.mean(np.dot(r0, c_vec))
            br_ = np.mean(np.dot(r1, c_vec))
            datal.append(bl)
            datar.append(br_)
            print('bl', bl, 'br', br_)

        plt.scatter(xx, np.concatenate([datal, datar]), c=colors[i], s=150)
    plt.xticks(xx, labels)
    plt.legend(regions)
    plt.title(timespan)

In [None]:
def plot_slope_proj_along_dir(regions, timespan, colors, direction='choice', pc=False,
                              original=True, control=False, test=False,
                              datatype='slope'):
    '''
    Scatter a summary of the projection of manifold trajectories onto a
    decoder direction, one colour per region.

    Like plot_avg_proj_along_dir, but the per-trajectory summary is chosen by
    `datatype` instead of the time average:
    'start' -> first time bin, 'end' -> last time bin,
    'slope' -> mean change per time bin (mean of np.diff).

    Parameters
    ----------
    regions : list of str
        Region acronyms.
    timespan : str
        One of 'duringchoice', 'duringstim', 'duringchoice1', 'duringstim1',
        'intertrial', 'intertrial1'. Anything else prints 'timespan?' and
        returns.
    colors : sequence
        One colour per region.
    direction : str
        Decoder target used in the decoding file name.
    pc : bool
        Project in a PCA space fitted to each split's trajectories.
    original : bool
        Load ``{direction}_{reg}_original.npy`` rather than
        ``{direction}_{reg}.npy``.
    control : bool
        Use the decoder's 'avg_peth' as the direction.
    test : bool
        Use a vector of ones as the direction (takes precedence over `control`).
    datatype : {'slope', 'start', 'end'}
        Summary statistic; previously read from an undefined name.
        Anything else prints 'datatype?' and returns.
    '''
    split_pairs = {
        'duringchoice': ['block_stim_l_duringchoice_l_f1', 'block_stim_r_duringchoice_r_f1'],
        'duringstim': ['block_duringstim_l_choice_l_f1', 'block_duringstim_r_choice_r_f1'],
        'duringchoice1': ['block_stim_r_duringchoice_l_f2', 'block_stim_l_duringchoice_r_f2'],
        'duringstim1': ['block_duringstim_r_choice_l_f2', 'block_duringstim_l_choice_r_f2'],
        'intertrial': ['block_stim_l_choice_l_f1', 'block_stim_r_choice_r_f1'],
        'intertrial1': ['block_stim_r_choice_l_f2', 'block_stim_l_choice_r_f2'],
    }
    if timespan not in split_pairs:
        print('timespan?')
        return
    if datatype not in ('start', 'end', 'slope'):
        print('datatype?')
        return
    splits = split_pairs[timespan]

    pth = Path(one.cache_dir, 'decoding')
    labels = ['clbl', 'crbl', 'clbr', 'crbr']
    xx = [0, 1, 2, 3]

    for i, reg in enumerate(regions):
        # load decoding result to get the projection direction
        suffix = '_original' if original else ''
        x = np.load(Path(pth, f'{direction}_{reg}{suffix}.npy'),
                    allow_pickle=True).flatten()[0]
        if test:
            c_vec = np.ones(len(x['avg_peth']))
        elif control:
            c_vec = np.asarray(x['avg_peth'])
        else:
            best = int(np.argmax(x['accuracy']))
            c_vec = np.concatenate(x['readout'][best])

        print(reg)

        datal, datar = [], []
        for split in splits:
            # load manifold trajectories; r[:, k, :] has shape (ncells, ntimes)
            pth_r = Path(one.cache_dir, 'manifold', reg)
            r = np.load(Path(pth_r, f'traj_{reg}_{split}.npy'), allow_pickle=True)
            r0 = r[:, 0, :]
            r1 = r[:, 1, :]
            print(r0.shape, np.sum(r0), np.sum(r1))
            ntimes = r0.shape[1]
            # rows = time bins so the projection is a time series
            r0 = r0.T
            r1 = r1.T

            if pc:
                from sklearn.decomposition import PCA
                # `a` and `ntimes` are computed here; they were previously
                # undefined in this function
                a = np.vstack([r0, r1])  # shape (2*ntimes, ncells)
                ndim = min(len(c_vec), a.shape[0], a.shape[1])
                c_vec = c_vec[:ndim]
                pca = PCA(n_components=ndim)
                a_pc = pca.fit_transform(a)
                print('var_exp', sum(pca.explained_variance_ratio_))
                r0 = a_pc[:ntimes]  # shape (ntimes, npcs)
                r1 = a_pc[ntimes:]

            proj_l = np.dot(r0, c_vec)
            proj_r = np.dot(r1, c_vec)

            if datatype == 'start':
                val_l, val_r = proj_l[0], proj_r[0]
            elif datatype == 'end':
                # last element (indexing with len() was out of range)
                val_l, val_r = proj_l[-1], proj_r[-1]
            else:  # 'slope'
                val_l = np.mean(np.diff(proj_l))
                val_r = np.mean(np.diff(proj_r))
            datal.append(val_l)
            datar.append(val_r)
            print('bl', val_l, 'br', val_r)

        plt.scatter(xx, np.concatenate([datal, datar]), c=colors[i], s=150)
    plt.xticks(xx, labels)
    plt.legend(regions)
    plt.title(timespan)

TODO: Why are these values all symmetric?
TODO: Look at the trajectories along the projection direction for incorrect trials.

In [None]:
#regions = ['MRN', 'IRN', 'MOs']
regions = ['GRN']#, 'MOp']#, 'PAG']
timespan = 'duringstim'
colors = ['b', 'orange', 'c']
direction = 'choice'
control = False
test = False

# Pass the flags by keyword. Positionally, `control` and `test` landed in the
# `pc` and `original` parameters of plot_avg_proj_along_dir.
plot_avg_proj_along_dir(regions, timespan, colors, direction,
                        control=control, test=test)

In [None]:
# Print a choice-decoding summary (mean accuracy, p-value, variance explained)
# for each region.
regions = [
    'GRN', 'IRN', 'MOs', 'MRN', 'MOp', 'CENT3', 'SIM',
    'IP', 'RSPagl', 'PL', 'AIv', 'PAG', 'CENT2',
    'CP', 'ENTl', 'SUB', 'ZI', 'ANcr1', 'ACAd', 'APN',
]

pth = Path(one.cache_dir, 'decoding')
for region in regions:
    dec_res = np.load(Path(pth, f'choice_{region}.npy'),
                      allow_pickle=True).flatten()[0]
    mean_acc = np.mean(dec_res['accuracy'])
    print(region, 'mean_acc:', mean_acc, 'p_val:', dec_res['p_value'],
          'var:', dec_res['var_exp'])

# Manifold Distance Analysis

## line plots

In [None]:
# Show the choice/stim distance curve (entry 0 of d_with_controls) for one
# region under different block conditions; splits containing 'bl' are
# sign-flipped and labelled block L.
from matplotlib.ticker import MaxNLocator

reg = 'MOs'
splits = ['durings_srcrbl_slclbl', 'durings_srcrbr_slclbr'] #'srcrbl_slclbl', 'srcrbr_slclbr'
# Time axis and window label inferred from the first split's prefix.
if 'duringc' in splits[0]:
    times = np.linspace(-0.15, 0, 72)
    time = 'duringchoice'
elif 'durings' in splits[0]:
    times = np.linspace(0, 0.15, 72)
    time = 'duringstim'
else:
    times = np.linspace(-0.4, -0.1, 144)
    time = 'intertrial'

fig = plt.figure(figsize=(4,3), dpi=150)

for split in splits:
    try:
        r = np.load(Path(pth_res, f'd_with_controls_{split}.npy'),
                    allow_pickle=True).flatten()[0][reg]
    except Exception as err:
        # Exception rather than BaseException, so Ctrl-C / SystemExit still
        # stop the cell; a missing file or region is reported and skipped.
        print("error:", split, err)
        continue
    # NOTE(review): substring test — any name starting with 'block_' would
    # also match 'bl'; fine for the 'srcrb?_slclb?' names used here.
    if 'bl' in split:
        r[0] = -1 * r[0]
        block = 'L'
    else:
        block = 'R'

    plt.plot(times, r[0], label=f'block{block}')

plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().xaxis.set_major_locator(MaxNLocator(4))

plt.savefig(f'/Users/ariliu/Desktop/{reg}_{time}_dist_trtl.pdf')

In [None]:
# OLD VERS - compare distance with controls
reg = 'IRN'
#splits = meta_splits['duringstim']
splits = ['block_stim_r_duringchoice_r_f1']  # alternative: 'act_block_stim_r_choice_r_f1'
if 'duringchoice' in splits[0]:
    times = np.linspace(-0.15, 0, 72)
elif 'duringstim' in splits[0]:
    times = np.linspace(0, 0.15, 72)
else:
    times = np.linspace(-0.4, -0.1, 144)

for split in splits:
    try:
        r = np.load(Path(pth_res, f'd_with_controls_{split}.npy'),
                    allow_pickle=True).flatten()[0][reg]
    except (OSError, KeyError, ValueError) as err:
        # Missing/unreadable file or region absent: report and move on
        # (narrower than BaseException so Ctrl-C still interrupts the cell).
        print("error:", split, err)
        continue

    # Assumes r is a dict keyed 0..N: key 0 is the true distance curve and the
    # remaining keys are control curves. TODO confirm against the code saving these files.
    true_curve = np.asarray(r[0])
    control_curves = np.array([r[k] for k in r if k != 0])

    fig, axs = plt.subplots(1, 2, sharey=True, figsize=(7, 4), dpi=250,
                            gridspec_kw={'width_ratios': [6, 1]})
    # Up to 19 control curves in grey; min() avoids a KeyError when fewer were saved.
    for i in range(1, min(20, len(r))):
        axs[0].plot(times, r[i], c='gray', alpha=0.2)
    axs[0].plot(times, true_curve, c='b')
    # Null distribution of the controls' peak distance vs. the true peak.
    axs[1].hist(np.max(control_curves, axis=1), density=True, bins=20,
                color='silver', orientation='horizontal')
    axs[1].axhline(y=np.max(true_curve), c='b')

    # Per-time-bin permutation p-value: (1 + #controls >= true) / (1 + #controls).
    # Counting the true curve itself keeps p > 0 and matches the rule
    # np.mean(r >= r[0], axis=0) used in plot_regional_distance.
    p = (1 + np.sum(control_curves >= true_curve, axis=0)) / (1 + len(control_curves))
    sig_frac = np.mean(p < 0.05)
    print(split, np.min(p), 'sig_frac:', sig_frac)

    # Region-level p-value stored with the split's results.
    d = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flatten()[0][reg]
    p_val_at_max = d['p_euc_c1']
    axs[0].text(0.55, 0.97, f'p_val {p_val_at_max:.3f}', transform=axs[0].transAxes,
                color='red', fontsize=14, ha='left', va='top')

    for ax in axs:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xticklabels([])
        ax.set_facecolor('none')
        ax.tick_params(labelsize=15)
    axs[0].set_xticks(np.linspace(times[0], times[-1], 4))
    axs[1].spines['bottom'].set_visible(False)
    axs[1].tick_params(axis='y', left=False, labelleft=False)
    axs[1].tick_params(axis='x', bottom=False, labelbottom=False)

    fig.tight_layout()
    save_dir = '/Users/ariliu/Desktop/ibl-figures'
    fig.savefig(f'{save_dir}/{reg}_{split}_dist.pdf',
                transparent=True)


In [347]:
def plot_regional_distance(reg, split, time, ptype='p_mean_c', alpha=0.05, plot_p_per_time=True):
    """Plot one region's true distance curve against its control curves and save it as a PDF.

    Parameters
    ----------
    reg : str
        Region acronym, used as key into the saved result dicts.
    split : str
        Result name. If it contains 'combined', curves are read from
        '{split}.npy' (true curve in r[0], controls stacked in r[1]) and the
        summary stats from 'combined_<text after "regde_">.npy'; otherwise
        from '{split}_regde.npy' and '{split}.npy'.
    time : str
        Timeframe label selecting the time axis: 'duringchoice' (-0.15..0),
        'duringstim1' (0..0.08), 'duringstim' (0..0.15), anything else (-0.4..-0.1).
    ptype : str
        Key of the precomputed p-value in the summary dict. Whether it contains
        'p_mean', 'p_amp' or 'p_max' selects the statistic in the side histogram.
    alpha : float
        Significance level for the text colour and the per-time-bin markers.
    plot_p_per_time : bool
        If True, mark time bins whose per-bin p-value is <= alpha.
    """
    if 'duringchoice' in time:
        times = np.linspace(-0.15, 0, 72)
    elif 'duringstim1' in time:
        times = np.linspace(0, 0.08, 42)
    elif 'duringstim' in time:
        times = np.linspace(0, 0.15, 72)
    else:
        times = np.linspace(-0.4, -0.1, 144)

    if 'combined' in split:
        r = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flatten()[0][reg]
        # Stack the true curve on top of the controls: shape (1 + n_controls, n_bins).
        r = np.concatenate([r[0].reshape(1, -1), r[1]], axis=0)
        split_name = split.split('regde_', 1)[1]
        d = np.load(Path(pth_res, f'combined_{split_name}.npy'), allow_pickle=True).flatten()[0][reg]
    else:
        r = np.load(Path(pth_res, f'{split}_regde.npy'), allow_pickle=True).flatten()[0][reg]
        d = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flatten()[0][reg]

    # In stimulus-aligned windows the first 5 bins are drawn (and tested) separately.
    stim_window = 'duringstim' in time

    fig, axs = plt.subplots(1, 2, sharey=True, figsize=(6, 4), dpi=250,
                            gridspec_kw={'width_ratios': [6, 1]})
    # Up to 39 control curves; min() avoids an IndexError when fewer were saved.
    for i in range(1, min(40, len(r))):
        if stim_window:
            axs[0].plot(times[:5], r[i][:5], c='#5f7ea3', alpha=0.5, linewidth=0.5)
            axs[0].plot(times[4:], r[i][4:], c='gray', alpha=0.2, linewidth=0.5)
        else:
            axs[0].plot(times, r[i], c='gray', alpha=0.2, linewidth=0.5)

    if stim_window:
        axs[0].plot(times[:5], r[0][:5], c='blue', linewidth=1)
        axs[0].plot(times[4:], r[0][4:], c='black', linewidth=1)
    else:
        axs[0].plot(times, r[0], c='black', linewidth=1)

    # Per-bin p-value: fraction of curves (the true one included) at or above the true curve.
    # For multiple-comparison control one could apply
    # multipletests(p_per_time, alpha=alpha, method='fdr_bh')  (Benjamini-Hochberg FDR).
    p_per_time = np.mean(r >= r[0], axis=0)

    # Same test on the mean of the first 5 bins ("offset" window).
    mean_first5 = np.mean(r[:, :5], axis=1)
    p_val_first5 = np.mean(mean_first5 >= mean_first5[0])

    # Region-level p-value precomputed and stored in the summary dict.
    p_val = d[ptype]

    if plot_p_per_time:
        sig_mask = p_per_time <= alpha
        axs[0].scatter(times[sig_mask], np.full(np.sum(sig_mask), axs[0].get_ylim()[0]),
                       marker='v', color='blue', s=20, zorder=5)

    if 'p_mean' in ptype:
        axs[1].hist(np.mean(r[1:], axis=1), density=True, bins=20,
                    color='silver', orientation='horizontal')
        axs[1].axhline(y=np.mean(r[0]), c='black')
        if stim_window:
            axs[1].hist(np.mean(r[1:, :5], axis=1), density=True, bins=20,
                        color='#5f7ea3', orientation='horizontal', alpha=0.5)
            axs[1].axhline(y=np.mean(r[0, :5]), c='blue')
    elif 'p_amp' in ptype:
        amplitude = np.max(r, axis=1) - np.min(r, axis=1)
        axs[1].hist(amplitude[1:], density=True, bins=20,
                    color='silver', orientation='horizontal')
        axs[1].axhline(y=amplitude[0], c='black')
    elif 'p_max' in ptype:
        axs[1].hist(np.max(r[1:], axis=1), density=True, bins=20,
                    color='silver', orientation='horizontal')
        axs[1].axhline(y=np.max(r[0]), c='black')

    axs[0].text(0.2, 0.97, f'p_val {p_val:.4f}', transform=axs[0].transAxes,
                color='red' if p_val <= alpha else 'black', fontsize=20, ha='left', va='top')
    if stim_window:
        axs[0].text(0.45, 0.15, f'p_val_offset {p_val_first5:.3f}', transform=axs[0].transAxes,
                    color='red' if p_val_first5 <= alpha else 'blue', fontsize=16, ha='left', va='top')

    for ax in axs:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_facecolor('none')
        ax.tick_params(labelsize=15)
    axs[0].spines['left'].set_visible(False)
    axs[0].tick_params(axis='y', left=False)
    n_ticks = 3 if 'duringstim1' in time else 4
    axs[0].set_xticks(np.linspace(times[0], times[-1], n_ticks))
    axs[1].spines['bottom'].set_visible(False)
    axs[1].tick_params(axis='y', left=False, labelleft=False)
    axs[1].tick_params(axis='x', bottom=False, labelbottom=False)

    fig.tight_layout()
    save_dir = '/Users/ariliu/Desktop/ibl-figures'
    if 'combined' in split:
        fig.savefig(f'{save_dir}/{reg}_{time}_{ptype}_dist.pdf', transparent=True)
    else:
        fig.savefig(f'{save_dir}/{reg}_{split}_{ptype}_dist.pdf', transparent=True)


In [None]:
def plot_regional_distance_comparison(reg, times):
    """Overlay one region's true distance curve for several timeframes and save a PDF.

    reg: region acronym.
    times: iterable of timeframe keys into run_align. For each key, the curves are
        read from 'combined_regde_<splits joined by "_">.npy' (true curve in r[0]).
    The x axis spans 0..0.15 with as many points as the loaded curve has samples.
    """
    fig, axs = plt.subplots(figsize=(6, 4), dpi=150)
    for time in times:
        splits = run_align[time]
        name = "_".join(splits)
        r = np.load(Path(pth_res, f'combined_regde_{name}.npy'), allow_pickle=True).flatten()[0][reg]
        r = np.concatenate([r[0].reshape(1, -1), r[1]], axis=0)
        # Take the sample count from the data instead of assuming 72 bins.
        x_times = np.linspace(0, 0.15, r.shape[1])

        axs.plot(x_times, r[0], linewidth=1, label=time)
        axs.scatter(x_times, r[0], color='black', s=10)  # one dot per data point

    axs.spines['top'].set_visible(False)
    axs.spines['right'].set_visible(False)
    axs.set_xticklabels([])
    axs.set_facecolor('none')
    axs.tick_params(labelsize=15)

    axs.set_xticks(np.linspace(0, 0.15, 4))
    axs.legend(fontsize=15, frameon=False)
    axs.set_title(reg, fontsize=15)

    fig.tight_layout()
    save_dir = '/Users/ariliu/Desktop/ibl-figures'
    # Join the timeframe names. Formatting the list directly would put brackets,
    # quotes and spaces into the file name.
    fig.savefig(f'{save_dir}/{reg}_{"_".join(times)}_comparison.pdf',
                transparent=True)
    plt.close()


In [None]:
reg = 'PRNr'
ptype = 'p_mean_c'

# Other timeframe sets tried:
#   ['block_duringstim', 'choice_duringstim0', 'block_duringchoice', 'choice_duringchoice0']
#   ['stim_duringstim0', 'choice_duringstim0', 'stim_duringchoice0', 'choice_duringchoice0']
timeframes = ['choice_duringstim0']
for timeframe in timeframes:
    splits = run_align[timeframe]
    # A single split has no combined file, so its own name is used instead.
    combined_name = splits[0] if len(splits) == 1 else 'combined_regde_' + "_".join(splits)
    # Per-split plots: plot_regional_distance(reg, split, timeframe, ptype=ptype, plot_p_per_time=True)
    plot_regional_distance(reg, combined_name, timeframe,
                           ptype=ptype, alpha=0.05, plot_p_per_time=True)

# plot_regional_distance(reg, 'block_only', timeframe, ptype=ptype, plot_p_per_time=True)

In [None]:
# Distance plots for a list of regions in one timeframe.
ptype = 'p_mean_c'

timeframe = 'stim_duringchoice0'
splits = run_align[timeframe]
# A single-split timeframe has no combined file. Use the split name itself, matching
# the single-region cell above and the naming in compute_p_value.
if len(splits) == 1:
    combined_name = splits[0]
else:
    combined_name = 'combined_regde_' + "_".join(splits)

# Other region sets tried: ['GRN', 'MRN', 'IRN'], ['GRN', 'GPi', 'SNr'],
# ['SIM', 'CLA', 'PCG', 'ACAd', 'MOs', 'RN', 'SMT'] (optionally 'DEC', 'SSs'), and 'PF'.
regs = ['MOs', 'SIM', 'IP', 'RN', 'APN', 'SCm']

for reg in regs:
    plot_regional_distance(reg, combined_name, timeframe, ptype=ptype, plot_p_per_time=True)


In [None]:
def plot_average_distance_over_regions(regs, timewindow, split=None, alpha=0.05,
                                       ptype='p_mean_c', plot_p_per_time=True):
    """Average the true and control distance curves over several regions, plot them and save a PDF.

    Parameters
    ----------
    regs : sequence of str
        Regions to average. All must be present in the loaded results and must
        have the same number of control curves and time bins (they are stacked).
    timewindow : str
        Timeframe key. It selects the time axis and, when ``split`` is None, the
        splits via ``run_align[timewindow]``.
    split : str, optional
        Result name. Names containing 'combined' use the combined layout (true
        curve in r[0], controls in r[1]); other names read '{split}_regde.npy'
        and '{split}.npy'. Defaults to the name derived from ``run_align``.
    alpha : float
        Significance level for the text colour and the per-bin markers.
    ptype : str
        Names ending in '_c' average the precomputed per-region values d[ptype].
        'p_mean', 'p_amp' and 'p_max' are recomputed on the region-averaged curves.
    plot_p_per_time : bool
        Overlay the per-bin p-value curve on a twin axis and mark bins <= alpha.

    Raises
    ------
    ValueError
        If ``regs`` is empty or ``ptype`` is unsupported.
    """
    if 'duringchoice' in timewindow:
        times = np.linspace(-0.15, 0, 72)
    elif 'duringstim' in timewindow:
        times = np.linspace(0, 0.15, 72)
    else:
        times = np.linspace(-0.4, -0.1, 144)

    if split is None:
        splits = run_align[timewindow]
        # Single-split timeframes have no combined file (same naming as compute_p_value).
        split = splits[0] if len(splits) == 1 else 'combined_regde_' + "_".join(splits)

    if len(regs) == 0:
        raise ValueError('regs must contain at least one region')

    # The line colour encodes the region group.
    if 'GRN' in regs:
        c = 'tomato'
    elif 'MOs' in regs:
        c = 'gold'
    else:
        c = 'blue'

    all_r = []
    all_d = []
    for reg in regs:
        if 'combined' in split:
            r = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flatten()[0][reg]
            r = np.concatenate([r[0].reshape(1, -1), r[1]], axis=0)
            split_name = split.split('regde_', 1)[1]
            d = np.load(Path(pth_res, f'combined_{split_name}.npy'), allow_pickle=True).flatten()[0][reg]
        else:
            r = np.load(Path(pth_res, f'{split}_regde.npy'), allow_pickle=True).flatten()[0][reg]
            # Single-split summaries are saved as '{split}.npy' (see compute_p_value / fdr_combined).
            d = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flatten()[0][reg]
        all_r.append(r)
        all_d.append(d)

    # Average across regions: shape (1 + n_controls, n_bins); row 0 is the true curve.
    r_avg = np.mean(np.stack(all_r), axis=0)

    fig, axs = plt.subplots(1, 2, sharey=True, figsize=(7, 4), dpi=250,
                            gridspec_kw={'width_ratios': [6, 1]})
    for i in range(1, min(40, r_avg.shape[0])):
        axs[0].plot(times, r_avg[i], c='gray', alpha=0.2, linewidth=0.5)

    axs[0].plot(times, r_avg[0], c=c, linewidth=1)

    # Per-bin p-value: fraction of averaged curves (true one included) >= the true curve.
    p_per_time = np.mean(r_avg >= r_avg[0], axis=0)

    # Offset test on the mean of the first 5 bins of the region-averaged curves.
    mean_first5 = np.mean(r_avg[:, :5], axis=1)
    p_val_first5 = np.mean(mean_first5 >= mean_first5[0])

    if ptype.endswith('_c'):
        p_val = np.mean([d[ptype] for d in all_d])
    elif ptype == 'p_mean':
        p_val = np.mean(np.mean(r_avg[1:], axis=1) >= np.mean(r_avg[0]))
    elif ptype == 'p_amp':
        amp = np.max(r_avg, axis=1) - np.min(r_avg, axis=1)
        p_val = np.mean(amp[1:] >= amp[0])
    elif ptype == 'p_max':
        p_val = np.mean(np.max(r_avg[1:], axis=1) >= np.max(r_avg[0]))
    else:
        raise ValueError(f"Unsupported ptype: {ptype}")

    if plot_p_per_time:
        ax2 = axs[0].twinx()
        ax2.plot(times, p_per_time, color='blue', linestyle='--', linewidth=1, label='p per time')
        ax2.set_ylim([0, 1])
        ax2.set_ylabel('p', fontsize=10, color='blue')
        ax2.tick_params(axis='y', labelcolor='blue', labelsize=8)

        sig_mask = p_per_time <= alpha
        axs[0].scatter(times[sig_mask], np.full(np.sum(sig_mask), axs[0].get_ylim()[0]),
                       marker='v', color='blue', s=20, zorder=5)

    if 'p_mean' in ptype:
        axs[1].hist(np.mean(r_avg[1:], axis=1), density=True, bins=20,
                    color='silver', orientation='horizontal')
        axs[1].axhline(y=np.mean(r_avg[0]), c='black')
    elif 'p_amp' in ptype:
        amp = np.max(r_avg, axis=1) - np.min(r_avg, axis=1)
        axs[1].hist(amp[1:], density=True, bins=20,
                    color='silver', orientation='horizontal')
        axs[1].axhline(y=amp[0], c='black')
    elif 'p_max' in ptype:
        axs[1].hist(np.max(r_avg[1:], axis=1), density=True, bins=20,
                    color='silver', orientation='horizontal')
        axs[1].axhline(y=np.max(r_avg[0]), c='black')

    axs[0].text(0.2, 0.97, f'p_val {p_val:.3f}', transform=axs[0].transAxes,
                color='red' if p_val <= alpha else 'black', fontsize=18, ha='left', va='top')
    axs[0].text(0.2, 0.85, f'p_val_offset {p_val_first5:.3f}', transform=axs[0].transAxes,
                color='red' if p_val_first5 <= alpha else 'black', fontsize=15, ha='left', va='top')

    for ax in axs:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_facecolor('none')
        ax.tick_params(labelsize=15)
    axs[0].spines['left'].set_visible(False)
    axs[0].tick_params(axis='y', left=False)

    axs[0].set_xticks(np.linspace(times[0], times[-1], 4))
    axs[1].spines['bottom'].set_visible(False)
    axs[1].tick_params(axis='y', left=False, labelleft=False)
    axs[1].tick_params(axis='x', bottom=False, labelbottom=False)

    fig.tight_layout()
    save_dir = '/Users/ariliu/Desktop/ibl-figures'
    fig.savefig(f'{save_dir}/{"_".join(regs)}_{split}_{ptype}_dist_avg.pdf', transparent=True)


In [None]:
def plot_group_comparison_over_regions(regs_int, regs_choice, timeframe, alpha=0.05, ptype='p_mean_c',
                                       label_A='integrator', label_B='choice'):
    """Plot the mean true distance curve of two region groups, using significant regions only.

    regs_int, regs_choice: region acronyms of the two groups. A region contributes
        only if manifold_to_csv flags it significant (ptype <= alpha) for this split.
    timeframe: key into run_align. A single split is used directly; several splits
        use the 'combined_<splits>' results.
    label_A, label_B: line labels for the two groups (the legend itself is disabled).

    Raises ValueError if either group has no significant region.
    Saves '/Users/ariliu/Desktop/ibl-figures/compare_int_choice_<timeframe>.pdf'.
    """
    splits = run_align[timeframe]
    # Same naming as compute_p_value / fdr_combined: no combined file for a single split.
    split = splits[0] if len(splits) == 1 else 'combined_' + "_".join(splits)

    if 'duringchoice' in split:
        times = np.linspace(-0.15, 0, 72)
    elif 'duringstim' in split:
        times = np.linspace(0, 0.15, 72)
    else:
        times = np.linspace(-0.4, -0.1, 144)

    # The significance table and the results file are shared by both groups: load them once.
    res = manifold_to_csv(split, alpha, ptype)
    significant = dict(zip(res['region'], res['significant']))
    results = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flatten()[0]

    def load_group(regs):
        """Return the mean 'd_euc' curve over the significant regions in ``regs``."""
        all_r = []
        for reg in regs:
            # Regions missing from the table are treated as not significant.
            if significant.get(reg) == 1:
                print(reg)
                all_r.append(results[reg]['d_euc'])
        if not all_r:
            raise ValueError(f'no significant regions among {list(regs)} for {split}')
        return np.mean(np.stack(all_r), axis=0)  # shape: (n_time_bins,)

    print('int regs')
    r_int = load_group(regs_int)
    print('move regs')
    r_choice = load_group(regs_choice)

    fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=250)

    ax.plot(times, r_int, color='gold', linewidth=2, label=label_A)
    ax.plot(times, r_choice, color='tomato', linewidth=2, label=label_B)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_facecolor('none')
    ax.tick_params(labelsize=12)
    ax.set_yticks([])
    ax.set_xticks([])
    # ax.legend(frameon=False, fontsize=10, loc='upper left')

    fig.tight_layout()
    save_dir = '/Users/ariliu/Desktop/ibl-figures'
    fig.savefig(f'{save_dir}/compare_int_choice_{timeframe}.pdf', transparent=True)


In [None]:
# Compare integrator vs. movement region groups for block distances.
times = ['stim_duringstim0', 'choice_duringstim0', 'stim_duringchoice0', 'choice_duringchoice0']
ptype = 'p_mean_c'
alpha = 0.05
sc_threshold = 0.6

res = get_sc_table(times, ptype, alpha,
                   combined_p=True, sc_threshold=sc_threshold)
regs = res['region']

for timeframe in ['duringstim', 'duringchoice']:
    # Earlier the groups came from res[f'sc_{timeframe}_int_mov'] (0.5 = int, 1 = move).
    # Now they are the unions of the stimulus- and choice-aligned groups computed elsewhere.
    regs_int = list(set(int_regs_stim) | set(int_regs_choice))
    regs_move = list(set(move_regs_stim) | set(move_regs_choice))

    plot_group_comparison_over_regions(regs_int, regs_move, f'act_block_{timeframe}')

In [359]:
# Earlier hand-picked region lists (kept for reference):
# regs_stim = ['ZI', 'VISp', 'LP']
# regs_int = ['SIM', 'CUL4 5', 'BMA', 'NTS', 'PRNc', 'PARN', 'SUV', 'MV', 'PRNr',
#         'PAR', 'MOs', 'APN', 'SCm', 'PPN']
# regs_move = ['CENT2', 'IP', 'CENT3', 'CP', 'GRN', 'V', 'IRN', 'RN', 'MRN', 'GPi']

# The stim, move and int region groups come from running plot_combined_onetype:
# move_regs_stim, move_regs_choice, int_regs_stim, int_regs_choice, stim_regs

# Other timeframes: 'act_block_duringstim', 'act_block_duringchoice', 'intertrial'
timeframe = 'choice_duringchoice0'
ptype = 'p_mean_c'

regs = int_regs_stim
splits = run_align[timeframe]
# A single-split timeframe has no combined file; use the split name itself.
if len(splits) == 1:
    combined_name = splits[0]
else:
    combined_name = 'combined_regde_' + "_".join(splits)

# plot_average_distance_over_regions(regs, timeframe, plot_p_per_time=False)

for reg in regs:
    plot_regional_distance(reg, combined_name, timeframe, ptype=ptype, alpha=0.05, plot_p_per_time=True)
    plt.close()

In [None]:
def get_line_plots(meta_split, regs, color, p_type=None, restr=True):
    '''
    Plot the region-averaged distance curve, summed over a meta split's splits.

    meta_split: key into meta_splits; the time axis comes from pre_post[meta_split].
    regs: regions to average over. The sum is divided by len(regs) even when
        p_type masks some regions out.
    color: matplotlib colour of the line.
    p_type: if given, regions whose r[p_type] >= sigl contribute zeros.
    restr: if True, use only the first two splits (restricted to correct trials).
    '''
    splits = meta_splits[meta_split]
    if restr:
        splits = splits[:2]

    d = {}
    for split in splits:
        # Load each split's results once instead of once per region.
        results = np.load(Path(pth_res, f'{split}.npy'), allow_pickle=True).flat[0]
        total = 0
        for reg in regs:
            r = results[reg]
            if p_type is None:
                total = total + r['d_euc']
            else:
                # Zero out regions that are not significant at level sigl.
                total = total + r['d_euc'] * (r[p_type] < sigl)
        d[split] = total

    d = pd.DataFrame(data=d)
    d['sum'] = d[splits].sum(axis=1)

    xx = np.linspace(-pre_post[meta_split][0],
                     pre_post[meta_split][1],
                     len(d['sum']))
    yy = d['sum'] / len(regs)
    plt.plot(xx, yy, c=color)
    
    
def get_multiple_line_plots(metasplits, reg, labels, p_type=None):
    """
    Plot all detailed splits' trajectories during a timeframe, averaging split i
    with split i+6 (presumably the same comparison on correct and incorrect trials;
    probably not useful).

    metasplits: keys into meta_splits; each must hold at least 12 splits.
    reg: region acronym.
    labels: legend entries, one per colour (6).
    p_type: if given, curves of splits with r[p_type] >= sigl are zeroed.
    """
    col = ['b', 'r', 'y', 'g', 'c', 'orange']
    for meta_split in metasplits:
        splits = meta_splits[meta_split]
        d = {}
        for split in splits:
            r = np.load(Path(pth_res, f'{split}.npy'),
                        allow_pickle=True).flat[0][reg]
            if p_type is None:
                d[split] = r['d_euc']
            else:
                d[split] = r['d_euc'] * (r[p_type] < sigl)

        # All curves of a meta split share one length; take it from the last split.
        xx = np.linspace(-pre_post[meta_split][0],
                         pre_post[meta_split][1],
                         len(d[splits[-1]]))
        for i, c in enumerate(col):
            yy = (d[splits[i]] + d[splits[i + 6]]) / 2
            plt.plot(xx, yy, c=c)

    plt.legend(labels)
    plt.title(reg)

    
def get_multiple_line_plots_0(metasplits, reg, labels, p_type=None):
    """
    Plot each split's distance curve for one region across the given meta splits,
    then label the axes and save '/Users/ariliu/Desktop/<reg>_lineplots.pdf'.

    metasplits: keys into meta_splits. Each split gets its own colour; colours
        restart for every meta split, so at most 6 splits per meta split.
    reg: region acronym.
    labels: currently unused (the legend call is disabled).
    p_type: if given, curves of splits with r[p_type] >= sigl are zeroed.
    """
    #col = sns.color_palette('flare', 6)
    col = ['#e3685c', '#e98d6b', '#b13c6c', '#8f3371', '#d14a61', '#6c2b6d']
    for meta_split in metasplits:
        for i, split in enumerate(meta_splits[meta_split]):
            r = np.load(Path(pth_res, f'{split}.npy'),
                        allow_pickle=True).flat[0][reg]
            if p_type is None:
                dist = r['d_euc']
            else:
                dist = r['d_euc'] * (r[p_type] < sigl)

            xx = np.linspace(-pre_post[meta_split][0],
                             pre_post[meta_split][1],
                             len(dist))
            plt.plot(xx, dist, c=col[i])

    #plt.legend(labels)
    plt.xlabel('Time(s)')
    plt.ylabel('Euclidean Distance')
    plt.title(reg)
    # NOTE(review): these labels assume the stimulus window is drawn first and the
    # choice window is relabelled from x=0.2 on. TODO confirm against pre_post.
    xticks = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35]
    ticklabels = ['0', '0.05', '0.1', '0.15', '-0.15', '-0.1', '-0.05', '0']
    plt.xticks(xticks, ticklabels)
    plt.savefig('/Users/ariliu/Desktop/' + reg + '_lineplots.pdf', dpi=300)

In [None]:
# Line plots of the detailed splits for one region.
reg = 'GRN'
p_type = None  # alternatives: 'p_euc_c1', 'p_euc'

labels = [
    'block difference, choice R', 'block difference, choice L',
    'choice difference, block L', 'choice difference, block R',
    'discordant, different block & choice', 'concordant, different block & choice',
]

# Other selections: ['intertrial_all_0', 'duringstim_all_0', 'duringchoice_all_0'],
# ['duringstim_subset', 'duringchoice_subset'], or the '*_all' meta splits
# (with optional 'intertrial_all') passed to get_multiple_line_plots.
metasplits = ['duringstim_all_0', 'duringchoice_all_0']
get_multiple_line_plots_0(metasplits, reg, labels, p_type)

In [None]:
# Parameters for averaging true vs. control curves over significant regions.
nrand = 2000

p_type = 'p_euc_c1'  # alternatives: None, 'p_euc'
regs_c = ['GRN', 'PAG', 'AIv', 'MOp', 'RSPagl', 'CENT3']
regs_s = ['IRN', 'MOs', 'MRN', 'CENT3']
regs = regs_c
colors = {'c': 'tomato', 's': 'gold'}
color = colors['c']

meta_split = 'duringchoice'
splits = meta_splits[meta_split][:2]  # first two splits only


In [None]:
# Average the true and control distance curves over the significant (split, region) pairs.
d = {}
n = 0  # number of significant (split, region) pairs
for split in splits:
    for reg in regs:
        r = np.load(Path(pth_res, f'{split}.npy'),
                    allow_pickle=True).flat[0][reg]
        regde = np.load(Path(pth_res, f'{split}_regde.npy'),
                        allow_pickle=True).flat[0][reg]
        # Non-significant pairs add zeros and are not counted in n.
        r['significant'] = r[p_type] < sigl
        n += 1 if r['significant'] else 0
        if 'true' not in d:
            d['true'] = np.zeros_like(regde[0], dtype=float)
        d['true'] += regde[0] * r['significant']
        for i, ctrl in enumerate(regde[1:]):
            key = f'ctrl_{i}'
            if key not in d:
                d[key] = np.zeros_like(ctrl, dtype=float)
            d[key] += ctrl * r['significant']

if n == 0:
    raise ValueError(f'no significant (split, region) pairs for p_type={p_type!r}')

# Normalise every curve exactly once. Count the controls actually loaded rather than
# assuming nrand of them.
ctrl_keys = [k for k in d if k.startswith('ctrl_')]
for key in ['true', *ctrl_keys]:
    d[key] = d[key] / n

xx = np.linspace(-pre_post[meta_split][0],
                 pre_post[meta_split][1],
                 len(d['true']))
# Draw the controls first so the true curve stays visible on top.
for key in ctrl_keys:
    plt.plot(xx, d[key], c='gray', alpha=0.5)
plt.plot(xx, d['true'], c=color)


In [None]:
### plot avg block L-R distance for stim integrators vs choice generators

p_type = 'p_euc_c1'  # alternatives: None, 'p_euc'
regs_c = ['GRN', 'PAG', 'AIv', 'MOp', 'RSPagl', 'CENT3']
regs_s = ['IRN', 'MOs', 'MRN', 'CENT3']
regs = {'c': regs_c, 's': regs_s}
colors = {'c': 'tomato', 's': 'gold'}

# One curve per (region group, timeframe); 'intertrial' can be added to the inner list.
for regtype in ['c', 's']:
    for meta_split in ['duringstim', 'duringchoice']:
        get_line_plots(meta_split, regs[regtype], colors[regtype], p_type)

plt.xlabel('Time(s)')
plt.ylabel('Average Block L-R Euclidean Distance')
xticks = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35]
ticklabels = ['0', '0.05', '0.1', '0.15', '-0.15', '-0.1', '-0.05', '0']
plt.xticks(xticks, ticklabels)
plt.savefig('/Users/ariliu/Desktop/avgd.pdf', dpi=300)

In [None]:
def get_sc_line_plots(sc_splits, reg, stype=None, p_type=None):
    """
    Plot, over time, the relative distance d1 / (d0 + d1) of the first two splits
    of each meta split in ``sc_splits`` for one region.

    stype: unused, kept for backward compatibility. Pass the p-value key by keyword
        (``p_type=...``); a third positional argument lands in ``stype``.
    p_type: if given, the distance of a split whose r[p_type] >= sigl is zeroed first.
        Where both distances are zero the ratio is NaN, which matplotlib leaves as a gap.
    """
    for meta_split in sc_splits:
        splits = meta_splits[meta_split]
        r0 = np.load(Path(pth_res, f'{splits[0]}.npy'),
                     allow_pickle=True).flat[0][reg]
        r1 = np.load(Path(pth_res, f'{splits[1]}.npy'),
                     allow_pickle=True).flat[0][reg]

        d0, d1 = r0['d_euc'], r1['d_euc']
        if p_type is not None:
            d0 = d0 * (r0[p_type] < sigl)
            d1 = d1 * (r1[p_type] < sigl)
        # 0/0 -> NaN is expected when neither split is significant; silence the warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            d = d1 / (d0 + d1)

        xx = np.linspace(-pre_post[meta_split][0],
                         pre_post[meta_split][1],
                         len(d))
        plt.plot(xx, d)

In [None]:
sc_splits = ['sc_duringstim', 'sc_duringchoice']
# Pass p_type by keyword: the third positional parameter is the unused `stype`,
# so passing 'p_euc_c1' positionally would leave p_type as None.
get_sc_line_plots(sc_splits, 'MV', p_type='p_euc_c1')

## process data p values

In [11]:
def compute_amp_slope(timeframe, n=10):
    '''
    Add shape features of each region's true distance curve to the saved results.

    For every region in the timeframe's results file, the true curve (row 0 of the
    regde data) is used to store:
      amp_slope    - slope of a linear fit over the whole curve (x spans 0..0.15)
      slope_last   - slope of a linear fit to the last n points (x = 0..n-1)
      amp_loc      - index of the curve's maximum
      slope_last_5 - slope of a linear fit to the last 5 points
    These are used later for region type classification. The results file is
    overwritten in place.

    timeframe: key into run_align. A single split uses '{split}.npy' /
        '{split}_regde.npy'; several splits use the 'combined_' files.
    n: number of trailing points for 'slope_last' (the curve must have >= n points).
    '''
    splits = run_align[timeframe]
    if len(splits) == 1:
        combined_name = splits[0]
        combined_regde_name = f'{combined_name}_regde'
    else:
        combined_name = 'combined_' + "_".join(splits)
        combined_regde_name = 'combined_regde_' + "_".join(splits)

    d = np.load(Path(pth_res, f'{combined_name}.npy'),
                allow_pickle=True).flat[0]
    # Load the curves once rather than re-reading the file for every region.
    regde = np.load(Path(pth_res, f'{combined_regde_name}.npy'),
                    allow_pickle=True).flatten()[0]
    for reg in d:
        r = regde[reg][0]  # true (non-control) distance curve
        d[reg]['amp_slope'] = np.polyfit(np.linspace(0, 0.15, len(r)), r, 1)[0]
        d[reg]['slope_last'] = np.polyfit(np.arange(n), r[-n:], 1)[0]
        d[reg]['amp_loc'] = np.argmax(r)
        d[reg]['slope_last_5'] = np.polyfit(np.arange(5), r[-5:], 1)[0]

    np.save(Path(pth_res, f'{combined_name}.npy'), d, allow_pickle=True)


In [19]:
# Add curve-shape features for the activity-based timeframes
# (iterate over run_align instead to process every timeframe).
for timeframe in (
    'act_stim_duringstim1',
    'act_choice_duringstim0',
    'act_stim_duringstim0',
    'act_stim_duringchoice0',
    'act_choice_duringchoice0',
):
    compute_amp_slope(timeframe, n=10)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/ariliu/Downloads/ONE/openalyx.internationalbrainlab.org/manifold/res/combined_act_stim_block_l_act_stim_block_r.npy'

In [12]:
def _fdr_correct_results_file(path, ptype, sigl):
    """Benjamini-Hochberg-correct d[reg][ptype] across all regions in one results file, in place."""
    d = np.load(path, allow_pickle=True).flat[0]
    regs = list(d)
    pvals = [d[reg][ptype] for reg in regs]
    _, pvals_c, _, _ = multipletests(pvals, sigl, method='fdr_bh')
    for reg, p_c in zip(regs, pvals_c):
        d[reg][f'{ptype}_c'] = p_c
    np.save(path, d, allow_pickle=True)


def fdr_combined(timeframe, ptype='p_euc', sigl=0.05):
    '''
    FDR correction across regions (Benjamini-Hochberg, same as in the bwm analysis).

    The corrected p-values are stored as d[reg][f'{ptype}_c'] in the timeframe's
    results file. When the timeframe has several splits, the combined file and
    each split's own '{split}.npy' are corrected.

    timeframe: key into run_align.
    ptype: key of the uncorrected p-value in each region's dict.
    sigl: family-wise alpha passed to multipletests.
    '''
    splits = run_align[timeframe]
    combined_name = splits[0] if len(splits) == 1 else 'combined_' + "_".join(splits)

    _fdr_correct_results_file(Path(pth_res, f'{combined_name}.npy'), ptype, sigl)

    if len(splits) > 1:
        for split in splits:
            _fdr_correct_results_file(Path(pth_res, f'{split}.npy'), ptype, sigl)


def _permutation_p_value(r, ptype):
    """Fraction of rows of ``r`` whose statistic is >= that of row 0 (the true curve).

    r: 2-D array; row 0 is the true curve, the other rows are controls.
    ptype: 'p_amp' (max - min), 'p_mean' (time average) or 'p_max' (peak).
    Row 0 counts itself, so the result is never 0.
    """
    if ptype == 'p_amp':
        stat = np.max(r, axis=1) - np.min(r, axis=1)
    elif ptype == 'p_mean':
        stat = np.mean(r, axis=1)
    elif ptype == 'p_max':
        stat = np.max(r, axis=1)
    else:
        raise ValueError(f"Invalid ptype: {ptype}")
    return np.mean(stat >= stat[0])


def compute_p_value(timeframe, ptype='p_mean'):
    '''
    Compute a permutation p-value per region and store it as d[reg][ptype].

    The statistic ('p_amp', 'p_mean' or 'p_max') of the true distance curve is
    compared with that of the control curves. Results are written back to the
    timeframe's results file and, when there are several splits, to each split's
    '{split}.npy' (using '{split}_regde.npy' for its curves).

    Raises ValueError for an unknown ptype, before any file is read.
    '''
    if ptype not in ('p_amp', 'p_mean', 'p_max'):
        raise ValueError(f"Invalid ptype: {ptype}")

    splits = run_align[timeframe]
    if len(splits) == 1:
        combined_name = splits[0]
        combined_regde_name = f'{combined_name}_regde'
    else:
        combined_name = 'combined_' + "_".join(splits)
        combined_regde_name = 'combined_regde_' + "_".join(splits)

    d = np.load(Path(pth_res, f'{combined_name}.npy'),
                allow_pickle=True).flat[0]
    regde = np.load(Path(pth_res, f'{combined_regde_name}.npy'),
                    allow_pickle=True).flatten()[0]
    for reg in d:
        r = regde[reg]
        if len(splits) > 1:
            # Combined files store (true curve, stacked controls); merge into one 2-D array.
            r = np.concatenate([r[0].reshape(1, -1), r[1]], axis=0)
        d[reg][ptype] = _permutation_p_value(r, ptype)

    np.save(Path(pth_res, f'{combined_name}.npy'), d, allow_pickle=True)

    if len(splits) > 1:
        for split in splits:
            d = np.load(Path(pth_res, f'{split}.npy'),
                        allow_pickle=True).flat[0]
            regde = np.load(Path(pth_res, f'{split}_regde.npy'),
                            allow_pickle=True).flat[0]
            for reg in d:
                d[reg][ptype] = _permutation_p_value(regde[reg], ptype)

            np.save(Path(pth_res, f'{split}.npy'), d, allow_pickle=True)
            


In [None]:
# Other timeframe sets:
#   ['stim_duringstim0', 'choice_duringstim0', 'stim_duringchoice0', 'choice_duringchoice0']
#   ['stim_duringstim', 'choice_duringstim', 'stim_duringchoice', 'choice_duringchoice']
#   ['block_duringstim', 'block_duringchoice', 'intertrial']
times = ['intertrial0_act', 'intertrial_act', 'block_duringstim_act', 'block_duringchoice_act']

for ptype in ['p_mean', 'p_amp', 'p_max']:
    for timeframe in times:
        # Raw p-values must be stored before they can be FDR-corrected.
        compute_p_value(timeframe, ptype=ptype)
        fdr_combined(timeframe, ptype=ptype)

In [None]:
# Inspect raw and FDR-corrected p-values for a few regions.
timeframe = 'choice_duringstim'
splits = run_align[timeframe]
combined_name = 'combined_' + "_".join(splits)

d = np.load(Path(pth_res, f'{combined_name}.npy'),
            allow_pickle=True).flat[0]
for reg in ['GRN', 'VISp']:
    print(reg)
    for key in ('p_mean', 'p_mean_c', 'p_amp', 'p_amp_c', 'p_max', 'p_max_c'):
        print(d[reg][key])


In [13]:
def manifold_to_csv(split, sigl, p_type, res_dir=None):
    '''
    Flatten the per-region manifold results of one split into a table.

    Loads ``{split}.npy`` (a pickled dict keyed by region) from the results
    directory and extracts, for every region, the chosen p-value plus the
    amplitude/latency/slope metrics. It adds an int ``significant`` column
    (1 where ``p_type <= sigl``) and writes the table to ``{split}.csv`` in
    the same directory.

    When ``intertrial.csv`` exists in the results directory, its region order
    is used so that tables from different splits line up row by row. Regions
    listed there but absent from this split get an empty (NaN) row.

    Parameters
    ----------
    split : str
        Stem of the results file (a single split or a 'combined_...' name).
    sigl : float
        Significance level.
    p_type : str
        Key of the p-value to report, e.g. 'p_mean_c'.
    res_dir : path-like, optional
        Directory holding the results; defaults to the global ``pth_res``.

    Returns
    -------
    pandas.DataFrame
        Columns: region, ``p_type``, amp_euc, lat_euc, amp_slope, slope_last,
        amp_loc, slope_last_5, significant.
    '''
    res_dir = pth_res if res_dir is None else Path(res_dir)

    metric_keys = ['amp_euc', 'lat_euc', 'amp_slope', 'slope_last',
                   'amp_loc', 'slope_last_5']
    columns = ['region', p_type] + metric_keys

    d = np.load(Path(res_dir, f'{split}.npy'), allow_pickle=True).flat[0]

    # align the regions to a reference table if one exists (easier here than
    # later when plotting tables)
    sample_path = Path(res_dir, 'intertrial.csv')
    if sample_path.exists():
        regs = pd.read_csv(sample_path).region
    else:
        regs = list(d)

    rows = []
    for reg in regs:
        if reg not in d:
            # Pad to the full column count. Previously only 4 values were
            # given for 8 columns, which fails outright when every region is
            # missing.
            rows.append([reg] + [None] * (len(columns) - 1))
            continue
        rows.append([reg, d[reg][p_type]] + [d[reg][key] for key in metric_keys])

    df = pd.DataFrame(data=rows, columns=columns)

    # to_numeric turns missing (None) p-values into NaN, which never count as significant
    df['significant'] = (pd.to_numeric(df[p_type], errors='coerce') <= sigl).astype(int)
    df.to_csv(Path(res_dir, f'{split}.csv'), index=False)
    return df

def manifold_to_csv_old(meta_split, sigl, p_type):
    '''
    Legacy exporter: write one CSV per split listed in ``meta_splits[meta_split]``.

    Each ``{split}.npy`` in ``pth_res`` is flattened into the columns
    ``region``, ``p_{split}``, ``amp_{split}`` and ``lat_{split}``, plus a
    boolean ``{split}_significant`` flag (``p <= sigl``). Rows follow the
    region order of ``intertrial.csv``. Regions absent from a split get empty
    values. Nothing is returned.
    '''
    splits = meta_splits[meta_split]
    reference = pd.read_csv(Path(pth_res, 'intertrial.csv'))

    for split in splits:
        results = np.load(Path(pth_res, f'{split}.npy'),
                          allow_pickle=True).flat[0]

        rows = [
            [reg, results[reg][p_type], results[reg]['amp_euc'], results[reg]['lat_euc']]
            if reg in results
            else [reg, None, None, None]
            for reg in reference.region
        ]

        p_col = f'p_{split}'
        frame = pd.DataFrame(data=rows,
                             columns=['region', p_col, f'amp_{split}', f'lat_{split}'])
        frame[f'{split}_significant'] = frame[p_col] <= sigl
        frame.to_csv(Path(pth_res, f'{split}.csv'), index=False)


## plot tables

In [3]:
from plot_tables import get_cmap_

In [4]:
def plot_table_with_styles(df, beryl_palette, colormap_lookup, out_path):
    '''
    Render ``df`` as a colour-coded matplotlib table and save it to ``out_path``.

    Cell styling:
    - header row: bold, rotated 270 degrees, transparent and borderless;
    - 'region' column: filled with ``beryl_palette[region]`` (white if
      unknown); the text is kept in bold;
    - numeric cells: filled with ``colormap_lookup[col_name](value)`` (white
      if the column has no colormap), with the number text removed. Grey
      '#f2f2f2' marks NaN, and also exact 0 in columns whose name does not
      contain 'sc';
    - anything else: white background, small text.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to draw (string column labels).
    beryl_palette : dict
        Region name -> hex colour.
    colormap_lookup : dict
        Column name -> callable mapping a value to an RGB(A) colour.
    out_path : path-like
        Destination image; saved transparent at 350 dpi.
    '''
    # df.iloc yields numpy scalars: np.float64 subclasses float, but np.int64
    # does not subclass int, so integer columns used to be drawn as plain text.
    numeric_types = (float, int, np.integer, np.floating)

    fig, ax = plt.subplots()
    ax.axis('off')

    table = ax.table(
        cellText=df.values,
        colLabels=df.columns,
        cellLoc='center',
        loc='center'
    )

    table.auto_set_font_size(False)
    table.set_fontsize(7)
    table.scale(1.2, 1.3)

    for (row, col), cell in table.get_celld().items():
        if row == 0:
            # header row
            cell.set_text_props(weight='bold', fontsize=7)
            cell.set_facecolor('none')          # transparent background
            cell.get_text().set_rotation(270)
            cell.set_linewidth(0)               # remove border line
            cell.get_text().set_verticalalignment('bottom')
            cell.get_text().set_horizontalalignment('center')
            cell.set_width(0.07)                # match body cell width
            cell.get_text().set_fontsize(6)
            continue

        col_name = df.columns[col]
        val = df.iloc[row - 1, col]

        if col_name == 'region':
            cell.set_facecolor(beryl_palette.get(val, '#ffffff'))
            cell.get_text().set_fontsize(10)
            cell.get_text().set_weight('bold')
            cell.set_width(0.18)
        elif isinstance(val, numeric_types):
            cell.set_width(0.12)
            # NaN is always grey; 0 is grey only outside 'sc' columns
            # (in 'sc' columns 0 is coloured like any other value)
            is_blank = np.isnan(val) or ('sc' not in col_name and val == 0)
            if is_blank:
                cell.set_facecolor('#f2f2f2')
            else:
                rgb = colormap_lookup.get(col_name, lambda x: (1, 1, 1))(val)
                cell.set_facecolor(to_hex(rgb))
            cell.get_text().set_text('')  # completely remove number
        else:
            cell.set_facecolor('#ffffff')
            cell.get_text().set_fontsize(6)

        cell.set_linewidth(0.5)
        cell.set_edgecolor('white')

    # save/close this figure explicitly rather than whatever is "current"
    fig.savefig(out_path, bbox_inches='tight', dpi=350, transparent=True)
    plt.close(fig)


def plot_table(times, alpha=0.05, ptype='p_euc_c', datatype='true_block'):
    '''
    Plot a region x timeframe table of significant, min-max scaled amplitudes.

    For each timeframe, the (single or 'combined_...') split results are
    exported with ``manifold_to_csv``. ``amp_euc`` is min-max scaled to
    [1e-4, 1 + 1e-4]; it is set to 0 for non-significant regions
    (``ptype > alpha``) and for missing values. Rows follow the shared
    ``region_order.txt``; if that file is absent it is created by sorting on
    cosmos region, then descending row sum. The figure is saved to
    ``meta_pth/table_{datatype}_alltimes_{ptype}_{alpha}.png``.

    Parameters
    ----------
    times : list of str
        Keys of ``run_align``; one table column each.
    alpha : float
        Significance level passed to ``manifold_to_csv``.
    ptype : str
        p-value key used for significance.
    datatype : str
        Tag used in the output filename.

    Returns
    -------
    pandas.DataFrame
        The plotted table ('region' first, then one column per timeframe).

    Raises
    ------
    ValueError
        If ``times`` is empty.
    '''
    if len(times) == 0:
        raise ValueError('times must contain at least one timeframe')

    table = {}
    for timeframe in times:
        splits = run_align[timeframe]
        if len(splits) == 1:
            split_name = splits[0]
        else:
            split_name = 'combined_' + "_".join(splits)
        res = manifold_to_csv(split_name, alpha, ptype)
        min_val = res['amp_euc'].min()
        max_val = res['amp_euc'].max()
        span = max_val - min_val
        if span == 0:
            # a single region / identical amplitudes would give 0/0 = NaN,
            # which fillna(0) below would turn into "not significant"
            span = 1.0
        res['amp_euc'] = (res['amp_euc'] - min_val) / span + 1e-4
        res['amp_euc'] *= res['significant']
        res = res.fillna(0)
        table[timeframe] = res['amp_euc']

    # assumes every manifold_to_csv result shares the same row order (true when
    # intertrial.csv exists), so the last one can supply the region column
    table = pd.DataFrame(data=table)
    table['region'] = res.region
    table['beryl_hex'] = res.region.apply(swanson_to_beryl_hex, args=[br])
    beryl_palette = dict(zip(table['region'], table['beryl_hex']))
    table['sum'] = table[times].sum(axis=1)
    table['cosmos'] = table['region'].apply(lambda r: beryl_to_cosmos(r, br))

    # Load or compute region order
    ordering_path = Path(meta_pth, 'region_order.txt')
    if ordering_path.exists():
        with open(ordering_path) as f:
            region_order = [line.strip() for line in f]
    else:
        table = table.sort_values(['cosmos', 'sum'], ascending=[True, False])
        region_order = table['region'].tolist()
        with open(ordering_path, 'w') as f:
            f.writelines(r + '\n' for r in region_order)

    table['region'] = pd.Categorical(table['region'], categories=region_order, ordered=True)
    table = table.sort_values('region')

    # Drop non-display columns and put 'region' first
    df_to_plot = table.drop(columns=['beryl_hex', 'sum', 'cosmos']).reset_index(drop=True)
    df_to_plot = df_to_plot[['region'] + [c for c in df_to_plot.columns if c != 'region']]

    colormap_lookup = {timeframe: get_cmap_(timeframe) for timeframe in times}
    plot_table_with_styles(
        df=df_to_plot,
        colormap_lookup=colormap_lookup,
        beryl_palette=beryl_palette,
        out_path=Path(meta_pth, f'table_{datatype}_alltimes_{ptype}_{alpha}.png')
    )
    return df_to_plot


In [338]:
# Block-timing tables (true block prior) at two significance levels.
# Action-block variant:
#   datatype = 'act_block'
#   times = ['act_block_duringchoice', 'act_block_duringstim', 'act_intertrial0']
ptype = 'p_mean_c'
datatype = 'true_block'
times = ['block_duringchoice', 'block_duringstim', 'intertrial0']

for alpha in (0.01, 0.05):
    table = plot_table(times, ptype=ptype, alpha=alpha, datatype=datatype)

# Stimulus/choice tables were previously drawn with `plot_table_combined`
# using times 'stim_duringstim[0]', 'choice_duringstim[0]',
# 'stim_duringchoice[0]', 'choice_duringchoice[0]' and datatype 'stimchoice[0]'.

In [16]:
# Stimulus-vs-choice category table.
# Alternative timeframes without the '0' suffix:
#   ['stim_duringstim', 'choice_duringstim', 'stim_duringchoice', 'choice_duringchoice']
times = ['stim_duringstim0', 'choice_duringstim0',
         'stim_duringchoice0', 'choice_duringchoice0']
ptype = 'p_mean_c'
metric = 'int_mov'
sc_threshold = 0.6

# NOTE: `plot_sc_table` and `get_sc_table` are defined further down in this
# file, so they must have been executed in an earlier cell before this runs.
plot_sc_table(times, ptype, alpha=0.05, metric=metric, sc_threshold=sc_threshold)

In [5]:
def get_sc_table(times, ptype, alpha=0.05, n=20, combined_p=True, slope_threshold=0.05, 
                 sc_threshold=0.6, amp_loc_threshold=69):
    '''
    Build the per-region stimulus/choice ("sc") classification table.

    For every timeframe in ``times``, the significant ``amp_euc`` is collected
    (non-significant or missing -> 0). For each window ('sc_duringchoice',
    'sc_duringstim'), the ratio ``choice / (stim + choice)`` is stored under
    the window name. It is NaN when both amplitudes are 0. This assumes
    exactly one stim and one choice timeframe per window. A category column
    ``{window}_int_mov`` is then derived:

    * 1   -- movement-initiation. The duringchoice choice curve has
             ``slope_last > slope_threshold``, ``slope_last_5 > 0`` and
             ``amp_loc > amp_loc_threshold``, and also
             ``sc_duringchoice > sc_threshold``. Both windows use this
             duringchoice test.
    * 0.5 -- integrator. The window ratio is > 0 and the duringchoice
             ``slope_last > slope_threshold``, excluding movement-initiation
             regions.
    * 0   -- duringstim only. Applies where the ratio is 0, i.e. stimulus
             amplitude without choice amplitude, or significance only in the
             short 'stim_duringstim1' split. It overrides the labels above.
    * NaN -- everything else.

    ``n`` is forwarded to ``compute_amp_slope``.

    NOTE(review): the ``combined_p=False`` branch does not fill the
    slope/amp_loc columns or 'stim_duringstim1', so the classification below
    raises KeyError on that path.

    Returns
    -------
    pandas.DataFrame
        Columns include the window ratios, the intermediate shape flags,
        ``*_move_init``, ``*_integrator``, ``*_int_mov`` and 'region'.
    '''
    # For each window: [stim timeframe(s)..., choice timeframe(s)...]
    sc_splits = {'sc_duringchoice': [time for time in times if 'duringchoice' in time and time.startswith('stim')] + 
                                  [time for time in times if 'duringchoice' in time and time.startswith('choice')],
                 'sc_duringstim': [time for time in times if 'duringstim' in time and time.startswith('stim')] + 
                                  [time for time in times if 'duringstim' in time and time.startswith('choice')]}

    tables, res = {}, {}

    if combined_p:
        for time in times:
            splits = run_align[time]
            split_name = 'combined_'+"_".join(splits)

            compute_amp_slope(time, n)
            results = manifold_to_csv(split_name, alpha, ptype)

            results['amp_euc'] *= results['significant']
            results = results.fillna(0)
            tables[time] = results['amp_euc']
            tables[f'{time}_amp_slope'] = results['amp_slope']
            tables[f'{time}_slope_last'] = results['slope_last']
            tables[f'{time}_amp_loc'] = results['amp_loc']
            tables[f'{time}_slope_last_5'] = results['slope_last_5']

        # add in short splits for stim (significance only)
        time = 'stim_duringstim1'
        splits = run_align[time]
        split_name = 'combined_'+"_".join(splits)
        results = manifold_to_csv(split_name, alpha, ptype)
        tables[time] = results['significant']

    else:
        for time in times:
            splits = run_align[time]
            for split in splits:
                results = manifold_to_csv(split, alpha, ptype)
                results['amp_euc'] *= results['significant']
                results = results.fillna(0)
                if time not in tables:
                    tables[time] = results['amp_euc']
                else:   
                    tables[time] += results['amp_euc']
        
    tables['region'] = results['region']

    # Ramp shape of the choice curve (duringchoice, movement-aligned) that
    # marks movement-initiation regions
    meta_split = 'sc_duringchoice'
    splits = sc_splits[meta_split]
    res[f'{meta_split}_amp_loc'] = tables[f'{splits[1]}_amp_loc'] > amp_loc_threshold
    res[f'{meta_split}_slope_last'] = tables[f'{splits[1]}_slope_last'] > slope_threshold
    res[f'{meta_split}_slope_last_5'] = tables[f'{splits[1]}_slope_last_5'] > 0

    move_shape = (res[f'{meta_split}_slope_last']
                  & res[f'{meta_split}_slope_last_5']
                  & res[f'{meta_split}_amp_loc'])

    for meta_split in sc_splits:
        # choice-stim ratio in [0, 1] (NaN when both amplitudes are 0)
        splits = sc_splits[meta_split]
        res[meta_split] = tables[splits[1]] / (tables[splits[0]] + tables[splits[1]])

        # Parenthesised on purpose: `&` binds tighter than `>`. The old
        # `move_shape & sc > thr` compared the boolean AND result against the
        # threshold, so sc_threshold had no effect.
        res[f'{meta_split}_move_init'] = (
            move_shape & (res['sc_duringchoice'] > sc_threshold)).astype(int)

        # Same precedence fix: `sc > 0 & slope_last` parsed as
        # `sc > (0 & slope_last)` and silently dropped the slope condition.
        res[f'{meta_split}_integrator'] = (
            ((res[meta_split] > 0) & res['sc_duringchoice_slope_last']).astype(int)
            - res[f'{meta_split}_move_init'])

        int_mov = np.full(len(res[meta_split]), np.nan)
        int_mov[(res[f'{meta_split}_move_init'] == 1).to_numpy()] = 1
        int_mov[(res[f'{meta_split}_integrator'] == 1).to_numpy()] = 0.5
        res[f'{meta_split}_int_mov'] = int_mov

    # Regions with no stim/choice amplitude but significant in the short stim
    # split count as stimulus regions (ratio 0)
    mask = np.isnan(res['sc_duringstim']) & (tables['stim_duringstim1'] == 1)
    res['sc_duringstim'] = res['sc_duringstim'].mask(mask, 0)
    res['sc_duringstim_int_mov'][(res['sc_duringstim'] == 0).to_numpy()] = 0

    res['region'] = results['region']
    res = pd.DataFrame(data=res)

    return res

In [6]:
def plot_combined_table_summary(sc_times, timing_splits, ptype='p_mean_c', alpha=0.05, combined_p=True, 
                                sc_threshold=0.6, slope_threshold=0.05, amp_loc_threshold=69):
    '''
    Plot the stim/choice category columns next to block-timing amplitudes.

    Displayed columns, left to right: region, 'sc_duringchoice_int_mov', the
    duringchoice entries of ``timing_splits``, 'sc_duringstim_int_mov', and
    the duringstim entries of ``timing_splits``. Each timing column holds the
    significant ``amp_euc``, min-max scaled per split with NaN -> 0, and
    summed over splits when ``combined_p`` is False.

    Regions follow region_order.txt. If that file is missing, it is created
    by sorting on cosmos region, then descending row sum. The output name
    encodes act/true block (taken from ``timing_splits[0]``), the '0'/'1'
    variant of ``sc_times``, ``ptype``, ``combined_p`` and ``alpha``.
    '''
    category_cols = ['sc_duringchoice_int_mov', 'sc_duringstim_int_mov']

    def scaled_amp(name):
        # min-max scale amp_euc, zero out non-significant regions, NaN -> 0
        out = manifold_to_csv(name, alpha, ptype)
        lo, hi = out['amp_euc'].min(), out['amp_euc'].max()
        out['amp_euc'] = (out['amp_euc'] - lo) / (hi - lo) + 1e-4
        out['amp_euc'] *= out['significant']
        return out.fillna(0)['amp_euc']

    # stim/choice categories (L/R already combined by get_sc_table)
    table = get_sc_table(sc_times, ptype, alpha=alpha, combined_p=combined_p,
                         sc_threshold=sc_threshold, slope_threshold=slope_threshold,
                         amp_loc_threshold=amp_loc_threshold)

    for timing_split in timing_splits:
        if combined_p:
            table[timing_split] = scaled_amp('combined_' + "_".join(run_align[timing_split]))
            continue
        for split in run_align[timing_split]:
            amp = scaled_amp(split)
            if timing_split in table:
                table[timing_split] += amp
            else:
                table[timing_split] = amp

    df = pd.DataFrame(table)
    df['beryl_hex'] = df['region'].apply(swanson_to_beryl_hex, args=[br])
    beryl_palette = dict(zip(df['region'], df['beryl_hex']))
    df['cosmos'] = df['region'].apply(lambda r: beryl_to_cosmos(r, br))
    df['sum'] = df[category_cols + timing_splits].sum(axis=1, skipna=True)

    ordering_path = Path(meta_pth, 'region_order.txt')
    if not ordering_path.exists():
        ranked = df.sort_values(['cosmos', 'sum'], ascending=[True, False])
        region_order = ranked['region'].tolist()
        with open(ordering_path, 'w') as fh:
            fh.writelines(r + '\n' for r in region_order)
    else:
        with open(ordering_path) as fh:
            region_order = [line.strip() for line in fh]

    df['region'] = pd.Categorical(df['region'], categories=region_order, ordered=True)
    df = df.sort_values('region')
    cmap_columns = df.columns.difference(['region']).tolist()

    choice_cols = [t for t in timing_splits if 'duringchoice' in t]
    stim_cols = [t for t in timing_splits if 'duringstim' in t]
    shown = (['region', 'sc_duringchoice_int_mov'] + choice_cols
             + ['sc_duringstim_int_mov'] + stim_cols)
    df_to_plot = df[shown].reset_index(drop=True)

    colormap_lookup = {name: get_cmap_(name) for name in cmap_columns}

    block_type = 'act_block' if 'act' in timing_splits[0] else 'true_block'

    if 'stim_duringstim0' in sc_times:
        out_path = Path(meta_pth, f'table_{block_type}_combined_summary0_{ptype}_combinedp{combined_p}_{alpha}.png')
    elif 'stim_duringstim1' in sc_times:
        out_path = Path(meta_pth, f'table_{block_type}_combined_summary1_{ptype}_combinedp{combined_p}_{alpha}.png')
    else:
        out_path = Path(meta_pth, f'table_{block_type}_combined_summary_{ptype}_combinedp{combined_p}_{alpha}.png')

    plot_table_with_styles(
        df=df_to_plot,
        beryl_palette=beryl_palette,
        colormap_lookup=colormap_lookup,
        out_path=out_path
    )


In [7]:
def plot_sc_table(times, ptype, metric='int_mov', alpha=0.05, slope_threshold=0.05, 
                 sc_threshold=0.7, amp_loc_threshold=69):
    '''
    Plot the per-region stimulus/choice table produced by ``get_sc_table``.

    metric: 'int_mov' (region category: integrator, movement, stim) or
            'move_shape' or 'sc'
        'int_mov'    -> columns 'sc_duringchoice_int_mov', 'sc_duringstim_int_mov'
        'sc'         -> raw ratios 'sc_duringchoice', 'sc_duringstim'
        'move_shape' -> 'sc_duringchoice_move_shape'. NOTE(review): the
                        visible get_sc_table does not create this column (its
                        assignment is commented out), so this option
                        currently raises KeyError.

    The figure is saved to ``meta_pth/table_stimchoice[0]_{metric}_{ptype}.png``;
    the '0' is added when 'choice_duringstim0' is in ``times``.

    Raises
    ------
    ValueError
        For an unknown ``metric``.
    '''
    if metric == 'int_mov': 
        sc_splits = ['sc_duringchoice_int_mov', 'sc_duringstim_int_mov']
    elif metric == 'move_shape':
        sc_splits = ['sc_duringchoice_move_shape']
    elif metric == 'sc':
        sc_splits = ['sc_duringchoice', 'sc_duringstim']
    else: 
        raise ValueError(f"Invalid metric: {metric}")
    
    if 'choice_duringstim0' in times:
        datatype = 'stimchoice0'
    else:
        datatype = 'stimchoice'
    datatype = f'{datatype}_{metric}'

    res = get_sc_table(times, ptype, alpha, sc_threshold=sc_threshold, 
                       slope_threshold=slope_threshold, amp_loc_threshold=amp_loc_threshold)
    
    # Add hex values for Beryl regions
    res['beryl_hex'] = res.region.apply(swanson_to_beryl_hex, args=[br])
    beryl_palette = dict(zip(res.region, res.beryl_hex))
        
    # Order columns according to panels in Figure. Copy so that adding
    # columns below does not write into a slice of the get_sc_table result.
    names = ['region'] + sc_splits
    res = res[names].copy()

    # Row sum over all plotted metric columns (NaN skipped); used to order
    # regions when region_order.txt does not exist yet. The old names[2:]
    # was left over from a removed 'region_color' column and skipped the
    # first metric.
    res['sum'] = res[names[1:]].sum(axis=1)
    res['cosmos'] = res['region'].apply(lambda r: beryl_to_cosmos(r, br))
    
    # Load or compute region order
    ordering_path = Path(meta_pth, 'region_order.txt')
    if ordering_path.exists():
        with open(ordering_path) as f:
            region_order = [line.strip() for line in f]
    else:
        res = res.sort_values(['cosmos', 'sum'], ascending=[True, False])
        region_order = res['region'].tolist()
        with open(ordering_path, 'w') as f:
            f.writelines(r + '\n' for r in region_order)

    res['region'] = pd.Categorical(res['region'], categories=region_order, ordered=True)
    res = res.sort_values('region')
    
    df_to_plot = res.drop(columns=['cosmos', 'sum']).reset_index(drop=True)

    # Ensure region is first column
    df_to_plot = df_to_plot[['region'] + [c for c in df_to_plot.columns if c != 'region']]

    # Build column-specific colormap dictionary
    colormap_lookup = {col: get_cmap_(col) for col in df_to_plot.columns if col != 'region'}

    outname = f'table_{datatype}_{ptype}.png'
    plot_table_with_styles(
        df=df_to_plot,
        beryl_palette=beryl_palette,
        colormap_lookup=colormap_lookup,
        out_path=Path(meta_pth, outname)
    )



In [None]:
# old version
# NOTE(review): legacy cell kept for reference. It does not match the current
# helpers and will fail if executed as-is:
#   * `expanded_meta_splits`, `fdr_splits` and `fdr_reg` are not defined in
#     this part of the file -- confirm they exist in an earlier cell;
#   * `plot_table` is defined as plot_table(times, alpha, ptype, datatype),
#     so the call below passes the split name as `times` and 'amp' as `alpha`.
sigl=0.05
# meta_split = 'intertrial_1.0' #'sc_duringstim1' 'intertrial_block_only' 'act_intertrial'
fdr = 'p_euc_c1' # p-value key for the tables: 'p_euc_c1' or 'p_euc' or 'p_euc_c'

for meta_split in expanded_meta_splits:
# for meta_split in ['intertrial_1.0']:
    # fdr correction on raw data, for fdr = 'p_euc_c1' correction over splits, and 'p_euc_c' correction over regions
    fdr_splits(meta_split, sigl)
    fdr_reg(meta_split, sigl)
    
    # load data to csv files
    manifold_to_csv(meta_split, sigl, fdr)
    
    res = plot_table(meta_split, 'amp', fdr)
    res.to_html(Path(meta_pth,f'table_{meta_split}.html'))

In [8]:
def plot_combined_sc_table_summary(sc_times, ptype='p_euc_c1', alpha=0.05, combined_p=True):
    '''
    Plot the raw stim/choice ratios next to the block-timing amplitudes.

    Displayed columns: 'sc_duringchoice', 'block_duringchoice',
    'sc_duringstim', 'block_duringstim'. The ratio columns come from
    ``get_sc_table``. Each block column holds the significant ``amp_euc``,
    min-max scaled per split with NaN -> 0, and summed over splits when
    ``combined_p`` is False.

    Regions follow region_order.txt, which is created if missing. Output:
    ``meta_pth/table_combined_sc_summary[0]_{ptype}_combinedp{combined_p}.png``.
    '''
    block_columns = ['block_duringchoice', 'block_duringstim']
    ratio_columns = ['sc_duringchoice', 'sc_duringstim']

    def scaled_amp(name):
        # min-max scale amp_euc, zero out non-significant regions, NaN -> 0
        out = manifold_to_csv(name, alpha, ptype)
        lo, hi = out['amp_euc'].min(), out['amp_euc'].max()
        out['amp_euc'] = (out['amp_euc'] - lo) / (hi - lo) + 1e-4
        out['amp_euc'] *= out['significant']
        return out.fillna(0)['amp_euc']

    # stim/choice ratios (L/R already combined by get_sc_table)
    table = get_sc_table(sc_times, ptype, alpha=alpha, combined_p=combined_p)

    for block_col in block_columns:
        if combined_p:
            table[block_col] = scaled_amp('combined_' + "_".join(run_align[block_col]))
            continue
        for split in run_align[block_col]:
            amp = scaled_amp(split)
            if block_col in table:
                table[block_col] += amp
            else:
                table[block_col] = amp

    df = pd.DataFrame(table)
    df['beryl_hex'] = df['region'].apply(swanson_to_beryl_hex, args=[br])
    beryl_palette = dict(zip(df['region'], df['beryl_hex']))
    df['cosmos'] = df['region'].apply(lambda r: beryl_to_cosmos(r, br))
    df['sum'] = df[ratio_columns + block_columns].sum(axis=1, skipna=True)

    ordering_path = Path(meta_pth, 'region_order.txt')
    if not ordering_path.exists():
        ranked = df.sort_values(['cosmos', 'sum'], ascending=[True, False])
        region_order = ranked['region'].tolist()
        with open(ordering_path, 'w') as fh:
            fh.writelines(r + '\n' for r in region_order)
    else:
        with open(ordering_path) as fh:
            region_order = [line.strip() for line in fh]

    df['region'] = pd.Categorical(df['region'], categories=region_order, ordered=True)
    df = df.sort_values('region')
    cmap_columns = df.columns.difference(['region']).tolist()

    shown = ['region', 'sc_duringchoice', 'block_duringchoice',
             'sc_duringstim', 'block_duringstim']
    df_to_plot = df[shown].reset_index(drop=True)

    colormap_lookup = {name: get_cmap_(name) for name in cmap_columns}

    if 'stim_duringstim0' in sc_times:
        out_path = Path(meta_pth, f'table_combined_sc_summary0_{ptype}_combinedp{combined_p}.png')
    else:
        out_path = Path(meta_pth, f'table_combined_sc_summary_{ptype}_combinedp{combined_p}.png')

    plot_table_with_styles(
        df=df_to_plot,
        beryl_palette=beryl_palette,
        colormap_lookup=colormap_lookup,
        out_path=out_path
    )


In [384]:
# Combined summary tables: stim/choice categories plus block timing.
# Alternatives:
#   times = ['stim_duringstim', 'choice_duringstim', 'stim_duringchoice', 'choice_duringchoice']
#   timing_splits = ['block_duringchoice', 'block_duringstim']   # true block prior
times = ['stim_duringchoice0', 'choice_duringchoice0',
         'stim_duringstim0', 'choice_duringstim0']
timing_splits = ['act_block_duringchoice', 'act_block_duringstim']

ptype = 'p_mean_c'
combined_p = True
sc_threshold = 0.6

summary_kwargs = dict(ptype=ptype, combined_p=combined_p, sc_threshold=sc_threshold)
for sigl in (0.05, 0.01):
    plot_combined_table_summary(times, timing_splits, alpha=sigl, **summary_kwargs)

# Other variants:
#   plot_combined_sc_table_summary(times, ptype=ptype, alpha=sigl, combined_p=combined_p)
#   for ptype in ['p_mean_c', 'p_amp_c', 'p_max_c']:
#       plot_combined_table_summary(times, ptype=ptype, alpha=sigl)


In [None]:
# Re-run the FDR corrections and the CSV export for each block-timing split,
# then plot them together.
# `plot_tables` is used as a module below, but only `get_cmap_` is imported
# from it by name, so bind the module itself.
# NOTE(review): `plot_block_alltimes` is not defined in this part of the
# file -- confirm it is available before running this cell.
import plot_tables

sigl = 0.05
ptype = 'p_euc_c1'  # 'p_euc_c1' (correction over splits), 'p_euc_c' (over regions) or 'p_euc'
times = ['duringstim', 'act_duringstim', 'duringchoice',
         'act_duringchoice', 'intertrial', 'act_intertrial']
blocktype = 'all'  # 'true' 'act' 'all'

for meta_split in times:
    if ptype == 'p_euc_c1':
        plot_tables.fdr_splits(meta_split, sigl)
    elif ptype == 'p_euc_c':
        plot_tables.fdr_reg(meta_split, sigl)

    manifold_to_csv(meta_split, sigl, ptype)

plot_block_alltimes(times, ptype, blocktype)

In [9]:
def plot_combined_onetype(sc_times, sc_type, timing_split, ptype='p_euc', alpha=0.05, combined_p=True,
                          sc_threshold=0.6, slope_threshold=0.05, amp_loc_threshold=69):
    '''
    Intersect one stim/choice category with positive block-timing amplitude.

    sc_type: 'stim' or 'choice' or 'integrator'
        Keeps regions whose ``*_int_mov`` value is 0, exactly 1, or strictly
        between 0 and 1, respectively.
    timing_split: 'block_duringstim' or 'block_duringchoice' or
        'act_block_duringstim' or 'act_block_duringchoice'
        If the name contains 'choice', the duringchoice category column is
        used; otherwise the duringstim one.

    A region gets ``combined == 1`` when it is in the requested category and
    its scaled block amplitude is > 0; otherwise the value is NaN. The
    single-column table is saved under ``meta_pth``. region_order.txt must
    already exist.

    Returns
    -------
    list
        Regions with ``combined == 1``, in region_order.txt order.

    Raises
    ------
    ValueError
        For an unknown ``sc_type``.
    '''
    sc_split = 'sc_duringchoice_int_mov' if 'choice' in timing_split else 'sc_duringstim_int_mov'
    print(sc_split, timing_split)

    def scaled_amp(name):
        # min-max scale amp_euc, zero out non-significant regions, NaN -> 0
        out = manifold_to_csv(name, alpha, ptype)
        lo, hi = out['amp_euc'].min(), out['amp_euc'].max()
        out['amp_euc'] = (out['amp_euc'] - lo) / (hi - lo) + 1e-4
        out['amp_euc'] *= out['significant']
        return out.fillna(0)['amp_euc']

    sc_res = get_sc_table(sc_times, ptype, alpha=alpha, combined_p=combined_p,
                          sc_threshold=sc_threshold, slope_threshold=slope_threshold,
                          amp_loc_threshold=amp_loc_threshold)
    table = {sc_split: sc_res[sc_split], 'region': sc_res['region']}

    if combined_p:
        table[timing_split] = scaled_amp('combined_' + "_".join(run_align[timing_split]))
    else:
        for split in run_align[timing_split]:
            amp = scaled_amp(split)
            if timing_split in table:
                table[timing_split] += amp
            else:
                table[timing_split] = amp

    # binarise: 1 where the block amplitude is positive, NaN elsewhere
    table[timing_split] = table[timing_split].apply(lambda x: 1 if x > 0 else np.nan)

    category_filters = {
        'stim': lambda x: 1 if x == 0 else np.nan,
        'choice': lambda x: 1 if x == 1 else np.nan,
        'integrator': lambda x: 1 if 0 < x < 1 else np.nan,
    }
    if sc_type not in category_filters:
        raise ValueError(f"Invalid sc_type: {sc_type}")
    table[sc_split] = table[sc_split].apply(category_filters[sc_type])
    table['combined'] = table[sc_split] * table[timing_split]

    df = pd.DataFrame(table)
    df['beryl_hex'] = df['region'].apply(swanson_to_beryl_hex, args=[br])
    beryl_palette = dict(zip(df['region'], df['beryl_hex']))

    with open(Path(meta_pth, 'region_order.txt')) as fh:
        region_order = [line.strip() for line in fh]

    df['region'] = pd.Categorical(df['region'], categories=region_order, ordered=True)
    df = df.sort_values('region')

    df_to_plot = df[['region', 'combined']].reset_index(drop=True)
    colormap_lookup = {'combined': get_cmap_('intertrial')}

    selected = df_to_plot.loc[df_to_plot['combined'] == 1, 'region'].tolist()

    out_path = Path(meta_pth, f'table_combined_{sc_type}_{timing_split}_{ptype}_combinedp{combined_p}_{alpha}.png')
    plot_table_with_styles(
        df=df_to_plot,
        beryl_palette=beryl_palette,
        colormap_lookup=colormap_lookup,
        out_path=out_path
    )

    return selected


In [None]:
# Region lists for each (category, action-block timing window) combination.
# Each call also saves the corresponding single-column table figure.
# Alternative sc timeframes without the '0' suffix:
#   ['stim_duringstim', 'choice_duringstim', 'stim_duringchoice', 'choice_duringchoice']
times = ['stim_duringstim0', 'choice_duringstim0', 'stim_duringchoice0', 'choice_duringchoice0']

ptype = 'p_mean_c'
combined_p = True
sc_threshold = 0.6
alpha = 0.05

onetype_kwargs = dict(combined_p=combined_p, sc_threshold=sc_threshold, alpha=alpha)
stim_regs = plot_combined_onetype(times, 'stim', 'act_block_duringstim', ptype, **onetype_kwargs)
move_regs_choice = plot_combined_onetype(times, 'choice', 'act_block_duringchoice', ptype, **onetype_kwargs)
int_regs_stim = plot_combined_onetype(times, 'integrator', 'act_block_duringstim', ptype, **onetype_kwargs)
int_regs_choice = plot_combined_onetype(times, 'integrator', 'act_block_duringchoice', ptype, **onetype_kwargs)
move_regs_stim = plot_combined_onetype(times, 'choice', 'act_block_duringstim', ptype, **onetype_kwargs)







