In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from pathlib import Path
import pynapple as nap

from spatial_manifolds.data.binning import get_bin_config
from spatial_manifolds.data.loading import load_session
from spatial_manifolds.behaviour_plots import *

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2


In [None]:
# Session selection and analysis parameters used to build `args` for load_session.
mouse = 0
day = 0
session_type = 'VR'
# NOTE(review): absolute local path; adjust per machine.
storage = Path('/Users/harryclark/Downloads/COHORT12_nwb/')

alpha = 0.001
n_jobs = 8
n_shuffles = 100
seed = 1
sorter = 'kilosort4'


class Args:
    """Plain attribute container passed to ``load_session``.

    Every constructor argument is stored unchanged on an attribute of the same
    name. Presumably this mirrors the project's argparse namespace — confirm
    against ``load_session`` if new fields are needed.
    """

    def __init__(self, mouse, day, session_type, storage, alpha, n_jobs, n_shuffles, seed, sorter):
        # Which session to load.
        self.mouse, self.day, self.session_type = mouse, day, session_type
        self.storage = storage
        # Analysis / shuffle settings.
        self.alpha, self.n_jobs, self.n_shuffles = alpha, n_jobs, n_shuffles
        self.seed = seed
        self.sorter = sorter


args = Args(
    mouse=mouse, day=day, session_type=session_type, storage=storage,
    alpha=alpha, n_jobs=n_jobs, n_shuffles=n_shuffles, seed=seed, sorter=sorter,
)

In [None]:
args.session_type = 'VR'

# Load every matching behaviour NWB file into `sessions`, keyed 'M<mouse>D<day>'.
# The glob expects <storage>/<mouse dir>/<day dir>/<session_type>/*<session_type>_beh.nwb;
# mouse and day are parsed from the directory names after dropping their first
# character (presumably an 'M' / 'D' prefix — confirm against the data layout).
# NOTE: args.mouse / args.day are overwritten for each file, so after this cell
# they hold the values of the last session loaded.
sessions = {}
nwb_pattern = f'*/*/{args.session_type}/*{args.session_type}_beh.nwb'
for session_path in sorted(args.storage.glob(nwb_pattern)):
    print(session_path)
    day_dir = session_path.parent.parent
    args.mouse = int(day_dir.parent.name[1:])
    args.day = int(day_dir.name[1:])
    mouse_day = f'M{args.mouse}D{args.day}'
    # Whatever load_session returns is materialised into a tuple.
    sessions[mouse_day] = tuple(load_session(args))


In [None]:
def _percent_hits(trials, trial_type):
    """Percentage of ``trial_type`` trials whose performance is 'hit'."""
    is_type = trials['type'] == trial_type
    n_hits = len(trials[is_type & (trials['performance'] == 'hit')])
    return 100 * n_hits / len(trials[is_type])


# Keep only sessions that contain both cued ('b') and uncued ('nb') trials and
# record their hit rates. A session whose `sessions` entry has exactly two
# elements is flagged as having no ephys.
mouse_day_array, b_correct, nb_correct, has_ephys = [], [], [], []
for session_key, session in sessions.items():
    print(session_key)
    beh = session[0]
    trials = beh['trials']
    trial_types = np.array(trials['type'])
    if 'nb' in trial_types and 'b' in trial_types:
        b_correct.append(_percent_hits(trials, 'b'))
        nb_correct.append(_percent_hits(trials, 'nb'))
        mouse_day_array.append(session_key)
        has_ephys.append(len(session) != 2)

b_correct, nb_correct = np.array(b_correct), np.array(nb_correct)
mouse_day_array = np.array(mouse_day_array)
has_ephys = np.array(has_ephys)

In [None]:
# Sessions are placed on a 3x3 grid of cued ('b') vs uncued ('nb') hit rate
# with cut-offs at BAD_HIT_PCT and GOOD_HIT_PCT and split into five mutually
# exclusive categories that together cover every session:
#   both   : b > 66.6 and nb > 66.6        -> green
#   cued   : b > 66.6, nb <= 66.6          -> blue
#   uncued : nb > 66.6, b <= 66.6          -> blue
#   bad    : b < 33.3 and nb < 33.3        -> black
#   mid    : every remaining session       -> grey
# Circles mark behaviour-only sessions, triangles sessions with ephys.
GOOD_HIT_PCT = 66.6
BAD_HIT_PCT = 33.3

