# RSNA 2024 Lumbar Spine — Version 7
## 2.5D CNN Approach (Fundamentally Different)

### Philosophy: Simplicity Over Complexity

Previous versions (v4-v6) used: `CNN → RNN/GRU → Attention → Classify`  
**v7 eliminates the sequence model entirely** and uses a **2.5D CNN** approach.

### How 2.5D Works:
- Take 7 adjacent slices centered on the diagnostic level
- **Stack them as input channels** (7 channels instead of 3 RGB)
- Feed through a single CNN (adapted for 7-channel input)
- The CNN learns cross-slice spatial features naturally

### Why This Should Work Better:
1. **Fewer parameters** — no GRU, no attention pool = less overfitting
2. **Joint spatial-temporal learning** — CNN kernels see across slices natively
3. **Proven approach** — used by top RSNA competition solutions
4. **Simple + strong baseline** — easier to debug and tune

### Other Key Features:
- Multi-head classification (main + ordinal)
- WeightedRandomSampler for class balance
- DICOM windowing preserved
- Gradient clipping + SWA
- No horizontal flip in TTA


In [None]:
import os
import copy
import cv2
import glob
import pydicom
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
from tqdm import tqdm
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score
from collections import Counter


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.amp import autocast, GradScaler
from torch.optim.swa_utils import AveragedModel, SWALR
import albumentations as A
from albumentations.pytorch import ToTensorV2


In [None]:
# Central run configuration for the v7 (2.5D CNN) experiment.
# Every training/evaluation cell below reads its hyper-parameters from this dict.
CONFIG = {
    'seed': 42,
    'img_size': 224,               # Slightly smaller for 2.5D (saves memory)
    'num_slices': 7,               # Adjacent slices to stack
    'batch_size': 16,              # Larger batch (no RNN = less memory)
    'epochs': 30,
    
    # Optimisation: heads and stem use 'learning_rate', backbone groups use 'backbone_lr'
    'learning_rate': 2e-4,
    'backbone_lr': 2e-5,
    'weight_decay': 0.03,
    'patience': 12,                # Early-stopping patience (epochs without BA improvement)
    'num_folds': 5,
    'train_folds': [0],            # Only the fold indices listed here are trained
    
    # Loss
    'focal_gamma': 2.0,
    'ordinal_weight': 0.5,        # Weight for ordinal auxiliary loss
    
    # Training
    'clip_grad_norm': 1.0,
    'use_swa': True,
    'swa_start_epoch': 20,         # 0-based epoch index from which the SWA scheduler/averaging is used
    'swa_lr': 5e-6,
    'warmup_epochs': 3,            # Linear LR warm-up length, converted to steps in the training function
    'freeze_backbone_epochs': 2,   # Backbone parameters have requires_grad=False for these first epochs
    
    # Architecture
    'dropout': 0.3,               # Less dropout (simpler model needs less regularization)
    
    # Mixup
    'use_mixup': True,
    'mixup_alpha': 0.2,            # Beta(alpha, alpha) parameter for the mixing coefficient
    
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'target_condition': 'spinal_canal_stenosis',   # Only labels of this condition are kept
    'target_series': 'Sagittal T2/STIR'            # Only coordinates from this series description are kept
}


In [None]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Covers Python's ``random``, the hash seed environment variable, NumPy,
    and PyTorch (CPU and all CUDA devices), and puts cuDNN into its
    deterministic configuration.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed *all* GPUs, not only the current one; this is a silent no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats `deterministic = True`; switch it off explicitly.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any data or model code runs, then echo the
# key settings of this run.
seed_everything(CONFIG['seed'])
print(f"‚úÖ Device: {CONFIG['device']}")
print(f"   Version: 7 (2.5D CNN ‚Äî No RNN)")
print(f"   Image: {CONFIG['img_size']}px, {CONFIG['num_slices']} slices stacked as channels")
print(f"   Batch: {CONFIG['batch_size']} (larger ‚Äî no RNN memory overhead)")


## 1. Data Loading

In [None]:
# Dataset locations (Kaggle input mount).
DATA_ROOT = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/"
TRAIN_IMAGES = os.path.join(DATA_ROOT, "train_images")

# Three CSVs: labels per study, annotated coordinates, and series descriptions.
df_train = pd.read_csv(f"{DATA_ROOT}/train.csv")
df_coords = pd.read_csv(f"{DATA_ROOT}/train_label_coordinates.csv")
df_desc = pd.read_csv(f"{DATA_ROOT}/train_series_descriptions.csv")

# Clean & Merge
# Normalise column names: lower-case and '/' -> '_'.
df_train.columns = [col.lower().replace('/', '_') for col in df_train.columns]
# Every column other than study_id is treated as a label column.
condition_cols = [c for c in df_train.columns if c != 'study_id']
# Wide -> long: one row per (study_id, condition_level) carrying its severity value.
df_labels = pd.melt(df_train, id_vars=['study_id'], value_vars=condition_cols,
                    var_name='condition_level', value_name='severity')
# Drop missing labels, then normalise severity text the same way as the
# column names so it matches the keys of the severity map used later.
df_labels = df_labels.dropna(subset=['severity'])
df_labels['severity'] = df_labels['severity'].astype(str).str.lower().str.replace('/', '_')

def extract_meta(val):
    """Split a ``condition_level`` name into its condition and level parts.

    The last two underscore-separated tokens form the level (for example
    ``l1_l2``); everything before them is the condition name.

    Args:
        val: Underscore-separated name such as
            ``'spinal_canal_stenosis_l1_l2'``.

    Returns:
        Tuple ``(condition, level)``.
    """
    tokens = val.split('_')
    level = f"{tokens[-2]}_{tokens[-1]}"
    return '_'.join(tokens[:-2]), level

