# BiLSTM Sleep Stage Clustering Analysis

This notebook provides a comprehensive analysis of BiLSTM-based sleep stage clustering results including:
- Interactive hypnogram visualization
- EEG signal segment analysis
- Cluster distribution and duration analysis
- Multitaper spectrogram analysis
- Frequency band dominance analysis

**Data Sources:**
- Model results: `results/bilstm_30s_4clusters.pkl`
- Metadata: `results/bilstm_30s_4clusters_metadata.json`
- EEG signal: `EEG_0_per_hour_2024-03-20 17_12_18.edf`

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import pickle
import json
import mne
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.colors as colors
from scipy import signal
from scipy.stats import mode
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.preprocessing import StandardScaler
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# including numerical warnings raised by later cells.
warnings.filterwarnings('ignore')

# Set plotting parameters
# Matplotlib style sheet for any static figures.
plt.style.use('seaborn-v0_8')
# Default template for Plotly Express figures.
# NOTE(review): the figures below are built with graph_objects/make_subplots,
# so this presumably does not restyle them -- confirm.
px.defaults.template = 'plotly_white'

In [None]:
# Load metadata
# Later cells read 'window_size_seconds' and 'overlap' from this dictionary.
with open('results/bilstm_30s_4clusters_metadata.json', 'r') as f:
    metadata = json.load(f)

print("Model Metadata:")
for key, value in metadata.items():
    print(f"{key}: {value}")

# Load the pickled BiLSTM clustering results.
# SECURITY: unpickling can execute arbitrary code -- only open files from a trusted source.
with open('results/bilstm_30s_4clusters.pkl', 'rb') as f:
    results = pickle.load(f, encoding='latin1')  # 'latin1' only controls how Python 2 byte strings are decoded; it does not affect CPU/GPU placement

print("\nLoaded results keys:")
# Assumes `results` is dict-like -- TODO confirm against the script that wrote the pickle.
for key in results.keys():
    print(f"- {key}: {type(results[key])}")
    # Only array-like values expose a shape.
    if hasattr(results[key], 'shape'):
        print(f"  Shape: {results[key].shape}")

In [None]:
# Load EEG data
# Each attempt pairs a candidate path with the message printed when that path is
# missing. Attempts run in order; the first file that loads wins.
_edf_attempts = [
    ('by captain borat/raw/EEG_0_per_hour_2024-03-20 17_12_18.edf',
     "Original file not found. Attempting to find in current directory..."),
    ('EEG_0_per_hour_2024-03-20 17_12_18.edf',
     "EDF file not found. Creating simulated data for demonstration..."),
]

for eeg_file, _missing_msg in _edf_attempts:
    try:
        raw = mne.io.read_raw_edf(eeg_file, preload=True, verbose=False)
    except FileNotFoundError:
        print(_missing_msg)
    else:
        break
else:
    # No recording could be read: build a synthetic single-channel signal so the
    # rest of the notebook can still run. Results on this data are illustrative only.
    duration = 86400  # 24 hours in seconds
    fs = 512  # hard-coded rate; NOTE(review): presumably matches the metadata -- confirm
    n_samples = int(duration * fs)
    time_vector = np.arange(n_samples) / fs

    # Seeded white noise scaled by 50 as the broadband background.
    np.random.seed(42)
    eeg_data = np.random.randn(n_samples) * 50

    # Superimpose sinusoids whose amplitude shrinks as frequency grows.
    for freq in [0.5, 3, 8, 12, 20]:
        eeg_data += np.sin(2 * np.pi * freq * time_vector) * (100 / (freq + 5))

    # Wrap the samples in an MNE Raw object with one EEG channel.
    info = mne.create_info(['EEG'], sfreq=fs, ch_types=['eeg'])
    raw = mne.io.RawArray(eeg_data.reshape(1, -1), info)
    print("Created simulated EEG data for demonstration")

# Report the basic properties of whichever recording was obtained.
print(f"EEG Data Info:")
print(f"Sampling frequency: {raw.info['sfreq']} Hz")
print(f"Number of channels: {len(raw.ch_names)}")
print(f"Channel names: {raw.ch_names}")
print(f"Duration: {raw.times[-1]:.2f} seconds ({raw.times[-1]/3600:.2f} hours)")

# Work with the first channel only (assumed to be the EEG channel) and rebuild
# the sample-time axis from the recording's own sampling rate.
fs = raw.info['sfreq']
eeg_data = raw.get_data()[0]
time_vector = np.arange(len(eeg_data)) / fs

In [None]:
# Extract clustering results
# Keys under which the label vector is commonly stored, in priority order.
possible_keys = ['cluster_labels', 'labels', 'predictions', 'y_pred', 'clusters']
cluster_labels = None

print("Searching for cluster labels...")

# 1) Top level of the results dictionary.
matched_key = next((name for name in possible_keys if name in results), None)
if matched_key is not None:
    cluster_labels = results[matched_key]
    print(f"Found cluster labels in '{matched_key}'")

# 2) One level down, under results['results'].
if cluster_labels is None and 'results' in results:
    nested_results = results['results']
    matched_key = next((name for name in possible_keys if name in nested_results), None)
    if matched_key is not None:
        cluster_labels = nested_results[matched_key]
        print(f"Found cluster labels in 'results.{matched_key}'")

# 3) Last resort: list everything and adopt the first 1-D array-like value.
if cluster_labels is None:
    print("Cluster labels not found in expected keys. Available keys:")
    for key, value in results.items():
        print(f"  {key}: {type(value)}")
        if not hasattr(value, 'shape'):
            continue
        print(f"    Shape: {value.shape}")
        if cluster_labels is None and len(value.shape) == 1:
            cluster_labels = value
            print(f"    Using '{key}' as cluster labels (first 1D array found)")

if cluster_labels is None:
    raise ValueError("Could not find cluster labels in the results. Please check the data structure.")

# Normalise to a NumPy array so boolean masking works in later cells.
cluster_labels = np.array(cluster_labels)

# Window timing: consecutive windows start `step_size` seconds apart.
window_size = metadata['window_size_seconds']
overlap = metadata['overlap']
step_size = window_size * (1 - overlap)

# Start time (seconds) of every window; n_clusters is the number of windows.
n_clusters = len(cluster_labels)
cluster_times = step_size * np.arange(n_clusters, dtype=np.float32)

print(f"✓ Successfully extracted clustering results:")
print(f"  Number of cluster windows: {n_clusters:,}")
print(f"  Window size: {window_size}s")
print(f"  Step size: {step_size}s")
covered_seconds = cluster_times[-1] + window_size
print(f"  Total duration covered: {covered_seconds:.2f}s ({covered_seconds/3600:.2f}h)")
unique_ids, id_counts = np.unique(cluster_labels, return_counts=True)
print(f"  Unique clusters: {unique_ids}")
print(f"  Cluster distribution: {dict(zip(unique_ids, id_counts))}")

