In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.signal import butter, filtfilt
import math
import os
from typing import Tuple, List, Dict, Optional, Union
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class SpatialSelfAttention(nn.Module):
    """
    Spatial self-attention module for processing temporal EEG data.
    Applies attention across the time dimension for each channel.
    """
    def __init__(self, in_channels: int, dk: int = 32):
        super().__init__()
        self.dk = dk
        
        # Projections for Query, Key, Value
        self.query_conv = nn.Conv1d(in_channels, dk, 1)
        self.key_conv = nn.Conv1d(in_channels, dk, 1)
        self.value_conv = nn.Conv1d(in_channels, in_channels, 1)
        
        self.out_proj = nn.Conv1d(in_channels, in_channels, 1)
        self.layer_norm = nn.LayerNorm([in_channels])
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, C, L = x.size()
        
        # Create Q, K, V projections
        Q = self.query_conv(x)
        K = self.key_conv(x)
        V = self.value_conv(x)
        
        # Reshape for attention computation
        Q = Q.permute(0, 2, 1)
        K = K.permute(0, 2, 1)
        V = V.permute(0, 2, 1)
        
        # Compute attention scores
        attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)
        attn = F.softmax(attn, dim=-1)
        
        # Apply attention to values
        out = torch.matmul(attn, V).permute(0, 2, 1)
        out = self.out_proj(out)
        
        # Residual connection and normalization
        out = x + out
        out = out.permute(0, 2, 1)
        out = self.layer_norm(out)
        out = out.permute(0, 2, 1)
        
        return out, attn

class CrossChannelAttention(nn.Module):
    """
    Cross-channel attention module for modeling relationships between EEG channels.
    """
    def __init__(self, in_channels: int, dk: int = 32):
        super().__init__()
        self.dk = dk
        
        self.q_proj = nn.Linear(in_channels, dk)
        self.k_proj = nn.Linear(in_channels, dk)
        self.v_proj = nn.Linear(in_channels, in_channels)
        
        self.out_proj = nn.Linear(in_channels, in_channels)
        self.layer_norm = nn.LayerNorm(in_channels)
        
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x_reshaped = x.permute(0, 2, 1)
        
        Q = self.q_proj(x_reshaped)
        K = self.k_proj(x_reshaped)
        V = self.v_proj(x_reshaped)
        
        attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)
        attn = F.softmax(attn, dim=-1)
        
        out = torch.matmul(attn, V)
        out = self.out_proj(out)
        
        out = x_reshaped + out
        out = self.layer_norm(out)
        out = out.permute(0, 2, 1)
        
        return out, attn

class LocalSpikeDetectionBranch(nn.Module):
    """
    Branch for detecting local spike patterns in EEG signals.
    """
    def __init__(self, kernel_size: int, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,  # Changed to same padding
            bias=True
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.pool = nn.MaxPool1d(kernel_size=4, stride=4)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.bn(self.conv(x)))
        x = self.pool(x)
        return x

