In [7]:
import os
import glob
import numpy as np
import mne
import xml.etree.ElementTree as ET
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, cohen_kappa_score
import yasa
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
from datetime import datetime
import logging
import sys
from contextlib import redirect_stdout, redirect_stderr
import io

# Configuration and Setup
# Fixed run metadata, stamped into the log, the plot titles and the output
# file/directory names.
CURRENT_DATE = "2025-02-03 12:01:57"
CURRENT_USER = "ZiadATAhmed"

# Create logs directory if it doesn't exist
os.makedirs('logs', exist_ok=True)

# Log to a timestamped file only. With no console handler attached, nothing is
# echoed to the terminal/notebook, so no /dev/null stream handler is needed.
# (The previous devnull handler also leaked an open file handle.)
# Note: logging.basicConfig is a no-op if the root logger already has handlers,
# e.g. when this cell is re-run in the same kernel.
log_filename = os.path.join('logs', f'sleep_analysis_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.FileHandler(log_filename)],
)
logger = logging.getLogger(__name__)

# Suppress warnings process-wide (this includes warnings from libraries).
warnings.filterwarnings("ignore")

# Sleep-stage display names and their integer codes: STAGE_NAMES[i] labels
# STAGE_LABELS[i]. REM is code 5, and code 4 is intentionally absent.
STAGE_NAMES = ["WAKE", "N1", "N2", "N3", "REM", "UNKNOWN"]
STAGE_LABELS = [0, 1, 2, 3, 5, 6]  # Include all possible stage values

