In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
from scipy import stats
from pathlib import Path
import pynapple as nap

from spatial_manifolds.data.binning import get_bin_config
from spatial_manifolds.data.loading import load_session

import warnings
# NOTE(review): this silences *every* warning for the whole kernel, including the
# numpy RuntimeWarnings the z-scoring / np.nanmean calls further down can raise
# (0/0 for flat tuning curves, mean of an all-NaN slice). Consider narrowing it
# to specific categories so genuine problems stay visible.
warnings.filterwarnings('ignore')
# Re-import edited project modules (e.g. spatial_manifolds) before each cell runs.
%load_ext autoreload
%autoreload 2


In [None]:
# Session selection. The loading cells below overwrite args.mouse / args.day for
# every session they load, so these two only seed the initial Args object.
mouse, day = 29, 0
session_type, sorter = 'MCVR', 'kilosort4'

# Root of the cohort's NWB data on the lab share (alternative local copy below).
storage = Path(
    '/Volumes/cmvm/sbms/groups/CDBS_SIDB_storage/NolanLab/ActiveProjects/Wolf/COHORT12/'
)
#storage = Path('/Users/harryclark/Downloads/nwb_data/')

# Analysis options; in this notebook they are only forwarded through Args.
alpha, n_jobs = 0.001, 8
n_shuffles, seed = 100, 1

class Args:
    """Plain attribute container for the settings handed to ``load_session``.

    Every constructor argument is stored unchanged as an attribute of the same
    name (presumably mirroring an argparse-style namespace — confirm against
    ``load_session``).
    """

    def __init__(self,mouse,day,session_type,sorter,storage,alpha,n_jobs,n_shuffles,seed):
        # Which recording to load and where it lives.
        self.mouse, self.day = mouse, day
        self.session_type, self.sorter = session_type, sorter
        self.storage = storage
        # Analysis options.
        self.alpha, self.n_jobs = alpha, n_jobs
        self.n_shuffles, self.seed = n_shuffles, seed
args = Args(mouse,day,session_type,sorter,storage,alpha,n_jobs,n_shuffles,seed)

# Track length (position units) per session type. A lookup always reassigns
# `tl` when this cell runs. The previous if/elif had no else, so for any other
# session type `tl` kept whatever value an earlier run left in the kernel.
# Unknown session types now get an explicit None.
# NOTE(review): the plotting cells below still hard-code 230 instead of using tl.
track_lengths = {'VR': 200, 'MCVR': 230}
tl = track_lengths.get(session_type)


In [None]:

def plot_stops(session, speed_threshold=3, track_length=230):
    """Raster of stopped samples per trial, coloured by trial type.

    Every sample whose speed is below ``speed_threshold`` is drawn at
    (position, trial number): trial type 1 in red, trial type 3 in blue.
    The regions 90-110 (orange) and 120-140 (teal) are shaded.

    Parameters
    ----------
    session : mapping
        Loaded session providing ``'trial_number'``, ``'P'`` (position),
        ``'trial_type'`` and ``'S'`` (speed) as equal-length array-likes.
        These are assumed to be per-sample series from ``load_session``
        (TODO confirm).
    speed_threshold : float, default 3
        Samples with speed strictly below this count as stopped.
    track_length : float, default 230
        Right x-limit of the plot; 230 is the MCVR value of ``tl`` in the config.
    """
    trial_numbers = np.array(session['trial_number'])
    position = np.array(session['P'])
    trial_types = np.array(session['trial_type'])
    speed = np.array(session['S'])
    stop_mask = speed < speed_threshold

    fig, ax = plt.subplots(1, 1, layout='constrained', figsize=(2, 2))
    for tt, ttc in zip([1, 3], ['red', 'blue']):
        mask = stop_mask & (trial_types == tt)
        ax.scatter(position[mask], trial_numbers[mask], alpha=0.05, s=20, c=ttc)

    ax.set_xlim(0, track_length)
    # np.min/np.max raise on an empty array; keep matplotlib's autoscale then.
    if trial_numbers.size > 0:
        ax.set_ylim(np.min(trial_numbers), np.max(trial_numbers))
    for span_start, span_stop, span_colour in [(90, 110, 'orange'), (120, 140, 'teal')]:
        ax.axvspan(
            span_start, span_stop,
            alpha=0.2, zorder=-10, edgecolor='none', facecolor=span_colour,
        )
    ax.set_xlabel('Position')
    ax.set_ylabel('Trial')
    plt.show()