b_good = b_correct > GOOD_HIT_PCT
nb_good = nb_correct > GOOD_HIT_PCT
both_mask = b_good & nb_good
b_mask = b_good & ~nb_good
# Mirrors b_mask; requiring b < 33.3 here would drop the (mid b, high nb) cell.
nb_mask = nb_good & ~b_good
none_mask = (b_correct < BAD_HIT_PCT) & (nb_correct < BAD_HIT_PCT)
# Exclude 'bad' sessions so they are not plotted/counted twice.
mid_mask = ~b_good & ~nb_good & ~none_mask
# dtype=bool so `~` still works when no session qualified (empty float array).
has_ephys_mask = np.asarray(has_ephys, dtype=bool)
has_no_ephys_mask = ~has_ephys_mask

# (mask, colour, printed label) — plotting order matches the legend-free figure.
categories = [
    (both_mask, 'green', 'both trial type sessions'),
    (b_mask, 'tab:blue', 'cued trial type sessions'),
    (nb_mask, 'tab:blue', 'uncued trial type sessions'),
    (mid_mask, 'tab:grey', 'mid sessions'),
    (none_mask, 'black', 'bad trial type sessions'),
]

fig, ax = plt.subplots(figsize=(3, 2.5))
for ephys_mask, marker in ((has_no_ephys_mask, 'o'), (has_ephys_mask, '^')):
    for mask, colour, _ in categories:
        selected = mask & ephys_mask
        ax.scatter(b_correct[selected], nb_correct[selected], color=colour, alpha=0.5, marker=marker)

cut_offs = [0, BAD_HIT_PCT, GOOD_HIT_PCT, 100]
ax.set_xlim(0, 100)
ax.set_ylim(0, 100)
ax.set_xticks(cut_offs)
ax.set_yticks(cut_offs)
ax.set_xlabel('Cued hits (%)')
ax.set_ylabel('Uncued hits (%)')
fig.tight_layout()
plt.show()

for mask, _, label in categories:
    print(f'{label} {len(mouse_day_array[mask])}, {mouse_day_array[mask]}')

for mask, _, label in categories:
    selected = mask & has_ephys_mask
    print(f'{label} with ephys {len(mouse_day_array[selected])}, {mouse_day_array[selected]}')

In [None]:
# Per-trial running streaks of hits (or, with sign='not_engaged', non-hits) within one trial type/context group
def compute_streaks(beh, context='rz1', type='b', sign='engaged', streak_cap=5):
    """Running streak count for one trial group, reported for every trial.

    Walks all trials of ``beh['trials']`` in row order. Trials belonging to
    the selected group (matching ``type`` and ``context``) update the streak:
    with ``sign='engaged'`` a 'hit' increments it and anything else resets it
    to 0; with ``sign='not_engaged'`` a non-'hit' increments it and a 'hit'
    resets it. Trials outside the group carry the current value forward, so
    the result has one entry per trial and lines up with the full trial list.

    Parameters
    ----------
    beh : mapping
        Must provide ``'trials'`` with 'number', 'type', 'context' and
        'performance' columns.
    context, type : str
        Group selector. (``type`` shadows the builtin but is kept so keyword
        callers keep working.)
    sign : {'engaged', 'not_engaged'}
        Whether to count runs of hits or runs of non-hits.
    streak_cap : int
        Streak values are clipped to at most this value.

    Returns
    -------
    numpy.ndarray
        One streak value per trial, in trial row order.

    Raises
    ------
    ValueError
        If ``sign`` is not one of the accepted values (checked up front, even
        when the group is empty).
    """
    if sign not in ('engaged', 'not_engaged'):
        raise ValueError("sign must be either 'engaged' or 'not_engaged'")

    trials = beh['trials']
    numbers = np.asarray(trials['number'])
    outcomes = np.asarray(trials['performance'])
    in_group = np.asarray((trials['type'] == type) & (trials['context'] == context), dtype=bool)
    # Group membership is decided by trial number; a set makes each lookup O(1)
    # instead of scanning the group's numbers for every trial.
    group_numbers = set(numbers[in_group].tolist())
    count_hits = sign == 'engaged'

    streaks = []
    current_streak = 0
    for number, outcome in zip(numbers.tolist(), outcomes.tolist()):
        if number in group_numbers:
            # Extend the streak when the outcome matches the counted kind.
            if (outcome == 'hit') == count_hits:
                current_streak += 1
            else:
                current_streak = 0
        current_streak = min(current_streak, streak_cap)
        streaks.append(current_streak)
    return np.array(streaks)

In [None]:
    
