# Time Series Transformer (TST) Sleep Analysis - Results & Clustering

This notebook provides comprehensive analysis of sleep stage detection using Time Series Transformer (TST) embeddings combined with K-means clustering. This work represents advanced sequential learning approaches for automated sleep analysis as part of the research project conducted at Institut de Neurosciences des Systèmes (INS).

## Research Methodology:
- **Time Series Transformers**: Advanced deep learning architecture for capturing long-range temporal dependencies in physiological signals
- **Multi-Scale Analysis**: Comparison between 3-second and 30-second time windows for different temporal resolution insights
- **Unsupervised Clustering**: K-means clustering applied to transformer embeddings for sleep stage discovery
- **Micro-arousal Detection**: High-resolution temporal analysis for detecting brief arousal events

## Analysis Components:
- **Cluster Distribution Analysis**: Statistical characterization of detected sleep patterns
- **Temporal Dynamics**: Evolution of sleep stages across recording sessions
- **Frequency Domain Analysis**: Spectral characteristics of different sleep clusters
- **Comparative Analysis**: Performance evaluation across different time scales
- **Clinical Interpretation**: Physiological relevance of detected patterns

## Data Sources:
- **3-second Analysis**: `predicted_labels_3s_tst_20_files_*.csv/npy` - High temporal resolution
- **30-second Analysis**: `predicted_labels_30s_tst_3_files_*.csv/npy` - Standard sleep scoring resolution
- **EEG Source**: `SC4001E0-PSG.edf` - Raw polysomnography data

In [None]:
# ==============================================================================
# LIBRARY IMPORTS AND CONFIGURATION
# ==============================================================================

# Core scientific computing
import numpy as np
import pandas as pd

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Signal processing and analysis
import scipy.signal as signal
from scipy import stats
import mne

# Machine learning and data analysis
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Configuration
import warnings
# NOTE(review): 'ignore' with no category filter silences every warning for the
# rest of the session, including those raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Set visualization styles for consistent appearance
# These three calls set global defaults for matplotlib, seaborn and plotly
# express respectively.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
px.defaults.template = 'plotly_white'

# Status messages only: reaching this point means the imports did not raise.
print('✅ All required libraries imported successfully!')
print('✅ Visualization styles configured for Time Series Transformer analysis')

# 1. Data Loading and Preprocessing

Load Time Series Transformer clustering results and original EEG data for comprehensive analysis.

In [None]:
# ==============================================================================
# TIME SERIES TRANSFORMER RESULTS LOADING
# ==============================================================================

def _load_tst_result_set(label, file_stem):
    """
    Load the CSV, NPY and metadata files belonging to one clustering run.

    Args:
        label (str): Short tag used in the console messages (e.g. "3s").
        file_stem (str): Path of the result files without their extension;
            ``.csv``, ``.npy`` and ``_metadata.txt`` are appended to it.

    Returns:
        dict | None: Dictionary with the keys 'csv', 'npy' and 'metadata', or
        None when any of the three files is missing.
    """
    try:
        loaded = {
            'csv': pd.read_csv(f'{file_stem}.csv'),
            'npy': np.load(f'{file_stem}.npy'),
        }
        with open(f'{file_stem}_metadata.txt', 'r') as f:
            loaded['metadata'] = f.read()

        print(f"  ✅ {label} CSV shape: {loaded['csv'].shape}")
        print(f"  ✅ {label} NPY shape: {loaded['npy'].shape}")

    except FileNotFoundError as e:
        # A partially loaded set is discarded so callers can rely on
        # "all three entries or nothing".
        print(f"  ❌ Error loading {label} results: {e}")
        return None

    return loaded


def load_tst_clustering_results():
    """
    Load Time Series Transformer clustering results for multi-scale analysis.

    Both result sets are read from fixed, timestamped file names under
    ``results/``. A missing file is reported on the console and turns the
    corresponding result into None instead of raising.

    Returns:
        tuple: (results_3s, results_30s) dictionaries containing CSV, NPY, and
        metadata; either element is None when its files could not be found.
    """
    print("📊 LOADING TIME SERIES TRANSFORMER RESULTS")
    print("=" * 55)

    # Load 3-second window results (high temporal resolution)
    print("🔍 Loading 3-second window analysis results...")
    results_3s = _load_tst_result_set(
        '3s', 'results/predicted_labels_3s_tst_20_files_20250530_131707')

    # Load 30-second window results (standard sleep scoring resolution)
    print("\n🔍 Loading 30-second window analysis results...")
    results_30s = _load_tst_result_set(
        '30s', 'results/predicted_labels_30s_tst_3_files_20250523_181027')

    return results_3s, results_30s

def load_source_eeg_data():
    """
    Read the first channel of the polysomnography (PSG) recording from disk.

    A fixed list of candidate paths is tried in order and the first file that
    MNE opens successfully is used. Failures are reported on the console and
    the next candidate is tried.

    Returns:
        tuple: (eeg_signal, sampling_frequency, raw_mne_object), or
        (None, None, None) when none of the candidates could be loaded.
    """
    print("\n📡 LOADING SOURCE EEG DATA")
    print("=" * 35)

    # Candidate locations, highest priority first
    candidate_paths = (
        'by captain borat/raw/SC4001E0-PSG.edf',
        'raw data/SC4001E0-PSG.edf',
        'SC4001E0-PSG.edf',
    )

    for candidate in candidate_paths:
        try:
            print(f"🔍 Attempting to load: {candidate}")

            recording = mne.io.read_raw_edf(candidate, preload=True, verbose=False)

            # Row 0 of the data matrix is taken as the EEG trace
            # NOTE(review): assumes the first channel is EEG - confirm per file
            first_channel = recording.get_data()[0]
            sfreq = recording.info['sfreq']

            channel_names = recording.ch_names
            name_preview = ', '.join(channel_names[:5])
            more_marker = '...' if len(channel_names) > 5 else ''
            duration_hours = len(first_channel) / sfreq / 3600

            print(f"  ✅ EEG data loaded successfully")
            print(f"  📊 Signal shape: {first_channel.shape}")
            print(f"  📊 Sampling frequency: {sfreq} Hz")
            print(f"  📊 Duration: {duration_hours:.2f} hours")
            print(f"  📊 Channels available: {len(channel_names)} ({name_preview}{more_marker})")

            return first_channel, sfreq, recording

        except FileNotFoundError:
            print(f"  ⚠️  File not found: {candidate}")
        except Exception as e:
            # Any other failure (e.g. an unreadable file) also falls through
            # to the next candidate
            print(f"  ❌ Error loading {candidate}: {e}")

    print("  ❌ No EEG data files found - analysis will be limited to clustering results only")
    return None, None, None

# Pull in the clustering outputs and the raw recording
results_3s, results_30s = load_tst_clustering_results()
eeg_data, fs, raw_eeg = load_source_eeg_data()

# One status line per data source
print("\n📋 DATA LOADING SUMMARY:")
print("  3-second analysis: " + ('✅ Available' if results_3s else '❌ Unavailable'))
print("  30-second analysis: " + ('✅ Available' if results_30s else '❌ Unavailable'))
print("  Source EEG data: " + ('❌ Unavailable' if eeg_data is None else '✅ Available'))

# Window counts are only reported when both scales were loaded
if results_3s and results_30s:
    print("\n🔬 MULTI-SCALE ANALYSIS READY:")
    print("  High-resolution (3s): {:,} windows".format(len(results_3s['csv'])))
    print("  Standard-resolution (30s): {:,} windows".format(len(results_30s['csv'])))

print("✅ Data loading completed")

# 2. Window Duration Validation and Cluster Distribution Analysis

Validate the temporal windowing strategy and analyze the distribution of detected sleep clusters.

In [None]:
# ==============================================================================
# WINDOW DURATION VALIDATION
# ==============================================================================

def _report_window_durations(df, expected_sec, heading):
    """
    Print duration statistics and the first windows of one result table.

    Args:
        df (pd.DataFrame): Table with the columns 'start_time_sec',
            'end_time_sec' and 'predicted_cluster'.
        expected_sec (float): Nominal window length, shown for comparison.
        heading (str): Section heading printed first.
    """
    print(heading)

    # Calculate actual window durations
    durations = df['end_time_sec'] - df['start_time_sec']

    print(f"  Total windows: {len(df):,}")
    print(f"  Expected duration: {expected_sec:.1f}s")
    print(f"  Actual duration range: {durations.min():.1f}s - {durations.max():.1f}s")
    print(f"  Mean duration: {durations.mean():.2f}s (±{durations.std():.3f}s)")

    # Display first few windows for verification
    print("\n  Sample windows:")
    preview = df.head(5)
    rows = zip(preview['start_time_sec'], preview['end_time_sec'],
               preview['predicted_cluster'])
    for i, (start, end, cluster) in enumerate(rows):
        print(f"    Window {i:2d}: {start:6.1f}s - {end:6.1f}s "
              f"(Δ{end - start:.1f}s) → Cluster {cluster}")


def validate_window_durations(results_3s, results_30s):
    """
    Validate temporal windowing parameters and analyze window characteristics.

    Prints a duration report for every result set that is available and, when
    both are available, a coverage/overlap comparison. The comparison is
    skipped with a message if either table has no rows, because its ratios
    would otherwise divide by zero.

    Args:
        results_3s, results_30s: Loaded clustering results dictionaries (the
            'csv' entry is read), or None/empty when unavailable.
    """
    print("⏱️  TEMPORAL WINDOWING VALIDATION")
    print("=" * 45)

    # Analyze 3-second windows
    if results_3s:
        df_3s = results_3s['csv']
        _report_window_durations(df_3s, 3.0, "🔍 3-Second Window Analysis:")

    # Analyze 30-second windows
    if results_30s:
        df_30s = results_30s['csv']
        _report_window_durations(df_30s, 30.0, "\n🔍 30-Second Window Analysis:")

    # Temporal coverage analysis needs both scales
    if not (results_3s and results_30s):
        return

    print("\n📊 Multi-Scale Coverage Comparison:")

    if df_3s.empty or df_30s.empty:
        print("  ⚠️  Coverage comparison skipped: at least one analysis contains no windows")
        return

    # Coverage is taken as the largest end time, i.e. measured from t = 0
    total_time_3s = df_3s['end_time_sec'].max()
    total_time_30s = df_30s['end_time_sec'].max()

    print(f"  3s analysis coverage: {total_time_3s/3600:.2f} hours")
    print(f"  30s analysis coverage: {total_time_30s/3600:.2f} hours")
    print(f"  Temporal resolution ratio: {len(df_3s)/len(df_30s):.1f}:1")

    # Check for temporal alignment
    overlap_start = max(df_3s['start_time_sec'].min(), df_30s['start_time_sec'].min())
    overlap_end = min(total_time_3s, total_time_30s)
    overlap_duration = max(0, overlap_end - overlap_start)

    longest_coverage = max(total_time_3s, total_time_30s)
    overlap_pct = overlap_duration / longest_coverage * 100 if longest_coverage > 0 else 0.0

    print(f"  Temporal overlap: {overlap_duration/3600:.2f} hours")
    print(f"  Overlap percentage: {overlap_pct:.1f}%")

# Run the validation whenever at least one result set was loaded
if not (results_3s or results_30s):
    print("❌ No clustering results available for validation")
else:
    validate_window_durations(results_3s, results_30s)
    print("✅ Window duration validation completed")

In [None]:
# ==============================================================================
# CLUSTER DISTRIBUTION ANALYSIS
# ==============================================================================

