# üî¨ ResNet-101 Comprehensive GridSearch

**Parameters to Explore:**
1. ‚úÖ **Optimizer**: Adam, AdamW, SGD
2. ‚úÖ **Activation Function**: ReLU, LeakyReLU, ELU
3. ‚úÖ **L1 Regularization**: [0, 1e-5, 1e-4]
4. ‚úÖ **L2 Regularization (weight_decay)**: [0, 1e-4, 1e-3]
5. ‚úÖ **Early Stopping**: patience=10
6. ‚úÖ **LR Scheduler**: CosineAnnealingLR
7. ‚úÖ **Auto-save**: Results saved after each config

**Fixed (Same as KAN):**
- Loss: SoftFocalLoss (gamma=3.0)
- Data: Hybrid loading (oversample + weighted)
- LR: 3e-4
- Batch size: 16
- Input: 224x224
- Pretrained: ImageNet weights

## üì¶ CELL 1: Setup & Imports

In [23]:
import os
from pathlib import Path
import random
import time
import gc
import json
import pickle
import warnings
from datetime import datetime
from itertools import product
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    f1_score, precision_score, recall_score, 
    confusion_matrix, accuracy_score
)
import torchvision.models as models

print("="*80)
print(" ResNet-101 Comprehensive GridSearch ".center(80, "="))
print(" Optimizer + Activation + Regularization ".center(80))
print("="*80)
print(f"\nStarted: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# ===== REPRODUCIBILITY =====
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything(42)
print("‚úÖ Random seed: 42")

# ===== DEVICE =====
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"‚úÖ Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    torch.cuda.empty_cache()

# ===== PATHS =====
DATA_PKG = Path("data_package")
SPEC_DIR = Path("spec_hr_out")
RESULTS_DIR = Path("gridsearch_results")
RESULTS_DIR.mkdir(exist_ok=True)

assert DATA_PKG.exists(), "‚ùå data_package/ not found"
assert SPEC_DIR.exists(), "‚ùå spec_hr_out/ not found"
print(f"\n‚úÖ Paths OK")
print(f"   Data: {DATA_PKG}")
print(f"   Spectrograms: {SPEC_DIR}")
print(f"   Results will be saved to: {RESULTS_DIR}")

                    Optimizer + Activation + Regularization                     

Started: 2026-01-07 00:25:52
‚úÖ Random seed: 42
‚úÖ Device: cuda:0
   GPU: NVIDIA GeForce RTX 5060 Ti
   Memory: 17.10 GB

‚úÖ Paths OK
   Data: data_package
   Spectrograms: spec_hr_out
   Results will be saved to: gridsearch_results


## üìä CELL 2: Load Data

In [24]:
print("\n" + "="*80)
print(" LOADING DATA ".center(80, "="))
print("="*80)

meta_use = pd.read_csv(DATA_PKG / "meta_use.csv")
lbl = np.load(DATA_PKG / "labels.npz", allow_pickle=True)
y_soft = lbl["y_soft"]
w_conf = lbl["w_conf"]
classes = [str(c) for c in lbl["classes"]]
y_hard = y_soft.argmax(axis=1)

print(f"\n‚úÖ Data loaded:")
print(f"   Metadata: {meta_use.shape}")
print(f"   Labels: {y_soft.shape}")
print(f"   Classes: {classes}")

# Class distribution
print("\nüìä CLASS DISTRIBUTION:")
print("-"*80)
for i, cls in enumerate(classes):
    count = (y_hard == i).sum()
    pct = 100 * count / len(y_hard)
    print(f"  {cls:10s}: {count:5d} ({pct:5.1f}%)")
print(f"  {'TOTAL':10s}: {len(y_hard):5d}")

# Create 3-fold CV
print("\nüìÇ Creating stratified folds...")
N_FOLDS = 3
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)
folds = list(skf.split(meta_use, y_hard))
print(f"‚úÖ Created {N_FOLDS}-fold CV")
for i, (tr_idx, va_idx) in enumerate(folds):
    print(f"   Fold {i}: train={len(tr_idx)}, val={len(va_idx)}")



‚úÖ Data loaded:
   Metadata: (17089, 3)
   Labels: (17089, 6)
   Classes: ['seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other']

üìä CLASS DISTRIBUTION:
--------------------------------------------------------------------------------
  seizure   :  2716 ( 15.9%)
  lpd       :  2583 ( 15.1%)
  gpd       :  1814 ( 10.6%)
  lrda      :   936 (  5.5%)
  grda      :  1835 ( 10.7%)
  other     :  7205 ( 42.2%)
  TOTAL     : 17089

