# **Import**

In [None]:
import os
import mne
import glob
import math
import random
import ntpath
import shutil
import dhedfreader
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from mne.io import read_raw_edf

mne.set_log_level('ERROR')

# **Preparation**

In this cell, we extract the EEG signals from the PhysioNet dataset and split them by subject into training, test, and validation sets.

In [None]:
def process_and_convert_edf(
    psg_files: list[str],
    annotation_files: list[str],
    output_dir: str,
    selected_channel: str,
    wake_edge_minutes: int
) -> None:
    """
    Processes EDF sleep data and converts it to NPZ format for classification.

    For every PSG/annotation pair the selected channel is cut into 30-second
    epochs, each epoch is labelled from the hypnogram annotations, epochs
    annotated as unknown/movement are dropped, and the recording is cropped to
    the sleep period plus `wake_edge_minutes` of wake on each side.

    Parameters:
    - psg_files (list[str]): List of PSG (Polysomnography) EDF file paths.
    - annotation_files (list[str]): Corresponding list of annotation EDF file paths
      (paired with `psg_files` by position).
    - output_dir (str): Directory where processed NPZ files will be saved.
    - selected_channel (str): EEG channel to extract data from.
    - wake_edge_minutes (int): Number of minutes of wake data to include at the edges.

    Returns:
    - None: Saves the processed EEG and labels as NPZ files in the output directory.

    Notes:
    - Any exception raised while handling a single file is caught, reported with
      `print`, and processing continues with the next file.
    """

    EPOCH_SEC_SIZE = 30  # Standard epoch size in seconds
    STAGE_LABELS = {
        "W": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4, "UNKNOWN": 5
    }
    ANNOTATION_TO_LABEL = {
        "Sleep stage W": 0,
        "Sleep stage 1": 1,
        "Sleep stage 2": 2,
        "Sleep stage 3": 3,
        "Sleep stage 4": 3,  # Merged with stage 3
        "Sleep stage R": 4,
        "Sleep stage ?": 5,
        "Movement time": 5
    }

    # Index-based iteration is kept on purpose: a missing annotation entry then
    # surfaces as a reported error for that file instead of being skipped silently.
    for i in tqdm(range(len(psg_files)), desc="Processing files"):
        try:
            # Load PSG signal
            raw = read_raw_edf(psg_files[i], preload=True, stim_channel=None)
            sampling_rate = raw.info['sfreq']
            # sfreq is a float; every sample count below uses this integer value
            samples_per_epoch = int(EPOCH_SEC_SIZE * sampling_rate)
            raw_data = raw.to_data_frame(scalings=100.0)[selected_channel].to_frame()
            raw_data.set_index(np.arange(len(raw_data)), inplace=True)

            # Read header information from PSG and annotation files
            with open(psg_files[i], 'r', encoding='iso-8859-1') as f:
                reader = dhedfreader.BaseEDFReader(f)
                reader.read_header()
                psg_header = reader.header

            with open(annotation_files[i], 'r', encoding='iso-8859-1') as f:
                reader = dhedfreader.BaseEDFReader(f)
                reader.read_header()
                annotation_header = reader.header
                _, _, annotations = list(zip(*reader.records()))

            # Convert timestamps to datetime objects
            psg_start_time = datetime.strptime(psg_header['date_time'], "%Y-%m-%d %H:%M:%S")
            annotation_start_time = datetime.strptime(annotation_header['date_time'], "%Y-%m-%d %H:%M:%S")

            if psg_start_time != annotation_start_time:
                raise ValueError("Mismatch in PSG and annotation start times.")

            # Process annotations (only the first record is used)
            labels, valid_indices, remove_indices = [], [], []
            for entry in annotations[0]:
                onset_sec, duration_sec, annotation_chars = entry
                annotation_text = "".join(annotation_chars)

                if annotation_text not in ANNOTATION_TO_LABEL:
                    continue

                label = ANNOTATION_TO_LABEL[annotation_text]
                # Sample indices covered by this annotation
                index_range = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=int)

                if label != STAGE_LABELS["UNKNOWN"]:
                    if duration_sec % EPOCH_SEC_SIZE != 0:
                        raise ValueError("Non-multiples of 30-second epoch found in annotation duration.")

                    num_epochs = int(duration_sec / EPOCH_SEC_SIZE)
                    labels.append(np.full(num_epochs, label, dtype=int))
                    valid_indices.append(index_range)
                else:
                    remove_indices.append(index_range)

            # Flatten labels and indices
            labels = np.hstack(labels)
            valid_indices = np.hstack(valid_indices)

            # Remove invalid indices
            if remove_indices:
                remove_indices = np.hstack(remove_indices)
                selected_indices = np.setdiff1d(np.arange(len(raw_data)), remove_indices)
            else:
                selected_indices = np.arange(len(raw_data))

            # Ensure valid indices align with available data
            selected_indices = np.intersect1d(selected_indices, valid_indices)

            # Trim excess labels if the annotations run past the end of the signal
            if len(valid_indices) > len(selected_indices):
                extra_indices = np.setdiff1d(valid_indices, selected_indices)
                if np.all(extra_indices > selected_indices[-1]):  # Ensure trimming occurs at the end
                    # Keep only whole epochs. Slicing from the front avoids the
                    # `[:-0]` pitfall (which would empty the arrays when the
                    # remainder is zero) and also handles annotations that
                    # overrun the recording by more than one epoch.
                    num_whole_epochs = len(selected_indices) // samples_per_epoch
                    selected_indices = selected_indices[:num_whole_epochs * samples_per_epoch]
                    labels = labels[:num_whole_epochs]

            # Extract raw EEG data
            processed_data = raw_data.values[selected_indices]

            if len(processed_data) % samples_per_epoch != 0:
                raise ValueError("Mismatch in EEG data length and epoch size.")

            num_epochs = len(processed_data) // samples_per_epoch

            # Reshape EEG data into epochs
            if num_epochs > 0:
                x_data = np.asarray(np.split(processed_data, num_epochs)).astype(np.float32)[:, :, 0]
                y_labels = labels.astype(np.int32)

                # Explicit check instead of `assert`, which is stripped under -O
                if len(x_data) != len(y_labels):
                    raise ValueError(
                        f"Epoch/label count mismatch: {len(x_data)} epochs vs {len(y_labels)} labels."
                    )

                # Extract wake edges
                non_wake_indices = np.where(y_labels != STAGE_LABELS["W"])[0]
                if len(non_wake_indices) == 0:
                    raise ValueError("No non-wake epochs found; cannot locate the sleep period.")

                # 60 s / 30 s = 2 epochs per minute of wake padding
                start_idx = max(non_wake_indices[0] - (wake_edge_minutes * 2), 0)
                end_idx = min(non_wake_indices[-1] + (wake_edge_minutes * 2), len(y_labels) - 1)
                final_indices = np.arange(start_idx, end_idx + 1)

                x_data = x_data[final_indices]
                y_labels = y_labels[final_indices]

                # Save processed data
                output_filename = ntpath.basename(psg_files[i]).replace("-PSG.edf", ".npz")
                save_dict = {
                    "x": x_data,
                    "y": y_labels,
                    "fs": sampling_rate,
                    "ch_label": selected_channel,
                    "header_raw": psg_header,
                    "header_annotation": annotation_header
                }
                np.savez(os.path.join(output_dir, output_filename), **save_dict)

        except Exception as e:
            # Best-effort batch conversion: report the failure and move on
            print(f"Error processing file {psg_files[i]}: {e}")