def suppress_output(func):
    """Decorate ``func`` so that anything it prints to stdout/stderr is discarded.

    Output is captured into throw-away ``io.StringIO`` buffers for the duration
    of the call. Return values and exceptions pass through unchanged.
    ``functools.wraps`` preserves the wrapped function's name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        with redirect_stdout(io.StringIO()), redirect_stderr(io.StringIO()):
            return func(*args, **kwargs)

    return wrapper

class SleepAnalysis:
    """Evaluate YASA automatic sleep staging against NSRR expert annotations.

    Each recording processed by :meth:`process_dataset` adds one entry to
    ``metrics_dict``, keyed by the file stem. Its epoch labels are also appended
    to the pooled ``overall_true`` / ``overall_pred`` lists. Reference and
    predicted hypnograms share the integer codes of the module-level
    ``STAGE_LABELS`` (REM = 5), so they can be compared directly.
    """

    # Length of one scoring epoch in seconds. Used both to expand XML event
    # durations into epochs and to build the hypnogram time axis.
    EPOCH_SEC = 30
    # Reference codes to merge before scoring. Code 4 (Rechtschaffen & Kales
    # stage 4) has no slot in STAGE_LABELS, so it is folded into 3 (N3), as in
    # AASM scoring.
    _NSRR_REMAP = {4: 3}
    # yasa.hypno_str_to_int encodes REM as 4, whereas STAGE_LABELS and the
    # reference hypnogram use 5. Without this remap, predicted REM could never
    # match true REM.
    _YASA_TO_NSRR = {4: 5}

    def __init__(self, dataset_dir, keep_raw=True):
        """Create an empty analysis over ``dataset_dir``.

        Args:
            dataset_dir: Directory scanned for ``*.edf`` / ``*-nsrr.xml`` pairs.
            keep_raw: If True (the default, and the original behaviour), each
                metrics entry keeps the preprocessed MNE Raw object under
                ``"raw"``. If False, ``None`` is stored there instead, so the
                signals of every recording are not held in memory at once.
        """
        self.dataset_dir = dataset_dir
        self.keep_raw = keep_raw
        self.metrics_dict = {}
        self.overall_true = []
        self.overall_pred = []
        logger.info(f"Analysis initialized by {CURRENT_USER} at {CURRENT_DATE}")

    def parse_nsrr_xml(self, xml_file):
        """Parse an NSRR XML file into a per-epoch integer array of stage codes.

        Each ``ScoredEvent`` whose ``EventType`` is ``"Stages|Stages"`` becomes
        ``Duration // EPOCH_SEC`` epochs. Every epoch carries the integer that
        follows the last ``|`` of the event's ``EventConcept``. Code 4 is merged
        into 3 so that the result uses the ``STAGE_LABELS`` coding. Event start
        times are ignored: events are concatenated in document order.

        Raises:
            Any parsing error, re-raised after it has been logged.
        """
        try:
            root = ET.parse(xml_file).getroot()

            stages = []
            for event in root.findall(".//ScoredEvent"):
                if event.findtext("EventType") != "Stages|Stages":
                    continue
                code = int(event.findtext("EventConcept").split("|")[-1])
                code = self._NSRR_REMAP.get(code, code)
                n_epochs = int(float(event.findtext("Duration")) // self.EPOCH_SEC)
                stages.extend([code] * n_epochs)

            return np.array(stages, dtype=int)
        except Exception as e:
            logger.error(f"Error parsing XML file {xml_file}: {str(e)}")
            raise

    @suppress_output
    def preprocess_eeg(self, raw):
        """Resample ``raw`` to 100 Hz and band-pass filter it to 0.3-45 Hz.

        MNE applies both operations in place; the same object is returned.
        """
        raw.resample(100, verbose=False)
        raw.filter(0.3, 45, verbose=False)
        return raw

    def detect_artifacts(self, raw, threshold=100e-6):
        """Flag samples whose absolute amplitude exceeds ``threshold``.

        Args:
            raw: MNE Raw object; ``raw.get_data()`` is compared sample by sample.
            threshold: Absolute amplitude limit, in the units returned by
                ``get_data()``. Presumably volts, which makes the default
                100 uV; confirm against the data.

        Returns:
            Boolean array with the same shape as ``raw.get_data()``. The
            percentage of flagged samples is also logged.
        """
        data = raw.get_data()
        artifacts = np.abs(data) > threshold
        artifact_percentage = np.mean(artifacts) * 100
        logger.info(f"Artifact percentage: {artifact_percentage:.2f}%")
        return artifacts

    @suppress_output
    def evaluate_sleep_staging(self, edf_file, xml_file):
        """Stage one recording with YASA and score it against its reference hypnogram.

        Returns:
            A dict with these keys:
            ``accuracy`` (percent);
            ``confusion_matrix`` (rows and columns ordered as ``STAGE_LABELS``);
            ``classification_report`` (str);
            ``cohen_kappa``;
            ``true_stages`` / ``pred_stages`` (equal-length arrays);
            ``raw`` (the preprocessed recording, or None when ``keep_raw`` is
            False).

        Raises:
            Any loading, staging or scoring error, re-raised after logging.
        """
        try:
            raw = mne.io.read_raw_edf(edf_file, preload=True, verbose=False)
            raw = self.preprocess_eeg(raw)

            true_stages = self.parse_nsrr_xml(xml_file)

            # Assumes the EDF headers name the channels EEG, EOG(L) and EMG.
            sls = yasa.SleepStaging(raw, eeg_name="EEG", eog_name="EOG(L)", emg_name="EMG")
            pred_stages = sls.predict()
            pred_int = self._yasa_to_nsrr(yasa.hypno_str_to_int(pred_stages))

            true_stages, pred_int = self._align_lengths(true_stages, pred_int, edf_file)

            return {
                "accuracy": accuracy_score(true_stages, pred_int) * 100,
                "confusion_matrix": confusion_matrix(true_stages, pred_int,
                                                     labels=STAGE_LABELS),
                "classification_report": classification_report(true_stages, pred_int,
                                                               labels=STAGE_LABELS,
                                                               target_names=STAGE_NAMES,
                                                               zero_division=0),
                "cohen_kappa": cohen_kappa_score(true_stages, pred_int),
                "true_stages": true_stages,
                "pred_stages": pred_int,
                "raw": raw if self.keep_raw else None,
            }

        except Exception as e:
            logger.error(f"Error processing {edf_file}: {str(e)}")
            raise

    @classmethod
    def _yasa_to_nsrr(cls, codes):
        """Translate YASA integer stage codes into the STAGE_LABELS coding."""
        return np.array(
            [cls._YASA_TO_NSRR.get(code, code) for code in np.asarray(codes).tolist()],
            dtype=int,
        )

    @staticmethod
    def _align_lengths(true_stages, pred_stages, source):
        """Trim both hypnograms to their common number of epochs.

        Assumes both start at the beginning of the recording (the XML parser
        ignores event start times), so any surplus epochs are dropped from the
        end. A warning is logged when the lengths differ.

        Raises:
            ValueError: If there are no overlapping epochs to score.
        """
        n_true, n_pred = len(true_stages), len(pred_stages)
        n_common = min(n_true, n_pred)
        if n_true != n_pred:
            logger.warning(
                f"{source}: {n_true} annotated vs {n_pred} predicted epochs; "
                f"truncating to {n_common}"
            )
        if n_common == 0:
            raise ValueError(f"{source}: no epochs to compare")
        return true_stages[:n_common], pred_stages[:n_common]

    def process_dataset(self):
        """Evaluate every ``*.edf`` recording in ``dataset_dir`` that has annotations.

        For ``name.edf``, annotations are expected in ``name-nsrr.xml`` in the
        same directory. Recordings without annotations are skipped with a
        warning. Recordings that fail are logged and skipped, so one bad file
        does not abort the run. Files are processed in sorted order.
        """
        edf_files = sorted(glob.glob(os.path.join(self.dataset_dir, "*.edf")))

        for edf_file in tqdm(edf_files, desc="Processing files", file=sys.stdout):
            # splitext only touches the extension; str.replace(".edf", ...) would
            # also rewrite any ".edf" that happens to appear in a directory name.
            stem = os.path.splitext(edf_file)[0]
            xml_file = stem + "-nsrr.xml"
            if not os.path.exists(xml_file):
                logger.warning(f"No XML file found for {edf_file}")
                continue

            try:
                metrics = self.evaluate_sleep_staging(edf_file, xml_file)
            except Exception as e:
                logger.error(f"Error processing {edf_file}: {str(e)}")
                continue

            sample_name = os.path.basename(stem)
            self.metrics_dict[sample_name] = metrics
            self.overall_true.extend(metrics["true_stages"])
            self.overall_pred.extend(metrics["pred_stages"])

            logger.info(f"Successfully processed {sample_name}")

    def plot_results(self):
        """Save a combined summary figure plus one hypnogram per recording.

        Files are written to ``results_<date>/``, where the date comes from
        ``CURRENT_DATE``:
        ``combined_analysis.png`` holds per-recording accuracy/kappa bars, the
        pooled confusion matrix and the pooled classification report;
        ``hypnogram_<sample>.png`` is written for each recording.
        If no recording has been processed, an error is logged and nothing is
        plotted.
        """
        if not self.metrics_dict:
            logger.error("No data available for plotting")
            return

        output_dir = f"results_{CURRENT_DATE.split()[0]}"
        os.makedirs(output_dir, exist_ok=True)

        # One figure: two panels on the top row, the text report spanning the bottom row.
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(2, 2)

        # Panel 1: per-sample accuracy and kappa (kappa scaled x100 so both share the axis).
        ax1 = fig.add_subplot(gs[0, 0])
        sample_names = list(self.metrics_dict.keys())
        x_values = np.arange(len(sample_names))
        accuracies = [m["accuracy"] for m in self.metrics_dict.values()]
        kappas = [m["cohen_kappa"] * 100 for m in self.metrics_dict.values()]

        width = 0.35
        ax1.bar(x_values - width / 2, accuracies, width, label='Accuracy')
        ax1.bar(x_values + width / 2, kappas, width, label="Cohen's Kappa")
        ax1.set_ylabel("Percentage", fontsize=12)
        ax1.set_title("Per-Sample Performance Metrics", fontsize=14)
        ax1.set_xticks(x_values)
        ax1.set_xticklabels([s[-4:] for s in sample_names], rotation=45)
        ax1.legend()
        ax1.grid(True, linestyle='--', alpha=0.7)

        # Panel 2: pooled confusion matrix. Every per-sample matrix was built
        # with labels=STAGE_LABELS, so all of them have the same shape.
        ax2 = fig.add_subplot(gs[0, 1])
        agg_cm = sum(m["confusion_matrix"] for m in self.metrics_dict.values())
        sns.heatmap(agg_cm, annot=True, fmt="d", cmap="Blues",
                    xticklabels=STAGE_NAMES, yticklabels=STAGE_NAMES, ax=ax2)
        ax2.set_title("Aggregated Confusion Matrix", fontsize=14)
        ax2.set_xlabel("Predicted Stage", fontsize=12)
        ax2.set_ylabel("True Stage", fontsize=12)

        # Panel 3: pooled classification report rendered as monospace text.
        ax3 = fig.add_subplot(gs[1, :])
        ax3.axis('off')
        overall_report = classification_report(self.overall_true, self.overall_pred,
                                               labels=STAGE_LABELS,
                                               target_names=STAGE_NAMES,
                                               zero_division=0)
        ax3.text(0.1, 0.1, f"Overall Classification Report:\n\n{overall_report}",
                 fontfamily='monospace', fontsize=10)

        plt.suptitle(f"Sleep Staging Analysis Summary - {CURRENT_DATE}", fontsize=16)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "combined_analysis.png"),
                    bbox_inches='tight', dpi=300)
        plt.close(fig)

        # Individual hypnograms, reference vs prediction.
        for sample_name, metrics in self.metrics_dict.items():
            hyp_fig = plt.figure(figsize=(15, 5))
            time_axis = np.arange(len(metrics["true_stages"])) * self.EPOCH_SEC / 3600  # hours

            plt.plot(time_axis, metrics["true_stages"], label="True", alpha=0.7)
            plt.plot(time_axis, metrics["pred_stages"], label="Predicted", alpha=0.7)

            plt.title(f"Hypnogram Comparison - {sample_name}")
            plt.xlabel("Time (hours)")
            plt.ylabel("Sleep Stage")
            # Tick positions must be the stage codes themselves (REM = 5), not
            # 0..N-1; otherwise code 4 gets labelled "REM" and 5 "UNKNOWN".
            plt.yticks(STAGE_LABELS, STAGE_NAMES)
            plt.legend()
            plt.grid(True)
            plt.tight_layout()

            plt.savefig(os.path.join(output_dir, f"hypnogram_{sample_name}.png"))
            plt.close(hyp_fig)

        logger.info(f"Plots saved to {output_dir}")

    def generate_report(self):
        """Write a plain-text report with pooled and per-recording metrics.

        The report is written to ``sleep_analysis_report_<date>.txt`` in the
        working directory. If no recording has been processed, an error is
        logged and no file is written, because the pooled metrics cannot be
        computed.
        """
        if not self.metrics_dict:
            logger.error("No data available for report")
            return

        report_path = f"sleep_analysis_report_{CURRENT_DATE.split()[0]}.txt"

        overall_acc = accuracy_score(self.overall_true, self.overall_pred) * 100
        overall_kappa = cohen_kappa_score(self.overall_true, self.overall_pred)

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("Sleep Analysis Report\n")
            f.write(f"Generated by: {CURRENT_USER}\n")
            f.write(f"Date: {CURRENT_DATE}\n")
            f.write("=" * 50 + "\n\n")

            f.write("Overall Metrics:\n")
            f.write(f"Accuracy: {overall_acc:.2f}%\n")
            f.write(f"Cohen's Kappa: {overall_kappa:.2f}\n\n")

            f.write("Per-sample Metrics:\n")
            for sample_name, metrics in self.metrics_dict.items():
                f.write(f"\n{sample_name}:\n")
                f.write(f"Accuracy: {metrics['accuracy']:.2f}%\n")
                f.write(f"Cohen's Kappa: {metrics['cohen_kappa']:.2f}\n")
                f.write("\nClassification Report:\n")
                f.write(metrics['classification_report'])
                f.write("\n" + "-" * 50 + "\n")

        logger.info(f"Report generated: {report_path}")

def main(dataset_dir="dataset/"):
    """Run the full pipeline: evaluate all recordings, then write plots and the report.

    Args:
        dataset_dir: Directory containing ``*.edf`` recordings and their
            ``*-nsrr.xml`` annotation files. Defaults to ``"dataset/"``.

    Raises:
        Any exception from the pipeline, re-raised after it has been logged
        together with its traceback.
    """
    try:
        analysis = SleepAnalysis(dataset_dir)

        logger.info("Starting dataset processing")
        analysis.process_dataset()

        logger.info("Generating plots and reports")
        analysis.plot_results()
        analysis.generate_report()

        logger.info("Analysis completed successfully")

    except Exception as e:
        # logger.exception records the traceback in the log file, not just the message.
        logger.exception(f"Analysis failed: {str(e)}")
        raise


if __name__ == "__main__":
    main()

Processing files: 100%|██████████| 50/50 [09:03<00:00, 10.86s/it]