class EpilepsyDetectionModel(nn.Module):
    """
    Complete model for epilepsy detection from EEG signals.
    Combines multiple detection branches with attention mechanisms.
    """
    def __init__(self, num_channels: int = 32, sampling_rate: int = 300, num_classes: int = 2):
        super().__init__()
        
        self.kernel_sizes = [
            int(0.07 * sampling_rate),  # 70ms
            int(0.15 * sampling_rate),  # 150ms
            int(0.20 * sampling_rate)   # 200ms
        ]
        
        self.branches = nn.ModuleList([
            LocalSpikeDetectionBranch(k, num_channels, 128)
            for k in self.kernel_sizes
        ])
        
        self.channel_conv = nn.Conv1d(384, 256, kernel_size=1)
        self.spatial_attention = SpatialSelfAttention(256)
        self.cross_channel_attention = CrossChannelAttention(256)
        
        # Adaptive pooling to handle variable input lengths
        self.adaptive_pool = nn.AdaptiveAvgPool1d(64)
        
        self.classifier = nn.Sequential(
            nn.Linear(256 * 64, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )
        
        self.apply(self._init_weights)
        
    def _init_weights(self, m: nn.Module):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.attention_weights = {}
        
        # Process each branch
        branch_outputs = []
        for branch in self.branches:
            branch_outputs.append(branch(x))
        
        x = torch.cat(branch_outputs, dim=1)
        x = self.channel_conv(x)
        
        x, spatial_attn = self.spatial_attention(x)
        x, cross_attn = self.cross_channel_attention(x)
        
        self.attention_weights['spatial'] = spatial_attn
        self.attention_weights['cross_channel'] = cross_attn
        
        # Adaptive pooling and classification
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        
        return F.log_softmax(x, dim=1)

class EEGDataset(Dataset):
    """Dataset class for EEG data handling."""
    def __init__(self, eeg_data: np.ndarray, labels: np.ndarray, 
                 spike_annotations: Optional[np.ndarray] = None,
                 transform: Optional[callable] = None):
        self.eeg_data = torch.FloatTensor(eeg_data)
        self.labels = torch.LongTensor(labels)
        self.transform = transform
        self.spike_annotations = (torch.FloatTensor(spike_annotations) 
                                if spike_annotations is not None 
                                else torch.zeros_like(self.eeg_data))

    def __len__(self) -> int:
        return len(self.eeg_data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.eeg_data[idx]
        y = self.labels[idx]
        spikes = self.spike_annotations[idx]

        if self.transform:
            x = self.transform(x)

        return x, y, spikes

class EEGAugmentation:
    """Data augmentation for EEG signals."""
    def __init__(self, noise_level: float = 0.01, 
                 shift_range: int = 15,
                 scale_range: Tuple[float, float] = (0.9, 1.1)):
        self.noise_level = noise_level
        self.shift_range = shift_range
        self.scale_range = scale_range

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Add Gaussian noise
        x = x + torch.randn_like(x) * self.noise_level

        # Random time shift
        if self.shift_range > 0:
            shift = torch.randint(-self.shift_range, self.shift_range + 1, (1,))
            x = torch.roll(x, shifts=shift.item(), dims=-1)

        # Random scaling
        if self.scale_range[0] < self.scale_range[1]:
            scale = torch.empty(1).uniform_(*self.scale_range)
            x = x * scale

        return x

def preprocess_eeg(raw_eeg: np.ndarray, sampling_rate: int = 300) -> np.ndarray:
    """Preprocess EEG data with filtering and normalization."""
    nyq = sampling_rate / 2
    b, a = butter(4, [1/nyq, 45/nyq], btype='band')
    
    # Apply filter along the last axis
    filtered_eeg = filtfilt(b, a, raw_eeg, axis=-1)
    
    # Z-score normalization
    mean = filtered_eeg.mean(axis=-1, keepdims=True)
    std = filtered_eeg.std(axis=-1, keepdims=True)
    normalized_eeg = (filtered_eeg - mean) / (std + 1e-8)
    
    return normalized_eeg

class Trainer:
    """Training manager for the epilepsy detection model."""
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, 
                 scheduler, device, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.config = config
        
        self.config.update({
            'spike_lambda': 0.1,
            'consistency_lambda': 0.05,
            'spike_margin': 0.5,
            'validation_interval': 100
        })
        
        self.branch_activations = {}
        self._register_hooks()

    def _register_hooks(self):
        def get_activation(name):
            def hook(module, input, output):
                self.branch_activations[name] = output
            return hook

        for i, branch in enumerate(self.model.branches):
            branch.conv.register_forward_hook(get_activation(f'branch_{i}'))

    def _calculate_metrics(self, trues, preds):
        precision, recall, f1, _ = precision_recall_fscore_support(
            trues, preds, average='binary'
        )
        conf_matrix = confusion_matrix(trues, preds)
        
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': conf_matrix
        }

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        predictions = []
        true_labels = []
        
        for batch_idx, (data, target, spike_annot) in enumerate(tqdm(self.train_loader)):
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            output = self.model(data)
            
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            
            pred = output.max(1)[1]
            predictions.extend(pred.cpu().numpy())
            true_labels.extend(target.cpu().numpy())
            
            if batch_idx % self.config['validation_interval'] == 0:
                logger.info(f'Training batch {batch_idx}/{len(self.train_loader)}, '
                          f'Loss: {loss.item():.6f}')
        
        metrics = self._calculate_metrics(true_labels, predictions)
        metrics['train_loss'] = total_loss / len(self.train_loader)
        
        logger.info(f"Training metrics: {metrics}")
        return metrics

    def validate(self):
        self.model.eval()
        val_loss = 0
        predictions = []
        true_labels = []
        
        with torch.no_grad():
            for data, target, spike_annot in tqdm(self.val_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                
                val_loss += self.criterion(output, target).item()
                
                pred = output.max(1)[1]
                predictions.extend(pred.cpu().numpy())
                true_labels.extend(target.cpu().numpy())
        
        metrics = self._calculate_metrics(true_labels, predictions)
        metrics['val_loss'] = val_loss / len(self.val_loader)
        
        logger.info(f"Validation metrics: {metrics}")
        return metrics

    def train(self, epochs):
        best_f1 = 0
        patience_counter = 0
        
        for epoch in range(epochs):
            logger.info(f"Epoch {epoch+1}/{epochs}")
            
            # Training phase
            train_metrics = self.train_epoch()
            
            # Validation phase
            val_metrics = self.validate()
            
            # Learning rate scheduling
            self.scheduler.step(val_metrics['val_loss'])
            
            # Early stopping
            if val_metrics['f1'] > best_f1:
                best_f1 = val_metrics['f1']
                patience_counter = 0
                # Save best model
                torch.save(self.model.state_dict(), 'best_model.pt')
                logger.info(f"New best model saved with F1: {best_f1:.4f}")
            else:
                patience_counter += 1
                
            if patience_counter >= self.config['early_stopping_patience']:
                logger.info("Early stopping triggered")
                break

def generate_dummy_data(num_samples: int = 1000,
                       num_channels: int = 32,
                       sequence_length: int = 1000,
                       sampling_rate: int = 300) -> Tuple[np.ndarray, ...]:
    """
    Generate synthetic EEG data with epileptic patterns.
    
    Args:
        num_samples: Number of EEG recordings to generate
        num_channels: Number of EEG channels
        sequence_length: Length of each recording
        sampling_rate: Sampling rate in Hz
    
    Returns:
        Tuple containing:
        - raw_train_data: Training data
        - train_labels: Training labels
        - train_spike_annotations: Spike annotations for training
        - raw_val_data: Validation data
        - val_labels: Validation labels
        - val_spike_annotations: Spike annotations for validation
    """
    
    def generate_normal_eeg(num_samples):
        # Generate background EEG as pink noise
        eeg = np.random.randn(num_samples, num_channels, sequence_length)
        
        # Add alpha rhythm (8-12 Hz)
        t = np.linspace(0, sequence_length/sampling_rate, sequence_length)
        alpha = np.sin(2 * np.pi * 10 * t)  # 10 Hz oscillation
        eeg += 0.5 * alpha.reshape(1, 1, -1)
        
        # Add beta rhythm (13-30 Hz)
        beta = np.sin(2 * np.pi * 20 * t)  # 20 Hz oscillation
        eeg += 0.3 * beta.reshape(1, 1, -1)
        
        return eeg
    
    def generate_epileptic_spikes(eeg, num_spikes=5):
        # Create spike template
        spike_length = int(0.1 * sampling_rate)  # 100ms spike
        spike = np.zeros(spike_length)
        spike[:spike_length//2] = np.linspace(0, 1, spike_length//2)
        spike[spike_length//2:] = np.linspace(1, 0, spike_length - spike_length//2)
        
        # Add spikes at random locations
        spike_locations = np.zeros((eeg.shape[0], num_channels, sequence_length))
        
        for i in range(eeg.shape[0]):
            if np.random.rand() > 0.5:  # 50% chance of epileptic sample
                for _ in range(num_spikes):
                    # Random start position
                    start = np.random.randint(0, sequence_length - spike_length)
                    
                    # Select random subset of channels for the spike
                    channels = np.random.choice(num_channels, size=num_channels//4, replace=False)
                    
                    # Add spike with random amplitude variation
                    amplitude = np.random.uniform(1.5, 2.5)
                    eeg[i, channels, start:start+spike_length] += amplitude * spike
                    spike_locations[i, channels, start:start+spike_length] = 1
                    
                    # Add propagation effect to neighboring channels
                    for ch in channels:
                        if ch > 0:
                            eeg[i, ch-1, start:start+spike_length] += 0.3 * amplitude * spike
                        if ch < num_channels-1:
                            eeg[i, ch+1, start:start+spike_length] += 0.3 * amplitude * spike
        
        return eeg, spike_locations

    # Generate training data
    total_samples = num_samples
    train_samples = int(0.8 * total_samples)
    
    # Generate normal EEG
    raw_train_data = generate_normal_eeg(train_samples)
    raw_val_data = generate_normal_eeg(total_samples - train_samples)
    
    # Add epileptic spikes and get annotations
    raw_train_data, train_spike_annotations = generate_epileptic_spikes(raw_train_data)
    raw_val_data, val_spike_annotations = generate_epileptic_spikes(raw_val_data)
    
    # Generate labels (1 if contains spikes, 0 otherwise)
    train_labels = (train_spike_annotations.sum(axis=(1, 2)) > 0).astype(int)
    val_labels = (val_spike_annotations.sum(axis=(1, 2)) > 0).astype(int)
    
    return (raw_train_data, train_labels, train_spike_annotations,
            raw_val_data, val_labels, val_spike_annotations)

def main():
    """Main execution function."""
    # Configuration
    config = {
        'batch_size': 32,
        'learning_rate': 1e-4,
        'weight_decay': 1e-4,
        'epochs': 10,  # Reduced for demonstration
        'early_stopping_patience': 3,
        'sampling_rate': 300,
        'spike_lambda': 0.1,
        'consistency_lambda': 0.05,
        'spike_margin': 0.5,
        'num_channels': 32,
        'sequence_length': 1000,
        'num_samples': 1000
    }

    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")

        # Generate dummy data
        logger.info("Generating dummy data...")
        (raw_train_data, train_labels, train_spike_annotations,
         raw_val_data, val_labels, val_spike_annotations) = generate_dummy_data(
            num_samples=config['num_samples'],
            num_channels=config['num_channels'],
            sequence_length=config['sequence_length'],
            sampling_rate=config['sampling_rate']
        )

        # Preprocess data
        logger.info("Preprocessing data...")
        train_data = preprocess_eeg(raw_train_data, config['sampling_rate'])
        val_data = preprocess_eeg(raw_val_data, config['sampling_rate'])

        # Create datasets
        train_transform = EEGAugmentation()
        train_dataset = EEGDataset(
            train_data, 
            train_labels, 
            spike_annotations=train_spike_annotations,
            transform=train_transform
        )
        val_dataset = EEGDataset(
            val_data, 
            val_labels,
            spike_annotations=val_spike_annotations
        )

        # Create dataloaders
        train_loader = DataLoader(
            train_dataset, 
            batch_size=config['batch_size'],
            shuffle=True, 
            num_workers=4,
            pin_memory=True
        )
        
        val_loader = DataLoader(
            val_dataset, 
            batch_size=config['batch_size'],
            shuffle=False, 
            num_workers=4,
            pin_memory=True
        )

        # Initialize model and training components
        model = EpilepsyDetectionModel(
            num_channels=config['num_channels'],
            sampling_rate=config['sampling_rate']
        ).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=0.5, 
            patience=5, 
            verbose=True
        )

        # Initialize trainer
        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config=config
        )

        # Train model
        logger.info("Starting training...")
        trainer.train(config['epochs'])

        logger.info("Training completed successfully!")

    except Exception as e:
        logger.error(f"An error occurred during execution: {str(e)}")
        raise

if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    
    main()

INFO:__main__:Using device: cuda
INFO:__main__:Generating dummy data...
INFO:__main__:Preprocessing data...
INFO:__main__:Starting training...
INFO:__main__:Epoch 1/10
  0%|          | 0/25 [00:05<?, ?it/s]
ERROR:__main__:An error occurred during execution: DataLoader worker (pid(s) 17660, 23240, 4932, 7336) exited unexpectedly


RuntimeError: DataLoader worker (pid(s) 17660, 23240, 4932, 7336) exited unexpectedly