In [1]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from torch.cuda.amp import GradScaler, autocast
import timm
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import torchvision.transforms.v2 as v2
import torch.nn.functional as F

class EEGSpectrogramDataset(Dataset):
    def __init__(self, eeg_ids, spectrogram_dict, label_dict):
        self.eeg_ids = eeg_ids
        self.spectrogram_dict = spectrogram_dict
        self.label_dict = label_dict

        self.transform = v2.Compose([
            v2.Resize((224, 224)),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
    def __len__(self):
        return len(self.eeg_ids)
    
    def __getitem__(self, idx):
        eeg_id = self.eeg_ids[idx]
        spectrogram = self.spectrogram_dict[eeg_id]
        
        label_str = self.label_dict[eeg_id]
        label = self._label_to_index(label_str)

        stacked = np.vstack([spectrogram[..., i] for i in range(4)])[..., np.newaxis]
        stacked = torch.tensor(stacked, dtype=torch.float32).permute(2, 0, 1)
    
        rgb = torch.cat([stacked, stacked, stacked], dim=0)
        
        # Resize to 224x224 using bilinear interpolation
        rgb = F.interpolate(rgb.unsqueeze(0), 
                          size=(224, 224), 
                          mode='bilinear',
                          align_corners=False).squeeze(0)
        
        # Apply transforms
        rgb = self.transform(rgb)
        
        label = torch.tensor(label, dtype=torch.long)
        
        return rgb, label
    
    def _label_to_index(self, label_str):
        label_mapping = {
            'Seizure': 0,
            'LPD': 1,
            'GPD': 2,
            'LRDA': 3,
            'GRDA': 4,
            'Other': 5
        }
        return label_mapping.get(label_str, 5)

# Load data
spectrogram_dict = np.load('/kaggle/input/brain-eeg-spectrograms/eeg_specs.npy', allow_pickle=True).item()
train_df = pd.read_csv('/kaggle/input/hms-harmful-brain-activity-classification/train.csv')
label_dict = dict(zip(train_df['eeg_id'], train_df['expert_consensus']))

common_eeg_ids = [eeg_id for eeg_id in spectrogram_dict.keys() if eeg_id in label_dict]

print(f"Total EEG IDs in spectrogram dict: {len(spectrogram_dict)}")
print(f"Total EEG IDs in labels: {len(label_dict)}")
print(f"Common EEG IDs with both: {len(common_eeg_ids)}")
print(f"Sample spectrogram shape: {spectrogram_dict[common_eeg_ids[0]].shape}")

Total EEG IDs in spectrogram dict: 17089
Total EEG IDs in labels: 17089
Common EEG IDs with both: 17089
Sample spectrogram shape: (128, 256, 4)


In [3]:
class EarlyStopping:
    def __init__(self, patience=3, verbose=False):
        self.patience = patience
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
        self.verbose = verbose

    def __call__(self, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping: {self.counter}/{self.patience} without improvement.")
            if self.counter >= self.patience:
                self.early_stop = True

In [5]:
from torch.cuda.amp import autocast, GradScaler

In [8]:
os.makedirs("/kaggle/working/saved_models", exist_ok=True)

In [None]:
# Prepare for K-Fold
num_folds = 5
batch_size = 16
num_epochs = 20
num_classes = 6
fp16 = True

# Get labels for stratification
labels = [label_dict[eeg_id] for eeg_id in common_eeg_ids]
skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

# Store fold results
fold_results = []

for fold, (train_idx, val_idx) in enumerate(skf.split(common_eeg_ids, labels)):
    print(f"\n{'='*40}")
    print(f"Fold {fold + 1}/{num_folds}")
    print(f"{'='*40}")
    
    # Create datasets
    train_eeg_ids = [common_eeg_ids[i] for i in train_idx]
    val_eeg_ids = [common_eeg_ids[i] for i in val_idx]
    
    train_dataset = EEGSpectrogramDataset(train_eeg_ids, spectrogram_dict, label_dict)
    val_dataset = EEGSpectrogramDataset(val_eeg_ids, spectrogram_dict, label_dict)
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    # Initialize model
    model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=num_classes)
    
    # Multi-GPU support
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model = model.to(device)
    
    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7)
    scaler = GradScaler(enabled=fp16)
    early_stopper = EarlyStopping(patience=3, verbose=True)
    
    best_val_acc = 0.0
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            with autocast(enabled=fp16):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        scheduler.step()
        
        train_acc = correct / total
        avg_train_loss = running_loss / len(train_loader)
        print(f"\nEpoch [{epoch+1}/{num_epochs}]")
        print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                with autocast(enabled=fp16):
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_acc = val_correct / val_total
        avg_val_loss = val_loss / len(val_loader)
        print(f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

        # Save best model for this fold
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), f"/kaggle/working/saved_models/swin_fold{fold+1}_best.pth")
            print(f"Best model for fold {fold+1} saved with val accuracy: {val_acc:.4f}")

        early_stopper(avg_val_loss)
        if early_stopper.early_stop:
            print("Early stopping triggered. Ending training for this fold.")
            break
    
    # Store fold results
    fold_results.append({
        'fold': fold + 1,
        'best_val_acc': best_val_acc,
        'best_val_loss': best_val_loss
    })

# Print final results
print("\nFinal Results:")
for result in fold_results:
    print(f"Fold {result['fold']}: Val Acc: {result['best_val_acc']:.4f}, Val Loss: {result['best_val_loss']:.4f}")

