# Advanced Time Series Transformer + Clustering Analysis

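This notebook analyzes the clusters obtained by embedding EEG windows with a time series transformer (TST) and grouping the embeddings with K-means; the clusters are treated as candidate sleep stages. For both the 3-second and the 30-second window runs, we examine cluster distributions, continuous segment durations, temporal dynamics, and frequency-domain characteristics.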

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.signal as signal
from scipy import stats
import mne
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
# Hide library warnings to keep the notebook output uncluttered.
warnings.filterwarnings(action='ignore')

# Shared look for every matplotlib/seaborn figure in this notebook.
MPL_STYLE, SNS_PALETTE = 'seaborn-v0_8-darkgrid', 'husl'
plt.style.use(MPL_STYLE)
sns.set_palette(SNS_PALETTE)

print('Libraries imported successfully!')

## 1. Data Loading and Preprocessing

In [None]:
# Load clustering results
def load_clustering_results(results_dir='results',
                            run_3s='predicted_labels_3s_tst_20_files_20250530_131707',
                            run_30s='predicted_labels_30s_tst_3_files_20250523_181027'):
    """Load the 3-second and 30-second clustering outputs.

    Each run is stored as three sibling files in *results_dir*:
    ``<run>.csv``, ``<run>.npy`` and ``<run>_metadata.txt``. The defaults
    reproduce the two runs this notebook was written against.

    Parameters
    ----------
    results_dir : str or os.PathLike, default 'results'
        Directory containing the run files.
    run_3s, run_30s : str
        File-name stems (without extension) of the 3 s and 30 s runs.

    Returns
    -------
    (dict, dict)
        ``(results_3s, results_30s)``. Each dict has the keys ``'csv'``
        (pandas.DataFrame from the CSV), ``'npy'`` (array from the .npy
        file) and ``'metadata'`` (text of the metadata file).

    Raises
    ------
    FileNotFoundError
        If any of the six files is missing.
    """
    from pathlib import Path

    base_dir = Path(results_dir)

    def _load_run(run_name):
        # Load one run's CSV, array and metadata text, in that order.
        bundle = {
            'csv': pd.read_csv(base_dir / f'{run_name}.csv'),
            'npy': np.load(base_dir / f'{run_name}.npy'),
        }
        with open(base_dir / f'{run_name}_metadata.txt', 'r', encoding='utf-8') as f:
            bundle['metadata'] = f.read()
        return bundle

    results_3s = _load_run(run_3s)
    results_30s = _load_run(run_30s)
    return results_3s, results_30s

# Load original EEG data
def load_eeg_data(edf_path='by captain borat/raw/SC4001E0-PSG.edf', channel=0):
    """Load one channel of the original EEG recording for comparison.

    Parameters
    ----------
    edf_path : str, default 'by captain borat/raw/SC4001E0-PSG.edf'
        Path of the EDF file to read.
    channel : int or str, default 0
        Index or name of the channel to return. The default (first channel)
        assumes channel 0 is EEG — TODO confirm for this recording.

    Returns
    -------
    (numpy.ndarray, float, mne.io.Raw) or (None, None, None)
        The selected channel's samples as returned by ``Raw.get_data``, the
        sampling rate in Hz from ``raw.info['sfreq']``, and the Raw object.
        Any error while loading is printed and ``(None, None, None)`` is
        returned, so the rest of the notebook can run without EEG.
    """
    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True)

        # picks=[channel] accepts an index or a channel name.
        eeg_data = raw.get_data(picks=[channel])[0]
        fs = raw.info['sfreq']

        return eeg_data, fs, raw
    except Exception as e:
        # Deliberate best effort: report the problem and continue without EEG.
        print(f'Error loading EEG data: {e}')
        return None, None, None

# Load the clustering outputs, then the reference EEG recording.
results_3s, results_30s = load_clustering_results()
eeg_data, fs, raw_eeg = load_eeg_data()

for _label, _res in (('3s', results_3s), ('30s', results_30s)):
    print(f'{_label} results shape: {_res["csv"].shape}')

eeg_summary = (
    f'EEG data shape: {eeg_data.shape}, Sampling rate: {fs} Hz'
    if eeg_data is not None
    else 'EEG data not available'
)
print(eeg_summary)

## 2. Cluster Distribution Analysis

In [None]:
# Sanity-check the stored window timings. The first window of each run
# should last as long as its nominal window size, and the window counts
# should relate roughly 10:1. The two runs were built from different numbers
# of files (20 vs 3, per their file names), so the ratio is only a rough check.
print("=== Window Duration Analysis ===")


def _report_window_durations(df, dataset_label, expected_sec, n_preview=5):
    """Print the first *n_preview* windows of *df* and its first-window duration.

    Returns the duration (s) of the first window, or NaN when *df* is empty.
    """
    print(f"\n{dataset_label} dataset:")
    print(f"  First few windows:")
    # Never index past the end of a short result file.
    for i in range(min(n_preview, len(df))):
        start = df['start_time_sec'].iloc[i]
        end = df['end_time_sec'].iloc[i]
        duration = end - start
        print(f"    Window {i}: {start:.1f}s - {end:.1f}s (duration: {duration:.1f}s)")
    first_duration = (
        df['end_time_sec'].iloc[0] - df['start_time_sec'].iloc[0] if len(df) else float('nan')
    )
    print(f"  Expected window duration: {expected_sec}s, Actual: {first_duration:.1f}s")
    return first_duration


df_3s = results_3s['csv']
window_duration_3s = _report_window_durations(df_3s, "3-second", 3)

df_30s = results_30s['csv']
window_duration_30s = _report_window_durations(df_30s, "30-second", 30)

print(f"\nTotal windows comparison:")
print(f"  3s dataset: {len(df_3s)} windows")
print(f"  30s dataset: {len(df_30s)} windows")
# Avoid ZeroDivisionError when the 30 s table is empty.
ratio_text = f"{len(df_3s) / len(df_30s):.1f}" if len(df_30s) else "n/a"
print(f"  Expected ratio (3s:30s): ~10:1, Actual ratio: {ratio_text}:1")

In [None]:
def analyze_cluster_distribution(results, title_suffix):
    """Summarise and plot how windows are distributed over cluster labels.

    Prints the window count, percentage and total minutes for each cluster.
    Then draws a 2x2 figure:
    - window counts per cluster;
    - cluster proportions;
    - a timeline of windows ending within the first two hours;
    - a matrix counting label changes between consecutive rows.

    Parameters
    ----------
    results : dict
        Clustering bundle from ``load_clustering_results``. Only
        ``results['csv']`` is read; it needs ``cluster_label``,
        ``start_time_sec`` and ``end_time_sec`` columns.
    title_suffix : str
        Label used in the printed header and the figure title.

    Returns
    -------
    (pandas.Series, pandas.Series)
        Window count and percentage (rounded to 2 decimals) per cluster
        label, indexed by label in ascending order.
    """
    df = results['csv']

    cluster_counts = df['cluster_label'].value_counts().sort_index()

    print(f'\n=== Cluster Distribution Analysis ({title_suffix}) ===')
    if df.empty:
        print('No windows to analyze')
        return cluster_counts, cluster_counts.astype(float)

    cluster_percentages = (cluster_counts / len(df) * 100).round(2)
    # One window's length, taken from the first row. Positional access keeps
    # this working when the DataFrame index does not start at 0.
    window_duration = df['end_time_sec'].iloc[0] - df['start_time_sec'].iloc[0]

    print(f'Total windows: {len(df)}')
    print(f'Total duration: {df["end_time_sec"].max()/3600:.2f} hours')
    print('\nCluster distribution:')
    for cluster, count in cluster_counts.items():
        pct = cluster_percentages[cluster]
        duration_min = count * window_duration / 60
        print(f'  Cluster {cluster}: {count:4d} windows ({pct:5.1f}%) - {duration_min:.1f} min')

    set1 = plt.cm.Set1  # qualitative colormap; label indices wrap at set1.N colours

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Cluster Distribution Analysis - {title_suffix}', fontsize=16, fontweight='bold')

    # Bar plot of counts, each bar annotated with its value
    axes[0, 0].bar(cluster_counts.index, cluster_counts.values, alpha=0.7)
    axes[0, 0].set_title('Cluster Counts')
    axes[0, 0].set_xlabel('Cluster Label')
    axes[0, 0].set_ylabel('Number of Windows')
    for x, v in zip(cluster_counts.index, cluster_counts.values):
        axes[0, 0].text(x, v, str(v), ha='center', va='bottom')

    # Pie chart
    axes[0, 1].pie(cluster_counts.values, labels=[f'Cluster {i}' for i in cluster_counts.index],
                   autopct='%1.1f%%', startangle=90)
    axes[0, 1].set_title('Cluster Proportions')

    # Timeline of windows that end within the first 2 hours
    df_subset = df[df['end_time_sec'] <= 7200]
    colors = set1(df_subset['cluster_label'].to_numpy().astype(int) % set1.N)
    axes[1, 0].scatter(df_subset['start_time_sec'] / 60, df_subset['cluster_label'],
                       c=colors, alpha=0.6, s=20)
    axes[1, 0].set_title('Cluster Timeline (First 2 Hours)')
    axes[1, 0].set_xlabel('Time (minutes)')
    axes[1, 0].set_ylabel('Cluster Label')
    axes[1, 0].set_yticks(sorted(df['cluster_label'].unique()))

    # Label changes between consecutive rows, vectorised. Reading the label
    # column directly keeps its dtype; row-wise iloc would upcast to float.
    labels = df['cluster_label'].to_numpy()
    changed = labels[1:] != labels[:-1]
    transition_df = pd.DataFrame({
        'from_cluster': labels[:-1][changed],
        'to_cluster': labels[1:][changed],
    })

    if not transition_df.empty:
        transition_matrix = pd.crosstab(transition_df['from_cluster'], transition_df['to_cluster'])
        sns.heatmap(transition_matrix, annot=True, fmt='d', cmap='Blues', ax=axes[1, 1])
        axes[1, 1].set_title('Cluster Transition Matrix')
        axes[1, 1].set_xlabel('To Cluster')
        axes[1, 1].set_ylabel('From Cluster')
    else:
        axes[1, 1].text(0.5, 0.5, 'No transitions found', ha='center', va='center',
                        transform=axes[1, 1].transAxes)
        axes[1, 1].set_title('Cluster Transition Matrix')

    plt.tight_layout()
    plt.show()

    return cluster_counts, cluster_percentages

# Analyze both datasets
cluster_counts_3s, cluster_pct_3s = analyze_cluster_distribution(results_3s, '3-second windows')
cluster_counts_30s, cluster_pct_30s = analyze_cluster_distribution(results_30s, '30-second windows')

## 3. Continuous Duration Analysis

In [None]:
def analyze_continuous_durations(results, title_suffix):
    """Measure how long the cluster label stays unchanged across consecutive windows.

    Consecutive rows of ``results['csv']`` that share a ``cluster_label``
    form one segment. A segment's duration is its window count times the
    duration of the first window in the table. Prints per-cluster duration
    statistics and draws a 2x2 summary figure.

    Parameters
    ----------
    results : dict
        Clustering bundle; ``results['csv']`` needs ``cluster_label``,
        ``start_time_sec`` and ``end_time_sec`` columns, with rows in time
        order.
    title_suffix : str
        Label used in the printed header and the figure title.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``segments_df`` has one row per segment, with columns ``cluster``,
        ``start_time``, ``end_time``, ``duration_sec``, ``duration_min`` and
        ``num_windows``. ``duration_stats`` holds count/mean/median/std/min/max
        of ``duration_min`` per cluster, rounded to 2 decimals.
    """
    df = results['csv']
    header = f'\n=== Continuous Duration Analysis ({title_suffix}) ==='
    if df.empty:
        print(header)
        print('No windows to analyze')
        empty_segments = pd.DataFrame(columns=['cluster', 'start_time', 'end_time',
                                               'duration_sec', 'duration_min', 'num_windows'])
        return empty_segments, pd.DataFrame()

    window_duration = df['end_time_sec'].iloc[0] - df['start_time_sec'].iloc[0]
    labels = df['cluster_label'].to_numpy()

    # A segment starts at row 0 and wherever the label differs from the previous row.
    seg_start = np.flatnonzero(np.r_[True, labels[1:] != labels[:-1]])
    seg_stop = np.r_[seg_start[1:], len(labels)]  # exclusive end row of each segment
    num_windows = seg_stop - seg_start
    duration_sec = num_windows * window_duration

    segments_df = pd.DataFrame({
        'cluster': labels[seg_start],
        'start_time': df['start_time_sec'].to_numpy()[seg_start],
        'end_time': df['end_time_sec'].to_numpy()[seg_stop - 1],
        'duration_sec': duration_sec,
        'duration_min': duration_sec / 60,
        'num_windows': num_windows,
    })

    print(header)
    print(f'Total continuous segments: {len(segments_df)}')

    # Statistics by cluster
    duration_stats = segments_df.groupby('cluster')['duration_min'].agg([
        'count', 'mean', 'median', 'std', 'min', 'max'
    ]).round(2)

    print('\nDuration statistics by cluster (minutes):')
    print(duration_stats)

    set1 = plt.cm.Set1  # qualitative colormap; label indices wrap at set1.N colours

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Continuous Duration Analysis - {title_suffix}', fontsize=16, fontweight='bold')

    # Box plot of durations by cluster. Tick labels are set separately so this
    # works whether or not boxplot's `labels` keyword is deprecated.
    cluster_ids = sorted(segments_df['cluster'].unique())
    duration_data = [segments_df.loc[segments_df['cluster'] == c, 'duration_min'].to_numpy()
                     for c in cluster_ids]
    axes[0, 0].boxplot(duration_data)
    axes[0, 0].set_xticklabels([f'Cluster {c}' for c in cluster_ids])
    axes[0, 0].set_title('Duration Distribution by Cluster')
    axes[0, 0].set_ylabel('Duration (minutes)')
    axes[0, 0].tick_params(axis='x', rotation=45)

    # Histogram of all durations
    axes[0, 1].hist(segments_df['duration_min'], bins=30, alpha=0.7, edgecolor='black')
    axes[0, 1].set_title('Overall Duration Distribution')
    axes[0, 1].set_xlabel('Duration (minutes)')
    axes[0, 1].set_ylabel('Frequency')

    # Mean duration by cluster, each bar annotated with its value
    mean_durations = segments_df.groupby('cluster')['duration_min'].mean()
    axes[1, 0].bar(mean_durations.index, mean_durations.values, alpha=0.7)
    axes[1, 0].set_title('Mean Duration by Cluster')
    axes[1, 0].set_xlabel('Cluster Label')
    axes[1, 0].set_ylabel('Mean Duration (minutes)')
    for x, v in zip(mean_durations.index, mean_durations.values):
        axes[1, 0].text(x, v, f'{v:.1f}', ha='center', va='bottom')

    # Timeline of segments starting within the first 4 hours. Each bar is
    # coloured by its cluster; the default colour cycle would give every
    # segment a different, meaningless colour.
    early = segments_df[segments_df['start_time'] <= 14400]
    for cluster, start_sec, end_sec in zip(early['cluster'], early['start_time'], early['end_time']):
        axes[1, 1].barh(cluster, (end_sec - start_sec) / 60, left=start_sec / 60,
                        alpha=0.7, color=set1(int(cluster) % set1.N))

    axes[1, 1].set_title('Segment Timeline (First 4 Hours)')
    axes[1, 1].set_xlabel('Time (minutes)')
    axes[1, 1].set_ylabel('Cluster Label')
    axes[1, 1].set_yticks(cluster_ids)

    plt.tight_layout()
    plt.show()

    return segments_df, duration_stats

# Analyze continuous durations
segments_3s, duration_stats_3s = analyze_continuous_durations(results_3s, '3-second windows')
segments_30s, duration_stats_30s = analyze_continuous_durations(results_30s, '30-second windows')

## 4. Frequency Domain Analysis (Hann-Window Spectrogram)

In [None]:
def create_multitaper_spectrogram(eeg_data, fs, window_length=30, overlap=0.5):
    """Compute a short-time power spectrogram of an EEG signal.

    Despite its name (kept so existing callers keep working), this is a
    single-taper STFT spectrogram from ``scipy.signal.spectrogram`` with a
    Hann window, not a multitaper estimate.

    Parameters
    ----------
    eeg_data : array_like or None
        Signal samples; the spectrogram is taken along the last axis.
    fs : float
        Sampling rate in Hz.
    window_length : float, default 30
        Segment length in seconds. It is capped at the signal length.
    overlap : float, default 0.5
        Fraction of each segment shared with the next one.

    Returns
    -------
    (ndarray, ndarray, ndarray) or (None, None, None)
        Frequencies (Hz), segment-centre times (s) and power ``Sxx`` with
        shape ``(n_freqs, n_times)`` for 1-D input. All three are ``None``
        (after a message) when *eeg_data* is None.
    """
    if eeg_data is None:
        print('No EEG data available for spectrogram analysis')
        return None, None, None

    n_samples = np.shape(eeg_data)[-1]
    # Cap the segment at the signal length ourselves. scipy would shrink an
    # oversized nperseg on its own but keep our noverlap, which can then be
    # >= nperseg and raise "noverlap must be less than nperseg".
    nperseg = min(int(window_length * fs), n_samples)
    noverlap = max(0, min(int(nperseg * overlap), nperseg - 1))

    frequencies, times, Sxx = signal.spectrogram(
        eeg_data, fs,
        window='hann',
        nperseg=nperseg,
        noverlap=noverlap,
    )

    return frequencies, times, Sxx

# Multitaper PSD built from scipy's DPSS (Slepian) tapers and Welch averaging
def create_multitaper_psd(eeg_data, fs, window_length=30, time_halfbandwidth=4):
    """Estimate the power spectral density with DPSS (multitaper) windows.

    The signal is split into ``window_length``-second segments with Welch's
    default 50 % overlap. Each segment is tapered with ``K = 2*NW - 1``
    Slepian (DPSS) tapers, where ``NW = time_halfbandwidth``. The resulting
    periodograms are averaged over segments and tapers.
    ``scipy.signal.welch`` has no multitaper option, so one Welch estimate is
    computed per taper and the results are averaged.

    Parameters
    ----------
    eeg_data : array_like or None
        Signal samples; the PSD is taken along the last axis.
    fs : float
        Sampling rate in Hz.
    window_length : float, default 30
        Segment length in seconds. It is capped at the signal length.
    time_halfbandwidth : float, default 4
        Time-halfbandwidth product NW. The frequency smoothing is about
        ``±NW / window_length`` Hz.

    Returns
    -------
    (ndarray, ndarray) or (None, None)
        Frequencies (Hz) and PSD (power per Hz). Both are ``None`` when
        *eeg_data* is None. If the DPSS tapers cannot be built for these
        parameters, the failure is printed and a Hann-window Welch PSD is
        returned instead.
    """
    if eeg_data is None:
        return None, None

    x = np.asarray(eeg_data, dtype=float)
    nperseg = min(int(window_length * fs), x.shape[-1])
    n_tapers = max(1, int(2 * time_halfbandwidth) - 1)

    try:
        # Shape (n_tapers, nperseg) because Kmax is given.
        tapers = signal.windows.dpss(nperseg, time_halfbandwidth, Kmax=n_tapers)
    except ValueError as e:
        # e.g. NW >= nperseg / 2 for very short signals
        print(f'Multitaper PSD calculation failed: {e}')
        return signal.welch(x, fs, window='hann', nperseg=nperseg)

    frequencies = None
    psd_sum = 0.0
    for taper in tapers:
        # welch normalises by sum(window**2), so each estimate is a proper PSD.
        frequencies, psd = signal.welch(x, fs, window=taper, nperseg=nperseg)
        psd_sum = psd_sum + psd

    return frequencies, psd_sum / len(tapers)

def analyze_frequency_bands(frequencies, Sxx, cluster_labels, times, bands=None):
    """Average spectrogram power within each frequency band, per time bin.

    Parameters
    ----------
    frequencies : array_like, shape (n_freqs,)
        Frequency of each row of *Sxx*, in Hz.
    Sxx : ndarray, shape (n_freqs, n_times)
        Power spectrogram.
    cluster_labels : any
        Kept for interface compatibility; not used by this function.
    times : array_like, shape (n_times,)
        Time of each column of *Sxx*. Only its length is used, to size the
        all-zero result for bands that contain no frequency bin.
    bands : dict[str, tuple[float, float]] or None
        Band name -> ``(low_hz, high_hz)``; both edges are inclusive.
        Defaults to the classical delta/theta/alpha/beta/gamma bands.

    Returns
    -------
    dict[str, ndarray]
        Band name -> mean power over the band's frequency bins for each time
        bin (length n_times), in the order of *bands*.
    """
    if bands is None:
        bands = {
            'Delta (0.5-4 Hz)': (0.5, 4),
            'Theta (4-8 Hz)': (4, 8),
            'Alpha (8-13 Hz)': (8, 13),
            'Beta (13-30 Hz)': (13, 30),
            'Gamma (30-50 Hz)': (30, 50),
        }

    frequencies = np.asarray(frequencies)
    band_powers = {}

    for band_name, (low_freq, high_freq) in bands.items():
        freq_mask = (frequencies >= low_freq) & (frequencies <= high_freq)

        if np.any(freq_mask):
            band_powers[band_name] = np.mean(Sxx[freq_mask, :], axis=0)
        else:
            # Band not resolved by this frequency grid
            band_powers[band_name] = np.zeros(len(times))

    return band_powers

def _cluster_time_indices(df, times):
    """Map each cluster label to the spectrogram columns whose time lies inside its windows.

    A column belongs to a window when ``start_time_sec <= t < end_time_sec``.
    A column inside several windows is listed once per window, so
    overlapping windows are weighted as often as they occur.
    """
    chunks = {}
    for start, end, cluster in zip(df['start_time_sec'].to_numpy(),
                                   df['end_time_sec'].to_numpy(),
                                   df['cluster_label'].to_numpy()):
        idx = np.flatnonzero((times >= start) & (times < end))
        if idx.size:
            chunks.setdefault(cluster, []).append(idx)
    return {cluster: np.concatenate(parts) for cluster, parts in chunks.items()}


def plot_frequency_analysis(eeg_data, fs, results, title_suffix):
    """Plot spectrogram, band powers and mean PSD of the EEG, split by cluster.

    The EEG is truncated to the last ``end_time_sec`` of the clustering
    table. A 30 s spectrogram is then computed (see
    ``create_multitaper_spectrogram``) and each spectrogram column is
    assigned to the cluster(s) whose windows contain its time.

    Parameters
    ----------
    eeg_data : array_like or None
        Single-channel EEG samples.
    fs : float
        Sampling rate in Hz.
    results : dict
        Clustering bundle; ``results['csv']`` needs ``start_time_sec``,
        ``end_time_sec`` and ``cluster_label`` columns.
    title_suffix : str
        Label used in printed output and figure titles.

    Returns
    -------
    tuple or None
        ``(frequencies, times, Sxx_db, band_powers)``, or ``None`` when no
        EEG or no spectrogram is available.
    """
    if eeg_data is None:
        print(f'Cannot perform frequency analysis for {title_suffix} - no EEG data')
        return None

    # Truncate EEG data to match clustering results duration
    df = results['csv']
    max_time = df['end_time_sec'].max()
    eeg_truncated = eeg_data[:int(max_time * fs)]

    print(f'\n=== Frequency Domain Analysis ({title_suffix}) ===')
    print(f'EEG duration: {len(eeg_truncated)/fs/3600:.2f} hours')
    print(f'Clustering duration: {max_time/3600:.2f} hours')

    frequencies, times, Sxx = create_multitaper_spectrogram(eeg_truncated, fs)
    if frequencies is None:
        return None

    # Convert power to dB; the offset avoids log10(0)
    Sxx_db = 10 * np.log10(Sxx + 1e-12)

    # Spectrogram column indices per cluster (replaces a per-time argmin search)
    cluster_indices = _cluster_time_indices(df, times)
    cluster_ids = sorted(df['cluster_label'].unique())

    # analyze_frequency_bands ignores its cluster_labels argument
    band_powers = analyze_frequency_bands(frequencies, Sxx, None, times)

    # Set1 has set1.N (9) colours; wrap on N, not 10, so labels never collide
    set1 = plt.cm.Set1

    fig = plt.figure(figsize=(20, 15))
    gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

    # 1. Full spectrogram
    ax1 = fig.add_subplot(gs[0, :])
    im1 = ax1.pcolormesh(times / 60, frequencies, Sxx_db, shading='gouraud', cmap='viridis')
    ax1.set_ylabel('Frequency (Hz)')
    ax1.set_title(f'Spectrogram - {title_suffix}')
    ax1.set_ylim([0, 50])
    plt.colorbar(im1, ax=ax1, label='Power (dB)')

    # 2. Cluster overlay on spectrogram (first 2 hours)
    ax2 = fig.add_subplot(gs[1, :])
    time_subset = times <= 7200
    im2 = ax2.pcolormesh(times[time_subset] / 60, frequencies, Sxx_db[:, time_subset],
                         shading='gouraud', cmap='viridis', alpha=0.7)

    # Coloured strip along the top edge marks each window's cluster
    df_subset = df[df['end_time_sec'] <= 7200]
    for start_sec, end_sec, cluster in zip(df_subset['start_time_sec'],
                                           df_subset['end_time_sec'],
                                           df_subset['cluster_label']):
        ax2.axvspan(start_sec / 60, end_sec / 60, ymin=0.95, ymax=1.0,
                    color=set1(int(cluster) % set1.N), alpha=0.8)

    ax2.set_ylabel('Frequency (Hz)')
    ax2.set_xlabel('Time (minutes)')
    ax2.set_title('Spectrogram with Cluster Overlay (First 2 Hours)')
    ax2.set_ylim([0, 50])
    plt.colorbar(im2, ax=ax2, label='Power (dB)')

    # 3-6. Mean band power by cluster for the first four bands
    for i, band_name in enumerate(list(band_powers)[:4]):
        ax = fig.add_subplot(gs[2 + i // 2, i % 2])
        band = band_powers[band_name]
        powers = [np.mean(band[cluster_indices[c]]) if c in cluster_indices else 0
                  for c in cluster_ids]

        bars = ax.bar(cluster_ids, powers, alpha=0.7,
                      color=[set1(int(c) % set1.N) for c in cluster_ids])
        ax.set_title(f'{band_name} Power by Cluster')
        ax.set_xlabel('Cluster Label')
        ax.set_ylabel('Mean Power')

        for bar, power in zip(bars, powers):
            if power > 0:
                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
                        f'{power:.2e}', ha='center', va='bottom', fontsize=8)

    # 7. Mean power spectral density by cluster
    ax7 = fig.add_subplot(gs[3, 2])
    for c in cluster_ids:
        if c in cluster_indices:
            mean_psd = np.mean(Sxx[:, cluster_indices[c]], axis=1)
            ax7.semilogy(frequencies, mean_psd, label=f'Cluster {c}', alpha=0.8,
                         color=set1(int(c) % set1.N))

    ax7.set_xlabel('Frequency (Hz)')
    ax7.set_ylabel('Power Spectral Density')
    ax7.set_title('Mean PSD by Cluster')
    ax7.set_xlim([0, 50])
    ax7.legend()
    ax7.grid(True, alpha=0.3)

    plt.suptitle(f'Comprehensive Frequency Analysis - {title_suffix}', fontsize=16, fontweight='bold')
    plt.show()

    return frequencies, times, Sxx_db, band_powers

# Perform frequency analysis
if eeg_data is not None:
    freq_results_3s = plot_frequency_analysis(eeg_data, fs, results_3s, '3-second windows')
    freq_results_30s = plot_frequency_analysis(eeg_data, fs, results_30s, '30-second windows')
else:
    print('EEG data not available - skipping frequency analysis')

## 5. Interactive Hypnogram Comparison

In [None]:
def create_interactive_hypnogram(results_3s, results_30s):
    """Display both clustering runs as stacked interactive hypnograms.

    The top panel shows each 3 s window as a small marker at its start time.
    The bottom panel shows the 30 s windows as larger squares, limited to
    start times no later than the last 3 s window so both panels cover the
    same span. Colours come from Plotly's qualitative Set1 palette, indexed
    by cluster label modulo the palette size.

    Returns the Plotly figure after showing it.
    """
    palette = px.colors.qualitative.Set1

    def _with_hours(frame, window_type):
        # Copy of *frame* with start times in hours and a window-type tag.
        out = frame.copy()
        out['time_hours'] = out['start_time_sec'] / 3600
        out['window_type'] = window_type
        return out

    def _add_cluster_traces(figure, frame, row, prefix, marker_style):
        # One marker trace per cluster label, added in ascending label order.
        for label in sorted(frame['cluster_label'].unique()):
            members = frame[frame['cluster_label'] == label]
            figure.add_trace(
                go.Scatter(
                    x=members['time_hours'],
                    y=members['cluster_label'],
                    mode='markers',
                    marker=dict(color=palette[int(label) % len(palette)], **marker_style),
                    name=f'{prefix} Cluster {label}',
                    legendgroup=f'{prefix}_{label}',
                    showlegend=True,
                ),
                row=row, col=1,
            )

    frame_3s = _with_hours(results_3s['csv'], '3-second')
    frame_30s = _with_hours(results_30s['csv'], '30-second')

    fig = make_subplots(
        rows=2, cols=1,
        subplot_titles=('3-Second Window Clustering', '30-Second Window Clustering'),
        vertical_spacing=0.1,
    )

    _add_cluster_traces(fig, frame_3s, 1, '3s', {'size': 4})

    # Only 30 s windows starting within the span covered by the 3 s run
    last_3s_hour = frame_3s['time_hours'].max()
    _add_cluster_traces(fig, frame_30s[frame_30s['time_hours'] <= last_3s_hour],
                        2, '30s', {'size': 8, 'symbol': 'square'})

    fig.update_layout(
        title='Interactive Hypnogram Comparison: Time Series Transformer + Clustering Results',
        height=800,
        hovermode='x unified',
    )
    fig.update_xaxes(title_text='Time (hours)', row=2, col=1)
    fig.update_yaxes(title_text='Cluster Label', dtick=1)

    fig.show()
    return fig

# Create interactive hypnogram
interactive_fig = create_interactive_hypnogram(results_3s, results_30s)

In [None]:
def create_interactive_synchronized_plot(eeg_data, fs, results, title_suffix, start_time_hours, duration_hours,
                                         eeg_scale=1e6):
    """Build a Plotly figure with EEG trace, spectrogram and hypnogram on one shared time axis.

    The three stacked panels share their x-axis (hours), so zooming or
    panning one moves the others. A range slider sits under the hypnogram.

    Parameters
    ----------
    eeg_data : array_like or None
        Full single-channel recording sampled at *fs*.
    fs : float
        Sampling rate in Hz.
    results : dict
        Clustering bundle; ``results['csv']`` needs ``start_time_sec``,
        ``end_time_sec`` and ``cluster_label`` columns.
    title_suffix : str
        Appended to the panel titles and the figure title.
    start_time_hours, duration_hours : float
        Start and length of the displayed segment, in hours.
    eeg_scale : float, default 1e6
        Factor applied to the EEG trace so it matches the ``Amplitude (µV)``
        axis label. ``load_eeg_data`` returns MNE ``Raw.get_data`` output,
        which MNE documents as SI units (volts), hence 1e6. Pass 1.0 if
        *eeg_data* is already in microvolts. The spectrogram is computed on
        the unscaled signal.

    Returns
    -------
    plotly.graph_objects.Figure or None
        ``None`` (after printing a message) when there is no EEG data or no
        samples in the requested range.
    """
    if eeg_data is None:
        print('No EEG data available for synchronized plot')
        return None

    # Requested range in seconds and in samples
    start_time_sec = start_time_hours * 3600
    duration_sec = duration_hours * 3600
    end_time_sec = start_time_sec + duration_sec
    start_idx = int(start_time_sec * fs)
    end_idx = int(end_time_sec * fs)

    eeg_segment = np.asarray(eeg_data[start_idx:end_idx])
    if len(eeg_segment) == 0:
        print('No EEG samples in the requested time range')
        return None

    # Timestamp each sample from its index (1/fs apart). A linspace over the
    # requested range would stretch the trace whenever the range runs past
    # the end of the recording.
    eeg_time = start_time_hours + np.arange(len(eeg_segment)) / fs / 3600

    # Clustering windows lying entirely inside the requested range
    df = results['csv']
    in_range = (df['start_time_sec'] >= start_time_sec) & (df['end_time_sec'] <= end_time_sec)
    df_segment = df[in_range]

    # Hann-window STFT: 10 s segments, 75 % overlap. nperseg is capped at the
    # segment length so a short segment still yields one column.
    nperseg = min(int(10 * fs), len(eeg_segment))
    noverlap = int(nperseg * 0.75)
    frequencies, times_spec, Sxx = signal.spectrogram(
        eeg_segment, fs,
        window='hann',
        nperseg=nperseg,
        noverlap=noverlap,
    )
    times_spec_hours = start_time_hours + times_spec / 3600
    Sxx_db = 10 * np.log10(Sxx + 1e-12)

    fig = make_subplots(
        rows=3, cols=1,
        row_heights=[0.3, 0.4, 0.3],
        subplot_titles=(
            f'EEG Signal - {title_suffix}',
            f'Spectrogram - {title_suffix}',
            f'Hypnogram - {title_suffix}'
        ),
        shared_xaxes=True,
        vertical_spacing=0.08
    )

    # 1. EEG trace, scaled to the unit on the y-axis label
    fig.add_trace(
        go.Scatter(
            x=eeg_time,
            y=eeg_segment * eeg_scale,
            mode='lines',
            name='EEG Signal',
            line=dict(color='blue', width=0.5),
            showlegend=False
        ),
        row=1, col=1
    )

    # 2. Spectrogram
    fig.add_trace(
        go.Heatmap(
            x=times_spec_hours,
            y=frequencies,
            z=Sxx_db,
            colorscale='Viridis',
            colorbar=dict(
                title='Power (dB)',
                x=1.02,
                y=0.5,
                len=0.4
            ),
            hovertemplate='Time: %{x:.2f} hrs<br>Frequency: %{y:.1f} Hz<br>Power: %{z:.1f} dB<extra></extra>',
            showscale=True
        ),
        row=2, col=1
    )

    # 3. Hypnogram
    colors = px.colors.qualitative.Set1
    cluster_groups = list(df_segment.groupby('cluster_label', sort=True))

    # One step-line trace per cluster. A None after each window breaks the line.
    for cluster, rows in cluster_groups:
        x_vals, y_vals = [], []
        for start_sec, end_sec, label in zip(rows['start_time_sec'], rows['end_time_sec'],
                                             rows['cluster_label']):
            x_vals.extend([start_sec / 3600, end_sec / 3600, end_sec / 3600])
            y_vals.extend([label, label, None])

        fig.add_trace(
            go.Scatter(
                x=x_vals,
                y=y_vals,
                mode='lines',
                line=dict(color=colors[int(cluster) % len(colors)], width=3),
                name=f'Cluster {cluster}',
                hovertemplate='Time: %{x:.2f} hrs<br>Cluster: %{y}<extra></extra>',
                connectgaps=False
            ),
            row=3, col=1
        )

    # Shaded box behind every window. All boxes of a cluster go into one
    # trace separated by None: with fill='toself' Plotly closes each gap-
    # separated piece into its own shape, which avoids one trace per window.
    for cluster, rows in cluster_groups:
        level = int(cluster)
        box_x, box_y = [], []
        for start_sec, end_sec in zip(rows['start_time_sec'], rows['end_time_sec']):
            s, e = start_sec / 3600, end_sec / 3600
            box_x += [s, e, e, s, s, None]
            box_y += [level - 0.4, level - 0.4, level + 0.4, level + 0.4, level - 0.4, None]

        fig.add_trace(
            go.Scatter(
                x=box_x,
                y=box_y,
                mode='lines',
                fill='toself',
                fillcolor=colors[level % len(colors)],
                opacity=0.3,
                line=dict(width=0),
                showlegend=False,
                hoverinfo='skip'
            ),
            row=3, col=1
        )

    fig.update_layout(
        title=f'Synchronized EEG Analysis - {title_suffix}<br>Time Range: {start_time_hours:.1f} - {start_time_hours + duration_hours:.1f} hours',
        height=1000,
        hovermode='x unified',
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    fig.update_xaxes(
        title_text='Time (hours)',
        row=3, col=1,
        rangeslider=dict(visible=True, thickness=0.05),
        type='linear'
    )

    fig.update_yaxes(title_text='Amplitude (µV)', row=1, col=1)
    fig.update_yaxes(
        title_text='Frequency (Hz)',
        row=2, col=1,
        range=[0, 50]  # Focus on relevant frequency range
    )
    # Guard: max() of an empty selection would raise.
    top_label = df_segment['cluster_label'].max() if not df_segment.empty else 0
    fig.update_yaxes(
        title_text='Cluster Label',
        row=3, col=1,
        dtick=1,
        range=[-0.5, top_label + 0.5]
    )

    # Statistics annotation in the top-left corner
    total_windows = len(df_segment)
    cluster_counts = df_segment['cluster_label'].value_counts().sort_index()
    stats_text = f"Total windows: {total_windows}<br>"
    for cluster, count in cluster_counts.items():
        pct = (count / total_windows) * 100
        stats_text += f"Cluster {cluster}: {count} ({pct:.1f}%)<br>"

    fig.add_annotation(
        text=stats_text,
        xref="paper", yref="paper",
        x=0.02, y=0.98,
        showarrow=False,
        align="left",
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor="black",
        borderwidth=1
    )

    return fig

def interactive_segment_analyzer(eeg_data, fs, results_3s, results_30s, start_hours, duration_hours):
    """Print a usage guide and return plotting helpers bound to one recording.

    The returned closures forward the captured ``eeg_data``/``fs`` and the
    selected clustering results to ``create_interactive_synchronized_plot``.

    Parameters
    ----------
    eeg_data, fs
        EEG signal and sampling rate, passed through unchanged to the plot helper.
    results_3s, results_30s : dict
        Clustering results for 3 s and 30 s windows. ``compare_segments`` reads
        ``results[...]['csv']`` and its ``start_time_sec`` / ``end_time_sec`` columns.
    start_hours, duration_hours : float
        Default time range (hours) used by the returned functions when they are
        called without arguments.

    Returns
    -------
    tuple
        ``(plot_3s_segment, plot_30s_segment, compare_segments)``.
    """
    # Captured as defaults so the returned helpers can also be called with no arguments.
    default_start, default_duration = start_hours, duration_hours

    print("=== Interactive Synchronized EEG/Spectrogram/Hypnogram Analyzer ===")
    print("\nAvailable functions:")
    print("1. plot_3s_segment(start_hours, duration_hours) - Analyze 3-second results")
    print("2. plot_30s_segment(start_hours, duration_hours) - Analyze 30-second results")
    print("3. compare_segments(start_hours, duration_hours) - Compare both window sizes")
    print("\nExample usage:")
    print("  fig_3s = plot_3s_segment(0, 2)     # First 2 hours with 3s windows")
    print("  fig_30s = plot_30s_segment(4, 1)   # Hour 4-5 with 30s windows")
    print("  compare_segments(8, 4)             # Hours 8-12 comparison")

    def plot_3s_segment(start_hours=default_start, duration_hours=default_duration):
        """Plot synchronized analysis for 3-second windows"""
        return create_interactive_synchronized_plot(
            eeg_data, fs, results_3s, '3-second windows', start_hours, duration_hours
        )

    def plot_30s_segment(start_hours=default_start, duration_hours=default_duration):
        """Plot synchronized analysis for 30-second windows"""
        return create_interactive_synchronized_plot(
            eeg_data, fs, results_30s, '30-second windows', start_hours, duration_hours
        )

    def _windows_in_range(df, start_sec, end_sec):
        """Return rows of *df* whose window lies entirely within [start_sec, end_sec]."""
        return df[(df['start_time_sec'] >= start_sec) & (df['end_time_sec'] <= end_sec)]

    def compare_segments(start_hours=default_start, duration_hours=default_duration):
        """Compare both window sizes side by side"""
        fig_3s = plot_3s_segment(start_hours, duration_hours)
        fig_30s = plot_30s_segment(start_hours, duration_hours)

        # A comparison needs both figures; bail out if either plot was not produced.
        if fig_3s is None or fig_30s is None:
            return None, None

        print(f"\nComparison for time range: {start_hours:.1f} - {start_hours + duration_hours:.1f} hours")
        fig_3s.show()
        fig_30s.show()

        # Count how many windows of each size fall fully inside the range.
        start_sec = start_hours * 3600
        end_sec = (start_hours + duration_hours) * 3600
        df_3s_seg = _windows_in_range(results_3s['csv'], start_sec, end_sec)
        df_30s_seg = _windows_in_range(results_30s['csv'], start_sec, end_sec)

        print(f"\n3-second windows: {len(df_3s_seg)} windows")
        print(f"30-second windows: {len(df_30s_seg)} windows")
        # max(..., 1) avoids division by zero when no 30 s window falls in range.
        print(f"Resolution ratio: {len(df_3s_seg) / max(len(df_30s_seg), 1):.1f}:1")

        return fig_3s, fig_30s

    # Return the functions for interactive use
    return plot_3s_segment, plot_30s_segment, compare_segments


In [None]:
# Initialize the interactive analyzer with a default time window (in hours).
start_hours = 12
duration_hours = 1

eeg_data, fs, raw_eeg = load_eeg_data()
if eeg_data is not None:
    plot_3s_segment, plot_30s_segment, compare_segments = interactive_segment_analyzer(
        eeg_data, fs, results_3s, results_30s, start_hours, duration_hours
    )

    # Create default plots for the configured window (not necessarily the first hours).
    print(f"\nCreating default synchronized plots for hours "
          f"{start_hours:.1f} - {start_hours + duration_hours:.1f}...")
    default_fig_3s = plot_3s_segment(start_hours, duration_hours)
    default_fig_30s = plot_30s_segment(start_hours, duration_hours)

    # The plot helper's result is checked for None before display.
    if default_fig_3s is not None:
        default_fig_3s.show()
    if default_fig_30s is not None:
        default_fig_30s.show()

else:
    print("EEG data not available - cannot create synchronized plots")

## 6. Advanced Statistical Analysis

In [None]:
def _run_lengths(labels):
    """Return the lengths of maximal runs of identical consecutive values in *labels*."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return []
    # Positions where a new run starts (value differs from its predecessor).
    starts = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    edges = np.concatenate(([0], starts, [labels.size]))
    return np.diff(edges).tolist()


def _transitions(labels):
    """Return ``(previous, next)`` label pairs at every position where the label changes."""
    labels = np.asarray(labels)
    change = np.flatnonzero(labels[1:] != labels[:-1]) + 1
    return list(zip(labels[change - 1].tolist(), labels[change].tolist()))


def _sliding_cluster_proportions(df, clusters, window_size):
    """Return (mid-window times in hours, {cluster: proportions}) over half-overlapping windows."""
    step = max(1, window_size // 2)
    time_points = []
    props = {c: [] for c in clusters}
    # Upper bound is len(df) + 1 so the last window ending exactly at len(df) is included.
    for end in range(window_size, len(df) + 1, step):
        window = df.iloc[end - window_size:end]
        time_points.append(window['start_time_sec'].mean() / 3600)
        window_labels = window['cluster_label']
        for c in clusters:
            props[c].append((window_labels == c).mean())
    return time_points, props


def _label_autocorrelation(sequence, max_lag):
    """Return Pearson correlations of *sequence* with itself for lags 0..max_lag-1.

    NaN correlations (e.g. a constant lagged slice) are reported as 0.0.
    NOTE(review): cluster labels are nominal, so correlating their integer
    codes is only a rough persistence indicator.
    """
    seq = np.asarray(sequence)
    autocorr = [1.0] if max_lag > 0 else []
    for lag in range(1, max_lag):
        with np.errstate(invalid='ignore', divide='ignore'):
            r = np.corrcoef(seq[:-lag], seq[lag:])[0, 1]
        autocorr.append(0.0 if np.isnan(r) else float(r))
    return autocorr


def advanced_cluster_statistics(results, title_suffix):
    """Print and plot temporal statistics of a cluster-label sequence.

    Parameters
    ----------
    results : dict
        Must contain ``'csv'``: a DataFrame with ``start_time_sec`` and
        ``cluster_label`` columns, one row per window. Consecutive rows are
        treated as adjacent windows (assumes time-sorted input).
    title_suffix : str
        Label used in printed headers and the figure title.

    Returns
    -------
    dict
        ``hourly_distribution`` : window counts per (hour, cluster).
        ``stability_scores`` : lengths of runs of identical consecutive labels.
        ``transitions`` : ``(from, to)`` label pairs at each label change.
        ``autocorr`` : lagged label correlations, lag 0 first.
    """
    df = results['csv'].copy()

    print(f'\n=== Advanced Statistical Analysis ({title_suffix}) ===')

    # Nothing to analyse; previously this crashed with IndexError at iloc[0].
    if df.empty:
        print('   No windows to analyse.')
        return {
            'hourly_distribution': pd.DataFrame(),
            'stability_scores': [],
            'transitions': [],
            'autocorr': []
        }

    labels = df['cluster_label'].to_numpy()
    clusters = sorted(df['cluster_label'].unique())

    # 1. Temporal distribution: window counts per (recording hour, cluster).
    df['hour'] = (df['start_time_sec'] / 3600).astype(int)
    hourly_distribution = df.groupby(['hour', 'cluster_label']).size().unstack(fill_value=0)

    print('\n1. Hourly Distribution of Clusters:')
    print(hourly_distribution)

    # 2. Cluster stability: lengths of runs of consecutive same-cluster windows.
    stability_scores = _run_lengths(labels)
    mean_run = np.mean(stability_scores)

    print('\n2. Cluster Stability Analysis:')
    print(f'   Mean consecutive windows: {mean_run:.2f}')
    print(f'   Median consecutive windows: {np.median(stability_scores):.2f}')
    print(f'   Max consecutive windows: {np.max(stability_scores)}')
    print(f'   Number of segments: {len(stability_scores)}')

    # 3. Peak hour per cluster, only once the recording reaches hour index 4.
    if df['hour'].max() >= 4:
        print('\n3. Circadian Pattern Analysis:')
        for cluster in clusters:
            cluster_hours = df.loc[df['cluster_label'] == cluster, 'hour']
            # Series.mode() returns all modes sorted ascending, so ties resolve to
            # the earliest hour (same tie-break as scipy.stats.mode).
            peak_hour = cluster_hours.mode().iloc[0]
            print(f'   Cluster {cluster}: Peak activity at hour {peak_hour}')

    # 4. Transitions between consecutive windows with different clusters.
    transitions = _transitions(labels)
    if transitions:
        transition_counts = pd.Series(transitions).value_counts()
        print('\n4. Most Common Transitions:')
        print(transition_counts.head(10))

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Advanced Statistical Analysis - {title_suffix}', fontsize=16, fontweight='bold')

    # Hourly distribution heatmap
    if not hourly_distribution.empty:
        sns.heatmap(hourly_distribution.T, annot=True, fmt='d', cmap='YlOrRd', ax=axes[0, 0])
        axes[0, 0].set_title('Hourly Cluster Distribution')
        axes[0, 0].set_xlabel('Hour of Recording')
        axes[0, 0].set_ylabel('Cluster Label')

    # Stability distribution
    axes[0, 1].hist(stability_scores, bins=min(30, len(set(stability_scores))), alpha=0.7, edgecolor='black')
    axes[0, 1].set_title('Distribution of Consecutive Window Lengths')
    axes[0, 1].set_xlabel('Consecutive Windows')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].axvline(mean_run, color='red', linestyle='--', label=f'Mean: {mean_run:.1f}')
    axes[0, 1].legend()

    # Cluster proportion over time (half-overlapping sliding window, adaptive size)
    window_size = max(100, len(df) // 20)
    time_points, cluster_props = _sliding_cluster_proportions(df, clusters, window_size)

    for cluster, props in cluster_props.items():
        if props:
            axes[1, 0].plot(time_points, props, marker='o', label=f'Cluster {cluster}', alpha=0.7)

    axes[1, 0].set_title('Cluster Proportions Over Time')
    axes[1, 0].set_xlabel('Time (hours)')
    axes[1, 0].set_ylabel('Proportion')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Autocorrelation of cluster sequence
    max_lag = min(100, len(labels) // 4)
    autocorr = _label_autocorrelation(labels, max_lag)

    axes[1, 1].plot(range(len(autocorr)), autocorr, marker='o', markersize=3)
    axes[1, 1].set_title('Autocorrelation of Cluster Sequence')
    axes[1, 1].set_xlabel('Lag (windows)')
    axes[1, 1].set_ylabel('Autocorrelation')
    axes[1, 1].grid(True, alpha=0.3)
    axes[1, 1].axhline(y=0, color='red', linestyle='--', alpha=0.5)

    # Show now so each figure appears next to its printed statistics.
    fig.tight_layout()
    plt.show()

    return {
        'hourly_distribution': hourly_distribution,
        'stability_scores': stability_scores,
        'transitions': transitions,
        'autocorr': autocorr
    }

# Run the advanced statistics once per window size; results are kept for later cells.
stats_3s = advanced_cluster_statistics(results=results_3s, title_suffix='3-second windows')
stats_30s = advanced_cluster_statistics(results=results_30s, title_suffix='30-second windows')