## 1. Setup and Installation

In [1]:
# Install required packages
!pip install scipy scikit-learn matplotlib seaborn tqdm



In [None]:
import os
import sys
import random
import numpy as np
import pickle
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Union
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import GradScaler, autocast

from scipy import signal
from scipy.io import loadmat
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

import math
import json
from datetime import datetime

# Report the runtime so logs record which PyTorch build / accelerator produced the results
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
@dataclass
class STFTConfig:
    """STFT preprocessing configuration.

    create_data_loaders forwards ``freq_range`` to every per-dataset
    STFTPreprocessor and passes the two target sizes to
    MultiDatasetSTFTPreprocessor, which resizes each spectrogram to
    (n_channels, target_freq_bins, target_time_bins).
    """
    freq_range: Tuple[float, float] = (0.5, 50.0)  # (low, high) band kept from the STFT, in the units of sampling_rate (Hz)
    target_freq_bins: int = 129  # frequency-axis length after resizing
    target_time_bins: int = 126  # time-axis length after resizing

@dataclass
class ModelConfig:
    """Model architecture configuration (read by Trainer when building the model)."""
    model_type: str = 'dynamic'  # NOTE(review): not read anywhere in this notebook
    gru_hidden_size: int = 128  # GRU hidden size per direction; the BiGRU output width is twice this
    gru_num_layers: int = 2  # NOTE(review): Trainer does not pass this to the model factory
    num_attention_heads: int = 4  # must evenly divide 2 * gru_hidden_size (asserted in MultiHeadSelfAttention)
    dropout: float = 0.5  # dropout rate passed to the GRU, attention block and classifier
    n_classes: int = 2  # number of output logits

@dataclass
class TrainingConfig:
    """Optimisation / training-loop configuration consumed by Trainer."""
    learning_rate: float = 0.001  # Adam learning rate
    weight_decay: float = 1e-4  # Adam weight decay
    batch_size: int = 32  # batch size used by the train, val and test DataLoaders
    num_epochs: int = 30
    use_scheduler: bool = True  # whether to use the cosine-annealing LR scheduler
    # kwargs for CosineAnnealingLR; keep T_max in step with num_epochs when changing either
    scheduler_params: Dict = field(default_factory=lambda: {'T_max': 30, 'eta_min': 1e-6})
    early_stopping: bool = True  # whether to stop once val loss stops improving
    patience: int = 10  # epochs without val-loss improvement tolerated before stopping
    min_delta: float = 0.001  # minimum val-loss decrease that counts as an improvement
    grad_clip: Optional[float] = 1.0  # max global grad norm; None/0 disables clipping
    gradient_accumulation_steps: int = 4  # batches whose gradients are summed per optimizer step
    use_amp: bool = True  # mixed precision (autocast + GradScaler)
    device: str = 'cuda'  # get_device falls back to CPU when CUDA is unavailable

@dataclass
class DataConfig:
    """Data loading configuration."""
    # Update these paths for your Kaggle dataset
    base_dir: str = '/kaggle/input'  # Kaggle input directory (root of attached datasets)
    augmented_dir: str = 'eeg-augmented-datasets'  # folder under base_dir holding the <NAME>_augmented/ subfolders
    datasets: List[str] = field(default_factory=lambda: ['DEAP', 'GAMEEMO', 'SEEDIV'])  # names must be keys of create_dataset_stft_configs
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    test_ratio: float = 0.15  # NOTE(review): informational only; the test split is whatever remains after train + val
    balance_classes: bool = True  # NOTE(review): not read by create_data_loaders as written
    cache_stft: bool = False  # NOTE(review): not read by create_data_loaders as written
    num_workers: int = 2  # DataLoader worker processes

