# Swin-UNETR Transfer Learning for Liver/Tumor Segmentation

This notebook fine-tunes MONAI's Swin-UNETR for liver/tumor segmentation, starting from weights pre-trained on 5,050 CT scans.

In [None]:
# Install MONAI if needed
# !pip install "monai[all]"   (quotes keep shells such as zsh from globbing the brackets)

In [None]:
import os
import time
import json
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference

from sklearn.metrics import confusion_matrix

# Report the PyTorch build and, when a CUDA device is present, which GPU
# will be used and how much memory it has.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1024**3:.1f} GB")

In [None]:
# -----------------------------------------------------------------------------
# Experiment configuration, output directory, RNG seeding and device selection
# -----------------------------------------------------------------------------

CONFIG = dict(
    # --- data ---
    data_dir='preprocessed_patches_v2',
    num_classes=3,
    input_size=(128, 128, 128),

    # --- model ---
    feature_size=48,  # 48 for base, 24 for small (less VRAM)
    use_pretrained=True,

    # --- optimisation ---
    batch_size=1,  # kept small for 12GB VRAM
    accumulation_steps=4,  # effective batch size = batch_size * accumulation_steps
    epochs=100,
    lr=1e-4,
    weight_decay=1e-5,

    # Per-class loss weights in the order [bg, liver, tumor]
    class_weights=[0.1, 1.0, 5.0],

    # --- misc ---
    seed=42,
    num_workers=2,
    checkpoint_dir='checkpoints_swin',
)

# Checkpoints, plots and the training history all land in this directory.
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

# Seed the PyTorch and NumPy global generators for reproducibility.
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# =============================================================================
# Dataset
# =============================================================================

class LiverDataset(Dataset):
    """
    Dataset for liver/tumor segmentation using preprocessed .npz files.

    Each file must provide the arrays ``'patches'`` and ``'segmentations'``,
    indexed by patch along the first axis. Every file is assumed to hold
    ``patches_per_file`` patches (20 by default); the files are not opened
    at construction time, so this is not verified up front.

    Args:
        file_list: Paths of the .npz files.
        augment: When True, apply random flips and a random intensity shift.
        patches_per_file: Number of patches indexed per file.

    Items are ``(image, label)`` tensors: ``image`` is float32, scaled by
    1/255, with a leading channel axis; ``label`` is int64 with the patch's
    spatial shape.
    """
    def __init__(self, file_list, augment=False, patches_per_file=20):
        self.file_list = file_list
        self.augment = augment
        self.patches_per_file = patches_per_file

        # Flat index of (file_idx, patch_idx) pairs, file-major.
        self.indices = [
            (file_idx, patch_idx)
            for file_idx in range(len(file_list))
            for patch_idx in range(patches_per_file)
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        file_idx, patch_idx = self.indices[idx]

        # The context manager closes the underlying archive; without it every
        # access leaves an open file handle behind.
        with np.load(self.file_list[file_idx]) as data:
            image = data['patches'][patch_idx].astype(np.float32) / 255.0
            label = data['segmentations'][patch_idx].astype(np.int64)

        # Add channel dimension: (D, H, W) -> (1, D, H, W)
        image = image[np.newaxis, ...]

        if self.augment:
            image, label = self._augment(image, label)

        return torch.from_numpy(image), torch.from_numpy(label)

    @staticmethod
    def _augment(image, label):
        """Randomly flip image/label together and shift the image intensity.

        Randomness comes from torch's generator rather than NumPy's: DataLoader
        seeds torch separately in every worker, whereas forked workers inherit
        identical NumPy state and would repeat the same augmentations.
        """
        # Flip each spatial axis with 50% probability; the image carries an
        # extra leading channel axis, hence the +1 offset.
        for axis in range(label.ndim):
            if torch.rand(1).item() > 0.5:
                image = np.flip(image, axis=axis + 1)
                label = np.flip(label, axis=axis)

        # Random global intensity shift in [-0.1, 0.1], clipped back to [0, 1].
        if torch.rand(1).item() > 0.5:
            shift = torch.empty(1).uniform_(-0.1, 0.1).item()
            image = np.clip(image + shift, 0, 1)

        # np.flip returns negatively-strided views, which torch.from_numpy
        # rejects; materialise contiguous arrays once at the end.
        return np.ascontiguousarray(image), np.ascontiguousarray(label)


# Collect every .npz file in the data directory, in sorted order.
all_files = sorted(
    os.path.join(CONFIG['data_dir'], fname)
    for fname in os.listdir(CONFIG['data_dir'])
    if fname.endswith('.npz')
)

# Shuffle at file level with a fixed seed, then cut 70% / 15% / remainder
# into train / val / test.
np.random.seed(CONFIG['seed'])
indices = np.random.permutation(len(all_files))
train_end = int(len(all_files) * 0.70)
val_end = train_end + int(len(all_files) * 0.15)

train_files = [all_files[i] for i in indices[:train_end]]
val_files = [all_files[i] for i in indices[train_end:val_end]]
test_files = [all_files[i] for i in indices[val_end:]]

print(f"Total files: {len(all_files)}")
for split_name, split_files in (('Train', train_files),
                                ('Val', val_files),
                                ('Test', test_files)):
    print(f"{split_name}: {len(split_files)} files ({len(split_files)*20} patches)")

# Only the training split is augmented.
train_dataset = LiverDataset(train_files, augment=True)
val_dataset = LiverDataset(val_files, augment=False)
test_dataset = LiverDataset(test_files, augment=False)


def _build_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the batch/worker settings from CONFIG."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=shuffle,
        num_workers=CONFIG['num_workers'],
        pin_memory=True,
    )


