In [1]:
# Reload edited modules (e.g. the project's ecephys / ecephys_analyses code) before each cell runs.
%load_ext autoreload
%autoreload 2

# Imports and definitions

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import xarray as xr
from scipy.stats import zscore

In [3]:
from ripple_detection.core import gaussian_smooth, segment_boolean_series

In [4]:
from ecephys_analyses.data import paths
from ecephys.utils import load_df_h5, add_attrs, get_disjoint_interval_intersections, get_interval_complements
from ecephys.scoring import get_separated_wake_hypnogram, write_visbrain_hypnogram

In [5]:
# Tuning parameters shared by the automated and piecemeal pipelines.
SMOOTHING_SIGMA = 4        # Gaussian sigma forwarded to gaussian_smooth (units of the `time` coord — TODO confirm seconds)
TD_THRESHOLD_ZSCORE = 1    # cut-off on the z-score of log(smoothed theta/delta) used to derive its threshold
EMG_THRESHOLD_ZSCORE = 0   # cut-off on the z-score of smoothed EMG used to derive its threshold
MINIMUM_QWK_DURATION = 10  # qwk intervals must be strictly longer than this to be kept

In [6]:
def my_gaussian_smooth(da, smoothing_sigma=4):
    """Gaussian-smooth a signal that carries a ``time`` coordinate.

    No sampling rate is passed in, so it is estimated as the reciprocal of the
    mean spacing between consecutive ``da.time`` values.

    Parameters
    ----------
    da : object with a ``time`` attribute (e.g. an xarray DataArray)
        Signal to smooth; handed unchanged to ``gaussian_smooth``.
    smoothing_sigma : float, default 4
        Kernel width forwarded to ``gaussian_smooth`` alongside the estimated
        sampling rate (presumably in the same units as ``da.time`` — TODO confirm).

    Returns
    -------
    Whatever ``gaussian_smooth`` returns for ``da``.
    """
    sample_spacing = np.mean(np.diff(da.time.values))
    return gaussian_smooth(da, smoothing_sigma, 1 / sample_spacing)

In [7]:
def below_threshold_intervals(sig, threshold, time):
    """Return the contiguous spans of ``time`` where ``sig`` is strictly below ``threshold``.

    Parameters
    ----------
    sig : array-like
        Signal values, one per entry of ``time``.
    threshold : scalar
        Samples with ``sig < threshold`` count as below threshold.
    time : array-like
        Sample times, used as the index of the boolean series.

    Returns
    -------
    The output of ``segment_boolean_series`` with ``minimum_duration=0``, so every
    below-threshold run is kept regardless of its length.
    """
    is_below = sig < threshold
    return segment_boolean_series(pd.Series(is_below, index=time), minimum_duration=0)

def above_threshold_intervals(sig, threshold, time):
    """Return the contiguous spans of ``time`` where ``sig`` is strictly above ``threshold``.

    Counterpart of ``below_threshold_intervals`` with the comparison reversed.

    Parameters
    ----------
    sig : array-like
        Signal values, one per entry of ``time``.
    threshold : scalar
        Samples with ``sig > threshold`` count as above threshold.
    time : array-like
        Sample times, used as the index of the boolean series.

    Returns
    -------
    The output of ``segment_boolean_series`` with ``minimum_duration=0``, so every
    above-threshold run is kept regardless of its length.
    """
    # Named for what it actually holds (the original reused the "below" name here).
    is_above_threshold = pd.Series(sig > threshold, index=time)
    return segment_boolean_series(is_above_threshold, minimum_duration=0)

In [13]:
def _below_zscore_threshold_intervals(sig, time, threshold_zscore, log_transform=False):
    """Smooth ``sig`` and return the intervals where it lies below a z-score-derived threshold.

    The threshold is the largest smoothed value whose z-score is strictly below
    ``threshold_zscore``. When ``log_transform`` is True, the z-scores are computed
    on ``np.log`` of the smoothed values.
    """
    sig_smooth = my_gaussian_smooth(sig, smoothing_sigma=SMOOTHING_SIGMA)
    scores = zscore(np.log(sig_smooth) if log_transform else sig_smooth)
    threshold = sig_smooth[scores < threshold_zscore].max()
    return below_threshold_intervals(sig_smooth, threshold, time)