# Split each condition_level name into its condition and level columns.
df_labels[['base_condition', 'level_str']] = df_labels['condition_level'].apply(lambda x: pd.Series(extract_meta(x)))
# Map severity strings to integer class ids; rows whose severity is not in
# the map become NaN and are dropped before the cast to int.
severity_map = {'normal_mild': 0, 'moderate': 1, 'severe': 2}
df_labels['label'] = df_labels['severity'].map(severity_map)
df_labels = df_labels.dropna(subset=['label'])
df_labels['label'] = df_labels['label'].astype(int)

# Attach the series description to every coordinate row and normalise the
# condition/level text so it can be joined against df_labels.
df_coords = df_coords.merge(df_desc, on=['study_id', 'series_id'], how='left')
df_coords['condition'] = df_coords['condition'].str.lower().str.replace(' ', '_')
df_coords['level'] = df_coords['level'].str.lower().str.replace('/', '_')
df_coords['condition_level'] = df_coords['condition'] + '_' + df_coords['level']

# Keep only the configured target condition and target series.
df_model = df_labels[df_labels['base_condition'] == CONFIG['target_condition']].copy()
df_coords_filt = df_coords[(df_coords['condition'] == CONFIG['target_condition']) & 
                           (df_coords['series_description'] == CONFIG['target_series'])]

# Inner join: a label row survives only if a coordinate annotation exists
# for the same study and condition_level.
df_final = df_model.merge(df_coords_filt[['study_id', 'condition_level', 'series_id', 'instance_number', 'x', 'y']],
                          on=['study_id', 'condition_level'], how='inner')


In [None]:
# Filter valid files
# Keep only rows whose annotated (center) DICOM exists on disk. A boolean
# mask is used instead of rebuilding the frame from `iterrows()` rows: it
# preserves the column dtypes, keeps the columns even when nothing matches,
# and avoids creating one Series per row.
dicom_exists = np.asarray([
    os.path.exists(f"{TRAIN_IMAGES}/{study_id}/{series_id}/{int(instance_number)}.dcm")
    for study_id, series_id, instance_number in tqdm(
        zip(df_final['study_id'], df_final['series_id'], df_final['instance_number']),
        total=len(df_final), desc="Checking Files")
], dtype=bool)

df_final = df_final.loc[dicom_exists].reset_index(drop=True)
level_map = {'l1_l2': 0, 'l2_l3': 1, 'l3_l4': 2, 'l4_l5': 3, 'l5_s1': 4}
df_final['level_idx'] = df_final['level_str'].map(level_map)

print(f"\n‚úÖ Data Ready: {len(df_final)} samples")
class_counts = df_final['label'].value_counts().sort_index()
# Iterate (label, count) pairs so the printed class id is the real label
# even when some class is missing from the data.
for cls, count in class_counts.items():
    pct = count / len(df_final) * 100
    print(f"   Class {cls}: {count} samples ({pct:.1f}%)")


## 2. Weighted Sampler

In [None]:
def create_weighted_sampler(df, num_classes=3):
    """Build a sampler that draws each class with equal total probability.

    Every sample is weighted by the inverse frequency of its class, and
    ``len(df)`` samples are drawn with replacement per epoch.

    Args:
        df: DataFrame with an integer ``label`` column.
        num_classes: Minimum number of classes to count (labels above this
            still get their own bin).

    Returns:
        A ``WeightedRandomSampler`` over the rows of ``df``.
    """
    labels = df['label'].values
    class_counts = np.bincount(labels, minlength=num_classes).astype(float)
    # Classes with no samples keep weight 0 instead of triggering a
    # divide-by-zero (inf weight + RuntimeWarning); such weights are never
    # indexed by any sample anyway.
    class_weights = np.zeros_like(class_counts)
    present = class_counts > 0
    class_weights[present] = 1.0 / class_counts[present]
    sample_weights = class_weights[labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights, num_samples=len(df), replacement=True
    )
    print(f"üìä WeightedRandomSampler: counts={class_counts.astype(int).tolist()}")
    return sampler


## 3. 2.5D Dataset

**Key difference from v4-v6**: Instead of returning `(7, 3, H, W)` sequence tensors,
this dataset returns `(7, H, W)` — 7 grayscale slices stacked as channels.

The CNN's first conv layer is adapted to accept 7 input channels.


