# RSNA 2024 Lumbar Spine — Version 6
## Ordinal-Aware Pipeline with Anti-Overfitting

### v5 Post-Mortem:
- v5 reached 74.5% balanced accuracy (BA) — no improvement over v4 (74.9%)
- Model peaked at **epoch 3**, then overfitted
- 27.7% of Severe cases were misclassified as Moderate (their ordinal neighbor)
- TTA hurt (-0.9%) — horizontal flip is inappropriate for sagittal spine images

### v6 Key Changes:
1. **CORAL Ordinal Loss** — encodes the class ordering (Normal/Mild < Moderate < Severe)
2. **Lower LR (1e-4 head / 1e-5 backbone)** + **3-epoch warmup** — v5 peaked too early
3. **Frame Dropout** — randomly masks non-center frames to reduce sequence overfitting
4. **Stronger Augmentation** — CoarseDropout, wider rotation, more scale variation
5. **Progressive Backbone Unfreezing** — optionally freeze the backbone for the first `freeze_backbone_epochs` epochs (set to 0, i.e. disabled, in this run)
6. **Fixed TTA** — no horizontal flip


In [None]:
import os
import copy
import cv2
import glob
import pydicom
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
from tqdm import tqdm
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score
from collections import Counter


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.amp import autocast, GradScaler
from torch.optim.swa_utils import AveragedModel, SWALR
import albumentations as A
from albumentations.pytorch import ToTensorV2


In [None]:
# Experiment configuration (v6). Keyword form; key order matches the v5-style layout.
CONFIG = dict(
    seed=42,
    img_size=256,
    seq_length=7,              # consecutive slices per sample, centred on the labelled instance
    batch_size=8,
    epochs=30,

    # Optimisation — v5 peaked at epoch 3, so both learning rates were lowered
    learning_rate=1e-4,        # head LR (was 3e-4)
    backbone_lr=1e-5,          # backbone LR (was 3e-5)
    weight_decay=0.03,
    patience=15,               # early-stopping patience, in epochs
    num_folds=5,
    train_folds=[0],           # only these fold indices are trained

    # Loss — ordinal, no label smoothing
    coral_lambda=1.0,          # NOTE(review): not passed to the loss by train_one_fold_v6 as written
    ce_weight=0.6,             # weight of the auxiliary cross-entropy term
    label_smoothing=0.0,       # NOTE(review): not read by the visible training code

    # Training stability
    clip_grad_norm=1.0,
    use_swa=True,
    swa_start_epoch=20,        # 0-based epoch index at which the SWA phase begins
    swa_lr=5e-6,

    # Architecture
    hidden_dim=256,
    dropout=0.4,
    frame_dropout=0.15,        # per-frame drop probability in training (centre frame always kept)
    stochastic_depth_rate=0.1,

    # Scheduler
    warmup_epochs=3,
    freeze_backbone_epochs=0,  # 0 = backbone is trainable from the first epoch

    # Mixup
    use_mixup=True,
    mixup_alpha=0.2,

    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    target_condition='spinal_canal_stenosis',
    target_series='Sagittal T2/STIR',
)


