# 2D SE-ResNet50 Classifier Training for Vertebral Fracture Detection

This notebook trains a binary classifier (Normal vs. Fracture, where "Fracture" means a fracture grade > 0 in `vertebra_data.json`) on 2D sagittal slices of straightened vertebrae.
The trained model is required for generating Grad-CAM heatmaps used in the HealthiVert-GAN pipeline.

In [None]:
# Install MONAI if not already installed (it provides SEresnet50 and the
# transforms imported in the next cell).
# NOTE: the leading "!" makes IPython/Jupyter run this line as a shell command;
# it is not valid plain-Python syntax outside a notebook.
!pip install monai

In [None]:
import os
import json
import glob
import random
import numpy as np
import nibabel as nib
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from monai.networks.nets import SEresnet50
from monai.transforms import (
    Compose,
    EnsureChannelFirst,
    Resize,
    ScaleIntensity,
    ToTensor,
    RandRotate,
    RandFlip,
    RandZoom
)
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

In [None]:
# ==========================================
# Configuration
# ==========================================

# Root directory for the project
PROJECT_ROOT = Path("..").resolve()

# --- UPDATE THESE PATHS FOR KAGGLE ---
# If you generated 30 samples to /kaggle/working/straighten_30s, point there:
DATA_DIR = Path("/kaggle/working/straighten_30s/CT")

# Path to the vertebra_data.json
JSON_PATH = Path("/kaggle/input/verse-19-genant-fracture-grades/vertebra_data.json")

# Checkpoints are written here by the training loop.
# parents=True so a nested location (e.g. /kaggle/working/runs/ckpt) also works.
OUTPUT_DIR = Path("./checkpoints")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
NUM_EPOCHS = 20
IMAGE_SIZE = (256, 256)  # (H, W) every slice is resized to
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")
print(f"Data Directory: {DATA_DIR}")

In [None]:
# ==========================================
# Dataset Definition
# ==========================================