In [None]:
class RSNA25DDataset(Dataset):
    """
    2.5D Dataset: stacks adjacent slices as channels.
    Output shape: (num_slices, H, W) instead of (seq, 3, H, W)

    Each item is built from the annotated slice plus its neighbours
    (by instance number), cropped around the annotated (x, y) point and
    resized to ``img_size``. Pixel values are returned in the 0..255 range;
    scaling to [0, 1] is left to the model.

    Args:
        df: DataFrame with columns study_id, series_id, instance_number,
            x, y, label and level_idx.
        num_slices: Number of adjacent slices stacked as channels.
        img_size: Side length of the square output image.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result holds the tensor under the ``'image'`` key.
        is_training: When True, the crop center is randomly jittered.
    """
    def __init__(self, df, num_slices=7, img_size=224, transform=None, is_training=False):
        self.df = df.reset_index(drop=True)
        self.num_slices = num_slices
        self.img_size = img_size
        self.transform = transform
        self.is_training = is_training
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        
    def __len__(self):
        return len(self.df)
    
    def load_dicom(self, path):
        """Read one DICOM file and return it as a CLAHE-enhanced uint8 image.

        Uses the file's WindowCenter/WindowWidth when both are present,
        otherwise min-max scaling. Any read/decode failure yields a blank
        ``(img_size, img_size)`` image (best-effort, with a warning).
        """
        try:
            dcm = pydicom.dcmread(path)
            img = dcm.pixel_array.astype(np.float32)
            if hasattr(dcm, 'WindowCenter') and hasattr(dcm, 'WindowWidth'):
                wc = dcm.WindowCenter
                ww = dcm.WindowWidth
                # Multi-valued window tags: use the first entry.
                if isinstance(wc, pydicom.multival.MultiValue): wc = wc[0]
                if isinstance(ww, pydicom.multival.MultiValue): ww = ww[0]
                wc, ww = float(wc), float(ww)
                img = np.clip((img - (wc - ww/2)) / max(ww, 1) * 255, 0, 255)
            elif img.max() > img.min():
                img = (img - img.min()) / (img.max() - img.min()) * 255.0
            else:
                # Constant image: nothing to scale. Emit zeros rather than
                # letting raw intensities wrap around in the uint8 cast.
                img = np.zeros_like(img)
            img = img.astype(np.uint8)
            return self.clahe.apply(img)
        except Exception as exc:
            # Deliberate best-effort: a corrupt slice must not abort training,
            # but unlike a bare `except:` this lets KeyboardInterrupt and
            # SystemExit through and reports what went wrong.
            import warnings
            warnings.warn(f"Failed to read DICOM {path!r}: {exc}; using a blank slice")
            return np.zeros((self.img_size, self.img_size), dtype=np.uint8)

    def __getitem__(self, idx):
        """Return ``(tensor, label, level_idx)`` for row ``idx``."""
        row = self.df.iloc[idx]
        center_inst = int(row['instance_number'])
        study_path = f"{TRAIN_IMAGES}/{row['study_id']}/{row['series_id']}"
        cx, cy = int(row['x']), int(row['y'])
        
        if self.is_training:
            # Random shift of the crop center, same offset for all slices.
            jitter = self.img_size // 16
            cx += random.randint(-jitter, jitter)
            cy += random.randint(-jitter, jitter)
        
        half = self.num_slices // 2
        indices = [center_inst + i - half for i in range(self.num_slices)]
        crop_size = self.img_size // 2
        
        # Load all slices as single-channel images
        slices = []
        for inst in indices:
            path = os.path.join(study_path, f"{inst}.dcm")
            if os.path.exists(path):
                img = self.load_dicom(path)
            else:
                # Neighbour outside the series: substitute a blank slice.
                img = np.zeros((self.img_size, self.img_size), dtype=np.uint8)
            
            h, w = img.shape
            # Crop window around (cx, cy), clipped to the image bounds.
            x1 = max(0, cx - crop_size)
            y1 = max(0, cy - crop_size)
            x2 = min(w, cx + crop_size)
            y2 = min(h, cy + crop_size)
            crop = img[y1:y2, x1:x2]
            
            if crop.size == 0:
                crop = np.zeros((self.img_size, self.img_size), dtype=np.uint8)
            else:
                crop = cv2.resize(crop, (self.img_size, self.img_size))
            
            slices.append(crop)
        
        # Stack as multi-channel image: (H, W, num_slices)
        multi_channel = np.stack(slices, axis=-1)
        
        # Apply augmentation (one call on the stacked array, so all slices
        # go through the same pipeline invocation)
        if self.transform:
            res = self.transform(image=multi_channel)
            tensor = res['image']
        else:
            # Keep the 0..255 range here as well: the model divides by 255
            # itself, so scaling in the dataset too would normalise twice.
            tensor = torch.from_numpy(multi_channel).permute(2, 0, 1).float()
        
        label = torch.tensor(row['label'], dtype=torch.long)
        level_idx = torch.tensor(row['level_idx'], dtype=torch.long)
        
        return tensor, label, level_idx

# Status banner: confirms the dataset class is defined and shows the
# per-sample tensor shape implied by CONFIG.
print("‚úÖ RSNA25DDataset ready")
print(f"   Output shape: ({CONFIG['num_slices']}, {CONFIG['img_size']}, {CONFIG['img_size']})")
print(f"   No sequence dimension ‚Äî slices are channels")


## 4. Augmentation (7-channel compatible)

In [None]:
# Albumentations handles arbitrary channel counts natively
# Training pipeline: geometric jitter, intensity changes, noise, mild
# warping and random erasing, applied to the stacked (H, W, num_slices) array.
# NOTE(review): several keyword names used here (value=, var_limit=,
# alpha_affine=, max_holes=/fill_value=) depend on the installed
# albumentations version — confirm before upgrading the library.
train_aug = A.Compose([
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15, rotate_limit=15,
                       border_mode=cv2.BORDER_CONSTANT, value=0, p=0.7),
    # Exactly one intensity transform, 70% of the time.
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
        A.RandomGamma(gamma_limit=(70, 130), p=1.0),
    ], p=0.7),
    # Exactly one noise model, 30% of the time.
    A.OneOf([
        A.GaussNoise(var_limit=(5.0, 40.0), p=1.0),
        A.MultiplicativeNoise(multiplier=(0.85, 1.15), p=1.0),
    ], p=0.3),
    # Exactly one non-rigid distortion, 25% of the time.
    A.OneOf([
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=25, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=1.0),
    ], p=0.25),
    A.CoarseDropout(max_holes=5, max_height=28, max_width=28,
                    min_holes=1, min_height=12, min_width=12,
                    fill_value=0, p=0.3),
    # No Normalize — the model's forward pass divides by 255 itself
    ToTensorV2()
])

# Validation: tensor conversion only.
val_aug = A.Compose([
    ToTensorV2()
])

# TTA (no horizontal flip)
# NOTE(review): the second and third variants presumably sample their
# parameters at random on every call, so TTA output is not bit-reproducible
# between runs — confirm whether that is intended.
tta_augs = [
    val_aug,
    A.Compose([
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
        ToTensorV2()
    ]),
    A.Compose([
        A.ShiftScaleRotate(shift_limit=0, scale_limit=0.05, rotate_limit=0, p=1.0),
        ToTensorV2()
    ]),
]

print(f"‚úÖ Augmentation (7-channel compatible): {len(tta_augs)} TTA variants")