def convert_edf_to_npz(
    data_path: str,
    output_path: str,
    channel: str = 'EEG Fpz-Cz',
    split_ratios: tuple[float, float, float] = (0.8, 0.1, 0.1),
    w_edge: int = 30
) -> None:
    """
    Converts EDF files into NPZ format and splits them into training, test, and validation sets.

    The sorted PSG and Hypnogram file lists are paired by position and cut into
    three consecutive slices (train, then test, then validation). All inputs are
    validated before the existing output directory is deleted, so a bad
    `data_path` or `split_ratios` never destroys previously generated data.

    Parameters:
    - data_path (str): Path to the directory containing raw EDF files.
    - output_path (str): Path where the processed NPZ files will be stored.
      Any existing directory at this path is removed and recreated.
    - channel (str): EEG channel to extract data from. Default is 'EEG Fpz-Cz'.
    - split_ratios (tuple[float, float, float]): Ratios for train, test, and validation splits.
    - w_edge (int): Edge window parameter for processing wake stages.

    Returns:
    - None: Saves processed files in the specified output directory.

    Raises:
    - ValueError: If the ratios do not sum to 1.0, are not exactly three values,
      or the number of PSG and Hypnogram files differs.
    - FileNotFoundError: If no PSG or Hypnogram EDF files are found in `data_path`.
    """

    subdirs = ['train', 'test', 'validation']

    # Validate split ratios
    if not np.isclose(sum(split_ratios), 1.0):
        raise ValueError("Split ratios must sum to 1.0 (e.g., (0.8, 0.1, 0.1)).")

    # The validation split takes whatever remains, so its ratio is not needed
    train_ratio, test_ratio, _ = split_ratios

    # Fetch EDF files (before touching the output directory)
    psg_files = sorted(glob.glob(os.path.join(data_path, "*PSG.edf")))
    annotation_files = sorted(glob.glob(os.path.join(data_path, "*Hypnogram.edf")))

    if not psg_files or not annotation_files:
        raise FileNotFoundError("No matching PSG or Hypnogram EDF files found in the specified directory.")

    if len(psg_files) != len(annotation_files):
        raise ValueError("Mismatch in the number of PSG and Hypnogram files.")

    # Inputs are valid: recreate the output directory
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.makedirs(output_path)

    # Create subdirectories for splits
    for subdir in subdirs:
        os.makedirs(os.path.join(output_path, subdir))

    # Convert lists to numpy arrays for easy indexing
    psg_files = np.array(psg_files)
    annotation_files = np.array(annotation_files)

    total_files = len(psg_files)

    # Compute indices for splitting
    train_end = int(train_ratio * total_files)
    test_end = train_end + int(test_ratio * total_files)

    # Split data
    train_psg, train_ann = psg_files[:train_end], annotation_files[:train_end]
    test_psg, test_ann = psg_files[train_end:test_end], annotation_files[train_end:test_end]
    val_psg, val_ann = psg_files[test_end:], annotation_files[test_end:]

    # Process each split
    for subset, psg, ann in zip(subdirs, [train_psg, test_psg, val_psg], [train_ann, test_ann, val_ann]):
        print(f"Processing {subset} dataset ...")
        process_and_convert_edf(psg, ann, os.path.join(output_path, subset), channel, w_edge)