class SagittalSliceDataset(Dataset):
    """Binary-labelled 2D slices taken from straightened vertebra CT volumes.

    Each item is one slice along the volume's third axis (axis 2), chosen near
    the middle of that axis, paired with a label 0 = Normal / 1 = Fracture.
    NOTE(review): axis 2 is assumed to be the sagittal axis of the
    straightened volumes; confirm against the straightening step.
    """

    def __init__(self, data_dir, json_path, split='train', transform=None, slice_range=15,
                 random_slice=None, fallback_size=(256, 256)):
        """
        Args:
            data_dir (Path | str): Directory containing straightened CT ``*.nii.gz`` files.
            json_path (Path | str): Path to vertebra_data.json. Its top-level keys are
                split names; each maps a file stem (``<subject>_<vertID>``) to a grade.
            split (str): Top-level JSON key to use, e.g. 'train' or 'test'.
            transform (callable, optional): Transform applied to the (1, H, W) slice.
            slice_range (int): Slices are drawn from
                ``[center - slice_range, center + slice_range)`` (2*range candidates).
            random_slice (bool, optional): If True, draw a random slice from that
                window on every access (augmentation); if False, always use the
                centre slice. Defaults to True only for ``split == 'train'`` so that
                evaluation is deterministic.
            fallback_size (tuple[int, int]): Spatial size of the blank image returned
                when a sample fails to load; should match the transform's output size.

        Raises:
            ValueError: If ``split`` is not a top-level key of the JSON file.
        """
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.slice_range = slice_range
        self.random_slice = (split == 'train') if random_slice is None else bool(random_slice)
        self.fallback_size = tuple(fallback_size)
        self.samples = []

        # Load Labels
        with open(json_path, 'r', encoding='utf-8') as f:
            labels_data = json.load(f)

        if split not in labels_data:
            raise ValueError(f"Split {split} not found in json")

        split_data = labels_data[split]

        # Sorted so sample order (and therefore sampler indices) is reproducible
        # regardless of filesystem listing order.
        nii_files = sorted(self.data_dir.glob("*.nii.gz"))

        print(f"Scanning {len(nii_files)} files for split '{split}'...")

        for file_path in tqdm(nii_files):
            # Filename format: subject_vertID.nii.gz (e.g., sub-verse004_16.nii.gz);
            # the stem is matched directly against the JSON keys, so files that
            # belong to other splits (or have no label) are skipped.
            filename = file_path.name.replace('.nii.gz', '')
            if filename not in split_data:
                continue

            grade = split_data[filename]
            # Binary Label: 0=Normal, 1=Fracture (Grade > 0).
            # NOTE(review): assumes grades are numeric in the JSON.
            label = 1 if grade > 0 else 0

            self.samples.append({
                'path': str(file_path),
                'label': label,
                'id': filename
            })

        # Balance check
        neg = sum(1 for s in self.samples if s['label'] == 0)
        pos = sum(1 for s in self.samples if s['label'] == 1)
        print(f"Found {len(self.samples)} volumes. Normal: {neg}, Fracture: {pos}")

    def __len__(self):
        return len(self.samples)

    def _select_slice(self, depth):
        """Return the axis-2 slice index to use for a volume with ``depth`` slices."""
        z_center = depth // 2
        if not self.random_slice:
            return z_center
        low = max(0, z_center - self.slice_range)
        high = min(depth, z_center + self.slice_range)
        # high is exclusive; fall back to the centre when the window is empty.
        return random.randint(low, high - 1) if high > low else z_center

    def __getitem__(self, idx):
        """Return ``(image, label)``: a float (1, H, W) tensor and a long scalar tensor.

        Loading is best-effort: on any error the problem is printed and a blank
        image of ``fallback_size`` with label 0 is returned instead of raising.
        """
        sample = self.samples[idx]
        path = sample['path']
        label = sample['label']

        try:
            img_nii = nib.load(path)
            slice_idx = self._select_slice(img_nii.shape[2])

            # Read only the requested slice through the array proxy instead of
            # materialising the whole volume with get_fdata(); scaling
            # (slope/intercept) is still applied by the proxy.
            slice_img = np.asarray(img_nii.dataobj[:, :, slice_idx], dtype=np.float32)

            # Add Channel Dimension (1, H, W)
            slice_img = slice_img[np.newaxis, ...]

            if self.transform:
                slice_img = self.transform(slice_img)

            # as_tensor so the untransformed (numpy) path also works.
            return torch.as_tensor(slice_img).float(), torch.tensor(label, dtype=torch.long)

        except Exception as e:
            print(f"Error loading {path}: {e}")
            return torch.zeros((1, *self.fallback_size)), torch.tensor(0).long()


In [None]:
# ==========================================
# Transforms
# ==========================================

