In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
from scipy import stats
from pathlib import Path
import pynapple as nap
from scipy.ndimage import gaussian_filter, rotate

from spatial_manifolds.data.binning import get_bin_config
from spatial_manifolds.data.loading import load_session

import warnings
# NOTE(review): this silences *every* warning for the whole kernel (e.g. numpy
# RuntimeWarnings from dividing by a zero std during z-scoring) — consider
# narrowing the filter so genuine data problems stay visible.
warnings.filterwarnings('ignore')
# Re-import edited modules (e.g. spatial_manifolds) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Directory the plotting code writes its PDF figures into.
# NOTE(review): hard-coded, user-specific absolute path — make configurable
# before running on another machine.
save_path = '/Users/harryclark/Documents/figs/Ianfigs'

In [None]:
# Session selection. `mouse` and `day` are only starting values: the loading
# cells overwrite args.mouse / args.day for every session file they discover.
mouse, day = 0, 0
session_type = 'MCVR'
sorter = 'kilosort4'
storage = Path('/Users/harryclark/Downloads/COHORT12/')  # NOTE(review): user-specific absolute path

# Remaining settings; all of them are bundled into `args` and handed to load_session.
alpha = 0.001
n_jobs = 8
n_shuffles = 100
seed = 1

class Args:
    """Plain attribute bag that mimics an argparse namespace; an instance is passed to load_session."""

    def __init__(self, mouse, day, session_type, sorter, storage, alpha, n_jobs, n_shuffles, seed):
        # Expose every constructor argument as a same-named attribute, in signature order.
        self.__dict__.update(
            mouse=mouse,
            day=day,
            session_type=session_type,
            sorter=sorter,
            storage=storage,
            alpha=alpha,
            n_jobs=n_jobs,
            n_shuffles=n_shuffles,
            seed=seed,
        )
args = Args(
    mouse=mouse,
    day=day,
    session_type=session_type,
    sorter=sorter,
    storage=storage,
    alpha=alpha,
    n_jobs=n_jobs,
    n_shuffles=n_shuffles,
    seed=seed,
)

# Track length (cm) for the chosen session type; `tl` stays unset for any other type.
if session_type == 'MCVR':
    tl = 230
elif session_type == 'VR':
    tl = 200


In [None]:
def plot_speeds(session, title, track_length=230):
    """Plot speed against position for trial_type 'nb' trials in contexts rz1 and rz2.

    One curve per context (rz1 black, rz2 blue). Each curve is computed with
    pynapple over 50 position bins spanning [0, track_length], restricted to that
    context's trials. Shaded spans mark x = 90-110 (orange) and 120-140 (teal).
    The figure is saved to '<save_path>/<title>_speeds.pdf' (module-level
    ``save_path``) and then shown.

    Args:
        session: loaded session mapping. Reads 'trials' (a table with
            'trial_type' and 'trial_context' columns), 'S' (speed) and
            'P' (position).
        title: name used for the output file.
        track_length: upper bound of the binning range and of the x-axis.
            Defaults to 230, the previously hard-coded value.
    """
    fig, ax = plt.subplots(1, 1, layout='constrained', figsize=(2, 2))
    trials_table = session['trials']

    for trial_context, colour in (('rz1', 'black'), ('rz2', '#6a95bf')):
        trials = trials_table[
            (trials_table['trial_type'] == 'nb')
            & (trials_table['trial_context'] == trial_context)
        ]
        print(f'n trials for speed {len(trials)}')
        # Distinct name for the curve so the context label is not overwritten.
        speed_curve = nap.compute_1d_tuning_curves_continuous(
            session['S'],
            session['P'],
            nb_bins=50,
            minmax=[0, track_length],
            ep=trials,
        )[0]
        ax.plot(speed_curve.index, speed_curve.values, color=colour)

    ax.set_xlim(0, track_length)
    for start, end, colour in ((90, 110, 'orange'), (120, 140, 'teal')):
        ax.axvspan(start, end, alpha=0.2, zorder=-10, edgecolor='none', facecolor=colour)
    plt.savefig(f'{save_path}/{title}_speeds.pdf', dpi=300)
    plt.show()