def extract_first_stops(session, tt, speed_threshold=3, min_position=30):
    """Return the position of the first stop on each trial of type ``tt``.

    A sample counts as a stop when its speed is below ``speed_threshold`` and
    its position is greater than ``min_position``; stops at or before that
    position are ignored. The "first" stop is the earliest such sample in array
    order, which assumes samples are stored in time order (TODO confirm).

    Parameters
    ----------
    session : mapping
        Provides ``'trial_number'``, ``'P'`` (position), ``'trial_type'`` and
        ``'S'`` (speed) as equal-length array-likes.
    tt : int or float
        Trial type to select (compared with ``==`` against ``'trial_type'``).
    speed_threshold : float, default 3
        Speed below which a sample is a stop.
    min_position : float, default 30
        Only stops at positions strictly greater than this are considered.

    Returns
    -------
    np.ndarray of float
        One entry per trial of type ``tt``, in ascending trial-number order.
        The entry is NaN for a trial with no qualifying stop.
    """
    # Bug fix: this used np.array('trial_number'), the literal string, which
    # merged every trial into one and returned a single value for the session.
    trial_numbers = np.array(session['trial_number'])
    position = np.array(session['P'])
    trial_types = np.array(session['trial_type'])
    speed = np.array(session['S'])

    candidate = (speed < speed_threshold) & (position > min_position)
    tt_mask = trial_types == tt

    first_stops = []
    # Trials of other types can never contain a matching stop (they would only
    # add NaN entries), so iterate over trials of the requested type only.
    for tn in np.unique(trial_numbers[tt_mask]):
        trial_stops = position[candidate & tt_mask & (trial_numbers == tn)]
        first_stops.append(trial_stops[0] if trial_stops.size else np.nan)
    return np.array(first_stops, dtype=float)


In [None]:
# All mice!
# Load session
# Load every MCVR session under args.storage into `sessions`, keyed
# 'M<mouse>D<day>'. Each value is the tuple returned by load_session(args) with
# the session's ramp-classification table (tuning_scores/ramp_class.parquet)
# appended. Later cells unpack it as (session, session_path, clusters, ramp_table).
sessions = {}
# Assumes the layout <storage>/<mouse dir>/<day dir>/MCVR/*MCVR.nwb.
# NOTE(review): hard-codes 'MCVR' rather than using args.session_type.
for session_path in sorted(
    list(args.storage.glob(f'*/*/MCVR/*MCVR.nwb'))
):
    # Mouse and day numbers come from the directory names with their first
    # character dropped (e.g. a mouse folder 'M29' -> 29). args is updated in
    # place, presumably because load_session reads args.mouse / args.day
    # (confirm). The last session's values remain in args afterwards.
    args.mouse = int(session_path.parent.parent.parent.name[1:])
    args.day = int(session_path.parent.parent.name[1:])
    mouse_day = f'M{args.mouse}D{args.day}'
    sessions[mouse_day] = (
        *load_session(args),
        pd.read_parquet(
            session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
        ),
    )
# Position-binning settings (num_bins, bounds) used by the tuning-curve cells.
bin_config = get_bin_config(args.session_type)['P']

In [None]:
# Plot
# One panel per (outbound sign, homebound sign) ramp class of 'rz1_b' neurons.
# Each panel shows the mean (± SEM) z-scored position tuning curve pooled over
# all loaded sessions, with trial type 1 in red and trial type 3 in blue.
# Uses `sessions` and `bin_config` from the loading cell above.
trial_types = [1,3]
ramp_classes = [('+','+'),
                ('+','-'),
                ('+','/')]
# Shaded track regions as (start, stop, colour), in tuning-curve position units.
# NOTE(review): what each region represents is not stated here; label them.
track_spans = [(90, 110, 'orange'), (120, 140, 'teal'), (0, 30, 'grey'), (200, 230, 'grey')]

fig, axs = plt.subplots(
    1, len(ramp_classes), layout='constrained', figsize=(8, 2)
)