def analyze_tst_cluster_distribution(results, analysis_type):
    """
    Comprehensive analysis of Time Series Transformer clustering results.

    Prints a recording overview and the per-cluster window counts, shares and
    durations. The window duration is taken from the first row and assumed to
    hold for every window; total recording time is the largest end time.

    Args:
        results (dict): TST clustering results containing CSV data
        analysis_type (str): Description of analysis (e.g., "3-second", "30-second")

    Returns:
        tuple: (cluster_counts, cluster_percentages, analysis_summary), or
        (None, None, None) when there is no result table or it has no rows.
    """

    # An empty table has no first window to measure, so it counts as "no results"
    if not results or 'csv' not in results or results['csv'].empty:
        print(f"❌ No valid results available for {analysis_type} analysis")
        return None, None, None

    df = results['csv']

    print(f"📊 TST CLUSTER DISTRIBUTION ANALYSIS ({analysis_type.upper()})")
    print("=" * 60)

    # Basic temporal parameters
    window_duration = df['end_time_sec'].iloc[0] - df['start_time_sec'].iloc[0]
    total_duration_hours = df['end_time_sec'].max() / 3600

    # Cluster distribution statistics (sorted by cluster label)
    cluster_counts = df['predicted_cluster'].value_counts().sort_index()
    cluster_percentages = (cluster_counts / len(df) * 100).round(2)

    print("Recording Overview:")
    print(f"  Total analysis windows: {len(df):,}")
    print(f"  Window duration: {window_duration:.1f}s")
    print(f"  Total recording time: {total_duration_hours:.2f}h ({total_duration_hours*60:.0f}min)")
    print(f"  Temporal resolution: {3600/window_duration:.0f} windows/hour")

    print("\nCluster Distribution:")
    print("-" * 40)

    for cluster, count in cluster_counts.items():
        percentage = cluster_percentages[cluster]
        duration_minutes = count * window_duration / 60
        duration_hours = duration_minutes / 60

        print(f"  Cluster {cluster}: {count:5,d} windows ({percentage:5.1f}%) "
              f"→ {duration_minutes:6.1f}min ({duration_hours:4.2f}h)")

    # Identify dominant and rare clusters
    dominant_cluster = cluster_counts.idxmax()
    rare_cluster = cluster_counts.idxmin()

    print("\nCluster Characteristics:")
    print(f"  Most frequent: Cluster {dominant_cluster} ({cluster_percentages[dominant_cluster]:.1f}%)")
    print(f"  Least frequent: Cluster {rare_cluster} ({cluster_percentages[rare_cluster]:.1f}%)")
    print(f"  Cluster diversity: {len(cluster_counts)} distinct sleep patterns detected")

    return cluster_counts, cluster_percentages, {
        'total_windows': len(df),
        'total_duration_hours': total_duration_hours,
        'window_duration': window_duration,
        'dominant_cluster': dominant_cluster,
        'dominant_percentage': cluster_percentages[dominant_cluster]
    }

def create_tst_distribution_plots(results_3s, results_30s):
    """
    Build the 2x3 Plotly overview figure for the 3s and 30s clustering runs.

    Row 1 shows per-cluster window counts for each run and a percentage
    comparison (only when both runs are present). Row 2 shows the cluster
    timeline of each run, restricted to windows ending within the first
    7200 seconds, and a summary table. Panels whose input is missing stay empty.

    Args:
        results_3s: 3-second results dictionary (only 'csv' is read) or None.
        results_30s: 30-second results dictionary (only 'csv' is read) or None.

    Returns:
        The assembled Plotly figure.
    """
    has_3s = bool(results_3s)
    has_30s = bool(results_30s)

    panel_titles = [
        '3s Windows: Cluster Counts', '30s Windows: Cluster Counts',
        'Cluster Proportions Comparison',
        '3s Timeline (First 2h)', '30s Timeline (First 2h)',
        'Multi-Scale Summary'
    ]
    panel_specs = [
        [{'type': 'bar'}, {'type': 'bar'}, {'type': 'bar'}],
        [{'type': 'scatter'}, {'type': 'scatter'}, {'type': 'table'}]
    ]
    fig = make_subplots(
        rows=2,
        cols=3,
        subplot_titles=panel_titles,
        specs=panel_specs,
        horizontal_spacing=0.08,
        vertical_spacing=0.12,
    )

    # One fixed colour per time scale
    shade_3s = '#1f77b4'
    shade_30s = '#17becf'

    # Row 1, col 1: window counts per cluster (3s)
    if has_3s:
        df_3s = results_3s['csv']
        counts_3s = df_3s['predicted_cluster'].value_counts().sort_index()
        bars_3s = go.Bar(
            x=counts_3s.index,
            y=counts_3s.values,
            name='3s Windows',
            marker_color=shade_3s,
            text=counts_3s.values,
            textposition='outside',
        )
        fig.add_trace(bars_3s, row=1, col=1)

    # Row 1, col 2: window counts per cluster (30s)
    if has_30s:
        df_30s = results_30s['csv']
        counts_30s = df_30s['predicted_cluster'].value_counts().sort_index()
        bars_30s = go.Bar(
            x=counts_30s.index,
            y=counts_30s.values,
            name='30s Windows',
            marker_color=shade_30s,
            text=counts_30s.values,
            textposition='outside',
        )
        fig.add_trace(bars_30s, row=1, col=2)

    # Row 1, col 3: share of windows per cluster, over the union of labels
    if has_3s and has_30s:
        cluster_ids = sorted(set(counts_3s.index) | set(counts_30s.index))
        n_3s = len(df_3s)
        n_30s = len(df_30s)

        share_3s = [counts_3s.get(label, 0) / n_3s * 100 for label in cluster_ids]
        share_30s = [counts_30s.get(label, 0) / n_30s * 100 for label in cluster_ids]

        fig.add_trace(
            go.Bar(x=cluster_ids, y=share_3s, name='3s (%)',
                   marker_color=shade_3s, opacity=0.7),
            row=1, col=3,
        )
        fig.add_trace(
            go.Bar(x=cluster_ids, y=share_30s, name='30s (%)',
                   marker_color=shade_30s, opacity=0.7),
            row=1, col=3,
        )

    # Row 2, cols 1-2: cluster label over time for windows ending by 7200 s
    if has_3s:
        early_3s = df_3s.loc[df_3s['end_time_sec'] <= 7200]
        fig.add_trace(
            go.Scatter(
                x=early_3s['start_time_sec'] / 60,
                y=early_3s['predicted_cluster'],
                mode='markers',
                name='3s Timeline',
                marker=dict(size=4, opacity=0.6),
                showlegend=False,
            ),
            row=2, col=1,
        )

    if has_30s:
        early_30s = df_30s.loc[df_30s['end_time_sec'] <= 7200]
        fig.add_trace(
            go.Scatter(
                x=early_30s['start_time_sec'] / 60,
                y=early_30s['predicted_cluster'],
                mode='markers',
                name='30s Timeline',
                marker=dict(size=8, opacity=0.8),
                showlegend=False,
            ),
            row=2, col=2,
        )

    # Row 2, col 3: one table row per available analysis
    table_rows = []
    if has_3s:
        hours_3s = df_3s["end_time_sec"].max() / 3600
        table_rows.append(['3s Analysis', f'{len(df_3s):,}', f'{len(counts_3s)}', f'{hours_3s:.1f}h'])
    if has_30s:
        hours_30s = df_30s["end_time_sec"].max() / 3600
        table_rows.append(['30s Analysis', f'{len(df_30s):,}', f'{len(counts_30s)}', f'{hours_30s:.1f}h'])

    if table_rows:
        # go.Table expects column-wise values, hence the transpose
        table_columns = list(zip(*table_rows))
        fig.add_trace(
            go.Table(
                header=dict(values=['Analysis Type', 'Windows', 'Clusters', 'Duration']),
                cells=dict(values=table_columns),
            ),
            row=2, col=3,
        )

    # Figure-level layout
    fig.update_layout(
        height=800,
        title=dict(
            text='Time Series Transformer Multi-Scale Clustering Analysis',
            x=0.5,
            font=dict(size=16),
        ),
        showlegend=True,
    )

    return fig

# Run the per-scale distribution analysis for every loaded result set
print("🔬 Starting Time Series Transformer cluster analysis...")

if results_3s:
    counts_3s, percentages_3s, summary_3s = analyze_tst_cluster_distribution(
        results_3s, "3-second")

if results_30s:
    counts_30s, percentages_30s, summary_30s = analyze_tst_cluster_distribution(
        results_30s, "30-second")

# Build and show the combined figure when there is anything to plot
if not (results_3s or results_30s):
    print("❌ No results available for visualization")
else:
    print("\n🎨 Creating multi-scale distribution plots...")
    distribution_fig = create_tst_distribution_plots(results_3s, results_30s)
    distribution_fig.show()
    print("✅ Cluster distribution analysis completed")

# 3. Continuous Duration Analysis

Analyze the temporal continuity and stability of detected sleep clusters to understand sleep architecture patterns.

In [None]:
# ==============================================================================
# CONTINUOUS DURATION AND SLEEP ARCHITECTURE ANALYSIS
# ==============================================================================

def analyze_continuous_durations(results, analysis_type):
    """
    Analyze continuous sleep cluster segments to understand sleep architecture.

    A segment is a maximal run of consecutive rows that share the same
    'predicted_cluster'. Segment durations are computed as
    ``number of windows * window duration``, where the window duration is
    taken from the first row and assumed to hold for every row.

    Args:
        results (dict): TST clustering results
        analysis_type (str): Type of analysis (e.g., "3-second", "30-second")

    Returns:
        tuple: (segments_dataframe, duration_statistics, architecture_summary),
        or (None, None, None) when there is no result table or it has no rows.
    """

    # An empty table has no first window/cluster, so it counts as "no results"
    if not results or 'csv' not in results or results['csv'].empty:
        print(f"❌ No valid results for {analysis_type} duration analysis")
        return None, None, None

    df = results['csv']
    window_duration = df['end_time_sec'].iloc[0] - df['start_time_sec'].iloc[0]

    print(f"🏗️  SLEEP ARCHITECTURE ANALYSIS ({analysis_type.upper()})")
    print("=" * 55)

    # Identify continuous segments (run-length encoding on row positions)
    labels = df['predicted_cluster'].to_numpy()
    window_starts = df['start_time_sec'].to_numpy()
    window_ends = df['end_time_sec'].to_numpy()

    # Row positions at which the label differs from the previous row
    boundaries = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    first_rows = np.concatenate(([0], boundaries))
    last_rows = np.concatenate((boundaries, [len(df)])) - 1
    window_counts = last_rows - first_rows + 1
    duration_sec = window_counts * window_duration

    segments_df = pd.DataFrame({
        'cluster': labels[first_rows],
        'start_time_sec': window_starts[first_rows],
        'end_time_sec': window_ends[last_rows],
        'duration_sec': duration_sec,
        'duration_min': duration_sec / 60,
        'duration_hours': duration_sec / 3600,
        'num_windows': window_counts,
        'start_window_idx': first_rows,
        'end_window_idx': last_rows,
    })

    print("Sleep Architecture Overview:")
    print(f"  Total continuous segments: {len(segments_df):,}")
    print(f"  Average segment length: {segments_df['duration_min'].mean():.2f} minutes")
    print(f"  Sleep transitions: {len(segments_df) - 1:,}")
    print(f"  Recording fragmentation: {len(segments_df) / (len(df) * window_duration / 3600):.1f} segments/hour")

    # Calculate statistics by cluster
    duration_stats = segments_df.groupby('cluster')['duration_min'].agg([
        'count', 'mean', 'median', 'std', 'min', 'max', 'sum'
    ]).round(2)

    duration_stats.columns = ['segments', 'mean_duration', 'median_duration',
                              'std_duration', 'min_duration', 'max_duration', 'total_duration']

    print("\nCluster-Specific Architecture (Duration in minutes):")
    print("-" * 60)
    print(f"{'Cluster':<8} {'Segments':<8} {'Mean':<7} {'Median':<7} {'Std':<7} {'Min':<6} {'Max':<7} {'Total':<8}")
    print("-" * 60)

    # 'row' (not 'stats') so the module-level scipy ``stats`` import is not shadowed
    for cluster in sorted(duration_stats.index):
        row = duration_stats.loc[cluster]
        print(f"{cluster:<8} {row['segments']:<8.0f} {row['mean_duration']:<7.1f} "
              f"{row['median_duration']:<7.1f} {row['std_duration']:<7.1f} "
              f"{row['min_duration']:<6.1f} {row['max_duration']:<7.1f} {row['total_duration']:<8.1f}")

    # Architecture quality metrics
    mean_segment_duration = segments_df['duration_min'].mean()
    segment_stability = segments_df['duration_min'].median()  # Median as stability measure

    # Find most and least stable clusters (by mean segment duration)
    most_stable_cluster = duration_stats['mean_duration'].idxmax()
    least_stable_cluster = duration_stats['mean_duration'].idxmin()

    # Read the two cells individually: selecting the whole row would upcast the
    # cluster label to float and print e.g. "Cluster 2.0".
    longest_idx = segments_df['duration_min'].idxmax()
    longest_cluster = segments_df.loc[longest_idx, 'cluster']
    longest_duration = segments_df.loc[longest_idx, 'duration_min']

    print("\nArchitecture Quality Metrics:")
    print(f"  Most stable cluster: {most_stable_cluster} (avg {duration_stats.loc[most_stable_cluster, 'mean_duration']:.1f}min)")
    print(f"  Most dynamic cluster: {least_stable_cluster} (avg {duration_stats.loc[least_stable_cluster, 'mean_duration']:.1f}min)")
    print(f"  Longest single segment: Cluster {longest_cluster} ({longest_duration:.1f}min)")
    print(f"  Overall sleep continuity: {segment_stability:.1f}min median segment")

    architecture_summary = {
        'total_segments': len(segments_df),
        'mean_segment_duration': mean_segment_duration,
        'most_stable_cluster': most_stable_cluster,
        'least_stable_cluster': least_stable_cluster,
        'longest_segment_duration': longest_duration,
        'segment_stability': segment_stability
    }

    return segments_df, duration_stats, architecture_summary

