In [None]:
import mne
import numpy as np
import os
import datetime
from mne.io.edf.edf import RawEDF

In [None]:
# Root of the raw corpus (an `edf/<split>/...` tree is expected below it) and the
# destination for the preprocessed per-window FIF files. Both can be overridden via
# environment variables so the notebook runs on machines other than the author's.
TUH_EEG_SEIZURE_CORPUS = os.environ.get(
    "TUH_EEG_SEIZURE_CORPUS", "/Users/jannis/Git/tuh-eeg-seizure-detection/data/raw"
)
OUTPUT = os.environ.get(
    "TUH_EEG_PREPROCESSED", "/Users/jannis/Git/tuh-eeg-seizure-detection/data/preprocessed"
)

SAMPLING_FREQUENCY = 250  # Hz; every recording is resampled to this rate before windowing
WINDOW_LENGTH = 1  # seconds of signal per saved window
# Configuration sub-directories (inside each session directory) to include when indexing.
CONFIGURATIONS = ["01_tcp_ar"]
# Channels picked from every EDF recording.
CHANNELS = [
    "EEG FP1-REF", "EEG FP2-REF", "EEG F7-REF", "EEG F3-REF", "EEG F4-REF",
    "EEG F8-REF", "EEG T3-REF", "EEG C3-REF", "EEG C4-REF", "EEG T4-REF",
    "EEG T5-REF", "EEG P3-REF", "EEG P4-REF", "EEG T6-REF", "EEG O1-REF",
    "EEG O2-REF", "EEG CZ-REF", "EEG A1-REF", "EEG A2-REF",
]

In [None]:
class Recording:
    """A single EDF recording of the corpus together with its `.csv_bi` annotation file."""

    path: str  # path to the `.edf` file; the `.csv_bi` file is expected next to it

    def __init__(self, path: str):
        self.path = path

    def load(self, channels: list[str] = CHANNELS):
        """Read the EDF file, keep only `channels`, and attach the parsed annotations.

        Args:
            channels: channel names to keep (passed to `raw.pick`).

        Returns:
            The preloaded MNE raw object with annotations taken from the `.csv_bi` file.
        """
        raw = mne.io.read_raw_edf(self.path, preload=True)
        raw.pick(channels)
        raw.set_annotations(self._parse_annotations())
        return raw

    def _parse_annotations(self):
        """Parse the `.csv_bi` file next to the EDF file into `mne.Annotations`.

        Data rows are comma-separated: column 1 is the start time, column 2 the stop
        time and column 3 the label; each row becomes onset=start, duration=stop-start.

        Assumes the TUSZ `.csv_bi` layout of a `#`-comment preamble, one column-header
        row, then data rows (TODO confirm for other corpus releases). Comment and blank
        lines are filtered out rather than skipping a fixed number of lines, so trailing
        blank lines or a different preamble length do not break parsing.
        """
        # Swap only the extension; `str.replace` would also rewrite ".edf" elsewhere in the path.
        stem, _ = os.path.splitext(self.path)
        annotation_file = stem + ".csv_bi"

        with open(annotation_file, "r", encoding="utf-8") as file:
            rows = [line.strip() for line in file]
        rows = [row for row in rows if row and not row.startswith("#")]

        onset: list[float] = []
        duration: list[float] = []
        description: list[str] = []

        # rows[0] is the column header; everything after it is data.
        for row in rows[1:]:
            parts = row.split(",")
            start = float(parts[1])
            end = float(parts[2])

            onset.append(start)
            duration.append(end - start)
            description.append(parts[3].strip())

        return mne.Annotations(onset, duration, description)
        
class Session:
    """One session of a patient: the recordings found under a single configuration directory."""

    session_id: str  # name of the session directory
    recordings: list["Recording"]
    configuration: str  # configuration directory name, e.g. "01_tcp_ar"

    def __init__(self, session_id: str, configuration: str):
        self.session_id = session_id
        self.configuration = configuration
        self.recordings = []

    def add_recording(self, recording: "Recording"):
        """Append `recording` to this session."""
        self.recordings.append(recording)
        