# Only the training loader shuffles.
train_loader = _build_loader(train_dataset, True)
val_loader = _build_loader(val_dataset, False)
test_loader = _build_loader(test_dataset, False)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# =============================================================================
# Model: Swin-UNETR with Pre-trained Weights
# =============================================================================

model = SwinUNETR(
    img_size=CONFIG['input_size'],
    in_channels=1,
    out_channels=CONFIG['num_classes'],
    feature_size=CONFIG['feature_size'],
    use_checkpoint=True,  # Gradient checkpointing to save VRAM
    spatial_dims=3,
)

if CONFIG['use_pretrained']:
    print("Loading pre-trained weights...")
    # Download pre-trained weights (trained on 5050 CT volumes)
    weight_path = os.path.join(CONFIG['checkpoint_dir'], 'swin_unetr_pretrained.pt')

    if not os.path.exists(weight_path):
        print("Downloading pre-trained weights (~400MB)...")
        import urllib.request
        url = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt"
        # Download under a temporary name and rename only on success, so an
        # interrupted download is never mistaken for a complete file later.
        partial_path = weight_path + '.part'
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, weight_path)
        print("Download complete!")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # obtained from a trusted source.
    checkpoint = torch.load(weight_path, map_location='cpu')

    # NOTE(review): MONAI's released checkpoints appear to nest the tensors
    # under a 'state_dict' key (confirm against the downloaded file). Iterating
    # the outer dict would match no model keys and silently load nothing, so
    # unwrap it when present and accept a bare state dict otherwise.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        pretrained_dict = checkpoint['state_dict']
    else:
        pretrained_dict = checkpoint

    # Checkpoints saved from a DataParallel-wrapped model prefix every key
    # with 'module.'; strip it so the names line up with this model.
    prefix = 'module.'
    pretrained_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in pretrained_dict.items()
    }

    model_dict = model.state_dict()

    # Keep only tensors whose name and shape match the model. This drops the
    # segmentation head when the checkpoint was trained with a different
    # number of output classes.
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if k in model_dict
                       and hasattr(v, 'shape')
                       and v.shape == model_dict[k].shape}

    if not pretrained_dict:
        print("WARNING: no pre-trained tensors matched the model; "
              "training will start from random initialization.")

    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} pre-trained layers")

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# -----------------------------------------------------------------------------
# Loss, optimizer, LR schedule, AMP scaler and validation metric
# -----------------------------------------------------------------------------

# Per-class weights for the cross-entropy term, placed on the training device.
class_weights = torch.tensor(
    CONFIG['class_weights'],
    dtype=torch.float32,
    device=device,
)

# Dice + cross-entropy; labels are one-hot encoded and logits soft-maxed
# inside the loss.
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True, ce_weight=class_weights)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
)

# Cosine annealing over the configured number of epochs, down to 1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['epochs'],
    eta_min=1e-6,
)

# Gradient scaler for mixed-precision training.
scaler = GradScaler()

# Dice is tracked per class with the background channel excluded.
dice_metric = DiceMetric(include_background=False, reduction='mean_batch')

print("Loss: DiceCE with class weights", CONFIG['class_weights'])
print(f"Optimizer: AdamW (lr={CONFIG['lr']}, wd={CONFIG['weight_decay']})")
for summary_line in ("Scheduler: CosineAnnealing", "Mixed precision: Enabled"):
    print(summary_line)