In [None]:
def seed_everything(seed):
    """Seed every random generator used by the notebook and request deterministic cuDNN.

    Args:
        seed: integer applied to Python's ``random``, NumPy, PyTorch (CPU and the
            current CUDA device) and exported as ``PYTHONHASHSEED``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

# Seed once, then print a one-shot summary of the run configuration.
seed_everything(CONFIG['seed'])
print("\n".join([
    f"‚úÖ Device: {CONFIG['device']}",
    f"   Version: 6 (Ordinal-Aware + Anti-Overfitting)",
    f"   LR: {CONFIG['learning_rate']} (head) / {CONFIG['backbone_lr']} (backbone)",
    f"   Warmup: {CONFIG['warmup_epochs']} epochs, Backbone freeze: {CONFIG['freeze_backbone_epochs']} epochs",
    f"   Frame dropout: {CONFIG['frame_dropout']}",
    f"   CORAL ordinal loss + {CONFIG['ce_weight']} CE",
]))


## 1. Data Loading

In [None]:
# Competition data location on Kaggle (kept with its trailing slash).
DATA_ROOT = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/"
_TRAIN_IMAGES_DIR = "train_images"
TRAIN_IMAGES = os.path.join(DATA_ROOT, _TRAIN_IMAGES_DIR)


In [None]:
# Competition CSVs. Joined with os.path.join (as TRAIN_IMAGES is) because DATA_ROOT
# already ends in '/', so f"{DATA_ROOT}/..." produced a doubled slash.
df_train = pd.read_csv(os.path.join(DATA_ROOT, "train.csv"))
df_coords = pd.read_csv(os.path.join(DATA_ROOT, "train_label_coordinates.csv"))
df_desc = pd.read_csv(os.path.join(DATA_ROOT, "train_series_descriptions.csv"))


In [None]:
# Normalise the wide-format column names: lowercase, '/' -> '_'.
df_train.columns = [name.lower().replace('/', '_') for name in df_train.columns]
condition_cols = [name for name in df_train.columns if name != 'study_id']

# Wide -> long: one row per (study, condition_level) that has a severity value.
df_labels = df_train.melt(
    id_vars=['study_id'], value_vars=condition_cols,
    var_name='condition_level', value_name='severity',
).dropna(subset=['severity'])
df_labels['severity'] = (
    df_labels['severity'].astype(str).str.lower().str.replace('/', '_')
)

def extract_meta(val):
    """Split '<condition>_<lvl_a>_<lvl_b>' into (condition, '<lvl_a>_<lvl_b>').

    The last two underscore-separated tokens form the level; everything before
    them (re-joined with '_') is the condition name.
    """
    tokens = val.split('_')
    upper, lower = tokens[-2], tokens[-1]
    return '_'.join(tokens[:-2]), f"{upper}_{lower}"

# Split each 'condition_level' into condition and level. A list comprehension avoids
# constructing a pandas Series for every row.
_meta_pairs = [extract_meta(v) for v in df_labels['condition_level']]
df_labels['base_condition'] = [condition for condition, _ in _meta_pairs]
df_labels['level_str'] = [level for _, level in _meta_pairs]
del _meta_pairs

severity_map = {'normal_mild': 0, 'moderate': 1, 'severe': 2}
df_labels['label'] = df_labels['severity'].map(severity_map)
# Severities outside the map become NaN and are dropped before the integer cast.
df_labels = df_labels.dropna(subset=['label'])
df_labels['label'] = df_labels['label'].astype(int)


In [None]:
# Attach series descriptions and build a 'condition_level' key in the same format as df_labels.
df_coords = (
    df_coords.merge(df_desc, on=['study_id', 'series_id'], how='left')
    .assign(
        condition=lambda d: d['condition'].str.lower().str.replace(' ', '_'),
        level=lambda d: d['level'].str.lower().str.replace('/', '_'),
    )
    .assign(condition_level=lambda d: d['condition'] + '_' + d['level'])
)

# Restrict labels to the target condition.
is_target_condition = df_labels['base_condition'] == CONFIG['target_condition']
df_model = df_labels[is_target_condition].copy()

# Restrict coordinates to the target condition on the target series type.
coords_mask = (
    (df_coords['condition'] == CONFIG['target_condition'])
    & (df_coords['series_description'] == CONFIG['target_series'])
)
df_coords_filt = df_coords[coords_mask]

# One row per labelled point: severity label + series/instance + (x, y).
coord_cols = ['study_id', 'condition_level', 'series_id', 'instance_number', 'x', 'y']
df_final = df_model.merge(df_coords_filt[coord_cols],
                          on=['study_id', 'condition_level'], how='inner')


In [None]:
# Filter valid files
# Keep only rows whose labelled DICOM instance exists on disk. A boolean mask keeps the
# original column dtypes (rebuilding a frame from iterrows() rows does not preserve them)
# and, unlike pd.DataFrame([]), still yields the expected columns when nothing survives.
file_exists = np.fromiter(
    (
        os.path.exists(f"{TRAIN_IMAGES}/{study}/{series}/{int(inst)}.dcm")
        for study, series, inst in tqdm(
            zip(df_final['study_id'], df_final['series_id'], df_final['instance_number']),
            total=len(df_final), desc="Checking Files",
        )
    ),
    dtype=bool, count=len(df_final),
)
df_final = df_final.loc[file_exists].reset_index(drop=True)

level_map = {'l1_l2': 0, 'l2_l3': 1, 'l3_l4': 2, 'l4_l5': 3, 'l5_s1': 4}
df_final['level_idx'] = df_final['level_str'].map(level_map)

print(f"\n‚úÖ Data Ready: {len(df_final)} samples")
class_counts = df_final['label'].value_counts().sort_index()
# Report the actual label value, not the position, so a missing class can't shift the names.
for cls, count in class_counts.items():
    pct = count / len(df_final) * 100
    print(f"   Class {cls}: {count} samples ({pct:.1f}%)")


## 2. Weighted Sampler

In [None]:
def create_weighted_sampler(df, num_classes=3):
    """Build a sampler that draws each class with equal expected frequency.

    Every sample is weighted by the inverse of its class count; sampling is with
    replacement and yields ``len(df)`` indices per epoch.

    Args:
        df: frame with an integer ``label`` column in ``[0, num_classes)``.
        num_classes: minimum number of classes to count (absent classes get
            weight 0 rather than an infinite weight and a divide-by-zero warning).

    Returns:
        A ``WeightedRandomSampler`` over the rows of ``df``.
    """
    labels = df['label'].values
    class_counts = np.bincount(labels, minlength=num_classes).astype(float)
    class_weights = np.divide(1.0, class_counts,
                              out=np.zeros_like(class_counts), where=class_counts > 0)
    sample_weights = class_weights[labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights, num_samples=len(df), replacement=True
    )
    print(f"üìä WeightedRandomSampler: counts={class_counts.astype(int).tolist()}")
    return sampler


## 3. Dataset with Frame Dropout

In [None]:
class RSNADatasetV6(Dataset):
    """Sequence-of-slices dataset with optional training-time frame dropout.

    Each item stacks ``seq_length`` consecutive DICOM instances centred on the
    row's ``instance_number``. Every slice is cropped around the row's (x, y)
    point, resized to ``img_size`` and passed through ``transform``.

    Args:
        df: frame with ``study_id``, ``series_id``, ``instance_number``, ``x``,
            ``y``, ``label`` and ``level_idx`` columns.
        seq_length: number of consecutive slices per sample.
        img_size: side length of each output frame; the crop window is also
            ``img_size`` pixels wide in the source image.
        transform: callable used as ``transform(image=...)['image']``; when falsy,
            frames become float tensors scaled to [0, 1].
        is_training: enables crop-centre jitter and frame dropout.
        frame_dropout: per-frame probability (training only) of replacing a
            non-centre frame with zeros.

    NOTE(review): ``transform`` is invoked once per frame, so random augmentation
    parameters differ between slices of the same sequence — confirm this is intended.
    """

    def __init__(self, df, seq_length=7, img_size=256, transform=None,
                 is_training=False, frame_dropout=0.0):
        self.df = df.reset_index(drop=True)
        self.seq_length = seq_length
        self.img_size = img_size
        self.transform = transform
        self.is_training = is_training
        self.frame_dropout = frame_dropout
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def __len__(self):
        return len(self.df)

    def load_dicom(self, path):
        """Read a DICOM slice and return a CLAHE-enhanced 2-D uint8 image.

        Uses the file's WindowCenter/WindowWidth (first value if multi-valued) when
        both are present, otherwise a min-max stretch. Any failure while reading or
        converting yields an all-zero ``img_size`` x ``img_size`` image.
        """
        try:
            dcm = pydicom.dcmread(path)
            img = dcm.pixel_array.astype(np.float32)

            if hasattr(dcm, 'WindowCenter') and hasattr(dcm, 'WindowWidth'):
                wc = dcm.WindowCenter
                ww = dcm.WindowWidth
                wc = float(wc[0]) if isinstance(wc, pydicom.multival.MultiValue) else float(wc)
                ww = float(ww[0]) if isinstance(ww, pydicom.multival.MultiValue) else float(ww)
                img = np.clip((img - (wc - ww / 2)) / max(ww, 1) * 255, 0, 255)
            elif img.max() > img.min():
                img = (img - img.min()) / (img.max() - img.min()) * 255.0
            else:
                img = np.zeros_like(img)

            return self.clahe.apply(img.astype(np.uint8))
        except Exception:
            # Best effort: unreadable/unsupported slices become blank instead of
            # aborting the batch (KeyboardInterrupt/SystemExit still propagate).
            return np.zeros((self.img_size, self.img_size), dtype=np.uint8)

    def _load_frame(self, path, cx, cy):
        """Load one slice, crop around (cx, cy), resize to ``img_size`` and transform it."""
        if os.path.exists(path):
            img = self.load_dicom(path)
        else:
            img = np.zeros((self.img_size, self.img_size), dtype=np.uint8)

        h, w = img.shape
        half = self.img_size // 2
        x1 = max(0, cx - half)
        y1 = max(0, cy - half)
        # Clamp the far edges at 0 too, so a window entirely left/above the image is
        # empty rather than wrapping around through a negative slice index.
        x2 = min(w, max(0, cx + half))
        y2 = min(h, max(0, cy + half))
        crop = img[y1:y2, x1:x2]

        # NOTE: crops clipped at the image border are not padded, so they get
        # stretched to a square by the resize.
        if crop.size == 0:
            crop = np.zeros((self.img_size, self.img_size), dtype=np.uint8)
        else:
            crop = cv2.resize(crop, (self.img_size, self.img_size))

        crop = cv2.cvtColor(crop, cv2.COLOR_GRAY2RGB)

        if self.transform:
            return self.transform(image=crop)['image']
        return torch.tensor(crop).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        """Return ``(sequence, label, level_idx)`` for row ``idx``.

        ``sequence`` stacks the per-frame tensors along a new dim 0.
        """
        row = self.df.iloc[idx]
        center_inst = int(row['instance_number'])
        study_path = f"{TRAIN_IMAGES}/{row['study_id']}/{row['series_id']}"
        cx, cy = int(row['x']), int(row['y'])

        # Shift the crop centre by up to img_size/16 pixels during training.
        if self.is_training:
            jitter = self.img_size // 16
            cx += random.randint(-jitter, jitter)
            cy += random.randint(-jitter, jitter)

        center_pos = self.seq_length // 2
        start = center_inst - center_pos
        indices = [start + i for i in range(self.seq_length)]

        # Frame dropout mask; the centre frame is always kept.
        if self.is_training and self.frame_dropout > 0:
            frame_mask = [random.random() > self.frame_dropout for _ in range(self.seq_length)]
            frame_mask[center_pos] = True
        else:
            frame_mask = [True] * self.seq_length

        frames = [
            self._load_frame(os.path.join(study_path, f"{inst}.dcm"), cx, cy) if keep else None
            for inst, keep in zip(indices, frame_mask)
        ]
        # Dropped frames are all-zero tensors shaped like the (always kept) centre
        # frame; no need to run the augmentation pipeline just to zero its output.
        zero_frame = torch.zeros_like(frames[center_pos])
        sequence = torch.stack([f if f is not None else zero_frame for f in frames], dim=0)

        label = torch.tensor(row['label'], dtype=torch.long)
        level_idx = torch.tensor(row['level_idx'], dtype=torch.long)
        return sequence, label, level_idx

print("‚úÖ RSNADatasetV6 with frame dropout ready")


## 4. Stronger Augmentation Pipeline

In [None]:
# ImageNet statistics matching the pretrained backbone's expected input normalisation.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _normalize_to_tensor():
    """Return a fresh [Normalize, ToTensorV2] tail for an A.Compose pipeline."""
    return [A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD), ToTensorV2()]


# Training: spatial jitter, intensity changes, noise, mild warping, then cut-out holes.
train_aug = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=15,
                       border_mode=cv2.BORDER_CONSTANT, value=0, p=0.7),
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
        A.RandomGamma(gamma_limit=(70, 130), p=1.0),
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
    ], p=0.8),
    A.OneOf([
        A.GaussNoise(var_limit=(5.0, 40.0), p=1.0),
        A.MultiplicativeNoise(multiplier=(0.85, 1.15), p=1.0),
    ], p=0.4),
    A.OneOf([
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=25, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=1.0),
        A.OpticalDistortion(distort_limit=0.1, shift_limit=0.05, p=1.0),
    ], p=0.3),
    A.CoarseDropout(max_holes=6, max_height=32, max_width=32,
                    min_holes=2, min_height=16, min_width=16,
                    fill_value=0, p=0.4),
    *_normalize_to_tensor(),
])

# Validation / inference: normalisation only.
val_aug = A.Compose(_normalize_to_tensor())

# TTA variants — deliberately no horizontal flip. The dataset applies the transform
# frame by frame, so random TTA parameters are re-drawn for every slice.
tta_augs = [
    val_aug,
    A.Compose([
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
        *_normalize_to_tensor(),
    ]),
    A.Compose([
        A.ShiftScaleRotate(shift_limit=0, scale_limit=0.05, rotate_limit=0, p=1.0),
        *_normalize_to_tensor(),
    ]),
]

print("\n".join([
    "‚úÖ Stronger augmentation pipeline",
    f"   - CoarseDropout added (forces spatial diversity)",
    f"   - Wider rotation: 15¬∞ (was 8¬∞)",
    f"   - TTA: {len(tta_augs)} augmentations (no horizontal flip)",
]))


## 5. Model Architecture

Same core as v5 (AttentionPool, FiLM, BiGRU), but the output layer is adapted for the CORAL ordinal loss:
- The CORAL head outputs **K-1 = 2 logits** instead of K = 3 (an auxiliary CE head still outputs K = 3 logits for training stability)
- Each CORAL logit models P(Y > k): P(Y > 0) and P(Y > 1)


In [None]:
class AttentionPool(nn.Module):
    """Learned attention pooling over dim 1 of a (batch, seq, dim) tensor.

    A small MLP (dim -> dim//4 -> 1, tanh) scores each position; scores are
    softmax-normalised over dim 1 and used to take a weighted sum.

    Returns:
        (pooled, weights): pooled has dim 1 reduced; weights drops the
        trailing singleton score dimension.
    """

    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 4
        self.attn = nn.Sequential(nn.Linear(dim, bottleneck), nn.Tanh(), nn.Linear(bottleneck, 1))

    def forward(self, x):
        scores = self.attn(x)
        weights = torch.softmax(scores, dim=1)
        pooled = torch.sum(weights * x, dim=1)
        return pooled, weights.squeeze(-1)


class FiLMLayer(nn.Module):
    """Feature-wise linear modulation keyed by an integer level index.

    Computes ``gamma[level_idx] * x + beta[level_idx]`` with per-level embedding
    tables; gamma starts at 1 and beta at 0, so the layer is initially the identity.
    """

    def __init__(self, num_levels, feature_dim):
        super().__init__()
        self.gamma = nn.Embedding(num_levels, feature_dim)
        self.beta = nn.Embedding(num_levels, feature_dim)
        with torch.no_grad():
            self.gamma.weight.fill_(1.0)
            self.beta.weight.zero_()

    def forward(self, x, level_idx):
        scale = self.gamma(level_idx)
        shift = self.beta(level_idx)
        return x * scale + shift


class SpineModelV6(nn.Module):
    """EfficientNetV2-S slice encoder + BiGRU + attention pooling with CORAL and CE heads.

    ``forward`` takes ``x`` of shape (batch, seq, channels, height, width): every
    slice is encoded by the backbone (1280-d), projected to ``hidden_dim``,
    contextualised by a bidirectional GRU, attention-pooled over the sequence and,
    if ``level_idx`` is given, FiLM-modulated per spinal level.

    Args:
        num_classes: number of ordered classes K (the CORAL head emits K-1 logits,
            the CE head K logits).
        hidden_dim: width of projected features and GRU output.
        gru_layers: number of stacked GRU layers.
        dropout: dropout used in the projection, GRU and heads.
        num_levels: number of distinct ``level_idx`` values for FiLM.
        stochastic_depth: maximum drop-path probability, ramped linearly over the
            backbone's blocks. Values <= 0 leave the backbone's default untouched.
    """

    def __init__(self, num_classes=3, hidden_dim=256, gru_layers=2,
                 dropout=0.4, num_levels=5, stochastic_depth=0.1):
        super().__init__()
        self.num_classes = num_classes

        effnet = models.efficientnet_v2_s(weights='IMAGENET1K_V1')
        if stochastic_depth > 0:
            # In torchvision's EfficientNet the drop-path modules live on the MBConv /
            # FusedMBConv blocks nested inside each stage, not on the stage containers
            # returned by `features.children()` — so search all sub-modules.
            sd_blocks = [m for m in effnet.features.modules()
                         if hasattr(m, 'stochastic_depth') and hasattr(m.stochastic_depth, 'p')]
            for i, block in enumerate(sd_blocks):
                block.stochastic_depth.p = stochastic_depth * i / len(sd_blocks)

        # Everything except the classifier: features + global average pool.
        self.backbone = nn.Sequential(*list(effnet.children())[:-1])
        self.feature_dim = 1280

        self.feature_proj = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU()
        )

        self.gru = nn.GRU(
            input_size=hidden_dim, hidden_size=hidden_dim // 2,
            num_layers=gru_layers, batch_first=True, bidirectional=True,
            dropout=dropout if gru_layers > 1 else 0
        )

        self.attn_pool = AttentionPool(hidden_dim)
        self.film = FiLMLayer(num_levels, hidden_dim)

        # Shared trunk feeding both heads.
        self.ordinal_features = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
        )
        # CORAL head: one shared weight vector plus one bias per rank threshold.
        self.ordinal_fc = nn.Linear(128, 1, bias=False)
        # For K=3 this is [-2.0, -2.95]: initial P(Y>0) ~= 12%, P(Y>1) ~= 5%.
        # NOTE(review): training batches come from an inverse-frequency sampler, so they
        # are roughly class-balanced and this prior does not match them — confirm intent.
        self.ordinal_bias = nn.Parameter(torch.linspace(-2.0, -2.95, num_classes - 1))

        # Auxiliary CE head for stability.
        self.ce_head = nn.Linear(128, num_classes)

    def forward(self, x, level_idx=None):
        b, s, c, h, w = x.size()
        # reshape (not view) tolerates non-contiguous inputs.
        features = self.backbone(x.reshape(b * s, c, h, w))
        features = self.feature_proj(features.reshape(b, s, -1))

        gru_out, _ = self.gru(features)
        context, attn_weights = self.attn_pool(gru_out)

        if level_idx is not None:
            context = self.film(context, level_idx)

        shared = self.ordinal_features(context)
        # (B, 1) shared score broadcast against (1, K-1) rank biases -> (B, K-1)
        ordinal_logits = self.ordinal_fc(shared) + self.ordinal_bias.unsqueeze(0)
        ce_logits = self.ce_head(shared)

        return {
            'ordinal': ordinal_logits,
            'ce': ce_logits,
            'attention': attn_weights
        }

    def predict_proba(self, ordinal_logits):
        """Convert CORAL logits of shape (B, K-1) into float32 class probabilities (B, K).

        With c_k = sigmoid(logit_k) = P(Y > k): P(Y=0) = 1 - c_0,
        P(Y=k) = c_{k-1} - c_k, and P(Y=K-1) = c_{K-2}. Negative entries (possible
        when thresholds cross) are clipped to 0 and each row is renormalised.
        """
        cum = torch.sigmoid(ordinal_logits.float())
        ones = cum.new_ones(cum.size(0), 1)
        zeros = cum.new_zeros(cum.size(0), 1)
        # [P(Y>-1)=1, c_0, ..., c_{K-2}] - [c_0, ..., c_{K-2}, P(Y>K-1)=0]
        probs = torch.cat([ones, cum], dim=1) - torch.cat([cum, zeros], dim=1)
        probs = probs.clamp(min=0)
        return probs / probs.sum(dim=1, keepdim=True).clamp(min=1e-8)

print("‚úÖ SpineModelV6: CORAL ordinal output + CE head")


## 6. CORAL Ordinal Loss + CE

In [None]:
class CoralLoss(nn.Module):
    """CORAL (Consistent Rank Logits) ordinal loss.

    For K ordered classes the model emits K-1 logits; logit k models P(Y > k).
    Each label y is expanded to K-1 binary targets [y > 0, y > 1, ...]:
    - Label 0 (Normal/Mild): [0, 0]
    - Label 1 (Moderate):    [1, 0]
    - Label 2 (Severe):      [1, 1]
    The loss is binary cross-entropy with logits, averaged over batch and thresholds.

    Predicting Normal when the truth is Severe gets both thresholds wrong, so
    distant ordinal mistakes are penalised more than neighbouring ones.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, ordinal_logits, labels):
        # ordinal_logits: (B, K-1) raw logits; labels: (B,) integer class indices.
        thresholds = torch.arange(self.num_classes - 1, device=labels.device)
        # Match the logits' dtype so BCE also works for non-float32 logits.
        targets = (labels.unsqueeze(1) > thresholds.unsqueeze(0)).to(ordinal_logits.dtype)
        return F.binary_cross_entropy_with_logits(ordinal_logits, targets, reduction='mean')


