In [2]:
import xml.etree.ElementTree as ET
import os
from collections import Counter, defaultdict
import pandas as pd

# Processing PSG XML Files for Arousal Analysis

This notebook processes PSG (Polysomnography) XML files to extract arousal events and calculate arousal burden metrics across different sleep stages.

## Overview

The analysis pipeline consists of four main components:

### 1. Data Extraction
- **XML Parsing**: Read PSG XML files using ElementTree parser
- **Event Extraction**: Extract arousal events with timestamps, durations, and event labels
- **Sleep Stage Mapping**: Extract sleep stage sequences for each epoch (default 30-second epochs)
- **Stage Code Mapping**: Map sleep stage names (N1, N2, N3, REM, Wake, Unscorable) to numeric codes
- **Error Handling**: Handle missing or malformed data gracefully with error logging

### 2. Data Validation and Filtering
- **Duration Filtering**: Remove arousal events with duration ≤ 0 seconds
- **Sleep-Time Focus**: Filter arousals to only include those occurring during sleep stages (N1, N2, N3, REM)
- **Wake Exclusion**: Exclude arousals during Wake (0) and Unscorable (9) epochs
- **Data Validation**: Check for missing required fields (event name, start time, duration)
- **Overlap Handling**: Merge overlapping arousal events into singular events to avoid double-counting

### 3. Arousal Burden Calculation
- **Proportional Allocation**: Use proportional allocation to determine arousal time within each sleep stage
- **Multi-Epoch Spanning**: For arousals spanning multiple epochs, calculate overlap with each stage epoch
- **Interval Merging**: Merge overlapping arousal intervals to prevent double-counting
- **Sleep Time Calculation**: Calculate total sleep time per stage (epoch_length × number_of_epochs_in_stage)
- **Burden Formula**: Compute arousal burden as: `(total_arousal_duration_in_stage / total_stage_duration) × 100`
- **Edge Case Handling**: Handle cases where sleep stages have zero duration (return None for burden)

### 4. Statistical Analysis and Export
- **Batch Processing**: Process all XML files in the `/psg/` folder
- **Summary Generation**: Generate per-file summary with total sleep time, total arousal duration, and stage-specific burdens
- **CSV Export**: Export results to CSV file for further analysis
- **Statistical Overview**: Display sample results and descriptive statistics across all processed files

## Key Metrics Calculated
- **Total Arousal Duration**: Sum of all arousal event durations during sleep (minutes)
- **Total Sleep Time**: Total time spent in sleep stages N1, N2, N3, and REM (minutes)
- **Overall Arousal Burden**: Percentage of sleep time occupied by arousal events
- **Stage-Specific Arousal Burden**: Arousal burden calculated separately for N1, N2, N3, and REM stages

