## 1. Setup and Imports

In [None]:
# Install required packages (run once)
# !pip install monai nibabel torch torchvision scikit-learn matplotlib seaborn tqdm

In [None]:
import os
import json
import numpy as np
import nibabel as nib
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms

from monai.networks.nets import SEResNet50
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import KFold

# Select the compute device: CUDA when it is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
# Only a CUDA device has a name worth reporting.
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# ============== MODIFY THESE PATHS ==============
# Folder that is scanned for the straightened CT volumes (*.nii.gz files)
CT_FOLDER = "/content/verse19/straighten/CT"  # Colab path
# CT_FOLDER = "d:/Graduation Project/HeathiVert/verse19/straighten/CT"  # Windows path

# JSON file with the ground-truth grades, keyed by vertebra ID inside each split
JSON_PATH = "/content/verse19/vertebra_data_test.json"  # Colab path
# JSON_PATH = "d:/Graduation Project/HeathiVert/verse19/vertebra_data_test.json"  # Windows path

# Output folder for the per-fold checkpoints, the results text file and the plots
CHECKPOINT_FOLDER = "/content/checkpoints/classifier_kfold"  # Colab path
# CHECKPOINT_FOLDER = "d:/Graduation Project/HeathiVert/checkpoints/classifier_kfold"  # Windows path

# ============== TRAINING HYPERPARAMETERS ==============
BATCH_SIZE = 32  # Slices per mini-batch for both the training and validation loaders
NUM_EPOCHS = 50  # Epochs trained for every fold
LEARNING_RATE = 1e-4  # Initial optimizer learning rate
WEIGHT_DECAY = 1e-5  # Weight decay passed to the optimizer
NUM_SLICES = 30  # Number of slices to extract from each volume (center ± 15)
N_FOLDS = 5  # Number of folds for cross-validation

# Create checkpoint folder (exist_ok=True: no error when it already exists)
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
print(f"Checkpoints will be saved to: {CHECKPOINT_FOLDER}")
print(f"Using {N_FOLDS}-fold cross-validation")

## 3. Load Ground Truth Labels

The JSON structure is:
```json
{
  "train": {"sub-verse004_ct_23": 0, "sub-verse020_12": 1, ...},
  "test": {...},
  "val": {...}
}
```

We convert Genant grades (0, 1, 2, 3) to binary labels: grade 0 → 0 (healthy), any grade ≥ 1 → 1 (fractured).

In [None]:
def load_labels(json_path):
    """Load vertebra labels from JSON and convert to binary classification.

    The 'train', 'test' and 'val' entries of the JSON object (whichever are
    present) are merged into one mapping; other top-level keys are ignored.
    If the same vertebra ID appears in several splits, the later split in
    that order wins.

    Args:
        json_path: Path to the JSON file with per-split
            ``{vertebra_id: grade}`` mappings.

    Returns:
        dict mapping vertebra_id -> 0 (grade 0) or 1 (grade > 0). The dict
        is empty when none of the three splits contains any entry.
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Combine all splits
    all_labels = {}
    for split in ('train', 'test', 'val'):
        if split in data:
            all_labels.update(data[split])

    # Convert to binary: 0 = healthy, 1+ = fractured
    binary_labels = {k: (1 if v > 0 else 0) for k, v in all_labels.items()}

    # Statistics
    total = len(binary_labels)
    n_fractured = sum(binary_labels.values())
    n_healthy = total - n_fractured

    print(f"Total vertebrae: {total}")
    if total:
        print(f"  Healthy (0): {n_healthy} ({100*n_healthy/total:.1f}%)")
        print(f"  Fractured (1): {n_fractured} ({100*n_fractured/total:.1f}%)")
    else:
        # Percentages are undefined for an empty label set; report it
        # instead of failing with a ZeroDivisionError.
        print("  No labels found in the 'train', 'test' or 'val' splits")

    return binary_labels

# Binary label lookup (vertebra_id -> 0/1) loaded from JSON_PATH
labels = load_labels(JSON_PATH)

## 4. Dataset Class

**Key Points:**
- Each `.nii.gz` file is a single straightened vertebra (not whole spine)
- File naming: `sub-verse004_ct_23.nii.gz` → lookup key: `sub-verse004_ct_23`
- Extract **middle 30 slices** (z_center ± 15) as 2D images
- Each slice gets the same label as its parent vertebra

In [None]:
class VertebraSliceDataset(Dataset):
    """Dataset that extracts 2D slices from 3D straightened vertebra volumes.

    Every ``*.nii.gz`` file in ``ct_folder`` whose name (without the
    extension) is a key of ``labels_dict`` contributes a window of
    ``num_slices`` slices centred on the middle of its third axis. Each
    slice carries the label of the volume it comes from.
    """

    _EXTENSION = '.nii.gz'

    def __init__(self, ct_folder, labels_dict, num_slices=30, transform=None):
        """
        Args:
            ct_folder: Path to folder containing .nii.gz files
            labels_dict: Dictionary mapping vertebra_id -> binary label
            num_slices: Number of slices to extract from each volume
                (a window centred on the middle of the third axis)
            transform: Optional transforms to apply to each slice
        """
        self.ct_folder = Path(ct_folder)
        self.labels_dict = labels_dict
        self.num_slices = num_slices
        self.transform = transform

        # Find all .nii.gz files and their labels
        self.samples = []  # List of (nii_path, slice_idx, label)
        self._prepare_samples()

    @staticmethod
    def _center_slice_range(z_dim, num_slices):
        """Return (start, end) of the centred window, clipped to [0, z_dim).

        The window holds exactly ``num_slices`` indices whenever the volume
        is deep enough, for odd as well as even ``num_slices``.
        """
        start = max(0, z_dim // 2 - num_slices // 2)
        end = min(z_dim, start + num_slices)
        return start, end

    @staticmethod
    def _normalize_slice(slice_2d):
        """Min-max scale a slice to [0, 1]; a constant slice becomes zeros."""
        slice_min, slice_max = slice_2d.min(), slice_2d.max()
        if slice_max > slice_min:
            return (slice_2d - slice_min) / (slice_max - slice_min)
        # No contrast to rescale: emit zeros rather than raw intensities so
        # the output range stays [0, 1].
        return np.zeros_like(slice_2d)

    def _prepare_samples(self):
        """Build list of (file_path, slice_index, label) tuples."""
        # Sorted so that sample order (and any seeded split built on top of
        # it) does not depend on filesystem enumeration order.
        nii_files = sorted(self.ct_folder.glob('*' + self._EXTENSION))
        print(f"Found {len(nii_files)} .nii.gz files")

        matched = 0
        unmatched = []
        failed = []

        for nii_path in nii_files:
            # Extract vertebra ID from filename
            # e.g., "sub-verse004_ct_23.nii.gz" -> "sub-verse004_ct_23"
            vertebra_id = nii_path.name[:-len(self._EXTENSION)]

            if vertebra_id not in self.labels_dict:
                unmatched.append(vertebra_id)
                continue
            label = self.labels_dict[vertebra_id]

            # Only the header is needed here to learn the volume depth.
            try:
                z_dim = nib.load(str(nii_path)).shape[2]
            except Exception as e:
                # Best effort: an unreadable or non-3D file is reported and
                # skipped so one bad volume does not abort dataset creation.
                print(f"Error loading {nii_path}: {e}")
                failed.append(vertebra_id)
                continue

            matched += 1
            start, end = self._center_slice_range(z_dim, self.num_slices)
            for slice_idx in range(start, end):
                self.samples.append((nii_path, slice_idx, label))

        print(f"Matched: {matched} vertebrae")
        print(f"Unmatched: {len(unmatched)} vertebrae")
        if failed:
            print(f"Failed to load: {len(failed)} vertebrae")
        print(f"Total slices: {len(self.samples)}")

        # Class distribution
        n_healthy = sum(1 for _, _, l in self.samples if l == 0)
        n_fractured = sum(1 for _, _, l in self.samples if l == 1)
        print(f"Slice distribution: Healthy={n_healthy}, Fractured={n_fractured}")

        if unmatched:
            print(f"Sample unmatched IDs: {unmatched[:5]}")

    def __len__(self):
        """Return the total number of slices across all matched volumes."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (slice tensor of shape (1, H, W), label tensor) for ``idx``."""
        nii_path, slice_idx, label = self.samples[idx]

        # Load volume and extract slice
        vol = nib.load(str(nii_path)).get_fdata()
        slice_2d = vol[:, :, slice_idx].astype(np.float32)

        # Normalize to [0, 1]
        slice_2d = self._normalize_slice(slice_2d)

        # Add channel dimension: (H, W) -> (1, H, W)
        slice_2d = np.expand_dims(slice_2d, axis=0)

        # Convert to tensor
        slice_tensor = torch.from_numpy(slice_2d)

        if self.transform:
            slice_tensor = self.transform(slice_tensor)

        return slice_tensor, torch.tensor(label, dtype=torch.long)

## 5. Setup K-Fold Cross-Validation

We'll use 5-fold cross-validation to better evaluate model performance and reduce overfitting to a single train/val split.

## 6. Visualize Sample Data

In [None]:
# Build one dataset over every labelled volume; the folds are carved out of it later.
full_dataset = VertebraSliceDataset(CT_FOLDER, labels, num_slices=NUM_SLICES, transform=None)

print(f"Total dataset samples: {len(full_dataset)}")

# Shuffled K-fold splitter with a fixed seed
kfold = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

# One (initially empty) list per tracked quantity; each fold appends one entry
fold_results = {
    key: []
    for key in (
        'train_loss',
        'val_loss',
        'accuracy',
        'precision',
        'recall',
        'f1',
        'confusion_matrices',
    )
}

print(f"\nK-Fold Setup: {N_FOLDS} folds, {len(full_dataset)} total samples")
print(
    "Approx. samples per fold: "
    f"Train={int(len(full_dataset)*(N_FOLDS-1)/N_FOLDS)}, "
    f"Val={int(len(full_dataset)/N_FOLDS)}"
)

In [None]:
# Draw one shuffled batch of eight slices for a quick visual sanity check.
temp_loader = DataLoader(full_dataset, batch_size=8, shuffle=True, num_workers=0)
images, labels_batch = next(iter(temp_loader))
print(f"Batch shape: {images.shape}")
print(f"Labels: {labels_batch}")

# 2x4 grid; zip stops at the batch size, so surplus axes are left untouched.
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for ax, image, label in zip(axes.flat, images, labels_batch):
    ax.imshow(image[0].numpy(), cmap='gray')
    if label == 0:
        label_text = 'Healthy'
    else:
        label_text = 'Fractured'
    ax.set_title(f'{label_text} ({label.item()})')
    ax.axis('off')
plt.suptitle('Sample Training Images')
plt.tight_layout()
plt.show()

# The loader was only needed for this preview.
del temp_loader

In [None]:
# Visualize: Understanding the data structure
# Each file is CENTERED on the target vertebra from preprocessing

nii_files = [*Path(CT_FOLDER).glob('*.nii.gz'), *Path(CT_FOLDER).glob('*.nii')]
if not nii_files:
    print("No .nii or .nii.gz files found in CT_FOLDER")
else:
    sample_ct_path = nii_files[0]
    ct_vol = nib.load(str(sample_ct_path)).get_fdata()

    # Center slice (where target vertebra should be) plus one slice 10 steps
    # to either side, clamped to the valid index range.
    z_center = ct_vol.shape[2] // 2
    z_early = max(0, z_center - 10)
    z_late = min(ct_vol.shape[2] - 1, z_center + 10)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (z_center, f'Center Slice (z={z_center})\nTarget vertebra centered here'),
        (z_early, f'Slice z={z_early}'),
        (z_late, f'Slice z={z_late}'),
    )
    for panel_ax, (z_index, panel_title) in zip(axes, panels):
        panel_ax.imshow(ct_vol[:, :, z_index].T, cmap='gray', origin='lower')
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

    # Crosshair through the in-plane center of the middle slice.
    axes[0].axhline(y=ct_vol.shape[1]//2, color='r', linestyle='--', alpha=0.5)
    axes[0].axvline(x=ct_vol.shape[0]//2, color='r', linestyle='--', alpha=0.5)

    plt.suptitle(f'Volume: {sample_ct_path.name}\nShape: {ct_vol.shape} (target vertebra = center)', fontsize=12)
    plt.tight_layout()
    plt.show()

    print(f"\nFile: {sample_ct_path.name}")
    print(f"Volume shape: {ct_vol.shape}")
    print("Center of volume = center of target vertebra")
    print("No external localization needed - preprocessing already centered the data")

## 7. Model Definition

**IMPORTANT**: Must match the architecture in `grad_CAM_3d_sagittal.py`:
```python
model = SEresnet50(spatial_dims=2, in_channels=1, num_classes=2)
model = torch.nn.DataParallel(model).cuda()
```

In [None]:
def create_model():
    """Build the SEResNet50 classifier in the configuration Grad-CAM expects.

    Returns:
        A freshly initialised SEResNet50 with 2D convolutions, a single
        (grayscale CT) input channel and two output classes
        (healthy / fractured).
    """
    architecture = dict(spatial_dims=2, in_channels=1, num_classes=2)
    return SEResNet50(**architecture)

# Instantiate the network on the selected device.
model = create_model().to(device)

# Print model summary: total parameter count and the subset that receives gradients
print("Model: SEResNet50")
print(f"Parameters: {sum(map(torch.numel, model.parameters())):,}")
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 8. Loss Function and Optimizer

In [None]:
# Count slices per class to derive weights for the imbalanced data
all_labels = [sample[2] for sample in full_dataset.samples]
n_healthy = all_labels.count(0)
n_fractured = all_labels.count(1)

# Inverse frequency weighting: total / (2 * class count); skipped when a class is absent
class_weights = None
if min(n_healthy, n_fractured) > 0:
    weight_healthy, weight_fractured = (
        len(all_labels) / (2 * count) for count in (n_healthy, n_fractured)
    )
    class_weights = torch.tensor(
        [weight_healthy, weight_fractured], dtype=torch.float32
    ).to(device)

print(f"Class weights: {class_weights}")

# Weighted cross-entropy, AdamW, and a cosine schedule spanning all epochs
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train (switched to training mode here).
        loader: Iterable of (images, labels) batches.
        criterion: Loss applied to the model outputs and labels.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        (mean of the per-batch losses, accuracy over all seen samples)
    """
    model.train()
    loss_sum = 0.0
    predicted, targets = [], []

    progress = tqdm(loader, desc='Training')
    for batch_images, batch_labels in progress:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value
        predicted.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(batch_labels.cpu().numpy())

        progress.set_postfix({'loss': f'{loss_value:.4f}'})

    return loss_sum / len(loader), accuracy_score(targets, predicted)


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        model: Network to evaluate (switched to eval mode here).
        loader: Iterable of (images, labels) batches.
        criterion: Loss applied to the model outputs and labels.
        device: Device the batches are moved to.

    Returns:
        (mean of the per-batch losses,
         dict with 'accuracy', 'precision', 'recall' and 'f1',
         list of predicted classes, list of true labels)
    """
    model.eval()
    loss_sum = 0.0
    predicted, targets = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Validating'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()

            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(loader)

    # Metrics; zero_division=0 scores an undefined ratio as 0.
    metrics = {'accuracy': accuracy_score(targets, predicted)}
    for metric_name, scorer in (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    ):
        metrics[metric_name] = scorer(targets, predicted, zero_division=0)

    return mean_loss, metrics, predicted, targets

## 10. K-Fold Training Loop

Train model on each fold and track performance across all folds.

In [None]:
print(f"Starting {N_FOLDS}-Fold Cross-Validation")
print("="*80)

# Re-running this cell must start from a clean slate; otherwise results of a
# previous run would be mixed into the summary statistics.
for _fold_values in fold_results.values():
    _fold_values.clear()

# Split at the volume (vertebra) level, not the slice level. Neighbouring
# slices of one vertebra are nearly identical and share a label, so a
# slice-level split puts them in both train and validation and inflates
# every validation metric.
# NOTE(review): vertebrae of the same patient can still end up in different
# folds; if the file names encode a patient ID, consider grouping by it.
sample_volumes = [str(path) for path, _, _ in full_dataset.samples]
volume_ids = sorted(set(sample_volumes))
volume_to_indices = {}
for sample_idx, volume in enumerate(sample_volumes):
    volume_to_indices.setdefault(volume, []).append(sample_idx)

# Iterate through each fold
for fold, (train_vols, val_vols) in enumerate(kfold.split(volume_ids)):
    train_ids = [i for v in train_vols for i in volume_to_indices[volume_ids[v]]]
    val_ids = [i for v in val_vols for i in volume_to_indices[volume_ids[v]]]

    print(f"\n{'='*80}")
    print(f"FOLD {fold + 1}/{N_FOLDS}")
    print(f"{'='*80}")
    print(f"Train volumes: {len(train_vols)}, Val volumes: {len(val_vols)}")
    print(f"Train samples: {len(train_ids)}, Val samples: {len(val_ids)}")

    # Create data loaders for this fold
    train_subset = Subset(full_dataset, train_ids)
    val_subset = Subset(full_dataset, val_ids)

    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Initialize fresh model for this fold
    model = create_model().to(device)

    # Inverse-frequency class weights computed from this fold's training
    # slices only, so no validation information enters the loss.
    train_counts = np.bincount(
        [full_dataset.samples[i][2] for i in train_ids], minlength=2
    )[:2]
    if train_counts.min() > 0:
        fold_class_weights = torch.tensor(
            len(train_ids) / (2.0 * train_counts), dtype=torch.float32
        ).to(device)
    else:
        # A class is absent from the training fold: fall back to unweighted loss.
        fold_class_weights = None

    # Loss, optimizer, scheduler
    criterion = nn.CrossEntropyLoss(weight=fold_class_weights)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

    # Training history for this fold
    fold_history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_f1': [], 'val_precision': [], 'val_recall': []
    }

    # Start below any attainable F1 so the first epoch always writes a
    # checkpoint. With 0.0, a fold whose F1 never rose above zero would
    # leave no file (or a stale one from an earlier run) to load afterwards.
    best_f1 = -1.0
    best_epoch = 0
    fold_ckpt_path = os.path.join(CHECKPOINT_FOLDER, f'fold_{fold+1}_best.tar')

    # Train for NUM_EPOCHS
    for epoch in range(NUM_EPOCHS):
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_metrics, val_preds, val_labels = validate(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step()

        # Log
        fold_history['train_loss'].append(train_loss)
        fold_history['train_acc'].append(train_acc)
        fold_history['val_loss'].append(val_loss)
        fold_history['val_acc'].append(val_metrics['accuracy'])
        fold_history['val_f1'].append(val_metrics['f1'])
        fold_history['val_precision'].append(val_metrics['precision'])
        fold_history['val_recall'].append(val_metrics['recall'])

        # Print every 5 epochs
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1}/{NUM_EPOCHS} | "
                  f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
                  f"Val Acc: {val_metrics['accuracy']:.4f} | F1: {val_metrics['f1']:.4f}")

        # Save best model for this fold
        if val_metrics['f1'] > best_f1:
            best_f1 = float(val_metrics['f1'])
            best_epoch = epoch + 1

            checkpoint = {
                'fold': fold + 1,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_f1': best_f1,
                'history': fold_history
            }
            torch.save(checkpoint, fold_ckpt_path)

    # Load best model and evaluate on validation set.
    # weights_only=False because the checkpoint also holds the history dict;
    # map_location keeps the load valid if the device differs from the save.
    checkpoint = torch.load(fold_ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['state_dict'])

    val_loss, val_metrics, val_preds, val_labels = validate(model, val_loader, criterion, device)
    # Fix the label set so the matrix is always 2x2, even when a class is
    # missing from both the fold's labels and its predictions.
    cm = confusion_matrix(val_labels, val_preds, labels=[0, 1])

    # Store fold results
    fold_results['train_loss'].append(fold_history['train_loss'][-1])
    fold_results['val_loss'].append(val_loss)
    fold_results['accuracy'].append(val_metrics['accuracy'])
    fold_results['precision'].append(val_metrics['precision'])
    fold_results['recall'].append(val_metrics['recall'])
    fold_results['f1'].append(val_metrics['f1'])
    fold_results['confusion_matrices'].append(cm)

    print(f"\nFold {fold + 1} Results:")
    print(f"  Best Epoch: {best_epoch}")
    print(f"  Accuracy:  {val_metrics['accuracy']:.4f}")
    print(f"  Precision: {val_metrics['precision']:.4f}")
    print(f"  Recall:    {val_metrics['recall']:.4f}")
    print(f"  F1 Score:  {val_metrics['f1']:.4f}")
    print(f"  Confusion Matrix:")
    print(f"    [[{cm[0,0]:4d} {cm[0,1]:4d}]")
    print(f"     [{cm[1,0]:4d} {cm[1,1]:4d}]]")

print(f"\n{'='*80}")
print("K-FOLD CROSS-VALIDATION COMPLETE")
print(f"{'='*80}")

## 11. Cross-Validation Results Summary

Calculate mean and standard deviation across all folds.

In [None]:
# Fail early with a clear message instead of NaNs and index errors further down.
if not fold_results['f1']:
    raise RuntimeError("No fold results available - run the K-fold training cell first.")

# Calculate mean and std for each metric (display name -> fold_results key)
results_summary = {
    display_name: (np.mean(fold_results[key]), np.std(fold_results[key]))
    for display_name, key in (
        ('Accuracy', 'accuracy'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1 Score', 'f1'),
        ('Val Loss', 'val_loss'),
    )
}

# Format once; the same lines are printed and written to the results file.
summary_lines = [
    f"{metric:12s}: {mean:.4f} ± {std:.4f}"
    for metric, (mean, std) in results_summary.items()
]
table_header = f"{'Fold':<6} {'Accuracy':<10} {'Precision':<11} {'Recall':<10} {'F1 Score':<10}"
# Iterate over the folds actually recorded rather than assuming N_FOLDS
# entries, so a partially completed run can still be summarised.
fold_rows = [
    f"{fold_no:<6} {acc:<10.4f} {prec:<11.4f} {rec:<10.4f} {f1:<10.4f}"
    for fold_no, (acc, prec, rec, f1) in enumerate(
        zip(
            fold_results['accuracy'],
            fold_results['precision'],
            fold_results['recall'],
            fold_results['f1'],
        ),
        start=1,
    )
]

print("\n" + "="*80)
print("CROSS-VALIDATION RESULTS (Mean ± Std)")
print("="*80)
for line in summary_lines:
    print(line)

# Per-fold detailed results
print("\n" + "="*80)
print("PER-FOLD DETAILED RESULTS")
print("="*80)
print(table_header)
print("-"*80)
for row in fold_rows:
    print(row)

# Average confusion matrix, rounded to the nearest integer count
# (plain truncation would systematically under-report every cell).
avg_cm = np.rint(np.mean(fold_results['confusion_matrices'], axis=0)).astype(int)
print("\n" + "="*80)
print("AVERAGE CONFUSION MATRIX")
print("="*80)
print("              Predicted 0   Predicted 1")
print(f"Actual 0:     {avg_cm[0,0]:6d}        {avg_cm[0,1]:6d}")
print(f"Actual 1:     {avg_cm[1,0]:6d}        {avg_cm[1,1]:6d}")

# Save results to file. The text contains "±", so the encoding is fixed to
# UTF-8 instead of relying on the platform default.
results_path = os.path.join(CHECKPOINT_FOLDER, 'kfold_results.txt')
with open(results_path, 'w', encoding='utf-8') as f:
    f.write(f"{N_FOLDS}-Fold Cross-Validation Results\n")
    f.write("="*80 + "\n\n")
    f.write("Mean ± Std:\n")
    f.writelines(line + "\n" for line in summary_lines)
    f.write("\nPer-Fold Results:\n")
    f.write(table_header + "\n")
    f.writelines(row + "\n" for row in fold_rows)

print(f"\nResults saved to: {results_path}")

## 12. Visualize Cross-Validation Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_bars, ax_box = axes[0, 0], axes[0, 1]
ax_cm, ax_f1 = axes[1, 0], axes[1, 1]

metrics_to_plot = ['accuracy', 'precision', 'recall', 'f1']
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']
x = np.arange(1, N_FOLDS + 1)

# 1. Metrics across folds: four 0.2-wide bars grouped around each fold number
for position, (metric, color) in enumerate(zip(metrics_to_plot, colors)):
    ax_bars.bar(
        x + position*0.2 - 0.3,
        fold_results[metric],
        width=0.2,
        label=metric.capitalize(),
        color=color,
        alpha=0.8,
    )

# Faint reference line at the mean accuracy
ax_bars.axhline(y=np.mean(fold_results['accuracy']), color='gray', linestyle='--', alpha=0.3)
ax_bars.set_xlabel('Fold')
ax_bars.set_ylabel('Score')
ax_bars.set_title('Metrics Across Folds')
ax_bars.legend()
ax_bars.set_xticks(x)
ax_bars.grid(True, alpha=0.3)

# 2. Box plot of metrics
box_data = [fold_results[m] for m in metrics_to_plot]
box_labels = [m.capitalize() for m in metrics_to_plot]
ax_box.boxplot(box_data, labels=box_labels, patch_artist=True)
ax_box.set_ylabel('Score')
ax_box.set_title('Metric Distribution Across Folds')
ax_box.grid(True, alpha=0.3)

# 3. Average Confusion Matrix
import seaborn as sns
class_names = ['Healthy', 'Fractured']
sns.heatmap(avg_cm, annot=True, fmt='d', cmap='Blues', ax=ax_cm,
            xticklabels=class_names,
            yticklabels=class_names,
            cbar_kws={'label': 'Count'})
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('Actual')
ax_cm.set_title('Average Confusion Matrix')

# 4. F1 Score per fold with the mean and a ±1 std band
f1_mean = np.mean(fold_results['f1'])
f1_std = np.std(fold_results['f1'])
band_low, band_high = f1_mean - f1_std, f1_mean + f1_std
ax_f1.bar(x, fold_results['f1'], color='#9b59b6', alpha=0.7, label='Per-fold F1')
ax_f1.axhline(y=f1_mean, color='red', linestyle='--', linewidth=2, label=f'Mean: {f1_mean:.4f}')
ax_f1.axhline(y=band_high, color='red', linestyle=':', alpha=0.5)
ax_f1.axhline(y=band_low, color='red', linestyle=':', alpha=0.5)
ax_f1.fill_between([0.5, N_FOLDS + 0.5], band_low, band_high,
                   color='red', alpha=0.1, label=f'±1 Std: {f1_std:.4f}')
ax_f1.set_xlabel('Fold')
ax_f1.set_ylabel('F1 Score')
ax_f1.set_title('F1 Score Across Folds')
ax_f1.set_xticks(x)
ax_f1.legend()
ax_f1.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_FOLDER, 'kfold_results.png'), dpi=150)
plt.show()

print(f"Visualization saved to: {os.path.join(CHECKPOINT_FOLDER, 'kfold_results.png')}")

## 13. Save Best Fold Model for Grad-CAM

Choose the best performing fold and save in format compatible with `grad_CAM_3d_sagittal.py`.

In [None]:
# Find best fold based on F1 score.
# np.argmax and the stored metric are NumPy scalars; convert to built-in
# int/float so values that are later printed or pickled are plain Python.
best_fold_idx = int(np.argmax(fold_results['f1']))
best_fold_num = best_fold_idx + 1
best_fold_f1 = float(fold_results['f1'][best_fold_idx])

print(f"Best performing fold: Fold {best_fold_num}")
print(f"Best F1 Score: {best_fold_f1:.4f}")
print(f"Accuracy: {fold_results['accuracy'][best_fold_idx]:.4f}")
print(f"Precision: {fold_results['precision'][best_fold_idx]:.4f}")
print(f"Recall: {fold_results['recall'][best_fold_idx]:.4f}")

# Load best fold model.
# weights_only=False: the per-fold checkpoint also stores the training
# history dict. map_location lets the load succeed when the checkpoint was
# written on a different device than the current one.
best_model_path = os.path.join(CHECKPOINT_FOLDER, f'fold_{best_fold_num}_best.tar')
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)

# Initialize fresh model (same factory as everywhere else) and load state
best_model = create_model().to(device)
best_model.load_state_dict(checkpoint['state_dict'])
# The model is only kept for export/inference from here on.
best_model.eval()

print(f"\nLoaded best fold model from: {best_model_path}")
print(f"  Epoch: {checkpoint['epoch']}")
print(f"  F1 Score: {checkpoint['best_f1']:.4f}")

## 14. Save Model for Grad-CAM

Save in format compatible with `grad_CAM_3d_sagittal.py`:
```python
model = SEresnet50(spatial_dims=2, in_channels=1, num_classes=2)
model = torch.nn.DataParallel(model).cuda()
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['state_dict'])
```

In [None]:
# Wrap model in DataParallel to match grad_CAM expectations
# (its state-dict keys then carry the "module." prefix).
model_dp = nn.DataParallel(best_model)

# Save checkpoint in expected format.
# Every metadata value is converted to a built-in int/float: the Grad-CAM
# script calls torch.load(ckpt_path) without weights_only=False, and recent
# PyTorch releases default to weights_only=True, which refuses to unpickle
# NumPy scalars such as those produced by np.argmax / np.mean / np.std.
gradcam_checkpoint = {
    'fold': int(best_fold_num),
    'epoch': int(checkpoint['epoch']),
    'state_dict': model_dp.state_dict(),  # DataParallel state dict
    'best_f1': float(best_fold_f1),
    'mean_f1': float(np.mean(fold_results['f1'])),
    'std_f1': float(np.std(fold_results['f1']))
}

save_path = os.path.join(CHECKPOINT_FOLDER, 'seresnet50_classifier_best.tar')
torch.save(gradcam_checkpoint, save_path)
print("\nGrad-CAM compatible checkpoint saved to:")
print(f"  {save_path}")

print("\n" + "="*60)
print("USAGE WITH GRAD-CAM")
print("="*60)
print("python Attention/grad_CAM_3d_sagittal.py \\")
print(f"  --ckpt-path {save_path} \\")
print("  --dataroot <straightened_CT_folder> \\")
print("  --output-folder <heatmap_output_folder>")

## Summary

### Training Approach
- **Method**: 5-Fold Cross-Validation
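- **Purpose**: A more robust performance estimate than a single train/val split (cross-validation measures generalization; it does not by itself reduce overfitting)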
- **Each fold**: Separate model trained independently

### Model Architecture
- **Model**: MONAI SEResNet50
- **Input**: 2D grayscale images (1 channel, 64x64)
- **Output**: 2 classes (healthy/fractured)
- **Spatial dims**: 2D

### Data Pipeline
1. Load 3D straightened vertebra volume (`.nii` or `.nii.gz`)
2. Extract middle 30 slices (center ± 15)
3. Each slice inherits parent vertebra's label
4. Normalize to [0, 1]
5. Binary classification: 0=healthy, 1+=fractured

### Ground Truth
- File: `vertebra_data_test.json` (the file configured in `JSON_PATH`)
- Format: `{"patient_ct_vertebraID": genant_grade}`
- Mapping: Grade 0 → Class 0, Grade 1/2/3 → Class 1

### Output Files
- **Per-fold checkpoints**: `fold_1_best.tar` through `fold_5_best.tar`
- **Best model**: `seresnet50_classifier_best.tar` (best fold, Grad-CAM compatible)
- **Results**: `kfold_results.txt` (detailed metrics)
- **Visualization**: `kfold_results.png` (plots)

### Cross-Validation Benefits
- More robust performance estimate
- Reduces variance from single train/val split
- Uses all data for both training and validation
- Identifies model stability across different data subsets