def create_duration_analysis_plots(segments_3s, segments_30s, stats_3s, stats_30s):
    """
    Build the 2x2 Plotly figure summarising continuous-segment durations.

    Panels: segment-duration histograms for both scales (top left), mean
    segment duration per cluster for both scales (top right, drawn only when
    both statistics tables are given), and one scatter per scale of segment
    count against mean segment duration (bottom row). Inputs that are None
    are skipped.

    Args:
        segments_3s, segments_30s: Segment tables with 'cluster' and
            'duration_min' columns, or None.
        stats_3s, stats_30s: Per-cluster statistics tables with a
            'mean_duration' column, or None.

    Returns:
        The assembled Plotly figure.
    """
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=[
            'Segment Duration Distributions',
            'Architecture Stability by Cluster',
            'Segment Count vs Duration (3s)',
            'Segment Count vs Duration (30s)'
        ],
        specs=[
            [{'type': 'histogram'}, {'type': 'bar'}],
            [{'type': 'scatter'}, {'type': 'scatter'}]
        ],
    )

    # Top left: duration histograms, 3s first so trace order is stable
    for seg_table, trace_name in ((segments_3s, '3s Segments'),
                                  (segments_30s, '30s Segments')):
        if seg_table is None:
            continue
        fig.add_trace(
            go.Histogram(x=seg_table['duration_min'], name=trace_name,
                         opacity=0.7, nbinsx=30),
            row=1, col=1,
        )

    # Top right: mean segment duration per cluster; clusters missing from one
    # scale are drawn with height 0
    if not (stats_3s is None or stats_30s is None):
        cluster_ids = sorted(set(stats_3s.index) | set(stats_30s.index))

        for stat_table, trace_name in ((stats_3s, '3s Mean Duration'),
                                       (stats_30s, '30s Mean Duration')):
            known = stat_table.index
            heights = [
                stat_table.loc[label, 'mean_duration'] if label in known else 0
                for label in cluster_ids
            ]
            fig.add_trace(
                go.Bar(x=cluster_ids, y=heights, name=trace_name, opacity=0.7),
                row=1, col=2,
            )

    def _add_count_vs_mean(seg_table, trace_name, col):
        """Add the segment-count vs. mean-duration scatter for one scale."""
        per_cluster = seg_table['cluster'].value_counts()
        mean_minutes = [
            seg_table[seg_table['cluster'] == label]['duration_min'].mean()
            for label in per_cluster.index
        ]
        fig.add_trace(
            go.Scatter(
                x=per_cluster.values,
                y=mean_minutes,
                mode='markers+text',
                text=[f'C{label}' for label in per_cluster.index],
                textposition='top center',
                name=trace_name,
                marker=dict(size=10, opacity=0.7),
            ),
            row=2, col=col,
        )

    # Bottom row: one scatter per scale
    if segments_3s is not None:
        _add_count_vs_mean(segments_3s, '3s Analysis', 1)
    if segments_30s is not None:
        _add_count_vs_mean(segments_30s, '30s Analysis', 2)

    # Figure-level layout
    fig.update_layout(
        height=700,
        title=dict(
            text='Sleep Architecture - Continuous Duration Analysis',
            x=0.5,
            font=dict(size=16),
        ),
    )

    # Axis titles as (row, col, x title, y title)
    axis_titles = (
        (1, 1, 'Duration (minutes)', 'Frequency'),
        (1, 2, 'Cluster', 'Mean Duration (minutes)'),
        (2, 1, 'Number of Segments', 'Mean Duration (minutes)'),
        (2, 2, 'Number of Segments', 'Mean Duration (minutes)'),
    )
    for row, col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)

    return fig

# Run the segment analysis for every loaded result set
print("🏗️  Starting sleep architecture analysis...")

# Defaults keep the names defined when a result set is unavailable
segments_3s = stats_3s = summary_3s = None
segments_30s = stats_30s = summary_30s = None

if results_3s:
    segments_3s, stats_3s, summary_3s = analyze_continuous_durations(
        results_3s, "3-second")

if results_30s:
    segments_30s, stats_30s, summary_30s = analyze_continuous_durations(
        results_30s, "30-second")

# Build and show the combined figure when at least one scale produced segments
if segments_3s is None and segments_30s is None:
    print("❌ No segment data available for duration analysis")
else:
    print("\n🎨 Creating duration analysis visualizations...")
    duration_fig = create_duration_analysis_plots(segments_3s, segments_30s, stats_3s, stats_30s)
    duration_fig.show()
    print("✅ Sleep architecture analysis completed")

In [None]:
# ==============================================================================
# FREQUENCY DOMAIN ANALYSIS AND SPECTRAL CHARACTERISTICS
# ==============================================================================

def analyze_cluster_frequency_characteristics(segments_df, results, analysis_type):
    """
    Analyze frequency domain characteristics of different sleep clusters.

    All metrics are derived from segment durations, not from the EEG signal:
    'dominant_frequency_hz' is the reciprocal of a cluster's mean segment
    duration in seconds, and the transition frequency is the number of
    changes between consecutive segments (segments - 1) per recorded hour.

    Args:
        segments_df (pd.DataFrame): Continuous segments data with 'cluster'
            and 'duration_min' columns
        results (dict): TST clustering results; only checked for None
        analysis_type (str): Type of analysis

    Returns:
        dict: Frequency analysis results with the keys 'frequency_metrics',
        'transition_frequency' and 'segment_distribution', or None when no
        data (or an empty segment table) was supplied.
    """

    # An empty segment table would make the percentage lines divide by zero
    if segments_df is None or results is None or segments_df.empty:
        print(f"❌ No data available for {analysis_type} frequency analysis")
        return None

    print(f"📊 FREQUENCY CHARACTERISTICS ANALYSIS ({analysis_type.upper()})")
    print("=" * 58)

    # Extract key frequency metrics
    frequency_metrics = {}
    segment_total = len(segments_df)
    total_minutes = segments_df['duration_min'].sum()

    # Cluster-specific frequency patterns
    for cluster in sorted(segments_df['cluster'].unique()):
        cluster_segments = segments_df[segments_df['cluster'] == cluster]

        total_time = cluster_segments['duration_min'].sum()
        avg_segment_duration = cluster_segments['duration_min'].mean()
        dominant_frequency = 1 / (avg_segment_duration * 60)  # Hz equivalent

        frequency_metrics[cluster] = {
            'total_time_min': total_time,
            'avg_segment_duration_min': avg_segment_duration,
            'segment_count': len(cluster_segments),
            'dominant_frequency_hz': dominant_frequency,
            'time_percentage': (total_time / total_minutes) * 100
        }

    print("Cluster Frequency Characteristics:")
    print("-" * 70)
    print(f"{'Cluster':<8} {'Time%':<6} {'Segments':<9} {'Avg Dur(min)':<12} {'Dom Freq(Hz)':<12}")
    print("-" * 70)

    for cluster in sorted(frequency_metrics.keys()):
        metrics = frequency_metrics[cluster]
        print(f"{cluster:<8} {metrics['time_percentage']:<6.1f} {metrics['segment_count']:<9} "
              f"{metrics['avg_segment_duration_min']:<12.2f} {metrics['dominant_frequency_hz']:<12.4f}")

    # Temporal frequency analysis (transitions per hour).
    # N consecutive segments are separated by N - 1 transitions, matching the
    # "Sleep transitions" figure printed by analyze_continuous_durations.
    total_recording_hours = total_minutes / 60
    transition_count = max(segment_total - 1, 0)
    if total_recording_hours > 0:
        transition_frequency = transition_count / total_recording_hours
    else:
        transition_frequency = 0.0

    print("\nTemporal Dynamics:")
    print(f"  Total recording time: {total_recording_hours:.2f} hours")
    print(f"  Segment transitions: {transition_count:,}")
    print(f"  Transition frequency: {transition_frequency:.2f} transitions/hour")
    print(f"  Sleep fragmentation index: {transition_frequency / 10:.2f} (normalized)")

    # Bucket segments by duration: [0, 5), [5, 20) and [20, inf) minutes
    minutes = segments_df['duration_min']
    short_segments = segments_df[minutes < 5]
    medium_segments = segments_df[(minutes >= 5) & (minutes < 20)]
    long_segments = segments_df[minutes >= 20]

    print("\nSegment Duration Distribution:")
    print(f"  Short segments (<5min): {len(short_segments):,} ({len(short_segments)/segment_total*100:.1f}%)")
    print(f"  Medium segments (5-20min): {len(medium_segments):,} ({len(medium_segments)/segment_total*100:.1f}%)")
    print(f"  Long segments (>20min): {len(long_segments):,} ({len(long_segments)/segment_total*100:.1f}%)")

    return {
        'frequency_metrics': frequency_metrics,
        'transition_frequency': transition_frequency,
        'segment_distribution': {
            'short': len(short_segments),
            'medium': len(medium_segments),
            'long': len(long_segments)
        }
    }

