## Imports and notebook configuration

In [1]:
# Automatically re-import edited modules (e.g. the ecephys package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Interactive (ipympl) figures, used by the slider-driven plots below.
%matplotlib widget
import matplotlib.pyplot as plt

In [3]:
from pathlib import Path
import numpy as np
import pandas as pd
from neurodsp.filt import filter_signal
from ecephys.signal.csd import kcsd_npix

In [95]:
from ecephys.data import paths, channel_groups
from ecephys.sglx_utils import load_timeseries
from ecephys.helpers.utils import store_df_h5
from ecephys.scoring import load_visbrain_hypnogram
import ecephys.plot as eplt

In [274]:
condition = "pre-SR"
subject = "Doppio"

## Load the data

In [275]:
# LFP binary for this subject/condition, plus the three channel groups used below.
bin_path = Path(paths.lfp_bin[condition][subject])
hpc_chans = channel_groups.hippocampus[subject]
so_chans = channel_groups.stratum_oriens_100um[subject]
sr_chans = channel_groups.stratum_radiatum_140um_to_200um[subject]

In [25]:
# Stratum radiatum channels over the whole recording (no start/end bounds).
# `time` and `fs` are rebound by each of the loads from `bin_path` in this section.
time, sr_lfps, fs = load_timeseries(
    bin_path, sr_chans, start_time=None, end_time=None
)

nChan: 385, nFileSamp: 18000019


In [26]:
# Stratum oriens channels over the whole recording (no start/end bounds).
# `time` and `fs` are rebound by each of the loads from `bin_path` in this section.
time, so_lfps, fs = load_timeseries(
    bin_path, so_chans, start_time=None, end_time=None
)

nChan: 385, nFileSamp: 18000019


In [276]:
# All hippocampal channels over the whole recording; these feed the CSD estimate.
# `time` and `fs` are rebound by each of the loads from `bin_path` in this section.
time, hpc_lfps, fs = load_timeseries(
    bin_path, hpc_chans, start_time=None, end_time=None
)

nChan: 385, nFileSamp: 18000019


## Explore LFPs

