- TODO: Parallelize sharp wave property computation

# Imports and definitions

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib widget
import matplotlib.pyplot as plt

In [3]:
import numpy as np
import pandas as pd
import xarray as xr
import json
from datetime import datetime
from ast import literal_eval

In [4]:
from ecephys_analyses.data import paths, channel_groups
from ecephys.sglx_utils import load_timeseries, load_multifile_timeseries
from ecephys.signal.csd import get_kcsd
from ecephys.signal.sharp_wave_ripples import detect_sharp_waves_by_value, detect_sharp_waves_by_zscore, get_durations, get_midpoints, get_sink_amplitudes, get_sink_integrals
from ecephys.utils import load_df_h5, store_df_h5, zscore_to_value
import ecephys.plot as eplt

## Get SPW detection parameters from early recovery sleep

In [5]:
def get_spw_detection_parameters(subject, detection_threshold_zscore=2.5, boundary_threshold_zscore=1):
    """Derive a subject's sharp-wave detection parameters and save them as JSON.

    Loads the hippocampal LFP from the subject's "recovery-sleep-2h" recording(s),
    estimates the kCSD with L-curve parameter selection, sums the negated CSD
    across the stratum radiatum channels, and converts the two z-score thresholds
    into values of that combined signal with `zscore_to_value`. The kCSD settings
    and thresholds are written to the subject's "sleep-homeostasis"
    `sharp_wave_detection_params.json`.

    Parameters
    ----------
    subject : str
        Key used to look up channel groups and data paths.
    detection_threshold_zscore : float, optional
        Z-score that is converted into the stored `detection_threshold`.
    boundary_threshold_zscore : float, optional
        Z-score that is converted into the stored `boundary_threshold`.

    Returns
    -------
    None
        The result is the JSON file written to disk.

    Raises
    ------
    FileExistsError
        If the parameter file already exists. Existing parameters are never
        overwritten; delete the file explicitly to regenerate them.
    """
    import errno
    import os

    sr_chans = channel_groups.stratum_radiatum_140um_to_200um[subject]
    hpc_chans = channel_groups.hippocampus[subject]

    bin_paths = paths.get_sglx_style_datapaths(subject=subject, condition="recovery-sleep-2h", ext="lf.bin")
    params_path = paths.get_datapath(subject=subject, condition="sleep-homeostasis", file="sharp_wave_detection_params.json")

    # The file is created in exclusive mode below, so an existing file can never be
    # overwritten. Detect that before loading data and running the L-curve
    # estimation rather than after all the work has been done.
    if params_path.exists():
        raise FileExistsError(errno.EEXIST, os.strerror(errno.EEXIST), str(params_path))

    if len(bin_paths) > 1:
        (time, hpc_lfps, fs) = load_multifile_timeseries(bin_paths, hpc_chans)
    else:
        (time, hpc_lfps, fs) = load_timeseries(bin_paths[0], hpc_chans)

    intersite_distance = 0.020  # presumably mm (20 um site pitch) -- TODO confirm units
    gdx = intersite_distance  # CSD estimated at the same spacing as the electrodes
    k = get_kcsd(
        hpc_lfps, intersite_distance=intersite_distance, gdx=gdx, do_lcurve=True
    )

    hpc_csd = k.values("CSD")
    # Keep only the rows of the CSD whose channel is in the stratum radiatum group.
    sr_csd = hpc_csd[np.isin(hpc_chans, sr_chans)]
    # Negate before summing across channels, so sinks become positive deflections.
    combined_csd = np.sum(-sr_csd.T, axis=1)

    detection_threshold = zscore_to_value(combined_csd, detection_threshold_zscore)
    boundary_threshold = zscore_to_value(combined_csd, boundary_threshold_zscore)

    metadata = dict(
        csd_chans=hpc_chans.tolist(),
        detection_chans=sr_chans,
        electrode_positions=k.ele_pos.tolist(),
        intersite_distance=intersite_distance,
        gdx=k.gdx,
        lambd=k.lambd,
        R=k.R,
        detect_states=["all"],
        detection_threshold_zscore=detection_threshold_zscore,
        boundary_threshold_zscore=boundary_threshold_zscore,
        detection_threshold=detection_threshold,
        boundary_threshold=boundary_threshold,
        minimum_duration=0.005,  # presumably seconds -- TODO confirm
        params_source_files=[str(path) for path in bin_paths],
    )

    params_path.parent.mkdir(parents=True, exist_ok=True)
    # "x" keeps the no-overwrite guarantee even if the file appeared after the check above.
    with open(params_path, "x") as params_file:
        json.dump(metadata, params_file, indent=4)