## 5. 2.5D CNN Model

**Architecture:**
```
Input: (B, 7, 224, 224)  — 7 slices as channels
  → Stem conv (7 input channels → same output channels as the original first conv, initialized from pretrained weights)
  → EfficientNet-V2-S backbone (pretrained)
  → Global Average Pool → 1280 features
  → Level conditioning (FiLM)
  → Dual head: main classifier + ordinal head
```

The stem conv **replaces the backbone's first convolution so that it accepts 7 input channels** while reusing the pretrained weights.
The center 3 channels use the original ImageNet weights; each edge channel starts at 0.1× the channel-mean of those pretrained kernels.


In [None]:
class Spine25DModel(nn.Module):
    """
    2.5D CNN: no RNN, no attention; adjacent slices are stacked as input
    channels, passed through an EfficientNet-V2-S backbone, and classified.

    The backbone's first convolution is replaced by a ``num_slices``-channel
    stem initialised from the pretrained 3-channel kernels. Pooled features
    are modulated per spinal level (scale and shift embeddings) and fed to
    two heads: a ``num_classes`` softmax head and a ``num_classes - 1``
    ordinal head.

    Args:
        num_classes: Number of severity classes.
        num_slices: Number of input channels (stacked slices); at least 3.
        dropout: Dropout rate of the classifier head (halved for inner layers).
        num_levels: Number of distinct level indices for conditioning.

    Raises:
        ValueError: If ``num_slices`` is smaller than 3.
    """
    def __init__(self, num_classes=3, num_slices=7, dropout=0.3, num_levels=5):
        super().__init__()
        if num_slices < 3:
            # The three pretrained RGB kernels are copied into the three
            # center channels, so fewer than three channels cannot work.
            raise ValueError(f"num_slices must be >= 3, got {num_slices}")
        self.num_classes = num_classes
        
        # Load pretrained backbone
        effnet = models.efficientnet_v2_s(weights='IMAGENET1K_V1')
        
        # Get original first conv weights (3-channel)
        orig_conv = effnet.features[0][0]
        orig_weight = orig_conv.weight.data  # (out_channels, 3, kH, kW)
        
        # Create new stem: num_slices channels -> same output as original first conv
        out_channels = orig_weight.shape[0]
        kH, kW = orig_weight.shape[2], orig_weight.shape[3]
        
        self.stem = nn.Conv2d(num_slices, out_channels, kernel_size=(kH, kW),
                             stride=orig_conv.stride, padding=orig_conv.padding,
                             bias=False)
        
        # Initialize stem weights from pretrained
        with torch.no_grad():
            center = num_slices // 2
            # Every channel starts as a down-scaled (x0.1) mean of the three
            # pretrained kernels (deterministic, not random) ...
            new_weight = (orig_weight.mean(dim=1, keepdim=True) * 0.1).repeat(1, num_slices, 1, 1)
            # ... then the three center channels are overwritten with the
            # original pretrained kernels.
            new_weight[:, center-1:center+2, :, :] = orig_weight
            self.stem.weight.copy_(new_weight)
        
        # Backbone (skip original first conv, use rest)
        self.backbone_norm = effnet.features[0][1]  # BatchNorm after first conv
        self.backbone_act = effnet.features[0][2]   # Activation after first conv
        self.backbone_rest = nn.Sequential(*list(effnet.features.children())[1:])
        self.avgpool = effnet.avgpool
        # Read the pooled feature width from the pretrained classifier's
        # Linear layer instead of hard-coding it (1280 for this backbone).
        self.feature_dim = effnet.classifier[1].in_features
        
        # Level conditioning: identity at init (gamma = 1, beta = 0)
        self.level_gamma = nn.Embedding(num_levels, self.feature_dim)
        self.level_beta = nn.Embedding(num_levels, self.feature_dim)
        nn.init.ones_(self.level_gamma.weight)
        nn.init.zeros_(self.level_beta.weight)
        
        # Main classifier (num_classes logits)
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, num_classes)
        )
        
        # Ordinal head (K-1 logits for ordinal consistency)
        self.ordinal_head = nn.Sequential(
            nn.Linear(self.feature_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(128, num_classes - 1)
        )
        
    def forward(self, x, level_idx=None):
        """Run the network.

        Args:
            x: Image batch of shape (B, num_slices, H, W) with values in the
                0..255 range (it is divided by 255 here).
            level_idx: Optional long tensor of level indices, one per sample;
                when None the level conditioning is skipped.

        Returns:
            Dict with ``'ce'`` (B, num_classes) logits and ``'ordinal'``
            (B, num_classes - 1) logits.
        """
        # Normalize to [0, 1] range
        x = x.float() / 255.0
        
        # Stem: num_slices channels -> backbone channel count
        x = self.stem(x)
        x = self.backbone_norm(x)
        x = self.backbone_act(x)
        
        # EfficientNet backbone (rest of the layers)
        x = self.backbone_rest(x)
        
        # Pool
        x = self.avgpool(x)
        features = x.view(x.size(0), -1)
        
        # Level conditioning (FiLM)
        if level_idx is not None:
            gamma = self.level_gamma(level_idx)
            beta = self.level_beta(level_idx)
            features = gamma * features + beta
        
        # Dual output
        ce_logits = self.classifier(features)
        ordinal_logits = self.ordinal_head(features)
        
        return {
            'ce': ce_logits,
            'ordinal': ordinal_logits
        }
    
    def predict_proba(self, outputs):
        """Combine CE and ordinal predictions for robust inference.

        The ordinal logits are read as cumulative probabilities
        ``P(Y > k)`` and differenced into per-class probabilities; negative
        differences (non-monotone outputs) are clamped to zero and the
        result renormalised. Works for any number of classes K as long as
        ``outputs['ordinal']`` has K - 1 columns.

        Args:
            outputs: Dict as returned by ``forward``.

        Returns:
            (B, K) tensor: ``0.6 * softmax(ce) + 0.4 * ordinal_probs``.
        """
        # CE probabilities
        ce_probs = F.softmax(outputs['ce'], dim=1)
        
        # Ordinal probabilities: P(Y = k) = P(Y > k-1) - P(Y > k),
        # with P(Y > -1) = 1 and P(Y > K-1) = 0.
        cum = torch.sigmoid(outputs['ordinal']).to(ce_probs.dtype)
        ones = torch.ones_like(cum[:, :1])
        zeros = torch.zeros_like(cum[:, :1])
        ord_probs = torch.cat([ones, cum], dim=1) - torch.cat([cum, zeros], dim=1)
        ord_probs = ord_probs.clamp(min=0)
        ord_probs = ord_probs / ord_probs.sum(dim=1, keepdim=True).clamp(min=1e-8)
        
        # Weighted average of CE and ordinal probabilities
        return 0.6 * ce_probs + 0.4 * ord_probs

# Status banner summarising the model defined above.
print("‚úÖ Spine25DModel ready")
print("   - Stem: 7ch ‚Üí pretrained weights (center channels = ImageNet)")
print("   - Backbone: EfficientNet-V2-S")
print("   - Output: CE logits + ordinal logits (averaged at inference)")


## 6. Combined Loss (CE + Ordinal)

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma`` where
    ``p_t`` is the predicted probability of the true class, so well
    classified samples contribute less.

    Args:
        gamma: Focusing exponent; 0 reduces to plain cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample losses.
    """
    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against class ids ``targets``."""
        ce = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce)  # probability assigned to the true class
        focal = ((1 - pt) ** self.gamma) * ce
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            # Previously 'sum' silently fell through to the unreduced vector.
            return focal.sum()
        return focal