In [None]:
# =============================================================================
# Training Functions
# =============================================================================

def train_epoch(model, loader, optimizer, loss_fn, scaler, accumulation_steps=1):
    """Run one training epoch with mixed precision and gradient accumulation.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches; must support ``len()``.
        optimizer: Optimizer stepped once per accumulation window.
        loss_fn: Called as ``loss_fn(outputs, labels.unsqueeze(1))``.
        scaler: Gradient scaler used for the scaled backward pass and step.
        accumulation_steps: Number of batches whose gradients are summed
            before each optimizer step.

    Returns:
        Mean un-normalised loss per batch over the epoch.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty loader")

    model.train()
    total_loss = 0.0
    optimizer.zero_grad()

    pbar = tqdm(loader, desc='Training')
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)

        # Mixed precision forward pass; the loss is divided so that gradients
        # summed over a full window average over its batches.
        with autocast():
            outputs = model(images)
            loss = loss_fn(outputs, labels.unsqueeze(1)) / accumulation_steps

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()

        # Step at the end of every accumulation window, and also on the final
        # batch: otherwise gradients of a trailing partial window (when the
        # batch count is not a multiple of accumulation_steps) would be
        # discarded by the next epoch's zero_grad().
        window_full = (batch_idx + 1) % accumulation_steps == 0
        last_batch = (batch_idx + 1) == num_batches
        if window_full or last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the accumulation scaling for reporting.
        batch_loss = loss.item() * accumulation_steps
        total_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return total_loss / num_batches


def validate(model, loader, loss_fn, dice_metric):
    """Evaluate the model on a loader, returning mean loss and Dice scores.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches; must support ``len()``.
        loss_fn: Called as ``loss_fn(outputs, labels.unsqueeze(1))``.
        dice_metric: Metric object that is reset, fed one-hot predictions and
            labels per batch, aggregated, and reset again.

    Returns:
        ``(mean_loss, dice_scores)`` where ``dice_scores`` is whatever
        ``dice_metric.aggregate()`` returns.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("validate received an empty loader")

    model.eval()
    total_loss = 0.0
    dice_metric.reset()

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validation'):
            images = images.to(device)
            labels = labels.to(device)

            with autocast():
                outputs = model(images)
                loss = loss_fn(outputs, labels.unsqueeze(1))

            total_loss += loss.item()

            # Take the class count from the model output (channel axis) so the
            # one-hot width always matches the network rather than relying on
            # a global setting.
            num_classes = outputs.shape[1]
            preds = torch.argmax(outputs, dim=1)

            # One-hot encode and move the class axis to position 1:
            # (B, D, H, W, C) -> (B, C, D, H, W)
            preds_onehot = torch.nn.functional.one_hot(preds, num_classes).permute(0, 4, 1, 2, 3)
            labels_onehot = torch.nn.functional.one_hot(labels, num_classes).permute(0, 4, 1, 2, 3)
            dice_metric(preds_onehot, labels_onehot)

    dice_scores = dice_metric.aggregate()
    dice_metric.reset()

    return total_loss / num_batches, dice_scores


print("Training functions defined.")

In [None]:
# =============================================================================
# Training Loop
# =============================================================================
# Trains for up to CONFIG['epochs'] epochs, keeps the checkpoint with the best
# validation tumor Dice, writes a "latest" checkpoint after every epoch, and
# stops early after `patience` epochs without improvement. Ctrl-C saves an
# "interrupted" checkpoint before falling through to the summary.

print("="*60)
print("Starting Swin-UNETR Training")
print("="*60)
print(f"Epochs: {CONFIG['epochs']}")
print(f"Batch size: {CONFIG['batch_size']} x {CONFIG['accumulation_steps']} = {CONFIG['batch_size'] * CONFIG['accumulation_steps']} effective")
print(f"Learning rate: {CONFIG['lr']}")
print("="*60 + "\n")

history = {
    'train_loss': [],
    'val_loss': [],
    'val_dice_liver': [],
    'val_dice_tumor': [],
    'lr': []
}

best_dice_tumor = 0
patience = 15
patience_counter = 0

