# RSNA 2024 Lumbar Spine Degenerative Classification
## Version 4

### Key Fixes in v4:
1. **Class weights calculated from ORIGINAL data** (not oversampled!)
2. **Model selection based on Balanced Accuracy** (not val_loss)
3. **Increased focal_gamma to 3.5** (from 2.0)
4. **Added minimum recall threshold** for model saving

In [None]:
import os
import cv2
import glob
import pydicom
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
from tqdm import tqdm
from sklearn.model_selection import StratifiedGroupKFold, train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
from collections import Counter

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# --- CONFIGURATION ---
# Every tunable of the notebook lives in this one mapping (insertion order is
# kept identical to the key order readers see in the printed summaries).
CONFIG = dict(
    # Reproducibility
    seed=42,
    # Input geometry: square crop side (pixels) and slices per sequence
    img_size=256,
    seq_length=7,
    # Optimisation
    batch_size=8,
    epochs=25,
    learning_rate=3e-4,   # heads (projection, LSTM, attention, classifier)
    backbone_lr=3e-5,     # pretrained backbone gets a 10x smaller rate
    weight_decay=0.05,
    patience=10,  # Increased patience since we use balanced_acc
    # Cross-validation: number of folds and which of them are actually trained
    num_folds=5,
    train_folds=[0],
    # Loss
    focal_gamma=3.5,  # INCREASED from 2.0 to focus more on hard examples
    label_smoothing=0.1,
    # Model
    dropout=0.4,
    num_attention_heads=4,
    warmup_epochs=2,
    # Imbalance handling
    oversample_strategy='progressive',  # 'progressive', 'smote_like', or 'balanced'
    min_minority_recall=0.20,  # Minimum recall for minority classes to save model
    # Runtime
    device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
    # Task selection
    target_condition='spinal_canal_stenosis',
    target_series='Sagittal T2/STIR',
)

In [None]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and switches cuDNN to deterministic kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers multi-GPU setups; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark enabled cuDNN may auto-tune to different algorithms from
    # run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

In [None]:
# Seed all RNGs before any data or model code runs, then echo the settings
# that most influence how class imbalance is handled.
seed_everything(CONFIG['seed'])
print(f"‚úÖ Device: {CONFIG['device']}")
print(f"   Oversample Strategy: {CONFIG['oversample_strategy']}")
print(f"   Focal Gamma: {CONFIG['focal_gamma']} (Higher = more focus on hard examples)")

## 1. Data Loading

In [None]:
# --- PATHS ---
# Root of the competition input directory (note the trailing slash).
DATA_ROOT = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/"
# DICOM files are later addressed as <TRAIN_IMAGES>/<study_id>/<series_id>/<instance>.dcm
TRAIN_IMAGES = os.path.join(DATA_ROOT, "train_images")

In [None]:
# Load the three competition tables. Paths are built with os.path.join (as for
# TRAIN_IMAGES) so the trailing slash of DATA_ROOT does not produce "//".
df_train = pd.read_csv(os.path.join(DATA_ROOT, "train.csv"))                     # per-study severity labels
df_coords = pd.read_csv(os.path.join(DATA_ROOT, "train_label_coordinates.csv"))  # annotated points per label
df_desc = pd.read_csv(os.path.join(DATA_ROOT, "train_series_descriptions.csv"))  # series type per series_id

In [None]:
# Clean & Merge
# Normalise the header names, then go from wide (one column per
# condition/level) to long (one row per study and condition/level).
df_train.columns = [name.lower().replace('/', '_') for name in df_train.columns]
condition_cols = [name for name in df_train.columns if name != 'study_id']
df_labels = df_train.melt(
    id_vars=['study_id'],
    value_vars=condition_cols,
    var_name='condition_level',
    value_name='severity',
).dropna(subset=['severity'])
# Severity strings get the same lower-case / underscore normalisation.
normalised_severity = df_labels['severity'].astype(str).str.lower()
df_labels['severity'] = normalised_severity.str.replace('/', '_')

In [None]:
def extract_meta(val):
    """Split a ``<condition>_<level>`` name into ``(condition, level)``.

    The last two underscore-separated tokens form the level (e.g. ``l1_l2``);
    everything before them, re-joined with underscores, is the condition.
    """
    tokens = val.split('_')
    upper, lower = tokens[-2], tokens[-1]
    return '_'.join(tokens[:-2]), f"{upper}_{lower}"

In [None]:
# Split every condition/level name once, then fan the tuple out into columns.
meta_pairs = df_labels['condition_level'].map(extract_meta)
df_labels['base_condition'] = meta_pairs.str[0]
df_labels['level_str'] = meta_pairs.str[1]
# Severity text -> ordinal class id; unmapped severities become NaN and are dropped.
severity_map = dict(normal_mild=0, moderate=1, severe=2)
df_labels['label'] = df_labels['severity'].map(severity_map)
df_labels = df_labels.dropna(subset=['label']).astype({'label': int})