class OrdinalLoss(nn.Module):
    """Binary cross-entropy over the K-1 cumulative targets ``[label > k]``.

    Args:
        num_classes: Number of ordered classes K.
    """
    def __init__(self, num_classes=3):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, logits, labels):
        """Return the mean BCE-with-logits between ``logits`` and the
        threshold indicators derived from integer ``labels``."""
        thresholds = torch.arange(self.num_classes - 1, device=labels.device)
        # Column k is 1.0 where the label lies strictly above threshold k.
        exceeds = torch.gt(labels[:, None], thresholds[None, :])
        return F.binary_cross_entropy_with_logits(
            logits, exceeds.float(), reduction='mean'
        )

class CombinedLoss(nn.Module):
    """Focal loss on the CE head plus a weighted ordinal loss on the ordinal head.

    Args:
        gamma: Focusing exponent passed to ``FocalLoss``.
        ordinal_weight: Multiplier applied to the ordinal term.
    """
    def __init__(self, gamma=2.0, ordinal_weight=0.5):
        super().__init__()
        self.focal = FocalLoss(gamma)
        self.ordinal = OrdinalLoss()
        self.ordinal_weight = ordinal_weight

    def forward(self, outputs, labels):
        """Return ``(total_loss, breakdown)``.

        ``outputs`` is a dict holding ``'ce'`` and ``'ordinal'`` logits;
        ``breakdown`` maps ``'total'``, ``'ce'`` and ``'ordinal'`` to
        Python floats.
        """
        focal_term = self.focal(outputs['ce'], labels)
        ordinal_term = self.ordinal(outputs['ordinal'], labels)
        total = focal_term + self.ordinal_weight * ordinal_term
        breakdown = {
            'total': total.item(),
            'ce': focal_term.item(),
            'ordinal': ordinal_term.item(),
        }
        return total, breakdown

