Imports

In [10]:
import os

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt
from ipywidgets import interact, FloatSlider, IntSlider


Band-pass filter for removing noise

In [11]:
from scipy.signal import butter, filtfilt


def band_pass_filter(signal, lowcut=None, highcut=None, sampling_rate=30000, order=2):
    """
    Apply a zero-phase Butterworth band-pass, low-pass, or high-pass filter.

    The filter type is chosen from which cutoffs are usable:
      * only ``highcut`` -> low-pass
      * only ``lowcut``  -> high-pass
      * both             -> band-pass

    A ``lowcut`` of 0 Hz (or below) means "no lower bound", and a ``highcut`` at or
    above the Nyquist frequency means "no upper bound". Both are treated like
    ``None`` rather than being passed to ``butter``, which rejects them.

    Parameters:
        signal (array-like): The input signal (filtered along its last axis).
        lowcut (float or None): The lower cutoff frequency (Hz) or None for no lower bound.
        highcut (float or None): The upper cutoff frequency (Hz) or None for no upper bound.
        sampling_rate (float): The sampling rate of the signal (Hz).
        order (int): Butterworth filter order (default 2). Because ``filtfilt`` runs
            the filter forward and backward, the effective magnitude response is squared.

    Returns:
        np.ndarray: The filtered signal, same shape as ``signal``.

    Raises:
        ValueError: If no usable cutoff remains, or if ``lowcut >= highcut``.
    """
    nyquist = 0.5 * sampling_rate

    # butter() requires 0 < cutoff < nyquist; map the degenerate ends to "no bound".
    if lowcut is not None and lowcut <= 0:
        lowcut = None
    if highcut is not None and highcut >= nyquist:
        highcut = None

    if lowcut is None and highcut is None:
        raise ValueError("At least one of lowcut or highcut must be specified.")

    # Passing fs lets butter() take cutoffs directly in Hz (no manual normalisation).
    if lowcut is None:  # Low-pass filter
        b, a = butter(order, highcut, btype="low", fs=sampling_rate)
    elif highcut is None:  # High-pass filter
        b, a = butter(order, lowcut, btype="high", fs=sampling_rate)
    else:  # Band-pass filter
        if lowcut >= highcut:
            raise ValueError(f"lowcut ({lowcut} Hz) must be lower than highcut ({highcut} Hz).")
        b, a = butter(order, [lowcut, highcut], btype="band", fs=sampling_rate)

    # filtfilt gives zero phase shift, so spike timing is not displaced by filtering.
    return filtfilt(b, a, signal)


Identify spikes

In [12]:
def detect_spike_starts(signal, lower_threshold, upper_threshold, min_distance=30):
    """
    Detect spike start indices in a signal based on amplitude thresholds.

    A sample is a spike candidate when it is below ``lower_threshold`` or above
    ``upper_threshold``. Candidates are thinned greedily: one is kept only if it lies
    at least ``min_distance`` samples after the previously *kept* candidate. The first
    threshold crossing of an event is therefore reported as its start. An excursion
    that stays beyond a threshold for longer than ``min_distance`` samples yields
    more than one start.

    Parameters:
        signal (array-like): The filtered signal.
        lower_threshold (float): Samples below this value are spike candidates.
        upper_threshold (float): Samples above this value are spike candidates.
        min_distance (int): Minimum distance between consecutive spikes (in samples).

    Returns:
        np.ndarray: Integer indices of detected spike starts. When no spike is found
        this is an *integer-typed* empty array, so it can still be used to index ``signal``.
    """
    signal = np.asarray(signal)
    candidates = np.where((signal < lower_threshold) | (signal > upper_threshold))[0]
    if candidates.size == 0:
        # Return the empty integer array from np.where: np.array([]) would be float64
        # and raise IndexError when used as an index.
        return candidates

    selected = [candidates[0]]
    for idx in candidates[1:]:
        if idx - selected[-1] >= min_distance:
            selected.append(idx)

    return np.array(selected, dtype=candidates.dtype)


Function for Spike Classification