In [10]:
def extract_arousal_events(xml_file):
    """
    Extract arousal events and sleep stages from PSG XML file.
    
    Returns:
        events (list): List of event dictionaries with event_label, start_time_sec, duration_sec
        sleep_stages (list): List of sleep stage codes for each epoch
        epoch_length (int): Length of each epoch in seconds
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        
        # Extract epoch length
        epoch_length_elem = root.find('EpochLength')
        epoch_length = int(epoch_length_elem.text) if epoch_length_elem is not None else 30
        
        # Extract events
        events = []
        scored_events = root.find('ScoredEvents')
        if scored_events is not None:
            for event in scored_events.findall('ScoredEvent'):
                name_elem = event.find('Name')
                start_elem = event.find('Start')
                duration_elem = event.find('Duration')
                
                if name_elem is not None and start_elem is not None and duration_elem is not None:
                    duration = float(duration_elem.text)
                    # Convert negative durations to positive
                    if duration < 0:
                        duration = abs(duration)
                    # Skip events with zero duration
                    if duration > 0:
                        events.append({
                            'event_label': name_elem.text,
                            'start_time_sec': float(start_elem.text),
                            'duration_sec': duration
                        })
        
        # Extract sleep stages
        sleep_stages = []
        sleep_stages_elem = root.find('SleepStages')
        if sleep_stages_elem is not None:
            for stage in sleep_stages_elem.findall('SleepStage'):
                stage_text = stage.text
                # Map sleep stage names to codes
                if stage_text in ['1', 'N1']:
                    sleep_stages.append(1)
                elif stage_text in ['2', 'N2']:
                    sleep_stages.append(2)
                elif stage_text in ['3', 'N3']:
                    sleep_stages.append(3)
                elif stage_text in ['5', 'REM', 'R']:
                    sleep_stages.append(5)
                elif stage_text in ['0', 'W', 'Wake']:
                    sleep_stages.append(0)
                elif stage_text in ['9', 'Unscorable']:
                    sleep_stages.append(9)
                else:
                    sleep_stages.append(int(stage_text) if stage_text.isdigit() else 0)
        
        return events, sleep_stages, epoch_length
        
    except Exception as e:
        print(f"Error processing {xml_file}: {e}")
        return [], [], 30

def filter_arousals_during_sleep(events_df, sleep_stages, epoch_length):
    """
    Filter arousal events to only include those that occur during sleep stages.
    
    Args:
        events_df (DataFrame): DataFrame with arousal events
        sleep_stages (list): List of sleep stage codes for each epoch
        epoch_length (int): Length of each epoch in seconds
    
    Returns:
        DataFrame: Filtered arousal events that occur during sleep
    """
    if events_df.empty or not sleep_stages:
        return events_df
    
    # Sleep stage codes: 1 = N1, 2 = N2, 3 = N3, 5 = REM (exclude 0 = Wake, 9 = Unscorable)
    sleep_stage_codes = [1, 2, 3, 5]
    
    # Filter arousals that occur during sleep epochs
    sleep_arousals = []
    
    for _, event in events_df.iterrows():
        start_time = event['start_time_sec']
        duration = event['duration_sec']
        end_time = start_time + duration
        
        # Find which epochs this event spans
        start_epoch = int(start_time // epoch_length)
        end_epoch = int(end_time // epoch_length)
        
        # Check if any part of the arousal occurs during sleep
        occurs_during_sleep = False
        for epoch_idx in range(start_epoch, min(end_epoch + 1, len(sleep_stages))):
            if epoch_idx < len(sleep_stages) and sleep_stages[epoch_idx] in sleep_stage_codes:
                occurs_during_sleep = True
                break
        
        if occurs_during_sleep:
            sleep_arousals.append(event)
        else:
            print(f"Skipping arousal at {event['start_time_sec']}s (duration {event['duration_sec']}s) - occurs during wakefulness.")
    
    return pd.DataFrame(sleep_arousals) if sleep_arousals else pd.DataFrame(columns=events_df.columns)

def calculate_total_arousal_time(arousals_df):
    """
    Calculate total arousal time by merging overlapping arousals into singular arousals.
    For any arousals that overlap, merge them into a single arousal spanning from the earliest start to the latest end.
    
    Args:
        arousals_df (DataFrame): DataFrame with arousal events containing start_time_sec and duration_sec
    
    Returns:
        float: Total arousal time in seconds (with overlapping arousals merged)
    """
    if arousals_df.empty:
        return 0.0
    
    # Create list of time intervals (start, end) for each arousal
    intervals = []
    for _, arousal in arousals_df.iterrows():
        start = arousal['start_time_sec']
        end = start + arousal['duration_sec']
        intervals.append((start, end))
    
    # Sort intervals by start time
    intervals.sort(key=lambda x: x[0])
    
    # Merge overlapping intervals
    merged_intervals = []
    for start, end in intervals:
        if not merged_intervals or merged_intervals[-1][1] < start:
            # No overlap with previous interval - add as new interval
            merged_intervals.append((start, end))
        else:
            # Overlap detected - merge by extending the previous interval to the latest end time
            merged_intervals[-1] = (merged_intervals[-1][0], max(merged_intervals[-1][1], end))
    
    # Calculate total time from merged intervals
    total_time = sum(end - start for start, end in merged_intervals)
    
    return total_time

def calculate_stage_specific_arousal_time(arousals_df, sleep_stages, epoch_length, target_stage):
    """
    Calculate arousal time that occurs specifically within a given sleep stage,
    proportionally allocating arousal duration based on overlap with stage epochs.
    
    Args:
        arousals_df (DataFrame): DataFrame with arousal events
        sleep_stages (list): List of sleep stage codes for each epoch
        epoch_length (int): Length of each epoch in seconds
        target_stage (int): Sleep stage code to calculate arousal time for
    
    Returns:
        float: Total arousal time in seconds that occurs within the target stage
    """
    if arousals_df.empty or not sleep_stages:
        return 0.0
    
    # Create list of time intervals for arousals that overlap with target stage epochs
    stage_arousal_intervals = []
    
    for _, arousal in arousals_df.iterrows():
        arousal_start = arousal['start_time_sec']
        arousal_end = arousal_start + arousal['duration_sec']
        
        # Find which epochs this arousal spans
        start_epoch = int(arousal_start // epoch_length)
        end_epoch = int(arousal_end // epoch_length)
        
        # For each epoch the arousal spans, calculate overlap with target stage
        for epoch_idx in range(start_epoch, min(end_epoch + 1, len(sleep_stages))):
            if epoch_idx < len(sleep_stages) and sleep_stages[epoch_idx] == target_stage:
                # Calculate the time boundaries of this epoch
                epoch_start = epoch_idx * epoch_length
                epoch_end = (epoch_idx + 1) * epoch_length
                
                # Calculate the overlap between arousal and this epoch
                overlap_start = max(arousal_start, epoch_start)
                overlap_end = min(arousal_end, epoch_end)
                
                if overlap_start < overlap_end:
                    stage_arousal_intervals.append((overlap_start, overlap_end))
    
    if not stage_arousal_intervals:
        return 0.0
    
    # Sort intervals by start time
    stage_arousal_intervals.sort(key=lambda x: x[0])
    
    # Merge overlapping intervals to avoid double-counting
    merged_intervals = []
    for start, end in stage_arousal_intervals:
        if not merged_intervals or merged_intervals[-1][1] < start:
            merged_intervals.append((start, end))
        else:
            merged_intervals[-1] = (merged_intervals[-1][0], max(merged_intervals[-1][1], end))
    
    # Calculate total time from merged intervals
    total_time = sum(end - start for start, end in merged_intervals)
    
    return total_time

# Pipeline: Process all PSG XML files in /psg/ folder and compute arousal metrics for each, outputting to CSV

psg_folder = "psg"
files = ["/wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs/EDF/bi/bi0002.edf.XML"]
results = []

for filename in files:
    if not filename.lower().endswith(".xml"):
        continue
    xml_file = os.path.join(psg_folder, filename)
    print(f"\nProcessing: {filename}")

    # Use the provided extract_arousal_events function to get events, sleep stages, and epoch length
    events, sleep_stages, epoch_length = extract_arousal_events(xml_file)

    # Convert the list of event dicts to a DataFrame (if events were found)
    if events:
        df_events = pd.DataFrame(events)
    else:
        print("  No events extracted.")
        continue

    # Optionally, convert sleep stages to a DataFrame (one row per epoch)
    if sleep_stages:
        df_sleep_stages = pd.DataFrame({
            "epoch": range(1, len(sleep_stages) + 1),
            "sleep_stage": sleep_stages
        })
    else:
        print("  No sleep stages extracted.")
        continue

    # Check if df_events is empty or doesn't have the expected columns
    if df_events.empty:
        print("  No events found in DataFrame.")
        continue
    
    # Check if 'event_label' column exists
    if 'event_label' not in df_events.columns:
        print(f"  Warning: 'event_label' column not found. Available columns: {list(df_events.columns)}")
        continue

    # Filter for arousal events - using case-insensitive search for "arousal"
    df_arousals = df_events[df_events['event_label'].str.contains('arousal', case=False, na=False)]

    # Filter arousals to only include those during sleep stages
    df_arousals_sleep = filter_arousals_during_sleep(df_arousals, sleep_stages, epoch_length)

    # Calculate total arousal duration by merging overlapping arousals
    total_arousal_duration_sec = calculate_total_arousal_time(df_arousals_sleep)

    # Sleep stage codes: 1 = N1, 2 = N2, 3 = N3, 5 = REM
    sleep_stage_codes = [1, 2, 3, 5]
    if sleep_stages:
        total_sleep_epochs = sum(1 for s in sleep_stages if s in sleep_stage_codes)
        total_sleep_time_sec = total_sleep_epochs * epoch_length
    else:
        total_sleep_time_sec = 0

    # Calculate arousal burden: (total arousal duration [min] / total sleep time [min]) * 100
    if sleep_stages and total_sleep_time_sec > 0:
        total_arousal_duration_min = total_arousal_duration_sec / 60
        total_sleep_time_min = total_sleep_time_sec / 60
        arousal_burden_percent = (total_arousal_duration_min / total_sleep_time_min) * 100
    else:
        arousal_burden_percent = None

    # Count the number of epochs for each sleep stage (e.g., N1, N2, N3, REM, Wake)
    if sleep_stages:
        stage_labels = {
            0: "Wake",
            1: "N1",
            2: "N2",
            3: "N3",
            5: "REM"
        }
        stage_counts = Counter(sleep_stages)
    else:
        stage_counts = {}

    # Calculate arousal burden broken down by sleep stage using proportional allocation
    arousal_burden_by_stage = {}
    if sleep_stages and not df_arousals_sleep.empty:
        stage_labels_sleep = {
            1: "N1",
            2: "N2",
            3: "N3",
            5: "REM"
        }
        stage_epoch_counts = Counter(sleep_stages)
        stage_total_sleep_sec = {code: stage_epoch_counts.get(code, 0) * epoch_length for code in sleep_stage_codes}
        
        # Calculate arousal duration per stage using proportional allocation
        for code in sleep_stage_codes:
            # Calculate arousal time that occurs specifically within this stage
            stage_arousal_sec = calculate_stage_specific_arousal_time(
                df_arousals_sleep, sleep_stages, epoch_length, code
            )
            
            total_sleep_sec = stage_total_sleep_sec.get(code, 0)
            if total_sleep_sec > 0:
                arousal_burden = (stage_arousal_sec / total_sleep_sec) * 100
            else:
                arousal_burden = None
            arousal_burden_by_stage[stage_labels_sleep.get(code, f"Stage {code}")] = arousal_burden

    # Collect results for this file
    print(total_arousal_duration_sec)
    print(total_arousal_duration_sec/60)
    results.append({
        "file_name": filename,
        "total_arousal_duration_min": total_arousal_duration_sec / 60 if total_arousal_duration_sec is not None else None,
        "total_sleep_time_min": total_sleep_time_sec / 60 if total_sleep_time_sec is not None else None,
        "arousal_burden_percent": arousal_burden_percent,
        "N1_arousal_burden_percent": arousal_burden_by_stage.get("N1"),
        "N2_arousal_burden_percent": arousal_burden_by_stage.get("N2"),
        "N3_arousal_burden_percent": arousal_burden_by_stage.get("N3"),
        "REM_arousal_burden_percent": arousal_burden_by_stage.get("REM"),
    })

# Output results to CSV
df_summary = pd.DataFrame(results)
#csv_output_path = "arousal_burden_by_file.csv"
#df_summary.to_csv(csv_output_path, index=False)
#print(f"\nSummary written to {csv_output_path}")
df_summary


Processing: /wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs/EDF/bi/bi0002.edf.XML
Skipping arousal at 13982.2s (duration 25.7s) - occurs during wakefulness.
Skipping arousal at 15548.1s (duration 15.9s) - occurs during wakefulness.
Skipping arousal at 17010.2s (duration 16.4s) - occurs during wakefulness.
Skipping arousal at 18211.5s (duration 27.3s) - occurs during wakefulness.
Skipping arousal at 22982.6s (duration 16.2s) - occurs during wakefulness.
Skipping arousal at 24693.4s (duration 22.9s) - occurs during wakefulness.
Skipping arousal at 25687.0s (duration 22.9s) - occurs during wakefulness.
Skipping arousal at 25890.4s (duration 14.9s) - occurs during wakefulness.
Skipping arousal at 26130.5s (duration 29.2s) - occurs during wakefulness.
Skipping arousal at 26889.9s (duration 19.9s) - occurs during wakefulness.
Skipping arousal at 27907.2s (duration 22.8s) - occurs during wakefulness.
Skipping arousal at 28023.9s (duration 21.6s) - occurs during wakefulness.
Skipping arousal

Unnamed: 0,file_name,total_arousal_duration_min,total_sleep_time_min,arousal_burden_percent,N1_arousal_burden_percent,N2_arousal_burden_percent,N3_arousal_burden_percent,REM_arousal_burden_percent
0,/wynton/group/andrews/data/MrOS/mros-sof_mjhe/...,13.035,298.0,4.374161,12.162393,3.892138,0.0,


# Sanity Checks

In [51]:
# Load the arousal burden by file CSV
burden_df = pd.read_csv('arousal_burden_by_file.csv')

# Display descriptive statistics for the arousal burden by file data
burden_df.describe()

Unnamed: 0,total_arousal_duration_min,total_sleep_time_min,arousal_burden_percent,N1_arousal_burden_percent,N2_arousal_burden_percent,N3_arousal_burden_percent,REM_arousal_burden_percent
count,2902.0,2902.0,2901.0,2869.0,2871.0,2773.0,2858.0
mean,20.669332,353.125775,5.983963,9.808612,6.078677,1.044127,4.450694
std,10.670067,70.657373,3.203529,4.258895,3.325762,1.689973,3.261034
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,13.49375,314.625,3.897832,6.911111,3.721544,0.0,2.406585
50%,18.904167,358.0,5.385253,9.619048,5.393802,0.560976,3.764273
75%,25.59875,399.0,7.418029,12.509434,7.779827,1.290598,5.744025
max,116.263333,616.5,37.408451,29.770115,22.941441,21.833333,60.579832


In [50]:
# Count how many files have arousal_burden_percent > 8.5
high_arousal_burden_count = (burden_df['arousal_burden_percent'] > 8.5).sum()
print(f"Number of files with arousal_burden_percent > 8.5: {high_arousal_burden_count}")

# Also show the percentage of total files
total_files = len(burden_df)
percentage = (high_arousal_burden_count / total_files) * 100
print(f"Percentage of files with arousal_burden_percent > 8.5: {percentage:.2f}%")


Number of files with arousal_burden_percent > 8.5: 486
Percentage of files with arousal_burden_percent > 8.5: 16.75%


In [54]:
# Filter out files with no arousal burden (0% or NaN)
filtered_burden_df = burden_df[
    (burden_df['arousal_burden_percent'] > 0) & 
    (burden_df['arousal_burden_percent'].notna())
]

print(f"Original number of files: {len(burden_df)}")
print(f"Files with arousal burden present: {len(filtered_burden_df)}")
print(f"Files excluded (0% or NaN arousal burden): {len(burden_df) - len(filtered_burden_df)}")

# Display descriptive statistics for files with arousal burden only
print(f"\nDescriptive statistics for files with arousal burden present:")
filtered_burden_df.describe()


Original number of files: 2902
Files with arousal burden present: 2871
Files excluded (0% or NaN arousal burden): 31

Descriptive statistics for files with arousal burden present:


Unnamed: 0,total_arousal_duration_min,total_sleep_time_min,arousal_burden_percent,N1_arousal_burden_percent,N2_arousal_burden_percent,N3_arousal_burden_percent,REM_arousal_burden_percent
count,2871.0,2871.0,2871.0,2869.0,2871.0,2773.0,2858.0
mean,20.892512,353.282828,6.046491,9.808612,6.078677,1.044127,4.450694
std,10.507885,70.009544,3.160959,4.258895,3.325762,1.689973,3.261034
min,0.105,51.0,0.029536,0.0,0.0,0.0,0.0
25%,13.768333,315.0,3.947046,6.911111,3.721544,0.0,2.406585
50%,19.045,358.0,5.419455,9.619048,5.393802,0.560976,3.764273
75%,25.6875,398.5,7.433269,12.509434,7.779827,1.290598,5.744025
max,116.263333,616.5,37.408451,29.770115,22.941441,21.833333,60.579832


In [55]:
# Count how many files with arousal burden present have arousal_burden_percent > 8.5
high_arousal_filtered_count = (filtered_burden_df['arousal_burden_percent'] > 8.5).sum()
print(f"Number of files with arousal burden present and arousal_burden_percent > 8.5: {high_arousal_filtered_count}")

# Show the percentage of files with arousal burden present
filtered_total_files = len(filtered_burden_df)
filtered_percentage = (high_arousal_filtered_count / filtered_total_files) * 100
print(f"Percentage of files with arousal burden present that have arousal_burden_percent > 8.5: {filtered_percentage:.2f}%")


Number of files with arousal burden present and arousal_burden_percent > 8.5: 486
Percentage of files with arousal burden present that have arousal_burden_percent > 8.5: 16.93%


## File Validation and Detailed Examination

Now let's examine individual files in detail to validate our data extraction process. This section provides functions to:

1. **Deep dive into specific XML files** - Extract and display detailed information about events, sleep stages, and calculated metrics
2. **Validate arousal detection** - Show sample arousal events and verify they're being detected correctly
3. **Check sleep stage distribution** - Ensure sleep stages are being parsed properly
4. **Verify calculations** - Confirm that arousal burden calculations match expectations

This validation step is crucial to ensure our batch processing is working correctly across all files.


In [None]:
# Function to examine individual files in detail for validation
def examine_file_details(xml_file_path, max_events_to_show=10):
    """
    Examine a single XML file in detail to validate data extraction.
    
    Args:
        xml_file_path (str): Path to the XML file
        max_events_to_show (int): Maximum number of events to display
    
    Returns:
        dict: Dictionary containing detailed information about the file
    """
    print(f"\n{'='*60}")
    print(f"EXAMINING FILE: {os.path.basename(xml_file_path)}")
    print(f"{'='*60}")
    
    # Extract data using our existing function
    events, sleep_stages, epoch_length = extract_arousal_events(xml_file_path)
    
    if events is None or sleep_stages is None:
        print("Failed to extract data from file")
        return None
    
    # Convert to DataFrames
    df_events = pd.DataFrame(events)
    df_sleep_stages = pd.DataFrame({
        "epoch": range(1, len(sleep_stages) + 1),
        "sleep_stage": sleep_stages
    })
    
    print(f"BASIC INFO:")
    print(f"   • Epoch length: {epoch_length} seconds")
    print(f"   • Total events: {len(df_events)}")
    print(f"   • Total epochs: {len(sleep_stages)}")
    
    # Show event types
    if not df_events.empty:
        event_counts = df_events['event_label'].value_counts()
        print(f"\nEVENT TYPES:")
        for event_type, count in event_counts.head(10).items():
            print(f"   • {event_type}: {count}")
        
        # Show arousal events specifically with sleep stage information
        df_arousals = df_events[df_events['event_label'].str.contains('arousal', case=False, na=False)]
        print(f"\nAROUSAL EVENTS: {len(df_arousals)} total")
        if not df_arousals.empty:
            # Add sleep stage information to arousal events
            df_arousals_with_stages = df_arousals.copy()
            stage_labels = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 5: "REM", 9: "Unscorable"}
            
            # Calculate which epoch each arousal occurs in and get the sleep stage
            df_arousals_with_stages['epoch'] = (df_arousals_with_stages['start_time_sec'] / epoch_length).astype(int)
            df_arousals_with_stages['sleep_stage_code'] = df_arousals_with_stages['epoch'].apply(
                lambda x: sleep_stages[x] if 0 <= x < len(sleep_stages) else 9
            )
            df_arousals_with_stages['sleep_stage_label'] = df_arousals_with_stages['sleep_stage_code'].map(stage_labels)
            
            print("   Sample arousal events with sleep stages:")
            display(df_arousals_with_stages[['event_label', 'start_time_sec', 'duration_sec', 'epoch', 'sleep_stage_label']].head(max_events_to_show))
    
    # Show sleep stage distribution
    stage_labels = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 5: "REM", 9: "Unscorable"}
    stage_counts = Counter(sleep_stages)
    print(f"\nSLEEP STAGE DISTRIBUTION:")
    for stage_code, count in sorted(stage_counts.items()):
        stage_name = stage_labels.get(stage_code, f"Unknown({stage_code})")
        percentage = (count / len(sleep_stages)) * 100
        print(f"   • {stage_name}: {count} epochs ({percentage:.1f}%)")
    
    # Filter arousals during sleep and show details
    df_arousals_sleep = filter_arousals_during_sleep(df_arousals, sleep_stages, epoch_length)
    total_arousal_duration_sec = calculate_total_arousal_time(df_arousals_sleep)
    print(total_arousal_duration_sec)
    sleep_stage_codes = [1, 2, 3, 5]
    total_sleep_epochs = sum(1 for s in sleep_stages if s in sleep_stage_codes)
    total_sleep_time_sec = total_sleep_epochs * epoch_length
    
    print(f"\nCALCULATED METRICS:")
    print(f"   • Arousals during sleep: {len(df_arousals_sleep)}")
    print(f"   • Total arousal duration: {total_arousal_duration_sec/60:.2f} minutes")
    print(f"   • Total sleep time: {total_sleep_time_sec/60:.2f} minutes")
    
    if total_sleep_time_sec > 0:
        arousal_burden = (total_arousal_duration_sec / total_sleep_time_sec) * 100
        print(f"   • Arousal burden: {arousal_burden:.2f}%")
    else:
        print(f"   • Arousal burden: N/A (no sleep time)")
    
    return {
        'events_df': df_events,
        'sleep_stages_df': df_sleep_stages,
        'arousals_df': df_arousals,
        'arousals_sleep_df': df_arousals_sleep,
        'epoch_length': epoch_length,
        'total_arousal_duration_sec': total_arousal_duration_sec,
        'total_sleep_time_sec': total_sleep_time_sec
    }

# Example usage: examine a few files for validation
#psg_folder = "psg"
#xml_files = [f for f in os.listdir(psg_folder) if f.lower().endswith('.xml')]
xml_files = [
    "/wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs2/EDF/bi0001_20110502_EDF/bi0001_20110502.edf.XML",
    "/wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs2/EDF/bi/bi0002.edf.XML",
    "/wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs2/EDF/bi0003_20110925_EDF/bi0003_20110925.edf.XML",]

print("EXAMINING FILES FOR VALIDATION")
print("=" * 80)

# Examine first 3 files as examples
for i, filename in enumerate(xml_files):
    xml_file_path = os.path.join(psg_folder, filename)
    file_details = examine_file_details(xml_file_path)
    
    if i < 2:  # Add separator between files
        print("\n" + "─" * 80)

# Interactive function to examine any specific file
def examine_specific_file(filename):
    """
    Examine a specific file by name.
    
    Args:
        filename (str): Name of the XML file to examine
    """
    xml_file_path = os.path.join(psg_folder, filename)
    if os.path.exists(xml_file_path):
        return examine_file_details(xml_file_path)
    else:
        print(f"File not found: {filename}")
        print(f"Available files: {xml_files[:10]}...")  # Show first 10 files
        return None

# Example: Uncomment the line below to examine a specific file
#examine_specific_file("psg\bi0002.edf.XML")


EXAMINING FILES FOR VALIDATION

EXAMINING FILE: bi0001.edf.XML
Error processing /wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs2/EDF/bi/bi0001.edf.XML: [Errno 2] No such file or directory: '/wynton/group/andrews/data/MrOS/mros-sof_mjhe/vs2/EDF/bi/bi0001.edf.XML'
BASIC INFO:
   • Epoch length: 30 seconds
   • Total events: 0
   • Total epochs: 0

SLEEP STAGE DISTRIBUTION:


UnboundLocalError: local variable 'df_arousals' referenced before assignment