In [None]:
# Create interactive hypnogram - OPTIMIZED VERSION
def create_hypnogram_optimized(cluster_labels, cluster_times, window_size, max_points=5000):
    """Create an interactive Plotly hypnogram, subsampled for large datasets.

    Parameters
    ----------
    cluster_labels : array-like
        Cluster id of every analysis window.
    cluster_times : array-like
        Start time of every window, in seconds (same length as ``cluster_labels``).
    window_size : float
        Window length in seconds. Not used by the plot; kept so existing
        callers keep working.
    max_points : int, default 5000
        If there are more windows than this, ``max_points`` evenly spaced
        windows are plotted instead of all of them.

    Returns
    -------
    plotly.graph_objects.Figure
        One marker trace per cluster plus a faint line joining consecutive
        windows. The x axis is in hours.
    """

    # Define colors for different clusters
    cluster_colors = {
        0: '#1f77b4',  # Blue
        1: '#ff7f0e',  # Orange
        2: '#2ca02c',  # Green
        3: '#d62728',  # Red
        4: '#9467bd',  # Purple
        5: '#8c564b',  # Brown
    }

    # Accept plain lists as well as arrays; boolean masking below needs arrays.
    cluster_labels = np.asarray(cluster_labels)
    cluster_times = np.asarray(cluster_times)

    # Sample data if too large
    if len(cluster_labels) > max_points:
        print(f"Sampling {max_points} points from {len(cluster_labels)} for visualization performance...")
        sample_indices = np.linspace(0, len(cluster_labels)-1, max_points, dtype=int)
        cluster_labels_plot = cluster_labels[sample_indices]
        cluster_times_plot = cluster_times[sample_indices]
    else:
        cluster_labels_plot = cluster_labels
        cluster_times_plot = cluster_times

    # Convert to hours once, up front. The previous implementation converted at
    # the end with fig.update_traces(x=...), which replaced the x data of EVERY
    # trace with the full time array and so misplaced the per-cluster markers.
    cluster_hours_plot = cluster_times_plot / 3600

    # Create figure
    fig = go.Figure()

    # One marker trace per cluster (cheaper than one rectangle per window)
    for cluster_id in np.unique(cluster_labels_plot):
        cluster_mask = cluster_labels_plot == cluster_id
        hours_subset = cluster_hours_plot[cluster_mask]
        labels_subset = cluster_labels_plot[cluster_mask]

        fig.add_trace(go.Scatter(
            x=hours_subset,
            y=labels_subset,
            mode='markers',
            marker=dict(
                size=8,
                color=cluster_colors.get(cluster_id, '#17becf'),  # fallback for ids > 5
                symbol='square',
                opacity=0.8
            ),
            name=f'Cluster {cluster_id}',
            text=[f'Time: {h:.2f}h<br>Cluster: {c}' for h, c in zip(hours_subset, labels_subset)],
            hovertemplate='%{text}<extra></extra>'
        ))

    # Add connecting lines for better visualization
    fig.add_trace(go.Scatter(
        x=cluster_hours_plot,
        y=cluster_labels_plot,
        mode='lines',
        line=dict(width=2, color='rgba(100,100,100,0.3)'),
        name='Transitions',
        showlegend=False,
        hoverinfo='skip'
    ))

    # Update layout
    fig.update_layout(
        title='Interactive Hypnogram - BiLSTM Clustering Results (Optimized)',
        xaxis_title='Time (hours)',
        yaxis_title='Cluster ID',
        height=400,
        yaxis=dict(tickmode='linear', tick0=0, dtick=1),
        hovermode='closest',
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='center', x=0.5)
    )

    # One decimal place on the hour ticks
    fig.update_xaxes(tickformat='.1f')

    return fig

# Create and display optimized hypnogram
# max_points is left at the function default, so long recordings are subsampled before plotting.
print("Creating optimized hypnogram...")
hypnogram_fig = create_hypnogram_optimized(cluster_labels, cluster_times, window_size)
hypnogram_fig.show()

In [None]:
# Interactive EEG segment visualization
def plot_eeg_segment(start_sec, end_sec, eeg_data, time_vector, cluster_labels, cluster_times, window_size):
    """Plot a specific segment of EEG data with cluster annotations.

    Parameters
    ----------
    start_sec, end_sec : float
        Segment boundaries in seconds, on the same clock as ``time_vector``.
    eeg_data : array-like
        Signal samples.
    time_vector : array-like
        Time in seconds of every sample in ``eeg_data``. Must be sorted in
        ascending order (it is looked up with ``np.searchsorted``).
    cluster_labels : array-like
        Cluster id of every analysis window.
    cluster_times : array-like
        Start time in seconds of every analysis window.
    window_size : float
        Window length in seconds.

    Returns
    -------
    plotly.graph_objects.Figure
        Two stacked panels sharing the x axis: the signal with each window
        shaded by its cluster colour, and a bar per window at the height of
        its cluster id.
    """

    # Locate the segment on the supplied time axis. The earlier version used the
    # notebook-level `fs` global, which made the function depend on hidden state.
    start_idx, end_idx = np.searchsorted(time_vector, [start_sec, end_sec], side='left')

    segment_data = eeg_data[start_idx:end_idx]
    segment_time = time_vector[start_idx:end_idx]

    # Keep only windows that genuinely overlap the segment. A window that ends
    # exactly at start_sec, or starts exactly at end_sec, does not overlap it.
    cluster_labels = np.asarray(cluster_labels)
    cluster_times = np.asarray(cluster_times)
    cluster_mask = (cluster_times + window_size > start_sec) & (cluster_times < end_sec)
    relevant_clusters = cluster_labels[cluster_mask]
    relevant_times = cluster_times[cluster_mask]

    # Create subplot
    fig = make_subplots(rows=2, cols=1,
                       shared_xaxes=True,
                       subplot_titles=['EEG Signal', 'Cluster Labels'],
                       vertical_spacing=0.1,
                       row_heights=[0.7, 0.3])

    # Plot EEG signal
    fig.add_trace(go.Scatter(
        x=segment_time,
        y=segment_data,
        mode='lines',
        name='EEG',
        line=dict(width=1, color='blue')
    ), row=1, col=1)

    # Plot cluster annotations
    colors_map = {0: 'blue', 1: 'orange', 2: 'green', 3: 'red'}
    for time, cluster in zip(relevant_times, relevant_clusters):
        color = colors_map.get(cluster, 'gray')

        # Clip each window to the requested segment so the shading never
        # stretches the x axis beyond [start_sec, end_sec].
        x_left = max(float(time), start_sec)
        x_right = min(float(time) + window_size, end_sec)

        fig.add_vrect(
            x0=x_left,
            x1=x_right,
            fillcolor=color,
            opacity=0.3,
            layer="below",
            line_width=0,
            row=1, col=1
        )

        # Add cluster label bar
        fig.add_shape(
            type="rect",
            x0=x_left,
            x1=x_right,
            y0=cluster - 0.4,
            y1=cluster + 0.4,
            fillcolor=color,
            line=dict(width=1, color=color),
            opacity=0.7,
            row=2, col=1
        )

    # Update layout
    fig.update_layout(
        title=f'EEG Segment Analysis ({start_sec}s - {end_sec}s)',
        height=600,
        showlegend=False
    )

    fig.update_xaxes(title_text='Time (seconds)', row=2, col=1)
    fig.update_yaxes(title_text='Amplitude (μV)', row=1, col=1)
    fig.update_yaxes(title_text='Cluster', row=2, col=1)

    return fig