def create_frequency_analysis_plots(freq_results_3s, freq_results_30s):
    """Create frequency domain analysis visualizations.

    Builds a 2x2 figure: per-cluster time share (bar), transition frequency
    (bar) and one pie of segment-duration categories per time scale. Every
    panel is drawn from whichever scale(s) are available, so a run with a
    single scale still yields populated bar charts.

    Args:
        freq_results_3s (dict | None): Frequency-characteristics results for
            the 3-second analysis. The keys 'frequency_metrics',
            'transition_frequency' and 'segment_distribution' are read.
        freq_results_30s (dict | None): Same structure for the 30-second
            analysis.

    Returns:
        The figure returned by make_subplots with all traces added.
    """

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=[
            'Time Distribution by Cluster',
            'Transition Frequency Comparison',
            'Segment Duration Categories (3s)',
            'Segment Duration Categories (30s)'
        ],
        specs=[
            [{'type': 'bar'}, {'type': 'bar'}],
            [{'type': 'pie'}, {'type': 'pie'}]
        ]
    )

    # One descriptor per time scale; falsy results are dropped up front so
    # each panel below simply iterates over what is actually available.
    scales = [
        {'tag': '3s', 'axis_label': '3-second Analysis',
         'results': freq_results_3s, 'color': 'lightblue', 'pie_col': 1},
        {'tag': '30s', 'axis_label': '30-second Analysis',
         'results': freq_results_30s, 'color': 'lightcoral', 'pie_col': 2},
    ]
    available = [scale for scale in scales if scale['results']]

    # Time distribution by cluster
    for scale in available:
        metrics = scale['results']['frequency_metrics']
        clusters = sorted(metrics.keys())
        fig.add_trace(go.Bar(
            x=clusters,
            y=[metrics[c]['time_percentage'] for c in clusters],
            name=f"{scale['tag']} Analysis",
            opacity=0.7
        ), row=1, col=1)

    # Transition frequency comparison
    if available:
        fig.add_trace(go.Bar(
            x=[scale['axis_label'] for scale in available],
            y=[scale['results']['transition_frequency'] for scale in available],
            name='Transition Frequency',
            marker_color=[scale['color'] for scale in available]
        ), row=1, col=2)

    # Segment duration category pie charts
    for scale in available:
        distribution = scale['results']['segment_distribution']
        fig.add_trace(go.Pie(
            labels=['Short (<5min)', 'Medium (5-20min)', 'Long (>20min)'],
            values=[distribution['short'], distribution['medium'], distribution['long']],
            name=f"{scale['tag']} Duration Categories"
        ), row=2, col=scale['pie_col'])

    # Update layout
    fig.update_layout(
        height=700,
        title={
            'text': 'Frequency Domain and Temporal Dynamics Analysis',
            'x': 0.5,
            'font': {'size': 16}
        }
    )

    # Update axes (only the two bar panels have cartesian axes)
    fig.update_xaxes(title_text='Cluster', row=1, col=1)
    fig.update_yaxes(title_text='Time Percentage (%)', row=1, col=1)
    fig.update_xaxes(title_text='Analysis Type', row=1, col=2)
    fig.update_yaxes(title_text='Transitions/Hour', row=1, col=2)

    return fig

# Run the frequency-characteristics analysis for each time scale that has
# segment data, then plot whatever came back.
print("📊 Starting frequency characteristics analysis...")

freq_results_3s = (
    analyze_cluster_frequency_characteristics(segments_3s, results_3s, "3-second")
    if segments_3s is not None
    else None
)
freq_results_30s = (
    analyze_cluster_frequency_characteristics(segments_30s, results_30s, "30-second")
    if segments_30s is not None
    else None
)

# Visualize only when at least one scale produced results
if not (freq_results_3s or freq_results_30s):
    print("❌ No frequency data available for analysis")
else:
    print("\n🎨 Creating frequency analysis visualizations...")
    freq_fig = create_frequency_analysis_plots(freq_results_3s, freq_results_30s)
    freq_fig.show()
    print("✅ Frequency characteristics analysis completed")

In [None]:
# ==============================================================================
# TEMPORAL DYNAMICS AND TRANSITION ANALYSIS
# ==============================================================================

def analyze_temporal_transitions(segments_df, results, analysis_type):
    """
    Analyze temporal transitions between sleep clusters.

    Each pair of consecutive rows in ``segments_df`` is counted as one
    transition (row i -> row i + 1). The counts are printed as a raw matrix
    and as row-normalized probabilities, followed by the five most common
    transitions and a few stability figures.

    Args:
        segments_df (pd.DataFrame): Continuous segments data. The columns
            'cluster', 'duration_min' and 'end_time_sec' are read.
        results (dict): TST clustering results (only checked against None).
        analysis_type (str): Type of analysis, used in the printed headings.

    Returns:
        dict | None: None when either input is None. Otherwise a dict with
        'transition_matrix' (counts, rows = from, columns = to),
        'transition_probs' (each non-empty row sums to 1), 'transitions_df'
        (one row per transition), 'clusters' (sorted labels indexing the
        matrices), 'stability_index' (share of self-transitions) and
        'most_common_transitions' (pair counts, descending).
    """

    if segments_df is None or results is None:
        print(f"❌ No data available for {analysis_type} transition analysis")
        return None

    print(f"🔄 TEMPORAL TRANSITION ANALYSIS ({analysis_type.upper()})")
    print("=" * 52)

    clusters = sorted(segments_df['cluster'].unique())
    n_clusters = len(clusters)
    cluster_to_idx = {cluster: i for i, cluster in enumerate(clusters)}

    cluster_seq = segments_df['cluster'].to_numpy()
    durations = segments_df['duration_min'].to_numpy()
    end_times = segments_df['end_time_sec'].to_numpy()

    # Built from column slices so the frame keeps its five columns even when
    # there are fewer than two segments (and therefore no transitions).
    transitions_df = pd.DataFrame({
        'from_cluster': cluster_seq[:-1],
        'to_cluster': cluster_seq[1:],
        'from_duration': durations[:-1],
        'to_duration': durations[1:],
        'transition_time': end_times[:-1] / 3600  # hours
    })

    # Count transitions: map every segment to its matrix position once, then
    # accumulate (from, to) pairs; np.add.at handles repeated index pairs.
    positions = np.fromiter(
        (cluster_to_idx[cluster] for cluster in cluster_seq),
        dtype=np.intp,
        count=len(cluster_seq),
    )
    transition_matrix = np.zeros((n_clusters, n_clusters))
    np.add.at(transition_matrix, (positions[:-1], positions[1:]), 1)

    # Row-normalize; rows without outgoing transitions stay all-zero
    row_sums = transition_matrix.sum(axis=1, keepdims=True)
    transition_probs = np.divide(
        transition_matrix,
        row_sums,
        out=np.zeros_like(transition_matrix),
        where=row_sums > 0,
    )

    # The backslash lives in a plain string: f-string expressions may not
    # contain one before Python 3.12.
    corner_label = 'From\\To'
    header_row = f"{corner_label:<8}" + "".join(f"{cluster:>8}" for cluster in clusters)

    def _print_matrix(title, matrix, cell_spec):
        # Print one labelled table with a 40-dash rule above and below the header
        print(title)
        print("-" * 40)
        print(header_row)
        print("-" * 40)
        for label, row in zip(clusters, matrix):
            print(f"{label:<8}" + "".join(format(value, cell_spec) for value in row))

    _print_matrix("Transition Matrix (Raw Counts):", transition_matrix, ">8.0f")
    _print_matrix("\nTransition Probabilities:", transition_probs, ">8.2f")

    # Analyze transition patterns
    most_common_transitions = (
        transitions_df.groupby(['from_cluster', 'to_cluster'])
        .size()
        .sort_values(ascending=False)
    )

    print(f"\nMost Common Transitions:")
    print("-" * 30)
    for (from_c, to_c), count in most_common_transitions.head(5).items():
        prob = count / len(transitions_df)
        print(f"  {from_c} → {to_c}: {count:,} transitions ({prob:.1%})")

    # Stability analysis
    diagonal_sum = np.diag(transition_matrix).sum()
    total_transitions = transition_matrix.sum()
    stability_index = diagonal_sum / total_transitions if total_transitions > 0 else 0
    n_segments = len(segments_df)
    transitions_per_segment = total_transitions / n_segments if n_segments > 0 else 0

    print(f"\nStability Metrics:")
    print(f"  Total transitions: {total_transitions:.0f}")
    print(f"  Self-transitions (stability): {diagonal_sum:.0f}")
    print(f"  Stability index: {stability_index:.3f}")
    print(f"  Average transitions per segment: {transitions_per_segment:.2f}")

    return {
        'transition_matrix': transition_matrix,
        'transition_probs': transition_probs,
        'transitions_df': transitions_df,
        'clusters': clusters,
        'stability_index': stability_index,
        'most_common_transitions': most_common_transitions
    }

def create_transition_analysis_plots(trans_results_3s, trans_results_30s):
    """Build the 2x2 transition figure for the 3 s and 30 s analyses.

    Top row: one transition-probability heatmap per scale (drawn only for the
    scales that are available). Bottom row: a scatter comparing off-diagonal
    transition probabilities for clusters present at both scales, and a bar
    chart of the two stability indices; both need both scales.

    Args:
        trans_results_3s (dict | None): Transition results for 3 s windows;
            'clusters', 'transition_probs' and 'stability_index' are read.
        trans_results_30s (dict | None): Same structure for 30 s windows.

    Returns:
        The figure returned by make_subplots with all traces added.
    """

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=[
            'Transition Matrix Heatmap (3s)',
            'Transition Matrix Heatmap (30s)',
            'Transition Probability Comparison',
            'Stability Index Comparison',
        ],
        specs=[
            [{'type': 'heatmap'}, {'type': 'heatmap'}],
            [{'type': 'scatter'}, {'type': 'bar'}],
        ],
    )

    # Heatmaps: (results, colour scale, trace name, subplot column)
    heatmap_panels = (
        (trans_results_3s, 'Blues', '3s Transitions', 1),
        (trans_results_30s, 'Reds', '30s Transitions', 2),
    )
    for scale_results, palette, trace_name, column in heatmap_panels:
        if not scale_results:
            continue
        axis_labels = scale_results['clusters']
        heatmap = go.Heatmap(
            z=scale_results['transition_probs'],
            x=axis_labels,
            y=axis_labels,
            colorscale=palette,
            name=trace_name,
        )
        fig.add_trace(heatmap, row=1, col=column)

    # Probability comparison: one labelled marker per ordered pair of
    # distinct clusters that exist at both scales.
    if trans_results_3s and trans_results_30s:
        labels_3s = trans_results_3s['clusters']
        labels_30s = trans_results_30s['clusters']
        probs_3s = trans_results_3s['transition_probs']
        probs_30s = trans_results_30s['transition_probs']

        shared = sorted(set(labels_3s) & set(labels_30s))
        ordered_pairs = [
            (source, target)
            for source in shared
            for target in shared
            if source != target  # self-transitions are left out for clarity
        ]

        for source, target in ordered_pairs:
            x_value = probs_3s[labels_3s.index(source), labels_3s.index(target)]
            y_value = probs_30s[labels_30s.index(source), labels_30s.index(target)]
            pair_label = f'{source}→{target}'

            marker = go.Scatter(
                x=[x_value],
                y=[y_value],
                mode='markers+text',
                text=[pair_label],
                textposition='top center',
                name=pair_label,
                showlegend=False,
            )
            fig.add_trace(marker, row=2, col=1)

    # Stability comparison
    if trans_results_3s and trans_results_30s:
        stability_bars = go.Bar(
            x=['3-second Analysis', '30-second Analysis'],
            y=[trans_results_3s['stability_index'], trans_results_30s['stability_index']],
            name='Stability Index',
            marker_color=['lightblue', 'lightcoral'],
        )
        fig.add_trace(stability_bars, row=2, col=2)

    figure_title = {
        'text': 'Temporal Dynamics - Transition Analysis',
        'x': 0.5,
        'font': {'size': 16},
    }
    fig.update_layout(height=800, title=figure_title)

    # Axis titles as (row, col, x title, y title)
    axis_titles = (
        (1, 1, 'To Cluster', 'From Cluster'),
        (1, 2, 'To Cluster', 'From Cluster'),
        (2, 1, '3s Transition Probability', '30s Transition Probability'),
        (2, 2, 'Analysis Type', 'Stability Index'),
    )
    for panel_row, panel_col, x_title, y_title in axis_titles:
        fig.update_xaxes(title_text=x_title, row=panel_row, col=panel_col)
        fig.update_yaxes(title_text=y_title, row=panel_row, col=panel_col)

    return fig

# Run the transition analysis for each time scale that has segment data,
# then plot whatever came back.
print("🔄 Starting temporal transition analysis...")

trans_results_3s = (
    analyze_temporal_transitions(segments_3s, results_3s, "3-second")
    if segments_3s is not None
    else None
)
trans_results_30s = (
    analyze_temporal_transitions(segments_30s, results_30s, "30-second")
    if segments_30s is not None
    else None
)

