In [None]:
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
import os

def MapLFPs(path, nch, dtype=np.int16, order='F'):
    '''Memory-map a binary recording as a read-only 2D array indexed [channel, sample].

    INPUT:
    - [path]:              <str> full path to the binary file
    - [nch]:               <int> number of channels stored in the file
    - [dtype]=np.int16:    <numpy-type> of each stored data point
    - [order]='F':         <str> memory layout passed to np.memmap; with 'F' and
                           shape (nch, n_samples), element [ch, s] is read from
                           flat item position s*nch + ch (channels interleaved
                           sample by sample)

    OUTPUT:
    - <np.memmap> opened with mode='r', shape (nch, n_samples), where n_samples is
      the number of whole items in the file divided by nch (truncated; any
      leftover items at the end of the file are not mapped)'''
    item_bytes = np.dtype(dtype).itemsize
    # Same float division + int() truncation as before, kept for identical shapes.
    n_items = int(os.path.getsize(path) / item_bytes)
    n_samples = int(n_items / nch)
    return np.memmap(path, mode='r', dtype=dtype, order=order,
                     shape=(nch, n_samples))

In [4]:
# Raw HPC continuous.dat recording. The default is a machine-specific absolute path;
# set the HPC_PATH environment variable to point at the file on another machine.
HPC_PATH = os.environ.get(
    'HPC_PATH',
    '/Volumes/behrens/mohamady_el-gaby/Taskspace_abstraction/Data/cohort4/me10/Ephys/HPC/2021-10-26_13-17-19/Record Node 122/experiment1/recording1/continuous/Rhythm_FPGA-121.0/continuous.dat',
)

In [5]:
# Memory-map the 64-channel HPC recording (read-only, indexed [channel, sample]).
HPC_DAT = MapLFPs(HPC_PATH, nch=64)

In [6]:
# Dimensions of the mapped recording: (channels, samples).
np.shape(HPC_DAT)

(64, 72432128)

In [7]:
import numpy as np
from scipy.signal import butter, filtfilt, decimate, hilbert, resample_poly