In [13]:
def classify_spikes(signal, spike_starts, window=50, lower_threshold=-5, upper_threshold=5):
    """
    Label each detected spike as 'bipolar', 'unipolar', or 'none'.

    For every start index the segment ``signal[start : start + window]`` (the start
    sample and the samples after it) is inspected:
      * 'bipolar'  - the segment dips below ``lower_threshold`` AND rises above ``upper_threshold``
      * 'unipolar' - the segment crosses exactly one of the two thresholds
      * 'none'     - the segment crosses neither threshold (e.g. when these thresholds
                     differ from the ones used for detection)

    Parameters:
        signal (np.ndarray): The filtered signal.
        spike_starts (iterable of int): Indices of detected spike starts.
        window (int): Number of samples, starting at the spike start, to analyze.
        lower_threshold (float): Lower amplitude threshold.
        upper_threshold (float): Upper amplitude threshold.

    Returns:
        dict: Maps each spike start index to 'bipolar', 'unipolar', or 'none'.
    """
    # Indexed by how many of the two thresholds the segment crosses (0, 1 or 2).
    labels_by_crossings = ("none", "unipolar", "bipolar")

    labelled = {}
    for start in spike_starts:
        segment = signal[start : start + window]
        crossings = int(np.min(segment) < lower_threshold) + int(np.max(segment) > upper_threshold)
        labelled[start] = labels_by_crossings[crossings]
    return labelled


Graph of the spikes (Plotting)

In [14]:
def plot_spikes(signal, spike_starts, classifications, window=50):
    """
    Draw the filtered signal and overlay markers for the classified spikes.

    Bipolar spikes are marked in red and unipolar spikes in yellow, each at its
    start index. Spikes with any other label (e.g. 'none') are not marked. The
    figure is shown immediately with ``plt.show()``.

    Parameters:
        signal (np.ndarray): The filtered signal.
        spike_starts (np.ndarray): Indices of detected spike starts. Not used by this
            function (marker positions come from ``classifications``); kept so existing
            calls keep working.
        classifications (dict): Maps spike start index -> spike type label.
        window (int): Not used by this function; kept so existing calls keep working.
    """
    _, ax = plt.subplots(figsize=(15, 6))
    ax.plot(signal, label="Filtered Signal", alpha=0.7)

    # One scatter per spike type (bipolar first) so each type gets exactly one legend entry.
    marker_styles = (
        ("bipolar", "red", "Bipolar Spikes"),
        ("unipolar", "yellow", "Unipolar Spikes"),
    )
    for spike_kind, marker_colour, legend_text in marker_styles:
        positions = [idx for idx, kind in classifications.items() if kind == spike_kind]
        ax.scatter(positions, signal[positions], color=marker_colour, label=legend_text, alpha=0.8)

    ax.set_title("Signal with Detected Spikes")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Amplitude (µV)")
    ax.legend(loc="upper right")
    ax.grid()
    plt.show()


Processing: running the functions

In [27]:
# Step 1: Load the raw recording.
# Set the SIGNAL_PATH environment variable to point at the data instead of
# editing this machine-specific default path.
SIGNAL_PATH = os.environ.get("SIGNAL_PATH", r"C:\test1\matan_bootcamp_python\bic13-ch259.npy")
SAMPLING_RATE = 30_000  # Hz: 30,000 samples correspond to 1 second of recording
raw_signal = np.load(SIGNAL_PATH)

# Step 2: Keep only the first second (SAMPLING_RATE samples) of the recording.
signal = raw_signal[:SAMPLING_RATE]

# Steps 3-5: Baseline 300-3000 Hz band-pass filter; the interactive widget below
# lets you tune the cutoffs.
filtered_signal = band_pass_filter(signal, lowcut=300, highcut=3000, sampling_rate=SAMPLING_RATE)


