# Libraries

In [None]:
import pyedflib
import plotly.express as px
from pathlib import Path
import os
import numpy as np
from scipy.signal import butter, filtfilt

# EDF class

In [None]:
class EDFFile:
    """Lazily opened EDF/BDF recording backed by ``pyedflib.EdfReader``.

    The reader is created by the first ``load_*`` call and reused by the
    others; it stays open until :meth:`close` is called. The object can also
    be used as a context manager, which closes the reader on exit.
    """

    def __init__(self, file: "str | os.PathLike[str]"):
        """Remember where the recording lives; nothing is read yet.

        Args:
            file: path to the recording. A ``str`` is accepted as well as a
                ``Path`` (it is normalised with ``Path(file)``).
        """
        file = Path(file)
        self.file_path = file.absolute().as_posix()
        self.file_name = file.name
        self.file_handle = None  # pyedflib.EdfReader, opened on first use
        self.signal_headers = None  # set by load_signal_headers()
        self.sample_frequency = None  # sample frequency of channel 0 only
        self.sample_frequencies = None  # every channel's frequency, file order
        self.raw_signals = []  # one entry per channel, set by load_raw_signals()
        # Not filled by any method of this class; callers may store their own
        # filtered output here.
        self.filtered_signals = []
        self.annotations = []  # whatever EdfReader.readAnnotations() returns

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the reader; returning None lets exceptions propagate.
        self.close()

    def _reader(self):
        """Return the open reader, opening the file on first use."""
        if self.file_handle is None:
            self.file_handle = pyedflib.EdfReader(self.file_path)
        return self.file_handle

    def load_raw_signals(self):
        """Read every signal channel into ``raw_signals`` (in file order).

        The list is rebuilt on each call, so calling this more than once does
        not duplicate channels.
        """
        reader = self._reader()
        self.raw_signals = [reader.readSignal(i) for i in range(reader.signals_in_file)]

    def load_signal_headers(self):
        """Store the reader's per-channel headers in ``signal_headers``."""
        self.signal_headers = self._reader().getSignalHeaders()

    def load_annotations(self):
        """Store the result of ``EdfReader.readAnnotations()`` in ``annotations``."""
        self.annotations = self._reader().readAnnotations()

    def load_sample_frequency(self):
        """Store the channels' sample frequencies.

        ``sample_frequencies`` receives every channel's frequency, while
        ``sample_frequency`` keeps only channel 0's for existing callers.
        NOTE: channels of one EDF recording may be sampled at different rates.
        """
        self.sample_frequencies = self._reader().getSampleFrequencies()
        self.sample_frequency = self.sample_frequencies[0]

    def close(self):
        """Close the reader if it is open; safe to call more than once."""
        if self.file_handle is not None:
            self.file_handle.close()
            self.file_handle = None

In [None]:
def read_files_from_dir(directory: Path, extensions=(".edf", ".bdf")):
    """Return an ``EDFFile`` for every EDF/BDF file directly inside *directory*.

    Files are matched on their suffix, case-insensitively, so ``rec.EDF``
    counts while ``notes.xedf`` does not. Results are sorted by path, because
    directory listing order is arbitrary and later code indexes the result
    positionally. Subdirectories are not searched, and no file is opened here.

    Args:
        directory: folder to scan (``str`` or ``Path``).
        extensions: accepted suffixes, with or without the leading dot.
    """
    wanted = {"." + ext.lower().lstrip(".") for ext in extensions}
    return [
        EDFFile(path)
        for path in sorted(Path(directory).iterdir())
        if path.is_file() and path.suffix.lower() in wanted
    ]

# Filter design

In [None]:
def apply_butterworth_lowpass(signal, cutoff, fs, order=3):
    """Zero-phase Butterworth low-pass filter.

    Designs a digital low-pass of the given *order* with its corner at
    *cutoff* (same units as *fs*). It then runs the filter forward and
    backward with ``filtfilt``, so the output has no phase shift.
    """
    wn = cutoff / (fs / 2)  # corner as a fraction of the Nyquist frequency
    numerator, denominator = butter(order, wn, btype='low', analog=False)
    return filtfilt(numerator, denominator, signal)

def apply_notch_filter(signal, notch_freq, fs, half_width=0.02):
    """Suppress a narrow band around *notch_freq* with a zero-phase band-stop.

    Uses an order-2 Butterworth band-stop design run through ``filtfilt``.
    The stop band is ``notch_freq +/- half_width * fs / 2``: *half_width* is
    a fraction of the Nyquist frequency, and the default 0.02 matches the
    previously hard-coded value.

    Args:
        signal: samples to filter.
        notch_freq: centre of the stop band, in the units of *fs*.
        fs: sampling frequency.
        half_width: half the stop-band width, as a fraction of Nyquist.

    Raises:
        ValueError: if the stop band does not lie strictly between 0 and the
            Nyquist frequency (e.g. *notch_freq* at or above ``fs / 2``).
    """
    nyquist = 0.5 * fs
    notch_normalized = notch_freq / nyquist
    low = notch_normalized - half_width
    high = notch_normalized + half_width
    # butter() needs 0 < low < high < 1; fail with a message naming the cause.
    if not 0 < low < high < 1:
        raise ValueError(
            f"notch band {low * nyquist:g}-{high * nyquist:g} must lie strictly "
            f"between 0 and the Nyquist frequency {nyquist:g} (fs={fs:g})"
        )
    b, a = butter(2, [low, high], btype='bandstop', analog=False)
    return filtfilt(b, a, signal)

