- TODO: Parallelize sharp wave property computation

# Imports and definitions

In [9]:
%load_ext autoreload
%autoreload 2
from IPython.core.debugger import set_trace

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
%matplotlib widget
import matplotlib.pyplot as plt

In [11]:
import numpy as np
import pandas as pd
import xarray as xr
import json
from datetime import datetime
from ast import literal_eval

In [12]:
from ecephys_analyses.data import paths, channel_groups
from ecephys.sglx_utils import load_timeseries, load_multifile_timeseries
from ecephys.signal.csd import get_kcsd
from ecephys.signal.sharp_wave_ripples import detect_sharp_waves_by_value, detect_sharp_waves_by_zscore, get_durations, get_midpoints, get_sink_amplitudes, get_sink_integrals
from ecephys.utils import load_df_h5, store_df_h5, zscore_to_value
import ecephys.plot as eplt

## Get SPW detection parameters from early recovery sleep

In [13]:
def get_spw_detection_parameters(
    subject,
    nfiles,
    detection_threshold_zscore=2.5,
    boundary_threshold_zscore=1,
    experiment="sleep-homeostasis",
    condition="recovery-sleep",
):
    """Derive sharp-wave detection thresholds from a reference recording and save them as JSON.

    The hippocampal LFP from the first `nfiles` LF files of `condition` is converted to CSD
    with kCSD (run with the L-curve enabled; the resulting `lambd` and `R` are saved so later
    detection can reuse them without re-running the L-curve). The negated CSD is summed across
    stratum radiatum channels, and the detection / boundary thresholds are the values of that
    combined signal corresponding to the requested z-scores.

    Parameters
    ----------
    subject : str
        Subject name, used to look up channel groups and data paths.
    nfiles : int
        Number of reference LF files (in path order) to use. Must be >= 1.
    detection_threshold_zscore : float
        Z-score of the combined SR CSD used as the event detection threshold.
    boundary_threshold_zscore : float
        Z-score of the combined SR CSD used to define event boundaries.
    experiment : str
        Experiment whose files are used and under which the parameters file is written.
    condition : str
        Condition whose LF files are used as the reference recording.

    Raises
    ------
    ValueError
        If `nfiles` < 1 or no LF files are found for `condition`.
    FileExistsError
        If the parameters file already exists (it is opened in "x" mode and never overwritten).
    """
    if nfiles < 1:
        raise ValueError(f"nfiles must be at least 1, got {nfiles}.")

    sr_chans = channel_groups.stratum_radiatum_140um_to_200um[subject]
    hpc_chans = channel_groups.hippocampus[subject]

    bin_paths = paths.get_sglx_style_datapaths(
        subject=subject, experiment=experiment, condition=condition, ext="lf.bin"
    )[:nfiles]
    if len(bin_paths) == 0:
        raise ValueError(f"No {condition} LF files found for {subject} ({experiment}).")
    params_path = paths.get_datapath(
        subject=subject, experiment=experiment, file="sharp_wave_detection_params.json"
    )

    # Treat several files as one contiguous recording; decide on the files actually found,
    # which may be fewer than requested.
    if len(bin_paths) > 1:
        da = load_multifile_timeseries(bin_paths, hpc_chans, contiguous=True)
    else:
        da = load_timeseries(bin_paths[0], hpc_chans)
    hpc_lfps = da.values

    gdx = intersite_distance = 0.020  # presumably mm, matching the "nA/mm" CSD units — TODO confirm
    k = get_kcsd(
        hpc_lfps, intersite_distance=intersite_distance, gdx=gdx, do_lcurve=True
    )

    hpc_csd = k.values("CSD")
    # CSD rows follow hpc_chans; keep only the stratum radiatum rows.
    sr_csd = hpc_csd[np.isin(hpc_chans, sr_chans)]
    # Negate (presumably so sinks are positive) and sum across SR channels: one value per sample.
    combined_csd = np.sum(-sr_csd.T, axis=1)

    detection_threshold = zscore_to_value(combined_csd, detection_threshold_zscore)
    boundary_threshold = zscore_to_value(combined_csd, boundary_threshold_zscore)

    # json cannot serialize numpy arrays or numpy integer scalars, so convert everything
    # that may come back as a numpy type to plain Python lists / floats.
    metadata = dict(
        csd_chans=np.asarray(hpc_chans).tolist(),
        detection_chans=np.asarray(sr_chans).tolist(),
        electrode_positions=np.asarray(k.ele_pos).tolist(),
        intersite_distance=intersite_distance,
        gdx=float(k.gdx),
        lambd=float(k.lambd),
        R=float(k.R),
        detect_states=["all"],
        detection_threshold_zscore=detection_threshold_zscore,
        boundary_threshold_zscore=boundary_threshold_zscore,
        detection_threshold=float(detection_threshold),
        boundary_threshold=float(boundary_threshold),
        minimum_duration=0.005,  # presumably seconds — TODO confirm against get_durations
        params_source_files=[str(path) for path in bin_paths],
    )

    params_path.parent.mkdir(parents=True, exist_ok=True)
    with open(params_path, "x") as params_file:
        json.dump(metadata, params_file, indent=4)

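## Detect SPWs, accounting for drift

Stratum radiatum channels are selected separately for each epoch (from each file's `sr_chans.csv`), so detection tracks the probe as it drifts over the recording.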

In [14]:
def get_epoch_spws(hpc_csd, params, epoch_start, epoch_end, sr_chans):
    """Detect sharp waves within one epoch using that epoch's stratum radiatum channels.

    Parameters
    ----------
    hpc_csd : xarray.DataArray
        Hippocampal CSD with `channel` and `time` coordinates and an `fs` attribute.
    params : dict
        Detection parameters; `detection_threshold`, `boundary_threshold` and
        `minimum_duration` are read.
    epoch_start, epoch_end : float
        Epoch bounds, used as a label-based `time` slice.
    sr_chans : list
        Stratum radiatum channels for this epoch. An empty list means there are no
        usable channels, and an empty DataFrame is returned.

    Returns
    -------
    pandas.DataFrame
        Events from `detect_sharp_waves_by_value`. If any were found, `duration`,
        `midpoint`, `sink_amplitude`, `sink_integral` and `sr_chans` columns are added.
    """
    if not sr_chans:
        return pd.DataFrame()

    epoch_csd = hpc_csd.sel(time=slice(epoch_start, epoch_end), channel=sr_chans)
    times = epoch_csd.time.values
    csd = epoch_csd.values

    spws = detect_sharp_waves_by_value(
        times,
        csd,
        params["detection_threshold"],
        params["boundary_threshold"],
        params["minimum_duration"],
    )
    if spws.empty:
        return spws

    spws["duration"] = get_durations(spws)
    spws["midpoint"] = get_midpoints(spws)
    # Scale to mA/mm (the CSD is built with units "nA/mm"; 1e-6 converts nA -> mA).
    spws["sink_amplitude"] = get_sink_amplitudes(spws, times, csd) * (1e-6)
    # Scale to mA * ms (1e-6: nA -> mA; 1e3: presumably s -> ms).
    spws["sink_integral"] = get_sink_integrals(spws, times, epoch_csd.fs, csd) * (1e-6) * (1e3)
    # Every event in this epoch records the same channel list.
    spws["sr_chans"] = [sr_chans for _ in range(len(spws))]

    return spws

In [15]:
def get_file_spws(bin_path, sr_chans_path, spw_path, params_path, hpc_chans):
    """Detect sharp waves in one LF file, epoch by epoch, and store them with `store_df_h5`.

    The hippocampal LFP is converted to CSD with kCSD, using the parameters saved in
    `params_path` (`intersite_distance`, `gdx`, `lambd`, `R`) with the L-curve disabled.
    Each row of `sr_chans_path` defines an epoch (`start_time`, `end_time`) and the
    stratum radiatum channels to use for it. Sharp waves are detected separately per
    epoch, and the results are concatenated.

    Parameters
    ----------
    bin_path : path-like
        LF binary to load.
    sr_chans_path : path-like
        CSV with `start_time`, `end_time` and `sr_chans` columns. `sr_chans` holds a
        Python-literal sequence of channels, or is empty when no channels are usable.
    spw_path : path-like
        Destination passed to `store_df_h5`.
    params_path : path-like
        JSON file of detection and kCSD parameters.
    hpc_chans : sequence
        Hippocampal channels to load and compute CSD over.
    """
    hpc_lfps = load_timeseries(bin_path, hpc_chans)

    with open(params_path) as params_file:
        params = json.load(params_file)

    k = get_kcsd(
        hpc_lfps.values,
        intersite_distance=params["intersite_distance"],
        gdx=params["gdx"],
        lambd=params["lambd"],
        R_init=params["R"],
        do_lcurve=False,
    )

    hpc_csd = xr.DataArray(
        k.values("CSD"),
        dims=("channel", "time"),
        coords={"channel": hpc_lfps.channel.values, "time": hpc_lfps.time.values},
        attrs={'units': "nA/mm", 'fs': hpc_lfps.fs}
    )

    sr_chans_df = pd.read_csv(sr_chans_path)
    # Empty cells (read as NaN) mean there are no usable SR channels for that epoch.
    sr_chans_df["sr_chans"] = sr_chans_df["sr_chans"].apply(
        lambda x: [] if pd.isnull(x) else list(literal_eval(x))
    )
    spws_by_epoch = [
        get_epoch_spws(hpc_csd, params, epoch.start_time, epoch.end_time, epoch.sr_chans)
        for epoch in sr_chans_df.itertuples()
    ]
    # pd.concat raises on an empty list, so a file with no epochs yields an empty frame.
    spws = pd.concat(spws_by_epoch) if spws_by_epoch else pd.DataFrame()

    metadata = dict(
        csd_chans=hpc_chans,
        electrode_positions=k.ele_pos,
        intersite_distance=params["intersite_distance"],
        gdx=k.gdx,
        lambd=k.lambd,
        R=k.R,
        detect_states=["all"],
        file_start=hpc_lfps.fileCreateTime,
    )
    # Depending on the pandas version, pd.concat either drops DataFrame.attrs or keeps them
    # only when every input has identical, non-empty attrs. Epochs without SR channels
    # contribute attrs-less empty frames. Merge each epoch's attrs explicitly so detection
    # metadata is not lost; if epochs disagree on a key, the last epoch wins.
    for epoch_spws in spws_by_epoch:
        metadata.update(epoch_spws.attrs)
    metadata.update(spws.attrs)

    store_df_h5(spw_path, spws, **metadata)

In [16]:
def get_experiment_spws(subject, experiment):
    """Run sharp-wave detection on every LF file of an experiment (all conditions).

    Each LF binary is paired positionally with its `sr_chans.csv` (per-epoch stratum
    radiatum channels) and its `spws.h5` output path, then processed by `get_file_spws`.
    Detection parameters are read from this experiment's `sharp_wave_detection_params.json`,
    which must already exist. A timestamped line is printed after each file finishes.
    """
    hpc_chans = channel_groups.hippocampus[subject]

    def datapaths_with_ext(ext):
        # Same SGLX-style lookup for every file type; only the extension differs.
        return paths.get_sglx_style_datapaths(subject=subject, experiment=experiment, condition="all", ext=ext)

    file_triples = zip(
        datapaths_with_ext("lf.bin"),
        datapaths_with_ext("sr_chans.csv"),
        datapaths_with_ext("spws.h5"),
    )
    params_path = paths.get_datapath(file="sharp_wave_detection_params.json", subject=subject, experiment=experiment)

    for lfp_path, epochs_path, output_path in file_triples:
        get_file_spws(lfp_path, epochs_path, output_path, params_path, hpc_chans)
        stamp = datetime.now().strftime("%H:%M:%S")
        print(f"{stamp}: Finished {str(lfp_path)}")

# Run automated pipeline

In [17]:
# Detect SPWs in every LF file of Doppio's atropine experiment; each file's events are stored to its own spws.h5.
get_experiment_spws(experiment="atropine", subject="Doppio")

nChan: 385, nFileSamp: 1802362


your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['sr_chans'], dtype='object')]

  # Remove the CWD from sys.path while we load stuff.


18:12:25: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX4-Doppio/3-26-2020/3-26-2020_g0/3-26-2020_g0_imec0/3-26-2020_g0_t0.imec0.lf.bin
nChan: 385, nFileSamp: 4977007
18:14:41: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX4-Doppio/3-26-2020/3-26-2020_2_g0/3-26-2020_2_g0_imec0/3-26-2020_2_g0_t0.imec0.lf.bin
nChan: 385, nFileSamp: 18000039
18:28:06: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX4-Doppio/3-26-2020/3-26-2020_3_g0/3-26-2020_3_g0_imec0/3-26-2020_3_g0_t0.imec0.lf.bin
nChan: 385, nFileSamp: 8553813
18:32:02: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX4-Doppio/3-26-2020/3-26-2020_3_g0/3-26-2020_3_g0_imec0/3-26-2020_3_g0_t1.imec0.lf.bin