def interactive_filter(lowcut, highcut):
    """
    Re-filter ``signal`` with the slider cutoffs and plot original vs. filtered.

    The result is stored in the global ``filtered_signal`` (previously it was only a
    local, so the slider choice was thrown away). The spike-detection widget reads
    that global each time its own sliders move, so it analyses the signal chosen
    here. The slider defaults match the baseline 300-3000 Hz filter, which keeps a
    fresh Restart & Run All reproducible.

    Parameters:
        lowcut (float): Lower cutoff (Hz); 0 means "no lower bound" (low-pass only).
        highcut (float): Upper cutoff (Hz).
    """
    global filtered_signal

    print(f"Filtering with lowcut={lowcut}, highcut={highcut}")
    if lowcut >= highcut:
        # The sliders can both reach 2000 Hz; an empty pass band makes butter() raise.
        print("lowcut must be lower than highcut; filter not updated.")
        return

    # 0 Hz is not a valid Butterworth corner frequency, so request a low-pass instead.
    filtered_signal = band_pass_filter(signal, lowcut=lowcut or None, highcut=highcut)

    # Display the filtered signal
    _, ax = plt.subplots(figsize=(15, 6))
    ax.plot(signal, label="Original Signal", alpha=0.5)
    ax.plot(filtered_signal, label="Filtered Signal", alpha=0.7)
    ax.set_title(f"Filtered Signal ({lowcut} - {highcut} Hz)")
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Amplitude")
    ax.legend()
    ax.grid()
    plt.show()


# Interactive widgets for filtering; the trailing ';' suppresses the function repr output.
interact(
    interactive_filter,
    lowcut=FloatSlider(min=0, max=2000, step=50, value=300),
    highcut=FloatSlider(min=2000, max=10000, step=50, value=3000),
);


interactive(children=(FloatSlider(value=300.0, description='lowcut', max=2000.0, step=50.0), FloatSlider(value…

<function __main__.interactive_filter(lowcut, highcut)>

In [29]:
# Steps 6-7: detect and classify spikes
def interactive_spike_detection(filtered_signal, lower_threshold=-5, upper_threshold=5, min_distance=30):
    """
    Detect, classify, plot, and count spikes for the given threshold settings.

    Parameters:
        filtered_signal (np.ndarray or None): Signal to analyse; None means the
            filtering step has not been run yet.
        lower_threshold (float): Samples below this value count as spike samples.
        upper_threshold (float): Samples above this value count as spike samples.
        min_distance (int): Minimum gap (in samples) between reported spike starts.
    """
    print(f"Detecting spikes with thresholds {lower_threshold}, {upper_threshold} and minimum distance {min_distance}")

    # Ensure the filtered signal is available
    if filtered_signal is None:
        print("Please filter the signal first.")
        return

    spike_starts = detect_spike_starts(
        signal=filtered_signal,
        lower_threshold=lower_threshold,
        upper_threshold=upper_threshold,
        min_distance=min_distance,
    )

    # Classify with the SAME thresholds used for detection. With classify_spikes'
    # defaults (-5/5), spikes detected at looser slider thresholds would be labelled
    # 'none' and then be neither plotted nor counted.
    classifications = classify_spikes(
        filtered_signal,
        spike_starts,
        lower_threshold=lower_threshold,
        upper_threshold=upper_threshold,
    )

    plot_spikes(filtered_signal, spike_starts, classifications)

    # Print spike counts
    labels = list(classifications.values())
    print(f"Number of Bipolar Spikes: {labels.count('bipolar')}")
    print(f"Number of Unipolar Spikes: {labels.count('unipolar')}")


# Interactive widgets for spike detection. The lambda reads the global
# filtered_signal when it is called (not when this cell runs). The trailing ';'
# suppresses the function repr output.
interact(
    lambda lower_threshold, upper_threshold, min_distance: interactive_spike_detection(
        filtered_signal, lower_threshold, upper_threshold, min_distance
    ),
    lower_threshold=FloatSlider(min=-10, max=0, step=0.5, value=-5),
    upper_threshold=FloatSlider(min=0, max=10, step=0.5, value=5),
    min_distance=IntSlider(min=10, max=100, step=10, value=30),
);

interactive(children=(FloatSlider(value=-5.0, description='lower_threshold', max=0.0, min=-10.0, step=0.5), Fl…

<function __main__.<lambda>(lower_threshold, upper_threshold, min_distance)>