In [None]:
# Attach the series description to every coordinate row, then normalise the
# condition and level text so it can be matched against the label table.
df_coords = pd.merge(df_coords, df_desc, how='left', on=['study_id', 'series_id'])
for column, separator in (('condition', ' '), ('level', '/')):
    df_coords[column] = df_coords[column].str.lower().str.replace(separator, '_')
df_coords['condition_level'] = df_coords['condition'] + '_' + df_coords['level']

In [None]:
# Keep only the labels of the configured condition ...
is_target_label = df_labels['base_condition'].eq(CONFIG['target_condition'])
df_model = df_labels[is_target_label].copy()
# ... and only the coordinates annotated for it on the configured series type.
is_target_coord = (
    df_coords['condition'].eq(CONFIG['target_condition'])
    & df_coords['series_description'].eq(CONFIG['target_series'])
)
df_coords_filt = df_coords[is_target_coord]

In [None]:
# Inner join: a label survives only if it has an annotated coordinate.
coord_columns = ['study_id', 'condition_level', 'series_id', 'instance_number', 'x', 'y']
df_final = pd.merge(
    df_model,
    df_coords_filt[coord_columns],
    how='inner',
    on=['study_id', 'condition_level'],
)

In [None]:
# Filter valid files
# Collect the rows whose annotated (centre) slice actually exists on disk.
valid_rows = []
for index, row in tqdm(df_final.iterrows(), total=len(df_final), desc="Checking Files"):
    path = os.path.join(
        TRAIN_IMAGES,
        str(row['study_id']),
        str(row['series_id']),
        f"{int(row['instance_number'])}.dcm",
    )
    if not os.path.exists(path):
        continue
    valid_rows.append(row)

In [None]:
# Rebuild the frame from the surviving rows with a fresh 0..n-1 index.
df_final = pd.DataFrame(valid_rows)
df_final = df_final.reset_index(drop=True)
# Disc levels ordered top to bottom, mapped to 0..4.
level_map = dict(zip(('l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1'), range(5)))
df_final['level_idx'] = df_final['level_str'].map(level_map)

In [None]:
# Report the class balance of the final table. The counts are computed once and
# iterated as (label, count) pairs so the printed class id is the real label
# even if some class is absent (enumerate() would print a positional index).
class_counts = df_final['label'].value_counts().sort_index()
print(f"\n‚úÖ Data Ready: {len(df_final)} samples")
print(f"   Class Distribution: {class_counts.to_dict()}")
for class_id, count in class_counts.items():
    pct = count / len(df_final) * 100
    print(f"   Class {class_id}: {count} samples ({pct:.1f}%)")

## 2. Balanced Oversampling Function

In [None]:
def create_stratified_balanced_df(df, strategy='progressive', random_state=42):
    """
    Oversample (level_idx, label) groups of ``df`` by duplicating rows.

    Every group keeps all of its original rows (``is_oversampled=False``,
    ``aug_variant=0``). If the strategy asks for more rows than the group has,
    the missing ones are drawn with replacement from the group itself and
    flagged ``is_oversampled=True`` with a random ``aug_variant`` in 0..3.

    Target size per group, from the label counts of the group's level:
      - 'progressive': ``int(median * (1 + 0.3 * label))``
      - 'smote_like':  ``int(median * 0.8)`` for groups below the median,
                       unchanged otherwise
      - 'balanced':    the largest label count of that level
      - anything else: unchanged (no oversampling)

    Args:
        df: Frame with at least ``level_idx`` and ``label`` columns.
        strategy: One of the strategy names above.
        random_state: Seed for both the row sampling and the final shuffle.

    Returns:
        A shuffled copy with a fresh 0..n-1 index and the two extra columns.
        An empty input yields an empty frame that still has both columns.
    """
    # A private generator yields the same draws np.random.seed(random_state)
    # would, without silently reseeding the global NumPy RNG for the caller.
    rng = np.random.RandomState(random_state)

    # Label frequencies per level are needed by every group of that level;
    # compute them once instead of re-filtering df inside the loop.
    counts_by_level = {
        level: level_df['label'].value_counts()
        for level, level_df in df.groupby('level_idx')
    }
    balanced_dfs = []

    print("\nüìä Stratified Sampling Details:")

    for (level, label), group_df in df.groupby(['level_idx', 'label']):
        group_df = group_df.copy()
        group_df['is_oversampled'] = False
        group_df['aug_variant'] = 0
        current_count = len(group_df)
        level_counts = counts_by_level[level]

        if strategy == 'progressive':
            target_count = int(level_counts.median() * (1 + 0.3 * label))
        elif strategy == 'smote_like':
            if current_count < level_counts.median():
                target_count = int(level_counts.median() * 0.8)
            else:
                target_count = current_count
        elif strategy == 'balanced':
            target_count = int(level_counts.max())
        else:
            target_count = current_count

        samples_needed = target_count - current_count
        balanced_dfs.append(group_df)

        if samples_needed > 0:
            # Sample row positions (not index labels) so duplicated index
            # values in df can never pull in extra rows.
            positions = rng.choice(current_count, size=samples_needed, replace=True)
            oversampled_df = group_df.iloc[positions].copy()
            oversampled_df['is_oversampled'] = True
            oversampled_df['aug_variant'] = rng.randint(0, 4, size=len(oversampled_df))
            print(f"   Level {level}, Label {label}: {current_count} ‚Üí {target_count} (+{samples_needed})")
            balanced_dfs.append(oversampled_df)
        else:
            print(f"   Level {level}, Label {label}: {current_count} (no oversampling)")

    if not balanced_dfs:
        # pd.concat([]) raises; hand back an empty frame with the expected columns.
        return df.assign(is_oversampled=False, aug_variant=0).reset_index(drop=True)

    balanced_df = pd.concat(balanced_dfs, ignore_index=True)
    return balanced_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

