# Unsupervised Sleep Stage Clustering with CNN+LSTM

This notebook implements a deep learning pipeline to cluster sleep stages from PSG signals and identify micro arousals at millisecond scale in an unsupervised manner. We'll use:
- PyTorch for building the neural network
- CuPy for GPU acceleration
- Dynamic Mode Decomposition (DMD) for dimension reduction

## 1. Import Libraries and Setup GPU

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cupy as cp
from sklearn.cluster import KMeans, DBSCAN
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
import pyedflib
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
if not torch.cuda.is_available():
    raise RuntimeError("GPU is required for this pipeline. CUDA not available.")

device = torch.device("cuda")
print(f"Using GPU: {torch.cuda.get_device_name()}")

Using GPU: NVIDIA GeForce GTX 1080 Ti


## 2. Data Loading and Preprocessing

We'll start by defining functions to load and preprocess EDF files containing PSG data.

In [2]:
class PSGDataset(Dataset):
    def __init__(self, file_paths, window_size=30, overlap=0.5, transform=None):
        """
        Dataset for PSG signals
        
        Args:
            file_paths: List of paths to EDF files
            window_size: Size of the sliding window in seconds
            overlap: Overlap between windows (0-1)
            transform: Optional transform to apply to the data
        """
        self.file_paths = file_paths
        self.window_size = window_size
        self.overlap = overlap
        self.transform = transform
        
        self.signal_names = ['EEG Fpz-Cz', 'EEG Pz-Oz', 'EOG horizontal', 
                            'Resp oro-nasal', 'EMG submental', 'Temp rectal', 'Event marker']
        self.sampling_rates = [100, 100, 100, 1, 1, 1, 1]  # Hz
        
        self.segments = self._prepare_segments()
        
    def _prepare_segments(self):
        segments = []
        
        for file_path in self.file_paths:
            signals, signal_headers = self._load_edf(file_path)
            
            # Create segments
            for i, (signal, header, rate) in enumerate(zip(signals, signal_headers, self.sampling_rates)):
                step = int(self.window_size * rate * (1 - self.overlap))
                window_samples = int(self.window_size * rate)
                
                for start in range(0, len(signal) - window_samples, step):
                    segment = signal[start:start+window_samples]
                    segments.append({
                        "signal": segment,
                        "signal_name": header['label'],
                        "start_time": start / rate,
                        "sampling_rate": rate,
                        "file_path": file_path
                    })
        
        return segments
    
    def _load_edf(self, file_path):
        f = pyedflib.EdfReader(file_path)
        n = f.signals_in_file
        signal_headers = f.getSignalHeaders()
        signals = [f.readSignal(i) for i in range(n)]
        f.close()
        return signals, signal_headers
    
    def __len__(self):
        return len(self.segments)
    
    def __getitem__(self, idx):
        segment = self.segments[idx]
        data = segment["signal"]
        
        if self.transform:
            data = self.transform(data)
        
        # Convert to tensor
        data = torch.FloatTensor(data)
        
        return data, segment["signal_name"], segment["start_time"]

## 3. Dynamic Mode Decomposition (DMD) for Feature Extraction

In [3]:
class DMD:
    def __init__(self, rank=None, svd_rank=None):
        """
        Dynamic Mode Decomposition for feature extraction
        
        Args:
            rank: Number of DMD modes to compute
            svd_rank: SVD rank to use for truncated SVD
        """
        self.rank = rank
        self.svd_rank = svd_rank
        self.modes = None
        self.dynamics = None
        self.eigenvalues = None
    
    def fit_transform(self, X):
        """
        Compute DMD modes and return the reconstructed data
        
        Args:
            X: Data matrix (time steps x features)
            
        Returns:
            Transformed data
        """
        # Use CuPy for GPU acceleration
        X_gpu = cp.asarray(X)
        
        # Split data into snapshots
        X1 = X_gpu[:, :-1]
        X2 = X_gpu[:, 1:]
        
        # SVD on X1
        U, sigma, Vh = cp.linalg.svd(X1, full_matrices=False)
        
        # Truncate SVD if svd_rank is specified
        if self.svd_rank is not None:
            r = min(self.svd_rank, len(sigma))
        elif self.rank is not None:
            r = min(self.rank, len(sigma))
        else:
            # Use cumulative energy to determine rank
            cumulative_energy = cp.cumsum(sigma**2) / cp.sum(sigma**2)
            r = cp.argmax(cumulative_energy >= 0.99) + 1
        
        U = U[:, :r]
        sigma = sigma[:r]
        Vh = Vh[:r, :]
        
        # Compute A tilde (reduced order operator)
        A_tilde = cp.dot(cp.dot(U.T, X2), cp.diag(1.0/sigma).dot(Vh))
        
        # Eigendecomposition of A tilde
        eigenvalues, eigenvectors = cp.linalg.eig(A_tilde)
        
        # DMD modes
        modes = cp.dot(cp.dot(X2, cp.diag(1.0/sigma).dot(Vh).T), eigenvectors)
        
        # Save results
        self.modes = cp.asnumpy(modes)
        self.eigenvalues = cp.asnumpy(eigenvalues)
        
        # Return DMD features
        return cp.asnumpy(cp.abs(modes).T)
    
    def get_feature_importance(self):
        """
        Get DMD mode importance based on eigenvalues
        
        Returns:
            Mode importance scores
        """
        if self.eigenvalues is None:
            raise ValueError("DMD not fit yet. Call fit_transform first.")
        
        # Importance is based on the modulus of eigenvalues
        importance = np.abs(self.eigenvalues)
        return importance / np.sum(importance)