## Detect SPWs, accounting for drift

In [5]:
def get_epoch_spws(hpc_csd, params, epoch_start, epoch_end, sr_chans):
    """Detect sharp waves within one epoch using that epoch's detection channels.

    Parameters
    ----------
    hpc_csd : xarray.DataArray
        CSD with `time` and `channel` coordinates and an `fs` attribute.
    params : dict
        Must provide "detection_threshold", "boundary_threshold" and
        "minimum_duration".
    epoch_start, epoch_end
        Bounds of the label-based time slice taken from `hpc_csd`.
    sr_chans : list
        Channel labels to select for detection during this epoch.

    Returns
    -------
    pandas.DataFrame
        The frame returned by `detect_sharp_waves_by_value`. When it is not
        empty, the columns "duration", "midpoint", "sink_amplitude",
        "sink_integral" and "sr_chans" are added to it.
    """
    epoch_csd = hpc_csd.sel(time=slice(epoch_start, epoch_end), channel=sr_chans)
    times = epoch_csd.time.values
    csd_values = epoch_csd.values

    spws = detect_sharp_waves_by_value(
        times,
        csd_values,
        params["detection_threshold"],
        params["boundary_threshold"],
        params["minimum_duration"],
    )

    # Nothing detected: hand back the empty frame without the derived columns.
    if spws.empty:
        return spws

    spws["duration"] = get_durations(spws)
    spws["midpoint"] = get_midpoints(spws)

    amplitudes = get_sink_amplitudes(spws, times, csd_values)
    spws["sink_amplitude"] = amplitudes * (1e-6)  # Scale to mA/mm

    integrals = get_sink_integrals(spws, times, epoch_csd.fs, csd_values)
    spws["sink_integral"] = integrals * (1e-6) * (1e3)  # Scale to mA * ms

    # Every event in this epoch was detected on the same set of channels.
    spws["sr_chans"] = [sr_chans] * len(spws)

    return spws

In [6]:
def get_file_spws(bin_path, sr_chans_path, spw_path, params_path, hpc_chans):
    """Detect sharp waves in one recording file and store them, with metadata, as HDF5.

    The kCSD is computed with the gdx / lambda / R values stored in the params
    file (no L-curve search). Detection then runs once per row of the CSV at
    `sr_chans_path`, each row supplying `start_time`, `end_time` and the
    `sr_chans` to use for that epoch.

    Parameters
    ----------
    bin_path
        LFP file to load.
    sr_chans_path
        CSV of epochs; the `sr_chans` column holds Python literals.
    spw_path
        Destination passed to `store_df_h5`.
    params_path
        JSON file of detection parameters.
    hpc_chans
        Channels to load and to compute the CSD over.

    Raises
    ------
    ValueError
        From `pandas.concat` if the CSV contains no epochs.
    """
    lfps = load_timeseries(bin_path, hpc_chans)

    with open(params_path) as fp:
        params = json.load(fp)
    intersite_distance = params["intersite_distance"]

    kcsd = get_kcsd(
        lfps.values,
        intersite_distance=intersite_distance,
        gdx=params["gdx"],
        lambd=params["lambd"],
        R_init=params["R"],
        do_lcurve=False,
    )

    # Label the CSD so that epochs can be selected by time and by channel.
    hpc_csd = xr.DataArray(
        kcsd.values("CSD"),
        dims=("channel", "time"),
        coords={"channel": lfps.channel.values, "time": lfps.time.values},
        attrs={"units": "nA/mm", "fs": lfps.fs},
    )

    epochs = pd.read_csv(sr_chans_path, converters={"sr_chans": literal_eval})
    epochs.sr_chans = epochs.sr_chans.apply(list)
    spws = pd.concat(
        [
            get_epoch_spws(hpc_csd, params, row.start_time, row.end_time, row.sr_chans)
            for row in epochs.itertuples()
        ]
    )

    # Attributes carried by the detections are merged last, so they win on key clashes.
    metadata = {
        "csd_chans": hpc_chans,
        "electrode_positions": kcsd.ele_pos,
        "intersite_distance": intersite_distance,
        "gdx": kcsd.gdx,
        "lambd": kcsd.lambd,
        "R": kcsd.R,
        "detect_states": ["all"],
        "file_start": lfps.fileCreateTime,
        **spws.attrs,
    }

    store_df_h5(spw_path, spws, **metadata)