def mixup_data(x, y, alpha=0.2):
    """Blend each sample of a batch with a randomly chosen partner.

    Args:
        x: Input batch; the first dimension is the batch dimension.
        y: Targets aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter for the mixing coefficient;
            a value <= 0 disables mixing (coefficient 1.0).

    Returns:
        Tuple ``(mixed_x, y, y_partner, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = torch.randperm(x.size(0), device=x.device)
    mixed = lam * x + (1 - lam) * x[perm]
    return mixed, y, y[perm], lam

# Status banner for the loss definitions above.
print("‚úÖ FocalLoss + OrdinalLoss combined")


In [None]:
def compute_per_class_metrics(preds, labels, num_classes=3):
    """Compute recall for every class id in ``range(num_classes)``.

    Args:
        preds: Array of predicted class ids.
        labels: Array of true class ids, same length as ``preds``.
        num_classes: Number of classes to report.

    Returns:
        Dict mapping ``'class_<c>_recall'`` to the recall of class ``c``;
        a class with no true samples gets 0.0.
    """
    recalls = {}
    for cls in range(num_classes):
        is_cls = (labels == cls)
        support = is_cls.sum()
        if support > 0:
            hits = ((preds == cls) & is_cls).sum()
            recalls[f'class_{cls}_recall'] = hits / support
        else:
            recalls[f'class_{cls}_recall'] = 0.0
    return recalls

class EarlyStopping:
    """Signal a stop after ``patience`` consecutive non-improving scores.

    Args:
        patience: Number of consecutive non-improving calls that triggers a stop.
        min_delta: Margin by which a score must beat the best one to count
            as an improvement.
        mode: ``'max'`` if larger scores are better; any other value treats
            smaller scores as better.
    """
    def __init__(self, patience=5, min_delta=0.001, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None

    def __call__(self, score):
        """Record ``score`` and return True when training should stop."""
        # The very first score only establishes the baseline.
        if self.best_score is None:
            self.best_score = score
            return False

        if self.mode == 'max':
            improved = score > self.best_score + self.min_delta
        else:
            improved = score < self.best_score - self.min_delta

        if not improved:
            self.counter += 1
            return self.counter >= self.patience

        self.best_score = score
        self.counter = 0
        return False


## 7. Training Function

In [None]:
def train_one_fold_v7(model, train_loader, val_loader, fold, config):
    """Train and validate ``model`` on one fold and return the best weights.

    Per epoch: one pass over ``train_loader`` (optional mixup, mixed
    precision on CUDA, gradient clipping), then one pass over
    ``val_loader``. A checkpoint is written whenever balanced accuracy
    improves and both minority-class recalls are at least 0.15. Training
    stops early when balanced accuracy stalls for ``config['patience']``
    epochs.

    Args:
        model: A ``Spine25DModel`` already moved to ``config['device']``.
        train_loader: Loader yielding ``(images, labels, level_idx)``.
        val_loader: Loader yielding ``(images, labels, level_idx)``.
        fold: Zero-based fold index (used in logs and the checkpoint name).
        config: The run configuration dict.

    Returns:
        Tuple ``(model, history, best_balanced_acc)``. ``model`` carries the
        best checkpoint's weights, or the last-epoch weights if no epoch
        met the save criteria.
    """
    criterion = CombinedLoss(gamma=config['focal_gamma'], ordinal_weight=config['ordinal_weight'])
    ckpt_path = f"best_v7_fold{fold}.pth"
    
    # Mixed precision is only meaningful on CUDA; disable it explicitly on
    # other devices instead of relying on PyTorch's warn-and-disable fallback.
    use_amp = torch.device(config['device']).type == 'cuda' and torch.cuda.is_available()
    
    # Separate learning rates: the new stem and heads learn fast, the
    # pretrained backbone learns slowly.
    optimizer = optim.AdamW([
        {'params': model.stem.parameters(), 'lr': config['learning_rate']},      # Stem adapts fast
        {'params': model.backbone_norm.parameters(), 'lr': config['backbone_lr']},
        {'params': model.backbone_act.parameters(), 'lr': config['backbone_lr']},
        {'params': model.backbone_rest.parameters(), 'lr': config['backbone_lr']},
        {'params': model.level_gamma.parameters(), 'lr': config['learning_rate']},
        {'params': model.level_beta.parameters(), 'lr': config['learning_rate']},
        {'params': model.classifier.parameters(), 'lr': config['learning_rate']},
        {'params': model.ordinal_head.parameters(), 'lr': config['learning_rate']},
    ], weight_decay=config['weight_decay'])
    
    # Per-step schedule: linear warm-up followed by cosine decay with a floor.
    warmup_steps = config['warmup_epochs'] * len(train_loader)
    total_steps = config['epochs'] * len(train_loader)
    def lr_lambda(step):
        if step < warmup_steps: return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return max(0.5 * (1 + np.cos(np.pi * progress)), 1e-6 / config['learning_rate'])
    
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler('cuda', enabled=use_amp)
    
    # NOTE(review): swa_model accumulates averaged weights below but is never
    # evaluated, BN-recalibrated or returned — only the SWA learning-rate
    # schedule affects the result. Confirm whether that is intended.
    swa_model = AveragedModel(model) if config['use_swa'] else None
    swa_scheduler = SWALR(optimizer, swa_lr=config['swa_lr']) if config['use_swa'] else None
    
    early_stopping = EarlyStopping(patience=config['patience'], min_delta=0.003, mode='max')
    best_balanced_acc = 0.0
    checkpoint_saved = False
    
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'balanced_acc': [],
        'class_0_recall': [], 'class_1_recall': [], 'class_2_recall': []
    }
    
    print(f"\nüöÄ Training Fold {fold+1} (v7 ‚Äî 2.5D CNN)")
    print(f"   Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}")
    print(f"   Batch: {config['batch_size']}, FocalGamma: {config['focal_gamma']}")
    
    # Freeze backbone initially
    for p in model.backbone_rest.parameters(): p.requires_grad = False
    for p in model.backbone_norm.parameters(): p.requires_grad = False
    backbone_frozen = True
    
    for epoch in range(config['epochs']):
        if backbone_frozen and epoch >= config['freeze_backbone_epochs']:
            for p in model.backbone_rest.parameters(): p.requires_grad = True
            for p in model.backbone_norm.parameters(): p.requires_grad = True
            backbone_frozen = False
            print(f"   üîì Backbone unfrozen at epoch {epoch+1}")
        
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        is_swa = config['use_swa'] and epoch >= config['swa_start_epoch']
        
        tag = "[FROZEN]" if backbone_frozen else ("[SWA]" if is_swa else "")
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} {tag}")
        
        for images, labels, level_idx in loop:
            images = images.to(config['device'])
            labels = labels.to(config['device'])
            level_idx = level_idx.to(config['device'])
            
            # Mixup on roughly half of the batches, never during the SWA phase.
            use_mix = config['use_mixup'] and not is_swa and random.random() < 0.5
            if use_mix:
                images, labels_a, labels_b, lam = mixup_data(images, labels, config['mixup_alpha'])
            
            optimizer.zero_grad()
            with autocast('cuda', enabled=use_amp):
                outputs = model(images, level_idx)
                if use_mix:
                    loss_a, _ = criterion(outputs, labels_a)
                    loss_b, _ = criterion(outputs, labels_b)
                    loss = lam * loss_a + (1 - lam) * loss_b
                else:
                    loss, _ = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['clip_grad_norm'])
            scaler.step(optimizer)
            scaler.update()
            
            if is_swa: swa_scheduler.step()
            else: scheduler.step()
            
            train_loss += loss.item()
            with torch.no_grad():
                preds = model.predict_proba(outputs).argmax(dim=1)
            if use_mix:
                # Credit each prediction proportionally to both mixed targets.
                train_correct += (lam * (preds == labels_a).float() + 
                                  (1-lam) * (preds == labels_b).float()).sum().item()
            else:
                train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)
            
            loop.set_postfix(loss=f"{train_loss/(loop.n+1):.4f}", 
                           acc=f"{100*train_correct/train_total:.1f}%")
        
        train_acc = train_correct / train_total
        if swa_model and is_swa: swa_model.update_parameters(model)
        
        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        all_preds, all_labels = [], []
        
        with torch.no_grad():
            for images, labels, level_idx in val_loader:
                images = images.to(config['device'])
                labels = labels.to(config['device'])
                level_idx = level_idx.to(config['device'])
                with autocast('cuda', enabled=use_amp):
                    outputs = model(images, level_idx)
                    loss, _ = criterion(outputs, labels)
                val_loss += loss.item()
                preds = model.predict_proba(outputs).argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        val_acc = val_correct / val_total
        all_preds, all_labels = np.array(all_preds), np.array(all_labels)
        pc = compute_per_class_metrics(all_preds, all_labels)
        # Balanced accuracy = mean of the three per-class recalls.
        ba = (pc['class_0_recall'] + pc['class_1_recall'] + pc['class_2_recall']) / 3
        
        history['train_loss'].append(train_loss / len(train_loader))
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss / len(val_loader))
        history['val_acc'].append(val_acc)
        history['balanced_acc'].append(ba)
        for c in range(3): history[f'class_{c}_recall'].append(pc[f'class_{c}_recall'])
        
        print(f"üìä Train Acc: {100*train_acc:.1f}% | Val Acc: {100*val_acc:.1f}% | "
              f"N={100*pc['class_0_recall']:.1f}% M={100*pc['class_1_recall']:.1f}% "
              f"S={100*pc['class_2_recall']:.1f}% | BA={100*ba:.1f}%")
        
        # Only checkpoint models that detect both minority classes at all.
        min_min = min(pc['class_1_recall'], pc['class_2_recall'])
        if ba > best_balanced_acc and min_min >= 0.15:
            best_balanced_acc = ba
            torch.save(model.state_dict(), ckpt_path)
            checkpoint_saved = True
            print(f"   ‚úÖ Saved! BA={100*ba:.1f}%")
        
        if early_stopping(ba):
            print(f"   ‚èπÔ∏è Early stop at epoch {epoch+1}")
            break
    
    if checkpoint_saved:
        model.load_state_dict(torch.load(ckpt_path))
    else:
        # No epoch satisfied the save criteria, so there is no file from this
        # run to load (loading would raise FileNotFoundError or pick up a
        # stale checkpoint from an earlier run). Keep the current weights.
        print(f"   WARNING: no checkpoint met the save criteria for fold {fold+1}; "
              f"returning last-epoch weights")
    return model, history, best_balanced_acc


## 8. Training

In [None]:
# Stratified by label and grouped by study (see the split call below), so
# rows of one study never appear in both train and validation.
kfold = StratifiedGroupKFold(n_splits=CONFIG['num_folds'], shuffle=True, random_state=CONFIG['seed'])
# One dict per trained fold: fold index, best balanced accuracy, history.
fold_results = []


In [None]:
# Train every fold listed in CONFIG['train_folds'].
# NOTE: `model` and `val_df` from the last trained fold remain defined after
# this loop and are reused by the TTA evaluation cells further down.
for fold, (train_idx, val_idx) in enumerate(kfold.split(df_final, df_final['label'], df_final['study_id'])):
    if fold not in CONFIG['train_folds']:
        continue
    
    print(f"\n{'='*60}")
    print(f"FOLD {fold+1} (v7 ‚Äî 2.5D CNN)")
    print(f"{'='*60}")
    
    train_df = df_final.iloc[train_idx].reset_index(drop=True)
    val_df = df_final.iloc[val_idx].reset_index(drop=True)
    
    # Class distribution of the training split.
    for i in range(3):
        c = (train_df['label']==i).sum()
        print(f"   Class {i}: {c} ({100*c/len(train_df):.1f}%)")
    
    # Class-balanced sampling is applied to the training split only.
    sampler = create_weighted_sampler(train_df)
    
    train_ds = RSNA25DDataset(train_df, num_slices=CONFIG['num_slices'], 
                              img_size=CONFIG['img_size'], transform=train_aug, is_training=True)
    val_ds = RSNA25DDataset(val_df, num_slices=CONFIG['num_slices'],
                            img_size=CONFIG['img_size'], transform=val_aug, is_training=False)
    
    # A sampler replaces shuffling; drop_last discards the final partial batch.
    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], sampler=sampler,
                             num_workers=2, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False,
                           num_workers=2, pin_memory=True)
    
    # A fresh model is created for every fold.
    model = Spine25DModel(
        num_classes=3, num_slices=CONFIG['num_slices'], dropout=CONFIG['dropout']
    ).to(CONFIG['device'])
    
    params = sum(p.numel() for p in model.parameters())
    print(f"   üèóÔ∏è Spine25DModel: {params:,} params (no RNN!)")
    
    model, history, best_ba = train_one_fold_v7(model, train_loader, val_loader, fold, CONFIG)
    fold_results.append({'fold': fold, 'best_balanced_acc': best_ba, 'history': history})
    print(f"\n‚úÖ Fold {fold+1}: Best BA = {100*best_ba:.1f}%")


In [None]:
# Print the best balanced accuracy reached by each trained fold.
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
for r in fold_results:
    print(f"Fold {r['fold']+1}: Best BA = {100*r['best_balanced_acc']:.1f}%")


## 9. TTA Evaluation

In [None]:
def predict_tta_v7(model, df, config, augs):
    """Predict with test-time augmentation by averaging class probabilities.

    For every pipeline in ``augs`` the whole of ``df`` is run through the
    model (no shuffling, so row order is identical in every pass) and the
    resulting probability matrices are averaged.

    Args:
        model: Trained model exposing ``predict_proba``.
        df: DataFrame accepted by ``RSNA25DDataset``.
        config: The run configuration dict (num_slices, img_size,
            batch_size, device).
        augs: Non-empty sequence of augmentation pipelines.

    Returns:
        Tuple ``(pred_classes, labels, avg_probs)`` of NumPy arrays.

    Raises:
        ValueError: If ``augs`` is empty.
    """
    if not augs:
        # Without this guard an empty list ends in an opaque NumPy error
        # from averaging nothing.
        raise ValueError("augs must contain at least one augmentation pipeline")
    
    model.eval()
    # Mixed precision only on CUDA; explicitly disabled elsewhere.
    use_amp = torch.device(config['device']).type == 'cuda' and torch.cuda.is_available()
    all_probs, all_labels = [], None
    for aug in augs:
        ds = RSNA25DDataset(df, num_slices=config['num_slices'],
                            img_size=config['img_size'], transform=aug, is_training=False)
        loader = DataLoader(ds, batch_size=config['batch_size'], shuffle=False, 
                          num_workers=2, pin_memory=True)
        probs_list, lbl = [], []
        with torch.no_grad():
            for imgs, labels, lidx in loader:
                imgs = imgs.to(config['device'])
                lidx = lidx.to(config['device'])
                with autocast('cuda', enabled=use_amp):
                    out = model(imgs, lidx)
                    p = model.predict_proba(out)
                probs_list.append(p.cpu().numpy())
                # Labels are identical in every pass; collect them once.
                if all_labels is None: lbl.extend(labels.numpy())
        all_probs.append(np.concatenate(probs_list, 0))
        if all_labels is None: all_labels = np.array(lbl)
    avg = np.mean(all_probs, 0)
    return np.argmax(avg, 1), all_labels, avg


In [None]:
# Compare validation performance with and without TTA.
# Relies on `model` and `val_df` left over from the fold-training loop.
model.eval()
tta_preds, tta_labels, _ = predict_tta_v7(model, val_df, CONFIG, tta_augs)
no_tta_preds, _, _ = predict_tta_v7(model, val_df, CONFIG, [val_aug])

# Per-class recall and balanced accuracy: index 1 = no TTA, index 2 = TTA.
pc1 = compute_per_class_metrics(no_tta_preds, tta_labels)
ba1 = np.mean([pc1[f'class_{c}_recall'] for c in range(3)])
pc2 = compute_per_class_metrics(tta_preds, tta_labels)
ba2 = np.mean([pc2[f'class_{c}_recall'] for c in range(3)])

print(f"\n{'='*60}")
print(f"Without TTA: BA={100*ba1:.1f}%  N={100*pc1['class_0_recall']:.1f}%  "
      f"M={100*pc1['class_1_recall']:.1f}%  S={100*pc1['class_2_recall']:.1f}%")
print(f"With TTA:    BA={100*ba2:.1f}%  N={100*pc2['class_0_recall']:.1f}%  "
      f"M={100*pc2['class_1_recall']:.1f}%  S={100*pc2['class_2_recall']:.1f}%")
print(f"TTA delta:   {100*(ba2-ba1):+.1f}%")


In [None]:
# Classification report and row-normalised confusion matrix for the TTA
# predictions.
# The class ids are passed explicitly so the report and the matrix always
# have three rows/columns matching the three display names, even when a
# class is missing from this validation fold.
class_ids = [0, 1, 2]
class_names = ['Normal/Mild', 'Moderate', 'Severe']

print("\n" + "="*50)
print("CLASSIFICATION REPORT")
print("="*50)
print(classification_report(tta_labels, tta_preds, labels=class_ids,
                           target_names=class_names, zero_division=0))

cm = confusion_matrix(tta_labels, tta_preds, labels=class_ids)
# Normalise each row by its true-class count; rows with no samples are
# divided by 1 and stay all-zero instead of turning into NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = cm.astype('float') / np.maximum(row_totals, 1)
plt.figure(figsize=(8, 6))
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names)
plt.ylabel('True'); plt.xlabel('Predicted')
plt.title(f'v7 2.5D CNN (BA: {100*ba2:.1f}%)')
plt.tight_layout(); plt.show()


In [None]:
# Training curves of the first trained fold: loss, per-class recall and
# balanced accuracy, one panel each.
if fold_results:
    h = fold_results[0]['history']
    ep = range(1, len(h['train_loss'])+1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Panel 0: train vs. validation loss
    axes[0].plot(ep, h['train_loss'], 'b-', label='Train')
    axes[0].plot(ep, h['val_loss'], 'r-', label='Val')
    axes[0].set_title('Loss')
    axes[0].legend()

    # Panel 1: validation recall of each class
    axes[1].plot(ep, h['class_0_recall'], 'g-o', label='Normal', ms=3)
    axes[1].plot(ep, h['class_1_recall'], color='orange', marker='s', label='Moderate', ms=3)
    axes[1].plot(ep, h['class_2_recall'], 'r-^', label='Severe', ms=3)
    axes[1].set_title('Per-Class Recall')
    axes[1].legend()

    # Panel 2: balanced accuracy, best value shown in the title
    axes[2].plot(ep, h['balanced_acc'], 'purple', marker='d', lw=2, ms=3)
    axes[2].set_title(f'BA (Best: {100*max(h["balanced_acc"]):.1f}%)')

    # Same light grid on all three panels
    for ax in axes:
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


In [None]:
# Closing banner summarising the v7 architecture.
print("\n" + "="*60)
print("v7 COMPLETE ‚Äî 2.5D CNN (No RNN)")
print("="*60)
print(f"  Architecture: 7-slice ‚Üí stem conv ‚Üí EfficientNet-V2-S ‚Üí classify")
print(f"  No GRU, no LSTM, no attention pooling")
print(f"  Simpler model = less overfitting with limited minority data")