# Example segment (first 5 minutes)
start_time = 0  # segment start, in seconds
end_time = 300  # 5 minutes
# Signal, time axis, labels and window timing all come from earlier cells.
segment_fig = plot_eeg_segment(start_time, end_time, eeg_data, time_vector, 
                              cluster_labels, cluster_times, window_size)
segment_fig.show()

In [None]:
# Cluster Distribution Analysis
def analyze_cluster_distribution(cluster_labels, window_size=None):
    """Print and tabulate how many windows fall into each cluster.

    Parameters
    ----------
    cluster_labels : array-like
        Cluster id of every analysis window.
    window_size : float, optional
        Window length in seconds, used to turn window counts into hours.
        When omitted, the notebook-level ``window_size`` variable is used so
        that existing one-argument calls keep working.

    Returns
    -------
    pandas.DataFrame
        One row per cluster with columns ``cluster``, ``count``,
        ``proportion`` (percent of all windows) and ``duration_hours``.

    Raises
    ------
    NameError
        If ``window_size`` is not given and no notebook-level value exists.

    Notes
    -----
    ``duration_hours`` is ``count * window_size``, i.e. it assumes the windows
    do not overlap. With overlapping windows it overstates the true duration.
    """

    if window_size is None:
        # Backward compatibility: the original function read this global directly.
        if 'window_size' not in globals():
            raise NameError("name 'window_size' is not defined")
        window_size = globals()['window_size']

    # Basic statistics
    unique_clusters, counts = np.unique(cluster_labels, return_counts=True)
    total_windows = len(cluster_labels)

    print("=== Cluster Distribution Analysis ===")
    print(f"Total number of windows: {total_windows}")
    print(f"Number of unique clusters: {len(unique_clusters)}")
    print("\nCluster Counts and Proportions:")

    cluster_stats = []
    for cluster, count in zip(unique_clusters, counts):
        proportion = count / total_windows * 100
        duration_hours = count * window_size / 3600
        print(f"Cluster {cluster}: {count} windows ({proportion:.1f}%) - {duration_hours:.2f} hours")
        cluster_stats.append({
            'cluster': cluster,
            'count': count,
            'proportion': proportion,
            'duration_hours': duration_hours
        })

    return pd.DataFrame(cluster_stats)

# Per-cluster counts, proportions and durations; reused by the plots and the summary below.
cluster_dist_df = analyze_cluster_distribution(cluster_labels)

In [None]:
# Create cluster distribution visualizations
# 2x2 dashboard: counts, proportions (pie), durations, and the label timeline.
_panel_titles = ['Cluster Counts', 'Cluster Proportions (%)',
                 'Duration (Hours)', 'Cluster Timeline']
_panel_specs = [[{'type': 'bar'}, {'type': 'pie'}],
                [{'type': 'bar'}, {'type': 'scatter'}]]
fig = make_subplots(rows=2, cols=2, subplot_titles=_panel_titles, specs=_panel_specs)

_cluster_ids = cluster_dist_df['cluster']

# Top-left: number of windows per cluster.
_count_bars = go.Bar(
    x=_cluster_ids,
    y=cluster_dist_df['count'],
    name='Count',
    marker_color='lightblue'
)
fig.add_trace(_count_bars, row=1, col=1)

# Top-right: share of windows per cluster.
_proportion_pie = go.Pie(
    labels=[f'Cluster {c}' for c in _cluster_ids],
    values=cluster_dist_df['proportion'],
    name='Proportion'
)
fig.add_trace(_proportion_pie, row=1, col=2)

# Bottom-left: total time per cluster.
_duration_bars = go.Bar(
    x=_cluster_ids,
    y=cluster_dist_df['duration_hours'],
    name='Duration',
    marker_color='lightgreen'
)
fig.add_trace(_duration_bars, row=2, col=1)

# Bottom-right: every window's label against its start time in hours.
_timeline = go.Scatter(
    x=cluster_times/3600,
    y=cluster_labels,
    mode='markers',
    marker=dict(size=2, opacity=0.6),
    name='Timeline'
)
fig.add_trace(_timeline, row=2, col=2)

fig.update_layout(height=600, title_text="Cluster Distribution Analysis")

# Axis titles for the three cartesian panels (the pie has no axes).
_axis_titles = {
    (1, 1): ('Cluster ID', 'Count'),
    (2, 1): ('Cluster ID', 'Duration (Hours)'),
    (2, 2): ('Time (Hours)', 'Cluster ID'),
}
for (_row, _col), (_x_title, _y_title) in _axis_titles.items():
    fig.update_xaxes(title_text=_x_title, row=_row, col=_col)
    fig.update_yaxes(title_text=_y_title, row=_row, col=_col)

fig.show()

In [None]:
# Continuous Cluster Duration Analysis
def analyze_continuous_durations(cluster_labels, window_size):
    """Measure how long each uninterrupted run of a cluster lasts.

    A run is a maximal stretch of consecutive windows carrying the same label.
    Its duration is ``run_length * window_size``.

    Parameters
    ----------
    cluster_labels : array-like
        Cluster id of every analysis window, in time order.
    window_size : float
        Window length in seconds.

    Returns
    -------
    dict
        Maps each distinct label to the list of its run durations in seconds,
        in order of occurrence. Empty input gives an empty dict (the previous
        implementation raised ``IndexError``).
    """
    from itertools import groupby

    cluster_labels = np.asarray(cluster_labels)

    continuous_durations = {cluster: [] for cluster in np.unique(cluster_labels)}

    # groupby yields one (label, run) pair per stretch of equal consecutive labels.
    for cluster, run in groupby(cluster_labels):
        run_length = sum(1 for _ in run)
        continuous_durations[cluster].append(run_length * window_size)

    return continuous_durations

