# BiLSTM + Time Series Transformer for Self-Supervised Signal Reconstruction and Clustering

This notebook implements a sophisticated deep learning pipeline that combines **BiLSTM** and **Time Series Transformer** architectures for self-supervised signal reconstruction and clustering of EEG sleep stage data.

## Key Features:
- **3-second time windows** for precise temporal analysis
- **BiLSTM + Transformer hybrid model** for capturing both sequential and attention-based patterns
- **Self-supervised signal reconstruction** without requiring labeled data
- **DMD feature extraction** for capturing dynamic mode information
- **HDBSCAN clustering** for discovering hidden states in the embedding space
- **Multi-channel support** for EEG, EOG, and other physiological signals

## Pipeline Overview:
1. Load and preprocess EEG/EDF files
2. Extract DMD features from 3-second windows
3. Create sequences for model input
4. Train BiLSTM + Transformer hybrid model
5. Extract embeddings from trained model
6. Apply HDBSCAN clustering
7. Visualize results and save outputs

## Section 1: Import Required Libraries

Import all necessary libraries for the pipeline including deep learning frameworks, signal processing, and clustering tools.

In [None]:
# Core libraries
import numpy as np
import pandas as pd
import os
import json
import pickle
import math
import warnings
from datetime import datetime
from tqdm import tqdm
warnings.filterwarnings('ignore')

# Deep learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.nn.utils.rnn import pad_sequence

# Signal processing and EEG
import mne
from scipy.signal import butter, filtfilt
from scipy.stats import mode
from mne.time_frequency import psd_array_multitaper

# DMD for feature extraction
from pydmd import DMD, EDMD

# Machine learning and clustering
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
import hdbscan

# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