def apply_butterworth_filters(signal, cutoff, fs, order=3, notch_frequencies=None):
    """Low-pass *signal*, then apply one notch per entry of *notch_frequencies*.

    Args:
        signal: samples to filter.
        cutoff: low-pass corner frequency, in the units of *fs*.
        fs: sampling frequency.
        order: Butterworth low-pass order.
        notch_frequencies: iterable of notch centres, or None for no notches.

    A component at or above the Nyquist frequency (``fs / 2``) cannot be
    represented at sampling rate *fs*, and the band-stop design fails for
    such a notch. Those notch frequencies are therefore skipped with a
    ``RuntimeWarning`` instead of aborting the whole pipeline. An example is
    a 150 Hz harmonic notch on a channel sampled at 256 Hz.
    """
    import warnings

    filtered_signal = apply_butterworth_lowpass(signal, cutoff, fs, order)

    if notch_frequencies is not None:
        nyquist = 0.5 * fs
        for notch_freq in notch_frequencies:
            if notch_freq >= nyquist:
                warnings.warn(
                    f"skipping notch at {notch_freq}: not below the Nyquist "
                    f"frequency {nyquist} (fs={fs})",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            filtered_signal = apply_notch_filter(filtered_signal, notch_freq, fs)

    return filtered_signal

In [None]:
# Folder holding the recordings to analyse, and the EDFFile objects for it.
directory = Path("data") / "edf"
files = read_files_from_dir(directory)

In [None]:
# Read everything each recording offers up front. The EdfReader opened by the
# first load_* call stays open until the cell that closes all files.
for recording in files:
    recording.load_raw_signals()
    recording.load_signal_headers()
    recording.load_annotations()
    recording.load_sample_frequency()

# Visualization

In [None]:
# Raw trace of the first channel of the second-to-last recording.
px.line(data_frame=files[-2].raw_signals[0])

In [None]:
# Same channel after an order-3 low-pass at 15 and notches at 50/100/150
# (presumably mains interference and its harmonics, in Hz).
px.line(
    data_frame=apply_butterworth_filters(
        files[-2].raw_signals[0],
        cutoff=15,
        fs=files[-2].sample_frequency,
        order=3,
        notch_frequencies=[50, 100, 150],
    )
)

# Testing zone

## Save the filtered signals to a new EDF file named after the source file, with a `_filtered.edf` suffix

In [None]:
from pyedflib import highlevel

# Write the filtered channels of one recording to "<stem>_filtered.edf" in the
# working directory. The draft of this cell referenced `self` at notebook top
# level (NameError) and iterated `filtered_signals`, which EDFFile never
# fills, so here the channels are filtered explicitly.
source = files[-2]  # same recording as in the plots above
LOWPASS_CUTOFF = 15
LOWPASS_ORDER = 3
NOTCH_FREQUENCIES = [50, 100, 150]

# Filter each channel at its own sampling rate; EDF channels may differ.
channel_rates = source.file_handle.getSampleFrequencies()

signals = []
signal_headers = []
for raw, channel_header, fs in zip(source.raw_signals, source.signal_headers, channel_rates):
    # Only notches below the Nyquist frequency can be designed at all.
    notches = [f for f in NOTCH_FREQUENCIES if f < 0.5 * fs]
    filtered = apply_butterworth_filters(raw, LOWPASS_CUTOFF, fs, LOWPASS_ORDER, notches)
    # EDF maps integer samples onto [physical_min, physical_max]; derive that
    # range from the filtered data so every sample is representable.
    physical_min = float(np.floor(filtered.min()))
    physical_max = float(np.ceil(filtered.max()))
    if physical_max <= physical_min:  # constant channel: EDF requires min < max
        physical_max = physical_min + 1.0
    signals.append(filtered)
    signal_headers.append(
        highlevel.make_signal_header(
            channel_header["label"],
            dimension=channel_header["dimension"],
            sample_frequency=fs,
            physical_min=physical_min,
            physical_max=physical_max,
        )
    )

# Reuse the recording's own header rather than placeholder patient details.
# NOTE(review): annotations are not carried over to the new file.
header = source.file_handle.getHeader()
output_path = f"{Path(source.file_name).stem}_filtered.edf"

if highlevel.write_edf(output_path, signals, signal_headers, header):
    print("The file has been successfully saved.")
else:
    print("Something went wrong with the saving process.")

In [None]:
 modified_data = [signal * modification_factor for signal in signal_data]

new_edf_file = pyedflib.EdfWriter(output_file_path, num_signals)
for i in range(num_signals):
    new_edf_file.setSignalHeader(i, label=signal_labels[i], samplefrequency=sample_frequency[i])

for i in range(num_signals):
    new_edf_file.writeSignal(i, modified_data[i])

edf_file.close()
new_edf_file.close()

In [None]:
# Re-open a filtered copy of a recording for inspection. The draft used
# `self`, which does not exist at notebook top level, and appended a second
# ".edf" to a name that already ends in it.
# NOTE(review): confirm the filtered files' location and naming; this assumes
# a "filtered/" folder holding a file with the original recording's name.
t = EDFFile(Path("filtered") / files[-2].file_name)

In [None]:
# Release every EdfReader opened by the load_* calls above.
for recording in files:
    recording.close()