# Run lengths (in seconds) of every uninterrupted stretch, keyed by cluster.
continuous_durations = analyze_continuous_durations(cluster_labels, window_size)

# Summarise the run durations cluster by cluster.
print("=== Continuous Duration Analysis ===")
duration_stats = []

for cluster in sorted(continuous_durations):
    durations = continuous_durations[cluster]
    if not durations:
        # Nothing to summarise for a cluster without runs.
        continue

    count = len(durations)
    mean_dur = np.mean(durations)
    median_dur = np.median(durations)
    max_dur = np.max(durations)
    min_dur = np.min(durations)

    # Human-readable report: seconds, with minutes in brackets.
    print(f"\nCluster {cluster}:")
    print(f"  Number of continuous segments: {count}")
    print(f"  Mean duration: {mean_dur:.1f}s ({mean_dur/60:.1f}min)")
    print(f"  Median duration: {median_dur:.1f}s ({median_dur/60:.1f}min)")
    print(f"  Max duration: {max_dur:.1f}s ({max_dur/60:.1f}min)")
    print(f"  Min duration: {min_dur:.1f}s ({min_dur/60:.1f}min)")

    # Same numbers, in seconds, for the table used by later cells.
    duration_stats.append(dict(
        cluster=cluster,
        segment_count=count,
        mean_duration=mean_dur,
        median_duration=median_dur,
        max_duration=max_dur,
        min_duration=min_dur,
    ))

duration_stats_df = pd.DataFrame(duration_stats)

In [None]:
# Visualize continuous duration distributions
# 2x2 dashboard built from duration_stats_df (per-cluster summary) and
# continuous_durations (raw run lengths in seconds); all panels show minutes.
fig = make_subplots(rows=2, cols=2,
                   subplot_titles=['Duration Statistics', 'Duration Distributions',
                                  'Segment Counts', 'Duration Boxplots'])

# Statistics bar plot
fig.add_trace(go.Bar(
    x=duration_stats_df['cluster'],
    y=duration_stats_df['mean_duration']/60,  # Convert to minutes
    name='Mean Duration (min)',
    marker_color='lightcoral'
), row=1, col=1)

# Histogram of all durations
# Flatten every run into one long-format table (duration + cluster name).
all_durations = []
cluster_labels_for_hist = []
for cluster, durations in continuous_durations.items():
    all_durations.extend(durations)
    cluster_labels_for_hist.extend([f'Cluster {cluster}'] * len(durations))

# Create histogram using plotly express approach
# NOTE(review): duration_df is not used by any trace in this cell -- confirm
# whether a later cell needs it before removing.
duration_df = pd.DataFrame({
    'duration_min': np.array(all_durations) / 60,
    'cluster': cluster_labels_for_hist
})

# Add histograms for each cluster
# Colours cycle by position in the sorted cluster list; a shared bingroup
# gives all histograms the same bins.
colors_hist = ['blue', 'orange', 'green', 'red', 'purple', 'brown']
for i, cluster in enumerate(sorted(continuous_durations.keys())):
    cluster_durations = np.array(continuous_durations[cluster]) / 60
    fig.add_trace(go.Histogram(
        x=cluster_durations,
        name=f'Cluster {cluster}',
        opacity=0.6,
        marker_color=colors_hist[i % len(colors_hist)],
        bingroup=1
    ), row=1, col=2)

# Segment counts
fig.add_trace(go.Bar(
    x=duration_stats_df['cluster'],
    y=duration_stats_df['segment_count'],
    name='Segment Count',
    marker_color='lightgreen'
), row=2, col=1)

# Boxplots for duration distribution
# Same colour assignment as the histograms above.
for i, cluster in enumerate(sorted(continuous_durations.keys())):
    cluster_durations = np.array(continuous_durations[cluster]) / 60
    fig.add_trace(go.Box(
        y=cluster_durations,
        name=f'Cluster {cluster}',
        marker_color=colors_hist[i % len(colors_hist)]
    ), row=2, col=2)

# Figure title and per-panel axis titles.
fig.update_layout(height=600, title_text="Continuous Duration Analysis")
fig.update_xaxes(title_text='Cluster ID', row=1, col=1)
fig.update_yaxes(title_text='Mean Duration (min)', row=1, col=1)
fig.update_xaxes(title_text='Duration (min)', row=1, col=2)
fig.update_yaxes(title_text='Frequency', row=1, col=2)
fig.update_xaxes(title_text='Cluster ID', row=2, col=1)
fig.update_yaxes(title_text='Segment Count', row=2, col=1)
fig.update_xaxes(title_text='Cluster', row=2, col=2)
fig.update_yaxes(title_text='Duration (min)', row=2, col=2)

fig.show()

In [None]:
# Multitaper Spectrogram Analysis
def compute_multitaper_spectrogram(data, fs, window_length=30, overlap=0.5, bandwidth=4):
    """Compute a multitaper spectrogram of a 1-D signal using DPSS tapers.

    The signal is cut into overlapping segments. Each segment is windowed with
    several orthogonal DPSS (Slepian) tapers, and the resulting power spectra
    are averaged across tapers. The previous implementation imported ``dpss``
    and accepted ``bandwidth`` but used neither, returning an ordinary
    single-window spectrogram.

    Parameters
    ----------
    data : array-like
        1-D signal.
    fs : float
        Sampling frequency in Hz.
    window_length : float, default 30
        Segment length in seconds. Clamped to the length of ``data``.
    overlap : float, default 0.5
        Fraction of each segment shared with the next one (must be < 1).
    bandwidth : float, default 4
        Interpreted as the time-half-bandwidth product NW of the tapers;
        ``2 * NW - 1`` tapers are averaged.
        NOTE(review): the original code never used this argument, so the
        intended unit (NW vs. Hz) could not be confirmed.

    Returns
    -------
    f : ndarray
        Frequencies in Hz.
    t : ndarray
        Segment times in seconds, as reported by ``scipy.signal.spectrogram``.
    Sxx : ndarray
        Taper-averaged power, shape ``(len(f), len(t))``.

    Raises
    ------
    ValueError
        Propagated from SciPy, e.g. when NW is not smaller than half the
        segment length in samples or when ``overlap`` leaves no hop.
    """

    from scipy.signal import spectrogram
    from scipy.signal.windows import dpss

    data = np.asarray(data)

    # Parameters (segment length clamped so short inputs still yield one segment)
    nperseg = min(int(window_length * fs), data.shape[-1])
    noverlap = int(nperseg * overlap)

    # The first 2*NW - 1 tapers are the well-concentrated ones.
    n_tapers = max(1, int(2 * bandwidth) - 1)
    tapers = dpss(nperseg, bandwidth, Kmax=n_tapers)  # shape (n_tapers, nperseg)

    # One spectrogram per taper, accumulated and then averaged.
    Sxx = None
    for taper in tapers:
        f, t, taper_power = spectrogram(data, fs, window=taper,
                                        nperseg=nperseg, noverlap=noverlap)
        Sxx = taper_power if Sxx is None else Sxx + taper_power
    Sxx = Sxx / n_tapers

    return f, t, Sxx