In [7]:
def get_condition_spws(subject, condition):
    """Run sharp-wave detection on every file of a subject's condition.

    For each LFP file, the matching `sr_chans.csv` and `spws.h5` paths are
    paired by position and passed to `get_file_spws`, along with the subject's
    "sleep-homeostasis" detection parameters. A timestamped line is printed
    after each file finishes.

    Parameters
    ----------
    subject : str
        Key used to look up channel groups and data paths.
    condition : str
        Condition whose files are processed.
    """
    hpc_chans = channel_groups.hippocampus[subject]

    bin_paths, sr_chans_paths, spw_paths = (
        paths.get_sglx_style_datapaths(subject=subject, condition=condition, ext=ext)
        for ext in ("lf.bin", "sr_chans.csv", "spws.h5")
    )
    params_path = paths.get_datapath(subject=subject, condition="sleep-homeostasis", file="sharp_wave_detection_params.json")

    per_file_paths = zip(bin_paths, sr_chans_paths, spw_paths)
    for bin_path, sr_chans_path, spw_path in per_file_paths:
        get_file_spws(bin_path, sr_chans_path, spw_path, params_path, hpc_chans)
        timestamp = datetime.now().strftime("%H:%M:%S")
        print(f"{timestamp}: Finished {str(bin_path)}")

# Run automated pipeline

### Segundo

In [16]:
# Derive Segundo's SPW thresholds (detection z=2.5, boundary z=1) and write the params JSON.
get_spw_detection_parameters(subject="Segundo", detection_threshold_zscore=2.5, boundary_threshold_zscore=1)

nChan: 385, nFileSamp: 18000001
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-11
max lambda 0.0133
min lambda 1e-11
max lambda 0.0133
l-curve (all lambda):  0.23
Best lambda and R =  0.00013089207687291224 ,  0.23


In [8]:
# Detect SPWs in every file of Segundo's "all" condition, using the stored params JSON.
get_condition_spws(subject="Segundo", condition="all")

nChan: 385, nFileSamp: 18000000


your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['sr_chans'], dtype='object')]

  # Remove the CWD from sys.path while we load stuff.


14:51:36: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX2-Segundo/1-21-2020/1-21-2020_g0/1-21-2020_g0_imec0/1-21-2020_g0_t8.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
15:04:08: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX2-Segundo/1-21-2020/1-21-2020_g0/1-21-2020_g0_imec0/1-21-2020_g0_t9.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
15:17:25: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX2-Segundo/1-21-2020/1-21-2020_g0/1-21-2020_g0_imec0/1-21-2020_g0_t10.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
15:28:50: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX2-Segundo/1-21-2020/1-21-2020_g0/1-21-2020_g0_imec0/1-21-2020_g0_t11.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
15:35:35: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX2-Segundo/1-21-2020/1-21-2020_g0/1-21-2020_g0_imec0/1-21-2020_g0_t12.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
15:48:19: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX2-Segundo/1-21-2020/1-21-2020_g0/1-21-2020

### Valentino