üìÇ Creating stratified folds...
‚úÖ Created 3-fold CV
   Fold 0: train=11392, val=5697
   Fold 1: train=11393, val=5696
   Fold 2: train=11393, val=5696


## ü§ñ CELL 3: Dataset Class

In [25]:
print("\n" + "="*80)
print(" DATASET CLASS ".center(80, "="))
print("="*80)

class SpecDataset(Dataset):
    """EEG Spectrogram Dataset"""
    
    def __init__(self, df, root_dir, y_soft, w_conf, F_target=81, T_target=600):
        self.df = df.reset_index(drop=True)
        self.root = Path(root_dir)
        self.y_soft = y_soft
        self.w_conf = w_conf
        self.F_target = F_target
        self.T_target = T_target

    def __len__(self):
        return len(self.df)

    def _center_crop_pad(self, x):
        C, F, T = x.shape
        # Frequency
        if F >= self.F_target:
            f0 = (F - self.F_target) // 2
            x = x[:, f0:f0+self.F_target, :]
        else:
            pad = self.F_target - F
            x = np.pad(x, ((0,0),(pad//2, pad-pad//2),(0,0)), mode="constant")
        # Time
        if T >= self.T_target:
            t0 = (T - self.T_target) // 2
            x = x[:, :, t0:t0+self.T_target]
        else:
            pad = self.T_target - T
            x = np.pad(x, ((0,0),(0,0),(pad//2, pad-pad//2)), mode="constant")
        return x.copy()

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        eid = int(row.eeg_id)
        
        npz = np.load(self.root / f"{eid}_hr.npz")
        x = npz["x"]
        
        x = self._center_crop_pad(x)
        x = torch.from_numpy(x).float()
        
        # Resize to 224x224
        x = F.interpolate(x.unsqueeze(0), size=(224, 224),
                          mode="bilinear", align_corners=False).squeeze(0)
        
        y = torch.from_numpy(self.y_soft[self.df.index[idx]]).float()
        w = torch.tensor(self.w_conf[self.df.index[idx]], dtype=torch.float32)
        
        return x, y, w

print("‚úÖ Dataset class ready")


‚úÖ Dataset class ready


## üèóÔ∏è CELL 4: ResNet-101 with Configurable Activation

In [26]:
class BaseModel4Ch(nn.Module):
    """Base class for adapting models to 4-channel input"""
    
    def adapt_first_conv(self, old_conv):
        new_conv = nn.Conv2d(
            in_channels=4,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False if old_conv.bias is None else True
        )
        with torch.no_grad():
            new_conv.weight[:, :3] = old_conv.weight
            new_conv.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True) * 0.33
            if old_conv.bias is not None:
                new_conv.bias = old_conv.bias
        return new_conv


class ResNet101_4Ch_Configurable(BaseModel4Ch):
    """ResNet-101 with configurable activation function"""
    
    def __init__(self, n_classes=6, activation='relu', pretrained=True):
        super().__init__()
        
        # Load pretrained model
        if pretrained:
            self.model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        else:
            self.model = models.resnet101(weights=None)
        
        # Adapt first conv
        self.model.conv1 = self.adapt_first_conv(self.model.conv1)
        
        # Replace classifier
        self.model.fc = nn.Linear(self.model.fc.in_features, n_classes)
        
        # Replace activation functions if not ReLU
        if activation == 'leakyrelu':
            self._replace_activations()
        
        self.activation_type = activation
    
    def _replace_activations(self):
        """Replace all ReLU with LeakyReLU"""
        for name, module in self.model.named_modules():
            if isinstance(module, nn.ReLU):
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                
                if parent_name:
                    parent = dict(self.model.named_modules())[parent_name]
                    setattr(parent, child_name, nn.LeakyReLU(0.01, True))
    
    def forward(self, x):
        return self.model(x)

## üéØ CELL 5: SoftFocalLoss

In [27]:
print("\n" + "="*80)
print(" SOFT FOCAL LOSS ".center(80, "="))
print("="*80)

class SoftFocalLoss(nn.Module):
    """Soft Focal Loss - same as KAN"""
    
    def __init__(self, alpha=None, gamma=3.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, logits, soft_targets, sample_weights=None):
        hard_targets = soft_targets.argmax(dim=1)
        probs = F.softmax(logits, dim=1)
        p_t = probs.gather(1, hard_targets.unsqueeze(1)).squeeze(1)
        ce_loss = -(soft_targets * F.log_softmax(logits, dim=1)).sum(dim=1)
        focal_weight = ((1 - p_t) ** self.gamma)
        loss = focal_weight * ce_loss
        
        if self.alpha is not None:
            alpha_t = self.alpha[hard_targets]
            loss = alpha_t * loss
        
        if sample_weights is not None:
            loss = loss * sample_weights
        
        return loss.mean()

print("‚úÖ SoftFocalLoss defined (gamma=3.0)")


‚úÖ SoftFocalLoss defined (gamma=3.0)


## üì¶ CELL 6: Hybrid Data Loader

In [28]:
print("\n" + "="*80)
print(" HYBRID DATA LOADING ".center(80, "="))
print("="*80)

def create_hybrid_loader(fold=0, target_ratio=0.4, weight_power=3.0, batch_size=16, verbose=False):
    tr_idx, va_idx = folds[fold]
    df_tr = meta_use.iloc[tr_idx]
    y_soft_tr, w_conf_tr = y_soft[tr_idx], w_conf[tr_idx]
    
    y_hard = y_soft_tr.argmax(axis=1)
    counts = np.bincount(y_hard, minlength=6)
    
    target = int(counts.max() * target_ratio)
    
    # Oversample
    indices_add = []
    for i in range(6):
        mask = y_hard == i
        if mask.sum() < target:
            idx = np.where(mask)[0]
            n_add = target - mask.sum()
            indices_add.extend(np.random.choice(idx, n_add, replace=True))
    
    all_idx = np.concatenate([np.arange(len(y_hard)), indices_add])
    np.random.shuffle(all_idx)
    
    df_tr_over = df_tr.iloc[all_idx].reset_index(drop=True)
    y_soft_over, w_conf_over = y_soft_tr[all_idx], w_conf_tr[all_idx]
    
    y_hard_over = y_soft_over.argmax(axis=1)
    counts_over = np.bincount(y_hard_over, minlength=6)
    
    weights = (len(y_hard_over) / (counts_over + 1)) ** weight_power
    weights = torch.FloatTensor(weights / weights.sum() * 6)
    
    sample_weights = weights[y_hard_over].numpy()
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    ds_tr = SpecDataset(df_tr_over, SPEC_DIR, y_soft_over, w_conf_over)
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, sampler=sampler, num_workers=0)
    
    ds_va = SpecDataset(meta_use.iloc[va_idx], SPEC_DIR, y_soft[va_idx], w_conf[va_idx])
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, num_workers=0)
    
    return dl_tr, dl_va, weights

print("‚úÖ Hybrid loader function ready")


‚úÖ Hybrid loader function ready


## üìà CELL 7: Evaluation Function

In [29]:
print("\n" + "="*80)
print(" EVALUATION FUNCTION ".center(80, "="))
print("="*80)

@torch.no_grad()
def evaluate_full(model, loader):
    model.eval()
    preds, targets = [], []
    
    for x, y, w in loader:
        x = x.to(device)
        logits = model(x)
        preds.append(logits.argmax(1).cpu().numpy())
        targets.append(y.argmax(1).cpu().numpy())
    
    y_pred = np.concatenate(preds)
    y_true = np.concatenate(targets)
    
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'f1': f1_score(y_true, y_pred, average='macro', zero_division=0),
    }

print("‚úÖ Evaluation function ready")


‚úÖ Evaluation function ready


## üèãÔ∏è CELL 8: Training Function with L1/L2 Regularization

In [30]:
def train_one_config(fold, optimizer_name, activation, l1_lambda, l2_lambda, 
                     lr=3e-4, batch_size=16, epochs=30, patience=10):
    """
    Train with specific configuration
    """
    import sys
    
    # [1/6] Data
    print(f"      [1/6] Data loaders...", end=" ", flush=True)
    sys.stdout.flush()
    t0 = time.time()
    
    dl_tr, dl_va, class_weights = create_hybrid_loader(
        fold=fold, target_ratio=0.4, weight_power=3.0, batch_size=batch_size
    )
    
    print(f"‚úì ({time.time()-t0:.1f}s)", flush=True)
    
    # [2/6] Model
    print(f"      [2/6] Model ({activation})...", end=" ", flush=True)
    sys.stdout.flush()
    t0 = time.time()
    
    model = ResNet101_4Ch_Configurable(
        n_classes=6, 
        activation=activation, 
        pretrained=True
    ).to(device)
    
    print(f"‚úì ({time.time()-t0:.1f}s)", flush=True)
    
    # [3/6] Optimizer
    print(f"      [3/6] Optimizer ({optimizer_name}, L2={l2_lambda:.0e})...", end=" ", flush=True)
    sys.stdout.flush()
    
    if optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_lambda)
    elif optimizer_name == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=l2_lambda)
    elif optimizer_name == 'adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=lr, weight_decay=l2_lambda)
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name}")
    
    print(f"‚úì", flush=True)
    
    # [4/6] Loss & Scheduler
    print(f"      [4/6] Loss & Scheduler...", end=" ", flush=True)
    sys.stdout.flush()
    
    criterion = SoftFocalLoss(alpha=class_weights.to(device), gamma=3.0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    print(f"‚úì", flush=True)
    
    # [5/6] CUDA warmup
    print(f"      [5/6] CUDA warmup...", end=" ", flush=True)
    sys.stdout.flush()
    t0 = time.time()
    
    xb, yb, wb = next(iter(dl_tr))
    xb = xb.to(device)
    with torch.no_grad():
        _ = model(xb)
    del xb, yb, wb
    
    print(f"‚úì ({time.time()-t0:.1f}s)", flush=True)
    
    # [6/6] Training
    print(f"      [6/6] Training (patience={patience}, L1={l1_lambda:.0e})...", flush=True)
    sys.stdout.flush()
    
    best_f1, best_state, no_improve = 0.0, None, 0
    
    for epoch in range(1, epochs + 1):
        # Training
        model.train()
        train_loss, n = 0.0, 0
        
        for x, y, w in dl_tr:
            x, y, w = x.to(device), y.to(device), w.to(device)
            
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y, w)
            
            # L1 Regularization
            if l1_lambda > 0:
                l1_norm = sum(p.abs().sum() for p in model.parameters())
                loss = loss + l1_lambda * l1_norm
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item() * x.size(0)
            n += x.size(0)
        
        train_loss /= n
        
        # Validation
        val_results = evaluate_full(model, dl_va)
        scheduler.step()
        
        # Early stopping
        if val_results['f1'] > best_f1:
            best_f1 = val_results['f1']
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"        Early stop at epoch {epoch}", flush=True)
                break
        
        # Progress
        if epoch % 5 == 0 or epoch == 1:
            print(f"        Epoch {epoch:2d}: F1={val_results['f1']:.4f}, Loss={train_loss:.4f}", flush=True)
        
        # Cleanup
        if epoch % 5 == 0:
            gc.collect()
            torch.cuda.empty_cache()
    
    # Load best
    if best_state:
        model.load_state_dict(best_state)
    
    final_results = evaluate_full(model, dl_va)
    
    # Cleanup
    del model, optimizer, scheduler, dl_tr, dl_va
    gc.collect()
    torch.cuda.empty_cache()
    
    return final_results