try:
    for epoch in range(CONFIG['epochs']):
        epoch_start = time.time()

        # Train
        train_loss = train_epoch(
            model, train_loader, optimizer, loss_fn, scaler,
            accumulation_steps=CONFIG['accumulation_steps']
        )

        # Validate
        val_loss, dice_scores = validate(model, val_loader, loss_fn, dice_metric)
        dice_liver = dice_scores[0].item()  # Class 1 (liver)
        dice_tumor = dice_scores[1].item()  # Class 2 (tumor)

        # Record the learning rate that was actually used for this epoch.
        # Reading it after scheduler.step() would log the next epoch's value.
        current_lr = optimizer.param_groups[0]['lr']

        # Update scheduler
        scheduler.step()

        # Save history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_dice_liver'].append(dice_liver)
        history['val_dice_tumor'].append(dice_tumor)
        history['lr'].append(current_lr)

        epoch_time = time.time() - epoch_start

        # Print progress
        print(f"Epoch {epoch+1:3d}/{CONFIG['epochs']} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Dice Liver: {dice_liver:.4f} | "
              f"Dice Tumor: {dice_tumor:.4f} | "
              f"LR: {current_lr:.2e} | "
              f"Time: {epoch_time:.0f}s")

        # Save best model (selected on validation tumor Dice)
        if dice_tumor > best_dice_tumor:
            best_dice_tumor = dice_tumor
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_dice_tumor': best_dice_tumor,
                'config': CONFIG,
            }, os.path.join(CONFIG['checkpoint_dir'], 'best_model.pt'))
            print(f"  >> New best tumor Dice: {best_dice_tumor:.4f} - Model saved!")
        else:
            patience_counter += 1

        # Save latest checkpoint. Scheduler and scaler state are included so
        # a resumed run can continue the LR schedule and loss scaling.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'history': history,
        }, os.path.join(CONFIG['checkpoint_dir'], 'latest_checkpoint.pt'))

        # Early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
            break

        print()

except KeyboardInterrupt:
    print("\n" + "="*60)
    print("Training interrupted by user")
    print("="*60)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'history': history,
    }, os.path.join(CONFIG['checkpoint_dir'], 'interrupted_model.pt'))
    print(f"Model saved to {CONFIG['checkpoint_dir']}/interrupted_model.pt")

print("\n" + "="*60)
print(f"Training Complete! Best Tumor Dice: {best_dice_tumor:.4f}")
print("="*60)

# Save history
with open(os.path.join(CONFIG['checkpoint_dir'], 'training_history.json'), 'w') as f:
    json.dump(history, f, indent=2)

In [None]:
# -----------------------------------------------------------------------------
# Training curves: loss, validation Dice and learning rate, side by side
# -----------------------------------------------------------------------------

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# The first two panels share the same layout: a couple of labelled series
# from `history`, epoch on the x-axis, a legend and a grid.
curve_panels = (
    (axes[0], (('train_loss', 'Train'), ('val_loss', 'Val')),
     'Loss', 'Training & Validation Loss'),
    (axes[1], (('val_dice_liver', 'Liver'), ('val_dice_tumor', 'Tumor')),
     'Dice Score', 'Validation Dice Scores'),
)
for panel, series, y_label, panel_title in curve_panels:
    for history_key, legend_label in series:
        panel.plot(history[history_key], label=legend_label)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True)

# Third panel: learning rate on a log scale, no legend.
lr_panel = axes[2]
lr_panel.plot(history['lr'])
lr_panel.set_xlabel('Epoch')
lr_panel.set_ylabel('Learning Rate')
lr_panel.set_title('Learning Rate Schedule')
lr_panel.set_yscale('log')
lr_panel.grid(True)

plt.tight_layout()
history_png = os.path.join(CONFIG['checkpoint_dir'], 'training_history.png')
plt.savefig(history_png, dpi=150)
plt.show()

In [None]:
# =============================================================================
# Load Best Model and Evaluate on Test Set
# =============================================================================
# Accumulates a voxel-level confusion matrix `cm` (rows = true class,
# columns = predicted class) over every test patch.

