In [None]:
# Automatically reload imported modules (e.g. edited ecephys / ecephys_analyses code) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
from datetime import datetime
from joblib import Parallel, delayed
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import itertools
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import ecephys.plot
from ecephys_analyses.data import paths
from ecephys.scoring.hypnogram import load_visbrain_hypnogram, load_consecutive_visbrain_hypnograms
from ecephys.sglx_utils import get_sf

In [None]:
# Global matplotlib defaults for every figure in this notebook.
# dpi=200 renders crisply but makes plotting noticeably slower.
SMALL_SIZE = 12
MEDIUM_SIZE = 15
BIGGER_SIZE = 18

plt.rcParams.update({
    'figure.figsize': [15, 6],
    'figure.dpi': 200,
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': MEDIUM_SIZE,    # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure title (suptitle)
})

In [None]:
# Hypnogram state -> colour mapping used by every plot below.
# Build a copy of the library defaults with local overrides instead of calling
# .update() on ecephys.plot.state_colors, which would mutate the shared
# module-level dict for every other user of ecephys.plot in this kernel.
# All plotting calls in this notebook pass state_colors explicitly.
state_colors = {
    **ecephys.plot.state_colors,
    'Wake': 'lightgreen',
    'N1': 'cornflowerblue',
    'N2': 'royalblue',
    'REM': 'lightcoral',
    'Trans': 'black',
}

In [None]:
def plot_drift_traces(subject, condition, sorting_condition, hyp_root, output_dir=None):
    """Plot the drift traces from ``drift.npy`` with the hypnogram overlaid.

    Parameters
    ----------
    subject, condition, sorting_condition : str
        Passed to ``paths.get_datapath`` to locate ``drift.npy``.
    hyp_root : str or Path
        Data root passed to ``load_hypno`` to find the hypnogram files.
    output_dir : str or Path, optional
        If given, the figure is saved there as
        ``drift_trace_{subject}_{condition}_{sorting_condition}.png``
        (the directory and any missing parents are created).

    Returns
    -------
    (fig, ax)
    """
    print(f'drift traces, {subject} {condition} {sorting_condition}')
    hyp = load_hypno(subject, condition, hyp_root)
    drift = np.load(paths.get_datapath(subject, condition, sorting_condition) / 'drift.npy')
    # drift.npy carries no timestamps: rows are spread evenly over the span
    # covered by the hypnogram. NOTE(review): assumes the sorting covers exactly
    # that span with equal-length batches — confirm.
    times = np.linspace(hyp.start_time.min(), hyp.end_time.max(), num=drift.shape[0])

    fig, ax = plt.subplots()
    n_traces = drift.shape[1]
    colors = plt.cm.plasma(np.linspace(0, 0.75, n_traces))
    for i in range(n_traces):
        ax.plot(times, drift[:, i], color=colors[i], alpha=0.7, linewidth=1)
    ecephys.plot.plot_hypnogram_overlay(hyp, state_colors=state_colors, ax=ax)
    ax.set_ylabel('Drift (um)')
    ax.set_xlabel('Time (s)')
    hypcols = {state: state_colors[state] for state in hyp.state.unique()}
    ax.set_title(f'Drift trace: {subject}, {condition}, {sorting_condition}.\n N={len(times)} batches.\n{hypcols}')

    if output_dir is not None:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        filename = f'drift_trace_{subject}_{condition}_{sorting_condition}.png'
        print(f'save {filename}')
        fig.savefig(out_dir / filename, bbox_inches='tight')

    return fig, ax