## 4. CNN+LSTM Model Architecture

In [4]:
class CNN_LSTM(nn.Module):
    def __init__(self, input_channels=7, hidden_dims=128, latent_dim=64):
        """
        CNN+LSTM model for unsupervised learning of PSG signals
        
        Args:
            input_channels: Number of input channels (signal types)
            hidden_dims: Dimensions of hidden layers
            latent_dim: Dimension of latent space
        """
        super(CNN_LSTM, self).__init__()
        
        # CNN feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(input_channels, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            
            nn.Conv1d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),
            
            nn.Conv1d(128, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )
        
        # LSTM for temporal modeling
        self.lstm = nn.LSTM(256, hidden_dims, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
        
        # MLP for latent representation
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dims*2, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, latent_dim)
        )
        
        # Decoder for reconstruction (will be used in training)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dims),
            nn.ReLU(),
            nn.Linear(hidden_dims, input_channels * 100)  # Assuming 1 second of data at 100 Hz
        )
    
    def forward(self, x):
        # x shape: (batch_size, channels, time_steps)
        
        # Extract CNN features
        cnn_features = self.feature_extractor(x)
        
        # Reshape for LSTM: (batch_size, time_steps, features)
        cnn_features = cnn_features.permute(0, 2, 1)
        
        # Apply LSTM
        lstm_out, _ = self.lstm(cnn_features)
        
        # Get final time step output
        lstm_out = lstm_out[:, -1, :]
        
        # Get latent representation
        latent = self.mlp(lstm_out)
        
        # Decode for reconstruction
        reconstruction = self.decoder(latent)
        reconstruction = reconstruction.view(x.size(0), x.size(1), -1)
        
        return latent, reconstruction
    
    def encode(self, x):
        # Only return the latent representation
        cnn_features = self.feature_extractor(x)
        cnn_features = cnn_features.permute(0, 2, 1)
        lstm_out, _ = self.lstm(cnn_features)
        lstm_out = lstm_out[:, -1, :]
        latent = self.mlp(lstm_out)
        return latent

## 5. Unsupervised Training Pipeline