print("‚úÖ Training function ready")
print("   Optimizers: Adam, AdamW, Adagrad")
print("   L1/L2 regularization supported")
print("   Early stopping: patience=10")

‚úÖ Training function ready
   Optimizers: Adam, AdamW, Adagrad
   L1/L2 regularization supported
   Early stopping: patience=10


## üîç CELL 9: GridSearch Configuration

In [31]:
print("\n" + "="*80)
print(" GRIDSEARCH CONFIGURATION ".center(80, "="))
print("="*80)

# ============================================================================
# HYPERPARAMETER GRID - WITH LEAKYRELU
# ============================================================================

param_grid = {
    'optimizer': ['adam', 'adamw', 'adagrad'],  # 4 optimizers
    'activation': ['relu', 'leakyrelu'],               # 2 activations ‚úÖ
    'l1_lambda': [0],                                   # No L1
    'l2_lambda': [0, 1e-4, 1e-3],                      # 3 L2 values
}

# Fixed parameters
fixed_params = {
    'lr': 3e-4,
    'batch_size': 16,
    'epochs': 30,
    'patience': 10,
}

# Generate combinations
keys = list(param_grid.keys())
values = list(param_grid.values())
combinations = list(product(*values))

print("\nüìã HYPERPARAMETER GRID:")
print("-"*80)
print(f"  Optimizer:   {param_grid['optimizer']}")
print(f"  Activation:  {param_grid['activation']}")
print(f"  L1 lambda:   {param_grid['l1_lambda']}")
print(f"  L2 lambda:   {param_grid['l2_lambda']}")

