## 1. Setup and Imports

In [None]:
# Install required packages (run once). seaborn is also needed for the confusion-matrix plot in Section 14.
# Prefer `%pip` over `!pip` so packages go into the running kernel's environment; pin versions for reproducibility.
# !pip install monai nibabel torch torchvision scikit-learn matplotlib tqdm seaborn

In [None]:
import os
import json
import random
from pathlib import Path

import numpy as np
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from monai.networks.nets import SEResNet50
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split

# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# ============== MODIFY THESE PATHS ==============
# Path to straightened CT volumes
CT_FOLDER = "/content/verse19/straighten/CT"  # Colab path
# CT_FOLDER = "d:/Graduation Project/HeathiVert/verse19/straighten/CT"  # Windows path

# Path to the JSON file with ground-truth Genant grades (split keys: train/test/val)
JSON_PATH = "/content/verse19/vertebra_data_test.json"  # Colab path
# JSON_PATH = "d:/Graduation Project/HeathiVert/verse19/vertebra_data_test.json"  # Windows path

# Output folder for checkpoints
CHECKPOINT_FOLDER = "/content/checkpoints/classifier"  # Colab path
# CHECKPOINT_FOLDER = "d:/Graduation Project/HeathiVert/checkpoints/classifier"  # Windows path

# ============== TRAINING HYPERPARAMETERS ==============
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_SLICES = 30  # Number of slices to extract from each volume (center ± 15)
VAL_SPLIT = 0.2  # Validation split ratio
SEED = 42  # Global seed for Python, NumPy and PyTorch RNGs

# ============== REPRODUCIBILITY ==============
# Without explicit seeds, weight initialisation and the shuffled training order
# change on every run, so reported metrics cannot be reproduced.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# Create checkpoint folder
os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
print(f"Checkpoints will be saved to: {CHECKPOINT_FOLDER}")

## 3. Load Ground Truth Labels

The JSON structure is:
```json
{
  "train": {"sub-verse004_ct_23": 0, "sub-verse020_12": 1, ...},
  "test": {...},
  "val": {...}
}
```

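All three splits are merged into a single pool; the notebook makes its own train/validation split in Section 5. Genant grades (0, 1, 2, 3) are converted to binary labels: grade 0 → 0 (healthy), and grades 1–3 → 1 (fractured).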

In [None]:
def load_labels(json_path):
    """Load vertebra labels from JSON and convert them to binary classes.

    The JSON may contain up to three split dictionaries ("train", "test",
    "val"), each mapping a vertebra ID to its Genant grade. All present splits
    are merged into one dictionary. If an ID occurs in several splits, the
    later split (order train -> test -> val) wins.

    Args:
        json_path: Path to the labels JSON file.

    Returns:
        dict mapping vertebra_id -> 0 (grade 0, healthy) or 1 (grade > 0, fractured).

    Raises:
        ValueError: if no split contains any label. Without this check, an empty
            file surfaced as a ZeroDivisionError while printing statistics.
    """
    with open(json_path, 'r') as f:
        data = json.load(f)

    # Combine all splits
    all_labels = {}
    for split in ('train', 'test', 'val'):
        all_labels.update(data.get(split, {}))

    if not all_labels:
        raise ValueError(f"No vertebra labels found under 'train'/'test'/'val' in {json_path}")

    # Convert to binary: 0 = healthy, 1+ = fractured
    binary_labels = {k: int(v > 0) for k, v in all_labels.items()}

    # Statistics (labels are 0/1, so their sum counts the fractured class)
    total = len(binary_labels)
    n_fractured = sum(binary_labels.values())
    n_healthy = total - n_fractured

    print(f"Total vertebrae: {total}")
    print(f"  Healthy (0): {n_healthy} ({100*n_healthy/total:.1f}%)")
    print(f"  Fractured (1): {n_fractured} ({100*n_fractured/total:.1f}%)")

    return binary_labels

# Global mapping vertebra_id -> binary label (0 = healthy, 1 = fractured), merged over all JSON splits.
labels = load_labels(JSON_PATH)

## 4. Dataset Class

**Key Points:**
- Each `.nii.gz` file is a single straightened vertebra (not whole spine)
- File naming: `sub-verse004_ct_23.nii.gz` → lookup key: `sub-verse004_ct_23`
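- Extract the **middle `NUM_SLICES` slices** as 2D images. With the default of 30, these are z indices `z_center - 15` … `z_center + 14`, clipped to the volume bounds.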
- Each slice gets the same label as its parent vertebra

In [None]:
class VertebraSliceDataset(Dataset):
    """Dataset that extracts 2D slices from 3D straightened vertebra volumes.

    Every labelled ``.nii.gz`` volume in ``ct_folder`` contributes up to
    ``num_slices`` slices along its third array axis (z), centred on
    ``z_dim // 2``. Each slice inherits the label of its parent vertebra.
    ``self.samples`` holds one ``(nii_path, slice_idx, label)`` tuple per slice.
    """

    _SUFFIX = '.nii.gz'

    def __init__(self, ct_folder, labels_dict, num_slices=30, transform=None):
        """
        Args:
            ct_folder: Path to folder containing .nii.gz files
            labels_dict: Dictionary mapping vertebra_id -> binary label
            num_slices: Number of slices to extract from each volume, centred on
                z_dim // 2 and clipped to the volume bounds
            transform: Optional transform applied to each (1, H, W) slice tensor
        """
        self.ct_folder = Path(ct_folder)
        self.labels_dict = labels_dict
        self.num_slices = num_slices
        self.transform = transform

        # List of (nii_path, slice_idx, label) tuples, one per slice
        self.samples = []
        self._prepare_samples()

    @classmethod
    def _vertebra_id(cls, nii_path):
        """Return the label key for a file, e.g. 'sub-verse004_ct_23.nii.gz' -> 'sub-verse004_ct_23'."""
        name = Path(nii_path).name
        # Strip only the trailing suffix. str.replace('.nii', '') would also
        # remove a '.nii' appearing anywhere else in the file name.
        return name[:-len(cls._SUFFIX)] if name.endswith(cls._SUFFIX) else name

    def _slice_range(self, z_dim):
        """Return num_slices z indices centred on z_dim // 2, clipped to [0, z_dim)."""
        first = z_dim // 2 - self.num_slices // 2
        # Ending at first + num_slices (not z_center + num_slices // 2) keeps an
        # odd num_slices from silently yielding one slice fewer than requested.
        return range(max(0, first), min(z_dim, first + self.num_slices))

    def _prepare_samples(self):
        """Build self.samples from every labelled .nii.gz file in ct_folder."""
        # Sorted so the sample order, and therefore any seeded split, is the same
        # on every filesystem. Raw glob order is not guaranteed.
        nii_files = sorted(self.ct_folder.glob('*' + self._SUFFIX))
        print(f"Found {len(nii_files)} .nii.gz files")

        matched = 0
        unmatched = []

        for nii_path in nii_files:
            vertebra_id = self._vertebra_id(nii_path)
            if vertebra_id not in self.labels_dict:
                unmatched.append(vertebra_id)
                continue
            label = self.labels_dict[vertebra_id]

            # nib.load reads the header lazily; only the shape is needed here.
            try:
                z_dim = nib.load(str(nii_path)).shape[2]
            except Exception as e:
                print(f"Error loading {nii_path}: {e}")
                continue

            matched += 1  # count only volumes that actually contribute slices
            self.samples.extend(
                (nii_path, slice_idx, label) for slice_idx in self._slice_range(z_dim)
            )

        print(f"Matched: {matched} vertebrae")
        print(f"Unmatched: {len(unmatched)} vertebrae")
        print(f"Total slices: {len(self.samples)}")

        # Class distribution
        n_healthy = sum(1 for _, _, lbl in self.samples if lbl == 0)
        n_fractured = sum(1 for _, _, lbl in self.samples if lbl == 1)
        print(f"Slice distribution: Healthy={n_healthy}, Fractured={n_fractured}")

        if unmatched:
            print(f"Sample unmatched IDs: {unmatched[:5]}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (slice tensor of shape (1, H, W) scaled to [0, 1], label as int64 tensor)."""
        nii_path, slice_idx, label = self.samples[idx]

        # Slice through the lazy array proxy instead of get_fdata(). This avoids
        # materialising the entire volume as float64 for every single slice.
        img = nib.load(str(nii_path))
        slice_2d = np.asarray(img.dataobj[:, :, slice_idx], dtype=np.float32)

        # Per-slice min-max normalisation to [0, 1]; constant slices are left as-is
        slice_min, slice_max = slice_2d.min(), slice_2d.max()
        if slice_max > slice_min:
            slice_2d = (slice_2d - slice_min) / (slice_max - slice_min)

        # Add channel dimension: (H, W) -> (1, H, W)
        slice_tensor = torch.from_numpy(np.expand_dims(slice_2d, axis=0))

        if self.transform:
            slice_tensor = self.transform(slice_tensor)

        return slice_tensor, torch.tensor(label, dtype=torch.long)

## 5. Create DataLoaders

In [None]:
# Create full dataset
full_dataset = VertebraSliceDataset(
    ct_folder=CT_FOLDER,
    labels_dict=labels,
    num_slices=NUM_SLICES,
    transform=None
)

# Split into train/val at the VERTEBRA level, not the slice level.
# Neighbouring slices of one vertebra are nearly identical. A slice-level
# random_split puts parts of the same vertebra in both sets, so validation
# metrics would measure memorisation rather than generalisation.
slice_indices_by_volume = {}
label_by_volume = {}
for slice_pos, (nii_path, _, slice_label) in enumerate(full_dataset.samples):
    slice_indices_by_volume.setdefault(nii_path, []).append(slice_pos)
    label_by_volume[nii_path] = slice_label

volume_paths = sorted(slice_indices_by_volume)
assert volume_paths, f"No labelled volumes found - check CT_FOLDER ({CT_FOLDER}) and JSON_PATH"
volume_labels = [label_by_volume[p] for p in volume_paths]

# Stratify on the per-vertebra label only when train_test_split can place each
# class in both splits (it raises otherwise on tiny or single-class data).
n_val_volumes = int(np.ceil(VAL_SPLIT * len(volume_paths)))
can_stratify = (
    min(volume_labels.count(0), volume_labels.count(1)) >= 2
    and n_val_volumes >= 2
    and len(volume_paths) - n_val_volumes >= 2
)
train_volumes, val_volumes = train_test_split(
    volume_paths,
    test_size=VAL_SPLIT,
    random_state=42,
    stratify=volume_labels if can_stratify else None,
)

train_dataset = torch.utils.data.Subset(
    full_dataset, [i for p in train_volumes for i in slice_indices_by_volume[p]]
)
val_dataset = torch.utils.data.Subset(
    full_dataset, [i for p in val_volumes for i in slice_indices_by_volume[p]]
)

print(f"\nTrain vertebrae: {len(train_volumes)} | Val vertebrae: {len(val_volumes)}")
print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Create DataLoaders (seeded generator -> reproducible shuffle order)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 6. Visualize Sample Data

In [None]:
# Pull one shuffled training batch and display up to eight of its slices.
images, labels_batch = next(iter(train_loader))
print(f"Batch shape: {images.shape}")
print(f"Labels: {labels_batch[:8]}")

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
# zip stops at the shorter input, so a batch smaller than 8 leaves the remaining panels empty
for ax, image, label_tensor in zip(axes.flat, images, labels_batch):
    label_value = label_tensor.item()
    ax.imshow(image[0].numpy(), cmap='gray')
    ax.set_title(f"{'Healthy' if label_value == 0 else 'Fractured'} ({label_value})")
    ax.axis('off')
plt.suptitle('Sample Training Images')
plt.tight_layout()
plt.show()

In [None]:
# Visualize: Understanding the data structure
# Each file is expected to be CENTERED on the target vertebra by preprocessing.

ct_files = sorted(Path(CT_FOLDER).glob('*.nii.gz'))
assert ct_files, f"No .nii.gz files found in {CT_FOLDER}"
sample_ct_path = ct_files[0]
ct_vol = nib.load(str(sample_ct_path)).get_fdata()

# Get center slice (where target vertebra should be)
z_dim = ct_vol.shape[2]
z_center = z_dim // 2

# Clamp the ±10 offsets. A negative index would silently wrap to the far end
# of the volume, and an index >= z_dim would raise IndexError on thin volumes.
Z_OFFSET = 10
z_early = max(0, z_center - Z_OFFSET)
z_late = min(z_dim - 1, z_center + Z_OFFSET)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (z_center, f'Center Slice (z={z_center})\nTarget vertebra centered here'),
    (z_early, f'Slice z={z_early}'),
    (z_late, f'Slice z={z_late}'),
]
for ax, (z_idx, title) in zip(axes, panels):
    ax.imshow(ct_vol[:, :, z_idx].T, cmap='gray', origin='lower')
    ax.set_title(title)
    ax.axis('off')

# Crosshair at the in-plane centre of the volume on the centre slice
axes[0].axhline(y=ct_vol.shape[1]//2, color='r', linestyle='--', alpha=0.5)
axes[0].axvline(x=ct_vol.shape[0]//2, color='r', linestyle='--', alpha=0.5)

plt.suptitle(f'Volume: {sample_ct_path.name}\nShape: {ct_vol.shape} (target vertebra = center)', fontsize=12)
plt.tight_layout()
plt.show()

print(f"\nFile: {sample_ct_path.name}")
print(f"Volume shape: {ct_vol.shape}")
print(f"Center of volume = center of target vertebra")
print(f"No external localization needed - preprocessing already centered the data")

## 7. Model Definition

**IMPORTANT**: Must match the architecture in `grad_CAM_3d_sagittal.py`:
```python
model = SEresnet50(spatial_dims=2, in_channels=1, num_classes=2)
model = torch.nn.DataParallel(model).cuda()
```

In [None]:
def create_model():
    """Build the 2D SEResNet50 classifier expected by grad_CAM_3d_sagittal.py.

    Returns:
        An untrained MONAI SEResNet50 that takes 1-channel 2D images and
        outputs logits for 2 classes (healthy / fractured).
    """
    return SEResNet50(spatial_dims=2, in_channels=1, num_classes=2)


model = create_model().to(device)

# Print model summary
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model: SEResNet50")
print(f"Parameters: {n_params:,}")
print(f"Trainable: {n_trainable:,}")

## 8. Loss Function and Optimizer

In [None]:
# Calculate class weights for imbalanced data from the TRAINING split only.
# Counting validation slices too would let the held-out labels shape training.
train_slice_labels = [full_dataset.samples[i][2] for i in train_dataset.indices]
n_healthy = train_slice_labels.count(0)
n_fractured = train_slice_labels.count(1)

# Inverse-frequency weighting: n_samples / (n_classes * n_samples_in_class)
if n_fractured > 0 and n_healthy > 0:
    n_total = len(train_slice_labels)
    class_weights = torch.tensor(
        [n_total / (2 * n_healthy), n_total / (2 * n_fractured)],
        dtype=torch.float32,
    ).to(device)
else:
    # A class is absent from the training split: fall back to an unweighted loss
    class_weights = None

print(f"Class weights: {class_weights}")

# Loss and optimizer
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        (mean of the per-batch losses, slice-level accuracy over the epoch)
    """
    model.train()
    loss_sum = 0.0
    predicted, targets = [], []

    progress = tqdm(loader, desc='Training')
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        step_loss = batch_loss.item()
        loss_sum += step_loss
        predicted.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(batch_labels.cpu().numpy())
        progress.set_postfix({'loss': f'{step_loss:.4f}'})

    mean_loss = loss_sum / len(loader)
    return mean_loss, accuracy_score(targets, predicted)


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean per-batch loss,
         dict with 'accuracy', 'precision', 'recall', 'f1' (positive class = 1),
         list of predicted labels,
         list of true labels)
    """
    model.eval()
    loss_sum = 0.0
    predicted, targets = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc='Validating'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(loader)

    metrics = dict(
        accuracy=accuracy_score(targets, predicted),
        precision=precision_score(targets, predicted, zero_division=0),
        recall=recall_score(targets, predicted, zero_division=0),
        f1=f1_score(targets, predicted, zero_division=0),
    )
    return mean_loss, metrics, predicted, targets

## 10. Training Loop

In [None]:
# Training history (one entry per epoch)
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'val_f1': [], 'val_precision': [], 'val_recall': []
}

# Start below any achievable F1 (F1 >= 0) so the first epoch always writes
# best_model.tar. With 0.0, a run whose validation F1 never rises above 0
# would never save it, and the final-evaluation cell would fail on a missing file.
best_f1 = -1.0
best_epoch = 0


def _make_checkpoint(epoch_num, **extra):
    """Bundle model, optimizer and scheduler state plus the history so far."""
    return {
        'epoch': epoch_num,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),  # needed to resume the cosine schedule
        'history': history,
        **extra,
    }


print(f"Starting training for {NUM_EPOCHS} epochs...")
print("="*60)

for epoch in range(NUM_EPOCHS):
    epoch_num = epoch + 1
    print(f"\nEpoch {epoch_num}/{NUM_EPOCHS}")

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_metrics, _, _ = validate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step()

    # Log as plain Python floats so the saved history contains no numpy scalars
    for key, value in (
        ('train_loss', train_loss), ('train_acc', train_acc),
        ('val_loss', val_loss), ('val_acc', val_metrics['accuracy']),
        ('val_f1', val_metrics['f1']), ('val_precision', val_metrics['precision']),
        ('val_recall', val_metrics['recall']),
    ):
        history[key].append(float(value))

    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_metrics['accuracy']:.4f}")
    print(f"  Val F1: {val_metrics['f1']:.4f} | Precision: {val_metrics['precision']:.4f} | Recall: {val_metrics['recall']:.4f}")

    # Save best model
    if val_metrics['f1'] > best_f1:
        best_f1 = float(val_metrics['f1'])
        best_epoch = epoch_num
        torch.save(_make_checkpoint(epoch_num, best_f1=best_f1),
                   os.path.join(CHECKPOINT_FOLDER, 'best_model.tar'))
        print(f"  ✓ Saved best model (F1: {best_f1:.4f})")

    # Save periodic checkpoint
    if epoch_num % 10 == 0:
        torch.save(_make_checkpoint(epoch_num),
                   os.path.join(CHECKPOINT_FOLDER, f'checkpoint_epoch_{epoch_num}.tar'))

print("\n" + "="*60)
print(f"Training complete!")
print(f"Best F1: {best_f1:.4f} at epoch {best_epoch}")
print(f"Checkpoints saved to: {CHECKPOINT_FOLDER}")

## 11. Training Curves

In [None]:
# Epochs are 1-based in the logs and checkpoint names, so plot them the same way;
# otherwise the best epoch printed above appears one step to the left.
epochs = range(1, len(history['train_loss']) + 1)

# (panel title, y-axis label, [(legend label, history key), ...])
panels = [
    ('Loss', 'Loss', [('Train', 'train_loss'), ('Val', 'val_loss')]),
    ('Accuracy', 'Accuracy', [('Train', 'train_acc'), ('Val', 'val_acc')]),
    ('Validation Metrics', 'Score',
     [('F1', 'val_f1'), ('Precision', 'val_precision'), ('Recall', 'val_recall')]),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (title, ylabel, series) in zip(axes, panels):
    for legend_label, key in series:
        ax.plot(epochs, history[key], label=legend_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(CHECKPOINT_FOLDER, 'training_curves.png'), dpi=150)
plt.show()

## 12. Final Evaluation

In [None]:
# Load the best checkpoint written by the training loop.
# - map_location: load onto the current device even if the file was written elsewhere.
# - weights_only=False: the file also holds optimizer state and training history and
#   was produced by this notebook, so it is trusted (torch>=2.6 defaults to True).
checkpoint = torch.load(
    os.path.join(CHECKPOINT_FOLDER, 'best_model.tar'),
    map_location=device,
    weights_only=False,
)
model.load_state_dict(checkpoint['state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")

# Final validation
val_loss, val_metrics, val_preds, val_labels = validate(model, val_loader, criterion, device)

print("\n" + "="*60)
print("FINAL EVALUATION RESULTS")
print("="*60)
print(f"Accuracy:  {val_metrics['accuracy']:.4f}")
print(f"Precision: {val_metrics['precision']:.4f}")
print(f"Recall:    {val_metrics['recall']:.4f}")
print(f"F1 Score:  {val_metrics['f1']:.4f}")

# Confusion Matrix. labels=[0, 1] keeps it 2x2 even when the labels or the
# predictions contain only one class; otherwise cm[1, 0] raises IndexError.
cm = confusion_matrix(val_labels, val_preds, labels=[0, 1])
print(f"\nConfusion Matrix:")
print(f"              Pred 0   Pred 1")
print(f"Actual 0:     {cm[0,0]:5d}    {cm[0,1]:5d}")
print(f"Actual 1:     {cm[1,0]:5d}    {cm[1,1]:5d}")

# Classification Report (explicit labels so both target_names always apply)
print("\nClassification Report:")
print(classification_report(val_labels, val_preds, labels=[0, 1],
                            target_names=['Healthy', 'Fractured'], zero_division=0))

## 13. Save Model for Grad-CAM

Save in format compatible with `grad_CAM_3d_sagittal.py`:
```python
model = SEresnet50(spatial_dims=2, in_channels=1, num_classes=2)
model = torch.nn.DataParallel(model).cuda()
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['state_dict'])
```

In [None]:
# grad_CAM_3d_sagittal.py wraps the network in DataParallel before calling
# load_state_dict, so save the state dict from a DataParallel wrapper as well
# so the parameter names match.
model_dp = nn.DataParallel(model)

gradcam_checkpoint = dict(
    epoch=checkpoint['epoch'],
    state_dict=model_dp.state_dict(),  # DataParallel state dict
    best_f1=best_f1,
    metrics=val_metrics,
)

save_path = os.path.join(CHECKPOINT_FOLDER, 'seresnet50_classifier.tar')
torch.save(gradcam_checkpoint, save_path)
print(f"\nGrad-CAM compatible checkpoint saved to:")
print(f"  {save_path}")

separator = "=" * 60
print("\n" + separator)
print("USAGE WITH GRAD-CAM")
print(separator)
for usage_line in (
    "python Attention/grad_CAM_3d_sagittal.py \\",
    f"  --ckpt-path {save_path} \\",
    "  --dataroot <straightened_CT_folder> \\",
    "  --output-folder <heatmap_output_folder>",
):
    print(usage_line)

## 14. Confusion Matrix Visualization

In [None]:
import seaborn as sns

class_names = ['Healthy', 'Fractured']

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title(f'Confusion Matrix\nAccuracy: {val_metrics["accuracy"]:.2%}, F1: {val_metrics["f1"]:.4f}')
fig.tight_layout()
fig.savefig(os.path.join(CHECKPOINT_FOLDER, 'confusion_matrix.png'), dpi=150)
plt.show()

## Summary

### Model Architecture
- **Model**: MONAI SEResNet50
- **Input**: 2D grayscale images (1 channel)
- **Output**: 2 classes (healthy/fractured)
- **Spatial dims**: 2D

### Data Pipeline
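1. Load 3D straightened vertebra volume (`.nii.gz`)
2. Extract the middle `NUM_SLICES` slices along z (default 30: `z_center - 15` … `z_center + 14`, clipped to the volume)
3. Each slice inherits parent vertebra's label
4. Min-max normalize each slice to [0, 1]
5. Binary classification: Genant grade 0 → 0 (healthy), grade ≥ 1 → 1 (fractured)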

### Ground Truth
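- File: the JSON at `JSON_PATH` (currently `vertebra_data_test.json`)
- Format: `{"train": {...}, "test": {...}, "val": {...}}`. Each split maps `"<vertebra_id>": genant_grade` (e.g. `"sub-verse004_ct_23": 0`), and all splits are merged.
- Mapping: Grade 0 → Class 0, Grade 1/2/3 → Class 1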

### Output
- Checkpoint: `seresnet50_classifier.tar`
- Compatible with `grad_CAM_3d_sagittal.py`