class CombinedOrdinalLoss(nn.Module):
    """Weighted sum of CORAL ordinal loss and auxiliary cross-entropy.

    ``total = coral_weight * CORAL(outputs['ordinal']) + ce_weight * CE(outputs['ce'])``

    Args:
        num_classes: number of ordered classes.
        ce_weight: weight of the cross-entropy term.
        coral_weight: weight of the CORAL term (default 1.0 keeps the previous behaviour).

    ``forward`` returns ``(total_loss_tensor, {'total', 'coral', 'ce'} as floats)``.
    """

    def __init__(self, num_classes=3, ce_weight=0.3, coral_weight=1.0):
        super().__init__()
        self.coral = CoralLoss(num_classes)
        self.ce_weight = ce_weight
        self.coral_weight = coral_weight

    def forward(self, outputs, labels):
        coral_loss = self.coral(outputs['ordinal'], labels)
        ce_loss = F.cross_entropy(outputs['ce'], labels)
        total = self.coral_weight * coral_loss + self.ce_weight * ce_loss
        return total, {
            'total': total.item(),
            'coral': coral_loss.item(),
            'ce': ce_loss.item()
        }


def mixup_data(x, y, alpha=0.2):
    """Blend each sample of ``x`` with a randomly permuted partner from the batch.

    ``lam`` is drawn from Beta(alpha, alpha) (or fixed at 1.0 when alpha <= 0).

    Returns:
        (mixed_x, y, y_partner, lam) — targets are returned unmixed so the caller
        can interpolate the two losses with ``lam``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    partner = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1 - lam) * x[partner], y, y[partner], lam

print("‚úÖ CORAL ordinal loss + CE ready")
print("   - CORAL encodes Normal < Moderate < Severe ordering")
print("   - CE provides gradient stability")


In [None]:
def compute_per_class_metrics(preds, labels, num_classes=3):
    """Return per-class recall as ``{'class_{c}_recall': float}`` for c in range(num_classes).

    Args:
        preds: predicted class indices (any array-like, same length as ``labels``).
        labels: true class indices.
        num_classes: number of classes to report.

    Classes with no true samples get a recall of 0.0.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    metrics = {}
    for c in range(num_classes):
        mask = labels == c
        support = int(mask.sum())
        hits = int(np.sum(preds[mask] == c))
        metrics[f'class_{c}_recall'] = hits / support if support else 0.0
    return metrics