In [None]:
# Test the function
# Baseline: label counts of the full table before any oversampling.
print("\nüìä Before Oversampling:")
print(df_final['label'].value_counts().sort_index())

In [None]:
# Smoke test on the whole table with the 'balanced' strategy. This is only a
# sanity check: training uses CONFIG['oversample_strategy'] on the train split.
balanced_test = create_stratified_balanced_df(df_final, strategy='balanced')
print("\nüìä After Balanced Oversampling:")
print(balanced_test['label'].value_counts().sort_index())
print(f"   Total samples: {len(df_final)} ‚Üí {len(balanced_test)}")
print(f"   Oversampled: {balanced_test['is_oversampled'].sum()} samples")

## 3. Dataset with Adaptive Augmentation

In [None]:
class RSNASequenceDataset(Dataset):
    """Sequences of cropped DICOM slices centred on an annotated point.

    Each item is built from one row of ``df``: ``seq_length`` consecutive
    instance numbers around ``instance_number`` are loaded from
    ``<TRAIN_IMAGES>/<study_id>/<series_id>/``, cropped around ``(x, y)``,
    resized to ``img_size`` x ``img_size`` and converted to 3 channels.

    Args:
        df: Frame with ``study_id``, ``series_id``, ``instance_number``, ``x``,
            ``y``, ``label`` and ``level_idx`` columns; ``is_oversampled`` is
            optional and defaults to False.
        seq_length: Number of slices per sequence.
        img_size: Side length of the square output crops.
        transform: Per-frame transform called as ``transform(image=...)``.
        strong_transform: Transform used instead of ``transform`` for rows
            flagged ``is_oversampled`` while ``is_training`` is True.
        is_training: Enables the strong transform for oversampled rows.
    """

    def __init__(self, df, seq_length=7, img_size=256, transform=None,
                 strong_transform=None, is_training=False):
        self.df = df.reset_index(drop=True)
        self.seq_length = seq_length
        self.img_size = img_size
        self.transform = transform
        self.strong_transform = strong_transform  # For oversampled data
        self.is_training = is_training
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def __len__(self):
        return len(self.df)

    def _blank_frame(self):
        """Return an all-zero uint8 frame used for missing or unreadable slices."""
        return np.zeros((self.img_size, self.img_size), dtype=np.uint8)

    def load_dicom(self, path):
        """Read one DICOM slice as a min-max normalised, CLAHE-equalised uint8 image.

        Loading is best-effort: any ordinary error yields a blank frame so that
        a single corrupt file cannot abort an epoch.
        """
        try:
            dcm = pydicom.dcmread(path)
            img = dcm.pixel_array.astype(np.float32)
            lo, hi = img.min(), img.max()
            if hi > lo:
                img = (img - lo) / (hi - lo) * 255.0
            else:
                # Constant image: nothing to stretch.
                img = np.zeros_like(img)
            return self.clahe.apply(img.astype(np.uint8))
        except Exception:
            # Not a bare ``except``: KeyboardInterrupt / SystemExit must still propagate.
            return self._blank_frame()

    def _crop_patch(self, img, cx, cy):
        """Cut an ``img_size``-wide window centred on (cx, cy), clipped to the image."""
        h, w = img.shape
        half = self.img_size // 2
        x1, y1 = max(0, cx - half), max(0, cy - half)
        x2, y2 = min(w, cx + half), min(h, cy + half)
        crop = img[y1:y2, x1:x2]
        if crop.size == 0:
            # The point lies outside this frame (e.g. a blank placeholder frame).
            return self._blank_frame()
        return cv2.resize(crop, (self.img_size, self.img_size))

    def __getitem__(self, idx):
        """Return ``(sequence, label, level_idx)`` for row ``idx``.

        ``sequence`` stacks the ``seq_length`` per-frame tensors along dim 0.
        """
        row = self.df.iloc[idx]
        center_inst = int(row['instance_number'])
        study_path = f"{TRAIN_IMAGES}/{row['study_id']}/{row['series_id']}"
        cx, cy = int(row['x']), int(row['y'])

        # Oversampled duplicates get the stronger pipeline during training.
        # The global RNGs are deliberately NOT reseeded here: a seed derived
        # from idx would replay the identical augmentation every epoch and
        # disturb the random stream of every later sample in this worker.
        # The ``aug_variant`` column is therefore not consulted.
        is_oversampled = bool(row.get('is_oversampled', False))
        if self.is_training and is_oversampled and self.strong_transform is not None:
            pipeline = self.strong_transform
        else:
            pipeline = self.transform

        start = center_inst - (self.seq_length // 2)

        images_list = []
        for inst in range(start, start + self.seq_length):
            path = os.path.join(study_path, f"{inst}.dcm")
            # Neighbouring slices may not exist near the ends of a series.
            img = self.load_dicom(path) if os.path.exists(path) else self._blank_frame()

            crop = cv2.cvtColor(self._crop_patch(img, cx, cy), cv2.COLOR_GRAY2RGB)

            if pipeline is not None:
                crop_tensor = pipeline(image=crop)['image']
            else:
                # No transform: HWC uint8 -> CHW float in [0, 1].
                crop_tensor = torch.tensor(crop).permute(2, 0, 1).float() / 255.0
            images_list.append(crop_tensor)

        sequence = torch.stack(images_list, dim=0)
        label = torch.tensor(row['label'], dtype=torch.long)
        level_idx = torch.tensor(row['level_idx'], dtype=torch.long)

        return sequence, label, level_idx

## 4. Augmentation Pipelines (Normal + Strong)

In [None]:
# Medical imaging-appropriate augmentation (normal)
# Passed to RSNASequenceDataset as `transform`, which calls it once per frame.
# NOTE(review): argument names such as `value` and `var_limit` differ between
# albumentations releases — confirm against the installed version.
train_aug = A.Compose([
    # Mild geometric jitter; uncovered borders are filled with 0 (black).
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=5, 
                       border_mode=cv2.BORDER_CONSTANT, value=0, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussNoise(var_limit=(5.0, 20.0), p=0.2),
    # Standard ImageNet mean/std, in line with the ImageNet-pretrained backbone.
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

In [None]:
# Medical imaging-appropriate augmentation (strong for oversampled)
# Passed to RSNASequenceDataset as `strong_transform`; it replaces the normal
# pipeline for rows flagged `is_oversampled` during training.
# NOTE(review): `value`, `var_limit` and `alpha_affine` are version-dependent
# argument names in albumentations — confirm against the installed version.
strong_aug = A.Compose([
    # Wider geometric jitter than the normal pipeline, applied more often.
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=8,
                       border_mode=cv2.BORDER_CONSTANT, value=0, p=0.7),
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=20, p=0.3),
    # Exactly one intensity transform is picked (80% of the time).
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
        A.RandomGamma(gamma_limit=(80, 120), p=1.0),
        A.CLAHE(clip_limit=4.0, p=1.0),
    ], p=0.8),
    # Exactly one noise model is picked (30% of the time).
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
        A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=1.0),
    ], p=0.3),
    A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.3),
    # Same normalisation as the other pipelines so all inputs share one scale.
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