In [19]:
# Derive Valentino's SPW thresholds (detection z=2.5, boundary z=1) and write the params JSON.
get_spw_detection_parameters(subject="Valentino", detection_threshold_zscore=2.5, boundary_threshold_zscore=1)

nChan: 385, nFileSamp: 18000000
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0119
min lambda 1e-12
max lambda 0.0119
l-curve (all lambda):  0.23
Best lambda and R =  0.00105886675336374 ,  0.23


In [35]:
# Detect SPWs in every file of Valentino's "all" condition, using the stored params JSON.
get_condition_spws(subject="Valentino", condition="all")

nChan: 385, nFileSamp: 18000000


your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['sr_chans'], dtype='object')]

  # Remove the CWD from sys.path while we load stuff.


19:14:40: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX3-Valentino/2-19-2020/2-19-2020_g1/2-19-2020_g1_imec0/2-19-2020_g1_t0.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
19:18:55: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX3-Valentino/2-19-2020/2-19-2020_g1/2-19-2020_g1_imec0/2-19-2020_g1_t1.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
19:25:07: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX3-Valentino/2-19-2020/2-19-2020_g1/2-19-2020_g1_imec0/2-19-2020_g1_t2.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
19:31:21: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX3-Valentino/2-19-2020/2-19-2020_g1/2-19-2020_g1_imec0/2-19-2020_g1_t3.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
19:36:24: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX3-Valentino/2-19-2020/2-19-2020_g1/2-19-2020_g1_imec0/2-19-2020_g1_t4.imec0.lf.bin
nChan: 385, nFileSamp: 18000000
19:43:01: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX3-Valentino/2-19-2020/2-19-2020_g1/

### Doppio

In [None]:
# NOTE(review): this cell has no execution count or recorded output -- Doppio has not been run here.
get_spw_detection_parameters(subject="Doppio", detection_threshold_zscore=2.5, boundary_threshold_zscore=1)

### Alessandro

In [18]:
# Derive Alessandro's SPW thresholds (detection z=2.5, boundary z=1) and write the params JSON.
get_spw_detection_parameters(subject="Alessandro", detection_threshold_zscore=2.5, boundary_threshold_zscore=1)

nChan: 385, nFileSamp: 9000052
nChan: 385, nFileSamp: 9000052
You are loading multifile SGLX data without xarray.
 Are you sure you want to do this? Please see documentation.
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  3.711471263028718e-05 ,  0.23


In [9]:
# Detect SPWs in every file of Alessandro's "all" condition, using the stored params JSON.
# NOTE(review): the output recorded below ends in
# "ValueError: Wrong number of items passed 2, placement implies 1" -- confirm this has
# been resolved before relying on Alessandro's results.
get_condition_spws(subject="Alessandro", condition="all")

nChan: 385, nFileSamp: 9000051


your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->Index(['sr_chans'], dtype='object')]

  # Remove the CWD from sys.path while we load stuff.


19:09:59: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX5-Alessandro/8-24-2020/8-24-2020_g0/8-24-2020_g0_imec0/8-24-2020_g0_t0.imec0.lf.bin
nChan: 385, nFileSamp: 9000052
19:13:28: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX5-Alessandro/8-24-2020/8-24-2020_g0/8-24-2020_g0_imec0/8-24-2020_g0_t1.imec0.lf.bin
nChan: 385, nFileSamp: 9000052
19:16:05: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX5-Alessandro/8-24-2020/8-24-2020_g0/8-24-2020_g0_imec0/8-24-2020_g0_t2.imec0.lf.bin
nChan: 385, nFileSamp: 9000052
19:19:17: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX5-Alessandro/8-24-2020/8-24-2020_g0/8-24-2020_g0_imec0/8-24-2020_g0_t3.imec0.lf.bin
nChan: 385, nFileSamp: 9000051
19:21:33: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX5-Alessandro/8-24-2020/8-24-2020_g0/8-24-2020_g0_imec0/8-24-2020_g0_t4.imec0.lf.bin
nChan: 385, nFileSamp: 9000051
19:24:45: Finished /Volumes/neuropixel_archive/Data/chronic/CNPIX5-Alessandro/8-24-2020/8-24-2020_g0