def run_classification_pipeline_on_file(bandpower_path, emg_path, hypnogram_path):
    """Split one recording into qwk / awk intervals and write the resulting hypnogram.

    qwk intervals (presumably "quiet wake") are the intersections of the
    low-theta/delta and low-EMG intervals that are strictly longer than
    ``MINIMUM_QWK_DURATION``. awk intervals are their complement on
    ``[0, recording_length]``.

    Parameters
    ----------
    bandpower_path : path-like
        File opened with ``xr.open_dataset``; must provide ``sr_theta``,
        ``sr_delta`` and a ``time`` coordinate.
    emg_path : path-like
        File opened with ``xr.open_dataarray``; must provide a ``time`` coordinate.
    hypnogram_path : path-like
        Destination handed to ``write_visbrain_hypnogram``.
    """
    # Context managers close both files even if any step below raises.
    with xr.open_dataset(bandpower_path) as lfp_bandpower, xr.open_dataarray(emg_path) as emg:
        tds_below_threshold_intervals = _below_zscore_threshold_intervals(
            lfp_bandpower.sr_theta / lfp_bandpower.sr_delta,
            lfp_bandpower.time.values,
            TD_THRESHOLD_ZSCORE,
            log_transform=True,
        )
        emgs_below_threshold_intervals = _below_zscore_threshold_intervals(
            emg, emg.time.values, EMG_THRESHOLD_ZSCORE
        )
        recording_length = np.max([lfp_bandpower.time.values.max(), emg.time.values.max()])

    # reshape(-1, 2) keeps an empty intersection 2-D so the column indexing below is valid.
    qwk_intervals = np.asarray(
        get_disjoint_interval_intersections(emgs_below_threshold_intervals, tds_below_threshold_intervals)
    ).reshape(-1, 2)
    qwk_durations = qwk_intervals[:, 1] - qwk_intervals[:, 0]
    qwk_intervals = qwk_intervals[qwk_durations > MINIMUM_QWK_DURATION]
    # NOTE(review): assumes the recording starts at t=0 — confirm against the time coordinates.
    awk_intervals = get_interval_complements(qwk_intervals, 0, recording_length)

    hypnogram = get_separated_wake_hypnogram(qwk_intervals, awk_intervals)
    write_visbrain_hypnogram(hypnogram, hypnogram_path)

In [14]:
def get_wake_classification(subject, experiment="novel_objects_deprivation", condition="extended-wake"):
    """Run ``run_classification_pipeline_on_file`` on every recording of ``subject``.

    Each recording's ``pow.nc`` and ``emg.nc`` files are classified, and the
    hypnogram is written to the matching ``hypnogram.txt`` path. A timestamped
    line is printed after each file.

    Parameters
    ----------
    subject : str
        Subject name passed to ``paths.get_sglx_style_datapaths``.
    experiment : str, default "novel_objects_deprivation"
        Experiment name passed to ``paths.get_sglx_style_datapaths``.
    condition : str, default "extended-wake"
        Condition name passed to ``paths.get_sglx_style_datapaths``.

    Raises
    ------
    ValueError
        If the three path lists differ in length. Without this check, ``zip``
        would silently drop recordings and could pair files from different
        recordings.
    """
    def _datapaths(ext):
        return list(paths.get_sglx_style_datapaths(subject=subject, experiment=experiment, condition=condition, ext=ext))

    bandpower_paths = _datapaths("pow.nc")
    emg_paths = _datapaths("emg.nc")
    hypnogram_paths = _datapaths("hypnogram.txt")

    if not len(bandpower_paths) == len(emg_paths) == len(hypnogram_paths):
        raise ValueError(
            f"Mismatched datapath counts for {subject}: {len(bandpower_paths)} pow.nc, "
            f"{len(emg_paths)} emg.nc, {len(hypnogram_paths)} hypnogram.txt"
        )

    for bandpower_path, emg_path, hypnogram_path in zip(bandpower_paths, emg_paths, hypnogram_paths):
        run_classification_pipeline_on_file(bandpower_path, emg_path, hypnogram_path)
        current_time = datetime.now().strftime("%H:%M:%S")
        print(f"{current_time}: Finished {str(bandpower_path.stem)}")

# Run the automated pipeline on every extended-wake recording of a subject

In [15]:
# Classify and write hypnograms for all of this subject's recordings.
get_wake_classification(subject="Allan")

14:24:15: Finished 3-2-2021_A_g0_t0.imec1.pow
14:24:16: Finished 3-2-2021_B_g0_t0.imec1.pow
14:24:17: Finished 3-2-2021_B_g0_t1.imec1.pow
14:24:19: Finished 3-2-2021_C_g0_t0.imec1.pow
14:24:20: Finished 3-2-2021_D_g0_t0.imec1.pow
14:24:21: Finished 3-2-2021_D_g0_t1.imec1.pow
14:24:22: Finished 3-2-2021_D_g0_t2.imec1.pow
14:24:23: Finished 3-2-2021_E_g0_t0.imec1.pow