def plot_stops(session, title):
    """Raster of stop positions (speed < 3) against trial number for trial types 1 and 3.

    Type 1 samples are drawn in black and type 3 in blue. Shaded spans mark
    x = 90-110 (orange) and 120-140 (teal). The figure is saved to
    '<save_path>/<title>.pdf' (module-level ``save_path``) and then shown.

    Args:
        session: loaded session mapping with per-sample 'trial_number',
            'P' (position), 'trial_type' and 'S' (speed).
        title: name used for the output file.
    """
    position = np.array(session['P'])
    trial_numbers = np.array(session['trial_number'])
    speed = np.array(session['S'])
    trial_types = np.array(session['trial_type'])
    is_stopped = speed < 3

    fig, ax = plt.subplots(1, 1, layout='constrained', figsize=(2, 2))
    for trial_type, colour in ((1, 'black'), (3, '#6a95bf')):
        keep = is_stopped & (trial_types == trial_type)
        ax.scatter(position[keep], trial_numbers[keep], alpha=0.025, s=15, c=colour)

    ax.set_xlim(0, 230)
    ax.set_ylim(trial_numbers.min(), trial_numbers.max())
    for start, end, colour in ((90, 110, 'orange'), (120, 140, 'teal')):
        ax.axvspan(start, end, alpha=0.2, zorder=-10, edgecolor='none', facecolor=colour)
    plt.savefig(f'{save_path}/{title}.pdf', dpi=300)
    plt.show()

def extract_first_stops(session, tt, stop_threshold=3, min_position=30):
    """Return the first stop location of every trial, for trials of type ``tt``.

    A sample counts as a stop when ``session['S'] < stop_threshold``. Only
    samples with ``session['P'] > min_position`` are considered. "First" means
    the earliest sample in stored order.

    Args:
        session: mapping with equal-length per-sample arrays 'trial_number',
            'P' (position), 'S' (speed) and 'trial_type'.
        tt: trial type to keep (compared with ``==`` against 'trial_type').
        stop_threshold: speed below which a sample is a stop (default 3).
        min_position: stops at or before this position are ignored (default 30).

    Returns:
        Array with one entry per unique trial number, in ascending trial order.
        Each entry is the first qualifying stop position, or NaN when the trial
        has none. That includes every trial whose type is not ``tt``.
    """
    # Per-sample trial numbers must come from the session, not the key string itself.
    trial_numbers = np.array(session['trial_number'])
    position = np.array(session['P'])
    trial_types = np.array(session['trial_type'])
    speed = np.array(session['S'])

    # Loop-invariant: stopped, on the requested trial type, and past min_position.
    candidate = (speed < stop_threshold) & (trial_types == tt) & (position > min_position)

    first_stops = []
    for tn in np.unique(trial_numbers):
        trial_stop_locations = position[candidate & (trial_numbers == tn)]
        first_stops.append(trial_stop_locations[0] if len(trial_stop_locations) > 0 else np.nan)
    return np.array(first_stops)


In [None]:
# All mice!
# Load every MCVR session found under the storage root, keyed 'M<mouse>D<day>'.
# NOTE(review): the next cell re-creates `sessions` from scratch for a curated
# subset, so this full load is discarded when the notebook runs top to bottom.
sessions = {}
for session_path in sorted(args.storage.glob('*/*/MCVR/*MCVR*.nwb')):
    # Mouse/day ids are parsed from the grandparent/great-grandparent folder
    # names minus their one-character prefix (presumably 'M'/'D' — verify).
    args.mouse = int(session_path.parents[2].name[1:])
    args.day = int(session_path.parents[1].name[1:])
    mouse_day = f'M{args.mouse}D{args.day}'
    sessions[mouse_day] = (
        *load_session(args),
        pd.read_parquet(session_path.parent / 'tuning_scores' / 'ramp_class.parquet'),
    )
bin_config = get_bin_config(args.session_type)['P']

In [None]:
# All mice!
# Curated sessions (keys 'M<mouse>D<day>') that are loaded and analysed below.
good_sesh = ["M22D43", "M22D45", "M22D46",
             "M25D26", "M25D29", "M25D30",
             "M26D24",
             "M28D29",
             "M29D24", "M29D26", "M29D29"]