# Plot appearance: matplotlib style sheet first, then the seaborn colour palette
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Fix every random number generator used by the notebook to the same seed (42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Pick the compute device (`device` is read by the training/inference cells below)
# and report what was found
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
else:
    device = torch.device('cpu')
    print(f"Using device: {device}")
    print("Warning: CUDA not available. Using CPU.")

print("✓ All libraries imported successfully!")

## Section 2: Data Loading and Preprocessing

Load EEG/EDF files, apply bandpass filtering, normalize signals to [-1, 1], and replace outliers with the window mean using consecutive fixed-length windows.

In [None]:
def bandpass_filter(signal, sfreq, low=0.5, high=40, order=4):
    """Band-pass filter a 1-D signal with a Butterworth design.

    Parameters:
    -----------
    signal : array-like
        Samples to filter
    sfreq : float
        Sampling frequency in Hz
    low, high : float
        Lower / upper cut-off frequencies in Hz
    order : int
        Order passed to the Butterworth design

    Returns:
    --------
    numpy.ndarray
        Filtered signal (forward-backward filtered via ``filtfilt``)
    """
    # Cut-offs are handed to `butter` as fractions of the Nyquist frequency
    nyquist = sfreq / 2.0
    band = [low / nyquist, high / nyquist]
    numerator, denominator = butter(order, band, btype='band')
    return filtfilt(numerator, denominator, signal)

def load_and_preprocess_edf(filepath, num_channels=4, window_sec=30, std_factor=3):
    """Read an EDF recording and return its first channels, cleaned.

    Each selected channel is band-pass filtered (0.5-40 Hz), rescaled to
    [-1, 1], and then cleaned window by window: within every consecutive
    block of ``window_sec`` seconds, samples at least ``std_factor`` standard
    deviations away from the block mean are replaced by that mean.

    Parameters:
    -----------
    filepath : str
        Path to the EDF file
    num_channels : int
        Number of channels to load; the first ``num_channels`` channels of
        the file are taken (capped at the number available)
    window_sec : int
        Block length in seconds for outlier replacement
    std_factor : int
        Standard deviation multiplier for outlier detection

    Returns:
    --------
    signals : list of numpy.ndarray
        One preprocessed signal per channel
    sfreqs : list of int
        Sampling frequency of each signal (``raw.info['sfreq']`` cast to int)
    channel_names : list of str
        Names of the channels that were loaded
    """
    print(f"Loading EDF file: {filepath}")
    recording = mne.io.read_raw_edf(filepath, preload=True, verbose=False)

    # Never ask for more channels than the file actually holds
    available = recording.ch_names
    num_channels = min(num_channels, len(available))
    channels_to_use = available[:num_channels]

    print(f"Processing {num_channels} channels: {channels_to_use}")

    signals, sfreqs = [], []

    for position, channel in enumerate(channels_to_use, start=1):
        print(f"  Processing channel {position}/{num_channels}: {channel}")

        sfreq = int(recording.info['sfreq'])
        trace = recording.get_data(picks=channel)[0]

        # 0.5-40 Hz band-pass, then rescale the filtered trace to [-1, 1]
        trace = bandpass_filter(trace, sfreq, low=0.5, high=40)
        rescaler = MinMaxScaler(feature_range=(-1, 1))
        scaled = rescaler.fit_transform(trace.reshape(-1, 1)).flatten()

        # Outlier replacement over consecutive, non-overlapping blocks
        block_len = window_sec * sfreq
        total = len(scaled)
        cleaned = scaled.copy()

        for lo in range(0, total, block_len):
            hi = min(lo + block_len, total)
            segment = scaled[lo:hi]
            centre = np.mean(segment)
            spread = np.std(segment)
            within = np.abs(segment - centre) < std_factor * spread
            cleaned[lo:hi] = np.where(within, segment, centre)

        signals.append(cleaned)
        sfreqs.append(sfreq)

        print(f"    Signal length: {len(cleaned)} samples ({len(cleaned)/sfreq:.1f} seconds)")

    print(f"✓ Successfully loaded {len(signals)} channels")
    return signals, sfreqs, channels_to_use

# Recordings to process: every path follows "raw data/<recording id>-PSG.edf"
edf_files = [
    f"raw data/{recording_id}-PSG.edf"
    for recording_id in ("SC4001E0", "SC4002E0", "SC4011E0", "SC4012E0", "SC4021E0")
]

print(f"Will process {len(edf_files)} EDF files")

## Section 3: Feature Extraction with DMD on 3-second Windows

Extract DMD features (amplitude magnitude and phase, plus eigenvalue modulus and real part) from each 3-second window of the preprocessed signals.

In [None]:
def _first_modes(values, n_modes):
    """Return the first ``n_modes`` entries of ``values``, zero-padded on the right."""
    values = np.asarray(values)
    if len(values) >= n_modes:
        return values[:n_modes]
    return np.pad(values, (0, n_modes - len(values)),
                  mode='constant', constant_values=0)


def extract_dmd_features_3s(signal, sfreq, window_sec=3, step_sec=1, n_modes=8):
    """Extract DMD features from sliding windows of a 1-D signal.

    For every window a DMD model is fitted and four groups of ``n_modes``
    values are concatenated into one feature vector, in this order:
    ``|amplitudes|``, ``angle(amplitudes)``, ``|eigenvalues|`` and
    ``real(eigenvalues)``. Groups with fewer than ``n_modes`` values are
    zero-padded.

    Parameters:
    -----------
    signal : numpy.ndarray
        Input signal array
    sfreq : float
        Sampling frequency in Hz
    window_sec : int
        Window size in seconds (3 by default)
    step_sec : int
        Step between window starts in seconds (1 by default)
    n_modes : int
        Number of DMD modes to keep per feature group

    Returns:
    --------
    features : numpy.ndarray
        Extracted DMD features, shape (n_windows, 4 * n_modes). When the
        signal is shorter than one window the shape is (0, 4 * n_modes).
    time_stamps : numpy.ndarray
        Start time of each window in seconds (sample index / sfreq)
    """
    win_samples = int(window_sec * sfreq)
    step_samples = int(step_sec * sfreq)
    n_features = 4 * n_modes

    # Number of complete windows; clamp at zero for signals shorter than a window
    n_windows = max(0, (len(signal) - win_samples) // step_samples + 1)

    if n_windows == 0:
        # Nothing to fit: return correctly shaped empty arrays instead of
        # failing later on `features.shape[1]`.
        print(f"Signal too short for a {window_sec}s window "
              f"({len(signal)} samples at {sfreq} Hz); no DMD features extracted")
        return np.empty((0, n_features)), np.empty((0,))

    features = []
    time_stamps = []

    print(f"Extracting DMD features from {n_windows} windows "
          f"({window_sec}s each, {step_sec}s step)...")

    # Use tqdm for progress tracking
    for start in tqdm(range(0, len(signal) - win_samples + 1, step_samples),
                      desc="DMD feature extraction"):

        window = signal[start:start + win_samples]

        dmd = DMD(svd_rank=n_modes)

        # The window is passed as a single-row snapshot matrix (1, n_samples).
        # NOTE(review): a one-row matrix has rank <= 1, so presumably most of
        # the n_modes slots below end up zero-padded; a time-delay (Hankel)
        # embedding may be what was intended - confirm against PyDMD.
        dmd.fit(window.reshape(1, -1))

        amplitudes = _first_modes(dmd.amplitudes, n_modes)
        eigenvalues = _first_modes(dmd.eigs, n_modes)

        combined_features = np.concatenate([
            np.abs(amplitudes),     # amplitude magnitude
            np.angle(amplitudes),   # amplitude phase
            np.abs(eigenvalues),    # eigenvalue modulus ("frequency-like")
            np.real(eigenvalues),   # eigenvalue real part ("growth")
        ])

        features.append(combined_features)
        time_stamps.append(start / sfreq)

    features = np.array(features)
    time_stamps = np.array(time_stamps)

    print(f"✓ Extracted {features.shape[0]} feature vectors with {features.shape[1]} features each")
    print(f"  Feature components: {n_modes} magnitude + {n_modes} phase + {n_modes} frequency + {n_modes} growth")

    return features, time_stamps

def process_all_files_dmd(edf_files, save_dir='features_3s'):
    """Process all EDF files and extract DMD features.

    Every file is loaded with ``load_and_preprocess_edf`` (4 channels), each
    channel is run through ``extract_dmd_features_3s``, and the resulting
    features and window time stamps are written to ``save_dir`` as
    ``<file>_<channel>_features_3s.npy`` / ``<file>_<channel>_timestamps_3s.npy``.

    Parameters:
    -----------
    edf_files : list
        List of EDF file paths (may be empty)
    save_dir : str
        Directory to save features (created if missing)

    Returns:
    --------
    feature_dict : dict
        ``{file_basename: {channel: {'features', 'time_stamps', 'sfreq'}}}``;
        empty when ``edf_files`` is empty
    """
    os.makedirs(save_dir, exist_ok=True)

    feature_dict = {}

    for file_path in edf_files:
        print(f"\n{'='*60}")
        print(f"Processing: {file_path}")
        print(f"{'='*60}")

        # Load and preprocess
        signals, sfreqs, channel_names = load_and_preprocess_edf(file_path, num_channels=4)

        # Results are keyed by the file name without directory or extension
        file_basename = os.path.splitext(os.path.basename(file_path))[0]
        feature_dict[file_basename] = {}

        for i, (signal, sfreq, channel) in enumerate(zip(signals, sfreqs, channel_names)):
            print(f"\nChannel {i+1}/{len(signals)}: {channel}")

            features, time_stamps = extract_dmd_features_3s(signal, sfreq)

            feature_dict[file_basename][channel] = {
                'features': features,
                'time_stamps': time_stamps,
                'sfreq': sfreq
            }

            # Persist features and their window start times side by side
            save_path = os.path.join(save_dir, f"{file_basename}_{channel}_features_3s.npy")
            np.save(save_path, features)

            time_path = os.path.join(save_dir, f"{file_basename}_{channel}_timestamps_3s.npy")
            np.save(time_path, time_stamps)

            print(f"  ✓ Saved features: {save_path}")

    # Channel count is reported from the first file; the `{}` default keeps the
    # summary from raising StopIteration when no files were given.
    channels_per_file = len(next(iter(feature_dict.values()), {}))

    print(f"\n{'='*60}")
    print(f"✓ All files processed successfully!")
    print(f"Total files: {len(feature_dict)}")
    print(f"Total channels per file: {channels_per_file}")

    return feature_dict

## Section 4: Sequence Creation for Model Input

Create overlapping sequences of DMD features for each channel to be used as input for the hybrid model.

In [None]:
def create_sequences_from_features(feature_dict, seq_length=20, overlap=0.5):
    """Cut each channel's feature matrix into overlapping fixed-length sequences.

    Parameters:
    -----------
    feature_dict : dict
        ``{file_name: {channel: {'features': ..., 'time_stamps': ...}}}``
    seq_length : int
        Number of consecutive feature vectors (windows) per sequence
    overlap : float
        Fraction of a sequence shared with the next one (0.0-1.0); the start
        index advances by ``int(seq_length * (1 - overlap))``, at least 1

    Returns:
    --------
    sequences : dict
        ``{file_name: {channel: {'sequences': array, 'timestamps': array}}}``.
        Channels too short to yield a single sequence are left out, so a
        file entry may be an empty dict.
    """
    stride = max(1, int(seq_length * (1 - overlap)))

    sequences = {}

    for file_name, channels in feature_dict.items():
        per_file = {}
        sequences[file_name] = per_file

        for channel, data in channels.items():
            features, time_stamps = data['features'], data['time_stamps']

            # Start indices of every full-length sequence that fits
            starts = range(0, len(features) - seq_length + 1, stride)
            channel_sequences = [features[s:s + seq_length] for s in starts]
            seq_timestamps = [time_stamps[s:s + seq_length] for s in starts]

            if not channel_sequences:
                continue

            stacked = np.array(channel_sequences)
            per_file[channel] = {
                'sequences': stacked,
                'timestamps': np.array(seq_timestamps)
            }

            print(f"{file_name} - {channel}: {len(channel_sequences)} sequences "
                  f"(shape: {stacked.shape})")

    return sequences

class MultiChannelSequenceDataset(Dataset):
    """Dataset for multi-channel sequence data.

    Each item is a tuple ``(sample, file_name, timestamp)`` where ``sample``
    maps channel name -> float32 tensor holding one sequence, and
    ``timestamp`` is the time-stamp array of that sequence taken from the
    file's first channel.
    """

    def __init__(self, sequences_dict):
        """
        Parameters:
        -----------
        sequences_dict : dict
            ``{file_name: {channel: {'sequences': ..., 'timestamps': ...}}}``.
            Files without any channel are skipped.
        """
        self.data = []
        self.file_names = []
        self.timestamps = []

        # Flatten all sequences from all files into one sample list
        for file_name, channels in sequences_dict.items():
            if not channels:
                # A file can end up with no channels when every channel was
                # too short to yield a sequence; it contributes no samples.
                continue

            # Only indices present in every channel can form a complete
            # multi-channel sample, so use the shortest channel.
            n_sequences = min(len(entry['sequences']) for entry in channels.values())
            reference = next(iter(channels.values()))  # first channel supplies timestamps

            for seq_idx in range(n_sequences):
                sample = {
                    channel: torch.tensor(entry['sequences'][seq_idx], dtype=torch.float32)
                    for channel, entry in channels.items()
                }

                self.data.append(sample)
                self.file_names.append(file_name)
                self.timestamps.append(reference['timestamps'][seq_idx])

    def __len__(self):
        """Return the total number of multi-channel samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, file_name, timestamp)`` for sample ``idx``."""
        return self.data[idx], self.file_names[idx], self.timestamps[idx]

def collate_multichannel_sequences(batch):
    """Collate ``(sample, file_name, timestamp)`` items into one batch.

    Returns a tuple ``(batched_data, file_names, timestamps)``:
    ``batched_data`` maps each channel of the first sample to the stacked
    tensors of all samples; ``file_names`` and ``timestamps`` are tuples in
    batch order.
    """
    samples, file_names, timestamps = zip(*batch)

    # The first sample defines which channels are batched
    batched_data = {
        channel: torch.stack([sample[channel] for sample in samples])
        for channel in samples[0]
    }

    return batched_data, file_names, timestamps


print("✓ Sequence creation functions defined")

## Section 5: Define BiLSTM + Time Series Transformer Hybrid Model

Implement a PyTorch model that combines BiLSTM layers with a transformer encoder, supporting multi-channel input and positional encoding.

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first transformer inputs.

    A table of shape ``[max_len, 1, d_model]`` is precomputed and stored as
    the non-trainable buffer ``pe``: even feature indices hold sines, odd
    indices hold cosines. Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency before filling the cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Insert the broadcastable batch axis: [max_len, 1, d_model]
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """
        Add the positional table to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim];
               seq_len must not exceed the ``max_len`` given at construction
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [None]:
class BiLSTMTransformerHybrid(nn.Module):
    """
    Hybrid model combining BiLSTM and Transformer for self-supervised signal reconstruction.

    Architecture:
    1. Input projection for each channel
    2. BiLSTM layers for sequential modeling (with residual + layer norm)
    3. Learned fusion of all channels into one d_model-wide stream
    4. Transformer encoder for attention-based modeling
    5. Reconstruction heads for each channel
    6. Embedding extraction for clustering

    Parameters:
    -----------
    input_dims : dict
        Maps channel name -> feature dimension of that channel
    d_model : int
        Width of the shared representation
    lstm_hidden, lstm_layers : int
        Hidden size (per direction) and depth of each channel's BiLSTM
    nhead, num_transformer_layers : int
        Attention heads and depth of the transformer encoder
    embedding_dim : int
        Size of the clustering embedding
    dropout : float
        Dropout probability used throughout
    seq_length : int
        Length of the positional-encoding table; inputs longer than this
        cannot be encoded
    """

    def __init__(self, input_dims, d_model=256, lstm_hidden=128, lstm_layers=2,
                 nhead=8, num_transformer_layers=4, embedding_dim=64,
                 dropout=0.1, seq_length=20):
        super().__init__()

        self.input_dims = input_dims
        self.d_model = d_model
        self.embedding_dim = embedding_dim
        self.seq_length = seq_length
        # Fixed channel order so the fused feature layout never depends on
        # the iteration order of the dict passed to forward().
        self.channel_order = list(input_dims.keys())

        # Input projections for each channel
        self.input_projs = nn.ModuleDict({
            channel: nn.Sequential(
                nn.Linear(dim, d_model // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_model // 2, d_model)
            ) for channel, dim in input_dims.items()
        })

        # BiLSTM layers for each channel
        self.bilstm_layers = nn.ModuleDict({
            channel: nn.LSTM(
                input_size=d_model,
                hidden_size=lstm_hidden,
                num_layers=lstm_layers,
                batch_first=True,
                dropout=dropout if lstm_layers > 1 else 0,
                bidirectional=True
            ) for channel in input_dims.keys()
        })

        # Project BiLSTM output to d_model
        self.lstm_proj = nn.ModuleDict({
            channel: nn.Linear(lstm_hidden * 2, d_model)  # *2 for bidirectional
            for channel in input_dims.keys()
        })

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout, seq_length)

        # Transformer encoder (sequence-first layout: [seq_len, batch, d_model])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_transformer_layers)

        # Reconstruction heads for each channel
        self.reconstruction_heads = nn.ModuleDict({
            channel: nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(d_model // 2, dim)
            ) for channel, dim in input_dims.items()
        })

        # Embedding projection for clustering
        self.embedding_proj = nn.Sequential(
            nn.Linear(d_model, embedding_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim)
        )

        # Layer normalization (shared by all channels after the LSTM residual)
        self.layer_norm = nn.LayerNorm(d_model)

        # Channel fusion: concatenated per-channel features -> d_model.
        # This must be a registered submodule; creating the layer inside
        # forward() would give it fresh random weights on every call and keep
        # it out of the optimizer and of saved checkpoints.
        self.channel_fusion = nn.Linear(d_model * len(self.channel_order), d_model)

    def forward(self, inputs, return_embeddings=False):
        """
        Forward pass through the hybrid model.

        Parameters:
        -----------
        inputs : dict
            Dictionary of channel inputs, each with shape [batch_size, seq_length, feature_dim].
            Must contain every channel named in ``input_dims``.
        return_embeddings : bool
            If True, return embeddings for clustering

        Returns:
        --------
        reconstructed : dict
            Reconstructed features for each channel, same shapes as the inputs
        embeddings : dict (optional)
            Only when ``return_embeddings=True``: channel -> [batch, embedding_dim].
            The same tensor is stored under every channel key because the
            embedding represents the fused signal.

        Raises:
        -------
        KeyError
            If a channel from ``input_dims`` is missing in ``inputs``
        """
        missing = [channel for channel in self.channel_order if channel not in inputs]
        if missing:
            raise KeyError(f"Missing input channels: {missing}")

        # Per-channel branch: projection -> BiLSTM -> residual + layer norm
        channel_features = []
        for channel in self.channel_order:
            x = self.input_projs[channel](inputs[channel])  # [batch, seq_len, d_model]

            lstm_out, _ = self.bilstm_layers[channel](x)  # [batch, seq_len, lstm_hidden*2]
            lstm_out = self.lstm_proj[channel](lstm_out)  # [batch, seq_len, d_model]

            channel_features.append(self.layer_norm(x + lstm_out))

        # Concatenate along the feature axis, then fuse back to d_model
        combined = torch.cat(channel_features, dim=-1)  # [batch, seq_len, d_model*n_channels]
        combined = self.channel_fusion(combined)  # [batch, seq_len, d_model]

        # The encoder expects [seq_len, batch, d_model]
        combined = combined.permute(1, 0, 2)
        combined = self.pos_encoder(combined)
        transformer_out = self.transformer(combined)  # [seq_len, batch, d_model]

        # Back to [batch, seq_len, d_model]
        transformer_out = transformer_out.permute(1, 0, 2)

        # Reconstruction for each channel
        reconstructed = {
            channel: self.reconstruction_heads[channel](transformer_out)
            for channel in self.channel_order
        }

        if return_embeddings:
            # Mean pooling over the sequence axis gives one vector per sequence
            pooled = transformer_out.mean(dim=1)  # [batch, d_model]
            embeddings = self.embedding_proj(pooled)  # [batch, embedding_dim]

            embedding_dict = {channel: embeddings for channel in self.channel_order}
            return reconstructed, embedding_dict

        return reconstructed

print("✓ BiLSTM + Transformer hybrid model defined")

## Section 6: Model Training for Self-Supervised Signal Reconstruction

Train the hybrid model to reconstruct the input signals in a self-supervised manner using MSE loss.

In [None]:
def train_hybrid_model(sequences_dict, input_dims, epochs=50, batch_size=32,
                      learning_rate=1e-4, weight_decay=1e-5, focus_channel=None):
    """
    Train the BiLSTM + Transformer hybrid model for self-supervised reconstruction.

    The model is asked to reproduce its own input features; the loss is the
    MSE between reconstruction and input. A checkpoint is written to the
    working directory every 10 epochs.

    Parameters:
    -----------
    sequences_dict : dict
        Dictionary of sequences for each file and channel
    input_dims : dict
        Dictionary of input dimensions for each channel
    epochs : int
        Number of training epochs
    batch_size : int
        Batch size for training
    learning_rate : float
        Learning rate for optimizer
    weight_decay : float
        Weight decay for regularization
    focus_channel : str or None
        Channel to focus on for reconstruction loss (if None, or if the
        channel is not present in a batch, all channels are used)

    Returns:
    --------
    model : BiLSTMTransformerHybrid
        Trained model
    training_history : dict
        ``'epoch_losses'``, ``'channel_losses'`` (per channel) and
        ``'learning_rates'``, one entry per epoch

    Raises:
    -------
    ValueError
        If ``sequences_dict`` yields no training sequences
    """

    # Create dataset and dataloader
    dataset = MultiChannelSequenceDataset(sequences_dict)
    if len(dataset) == 0:
        raise ValueError("No training sequences available: sequences_dict produced an empty dataset")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_multichannel_sequences,
        num_workers=0  # Set to 0 for Windows compatibility
    )

    print(f"Training dataset size: {len(dataset)} sequences")
    print(f"Batch size: {batch_size}, Batches per epoch: {len(dataloader)}")

    # Initialize model.
    # NOTE: built with its default hyperparameters, including seq_length=20
    # for the positional-encoding table, so sequences longer than 20 steps
    # cannot be processed.
    model = BiLSTMTransformerHybrid(input_dims).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params:,} total, {trainable_params:,} trainable")

    # Initialize optimizer and scheduler (cosine decay over the whole run)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Loss function
    criterion = nn.MSELoss()

    # Training history
    training_history = {
        'epoch_losses': [],
        'channel_losses': {channel: [] for channel in input_dims.keys()},
        'learning_rates': []
    }

    # Training loop
    print(f"\nStarting training for {epochs} epochs...")
    print(f"Focus channel: {focus_channel if focus_channel else 'All channels'}")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        channel_losses = {channel: 0.0 for channel in input_dims.keys()}

        pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}', leave=False)

        for batch_data, file_names, timestamps in pbar:
            # Move data to device
            batch_data = {channel: data.to(device) for channel, data in batch_data.items()}

            # Forward pass: the targets are the inputs themselves
            reconstructed = model(batch_data)

            # Calculate loss
            if focus_channel and focus_channel in batch_data:
                # Focus on specific channel
                total_loss = criterion(reconstructed[focus_channel], batch_data[focus_channel])
                channel_losses[focus_channel] += total_loss.item()
            else:
                # Use all channels (unweighted sum of per-channel MSE)
                total_loss = 0.0
                for channel in batch_data.keys():
                    loss = criterion(reconstructed[channel], batch_data[channel])
                    total_loss += loss
                    channel_losses[channel] += loss.item()

            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            batch_loss = total_loss.item()
            epoch_loss += batch_loss

            # Update progress bar
            pbar.set_postfix(loss=f"{batch_loss:.6f}")

        # Update learning rate
        scheduler.step()

        # Calculate average losses
        avg_epoch_loss = epoch_loss / len(dataloader)
        avg_channel_losses = {channel: loss / len(dataloader) for channel, loss in channel_losses.items()}

        # Store history
        training_history['epoch_losses'].append(avg_epoch_loss)
        training_history['learning_rates'].append(scheduler.get_last_lr()[0])
        for channel, loss in avg_channel_losses.items():
            training_history['channel_losses'][channel].append(loss)

        # Print epoch summary
        print(f'Epoch {epoch+1}/{epochs}: Loss = {avg_epoch_loss:.6f}, LR = {scheduler.get_last_lr()[0]:.2e}')

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_epoch_loss,
                'input_dims': input_dims,
                'training_history': training_history
            }
            torch.save(checkpoint, f'bilstm_transformer_checkpoint_epoch_{epoch+1}.pt')
            print(f"  ✓ Checkpoint saved")

    print(f"\n✓ Training completed!")
    return model, training_history

def plot_training_history(training_history):
    """Plot the loss curve, the learning-rate schedule and per-channel losses.

    ``training_history`` must provide the keys ``'epoch_losses'``,
    ``'learning_rates'`` and ``'channel_losses'`` (channel -> list of losses).
    The per-channel figure is drawn only when more than one channel is present.
    """
    fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, history key, line format, legend label, y label, title, y scale)
    panels = (
        (loss_ax, 'epoch_losses', 'b-', 'Total Loss', 'Loss',
         'Training Loss Over Time', None),
        (lr_ax, 'learning_rates', 'r-', 'Learning Rate', 'Learning Rate',
         'Learning Rate Schedule', 'log'),
    )
    for ax, key, line_fmt, label, y_label, title, y_scale in panels:
        ax.plot(training_history[key], line_fmt, linewidth=2, label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        if y_scale is not None:
            ax.set_yscale(y_scale)
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.tight_layout()
    plt.show()

    # Second figure: one curve per channel, skipped for single-channel runs
    per_channel = training_history['channel_losses']
    if len(per_channel) <= 1:
        return

    plt.figure(figsize=(12, 6))
    for channel, losses in per_channel.items():
        plt.plot(losses, linewidth=2, label=f'{channel} Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Channel-Specific Training Losses')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

print("✓ Training functions defined")

## Section 7: Extract Embeddings from Trained Model

Obtain one embedding per input sequence of 3-second windows by mean-pooling the model's transformer output, for use in downstream clustering.

In [None]:
def extract_embeddings_from_model(model, sequences_dict, batch_size=64):
    """
    Run the trained hybrid model over every sequence and collect its embeddings.

    Parameters:
    -----------
    model : BiLSTMTransformerHybrid
        Trained model (switched to eval mode here)
    sequences_dict : dict
        Dictionary of sequences for each file and channel
    batch_size : int
        Batch size for inference

    Returns:
    --------
    embeddings : numpy.ndarray
        One row per sequence, in dataset order
    metadata : list
        One dict per embedding with ``file_name``, ``timestamp_start``,
        ``timestamp_end`` and ``sequence_length``
    """
    model.eval()

    dataset = MultiChannelSequenceDataset(sequences_dict)
    # No shuffling: row i of the result must correspond to dataset item i
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_multichannel_sequences,
        num_workers=0
    )

    chunks = []
    metadata = []

    print(f"Extracting embeddings from {len(dataset)} sequences...")

    with torch.no_grad():
        for batch_data, file_names, timestamps in tqdm(loader, desc="Extracting embeddings"):
            on_device = {channel: data.to(device) for channel, data in batch_data.items()}

            _, per_channel = model(on_device, return_embeddings=True)

            # Every channel key carries the same fused embedding; take the first
            shared = next(iter(per_channel.values()))
            chunks.append(shared.cpu().numpy())

            metadata.extend(
                {
                    'file_name': name,
                    'timestamp_start': stamps[0],
                    'timestamp_end': stamps[-1],
                    'sequence_length': len(stamps)
                }
                for name, stamps in zip(file_names, timestamps)
            )

    # Stack the per-batch arrays into a single [n_sequences, dim] matrix
    embeddings = np.vstack(chunks)

    print(f"✓ Extracted {embeddings.shape[0]} embeddings with {embeddings.shape[1]} dimensions")

    return embeddings, metadata

def _embedding_json_default(value):
    """Fallback encoder for ``json.dump``: turn NumPy values into built-ins.

    The standard ``json`` encoder rejects NumPy scalars such as ``np.int64``
    or ``np.float32`` as well as NumPy arrays, so metadata fields that were
    read out of arrays would otherwise abort the save with a ``TypeError``.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_embeddings(embeddings, metadata, save_path='embeddings_3s'):
    """
    Save embeddings and their metadata to disk as .npy, .json and .csv files.

    All three files share one timestamp suffix so they can be matched later.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        2-D array with one row per embedding; its second dimension sets the
        number of ``emb_<i>`` columns in the CSV
    metadata : list
        One dictionary per embedding row. The keys of the first entry become
        the extra CSV columns. May be empty, in which case the CSV only
        contains the embedding columns.
    save_path : str
        Output directory (created if it does not exist)

    Returns:
    --------
    emb_file, meta_file, csv_file : str
        Paths of the written .npy, .json and .csv files
    """
    os.makedirs(save_path, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    emb_file = os.path.join(save_path, f"embeddings_{timestamp}.npy")
    meta_file = os.path.join(save_path, f"metadata_{timestamp}.json")
    csv_file = os.path.join(save_path, f"embeddings_{timestamp}.csv")

    # Raw embeddings
    np.save(emb_file, embeddings)

    # Metadata; the default hook converts NumPy scalars the encoder rejects
    with open(meta_file, 'w') as f:
        json.dump(metadata, f, indent=2, default=_embedding_json_default)

    # Combined CSV for compatibility with tools that cannot read .npy
    emb_df = pd.DataFrame(embeddings, columns=[f'emb_{i}' for i in range(embeddings.shape[1])])

    # Metadata columns are taken from the first entry; skip when there is none
    if metadata:
        for key in metadata[0]:
            emb_df[key] = [m[key] for m in metadata]

    emb_df.to_csv(csv_file, index=False)

    print(f"✓ Embeddings saved:")
    print(f"  - {emb_file}")
    print(f"  - {meta_file}")
    print(f"  - {csv_file}")

    return emb_file, meta_file, csv_file

# Cell-completion marker: confirms this cell's definitions were executed.
print("✓ Embedding extraction functions defined")

## Section 8: Cluster Embeddings Using HDBSCAN

Apply HDBSCAN to the extracted embeddings to discover clusters representing hidden states.

In [None]:
def cluster_embeddings_hdbscan(embeddings, min_cluster_size=50, min_samples=10, 
                              metric='euclidean', cluster_selection_epsilon=0.0):
    """
    Perform HDBSCAN clustering on standardized embeddings.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Embeddings to cluster (one row per sample)
    min_cluster_size : int
        Minimum size of clusters
    min_samples : int
        Minimum number of samples in a neighborhood for a point to be considered a core point
    metric : str
        Distance metric for clustering
    cluster_selection_epsilon : float
        Distance threshold for cluster selection

    Returns:
    --------
    cluster_labels : numpy.ndarray
        Cluster labels for each embedding (-1 marks noise)
    clusterer : hdbscan.HDBSCAN
        Fitted HDBSCAN object
    cluster_stats : dict
        Statistics about the clustering. All values are built-in Python
        types (int/float/dict) so the dictionary can be passed to
        ``json.dump`` without a custom encoder.
    """
    print(f"Clustering {embeddings.shape[0]} embeddings with {embeddings.shape[1]} dimensions...")
    print(f"HDBSCAN parameters: min_cluster_size={min_cluster_size}, min_samples={min_samples}")

    # Standardize so every embedding dimension contributes on the same scale
    scaled_embeddings = StandardScaler().fit_transform(embeddings)

    # Apply HDBSCAN
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=min_samples,
        metric=metric,
        cluster_selection_epsilon=cluster_selection_epsilon
    )
    cluster_labels = clusterer.fit_predict(scaled_embeddings)

    # np.unique yields NumPy integers; cast to built-in ints because the
    # standard json encoder cannot serialize np.int64 values.
    unique_labels, counts = np.unique(cluster_labels, return_counts=True)
    cluster_sizes = {int(label): int(count) for label, count in zip(unique_labels, counts)}

    n_total = len(cluster_labels)
    n_noise = cluster_sizes.get(-1, 0)
    n_clusters = len(cluster_sizes) - (1 if -1 in cluster_sizes else 0)  # Exclude noise

    # NOTE: the score is computed over all points, so noise points (label -1)
    # are treated as one more group by silhouette_score.
    if n_clusters > 1:
        silhouette = float(silhouette_score(scaled_embeddings, cluster_labels))
    else:
        silhouette = -1.0  # sentinel: score is undefined for fewer than two clusters

    cluster_stats = {
        'n_clusters': n_clusters,
        'n_noise': n_noise,
        'n_total': n_total,
        'noise_ratio': n_noise / n_total if n_total else 0.0,
        'cluster_sizes': cluster_sizes,
        'silhouette_score': silhouette
    }

    print(f"✓ Clustering completed:")
    print(f"  - Number of clusters: {n_clusters}")
    print(f"  - Noise points: {n_noise} ({cluster_stats['noise_ratio']:.2%})")
    print(f"  - Silhouette score: {cluster_stats['silhouette_score']:.3f}")

    # Print cluster sizes, noise first (label -1 sorts before real clusters)
    print(f"  - Cluster sizes:")
    for label, size in sorted(cluster_sizes.items()):
        if label == -1:
            print(f"    Noise: {size}")
        else:
            print(f"    Cluster {label}: {size}")

    return cluster_labels, clusterer, cluster_stats

def optimize_hdbscan_parameters(embeddings, min_cluster_sizes=None, min_samples_list=None):
    """
    Grid-search HDBSCAN parameters and select the best silhouette score.

    Every (min_cluster_size, min_samples) combination is fitted on the
    standardized embeddings. Combinations that yield fewer than two clusters
    receive the sentinel score -1; combinations that raise are reported and
    skipped.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Embeddings to cluster
    min_cluster_sizes : list
        List of min_cluster_size values to try (defaults to [20, 30, 50, 75, 100])
    min_samples_list : list
        List of min_samples values to try (defaults to [5, 10, 15, 20])

    Returns:
    --------
    best_params : dict
        Result entry with the highest silhouette score. If no combination
        produced more than one cluster, this is the first evaluated entry
        (silhouette_score == -1).
    results : list
        List of all results

    Raises:
    -------
    RuntimeError
        If no parameter combination could be evaluated at all
    """
    if min_cluster_sizes is None:
        min_cluster_sizes = [20, 30, 50, 75, 100]

    if min_samples_list is None:
        min_samples_list = [5, 10, 15, 20]

    print("Optimizing HDBSCAN parameters...")

    scaled_embeddings = StandardScaler().fit_transform(embeddings)

    results = []
    # Start below the -1 sentinel so the first successful run is always kept;
    # starting at -1 left best_params as None when every run scored -1.
    best_score = float('-inf')
    best_params = None

    for min_cluster_size in min_cluster_sizes:
        for min_samples in min_samples_list:
            print(f"  Testing: min_cluster_size={min_cluster_size}, min_samples={min_samples}")

            # Best-effort search: a failing combination is reported and skipped
            try:
                clusterer = hdbscan.HDBSCAN(
                    min_cluster_size=min_cluster_size,
                    min_samples=min_samples,
                    metric='euclidean'
                )
                labels = clusterer.fit_predict(scaled_embeddings)

                n_clusters = len(np.unique(labels)) - (1 if -1 in labels else 0)
                n_noise = int(np.sum(labels == -1))

                # NOTE: noise points (label -1) are scored as their own group
                if n_clusters > 1:
                    silhouette = float(silhouette_score(scaled_embeddings, labels))
                else:
                    silhouette = -1
            except Exception as e:
                print(f"    Error: {e}")
                continue

            result = {
                'min_cluster_size': min_cluster_size,
                'min_samples': min_samples,
                'n_clusters': n_clusters,
                'n_noise': n_noise,
                'noise_ratio': n_noise / len(labels),
                'silhouette_score': silhouette
            }
            results.append(result)

            if silhouette > best_score:
                best_score = silhouette
                best_params = result

            print(f"    Clusters: {n_clusters}, Noise: {n_noise}, Silhouette: {silhouette:.3f}")

    if best_params is None:
        raise RuntimeError(
            "HDBSCAN parameter search produced no results: every parameter "
            "combination failed or the parameter lists were empty"
        )

    print(f"\n✓ Best parameters found:")
    print(f"  - min_cluster_size: {best_params['min_cluster_size']}")
    print(f"  - min_samples: {best_params['min_samples']}")
    print(f"  - Silhouette score: {best_params['silhouette_score']:.3f}")
    print(f"  - Number of clusters: {best_params['n_clusters']}")

    return best_params, results

# Cell-completion marker: confirms this cell's definitions were executed.
print("✓ HDBSCAN clustering functions defined")

## Section 9: Visualize Cluster Assignments and Embedding Space

Visualize the clustering results using PCA or t-SNE and plot cluster assignments over time.

In [None]:
def visualize_embedding_space(embeddings, labels, method='PCA', sample_size=5000):
    """
    Visualize embedding space in 2-D using PCA or t-SNE.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Embeddings to visualize
    labels : numpy.ndarray
        Cluster labels, one per embedding (-1 is drawn as "Noise")
    method : str
        Dimensionality reduction method ('PCA' or 'TSNE')
    sample_size : int
        Maximum number of points to plot; larger inputs are randomly
        subsampled without replacement (for performance)

    Returns:
    --------
    fig : plotly.graph_objects.Figure
        Plotly figure object

    Raises:
    -------
    ValueError
        If ``method`` is neither 'PCA' nor 'TSNE'
    """
    # Reject unknown methods up front; previously they fell through both
    # branches and failed later with a NameError on the reduced coordinates.
    if method not in ('PCA', 'TSNE'):
        raise ValueError(f"Unknown method {method!r}; expected 'PCA' or 'TSNE'")

    # Accept plain sequences as labels: index arrays and boolean masks below
    # need NumPy semantics.
    labels = np.asarray(labels)

    # Sample data if too large
    if len(embeddings) > sample_size:
        indices = np.random.choice(len(embeddings), sample_size, replace=False)
        embeddings_sample = embeddings[indices]
        labels_sample = labels[indices]
    else:
        embeddings_sample = embeddings
        labels_sample = labels

    # Apply dimensionality reduction
    if method == 'PCA':
        reducer = PCA(n_components=2)
        reduced_embeddings = reducer.fit_transform(embeddings_sample)
        explained_var = reducer.explained_variance_ratio_
        # Plotly renders "<br>" as a line break in titles; "\n" is not honored
        title = (
            "PCA Visualization of Embedding Space<br>"
            f"(Explained variance: {explained_var[0]:.2%} + {explained_var[1]:.2%} "
            f"= {explained_var.sum():.2%})"
        )
    else:
        reducer = TSNE(n_components=2, random_state=42, perplexity=30)
        reduced_embeddings = reducer.fit_transform(embeddings_sample)
        title = "t-SNE Visualization of Embedding Space"

    # Color palette, cycled when there are more labels than colors
    unique_labels = np.unique(labels_sample)
    colors = px.colors.qualitative.Set3 + px.colors.qualitative.Set1

    fig = go.Figure()

    # One trace per label so each cluster gets its own legend entry
    for i, label in enumerate(unique_labels):
        mask = labels_sample == label
        is_noise = label == -1

        fig.add_trace(go.Scatter(
            x=reduced_embeddings[mask, 0],
            y=reduced_embeddings[mask, 1],
            mode='markers',
            marker=dict(
                size=4,
                color=colors[i % len(colors)],
                opacity=0.6,
                symbol="x" if is_noise else "circle"
            ),
            name="Noise" if is_noise else f"Cluster {label}",
            # "Point" is the row index within the plotted (possibly sampled) data
            text=[f"Label: {label}<br>Point: {j}" for j in np.where(mask)[0]],
            hovertemplate='%{text}<extra></extra>'
        ))

    fig.update_layout(
        title=title,
        xaxis_title=f"{method} Component 1",
        yaxis_title=f"{method} Component 2",
        width=800,
        height=600,
        showlegend=True
    )

    return fig

def plot_cluster_timeline(metadata, labels, file_name=None):
    """
    Plot cluster assignments over time.

    Each sequence is drawn as a square marker at its ``timestamp_start`` on
    the x-axis and its cluster label on the y-axis.

    Parameters:
    -----------
    metadata : list
        List of metadata dictionaries; the ``file_name`` and
        ``timestamp_start`` keys are read here
    labels : numpy.ndarray
        Cluster labels, aligned with ``metadata`` (-1 is shown as "Noise")
    file_name : str or None
        Specific file to plot (if None, plot all files)

    Returns:
    --------
    fig : plotly.graph_objects.Figure
        Plotly figure object
    """
    # Convert to DataFrame for easier manipulation
    df = pd.DataFrame(metadata)
    df['cluster'] = labels

    # Filter by file if specified
    if file_name:
        df = df[df['file_name'] == file_name]
        title = f"Cluster Timeline for {file_name}"
    else:
        title = "Cluster Timeline (All Files)"

    fig = go.Figure()

    # Color palette, cycled when there are more clusters than colors
    unique_clusters = sorted(df['cluster'].unique())
    colors = px.colors.qualitative.Set3 + px.colors.qualitative.Set1

    # One trace per cluster so each gets its own legend entry
    for i, cluster in enumerate(unique_clusters):
        cluster_data = df[df['cluster'] == cluster]
        is_noise = cluster == -1

        hover_text = [
            f"File: {row['file_name']}<br>Time: {row['timestamp_start']:.1f}s<br>Cluster: {row['cluster']}"
            for _, row in cluster_data.iterrows()
        ]

        fig.add_trace(go.Scatter(
            x=cluster_data['timestamp_start'],
            y=cluster_data['cluster'],
            mode='markers',
            marker=dict(
                size=8,
                color=colors[i % len(colors)],
                opacity=0.5 if is_noise else 0.8,  # de-emphasize noise
                symbol='square'
            ),
            name="Noise" if is_noise else f"Cluster {cluster}",
            text=hover_text,
            hovertemplate='%{text}<extra></extra>'
        ))

    fig.update_layout(
        title=title,
        xaxis_title="Time (seconds)",
        yaxis_title="Cluster",
        width=1000,
        height=600,
        showlegend=True
    )

    # Show discrete cluster values on the y-axis. The Figure method is
    # ``update_yaxes``; ``update_yaxis`` does not exist and raised AttributeError.
    fig.update_yaxes(
        tickmode='array',
        tickvals=unique_clusters,
        ticktext=[f"Cluster {c}" if c != -1 else "Noise" for c in unique_clusters]
    )

    return fig

def plot_cluster_statistics(cluster_stats, labels):
    """
    Plot a 2x2 summary of the clustering: sizes, proportions, label histogram and text summary.

    Parameters:
    -----------
    cluster_stats : dict
        Statistics about the clustering; reads the keys 'cluster_sizes',
        'n_total', 'n_clusters', 'n_noise', 'noise_ratio' and
        'silhouette_score'
    labels : numpy.ndarray
        Cluster labels (-1 marks noise)

    Returns:
    --------
    None
        The figure is displayed with ``plt.show()``
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    size_by_label = cluster_stats['cluster_sizes']

    # Cluster size distribution (noise excluded)
    cluster_sizes = [size for label, size in size_by_label.items() if label != -1]
    cluster_labels_plot = [f"Cluster {label}" for label in size_by_label if label != -1]

    axes[0, 0].bar(cluster_labels_plot, cluster_sizes)
    axes[0, 0].set_title('Cluster Size Distribution')
    axes[0, 0].set_ylabel('Number of Points')
    axes[0, 0].tick_params(axis='x', rotation=45)

    # Cluster proportion pie chart (noise included as its own slice)
    sizes = list(size_by_label.values())
    labels_pie = [f"Cluster {label}" if label != -1 else "Noise" for label in size_by_label]

    axes[0, 1].pie(sizes, labels=labels_pie, autopct='%1.1f%%', startangle=90)
    axes[0, 1].set_title('Cluster Proportion')

    # Histogram of cluster assignments; bin edges sit at half-integers so
    # every integer label gets its own bar
    axes[1, 0].hist(labels, bins=np.arange(min(labels)-0.5, max(labels)+1.5, 1), alpha=0.7, edgecolor='black')
    axes[1, 0].set_title('Distribution of Cluster Assignments')
    axes[1, 0].set_xlabel('Cluster Label')
    axes[1, 0].set_ylabel('Frequency')

    # Summary statistics. With no real clusters every size figure falls back
    # to 0; mean/std were previously unguarded and rendered as "nan ± nan".
    largest = max(cluster_sizes) if cluster_sizes else 0
    smallest = min(cluster_sizes) if cluster_sizes else 0
    mean_size = np.mean(cluster_sizes) if cluster_sizes else 0.0
    std_size = np.std(cluster_sizes) if cluster_sizes else 0.0

    stats_text = f"""
    Total Points: {cluster_stats['n_total']}
    Number of Clusters: {cluster_stats['n_clusters']}
    Noise Points: {cluster_stats['n_noise']} ({cluster_stats['noise_ratio']:.2%})
    Silhouette Score: {cluster_stats['silhouette_score']:.3f}
    
    Largest Cluster: {largest}
    Smallest Cluster: {smallest}
    Average Cluster Size: {mean_size:.1f} ± {std_size:.1f}
    """

    axes[1, 1].text(0.1, 0.9, stats_text, transform=axes[1, 1].transAxes, 
                    verticalalignment='top', fontsize=12, family='monospace')
    axes[1, 1].set_title('Clustering Summary')
    axes[1, 1].axis('off')

    plt.tight_layout()
    plt.show()

# Cell-completion marker: confirms this cell's definitions were executed.
print("✓ Visualization functions defined")

## Section 10: Save Cluster Labels and Embeddings

Save the predicted cluster labels and embeddings to disk for further analysis.

In [None]:
def _results_to_builtin(value):
    """Recursively convert NumPy scalars/arrays inside ``value`` to built-in types.

    Dictionary keys are converted to strings, which is what JSON requires.
    Needed because the standard ``json`` encoder rejects values such as
    ``np.int64`` (e.g. counts produced by ``np.sum`` or ``np.unique``).
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        return {str(key): _results_to_builtin(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_results_to_builtin(item) for item in value]
    return value


def save_clustering_results(embeddings, labels, metadata, cluster_stats, save_dir='results_3s'):
    """
    Save all clustering results to disk.

    Writes six files sharing one timestamped base name: embeddings (.npy),
    labels (.npy), metadata (.json), statistics (.json), a combined table
    (.csv) and a plain-text report (.txt).

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Embeddings array (2-D, one row per sequence)
    labels : numpy.ndarray
        Cluster labels (-1 marks noise)
    metadata : list
        Metadata dictionaries, one per embedding; may be empty. The report
        reads the ``file_name`` key of each entry.
    cluster_stats : dict
        Clustering statistics; the report reads 'n_clusters', 'n_noise',
        'noise_ratio', 'silhouette_score' and 'cluster_sizes'
    save_dir : str
        Directory to save results (created if missing)

    Returns:
    --------
    file_paths : dict
        Dictionary of saved file paths with the keys 'embeddings', 'labels',
        'metadata', 'stats', 'csv' and 'report'
    """
    # Local import keeps this function self-contained
    from collections import Counter

    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    base_name = f"bilstm_transformer_hdbscan_3s_{timestamp}"

    file_paths = {}

    # Save embeddings
    emb_file = os.path.join(save_dir, f"{base_name}_embeddings.npy")
    np.save(emb_file, embeddings)
    file_paths['embeddings'] = emb_file

    # Save cluster labels
    labels_file = os.path.join(save_dir, f"{base_name}_labels.npy")
    np.save(labels_file, labels)
    file_paths['labels'] = labels_file

    # Save metadata (NumPy scalars converted so json can encode them)
    meta_file = os.path.join(save_dir, f"{base_name}_metadata.json")
    with open(meta_file, 'w') as f:
        json.dump(_results_to_builtin(metadata), f, indent=2)
    file_paths['metadata'] = meta_file

    # Save cluster statistics. Converting only ndarrays is not enough:
    # np.int64 counts also make json.dump raise TypeError.
    stats_file = os.path.join(save_dir, f"{base_name}_stats.json")
    with open(stats_file, 'w') as f:
        json.dump(_results_to_builtin(cluster_stats), f, indent=2)
    file_paths['stats'] = stats_file

    # Save combined CSV for easy analysis
    csv_file = os.path.join(save_dir, f"{base_name}_combined.csv")
    df = pd.DataFrame(embeddings, columns=[f'emb_{i}' for i in range(embeddings.shape[1])])
    df['cluster_label'] = labels

    # Metadata columns are taken from the first entry; skip when there is none
    if metadata:
        for key in metadata[0]:
            df[key] = [m[key] for m in metadata]

    # Derived columns
    df['is_noise'] = (df['cluster_label'] == -1)
    if 'timestamp_start' in df.columns and 'timestamp_end' in df.columns:
        df['sequence_duration'] = df['timestamp_end'] - df['timestamp_start']

    df.to_csv(csv_file, index=False)
    file_paths['csv'] = csv_file

    # Save summary report. Lines are joined with real newlines; the previous
    # "\\n" literals wrote a backslash and an "n" instead of line breaks.
    report_lines = [
        "BiLSTM + Transformer + HDBSCAN Clustering Results",
        "=" * 50,
        "",
        f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"Total sequences: {len(embeddings)}",
        f"Embedding dimension: {embeddings.shape[1]}",
        f"Number of clusters: {cluster_stats['n_clusters']}",
        f"Noise points: {cluster_stats['n_noise']} ({cluster_stats['noise_ratio']:.2%})",
        f"Silhouette score: {cluster_stats['silhouette_score']:.3f}",
        "",
        "Cluster Sizes:",
    ]
    for label, size in sorted(cluster_stats['cluster_sizes'].items()):
        if label == -1:
            report_lines.append(f"  Noise: {size} points")
        else:
            report_lines.append(f"  Cluster {label}: {size} points")

    report_lines.append("")
    report_lines.append("Files processed:")
    # Single pass over metadata instead of one scan per file name
    sequences_per_file = Counter(m['file_name'] for m in metadata)
    for file_name in sorted(sequences_per_file):
        report_lines.append(f"  {file_name}: {sequences_per_file[file_name]} sequences")

    report_file = os.path.join(save_dir, f"{base_name}_report.txt")
    with open(report_file, 'w') as f:
        f.write("\n".join(report_lines) + "\n")
    file_paths['report'] = report_file

    print(f"✓ Results saved to {save_dir}:")
    for key, path in file_paths.items():
        print(f"  - {key}: {os.path.basename(path)}")

    return file_paths

# Cell-completion marker: confirms this cell's definitions were executed.
print("✓ Save functions defined")

## Execution Workflow

Now let's execute the complete pipeline step by step:

In [None]:
# Step 1: Process EDF files and extract DMD features
print("Step 1: Processing EDF files and extracting DMD features...")
feature_dict = process_all_files_dmd(edf_files, save_dir='features_3s')

# Fail with a clear message instead of a bare StopIteration from next() below
if not feature_dict:
    raise RuntimeError("No features were extracted; check the contents of edf_files")

# Get input dimensions for model: the feature width of every channel, read
# from the first file (assumes all files share that channel layout — TODO confirm)
first_file = next(iter(feature_dict))
input_dims = {
    channel: data['features'].shape[1] 
    for channel, data in feature_dict[first_file].items()
}

print(f"\nInput dimensions: {input_dims}")
print(f"Features extracted for {len(feature_dict)} files")

In [None]:
# Step 2: Create sequences for model input
print("Step 2: Creating sequences for model input...")
sequences_dict = create_sequences_from_features(
    feature_dict, 
    seq_length=20,  # 20 windows of 3 seconds each = 60 seconds per sequence
    overlap=0.5     # 50% overlap between sequences
)

print(f"\nSequences created for {len(sequences_dict)} files")

In [None]:
# Step 3: Fit the BiLSTM + Transformer hybrid model on the prepared sequences
print("Step 3: Training the BiLSTM + Transformer hybrid model...")
trained_model, training_history = train_hybrid_model(
    input_dims=input_dims,
    sequences_dict=sequences_dict,
    focus_channel='EEG Fpz-Cz',  # EEG channel the reconstruction focuses on
    learning_rate=1e-4,
    batch_size=16,
    epochs=30,
)

# Display the history returned by training
plot_training_history(training_history)

In [None]:
# Step 4: Extract embeddings from trained model
print("Step 4: Extracting embeddings from trained model...")
embeddings, metadata = extract_embeddings_from_model(
    model=trained_model,
    sequences_dict=sequences_dict,
    batch_size=64
)

# One metadata entry is expected per embedding row; print both to compare
print(f"\nEmbeddings shape: {embeddings.shape}")
print(f"Metadata entries: {len(metadata)}")

In [None]:
# Step 5: Optimize HDBSCAN parameters and perform clustering
print("Step 5: Optimizing HDBSCAN parameters and performing clustering...")

# Grid-search the parameters. The clustering call below reads best_params, so
# if this search is skipped for faster execution, pass explicit
# min_cluster_size / min_samples values there instead.
best_params, optimization_results = optimize_hdbscan_parameters(
    embeddings,
    min_cluster_sizes=[30, 50, 75, 100],
    min_samples_list=[5, 10, 15]
)

# Perform clustering with optimized parameters
cluster_labels, clusterer, cluster_stats = cluster_embeddings_hdbscan(
    embeddings,
    min_cluster_size=best_params['min_cluster_size'],
    min_samples=best_params['min_samples']
)

print(f"\nClustering completed with {cluster_stats['n_clusters']} clusters")

In [None]:
# Step 6: Visualize results
# The figures below are shown in this order; each call only reads the
# embeddings, labels, metadata and statistics produced by the earlier steps.
print("Step 6: Visualizing clustering results...")

# Plot cluster statistics (matplotlib 2x2 summary)
plot_cluster_statistics(cluster_stats, cluster_labels)

# Visualize embedding space with PCA (2-D projection, one trace per cluster)
fig_pca = visualize_embedding_space(embeddings, cluster_labels, method='PCA')
fig_pca.show()

# Visualize embedding space with t-SNE (optional - takes longer)
# fig_tsne = visualize_embedding_space(embeddings, cluster_labels, method='TSNE')
# fig_tsne.show()

# Plot cluster timeline across all files
fig_timeline = plot_cluster_timeline(metadata, cluster_labels)
fig_timeline.show()

# Plot timeline for specific file: the file of the first metadata entry.
# NOTE: this rebinds the global name first_file.
first_file = metadata[0]['file_name']
fig_timeline_single = plot_cluster_timeline(metadata, cluster_labels, file_name=first_file)
fig_timeline_single.show()

In [None]:
# Step 7: Save all results
print("Step 7: Saving all results...")

# Save clustering results (embeddings, labels, metadata, stats, CSV, report)
file_paths = save_clustering_results(
    embeddings=embeddings,
    labels=cluster_labels,
    metadata=metadata,
    cluster_stats=cluster_stats,
    save_dir='results_3s'
)

# Save embeddings separately
save_embeddings(embeddings, metadata, save_path='embeddings_3s')

# Final summary; "\n" must be a real newline escape so blank lines are printed
print("\n" + "="*60)
print("🎉 PIPELINE EXECUTION COMPLETED SUCCESSFULLY! 🎉")
print("="*60)
print(f"\nSummary:")
print(f"- Processed {len(edf_files)} EDF files")
print(f"- Extracted {len(embeddings)} embeddings from 3-second windows")
print(f"- Discovered {cluster_stats['n_clusters']} distinct clusters")
print(f"- Silhouette score: {cluster_stats['silhouette_score']:.3f}")
print(f"- Results saved to: results_3s/")
print("\nNext steps:")
print("- Analyze cluster characteristics")
print("- Correlate with sleep stage annotations (if available)")
print("- Perform statistical analysis of cluster transitions")
print("- Generate detailed reports")