mean_acc = np.mean([r['best_val_acc'] for r in fold_results])
mean_loss = np.mean([r['best_val_loss'] for r in fold_results])
print(f"\nMean Val Accuracy: {mean_acc:.4f}")
print(f"Mean Val Loss: {mean_loss:.4f}")


Fold 1/5


  scaler = GradScaler(enabled=fp16)
  with autocast(enabled=fp16):
Epoch 1/20: 100%|██████████| 854/854 [02:02<00:00,  6.98it/s]


Epoch [1/20]
Train Loss: 1.4986, Train Accuracy: 0.4483



  with autocast(enabled=fp16):


Val Loss: 1.4353, Val Accuracy: 0.4912
Best model for fold 1 saved with val accuracy: 0.4912


Epoch 2/20: 100%|██████████| 854/854 [02:02<00:00,  6.98it/s]


Epoch [2/20]
Train Loss: 1.3696, Train Accuracy: 0.4930





Val Loss: 1.2356, Val Accuracy: 0.5477
Best model for fold 1 saved with val accuracy: 0.5477


Epoch 3/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [3/20]
Train Loss: 1.2379, Train Accuracy: 0.5492





Val Loss: 1.1905, Val Accuracy: 0.5690
Best model for fold 1 saved with val accuracy: 0.5690


Epoch 4/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [4/20]
Train Loss: 1.0758, Train Accuracy: 0.6114





Val Loss: 1.1849, Val Accuracy: 0.5591


Epoch 5/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [5/20]
Train Loss: 0.9677, Train Accuracy: 0.6511





Val Loss: 1.0124, Val Accuracy: 0.6202
Best model for fold 1 saved with val accuracy: 0.6202


Epoch 6/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [6/20]
Train Loss: 0.8888, Train Accuracy: 0.6789





Val Loss: 0.8761, Val Accuracy: 0.6823
Best model for fold 1 saved with val accuracy: 0.6823


Epoch 7/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [7/20]
Train Loss: 0.8044, Train Accuracy: 0.7089





Val Loss: 0.8342, Val Accuracy: 0.6984
Best model for fold 1 saved with val accuracy: 0.6984


Epoch 8/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [8/20]
Train Loss: 0.7323, Train Accuracy: 0.7364





Val Loss: 0.7937, Val Accuracy: 0.7203
Best model for fold 1 saved with val accuracy: 0.7203


Epoch 9/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [9/20]
Train Loss: 0.6686, Train Accuracy: 0.7600





Val Loss: 0.7925, Val Accuracy: 0.7133


Epoch 10/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [10/20]
Train Loss: 0.6238, Train Accuracy: 0.7783





Val Loss: 0.7971, Val Accuracy: 0.7139
EarlyStopping: 1/3 without improvement.


Epoch 11/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [11/20]
Train Loss: 0.6019, Train Accuracy: 0.7878





Val Loss: 0.7964, Val Accuracy: 0.7156
EarlyStopping: 2/3 without improvement.


Epoch 12/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [12/20]
Train Loss: 0.6078, Train Accuracy: 0.7829





Val Loss: 0.8022, Val Accuracy: 0.7142
EarlyStopping: 3/3 without improvement.
Early stopping triggered. Ending training for this fold.

Fold 2/5


Epoch 1/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [1/20]
Train Loss: 1.3540, Train Accuracy: 0.5112





Val Loss: 1.1184, Val Accuracy: 0.5848
Best model for fold 2 saved with val accuracy: 0.5848


Epoch 2/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [2/20]
Train Loss: 1.0813, Train Accuracy: 0.6104





Val Loss: 1.0201, Val Accuracy: 0.6325
Best model for fold 2 saved with val accuracy: 0.6325


Epoch 3/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [3/20]
Train Loss: 0.9357, Train Accuracy: 0.6647





Val Loss: 0.9080, Val Accuracy: 0.6697
Best model for fold 2 saved with val accuracy: 0.6697


Epoch 4/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [4/20]
Train Loss: 0.8510, Train Accuracy: 0.6937





Val Loss: 0.8178, Val Accuracy: 0.7159
Best model for fold 2 saved with val accuracy: 0.7159


Epoch 5/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [5/20]
Train Loss: 0.7744, Train Accuracy: 0.7242





Val Loss: 0.8108, Val Accuracy: 0.7209
Best model for fold 2 saved with val accuracy: 0.7209


Epoch 6/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [6/20]
Train Loss: 0.6782, Train Accuracy: 0.7591





Val Loss: 0.8723, Val Accuracy: 0.7057
EarlyStopping: 1/3 without improvement.


Epoch 7/20: 100%|██████████| 854/854 [02:02<00:00,  6.96it/s]


Epoch [7/20]
Train Loss: 0.5648, Train Accuracy: 0.8023





Val Loss: 0.8352, Val Accuracy: 0.7305
Best model for fold 2 saved with val accuracy: 0.7305
EarlyStopping: 2/3 without improvement.


Epoch 8/20: 100%|██████████| 854/854 [02:02<00:00,  6.97it/s]


Epoch [8/20]
Train Loss: 0.4377, Train Accuracy: 0.8473





Val Loss: 0.8756, Val Accuracy: 0.7326
Best model for fold 2 saved with val accuracy: 0.7326
EarlyStopping: 3/3 without improvement.
Early stopping triggered. Ending training for this fold.

Fold 3/5


Epoch 1/20:   5%|▍         | 42/854 [00:06<01:55,  7.01it/s]