In [None]:
import os
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

In [None]:
class Sample:
    """A single preprocessed sample (one .fif file) and its label.

    Attributes:
        session_id: session identifier of the sample.
        recording_id: recording identifier within that session.
        seizure: label flag; ``load_preprocessed_data`` sets it to True for
            files found under a patient's ``seizures/`` folder.
    """

    def __init__(self, session_id, recording_id, seizure):
        self.session_id, self.recording_id, self.seizure = (
            session_id,
            recording_id,
            seizure,
        )
    
class Patient:
    """All preprocessed samples that belong to one patient.

    Attributes:
        patient_id: patient identifier; ``load_preprocessed_data`` uses the
            patient's directory name, and it serves as the CV grouping key.
        samples: the patient's ``Sample`` objects.
    """

    # ``[Sample]`` (a list literal) is not a valid type hint; the string form is,
    # and it avoids evaluating ``Sample`` at class-definition time.
    def __init__(self, patient_id, samples: "list[Sample]"):
        self.patient_id = patient_id
        self.samples = samples

In [None]:
# Root of the preprocessed dataset: one sub-directory per patient, each with
# `non_seizures/` and `seizures/` folders of .fif files. Set the
# PREPROCESSED_PATH environment variable to point elsewhere instead of editing
# this machine-specific default.
PREPROCESSED_PATH = os.environ.get(
    "PREPROCESSED_PATH",
    "/Users/jannis/Git/tuh-eeg-seizure-detection/data/preprocessed",
)

In [None]:
def _load_samples(directory, seizure):
    """Build a ``Sample`` for every ``.fif`` file directly inside ``directory``.

    Filenames are parsed as ``<f0>_<f1>_<f2>[_...].fif``: ``f0_f1`` becomes the
    session id and ``f2`` the recording id. A missing ``directory`` yields no
    samples (a patient may have no seizure or no non-seizure data).

    Args:
        directory: folder to scan (not recursive).
        seizure: label assigned to every sample found.

    Returns:
        list of ``Sample`` in sorted filename order.

    Raises:
        ValueError: if a ``.fif`` filename has fewer than three ``_`` fields.
    """
    if not os.path.isdir(directory):
        return []
    samples = []
    # Sorted so results do not depend on arbitrary os.listdir ordering.
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not name.endswith(".fif") or not os.path.isfile(path):
            continue
        # Split the stem, not the full name, so a recording id in the last
        # field (e.g. "..._t000.fif") does not keep the ".fif" extension.
        parts = os.path.splitext(name)[0].split("_")
        if len(parts) < 3:
            raise ValueError(
                f"Unexpected sample filename {name!r} in {directory!r}: "
                "expected '<f0>_<f1>_<recording>[...].fif'"
            )
        samples.append(Sample(parts[0] + "_" + parts[1], parts[2], seizure))
    return samples


def load_preprocessed_data(preprocessed_path=None):
    """Load every patient's labelled samples from the preprocessed dataset.

    Expected layout: ``<root>/<patient_id>/non_seizures/*.fif`` and
    ``<root>/<patient_id>/seizures/*.fif``. Non-directory entries in ``<root>``
    are ignored, as are patients with no ``.fif`` samples at all.

    Args:
        preprocessed_path: dataset root; defaults to ``PREPROCESSED_PATH``
            (looked up at call time).

    Returns:
        list of ``Patient`` in sorted patient-id order; each patient's samples
        list non-seizure samples first, then seizure samples.

    Raises:
        ValueError: on a malformed ``.fif`` filename (see ``_load_samples``).
    """
    root = PREPROCESSED_PATH if preprocessed_path is None else preprocessed_path
    patients = []
    for patient_id in sorted(os.listdir(root)):
        patient_path = os.path.join(root, patient_id)
        if not os.path.isdir(patient_path):
            continue
        samples = _load_samples(
            os.path.join(patient_path, "non_seizures"), seizure=False
        ) + _load_samples(os.path.join(patient_path, "seizures"), seizure=True)
        if samples:
            patients.append(Patient(patient_id, samples))
    return patients

# Sanity check that the dataset root is populated. Summarise rather than echo
# a list of opaque Patient object reprs.
patients = load_preprocessed_data()
assert patients, f"No patients with .fif samples found under {PREPROCESSED_PATH}"
n_samples = sum(len(p.samples) for p in patients)
n_seizure_samples = sum(bool(s.seizure) for p in patients for s in p.samples)
print(f"{len(patients)} patients, {n_samples} samples ({n_seizure_samples} seizure)")

In [None]:
# Flatten patients into parallel arrays: a per-sample identifier, the seizure
# label, and the patient id used as the grouping key so that no patient ends up
# in both train and test.
x = []
y = []
groups = []

for patient in load_preprocessed_data():
    for sample in patient.samples:
        x.append(sample.session_id + "_" + sample.recording_id)
        y.append(sample.seizure)
        groups.append(patient.patient_id)

x = np.array(x)
y = np.array(y)
groups = np.array(groups)

# Seed both the fold assignment and the choice of test fold so the split is
# reproducible across kernel restarts (np.random.choice was unseeded).
SEED = 42
N_SPLITS = 5
cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
splits = list(cv.split(x, y, groups))

# Hold out one fold as the test set; the remaining folds form the training set.
rng = np.random.default_rng(SEED)
test_fold_idx = int(rng.integers(len(splits)))
train_idx, test_idx = splits[test_fold_idx]

x_train, y_train = x[train_idx], y[train_idx]
x_test, y_test = x[test_idx], y[test_idx]

# Grouping by patient must keep train and test patients disjoint.
assert set(groups[train_idx]).isdisjoint(groups[test_idx]), (
    "a patient appears in both train and test"
)


def report_class_balance(name, labels):
    """Print label counts for one split and return its positive/negative ratio.

    Indexing ``np.unique`` counts as ``counts[1] / counts[0]`` breaks when a
    split holds a single class. Here the ratio is ``inf`` when there are no
    negatives and ``nan`` for an empty split.
    """
    unique, counts = np.unique(labels, return_counts=True)
    n_pos = int(np.count_nonzero(labels))
    n_neg = labels.size - n_pos
    if n_neg:
        ratio = n_pos / n_neg
    else:
        ratio = float("inf") if n_pos else float("nan")
    print(f"{name} - Unique: {unique} Counts: {counts}")
    print(f"{name} ratio: {ratio}")
    return ratio


train_ratio = report_class_balance("Train", y_train)
test_ratio = report_class_balance("Test", y_test)