print("\nüìä GRIDSEARCH STATISTICS:")
print("-"*80)
print(f"  Total combinations: {len(combinations)}")
print(f"  Folds per config:   {N_FOLDS}")
print(f"  Total trainings:    {len(combinations) * N_FOLDS}")
print(f"  Est. time per run:  ~20 min")
print(f"  Est. total time:    ~{len(combinations) * N_FOLDS * 20 / 60:.1f} hours")

print("\nüîß FIXED PARAMETERS:")
print("-"*80)
print(f"  Learning rate: {fixed_params['lr']}")
print(f"  Batch size:    {fixed_params['batch_size']}")
print(f"  Max epochs:    {fixed_params['epochs']}")
print(f"  Patience:      {fixed_params['patience']}")

print("\nüìù ALL COMBINATIONS TO TEST:")
print("-"*80)
for i, combo in enumerate(combinations, 1):
    params = dict(zip(keys, combo))
    print(f"  {i:2d}. {params['optimizer']:7s} + {params['activation']:10s} + "
          f"L1={params['l1_lambda']:.0e} + L2={params['l2_lambda']:.0e}")

print("\nüéØ COMPARISON FOCUS:")
print("-"*80)
print("  ‚úÖ 4 Optimizers: Adam vs AdamW vs Adagrad")
print("  ‚úÖ 2 Activations: ReLU vs LeakyReLU")
print("  ‚úÖ 3 L2 values: None vs Light (1e-4) vs Strong (1e-3)")