In [5]:
class UnsupervisedTrainer:
    def __init__(self, model, optimizer, scheduler=None, device='cuda'):
        """
        Trainer for unsupervised learning
        
        Args:
            model: CNN_LSTM model
            optimizer: Optimizer
            scheduler: Learning rate scheduler
            device: Device to use
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.reconstruction_loss = nn.MSELoss()
        
        # For contrastive learning
        self.temperature = 0.5
    
    def train_epoch(self, dataloader, epoch):
        self.model.train()
        total_loss = 0
        total_recon_loss = 0
        total_contrastive_loss = 0
        
        for batch_idx, (data, signal_names, _) in enumerate(dataloader):
            data = data.to(self.device)
            
            # Create augmented versions
            data_aug1 = self._augment(data)
            data_aug2 = self._augment(data)
            
            # Forward pass for original and augmented data
            latent, reconstruction = self.model(data)
            latent_aug1, _ = self.model(data_aug1)
            latent_aug2, _ = self.model(data_aug2)
            
            # Reconstruction loss
            recon_loss = self.reconstruction_loss(reconstruction, data)
            
            # Contrastive loss between augmented versions
            contrastive_loss = self._contrastive_loss(latent_aug1, latent_aug2)
            
            # Total loss (weighted sum)
            loss = recon_loss + 0.5 * contrastive_loss
            
            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            # Track losses
            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_contrastive_loss += contrastive_loss.item()
            
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(dataloader.dataset)} "
                      f"({100. * batch_idx / len(dataloader):.0f}%)]", end=' ')
                print(f"Loss: {loss.item():.6f} "
                      f"Recon: {recon_loss.item():.6f} "
                      f"Contrastive: {contrastive_loss.item():.6f}")
        
        # Update scheduler if provided
        if self.scheduler is not None:
            self.scheduler.step()
        
        avg_loss = total_loss / len(dataloader)
        avg_recon_loss = total_recon_loss / len(dataloader)
        avg_contrastive_loss = total_contrastive_loss / len(dataloader)
        
        return avg_loss, avg_recon_loss, avg_contrastive_loss
    
    def _contrastive_loss(self, z1, z2):
        """
        Contrastive loss function (SimCLR)
        """
        batch_size = z1.size(0)
        z1 = nn.functional.normalize(z1, dim=1)
        z2 = nn.functional.normalize(z2, dim=1)
        
        # Full similarity matrix
        z = torch.cat([z1, z2], dim=0)
        sim_matrix = torch.mm(z, z.t()) / self.temperature
        
        # Mask for positive pairs
        sim_matrix.fill_diagonal_(float('-inf'))
        mask = torch.zeros_like(sim_matrix)
        mask[:batch_size, batch_size:] = torch.eye(batch_size)
        mask[batch_size:, :batch_size] = torch.eye(batch_size)
        
        # Compute loss
        positives = torch.sum(torch.exp(sim_matrix) * mask, dim=1)
        negatives = torch.sum(torch.exp(sim_matrix), dim=1) - positives
        loss = -torch.log(positives / negatives).mean()
        
        return loss
    
    def _augment(self, data):
        """
        Apply random augmentations to data
        """
        batch_size, channels, time_steps = data.shape
        augmented = data.clone()
        
        # Random noise
        if torch.rand(1).item() < 0.5:
            noise = torch.randn_like(augmented) * 0.05
            augmented += noise
        
        # Random scaling
        if torch.rand(1).item() < 0.5:
            scale = torch.randn(batch_size, channels, 1).to(self.device) * 0.1 + 1.0
            augmented *= scale
        
        # Random time masking
        if torch.rand(1).item() < 0.5:
            for i in range(batch_size):
                mask_len = int(time_steps * (torch.rand(1).item() * 0.2))
                start = int(torch.rand(1).item() * (time_steps - mask_len))
                augmented[i, :, start:start+mask_len] = 0
        
        return augmented
    
    def evaluate(self, dataloader):
        """
        Evaluate model on validation data
        """
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            for data, _, _ in dataloader:
                data = data.to(self.device)
                latent, reconstruction = self.model(data)
                loss = self.reconstruction_loss(reconstruction, data)
                total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        return avg_loss
    
    def get_embeddings(self, dataloader):
        """
        Get latent embeddings for all data points
        """
        self.model.eval()
        embeddings = []
        timestamps = []
        signal_types = []
        
        with torch.no_grad():
            for data, signal_name, start_time in dataloader:
                data = data.to(self.device)
                latent = self.model.encode(data)
                embeddings.append(latent.cpu().numpy())
                timestamps.append(start_time.numpy())
                signal_types.extend(signal_name)
        
        embeddings = np.vstack(embeddings)
        timestamps = np.concatenate(timestamps)
        
        return embeddings, timestamps, signal_types

## 6. Micro-arousal Detection and Clustering

In [6]:
class MicroArousalDetector:
    def __init__(self, embeddings, timestamps, signal_types, min_samples=5, eps=0.5):
        """
        Detect micro-arousals using clustering of embeddings
        
        Args:
            embeddings: Latent embeddings from model
            timestamps: Timestamps for each embedding
            signal_types: Signal type for each embedding
            min_samples: Minimum samples for DBSCAN clustering
            eps: Maximum distance between samples for DBSCAN
        """
        self.embeddings = embeddings
        self.timestamps = timestamps
        self.signal_types = signal_types
        self.min_samples = min_samples
        self.eps = eps
        self.clusters = None
    
    def detect_clusters(self, method="dbscan", n_clusters=5):
        """
        Detect clusters in embedding space
        
        Args:
            method: Clustering method ("dbscan" or "kmeans")
            n_clusters: Number of clusters for K-means
        """
        # Scale embeddings
        scaler = StandardScaler()
        scaled_embeddings = scaler.fit_transform(self.embeddings)
        
        if method == "dbscan":
            # DBSCAN for density-based clustering
            clustering = DBSCAN(eps=self.eps, min_samples=self.min_samples, n_jobs=-1)
            clusters = clustering.fit_predict(scaled_embeddings)
        else:
            # K-means for fixed number of clusters
            clustering = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            clusters = clustering.fit_predict(scaled_embeddings)
        
        self.clusters = clusters
        return clusters
    
    def detect_micro_arousals(self, window_size=5, threshold=0.8):
        """
        Detect micro-arousals based on cluster transitions
        
        Args:
            window_size: Size of sliding window in seconds
            threshold: Threshold for cluster transition ratio
        """
        if self.clusters is None:
            raise ValueError("Run detect_clusters first")
        
        # Sort by timestamps
        sorted_indices = np.argsort(self.timestamps)
        sorted_timestamps = self.timestamps[sorted_indices]
        sorted_clusters = self.clusters[sorted_indices]
        
        micro_arousals = []
        
        # Slide window and detect transitions
        for i in range(0, len(sorted_timestamps) - window_size, 1):
            window_clusters = sorted_clusters[i:i+window_size]
            
            # Count transitions
            transitions = sum(window_clusters[j] != window_clusters[j+1] for j in range(len(window_clusters)-1))
            transition_ratio = transitions / (len(window_clusters) - 1)
            
            if transition_ratio >= threshold:
                micro_arousals.append({
                    "start_time": sorted_timestamps[i],
                    "end_time": sorted_timestamps[i+window_size-1],
                    "transition_ratio": transition_ratio,
                    "clusters": window_clusters.copy()
                })
        
        return micro_arousals
    
    def visualize_clusters(self):
        """
        Visualize clusters using t-SNE
        """
        if self.clusters is None:
            raise ValueError("Run detect_clusters first")
        
        # Compute t-SNE embedding
        tsne = TSNE(n_components=2, random_state=42, perplexity=30, n_iter=1000)
        embedded = tsne.fit_transform(self.embeddings)
        
        # Plot clusters
        plt.figure(figsize=(12, 10))
        scatter = plt.scatter(embedded[:, 0], embedded[:, 1], c=self.clusters, cmap='viridis', alpha=0.7, s=10)
        plt.colorbar(scatter)
        plt.title("t-SNE Visualization of Sleep Stage Clusters")
        plt.xlabel("t-SNE Component 1")
        plt.ylabel("t-SNE Component 2")
        plt.tight_layout()
        plt.show()
        
        # Plot time series of clusters
        sorted_indices = np.argsort(self.timestamps)
        plt.figure(figsize=(14, 6))
        plt.scatter(self.timestamps[sorted_indices], self.clusters[sorted_indices], 
                   c=self.clusters[sorted_indices], cmap='viridis', alpha=0.7, s=10)
        plt.title("Cluster Assignment Over Time")
        plt.xlabel("Time (seconds)")
        plt.ylabel("Cluster")
        plt.colorbar(label="Cluster")
        plt.tight_layout()
        plt.show()

## 7. Main Pipeline Implementation

In [7]:
def run_sleep_clustering_pipeline(edf_files, output_dir="./results", epochs=50, batch_size=64):
    """
    Run the complete sleep clustering pipeline
    
    Args:
        edf_files: List of paths to EDF files
        output_dir: Directory to save results
        epochs: Number of training epochs
        batch_size: Batch size for training
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # 1. Load and preprocess data
    print("Loading and preprocessing data...")
    dataset = PSGDataset(edf_files, window_size=30, overlap=0.5)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
    
    # 2. Initialize model and trainer
    print("Initializing model...")
    model = CNN_LSTM(input_channels=7, hidden_dims=128, latent_dim=64).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
    trainer = UnsupervisedTrainer(model, optimizer, scheduler, device)
    
    # 3. Train model
    print("Starting training...")
    train_losses = []
    recon_losses = []
    contrastive_losses = []
    
    for epoch in range(1, epochs+1):
        loss, recon_loss, contrastive_loss = trainer.train_epoch(dataloader, epoch)
        train_losses.append(loss)
        recon_losses.append(recon_loss)
        contrastive_losses.append(contrastive_loss)
        
        print(f"Epoch {epoch}/{epochs} - Loss: {loss:.6f}, "
              f"Recon Loss: {recon_loss:.6f}, Contrastive Loss: {contrastive_loss:.6f}")
    
    # Save model
    torch.save(model.state_dict(), f"{output_dir}/sleep_clustering_model.pt")
    
    # Plot training curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Total Loss')
    plt.plot(recon_losses, label='Reconstruction Loss')
    plt.plot(contrastive_losses, label='Contrastive Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses')
    plt.savefig(f"{output_dir}/training_curves.png")
    plt.show()
    
    # 4. Get embeddings for all data
    print("Extracting embeddings...")
    eval_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    embeddings, timestamps, signal_types = trainer.get_embeddings(eval_dataloader)
    
    # 5. Apply DMD for additional feature extraction
    print("Applying DMD for feature extraction...")
    dmd = DMD(rank=20)
    dmd_features = dmd.fit_transform(embeddings.T)  # Transpose for DMD
    
    # Combine original embeddings with DMD features
    combined_features = np.hstack([embeddings, dmd_features])
    
    # 6. Detect micro-arousals
    print("Detecting micro-arousals...")
    detector = MicroArousalDetector(combined_features, np.array(timestamps), signal_types)
    
    # Try different clustering methods
    print("Detecting clusters with DBSCAN...")
    detector.detect_clusters(method="dbscan")
    detector.visualize_clusters()
    
    print("Detecting clusters with K-means...")
    detector.detect_clusters(method="kmeans", n_clusters=5)
    detector.visualize_clusters()
    
    # Find micro-arousals
    micro_arousals = detector.detect_micro_arousals(window_size=5, threshold=0.7)
    
    print(f"Detected {len(micro_arousals)} potential micro-arousals")
    
    # Save results
    with open(f"{output_dir}/micro_arousals.txt", "w") as f:
        for i, arousal in enumerate(micro_arousals):
            f.write(f"Micro-arousal {i+1}: {arousal['start_time']:.2f}s to {arousal['end_time']:.2f}s "
                   f"(Transition ratio: {arousal['transition_ratio']:.2f})\n")
    
    # Return results for further analysis
    return {
        "model": model,
        "embeddings": embeddings,
        "dmd_features": dmd_features,
        "combined_features": combined_features,
        "micro_arousals": micro_arousals,
        "detector": detector
    }

