In [1]:
import h5py
import numpy as np
import scipy.signal as signal
import matplotlib.pyplot as plt
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from PyQt5 import QtWidgets, QtCore
from scipy.signal import savgol_filter
from scipy.signal import spectrogram
# Open the v7.3+ .mat file using h5py

# This opens the .mat file (which is an HDF5 file under the hood) in read mode ('r').
# The with statement ensures the file is properly closed after use, even if an error occurs
# MATLAB v7.3+ .mat files are HDF5 containers, so h5py can read them directly.
# The context manager closes the file handle when the block exits, including on error.
with h5py.File(r'C:\Users\congy\Desktop\Gritton_Lab\Katamine_Data\622\downsampled_LFP.mat', 'r') as mat_file:
    # List the top-level variable names stored in the file.
    print(list(mat_file.keys()))

    # Slicing the dataset with [:] copies all of it into memory as a NumPy array.
    data = mat_file['data_combine_Low_downsample'][:]

    # Report the full array shape, then the shape of its first column.
    print(data.shape)
    print(data[:, 0].shape)

['data_combine_Low_downsample']
(4968346, 32)
(4968346,)


In [2]:
# A notch filter is designed to remove a narrow band of frequencies,
# such as 60 Hz line noise (common in the U.S. due to electrical power systems).
def create_notch_filter(fs, f0=60.0, Q=30.0):
    """
    Design a second-order digital IIR notch filter (e.g. to remove 60 Hz line noise).

    Parameters:
    fs : float
        Sampling frequency in Hz.
    f0 : float, optional
        Notch (centre) frequency in Hz. Defaults to 60 Hz.
    Q : float, optional
        Quality factor controlling the notch width. Defaults to 30.
            Higher Q -> narrower notch (removes only frequencies very close to f0).
            Lower Q -> wider notch (removes more of the surrounding frequencies).

    Returns:
    b, a : ndarray
        Numerator and denominator polynomials of the IIR filter.

    Raises:
    ValueError
        If f0 does not lie strictly between 0 and the Nyquist frequency (fs / 2).
    """
    # The Nyquist frequency (half the sampling rate) is the highest frequency
    # that can be represented without aliasing at this sampling rate.
    nyquist = fs / 2

    # A notch at or above Nyquist (or at/below 0 Hz) cannot be realised.
    # Fail here with a message in Hz rather than letting iirnotch complain
    # about the normalized value.
    if not 0 < f0 < nyquist:
        raise ValueError(
            f"Notch frequency f0={f0} Hz must lie strictly between 0 and the "
            f"Nyquist frequency ({nyquist} Hz)."
        )

    # iirnotch is called with the frequency normalized by Nyquist,
    # i.e. a value in the open interval (0, 1).
    w0 = f0 / nyquist

    # b: numerator of the filter's transfer function, a: denominator.
    b, a = signal.iirnotch(w0, Q)

    return b, a

In [3]:
def lowpass_filter(data, fs, cutoff=40, order=6):
    """
    Apply a zero-phase Butterworth low-pass filter to `data`.

    Parameters:
    data : array_like
        Signal to filter.
    fs : float
        Sampling frequency in Hz.
    cutoff : float
        Cutoff frequency in Hz; content below it is kept, content above it is attenuated.
    order : int
        Butterworth filter order (higher = steeper transition).

    Returns:
    ndarray
        The filtered signal, produced by forward-backward filtering (filtfilt).
    """
    # Express the cutoff as a fraction of the Nyquist frequency (half the sampling rate).
    half_rate = fs * 0.5
    cutoff_fraction = cutoff / half_rate

    numerator, denominator = signal.butter(
        order, cutoff_fraction, btype='low', analog=False
    )

    # filtfilt runs the filter forwards and then backwards over the data.
    smoothed = signal.filtfilt(numerator, denominator, data)
    return smoothed

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PyQt5 import QtWidgets, QtCore
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from scipy.signal import savgol_filter

# At the start and end of the signal, there's not enough data to fully center a window around those time points.

# Spectrogram truncates edges:
# With nperseg=1024, the spectrogram window slides over the signal.
# At the start and end, it needs full windows to compute FFT, so it cannot compute values for the very first and last portions of the signal.
# As a result, the returned t_spec is shorter than the full duration of your signal.

# changing nperseg to a smaller number can reduce or even eliminate the blank areas on both ends because:
# The larger the nperseg, the more data the spectrogram function needs for each FFT window, which shortens the usable time span of the output.
# Smaller nperseg means shorter FFT windows, which gives more time bins and better temporal coverage (i.e., fewer or no missing edges), but at the cost of lower frequency resolution