# Run the EDF -> NPZ conversion on the sleep-cassette recordings with an
# 80/10/10 train/test/validation split and 30 minutes of wake padding.
convert_edf_to_npz(
    "../data/original/sleep-cassette",
    "../data/split",
    channel="EEG Fpz-Cz",
    split_ratios=(0.8, 0.1, 0.1),
    w_edge=30,
)

# **Windowing**

In this cell, we validate the output of the previous cell and extract the individual 30-second signals for each subject.

In [None]:
def convert_to_windows(data_path: str, file_list: list[str], save_path: str) -> None:
    """
    Splits EEG recordings into smaller windows and saves them as separate NPZ files.

    Each row of a recording's `x` array is written, together with the matching
    entry of `y`, to its own file named with a zero-padded running counter
    (`0000000.npz`, `0000001.npz`, ...). The counter runs across all files in
    `file_list`.

    Parameters:
    - data_path (str): Path to the directory containing EEG NPZ files.
    - file_list (list[str]): List of NPZ filenames to process.
    - save_path (str): Directory where the processed files will be saved
      (created if it does not exist).

    Returns:
    - None: Saves the extracted EEG windows as individual NPZ files.

    Notes:
    - A failure on one input file is reported with `print` and the remaining
      files are still processed.
    """

    os.makedirs(save_path, exist_ok=True)

    counter = 0

    for file_name in tqdm(file_list, desc="Processing EEG files"):
        try:
            # Load EEG sample; the context manager closes the underlying file
            # handle so long runs do not accumulate open descriptors.
            sample_path = os.path.join(data_path, file_name)
            with np.load(sample_path) as sample:
                # Extract EEG signal and labels
                x_signals = sample['x']
                y_labels = sample['y']

            for idx, signal in enumerate(x_signals):
                output_file = os.path.join(save_path, f"{counter:07d}.npz")
                np.savez(output_file, x=signal, y=y_labels[idx])
                counter += 1

        except Exception as e:
            print(f"Error processing file {file_name}: {e}")