# Visualize only when at least one scale produced results
if not (trans_results_3s or trans_results_30s):
    print("❌ No transition data available for analysis")
else:
    print("\n🎨 Creating transition analysis visualizations...")
    trans_fig = create_transition_analysis_plots(trans_results_3s, trans_results_30s)
    trans_fig.show()
    print("✅ Temporal transition analysis completed")

In [None]:
# ==============================================================================
# COMPREHENSIVE ANALYSIS SUMMARY AND CLINICAL INSIGHTS
# ==============================================================================

def generate_comprehensive_summary(results_3s, results_30s, segments_3s, segments_30s,
                                   freq_results_3s, freq_results_30s,
                                   trans_results_3s, trans_results_30s):
    """
    Print a sectioned report of the TST clustering analysis and collect its
    headline numbers.

    Each section is filled in only when the inputs it needs are truthy for
    both time scales; otherwise only the section heading is printed and the
    matching summary entry stays an empty dict.

    Args:
        results_3s, results_30s (dict | None): Clustering results; the table
            under 'csv' and its 'predicted_cluster' column are read.
        segments_3s, segments_30s: Accepted but not read by this function.
        freq_results_3s, freq_results_30s (dict | None): Frequency results;
            'frequency_metrics' and 'transition_frequency' are read.
        trans_results_3s, trans_results_30s (dict | None): Transition
            results; 'stability_index' is read.

    Returns:
        dict: Keys 'temporal_scale_comparison', 'cluster_characteristics',
        'clinical_insights' and 'methodological_findings'.
    """

    def _count_windows_and_clusters(scale_results):
        # (number of windows, number of distinct clusters); zeros without a table
        table = scale_results.get('csv')
        if table is None:
            return 0, 0
        return len(table), len(set(table['predicted_cluster']))

    print("📋 COMPREHENSIVE TST CLUSTERING ANALYSIS SUMMARY")
    print("=" * 58)

    summary = {
        section: {}
        for section in (
            'temporal_scale_comparison',
            'cluster_characteristics',
            'clinical_insights',
            'methodological_findings',
        )
    }

    # ---- Temporal scale comparison -----------------------------------------
    print("\n🔍 TEMPORAL SCALE COMPARISON")
    print("-" * 35)

    if results_3s and results_30s:
        windows_3s, clusters_3s = _count_windows_and_clusters(results_3s)
        windows_30s, clusters_30s = _count_windows_and_clusters(results_30s)

        scale_reports = (
            ("3-second Analysis:", windows_3s, clusters_3s,
             "High (fine-grained micro-patterns)"),
            ("\n30-second Analysis:", windows_30s, clusters_30s,
             "Standard (macro sleep architecture)"),
        )
        for heading, n_windows, n_clusters, resolution_note in scale_reports:
            print(heading)
            print(f"  • Total windows: {n_windows:,}")
            print(f"  • Unique clusters: {n_clusters}")
            print(f"  • Temporal resolution: {resolution_note}")

        # How many 3 s windows there are per 30 s window
        if windows_30s > 0:
            resolution_ratio = windows_3s / windows_30s
        else:
            resolution_ratio = 0
        print(f"\nTemporal Resolution Ratio: {resolution_ratio:.1f}:1 (3s:30s)")

        summary['temporal_scale_comparison'] = dict(
            windows_3s=windows_3s,
            windows_30s=windows_30s,
            clusters_3s=clusters_3s,
            clusters_30s=clusters_30s,
            resolution_ratio=resolution_ratio,
        )

    # ---- Cluster characteristics -------------------------------------------
    print("\n🧠 CLUSTER CHARACTERISTICS SUMMARY")
    print("-" * 38)

    characteristics = summary['cluster_characteristics']

    if freq_results_3s and freq_results_30s:
        print("Time Distribution Patterns:")

        for scale, scale_freq in (("3s", freq_results_3s), ("30s", freq_results_30s)):
            print(f"\n{scale} Analysis:")

            # Share of recording time per cluster; largest and smallest share
            shares = {
                cluster: cluster_metrics['time_percentage']
                for cluster, cluster_metrics in scale_freq['frequency_metrics'].items()
            }
            dominant_cluster, dominant_share = max(shares.items(), key=lambda item: item[1])
            minor_cluster, minor_share = min(shares.items(), key=lambda item: item[1])
            fragmentation = scale_freq['transition_frequency']

            print(f"  • Dominant cluster: {dominant_cluster} ({dominant_share:.1f}% of recording)")
            print(f"  • Minor cluster: {minor_cluster} ({minor_share:.1f}% of recording)")
            print(f"  • Sleep fragmentation: {fragmentation:.1f} transitions/hour")

            characteristics[f'{scale}_dominant'] = dominant_cluster
            characteristics[f'{scale}_fragmentation'] = fragmentation

    # ---- Architecture stability --------------------------------------------
    if trans_results_3s and trans_results_30s:
        print("\nSleep Architecture Stability:")
        stability_3s = trans_results_3s['stability_index']
        stability_30s = trans_results_30s['stability_index']

        print(f"  • 3s stability index: {stability_3s:.3f}")
        print(f"  • 30s stability index: {stability_30s:.3f}")

        stability_gap = abs(stability_3s - stability_30s)
        if stability_gap <= 0.1:
            print(f"  • Consistent stability across scales (Δ={stability_gap:.3f})")
        else:
            print(f"  • Scale-dependent stability detected (Δ={stability_gap:.3f})")

        characteristics['stability_comparison'] = {
            '3s': stability_3s,
            '30s': stability_30s,
            'difference': stability_gap,
        }

    # ---- Clinical insights -------------------------------------------------
    print("\n🏥 CLINICAL INSIGHTS AND INTERPRETATION")
    print("-" * 42)

    if freq_results_3s and freq_results_30s:
        frag_3s = freq_results_3s['transition_frequency']
        frag_30s = freq_results_30s['transition_frequency']

        # Thresholds in transitions/hour: 60 (one per minute) at the 3 s
        # scale, 10 at the 30 s scale.
        micro_fragmented = frag_3s > 60
        macro_fragmented = frag_30s > 10

        print("Sleep Quality Assessment:")
        print(
            "  • High micro-fragmentation detected (potential micro-arousals)"
            if micro_fragmented
            else "  • Normal micro-architecture stability"
        )
        print(
            "  • Fragmented macro-architecture (potential sleep disorder)"
            if macro_fragmented
            else "  • Stable macro sleep architecture"
        )

        if frag_30s > 0:
            scale_ratio = frag_3s / frag_30s
        else:
            scale_ratio = 0
        print("\nTemporal Scale Insights:")
        print(f"  • Micro/Macro fragmentation ratio: {scale_ratio:.1f}")

        if scale_ratio > 20:
            ratio_notes = (
                "  • Significant micro-instability with stable macro-architecture",
                "  • Suggests: Possible micro-arousal events, breathing disorders",
            )
        elif scale_ratio < 5:
            ratio_notes = (
                "  • Proportional fragmentation across scales",
                "  • Suggests: Primary sleep architecture disruption",
            )
        else:
            ratio_notes = ("  • Balanced micro/macro sleep dynamics",)
        for note in ratio_notes:
            print(note)

        summary['clinical_insights'] = {
            'micro_fragmentation': micro_fragmented,
            'macro_fragmentation': macro_fragmented,
            'scale_ratio': scale_ratio,
            'sleep_quality': 'fragmented' if (micro_fragmented or macro_fragmented) else 'stable',
        }

    # ---- Methodological findings -------------------------------------------
    print("\n🔬 METHODOLOGICAL FINDINGS")
    print("-" * 30)

    print("Time Series Transformer Performance:")
    if results_3s and results_30s:
        for finding in (
            "  • Multi-scale clustering successfully implemented",
            "  • Temporal resolution affects cluster granularity",
            "  • Both scales provide complementary information",
        ):
            print(finding)

        extra_clusters = clusters_3s - clusters_30s
        if extra_clusters > 0:
            print(f"  • Higher resolution reveals {extra_clusters} additional micro-patterns")

        print("\nRecommendations:")
        for recommendation in (
            "  • Use 3s windows for micro-arousal detection",
            "  • Use 30s windows for sleep stage classification",
            "  • Combine both scales for comprehensive sleep analysis",
            "  • Consider clinical context when interpreting fragmentation",
        ):
            print(recommendation)

        summary['methodological_findings'] = {
            'multi_scale_effective': True,
            'complementary_information': True,
            'recommended_use': {
                '3s': 'micro-arousal detection',
                '30s': 'sleep stage classification',
            },
        }

    print("\n✅ Analysis Summary Generated Successfully")
    return summary

def create_summary_dashboard(summary, results_3s, results_30s):
    """Create a comprehensive dashboard summarizing all results.

    Each panel is filled independently from the part of ``summary`` it
    needs, and is left empty when that part is missing. The two "Cluster
    Time Distribution" pies are titled but not populated: ``summary`` does
    not carry per-cluster time shares.

    Args:
        summary (dict): Output of generate_comprehensive_summary. The
            entries 'temporal_scale_comparison', 'cluster_characteristics'
            and 'clinical_insights' are read.
        results_3s: Not read; kept so existing callers keep working.
        results_30s: Not read; kept so existing callers keep working.

    Returns:
        The figure returned by make_subplots with all traces added.
    """

    fig = make_subplots(
        rows=2, cols=3,
        subplot_titles=[
            'Temporal Resolution Comparison',
            'Sleep Fragmentation Assessment',
            'Cluster Time Distribution (3s)',
            'Cluster Time Distribution (30s)',
            'Architecture Stability',
            'Clinical Quality Indicators'
        ],
        specs=[
            [{'type': 'bar'}, {'type': 'scatter'}, {'type': 'pie'}],
            [{'type': 'pie'}, {'type': 'bar'}, {'type': 'indicator'}]
        ]
    )

    # Resolution comparison
    comparison = summary.get('temporal_scale_comparison')
    if comparison:
        fig.add_trace(go.Bar(
            x=['3-second', '30-second'],
            y=[comparison['windows_3s'], comparison['windows_30s']],
            name='Total Windows',
            marker_color=['lightblue', 'lightcoral']
        ), row=1, col=1)

    characteristics = summary.get('cluster_characteristics') or {}

    # Fragmentation assessment. The keys are checked explicitly because
    # 'cluster_characteristics' can hold only 'stability_comparison' when
    # the frequency results were unavailable.
    if '3s_fragmentation' in characteristics and '30s_fragmentation' in characteristics:
        fig.add_trace(go.Scatter(
            x=[characteristics['3s_fragmentation'], characteristics['30s_fragmentation']],
            y=[3, 30],
            mode='markers+text',
            text=['3s Scale', '30s Scale'],
            textposition='middle right',
            marker=dict(size=[15, 15], color=['blue', 'red']),
            name='Fragmentation Rate'
        ), row=1, col=2)

    # Architecture stability
    stability = characteristics.get('stability_comparison')
    if stability:
        fig.add_trace(go.Bar(
            x=['3-second', '30-second'],
            y=[stability['3s'], stability['30s']],
            name='Stability Index',
            marker_color=['lightblue', 'lightcoral']
        ), row=2, col=2)

    # Clinical quality indicator: 100 for 'stable', 50 for anything else
    clinical = summary.get('clinical_insights')
    if clinical:
        quality_score = 100 if clinical['sleep_quality'] == 'stable' else 50

        fig.add_trace(go.Indicator(
            mode="gauge+number",
            value=quality_score,
            title={'text': "Sleep Quality Score"},
            gauge={
                'axis': {'range': [0, 100]},
                'bar': {'color': "darkgreen" if quality_score > 75 else "orange"},
                'steps': [
                    {'range': [0, 50], 'color': "lightgray"},
                    {'range': [50, 80], 'color': "gray"},
                    {'range': [80, 100], 'color': "lightgreen"}
                ]
            }
        ), row=2, col=3)

    fig.update_layout(
        height=800,
        title={
            'text': 'TST Clustering Analysis - Comprehensive Dashboard',
            'x': 0.5,
            'font': {'size': 18}
        }
    )

    return fig