# This defines a new class DataPlotter that inherits from QWidget, which is a basic GUI window in PyQt5
class DataPlotter(QtWidgets.QWidget):
    """
    Qt widget for browsing one channel of a multichannel recording.

    The top axes overlay the raw and the filtered trace for the selected
    channel and time window; the bottom axes show a spectrogram of the
    filtered trace over the same window.

    Both data arrays are indexed as [sample, channel]: rows are time samples,
    columns are channels.
    """

    # Spectrogram / display settings shared by every redraw.
    _MAX_NPERSEG = 512          # upper bound on the FFT segment length (samples)
    _OVERLAP_FRACTION = 0.9     # fraction of each segment that overlaps the next
    _MAX_DISPLAY_FREQ_HZ = 100  # upper y-limit of the spectrogram axes

    def __init__(self, raw_data, filtered_data, sampling_rate):
        """
        Parameters:
        raw_data : ndarray
            Unfiltered recording, indexed [sample, channel].
        filtered_data : ndarray
            Filtered recording, indexed the same way as raw_data.
        sampling_rate : float
            Sampling frequency in Hz, used to convert between seconds and sample indices.
        """
        # Initialise the QWidget base class before touching any Qt functionality.
        super().__init__()

        self.raw_data = raw_data
        self.smoothed_data = filtered_data
        self.sampling_rate = sampling_rate
        self.window_size_sec = 1  # default window size in seconds

        # Build the controls and the canvas, and draw the initial view.
        self.initUI()

    def initUI(self):
        """Create the canvas, channel selector, time inputs, slider and plot button."""
        # Passing `self` installs this layout on the widget, so no separate
        # setLayout() call is needed.
        layout = QtWidgets.QVBoxLayout(self)

        # Two stacked axes sharing the time axis: traces on top, spectrogram below.
        self.figure, (self.ax, self.ax2) = plt.subplots(2, 1, sharex=True, figsize=(5, 6))
        self.canvas = FigureCanvas(self.figure)
        layout.addWidget(self.canvas)

        # Drop-down with one entry per channel (column of raw_data).
        self.channel_selector = QtWidgets.QComboBox()
        for i in range(self.raw_data.shape[1]):
            self.channel_selector.addItem(f"Channel {i+1}")
        # Connect only after the items exist so filling the box does not trigger redraws.
        self.channel_selector.currentIndexChanged.connect(self.update_plot)
        layout.addWidget(self.channel_selector)

        # Start time and stop time / window size inputs, side by side.
        time_layout = QtWidgets.QHBoxLayout()

        self.start_time_input = QtWidgets.QLineEdit()
        self.start_time_input.setPlaceholderText("Start time (s)")
        time_layout.addWidget(QtWidgets.QLabel("Start Time (s):"))
        time_layout.addWidget(self.start_time_input)

        # NOTE(review): update_plot() reads this field as an absolute stop time,
        # while slide_plot() reads it as a window length. The two agree only when
        # the start time is 0 -- confirm which meaning is intended.
        self.window_size_input = QtWidgets.QLineEdit()
        self.window_size_input.setPlaceholderText("Infinity (s)")
        time_layout.addWidget(QtWidgets.QLabel("Stop Time (s):"))
        time_layout.addWidget(self.window_size_input)

        layout.addLayout(time_layout)

        # Slider whose value is a sample index into the recording.
        self.slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.slider.setMinimum(0)
        self.slider.setMaximum(self.raw_data.shape[0] - 1)
        self.slider.setValue(0)
        self.slider.valueChanged.connect(self.slide_plot)
        layout.addWidget(self.slider)

        # Button that redraws using the values in the text inputs.
        self.plot_button = QtWidgets.QPushButton("Plot")
        self.plot_button.clicked.connect(self.update_plot)
        layout.addWidget(self.plot_button)

        # Default values for the text inputs, then the initial draw.
        self.start_time_input.setText('0')
        self.window_size_input.setText(str(self.window_size_sec))

        self.update_plot()

    def update_plot(self):
        """Redraw both axes for the start and stop times typed into the text inputs."""
        try:
            start_time = float(self.start_time_input.text())
            window_size_sec = float(self.window_size_input.text())
        except ValueError:
            print("Invalid input! Please enter valid numbers for start time and window size.")
            return

        # float() accepts "nan" and "inf", but int() below would raise on them;
        # an uncaught exception inside a Qt slot can abort the application.
        if not (np.isfinite(start_time) and np.isfinite(window_size_sec)):
            print("Invalid input! Please enter valid numbers for start time and window size.")
            return

        # Convert times to sample indices. The second field is used here as an
        # absolute stop time, not as a duration added to the start time.
        start_idx = int(start_time * self.sampling_rate)
        stop_idx = int(window_size_sec * self.sampling_rate)

        if start_idx < 0 or stop_idx > self.raw_data.shape[0] or start_idx >= stop_idx:
            print("Invalid time range. Please check your start time and window size.")
            return

        channel_idx = self.channel_selector.currentIndex()
        title = f'Channel {channel_idx + 1}: {start_time:.2f}s - {window_size_sec:.2f}s window'
        self._draw_window(start_idx, stop_idx, channel_idx, title)

    def slide_plot(self):
        """Redraw both axes for a window that starts at the slider's sample index."""
        start_idx = self.slider.value()

        # An empty field falls back to the default window size.
        input_text = self.window_size_input.text()

        try:
            window_size_sec = float(input_text) if input_text else self.window_size_sec
            if not np.isfinite(window_size_sec):
                # "nan" / "inf" parse as floats but cannot be converted to a sample count.
                raise ValueError("window size must be finite")
        except ValueError:
            QtWidgets.QMessageBox.warning(self, "Input Error", "Please enter a valid number for window size.")
            return

        # Here the field is used as a window length measured from the slider position.
        window_size_samples = int(window_size_sec * self.sampling_rate)

        # Ensure the end index doesn't run past the end of the recording.
        end_idx = min(start_idx + window_size_samples, self.raw_data.shape[0])

        # A zero or negative window length would make spectrogram() raise.
        if end_idx <= start_idx:
            print("Invalid time range. Please check your start time and window size.")
            return

        channel_idx = self.channel_selector.currentIndex()
        self._draw_window(start_idx, end_idx, channel_idx, f'Channel {channel_idx + 1}')

    def _draw_window(self, start_idx, stop_idx, channel_idx, trace_title):
        """Draw the traces and the spectrogram for samples [start_idx, stop_idx) of one channel."""
        fs = self.sampling_rate
        time_axis = np.arange(start_idx, stop_idx) / fs
        raw_trace = self.raw_data[start_idx:stop_idx, channel_idx]
        # Not named `signal`, so the scipy.signal module alias is not shadowed.
        filtered_trace = self.smoothed_data[start_idx:stop_idx, channel_idx]

        # --- Top axes: raw vs. filtered trace ---
        self.ax.clear()
        self.ax.plot(time_axis, raw_trace, label='Original')
        self.ax.plot(time_axis, filtered_trace, label='Filtered')
        self.ax.set_xlabel('Time (s)')
        self.ax.set_ylabel('Amplitude')
        self.ax.set_title(trace_title)
        self.ax.legend()

        # --- Bottom axes: spectrogram of the filtered trace ---
        self.ax2.clear()

        # The segment length can never exceed the number of samples in the window.
        nperseg = min(self._MAX_NPERSEG, stop_idx - start_idx)
        noverlap = int(nperseg * self._OVERLAP_FRACTION)
        freqs, t_spec, Sxx = spectrogram(filtered_trace, fs, nperseg=nperseg, noverlap=noverlap)

        # spectrogram() reports times relative to the start of the slice; shift
        # them so the mesh lines up with the absolute time axis of the top plot.
        t_spec_shifted = t_spec + (start_idx / fs)

        # Power in dB; the small offset avoids log10(0).
        Sxx_db = 10 * np.log10(Sxx + 1e-20)
        # Clip the colour scale to the 2nd-98th percentile so outliers do not
        # wash out the rest of the image.
        vmin, vmax = np.percentile(Sxx_db, [2, 98])

        # Draw on ax2 explicitly: plt.pcolormesh would target whichever axes
        # pyplot currently considers "current", which need not be this widget's.
        self.ax2.pcolormesh(t_spec_shifted, freqs, Sxx_db, shading='gouraud', vmin=vmin, vmax=vmax)

        self.ax2.set_xlim(start_idx / fs, stop_idx / fs)
        self.ax2.set_ylim(0, self._MAX_DISPLAY_FREQ_HZ)  # Limit to 0-100 Hz
        self.ax2.set_xlabel('Time (s)')
        self.ax2.set_ylabel('Frequency (Hz)')
        self.ax2.set_title(f'Spectrogram - Channel {channel_idx + 1}')

        self.canvas.draw()
        