class EarlyStopping:
    """Signal a stop after ``patience`` consecutive non-improving scores.

    A score improves when it beats the best so far by more than ``min_delta``
    (higher is better for ``mode='max'``, lower for any other mode). The first
    score seen only initialises the best value. Calling the instance returns
    True once the stop condition is reached.
    """

    def __init__(self, patience=5, min_delta=0.001, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None

    def _is_improvement(self, score):
        if self.mode == 'max':
            return score > self.best_score + self.min_delta
        return score < self.best_score - self.min_delta

    def __call__(self, val_score):
        if self.best_score is None:
            self.best_score = val_score
            return False
        if not self._is_improvement(val_score):
            self.counter += 1
            return self.counter >= self.patience
        self.best_score = val_score
        self.counter = 0
        return False


## 7. Training Function v6

In [None]:
def _evaluate_v6(model, loader, criterion, config):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean batch loss, accuracy, per-class recall dict, balanced accuracy), where
        balanced accuracy is the unweighted mean of the three class recalls.
    """
    device = config['device']
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, level_idx in loader:
            images = images.to(device)
            labels = labels.to(device)
            level_idx = level_idx.to(device)

            with autocast('cuda'):
                outputs = model(images, level_idx)
                loss, _ = criterion(outputs, labels)

            total_loss += loss.item()
            predicted = model.predict_proba(outputs['ordinal']).argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    per_class = compute_per_class_metrics(np.array(all_preds), np.array(all_labels))
    balanced_acc = (per_class['class_0_recall'] +
                    per_class['class_1_recall'] +
                    per_class['class_2_recall']) / 3
    return total_loss / len(loader), n_correct / n_seen, per_class, balanced_acc


def train_one_fold_v6(model, train_loader, val_loader, fold, config):
    """Train ``model`` on one CV fold and return it loaded with its best weights.

    Optimises CORAL + ``config['ce_weight']`` * CE with AdamW (separate backbone LR),
    a per-batch linear-warmup/cosine schedule, AMP, gradient clipping, optional mixup
    and an optional SWA phase. After every epoch the model is validated; the
    checkpoint ``best_model_v6_fold{fold}.pth`` is written when balanced accuracy
    improves and both Moderate and Severe recall are at least
    ``config.get('min_minority_recall', 0.1)``.

    If SWA ran, the averaged model's BatchNorm statistics are recomputed on
    ``train_loader``, it is validated under the same rule, and it replaces the
    checkpoint if it wins.

    Args:
        model: a ``SpineModelV6``-like module (backbone / head sub-modules, ``predict_proba``).
        train_loader, val_loader: loaders yielding ``(images, labels, level_idx)``.
        fold: fold index, used in log messages and the checkpoint filename.
        config: the CONFIG dict.

    Returns:
        (model, history, best_balanced_acc). If no epoch qualified for saving,
        ``model`` keeps its final-epoch weights and ``best_balanced_acc`` is 0.0.
    """
    device = config['device']
    head_lr = config['learning_rate']
    criterion = CombinedOrdinalLoss(num_classes=3, ce_weight=config['ce_weight'])

    # Backbone uses its own (lower) LR; all head components share head_lr.
    # param_groups[0] is the backbone group — its LR is what the progress bar shows.
    optimizer = optim.AdamW([
        {'params': model.backbone.parameters(), 'lr': config['backbone_lr']},
        {'params': model.feature_proj.parameters(), 'lr': head_lr},
        {'params': model.gru.parameters(), 'lr': head_lr},
        {'params': model.attn_pool.parameters(), 'lr': head_lr},
        {'params': model.film.parameters(), 'lr': head_lr},
        {'params': model.ordinal_features.parameters(), 'lr': head_lr},
        {'params': model.ordinal_fc.parameters(), 'lr': head_lr},
        {'params': [model.ordinal_bias], 'lr': head_lr},
        {'params': model.ce_head.parameters(), 'lr': head_lr},
    ], weight_decay=config['weight_decay'])

    warmup_steps = config['warmup_epochs'] * len(train_loader)
    total_steps = config['epochs'] * len(train_loader)
    min_lr_factor = 1e-6 / head_lr  # floors the multiplier so the head LR never drops below 1e-6

    def lr_lambda(step):
        """LR multiplier: linear warmup, then cosine decay (stepped once per batch)."""
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return max(0.5 * (1 + np.cos(np.pi * progress)), min_lr_factor)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler('cuda')

    swa_model = AveragedModel(model) if config['use_swa'] else None
    swa_scheduler = SWALR(optimizer, swa_lr=config['swa_lr']) if config['use_swa'] else None

    early_stopping = EarlyStopping(patience=config['patience'], min_delta=0.003, mode='max')

    ckpt_path = f"best_model_v6_fold{fold}.pth"
    min_minority_required = config.get('min_minority_recall', 0.1)
    checkpoint_saved = False  # tracks this run only; ignores stale files on disk
    best_balanced_acc = 0.0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'balanced_acc': [],
        'class_0_recall': [], 'class_1_recall': [], 'class_2_recall': [],
        'coral_loss': [], 'ce_loss': []
    }

    print(f"\nüöÄ Training Fold {fold+1}/{config['num_folds']} (v6 ‚Äî Ordinal)")
    print(f"   Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}")
    print(f"   CORAL + {config['ce_weight']}*CE, LR: {config['learning_rate']}")
    print(f"   Backbone frozen for first {config['freeze_backbone_epochs']} epochs")

    # Progressive unfreezing: start frozen, unfreeze at epoch freeze_backbone_epochs.
    for param in model.backbone.parameters():
        param.requires_grad = False
    backbone_frozen = True

    for epoch in range(config['epochs']):
        if backbone_frozen and epoch >= config['freeze_backbone_epochs']:
            for param in model.backbone.parameters():
                param.requires_grad = True
            backbone_frozen = False
            print(f"   üîì Backbone unfrozen at epoch {epoch+1}")

        model.train()
        train_loss = 0.0
        train_correct = 0.0
        train_total = 0
        epoch_coral = 0.0
        epoch_ce = 0.0
        clean_batches = 0  # batches trained without mixup; only these report CORAL/CE parts
        is_swa_phase = config['use_swa'] and epoch >= config['swa_start_epoch']

        status = "[FROZEN]" if backbone_frozen else ("[SWA]" if is_swa_phase else "")
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} {status}")

        for batch_num, (images, labels, level_idx) in enumerate(loop, start=1):
            images = images.to(device)
            labels = labels.to(device)
            level_idx = level_idx.to(device)

            # Mixup on ~half of the batches, never during the SWA phase.
            use_mixup = config['use_mixup'] and not is_swa_phase and random.random() < 0.5
            if use_mixup:
                images, labels_a, labels_b, lam = mixup_data(images, labels, config['mixup_alpha'])

            optimizer.zero_grad()

            with autocast('cuda'):
                outputs = model(images, level_idx)
                if use_mixup:
                    loss_a, _ = criterion(outputs, labels_a)
                    loss_b, _ = criterion(outputs, labels_b)
                    loss = lam * loss_a + (1 - lam) * loss_b
                else:
                    loss, loss_parts = criterion(outputs, labels)
                    epoch_coral += loss_parts['coral']
                    epoch_ce += loss_parts['ce']
                    clean_batches += 1

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # clip on the true (unscaled) gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['clip_grad_norm'])
            scaler.step(optimizer)
            scaler.update()

            if is_swa_phase:
                swa_scheduler.step()
            else:
                scheduler.step()

            train_loss += loss.item()

            with torch.no_grad():
                predicted = model.predict_proba(outputs['ordinal']).argmax(dim=1)

            if use_mixup:
                train_correct += (lam * (predicted == labels_a).float() +
                                  (1 - lam) * (predicted == labels_b).float()).sum().item()
            else:
                train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)

            # Use our own batch counter: tqdm's `.n` is only refreshed when it redraws.
            loop.set_postfix(
                loss=f"{train_loss / batch_num:.4f}",
                acc=f"{100*train_correct/train_total:.1f}%",
                lr=f"{optimizer.param_groups[0]['lr']:.2e}"
            )

        n_batches = len(train_loader)
        train_epoch_loss = train_loss / n_batches
        train_acc = train_correct / train_total

        if swa_model is not None and is_swa_phase:
            swa_model.update_parameters(model)

        val_epoch_loss, val_acc, per_class, balanced_acc = _evaluate_v6(
            model, val_loader, criterion, config
        )

        history['train_loss'].append(train_epoch_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_acc)
        history['balanced_acc'].append(balanced_acc)
        history['class_0_recall'].append(per_class['class_0_recall'])
        history['class_1_recall'].append(per_class['class_1_recall'])
        history['class_2_recall'].append(per_class['class_2_recall'])
        # Average only over batches that actually reported the components.
        history['coral_loss'].append(epoch_coral / max(clean_batches, 1))
        history['ce_loss'].append(epoch_ce / max(clean_batches, 1))

        print(f"üìä Train Loss: {train_epoch_loss:.4f} | Train Acc: {100*train_acc:.1f}% | "
              f"Val Loss: {val_epoch_loss:.4f} | Val Acc: {100*val_acc:.1f}%")
        print(f"   Per-class: Normal={100*per_class['class_0_recall']:.1f}%, "
              f"Moderate={100*per_class['class_1_recall']:.1f}%, "
              f"Severe={100*per_class['class_2_recall']:.1f}%")
        print(f"   üéØ Balanced Accuracy: {100*balanced_acc:.1f}%"
              f"{' [SWA]' if is_swa_phase else ''}"
              f"{' [FROZEN]' if backbone_frozen else ''}")

        min_minority_recall = min(per_class['class_1_recall'], per_class['class_2_recall'])

        if balanced_acc > best_balanced_acc and min_minority_recall >= min_minority_required:
            best_balanced_acc = balanced_acc
            torch.save(model.state_dict(), ckpt_path)
            checkpoint_saved = True
            print(f"‚úÖ Best Model Saved! (BA: {100*balanced_acc:.1f}%, "
                  f"Min Minority: {100*min_minority_recall:.1f}%)")

        if early_stopping(balanced_acc):
            print(f"‚èπÔ∏è Early stopping at epoch {epoch+1}")
            break

    if swa_model is not None and swa_model.n_averaged.item() > 0:
        # AveragedModel averages parameters only; BatchNorm running statistics must be
        # recomputed for the averaged weights before the SWA model can be evaluated.
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        swa_loss, _, swa_pc, swa_ba = _evaluate_v6(swa_model.module, val_loader, criterion, config)
        swa_min_minority = min(swa_pc['class_1_recall'], swa_pc['class_2_recall'])
        print(f"   SWA model: BA {100*swa_ba:.1f}%, Val Loss {swa_loss:.4f}, "
              f"Min Minority {100*swa_min_minority:.1f}%")
        if swa_ba > best_balanced_acc and swa_min_minority >= min_minority_required:
            best_balanced_acc = swa_ba
            torch.save(swa_model.module.state_dict(), ckpt_path)
            checkpoint_saved = True
            print("   SWA model is the new best checkpoint")

    if checkpoint_saved:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print(f"   WARNING: no epoch reached min_minority_recall >= {min_minority_required}; "
              f"returning final-epoch weights (no checkpoint written).")
    return model, history, best_balanced_acc


## 8. Training

In [None]:
# Accumulates per-fold summaries from the training loop.
fold_results = []
# Splits are stratified on the label and grouped by study (see the split call in the
# training loop), so one study never lands in both train and validation.
kfold = StratifiedGroupKFold(
    n_splits=CONFIG['num_folds'],
    shuffle=True,
    random_state=CONFIG['seed'],
)


In [None]:
# Train every fold listed in CONFIG['train_folds']; `model`, `val_df` etc. are left
# bound to the last trained fold for the evaluation cells.
for fold, (train_idx, val_idx) in enumerate(
    kfold.split(df_final, df_final['label'], df_final['study_id'])
):
    if fold not in CONFIG['train_folds']:
        continue

    print("\n".join([
        f"\n{'='*60}",
        f"FOLD {fold + 1}/{CONFIG['num_folds']} (v6 ‚Äî Ordinal)",
        f"{'='*60}",
    ]))

    train_df, val_df = (df_final.iloc[ix].reset_index(drop=True) for ix in (train_idx, val_idx))

    print(f"\nüìä Class Distribution:")
    n_train = len(train_df)
    for i in range(3):
        count = (train_df['label'] == i).sum()
        print(f"   Class {i}: {count} ({100*count/n_train:.1f}%)")

    sampler = create_weighted_sampler(train_df)

    # Shared constructor arguments for both splits.
    dataset_kwargs = dict(seq_length=CONFIG['seq_length'], img_size=CONFIG['img_size'])
    train_dataset = RSNADatasetV6(train_df, transform=train_aug, is_training=True,
                                  frame_dropout=CONFIG['frame_dropout'], **dataset_kwargs)
    val_dataset = RSNADatasetV6(val_df, transform=val_aug, is_training=False, **dataset_kwargs)

    loader_kwargs = dict(batch_size=CONFIG['batch_size'], num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_dataset, sampler=sampler, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    model = SpineModelV6(
        num_classes=3,
        hidden_dim=CONFIG['hidden_dim'],
        dropout=CONFIG['dropout'],
        stochastic_depth=CONFIG['stochastic_depth_rate'],
    ).to(CONFIG['device'])

    param_count = sum(p.numel() for p in model.parameters())
    print(f"\nüèóÔ∏è  Model: SpineModelV6 ({param_count:,} params)")

    model, history, best_balanced_acc = train_one_fold_v6(
        model, train_loader, val_loader, fold, CONFIG
    )

    fold_results.append(dict(fold=fold, best_balanced_acc=best_balanced_acc, history=history))

    print(f"\n‚úÖ Fold {fold+1} Complete | Best BA: {100*best_balanced_acc:.1f}%")


In [None]:
# One line per trained fold, under a framed header.
_summary_lines = ["\n" + "=" * 60, "TRAINING SUMMARY", "=" * 60]
_summary_lines.extend(
    f"Fold {res['fold']+1}: Best BA = {100*res['best_balanced_acc']:.1f}%"
    for res in fold_results
)
print("\n".join(_summary_lines))


## 9. Evaluation with TTA

In [None]:
def predict_with_tta_v6(model, dataset_df, config, tta_augs):
    """Run inference once per TTA transform and average the class probabilities.

    For each transform in ``tta_augs``, an ``RSNADatasetV6`` (``is_training=False``)
    and an unshuffled ``DataLoader`` are built over ``dataset_df``. The model's
    ``'ordinal'`` output is turned into class probabilities with
    ``model.predict_proba``. The per-transform probability arrays are then
    averaged element-wise.

    Args:
        model: Trained model, called as ``model(images, level_idx)``. It must
            return a dict with an ``'ordinal'`` entry and expose ``predict_proba``.
        dataset_df: DataFrame of samples to predict on.
        config: Dict that supplies ``'seq_length'``, ``'img_size'``,
            ``'batch_size'`` and ``'device'``.
        tta_augs: Non-empty sequence of transforms. Pass ``[val_aug]`` for a
            plain, no-TTA prediction.

    Returns:
        Tuple ``(avg_preds, all_labels, avg_probs)``:
        - the argmax class index per sample;
        - the ground-truth labels, collected on the first pass;
        - the averaged probability array.

    Raises:
        ValueError: If ``tta_augs`` is empty.
    """
    tta_augs = list(tta_augs)
    if not tta_augs:
        raise ValueError("tta_augs must contain at least one transform")

    device = torch.device(config['device'])
    # Mixed precision and pinned memory only make sense on CUDA. On other
    # devices, run in full precision rather than hitting autocast/pin_memory
    # warnings about an unavailable accelerator.
    on_cuda = device.type == 'cuda'

    model.eval()
    all_probs = []
    all_labels = None

    for aug_idx, aug in enumerate(tta_augs):
        ds = RSNADatasetV6(
            dataset_df, seq_length=config['seq_length'], img_size=config['img_size'],
            transform=aug, is_training=False
        )
        # shuffle=False keeps sample order identical across passes, so the
        # per-transform probability arrays can be averaged row by row.
        loader = DataLoader(ds, batch_size=config['batch_size'], shuffle=False,
                            num_workers=2, pin_memory=on_cuda)

        aug_probs = []
        labels_list = []  # only filled on the first pass

        with torch.no_grad():
            for images, labels, level_idx in loader:
                images = images.to(device)
                level_idx = level_idx.to(device)

                with autocast('cuda', enabled=on_cuda):
                    outputs = model(images, level_idx)
                    probs = model.predict_proba(outputs['ordinal'])

                # Cast to float32 so the cross-pass mean is not computed in fp16.
                aug_probs.append(probs.float().cpu().numpy())
                if aug_idx == 0:
                    labels_list.extend(labels.numpy())

        all_probs.append(np.concatenate(aug_probs, axis=0))
        if aug_idx == 0:
            all_labels = np.array(labels_list)

    avg_probs = np.mean(all_probs, axis=0)
    avg_preds = np.argmax(avg_probs, axis=1)
    return avg_preds, all_labels, avg_probs


In [None]:
# Run evaluation: averaged-TTA predictions vs. a single val_aug pass on the
# same validation frame. Balanced accuracy is the mean of the 3 per-class recalls.
model.eval()
tta_preds, tta_labels, _ = predict_with_tta_v6(model, val_df, CONFIG, tta_augs)
no_tta_preds, _, _ = predict_with_tta_v6(model, val_df, CONFIG, [val_aug])

pc_no, pc_tta = (
    compute_per_class_metrics(preds, tta_labels)
    for preds in (no_tta_preds, tta_preds)
)
ba_no, ba_tta = (
    np.mean([per_class[f'class_{c}_recall'] for c in range(3)])
    for per_class in (pc_no, pc_tta)
)

comparison_rule = '=' * 60
print(f"\n{comparison_rule}")
print("RESULTS COMPARISON")
print(comparison_rule)
for heading, per_class, ba_value in (("\nWithout TTA:  ", pc_no, ba_no),
                                     ("With TTA:     ", pc_tta, ba_tta)):
    print(f"{heading}BA={100*ba_value:.1f}%  N={100*per_class['class_0_recall']:.1f}%  "
          f"M={100*per_class['class_1_recall']:.1f}%  S={100*per_class['class_2_recall']:.1f}%")
print(f"TTA delta:    {100*(ba_tta-ba_no):+.1f}%")


In [None]:
# Per-class report and row-normalised confusion matrix for the TTA predictions.
class_names = ['Normal/Mild', 'Moderate', 'Severe']
class_ids = list(range(len(class_names)))

print("\n" + "="*50)
print("CLASSIFICATION REPORT")
print("="*50)
# Pin labels= so there is always one row per class, even when a class is
# missing from this fold. Without it, the 3 target_names would mismatch the
# detected classes and sklearn would raise ValueError.
print(classification_report(tta_labels, tta_preds, labels=class_ids,
                            target_names=class_names, zero_division=0))

# Confusion matrix (fixed 3x3 so it always matches the tick labels below)
cm = confusion_matrix(tta_labels, tta_preds, labels=class_ids)
row_totals = cm.sum(axis=1, keepdims=True)
# Rows with no true samples stay at 0 instead of becoming NaN (0/0).
cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                    where=row_totals > 0)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title(f'v6 Confusion Matrix (BA: {100*ba_tta:.1f}%)')
plt.tight_layout()
plt.show()


In [None]:
# Training history plots for the first trained fold: loss curves,
# per-class validation recall, and balanced accuracy per epoch.
if fold_results:
    hist = fold_results[0]['history']
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    ax_loss, ax_recall, ax_ba = axes
    epoch_axis = range(1, len(hist['train_loss']) + 1)

    ax_loss.plot(epoch_axis, hist['train_loss'], 'b-', label='Train')
    ax_loss.plot(epoch_axis, hist['val_loss'], 'r-', label='Val')
    ax_loss.set_title('Loss')
    ax_loss.legend()
    ax_loss.grid(True, alpha=0.3)

    ax_recall.plot(epoch_axis, hist['class_0_recall'], 'g-o', label='Normal', ms=3)
    ax_recall.plot(epoch_axis, hist['class_1_recall'], color='orange', marker='s', label='Moderate', ms=3)
    ax_recall.plot(epoch_axis, hist['class_2_recall'], 'r-^', label='Severe', ms=3)
    ax_recall.set_title('Per-Class Recall')
    ax_recall.legend()
    ax_recall.grid(True, alpha=0.3)

    best_ba = max(hist["balanced_acc"])
    ax_ba.plot(epoch_axis, hist['balanced_acc'], 'purple', marker='d', lw=2, ms=3)
    ax_ba.set_title(f'Balanced Acc (Best: {100*best_ba:.1f}%)')
    ax_ba.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


In [None]:
# Final recap of the v6 training recipe. The LR and frame-dropout values come
# from CONFIG; the other lines are fixed text.
# NOTE(review): the warmup/freeze epoch counts below are hard-coded. Keep them
# in sync with the training configuration.
print("\n" + "="*60)
print("TRAINING COMPLETE — Version 6 (Ordinal-Aware)")
print("="*60)
print("  ✓ CORAL ordinal loss (encodes Normal < Moderate < Severe)")
print(f"  ✓ Lower LR: {CONFIG['learning_rate']} head / {CONFIG['backbone_lr']} backbone")
print("  ✓ 4-epoch warmup + 3-epoch backbone freeze")
print(f"  ✓ Frame dropout: {CONFIG['frame_dropout']}")
print("  ✓ Stronger augmentation (CoarseDropout, wider rotation)")
print("  ✓ Fixed TTA (no horizontal flip)")
print("  ✓ Gradient clipping + SWA")