# Run the pipeline piecemeal, one stage per cell, to inspect and plot intermediate results

In [121]:
%matplotlib widget
import matplotlib.pyplot as plt
from neurodsp.plts.time_series import plot_time_series
import ecephys.plot as eplt

In [122]:
# Recording to step through below.
SUBJECT, CONDITION = "Segundo", "REC-6-4"

In [124]:
lfp_bandpower = xr.open_dataset(paths.get_datapath_from_csv(subject=SUBJECT, condition=CONDITION, data="lfp_bandpower.nc"))
# A DataArray selected from a Dataset (`ds.emg`) presumably does not carry the Dataset's
# close handle (TODO confirm for the installed xarray), so a later `emg.close()` would
# leave the file open. Load the variable into memory and close the Dataset right away.
with xr.open_dataset(paths.get_datapath_from_csv(subject=SUBJECT, condition=CONDITION, data="emg.nc")) as emg_dataset:
    emg = emg_dataset.emg.load()

In [110]:
# Smoothed theta/delta ratio. Its threshold is the largest smoothed value whose
# z-scored log lies below TD_THRESHOLD_ZSCORE.
td_ratio = lfp_bandpower.sr_theta / lfp_bandpower.sr_delta
td_smooth = my_gaussian_smooth(td_ratio, smoothing_sigma=SMOOTHING_SIGMA)
td_log_zscores = zscore(np.log(td_smooth))
tds_threshold = td_smooth[td_log_zscores < TD_THRESHOLD_ZSCORE].max()
tds_below_threshold_intervals = below_threshold_intervals(
    td_smooth, tds_threshold, lfp_bandpower.time.values
)

In [111]:
# Smoothed EMG. Its threshold is the largest smoothed value whose z-score lies
# below EMG_THRESHOLD_ZSCORE.
emg_smooth = my_gaussian_smooth(emg, smoothing_sigma=SMOOTHING_SIGMA)
emg_zscores = zscore(emg_smooth)
emgs_threshold = emg_smooth[emg_zscores < EMG_THRESHOLD_ZSCORE].max()
emgs_below_threshold_intervals = below_threshold_intervals(
    emg_smooth, emgs_threshold, emg.time.values
)

In [112]:
# End of the longer of the two signals.
recording_length = np.max([lfp_bandpower.time.values.max(), emg.time.values.max()])

# qwk candidates: spans where both EMG and theta/delta are below threshold,
# kept only if strictly longer than MINIMUM_QWK_DURATION.
qwk_intervals = np.asarray(
    get_disjoint_interval_intersections(emgs_below_threshold_intervals, tds_below_threshold_intervals)
)
qwk_durations = np.asarray([stop - start for (start, stop) in qwk_intervals])
is_long_enough = qwk_durations > MINIMUM_QWK_DURATION
qwk_intervals = qwk_intervals[is_long_enough]

# awk: everything on [0, recording_length] not covered by the kept qwk spans.
awk_intervals = get_interval_complements(qwk_intervals, 0, recording_length)

In [113]:
# Build the hypnogram from the kept qwk intervals and their awk complement.
hypnogram = get_separated_wake_hypnogram(qwk_intervals, awk_intervals)

In [102]:
# NOTE(review): this cell ran as In [102], before the cells that compute `hypnogram`
# (In [110]–[113]); re-run it after them so the saved file matches the current result.
# NOTE(review): inputs were located with `get_datapath_from_csv`, output with
# `get_datapath` — confirm both resolve to the same recording.
hypnogram_path = paths.get_datapath(subject=SUBJECT, condition=CONDITION, data="hypnogram.txt")
write_visbrain_hypnogram(hypnogram, hypnogram_path)

In [114]:
# sharex=True keeps both panels' x-limits linked, including while zooming/panning
# the interactive widget (a one-time set_xlim copy does not stay in sync).
fig, axes = plt.subplots(2, 1, figsize=(20, 4), sharex=True)
plot_time_series(lfp_bandpower.time.values, td_smooth, ax=axes[0], title="Theta / Delta ratio, smoothed.", ylabel="Ratio")
plot_time_series(emg.time.values, emg_smooth, ax=axes[1], title="Buzsaki-style EMG, smoothed", ylabel="Corr")
# Overlay the hypnogram on every panel.
for ax in axes:
    eplt.plot_hypnogram_overlay(hypnogram, ax=ax)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [115]:
# Close the objects opened in the loading cell.
for opened in (lfp_bandpower, emg):
    opened.close()