print("\n‚è±Ô∏è  TIMELINE:")
print("-"*80)
current_time = datetime.now()
finish_time = current_time + pd.Timedelta(hours=24)
print(f"  Start:  {current_time.strftime('%Y-%m-%d %H:%M')}")
print(f"  Finish: {finish_time.strftime('%Y-%m-%d %H:%M')} (approx)")

print("\nüíæ AUTO-SAVE ENABLED:")
print("-"*80)
print(f"  Results directory: {RESULTS_DIR}/")
print(f"  - gridsearch_progress.json (updated after each config)")
print(f"  - gridsearch_summary.txt (human-readable)")
print(f"  - gridsearch_final.json (complete results at end)")



üìã HYPERPARAMETER GRID:
--------------------------------------------------------------------------------
  Optimizer:   ['adam', 'adamw', 'adagrad']
  Activation:  ['relu', 'leakyrelu']
  L1 lambda:   [0]
  L2 lambda:   [0, 0.0001, 0.001]

üìä GRIDSEARCH STATISTICS:
--------------------------------------------------------------------------------
  Total combinations: 18
  Folds per config:   3
  Total trainings:    54
  Est. time per run:  ~20 min
  Est. total time:    ~18.0 hours

üîß FIXED PARAMETERS:
--------------------------------------------------------------------------------
  Learning rate: 0.0003
  Batch size:    16
  Max epochs:    30
  Patience:      10

üìù ALL COMBINATIONS TO TEST:
--------------------------------------------------------------------------------
   1. adam    + relu       + L1=0e+00 + L2=0e+00
   2. adam    + relu       + L1=0e+00 + L2=1e-04
   3. adam    + relu       + L1=0e+00 + L2=1e-03
   4. adam    + leakyrelu  + L1=0e+00 + L2=0e+00
   5. adam 

## üöÄ CELL 10: Run GridSearch with Auto-Save

In [32]:
print("\n" + "="*80)
print(" RUNNING GRIDSEARCH (WITH AUTO-SAVE) ".center(80, "="))
print("="*80)

# Results storage
all_results = []
start_time = time.time()

print(f"\nStarted: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total runs: {len(combinations) * N_FOLDS}\n")