# Load session
sessions = {}
# `mouse_id` rather than `mouse`, so the config-cell global `mouse` is not clobbered.
for mouse_id in [22, 25, 26, 27, 28, 29]:
    for session_path in sorted(args.storage.glob(f'M{mouse_id}/*/MCVR/*MCVR*.nwb')):
        args.mouse = int(session_path.parent.parent.parent.name[1:])
        args.day = int(session_path.parent.parent.name[1:])
        mouse_day = f'M{args.mouse}D{args.day}'
        if mouse_day not in good_sesh:
            continue

        sessions[mouse_day] = (
            *load_session(args),
            pd.read_parquet(session_path.parent / 'tuning_scores' / 'ramp_class.parquet'),
        )
        print(f'mouse_day, {mouse_day}')
        # Behavioural sanity plots for the session just loaded (element 0 of the tuple).
        plot_stops(sessions[mouse_day][0], title=mouse_day)
        plot_speeds(sessions[mouse_day][0], title=mouse_day)


bin_config = get_bin_config(args.session_type)['P']
# Plot: for each ramp class, mean +/- SEM of z-scored position tuning curves
# pooled over all loaded sessions, one curve per trial context (rz1 black, rz2 blue).
trial_types = [1,3]
# (outbound sign, homebound sign) taken from the 'rz1_nb' ramp table;
# ' ' in the second slot means "no homebound constraint".
ramp_classes = [('+',' '),
                ('+','+'),
                ('+','-'),
                ('+','/'),
                ('-',' '),
                ('-','+'),
                ('-','-'),
                ('-','/')]
# Shaded x-ranges drawn once per panel: (start, end, colour).
track_zones = [(90, 110, 'orange'), (120, 140, 'teal'), (0, 30, 'grey'), (200, 230, 'grey')]

fig, axs = plt.subplots(2, 4, layout='constrained', figsize=(8, 3.5))

