# In [6]:
import functools
import glob
import io
import logging
import os
import sys
import warnings  # Suppress warnings
import xml.etree.ElementTree as ET
from contextlib import redirect_stdout, redirect_stderr
from datetime import datetime

import matplotlib.pyplot as plt
import mne
import numpy as np
import seaborn as sns
import yasa
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, cohen_kappa_score
from tqdm import tqdm  # Progress bar

# -------------------------------
# Configuration and Setup
# -------------------------------
# One timestamp for the whole run, so the report header (CURRENT_DATE) and the
# log filename always agree instead of coming from two separate now() calls.
_RUN_STARTED_AT = datetime.now()
CURRENT_DATE = _RUN_STARTED_AT.strftime("%Y-%m-%d %H:%M:%S")
CURRENT_USER = "ZiadATAhmed"
DATASET_DIR = "dataset/"

# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)
log_filename = os.path.join('logs', f'sleep_analysis_{_RUN_STARTED_AT.strftime("%Y%m%d_%H%M%S")}.log')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.FileHandler(log_filename),
        logging.StreamHandler(sys.stdout)
    ],
    # Without force=True, re-running this cell/module is a silent no-op once the
    # root logger has handlers: the FileHandler opened above would be leaked
    # (leaving an empty log file) and messages would keep going to the old file.
    # force=True closes and replaces the previous root handlers instead.
    force=True,
)
logger = logging.getLogger(__name__)
# NOTE: this silences *all* Python warnings process-wide, not just MNE/YASA ones.
warnings.filterwarnings("ignore")

# Set publication-ready matplotlib settings
plt.rcParams.update({
    "font.size": 14,
    "figure.figsize": (18, 10),
    "axes.titlesize": 16,
    "axes.labelsize": 16,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "legend.fontsize": 14,
    "figure.titlesize": 18,
    "lines.linewidth": 2
})

# For five-class sleep staging, we use these stage names.
STAGE_NAMES = ["WAKE", "N1", "N2", "N3", "REM"]
# The five-class mapping corresponds to integer labels [0, 1, 2, 3, 4].