def plot_stops_with_streaks(beh, title=None, savepath=None, tl=None, sort=False, return_fig=True, sign='engaged', streak_cap=5, stop_threshold=3):
    """Stop raster per trial alongside cued/uncued rz1 streak bars.

    Panels, left to right:
    - stop positions, one row per trial, coloured by (context, type, performance);
    - a colour-key square for each trial;
    - horizontal bars of the cued ('b') rz1 streak from ``compute_streaks``;
    - the same for uncued ('nb') trials.

    Parameters
    ----------
    beh : mapping
        Needs 'trial_number', 'P' (position) and 'S' (speed) sample arrays,
        plus a 'trials' table with 'number', 'context', 'type' and
        'performance' columns.
    title : str, optional
        Figure suptitle.
    savepath : str or Path, optional
        If given, the figure is saved there; missing parent directories are
        created.
    tl : float, optional
        Right x-limit of the position axis; None autoscales.
    sort : bool
        Currently unused; kept for call compatibility.
    return_fig : bool
        Return the Figure when True, otherwise return None.
    sign : {'engaged', 'not_engaged'}
        Forwarded to ``compute_streaks``: count runs of hits or of non-hits.
    streak_cap : int
        Maximum streak value; also the x-limit of both streak panels.
    stop_threshold : float
        Samples with speed below this value are drawn as stops (same units as
        ``beh['S']``).

    Raises
    ------
    ValueError
        If ``sign`` is not one of the accepted values.
    """
    if sign not in ('engaged', 'not_engaged'):
        raise ValueError("sign must be either 'engaged' or 'not_engaged'")

    trials = beh['trials']
    n_trials = len(trials)
    trial_numbers = np.asarray(beh['trial_number'])
    position = np.asarray(beh['P'])
    stop_mask = np.asarray(beh['S']) < stop_threshold

    # Per-trial metadata in row order (same order compute_streaks uses).
    numbers = np.asarray(trials['number'])
    contexts = np.asarray(trials['context'])
    types = np.asarray(trials['type'])
    outcomes = np.asarray(trials['performance'])

    b_streaks = compute_streaks(beh, context='rz1', type='b', sign=sign, streak_cap=streak_cap)
    nb_streaks = compute_streaks(beh, context='rz1', type='nb', sign=sign, streak_cap=streak_cap)

    fig, ax = plt.subplots(ncols=4, nrows=1, figsize=(3, 3), width_ratios=[1, 0.05, 0.4, 0.4], sharey=True)
    for row in range(n_trials):
        color = get_color_for_group((contexts[row], types[row], outcomes[row]))
        stops = position[stop_mask & (trial_numbers == numbers[row])]
        ax[0].scatter(stops, row * np.ones(len(stops)), alpha=0.025, s=3, c=color, rasterized=True)
        ax[1].scatter(1, row, c=color, marker='s', rasterized=True)

    # Bars use row positions so they line up with the raster rows above.
    rows = np.arange(n_trials)
    outcome_key = 'hit' if sign == 'engaged' else 'run'
    ax[2].barh(rows, b_streaks, color=get_color_for_group(('rz1', 'b', outcome_key)), alpha=0.5)
    ax[3].barh(rows, nb_streaks, color=get_color_for_group(('rz1', 'nb', outcome_key)), alpha=0.5)
    ax[2].set_xlim(0, streak_cap)
    ax[3].set_xlim(0, streak_cap)
    ax[1].axis('off')
    ax[0].set_xlabel('Pos (cm)')
    ax[0].set_xlim(0, tl)
    ax[0].set_ylim(0, n_trials)
    ax[0].invert_yaxis()
    # Shaded band at 90-110 on the position axis.
    ax[0].axvspan(
        90, 110,
        alpha=0.2,
        zorder=-10,
        edgecolor='none',
        facecolor='grey',
    )

    if title is not None:
        fig.suptitle(title)
    if savepath is not None:
        Path(savepath).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath)
    return fig if return_fig else None


In [None]:
# Example session. The output filename is built from the session key so it
# names the session actually plotted; the config-cell `mouse`/`day` are
# placeholders (0, 0) and would have produced 'M0D0_...'.
session_key = 'M26D13'
beh = sessions[session_key][0]
plot_stops_with_streaks(beh, tl=200, sign='engaged', return_fig=True,
                        savepath=f'/Users/harryclark/Documents/figs/toroidal/{session_key}_stops_streaks_engaged.pdf',
                        streak_cap=5);

In [None]:
# Plot stop rasters + streaks for every session recorded after day 10.
# Session keys have the form 'M<mouse>D<day>' (built in the loading cell).
# Uses loop-local names so the config-cell `day` global is not overwritten,
# and names each file after the full session key (keeps the 'M' prefix).
for session_key, session in sessions.items():
    print(session_key)
    session_day = int(session_key[1:].split('D')[1])
    if session_day > 10:
        plot_stops_with_streaks(session[0], tl=200, sign='engaged', return_fig=True,
                                savepath=f'/Users/harryclark/Documents/figs/toroidal/{session_key}_stops_streaks_engaged.pdf')
        plt.show()