ValueError: Wrong number of items passed 2, placement implies 1

### Eugene

In [34]:
# Derive Eugene's SPW thresholds (detection z=2.5, boundary z=1) and write the params JSON.
# NOTE(review): no matching get_condition_spws call for Eugene appears in this notebook.
get_spw_detection_parameters(subject="Eugene", detection_threshold_zscore=2.5, boundary_threshold_zscore=1)

nChan: 385, nFileSamp: 9000026
nChan: 385, nFileSamp: 9000026
You are loading multifile SGLX data without xarray.
 Are you sure you want to do this? Please see documentation.
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-11
max lambda 0.0140
min lambda 1e-11
max lambda 0.0140
l-curve (all lambda):  0.23
Best lambda and R =  0.001404472887665455 ,  0.23


# Run pipeline piecemeal

In [10]:
# Subject and condition used by the path lookups in the piecemeal cells below.
subject = "Alessandro"
condition = "all"

## Load the data

In [19]:
# Hippocampal channel group for this subject.
hpc_chans = channel_groups.hippocampus[subject]
# Per-file path lists, indexed together by file number: LFP input, per-epoch
# detection-channel table, and SPW output file.
bin_paths = paths.get_sglx_style_datapaths(subject=subject, condition=condition, ext="lf.bin")
sr_chans_paths = paths.get_sglx_style_datapaths(subject=subject, condition=condition, ext="sr_chans.csv")
spw_paths = paths.get_sglx_style_datapaths(subject=subject, condition=condition, ext="spws.h5")
# Detection parameters are shared across conditions and always live under "sleep-homeostasis".
params_path = paths.get_datapath(subject=subject, condition="sleep-homeostasis", file="sharp_wave_detection_params.json")

In [21]:
# Pick a single file by its position in the per-file path lists.
filenum = 9
bin_path = bin_paths[filenum]
sr_chans_path = sr_chans_paths[filenum]
spw_path = spw_paths[filenum]

In [25]:
# Load the selected file for the hippocampal channels. Later cells read .values, .time,
# .channel, .fs and .fileCreateTime from the returned object.
hpc_lfps = load_timeseries(bin_path, hpc_chans)

nChan: 385, nFileSamp: 9000052


## Detect sharp waves

### If we need to determine detection parameters

In [8]:
# Estimate the kCSD for this file, letting the L-curve search choose lambda and R.
intersite_distance = 0.020  # presumably mm (20 um site pitch) -- TODO confirm units
# Detection channels for parameter estimation: the subject's fixed stratum radiatum group,
# as used by get_spw_detection_parameters. Defined here so the cell does not depend on
# `sr_chans` left over from another cell.
sr_chans = channel_groups.stratum_radiatum_140um_to_200um[subject]
# Pass the raw array, as every other get_kcsd call on `hpc_lfps` in this notebook does.
k = get_kcsd(hpc_lfps.values, intersite_distance=intersite_distance, gdx=0.020, do_lcurve=True)
hpc_csd = k.values('CSD')
# Keep only the rows of the CSD whose channel is in the detection group.
sr_csd = hpc_csd[np.isin(hpc_chans, sr_chans)]

Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  0.0003822395851068327 ,  0.23


In [32]:
# Sample times come from the loaded LFP; no standalone `time` variable exists at notebook scope.
spws = detect_sharp_waves_by_zscore(hpc_lfps.time.values, sr_csd)

### If we are using detection parameters obtained elsewhere

In [26]:
# Load the previously saved detection parameters.
with open(params_path) as params_file:
    params = json.load(params_file)

intersite_distance = params["intersite_distance"]
# Recompute the kCSD with the stored gdx / lambda / R instead of re-running the L-curve search.
k = get_kcsd(
    hpc_lfps.values,
    intersite_distance=params["intersite_distance"],
    gdx=params["gdx"],
    lambd=params["lambd"],
    R_init=params["R"],
    do_lcurve=False,
)