# Both pipelines start with the same deterministic preprocessing: rescale
# intensities with ScaleIntensity's default range, then resize the
# channel-first slice to IMAGE_SIZE. The training pipeline then appends random
# augmentations, each applied with probability 0.5, before tensor conversion.
train_transforms = Compose(
    [ScaleIntensity(), Resize(spatial_size=IMAGE_SIZE)]
    + [
        # rotation angle drawn from +/- pi/12 rad (15 degrees); output keeps its size
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        # flip along the second spatial axis
        # NOTE(review): confirm this axis is anatomically safe to mirror
        RandFlip(spatial_axis=1, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
    ]
    + [ToTensor()]
)

# Evaluation pipeline: preprocessing only, no randomness.
val_transforms = Compose(
    [ScaleIntensity(), Resize(spatial_size=IMAGE_SIZE)]
    + [ToTensor()]
)

In [None]:
# ==========================================
# Data Loaders
# ==========================================

train_ds = SagittalSliceDataset(
    data_dir=DATA_DIR,
    json_path=JSON_PATH,
    split='train',
    transform=train_transforms
)

# NOTE(review): the 'test' split serves as the validation / model-selection set
# here, so metrics reported on it are optimistically biased.
val_ds = SagittalSliceDataset(
    data_dir=DATA_DIR,
    json_path=JSON_PATH,
    split='test',
    transform=val_transforms
)

# Weighted Sampler for Class Imbalance: each sample is weighted by the inverse
# frequency of its class, so batches are roughly class-balanced in expectation.
labels = [s['label'] for s in train_ds.samples]
if labels:
    class_counts = np.bincount(labels)
    weights = 1. / (class_counts + 1e-6)  # epsilon guards an absent class
    samples_weights = weights[labels]
    sampler = torch.utils.data.WeightedRandomSampler(samples_weights, len(samples_weights))
    shuffle = False  # DataLoader forbids combining a sampler with shuffle=True
else:
    sampler = None
    # shuffle=True would build a RandomSampler, which raises ValueError on an
    # empty dataset; without it the training loop just sees zero batches.
    shuffle = False
    print("Warning: No samples found! Check paths.")

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, shuffle=shuffle, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [None]:
# ==========================================
# Model Setup
# ==========================================

# 2D SE-ResNet50 with one input channel and a two-logit head
# (index 0: Normal, index 1: Fracture), moved to the selected device.
model = SEresnet50(spatial_dims=2, in_channels=1, num_classes=2).to(DEVICE)

# Cross-entropy over the two logits, optimised with Adam at LEARNING_RATE.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# ==========================================
# Training Loop
# ==========================================

# Start below any achievable accuracy so the first validated epoch always
# writes a checkpoint (with 0.0 and a strict ">", an all-wrong first epoch on a
# tiny validation set would never save one).
best_acc = -1.0

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")

    # --- TRAINING ---
    model.train()
    train_loss = 0.0
    all_preds, all_labels = [], []

    for images, targets in tqdm(train_loader):
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

    if len(train_loader) > 0:
        epoch_loss = train_loss / len(train_loader)
        train_acc = accuracy_score(all_labels, all_preds)
        print(f"Train Loss: {epoch_loss:.4f}, Acc: {train_acc:.4f}")

    # --- VALIDATION ---
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []

    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            preds = outputs.argmax(dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(targets.cpu().numpy())

    if len(val_loader) > 0:
        val_acc = accuracy_score(val_labels, val_preds)
        # zero_division=0 avoids undefined-metric warnings when a class is
        # never predicted (common on small subsets); the value is unchanged.
        val_f1 = f1_score(val_labels, val_preds, average='weighted', zero_division=0)
        print(f"Val Loss: {val_loss/len(val_loader):.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")

        # Save Best Model (selected by validation accuracy).
        # NOTE(review): with imbalanced classes accuracy can favour a
        # majority-class predictor; consider selecting on F1 instead.
        if val_acc > best_acc:
            best_acc = val_acc
            save_path = OUTPUT_DIR / "best_ckpt.tar"
            torch.save({
                'state_dict': model.state_dict(),
                'epoch': epoch,
                'accuracy': best_acc,
                'f1': val_f1,
            }, save_path)
            print(f"New Best Model Saved to {save_path}!")

print("Training Complete.")
if best_acc < 0:
    print("Warning: no validation batches were evaluated, so no checkpoint was saved.")

In [None]:
# ==========================================
# Validation Report
# ==========================================

# Per-class report for the predictions collected in the LAST validation epoch.
# These are not necessarily from the epoch whose weights were saved as
# best_ckpt.tar.
if len(val_labels) > 0:
    print("Classification Report ON TEST SET:")
    # labels=[0, 1] keeps both target names valid even when a small validation
    # set contains only one class (sklearn raises ValueError otherwise);
    # zero_division=0 silences the resulting undefined-metric warnings.
    print(classification_report(
        val_labels,
        val_preds,
        labels=[0, 1],
        target_names=['Normal', 'Fracture'],
        zero_division=0,
    ))
else:
    print("No validation data found to report.")