# Generate comprehensive analysis summary
print("📋 Generating comprehensive analysis summary...")

# Result dicts are tested for truthiness. Segment tables are DataFrames, and
# bool() of a DataFrame raises ValueError, so those are tested against None.
has_results = bool(results_3s) or bool(results_30s)
has_segments = segments_3s is not None or segments_30s is not None

if has_results or has_segments:
    analysis_summary = generate_comprehensive_summary(
        results_3s, results_30s, segments_3s, segments_30s,
        freq_results_3s, freq_results_30s,
        trans_results_3s, trans_results_30s
    )

    # Create summary dashboard
    print("\n🎨 Creating comprehensive dashboard...")
    summary_dashboard = create_summary_dashboard(analysis_summary, results_3s, results_30s)
    summary_dashboard.show()

    print("\n🎯 KEY FINDINGS SUMMARY:")
    print("=" * 30)
    if analysis_summary.get('clinical_insights'):
        clinical = analysis_summary['clinical_insights']
        print(f"Sleep Quality: {clinical['sleep_quality'].upper()}")
        print(f"Micro-fragmentation: {'DETECTED' if clinical['micro_fragmentation'] else 'NORMAL'}")
        print(f"Macro-fragmentation: {'DETECTED' if clinical['macro_fragmentation'] else 'NORMAL'}")
        print(f"Scale Ratio: {clinical['scale_ratio']:.1f}")

    print("\n✅ Comprehensive TST clustering analysis completed successfully!")

else:
    print("❌ Insufficient data for comprehensive analysis")

## 4. Frequency Domain Analysis with Multitaper Spectrogram

In [None]:
def create_multitaper_spectrogram(eeg_data, fs, window_length=30, overlap=0.5):
    """Create a spectrogram of EEG data.

    Despite the name, no multitaper estimator is used: the signal is passed
    to ``signal.spectrogram`` with a single Hann window.

    Args:
        eeg_data: EEG samples (time on the last axis), or None.
        fs: Sampling rate in Hz.
        window_length: Segment length in seconds. Segments are shortened to
            the signal length when the recording is shorter than one window.
        overlap: Fraction of each segment shared with the next one.

    Returns:
        tuple: (frequencies, times, Sxx) as returned by
        ``signal.spectrogram``, or (None, None, None) when ``eeg_data`` is
        None.
    """
    if eeg_data is None:
        print('No EEG data available for spectrogram analysis')
        return None, None, None

    # Clamp the segment length to the signal before deriving the overlap.
    # Otherwise scipy shrinks nperseg on its own while noverlap keeps its
    # original size, and short recordings fail with
    # "noverlap must be less than nperseg".
    n_samples = np.shape(eeg_data)[-1]
    nperseg = min(int(window_length * fs), n_samples)
    noverlap = int(nperseg * overlap)

    frequencies, times, Sxx = signal.spectrogram(
        eeg_data, fs,
        window='hann',
        nperseg=nperseg,
        noverlap=noverlap
    )

    return frequencies, times, Sxx

# Alternative estimator: multitaper PSD built from scipy's DPSS windows
def create_multitaper_psd(eeg_data, fs, window_length=30):
    """Create multitaper power spectral density.

    ``signal.welch`` has no multitaper mode, so the estimate is assembled
    here: one Welch PSD is computed per DPSS (Slepian) taper and the results
    are averaged with equal weights. A time-half-bandwidth product of 4 with
    7 tapers is used.

    Args:
        eeg_data: EEG samples (time on the last axis), or None.
        fs: Sampling rate in Hz.
        window_length: Segment length in seconds; clamped to the signal
            length for short recordings.

    Returns:
        tuple: (frequencies, psd), or (None, None) when ``eeg_data`` is
        None. If the tapers cannot be built for the requested segment length
        a message is printed and a single Hann-window Welch estimate is
        returned instead.
    """
    if eeg_data is None:
        return None, None

    from scipy import signal

    time_half_bandwidth = 4.0
    n_tapers = 7  # 2 * NW - 1

    # Bound before the try block so the fallback below can always use it
    n_samples = np.shape(eeg_data)[-1]
    nperseg = min(int(window_length * fs), n_samples)

    try:
        tapers = signal.windows.dpss(
            nperseg, time_half_bandwidth, Kmax=n_tapers, sym=False
        )
        taper_psds = []
        for taper in tapers:
            # welch normalizes by the window energy, so tapers are comparable
            frequencies, taper_psd = signal.welch(
                eeg_data, fs,
                window=taper,
                nperseg=nperseg
            )
            taper_psds.append(taper_psd)
        psd = np.mean(taper_psds, axis=0)
        return frequencies, psd
    except ValueError as e:
        # Raised e.g. when the segment is too short for the requested tapers
        print(f'Multitaper PSD calculation failed: {e}')
        # Fallback to regular welch method
        frequencies, psd = signal.welch(
            eeg_data, fs,
            window='hann',
            nperseg=nperseg
        )
        return frequencies, psd

def analyze_frequency_bands(frequencies, Sxx, cluster_labels, times):
    """Average spectrogram power inside each EEG frequency band.

    For every band, the rows of ``Sxx`` whose frequency lies within the band
    (both edges inclusive) are averaged, giving one value per spectrogram
    column. A band with no matching frequency gets zeros of length
    ``len(times)``.

    Args:
        frequencies: Frequency of each row of ``Sxx``, in Hz.
        Sxx: Spectrogram power, indexed as [frequency, time].
        cluster_labels: Not read by this function.
        times: Spectrogram time axis; only its length is used.

    Returns:
        dict: Band name -> array of mean band power over time, in the order
        Delta, Theta, Alpha, Beta, Gamma.
    """

    # (name, lower edge in Hz, upper edge in Hz)
    band_edges = (
        ('Delta (0.5-4 Hz)', 0.5, 4),
        ('Theta (4-8 Hz)', 4, 8),
        ('Alpha (8-13 Hz)', 8, 13),
        ('Beta (13-30 Hz)', 13, 30),
        ('Gamma (30-50 Hz)', 30, 50),
    )

    band_powers = {}
    for label, lower, upper in band_edges:
        in_band = (frequencies >= lower) & (frequencies <= upper)

        if not np.any(in_band):
            # No spectrogram row falls inside this band
            band_powers[label] = np.zeros(len(times))
            continue

        band_rows = Sxx[in_band, :]
        band_powers[label] = np.mean(band_rows, axis=0)

    return band_powers

def plot_frequency_analysis(eeg_data, fs, results, title_suffix):
    """Comprehensive frequency domain analysis.

    Computes a spectrogram of the EEG trace (truncated to the time span
    covered by the clustering windows), assigns every spectrogram column to
    the clustering window(s) containing its time stamp, and draws one
    matplotlib figure with: the full spectrogram, the first two hours with a
    cluster colour strip, mean power per cluster for the first four
    frequency bands, and the mean PSD per cluster.

    Args:
        eeg_data: EEG samples, or None.
        fs: Sampling rate in Hz.
        results (dict | None): Clustering results. The table under 'csv' is
            read for the columns 'start_time_sec', 'end_time_sec' and
            'cluster_label'.
        title_suffix (str): Text appended to printed and plotted titles.

    Returns:
        tuple | None: (frequencies, times, Sxx_db, band_powers), or None
        when the EEG data, the clustering table or the spectrogram is
        unavailable.
    """

    if eeg_data is None:
        print(f'Cannot perform frequency analysis for {title_suffix} - no EEG data')
        return

    # The notebook driver only checks eeg_data, so a missing clustering
    # table has to be caught here.
    if results is None or results.get('csv') is None:
        print(f'Cannot perform frequency analysis for {title_suffix} - no clustering results')
        return

    # Truncate EEG data to match clustering results duration
    df = results['csv']
    max_time = df['end_time_sec'].max()
    eeg_truncated = eeg_data[:int(max_time * fs)]

    print(f'\n=== Frequency Domain Analysis ({title_suffix}) ===')
    print(f'EEG duration: {len(eeg_truncated)/fs/3600:.2f} hours')
    print(f'Clustering duration: {max_time/3600:.2f} hours')

    frequencies, times, Sxx = create_multitaper_spectrogram(eeg_truncated, fs)

    if frequencies is None:
        return

    # Convert power to dB (the small offset avoids log10(0))
    Sxx_db = 10 * np.log10(Sxx + 1e-12)

    # Align clusters with the spectrogram: for every clustering window,
    # record which spectrogram columns fall inside [start, end). Storing the
    # column indices directly avoids searching the time axis again for each
    # panel below.
    cluster_labels = []
    columns_by_cluster = {}
    window_bounds = zip(df['start_time_sec'], df['end_time_sec'], df['cluster_label'])
    for start_time, end_time, cluster in window_bounds:
        columns = np.flatnonzero((times >= start_time) & (times < end_time)).tolist()
        cluster_labels.extend([cluster] * len(columns))
        columns_by_cluster.setdefault(cluster, []).extend(columns)

    unique_clusters = sorted(df['cluster_label'].unique())

    # Analyze frequency bands
    band_powers = analyze_frequency_bands(frequencies, Sxx, cluster_labels, times)

    # Cluster colours wrap around the actual size of the colormap
    palette = plt.cm.Set1
    n_colors = palette.N

    # Create comprehensive visualization
    fig = plt.figure(figsize=(20, 15))
    gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

    # 1. Full spectrogram
    ax1 = fig.add_subplot(gs[0, :])
    im1 = ax1.pcolormesh(times/60, frequencies, Sxx_db, shading='gouraud', cmap='viridis')
    ax1.set_ylabel('Frequency (Hz)')
    ax1.set_title(f'Spectrogram - {title_suffix}')
    ax1.set_ylim([0, 50])
    plt.colorbar(im1, ax=ax1, label='Power (dB)')

    # 2. Cluster overlay on spectrogram (first 2 hours)
    overlay_limit_sec = 7200
    ax2 = fig.add_subplot(gs[1, :])
    time_subset = times <= overlay_limit_sec
    im2 = ax2.pcolormesh(times[time_subset]/60, frequencies, Sxx_db[:, time_subset],
                        shading='gouraud', cmap='viridis', alpha=0.7)

    # Overlay cluster information as a colour strip along the top 5% of the axes
    df_subset = df[df['end_time_sec'] <= overlay_limit_sec]
    overlay_windows = zip(
        df_subset['start_time_sec'], df_subset['end_time_sec'], df_subset['cluster_label']
    )
    for start_sec, end_sec, cluster in overlay_windows:
        ax2.axvspan(start_sec / 60, end_sec / 60, ymin=0.95, ymax=1.0,
                   color=palette(int(cluster) % n_colors), alpha=0.8)

    ax2.set_ylabel('Frequency (Hz)')
    ax2.set_xlabel('Time (minutes)')
    ax2.set_title('Spectrogram with Cluster Overlay (First 2 Hours)')
    ax2.set_ylim([0, 50])
    plt.colorbar(im2, ax=ax2, label='Power (dB)')

    # 3-6. Frequency band analysis by cluster
    band_names = list(band_powers.keys())[:4]  # Show first 4 bands

    for i, band_name in enumerate(band_names):
        ax = fig.add_subplot(gs[2 + i//2, i%2])
        band_series = band_powers[band_name]

        # Mean band power over the spectrogram columns of each cluster;
        # clusters without any aligned column are reported as 0.
        cluster_band_power = {}
        for cluster in unique_clusters:
            columns = columns_by_cluster.get(cluster, [])
            cluster_band_power[cluster] = np.mean(band_series[columns]) if columns else 0

        clusters = list(cluster_band_power.keys())
        powers = list(cluster_band_power.values())

        bar_colors = [palette(int(c) % n_colors) for c in clusters]
        bars = ax.bar(clusters, powers, alpha=0.7, color=bar_colors)
        ax.set_title(f'{band_name} Power by Cluster')
        ax.set_xlabel('Cluster Label')
        ax.set_ylabel('Mean Power')

        # Add value labels
        for bar, power in zip(bars, powers):
            if power > 0:
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                       f'{power:.2e}', ha='center', va='bottom', fontsize=8)

    # 7. Power spectral density by cluster
    ax7 = fig.add_subplot(gs[3, 2])

    for cluster in unique_clusters:
        columns = columns_by_cluster.get(cluster, [])
        if columns:
            mean_psd = np.mean(Sxx[:, columns], axis=1)
            ax7.semilogy(frequencies, mean_psd, label=f'Cluster {cluster}', alpha=0.8)

    ax7.set_xlabel('Frequency (Hz)')
    ax7.set_ylabel('Power Spectral Density')
    ax7.set_title('Mean PSD by Cluster')
    ax7.set_xlim([0, 50])
    if ax7.lines:
        # A legend without any plotted line only produces a warning
        ax7.legend()
    ax7.grid(True, alpha=0.3)

    plt.suptitle(f'Comprehensive Frequency Analysis - {title_suffix}', fontsize=16, fontweight='bold')
    plt.show()

    return frequencies, times, Sxx_db, band_powers

# Run the spectrogram-based analysis for both window sizes.
# NOTE(review): this rebinds freq_results_3s / freq_results_30s, which the
# frequency-characteristics cell earlier in this notebook sets to dicts.
# plot_frequency_analysis returns a tuple (or None), so re-running the summary
# cell after this one would index a tuple; consider distinct names.
if eeg_data is None:
    print('EEG data not available - skipping frequency analysis')
else:
    freq_results_3s = plot_frequency_analysis(eeg_data, fs, results_3s, '3-second windows')
    freq_results_30s = plot_frequency_analysis(eeg_data, fs, results_30s, '30-second windows')

## 5. Interactive Hypnogram Comparison

In [None]:
def create_interactive_hypnogram(results_3s, results_30s):
    """Show 3-second and 30-second cluster labels as two stacked scatter rows.

    Args:
        results_3s: Mapping whose 'csv' entry is a DataFrame with
            'start_time_sec' and 'cluster_label' columns (3-second windows).
        results_30s: Same structure for the 30-second windows.

    Returns:
        The Plotly figure, after it has been displayed with ``fig.show()``.
    """

    def _with_hours(results, window_label):
        # Work on a copy so the caller's table is not modified
        frame = results['csv'].copy()
        frame['time_hours'] = frame['start_time_sec'] / 3600
        frame['window_type'] = window_label
        return frame

    frame_3s = _with_hours(results_3s, '3-second')
    frame_30s = _with_hours(results_30s, '30-second')

    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=('3-Second Window Clustering', '30-Second Window Clustering'),
        vertical_spacing=0.1,
    )

    palette = px.colors.qualitative.Set1

    def _add_cluster_markers(frame, prefix, subplot_row, marker_style):
        # One marker trace per cluster, in ascending label order
        for label in sorted(frame['cluster_label'].unique()):
            points = frame[frame['cluster_label'] == label]
            # Labels are cast to int so they can index the palette (wrapping around)
            shade = palette[int(label) % len(palette)]
            trace = go.Scatter(
                x=points['time_hours'],
                y=points['cluster_label'],
                mode='markers',
                marker=dict(color=shade, **marker_style),
                name=f'{prefix} Cluster {label}',
                legendgroup=f'{prefix}_{label}',
                showlegend=True,
            )
            fig.add_trace(trace, row=subplot_row, col=1)

    # Top row: every 3-second window
    _add_cluster_markers(frame_3s, '3s', 1, dict(size=4))

    # Bottom row: 30-second windows, restricted to the time span covered by the 3s data
    last_hour_3s = frame_3s['time_hours'].max()
    frame_30s_clipped = frame_30s[frame_30s['time_hours'] <= last_hour_3s]
    _add_cluster_markers(frame_30s_clipped, '30s', 2, dict(size=8, symbol='square'))

    fig.update_layout(
        title='Interactive Hypnogram Comparison: Time Series Transformer + Clustering Results',
        height=800,
        hovermode='x unified',
    )

    # Only the bottom subplot carries the x-axis title; both y-axes get the same settings
    fig.update_xaxes(title_text='Time (hours)', row=2, col=1)
    fig.update_yaxes(title_text='Cluster Label', dtick=1)

    fig.show()

    return fig

