# BiLSTM Sleep Stage Clustering - Results Analysis

This notebook provides a comprehensive analysis of BiLSTM-based sleep stage clustering results from the research project "Sleep Stages Analysis with Machine Learning - Unsupervised Approach" conducted at Institut de Neurosciences des Systèmes (INS).

## Analysis Components:
- **Data Loading & Validation**: Load model results and EEG signals with robust error handling
- **Interactive Hypnogram Visualization**: Time-series visualization of detected sleep clusters
- **Cluster Distribution Analysis**: Statistical analysis of sleep stage frequency and duration
- **Frequency Band Analysis**: Spectral power analysis across traditional EEG frequency bands
- **Transition Pattern Analysis**: Markov-like analysis of sleep stage transitions
- **Micro-arousal Detection**: Identification of brief arousal events

## Data Sources:
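- **Model Results**: `results/bilstm_30s_4clusters.pkl` - Cluster assignments produced by the trained BiLSTM model
- **Metadata**: `results/bilstm_30s_4clusters_metadata.json` - Model configuration and parameters (window size, overlap). If the file is missing or is not valid JSON, the defaults are 30 s windows with 50% overlap.
- **EEG Data**: `EEG_0_per_hour_2024-03-20 17_12_18.edf` - Continuous EEG recording, searched for in several candidate directories. If no file is found, simulated EEG is generated for demonstration.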

## Research Context:
This analysis supports the unsupervised machine learning approach for automated sleep stage classification, contributing to precision medicine applications in neurology and sleep disorders research.

In [None]:
# ==============================================================================
# LIBRARY IMPORTS AND CONFIGURATION
# ==============================================================================

# Core scientific computing libraries
import numpy as np
import pandas as pd
import pickle
import json

# EEG signal processing
import mne

# Interactive visualization
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as colors

# Signal processing and statistical analysis
from scipy import signal
from scipy.stats import mode
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.preprocessing import StandardScaler

# Configuration
import warnings
# Silence routine library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Shared plotting defaults: matplotlib style and the Plotly Express template.
plt.style.use('seaborn-v0_8')
px.defaults.template = 'plotly_white'

for status_line in ("✓ All required libraries imported successfully",
                    "✓ Visualization settings configured"):
    print(status_line)

In [None]:
# ==============================================================================
# DATA LOADING: MODEL RESULTS AND METADATA
# ==============================================================================

# Load model metadata (training parameters and configuration). If the file is
# missing or is not valid JSON, fall back to 30 s windows with 50% overlap.
metadata_fallback_reason = None
try:
    with open('results/bilstm_30s_4clusters_metadata.json', 'r') as f:
        metadata = json.load(f)
except FileNotFoundError:
    metadata_fallback_reason = "❌ Metadata file not found. Using default parameters."
except json.JSONDecodeError:
    metadata_fallback_reason = "❌ Error parsing metadata JSON. Using default parameters."

if metadata_fallback_reason is not None:
    print(metadata_fallback_reason)
    metadata = {'window_size_seconds': 30, 'overlap': 0.5}
    window_size, overlap = 30, 0.5
else:
    print("📊 Model Metadata Loaded:")
    print("-" * 40)
    for key, value in metadata.items():
        print(f"{key:20}: {value}")

    # Parameters reused by later cells; missing keys take the same defaults.
    window_size = metadata.get('window_size_seconds', 30)
    overlap = metadata.get('overlap', 0.5)

    print(f"\n✓ Key parameters extracted:")
    print(f"  Window size: {window_size}s")
    print(f"  Overlap: {overlap*100}%")

In [None]:
# Load the pickled BiLSTM clustering results.
# NOTE: unpickling can execute arbitrary code; this file must come from a
# trusted source.
def _describe_result_entry(item):
    """Return an indented shape/length line for one results entry, or None."""
    if hasattr(item, 'shape'):
        return f"{'':<20}  Shape: {item.shape}"
    if hasattr(item, '__len__') and not isinstance(item, str):
        return f"{'':<20}  Length: {len(item)}"
    return None


try:
    with open('results/bilstm_30s_4clusters.pkl', 'rb') as f:
        results = pickle.load(f)

    print("\n📦 BiLSTM Results Loaded Successfully:")
    print("-" * 40)
    for key, value in results.items():
        print(f"{key:20}: {type(value).__name__}")
        detail_line = _describe_result_entry(value)
        if detail_line is not None:
            print(detail_line)

    print("✓ Results data structure validated")

except FileNotFoundError:
    print("❌ Results file not found. Please ensure 'results/bilstm_30s_4clusters.pkl' exists.")
    raise
except Exception as e:
    print(f"❌ Error loading results: {e}")
    raise

In [None]:
# ==============================================================================
# EEG DATA LOADING WITH FALLBACK MECHANISMS
# ==============================================================================

def load_eeg_data():
    """
    Load the first EEG channel, falling back to simulated data if no file is found.

    The EDF file is looked up in a fixed priority list of locations. MNE
    returns EEG samples in volts, while this notebook reports and plots
    amplitudes in µV, so data read from a real file is scaled by 1e6.
    Simulated data is generated with amplitudes of tens of units (intended as
    µV-scale values) and is returned unscaled.

    Returns:
        tuple: (raw_data, eeg_signal, sampling_frequency, time_vector), where
        ``eeg_signal`` is the first channel in µV and ``time_vector`` holds
        each sample's time in seconds.
    """
    # Priority order for EEG file locations
    eeg_paths = [
        'by captain borat/raw/EEG_0_per_hour_2024-03-20 17_12_18.edf',
        'raw data/EEG_0_per_hour_2024-03-20 17_12_18.edf',
        'EEG_0_per_hour_2024-03-20 17_12_18.edf'
    ]

    # Volts -> microvolts for real recordings; reset to 1.0 for simulated data.
    amplitude_scale = 1e6

    # Attempt to load actual EEG data
    for eeg_file in eeg_paths:
        try:
            print(f"📂 Attempting to load: {eeg_file}")
            raw = mne.io.read_raw_edf(eeg_file, preload=True, verbose=False)
            print(f"✓ Successfully loaded EEG data from: {eeg_file}")
            break
        except FileNotFoundError:
            continue
    else:
        # No candidate path exists: use simulated data instead
        print("⚠️  No EEG file found. Generating simulated EEG data for demonstration...")
        raw = create_simulated_eeg_data()
        print("✓ Simulated EEG data created")
        amplitude_scale = 1.0

    # Extract signal properties
    fs = raw.info['sfreq']
    # Pick only the first channel. get_data() without picks would copy every
    # channel of a long recording just to keep one of them.
    eeg_data = raw.get_data(picks=[0])[0] * amplitude_scale
    time_vector = np.arange(len(eeg_data)) / fs

    # Display data characteristics
    print(f"\n📊 EEG Data Characteristics:")
    print(f"  Sampling frequency: {fs} Hz")
    print(f"  Channels: {len(raw.ch_names)} ({', '.join(raw.ch_names)})")
    print(f"  Duration: {time_vector[-1]:.1f}s ({time_vector[-1]/3600:.2f}h)")
    print(f"  Data points: {len(eeg_data):,}")
    print(f"  Amplitude range: [{eeg_data.min():.1f}, {eeg_data.max():.1f}] µV")

    return raw, eeg_data, fs, time_vector

def create_simulated_eeg_data(duration=86400, fs=512, seed=42):
    """
    Create a simulated single-channel EEG-like recording for demonstration.

    The signal is Gaussian noise (scale 10), plus one sinusoid whose frequency
    is drawn from each of the delta/theta/alpha/beta ranges, plus a slow
    sinusoid with a 24-hour period.

    Args:
        duration (float): Recording length in seconds (default: 24 hours).
        fs (float): Sampling frequency in Hz (default: 512).
        seed (int): Seed for a private random generator. The default
            reproduces the previous simulation exactly.

    Returns:
        mne.io.RawArray: Single-channel raw object with channel name 'EEG'.
    """
    n_samples = int(duration * fs)
    time_vector = np.arange(n_samples) / fs

    # A private RandomState gives exactly the same draws as the former
    # np.random.seed(seed) + np.random.* calls, without resetting NumPy's
    # global RNG for the rest of the notebook.
    rng = np.random.RandomState(seed)

    # Base noise
    eeg_signal = rng.randn(n_samples) * 10

    # (low Hz, high Hz, amplitude) per band; one random frequency per band
    frequency_components = {
        'delta': (0.5, 2.0, 80),    # Deep sleep dominant
        'theta': (4.0, 8.0, 40),    # Light sleep, REM
        'alpha': (8.0, 13.0, 30),   # Relaxed wakefulness
        'beta': (13.0, 30.0, 20),   # Active wakefulness
    }
    for f_low, f_high, amplitude in frequency_components.values():
        freq = rng.uniform(f_low, f_high)
        eeg_signal += amplitude * np.sin(2 * np.pi * freq * time_vector)

    # Slow modulation with a 24-hour period (circadian rhythm simulation)
    eeg_signal += np.sin(2 * np.pi * time_vector / 86400) * 20

    # Wrap in an MNE Raw object; quiet, like read_raw_edf(verbose=False)
    info = mne.create_info(['EEG'], sfreq=fs, ch_types=['eeg'])
    return mne.io.RawArray(eeg_signal.reshape(1, -1), info, verbose=False)

# Load the recording (real EDF file if available, simulated otherwise) and
# unpack the raw object, first-channel signal, sampling rate and time axis.
eeg_bundle = load_eeg_data()
raw, eeg_data, fs, time_vector = eeg_bundle

In [None]:
# ==============================================================================
# CLUSTER LABELS EXTRACTION AND VALIDATION
# ==============================================================================

def extract_cluster_labels(results, verbose=True):
    """
    Locate the cluster-label array inside a loaded results dictionary.

    Search order:
      1. Top-level entries with a common label name (``cluster_labels``,
         ``labels``, ``predictions``, ...). The first non-None match wins.
      2. The same names inside ``results['results']``, but only when that
         entry is a dict. Any other type is skipped rather than crashing the
         ``in`` membership test.
      3. Fallback: the first top-level value with a 1-D ``shape`` and more
         than 100 elements.

    Args:
        results (dict): Loaded results dictionary.
        verbose (bool): Print where the labels were found.

    Returns:
        numpy.ndarray: Extracted cluster labels.

    Raises:
        ValueError: If no suitable entry is found.
    """
    # Define potential keys where cluster labels might be stored
    potential_keys = [
        'cluster_labels', 'labels', 'predictions', 'y_pred',
        'clusters', 'clustering_labels', 'cluster_assignments'
    ]

    def _find_in(container):
        """Return (key, value) of the first candidate key holding a non-None value."""
        for key in potential_keys:
            if key in container and container[key] is not None:
                return key, container[key]
        return None, None

    if verbose:
        print("🔍 Searching for cluster labels...")

    # 1. Main results dictionary
    found_key, cluster_labels = _find_in(results)
    if cluster_labels is not None and verbose:
        print(f"✓ Found cluster labels in: '{found_key}'")

    # 2. Nested 'results' dictionary (only if it really is a dict)
    if cluster_labels is None and 'results' in results:
        nested = results['results']
        if isinstance(nested, dict):
            found_key, cluster_labels = _find_in(nested)
            if cluster_labels is not None and verbose:
                print(f"✓ Found cluster labels in: 'results.{found_key}'")

    # 3. Fallback: first suitable array-like object
    if cluster_labels is None:
        if verbose:
            print("⚠️  Cluster labels not found in expected locations. Available keys:")
            for key, value in results.items():
                print(f"  {key:20}: {type(value).__name__}")
                if hasattr(value, 'shape'):
                    print(f"{'':<20}    Shape: {value.shape}")

        # Heuristic: a long 1-D array is assumed to be per-window labels
        for key, value in results.items():
            if hasattr(value, 'shape') and len(value.shape) == 1 and len(value) > 100:
                cluster_labels = value
                if verbose:
                    print(f"✓ Using '{key}' as cluster labels (first suitable 1D array)")
                break

    if cluster_labels is None:
        raise ValueError("❌ Could not locate cluster labels in results. Please verify data structure.")

    return np.array(cluster_labels)

def validate_cluster_labels(cluster_labels, verbose=True):
    """
    Summarise cluster labels and return their distinct values and counts.

    Args:
        cluster_labels (numpy.ndarray): Per-window labels. They are cast to
            int for counting, so non-negative integer ids are expected.
        verbose (bool): Print a summary of the label distribution.

    Returns:
        tuple: (unique_labels, label_counts). ``label_counts`` is in
        ``numpy.bincount`` layout: entry k is the number of windows labelled
        k, and ids with no windows appear as zeros.
    """
    distinct = np.unique(cluster_labels)
    counts_by_id = np.bincount(cluster_labels.astype(int))

    if not verbose:
        return distinct, counts_by_id

    n_total = len(cluster_labels)
    print(f"\n📊 Cluster Labels Validation:")
    print(f"  Total windows: {n_total:,}")
    print(f"  Unique clusters: {len(distinct)} {list(distinct)}")
    print(f"  Label distribution:")
    # Report only ids that actually occur
    for cluster_id in np.flatnonzero(counts_by_id > 0):
        n_windows = counts_by_id[cluster_id]
        share = n_windows / n_total * 100
        print(f"    Cluster {cluster_id}: {n_windows:,} ({share:.1f}%)")

    return distinct, counts_by_id

# Extract per-window cluster labels and report their distribution
cluster_labels = extract_cluster_labels(results)
unique_clusters, cluster_counts = validate_cluster_labels(cluster_labels)

# Window k starts at k * step_size seconds. step_size is the part of each
# window not shared with the next one.
n_windows = len(cluster_labels)
step_size = window_size * (1 - overlap)
cluster_times = step_size * np.arange(n_windows)

print(f"\n⏱️  Temporal Alignment:")
print(f"  Window size: {window_size}s")
print(f"  Step size: {step_size}s")
coverage_seconds = cluster_times[-1] + window_size
print(f"  Total coverage: {coverage_seconds:.1f}s ({coverage_seconds/3600:.2f}h)")
print(f"  Effective recording: {n_windows} windows")

print("\n✓ Cluster extraction and validation completed")

In [None]:
# ==============================================================================
# INTERACTIVE HYPNOGRAM VISUALIZATION
# ==============================================================================

def create_hypnogram(cluster_labels, cluster_times, window_size, max_points=5000):
    """
    Create an interactive hypnogram showing cluster assignments over time.

    Each cluster gets its own marker trace, and a grey "Sleep Progression"
    line joins consecutive windows. That line uses the ``'hv'`` step shape, so
    a change of cluster appears as a vertical step (the conventional
    hypnogram look). Straight-line interpolation would instead draw diagonal
    ramps that suggest intermediate stages between windows.

    Args:
        cluster_labels (array-like): Cluster assignment for each time window.
        cluster_times (array-like): Start time (seconds) of each window.
        window_size (float): Window duration in seconds (shown in hover text).
        max_points (int): Maximum number of windows to plot. Longer inputs are
            subsampled at evenly spaced indices, so very short segments may
            not be visible.

    Returns:
        plotly.graph_objects.Figure: Interactive hypnogram plot.
    """
    # Accept lists too: the boolean masking below requires ndarrays.
    cluster_labels = np.asarray(cluster_labels)
    cluster_times = np.asarray(cluster_times)

    # Distinct colours per cluster; stage interpretations are tentative.
    cluster_colors = {
        0: '#1f77b4',  # Deep Blue - potentially deep sleep
        1: '#ff7f0e',  # Orange - potentially light sleep
        2: '#2ca02c',  # Green - potentially REM sleep
        3: '#d62728',  # Red - potentially wake/arousal
        4: '#9467bd',  # Purple - additional cluster
        5: '#8c564b',  # Brown - additional cluster
    }

    # Downsample for visualization performance if dataset is large
    if len(cluster_labels) > max_points:
        print(f"📊 Downsampling to {max_points:,} points from {len(cluster_labels):,} for optimal visualization...")
        sample_indices = np.linspace(0, len(cluster_labels)-1, max_points, dtype=int)
        labels_plot = cluster_labels[sample_indices]
        times_plot = cluster_times[sample_indices]
    else:
        labels_plot = cluster_labels
        times_plot = cluster_times

    fig = go.Figure()

    # One marker trace per cluster
    for cluster_id in np.unique(labels_plot):
        cluster_mask = labels_plot == cluster_id
        times_subset = times_plot[cluster_mask] / 3600  # Convert to hours
        labels_subset = labels_plot[cluster_mask]

        if len(times_subset) > 0:
            color = cluster_colors.get(cluster_id, '#17becf')

            fig.add_trace(go.Scatter(
                x=times_subset,
                y=labels_subset,
                mode='markers',
                marker=dict(
                    size=6,
                    color=color,
                    symbol='square',
                    opacity=0.8,
                    line=dict(width=1, color='white')
                ),
                name=f'Cluster {cluster_id}',
                text=[f'Time: {t:.2f}h<br>Cluster: {c}<br>Duration: {window_size}s'
                      for t, c in zip(times_subset, labels_subset)],
                hovertemplate='%{text}<extra></extra>'
            ))

    # Step line: hold each cluster until the next window, then jump vertically.
    times_hours = times_plot / 3600
    fig.add_trace(go.Scatter(
        x=times_hours,
        y=labels_plot,
        mode='lines',
        line=dict(width=1.5, color='rgba(100,100,100,0.4)', shape='hv'),
        name='Sleep Progression',
        showlegend=True,
        hoverinfo='skip'
    ))

    fig.update_layout(
        title={
            'text': 'Sleep Stage Hypnogram - BiLSTM Unsupervised Clustering',
            'x': 0.5,
            'font': {'size': 16}
        },
        xaxis_title='Time (hours)',
        yaxis_title='Sleep Cluster (Stage)',
        height=500,
        yaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=1,
            title_font_size=14
        ),
        xaxis=dict(
            title_font_size=14,
            tickformat='.1f'
        ),
        hovermode='closest',
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='center',
            x=0.5
        ),
        template='plotly_white'
    )

    return fig

# Build the hypnogram over all windows and render it inline
print("🎨 Generating interactive hypnogram...")
hypnogram_fig = create_hypnogram(
    cluster_labels=cluster_labels,
    cluster_times=cluster_times,
    window_size=window_size,
)
hypnogram_fig.show()
print("✓ Hypnogram visualization complete")

In [None]:
# ==============================================================================
# EEG SEGMENT VISUALIZATION
# ==============================================================================

def plot_eeg_segment(start_sec, end_sec, eeg_data, time_vector,
                     cluster_labels, cluster_times, window_size):
    """
    Plot an EEG segment together with the cluster windows that overlap it.

    Sample indices are found by binary search on ``time_vector`` instead of
    multiplying by a sampling frequency, so the function does not depend on
    a global ``fs`` variable. Only windows that actually overlap
    ``[start_sec, end_sec)`` are drawn, and their shading is clipped to that
    range, so the x-axis stays on the requested segment.

    Args:
        start_sec, end_sec (float): Time range to visualize in seconds.
        eeg_data (array): EEG signal samples (labelled in µV).
        time_vector (array): Sample times in seconds, in increasing order.
        cluster_labels (array): Cluster assignment for each window.
        cluster_times (array): Start time (seconds) of each window.
        window_size (float): Duration of each cluster window in seconds.

    Returns:
        plotly.graph_objects.Figure: Two-row figure with the shaded signal
        and a cluster timeline.
    """
    eeg_data = np.asarray(eeg_data)
    time_vector = np.asarray(time_vector)
    cluster_labels = np.asarray(cluster_labels)
    cluster_times = np.asarray(cluster_times)

    # First sample at or after each boundary. searchsorted clamps to
    # [0, len], so out-of-range bounds cannot wrap around via negative indices.
    start_idx, end_idx = np.searchsorted(time_vector, [start_sec, end_sec], side='left')
    segment_data = eeg_data[start_idx:end_idx]
    segment_time = time_vector[start_idx:end_idx]

    # Window [t, t + window_size) overlaps [start_sec, end_sec) iff
    # t < end_sec and t + window_size > start_sec. Windows that merely touch
    # an edge are excluded.
    cluster_mask = (cluster_times < end_sec) & (cluster_times + window_size > start_sec)
    relevant_clusters = cluster_labels[cluster_mask]
    relevant_times = cluster_times[cluster_mask]

    # Create dual-panel figure
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        subplot_titles=['EEG Signal with Cluster Annotations', 'Cluster Assignment Timeline'],
        vertical_spacing=0.12,
        row_heights=[0.75, 0.25]
    )

    fig.add_trace(go.Scatter(
        x=segment_time,
        y=segment_data,
        mode='lines',
        name='EEG Signal',
        line=dict(width=1.2, color='#2E86C1'),
        hovertemplate='Time: %{x:.1f}s<br>Amplitude: %{y:.1f}μV<extra></extra>'
    ), row=1, col=1)

    # Same per-cluster colours as create_hypnogram, so a cluster looks the
    # same in every plot.
    cluster_colors = {
        0: '#1f77b4', 1: '#ff7f0e', 2: '#2ca02c',
        3: '#d62728', 4: '#9467bd', 5: '#8c564b',
    }

    for window_start, cluster in zip(relevant_times, relevant_clusters):
        color = cluster_colors.get(cluster, '#17becf')
        # Clip to the displayed range so shapes do not widen the x-axis.
        x0 = max(window_start, start_sec)
        x1 = min(window_start + window_size, end_sec)

        # Background shading on EEG plot
        fig.add_vrect(
            x0=x0,
            x1=x1,
            fillcolor=color,
            opacity=0.2,
            layer="below",
            line_width=0,
            row=1, col=1
        )

        # Cluster timeline bar
        fig.add_shape(
            type="rect",
            x0=x0,
            x1=x1,
            y0=cluster - 0.35,
            y1=cluster + 0.35,
            fillcolor=color,
            line=dict(width=1, color=color),
            opacity=0.8,
            row=2, col=1
        )

    fig.update_layout(
        title={
            'text': f'EEG Segment Analysis: {start_sec}s - {end_sec}s ({(end_sec-start_sec)/60:.1f} min)',
            'x': 0.5,
            'font': {'size': 14}
        },
        height=600,
        showlegend=False,
        hovermode='x unified'
    )

    fig.update_xaxes(title_text='Time (seconds)', row=2, col=1)
    fig.update_yaxes(title_text='Amplitude (μV)', row=1, col=1)
    fig.update_yaxes(title_text='Cluster', row=2, col=1,
                     tickmode='linear', tick0=0, dtick=1)

    return fig

# Inspect the first 10 minutes of the recording in detail
demo_start, demo_end = 0, 10 * 60  # seconds
print(f"🔍 Analyzing EEG segment: {demo_start}s - {demo_end}s")
segment_fig = plot_eeg_segment(
    demo_start, demo_end,
    eeg_data, time_vector,
    cluster_labels, cluster_times, window_size,
)
segment_fig.show()
print("✓ Segment visualization complete")

In [None]:
# ==============================================================================
# CLUSTER DISTRIBUTION ANALYSIS
# ==============================================================================

def analyze_cluster_distribution(cluster_labels, window_size, step_size=None):
    """
    Summarise how many windows, and how much recording time, each cluster takes.

    With overlapping windows, each window adds only ``step_size`` seconds of
    new recording time. Counting every window as a full ``window_size``
    overstates durations by ``window_size / step_size`` (2x at 50% overlap),
    so durations are computed from ``step_size``.

    Args:
        cluster_labels (array): Cluster assignments for each window.
        window_size (float): Window length in seconds (reported for reference).
        step_size (float, optional): Seconds between consecutive window starts.
            Defaults to ``window_size`` (non-overlapping windows), which
            matches the previous accounting.

    Returns:
        pandas.DataFrame: One row per cluster with columns ``cluster``,
        ``count``, ``proportion_pct``, ``duration_hours`` and
        ``duration_minutes``. The columns exist even when the input is empty.
    """
    if step_size is None:
        step_size = window_size

    unique_clusters, counts = np.unique(np.asarray(cluster_labels), return_counts=True)
    total_windows = len(cluster_labels)
    total_duration_hrs = total_windows * step_size / 3600

    print("📊 CLUSTER DISTRIBUTION ANALYSIS")
    print("=" * 50)
    print(f"Recording Overview:")
    print(f"  Total windows: {total_windows:,}")
    print(f"  Total duration: {total_duration_hrs:.2f} hours")
    print(f"  Window size: {window_size}s")
    print(f"  Step size: {step_size}s")
    print(f"  Unique clusters: {len(unique_clusters)}")

    print(f"\nCluster Breakdown:")
    cluster_stats = []

    for cluster, count in zip(unique_clusters, counts):
        proportion = count / total_windows * 100
        # Each window contributes step_size seconds of recording time
        duration_hours = count * step_size / 3600
        duration_minutes = duration_hours * 60

        print(f"  Cluster {cluster}: {count:,} windows ({proportion:5.1f}%) "
              f"→ {duration_hours:5.2f}h ({duration_minutes:6.1f}min)")

        cluster_stats.append({
            'cluster': cluster,
            'count': count,
            'proportion_pct': proportion,
            'duration_hours': duration_hours,
            'duration_minutes': duration_minutes
        })

    print(f"\nInterpretation Notes:")
    print(f"  • Higher proportion clusters may represent dominant sleep stages")
    print(f"  • Consider circadian patterns and typical sleep architecture")
    print(f"  • Cluster transitions indicate sleep stage stability")

    return pd.DataFrame(
        cluster_stats,
        columns=['cluster', 'count', 'proportion_pct', 'duration_hours', 'duration_minutes'],
    )

# Analyze cluster distribution. Windows start step_size seconds apart, so each
# window accounts for step_size seconds of recording time.
cluster_dist_df = analyze_cluster_distribution(cluster_labels, window_size, step_size=step_size)
print("\n✓ Cluster distribution analysis completed")

In [None]:
# ==============================================================================
# CLUSTER DISTRIBUTION VISUALIZATION
# ==============================================================================

def create_cluster_distribution_plots(cluster_dist_df, cluster_labels, cluster_times):
    """
    Assemble a 2x2 Plotly dashboard describing the cluster distribution.

    Panels (row, col):
      (1, 1) window count per cluster,
      (1, 2) pie of proportions,
      (2, 1) total duration in hours per cluster,
      (2, 2) cluster against recording time.
    For more than 2000 windows, the last panel keeps every k-th window,
    with k = len(cluster_labels) // 2000.

    Args:
        cluster_dist_df (pandas.DataFrame): Needs the columns ``cluster``,
            ``count``, ``proportion_pct`` and ``duration_hours``.
        cluster_labels (numpy.ndarray): Cluster assignment per window.
        cluster_times (numpy.ndarray): Window start times in seconds.

    Returns:
        plotly.graph_objects.Figure: The dashboard figure.
    """
    cluster_ids = cluster_dist_df['cluster']

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=[
            'Cluster Window Counts',
            'Sleep Stage Proportions (%)',
            'Duration Distribution (Hours)',
            'Temporal Evolution'
        ],
        specs=[
            [{'type': 'bar'}, {'type': 'pie'}],
            [{'type': 'bar'}, {'type': 'scatter'}]
        ],
        horizontal_spacing=0.1,
        vertical_spacing=0.15
    )

    # (1, 1) windows per cluster
    count_bar = go.Bar(
        x=cluster_ids,
        y=cluster_dist_df['count'],
        name='Window Count',
        marker_color='lightblue',
        text=cluster_dist_df['count'],
        textposition='outside',
        hovertemplate='Cluster %{x}<br>Count: %{y}<extra></extra>'
    )
    fig.add_trace(count_bar, row=1, col=1)

    # (1, 2) share of windows per cluster
    proportion_pie = go.Pie(
        labels=[f'Cluster {c}' for c in cluster_ids],
        values=cluster_dist_df['proportion_pct'],
        name='Proportion',
        textinfo='label+percent',
        hovertemplate='%{label}<br>%{percent}<br>%{value:.1f}%<extra></extra>'
    )
    fig.add_trace(proportion_pie, row=1, col=2)

    # (2, 1) total time per cluster
    hours = cluster_dist_df['duration_hours']
    duration_bar = go.Bar(
        x=cluster_ids,
        y=hours,
        name='Duration (Hours)',
        marker_color='lightgreen',
        text=[f'{h:.1f}h' for h in hours],
        textposition='outside',
        hovertemplate='Cluster %{x}<br>Duration: %{y:.2f}h<extra></extra>'
    )
    fig.add_trace(duration_bar, row=2, col=1)

    # (2, 2) cluster vs time; a stride of 1 keeps every window
    max_points = 2000
    stride = max(1, len(cluster_labels) // max_points)
    times_plot = cluster_times[::stride] / 3600  # hours
    labels_plot = cluster_labels[::stride]

    evolution_scatter = go.Scatter(
        x=times_plot,
        y=labels_plot,
        mode='markers',
        marker=dict(size=3, opacity=0.6, color=labels_plot, colorscale='Viridis'),
        name='Sleep Evolution',
        hovertemplate='Time: %{x:.1f}h<br>Cluster: %{y}<extra></extra>'
    )
    fig.add_trace(evolution_scatter, row=2, col=2)

    fig.update_layout(
        height=700,
        title={
            'text': 'Comprehensive Cluster Distribution Analysis',
            'x': 0.5,
            'font': {'size': 16}
        },
        showlegend=False
    )

    # (row, col, x title, y title) for every non-pie panel
    axis_titles = [
        (1, 1, 'Cluster ID', 'Number of Windows'),
        (2, 1, 'Cluster ID', 'Duration (Hours)'),
        (2, 2, 'Recording Time (Hours)', 'Cluster ID'),
    ]
    for row, col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)

    return fig

# Render the 2x2 distribution dashboard
print("🎨 Creating cluster distribution visualizations...")
distribution_fig = create_cluster_distribution_plots(
    cluster_dist_df=cluster_dist_df,
    cluster_labels=cluster_labels,
    cluster_times=cluster_times,
)
distribution_fig.show()
print("✓ Distribution plots generated successfully")

In [None]:
# ==============================================================================
# CONTINUOUS DURATION ANALYSIS
# ==============================================================================

def analyze_continuous_durations(cluster_labels, window_size, step_size=None):
    """
    Analyze how long each cluster persists between changes (sleep-stage stability).

    Consecutive windows with the same label form one continuous segment. A
    segment of ``k`` windows is credited with ``k * step_size`` seconds. With
    overlapping windows each window adds only ``step_size`` seconds of new
    recording, so using ``window_size`` would inflate durations by
    ``window_size / step_size``.

    Args:
        cluster_labels (array): Cluster assignments for each window.
        window_size (float): Window length in seconds.
        step_size (float, optional): Seconds between consecutive window starts.
            Defaults to ``window_size`` (non-overlapping windows), which
            matches the previous accounting.

    Returns:
        tuple: (continuous_durations, duration_stats_df).
        ``continuous_durations`` maps each cluster to a list of segment
        durations in seconds. ``duration_stats_df`` holds per-cluster
        statistics in minutes; its columns exist even when it is empty.
    """
    if step_size is None:
        step_size = window_size
    labels = np.asarray(cluster_labels)

    print("🔍 CONTINUOUS DURATION ANALYSIS")
    print("=" * 50)

    continuous_durations = {cluster: [] for cluster in np.unique(labels)}

    # Run-length encode the label sequence. A run starts at index 0 and at
    # every index whose label differs from the previous window's label.
    if labels.size:
        change_points = np.flatnonzero(labels[1:] != labels[:-1]) + 1
        run_starts = np.concatenate(([0], change_points))
        run_lengths = np.diff(np.concatenate((run_starts, [labels.size])))
        for start, length in zip(run_starts, run_lengths):
            continuous_durations[labels[start]].append(int(length) * step_size)

    total_segments = sum(len(durations) for durations in continuous_durations.values())
    # n segments are separated by n - 1 transitions
    n_transitions = max(total_segments - 1, 0)
    recording_hours = labels.size * step_size / 3600

    print(f"Sleep Architecture Summary:")
    print(f"  Total continuous segments: {total_segments}")
    if recording_hours > 0:
        print(f"  Average transitions per hour: {n_transitions / recording_hours:.1f}")

    print(f"\nCluster-Specific Duration Statistics:")
    duration_stats = []
    for cluster in sorted(continuous_durations.keys()):
        durations = continuous_durations[cluster]

        if durations:
            # Statistics in minutes for readability
            durations_min = np.array(durations) / 60

            stats = {
                'cluster': cluster,
                'segment_count': len(durations),
                'mean_duration_min': np.mean(durations_min),
                'median_duration_min': np.median(durations_min),
                'max_duration_min': np.max(durations_min),
                'min_duration_min': np.min(durations_min),
                'std_duration_min': np.std(durations_min),
                'total_time_min': np.sum(durations_min)
            }

            print(f"  Cluster {cluster}:")
            print(f"    Segments: {stats['segment_count']:3d} | "
                  f"Mean: {stats['mean_duration_min']:5.1f}min | "
                  f"Median: {stats['median_duration_min']:5.1f}min")
            print(f"    Range: {stats['min_duration_min']:5.1f}min - {stats['max_duration_min']:5.1f}min | "
                  f"Total: {stats['total_time_min']:6.1f}min")

            duration_stats.append(stats)

    duration_stats_df = pd.DataFrame(
        duration_stats,
        columns=['cluster', 'segment_count', 'mean_duration_min', 'median_duration_min',
                 'max_duration_min', 'min_duration_min', 'std_duration_min', 'total_time_min'],
    )

    print(f"\n📊 Sleep Stage Stability Analysis:")
    if len(duration_stats_df) > 0:
        # Use .at so the cluster id keeps its column dtype. A whole row from
        # .loc is upcast to float, which printed integer ids as "2.0".
        stable_idx = duration_stats_df['mean_duration_min'].idxmax()
        longest_idx = duration_stats_df['max_duration_min'].idxmax()

        print(f"  Most stable cluster: {duration_stats_df.at[stable_idx, 'cluster']} "
              f"(avg {duration_stats_df.at[stable_idx, 'mean_duration_min']:.1f}min)")
        print(f"  Longest single segment: Cluster {duration_stats_df.at[longest_idx, 'cluster']} "
              f"({duration_stats_df.at[longest_idx, 'max_duration_min']:.1f}min)")

    return continuous_durations, duration_stats_df

# Analyze continuous durations. Windows start step_size seconds apart, so each
# window contributes step_size seconds of recording time.
continuous_durations, duration_stats_df = analyze_continuous_durations(
    cluster_labels, window_size, step_size=step_size
)
print("✓ Continuous duration analysis completed")

In [None]:
# ==============================================================================
# CONTINUOUS DURATION VISUALIZATION
# ==============================================================================

def create_duration_analysis_plots(duration_stats_df, continuous_durations):
    """
    Visualise continuous-duration (run-length) statistics per cluster.

    Panels: mean segment duration (bar), overlaid duration histograms,
    number of segments (bar), and duration box plots.

    Each cluster's colour is chosen by its position in the sorted cluster
    list and reused for both its histogram and its box plot. This works for
    any label type (including float or non-contiguous ids).

    Args:
        duration_stats_df (pandas.DataFrame): Per-cluster statistics with
            columns ``cluster``, ``mean_duration_min`` and ``segment_count``.
        continuous_durations (dict): Maps each cluster to a list of segment
            durations in seconds.

    Returns:
        plotly.graph_objects.Figure: The 2x2 continuity figure.
    """
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=[
            'Average Duration by Cluster',
            'Duration Distribution Histograms',
            'Segment Count Analysis',
            'Duration Boxplots'
        ],
        specs=[
            [{'type': 'bar'}, {'type': 'histogram'}],
            [{'type': 'bar'}, {'type': 'box'}]
        ],
        horizontal_spacing=0.1,
        vertical_spacing=0.15
    )

    # Named `palette` so the module-level `colors` (plotly.colors) is not shadowed.
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
    sorted_clusters = sorted(continuous_durations.keys())
    # Indexing the palette by the label value raised TypeError for float
    # labels and could give one cluster two colours; index by position instead.
    cluster_color = {cluster: palette[i % len(palette)]
                     for i, cluster in enumerate(sorted_clusters)}
    durations_min = {cluster: np.array(continuous_durations[cluster]) / 60
                     for cluster in sorted_clusters}

    # 1. Bar chart of mean durations
    fig.add_trace(go.Bar(
        x=duration_stats_df['cluster'],
        y=duration_stats_df['mean_duration_min'],
        name='Average Duration',
        marker_color='lightcoral',
        text=[f'{dur:.1f}min' for dur in duration_stats_df['mean_duration_min']],
        textposition='outside',
        hovertemplate='Cluster %{x}<br>Mean Duration: %{y:.1f}min<extra></extra>'
    ), row=1, col=1)

    # 2. Overlapping histograms of duration distributions
    for cluster in sorted_clusters:
        fig.add_trace(go.Histogram(
            x=durations_min[cluster],
            name=f'Cluster {cluster}',
            opacity=0.7,
            marker_color=cluster_color[cluster],
            nbinsx=20,
            showlegend=True
        ), row=1, col=2)

    # 3. Segment count analysis
    fig.add_trace(go.Bar(
        x=duration_stats_df['cluster'],
        y=duration_stats_df['segment_count'],
        name='Segment Count',
        marker_color='lightgreen',
        text=duration_stats_df['segment_count'],
        textposition='outside',
        hovertemplate='Cluster %{x}<br>Segments: %{y}<extra></extra>'
    ), row=2, col=1)

    # 4. Box plots, one per cluster, using the same colour as its histogram
    for cluster in sorted_clusters:
        fig.add_trace(go.Box(
            y=durations_min[cluster],
            name=f'Cluster {cluster}',
            boxpoints='outliers',
            showlegend=False,
            marker_color=cluster_color[cluster]
        ), row=2, col=2)

    fig.update_layout(
        height=700,
        title={
            'text': 'Sleep Stage Continuity Analysis - Duration Patterns',
            'x': 0.5,
            'font': {'size': 16}
        }
    )

    fig.update_xaxes(title_text='Cluster ID', row=1, col=1)
    fig.update_yaxes(title_text='Mean Duration (minutes)', row=1, col=1)
    fig.update_xaxes(title_text='Duration (minutes)', row=1, col=2)
    fig.update_yaxes(title_text='Frequency', row=1, col=2)
    fig.update_xaxes(title_text='Cluster ID', row=2, col=1)
    fig.update_yaxes(title_text='Number of Segments', row=2, col=1)
    fig.update_xaxes(title_text='Cluster ID', row=2, col=2)
    fig.update_yaxes(title_text='Duration (minutes)', row=2, col=2)

    # Overlay the per-cluster histograms instead of stacking/grouping them
    fig.update_layout(barmode='overlay')

    return fig

# Render the continuity (run-length) dashboard
print("🎨 Creating duration analysis visualizations...")
duration_fig = create_duration_analysis_plots(
    duration_stats_df=duration_stats_df,
    continuous_durations=continuous_durations,
)
duration_fig.show()
print("✓ Duration analysis plots generated successfully")

In [None]:
# ==============================================================================
# SPECTRAL ANALYSIS - MULTITAPER SPECTROGRAM
# ==============================================================================

def compute_multitaper_spectrogram(data, fs, window_length=30, overlap=0.75):
    """
    Compute a short-time power spectrogram of an EEG trace.

    Note: despite the (historical) function name, this is a single-taper STFT
    spectrogram using a Hann window via ``scipy.signal.spectrogram``; no DPSS
    multitaper averaging is performed.

    Args:
        data (array): EEG time series; the spectrogram is taken along the last axis.
        fs (float): Sampling frequency in Hz.
        window_length (float): Window length in seconds. A signal shorter than
            one window is analysed with a single window spanning all of it.
        overlap (float): Fractional overlap between consecutive windows, in [0, 1).

    Returns:
        tuple: (frequencies, times, power_spectral_density) as returned by
        ``scipy.signal.spectrogram`` with ``scaling='density'``; ``times`` are
        window-centre times in seconds relative to the start of ``data``.

    Raises:
        ValueError: If ``data`` is empty, ``fs`` or ``window_length`` is not
            positive, or ``overlap`` is outside [0, 1).
    """
    from scipy.signal import spectrogram

    data = np.asarray(data)
    n_samples = data.shape[-1] if data.ndim > 0 else 0
    if n_samples == 0:
        raise ValueError("data must contain at least one sample")
    if fs <= 0 or window_length <= 0:
        raise ValueError("fs and window_length must be positive")
    if not 0 <= overlap < 1:
        raise ValueError("overlap must be in [0, 1)")

    # Clamp the window to the signal length here: scipy would shrink nperseg
    # on its own but keep our noverlap, which would then be >= nperseg and raise.
    requested_nperseg = int(window_length * fs)
    nperseg = max(1, min(requested_nperseg, n_samples))
    noverlap = min(int(nperseg * overlap), nperseg - 1)

    print(f"🔬 Computing Spectrogram (Hann-window STFT):")
    if nperseg < requested_nperseg:
        print(f"  Signal shorter than {window_length}s window - using {nperseg} samples instead")
    print(f"  Window length: {window_length}s ({nperseg} samples)")
    print(f"  Overlap: {overlap*100}% ({noverlap} samples)")
    print(f"  Frequency resolution: {fs/nperseg:.3f} Hz")

    f, t, Sxx = spectrogram(
        data, fs,
        window='hann',  # Hann window for good spectral properties
        nperseg=nperseg,
        noverlap=noverlap,
        scaling='density'
    )

    print(f"  Spectrogram shape: {Sxx.shape}")
    print(f"  Frequency range: {f[0]:.1f} - {f[-1]:.1f} Hz")
    print(f"  Time coverage: {t[-1]:.1f}s ({t[-1]/60:.1f}min)")

    return f, t, Sxx

# Spectral analysis runs on a leading excerpt only (at most one hour, or the
# whole recording if shorter) to keep the spectrogram cheap to compute.
analysis_duration = min(3600, len(eeg_data) / fs)  # seconds
analysis_samples = int(analysis_duration * fs)
eeg_segment = eeg_data[:analysis_samples]

print(f"📊 Analyzing {analysis_duration/60:.1f} minutes of EEG data for spectral characteristics...")
(frequencies,
 spec_times,
 power_spectral_density) = compute_multitaper_spectrogram(eeg_segment, fs)

# Decibel scale for display; the 1e-12 floor keeps log10 finite on zero-power bins
psd_db = 10 * np.log10(power_spectral_density + 1e-12)

print("✓ Spectral analysis completed")

In [None]:
# ==============================================================================
# COMBINED HYPNOGRAM AND SPECTROGRAM VISUALIZATION
# ==============================================================================

def create_hypnogram_spectrogram_plot(cluster_labels, cluster_times, window_size, 
                                      frequencies, spec_times, psd_db, 
                                      display_duration=1800):
    """
    Stack a cluster hypnogram above a spectrogram heatmap on a shared time axis.

    Args:
        cluster_labels, cluster_times: Cluster label and window start time (s)
            for every analysis window (arrays that support boolean masking).
        window_size: Width in seconds drawn for each hypnogram window.
        frequencies, spec_times, psd_db: Spectrogram frequency bins (Hz),
            frame times (s) and power in dB indexed [frequency, frame].
        display_duration: Only windows/frames at or before this time (s) are drawn.

    Returns:
        plotly.graph_objects.Figure: Two-row figure (hypnogram over spectrogram).
    """
    # Colour per cluster id. The stage names are presumed interpretations of
    # the clusters; nothing here verifies them.
    stage_palette = {
        0: '#3498DB',  # Blue - Deep sleep
        1: '#E67E22',  # Orange - Light sleep
        2: '#27AE60',  # Green - REM sleep
        3: '#E74C3C',  # Red - Wake/Arousal
    }
    unknown_color = '#95A5A6'
    bar_half_height = 0.35

    # Boolean selectors restricting both panels to the display range
    frames_shown = spec_times <= display_duration
    windows_shown = cluster_times <= display_duration

    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=[
            f'Sleep Stage Hypnogram ({display_duration/60:.0f} min)',
            'Multitaper Spectrogram (0-50 Hz)',
        ],
        vertical_spacing=0.12,
        row_heights=[0.25, 0.75],
    )

    # Hypnogram: one filled rectangle per window, centred vertically on its id
    for label, t0 in zip(cluster_labels[windows_shown], cluster_times[windows_shown]):
        fill = stage_palette.get(label, unknown_color)
        fig.add_shape(
            type="rect",
            x0=t0,
            x1=t0 + window_size,
            y0=label - bar_half_height,
            y1=label + bar_half_height,
            fillcolor=fill,
            line=dict(width=1, color=fill),
            opacity=0.8,
            row=1,
            col=1,
        )

    spectrogram_trace = go.Heatmap(
        x=spec_times[frames_shown],
        y=frequencies,
        z=psd_db[:, frames_shown],
        colorscale='Viridis',
        colorbar=dict(
            title='Power Spectral Density (dB)',
            x=1.02,
            len=0.75,
            y=0.4,
        ),
        name='Spectrogram',
        hovertemplate='Time: %{x:.1f}s<br>Frequency: %{y:.1f}Hz<br>Power: %{z:.1f}dB<extra></extra>',
    )
    fig.add_trace(spectrogram_trace, row=2, col=1)

    fig.update_layout(
        title={
            'text': 'Integrated Sleep Analysis: Hypnogram with Spectral Dynamics',
            'x': 0.5,
            'font': {'size': 16},
        },
        height=700,
        showlegend=False,
    )

    fig.update_xaxes(title_text='Time (seconds)', row=2, col=1)
    # One tick per integer cluster id on the hypnogram axis
    fig.update_yaxes(title_text='Sleep Stage', tickmode='linear', tick0=0, dtick=1, row=1, col=1)
    # Spectrogram view limited to 0-50 Hz
    fig.update_yaxes(title_text='Frequency (Hz)', range=[0, 50], row=2, col=1)

    return fig

print("🎨 Creating integrated hypnogram-spectrogram visualization...")
# Show the first 30 minutes, or the whole recording when it is shorter
display_time = min(1800, len(eeg_data) / fs)
combined_plot = create_hypnogram_spectrogram_plot(
    cluster_labels=cluster_labels,
    cluster_times=cluster_times,
    window_size=window_size,
    frequencies=frequencies,
    spec_times=spec_times,
    psd_db=psd_db,
    display_duration=display_time,
)
combined_plot.show()
print("✓ Combined visualization completed")

In [None]:
# ==============================================================================
# FREQUENCY BAND ANALYSIS
# ==============================================================================

def analyze_frequency_bands(frequencies, spec_times, power_spectral_density, 
                           cluster_labels, cluster_times, window_size):
    """
    Analyze EEG frequency band power distribution across sleep clusters.

    Each cluster window ``[t, t + window_size)`` is matched to the spectrogram
    frames whose centre time falls inside it, and its band power is the mean
    over those frames. If no frame centre falls inside (e.g. windows shorter
    than the frame step), the frame nearest the window centre is used.
    Cluster windows whose centre lies outside the time span covered by the
    spectrogram are skipped rather than mapped onto the edge frame, because
    the spectrogram may cover only an excerpt of the recording.

    Args:
        frequencies: 1-D frequency bins (Hz), one per PSD row.
        spec_times: 1-D spectrogram frame centre times (s), one per PSD column.
        power_spectral_density: 2-D linear PSD matrix indexed [frequency, frame].
        cluster_labels: Cluster label of each analysis window.
        cluster_times: Start time (s) of each analysis window, aligned with
            cluster_labels.
        window_size: Duration (s) of each analysis window.

    Returns:
        tuple: (band_summary, band_powers)
            band_summary (pandas.DataFrame): one row per (cluster, band) with
                columns cluster, band, mean_power, std_power, median_power,
                mean_log_power, std_log_power.
            band_powers (dict): band name -> 1-D array of mean band power per
                spectrogram frame (bands with no frequency bins are omitted).
    """
    frequencies = np.asarray(frequencies)
    spec_times = np.asarray(spec_times, dtype=float)
    power_spectral_density = np.asarray(power_spectral_density)
    cluster_labels = np.asarray(cluster_labels)
    cluster_times = np.asarray(cluster_times, dtype=float)

    # Define physiologically relevant EEG frequency bands
    frequency_bands = {
        'Delta (0.5-4 Hz)': (0.5, 4.0),    # Deep sleep, slow-wave activity
        'Theta (4-8 Hz)': (4.0, 8.0),      # Light sleep, drowsiness, REM
        'Alpha (8-13 Hz)': (8.0, 13.0),    # Relaxed wakefulness, eyes closed
        'Beta (13-30 Hz)': (13.0, 30.0),   # Active wakefulness, cognitive activity
        'Gamma (30-50 Hz)': (30.0, 50.0)   # High-frequency activity, cognition
    }
    summary_columns = ['cluster', 'band', 'mean_power', 'std_power',
                       'median_power', 'mean_log_power', 'std_log_power']

    print("🧠 FREQUENCY BAND ANALYSIS")
    print("=" * 50)
    print(f"Analyzing {len(frequency_bands)} frequency bands across {len(np.unique(cluster_labels))} clusters")

    # Half-open [f_low, f_high) so a bin lying exactly on a boundary (e.g. 4 Hz)
    # is attributed to one band only instead of being double-counted.
    band_powers = {}
    for band_name, (f_low, f_high) in frequency_bands.items():
        mask = (frequencies >= f_low) & (frequencies < f_high)
        n_bins = int(np.sum(mask))
        print(f"  {band_name}: {n_bins} frequency bins ({f_low}-{f_high} Hz)")
        if n_bins == 0:
            # e.g. a band above the Nyquist frequency; a mean over zero bins is NaN
            print("    ⚠️  No bins in range - band skipped")
            continue
        band_powers[band_name] = power_spectral_density[mask, :].mean(axis=0)

    band_names = list(band_powers)
    if not band_names or spec_times.size == 0:
        print("⚠️  No usable frequency bands or spectrogram frames - nothing to summarise")
        return pd.DataFrame(columns=summary_columns), band_powers
    band_matrix = np.vstack([band_powers[name] for name in band_names])  # (n_bands, n_frames)

    # Time span represented by the spectrogram: frame centres +/- half a frame step
    if spec_times.size > 1:
        frame_step = float(np.median(np.diff(spec_times)))
    else:
        frame_step = float(window_size)
    coverage_start = spec_times[0] - frame_step / 2
    coverage_end = spec_times[-1] + frame_step / 2

    window_centres = cluster_times + window_size / 2
    covered = np.flatnonzero((window_centres >= coverage_start) & (window_centres <= coverage_end))

    print(f"\nProcessing {len(covered)} of {len(cluster_labels)} cluster windows "
          f"(inside spectrogram span {coverage_start:.0f}-{coverage_end:.0f}s)...")

    cluster_band_stats = []
    for i in covered:
        window_start = cluster_times[i]
        in_window = (spec_times >= window_start) & (spec_times < window_start + window_size)
        if in_window.any():
            window_power = band_matrix[:, in_window].mean(axis=1)
        else:
            window_power = band_matrix[:, np.argmin(np.abs(spec_times - window_centres[i]))]

        for band_name, power_value in zip(band_names, window_power):
            cluster_band_stats.append({
                'cluster': cluster_labels[i],
                'band': band_name,
                'power': power_value,
                'log_power': np.log10(power_value + 1e-12),  # Avoid log(0)
                'time_point': window_start
            })

    if not cluster_band_stats:
        print("⚠️  No cluster windows overlap the spectrogram - nothing to summarise")
        return pd.DataFrame(columns=summary_columns), band_powers

    band_df = pd.DataFrame(cluster_band_stats)
    band_summary = (
        band_df.groupby(['cluster', 'band'])
        .agg(
            mean_power=('power', 'mean'),
            std_power=('power', 'std'),
            median_power=('power', 'median'),
            mean_log_power=('log_power', 'mean'),
            std_log_power=('log_power', 'std'),
        )
        .reset_index()
    )
    # Round only the log-scale columns: absolute PSD values can be far below
    # 1e-4 (e.g. V²/Hz), and rounding them would collapse them to 0.
    log_columns = ['mean_log_power', 'std_log_power']
    band_summary[log_columns] = band_summary[log_columns].round(4)

    print(f"\n📊 Band Power Analysis Summary:")
    for cluster in sorted(band_summary['cluster'].unique()):
        cluster_data = band_summary[band_summary['cluster'] == cluster]
        dominant_band = cluster_data.loc[cluster_data['mean_power'].idxmax(), 'band']
        max_power = cluster_data['mean_power'].max()
        print(f"  Cluster {cluster}: Dominant band = {dominant_band} (Power: {max_power:.2e})")

    return band_summary, band_powers

# Summarise band power per cluster; keep both the summary and per-frame band powers
print("🔬 Starting frequency band analysis...")
band_stats_df, band_power_dict = analyze_frequency_bands(
    frequencies=frequencies,
    spec_times=spec_times,
    power_spectral_density=power_spectral_density,
    cluster_labels=cluster_labels,
    cluster_times=cluster_times,
    window_size=window_size,
)

print("✓ Frequency band analysis completed")

In [None]:
# ==============================================================================
# FREQUENCY BAND VISUALIZATION
# ==============================================================================

def create_frequency_band_plots(band_stats_df):
    """
    Create a four-panel figure summarising band power per cluster.

    Panels: (1) grouped bars of mean band power with std error bars (log y),
    (2) heatmap of mean log10 power, (3) each cluster's band power as a
    percentage of its summed band power, (4) power of each cluster's dominant
    band, annotated with the band name.

    Bands are displayed in ascending frequency order, taken from the first
    number in each band label (e.g. "Delta (0.5-4 Hz)" -> 0.5), rather than in
    alphabetical order.

    Args:
        band_stats_df: DataFrame with one row per (cluster, band) and columns
            cluster, band, mean_power, std_power, mean_log_power.

    Returns:
        plotly.graph_objects.Figure: Multi-panel frequency analysis plot
    """
    import re

    def _band_sort_key(band_name):
        # Lowest frequency quoted in the label; labels without a number go last
        match = re.search(r'\d+(?:\.\d+)?', str(band_name))
        return (float(match.group()) if match else float('inf'), str(band_name))

    band_order = sorted(band_stats_df['band'].unique(), key=_band_sort_key)
    clusters = sorted(band_stats_df['cluster'].unique())

    # Heatmap matrix: rows = clusters, columns = bands in frequency order
    log_power_pivot = (
        band_stats_df
        .pivot(index='cluster', columns='band', values='mean_log_power')
        .reindex(columns=band_order)
    )

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=[
            'Band Power by Cluster (Linear Scale)',
            'Band Power Heatmap (Log Scale)', 
            'Relative Band Power Distribution (%)',
            'Dominant Frequency Bands'
        ],
        specs=[
            [{'type': 'bar'}, {'type': 'heatmap'}],
            [{'type': 'bar'}, {'type': 'bar'}]
        ],
        horizontal_spacing=0.12,
        vertical_spacing=0.15
    )

    # Color scheme for clusters
    cluster_colors = ['#3498DB', '#E67E22', '#27AE60', '#E74C3C', '#9B59B6', '#1ABC9C']

    # 1. Grouped bar chart of absolute band powers
    for i, cluster in enumerate(clusters):
        cluster_data = band_stats_df[band_stats_df['cluster'] == cluster]

        fig.add_trace(go.Bar(
            x=cluster_data['band'],
            y=cluster_data['mean_power'],
            name=f'Cluster {cluster}',
            marker_color=cluster_colors[i % len(cluster_colors)],
            opacity=0.8,
            error_y=dict(type='data', array=cluster_data['std_power'], visible=True),
            hovertemplate='Band: %{x}<br>Power: %{y:.2e}<br>Cluster: ' + str(cluster) + '<extra></extra>'
        ), row=1, col=1)

    # 2. Heatmap of log-scaled band powers
    fig.add_trace(go.Heatmap(
        x=list(log_power_pivot.columns),
        y=[f'Cluster {c}' for c in log_power_pivot.index],
        z=log_power_pivot.values,
        colorscale='Viridis',
        colorbar=dict(
            title='Log₁₀(Power)',
            x=0.48,
            y=0.8,
            len=0.35
        ),
        hovertemplate='Band: %{x}<br>Cluster: %{y}<br>Log Power: %{z:.2f}<extra></extra>'
    ), row=1, col=2)

    # 3. Relative band power (normalized to 100% per cluster)
    for i, cluster in enumerate(clusters):
        cluster_data = band_stats_df[band_stats_df['cluster'] == cluster]
        total_power = cluster_data['mean_power'].sum()
        if total_power > 0:
            relative_power = (cluster_data['mean_power'] / total_power) * 100
        else:
            # Zero total power: the relative share is undefined, leave bars empty
            relative_power = pd.Series(np.nan, index=cluster_data.index)

        fig.add_trace(go.Bar(
            x=cluster_data['band'],
            y=relative_power,
            name=f'Cluster {cluster} (Relative)',
            marker_color=cluster_colors[i % len(cluster_colors)],
            opacity=0.8,
            showlegend=False,
            hovertemplate='Band: %{x}<br>Relative Power: %{y:.1f}%<br>Cluster: ' + str(cluster) + '<extra></extra>'
        ), row=2, col=1)

    # 4. Dominant band (highest mean power) per cluster; groupby sorts by cluster
    dominant_idx = band_stats_df.groupby('cluster')['mean_power'].idxmax()
    dominant_df = (
        band_stats_df.loc[dominant_idx, ['cluster', 'band', 'mean_power']]
        .rename(columns={'band': 'dominant_band', 'mean_power': 'max_power'})
    )

    fig.add_trace(go.Bar(
        x=[f'Cluster {c}' for c in dominant_df['cluster']],
        y=dominant_df['max_power'],
        text=dominant_df['dominant_band'],
        textposition='outside',
        marker_color='lightsteelblue',
        name='Dominant Band Power',
        showlegend=False,
        hovertemplate='Cluster: %{x}<br>Dominant Band: %{text}<br>Power: %{y:.2e}<extra></extra>'
    ), row=2, col=2)

    fig.update_layout(
        height=800,
        title={
            'text': 'Comprehensive EEG Frequency Band Analysis Across Sleep Clusters',
            'x': 0.5,
            'font': {'size': 16}
        },
        showlegend=True,
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='center',
            x=0.5
        )
    )

    # Band axes follow frequency order instead of first-appearance order
    fig.update_xaxes(title_text='Frequency Band', row=1, col=1, tickangle=45,
                     categoryorder='array', categoryarray=band_order)
    fig.update_yaxes(title_text='Mean Power', row=1, col=1, type='log')  # Log scale for better visualization
    fig.update_xaxes(title_text='Frequency Band', row=2, col=1, tickangle=45,
                     categoryorder='array', categoryarray=band_order)
    fig.update_yaxes(title_text='Relative Power (%)', row=2, col=1)
    fig.update_xaxes(title_text='Cluster', row=2, col=2)
    fig.update_yaxes(title_text='Dominant Band Power', row=2, col=2, type='log')

    return fig

# Render the four-panel band-power figure from the per-cluster summary
print("🎨 Creating frequency band visualizations...")
frequency_plots = create_frequency_band_plots(band_stats_df=band_stats_df)
frequency_plots.show()
print("✓ Frequency band analysis plots completed")

In [None]:
# ==============================================================================
# COMPREHENSIVE ANALYSIS SUMMARY
# ==============================================================================

def generate_analysis_summary(cluster_dist_df, duration_stats_df, band_stats_df, 
                             cluster_labels, eeg_data, fs, continuous_durations):
    """
    Print a human-readable summary of the clustering analysis and return key metrics.

    Args:
        cluster_dist_df (pandas.DataFrame): Per-cluster distribution with columns
            cluster, count, proportion_pct, duration_hours.
        duration_stats_df (pandas.DataFrame): Per-cluster continuity stats with
            columns cluster, segment_count, mean_duration_min, max_duration_min.
        band_stats_df (pandas.DataFrame): Per-(cluster, band) stats with at least
            columns cluster, band, mean_power.
        cluster_labels (array-like): Cluster label of each analysis window, in time order.
        eeg_data (array-like): EEG samples; only its length is used.
        fs (float): Sampling frequency in Hz.
        continuous_durations (dict): Cluster -> list of contiguous segment
            durations in seconds.

    Returns:
        dict: total_duration_hrs, total_windows, transitions_count,
        stability_ratio, most_frequent_cluster, dominant_bands (cluster -> band).
    """
    print("📋 COMPREHENSIVE ANALYSIS SUMMARY")
    print("=" * 60)

    # 1. Recording and Model Overview
    cluster_labels = np.asarray(cluster_labels)
    total_duration_hrs = len(eeg_data) / fs / 3600
    total_windows = len(cluster_labels)
    transitions_count = int(np.sum(np.diff(cluster_labels) != 0))
    # Stability = share of adjacent window pairs without a label change. There
    # are len - 1 such pairs, which matches the self-transition ratio computed
    # by analyze_cluster_transitions().
    n_pairs = total_windows - 1
    stability_ratio = 1 - transitions_count / n_pairs if n_pairs > 0 else 1.0

    print(f"\n🔍 RECORDING OVERVIEW:")
    print(f"  Total duration: {total_duration_hrs:.2f} hours ({total_duration_hrs*60:.0f} minutes)")
    print(f"  Total analysis windows: {total_windows:,}")
    print(f"  Sampling frequency: {fs} Hz")
    print(f"  Sleep stage transitions: {transitions_count:,}")
    print(f"  Overall stability: {stability_ratio:.3f} (lower = more transitions)")

    # 2. Cluster Distribution Summary
    print(f"\n📊 CLUSTER DISTRIBUTION:")
    print("-" * 40)
    for _, row in cluster_dist_df.iterrows():
        print(f"  Cluster {row['cluster']:1.0f}: {row['count']:4.0f} windows "
              f"({row['proportion_pct']:5.1f}%) → {row['duration_hours']:5.2f}h")

    most_frequent_cluster = None
    if len(cluster_dist_df) > 0:
        most_frequent_cluster = cluster_dist_df.loc[cluster_dist_df['count'].idxmax()]
        least_frequent_cluster = cluster_dist_df.loc[cluster_dist_df['count'].idxmin()]
        # :.0f keeps ids integral even when the row Series was upcast to float
        print(f"\n  Most frequent: Cluster {most_frequent_cluster['cluster']:.0f} ({most_frequent_cluster['proportion_pct']:.1f}%)")
        print(f"  Least frequent: Cluster {least_frequent_cluster['cluster']:.0f} ({least_frequent_cluster['proportion_pct']:.1f}%)")

    # 3. Sleep Architecture Analysis
    print(f"\n🏗️  SLEEP ARCHITECTURE (Duration Analysis):")
    print("-" * 40)
    if len(duration_stats_df) > 0:
        for _, row in duration_stats_df.iterrows():
            print(f"  Cluster {row['cluster']:1.0f}: {row['segment_count']:3.0f} segments | "
                  f"Avg: {row['mean_duration_min']:5.1f}min | "
                  f"Max: {row['max_duration_min']:5.1f}min")

        most_stable = duration_stats_df.loc[duration_stats_df['mean_duration_min'].idxmax()]
        longest_segment = duration_stats_df.loc[duration_stats_df['max_duration_min'].idxmax()]

        print(f"\n  Most stable cluster: {most_stable['cluster']:.0f} (avg {most_stable['mean_duration_min']:.1f}min per segment)")
        print(f"  Longest single segment: Cluster {longest_segment['cluster']:.0f} ({longest_segment['max_duration_min']:.1f}min)")

    # 4. Frequency Band Dominance (computed once, reused in the return value)
    print(f"\n🧠 FREQUENCY BAND ANALYSIS:")
    print("-" * 40)
    dominant_bands = {}
    if len(band_stats_df) > 0:
        for cluster in sorted(band_stats_df['cluster'].unique()):
            cluster_bands = band_stats_df[band_stats_df['cluster'] == cluster]
            dominant_band = cluster_bands.loc[cluster_bands['mean_power'].idxmax(), 'band']
            dominant_bands[cluster] = dominant_band

            max_power = cluster_bands['mean_power'].max()
            total_power = cluster_bands['mean_power'].sum()
            dominant_percentage = (max_power / total_power) * 100 if total_power > 0 else float('nan')

            print(f"  Cluster {cluster}: {dominant_band} dominant ({dominant_percentage:.1f}% of total power)")

    # 5. Clinical and Research Implications
    print(f"\n🏥 CLINICAL INTERPRETATION NOTES:")
    print("-" * 40)
    print("  • Delta dominance (0.5-4 Hz) → Likely deep sleep (N3 stage)")
    print("  • Theta dominance (4-8 Hz) → Likely light sleep (N1/N2) or REM")
    print("  • Alpha dominance (8-13 Hz) → Likely relaxed wakefulness or sleep onset")
    print("  • Beta/Gamma dominance (>13 Hz) → Likely active wakefulness or micro-arousals")

    print(f"\n💡 RESEARCH INSIGHTS:")
    print("-" * 40)
    print(f"  • BiLSTM identified {len(cluster_dist_df)} distinct sleep clusters")
    print(f"  • Sleep architecture shows {transitions_count} stage transitions")
    # Mean over every segment (pooled across clusters), not a mean of per-cluster means
    all_segments = [d for durations in continuous_durations.values() for d in durations]
    if all_segments:
        print(f"  • Average sleep segment duration: {np.mean(all_segments)/60:.1f} minutes")
    else:
        print("  • Average sleep segment duration: n/a (no segments)")
    print(f"  • Cluster stability suggests {'high' if stability_ratio > 0.8 else 'moderate' if stability_ratio > 0.6 else 'low'} sleep continuity")

    return {
        'total_duration_hrs': total_duration_hrs,
        'total_windows': total_windows,
        'transitions_count': transitions_count,
        'stability_ratio': stability_ratio,
        'most_frequent_cluster': most_frequent_cluster['cluster'] if most_frequent_cluster is not None else None,
        'dominant_bands': dominant_bands
    }

# Print the end-of-run report and keep the returned metrics for later cells
analysis_summary = generate_analysis_summary(
    cluster_dist_df=cluster_dist_df,
    duration_stats_df=duration_stats_df,
    band_stats_df=band_stats_df,
    cluster_labels=cluster_labels,
    eeg_data=eeg_data,
    fs=fs,
    continuous_durations=continuous_durations,
)

print("\n✅ Analysis summary completed")

In [None]:
# ==============================================================================
# INTERACTIVE ANALYSIS TOOLS
# ==============================================================================

def interactive_segment_analysis(start_sec, end_sec, plot_spectrogram=True):
    """
    Print cluster statistics for a time range and plot its EEG (and optionally
    a spectrogram).

    Relies on notebook globals from earlier cells: cluster_labels,
    cluster_times, eeg_data, time_vector, window_size, fs, plot_eeg_segment,
    compute_multitaper_spectrogram.

    Args:
        start_sec, end_sec (float): Time range to analyze in seconds. A cluster
            window is included when its start time lies in [start_sec, end_sec).
        plot_spectrogram (bool): Whether to generate a spectrogram for the
            segment (only done for ranges of at most 10 minutes).

    Returns:
        None. Results are printed and figures are shown.

    Usage:
        interactive_segment_analysis(600, 1200)  # Analyze 10-20 minute segment
    """
    if end_sec <= start_sec:
        # An empty/inverted range would otherwise divide by zero in the
        # windows-per-minute rate below.
        print(f"❌ Invalid time range: end_sec ({end_sec}) must be greater than start_sec ({start_sec})")
        return

    print(f"\n🔍 DETAILED SEGMENT ANALYSIS")
    print(f"Time Range: {start_sec}s - {end_sec}s ({(end_sec-start_sec)/60:.1f} minutes)")
    print("=" * 50)

    # A window starting exactly at end_sec lies entirely after the range, hence '<'
    segment_mask = (cluster_times >= start_sec) & (cluster_times < end_sec)
    segment_clusters = cluster_labels[segment_mask]

    if len(segment_clusters) == 0:
        print("❌ No cluster data available for this time segment")
        return

    unique, counts = np.unique(segment_clusters, return_counts=True)
    cluster_distribution = dict(zip(unique, counts))
    dominant_pos = np.argmax(counts)
    num_transitions = np.sum(np.diff(segment_clusters) != 0)
    # Transitions are counted over len - 1 adjacent pairs
    n_pairs = len(segment_clusters) - 1
    stability = 1 - (num_transitions / n_pairs) if n_pairs > 0 else 1.0

    print(f"📊 Cluster Distribution: {cluster_distribution}")
    print(f"🎯 Dominant cluster: {unique[dominant_pos]} ({counts[dominant_pos]} windows)")
    print(f"🔄 Transitions: {num_transitions}")
    print(f"📈 Stability: {stability:.3f}")

    total_segment_time = (end_sec - start_sec) / 60
    windows_per_minute = len(segment_clusters) / total_segment_time
    print(f"⏱️  Analysis windows: {len(segment_clusters)} ({windows_per_minute:.1f} per minute)")

    # Plot EEG segment with cluster annotations
    print(f"\n🎨 Generating EEG segment visualization...")
    segment_fig = plot_eeg_segment(start_sec, end_sec, eeg_data, time_vector,
                                   cluster_labels, cluster_times, window_size)
    segment_fig.show()

    # Optional spectrogram analysis for shorter segments
    if plot_spectrogram and (end_sec - start_sec) <= 600:  # Max 10 minutes for performance
        print(f"🔬 Computing segment spectrogram...")

        # Clamp so a negative start does not wrap around to the end of the array
        start_idx = max(0, int(start_sec * fs))
        end_idx = int(end_sec * fs)
        segment_data = eeg_data[start_idx:end_idx]

        if len(segment_data) > fs * 10:  # Minimum 10 seconds of data
            f_seg, t_seg, Sxx_seg = compute_multitaper_spectrogram(segment_data, fs, window_length=5)
            Sxx_seg_db = 10 * np.log10(Sxx_seg + 1e-12)

            spec_fig = go.Figure()
            spec_fig.add_trace(go.Heatmap(
                x=t_seg + start_idx / fs,  # Shift to recording time of the first sample used
                y=f_seg,
                z=Sxx_seg_db,
                colorscale='Viridis',
                colorbar=dict(title='Power (dB)')
            ))

            spec_fig.update_layout(
                title=f'Segment Spectrogram: {start_sec}s - {end_sec}s',
                xaxis_title='Time (seconds)',
                yaxis_title='Frequency (Hz)',
                height=400
            )
            spec_fig.update_yaxes(range=[0, 50])  # Focus on sleep-relevant frequencies
            spec_fig.show()

            print("✅ Spectrogram analysis completed")
        else:
            print("⚠️  Segment too short for reliable spectrogram analysis")

    elif plot_spectrogram and (end_sec - start_sec) > 600:
        print("⚠️  Segment too long for spectrogram analysis (>10 min). Set plot_spectrogram=False for longer segments.")

    print("✅ Segment analysis completed\n")

# Demonstration of interactive analysis
# The recommended segments below are all longer than 10 minutes, the limit
# above which interactive_segment_analysis skips the spectrogram (and prints a
# warning), so they are suggested with plot_spectrogram=False.
print("🛠️  INTERACTIVE ANALYSIS TOOLS READY")
print("=" * 50)
print("Available Functions:")
print("  • interactive_segment_analysis(start_sec, end_sec, plot_spectrogram=True)")
print("    - Analyze any time segment in detail")
print("    - Includes EEG visualization and optional spectrogram (segments ≤ 10 min)")
print("    - Example: interactive_segment_analysis(600, 900)")
print("\nRecommended Analysis Segments:")
print("  • First hour: interactive_segment_analysis(0, 3600, plot_spectrogram=False)")
print("  • Sleep onset: interactive_segment_analysis(0, 1800, plot_spectrogram=False)  # First 30 min")
print("  • Mid-sleep: interactive_segment_analysis(14400, 16200, plot_spectrogram=False)  # 4-4.5 hours in")
print("  • Wake period: interactive_segment_analysis(28800, 30600, plot_spectrogram=False)  # 8-8.5 hours in")

print("\n💡 ANALYSIS RECOMMENDATIONS:")
print("1. 🔬 Cluster Interpretation:")
print("   - Compare frequency band dominance with physiological sleep stages")
print("   - Delta dominance → Deep sleep (N3)")
print("   - Theta dominance → Light sleep (N1/N2) or REM")
print("   - Alpha/Beta dominance → Wake or micro-arousals")
print("\n2. 📊 Model Validation:")
print("   - Check cluster consistency across different time periods")
print("   - Validate against expected sleep architecture patterns")
print("   - Consider individual sleep characteristics and disorders")
print("\n3. 🔍 Further Investigation:")
print("   - Analyze cluster transition patterns and timing")
print("   - Investigate circadian rhythm effects on clustering")
print("   - Compare results with other sleep analysis methods")

In [None]:
# ==============================================================================
# CLUSTER TRANSITION ANALYSIS
# ==============================================================================

def analyze_cluster_transitions(cluster_labels):
    """
    Count and normalise transitions between consecutive cluster labels.

    Labels need not be contiguous integers starting at 0 (e.g. {1, 3} or
    {0, 2, 5} work): each label is mapped to its position in the sorted list of
    unique labels, and that position indexes the matrices.

    Args:
        cluster_labels (array-like): Sequence of cluster assignments in time order.

    Returns:
        tuple: (transition_matrix, transition_probabilities, unique_clusters)
            transition_matrix (ndarray of int): counts, row = from, column = to.
            transition_probabilities (ndarray of float): row-normalised counts;
                rows with no outgoing transitions are all zeros.
            unique_clusters (list): sorted unique labels giving the row/column order.
    """
    labels = np.asarray(cluster_labels)
    unique_clusters = sorted(np.unique(labels))
    n_clusters = len(unique_clusters)

    print("🔄 CLUSTER TRANSITION ANALYSIS")
    print("=" * 50)
    print(f"Analyzing transitions between {n_clusters} clusters...")

    transition_matrix = np.zeros((n_clusters, n_clusters), dtype=int)

    # Count transitions between consecutive time windows
    if labels.size > 1:
        # Position of each label in unique_clusters (not the raw label value)
        positions = np.searchsorted(np.asarray(unique_clusters), labels)
        # Unbuffered add so repeated (from, to) pairs are all counted
        np.add.at(transition_matrix, (positions[:-1], positions[1:]), 1)

    # Calculate transition probabilities (normalize by row sums)
    row_sums = transition_matrix.sum(axis=1, keepdims=True)
    transition_probs = np.divide(transition_matrix, row_sums, 
                                out=np.zeros_like(transition_matrix, dtype=float), 
                                where=row_sums!=0)

    # Display transition matrices
    print(f"\n📊 Transition Matrix (Raw Counts):")
    trans_df = pd.DataFrame(
        transition_matrix, 
        index=[f'From Cluster {c}' for c in unique_clusters], 
        columns=[f'To Cluster {c}' for c in unique_clusters]
    )
    print(trans_df.to_string())

    print(f"\n📈 Transition Probabilities:")
    trans_prob_df = pd.DataFrame(
        transition_probs, 
        index=[f'From Cluster {c}' for c in unique_clusters], 
        columns=[f'To Cluster {c}' for c in unique_clusters]
    )
    print(trans_prob_df.round(3).to_string())

    # Calculate transition statistics
    total_transitions = np.sum(transition_matrix)
    self_transitions = np.sum(np.diag(transition_matrix))
    stability_ratio = self_transitions / total_transitions if total_transitions > 0 else 0

    print(f"\n📋 Transition Statistics:")
    print(f"  Total transitions: {total_transitions:,}")
    print(f"  Self-transitions (no change): {self_transitions:,} ({stability_ratio:.1%})")
    print(f"  Between-cluster transitions: {total_transitions - self_transitions:,}")
    print(f"  Overall stability: {stability_ratio:.3f}")

    # Identify most common transitions
    print(f"\n🔝 Most Common Transitions:")
    transition_list = []
    for i, from_cluster in enumerate(unique_clusters):
        for j, to_cluster in enumerate(unique_clusters):
            if transition_matrix[i, j] > 0:
                transition_list.append({
                    'from': from_cluster,
                    'to': to_cluster,
                    'count': transition_matrix[i, j],
                    'probability': transition_probs[i, j]
                })

    # Sort by count and display top transitions
    transition_list.sort(key=lambda x: x['count'], reverse=True)
    for trans in transition_list[:8]:  # Show top 8 transitions
        arrow = "→" if trans['from'] != trans['to'] else "↻"
        print(f"  Cluster {trans['from']} {arrow} Cluster {trans['to']}: "
              f"{trans['count']:,} times ({trans['probability']:.3f})")

    return transition_matrix, transition_probs, unique_clusters

# Run the transition analysis over the full label sequence
transition_matrix, transition_probabilities, cluster_ids = analyze_cluster_transitions(
    cluster_labels
)
print("✅ Transition analysis completed")

In [None]:
# ==============================================================================
# TRANSITION MATRIX VISUALIZATION
# ==============================================================================

def create_transition_matrix_plots(transition_matrix, transition_probs, cluster_ids):
    """Create side-by-side heatmaps of cluster transition counts and probabilities.

    Parameters
    ----------
    transition_matrix : array-like, shape (k, k)
        Raw transition counts. Row ``i`` is the "from" cluster and column ``j``
        the "to" cluster, both ordered like ``cluster_ids``.
    transition_probs : array-like, shape (k, k)
        Transition probabilities in the same row/column order as
        ``transition_matrix``.
    cluster_ids : sequence
        Cluster identifiers used to label both heatmap axes.

    Returns
    -------
    plotly.graph_objects.Figure
        A 1x2 subplot figure: raw counts on the left, probabilities on the right.
    """
    counts = np.asarray(transition_matrix)
    # Plot the matrix the caller passed in. The previous version read the
    # notebook-global ``transition_probabilities`` here and silently ignored
    # this argument.
    probs = np.asarray(transition_probs)

    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=['Transition Counts (Raw)', 'Transition Probabilities'],
        specs=[[{'type': 'heatmap'}, {'type': 'heatmap'}]],
        horizontal_spacing=0.15
    )

    # Named ``axis_labels`` so it does not shadow the notebook-global
    # per-epoch ``cluster_labels`` array.
    axis_labels = [f'Cluster {c}' for c in cluster_ids]

    # 1. Raw transition counts heatmap (rows = from cluster, columns = to cluster)
    fig.add_trace(go.Heatmap(
        x=axis_labels,
        y=axis_labels,
        z=counts,
        colorscale='Blues',
        text=counts.astype(int),
        texttemplate='%{text}',
        textfont={'size': 12, 'color': 'white'},
        colorbar=dict(
            title='Transition Count',
            x=0.46,  # place the first colorbar between the two subplots
            len=0.8
        ),
        hovertemplate='From: %{y}<br>To: %{x}<br>Count: %{z}<extra></extra>'
    ), row=1, col=1)

    # 2. Transition probabilities heatmap
    fig.add_trace(go.Heatmap(
        x=axis_labels,
        y=axis_labels,
        z=probs,
        colorscale='Reds',
        text=np.round(probs, 3),
        texttemplate='%{text}',
        textfont={'size': 12, 'color': 'white'},
        colorbar=dict(
            title='Transition Probability',
            x=1.02,
            len=0.8
        ),
        hovertemplate='From: %{y}<br>To: %{x}<br>Probability: %{z:.3f}<extra></extra>'
    ), row=1, col=2)

    # Configure layout
    fig.update_layout(
        title={
            'text': 'Sleep Cluster Transition Analysis - Markov Chain Dynamics',
            'x': 0.5,
            'font': {'size': 16}
        },
        height=500,
        font=dict(size=12)
    )

    # Configure axes identically on both subplots
    for col in (1, 2):
        fig.update_xaxes(title_text='To Cluster', row=1, col=col, tickangle=45)
        fig.update_yaxes(title_text='From Cluster', row=1, col=col)

    return fig

# Create and display transition matrix visualization
print("🎨 Creating transition matrix visualizations...")
transition_fig = create_transition_matrix_plots(transition_matrix, transition_probabilities, cluster_ids)
transition_fig.show()

# Additional transition insights
print("\n🧠 TRANSITION PATTERN INSIGHTS:")
print("=" * 50)

# Self-transition probability (matrix diagonal) per cluster, in ``cluster_ids`` order.
self_stability = np.diag(transition_probabilities)

# ``self_stability`` is indexed by POSITION, not by cluster ID. Keep the
# argmax/argmin positions so that non-0..k-1 IDs (e.g. 1-based IDs or a -1
# noise label) do not index the wrong entry or raise IndexError.
most_stable_idx = int(np.argmax(self_stability))
least_stable_idx = int(np.argmin(self_stability))
most_stable_cluster = cluster_ids[most_stable_idx]
least_stable_cluster = cluster_ids[least_stable_idx]

print("📊 Cluster Stability (Self-Transition Probabilities):")
for cluster, stability in zip(cluster_ids, self_stability):
    stability_desc = "High" if stability > 0.8 else "Moderate" if stability > 0.6 else "Low"
    print(f"  Cluster {cluster}: {stability:.3f} ({stability_desc} stability)")

print("\n🎯 Key Findings:")
print(f"  Most stable cluster: {most_stable_cluster} (probability: {self_stability[most_stable_idx]:.3f})")
print(f"  Most dynamic cluster: {least_stable_cluster} (probability: {self_stability[least_stable_idx]:.3f})")

# Off-diagonal transitions above this probability are reported as significant.
significant_transition_threshold = 0.05

between_cluster_transitions = []
for i, from_cluster in enumerate(cluster_ids):
    for j, to_cluster in enumerate(cluster_ids):
        if i != j and transition_probabilities[i, j] > significant_transition_threshold:
            between_cluster_transitions.append({
                'from': from_cluster,
                'to': to_cluster,
                'probability': transition_probabilities[i, j],
                'count': transition_matrix[i, j]
            })

if between_cluster_transitions:
    between_cluster_transitions.sort(key=lambda x: x['probability'], reverse=True)
    print(f"\n🔄 Significant Between-Cluster Transitions (>{significant_transition_threshold:.0%} probability):")
    for trans in between_cluster_transitions[:5]:  # Top 5
        print(f"  Cluster {trans['from']} → Cluster {trans['to']}: "
              f"{trans['probability']:.3f} ({trans['count']} transitions)")
else:
    print(f"\n🔒 No significant between-cluster transitions detected (all <{significant_transition_threshold:.0%} probability)")

print("✅ Transition visualization and analysis completed")