# Main GridSearch loop
for combo_idx, combo in enumerate(combinations, 1):
    params = dict(zip(keys, combo))
    
    print("\n" + "="*80)
    print(f" CONFIG {combo_idx}/{len(combinations)} ".center(80, "="))
    print("="*80)
    print(f"  Optimizer: {params['optimizer']}")
    print(f"  Activation: {params['activation']}")
    print(f"  L1 lambda: {params['l1_lambda']:.0e}")
    print(f"  L2 lambda: {params['l2_lambda']:.0e}")
    print("-"*80)
    
    fold_results = []
    
    # Train on each fold
    for fold in range(N_FOLDS):
        print(f"\n    Fold {fold+1}/{N_FOLDS}...", flush=True)
        
        fold_start = time.time()
        
        try:
            result = train_one_config(
                fold=fold,
                optimizer_name=params['optimizer'],
                activation=params['activation'],
                l1_lambda=params['l1_lambda'],
                l2_lambda=params['l2_lambda'],
                **fixed_params
            )
            
            fold_results.append(result)
            
            fold_time = time.time() - fold_start
            print(f"\n    ‚úì Fold {fold+1}: F1={result['f1']:.4f} ({fold_time/60:.1f} min)", flush=True)
            
        except Exception as e:
            print(f"\n    ‚ùå Fold {fold+1} Error: {e}", flush=True)
            fold_results.append({'f1': 0.0, 'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0})
    
    # Compute mean metrics
    mean_metrics = {
        'f1': np.mean([r['f1'] for r in fold_results]),
        'accuracy': np.mean([r['accuracy'] for r in fold_results]),
        'precision': np.mean([r['precision'] for r in fold_results]),
        'recall': np.mean([r['recall'] for r in fold_results]),
        'f1_std': np.std([r['f1'] for r in fold_results]),
    }
    
    # Store results
    result_entry = {
        'config_id': combo_idx,
        'params': params,
        'mean_metrics': mean_metrics,
        'fold_results': fold_results,
        'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }
    all_results.append(result_entry)
    
    # Print summary
    print("\n  " + "-"*76)
    print(f"  Mean F1: {mean_metrics['f1']:.4f} ¬± {mean_metrics['f1_std']:.4f}")
    print(f"  Mean Acc: {mean_metrics['accuracy']:.4f}")
    
    # ============================================================================
    # AUTO-SAVE AFTER EACH CONFIG
    # ============================================================================
    
    # Save progress JSON
    with open(RESULTS_DIR / 'gridsearch_progress.json', 'w') as f:
        json.dump(all_results, f, indent=2, default=str)
    
    # Save readable summary
    with open(RESULTS_DIR / 'gridsearch_summary.txt', 'w') as f:
        f.write("="*80 + "\n")
        f.write(" ResNet-101 GridSearch Progress ".center(80, "=") + "\n")
        f.write("="*80 + "\n\n")
        f.write(f"Last updated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Completed: {combo_idx}/{len(combinations)} configs\n")
        f.write(f"Progress: {100*combo_idx/len(combinations):.1f}%\n\n")
        
        # Top configs so far
        sorted_results = sorted(all_results, key=lambda x: x['mean_metrics']['f1'], reverse=True)
        f.write("TOP 5 CONFIGS SO FAR:\n")
        f.write("-"*80 + "\n")
        for i, res in enumerate(sorted_results[:5], 1):
            p = res['params']
            m = res['mean_metrics']
            f.write(f"{i}. F1={m['f1']:.4f} | Opt={p['optimizer']}, Act={p['activation']}, "
                   f"L1={p['l1_lambda']:.0e}, L2={p['l2_lambda']:.0e}\n")
    
    print(f"\n  üíæ Progress saved to {RESULTS_DIR}/", flush=True)

total_time = time.time() - start_time

# ============================================================================
# FINAL SAVE
# ============================================================================

print("\n" + "="*80)
print(" GRIDSEARCH COMPLETE ".center(80, "="))
print("="*80)
print(f"\nFinished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total time: {total_time/3600:.2f} hours")

# Save final results
with open(RESULTS_DIR / 'gridsearch_final.json', 'w') as f:
    json.dump({
        'all_results': all_results,
        'param_grid': param_grid,
        'fixed_params': fixed_params,
        'n_folds': N_FOLDS,
        'total_time_hours': total_time/3600,
        'completed_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }, f, indent=2, default=str)

print(f"\nüíæ Final results saved to {RESULTS_DIR}/gridsearch_final.json")



Started: 2026-01-07 00:26:03
Total runs: 54


  Optimizer: adam
  Activation: relu
  L1 lambda: 0e+00
  L2 lambda: 0e+00
--------------------------------------------------------------------------------

    Fold 1/3...
      [1/6] Data loaders... ‚úì (0.0s)
      [2/6] Model (relu)... ‚úì (1.1s)
      [3/6] Optimizer (adam, L2=0e+00)... ‚úì
      [4/6] Loss & Scheduler... ‚úì
      [5/6] CUDA warmup... ‚úì (0.4s)
      [6/6] Training (patience=10, L1=0e+00)...
        Epoch  1: F1=0.3587, Loss=0.7288
        Epoch  5: F1=0.3684, Loss=0.1682
        Epoch 10: F1=0.3800, Loss=0.0999
        Epoch 15: F1=0.4137, Loss=0.0836
        Epoch 20: F1=0.4658, Loss=0.0739
        Epoch 25: F1=0.5104, Loss=0.0737
        Epoch 30: F1=0.5189, Loss=0.0702

    ‚úì Fold 1: F1=0.5189 (93.5 min)

    Fold 2/3...
      [1/6] Data loaders... ‚úì (0.0s)
      [2/6] Model (relu)... ‚úì (0.9s)
      [3/6] Optimizer (adam, L2=0e+00)... ‚úì
      [4/6] Loss & Scheduler... ‚úì
      [5/6] CUDA warmup... ‚úì 

## üìä CELL 11: Analyze Results

In [33]:
print("\n" + "="*80)
print(" RESULTS ANALYSIS ".center(80, "="))
print("="*80)

# Sort by F1
sorted_results = sorted(all_results, key=lambda x: x['mean_metrics']['f1'], reverse=True)

print("\nüèÜ TOP 10 CONFIGURATIONS:")
print("="*80)
print(f"{'Rank':<6} {'Optimizer':>10} {'Activation':>12} {'L1':>8} {'L2':>8} {'F1':>10} {'Acc':>8}")
print("-"*80)

for i, result in enumerate(sorted_results[:10], 1):
    p = result['params']
    m = result['mean_metrics']
    print(f"{i:<6} {p['optimizer']:>10} {p['activation']:>12} {p['l1_lambda']:>8.0e} "
          f"{p['l2_lambda']:>8.0e} {m['f1']:>10.4f} {m['accuracy']:>8.4f}")

# Best config
best_result = sorted_results[0]
best_params = best_result['params']
best_metrics = best_result['mean_metrics']

print("\n" + "="*80)
print(" BEST CONFIGURATION ".center(80, "="))
print("="*80)
print("\nüìã Best Hyperparameters:")
print(f"  Optimizer:    {best_params['optimizer']}")
print(f"  Activation:   {best_params['activation']}")
print(f"  L1 lambda:    {best_params['l1_lambda']:.0e}")
print(f"  L2 lambda:    {best_params['l2_lambda']:.0e}")

print("\nüìä Best Performance:")
print(f"  F1 Score:   {best_metrics['f1']:.4f} ¬± {best_metrics['f1_std']:.4f}")
print(f"  Accuracy:   {best_metrics['accuracy']:.4f}")
print(f"  Precision:  {best_metrics['precision']:.4f}")
print(f"  Recall:     {best_metrics['recall']:.4f}")

# Comparison
print("\nüìà COMPARISON WITH BASELINES:")
print("-"*80)
print(f"  KAN:       F1 = 0.4073")
print(f"  EEGNet:    F1 = 0.3281")
print(f"  ResNet-101 (Best): F1 = {best_metrics['f1']:.4f}")

if best_metrics['f1'] > 0.4073:
    improvement = ((best_metrics['f1'] - 0.4073) / 0.4073) * 100
    print(f"\n  üéâ ResNet-101 BEATS KAN by {improvement:.1f}%!")
else:
    gap = ((0.4073 - best_metrics['f1']) / 0.4073) * 100
    print(f"\n  ‚ö†Ô∏è  ResNet-101 is {gap:.1f}% below KAN")

print("\n" + "="*80)
print(" ALL RESULTS SAVED ".center(80, "="))
print("="*80)
print(f"\nüìÅ Results directory: {RESULTS_DIR}/")
print(f"   - gridsearch_final.json (complete results)")
print(f"   - gridsearch_summary.txt (readable summary)")
print(f"   - gridsearch_progress.json (backup)")



üèÜ TOP 10 CONFIGURATIONS:
Rank    Optimizer   Activation       L1       L2         F1      Acc
--------------------------------------------------------------------------------
1         adagrad         relu    0e+00    1e-03     0.5585   0.5921
2         adagrad    leakyrelu    0e+00    1e-03     0.5580   0.5979
3         adagrad         relu    0e+00    1e-04     0.5499   0.5805
4         adagrad         relu    0e+00    0e+00     0.5491   0.5724
5         adagrad    leakyrelu    0e+00    0e+00     0.5488   0.5743
6         adagrad    leakyrelu    0e+00    1e-04     0.5484   0.5748
7           adamw         relu    0e+00    1e-03     0.5356   0.5864
8            adam         relu    0e+00    0e+00     0.5320   0.5836
9            adam    leakyrelu    0e+00    0e+00     0.5315   0.5821
10          adamw         relu    0e+00    0e+00     0.5295   0.5773


üìã Best Hyperparameters:
  Optimizer:    adagrad
  Activation:   relu
  L1 lambda:    0e+00
  L2 lambda:    1e-03

üìä Best P

## üîÑ CELL 12: Load Previous Results (If Runtime Died)

In [None]:
# ============================================================================
# USE THIS CELL TO LOAD RESULTS IF RUNTIME DIED
# ============================================================================

print("Loading saved results...\n")

# Try to load final results first
final_path = RESULTS_DIR / 'gridsearch_final.json'
progress_path = RESULTS_DIR / 'gridsearch_progress.json'

if final_path.exists():
    with open(final_path, 'r') as f:
        saved_data = json.load(f)
    all_results = saved_data['all_results']
    print(f"‚úÖ Loaded FINAL results: {len(all_results)} configs")
    
elif progress_path.exists():
    with open(progress_path, 'r') as f:
        all_results = json.load(f)
    print(f"‚úÖ Loaded PROGRESS results: {len(all_results)} configs")
    print(f"   (GridSearch was interrupted, this is partial results)")
    
else:
    print("‚ùå No saved results found in gridsearch_results/")
    print("   Make sure you ran Cell 10 (GridSearch) first!")
    all_results = []

if all_results:
    # Show top 5
    sorted_results = sorted(all_results, key=lambda x: x['mean_metrics']['f1'], reverse=True)
    
    print("\nüèÜ TOP 5 CONFIGURATIONS:")
    print("-"*80)
    for i, result in enumerate(sorted_results[:5], 1):
        p = result['params']
        m = result['mean_metrics']
        print(f"{i}. F1={m['f1']:.4f} | Opt={p['optimizer']}, Act={p['activation']}, "
              f"L1={p['l1_lambda']:.0e}, L2={p['l2_lambda']:.0e}")
    
    print(f"\n‚úÖ Results loaded! You can now run Cell 11 for full analysis.")

In [1]:
print("\n" + "="*80)
print(" RESULTS ANALYSIS ".center(80, "="))
print("="*80)

# Sort by F1
sorted_results = sorted(all_results, key=lambda x: x['mean_metrics']['f1'], reverse=True)

print("\nüèÜ TOP 10 CONFIGURATIONS:")
print("="*80)
print(f"{'Rank':<6} {'Optimizer':>10} {'Grid':>6} {'Spline':>8} {'L1':>8} {'L2':>8} {'F1':>10} {'Acc':>8}")
print("-"*80)

for i, result in enumerate(sorted_results[:10], 1):
    p = result['params']
    m = result['mean_metrics']
    print(f"{i:<6} {p['optimizer']:>10} {p['grid_size']:>6} {p['spline_order']:>8} "
          f"{p['l1_lambda']:>8.0e} {p['l2_lambda']:>8.0e} "
          f"{m['f1']:>10.4f} {m['accuracy']:>8.4f}")

# Best config
best_result = sorted_results[0]
best_params = best_result['params']
best_metrics = best_result['mean_metrics']

print("\n" + "="*80)
print(" BEST CONFIGURATION ".center(80, "="))
print("="*80)

print("\nüìã Best Hyperparameters:")
print(f"  Optimizer:    {best_params['optimizer']}")
print(f"  Grid size:    {best_params['grid_size']}")
print(f"  Spline order: {best_params['spline_order']}")
print(f"  L1 lambda:    {best_params['l1_lambda']:.0e}")
print(f"  L2 lambda:    {best_params['l2_lambda']:.0e}")

print("\nüìä Best Performance:")
print(f"  F1 Score:   {best_metrics['f1']:.4f} ¬± {best_metrics['f1_std']:.4f}")
print(f"  Accuracy:   {best_metrics['accuracy']:.4f}")
print(f"  Precision:  {best_metrics['precision']:.4f}")
print(f"  Recall:     {best_metrics['recall']:.4f}")

# Comparison
print("\nüìà COMPARISON WITH BASELINES:")
print("-"*80)
print(f"  EEGNet:    F1 = 0.3281")
print(f"  KAN (Best):       F1 = {best_metrics['f1']:.4f}")
print(f"  ResNet-101: F1 = 0.5585")

if best_metrics['f1'] > 0.5585:
    improvement = ((best_metrics['f1'] - 0.5585) / 0.5585) * 100
    print(f"\n  üéâ KAN BEATS ResNet-101 by {improvement:.1f}%!")
elif best_metrics['f1'] > 0.3281:
    improvement_eegnet = ((best_metrics['f1'] - 0.3281) / 0.3281) * 100
    gap_resnet = ((0.5585 - best_metrics['f1']) / 0.5585) * 100
    print(f"\n  ‚úÖ KAN beats EEGNet by {improvement_eegnet:.1f}%")
    print(f"  üìâ KAN is {gap_resnet:.1f}% below ResNet-101")
else:
    gap = ((0.3281 - best_metrics['f1']) / 0.3281) * 100
    print(f"\n  ‚ö†Ô∏è  KAN is {gap:.1f}% below EEGNet baseline")

print("\n" + "="*80)
print(" ALL RESULTS SAVED ".center(80, "="))
print("="*80)
print(f"\nüìÅ Results directory: {RESULTS_DIR}/")
print(f"   - kan_gridsearch_final.json (complete results)")
print(f"   - kan_gridsearch_summary.txt (readable summary)")
print(f"   - kan_gridsearch_progress.json (backup)")

print("\n" + "="*80)




NameError: name 'all_results' is not defined