if __name__ == '__main__':
    import sys

    # Sampling rate of the recording after downsampling.
    sampling_rate = 24414.0625 / 20  # 1220.703125 Hz
    b, a = create_notch_filter(sampling_rate, f0=60.0, Q=30)

    # np.zeros_like(data) alone would inherit the dataset's dtype; if that were
    # an integer type, the filter output would be truncated on assignment.
    # Promote to a floating dtype (floating input keeps its own dtype).
    float_dtype = np.result_type(data.dtype, np.float32)
    filtered_data = np.zeros_like(data, dtype=float_dtype)
    smoothed_data = np.zeros_like(data, dtype=float_dtype)

    # Process each channel (column of `data`) independently.
    for ch in range(data.shape[1]):
        # data[:, ch] selects every time sample of channel ch.
        # NOTE(review): lfilter is a single forward pass, whereas lowpass_filter
        # uses forward-backward filtering -- confirm the notch phase shift is acceptable.
        notch_applied = signal.lfilter(b, a, data[:, ch])
        filtered_data[:, ch] = lowpass_filter(notch_applied, sampling_rate)

        # Savitzky-Golay filter: fits a low-degree polynomial to a sliding window
        # of points by least squares and replaces the centre point with the fitted
        # value, reducing noise while retaining the shape of the signal.
        smoothed_data[:, ch] = savgol_filter(filtered_data[:, ch], window_length=31, polyorder=3)

    # Qt allows a single QApplication per process. Re-running this cell in a
    # live kernel must reuse the existing instance instead of creating another.
    app = QtWidgets.QApplication.instance() or QtWidgets.QApplication(sys.argv)

    window = DataPlotter(data, smoothed_data, sampling_rate)
    window.setWindowTitle('Data Plotter')
    window.resize(800, 600)
    window.show()
    sys.exit(app.exec_())
