In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
from kcsd import KCSD1D

In [43]:
from ipywidgets import interact, fixed

In [44]:
from ecephys.graham import channel_groups, paths
from ecephys.helpers.utils import load_df_h5
from ecephys.sglx_utils import load_timeseries
from ecephys.signal.ripples import apply_ripple_filter
from ecephys.signal.timefrequency import get_perievent_cwtm

In [63]:
condition = "SR"
subject = "Eugene"

# Open read-only: HDFStore defaults to mode="a", which silently creates an empty
# file when the path is wrong and opens an existing file for writing.
ripple_path = Path(paths.ripples[condition][subject])
with pd.HDFStore(ripple_path, mode="r") as store:
    ripples, metadata = load_df_h5(store)

# The explorer below builds a slider over 1..len(ripples); fail clearly if there is nothing to show.
assert not ripples.empty, f"No ripples found in {ripple_path}"

In [69]:
@interact(ripples=fixed(ripples), metadata=fixed(metadata), subject=fixed(subject), condition=fixed(condition), window_length=fixed(1), ripple_number=(1, len(ripples)), ax=fixed(None))
def ripple_explorer(ripples, metadata, subject, condition, window_length=1, ripple_number=1, ax=None,):
    """Plot a window of LFP around one detected ripple in four stacked panels.

    Panels: (0) ripple-filtered detection channels, (1) raw detection channels
    stacked with a vertical offset, (2) frequency-scaled wavelet spectrogram,
    (3) kCSD estimate restricted to hippocampal depths. Every detected ripple
    visible in the window is shaded on panels 0 and 1.

    Parameters
    ----------
    ripples : pandas.DataFrame
        Detected ripples with ``center_time``, ``start_time`` and ``end_time`` columns.
    metadata : mapping
        Detection metadata; ``metadata["chans"]`` lists the channels used for detection.
    subject, condition : str
        Keys into ``paths`` and ``channel_groups``.
    window_length : float
        Width of the plotted window, in the same units as ``center_time``
        (seconds, per the time-axis label), centered on the selected ripple.
    ripple_number : int
        Index label of the ripple to display (looked up with ``.loc``).
    ax : sequence of 4 matplotlib Axes, optional
        Axes to draw into; they are cleared first. A new 4-row figure is created if None.
    """
    # --- Figure setup -------------------------------------------------------
    if ax is None:
        _, axes = plt.subplots(4, 1, figsize=(30, 12))
    else:
        # Reuse the caller's axes, wiping whatever a previous call drew.
        axes = ax
        for panel in axes:
            panel.cla()

    # --- Load the LFP window centered on the selected ripple ---------------
    # NOTE(review): `.loc` looks up by index label; the slider spans 1..len(ripples),
    # which covers every ripple only if the index is 1-based -- confirm.
    ripple = ripples.loc[ripple_number]
    window_start_time = ripple.center_time - window_length / 2
    window_end_time = ripple.center_time + window_length / 2

    all_chans = channel_groups.full[subject]
    (time, sig, fs) = load_timeseries(
        Path(paths.lfp_bin[condition][subject]),
        all_chans,
        start_time=window_start_time,
        end_time=window_end_time,
    )

    # Boolean mask over the columns of `sig` for the channels used in detection.
    idx_detection_chans = np.isin(all_chans, metadata["chans"])
    detection_sig = sig[:, idx_detection_chans]
    filtered_detection_sig = apply_ripple_filter(detection_sig, fs)

    # --- Panel 0: ripple-filtered detection channels -----------------------
    axes[0].plot(time, filtered_detection_sig, linewidth=1)

    # --- Panel 1: raw detection channels, stacked ---------------------------
    # Broadcasting subtracts k * 300 from column k so the traces don't overlap.
    offset_lfps = detection_sig - np.arange(detection_sig.shape[1]) * 300
    axes[1].plot(time, offset_lfps, color="black", linewidth=0.5)
    axes[1].set_xlim(axes[0].get_xlim())

    # --- Panel 2: wavelet spectrogram --------------------------------------
    freq = np.linspace(1, 300, 300)
    cwtm = get_perievent_cwtm(detection_sig, fs, freq)
    # Dividing by 1/f multiplies each frequency row by f (presumably to offset 1/f power falloff).
    cwtm = cwtm / (1 / freq)[:, None]
    axes[2].pcolormesh(time, freq, cwtm, cmap="viridis", shading="gouraud")
    axes[2].set_xlim(axes[0].get_xlim())
    # Dashed guides at 150 and 250 -- presumably the ripple band edges.
    axes[2].axhline(150, color="k", alpha=0.5, linestyle="--")
    axes[2].axhline(250, color="k", alpha=0.5, linestyle="--")

    # --- Panel 3: kCSD over hippocampal depths -----------------------------
    n_chans = len(all_chans)
    intersite_distance = 0.020  # mm, matching the "Depth (mm)" axis label
    ele_pos = np.linspace(0., (n_chans - 1) * intersite_distance, n_chans).reshape(n_chans, 1)
    k = KCSD1D(ele_pos, sig.T)
    est_csd = k.values('CSD')

    # Keep only estimation points between the shallowest and deepest hippocampal electrodes.
    idx_hpc_chans = np.isin(all_chans, channel_groups.hippocampus[subject])
    hpc_ele_pos = ele_pos[idx_hpc_chans]
    idx_hpc_src = np.logical_and(k.estm_x >= np.min(hpc_ele_pos), k.estm_x <= np.max(hpc_ele_pos))

    axes[3].pcolormesh(time, k.estm_x[idx_hpc_src], est_csd[idx_hpc_src, :], shading="gouraud")
    axes[3].set_xlim(axes[0].get_xlim())
    axes[3].set_xlabel("Time [sec]")
    axes[3].set_ylabel("Depth (mm)")

    # Dotted lines bracketing the depths of the channels used for ripple detection.
    detection_ele_pos = ele_pos[idx_detection_chans]
    axes[3].axhline(np.min(detection_ele_pos), alpha=0.5, color='k', linestyle=":")
    axes[3].axhline(np.max(detection_ele_pos), alpha=0.5, color='k', linestyle=":")

    # --- Shade every detected ripple visible in the window -----------------
    # Use an overlap (not containment) test so ripples straddling the window edge
    # are shaded too. Clip spans to the window so they don't widen panel 0's
    # autoscaled x-limits beyond the panels whose limits were already fixed.
    visible = (ripples.end_time >= window_start_time) & (ripples.start_time <= window_end_time)
    for other in ripples[visible].itertuples():
        span_start = max(other.start_time, window_start_time)
        span_end = min(other.end_time, window_end_time)
        for panel in axes[:2]:
            panel.axvspan(span_start, span_end, alpha=0.3, color="red", zorder=1000)

interactive(children=(IntSlider(value=1, description='ripple_number', max=518, min=1), Output()), _dom_classes…