In [None]:
# Validation pipeline: no random augmentation, only the normalisation and
# tensor conversion shared with the training pipelines.
val_aug = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

In [None]:
# Status output only: summarises which pipeline applies to which samples.
print("‚úÖ Dual augmentation pipeline (ENHANCED):")
print("   - Normal: for original samples")
print("   - Strong: for oversampled minority class samples (more aggressive)")

## 5. Model Architecture

In [None]:
class SpineSeqAttention(nn.Module):
    """CNN + BiLSTM + self-attention classifier over a sequence of slices.

    Pipeline: per-frame EfficientNetV2-S features -> linear projection ->
    bidirectional LSTM -> multi-head self-attention -> mean over the sequence
    -> concatenation with a learned disc-level embedding -> MLP classifier.

    Args:
        num_classes: Number of output logits.
        hidden_dim: LSTM hidden size per direction; the sequence feature
            width is ``2 * hidden_dim``.
        lstm_layers: Number of stacked LSTM layers.
        num_heads: Attention heads (must divide ``2 * hidden_dim``).
        dropout: Dropout probability used throughout the heads.
        num_levels: Number of distinct level indices for the embedding.
        level_embed_dim: Width of the level embedding (previously a
            hard-coded 64 repeated in ``__init__`` and ``forward``).
    """

    def __init__(self, num_classes=3, hidden_dim=256, lstm_layers=2,
                 num_heads=4, dropout=0.4, num_levels=5, level_embed_dim=64):
        super().__init__()

        seq_dim = hidden_dim * 2
        self.level_embed_dim = level_embed_dim

        effnet = models.efficientnet_v2_s(weights='IMAGENET1K_V1')
        # Drop the last child (the classification head); keep features + pooling.
        self.backbone = nn.Sequential(*list(effnet.children())[:-1])
        self.feature_dim = 1280

        self.feature_proj = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, seq_dim),
            nn.LayerNorm(seq_dim),
            nn.GELU()
        )

        self.lstm = nn.LSTM(
            input_size=seq_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM only applies dropout between layers.
            dropout=dropout if lstm_layers > 1 else 0
        )

        self.attention = nn.MultiheadAttention(
            embed_dim=seq_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.level_embedding = nn.Embedding(num_levels, level_embed_dim)

        self.classifier = nn.Sequential(
            nn.LayerNorm(seq_dim + level_embed_dim),
            nn.Dropout(dropout),
            nn.Linear(seq_dim + level_embed_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, x, level_idx=None):
        """Classify a batch of slice sequences.

        Args:
            x: Tensor of shape (batch, seq, channels, height, width).
            level_idx: Optional long tensor of shape (batch,) with level ids;
                when omitted a zero vector stands in for the embedding.

        Returns:
            ``(logits, avg_attn)``: logits of shape (batch, num_classes) and
            the attention weights averaged over query positions, shape
            (batch, seq).
        """
        b, s, c, h, w = x.size()
        # Fold the sequence into the batch so the CNN sees individual frames;
        # reshape (unlike view) also accepts non-contiguous input.
        frames = x.reshape(b * s, c, h, w)

        features = self.backbone(frames).reshape(b, s, -1)
        features = self.feature_proj(features)

        lstm_out, _ = self.lstm(features)
        attn_out, attn_weights = self.attention(lstm_out, lstm_out, lstm_out)
        context = attn_out.mean(dim=1)

        if level_idx is not None:
            level_feat = self.level_embedding(level_idx)
        else:
            # Same device and dtype as the context it is concatenated with.
            level_feat = context.new_zeros(b, self.level_embed_dim)
        context = torch.cat([context, level_feat], dim=-1)

        out = self.classifier(context)
        avg_attn = attn_weights.mean(dim=1)

        return out, avg_attn

In [None]:
# Status output only: the class is defined, no model has been instantiated yet.
print("‚úÖ Model architecture loaded")

## 6. Focal Loss + Early Stopping

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on top of (optionally weighted, label-smoothed) cross-entropy.

    ``loss_i = (1 - p_t)^gamma * CE_i`` where ``p_t`` is the softmax
    probability the model assigns to the true class of sample ``i``.

    Args:
        alpha: Optional per-class weight tensor forwarded to ``cross_entropy``.
        gamma: Focusing exponent; 0 reduces to plain cross-entropy.
        label_smoothing: Smoothing factor forwarded to ``cross_entropy``.
        reduction: 'mean' or 'sum'; any other value returns per-sample losses.
    """

    def __init__(self, alpha=None, gamma=2.0, label_smoothing=0.1, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` (N, C) and class ids ``targets`` (N,)."""
        ce_loss = F.cross_entropy(
            inputs, targets,
            weight=self.alpha,
            reduction='none',
            label_smoothing=self.label_smoothing
        )
        # p_t must come from the raw softmax. Deriving it as exp(-ce_loss)
        # would fold the class weight and the smoothing into the
        # "probability": a class with weight w would be treated as if
        # predicted with p_t ** w.
        log_pt = F.log_softmax(inputs, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

In [None]:
class EarlyStopping:
    """Signal a stop once a monitored score has stalled for ``patience`` calls.

    The first score only initialises the reference. Afterwards a score counts
    as an improvement when it beats the best one by more than ``min_delta``:
    upwards for ``mode='max'``, downwards for every other mode value.
    """

    def __init__(self, patience=5, min_delta=0.001, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode  # 'max' for balanced accuracy, 'min' for loss
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def _improved(self, val_score):
        """True if ``val_score`` beats ``best_score`` by more than ``min_delta``."""
        if self.mode == 'max':
            return val_score > self.best_score + self.min_delta
        return val_score < self.best_score - self.min_delta

    def __call__(self, val_score):
        """Record ``val_score``; return True when training should stop."""
        if self.best_score is None:
            self.best_score = val_score
            return False

        if self._improved(val_score):
            self.best_score = val_score
            self.counter = 0
            return False

        # No improvement: one more strike towards the patience limit.
        self.counter += 1
        if self.counter < self.patience:
            return False
        self.early_stop = True
        return True

In [None]:
# Status output only.
print("‚úÖ Focal Loss + Early Stopping (uses balanced accuracy)")

## 7. Training Function with Per-Class Metrics (FIXED!)

In [None]:
def compute_per_class_metrics(preds, labels, num_classes=3):
    """Compute per-class accuracy (recall).

    Args:
        preds: Predicted class ids (any array-like, e.g. list or ndarray).
        labels: True class ids, same length as ``preds``.
        num_classes: Classes ``0 .. num_classes - 1`` are reported.

    Returns:
        Dict ``{'class_<c>_recall': float}``. A class with no true samples
        is reported as 0.0.
    """
    # Coerce so plain Python lists work too (list == int is a single bool).
    preds = np.asarray(preds)
    labels = np.asarray(labels)

    metrics = {}
    for c in range(num_classes):
        mask = labels == c
        support = int(mask.sum())
        if support:
            hits = int(np.count_nonzero(preds[mask] == c))
            metrics[f'class_{c}_recall'] = hits / support
        else:
            metrics[f'class_{c}_recall'] = 0.0
    return metrics

In [None]:
def _warmup_cosine_lambda(warmup_steps, total_steps):
    """Build an LR multiplier: linear warmup, then cosine decay towards zero."""
    # Guard the degenerate schedule where warmup covers every step.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    return lr_lambda


def _validate(model, val_loader, criterion, device, use_amp):
    """Run one pass over ``val_loader``.

    Returns:
        ``(mean batch loss, accuracy, predictions, labels)``; the last two are
        1-D numpy arrays covering the whole loader.
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, level_idx in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            level_idx = level_idx.to(device)

            with autocast('cuda', enabled=use_amp):
                outputs, _ = model(images, level_idx)
                loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # max(1, ...) keeps an empty loader from raising ZeroDivisionError.
    mean_loss = val_loss / max(1, len(val_loader))
    accuracy = val_correct / max(1, val_total)
    return mean_loss, accuracy, np.array(all_preds), np.array(all_labels)


def train_one_fold(model, train_loader, val_loader, fold, config, original_class_counts):
    """
    Train ``model`` on one fold and restore its best checkpoint.

    1. Uses ORIGINAL class counts for loss weights (not oversampled!)
    2. Uses BALANCED ACCURACY for model selection (not val_loss!)
    3. Tracks minimum minority recall threshold

    Args:
        model: Network exposing ``backbone``, ``feature_proj``, ``lstm``,
            ``attention``, ``level_embedding`` and ``classifier`` submodules.
        train_loader: Loader yielding ``(images, labels, level_idx)`` batches.
        val_loader: Loader with the same batch layout, used for selection.
        fold: Zero-based fold index; also names the checkpoint file
            ``best_model_fold{fold}.pth`` in the working directory.
        config: Settings dict (reads ``device``, ``focal_gamma``,
            ``label_smoothing``, ``backbone_lr``, ``learning_rate``,
            ``weight_decay``, ``warmup_epochs``, ``epochs``, ``patience``,
            ``num_folds`` and, optionally, ``min_minority_recall``).
        original_class_counts: Per-class sample counts of the training split
            BEFORE oversampling (three classes are assumed by the printout).

    Returns:
        ``(model, history, best_balanced_acc)``. ``model`` carries the best
        saved weights, or its last-epoch weights if no epoch ever satisfied
        the minority-recall threshold (``best_balanced_acc`` is then 0.0).
    """
    device = config['device']
    # Mixed precision only makes sense (and only works quietly) on CUDA.
    use_amp = torch.device(device).type == 'cuda'
    checkpoint_path = f"best_model_fold{fold}.pth"

    # ========================================
    # FIX #1: Use ORIGINAL class counts for loss weights!
    # ========================================
    # After oversampling the class counts no longer reflect the true class
    # frequencies, so inverse-frequency weights are taken from the ORIGINAL
    # (pre-oversampling) counts and rescaled to sum to the number of classes.
    class_weights = 1. / (np.asarray(original_class_counts) + 1e-6)
    class_weights = class_weights / class_weights.sum() * len(class_weights)
    loss_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)

    print(f"\n   üìä Class weights from ORIGINAL data: {class_weights}")
    print(f"      (Class 0: {class_weights[0]:.3f}, Class 1: {class_weights[1]:.3f}, Class 2: {class_weights[2]:.3f})")

    criterion = FocalLoss(
        alpha=loss_weights,
        gamma=config['focal_gamma'],
        label_smoothing=config['label_smoothing']
    )

    # The pretrained backbone trains with its own (smaller) learning rate;
    # every freshly initialised head shares the main one.
    head_modules = (model.feature_proj, model.lstm, model.attention,
                    model.level_embedding, model.classifier)
    param_groups = [{'params': model.backbone.parameters(), 'lr': config['backbone_lr']}]
    param_groups += [{'params': module.parameters(), 'lr': config['learning_rate']}
                     for module in head_modules]
    optimizer = optim.AdamW(param_groups, weight_decay=config['weight_decay'])

    # The schedule is stepped once per batch, so it is expressed in batches.
    warmup_steps = config['warmup_epochs'] * len(train_loader)
    total_steps = config['epochs'] * len(train_loader)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, _warmup_cosine_lambda(warmup_steps, total_steps)
    )
    scaler = GradScaler('cuda', enabled=use_amp)

    # ========================================
    # FIX #2: Use BALANCED ACCURACY for early stopping!
    # ========================================
    early_stopping = EarlyStopping(patience=config['patience'], min_delta=0.005, mode='max')

    best_balanced_acc = 0.0
    checkpoint_saved = False
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'balanced_acc': [],
        'class_0_recall': [], 'class_1_recall': [], 'class_2_recall': []
    }

    print(f"\nüöÄ Training Fold {fold+1}/{config['num_folds']}")
    print(f"   Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}")
    print(f"   Focal Gamma: {config['focal_gamma']}")
    print(f"   ‚ö†Ô∏è  Model saved based on BALANCED ACCURACY (not val_loss!)")

    for epoch in range(config['epochs']):
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")

        for images, labels, level_idx in loop:
            images = images.to(device)
            labels = labels.to(device)
            level_idx = level_idx.to(device)

            optimizer.zero_grad()

            with autocast('cuda', enabled=use_amp):
                outputs, _ = model(images, level_idx)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)

            loop.set_postfix(
                loss=f"{train_loss/(loop.n+1):.4f}",
                acc=f"{100*train_correct/train_total:.1f}%",
                lr=f"{optimizer.param_groups[0]['lr']:.2e}"
            )

        train_epoch_loss = train_loss / len(train_loader)
        train_acc = train_correct / train_total

        # Validation
        val_epoch_loss, val_acc, all_preds, all_labels = _validate(
            model, val_loader, criterion, device, use_amp
        )

        # Compute per-class recall
        per_class = compute_per_class_metrics(all_preds, all_labels)

        # ========================================
        # FIX #3: Compute and track BALANCED ACCURACY
        # ========================================
        balanced_acc = (per_class['class_0_recall'] +
                       per_class['class_1_recall'] +
                       per_class['class_2_recall']) / 3

        history['train_loss'].append(train_epoch_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_acc)
        history['balanced_acc'].append(balanced_acc)
        history['class_0_recall'].append(per_class['class_0_recall'])
        history['class_1_recall'].append(per_class['class_1_recall'])
        history['class_2_recall'].append(per_class['class_2_recall'])

        print(f"üìä Train Loss: {train_epoch_loss:.4f} | Train Acc: {100*train_acc:.1f}% | "
              f"Val Loss: {val_epoch_loss:.4f} | Val Acc: {100*val_acc:.1f}%")
        print(f"   Per-class Recall: Normal={100*per_class['class_0_recall']:.1f}%, "
              f"Moderate={100*per_class['class_1_recall']:.1f}%, "
              f"Severe={100*per_class['class_2_recall']:.1f}%")
        print(f"   üéØ Balanced Accuracy: {100*balanced_acc:.1f}%")

        # ========================================
        # FIX #4: Save based on BALANCED ACCURACY!
        # ========================================
        # Also check minimum recall threshold for minority classes
        min_minority_recall = min(per_class['class_1_recall'], per_class['class_2_recall'])

        if balanced_acc > best_balanced_acc and min_minority_recall >= config.get('min_minority_recall', 0.1):
            best_balanced_acc = balanced_acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"‚úÖ Best Model Saved! (Balanced Acc: {100*balanced_acc:.1f}%, "
                  f"Min Minority Recall: {100*min_minority_recall:.1f}%)")

        if early_stopping(balanced_acc):
            print(f"‚èπÔ∏è Early stopping at epoch {epoch+1} (balanced acc not improving)")
            break

    if checkpoint_saved:
        model.load_state_dict(
            torch.load(checkpoint_path, map_location=device, weights_only=True)
        )
    else:
        # Never load blindly: the file may be missing, or left over from an
        # earlier run with different weights.
        print(f"WARNING: no epoch of fold {fold} met the minority-recall threshold; "
              f"no checkpoint was written, returning last-epoch weights.")

    return model, history, best_balanced_acc

## 8. Training with Balanced Oversampling (FIXED!)

In [None]:
# Splitter for the fold loop below, which stratifies on the label and groups
# by study_id so that all rows of one study land in the same fold.
kfold = StratifiedGroupKFold(n_splits=CONFIG['num_folds'], shuffle=True, random_state=CONFIG['seed'])

In [None]:
# One dict per trained fold: 'fold', 'best_balanced_acc' and 'history'.
fold_results = []

In [None]:
for fold, (train_idx, val_idx) in enumerate(kfold.split(df_final, df_final['label'], df_final['study_id'])):
    # Only the folds listed in CONFIG['train_folds'] are trained.
    if fold not in CONFIG['train_folds']:
        continue

    separator = '=' * 60
    print(f"\n{separator}")
    print(f"FOLD {fold + 1}/{CONFIG['num_folds']}")
    print(f"{separator}")

    # The splitter returns positional indices; both parts get a fresh index.
    train_df_original = df_final.iloc[train_idx].reset_index(drop=True)
    val_df = df_final.iloc[val_idx].reset_index(drop=True)

    # Class frequencies of the untouched training split. They feed the loss
    # weights and therefore must be taken BEFORE any row is duplicated.
    original_class_counts = np.bincount(train_df_original['label'].to_numpy(), minlength=3)
    print(f"\nüìä ORIGINAL Class Distribution (for loss weights):")
    for i, count in enumerate(original_class_counts):
        print(f"   Class {i}: {count} samples")

    # Oversampling touches the training split only; the seed varies per fold.
    train_df = create_stratified_balanced_df(
        train_df_original,
        strategy=CONFIG['oversample_strategy'],
        random_state=CONFIG['seed'] + fold,
    )

    print(f"\nüìä Training Data (after oversampling):")
    print(f"   Original: {len(train_df_original)} samples")
    print(f"   After Oversampling: {len(train_df)} samples")
    print(f"   Class distribution: {train_df['label'].value_counts().sort_index().to_dict()}")

    # Settings shared by the training and the validation side.
    dataset_kwargs = dict(seq_length=CONFIG['seq_length'], img_size=CONFIG['img_size'])
    loader_kwargs = dict(batch_size=CONFIG['batch_size'], num_workers=2, pin_memory=True)

    train_dataset = RSNASequenceDataset(
        train_df,
        transform=train_aug,
        strong_transform=strong_aug,  # used for the oversampled rows
        is_training=True,
        **dataset_kwargs,
    )

    # Validation keeps the original distribution; the flag column is added
    # only so both frames expose the same schema.
    val_df = val_df.assign(is_oversampled=False)
    val_dataset = RSNASequenceDataset(
        val_df,
        transform=val_aug,
        is_training=False,
        **dataset_kwargs,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = SpineSeqAttention(
        num_classes=3,
        num_heads=CONFIG['num_attention_heads'],
        dropout=CONFIG['dropout'],
    )
    model = model.to(CONFIG['device'])

    # The pre-oversampling counts travel with the call for the loss weights.
    model, history, best_balanced_acc = train_one_fold(
        model, train_loader, val_loader, fold, CONFIG, original_class_counts
    )

    fold_results.append(dict(fold=fold, best_balanced_acc=best_balanced_acc, history=history))

    print(f"\n‚úÖ Fold {fold+1} Complete | Best Balanced Acc: {100*best_balanced_acc:.1f}%")

In [None]:
# Per-fold recap of the best balanced accuracy recorded during training.
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
for r in fold_results:
    print(f"Fold {r['fold']+1}: Best Balanced Acc = {100*r['best_balanced_acc']:.1f}%")

## 9. Final Evaluation

In [None]:
# Evaluate the model left over from the training loop. Nothing is loaded from
# disk here: restoring the best weights is train_one_fold's job, and `model`
# and `val_loader` are those of the last fold that was trained.
model.eval()
all_preds = []   # predicted class ids
all_labels = []  # ground-truth class ids
all_probs = []   # softmax probabilities per sample

In [None]:
# Inference over the validation loader; labels never leave the CPU because
# they are only collected, not fed to the model.
with torch.no_grad():
    for images, labels, level_idx in val_loader:
        images, level_idx = images.to(CONFIG['device']), level_idx.to(CONFIG['device'])

        with autocast('cuda'):
            outputs = model(images, level_idx)[0]
            probs = F.softmax(outputs, dim=1)

        predicted = torch.max(outputs, 1)[1]
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())
        all_probs.extend(probs.cpu().numpy())

In [None]:
# Lists of numpy scalars -> 1-D arrays for the report below.
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

In [None]:
# Per-class precision / recall / F1; target_names follow label order 0, 1, 2.
print("\n" + "="*50)
print("CLASSIFICATION REPORT")
print("="*50)
print(classification_report(all_labels, all_preds, 
                           target_names=['Normal/Mild', 'Moderate', 'Severe']))

In [None]:
# Closing banner: restates the changes of this notebook version (output only).
print("\n" + "="*60)
print("TRAINING COMPLETE - Version 4 (FIXED)")
print("="*60)
print(f"\nKey Fixes in v4:")
print(f"  ‚úì Class weights from ORIGINAL data (not oversampled!)")
print(f"  ‚úì Model selection based on BALANCED ACCURACY (not val_loss!)")
print(f"  ‚úì Increased focal_gamma to {CONFIG['focal_gamma']} (from 2.0)")
print(f"  ‚úì Stronger augmentations for oversampled samples")
print(f"  ‚úì Minimum minority recall threshold for model saving")
print(f"\nüéØ These changes should significantly improve Moderate and Severe recall!")