# Directories holding the per-recording NPZ files of each split
train_data_path = "../data/split/train/"
test_data_path = "../data/split/test/"
val_data_path = "../data/split/validation/"

# List the contents of every split up front, before any windowing starts
train_files, test_files, val_files = (
    os.listdir(split_dir)
    for split_dir in (train_data_path, test_data_path, val_data_path)
)

# Window each split in turn: (progress message, source dir, files, target dir)
for banner, source_dir, npz_names, window_dir in (
    ("Preparing dataset for the training set...", train_data_path, train_files, "../data/3000/train/"),
    ("Preparing dataset for the test set...", test_data_path, test_files, "../data/3000/test/"),
    ("Preparing dataset for the validation set...", val_data_path, val_files, "../data/3000/validation/"),
):
    print(banner)
    convert_to_windows(source_dir, npz_names, window_dir)

# **Convert to CSV**

In this cell, we convert the split signals into CSV files.

In [None]:
def convert_to_csv(data_path: str, save_path: str) -> None:
    """
    Converts EEG NPZ files into a CSV file.

    Every file in `data_path` contributes one CSV row: the flattened `x`
    signal followed by its `y` label as the last column. Rows are written in a
    randomly shuffled file order.

    Parameters:
    - data_path (str): Path to the directory containing EEG NPZ files.
    - save_path (str): File path where the final CSV will be saved.

    Returns:
    - None: Saves the EEG data as a CSV file.

    Raises:
    - FileNotFoundError: If `data_path` does not exist.

    Notes:
    - Files that cannot be loaded are reported with `print` and skipped.
    """

    # Ensure the data directory exists
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data directory '{data_path}' not found.")

    # Retrieve and shuffle file list
    file_list = os.listdir(data_path)
    random.shuffle(file_list)

    array_list = []

    for file_name in tqdm(file_list, desc="Processing EEG files"):
        try:
            # Load EEG sample; close the archive as soon as the row is built,
            # since this loop may visit a very large number of files.
            sample_path = os.path.join(data_path, file_name)
            with np.load(sample_path) as sample:
                # Concatenate EEG signal and label
                concatenated_data = np.append(sample['x'], sample['y'])
            array_list.append(concatenated_data)

        except Exception as e:
            print(f"Error processing file {file_name}: {e}")

    # Convert to DataFrame and save as CSV
    df = pd.DataFrame(np.array(array_list))
    df.to_csv(save_path, index=False)

    print(f"CSV file saved to {save_path}")


# Build one CSV per split: (progress message, window directory, CSV target)
for banner, window_dir, csv_target in (
    ("Preparing CSV file for the training set...", "../data/3000/train/", "../data/3000/train.csv"),
    ("Preparing CSV file for the test set...", "../data/3000/test/", "../data/3000/test.csv"),
    ("Preparing CSV file for the validation set...", "../data/3000/validation/", "../data/3000/validation.csv"),
):
    print(banner)
    convert_to_csv(window_dir, csv_target)