# -------------------------------
# Utility Functions
# -------------------------------
def suppress_output(func):
    """
    Decorator that discards anything *func* writes to stdout/stderr.

    Only Python-level writes through ``sys.stdout``/``sys.stderr`` are captured
    (that is what ``redirect_stdout``/``redirect_stderr`` swap out). The wrapped
    call's return value is passed through, and exceptions still propagate; the
    redirection is undone when the ``with`` block exits either way.
    ``functools.wraps`` preserves the wrapped function's name and docstring so
    introspection, ``help()`` and log messages still show the real function.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            return func(*args, **kwargs)
    return wrapper

def parse_nsrr_xml(xml_file, epoch_sec=30):
    """
    Parse an NSRR XML annotation file (expected to end with "-nsrr.xml") into
    an epoch-by-epoch hypnogram.

    Only "ScoredEvent" elements whose "EventType" text is exactly
    "Stages|Stages" are used. For each such event:
      - EventConcept: the text after the last "|" is the raw numeric stage.
        Events whose stage part is not an integer are skipped.
      - Duration: length in seconds; must be a whole multiple of ``epoch_sec``.
    Events with a missing or empty EventConcept/Duration element are skipped.

    NOTE: a skipped stage event contributes no epochs, so every later epoch
    shifts earlier in time; callers comparing against predictions should be
    aware of this.

    Parameters:
      xml_file: path to the XML file.
      epoch_sec: epoch length in seconds (default 30).

    Returns:
      A 1-D numpy array with one raw stage integer per epoch, in file order.

    Raises:
      ValueError: if the file cannot be read/parsed (original error chained as
        ``__cause__``), if a stage duration is not a multiple of ``epoch_sec``,
        if ``epoch_sec`` is not positive, or if no usable stage annotation exists.
    """
    if epoch_sec <= 0:
        raise ValueError("epoch_sec must be positive.")

    try:
        root = ET.parse(xml_file).getroot()
    except (ET.ParseError, OSError) as e:
        raise ValueError(f"Error parsing XML file {xml_file}: {e}") from e

    stages = []
    for event in root.findall(".//ScoredEvent"):
        event_type = event.find("EventType")
        if event_type is None or event_type.text != "Stages|Stages":
            continue
        duration_elem = event.find("Duration")
        concept_elem = event.find("EventConcept")
        if duration_elem is None or concept_elem is None:
            continue
        # Empty elements have .text == None; treat them like missing elements
        # instead of crashing on None.split / float(None).
        if not concept_elem.text or not duration_elem.text:
            continue
        try:
            stage = int(concept_elem.text.split("|")[-1])
        except ValueError:
            continue  # Skip non-numeric annotations
        duration = float(duration_elem.text)
        if duration % epoch_sec != 0:
            raise ValueError(f"Annotation duration is not a multiple of {epoch_sec} seconds.")
        # One entry per epoch covered by this event.
        stages.extend([stage] * int(duration // epoch_sec))

    if not stages:
        raise ValueError("No valid sleep stage annotations were found in the XML file.")

    return np.array(stages)

def map_stage(stage):
    """
    Translate a raw XML stage code into one of the five-class stage names.

    Codes 0, 1, 2 and 5 map to "WAKE", "N1", "N2" and "REM". Codes 3 and 4 are
    both folded into "N3" (N3 and N4 merged). Every other value yields "UNS".
    Matching uses ``==`` comparison (via tuple membership), so numerically
    equal values such as 3.0 are treated the same as 3.
    """
    stage_table = (
        ((0,), "WAKE"),
        ((1,), "N1"),
        ((2,), "N2"),
        ((3, 4), "N3"),  # two raw codes collapse into a single class
        ((5,), "REM"),
    )
    for codes, label in stage_table:
        if stage in codes:
            return label
    return "UNS"

# -------------------------------
# Core Analysis Class
# -------------------------------
class SleepAnalysis:
    """
    Evaluate YASA automatic sleep staging against NSRR expert annotations.

    Typical use: ``process_dataset()``, then ``plot_results()`` and
    ``generate_report()``.

    Attributes:
      dataset_dir: root directory searched recursively for ``*.edf`` files.
      metrics_dict: sample name (EDF basename without extension) -> the metrics
        dict returned by ``evaluate_sleep_staging``.
      overall_true / overall_pred: epoch labels concatenated over every
        successfully processed sample; used for pooled metrics.
    """

    # Integer stage labels in STAGE_NAMES order. Passing them explicitly to
    # sklearn keeps every confusion matrix / report five-class even when a
    # recording contains no epochs of some stage. Without labels=,
    # classification_report raises because the number of observed classes no
    # longer matches len(STAGE_NAMES).
    _STAGE_LABELS = [0, 1, 2, 3, 4]

    def __init__(self, dataset_dir):
        self.dataset_dir = dataset_dir
        self.metrics_dict = {}
        self.overall_true = []
        self.overall_pred = []
        logger.info(f"Analysis initialized by {CURRENT_USER} at {CURRENT_DATE}")

    @suppress_output
    def preprocess_eeg(self, raw):
        """
        Resample *raw* to 100 Hz and band-pass filter it to 0.3-45 Hz.

        The return values of ``resample``/``filter`` are ignored, so this relies
        on MNE modifying the Raw object in place; the same object is returned.
        stdout/stderr output produced during the call is discarded by
        ``suppress_output``.
        """
        raw.resample(100, npad="auto")
        raw.filter(0.3, 45, fir_design="firwin", verbose=False)
        return raw

    def evaluate_sleep_staging(self, edf_file, xml_file):
        """
        Evaluate sleep staging on one EDF file paired with its NSRR XML annotation.

        Steps:
          1. Read and preprocess the EDF file.
          2. Parse the XML to extract the true sleep stage annotations.
          3. Map raw annotations to stage labels using `map_stage`.
          4. Convert these to integer labels via YASA's Hypnogram.
          5. Perform automatic sleep staging using YASA (channels "EEG",
             "EOG(L)", "EMG" must exist in the EDF).
          6. Compute evaluation metrics.

        Returns:
          A dict with accuracy (percent), confusion matrix (5x5, label order
          [0..4]), classification report (text), Cohen's kappa, true labels,
          predicted labels, and the preprocessed Raw.

        Raises:
          ValueError: if the annotation and prediction epoch counts differ.
          Errors from reading, parsing or staging are logged and re-raised.
        """
        logger.info(f"Processing EDF file: {edf_file}")
        try:
            raw = mne.io.read_raw_edf(edf_file, preload=True, verbose=False)
        except Exception as e:
            logger.error(f"Failed to read EDF file {edf_file}: {e}")
            raise

        raw = self.preprocess_eeg(raw)

        # Parse XML and map annotations
        try:
            hypnogram = parse_nsrr_xml(xml_file)
        except Exception as e:
            logger.error(f"Error parsing XML {xml_file}: {e}")
            raise

        # Map raw stages to five-class labels
        hypno_labels = [map_stage(stage) for stage in hypnogram]
        # Convert string labels to integer labels using YASA's Hypnogram.
        # YASA will map: "WAKE"->0, "N1"->1, "N2"->2, "N3"->3, "REM"->4.
        try:
            true_hyp = yasa.Hypnogram(hypno_labels, freq="30s")
            true_int = true_hyp.as_int()
        except Exception as e:
            logger.error(f"Error converting hypnogram to integer labels: {e}")
            raise

        # Automatic sleep staging using YASA
        try:
            sls = yasa.SleepStaging(raw, eeg_name="EEG", eog_name="EOG(L)", emg_name="EMG")
            hypno_pred = sls.predict()
            pred_int = yasa.hypno_str_to_int(hypno_pred)
        except Exception as e:
            logger.error(f"Error during sleep staging prediction: {e}")
            raise

        # Ensure both arrays have the same length.
        if len(true_int) != len(pred_int):
            err_msg = "Mismatch in the number of epochs between true annotations and predictions."
            logger.error(err_msg)
            raise ValueError(err_msg)

        # Use consistent label order [0, 1, 2, 3, 4] for confusion matrix AND
        # reporting, so a recording with no epochs of some stage still works.
        labels = self._STAGE_LABELS
        accuracy = 100 * accuracy_score(true_int, pred_int)
        cm = confusion_matrix(true_int, pred_int, labels=labels)
        class_report = classification_report(
            true_int, pred_int, labels=labels, target_names=STAGE_NAMES, zero_division=0
        )
        kappa = cohen_kappa_score(true_int, pred_int)

        logger.info(f"File processed. Accuracy: {accuracy:.2f}%, Cohen's Kappa: {kappa:.2f}")
        metrics = {
            "accuracy": accuracy,
            "confusion_matrix": cm,
            "classification_report": class_report,
            "cohen_kappa": kappa,
            "true_int": true_int,
            "pred_int": pred_int,
            # NOTE(review): process_dataset stores this dict in metrics_dict, so
            # every preprocessed (preloaded) Raw stays in memory for the whole
            # run; memory grows with the number of recordings.
            "raw": raw
        }
        return metrics

    def process_dataset(self):
        """
        Process every EDF file under ``dataset_dir`` that has a sibling
        "<name>-nsrr.xml" annotation file.

        Per-sample metrics go into ``metrics_dict``, and labels are appended to
        ``overall_true`` / ``overall_pred``. Files without an XML partner are
        skipped with a warning. Files that fail are logged and skipped, so one
        bad recording does not abort the batch.
        """
        edf_files = glob.glob(os.path.join(self.dataset_dir, "**", "*.edf"), recursive=True)
        if not edf_files:
            logger.error("No EDF files found in the dataset directory.")
            return

        for edf_file in tqdm(edf_files, desc="Processing files"):
            base = os.path.splitext(edf_file)[0]
            xml_file = base + "-nsrr.xml"
            sample_name = os.path.basename(base)
            if not os.path.exists(xml_file):
                logger.warning(f"XML file not found for {edf_file}. Skipping.")
                continue
            try:
                metrics = self.evaluate_sleep_staging(edf_file, xml_file)
                self.metrics_dict[sample_name] = metrics
                self.overall_true.extend(metrics["true_int"])
                self.overall_pred.extend(metrics["pred_int"])
                logger.info(f"Successfully processed sample: {sample_name}")
            except Exception as e:
                # Deliberate best-effort: log and continue with the next file.
                logger.error(f"Error processing {edf_file}: {e}")
                continue

    def plot_results(self):
        """
        Generate and save figures to ``results_<YYYYMMDD>/``:
          - a grouped bar chart of per-sample accuracy and Cohen's kappa
            (kappa scaled x100 to share the percentage axis),
          - an aggregated confusion matrix heatmap,
          - a text panel with the pooled classification report,
          - one hypnogram comparison plot per sample.
        Does nothing (besides logging an error) when no sample was processed.
        """
        if not self.metrics_dict:
            logger.error("No metrics available for plotting.")
            return

        output_dir = f"results_{datetime.now().strftime('%Y%m%d')}"
        os.makedirs(output_dir, exist_ok=True)

        sample_names = list(self.metrics_dict.keys())
        x_values = np.arange(1, len(sample_names) + 1)
        accuracies = [self.metrics_dict[s]["accuracy"] for s in sample_names]
        kappas = [self.metrics_dict[s]["cohen_kappa"] * 100 for s in sample_names]

        # Aggregate confusion matrices into a NEW array. Accumulating with +=
        # into the first sample's matrix would silently overwrite that sample's
        # stored per-sample confusion matrix in metrics_dict.
        agg_cm = np.sum(
            [self.metrics_dict[s]["confusion_matrix"] for s in sample_names], axis=0
        )

        overall_class_report = classification_report(
            self.overall_true,
            self.overall_pred,
            labels=self._STAGE_LABELS,
            target_names=STAGE_NAMES,
            zero_division=0,
        )

        # Create combined figure with three panels
        fig = plt.figure(constrained_layout=True, figsize=(18, 12))
        gs = fig.add_gridspec(2, 2)

        # Panel 1: Grouped Bar Chart (per-sample metrics)
        ax1 = fig.add_subplot(gs[0, 0])
        width = 0.25
        ax1.bar(x_values - width/2, accuracies, width, label='Accuracy')
        ax1.bar(x_values + width/2, kappas, width, label="Cohen's Kappa")
        ax1.set_ylabel("Percentage", fontsize=14)
        ax1.set_title("Per-Sample Metrics", fontsize=16)
        ax1.set_xticks(x_values)
        # Only the last 4 characters of each sample name fit as tick labels.
        ax1.set_xticklabels([s[-4:] for s in sample_names], rotation=45, ha='center', fontsize=10)
        ax1.legend(fontsize=12, loc='upper center', bbox_to_anchor=(0.5, -0.10), ncol=2)
        ax1.grid(True, linestyle="--", alpha=0.5)

        # Panel 2: Aggregated Confusion Matrix Heatmap
        ax2 = fig.add_subplot(gs[0, 1])
        sns.heatmap(agg_cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=STAGE_NAMES, yticklabels=STAGE_NAMES, ax=ax2)
        ax2.set_title("Aggregated Confusion Matrix", fontsize=16)
        ax2.set_xlabel("Predicted Stage", fontsize=14)
        ax2.set_ylabel("True Stage", fontsize=14)

        # Panel 3: Overall Classification Report (text)
        ax3 = fig.add_subplot(gs[1, :])
        ax3.axis("off")
        ax3.set_title("Overall Classification Report", fontsize=16, pad=30, loc='left')
        ax3.text(0, 0.5, overall_class_report, fontsize=14, family="monospace")
        fig.suptitle("Sleep Staging Evaluation Summary", fontsize=20)
        combined_fig_path = os.path.join(output_dir, "combined_analysis.png")
        fig.savefig(combined_fig_path, bbox_inches='tight', dpi=300)
        plt.close(fig)
        logger.info(f"Combined analysis figure saved to {combined_fig_path}")

        # Save individual hypnogram comparisons
        for sample_name, metrics in self.metrics_dict.items():
            plt.figure(figsize=(15, 5))
            time_axis = np.arange(len(metrics["true_int"])) * 30 / 3600  # Convert 30-s epochs to hours
            plt.plot(time_axis, metrics["true_int"], label="True", alpha=0.7)
            plt.plot(time_axis, metrics["pred_int"], label="Predicted", alpha=0.7, linestyle="--")
            plt.title(f"Hypnogram Comparison - {sample_name}")
            plt.xlabel("Time (hours)")
            plt.ylabel("Sleep Stage")
            plt.yticks(ticks=[0, 1, 2, 3, 4], labels=STAGE_NAMES)
            plt.legend()
            plt.grid(True)
            hypnogram_path = os.path.join(output_dir, f"hypnogram_{sample_name}.png")
            plt.tight_layout()
            plt.savefig(hypnogram_path, dpi=300)
            plt.close()
            logger.info(f"Hypnogram for sample {sample_name} saved to {hypnogram_path}")

    def generate_report(self):
        """
        Write a text report (``sleep_analysis_report_<YYYYMMDD>.txt`` in the
        current working directory) with pooled accuracy/kappa and each sample's
        metrics and classification report.

        Does nothing (besides logging an error) when no sample was processed,
        since pooled metrics over zero epochs are undefined.
        """
        if not self.metrics_dict:
            logger.error("No metrics available for report generation.")
            return

        report_path = f"sleep_analysis_report_{datetime.now().strftime('%Y%m%d')}.txt"
        overall_acc = accuracy_score(self.overall_true, self.overall_pred) * 100
        overall_kappa = cohen_kappa_score(self.overall_true, self.overall_pred)
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("Sleep Analysis Report\n")
            f.write(f"Generated by: {CURRENT_USER}\n")
            f.write(f"Date: {CURRENT_DATE}\n")
            f.write("=" * 50 + "\n\n")
            f.write("Overall Metrics:\n")
            f.write(f"Accuracy: {overall_acc:.2f}%\n")
            f.write(f"Cohen's Kappa: {overall_kappa:.2f}\n\n")
            f.write("Per-sample Metrics:\n")
            for sample_name, metrics in self.metrics_dict.items():
                f.write(f"\nSample: {sample_name}\n")
                f.write(f"Accuracy: {metrics['accuracy']:.2f}%\n")
                f.write(f"Cohen's Kappa: {metrics['cohen_kappa']:.2f}\n")
                f.write("Classification Report:\n")
                f.write(metrics["classification_report"])
                f.write("\n" + "-" * 50 + "\n")
        logger.info(f"Report generated: {report_path}")

# -------------------------------
# Main Execution
# -------------------------------
def main():
    """
    Run the full pipeline: process the dataset, then write plots and the report.

    Any exception is logged with its traceback (so it also reaches the log
    file) and then re-raised so the caller or notebook still sees the failure.
    """
    try:
        analysis = SleepAnalysis(DATASET_DIR)
        logger.info("Starting dataset processing")
        analysis.process_dataset()
        logger.info("Dataset processing completed")
        logger.info("Generating plots and report")
        analysis.plot_results()
        analysis.generate_report()
        logger.info("Analysis completed successfully")
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"Analysis failed: {e}")
        raise

if __name__ == "__main__":
    main()


# Output: Processing files: 100%|██████████| 50/50 [09:08<00:00, 10.98s/it]