# Compute spectrogram for a segment of data
# Only the first hour at most is analysed; len(eeg_data)/fs is the recording length in seconds.
segment_duration = min(3600, len(eeg_data)/fs)  # 1 hour or full data if shorter
segment_samples = int(segment_duration * fs)
data_segment = eeg_data[:segment_samples]

print(f"Computing spectrogram for {segment_duration/60:.1f} minutes of data...")
# Called with the function defaults (window_length=30, overlap=0.5, bandwidth=4).
f, t_spec, Sxx = compute_multitaper_spectrogram(data_segment, fs)

# Convert to dB
# The 1e-12 offset keeps log10 finite where the power is exactly zero.
Sxx_db = 10 * np.log10(Sxx + 1e-12)

print(f"Spectrogram shape: {Sxx_db.shape}")
print(f"Frequency range: {f[0]:.1f} - {f[-1]:.1f} Hz")
print(f"Time range: {t_spec[0]:.1f} - {t_spec[-1]:.1f} seconds")

In [None]:
# Combined Hypnogram and Spectrogram Visualization
def create_combined_visualization(cluster_labels, cluster_times, window_size, 
                                f, t_spec, Sxx_db, max_time=3600):
    """Stack a hypnogram above a spectrogram on a shared time axis.

    Only spectrogram columns and cluster windows whose time is at most
    ``max_time`` seconds are drawn. The top panel shows one rectangle per
    window at the height of its cluster id. The bottom panel is a heatmap of
    ``Sxx_db`` restricted to the 0-50 Hz range.
    """

    # Spectrogram columns inside the requested time span.
    spec_keep = t_spec <= max_time
    spec_times = t_spec[spec_keep]
    spec_power = Sxx_db[:, spec_keep]

    # Cluster windows starting inside the requested time span.
    window_keep = cluster_times <= max_time
    window_starts = cluster_times[window_keep]
    window_labels = cluster_labels[window_keep]

    # Two rows, hypnogram on top (30 %) and spectrogram below (70 %).
    fig = make_subplots(rows=2, cols=1,
                       shared_xaxes=True,
                       subplot_titles=['Hypnogram', 'Multitaper Spectrogram'],
                       vertical_spacing=0.1,
                       row_heights=[0.3, 0.7])

    # Hypnogram: one coloured bar per window; unknown ids fall back to gray.
    palette = {0: 'blue', 1: 'orange', 2: 'green', 3: 'red'}
    for window_start, label in zip(window_starts, window_labels):
        shade = palette.get(label, 'gray')
        fig.add_shape(
            type="rect",
            x0=window_start,
            x1=window_start + window_size,
            y0=label - 0.4,
            y1=label + 0.4,
            fillcolor=shade,
            line=dict(width=1, color=shade),
            opacity=0.7,
            row=1, col=1
        )

    # Spectrogram heatmap with its colour bar just right of the plot area.
    spectrogram_trace = go.Heatmap(
        x=spec_times,
        y=f,
        z=spec_power,
        colorscale='Viridis',
        colorbar=dict(title='Power (dB)', x=1.02),
        name='Spectrogram'
    )
    fig.add_trace(spectrogram_trace, row=2, col=1)

    # Titles and overall size.
    fig.update_layout(
        title='Combined Hypnogram and Multitaper Spectrogram',
        height=700,
        xaxis2_title='Time (seconds)',
        yaxis1_title='Cluster ID',
        yaxis2_title='Frequency (Hz)'
    )

    # Integer ticks for cluster ids; frequency axis limited to 0-50 Hz.
    fig.update_yaxes(tickmode='linear', tick0=0, dtick=1, row=1, col=1)
    fig.update_yaxes(range=[0, 50], row=2, col=1)

    return fig

# Create and display combined visualization
combined_fig = create_combined_visualization(cluster_labels, cluster_times, window_size,
                                           f, t_spec, Sxx_db, max_time=1800)  # 30 minutes
combined_fig.show()