@dataclass
class ExperimentConfig:
    """Top-level experiment configuration bundling all sub-configs."""
    experiment_name: str = "gru_xnet_kaggle"  # NOTE(review): not referenced elsewhere in this notebook
    output_dir: str = '/kaggle/working/outputs'  # Kaggle working directory; checkpoints/ and figures/ are created beneath it
    stft: STFTConfig = field(default_factory=STFTConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    log_interval: int = 10  # NOTE(review): not referenced elsewhere in this notebook
    seed: int = 42  # passed to set_seed (Python, NumPy and PyTorch RNGs)
    deterministic: bool = True  # makes set_seed force deterministic cuDNN

# Instantiate the default experiment configuration
config = ExperimentConfig()

# Ensure the output tree exists: <output_dir>/, plus checkpoints/ and figures/ beneath it
output_root = config.output_dir
os.makedirs(output_root, exist_ok=True)
for subfolder in ('checkpoints', 'figures'):
    os.makedirs(os.path.join(output_root, subfolder), exist_ok=True)

print("Configuration created successfully!")
print(f"Output directory: {config.output_dir}")
print(f"Datasets: {config.data.datasets}")

## 3. Utility Functions

In [None]:
def set_seed(seed: int, deterministic: bool = True):
    """Seed Python, NumPy and PyTorch (CPU and, if present, every CUDA device).

    Args:
        seed: Value used for every RNG.
        deterministic: When True, force cuDNN into deterministic mode and disable
            its autotuner so repeated runs pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_device(preferred_device: str = 'cuda') -> torch.device:
    """Return the CUDA device when it is requested and available, otherwise the CPU."""
    use_cuda = preferred_device == 'cuda' and torch.cuda.is_available()
    return torch.device('cuda' if use_cuda else 'cpu')

class EarlyStopping:
    """Track a monitored metric and raise ``should_stop`` once it stalls.

    The first score seen becomes the baseline. Afterwards, a score improves only if it
    beats the best so far by more than ``min_delta`` (lower is better for
    ``mode='min'``; any other mode treats higher as better). Each non-improving
    call bumps ``counter``; when it reaches ``patience``, ``should_stop`` is set.

    Args:
        patience: Consecutive non-improving calls tolerated.
        min_delta: Margin a score must beat the best by to count as improvement.
        mode: 'min' to minimise the metric, anything else to maximise it.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.001, mode: str = 'min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.should_stop = False

    def _is_improvement(self, score: float) -> bool:
        # Compare against the current best, shifted by min_delta in the "better" direction
        if self.mode == 'min':
            return score < (self.best_score - self.min_delta)
        return score > (self.best_score + self.min_delta)

    def __call__(self, score: float):
        if self.best_score is None:
            self.best_score = score
            return
        if self._is_improvement(score):
            self.best_score = score
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True

def plot_training_history(history: Dict, save_path: str = None):
    """Draw side-by-side loss and accuracy curves for the train/val splits.

    Args:
        history: Mapping with 'train' and 'val' entries, each a dict that may hold
            per-epoch 'loss' and 'accuracy' lists; missing series are skipped.
        save_path: Optional file path; when given, the figure is saved at 300 dpi.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))
    # (axis, history key, y-axis label, panel title)
    panels = (
        (loss_ax, 'loss', 'Loss', 'Loss'),
        (acc_ax, 'accuracy', 'Accuracy (%)', 'Accuracy'),
    )
    for ax, key, y_label, title in panels:
        for split, curve_label in (('train', 'Train'), ('val', 'Val')):
            if key in history[split]:
                ax.plot(history[split][key], label=curve_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

print("Utility functions defined!")

## 4. STFT Preprocessing

In [None]:
class STFTPreprocessor:
    """STFT magnitude transform for multichannel EEG.

    Wraps ``scipy.signal.stft`` and keeps only the frequency rows that fall inside
    ``freq_range``.

    Args:
        sampling_rate: Sampling frequency of the signal (same units as freq_range, e.g. Hz).
        nperseg: Samples per STFT segment.
        noverlap: Overlapping samples between segments; defaults to nperseg // 2.
        nfft: FFT length; defaults to nperseg.
        window: Window name passed to scipy.
        freq_range: Optional (low, high) band to keep; None keeps every bin.

    Raises:
        ValueError: if ``freq_range`` selects no frequency bins.
    """

    def __init__(self, sampling_rate: int, nperseg: int = 256, noverlap: int = None,
                 nfft: int = None, window: str = 'hann', freq_range: Tuple[float, float] = None):
        self.sampling_rate = sampling_rate
        self.nperseg = nperseg
        self.noverlap = noverlap if noverlap is not None else nperseg // 2
        self.nfft = nfft if nfft is not None else nperseg
        self.window = window
        self.freq_range = freq_range

        # scipy's one-sided STFT returns nfft // 2 + 1 frequency rows
        total_bins = self.nfft // 2 + 1
        if self.freq_range is not None:
            low, high = self.freq_range
            freq_resolution = self.sampling_rate / self.nfft
            # Clamp to rows that really exist: without this, a band reaching past
            # Nyquist made n_freq_bins larger than the slice compute_stft returns.
            self.freq_start_idx = min(max(int(low / freq_resolution), 0), total_bins)
            self.freq_end_idx = min(max(int(high / freq_resolution) + 1, 0), total_bins)
        else:
            self.freq_start_idx = 0
            self.freq_end_idx = total_bins
        self.n_freq_bins = self.freq_end_idx - self.freq_start_idx
        if self.n_freq_bins <= 0:
            raise ValueError(
                f"freq_range={self.freq_range} selects no STFT bins "
                f"(sampling_rate={self.sampling_rate}, nfft={self.nfft})")

    def compute_stft(self, signal_data: np.ndarray) -> np.ndarray:
        """Return |STFT| of ``signal_data`` restricted to the configured band.

        Accepts one channel (n_timepoints,) -> (n_freq_bins, n_time_bins), or stacked
        channels (..., n_timepoints) -> (..., n_freq_bins, n_time_bins), because scipy
        transforms along the last axis.
        """
        _, _, Zxx = signal.stft(signal_data, fs=self.sampling_rate, window=self.window,
                                nperseg=self.nperseg, noverlap=self.noverlap, nfft=self.nfft)
        return np.abs(Zxx)[..., self.freq_start_idx:self.freq_end_idx, :]

    def transform(self, eeg_data: np.ndarray) -> np.ndarray:
        """Transform EEG data (n_channels, n_timepoints) to float32 STFT magnitudes.

        A 1-D signal is treated as a single channel.

        Returns:
            Array of shape (n_channels, n_freq_bins, n_time_bins), dtype float32.
        """
        # One vectorised scipy call over all channels instead of a per-channel loop
        return self.compute_stft(np.atleast_2d(eeg_data)).astype(np.float32, copy=False)

class MultiDatasetSTFTPreprocessor:
    """Route each sample to its dataset's STFTPreprocessor and resize to a common grid.

    Args:
        dataset_configs: Mapping dataset name -> keyword args for STFTPreprocessor.
        target_freq_bins: Frequency-axis length every output is resized to.
        target_time_bins: Time-axis length every output is resized to.
    """

    def __init__(self, dataset_configs: Dict, target_freq_bins: int = 129, target_time_bins: int = 126):
        self.target_freq_bins = target_freq_bins
        self.target_time_bins = target_time_bins
        self.preprocessors = {}
        for name, cfg in dataset_configs.items():
            self.preprocessors[name] = STFTPreprocessor(**cfg)

    def transform(self, eeg_data: np.ndarray, dataset_name: str, standardize: bool = True) -> np.ndarray:
        """STFT ``eeg_data`` with the named dataset's settings.

        When ``standardize`` is True, the (channels, freq, time) result is linearly
        interpolated (scipy.ndimage.zoom, order=1) onto the target freq/time grid.
        The channel axis is left untouched. Already-matching arrays are returned
        as-is.
        """
        spectrogram = self.preprocessors[dataset_name].transform(eeg_data)
        if not standardize:
            return spectrogram

        from scipy.ndimage import zoom
        _, freq_bins, time_bins = spectrogram.shape
        if (freq_bins, time_bins) == (self.target_freq_bins, self.target_time_bins):
            return spectrogram
        factors = (1, self.target_freq_bins / freq_bins, self.target_time_bins / time_bins)
        return zoom(spectrogram, factors, order=1)

def create_dataset_stft_configs(target_freq_range: Tuple[float, float] = (0.5, 50.0)) -> Dict:
    """Build STFTPreprocessor kwargs for every supported dataset.

    Every dataset uses a window of nperseg = 2 * sampling_rate samples (2 s), and all
    share ``target_freq_range``.

    Returns:
        Mapping dataset name -> {'sampling_rate', 'nperseg', 'freq_range'}.
    """
    # dataset name -> (sampling_rate, nperseg)
    acquisition = {
        'DEAP': (128, 256),
        'GAMEEMO': (128, 256),
        'SEEDIV': (200, 400),
    }
    return {
        name: {'sampling_rate': rate, 'nperseg': window, 'freq_range': target_freq_range}
        for name, (rate, window) in acquisition.items()
    }

print("STFT preprocessing classes defined!")

## 5. Model Architecture

In [None]:
class ChannelIndependentCNN(nn.Module):
    """Three-stage 2-D CNN applied to one channel's spectrogram.

    Each stage is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2, stride 2).
    Feature maps widen 1 -> 32 -> 64 -> 128, and each stage halves both spatial axes
    (floor). Dropout(0.5) is applied at the end.

    Input:  (batch, 1, n_freq, n_time)
    Output: (batch, 128, n_freq // 8, n_time // 8)

    ``n_freq_bins`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, n_freq_bins: int):
        super().__init__()
        # Registration order is kept so parameter names/checkpoint keys and init order are stable
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        return self.dropout(x)

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product multi-head self-attention with a residual LayerNorm.

    Input and output are (batch, seq_len, d_model). Dropout is applied to the
    attention weights and again to the projected output, before the residual
    add + LayerNorm.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability.
    """

    def __init__(self, d_model: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        bsz, length, _ = x.shape

        def split_heads(t):
            # (B, L, d_model) -> (B, H, L, d_k)
            return t.view(bsz, length, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        weights = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.d_k), dim=-1)
        weights = self.dropout(weights)

        # Merge heads back: (B, H, L, d_k) -> (B, L, d_model)
        context = (weights @ v).transpose(1, 2).contiguous().view(bsz, length, self.d_model)

        projected = self.dropout(self.W_o(context))
        return self.layer_norm(x + projected)

class gru_xnetDynamic(nn.Module):
    """gru_xnet: CNN-BiGRU-Self Attention Network.

    Pipeline:
    1. Each input channel's spectrogram goes through its own ChannelIndependentCNN.
    2. The per-channel maps are flattened per pooled time step and concatenated
       across channels.
    3. The sequence is fed to a bidirectional GRU, then refined by multi-head
       self-attention.
    4. The result is mean-pooled over time and classified by an MLP.

    Args:
        n_channels: Number of channel CNNs (model input width).
        n_freq_bins: Frequency bins of each input spectrogram.
        n_time_bins: Time bins of each input spectrogram.
        n_classes: Number of output logits.
        gru_hidden_size: GRU hidden size per direction.
        num_attention_heads: Heads in the attention block (must divide 2 * gru_hidden_size).
        dropout: Dropout for the GRU (between layers), attention and classifier.
        gru_num_layers: Stacked GRU layers (default 2, the previously hard-coded depth).
    """

    def __init__(self, n_channels: int, n_freq_bins: int, n_time_bins: int, n_classes: int,
                 gru_hidden_size: int = 128, num_attention_heads: int = 4, dropout: float = 0.5,
                 gru_num_layers: int = 2):
        super().__init__()

        self.n_channels = n_channels
        self.gru_hidden_size = gru_hidden_size

        # Channel-independent CNNs
        self.channel_cnns = nn.ModuleList([ChannelIndependentCNN(n_freq_bins) for _ in range(n_channels)])

        # Three stride-2 max-pools shrink each spatial axis to floor(n / 8)
        self.freq_reduced = n_freq_bins // 8
        self.time_reduced = n_time_bins // 8
        self.feature_dim_per_channel = 128 * self.freq_reduced

        # GRU inter-layer dropout only exists with >1 layer (PyTorch warns otherwise)
        gru_dropout = dropout if (gru_num_layers > 1 and dropout > 0) else 0.0
        self.bigru = nn.GRU(n_channels * self.feature_dim_per_channel, gru_hidden_size,
                            num_layers=gru_num_layers, batch_first=True, bidirectional=True,
                            dropout=gru_dropout)

        # Multi-head self-attention over the BiGRU outputs (width 2 * hidden)
        self.attention = MultiHeadSelfAttention(gru_hidden_size * 2, num_attention_heads, dropout)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(gru_hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward(self, x_stft):
        """Classify a batch of multichannel spectrograms.

        Args:
            x_stft: (batch, channels, n_freq_bins, n_time_bins). Channels beyond
                ``n_channels`` are ignored. Missing channels are zero-filled at the
                feature level.

        Returns:
            (batch, n_classes) logits.
        """
        batch_size, actual_channels = x_stft.shape[0], x_stft.shape[1]

        channel_features = []
        for i in range(min(self.n_channels, actual_channels)):
            channel_feat = self.channel_cnns[i](x_stft[:, i:i + 1, :, :])
            # (B, 128, F', T') -> (B, T', 128 * F'): one feature vector per pooled time step
            channel_features.append(channel_feat.permute(0, 3, 1, 2).flatten(start_dim=2))

        # Zero-fill the feature slots of absent channels, matching the real features'
        # sequence length and dtype (they may be fp16 under autocast)
        if actual_channels < self.n_channels:
            if channel_features:
                seq_len, feat_dtype = channel_features[0].shape[1], channel_features[0].dtype
            else:
                seq_len, feat_dtype = self.time_reduced, x_stft.dtype
            padding = torch.zeros(batch_size, seq_len,
                                  (self.n_channels - actual_channels) * self.feature_dim_per_channel,
                                  device=x_stft.device, dtype=feat_dtype)
            channel_features.append(padding)

        temporal_features = torch.cat(channel_features, dim=2)

        gru_output, _ = self.bigru(temporal_features)
        attn_output = self.attention(gru_output)

        # Global average pooling over time, then classify
        pooled = torch.mean(attn_output, dim=1)
        return self.classifier(pooled)

def create_gru_xnet_model(n_channels: int, n_freq_bins: int, n_time_bins: int,
                       n_classes: int, **kwargs) -> nn.Module:
    """Build a gru_xnetDynamic model.

    Extra keyword arguments (e.g. gru_hidden_size, num_attention_heads, dropout)
    are forwarded unchanged to the model constructor.
    """
    model = gru_xnetDynamic(n_channels=n_channels, n_freq_bins=n_freq_bins,
                            n_time_bins=n_time_bins, n_classes=n_classes, **kwargs)
    return model

print("Model architecture defined!")

## 6. Data Loading

In [None]:
def custom_collate_fn(batch):
    """Stack (stft_features, label) samples into one batch, zero-padding channels.

    Samples can come from datasets with different electrode counts. Every feature
    tensor is padded with zero channels up to the widest sample in this batch.
    Frequency/time sizes are read from the first sample and assumed to be shared.

    Returns:
        (features of shape (batch, max_channels, n_freq, n_time), stacked labels)
    """
    feature_list = [sample[0] for sample in batch]
    labels = torch.stack([sample[1] for sample in batch])

    widest = max(feat.shape[0] for feat in feature_list)
    freq_size = feature_list[0].shape[1]
    time_size = feature_list[0].shape[2]

    def pad_channels(feat):
        missing = widest - feat.shape[0]
        if missing <= 0:
            return feat
        filler = torch.zeros(missing, freq_size, time_size, dtype=feat.dtype)
        return torch.cat([feat, filler], dim=0)

    return torch.stack([pad_channels(feat) for feat in feature_list], dim=0), labels

class gru_xnetDataset(Dataset):
    """Map-style dataset that converts raw EEG to spectrograms on the fly.

    Args:
        data: Raw EEG samples, each (n_channels, n_timepoints).
        labels: Integer class label per sample.
        dataset_names: Source dataset name per sample (selects the STFT settings).
        stft_preprocessor: Shared preprocessor used with standardize=True.

    Each item is a (float32 spectrogram tensor, int64 label tensor) pair.
    """

    def __init__(self, data: List[np.ndarray], labels: np.ndarray, dataset_names: np.ndarray,
                 stft_preprocessor: MultiDatasetSTFTPreprocessor):
        self.data, self.labels = data, labels
        self.dataset_names = dataset_names
        self.stft_preprocessor = stft_preprocessor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        spectrogram = self.stft_preprocessor.transform(self.data[idx], self.dataset_names[idx],
                                                       standardize=True)
        return torch.from_numpy(spectrogram).float(), torch.tensor(self.labels[idx]).long()

def load_augmented_dataset(base_dir: str, dataset_name: str) -> Tuple:
    """Load one augmented dataset from ``<base_dir>/<dataset_name>_augmented/``.

    Reads data.pkl, labels.pkl and subjects.pkl (in that order) and returns them as
    (data, labels, subjects). Adjust the file names if your Kaggle dataset differs.

    SECURITY NOTE: pickle.load can execute arbitrary code, so only point this at
    trusted files.
    """
    dataset_path = os.path.join(base_dir, f'{dataset_name}_augmented')

    def read_pickle(filename):
        with open(os.path.join(dataset_path, filename), 'rb') as handle:
            return pickle.load(handle)

    data = read_pickle('data.pkl')
    labels = read_pickle('labels.pkl')
    subjects = read_pickle('subjects.pkl')
    return data, labels, subjects

def create_data_loaders(config: ExperimentConfig) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Load every configured dataset, split the pooled samples and wrap each split in a DataLoader.

    Samples from all of ``config.data.datasets`` are pooled and shuffled with NumPy's
    global RNG (seed it first, e.g. via set_seed). They are split by
    ``train_ratio``/``val_ratio``; the test split receives the remainder.

    NOTE(review): the split is per-sample, not per-subject/recording. If the pickles
    contain augmented copies of the same trial, near-duplicates can land in both train
    and test and inflate the reported accuracy. ``subjects`` is loaded but unused
    here; a subject-wise split would avoid that leak.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if no datasets are configured, the split ratios are invalid, or a
            dataset's sample and label counts differ.
    """
    print("Loading datasets...")

    if not config.data.datasets:
        raise ValueError("config.data.datasets is empty; nothing to load")
    train_ratio, val_ratio = config.data.train_ratio, config.data.val_ratio
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(f"Invalid split ratios: train={train_ratio}, val={val_ratio}")

    all_data = []
    all_labels = []
    all_dataset_names = []

    augmented_dir = os.path.join(config.data.base_dir, config.data.augmented_dir)

    for dataset_name in config.data.datasets:
        print(f"Loading {dataset_name}...")
        data, labels, subjects = load_augmented_dataset(augmented_dir, dataset_name)
        # Samples are addressed by one global index below; a length mismatch would
        # silently pair every later sample with the wrong label and dataset name.
        if len(data) != len(labels):
            raise ValueError(f"{dataset_name}: {len(data)} samples but {len(labels)} labels")

        all_data.append(data)
        all_labels.append(labels)
        all_dataset_names.extend([dataset_name] * len(labels))
        print(f"  Loaded {len(labels)} samples")

    # Combine
    labels = np.concatenate(all_labels)
    dataset_names = np.array(all_dataset_names)

    # Random split over the pooled samples
    n_samples = len(labels)
    indices = np.random.permutation(n_samples)

    n_train = int(n_samples * train_ratio)
    n_val = int(n_samples * val_ratio)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]

    # offsets[j] is the global index of dataset j's first sample
    offsets = np.cumsum([0] + [len(d) for d in all_data])

    def extract_split(split_indices):
        """Gather raw samples, labels and dataset names for the given global indices."""
        # side='right' maps an index to the last dataset starting at or before it (skips empty datasets)
        owners = np.searchsorted(offsets, split_indices, side='right') - 1
        split_data = [all_data[j][i - offsets[j]] for i, j in zip(split_indices, owners)]
        return split_data, labels[split_indices], dataset_names[split_indices]

    train_data, train_labels, train_datasets = extract_split(train_idx)
    val_data, val_labels, val_datasets = extract_split(val_idx)
    test_data, test_labels, test_datasets = extract_split(test_idx)

    print(f"\nSplit: Train={len(train_labels)}, Val={len(val_labels)}, Test={len(test_labels)}")

    # Create STFT preprocessor
    stft_configs = create_dataset_stft_configs(config.stft.freq_range)
    stft_preprocessor = MultiDatasetSTFTPreprocessor(stft_configs, config.stft.target_freq_bins,
                                                     config.stft.target_time_bins)

    # Create datasets
    train_dataset = gru_xnetDataset(train_data, train_labels, train_datasets, stft_preprocessor)
    val_dataset = gru_xnetDataset(val_data, val_labels, val_datasets, stft_preprocessor)
    test_dataset = gru_xnetDataset(test_data, test_labels, test_datasets, stft_preprocessor)

    # Create loaders
    train_loader = DataLoader(train_dataset, batch_size=config.training.batch_size, shuffle=True,
                              num_workers=config.data.num_workers, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=config.training.batch_size, shuffle=False,
                            num_workers=config.data.num_workers, collate_fn=custom_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=config.training.batch_size, shuffle=False,
                             num_workers=config.data.num_workers, collate_fn=custom_collate_fn)

    return train_loader, val_loader, test_loader

print("Data loading functions defined!")

## 7. Training Loop

In [None]:
class Trainer:
    """Training manager for gru_xnet.

    Builds the data loaders and model from an ExperimentConfig and provides:
    - ``train()``: epoch loop with best-checkpoint saving, optional LR scheduling
      and optional early stopping.
    - ``test()``: evaluates the best checkpoint on the test split.

    Everything, including output paths, is read from ``self.config``.
    """

    def __init__(self, config: ExperimentConfig):
        self.config = config
        set_seed(config.seed, config.deterministic)
        self.device = get_device(config.training.device)
        print(f"Using device: {self.device}")

        # Create data loaders
        self.train_loader, self.val_loader, self.test_loader = create_data_loaders(config)

        # Size the model from the data, not from one random batch. The collate function
        # only pads to the widest sample *in that batch*, so a first batch lacking the
        # largest montage would under-size the model and silently drop channels later.
        n_channels, n_freq_bins, n_time_bins = self._infer_input_dims()
        print(f"\nData dimensions: channels={n_channels}, freq={n_freq_bins}, time={n_time_bins}")

        # Create model
        self.model = create_gru_xnet_model(
            n_channels=n_channels, n_freq_bins=n_freq_bins, n_time_bins=n_time_bins,
            n_classes=config.model.n_classes, gru_hidden_size=config.model.gru_hidden_size,
            num_attention_heads=config.model.num_attention_heads, dropout=config.model.dropout
        ).to(self.device)

        n_params = sum(p.numel() for p in self.model.parameters())
        print(f"Model parameters: {n_params:,}")

        # Training components
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=config.training.learning_rate,
                                    weight_decay=config.training.weight_decay)
        # Honour TrainingConfig.use_scheduler (previously ignored)
        self.scheduler = (optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                               **config.training.scheduler_params)
                          if config.training.use_scheduler else None)
        # CUDA autocast/GradScaler only do anything on a CUDA device
        self.use_amp = bool(config.training.use_amp) and self.device.type == 'cuda'
        self.scaler = GradScaler() if self.use_amp else None
        self.accum_steps = max(1, int(config.training.gradient_accumulation_steps))
        self.early_stopping = EarlyStopping(config.training.patience, config.training.min_delta)
        self.checkpoint_path = os.path.join(config.output_dir, 'checkpoints', 'best_model.pth')

        self.history = {'train': {'loss': [], 'accuracy': []}, 'val': {'loss': [], 'accuracy': []}}
        self.best_val_loss = float('inf')

    def _infer_input_dims(self) -> Tuple[int, int, int]:
        """Return (max channel count over all splits, freq bins, time bins)."""
        datasets = [loader.dataset for loader in (self.train_loader, self.val_loader, self.test_loader)]
        # Raw samples are (n_channels, n_timepoints), so len() is the channel count
        n_channels = max(len(sample) for ds in datasets for sample in ds.data)
        # All samples are resized to the same freq/time grid; one sample is enough
        sample_features, _ = datasets[0][0]
        return n_channels, sample_features.shape[1], sample_features.shape[2]

    def _optimizer_step(self):
        """Apply accumulated gradients (with optional clipping) and reset them."""
        clip = self.config.training.grad_clip
        if self.scaler is not None:
            if clip:
                # Gradients must be unscaled before clipping their true norm
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            if clip:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.optimizer.step()
        self.optimizer.zero_grad()

    def train_epoch(self, epoch: int):
        """Train one epoch; returns (mean loss, accuracy %)."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        n_batches = len(self.train_loader)
        # Start every epoch from clean gradients
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config.training.num_epochs}")

        for batch_idx, (stft_features, labels) in enumerate(pbar):
            stft_features, labels = stft_features.to(self.device), labels.to(self.device)

            with autocast(enabled=self.use_amp):
                outputs = self.model(stft_features)
                loss = self.criterion(outputs, labels) / self.accum_steps

            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step every accum_steps batches AND on the last batch, so a trailing partial
            # group is applied this epoch instead of leaking into the next one
            if (batch_idx + 1) % self.accum_steps == 0 or (batch_idx + 1) == n_batches:
                self._optimizer_step()

            total_loss += loss.item() * self.accum_steps
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({'loss': total_loss / (batch_idx + 1), 'acc': 100. * correct / total})

        if total == 0:
            return 0.0, 0.0
        return total_loss / n_batches, 100. * correct / total

    @torch.no_grad()
    def _evaluate(self, loader, desc: str):
        """Run the model over ``loader``; returns (mean loss, accuracy %, predictions, labels)."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        for stft_features, labels in tqdm(loader, desc=desc):
            stft_features, labels = stft_features.to(self.device), labels.to(self.device)
            outputs = self.model(stft_features)
            loss = self.criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        mean_loss = total_loss / len(loader) if len(loader) else 0.0
        accuracy = 100. * correct / total if total else 0.0
        return mean_loss, accuracy, np.array(all_preds), np.array(all_labels)

    @torch.no_grad()
    def validate(self):
        """Validate; returns (mean loss, accuracy %)."""
        val_loss, val_acc, _, _ = self._evaluate(self.val_loader, "Validation")
        return val_loss, val_acc

    def train(self):
        """Main training loop."""
        print("\n" + "="*60)
        print("Starting Training")
        print("="*60 + "\n")

        for epoch in range(self.config.training.num_epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            self.history['train']['loss'].append(train_loss)
            self.history['train']['accuracy'].append(train_acc)
            self.history['val']['loss'].append(val_loss)
            self.history['val']['accuracy'].append(val_acc)

            print(f"\nEpoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%")
            print(f"           Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")

            # Save best model (path derived from self.config, not the notebook-global config)
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                torch.save(self.model.state_dict(), self.checkpoint_path)
                print("  Saved best model")

            if self.scheduler is not None:
                self.scheduler.step()

            # Early stopping, only when enabled in the config
            if self.config.training.early_stopping:
                self.early_stopping(val_loss)
                if self.early_stopping.should_stop:
                    print(f"\nEarly stopping at epoch {epoch+1}")
                    break

        print("\n" + "="*60)
        print("Training Complete!")
        print("="*60)

    @torch.no_grad()
    def test(self):
        """Evaluate the best checkpoint on the test split.

        Returns:
            (test_loss, test_acc, predictions ndarray, labels ndarray)

        Raises:
            FileNotFoundError: if no checkpoint has been saved yet.
        """
        # map_location lets a checkpoint written on another device load here
        self.model.load_state_dict(torch.load(self.checkpoint_path, map_location=self.device))

        test_loss, test_acc, all_preds, all_labels = self._evaluate(self.test_loader, "Testing")

        print(f"\nTest Results:")
        print(f"  Loss: {test_loss:.4f}")
        print(f"  Accuracy: {test_acc:.2f}%")

        return test_loss, test_acc, all_preds, all_labels

print("Trainer class defined!")

## 8. Run Training

In [None]:
# Seed every RNG before building the trainer so the data split and weight init are reproducible
set_seed(config.seed, config.deterministic)

trainer = Trainer(config)
trainer.train()

# Persist per-epoch metrics, then draw the learning curves
history_path = os.path.join(config.output_dir, 'history.json')
with open(history_path, 'w') as f:
    json.dump(trainer.history, f, indent=4)

curves_path = os.path.join(config.output_dir, 'figures', 'training_curves.png')
plot_training_history(trainer.history, save_path=curves_path)

## 9. Test and Evaluation

In [None]:
# Evaluate the best checkpoint on the held-out test split
test_loss, test_acc, predictions, labels = trainer.test()

# Pin the label set so the matrix and report always cover every class, even when the
# model never predicts (or the test split never contains) one of them; otherwise
# classification_report raises on the target_names length mismatch.
class_ids = list(range(config.model.n_classes))
if config.model.n_classes == 2:
    class_names = ['Negative', 'Positive']
else:
    class_names = [f'Class {i}' for i in class_ids]

# Confusion matrix
cm = confusion_matrix(labels, predictions, labels=class_ids)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_title(f'Confusion Matrix (Acc: {test_acc:.2f}%)')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
fig.savefig(os.path.join(config.output_dir, 'figures', 'confusion_matrix.png'), dpi=300)
plt.show()

# Classification report
print("\nClassification Report:")
print(classification_report(labels, predictions, labels=class_ids,
                            target_names=class_names, zero_division=0))

## 10. Save Results

In [None]:
# Summarise the run: headline test metrics plus the key hyperparameters
run_settings = {
    'batch_size': config.training.batch_size,
    'learning_rate': config.training.learning_rate,
    'num_epochs': config.training.num_epochs,
    'datasets': config.data.datasets,
}
results = {
    'test_loss': float(test_loss),
    'test_accuracy': float(test_acc),
    'config': run_settings,
}

results_path = os.path.join(config.output_dir, 'results.json')
with open(results_path, 'w') as f:
    json.dump(results, f, indent=4)

checkpoint_path = os.path.join(config.output_dir, 'checkpoints', 'best_model.pth')
print("\nResults saved!")
print(f"Output directory: {config.output_dir}")
print(f"Best model: {checkpoint_path}")