# Build and display the 3s vs 30s hypnogram comparison; the returned figure is kept in
# `interactive_fig`
interactive_fig = create_interactive_hypnogram(results_3s, results_30s)

In [None]:
def create_interactive_synchronized_plot(eeg_data, fs, results, title_suffix, start_time_hours, duration_hours):
    """Build a three-row Plotly figure (EEG trace, spectrogram, hypnogram) for one time segment.

    The three rows share their x-axis, so zooming or panning one row moves the others.

    Args:
        eeg_data: Sliceable sequence of EEG samples (assumed 1-D), or None when
            no recording is available.
        fs: Sampling frequency in Hz.
        results: Mapping whose 'csv' entry is a DataFrame with 'start_time_sec',
            'end_time_sec' and 'cluster_label' columns.
        title_suffix: Text appended to the figure and subplot titles.
        start_time_hours: Segment start, in hours from the beginning of the recording.
        duration_hours: Segment length in hours.

    Returns:
        The Plotly figure (not shown), or None when `eeg_data` is None or the
        requested range contains no EEG samples.
    """

    if eeg_data is None:
        print('No EEG data available for synchronized plot')
        return None

    # Convert time parameters to seconds
    start_time_sec = start_time_hours * 3600
    duration_sec = duration_hours * 3600
    end_time_sec = start_time_sec + duration_sec

    # Extract EEG segment
    start_idx = int(start_time_sec * fs)
    end_idx = int(end_time_sec * fs)
    eeg_segment = eeg_data[start_idx:end_idx]

    # A range that lies past the end of the recording yields an empty slice;
    # nothing can be plotted, so bail out the same way as for missing data.
    if len(eeg_segment) == 0:
        print('Requested time range contains no EEG samples - skipping synchronized plot')
        return None

    # Time axis derived from the sample index, so it stays correct even when the
    # recording ends before the requested end time (slice shorter than requested).
    eeg_time = start_time_hours + np.arange(len(eeg_segment)) / (fs * 3600)

    # Clustering windows that lie entirely inside the segment
    df = results['csv']
    in_segment = (df['start_time_sec'] >= start_time_sec) & (df['end_time_sec'] <= end_time_sec)
    df_segment = df[in_segment]

    # Spectrogram with 10-second Hann windows and 75% overlap; the window is capped
    # at the segment length so noverlap always stays below nperseg.
    window_length = 10  # seconds for spectrogram
    nperseg = min(int(window_length * fs), len(eeg_segment))
    noverlap = int(nperseg * 0.75)

    frequencies, times_spec, Sxx = signal.spectrogram(
        eeg_segment, fs,
        window='hann',
        nperseg=nperseg,
        noverlap=noverlap
    )

    # Spectrogram times are relative to the segment start; shift to recording hours
    times_spec_hours = start_time_hours + times_spec / 3600

    # Convert power to dB (small offset avoids log10(0))
    Sxx_db = 10 * np.log10(Sxx + 1e-12)

    # Create subplot figure with synchronized axes
    fig = make_subplots(
        rows=3, cols=1,
        row_heights=[0.3, 0.4, 0.3],
        subplot_titles=(
            f'EEG Signal - {title_suffix}',
            # The spectrogram above is a Hann-window STFT, not a multitaper estimate
            f'Spectrogram (Hann window) - {title_suffix}',
            f'Hypnogram - {title_suffix}'
        ),
        shared_xaxes=True,
        vertical_spacing=0.08
    )

    # 1. EEG Signal Plot
    fig.add_trace(
        go.Scatter(
            x=eeg_time,
            y=eeg_segment,
            mode='lines',
            name='EEG Signal',
            line=dict(color='blue', width=0.5),
            showlegend=False
        ),
        row=1, col=1
    )

    # 2. Spectrogram Plot
    fig.add_trace(
        go.Heatmap(
            x=times_spec_hours,
            y=frequencies,
            z=Sxx_db,
            colorscale='Viridis',
            colorbar=dict(
                title='Power (dB)',
                x=1.02,
                y=0.5,
                len=0.4
            ),
            hovertemplate='Time: %{x:.2f} hrs<br>Frequency: %{y:.1f} Hz<br>Power: %{z:.1f} dB<extra></extra>',
            showscale=True
        ),
        row=2, col=1
    )

    # 3. Hypnogram with cluster transitions
    colors = px.colors.qualitative.Set1

    # Window start/end times (in hours) grouped per cluster, ascending label order
    cluster_spans = []
    for cluster in sorted(df_segment['cluster_label'].unique()):
        cluster_data = df_segment[df_segment['cluster_label'] == cluster]
        starts = (cluster_data['start_time_sec'].to_numpy() / 3600).tolist()
        ends = (cluster_data['end_time_sec'].to_numpy() / 3600).tolist()
        cluster_spans.append((cluster, starts, ends))

    # Horizontal line per window; a None entry breaks the line between windows
    for cluster, starts, ends in cluster_spans:
        level = int(cluster)
        x_vals = []
        y_vals = []
        for start_hour, end_hour in zip(starts, ends):
            x_vals.extend([start_hour, end_hour, end_hour])
            y_vals.extend([level, level, None])

        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                mode='lines',
                line=dict(color=colors[level % len(colors)], width=3),
                name=f'Cluster {cluster}',
                hovertemplate='Time: %{x:.2f} hrs<br>Cluster: %{y}<extra></extra>',
                connectgaps=False
            ),
            row=3, col=1
        )

    # Shaded rectangle per window. All rectangles of a cluster share ONE trace
    # (separated by None gaps, each gap-delimited piece is filled on its own by
    # fill='toself') instead of one trace per window, which produced thousands
    # of traces for 3-second windows.
    for cluster, starts, ends in cluster_spans:
        level = int(cluster)
        low, high = level - 0.4, level + 0.4
        box_x = []
        box_y = []
        for start_hour, end_hour in zip(starts, ends):
            box_x.extend([start_hour, end_hour, end_hour, start_hour, start_hour, None])
            box_y.extend([low, low, high, high, low, None])

        fig.add_trace(
            go.Scatter(
                x=box_x,
                y=box_y,
                mode='lines',
                fill='toself',
                fillcolor=colors[level % len(colors)],
                opacity=0.3,
                line=dict(width=0),
                showlegend=False,
                hoverinfo='skip'
            ),
            row=3, col=1
        )

    # Update layout for synchronized zooming and panning
    fig.update_layout(
        title=f'Synchronized EEG Analysis - {title_suffix}<br>Time Range: {start_time_hours:.1f} - {start_time_hours + duration_hours:.1f} hours',
        height=1000,
        hovermode='x unified',
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Update x-axes
    fig.update_xaxes(
        title_text='Time (hours)',
        row=3, col=1,
        rangeslider=dict(visible=True, thickness=0.05),
        type='linear'
    )

    # Update y-axes
    fig.update_yaxes(title_text='Amplitude (µV)', row=1, col=1)
    fig.update_yaxes(
        title_text='Frequency (Hz)',
        row=2, col=1,
        range=[0, 50]  # Focus on relevant frequency range
    )
    # With no windows in the segment there is no maximum label; fall back to 0
    # so the axis range is still valid instead of raising on an empty max().
    top_label = float(df_segment['cluster_label'].max()) if len(df_segment) else 0.0
    fig.update_yaxes(
        title_text='Cluster Label',
        row=3, col=1,
        dtick=1,
        range=[-0.5, top_label + 0.5]
    )

    # Add annotation with statistics (the loop is empty when there are no windows,
    # so no division by zero can occur)
    total_windows = len(df_segment)
    cluster_counts = df_segment['cluster_label'].value_counts().sort_index()
    stats_text = f"Total windows: {total_windows}<br>"
    for cluster, count in cluster_counts.items():
        pct = (count / total_windows) * 100
        stats_text += f"Cluster {cluster}: {count} ({pct:.1f}%)<br>"

    fig.add_annotation(
        text=stats_text,
        xref="paper", yref="paper",
        x=0.02, y=0.98,
        showarrow=False,
        align="left",
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor="black",
        borderwidth=1
    )

    return fig

def interactive_segment_analyzer(eeg_data, fs, results_3s, results_30s, start_hours, duration_hours):
    """Print a usage banner and return three plotting closures bound to the given data.

    The outer `start_hours` / `duration_hours` arguments are not read by this
    function's own body; each returned closure takes its own pair.

    Returns:
        Tuple ``(plot_3s_segment, plot_30s_segment, compare_segments)``.
    """

    usage_lines = (
        "=== Interactive Synchronized EEG/Spectrogram/Hypnogram Analyzer ===",
        "\nAvailable functions:",
        "1. plot_3s_segment(start_hours, duration_hours) - Analyze 3-second results",
        "2. plot_30s_segment(start_hours, duration_hours) - Analyze 30-second results",
        "3. compare_segments(start_hours, duration_hours) - Compare both window sizes",
        "\nExample usage:",
        "  fig_3s = plot_3s_segment(0, 2)     # First 2 hours with 3s windows",
        "  fig_30s = plot_30s_segment(4, 1)   # Hour 4-5 with 30s windows",
        "  compare_segments(8, 4)             # Hours 8-12 comparison",
    )
    for usage_line in usage_lines:
        print(usage_line)

    def plot_3s_segment(start_hours, duration_hours):
        """Synchronized figure for the 3-second clustering results."""
        return create_interactive_synchronized_plot(
            eeg_data, fs, results_3s, '3-second windows', start_hours, duration_hours
        )

    def plot_30s_segment(start_hours, duration_hours):
        """Synchronized figure for the 30-second clustering results."""
        return create_interactive_synchronized_plot(
            eeg_data, fs, results_30s, '30-second windows', start_hours, duration_hours
        )

    def compare_segments(start_hours, duration_hours):
        """Show both figures for the same segment and print window-count statistics."""
        fig_3s = plot_3s_segment(start_hours, duration_hours)
        fig_30s = plot_30s_segment(start_hours, duration_hours)

        # Either figure missing: nothing to compare
        if not (fig_3s and fig_30s):
            return None, None

        print(f"\nComparison for time range: {start_hours:.1f} - {start_hours + duration_hours:.1f} hours")
        fig_3s.show()
        fig_30s.show()

        start_sec = start_hours * 3600
        end_sec = (start_hours + duration_hours) * 3600

        def _windows_inside(table):
            # Number of windows lying entirely within [start_sec, end_sec]
            inside = (table['start_time_sec'] >= start_sec) & (table['end_time_sec'] <= end_sec)
            return len(table[inside])

        n_3s = _windows_inside(results_3s['csv'])
        n_30s = _windows_inside(results_30s['csv'])

        print(f"\n3-second windows: {n_3s} windows")
        print(f"30-second windows: {n_30s} windows")
        print(f"Resolution ratio: {n_3s / max(n_30s, 1):.1f}:1")

        return fig_3s, fig_30s

    # Hand the closures back for interactive use
    return plot_3s_segment, plot_30s_segment, compare_segments


In [None]:
# Initialize the interactive analyzer for one default segment of the recording
start_hours = 12     # segment start, hours from the beginning of the recording
duration_hours = 1   # segment length in hours

eeg_data, fs, raw_eeg = load_eeg_data()
if eeg_data is not None:
    plot_3s_segment, plot_30s_segment, compare_segments = interactive_segment_analyzer(
        eeg_data, fs, results_3s, results_30s, start_hours, duration_hours
    )

    # Create default plots for the configured segment; the message is derived from
    # the variables above so it cannot drift from what is actually plotted.
    print(
        f"\nCreating default synchronized plots for hours "
        f"{start_hours}-{start_hours + duration_hours}..."
    )
    default_fig_3s = plot_3s_segment(start_hours, duration_hours)
    default_fig_30s = plot_30s_segment(start_hours, duration_hours)

    # The plotting helpers return None when no figure could be built
    if default_fig_3s is not None:
        default_fig_3s.show()
    if default_fig_30s is not None:
        default_fig_30s.show()

else:
    print("EEG data not available - cannot create synchronized plots")

## 6. Advanced Statistical Analysis

In [None]:
def _label_runs(labels):
    """Return (run_lengths, transitions) for a 1-D array of consecutive labels.

    run_lengths lists the length of every maximal run of identical labels;
    transitions lists a (previous, current) tuple for every label change.
    """
    n = len(labels)
    if n == 0:
        return [], []
    # Positions i (>= 1) where labels[i] differs from labels[i - 1]
    change_points = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    boundaries = np.concatenate(([0], change_points, [n]))
    run_lengths = np.diff(boundaries).tolist()
    transitions = list(zip(labels[change_points - 1], labels[change_points]))
    return run_lengths, transitions


def _label_autocorrelation(sequence, max_lag):
    """Return Pearson correlations of `sequence` with itself for lags 0..max_lag-1."""
    autocorr = []
    for lag in range(max_lag):
        if lag == 0:
            autocorr.append(1.0)
            continue
        # A constant slice has zero variance and yields NaN; report that as 0.0
        with np.errstate(invalid='ignore', divide='ignore'):
            corr_val = np.corrcoef(sequence[:-lag], sequence[lag:])[0, 1]
        autocorr.append(0.0 if np.isnan(corr_val) else corr_val)
    return autocorr


def advanced_cluster_statistics(results, title_suffix):
    """Print and plot temporal statistics of a clustering result.

    Reports the hourly cluster distribution, run lengths of consecutive
    identical labels, the most frequent hour per cluster, label transitions,
    and the autocorrelation of the label sequence, then draws a 2x2 summary figure.

    Args:
        results: Mapping whose 'csv' entry is a DataFrame with 'start_time_sec'
            and 'cluster_label' columns, one row per window in temporal order
            (row order is assumed chronological - TODO confirm against the loader).
        title_suffix: Text appended to the printed header and the figure title.

    Returns:
        Dict with keys 'hourly_distribution' (DataFrame), 'stability_scores'
        (list of run lengths), 'transitions' (list of (prev, curr) tuples) and
        'autocorr' (list of floats). All are empty when the table has no rows.
    """
    df = results['csv'].copy()

    print(f'\n=== Advanced Statistical Analysis ({title_suffix}) ===')

    # Nothing to analyse: return empty results instead of failing on the first row
    if df.empty:
        print('   No clustering windows available - skipping analysis')
        return {
            'hourly_distribution': pd.DataFrame(),
            'stability_scores': [],
            'transitions': [],
            'autocorr': []
        }

    # 1. Temporal distribution analysis
    df['hour'] = (df['start_time_sec'] / 3600).astype(int)
    hourly_distribution = df.groupby(['hour', 'cluster_label']).size().unstack(fill_value=0)

    print('\n1. Hourly Distribution of Clusters:')
    print(hourly_distribution)

    # 2. Cluster stability analysis (consecutive same-cluster windows) and
    # 4. transitions are both derived from the positions where the label changes
    cluster_sequence = df['cluster_label'].to_numpy()
    stability_scores, transitions = _label_runs(cluster_sequence)

    print(f'\n2. Cluster Stability Analysis:')
    print(f'   Mean consecutive windows: {np.mean(stability_scores):.2f}')
    print(f'   Median consecutive windows: {np.median(stability_scores):.2f}')
    print(f'   Max consecutive windows: {np.max(stability_scores)}')
    print(f'   Number of segments: {len(stability_scores)}')

    # 3. Circadian rhythm analysis (if data spans multiple hours)
    if df['hour'].max() >= 4:  # At least 4 hours of data
        print(f'\n3. Circadian Pattern Analysis:')
        for cluster in sorted(df['cluster_label'].unique()):
            cluster_hours = df[df['cluster_label'] == cluster]['hour'].values
            if len(cluster_hours) > 0:
                # Most frequent hour; np.unique sorts the hours, so ties resolve
                # to the earliest hour regardless of the installed SciPy version
                hours, counts = np.unique(cluster_hours, return_counts=True)
                peak_hour = hours[np.argmax(counts)]
                print(f'   Cluster {cluster}: Peak activity at hour {peak_hour}')

    # 4. Transition analysis
    if transitions:
        transition_counts = pd.Series(transitions).value_counts()
        print(f'\n4. Most Common Transitions:')
        print(transition_counts.head(10))

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Advanced Statistical Analysis - {title_suffix}', fontsize=16, fontweight='bold')

    # Hourly distribution heatmap
    if not hourly_distribution.empty:
        sns.heatmap(hourly_distribution.T, annot=True, fmt='d', cmap='YlOrRd', ax=axes[0,0])
        axes[0,0].set_title('Hourly Cluster Distribution')
        axes[0,0].set_xlabel('Hour of Recording')
        axes[0,0].set_ylabel('Cluster Label')

    # Stability distribution
    mean_run = np.mean(stability_scores)
    axes[0,1].hist(stability_scores, bins=min(30, len(set(stability_scores))), alpha=0.7, edgecolor='black')
    axes[0,1].set_title('Distribution of Consecutive Window Lengths')
    axes[0,1].set_xlabel('Consecutive Windows')
    axes[0,1].set_ylabel('Frequency')
    axes[0,1].axvline(mean_run, color='red', linestyle='--', label=f'Mean: {mean_run:.1f}')
    axes[0,1].legend()

    # Cluster proportion over time (sliding window, 50% overlap)
    window_size = max(100, len(df) // 20)  # Adaptive window size
    time_points = []
    cluster_props = {c: [] for c in sorted(df['cluster_label'].unique())}

    for i in range(window_size, len(df), window_size//2):
        window_data = df.iloc[i-window_size:i]
        time_points.append(window_data['start_time_sec'].mean() / 3600)

        for cluster, props in cluster_props.items():
            props.append((window_data['cluster_label'] == cluster).mean())

    for cluster, props in cluster_props.items():
        if len(props) > 0:
            axes[1,0].plot(time_points, props, marker='o', label=f'Cluster {cluster}', alpha=0.7)

    axes[1,0].set_title('Cluster Proportions Over Time')
    axes[1,0].set_xlabel('Time (hours)')
    axes[1,0].set_ylabel('Proportion')
    # Tables shorter than one sliding window produce no curves; skip the legend
    # then, since there are no labelled artists to list
    if time_points:
        axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)

    # Autocorrelation of cluster sequence (max_lag <= len // 4, so every lag is valid)
    max_lag = min(100, len(cluster_sequence) // 4)
    autocorr = _label_autocorrelation(cluster_sequence, max_lag)

    axes[1,1].plot(range(len(autocorr)), autocorr, marker='o', markersize=3)
    axes[1,1].set_title('Autocorrelation of Cluster Sequence')
    axes[1,1].set_xlabel('Lag (windows)')
    axes[1,1].set_ylabel('Autocorrelation')
    axes[1,1].grid(True, alpha=0.3)
    axes[1,1].axhline(y=0, color='red', linestyle='--', alpha=0.5)

    # Render explicitly, as the other matplotlib-based analysis in this notebook does
    plt.tight_layout()
    plt.show()

    return {
        'hourly_distribution': hourly_distribution,
        'stability_scores': stability_scores,
        'transitions': transitions,
        'autocorr': autocorr
    }

# Perform advanced statistical analysis for both window sizes; each call prints its
# report, draws its figure, and returns a dict of the computed statistics
stats_3s = advanced_cluster_statistics(results_3s, '3-second windows')
stats_30s = advanced_cluster_statistics(results_30s, '30-second windows')