# Label the CSD so that epochs can be selected by time and by channel.
hpc_csd = xr.DataArray(
    k.values("CSD"),
    dims=("channel", "time"),
    coords={"channel": hpc_lfps.channel.values, "time": hpc_lfps.time.values},
    attrs={'units': "nA/mm", 'fs': hpc_lfps.fs}
) 

In [57]:
# Run detection once per epoch listed in the detection-channel CSV, then combine the results.
spws_by_epoch = list()
# The sr_chans column is stored as a Python literal; parse it, then coerce to a list.
sr_chans_df = pd.read_csv(sr_chans_path, converters={"sr_chans": literal_eval})
sr_chans_df.sr_chans = sr_chans_df.sr_chans.apply(list)
for epoch in sr_chans_df.itertuples():
    spws_by_epoch.append(get_epoch_spws(hpc_csd, params, epoch.start_time, epoch.end_time, epoch.sr_chans))

# NOTE(review): each epoch's frame keeps its own index, so labels may repeat after concat -- confirm intended.
spws = pd.concat(spws_by_epoch)

In [49]:
# Step through a single epoch by hand; this mirrors the body of get_epoch_spws.
epoch_num = 1
epoch_start = sr_chans_df.start_time[epoch_num]
epoch_end = sr_chans_df.end_time[epoch_num]
sr_chans = sr_chans_df.sr_chans[epoch_num]

# Restrict the CSD to this epoch's time window and detection channels.
sr_csd = hpc_csd.sel(time=slice(epoch_start, epoch_end), channel=sr_chans)

# NOTE: this overwrites any `spws` produced by the all-epochs cell.
spws = detect_sharp_waves_by_value(
    sr_csd.time.values,
    sr_csd.values,
    params["detection_threshold"],
    params["boundary_threshold"],
    params["minimum_duration"],
)

# Derived columns are only added when at least one event was detected.
if not spws.empty:
    spws["duration"] = get_durations(spws)
    spws["midpoint"] = get_midpoints(spws)
    spws["sink_amplitude"] = get_sink_amplitudes(spws, sr_csd.time.values, sr_csd.values) * (
        1e-6
    )  # Scale to mA/mm
    spws["sink_integral"] = (
        get_sink_integrals(spws, sr_csd.time.values, sr_csd.fs, sr_csd.values) * (1e-6) * (1e3)
    )  # Scale to mA * ms

    # Record which channels were used for every event in this epoch.
    spws["sr_chans"] = [sr_chans] * len(spws)

## Export results

In [16]:
# Metadata stored alongside the detections: the kCSD configuration and the file start time.
metadata = dict(
    csd_chans=hpc_chans,
    electrode_positions=k.ele_pos,
    intersite_distance=intersite_distance,
    gdx=k.gdx,
    lambd=k.lambd,
    R=k.R,
    detect_states=["all"],
    file_start=hpc_lfps.fileCreateTime,
)
# Attributes attached to the detections frame take precedence on key clashes.
metadata.update(spws.attrs)

store_df_h5(spw_path, spws, **metadata)

## If necessary, create params file

In [85]:
# Record which recording the parameters were derived from.
# NOTE(review): get_spw_detection_parameters writes this as a list under
# 'params_source_files' (plural) -- confirm which key readers expect.
metadata.update({'params_source_file': str(bin_path)})
# Convert to plain lists so json can serialize them. This mutates `metadata` in place,
# so re-running the cell fails on .tolist().
metadata['csd_chans'] = metadata['csd_chans'].tolist()
metadata['electrode_positions'] = metadata['electrode_positions'].tolist()
# 'x' is exclusive creation: raises FileExistsError rather than overwriting an existing params file.
# NOTE(review): get_epoch_spws reads "detection_threshold", "boundary_threshold" and
# "minimum_duration" from this file; here they can only come from spws.attrs -- confirm they are present.
with open(params_path, 'x') as params_file:
    json.dump(metadata, params_file, indent=4)