In [None]:
# Frequency Band Analysis - OPTIMIZED VERSION
def analyze_frequency_bands_optimized(f, t_spec, Sxx, cluster_labels, cluster_times, window_size, max_time_samples=1000):
    """Summarise spectral band power per cluster.

    Every cluster window is matched to the spectrogram column nearest to the
    window's centre. Windows whose centre lies outside the time span covered
    by the spectrogram are ignored. The previous implementation matched
    window *start* times and silently assigned every window beyond the
    spectrogram to its last column, which contaminated the statistics.

    Parameters
    ----------
    f : array-like
        Spectrogram frequencies in Hz.
    t_spec : array-like
        Spectrogram column times in seconds, ascending.
    Sxx : array-like
        Linear power, shape ``(len(f), len(t_spec))``.
    cluster_labels : array-like
        Cluster id of every analysis window.
    cluster_times : array-like
        Start time in seconds of every analysis window.
    window_size : float
        Window length in seconds. Used to find window centres and as the
        matching tolerance (half a window).
    max_time_samples : int, default 1000
        Only the first ``max_time_samples`` spectrogram columns are used.

    Returns
    -------
    band_stats : pandas.DataFrame
        One row per (cluster, band) with ``mean_power``, ``std_power`` and
        ``log_power`` (log10 of the mean). Clusters with no matched window
        are absent.
    band_powers : dict
        Band name -> mean power over the band's frequency bins, one value per
        spectrogram column.
    time_points : list
        The spectrogram times actually used.
    """

    f = np.asarray(f)
    t_spec = np.asarray(t_spec, dtype=float)
    Sxx = np.asarray(Sxx)
    cluster_labels = np.asarray(cluster_labels)
    window_centers = np.asarray(cluster_times, dtype=float) + window_size / 2.0

    # Limit analysis to reduce computation time
    if len(t_spec) > max_time_samples:
        print(f"Limiting analysis to first {max_time_samples} time samples for performance...")
        t_spec = t_spec[:max_time_samples]
        Sxx = Sxx[:, :max_time_samples]

    # Define frequency bands
    bands = {
        'Delta (0.5-4 Hz)': (0.5, 4),
        'Theta (4-8 Hz)': (4, 8),
        'Alpha (8-13 Hz)': (8, 13),
        'Beta (13-30 Hz)': (13, 30),
        'Gamma (30-50 Hz)': (30, 50)
    }

    # Mean power across each band's frequency bins, per spectrogram column.
    # Both band edges are inclusive, so a bin lying exactly on a shared edge
    # contributes to both neighbouring bands.
    band_powers = {
        band_name: np.mean(Sxx[(f >= low_freq) & (f <= high_freq), :], axis=0)
        for band_name, (low_freq, high_freq) in bands.items()
    }

    cluster_band_analysis = {cluster: {band: [] for band in bands}
                             for cluster in np.unique(cluster_labels)}

    # Keep only windows whose centre lies within half a window of the
    # spectrogram's time span. Windows further away have no spectral data.
    tolerance = window_size / 2.0
    if len(t_spec) > 0:
        covered = ((window_centers >= t_spec[0] - tolerance) &
                   (window_centers <= t_spec[-1] + tolerance))
    else:
        covered = np.zeros(len(window_centers), dtype=bool)
    covered_indices = np.flatnonzero(covered)

    # Thin the covered windows to roughly a hundred samples for speed.
    sample_rate = max(1, len(covered_indices) // 100)

    for idx in covered_indices[::sample_rate]:
        cluster_label = cluster_labels[idx]

        # Nearest spectrogram column to this window's centre
        time_idx = np.argmin(np.abs(t_spec - window_centers[idx]))

        for band_name in bands:
            cluster_band_analysis[cluster_label][band_name].append(
                band_powers[band_name][time_idx])

    # Calculate statistics for each cluster-band combination
    band_stats = []
    for cluster in sorted(cluster_band_analysis.keys()):
        for band_name in bands:
            powers = cluster_band_analysis[cluster][band_name]
            if powers:
                mean_power = np.mean(powers)
                band_stats.append({
                    'cluster': cluster,
                    'band': band_name,
                    'mean_power': mean_power,
                    'std_power': np.std(powers),
                    # small offset keeps log10 finite for zero power
                    'log_power': np.log10(mean_power + 1e-12)
                })

    return pd.DataFrame(band_stats), band_powers, list(t_spec)

# Band analysis runs on the linear-power spectrogram (Sxx), not the dB version.
# NOTE(review): t_spec covers at most the first hour computed above, while
# cluster_times span every window -- confirm the intended time coverage.
print("Analyzing frequency bands (optimized version)...")
band_stats_df, band_powers, spec_time_points = analyze_frequency_bands_optimized(
    f, t_spec, Sxx, cluster_labels, cluster_times, window_size)

print("\nBand Analysis Results:")
# Preview only the first ten (cluster, band) rows.
print(band_stats_df.head(10))

In [None]:
# Visualize Frequency Band Analysis
def _band_ratio(powers, numerator_band, denominator_band):
    """Return powers[numerator]/powers[denominator]; 0 if a band is missing, NaN if the denominator is 0."""
    if numerator_band not in powers or denominator_band not in powers:
        return 0
    denominator = powers[denominator_band]
    if denominator == 0:
        return float('nan')
    return powers[numerator_band] / denominator


def create_band_analysis_plots(band_stats_df):
    """Create a 2x2 dashboard of per-cluster frequency-band statistics.

    Parameters
    ----------
    band_stats_df : pandas.DataFrame
        One row per (cluster, band) with at least the columns ``cluster``,
        ``band``, ``mean_power`` and ``log_power``.

    Returns
    -------
    plotly.graph_objects.Figure
        Panels: mean band power per cluster, heatmap of log power,
        band power as a percentage of the cluster's total, and the
        Delta/Alpha and Theta/Beta ratios per cluster.

    Notes
    -----
    Each cluster keeps the same colour in panels 1 and 3. The earlier version
    indexed the palette by position in one panel and by cluster label in the
    other, which disagreed unless labels were exactly 0..k-1 and failed for
    non-integer labels. Ratios with a zero denominator are reported as NaN.
    """

    clusters = sorted(band_stats_df['cluster'].unique())
    # Bands in order of first appearance, so the heatmap matches the bar plots
    # instead of pivot's alphabetical column order.
    band_order = list(dict.fromkeys(band_stats_df['band']))

    # One colour per cluster, assigned by position in the sorted cluster list.
    colors_bar = ['blue', 'orange', 'green', 'red', 'purple']
    cluster_color = {cluster: colors_bar[i % len(colors_bar)]
                     for i, cluster in enumerate(clusters)}

    # Cluster x band table of log power for the heatmap
    pivot_log = (band_stats_df
                 .pivot(index='cluster', columns='band', values='log_power')
                 .reindex(columns=band_order))

    # Create subplots
    fig = make_subplots(rows=2, cols=2,
                       subplot_titles=['Mean Band Power by Cluster', 'Log Band Power Heatmap',
                                      'Normalized Band Power', 'Band Power Ratios'],
                       specs=[[{'type': 'bar'}, {'type': 'heatmap'}],
                             [{'type': 'bar'}, {'type': 'bar'}]])

    # 1. Bar plot of mean band powers
    for cluster in clusters:
        cluster_data = band_stats_df[band_stats_df['cluster'] == cluster]
        fig.add_trace(go.Bar(
            x=cluster_data['band'],
            y=cluster_data['mean_power'],
            name=f'Cluster {cluster}',
            marker_color=cluster_color[cluster],
            opacity=0.7
        ), row=1, col=1)

    # 2. Heatmap of log power
    fig.add_trace(go.Heatmap(
        x=list(pivot_log.columns),
        y=list(pivot_log.index),
        z=pivot_log.values,
        colorscale='Viridis',
        name='Log Power',
        colorbar=dict(title='Log Power', x=0.48, y=0.8, len=0.4)
    ), row=1, col=2)

    # 3. Normalized band power (relative to the cluster's total over all bands)
    for cluster in clusters:
        cluster_data = band_stats_df[band_stats_df['cluster'] == cluster]
        total_power = cluster_data['mean_power'].sum()
        normalized_powers = cluster_data['mean_power'] / total_power * 100

        fig.add_trace(go.Bar(
            x=cluster_data['band'],
            y=normalized_powers,
            name=f'Cluster {cluster} (Norm)',
            marker_color=cluster_color[cluster],
            opacity=0.7,
            showlegend=False
        ), row=2, col=1)

    # 4. Specific band ratios (Delta/Alpha, Theta/Beta)
    ratios_data = []
    for cluster in clusters:
        cluster_data = band_stats_df[band_stats_df['cluster'] == cluster]
        powers_dict = dict(zip(cluster_data['band'], cluster_data['mean_power']))

        ratios_data.append({
            'cluster': cluster,
            'Delta/Alpha': _band_ratio(powers_dict, 'Delta (0.5-4 Hz)', 'Alpha (8-13 Hz)'),
            'Theta/Beta': _band_ratio(powers_dict, 'Theta (4-8 Hz)', 'Beta (13-30 Hz)')
        })

    ratios_df = pd.DataFrame(ratios_data)

    fig.add_trace(go.Bar(
        x=ratios_df['cluster'],
        y=ratios_df['Delta/Alpha'],
        name='Delta/Alpha',
        marker_color='lightcoral'
    ), row=2, col=2)

    fig.add_trace(go.Bar(
        x=ratios_df['cluster'],
        y=ratios_df['Theta/Beta'],
        name='Theta/Beta',
        marker_color='lightblue'
    ), row=2, col=2)

    # Update layout
    fig.update_layout(
        height=700,
        title_text='Frequency Band Analysis by Cluster',
        showlegend=True
    )

    # Update axes
    fig.update_xaxes(title_text='Frequency Band', row=1, col=1)
    fig.update_yaxes(title_text='Mean Power', row=1, col=1)
    fig.update_xaxes(title_text='Frequency Band', row=1, col=2)
    fig.update_yaxes(title_text='Cluster', row=1, col=2)
    fig.update_xaxes(title_text='Frequency Band', row=2, col=1)
    fig.update_yaxes(title_text='Normalized Power (%)', row=2, col=1)
    fig.update_xaxes(title_text='Cluster', row=2, col=2)
    fig.update_yaxes(title_text='Ratio', row=2, col=2)

    return fig

# Create and display band analysis plots
# band_stats_df holds one row per (cluster, band) pair from the band analysis above.
band_analysis_fig = create_band_analysis_plots(band_stats_df)
band_analysis_fig.show()

In [None]:
# Summary Statistics and Conclusions
print("=== COMPREHENSIVE ANALYSIS SUMMARY ===")
print("\n1. CLUSTER DISTRIBUTION:")
print(cluster_dist_df.to_string(index=False))

print("\n2. CONTINUOUS DURATION STATISTICS:")
print(duration_stats_df.to_string(index=False))

print("\n3. FREQUENCY BAND DOMINANCE:")
# Find dominant frequency band (highest mean power) for each cluster
for cluster in sorted(band_stats_df['cluster'].unique()):
    cluster_bands = band_stats_df[band_stats_df['cluster'] == cluster]
    dominant_band = cluster_bands.loc[cluster_bands['mean_power'].idxmax(), 'band']
    max_power = cluster_bands['mean_power'].max()
    print(f"Cluster {cluster}: Dominant band = {dominant_band} (Power: {max_power:.2e})")

# Label changes between consecutive windows, computed once and reused below.
transitions = np.diff(cluster_labels)
n_transitions = int(np.sum(transitions != 0))

# Mean over ALL continuous segments. The earlier code averaged the per-cluster
# means, which weights a cluster with one segment as heavily as one with hundreds.
all_segment_durations = [seconds
                         for segment_list in continuous_durations.values()
                         for seconds in segment_list]
mean_segment_duration = (np.mean(all_segment_durations)
                         if all_segment_durations else float('nan'))

most_frequent_cluster = cluster_dist_df.loc[cluster_dist_df['count'].idxmax(), 'cluster']

print("\n4. KEY FINDINGS:")
print(f"- Total recording duration: {len(eeg_data)/fs/3600:.2f} hours")
print(f"- Number of cluster transitions: {n_transitions}")
print(f"- Most frequent cluster: {most_frequent_cluster}")
print(f"- Longest continuous segment: {duration_stats_df['max_duration'].max()/60:.1f} minutes")
print(f"- Average segment duration: {mean_segment_duration/60:.1f} minutes")

# Cluster stability analysis: fraction of consecutive window pairs that keep
# the same label. Undefined (NaN) when there are fewer than two windows.
stability = 1 - n_transitions / len(transitions) if len(transitions) else float('nan')
print(f"- Cluster stability (1-transition_rate): {stability:.3f}")

In [None]:
# Interactive Analysis Function
def interactive_segment_analysis(start_sec, end_sec, plot_spectrogram=True):
    """Summarize and plot an arbitrary time window of the recording.

    Prints which clusters occur in the window, shows the EEG trace for the
    window via ``plot_eeg_segment`` and, optionally, a multitaper spectrogram.

    Relies on notebook-level state: ``cluster_times``, ``cluster_labels``,
    ``eeg_data``, ``time_vector``, ``window_size`` and ``fs``.

    Parameters
    ----------
    start_sec, end_sec : float
        Window bounds, in seconds from the start of the recording.
    plot_spectrogram : bool, default True
        When True, also show a spectrogram, provided the window is at most
        10 minutes long and contains at least 10 seconds of EEG samples.
        Otherwise a short message explains why it was skipped.

    Returns
    -------
    None
        All results are printed or displayed.
    """

    print(f"\n=== SEGMENT ANALYSIS ({start_sec}s - {end_sec}s) ===")

    # Cluster labels whose timestamps fall inside the window (both ends inclusive)
    in_window = (cluster_times >= start_sec) & (cluster_times <= end_sec)
    segment_clusters = cluster_labels[in_window]

    if len(segment_clusters) > 0:
        unique, counts = np.unique(segment_clusters, return_counts=True)
        # .tolist() yields native Python scalars so the dict prints as {0: 5}
        # rather than with NumPy scalar reprs
        print(f"Clusters present: {dict(zip(unique.tolist(), counts.tolist()))}")
        print(f"Dominant cluster: {unique[np.argmax(counts)]}")
        print(f"Number of transitions: {np.sum(np.diff(segment_clusters) != 0)}")
    else:
        print("No cluster data available for this segment")

    # Plot EEG segment
    fig = plot_eeg_segment(start_sec, end_sec, eeg_data, time_vector,
                          cluster_labels, cluster_times, window_size)
    fig.show()

    # Optional spectrogram for the segment
    if not plot_spectrogram:
        return

    if end_sec - start_sec > 600:  # Only for segments <= 10 minutes
        print("Spectrogram skipped: segment is longer than 10 minutes")
        return

    # Clamp sample indices to the recording. An unclamped negative index would
    # silently wrap around and slice from the end of the signal.
    n_samples = len(eeg_data)
    start_idx = min(max(int(start_sec * fs), 0), n_samples)
    end_idx = min(max(int(end_sec * fs), 0), n_samples)
    segment_data = eeg_data[start_idx:end_idx]

    if len(segment_data) < fs * 10:  # Need at least 10 seconds of data
        print("Spectrogram skipped: fewer than 10 seconds of EEG data in this segment")
        return

    f_seg, t_seg, Sxx_seg = compute_multitaper_spectrogram(segment_data, fs, window_length=5)
    # Small offset keeps log10 finite where the power is exactly zero
    Sxx_seg_db = 10 * np.log10(Sxx_seg + 1e-12)

    # Spectrogram times are relative to the first sample actually used, which
    # is the recording start when start_sec was negative
    segment_origin_sec = max(start_sec, 0)

    fig_spec = go.Figure()
    fig_spec.add_trace(go.Heatmap(
        x=t_seg + segment_origin_sec,
        y=f_seg,
        z=Sxx_seg_db,
        colorscale='Viridis',
        colorbar=dict(title='Power (dB)')
    ))

    fig_spec.update_layout(
        title=f'Spectrogram for Segment ({start_sec}s - {end_sec}s)',
        xaxis_title='Time (seconds)',
        yaxis_title='Frequency (Hz)',
        height=400
    )
    fig_spec.update_yaxes(range=[0, 50])
    fig_spec.show()

# Usage banner for the helper defined above
print("\n".join([
    "\n=== INTERACTIVE ANALYSIS FUNCTION READY ===",
    "Use: interactive_segment_analysis(start_sec, end_sec, plot_spectrogram=True)",
    "Example: interactive_segment_analysis(600, 900)  # Analyze 10-15 minute segment",
]))

# Closing recommendations, grouped under three numbered headings
print("\n".join([
    "\n=== FINAL RECOMMENDATIONS ===",
    "1. Cluster Interpretation:",
    "   - Examine frequency band dominance to interpret sleep stages",
    "   - Delta dominance typically indicates deep sleep",
    "   - Alpha/Beta dominance may indicate wake or REM states",
    "\n2. Further Analysis:",
    "   - Investigate cluster transitions patterns",
    "   - Analyze circadian rhythm effects",
    "   - Compare with manual sleep staging if available",
    "\n3. Model Validation:",
    "   - Check cluster consistency across different nights",
    "   - Validate against physiological sleep patterns",
    "   - Consider individual differences in sleep architecture",
]))

In [None]:
# Cluster Transition Matrix Analysis
def analyze_cluster_transitions(cluster_labels):
    """Count and normalize transitions between consecutive cluster labels.

    Parameters
    ----------
    cluster_labels : array-like
        Cluster labels in temporal order. Labels may be any sortable values;
        they do not need to be contiguous or to start at zero.

    Returns
    -------
    transition_matrix : np.ndarray, shape (n_clusters, n_clusters)
        ``transition_matrix[i, j]`` is the number of times
        ``unique_clusters[i]`` is immediately followed by ``unique_clusters[j]``.
    transition_probs : np.ndarray, shape (n_clusters, n_clusters)
        Row-normalized ``transition_matrix``. A row with no outgoing
        transitions is all zeros.
    unique_clusters : list
        Sorted distinct labels; defines the row/column order of both matrices.
    """

    labels = np.asarray(cluster_labels).ravel()

    # Map each label to its position in the sorted unique labels, so the matrix
    # is indexed by position rather than by raw label value. Raw values break
    # for non-contiguous labels and wrap around for negative ones such as -1.
    unique_values, label_positions = np.unique(labels, return_inverse=True)
    label_positions = label_positions.reshape(-1)
    unique_clusters = list(unique_values)
    n_clusters = len(unique_clusters)

    # Create transition matrix; np.add.at accumulates repeated (from, to) pairs
    transition_matrix = np.zeros((n_clusters, n_clusters))
    np.add.at(transition_matrix, (label_positions[:-1], label_positions[1:]), 1)

    # Normalize to get probabilities; rows without outgoing transitions stay zero
    row_totals = transition_matrix.sum(axis=1, keepdims=True)
    transition_probs = np.divide(
        transition_matrix,
        row_totals,
        out=np.zeros_like(transition_matrix),
        where=row_totals > 0,
    )

    return transition_matrix, transition_probs, unique_clusters

# Calculate transition matrices
trans_matrix, trans_probs, clusters = analyze_cluster_transitions(cluster_labels)

# Row/column captions shared by the count and probability tables
trans_row_labels = [f'From {c}' for c in clusters]
trans_col_labels = [f'To {c}' for c in clusters]

print("=== CLUSTER TRANSITION ANALYSIS ===")

# Raw counts table
print("\nTransition Matrix (Raw Counts):")
trans_df = pd.DataFrame(trans_matrix, index=trans_row_labels, columns=trans_col_labels)
print(trans_df)

# Row-normalized probabilities, rounded to three decimals for display only
print("\nTransition Probabilities:")
trans_prob_df = pd.DataFrame(trans_probs, index=trans_row_labels, columns=trans_col_labels)
print(trans_prob_df.round(3))

In [None]:
# Visualize Transition Matrix
def _transition_heatmap(z, annotations, colorscale, bar_title, bar_x):
    """Build one annotated From/To heatmap panel for the transition figure."""
    return go.Heatmap(
        x=[f'To {c}' for c in clusters],
        y=[f'From {c}' for c in clusters],
        z=z,
        colorscale=colorscale,
        text=annotations,
        texttemplate='%{text}',
        textfont={'size': 12},
        colorbar=dict(title=bar_title, x=bar_x),
    )


fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=['Transition Counts', 'Transition Probabilities'],
    specs=[[{'type': 'heatmap'}, {'type': 'heatmap'}]],
)