def plot_drift_map(subject, condition, sorting_condition, hyp_root, output_dir=None):
    """Scatter-plot spike position vs. time (pre drift correction) with the hypnogram.

    Loads ``pre_correction_sorting.npy`` from
    ``paths.get_datapath(subject, condition, sorting_condition)``. Column 0 is
    divided by the sampling rate of the first ``ap.bin`` file to get seconds,
    column 1 is plotted as position, and column 2 selects the grey level:
    integer values ``j`` in ``[0, 100)`` are drawn with intensity
    ``max(0, 1 - j/40)`` (higher ``j`` -> darker). Rows with other values in
    column 2 are not drawn.

    Parameters
    ----------
    subject, condition, sorting_condition : str
        Identify the sorting output directory.
    hyp_root : str or Path
        Data root passed to ``load_hypno``.
    output_dir : str or Path, optional
        If given, the figure is saved there as
        ``drift_map_{subject}_{condition}_{sorting_condition}.png``.

    Returns
    -------
    (fig, ax)
    """
    print(f'drift map, {subject} {condition} {sorting_condition}')
    hyp = load_hypno(subject, condition, hyp_root)

    st = np.load(
        paths.get_datapath(subject, condition, sorting_condition) / 'pre_correction_sorting.npy'
    )
    sf = get_sf(
        paths.get_sglx_style_datapaths(
            subject=subject,
            condition=condition,
            ext='ap.bin',
        )[0]
    )

    spike_times = st[:, 0] / sf
    spike_positions = st[:, 1]
    # NOTE(review): column 2 presumably holds an integer amplitude bin — confirm.
    levels = st[:, 2]

    fig, ax = plt.subplots()
    for j in range(100):
        ix = np.flatnonzero(levels == j)
        if ix.size == 0:
            continue
        ax.scatter(
            spike_times[ix],
            spike_positions[ix],
            s=0.05,
            c=np.full(ix.size, max(0, 1 - j / 40)),
            cmap='gray', vmin=0, vmax=1,
        )
    ecephys.plot.plot_hypnogram_overlay(hyp, state_colors=state_colors, ax=ax, alpha=0.2)
    ax.set_ylabel('Spike position (um)')
    ax.set_xlabel('Time (s)')
    hypcols = {state: state_colors[state] for state in hyp.state.unique()}
    ax.set_title(f'Drift map: {subject}, {condition}, {sorting_condition}.\n{hypcols}')

    if output_dir is not None:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        filename = f'drift_map_{subject}_{condition}_{sorting_condition}.png'
        print(f'save {filename}')
        fig.savefig(out_dir / filename, bbox_inches='tight')

    return fig, ax

    
def load_hypno(subject, condition, root):
    """Load the hypnogram(s) for ``subject``/``condition`` as one table.

    The ``hypnogram.txt`` files returned by ``paths.get_sglx_style_datapaths``
    under ``root`` are merged with ``load_consecutive_visbrain_hypnograms``.
    """
    hypnogram_files = paths.get_sglx_style_datapaths(
        data_root=root,
        ext='hypnogram.txt',
        condition=condition,
        subject=subject,
    )
    return load_consecutive_visbrain_hypnograms(hypnogram_files)

In [None]:
# Depth range mapped linearly onto the rows of the F / F0 fingerprint arrays
# (row i <-> np.linspace(DEPTHS[0], DEPTHS[1], n_rows)[i]).
# NOTE(review): presumably in um, like the drift-map y axis — confirm.
DEPTHS = [0, 7680]

def get_state(hyp, t):
    """Return the hypnogram state at time ``t``.

    ``hyp`` must have ``start_time``, ``end_time`` and ``state`` columns.
    Bouts are treated as closed intervals, so at a boundary shared by two
    bouts the one listed first in ``hyp`` wins.

    Raises
    ------
    IndexError
        If no bout contains ``t`` (e.g. a gap between hypnograms).
    """
    matches = hyp.loc[(hyp['start_time'] <= t) & (hyp['end_time'] >= t), 'state']
    if matches.empty:
        raise IndexError(f'No hypnogram bout contains t={t}')
    return matches.iloc[0]


