- TODO: Parallelize the sharp-wave property computations (durations, midpoints, sink amplitudes and sink integrals).

# Imports and pipeline definition

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib widget
import matplotlib.pyplot as plt

In [3]:
import numpy as np
import pandas as pd
import json

In [4]:
from ecephys.data import paths, channel_groups
from ecephys.sglx_utils import load_timeseries
from ecephys.signal.csd import get_kcsd
from ecephys.signal.sharp_wave_ripples import detect_sharp_waves_by_value, detect_sharp_waves_by_zscore, get_durations, get_midpoints, get_sink_amplitudes, get_sink_integrals
from ecephys.utils import load_df_h5, store_df_h5, zscore_to_value
import ecephys.plot as eplt

In [7]:
def run_spw_detection_pipeline(
    subject, condition, use_existing_params, params_condition="REC-30+18"
):
    """Detect sharp waves (SPWs) in one recording and store them to sharp_waves.h5.

    Loads the hippocampal LFP channels of `subject`/`condition`, estimates the
    CSD with kCSD, detects SPWs on the stratum radiatum CSD rows, adds
    per-event properties, and writes the events plus metadata with
    `store_df_h5`.

    Parameters
    ----------
    subject : str
        Subject name, used to look up channel groups and data paths.
    condition : str
        Condition of the recording whose ``lf.bin`` is analysed.
    use_existing_params : bool
        True: read kCSD and detection parameters from the params JSON file
        and detect SPWs with those fixed thresholds.
        False: fit kCSD parameters by L-curve, detect SPWs by z-score, and
        write the resulting parameters to a new params JSON file.
    params_condition : str, optional
        Condition under which the params JSON file lives. Defaults to
        "REC-30+18", the location every existing call uses.

    Raises
    ------
    FileNotFoundError
        From `open` if `use_existing_params` is True and the params file is
        missing (checked before the LFP is loaded).
    FileExistsError
        If `use_existing_params` is False and the params file already exists.
        sharp_waves.h5 has already been written at that point.
    """
    sr_chans = channel_groups.stratum_radiatum_140um_to_200um[subject]
    hpc_chans = channel_groups.hippocampus[subject]
    bin_path = paths.get_datapath(subject=subject, condition=condition, data="lf.bin")
    params_path = paths.get_datapath(
        subject=subject,
        condition=params_condition,
        data="sharp_wave_detection_params.json",
    )

    # Read the params before the (large) LFP load so a missing or malformed
    # params file fails immediately rather than after the data are in memory.
    params = None
    if use_existing_params:
        with open(params_path) as params_file:
            params = json.load(params_file)

    (time, hpc_lfps, fs) = load_timeseries(
        bin_path, hpc_chans, start_time=None, end_time=None
    )

    if use_existing_params:
        intersite_distance = params["intersite_distance"]
        k = get_kcsd(
            hpc_lfps,
            intersite_distance=intersite_distance,
            gdx=params["gdx"],
            lambd=params["lambd"],
            R_init=params["R"],
            do_lcurve=False,
        )
    else:
        # Same value for electrode spacing and CSD grid step, in whatever
        # units get_kcsd expects (presumably mm, i.e. 20 um) — TODO confirm.
        gdx = intersite_distance = 0.020
        # Let kCSD pick lambda and R by L-curve search.
        k = get_kcsd(
            hpc_lfps, intersite_distance=intersite_distance, gdx=gdx, do_lcurve=True
        )

    hpc_csd = k.values("CSD")
    # Keep only the stratum radiatum rows; assumes CSD rows follow the order
    # of hpc_chans — TODO confirm against get_kcsd.
    sr_csd = hpc_csd[np.isin(hpc_chans, sr_chans)]

    if use_existing_params:
        spws = detect_sharp_waves_by_value(
            time,
            sr_csd,
            params["detection_threshold"],
            params["boundary_threshold"],
            params["minimum_duration"],
        )
    else:
        spws = detect_sharp_waves_by_zscore(time, sr_csd)

    spws["duration"] = get_durations(spws)
    spws["midpoint"] = get_midpoints(spws)
    spws["sink_amplitude"] = get_sink_amplitudes(spws, time, sr_csd) * (
        1e-6
    )  # Scale to mA/mm
    spws["sink_integral"] = (
        get_sink_integrals(spws, time, fs, sr_csd) * (1e-6) * (1e3)
    )  # Scale to mA * ms

    metadata = dict(
        csd_chans=hpc_chans,
        detection_chans=sr_chans,
        electrode_positions=k.ele_pos,
        intersite_distance=intersite_distance,
        gdx=k.gdx,
        lambd=k.lambd,
        R=k.R,
        detect_states=["all"],
    )
    # Detection attributes attached by the detector (e.g. thresholds) are
    # stored alongside the kCSD settings.
    metadata.update(spws.attrs)
    spws_path = paths.get_datapath(
        subject=subject, condition=condition, data="sharp_waves.h5"
    )
    store_df_h5(spws_path, spws, **metadata)

    if not use_existing_params:
        # Serialize a copy so `metadata` (already stored above) stays untouched;
        # numpy arrays/scalars are converted by _json_default.
        params_out = dict(metadata, params_source_file=str(bin_path))
        params_json = json.dumps(params_out, indent=4, default=_json_default)
        # Encode before opening: mode "x" refuses to overwrite, so a failure
        # mid-dump would otherwise leave a truncated file blocking any re-run.
        with open(params_path, "x") as params_file:
            params_file.write(params_json)


def _json_default(obj):
    """Fallback encoder for `json`: turn numpy arrays/scalars into Python values."""
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

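# Run automated pipeline

For each subject, the first call (`use_existing_params=False`) fits the kCSD and detection parameters and writes the subject's params file. The remaining calls reuse those parameters.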

### Segundo

In [16]:
# Fit kCSD/detection parameters on REC-0+2 and write Segundo's params file;
# the later calls reuse them. Re-running this first call raises
# FileExistsError once the params file exists (it is opened with mode "x").
run_spw_detection_pipeline(subject="Segundo", condition="REC-0+2", use_existing_params=False)

nChan: 385, nFileSamp: 18000001
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-11
max lambda 0.0133
min lambda 1e-11
max lambda 0.0133
l-curve (all lambda):  0.23
Best lambda and R =  0.00013089207687291224 ,  0.23


In [17]:
run_spw_detection_pipeline(subject="Segundo", condition="REC-2-0", use_existing_params=True)

nChan: 385, nFileSamp: 18000001


In [27]:
run_spw_detection_pipeline(subject="Segundo", condition="REC-4-2", use_existing_params=True)

nChan: 385, nFileSamp: 18000001


In [28]:
run_spw_detection_pipeline(subject="Segundo", condition="REC-6-4", use_existing_params=True)

nChan: 385, nFileSamp: 18000001


In [18]:
run_spw_detection_pipeline(subject="Segundo", condition="BSL-0+2", use_existing_params=True)

nChan: 385, nFileSamp: 18000000


### Valentino

In [19]:
# Fit kCSD/detection parameters on REC-0+2 and write Valentino's params file;
# the later calls reuse them. Re-running this first call raises
# FileExistsError once the params file exists (it is opened with mode "x").
run_spw_detection_pipeline(subject="Valentino", condition="REC-0+2", use_existing_params=False)

nChan: 385, nFileSamp: 18000000
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0119
min lambda 1e-12
max lambda 0.0119
l-curve (all lambda):  0.23
Best lambda and R =  0.00105886675336374 ,  0.23


In [20]:
run_spw_detection_pipeline(subject="Valentino", condition="REC-2-0", use_existing_params=True)

nChan: 385, nFileSamp: 18000000


In [25]:
run_spw_detection_pipeline(subject="Valentino", condition="REC-4-2", use_existing_params=True)

nChan: 385, nFileSamp: 18000000


In [26]:
run_spw_detection_pipeline(subject="Valentino", condition="REC-6-4", use_existing_params=True)

nChan: 385, nFileSamp: 18000000


In [24]:
run_spw_detection_pipeline(subject="Valentino", condition="BSL-0+2", use_existing_params=True)

nChan: 385, nFileSamp: 18000000


### Doppio

In [None]:
run_spw_detection_pipeline(subject="Doppio", condition="REC-0+2", use_existing_params=False)

In [None]:
run_spw_detection_pipeline(subject="Doppio", condition="REC-2-0", use_existing_params=True)

In [None]:
run_spw_detection_pipeline(subject="Doppio", condition="REC-4-2", use_existing_params=True)

In [None]:
run_spw_detection_pipeline(subject="Doppio", condition="BSL-6-4", use_existing_params=True)

In [None]:
run_spw_detection_pipeline(subject="Doppio", condition="BSL-0+2", use_existing_params=True)

### Alessandro

In [24]:
# Fit kCSD/detection parameters on REC-0+1 and write Alessandro's params file;
# the later calls reuse them. Re-running this first call raises
# FileExistsError once the params file exists (it is opened with mode "x").
run_spw_detection_pipeline(subject="Alessandro", condition="REC-0+1", use_existing_params=False)

nChan: 385, nFileSamp: 9000052
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  0.00011910798611831335 ,  0.23


In [9]:
run_spw_detection_pipeline(subject="Alessandro", condition="REC-1-0", use_existing_params=True)

nChan: 385, nFileSamp: 9000052


In [10]:
run_spw_detection_pipeline(subject="Alessandro", condition="BSL-0+1", use_existing_params=True)

nChan: 385, nFileSamp: 9000052


### Eugene

In [12]:
# Fit kCSD/detection parameters on REC-0+2 and write Eugene's params file;
# the BSL call reuses them. Re-running this first call raises
# FileExistsError once the params file exists (it is opened with mode "x").
run_spw_detection_pipeline(subject="Eugene", condition="REC-0+2", use_existing_params=False)

nChan: 385, nFileSamp: 9000026
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-11
max lambda 0.0140
min lambda 1e-11
max lambda 0.0140
l-curve (all lambda):  0.23
Best lambda and R =  0.001404472887665455 ,  0.23


In [14]:
run_spw_detection_pipeline(subject="Eugene", condition="BSL-0+2", use_existing_params=True)

nChan: 385, nFileSamp: 9000025


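# Run pipeline piecemeal

A step-by-step version of `run_spw_detection_pipeline`, useful for inspecting intermediate results.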

In [7]:
# Recording to process step by step in the cells below.
SUBJECT, CONDITION = "Doppio", "BSL-0+2"

## Load the data

In [8]:
# Channel groups and file locations for the selected recording.
hpc_chans = channel_groups.hippocampus[SUBJECT]
sr_chans = channel_groups.stratum_radiatum_140um_to_200um[SUBJECT]
bin_path = paths.get_datapath(subject=SUBJECT, condition=CONDITION, data="lf.bin")
params_path = paths.get_datapath(
    subject=SUBJECT,
    condition="REC-30+18",
    data="sharp_wave_detection_params.json",
)

In [9]:
# Load the hippocampal LFP channels: sample times, traces and sampling rate.
time, hpc_lfps, fs = load_timeseries(bin_path, hpc_chans)

nChan: 385, nFileSamp: 18000019


### Explore LFPs (optional)

In [None]:
# Optional visual check of the loaded hippocampal LFPs.
eplt.lfp_explorer(
    time,
    hpc_lfps,
    chan_labels=hpc_chans,
)

## Detect sharp waves

### Option A: determine detection parameters from this recording

In [8]:
# Fit the kCSD regularisation (lambda, R) by L-curve, using 0.020 for both the
# electrode spacing and the CSD grid step.
intersite_distance = 0.020
k = get_kcsd(
    hpc_lfps,
    intersite_distance=intersite_distance,
    gdx=0.020,
    do_lcurve=True,
)
hpc_csd = k.values("CSD")
sr_csd = hpc_csd[np.isin(hpc_chans, sr_chans)]

Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  0.0003822395851068327 ,  0.23


In [32]:
# Detect SPWs on the stratum radiatum CSD with z-score based thresholds.
spws = detect_sharp_waves_by_zscore(
    time,
    sr_csd,
)

### Option B: reuse detection parameters from the params file

In [10]:
# Load the kCSD and detection parameters saved by an earlier L-curve fit.
with open(params_path) as params_file:
    params = json.load(params_file)
intersite_distance = params["intersite_distance"]

In [11]:
# Recompute the CSD with the stored kCSD settings (no L-curve search).
k = get_kcsd(
    hpc_lfps,
    intersite_distance=params["intersite_distance"],
    gdx=params["gdx"],
    lambd=params["lambd"],
    R_init=params["R"],
    do_lcurve=False,
)
hpc_csd = k.values("CSD")
sr_csd = hpc_csd[np.isin(hpc_chans, sr_chans)]

In [12]:
# Detect SPWs with the fixed thresholds stored in the params file.
spws = detect_sharp_waves_by_value(
    time,
    sr_csd,
    params["detection_threshold"],
    params["boundary_threshold"],
    params["minimum_duration"],
)

## Compute SPW properties

In [13]:
# Per-event timing columns; "duration" is added before "midpoint" is computed.
for prop_name, prop_fn in (("duration", get_durations), ("midpoint", get_midpoints)):
    spws[prop_name] = prop_fn(spws)

In [14]:
# Sink amplitude of each event, scaled by 1e-6 to mA/mm.
spws["sink_amplitude"] = 1e-6 * get_sink_amplitudes(spws, time, sr_csd)

In [15]:
# Sink integral of each event, scaled to mA * ms.
spws["sink_integral"] = (
    get_sink_integrals(spws, time, fs, sr_csd)
    * 1e-6
    * 1e3
)

## Export results

In [16]:
# kCSD settings, channel sets and detector attributes saved with the events.
metadata = {
    "csd_chans": hpc_chans,
    "detection_chans": sr_chans,
    "electrode_positions": k.ele_pos,
    "intersite_distance": intersite_distance,
    "gdx": k.gdx,
    "lambd": k.lambd,
    "R": k.R,
    "detect_states": ["all"],
    **spws.attrs,
}
spws_path = paths.get_datapath(
    subject=SUBJECT, condition=CONDITION, data="sharp_waves.h5"
)
store_df_h5(spws_path, spws, **metadata)

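## If necessary, create the params file

Only needed after Option A. The file is opened with mode `"x"`, so this cell fails if the params file already exists.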

In [85]:
# Build the JSON-ready params from a copy of `metadata`: numpy arrays/scalars
# become plain Python values, the dict stored to sharp_waves.h5 is left
# untouched, and the cell no longer breaks when re-run.
params_to_save = {
    key: value.tolist() if isinstance(value, (np.ndarray, np.generic)) else value
    for key, value in metadata.items()
}
params_to_save["params_source_file"] = str(bin_path)
# Encode before opening: mode "x" refuses to overwrite, so a failure inside
# the dump would otherwise leave a truncated file that blocks later attempts.
params_json = json.dumps(params_to_save, indent=4)
with open(params_path, "x") as params_file:
    params_file.write(params_json)