# Define filtering utility
def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth band-pass filter.

    Parameters:
        lowcut (float): Lower cutoff frequency (Hz).
        highcut (float): Upper cutoff frequency (Hz).
        fs (float): Sampling rate of the signal to be filtered (Hz).
        order (int): Filter order passed to scipy.signal.butter.

    Returns:
        (b, a): Numerator and denominator polynomial coefficients.
    """
    # butter expects cutoffs normalised to the Nyquist frequency (fs / 2).
    edges = [cutoff / (0.5 * fs) for cutoff in (lowcut, highcut)]
    return butter(order, edges, btype='band')

def apply_filter(data, lowcut, highcut, fs, order=4):
    """Zero-phase band-pass filter `data` along axis 1.

    Parameters:
        data (np.ndarray): Array with at least 2 dimensions; axis 1 is filtered
            (e.g. channels x timepoints).
        lowcut, highcut (float): Pass-band edges (Hz).
        fs (float): Sampling rate (Hz).
        order (int): Butterworth order forwarded to butter_bandpass.

    Returns:
        np.ndarray: Filtered array, same shape as `data`.
    """
    coeffs = butter_bandpass(lowcut, highcut, fs, order=order)
    # filtfilt runs the filter forward and then backward, so no phase shift is introduced.
    return filtfilt(*coeffs, data, axis=1)

# Downsample function
def downsample(data, original_fs, target_fs):
    """Anti-alias filter and decimate axis 1 of `data` by an integer factor.

    For 2D input the rows are resampled one at a time. resample_poly converts its
    input to floating point before filtering, so calling it on a full on-disk
    array (e.g. the np.memmap from MapLFPs) would materialise every channel in
    memory at once. Row-by-row processing keeps peak memory to roughly one
    channel plus the (much smaller) output. Note: each row of an order='F'
    memmap is strided across the whole file, so this trades extra I/O for
    bounded memory. Results are the same as resample_poly(data, 1, factor, axis=1).

    Parameters:
        data (array-like): LFP data (channels x timepoints); axis 1 is resampled.
        original_fs (float): Sampling rate of `data` (Hz).
        target_fs (float): Desired sampling rate (Hz); must divide original_fs exactly.

    Returns:
        np.ndarray: Downsampled data, shape (channels, ceil(timepoints / factor)).

    Raises:
        ValueError: If target_fs is not positive, or original_fs is not an
            integer multiple of target_fs.
    """
    if target_fs <= 0:
        raise ValueError("Target sampling rate must be positive.")
    if original_fs % target_fs != 0:
        raise ValueError("Original sampling rate must be an integer multiple of target sampling rate.")
    downsample_factor = int(original_fs // target_fs)

    # asanyarray keeps np.memmap inputs as memmaps (no copy into RAM).
    data = np.asanyarray(data)
    if data.ndim != 2 or data.shape[0] == 0:
        # Non-2D or channel-less input: nothing to stream, defer to resample_poly directly.
        return resample_poly(data, up=1, down=downsample_factor, axis=1)

    first = resample_poly(data[0], up=1, down=downsample_factor)
    out = np.empty((data.shape[0], first.shape[0]), dtype=first.dtype)
    out[0] = first
    for ch in range(1, data.shape[0]):
        out[ch] = resample_poly(data[ch], up=1, down=downsample_factor)
    return out

# Detect sharp-wave ripples
def detect_swr(data, ripple_fs, ripple_band=(150, 250), zscore_threshold=3):
    """
    Detects sharp-wave ripples in LFP data.

    Each channel is band-pass filtered to `ripple_band`. Its amplitude envelope is
    the magnitude of the analytic (Hilbert) signal, and the envelope is z-scored
    over the whole recording. Every maximal run of consecutive samples whose
    z-score exceeds `zscore_threshold` is reported as one event.

    Parameters:
        data (np.ndarray): LFP data (channels x timepoints) at ripple sampling rate.
        ripple_fs (float): Sampling rate of the input data (Hz).
        ripple_band (tuple): Ripple frequency band (Hz) as (low, high).
        zscore_threshold (float): Threshold for SWR detection in z-score units.

    Returns:
        List with one entry per channel. Each entry is a list of
        (start_idx, end_idx) tuples; both indices are inclusive sample indices
        along axis 1 of `data`. A channel whose envelope has zero (or undefined)
        standard deviation, e.g. a flat channel, yields no events.
    """
    swr_events = []
    for ch in range(data.shape[0]):
        # Keep the channel 2D (1 x T) because apply_filter filters along axis 1.
        filtered = apply_filter(data[ch:ch+1, :], ripple_band[0], ripple_band[1], ripple_fs)
        envelope = np.abs(hilbert(filtered))[0]
        envelope_sd = np.std(envelope)
        if not envelope_sd > 0:
            # z-score is 0/0 for a constant envelope; report no events instead of NaNs.
            swr_events.append([])
            continue
        zscored_envelope = (envelope - np.mean(envelope)) / envelope_sd
        above_thresh = np.flatnonzero(zscored_envelope > zscore_threshold)
        swr_events.append(_group_consecutive(above_thresh))
    return swr_events


def _group_consecutive(indices):
    """Collapse sorted integer indices into inclusive (start, end) runs of consecutive values."""
    if indices.size == 0:
        return []
    # A new run begins wherever the step from the previous index is not exactly 1.
    breaks = np.flatnonzero(np.diff(indices) != 1)
    starts = np.concatenate(([indices[0]], indices[breaks + 1]))
    ends = np.concatenate((indices[breaks], [indices[-1]]))
    return list(zip(starts.tolist(), ends.tolist()))

# Main processing pipeline
def process_lfp(lfp_path, nch, original_fs=30000, target_fs=1000, ripple_band=(150, 250)):
    """Memory-map a raw recording, downsample it, and detect sharp-wave ripples.

    Parameters:
        lfp_path (str): Path to the binary recording (opened with MapLFPs).
        nch (int): Number of channels in the file.
        original_fs (float): Sampling rate of the raw file (Hz).
        target_fs (float): Sampling rate after downsampling (Hz).
        ripple_band (tuple): Ripple band (Hz) used for detection.

    Returns:
        (downsampled_data, swr_events): the downsampled channels x timepoints
        array and the per-channel event lists produced by detect_swr.
    """
    raw = MapLFPs(lfp_path, nch)
    print(f"Original data shape: {raw.shape}")

    print("Downsampling...")
    lfp_ds = downsample(raw, original_fs, target_fs)
    print(f"Downsampled data shape: {lfp_ds.shape}")

    print("Detecting sharp-wave ripples...")
    events = detect_swr(lfp_ds, target_fs, ripple_band)
    return lfp_ds, events

# Usage example
# lfp_path = HPC_PATH  # path to the raw binary (.dat) file
# nch = HPC_DAT.shape[0]  # number of channels; process_lfp expects an int, not the shape tuple
# downsampled_data, swr_events = process_lfp(lfp_path, nch)

# # Output SWR events for the first channel
# print(f"SWR events for channel 1: {swr_events[0]}")


In [None]:
import matplotlib.pyplot as plt

def plot_lfp_with_swr(data, swr_events, channel, fs, start_time=None, end_time=None):
    """
    Plot LFP data for a single channel with SWR events highlighted.

    Parameters:
        data (np.ndarray): LFP data (channels x timepoints).
        swr_events (list): Per-channel lists of inclusive (start_idx, end_idx)
            sample-index tuples, as returned by detect_swr.
        channel (int): Channel index to plot (0-based).
        fs (float): Sampling rate of the data (Hz).
        start_time (float): Start time of the plot (seconds). None means the
            start of the recording.
        end_time (float): End time of the plot (seconds). None means the end
            of the recording.

    Events that overlap the plotted window are shaded, clipped to the window
    edges. Events lying entirely outside the window are skipped.
    """
    num_samples = data.shape[1]
    # Each bound defaults independently; clamp to the valid sample range.
    start_idx = 0 if start_time is None else max(int(start_time * fs), 0)
    end_idx = num_samples if end_time is None else min(int(end_time * fs), num_samples)

    time_vector = np.arange(start_idx, end_idx) / fs  # in seconds
    lfp_trace = data[channel, start_idx:end_idx]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(time_vector, lfp_trace, color='black', label="LFP Trace")
    ax.set_xlabel("Time (s)")
    # NOTE(review): values are plotted in the input array's units; the µV label
    # assumes the caller has already scaled raw samples to microvolts — confirm.
    ax.set_ylabel("Amplitude (µV)")
    ax.set_title(f"LFP Trace with SWRs (Channel {channel + 1})")

    # Add orange shaded bars for SWR events (only the first one gets a legend entry).
    last_idx = end_idx - 1
    labelled = False
    for ev_start, ev_end in swr_events[channel]:
        if ev_end < start_idx or ev_start > last_idx:
            continue
        ax.axvspan(max(ev_start, start_idx) / fs, min(ev_end, last_idx) / fs,
                   color='orange', alpha=0.4, label="" if labelled else "SWR")
        labelled = True

    ax.legend(loc='upper right')
    fig.tight_layout()
    plt.show()

# Main processing pipeline with visualization
def process_and_plot_lfp(lfp_path, nch, channel_to_plot, start_time=None, end_time=None,
                         original_fs=30000, target_fs=1000, ripple_band=(150, 250)):
    """Run the LFP pipeline (memory-map, downsample, detect SWRs) and plot the results.

    Parameters:
        lfp_path (str): Path to the binary recording.
        nch (int): Number of channels in the file.
        channel_to_plot (int | iterable of int | None): 0-based channel index, or
            several indices, to plot. Pass None to plot every channel.
        start_time, end_time (float | None): Plot window in seconds, forwarded
            to plot_lfp_with_swr.
        original_fs (float): Sampling rate of the raw file (Hz).
        target_fs (float): Sampling rate after downsampling (Hz).
        ripple_band (tuple): Ripple band (Hz) used for detection.
    """
    # Reuse the shared pipeline instead of duplicating it here.
    downsampled_data, swr_events = process_lfp(lfp_path, nch, original_fs=original_fs,
                                               target_fs=target_fs, ripple_band=ripple_band)

    # Previously channel_to_plot was ignored and every channel was plotted.
    if channel_to_plot is None:
        channels = range(nch)
    elif isinstance(channel_to_plot, (int, np.integer)):
        channels = [channel_to_plot]
    else:
        channels = list(channel_to_plot)

    for ch in channels:
        plot_lfp_with_swr(downsampled_data, swr_events, channel=ch,
                          fs=target_fs, start_time=start_time, end_time=end_time)

# Usage example: memory-map the HPC recording, downsample it, detect SWRs and
# plot the 10-12 s window.
lfp_path = HPC_PATH  # binary file to memory-map
shape = 64  # number of channels in the file (passed as `nch`)
start_time = 10  # start of the plotted window (s)
end_time = 12  # end of the plotted window (s)
channel_to_plot = 1  # 0-based channel index passed as `channel_to_plot`
process_and_plot_lfp(lfp_path, shape, channel_to_plot, start_time=start_time, end_time=end_time)


Original data shape: (64, 72432128)
Downsampling...