def plot_drift_fingerprints(subject, condition, sorting_condition, hyp_root, N=25, output_dir=None, depths_interval=None):
    """Plot subsampled drift fingerprints over time next to the target F0.

    Loads ``F.npy``, ``F0.npy`` and ``drift.npy`` from
    ``paths.get_datapath(subject, condition, sorting_condition)``.
    The figure has three rows:
      0. F0, then every ``step``-th slice ``F[:, :, k]``;
      1. a strip coloured by the hypnogram state at each slice's time;
      2. the drift traces at the same subsampled times.

    Parameters
    ----------
    subject, condition, sorting_condition : str
        Identify the sorting output directory.
    hyp_root : str or Path
        Data root passed to ``load_hypno``.
    N : int
        Target number of fingerprint columns. ``step = max(1, n_batches // N)``,
        so at least ``N`` columns are drawn when there are that many batches.
    output_dir : str or Path, optional
        If given, save the figure there. A ``_depths=<lo>-<hi>`` suffix is added
        to the filename when ``depths_interval`` is set, so restricted plots do
        not overwrite the full-depth one.
    depths_interval : (float, float), optional
        Inclusive depth window; only fingerprint rows whose depth falls inside
        it are shown.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``depths_interval`` selects no rows.
    """
    print(f'drift fingerprints, {subject} {condition} {sorting_condition}')

    hyp = load_hypno(subject, condition, hyp_root)
    datapath = paths.get_datapath(subject, condition, sorting_condition)
    F = np.load(datapath / 'F.npy')
    F0 = np.load(datapath / 'F0.npy')
    drift = np.load(datapath / 'drift.npy')
    # NOTE(review): batch times are assumed evenly spread over the hypnogram span.
    times = np.linspace(hyp.start_time.min(), hyp.end_time.max(), num=drift.shape[0])
    depths = np.linspace(DEPTHS[0], DEPTHS[1], num=F.shape[0])

    if depths_interval is not None:
        print(f"select depths at {depths_interval}")
        # Select on the un-flipped arrays, where row i corresponds to depths[i].
        keep = (depths >= depths_interval[0]) & (depths <= depths_interval[1])
        if not keep.any():
            raise ValueError(f'No fingerprint rows within depths_interval={depths_interval}')
        F, F0, depths = F[keep], F0[keep], depths[keep]

    # Flip the depth axis for display: with imshow's default origin='upper',
    # the last row (depths.max()) ends up on top, matching the extent labels.
    F = F[::-1]
    F0 = F0[::-1]

    # Guard against step == 0 when there are fewer batches than N.
    step = max(1, len(times) // N)
    F_sub = F[:, :, ::step]
    times_sub = times[::step]
    drift_sub = drift[::step]
    n = len(times_sub)

    # Draw everything on this (pyplot-managed) figure explicitly; creating an
    # unmanaged plt.Figure and drawing with plt.subplot would target gcf()
    # instead and save a blank figure.
    fig = plt.figure(figsize=(30, 20))
    grid = fig.add_gridspec(3, n + 1, wspace=0.0, hspace=0.0, height_ratios=[15, 1, 5])

    ax = fig.add_subplot(grid[0, 0])
    ax.imshow(F0, aspect='auto', extent=[0, 20, depths.min(), depths.max()])
    ax.set_title('F0 (target)')

    for i, t in enumerate(times_sub):
        ax = fig.add_subplot(grid[0, i + 1])
        ax.imshow(F_sub[:, :, i], aspect='auto')
        ax.get_yaxis().set_visible(False)
        ax.set_title(f't={t:.0f}', rotation=90)

        ax = fig.add_subplot(grid[1, i + 1])
        ax.set_facecolor(state_colors[get_state(hyp, t)])
        ax.set_xticks([])
        ax.set_yticks([])

    ax = fig.add_subplot(grid[2, 1:])
    for i in range(drift_sub.shape[1]):
        ax.plot(times_sub, drift_sub[:, i], alpha=0.7, linewidth=1)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Drift (um)')

    hypcols = {state: state_colors[state] for state in hyp.state.unique()}
    fig.suptitle(f'Fingerprints: {subject}, {condition}, {sorting_condition}.\n{hypcols}')

    if output_dir is not None:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        suffix = '' if depths_interval is None else f'_depths={depths_interval[0]}-{depths_interval[1]}'
        filename = f'drift_fingerprint_{subject}_{condition}_{sorting_condition}{suffix}.png'
        print(f'save {filename}')
        fig.savefig(out_dir / filename, bbox_inches='tight')

    return fig


def plot_state_average_fingerprints(
    subject, condition, sorting_condition, hyp_root, N=25, output_dir=None,
    T_max=float('inf'), depths_interval=None,
):
    """Plot F0 next to the mean fingerprint of the batches in each hypnogram state.

    Loads ``F.npy`` and ``F0.npy`` from
    ``paths.get_datapath(subject, condition, sorting_condition)``, labels each
    batch with ``get_state`` and averages ``F`` over the batches in each state.

    Parameters
    ----------
    subject, condition, sorting_condition : str
        Identify the sorting output directory.
    hyp_root : str or Path
        Data root passed to ``load_hypno``.
    N : int
        Unused; kept for signature compatibility with ``plot_drift_fingerprints``.
    output_dir : str or Path, optional
        If given, save the figure there. ``_depths=<lo>-<hi>`` and/or
        ``_Tmax=<T_max>`` suffixes are added to the filename when those options
        are set, so differently-restricted runs do not overwrite each other.
    T_max : float
        Only batches whose time is ``<= T_max`` are used.
    depths_interval : (float, float), optional
        Inclusive depth window; only rows whose depth falls inside are shown.
        (Previously this was read from a global variable of the same name.)

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``depths_interval`` selects no rows.
    """
    print(f'drift average fingerprint by state, {subject} {condition} {sorting_condition}')

    hyp = load_hypno(subject, condition, hyp_root)
    datapath = paths.get_datapath(subject, condition, sorting_condition)
    F = np.load(datapath / 'F.npy')
    F0 = np.load(datapath / 'F0.npy')
    # NOTE(review): batch times are assumed evenly spread over the hypnogram span.
    times = np.linspace(hyp.start_time.min(), hyp.end_time.max(), num=F.shape[2])
    depths = np.linspace(DEPTHS[0], DEPTHS[1], num=F.shape[0])

    # Subset of depths (row i corresponds to depths[i])
    if depths_interval is not None:
        print(f"select depths at {depths_interval}")
        keep = (depths >= depths_interval[0]) & (depths <= depths_interval[1])
        if not keep.any():
            raise ValueError(f'No fingerprint rows within depths_interval={depths_interval}')
        F, F0, depths = F[keep], F0[keep], depths[keep]

    # Subset of times
    print(f'times: {times.min()}, {times.max()}, Tmax={T_max}')
    in_window = times <= T_max
    F = F[:, :, in_window]
    times = times[in_window]

    states = hyp.state.unique()
    # Label each batch once; dtype=object keeps elementwise == working even
    # when no batch survives the T_max cut.
    batch_states = np.array([get_state(hyp, t) for t in times], dtype=object)

    # squeeze=False keeps `axes` indexable even with a single panel.
    fig, axes = plt.subplots(1, len(states) + 1, squeeze=False)
    axes = axes[0]

    # Rows are flipped for display so that depths.max() is on top, matching
    # the extent labels (same convention as plot_drift_fingerprints).
    ax = axes[0]
    ax.imshow(F0[::-1], aspect='auto', extent=[0, 20, depths.min(), depths.max()])
    ax.set_title('F0 (target)')

    for ax, state in zip(axes[1:], states):
        in_state = batch_states == state
        n_batches = int(in_state.sum())
        print(f'{state}: {n_batches} batches')
        # Leave the panel empty instead of averaging zero batches (NaN image).
        if n_batches:
            Fstate = F[:, :, in_state].mean(axis=2)
            ax.imshow(Fstate[::-1], aspect='auto')
        ax.set_title(state)
        ax.get_yaxis().set_visible(False)

    fig.tight_layout()

    hypcols = {state: state_colors[state] for state in states}
    fig.suptitle(f'Fingerprint averages: {subject}, {condition}, {sorting_condition}.\n{hypcols}')

    if output_dir is not None:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        suffix = ''
        if depths_interval is not None:
            suffix += f'_depths={depths_interval[0]}-{depths_interval[1]}'
        if np.isfinite(T_max):
            suffix += f'_Tmax={T_max:g}'
        filename = f'drift_state_average_fingerprint_{subject}_{condition}_{sorting_condition}{suffix}.png'
        print(f'save {filename}')
        fig.savefig(out_dir / filename, bbox_inches='tight')

    return fig

In [None]:
# Destination for all saved PNGs (relative to the notebook's working directory).
output_dir = "./plots_drift"

# Plots for every configuration in `conds`

In [None]:
# (subject, condition, sorting_condition, hyp_root) tuples to plot.
# Previously plotted configurations are kept commented out for reference.
conds = [
#     ('Doppio', 'drift_test_01', 'ks2_5_raw_df', '/Volumes/neuropixel/Data/CNPIX4-Doppio'),
#     ('Doppio', 'drift_test_01', 'ks2_5_raw_8s-batches_minFR=0', '/Volumes/neuropixel/Data/CNPIX4-Doppio'),
#     ('Doppio', 'drift_test_01', 'ks2_5_raw_rigid', '/Volumes/neuropixel/Data/CNPIX4-Doppio'),
#     ('Doppio', 'drift_test_01', 'ks2_5_raw_4s-batches', '/Volumes/neuropixel/Data/CNPIX4-Doppio'),
#     ('Doppio', 'drift_test_01', 'ks2_5_raw_8s-batches', '/Volumes/neuropixel/Data/CNPIX4-Doppio'),
#     ('Doppio', 'drift_test_02', 'ks2_5_raw_df', '/Volumes/neuropixel/Data/CNPIX4-Doppio'),
#     ('Valentino', 'drift_test_01', 'ks2_5_raw_df', '/Volumes/neuropixel/Data/CNPIX3-Valentino'),
#     ('Valentino', 'drift_test_02', 'ks2_5_raw_df', '/Volumes/neuropixel/Data/CNPIX3-Valentino'),
#     ('Valentino', 'baseline_6h', 'ks2_5_raw_df', '/Volumes/neuropixel/Data/CNPIX3-Valentino'),
#     ('Valentino', 'baseline_12h', 'ks2_5_raw_df', '/Volumes/neuropixel/Data/CNPIX3-Valentino'),
#     ('Valentino', 'baseline_12h', 'ks2_5_raw_niter=100_nmaxshift_x2', '/Volumes/neuropixel/Data/CNPIX3-Valentino'),
    (
        'Valentino',
        'baseline_12h',
        'ks2_5_raw_niter=100_nmaxshift_x2_8s-batches',
        '/Volumes/neuropixel/Data/CNPIX3-Valentino',
    ),
]

In [None]:
# One drift-trace figure per configuration, saved under output_dir.
for subject, condition, sorting_condition, hyp_root in conds:
    plot_drift_traces(subject=subject, condition=condition,
                      sorting_condition=sorting_condition, hyp_root=hyp_root,
                      output_dir=output_dir)



In [None]:
# One drift-map figure per configuration, saved under output_dir.
for subject, condition, sorting_condition, hyp_root in conds:
    plot_drift_map(subject=subject, condition=condition,
                   sorting_condition=sorting_condition, hyp_root=hyp_root,
                   output_dir=output_dir)



In [None]:
# Optional depth window for the fingerprint plots, e.g. [5328, 6218];
# None shows the full depth range.
depths_interval = None

In [None]:
# Fingerprint figures restricted to `depths_interval` (full depth when None).
for subject, condition, sorting_condition, hyp_root in conds:
    plot_drift_fingerprints(subject=subject, condition=condition,
                            sorting_condition=sorting_condition, hyp_root=hyp_root,
                            output_dir=output_dir, depths_interval=depths_interval)



In [None]:
# NOTE(review): same arguments as the previous cell — this redraws and
# re-saves the identical figure; consider deleting it.
for subject, condition, sorting_condition, hyp_root in conds:
    plot_drift_fingerprints(subject=subject, condition=condition,
                            sorting_condition=sorting_condition, hyp_root=hyp_root,
                            output_dir=output_dir, depths_interval=depths_interval)



In [None]:
# Full-depth fingerprint figures (depths_interval left at its default, None).
for subject, condition, sorting_condition, hyp_root in conds:
    plot_drift_fingerprints(subject=subject, condition=condition,
                            sorting_condition=sorting_condition, hyp_root=hyp_root,
                            output_dir=output_dir)



In [None]:
# NOTE(review): same arguments as the previous cell — this redraws and
# re-saves the identical figure; consider deleting it.
for subject, condition, sorting_condition, hyp_root in conds:
    plot_drift_fingerprints(subject=subject, condition=condition,
                            sorting_condition=sorting_condition, hyp_root=hyp_root,
                            output_dir=output_dir)



In [None]:
# Per-state average fingerprints using every batch.
for subject, condition, sorting_condition, hyp_root in conds:
    plot_state_average_fingerprints(subject=subject, condition=condition,
                                    sorting_condition=sorting_condition, hyp_root=hyp_root,
                                    output_dir=output_dir)



In [None]:
# Per-state average fingerprints using only batches with t <= 20000 s.
for subject, condition, sorting_condition, hyp_root in conds:
    plot_state_average_fingerprints(subject=subject, condition=condition,
                                    sorting_condition=sorting_condition, hyp_root=hyp_root,
                                    output_dir=output_dir, T_max=20000)



# Doppio: temporary single-condition test runs

In [None]:
# Scratch configuration for quick single-condition checks (figures are not saved).
subject, condition, sorting_condition, hyp_root = (
    'Doppio',
    'drift_test_01',
    'test',
    '/Volumes/neuropixel/Data/CNPIX4-Doppio',
)

In [None]:
# Trailing ';' suppresses the echoed (fig, ax) tuple repr; the figure still renders.
plot_drift_map(subject, condition, sorting_condition, hyp_root);

In [None]:
# Trailing ';' suppresses the echoed (fig, ax) tuple repr; the figure still renders.
plot_drift_traces(subject, condition, sorting_condition, hyp_root);

In [None]:
# Trailing ';' stops the returned Figure from being displayed a second time.
plot_drift_fingerprints(subject, condition, sorting_condition, hyp_root);

In [None]:
# Trailing ';' stops the returned Figure from being displayed a second time.
plot_state_average_fingerprints(subject, condition, sorting_condition, hyp_root);