print("Loading best model for evaluation...")
# map_location lets a checkpoint saved on GPU be restored on a CPU-only
# machine (and vice versa).
checkpoint = torch.load(os.path.join(CONFIG['checkpoint_dir'], 'best_model.pt'),
                        map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch']+1} with tumor Dice: {checkpoint['best_dice_tumor']:.4f}")

# Evaluate on test set
model.eval()
cm = np.zeros((CONFIG['num_classes'], CONFIG['num_classes']), dtype=np.int64)

# Derive the label list from the same setting that sizes `cm`, so the two
# cannot disagree if the number of classes changes.
class_ids = list(range(CONFIG['num_classes']))

print("\nEvaluating on test set...")
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images = images.to(device)

        with autocast():
            outputs = model(images)

        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        labels = labels.numpy()

        cm += confusion_matrix(labels.flatten(), preds.flatten(), labels=class_ids)

print(f"\nTotal voxels evaluated: {cm.sum():,}")

In [None]:
# =============================================================================
# Plot Confusion Matrix and Metrics
# =============================================================================
# Draws the raw and row-normalised confusion matrices, then prints per-class
# precision / recall / F1 / Dice and the overall voxel accuracy, all derived
# from `cm` (rows = true class, columns = predicted class).

class_names = ['Background', 'Liver', 'Tumor']
n_classes = len(class_names)

# Row-normalise to percentages. A class with no true voxels in the test set
# has a zero row total; leave that row at 0 instead of dividing by zero,
# which would fill it with NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
) * 100

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts
im1 = axes[0].imshow(cm, cmap='Blues')
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14)
axes[0].set_xlabel('Predicted', fontsize=12)
axes[0].set_ylabel('True', fontsize=12)
axes[0].set_xticks(range(n_classes))
axes[0].set_yticks(range(n_classes))
axes[0].set_xticklabels(class_names)
axes[0].set_yticklabels(class_names)
for i in range(n_classes):
    for j in range(n_classes):
        axes[0].text(j, i, f'{cm[i,j]:,}', ha='center', va='center', fontsize=10)
plt.colorbar(im1, ax=axes[0])

# Normalized
im2 = axes[1].imshow(cm_norm, cmap='Blues', vmin=0, vmax=100)
axes[1].set_title('Confusion Matrix (% by True Class)', fontsize=14)
axes[1].set_xlabel('Predicted', fontsize=12)
axes[1].set_ylabel('True', fontsize=12)
axes[1].set_xticks(range(n_classes))
axes[1].set_yticks(range(n_classes))
axes[1].set_xticklabels(class_names)
axes[1].set_yticklabels(class_names)
for i in range(n_classes):
    for j in range(n_classes):
        axes[1].text(j, i, f'{cm_norm[i,j]:.1f}%', ha='center', va='center', fontsize=10)
plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.savefig(os.path.join(CONFIG['checkpoint_dir'], 'confusion_matrix.png'), dpi=150)
plt.show()

# Print metrics
print("\n" + "="*60)
print("TEST SET RESULTS - Swin-UNETR")
print("="*60)
print(f"{'Class':<12} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Dice':>10}")
print("-"*60)

for i, name in enumerate(class_names):
    # One-vs-rest counts for class i.
    tp = cm[i, i]
    fp = cm[:, i].sum() - tp
    fn = cm[i, :].sum() - tp

    # Each ratio falls back to 0 when its denominator is empty.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0

    print(f"{name:<12} {precision:>10.4f} {recall:>10.4f} {f1:>10.4f} {dice:>10.4f}")

print("="*60)
accuracy = np.trace(cm) / cm.sum()
print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

In [None]:
# -----------------------------------------------------------------------------
# Side-by-side Dice comparison against the earlier 3D U-Net baseline
# -----------------------------------------------------------------------------

banner = "=" * 60
print("\n" + banner)
print("COMPARISON: Swin-UNETR vs 3D U-Net Baseline")
print(banner)

# Baseline Dice scores recorded from the previous U-Net evaluation.
unet_results = {
    'Background': {'dice': 0.9734},
    'Liver': {'dice': 0.9036},
    'Tumor': {'dice': 0.6539},
}

# Per-class Dice for Swin-UNETR, derived from the confusion matrix `cm`.
swin_results = {}
for class_idx, class_label in enumerate(class_names):
    true_pos = cm[class_idx, class_idx]
    false_pos = cm[:, class_idx].sum() - true_pos
    false_neg = cm[class_idx, :].sum() - true_pos
    denom = 2 * true_pos + false_pos + false_neg
    if denom > 0:
        class_dice = (2 * true_pos) / denom
    else:
        class_dice = 0
    swin_results[class_label] = {'dice': class_dice}

print(f"{'Class':<12} {'U-Net':>12} {'Swin-UNETR':>12} {'Improvement':>12}")
print("-" * 50)
for class_label in class_names:
    unet_dice = unet_results[class_label]['dice']
    swin_dice = swin_results[class_label]['dice']
    # Difference in Dice, expressed in percentage points.
    improvement = (swin_dice - unet_dice) * 100
    if improvement > 0:
        arrow = '↑'
    elif improvement < 0:
        arrow = '↓'
    else:
        arrow = '→'
    print(f"{class_label:<12} {unet_dice:>11.4f} {swin_dice:>12.4f} {arrow:>6}{abs(improvement):>5.1f}%")
print("=" * 50)