# Left panel: raw transition counts, annotated as integers
fig.add_trace(
    _transition_heatmap(trans_matrix, trans_matrix.astype(int), 'Blues', 'Count', 0.46),
    row=1, col=1,
)

# Right panel: transition probabilities, annotated to two decimals
fig.add_trace(
    _transition_heatmap(trans_probs, np.round(trans_probs, 2), 'Reds', 'Probability', 1.02),
    row=1, col=2,
)

fig.update_layout(title='Cluster Transition Analysis', height=500)

fig.show()

# Calculate transition statistics
print("\n=== TRANSITION STATISTICS ===")
total_transitions = np.sum(trans_matrix)
# Diagonal entries are window pairs that stayed in the same cluster
self_transitions = np.sum(np.diag(trans_matrix))
# With no window pairs at all the ratio is undefined; report NaN explicitly
# instead of relying on a suppressed divide-by-zero warning
stability_ratio = self_transitions / total_transitions if total_transitions > 0 else float('nan')

print(f"Total transitions: {int(total_transitions)}")
print(f"Self-transitions (no change): {int(self_transitions)}")
print(f"Stability ratio: {stability_ratio:.3f}")

# Most common transitions: every observed (from, to) pair, most frequent first.
# The sort is stable, so pairs with equal counts keep their matrix order.
print("\nMost common transitions:")
observed_pairs = np.argwhere(trans_matrix > 0).tolist()
ranked_pairs = sorted(observed_pairs,
                      key=lambda pair: trans_matrix[pair[0], pair[1]],
                      reverse=True)
for i, j in ranked_pairs:
    print(f"  {clusters[i]} -> {clusters[j]}: {int(trans_matrix[i, j])} times ({trans_probs[i, j]:.3f})")