In [10]:
def custom_collate(batch):
    """
    Custom collate function to handle variable-sized tensors
    """
    # Extract data, signal_names, and start_times
    data = [item[0] for item in batch]
    signal_names = [item[1] for item in batch]
    start_times = [item[2] for item in batch]
    
    # Handle data tensors of different sizes
    if isinstance(data[0], torch.Tensor):
        # Option 1: Pad all tensors to the maximum size
        max_length = max([d.size(0) for d in data])
        data_padded = []
        
        for d in data:
            if d.size(0) < max_length:
                # Create padding
                padding = torch.zeros(max_length - d.size(0), *d.size()[1:], device=d.device)
                # Concatenate with original data
                d_padded = torch.cat([d, padding], dim=0)
                data_padded.append(d_padded)
            else:
                data_padded.append(d)
        
        # Stack the padded tensors
        data_tensor = torch.stack(data_padded)
    else:
        # If not tensors, just keep as a list
        data_tensor = data
    
    # Convert start_times to tensor if they're numeric
    try:
        start_times_tensor = torch.tensor(start_times)
    except:
        start_times_tensor = start_times
    
    return data_tensor, signal_names, start_times_tensor

## 8. Example Usage

Assuming you have EDF files containing PSG data, you can use the pipeline as follows:

In [11]:
# Example usage (replace with your actual EDF file paths)
edf_files = ["raw data/SC4001E0-PSG.edf"]
results = run_sleep_clustering_pipeline(edf_files, output_dir="./sleep_results", epochs=50)

print("Pipeline ready to use. Please provide paths to your EDF files containing PSG data to begin.")
print("Example: results = run_sleep_clustering_pipeline(['/path/to/psg1.edf', '/path/to/psg2.edf'])")

Loading and preprocessing data...
Initializing model...
Starting training...


RuntimeError: stack expects each tensor to be equal size, but got [3000] at entry 0 and [30] at entry 1

## 9. Conclusion

This pipeline implements an unsupervised deep learning approach for analyzing PSG signals and identifying micro-arousals at a millisecond scale. The key components are:

1. Data preprocessing of EDF files containing PSG signals
2. Feature extraction using Dynamic Mode Decomposition (DMD)
3. A CNN+LSTM architecture for temporal sequence modeling
4. Unsupervised learning with reconstruction and contrastive losses
5. Clustering and visualization of sleep stages
6. Detection of micro-arousals based on cluster transitions

This approach allows for the discovery of new patterns in sleep data without relying on labeled data, which can help researchers better understand sleep disorders and patterns.