for class_idx, ramp_class in enumerate(ramp_classes):
    ax = axs[class_idx]
    for trial_type, tt_c in zip(trial_types, ['red', 'blue']):
        tcs_zscored = []
        # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config
        # cell's integer `day` in the notebook namespace.
        for day_idx, (
            day,
            (session, session_path, clusters, ramp_table),
        ) in enumerate(sessions.items()):
            # Neurons with outbound sign ramp_class[0] AND homebound sign ramp_class[1].
            subset_ids1 = ramp_table[
                (ramp_table['group'] == 'rz1_b')
                & (ramp_table['sign'] == ramp_class[0])
                & (ramp_table['region'] == 'outbound')
            ]['cluster_id'].values

            subset_ids2 = ramp_table[
                (ramp_table['group'] == 'rz1_b')
                & (ramp_table['sign'] == ramp_class[1])
                & (ramp_table['region'] == 'homebound')
            ]['cluster_id'].values

            subset_ids = np.intersect1d(subset_ids1, subset_ids2)

            trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

            # Position tuning curves (rows: position bins, columns: clusters),
            # restricted to moving epochs on this trial type.
            tcs  = nap.compute_1d_tuning_curves(
                clusters,
                session['P'],
                nb_bins=bin_config['num_bins'],
                minmax=bin_config['bounds'],
                ep=session['moving'].intersect(trials),
            )
            # z-score each selected neuron's curve across position bins.
            zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
            subset_ids].std(axis=0)

            tcs_zscored.extend(zscored.to_numpy().T.tolist())

        # Rows: neurons pooled over sessions; columns: position bins.
        tcs_zscored = np.array(tcs_zscored)

        print(f'rampclass {ramp_class} , len(zscored)={len(tcs_zscored)}')

        if len(tcs_zscored)>0:
            mean = np.nanmean(tcs_zscored, axis=0)
            sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
            # tcs.index holds the last session's bin positions; every session
            # is binned with the same bin_config.
            ax.plot(tcs.index, mean, color=tt_c)
            ax.fill_between(
                tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )

    # Both trial types share these decorations, so draw them once per panel.
    # Drawing them inside the trial-type loop stacked two alpha=0.2 spans
    # (darker shading) and two N labels. N is the same for both trial types
    # because the neuron subset does not depend on trial type.
    ax.text(
        1.0,
        0.05,
        f'N={len(tcs_zscored)}',
        fontsize=8,
        ha='right',
        transform=ax.transAxes,
    )
    for span_start, span_stop, span_colour in track_spans:
        ax.axvspan(
            span_start, span_stop,
            alpha=0.2, zorder=-10, edgecolor='none', facecolor=span_colour,
        )
    ax.set_xlim(0,230)
    ax.set_title(f'out {ramp_class[0]} / home {ramp_class[1]}', fontsize=8)
plt.show()


In [None]:
# All mice!
# Load session
# One panel per mouse. Each shows the mean (± SEM) z-scored position tuning
# curve of 'rz1_b' neurons with a '+' outbound ramp, pooled over that mouse's
# sessions, with trial type 1 in red and trial type 3 in blue. A vertical line
# marks each trial type's mean first-stop position (from extract_first_stops).
# NOTE(review): this reloads every session from disk, once per mouse.

mice = [22,25,26,27,28,29]
trial_types = [1,3]
# Shaded track regions as (start, stop, colour), in tuning-curve position units.
# NOTE(review): what each region represents is not stated here; label them.
track_spans = [(90, 110, 'orange'), (120, 140, 'teal'), (0, 30, 'grey'), (200, 230, 'grey')]
# Loop-invariant: args.session_type is not changed by the loading loop.
bin_config = get_bin_config(args.session_type)['P']

fig, axs = plt.subplots(
    1, len(mice), layout='constrained', figsize=(10, 2)
)