class Patient:
    """A patient of the corpus and the sessions that contain at least one recording."""

    patient_id: str  # name of the patient directory
    sessions: list["Session"]

    def __init__(self, patient_id: str):
        self.patient_id = patient_id
        self.sessions = []

    def add_session(self, session: "Session"):
        """Append `session` to this patient."""
        self.sessions.append(session)
        
class Dataset:
    """A collection of patients (one corpus split, or several splits merged)."""

    patients: list["Patient"]

    def __init__(self):
        self.patients = []

    def add_patient(self, patient: "Patient"):
        """Append `patient` to this dataset."""
        self.patients.append(patient)

In [None]:
def load_dataset(set_name: str, configurations: list[str] = CONFIGURATIONS):
    """Index the EDF recordings of one corpus split without reading any signal data.

    Walks `<TUH_EEG_SEIZURE_CORPUS>/edf/<set_name>/<patient>/<session>/<configuration>/*.edf`
    and builds a Dataset -> Patient -> Session -> Recording tree.

    Only configuration directories whose name is in `configurations` are used. Each
    matching configuration directory becomes its own Session, so `Session.configuration`
    always names the directory its recordings came from. Sessions without any `.edf`
    file are dropped, and so are patients left without sessions.

    Directory listings are sorted because `os.listdir` order is arbitrary. That makes
    the patient order, and any later slicing of `dataset.patients`, reproducible.

    Args:
        set_name: split directory name, e.g. "dev", "eval" or "train".
        configurations: configuration directory names to include.

    Returns:
        The populated Dataset.
    """
    path = f"{TUH_EEG_SEIZURE_CORPUS}/edf/{set_name}"
    dataset = Dataset()

    for patient_id in sorted(os.listdir(path)):
        patient_path = f"{path}/{patient_id}"
        if not os.path.isdir(patient_path):
            continue

        patient = Patient(patient_id)

        for session_id in sorted(os.listdir(patient_path)):
            session_path = f"{patient_path}/{session_id}"
            if not os.path.isdir(session_path):
                continue

            for configuration in sorted(os.listdir(session_path)):
                configuration_path = f"{session_path}/{configuration}"
                if configuration not in configurations or not os.path.isdir(configuration_path):
                    continue

                session = Session(session_id, configuration)

                for recording_id in sorted(os.listdir(configuration_path)):
                    recording_path = f"{configuration_path}/{recording_id}"
                    if recording_path.endswith(".edf") and os.path.isfile(recording_path):
                        session.add_recording(Recording(recording_path))

                if session.recordings:
                    patient.add_session(session)

        if patient.sessions:
            dataset.add_patient(patient)

    return dataset

In [None]:
# Index every split of the corpus (directory walk only, no EEG samples are read here).
dev_dataset, eval_dataset, train_dataset = map(load_dataset, ("dev", "eval", "train"))

print(f"Dev set: {len(dev_dataset.patients)}")
print(f"Eval set: {len(eval_dataset.patients)}")
print(f"Train set: {len(train_dataset.patients)}")

In [None]:
# Merge the three splits into one dataset: dev patients first, then eval, then train.
dataset = Dataset()
dataset.patients = [
    *dev_dataset.patients,
    *eval_dataset.patients,
    *train_dataset.patients,
]
print(f"Combined set: {len(dataset.patients)}")

In [None]:
# Sanity check: plot the first recording of the first session of one known patient.
INSPECT_PATIENT_ID = "aaaaamnk"
patient = next((p for p in dataset.patients if p.patient_id == INSPECT_PATIENT_ID), None)
if patient is None:
    # Fail with a clear message instead of an AttributeError on `None` below.
    raise LookupError(f"Patient {INSPECT_PATIENT_ID} not found in the combined dataset")

print(f"Patient {patient.patient_id} has {len(patient.sessions)} sessions")
sess = patient.sessions[0]
rec = sess.recordings[0]
raw = mne.io.read_raw_edf(rec.path, preload=True)
raw.plot()