for class_idx, ramp_class in enumerate(ramp_classes):
    ax = axs[class_idx // 4, class_idx % 4]

    for cond_idx, (trial_type, trial_context, tt_c) in enumerate(
        zip(['nb', 'nb'], ['rz1', 'rz2'], ['black', '#6a95bf'])
    ):
        tcs_zscored = []
        bin_centres = None
        # Unpack without reusing `day`/`session_path` so config globals are not clobbered.
        for session, _, clusters, ramp_table in sessions.values():
            subset_ids = ramp_table[
                (ramp_table['group'] == 'rz1_nb')
                & (ramp_table['sign'] == ramp_class[0])
                & (ramp_table['region'] == 'outbound')
            ]['cluster_id'].values
            if ramp_class[1] != ' ':
                homebound_ids = ramp_table[
                    (ramp_table['group'] == 'rz1_nb')
                    & (ramp_table['sign'] == ramp_class[1])
                    & (ramp_table['region'] == 'homebound')
                ]['cluster_id'].values
                subset_ids = np.intersect1d(subset_ids, homebound_ids)

            trials_table = session['trials']
            trials = trials_table[(trials_table['trial_performance'] == 'run') &
                                  (trials_table['trial_type'] == trial_type) &
                                  (trials_table['trial_context'] == trial_context)]
            print(f'n trials for condition = {len(trials)}')
            tcs = nap.compute_1d_tuning_curves(
                clusters,
                session['P'],
                nb_bins=bin_config['num_bins'],
                minmax=bin_config['bounds'],
                ep=session['moving'].intersect(trials),
            )
            bin_centres = tcs.index
            # z-score each selected cluster's curve across position bins.
            subset_tcs = tcs[subset_ids]
            zscored = (subset_tcs - subset_tcs.mean(axis=0)) / subset_tcs.std(axis=0)
            tcs_zscored.extend(zscored.to_numpy().T.tolist())

        tcs_zscored = np.array(tcs_zscored)
        print(f'rampclass {ramp_class} , len(zscored)={len(tcs_zscored)}')

        if len(tcs_zscored) > 0:
            mean = np.nanmean(tcs_zscored, axis=0)
            sem = stats.sem(tcs_zscored, axis=0, nan_policy='omit')
            ax.plot(bin_centres, mean, color=tt_c)
            ax.fill_between(bin_centres, mean - sem, mean + sem, alpha=0.2, color=tt_c)
        # One N label per context, stacked and colour-coded so the two do not overlap.
        ax.text(
            1.0,
            0.05 + 0.1 * cond_idx,
            f'N={len(tcs_zscored)}',
            fontsize=8,
            ha='right',
            color=tt_c,
            transform=ax.transAxes,
        )

    # Shade once per panel; shading inside the context loop doubled the opacity.
    for start, end, colour in track_zones:
        ax.axvspan(start, end, alpha=0.2, zorder=-10, edgecolor='none', facecolor=colour)
    ax.set_xlim(0, 230)
plt.savefig(f'{save_path}/rz1_nb_ramps_zscored_good_sessions_run_trials.pdf', dpi=300)
plt.show()


In [None]:
    
def plot_example(tc, bin_centres, title, save_path=''):
    fig, ax = plt.subplots(
        1, 1, layout='constrained', figsize=(2, 2)
    )
    for tt, tt_c in zip([1,3], ['black', '#6a95bf']):
        tt_tc = tc[f'{tt}']
        mean = np.nanmean(tt_tc, axis=0)
        sem = stats.sem(tt_tc, axis=0, nan_policy='omit')
        ax.plot(bin_centres, mean, color=tt_c)
        ax.fill_between(
            bin_centres, mean - sem, mean + sem, alpha=0.2, color=tt_c
            )
    ax.set_title(
        f'{title}',
        fontsize=8,
    )
    ax.set_xlim(0,230)
    ax.set_xlabel('position (cm)')
    ax.set_ylabel('firing rate (Hz)')

    ax.axvspan(
        90,110,
        alpha=0.2,
        zorder=-10,
        edgecolor='none',
        facecolor='orange',
    )
    ax.axvspan(
        120,140,
        alpha=0.2,
        zorder=-10,
        edgecolor='none',
        facecolor='teal',
    )
    ax.axvspan(
    0,
    30,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
    ax.axvspan(
    200,
    230,
    alpha=0.2,
    zorder=-10,
    edgecolor='none',
    facecolor='grey',
    )
    plt.savefig(f'{save_path}/{title}.pdf', dpi=300)
    plt.show()



In [None]:

bin_config = get_bin_config(args.session_type)['P']
# Plot: per ramp class, one small panel per selected cell showing its mean +/- SEM
# smoothed per-trial tuning curve for trial types 1 (black) and 3 (blue), plus
# stand-alone figures for the hand-picked exemplar cells.
trial_types = [1,3]
# (outbound sign, homebound sign) from the 'rz1_nb' ramp table; ' ' = no homebound constraint.
ramp_classes = [('+',' '),
                ('+','+'),
                ('+','-'),
                ('+','/'),
                ('-',' '),
                ('-','+'),
                ('-','-'),
                ('-','/')]

examplar_cell_ids = ['M25D29MCVR112',
                     'M25D29MCVR287',
                     'M25D29MCVR265',
                     'M26D24MCVR18',
                     'M26D24MCVR27',
                     'M29D26MCVR379',
                     'M29D26MCVR414',
                     'M29D24MCVR130',
                     'M29D24MCVR129',
                     'M28D29MCVR342',
                     'M28D29MCVR340',
]

n_bins = bin_config['num_bins']
bounds = bin_config['bounds']
sigma = bin_config['smooth_sigma']
ncols = 10
tt_colours = dict(zip(trial_types, ['black', '#6a95bf']))
track_zones = [(90, 110, 'orange'), (120, 140, 'teal'), (0, 30, 'grey'), (200, 230, 'grey')]

for class_idx, ramp_class in enumerate(ramp_classes):
    print(ramp_class)

    # tcs_by_type[trial_type][unique cell id] -> (n_trials, n_bins) smoothed rate maps.
    tcs_by_type = {}
    bin_centres = None
    for trial_type in trial_types:
        tcs = {}
        for session, _, clusters, ramp_table in sessions.values():
            subset_ids = ramp_table[
                (ramp_table['group'] == 'rz1_nb')
                & (ramp_table['sign'] == ramp_class[0])
                & (ramp_table['region'] == 'outbound')
            ]['cluster_id'].values
            if ramp_class[1] != ' ':
                homebound_ids = ramp_table[
                    (ramp_table['group'] == 'rz1_nb')
                    & (ramp_table['sign'] == ramp_class[1])
                    & (ramp_table['region'] == 'homebound')
                ]['cluster_id'].values
                subset_ids = np.intersect1d(subset_ids, homebound_ids)

            # NOTE(review): the z-scored summary cell filters trial_type == 'nb' (a string);
            # comparing with a float here only matches if this column is numeric — confirm.
            trials = session['trials'][session['trials']['trial_type'] == float(trial_type)]
            if len(trials) == 0:
                print(f'{session.name}: no trials with trial_type == {float(trial_type)}, skipping')
                continue

            for index in subset_ids:
                unique_idx = f'{session.name}{index}'
                this_neuron = clusters[clusters.index == index]
                trial_tcs = []
                for tn in trials.trial_number:
                    trial = session['trials'][session['trials']['trial_number'] == tn]
                    tc = nap.compute_1d_tuning_curves(
                        this_neuron,
                        session["P"],
                        nb_bins=n_bins,
                        minmax=[bounds[0], bounds[1]],
                        ep=session["moving"].intersect(trial),
                    )[index]
                    bin_centres = tc.index
                    # NaN bins (presumably unvisited in this trial) are treated as 0.
                    trial_tcs.append(np.nan_to_num(tc))
                trial_tcs = np.array(trial_tcs)
                # Smooth the trials laid end to end as one 1-D signal, so the kernel
                # deliberately spans the end of one trial and the start of the next.
                smoothed = gaussian_filter(
                    np.nan_to_num(trial_tcs.flatten()).astype(np.float64), sigma=sigma
                )
                tcs[unique_idx] = smoothed.reshape(trial_tcs.shape)
        print(f'for ramp class{ramp_class}, there is {len(tcs)} found')
        tcs_by_type[trial_type] = tcs

    # One panel per cell (union over trial types, first-seen order); grid sized to fit.
    cell_ids = list(dict.fromkeys(uid for tt in trial_types for uid in tcs_by_type[tt]))
    if cell_ids:
        nrows = int(np.ceil(len(cell_ids) / ncols))
        # ~1.67 in per row, the same per-row height as the former fixed 84-row layout.
        fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15, 5 * nrows / 3), squeeze=False)
        for unused_ax in axs.flat[len(cell_ids):]:
            unused_ax.set_visible(False)

        for tc_idx, unique_idx in enumerate(cell_ids):
            ax = axs[tc_idx // ncols, tc_idx % ncols]
            for trial_type in trial_types:
                tc = tcs_by_type[trial_type].get(unique_idx)
                if tc is None:
                    continue
                mean = np.nanmean(tc, axis=0)
                sem = stats.sem(tc, axis=0, nan_policy='omit')
                ax.plot(bin_centres, mean, color=tt_colours[trial_type])
                ax.fill_between(
                    bin_centres, mean - sem, mean + sem, alpha=0.2, color=tt_colours[trial_type]
                )
            # Label and shading are drawn once per panel, not once per trial type.
            ax.text(1.0, -0.05, f'{unique_idx}', fontsize=8, ha='right', transform=ax.transAxes)
            for start, end, colour in track_zones:
                ax.axvspan(start, end, alpha=0.2, zorder=-10, edgecolor='none', facecolor=colour)
            ax.set_xlim(0, 230)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            for side in ('top', 'right', 'bottom', 'left'):
                ax.spines[side].set_visible(False)
        plt.show()

    # Exemplar figures: only cells of this ramp class that have both trial types
    # (plot_example reads keys '1' and '3'). Real plotting errors now propagate
    # instead of being swallowed by a bare except.
    examplar_tcs = {
        unique_idx: {f'{tt}': tcs_by_type[tt][unique_idx] for tt in trial_types}
        for unique_idx in examplar_cell_ids
        if all(unique_idx in tcs_by_type[tt] for tt in trial_types)
    }
    for unique_idx, tc in examplar_tcs.items():
        plot_example(tc, bin_centres, title=unique_idx, save_path=save_path)