for m_idx, mouse in enumerate(mice):
    ax = axs[m_idx]

    # Load only this mouse's sessions, keyed 'M<mouse>D<day>'. args is mutated
    # because load_session is called with it.
    sessions = {}
    for session_path in sorted(
        list(args.storage.glob(f'M{mouse}/*/MCVR/*MCVR.nwb'))
    ):
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'
        sessions[mouse_day] = (
            *load_session(args),
            pd.read_parquet(
                session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
            ),
        )

    for trial_type, tt_c in zip(trial_types, ['red', 'blue']):
        tcs_zscored = []
        first_stops = []
        # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config
        # cell's integer `day` in the notebook namespace.
        for day_idx, (
            day,
            (session, session_path, clusters, ramp_table),
        ) in enumerate(sessions.items()):
            subset_ids = ramp_table[
                (ramp_table['group'] == 'rz1_b')
                & (ramp_table['sign'] == '+')
                & (ramp_table['region'] == 'outbound')
            ]['cluster_id'].values

            trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

            # Position tuning curves during movement on this trial type.
            tcs  = nap.compute_1d_tuning_curves(
                clusters,
                session['P'],
                nb_bins=bin_config['num_bins'],
                minmax=bin_config['bounds'],
                ep=session['moving'].intersect(trials),
            )
            # z-score each selected neuron's curve across position bins.
            zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
            subset_ids].std(axis=0)

            tcs_zscored.extend(zscored.to_numpy().T.tolist())
            first_stops.extend(extract_first_stops(session, trial_type).tolist())

        # Rows: neurons pooled over sessions; columns: position bins.
        tcs_zscored = np.array(tcs_zscored)
        first_stops = np.array(first_stops)

        print(f'mouse {mouse}, len(zscored)={len(tcs_zscored)}')

        if len(tcs_zscored)>0:
            mean = np.nanmean(tcs_zscored, axis=0)
            sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
            ax.plot(tcs.index, mean, color=tt_c)
            ax.fill_between(
                tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
            # nanmean of an empty or all-NaN array is NaN; skip the marker then.
            if np.isfinite(first_stops).any():
                ax.axvline(x=np.nanmean(first_stops), color=tt_c)

    # Shared decorations, drawn once per panel. Drawing them inside the
    # trial-type loop stacked two alpha=0.2 spans and two N labels. N does not
    # depend on trial type because the neuron subset ignores it.
    ax.text(
        1.0,
        0.05,
        f'N={len(tcs_zscored)}',
        fontsize=8,
        ha='right',
        transform=ax.transAxes,
    )
    for span_start, span_stop, span_colour in track_spans:
        ax.axvspan(
            span_start, span_stop,
            alpha=0.2, zorder=-10, edgecolor='none', facecolor=span_colour,
        )
    ax.set_xlim(0,230)
    ax.set_title(f'M{mouse}', fontsize=8)
plt.show()


In [None]:
# All mice!
# Load session
# Load every MCVR session of the listed mice into `sessions` (keyed
# 'M<mouse>D<day>'), printing each key and showing its stop raster. Then pool
# all sessions into a single panel.
sessions = {}
# NOTE(review): m_idx is unused in this cell.
for m_idx, mouse in enumerate([22,25,26,27,28,29]):
    for session_path in sorted(
        list(args.storage.glob(f'M{mouse}/*/MCVR/*MCVR.nwb'))
    ):
        # Mouse/day numbers come from directory names (first character dropped).
        # args is mutated for load_session, and the last values persist after
        # the cell (later cells read mouse_day).
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'
        sessions[mouse_day] = (
            *load_session(args),
            pd.read_parquet(
                session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
            ),
        )
        print(f'mouse_day, {mouse_day}')
        # Element 0 of the stored tuple is the session passed to plot_stops.
        plot_stops(sessions[mouse_day][0])

bin_config = get_bin_config(args.session_type)['P']

fig, ax = plt.subplots(
    1, 1, layout='constrained', figsize=(3, 3)
)
# Plot
# Pooled over all loaded sessions: mean (± SEM) z-scored position tuning curve
# of 'rz1_b' neurons with a '+' outbound ramp, trial type 1 red / 3 blue. A
# vertical line marks each trial type's mean first-stop position.
trial_types = [1,3]
for trial_type, tt_c in zip([1,3], ['red', 'blue']):
    tcs_zscored = []
    first_stops = []
    # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config cell's
    # integer `day` in the notebook namespace.
    for day_idx, (
        day,
        (session, session_path, clusters, ramp_table),
    ) in enumerate(sessions.items()):
        
        subset_ids = ramp_table[
            (ramp_table['group'] == 'rz1_b')
            & (ramp_table['sign'] == '+')
            & (ramp_table['region'] == 'outbound')
        ]['cluster_id'].values

        trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

        # Position tuning curves during movement on this trial type.
        tcs = nap.compute_1d_tuning_curves(
            clusters,
            session['P'],
            nb_bins=bin_config['num_bins'],
            minmax=bin_config['bounds'],
            ep=session['moving'].intersect(trials),
        )
        # z-score each selected neuron's curve across position bins.
        zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
        subset_ids].std(axis=0)

        tcs_zscored.extend(zscored.to_numpy().T.tolist())
        first_stops.extend(extract_first_stops(session, trial_type).tolist())

    # Rows: neurons pooled over sessions; columns: position bins.
    tcs_zscored = np.array(tcs_zscored)
    first_stops = np.array(first_stops)

    if len(tcs_zscored)>0:
        mean = np.nanmean(tcs_zscored, axis=0)
        sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
        # tcs.index: bin positions of the last session (all share bin_config).
        ax.plot(tcs.index, mean, color=tt_c)
        ax.fill_between(
            tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
        ax.axvline(x=np.nanmean(first_stops), color=tt_c)

# N is the neuron count from the last trial type. It is the same for both
# trial types because the neuron subset does not depend on trial type.
ax.text(
    1.0,
    0.05,
    f'N={len(tcs_zscored)}',
    fontsize=8,
    ha='right',
    transform=ax.transAxes,
)
# Shaded track regions. NOTE(review): their meaning is not stated in this
# notebook; add labels or a legend.
ax.axvspan(
    90,
    110,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='orange',
    )
ax.axvspan(
    120,
    140,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='teal',
    )
ax.axvspan(
    0,
    30,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.axvspan(
    200,
    230,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.set_xlim(0,230)
plt.show()


In [None]:
# Debug: display the last session key loaded by a previous cell. This relies on
# that cell's loop variable leaking into the notebook's global namespace.
mouse_day

In [None]:
# Debug: is the last-loaded session in the hand-picked good_sesh list?
# NOTE(review): good_sesh is only defined in a later cell, so this raises
# NameError under Restart & Run All. Move it below that cell or delete it.
mouse_day in good_sesh

In [None]:
# All mice!
# Hand-picked sessions to include.
# NOTE(review): the selection criterion is not recorded here; document it.
good_sesh = ["M22D43", "M22D45", "M22D46", 
             "M25D26", "M25D29", "M25D30", 
             "M26D24", 
             "M28D29", 
             "M29D24", "M29D26", "M29D29"]

# Load session
# Load only the sessions in good_sesh (keyed 'M<mouse>D<day>'), printing each
# key and showing its stop raster. Then pool them into a single panel.
sessions = {}
# NOTE(review): m_idx is unused in this cell.
for m_idx, mouse in enumerate([22,25,26,27,28,29]):
    for session_path in sorted(
        list(args.storage.glob(f'M{mouse}/*/MCVR/*MCVR.nwb'))
    ):
        # Mouse/day numbers come from directory names (first character dropped).
        # args is mutated for load_session, and the last values persist.
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'

        # NOTE(review): `care` is assigned but never read.
        care = True
        if mouse_day in good_sesh:
            sessions[mouse_day] = (
                *load_session(args),
                pd.read_parquet(
                    session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
                ),
            )
            print(f'mouse_day, {mouse_day}')
            # Element 0 of the stored tuple is the session passed to plot_stops.
            plot_stops(sessions[mouse_day][0])

bin_config = get_bin_config(args.session_type)['P']

fig, ax = plt.subplots(
    1, 1, layout='constrained', figsize=(3, 3)
)
# Plot
# Pooled over the good sessions: mean (± SEM) z-scored position tuning curve of
# 'rz1_b' neurons with a '+' outbound ramp, trial type 1 red / 3 blue. A
# vertical line marks each trial type's mean first-stop position.
trial_types = [1,3]
for trial_type, tt_c in zip([1,3], ['red', 'blue']):
    tcs_zscored = []
    first_stops = []
    # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config cell's
    # integer `day` in the notebook namespace.
    for day_idx, (
        day,
        (session, session_path, clusters, ramp_table),
    ) in enumerate(sessions.items()):
        
        subset_ids = ramp_table[
            (ramp_table['group'] == 'rz1_b')
            & (ramp_table['sign'] == '+')
            & (ramp_table['region'] == 'outbound')
        ]['cluster_id'].values

        trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

        # Position tuning curves during movement on this trial type.
        tcs = nap.compute_1d_tuning_curves(
            clusters,
            session['P'],
            nb_bins=bin_config['num_bins'],
            minmax=bin_config['bounds'],
            ep=session['moving'].intersect(trials),
        )
        # z-score each selected neuron's curve across position bins.
        zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
        subset_ids].std(axis=0)

        tcs_zscored.extend(zscored.to_numpy().T.tolist())
        first_stops.extend(extract_first_stops(session, trial_type).tolist())

    # Rows: neurons pooled over sessions; columns: position bins.
    tcs_zscored = np.array(tcs_zscored)
    first_stops = np.array(first_stops)

    if len(tcs_zscored)>0:
        mean = np.nanmean(tcs_zscored, axis=0)
        sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
        # tcs.index: bin positions of the last session (all share bin_config).
        ax.plot(tcs.index, mean, color=tt_c)
        ax.fill_between(
            tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
        ax.axvline(x=np.nanmean(first_stops), color=tt_c)

# N is the neuron count from the last trial type. It is the same for both
# trial types because the neuron subset does not depend on trial type.
ax.text(
    1.0,
    0.05,
    f'N={len(tcs_zscored)}',
    fontsize=8,
    ha='right',
    transform=ax.transAxes,
)
# Shaded track regions. NOTE(review): their meaning is not stated in this
# notebook; add labels or a legend.
ax.axvspan(
    90,
    110,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='orange',
    )
ax.axvspan(
    120,
    140,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='teal',
    )
ax.axvspan(
    0,
    30,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.axvspan(
    200,
    230,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.set_xlim(0,230)
plt.show()


In [None]:
# All mice!
# Load session
# NOTE(review): this cell duplicates an earlier per-mouse cell in this notebook;
# consider deleting one copy.
# One panel per mouse. Each shows the mean (± SEM) z-scored position tuning
# curve of 'rz1_b' neurons with a '+' outbound ramp, pooled over that mouse's
# sessions, with trial type 1 in red and trial type 3 in blue. A vertical line
# marks each trial type's mean first-stop position (from extract_first_stops).

mice = [22,25,26,27,28,29]
trial_types = [1,3]
# Shaded track regions as (start, stop, colour), in tuning-curve position units.
# NOTE(review): what each region represents is not stated here; label them.
track_spans = [(90, 110, 'orange'), (120, 140, 'teal'), (0, 30, 'grey'), (200, 230, 'grey')]
# Loop-invariant: args.session_type is not changed by the loading loop.
bin_config = get_bin_config(args.session_type)['P']

fig, axs = plt.subplots(
    1, len(mice), layout='constrained', figsize=(10, 2)
)

for m_idx, mouse in enumerate(mice):
    ax = axs[m_idx]

    # Load only this mouse's sessions, keyed 'M<mouse>D<day>'. args is mutated
    # because load_session is called with it.
    sessions = {}
    for session_path in sorted(
        list(args.storage.glob(f'M{mouse}/*/MCVR/*MCVR.nwb'))
    ):
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'
        sessions[mouse_day] = (
            *load_session(args),
            pd.read_parquet(
                session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
            ),
        )

    for trial_type, tt_c in zip(trial_types, ['red', 'blue']):
        tcs_zscored = []
        first_stops = []
        # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config
        # cell's integer `day` in the notebook namespace.
        for day_idx, (
            day,
            (session, session_path, clusters, ramp_table),
        ) in enumerate(sessions.items()):
            subset_ids = ramp_table[
                (ramp_table['group'] == 'rz1_b')
                & (ramp_table['sign'] == '+')
                & (ramp_table['region'] == 'outbound')
            ]['cluster_id'].values

            trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

            # Position tuning curves during movement on this trial type.
            tcs  = nap.compute_1d_tuning_curves(
                clusters,
                session['P'],
                nb_bins=bin_config['num_bins'],
                minmax=bin_config['bounds'],
                ep=session['moving'].intersect(trials),
            )
            # z-score each selected neuron's curve across position bins.
            zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
            subset_ids].std(axis=0)

            tcs_zscored.extend(zscored.to_numpy().T.tolist())
            first_stops.extend(extract_first_stops(session, trial_type).tolist())

        # Rows: neurons pooled over sessions; columns: position bins.
        tcs_zscored = np.array(tcs_zscored)
        first_stops = np.array(first_stops)

        print(f'mouse {mouse}, len(zscored)={len(tcs_zscored)}')

        if len(tcs_zscored)>0:
            mean = np.nanmean(tcs_zscored, axis=0)
            sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
            ax.plot(tcs.index, mean, color=tt_c)
            ax.fill_between(
                tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
            # nanmean of an empty or all-NaN array is NaN; skip the marker then.
            if np.isfinite(first_stops).any():
                ax.axvline(x=np.nanmean(first_stops), color=tt_c)

    # Shared decorations, drawn once per panel. Drawing them inside the
    # trial-type loop stacked two alpha=0.2 spans and two N labels. N does not
    # depend on trial type because the neuron subset ignores it.
    ax.text(
        1.0,
        0.05,
        f'N={len(tcs_zscored)}',
        fontsize=8,
        ha='right',
        transform=ax.transAxes,
    )
    for span_start, span_stop, span_colour in track_spans:
        ax.axvspan(
            span_start, span_stop,
            alpha=0.2, zorder=-10, edgecolor='none', facecolor=span_colour,
        )
    ax.set_xlim(0,230)
    ax.set_title(f'M{mouse}', fontsize=8)
plt.show()


In [None]:
# All mice!
# Hand-picked sessions to include (same list as the earlier good-sessions cell).
# NOTE(review): the selection criterion is not recorded here; document it.
good_sesh = ["M22D43", "M22D45", "M22D46", 
             "M25D26", "M25D29", "M25D30", 
             "M26D24", 
             "M28D29", 
             "M29D24", "M29D26", "M29D29"]

# Load session
# Load only the sessions in good_sesh (keyed 'M<mouse>D<day>'), printing each
# key and showing its stop raster. Then pool them into a single panel.
sessions = {}
# NOTE(review): m_idx is unused in this cell.
for m_idx, mouse in enumerate([22,25,26,27,28,29]):
    for session_path in sorted(
        list(args.storage.glob(f'M{mouse}/*/MCVR/*MCVR.nwb'))
    ):
        # Mouse/day numbers come from directory names (first character dropped).
        # args is mutated for load_session, and the last values persist.
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'

        # NOTE(review): `care` is assigned but never read.
        care = True
        if mouse_day in good_sesh:
            sessions[mouse_day] = (
                *load_session(args),
                pd.read_parquet(
                    session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
                ),
            )
            print(f'mouse_day, {mouse_day}')
            # Element 0 of the stored tuple is the session passed to plot_stops.
            plot_stops(sessions[mouse_day][0])

bin_config = get_bin_config(args.session_type)['P']

fig, ax = plt.subplots(
    1, 1, layout='constrained', figsize=(3, 3)
)
# Plot
# Pooled over the good sessions: neurons with a '+' outbound ramp in BOTH the
# 'rz1_nb' and 'rz2_nb' groups. Shows mean (± SEM) z-scored tuning curve, trial
# type 1 red / 3 blue, and each trial type's mean first-stop position.
trial_types = [1,3]
for trial_type, tt_c in zip([1,3], ['red', 'blue']):
    tcs_zscored = []
    first_stops = []
    # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config cell's
    # integer `day` in the notebook namespace.
    for day_idx, (
        day,
        (session, session_path, clusters, ramp_table),
    ) in enumerate(sessions.items()):
        
        subset_ids1 = ramp_table[
            (ramp_table['group'] == 'rz1_nb')
            & (ramp_table['sign'] == '+')
            & (ramp_table['region'] == 'outbound')
        ]['cluster_id'].values

        subset_ids2 = ramp_table[
            (ramp_table['group'] == 'rz2_nb')
            & (ramp_table['sign'] == '+')
            & (ramp_table['region'] == 'outbound')
        ]['cluster_id'].values

        # Keep clusters that satisfy both group criteria.
        subset_ids = np.intersect1d(subset_ids1, subset_ids2)

        trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

        # Position tuning curves during movement on this trial type.
        tcs = nap.compute_1d_tuning_curves(
            clusters,
            session['P'],
            nb_bins=bin_config['num_bins'],
            minmax=bin_config['bounds'],
            ep=session['moving'].intersect(trials),
        )
        # z-score each selected neuron's curve across position bins.
        zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
        subset_ids].std(axis=0)

        tcs_zscored.extend(zscored.to_numpy().T.tolist())
        first_stops.extend(extract_first_stops(session, trial_type).tolist())

    # Rows: neurons pooled over sessions; columns: position bins.
    tcs_zscored = np.array(tcs_zscored)
    first_stops = np.array(first_stops)

    if len(tcs_zscored)>0:
        mean = np.nanmean(tcs_zscored, axis=0)
        sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
        # tcs.index: bin positions of the last session (all share bin_config).
        ax.plot(tcs.index, mean, color=tt_c)
        ax.fill_between(
            tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
        ax.axvline(x=np.nanmean(first_stops), color=tt_c)

# N is the neuron count from the last trial type. It is the same for both
# trial types because the neuron subset does not depend on trial type.
ax.text(
    1.0,
    0.05,
    f'N={len(tcs_zscored)}',
    fontsize=8,
    ha='right',
    transform=ax.transAxes,
)
# Shaded track regions. NOTE(review): their meaning is not stated in this
# notebook; add labels or a legend.
ax.axvspan(
    90,
    110,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='orange',
    )
ax.axvspan(
    120,
    140,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='teal',
    )
ax.axvspan(
    0,
    30,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.axvspan(
    200,
    230,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.set_xlim(0,230)
plt.show()

In [None]:
# Load session
# Load every MCVR session of the listed mice (no good_sesh filter), keyed
# 'M<mouse>D<day>', printing each key and showing its stop raster. Then pool
# all sessions into a single panel.
sessions = {}
# NOTE(review): m_idx is unused in this cell.
for m_idx, mouse in enumerate([22,25,26,27,28,29]):
    for session_path in sorted(
        list(args.storage.glob(f'M{mouse}/*/MCVR/*MCVR.nwb'))
    ):
        # Mouse/day numbers come from directory names (first character dropped).
        # args is mutated for load_session, and the last values persist.
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'

        sessions[mouse_day] = (
            *load_session(args),
            pd.read_parquet(
                session_path.parent / 'tuning_scores' / 'ramp_class.parquet'
            ),
        )
        print(f'mouse_day, {mouse_day}')
        # Element 0 of the stored tuple is the session passed to plot_stops.
        plot_stops(sessions[mouse_day][0])

bin_config = get_bin_config(args.session_type)['P']

fig, ax = plt.subplots(
    1, 1, layout='constrained', figsize=(3, 3)
)
# Plot
# Pooled over all sessions: neurons with a '+' outbound ramp in BOTH the
# 'rz1_nb' and 'rz2_nb' groups. Shows mean (± SEM) z-scored tuning curve, trial
# type 1 red / 3 blue, and each trial type's mean first-stop position.
trial_types = [1,3]
for trial_type, tt_c in zip([1,3], ['red', 'blue']):
    tcs_zscored = []
    first_stops = []
    # NOTE: `day` is the 'M<mouse>D<day>' key and overwrites the config cell's
    # integer `day` in the notebook namespace.
    for day_idx, (
        day,
        (session, session_path, clusters, ramp_table),
    ) in enumerate(sessions.items()):
        
        subset_ids1 = ramp_table[
            (ramp_table['group'] == 'rz1_nb')
            & (ramp_table['sign'] == '+')
            & (ramp_table['region'] == 'outbound')
        ]['cluster_id'].values

        subset_ids2 = ramp_table[
            (ramp_table['group'] == 'rz2_nb')
            & (ramp_table['sign'] == '+')
            & (ramp_table['region'] == 'outbound')
        ]['cluster_id'].values

        # Keep clusters that satisfy both group criteria.
        subset_ids = np.intersect1d(subset_ids1, subset_ids2)

        trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]

        # Position tuning curves during movement on this trial type.
        tcs = nap.compute_1d_tuning_curves(
            clusters,
            session['P'],
            nb_bins=bin_config['num_bins'],
            minmax=bin_config['bounds'],
            ep=session['moving'].intersect(trials),
        )
        # z-score each selected neuron's curve across position bins.
        zscored = (tcs[subset_ids] - tcs[subset_ids].mean(axis=0)) / tcs[
        subset_ids].std(axis=0)

        tcs_zscored.extend(zscored.to_numpy().T.tolist())
        first_stops.extend(extract_first_stops(session, trial_type).tolist())

    # Rows: neurons pooled over sessions; columns: position bins.
    tcs_zscored = np.array(tcs_zscored)
    first_stops = np.array(first_stops)

    if len(tcs_zscored)>0:
        mean = np.nanmean(tcs_zscored, axis=0)
        sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
        # tcs.index: bin positions of the last session (all share bin_config).
        ax.plot(tcs.index, mean, color=tt_c)
        ax.fill_between(
            tcs.index, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
        ax.axvline(x=np.nanmean(first_stops), color=tt_c)

# N is the neuron count from the last trial type. It is the same for both
# trial types because the neuron subset does not depend on trial type.
ax.text(
    1.0,
    0.05,
    f'N={len(tcs_zscored)}',
    fontsize=8,
    ha='right',
    transform=ax.transAxes,
)
# Shaded track regions. NOTE(review): their meaning is not stated in this
# notebook; add labels or a legend.
ax.axvspan(
    90,
    110,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='orange',
    )
ax.axvspan(
    120,
    140,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='teal',
    )
ax.axvspan(
    0,
    30,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.axvspan(
    200,
    230,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
ax.set_xlim(0,230)
plt.show()