In [13]:
# Browse the raw stratum radiatum LFPs.
# `plot_timeseries_interactive` is never imported by name in this notebook, so a bare call
# raises NameError on a fresh kernel; call it through the `eplt` alias (ecephys.plot, the
# module that also provides `plot_ripple`). NOTE(review): confirm it lives in ecephys.plot.
eplt.plot_timeseries_interactive(time, sr_lfps, chan_labels=sr_chans)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(FloatSlider(value=1.0, description='Secs', max=4.0, min=0.25, step=0.25), BoundedFloatText(valu…

Output()

In [12]:
# Re-reference the stratum radiatum LFPs against the stratum oriens LFPs (SR minus SO).
# The transposes make any broadcasting happen along the time axis; this matters if
# `so_lfps` has a different channel count than `sr_lfps` -- TODO confirm its shape.
lfps = (sr_lfps.T - so_lfps.T).T

In [14]:
# Browse the re-referenced (SR minus SO) LFPs.
# `plot_timeseries_interactive` is never imported by name in this notebook; call it
# through the `eplt` alias (ecephys.plot). NOTE(review): confirm it lives in ecephys.plot.
eplt.plot_timeseries_interactive(time, lfps, chan_labels=sr_chans)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(FloatSlider(value=1.0, description='Secs', max=4.0, min=0.25, step=0.25), BoundedFloatText(valu…

Output()

## Filter signal

In [73]:
# Band-pass the stratum radiatum LFPs (f_range in Hz, per neurodsp's convention).
# The filter is applied to the transposed array, so `filtered_sr_lfps` is laid out
# transposed relative to `sr_lfps`; downstream cells transpose it back.
f_range = (2, 35)
filtered_sr_lfps = filter_signal(sr_lfps.T, fs, pass_type='bandpass', f_range=f_range)

In [74]:
# Browse the band-passed stratum radiatum LFPs (transposed back to the layout of `sr_lfps`).
# `plot_timeseries_interactive` is never imported by name in this notebook; call it
# through the `eplt` alias (ecephys.plot). NOTE(review): confirm it lives in ecephys.plot.
eplt.plot_timeseries_interactive(time, filtered_sr_lfps.T, chan_labels=sr_chans)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(FloatSlider(value=1.0, description='Secs', max=4.0, min=0.25, step=0.25), BoundedFloatText(valu…

Output()

## Detection type I: thresholding z-scored LFP negativity

In [94]:
from ecephys.signal.ripples import threshold_by_zscore
from ripple_detection.core import get_envelope  # NOTE(review): not used in this cell

# Sign-flip so that LFP negativity (what this detector thresholds) becomes positive,
# then zero out what was the positive-going part of the filtered signal.
filtered_lfps = -filtered_sr_lfps.T
filtered_lfps[filtered_lfps < 0] = 0

# Keep only samples where every channel is non-NaN (the NaNs presumably come from the
# filter edges -- TODO confirm), and keep the matching timestamps in `_time`.
not_null = np.all(pd.notnull(filtered_lfps), axis=1)
filtered_lfps, _time = (filtered_lfps[not_null], time[not_null])

# Sum across channels to get a single detection trace.
combined_filtered_lfps = np.sum(filtered_lfps, axis=1)

# Use the public `threshold_by_zscore`, the same helper the CSD-based detector below
# imports and calls with the same keyword arguments, rather than the private
# `_threshold_by_zscore`.
candidate_spw_times = threshold_by_zscore(
    combined_filtered_lfps,
    _time,
    minimum_duration=0.005,
    detection_zscore_threshold=3,
    boundary_zscore_threshold=1,
)

# Number events from 1 so `spw_number` matches the slider range used when plotting.
index = pd.Index(np.arange(len(candidate_spw_times)) + 1, name="spw_number")
spw_times = pd.DataFrame(candidate_spw_times, columns=["start_time", "end_time"], index=index)

In [95]:
from ecephys.signal.ripples import compute_ripple_features

# Add per-event summary columns (duration, nadir, envelope and RMS/amplitude statistics,
# as shown in the table below), computed on the NaN-free, clipped signal from the cell above.
# 'Kay' presumably selects the envelope/feature method -- confirm in ecephys.signal.ripples.
spw_times = compute_ripple_features(
    _time,
    filtered_lfps,
    spw_times,
    fs,
    'Kay',
    smoothing_sigma=0.0,
)

In [96]:
# One row per candidate event from the LFP-negativity detector, indexed by spw_number.
spw_times

Unnamed: 0_level_0,start_time,end_time,duration,center_time,nadir_time,envelope_integral,envelope_peak,mean_rms,summed_rms,max_rms,mean_amplitude,max_amplitude
spw_number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
1,47.021553,47.048353,0.0268,47.034953,47.021553,166371.545583,2652.641939,584.854823,7603.112701,742.045104,526.215469,715.590225
2,73.973126,74.013126,0.0400,73.993126,74.013126,269764.957808,3110.843671,636.294686,8271.830917,727.651448,639.632100,740.460225
3,87.201112,87.301112,0.1000,87.251112,87.301112,758746.360762,3490.477184,710.592883,9237.707485,875.846920,716.534301,834.688976
4,87.999512,88.027911,0.0284,88.013711,87.999512,200044.981153,3034.285481,649.571194,8444.425524,839.666150,701.843824,756.933179
5,90.935109,90.964709,0.0296,90.949909,90.935109,181114.458519,2608.839470,580.888174,7551.546256,684.333483,439.888779,523.817146
...,...,...,...,...,...,...,...,...,...,...,...,...
6631,7197.235162,7197.277562,0.0424,7197.256362,7197.235162,277196.371691,2831.149172,605.984285,7877.795705,770.190965,525.649443,720.915793
6632,7198.137562,7198.175562,0.0380,7198.156562,7198.175562,227876.350072,2609.146032,528.896318,6875.652135,911.672230,406.677231,789.213268
6633,7198.634761,7198.716761,0.0820,7198.675761,7198.634761,958930.913475,4978.687382,1065.828176,13855.766290,1297.673928,1220.742873,1406.807306
6634,7199.073961,7199.127161,0.0532,7199.100561,7199.127161,404512.767081,3689.438886,710.093005,9231.209062,886.255519,720.032279,876.217046


In [97]:
from ecephys.plot import plot_ripple

from ipywidgets import (
    fixed,
    interact,
)

# One figure with three stacked axes, redrawn by `plot_ripple` whenever a slider moves.
ax = plt.subplots(3, 1, figsize=(18, 6))[1]
_ = interact(
    plot_ripple,
    # Passed through unchanged (no widget is created for `fixed` values).
    time=fixed(_time),
    lfps=fixed(sr_lfps[not_null]),
    filtered_lfps=fixed(-filtered_lfps),
    fs=fixed(fs),
    ripple_times=fixed(spw_times),
    # (min, max, step) tuples become sliders.
    window_length=(0.25, 2, 0.25),
    ripple_number=(1, len(spw_times), 1),
    ax=fixed(ax),
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=1.0, description='window_length', max=2.0, min=0.25, step=0.25), IntSl…

## Detection type II: thresholding the current source density (CSD)

In [277]:
from kcsd import KCSD1D

# kCSD geometry: one electrode per hippocampal channel, evenly spaced along depth from 0.
# NOTE(review): 0.020 presumably means 20 um expressed in mm -- confirm the units KCSD1D
# expects and that consecutive entries of `hpc_chans` really are that far apart.
intersite_distance = 0.020
n_hpc_chans = len(hpc_chans)
# Estimate the CSD at the same spacing as the recording sites.
interestimate_distance = intersite_distance
depth_span = (n_hpc_chans - 1) * intersite_distance
ele_pos = np.linspace(0., depth_span, n_hpc_chans).reshape(-1, 1)

In [278]:
# Fit kCSD on all hippocampal channels; rows of `pots` line up with rows of `ele_pos`.
pots = hpc_lfps.transpose()
k = KCSD1D(ele_pos, pots, gdx=interestimate_distance)
# NOTE(review): this estimate is taken before `k.L_curve()` runs, so it does not use the
# L-curve-selected lambda; `est_csd_val` (recomputed after the L-curve) is what detection
# uses. Evaluating the CSD over the whole recording is presumably costly -- consider dropping.
est_csd = k.values('CSD')

In [279]:
# Choose the regularization via the L-curve (prints the best lambda and R). The exported
# metadata later reads `k.lambd` / `k.R`, presumably updated in place by this call.
k.L_curve()

No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  0.0012266776157016984 ,  0.23


In [280]:
# Re-evaluate the CSD after `L_curve()` so it reflects the selected regularization;
# this is the estimate the SPW detection below runs on.
est_csd_val = k.values('CSD')

### Optional: detrend the LFPs before CSD estimation

In [54]:
# `detrend` was never imported (NameError on a fresh kernel). scipy's default axis=-1
# would, for `hpc_lfps` (samples x channels, as implied by `pots = hpc_lfps.T` feeding
# KCSD1D), fit a line across channels at every time point; detrend each channel over
# time instead.
from scipy.signal import detrend

detrended_hpc_lfps = detrend(hpc_lfps, axis=0)

In [55]:
# Same electrode geometry and estimation spacing as the non-detrended fit `k`: use the
# named `interestimate_distance` instead of repeating the literal 0.02.
# (`l` is the detrended counterpart of `k`; later cells refer to it by this name.)
l = KCSD1D(ele_pos, detrended_hpc_lfps.T, gdx=interestimate_distance)
est_csd_detrend = l.values('CSD')

In [56]:
# L-curve selection for the detrended fit (prints the best lambda and R).
l.L_curve()

No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  0.0003822395851068327 ,  0.23


In [57]:
# Detrended CSD after L-curve selection. Note: the detection cells below use
# `est_csd_val`, not this array.
est_csd_val_detrend = l.values('CSD')

### Detect

In [281]:
from ecephys.signal.ripples import threshold_by_zscore

# Detection parameters (also written to the exported metadata below).
detection_threshold = 2.5  # -> detection_zscore_threshold
boundary_threshold = 1  # -> boundary_zscore_threshold
minimum_duration = 0.005  # presumably seconds, like `time` -- confirm in threshold_by_zscore


def detect_spw(
    time,
    csd,
    detection_zscore_threshold=None,
    boundary_zscore_threshold=None,
    min_duration=None,
):
    """Detect candidate SPWs as z-score threshold crossings of the summed, sign-flipped CSD.

    Parameters
    ----------
    time : array
        Sample times, one per column of `csd`.
    csd : 2-D array
        CSD estimates with one row per site and one column per sample; it is
        transposed internally so summing over axis 1 collapses the sites.
    detection_zscore_threshold, boundary_zscore_threshold, min_duration : optional
        Forwarded to `threshold_by_zscore`. When None (the default), the
        notebook-level `detection_threshold`, `boundary_threshold` and
        `minimum_duration` are read at call time, as the original version did.

    Returns
    -------
    pandas.DataFrame
        One row per candidate event with `start_time` and `end_time` columns,
        indexed by `spw_number` starting at 1.
    """
    if detection_zscore_threshold is None:
        detection_zscore_threshold = detection_threshold
    if boundary_zscore_threshold is None:
        boundary_zscore_threshold = boundary_threshold
    if min_duration is None:
        min_duration = minimum_duration

    # Flip the sign so negative CSD values (sinks) become positive, then sum across sites.
    # Unlike the LFP-based detector above, nothing is clipped at zero and NaNs are not dropped.
    sign_flipped = -csd.T
    combined_csd = np.sum(sign_flipped, axis=1)

    candidate_spw_times = threshold_by_zscore(
        combined_csd,
        time,
        minimum_duration=min_duration,
        detection_zscore_threshold=detection_zscore_threshold,
        boundary_zscore_threshold=boundary_zscore_threshold,
    )

    index = pd.Index(np.arange(len(candidate_spw_times)) + 1, name="spw_number")
    return pd.DataFrame(candidate_spw_times, columns=["start_time", "end_time"], index=index)

In [282]:
# Select the stratum radiatum rows of the CSD. The boolean mask assumes one CSD row per
# entry of `hpc_chans`, in the same order (estimation spacing == site spacing), and that
# every detection channel is one of the CSD channels. Check both rather than silently
# detecting on the wrong or fewer rows.
assert est_csd_val.shape[0] == len(hpc_chans), (
    f"expected {len(hpc_chans)} CSD rows (one per hippocampal channel), "
    f"got {est_csd_val.shape[0]}"
)
assert np.isin(sr_chans, hpc_chans).all(), "some stratum radiatum channels are not in hpc_chans"
sr_csd = est_csd_val[np.isin(hpc_chans, sr_chans)]
spws = detect_spw(time, sr_csd)

[146, 149, 150, 153]


### Compute SPW properties

In [283]:
from ecephys.signal.ripples import get_durations, get_midpoints, get_sink_amplitudes, get_sink_integrals

In [284]:
spws["duration"] = get_durations(spws)
spws["midpoint"] = get_midpoints(spws)

In [285]:
spws["sink_amplitude"] = get_sink_amplitudes(spws, time, sr_csd) * (1e-6) # Scale to mA/mm

In [286]:
spws["sink_integral"] = get_sink_integrals(spws, time, fs, sr_csd) * (1e-6) * (1e3) # Scale to mA * ms

In [287]:
# One row per CSD-detected event, with timing and sink summary columns.
spws

Unnamed: 0_level_0,start_time,end_time,duration,midpoint,sink_amplitude,sink_integral
spw_number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1,1.353199,1.434799,0.0816,1.393999,-0.024114,-1.279671
2,1.755998,1.787198,0.0312,1.771598,-0.020968,-0.439843
3,4.450796,4.565595,0.1148,4.508195,-0.039624,-2.053179
4,4.989595,5.067195,0.0776,5.028395,-0.057252,-2.303828
5,5.196395,5.242795,0.0464,5.219595,-0.025813,-0.767067
...,...,...,...,...,...,...
3825,7176.205184,7176.333583,0.1284,7176.269384,-0.029055,-2.053046
3826,7176.904383,7176.951583,0.0472,7176.927983,-0.046288,-1.248569
3827,7184.205176,7184.233976,0.0288,7184.219576,-0.020775,-0.399027
3828,7184.333975,7184.463975,0.1300,7184.398975,-0.041863,-3.597945


### Export results

In [288]:
# Parameters needed to reproduce the detection, stored alongside the event table.
# NOTE(review): `detect_states` is recorded, but no cell in this notebook restricts
# detection to these states (`load_visbrain_hypnogram` is imported but never called).
# Confirm the metadata matches what was actually run.
metadata = {
    # CSD estimation
    "csd_chans": hpc_chans,
    "detection_chans": sr_chans,
    "electrode_positions": ele_pos,
    "intersite_distance": intersite_distance,
    "gdx": interestimate_distance,
    "lambd": k.lambd,
    "R": k.R,
    # Detection
    "detect_states": ["Wake", "N1", "N2", "REM"],
    "detection_zscore_threshold": detection_threshold,
    "boundary_zscore_threshold": boundary_threshold,
    "minimum_duration": minimum_duration,
}
spws_path = Path(paths.spws[condition][subject])
store_df_h5(spws_path, spws, **metadata)

In [289]:
# Echo the detection channels (exported above as the `detection_chans` metadata entry).
sr_chans

[146, 149, 150, 153]