In [None]:
# Cut every annotated segment of a slice of patients into WINDOW_LENGTH-second windows
# and save each window as its own FIF file under
# OUTPUT/<patient_id>/{seizures,non_seizures}/.
# Process NUM_PATIENTS patients, skipping the first FIRST_PATIENT (positions follow the
# order of `dataset.patients`).
FIRST_PATIENT = 100
NUM_PATIENTS = 20
patients = dataset.patients[FIRST_PATIENT:FIRST_PATIENT + NUM_PATIENTS]


def split_annotated_segments(raw):
    """Crop `raw` to each of its annotations.

    Returns:
        Two lists of `(annotation_index, segment)` pairs: (seizure, non_seizure).
        Annotations labelled "bckg" go to the non-seizure list; every other label
        goes to the seizure list.

    Annotations running past the end of the recording are clipped to it. Annotations
    that start at or after the end are skipped (cropping them would fail).
    """
    seizure_segments = []
    non_seizure_segments = []
    last_time = raw.times[-1]

    for index, annotation in enumerate(raw.annotations):
        onset = annotation["onset"]
        # prevent overlong segments
        end = min(onset + annotation["duration"], last_time)
        if end <= onset:
            continue

        segment = raw.copy().crop(onset, end)
        if annotation["description"] == "bckg":
            non_seizure_segments.append((index, segment))
        else:
            seizure_segments.append((index, segment))

    return seizure_segments, non_seizure_segments


def save_windows(segment, out_dir, name_prefix):
    """Save consecutive non-overlapping WINDOW_LENGTH-second windows of `segment`.

    Files are named `<out_dir>/<name_prefix>_<window_index>_raw.fif`. Windows whose
    file already exists are skipped, so an interrupted run can be resumed. A trailing
    partial window is dropped.
    """
    num_of_epochs = int(segment.times[-1] / WINDOW_LENGTH)
    for i in range(num_of_epochs):
        epoch_path = f"{out_dir}/{name_prefix}_{i}_raw.fif"
        if os.path.isfile(epoch_path):
            continue

        start = i * WINDOW_LENGTH
        # include_tmax=False: a window must not share its last sample with the next one.
        epoch = segment.copy().crop(start, start + WINDOW_LENGTH, include_tmax=False)

        # sometimes measurement date is not set correctly -> we don't need it, so just set it to now
        epoch.set_meas_date(datetime.datetime.now(datetime.UTC))

        epoch.save(epoch_path, overwrite=False)


for patient in patients:
    seizure_dir = f"{OUTPUT}/{patient.patient_id}/seizures"
    non_seizure_dir = f"{OUTPUT}/{patient.patient_id}/non_seizures"

    for session in patient.sessions:
        for recording in session.recordings:
            # last "_"-separated token of the EDF file name
            recording_name = os.path.basename(recording.path).replace(".edf", "").split("_")[-1]

            raw = recording.load(CHANNELS)
            raw.resample(SAMPLING_FREQUENCY)

            seizure_segments, non_seizure_segments = split_annotated_segments(raw)
            if not seizure_segments and not non_seizure_segments:
                continue

            os.makedirs(seizure_dir, exist_ok=True)
            os.makedirs(non_seizure_dir, exist_ok=True)

            # The annotation index is part of the file name. Without it, a recording with
            # several annotations of the same class would map them onto identical paths,
            # and the "already exists" check would silently drop all but the first.
            prefix = f"{session.session_id}_{recording_name}"
            for index, segment in seizure_segments:
                save_windows(segment, seizure_dir, f"{prefix}_{index}")
            for index, segment in non_seizure_segments:
                save_windows(segment, non_seizure_dir, f"{prefix}_{index}")

In [None]:
def list_fif_files(directory):
    """Return the paths of all `.fif` files directly inside `directory`, sorted by name.

    Sorting makes positional picks (e.g. `[2]`) reproducible; `os.listdir` order is arbitrary.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".fif") and os.path.isfile(os.path.join(directory, name))
    )


# Spot-check one saved seizure window and one saved non-seizure window.
seizures = list_fif_files(f"{OUTPUT}/aaaaaiij/seizures")
raw = mne.io.read_raw_fif(seizures[2], preload=True)
raw.plot()

non_seizures = list_fif_files(f"{OUTPUT}/aaaaamnk/non_seizures")
raw = mne.io.read_raw_fif(non_seizures[2], preload=True)
raw.plot()