# X-Ray Segmentation using U-Net

This notebook implements pelvis X-ray segmentation using a U-Net model.

In [4]:
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Add project root to path so that the `src` package can be imported.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.utils.pengwin_utils import (
    load_image, 
    load_masks, 
    build_augmentation,
    visualize_sample,
    CATEGORIES
)

# Report the library version and accelerator so runs can be reproduced
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    # Only query the device name when a CUDA device is actually present
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.5.1
CUDA available: True
CUDA device: NVIDIA GeForce RTX 4070


## 1. Dataset Implementation
First, let's create our custom dataset class for handling X-ray images and their masks.

In [5]:
class XrayDataset(Dataset):
    """Paired X-ray image / mask dataset read from ``.tif`` files.

    Expected layout under ``root_dir``::

        <root_dir>/<split>/input/images/x-ray/*.tif    (images)
        <root_dir>/<split>/output/images/x-ray/*.tif   (masks)

    Images and masks are paired by their position in the sorted file lists,
    so both directories must contain the same number of ``.tif`` files.

    Args:
        root_dir: Dataset root (string or path-like).
        split: Sub-directory name of the split; augmentation is built in
            training mode only when this equals ``'train'``.
        img_size: Forwarded to ``build_augmentation``.

    Raises:
        FileNotFoundError: If the image or mask directory does not exist.
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, root_dir, split='train', img_size=448):
        self.root = Path(root_dir)
        self.split = split
        self.img_size = img_size

        # Setup directories
        self.input_dir = self.root / split / "input" / "images" / "x-ray"
        self.output_dir = self.root / split / "output" / "images" / "x-ray"

        # Fail early with a clear message instead of silently building an
        # empty dataset when the path is wrong (glob on a missing directory
        # simply yields nothing).
        for directory in (self.input_dir, self.output_dir):
            if not directory.is_dir():
                raise FileNotFoundError(
                    f"Dataset directory not found: {directory}"
                )

        # Get file paths (sorted so that image i lines up with mask i)
        self.image_paths = sorted(self.input_dir.glob("*.tif"))
        self.mask_paths = sorted(self.output_dir.glob("*.tif"))

        # Explicit check rather than `assert`, which carries no message and
        # is removed when Python runs with -O.
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Image/mask count mismatch for split '{split}': "
                f"{len(self.image_paths)} images in {self.input_dir} vs "
                f"{len(self.mask_paths)} masks in {self.output_dir}"
            )

        # Setup augmentation
        self.aug = build_augmentation(train=(split == 'train'), img_size=img_size)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment and tensorise the sample at position ``idx``.

        Returns:
            Tuple ``(image, masks)`` of float32 tensors. Their shapes are
            whatever the augmentation pipeline produces.
            NOTE(review): no channel axis is added here -- confirm that the
            augmentation output already matches what the model expects.
        """
        # Load image and mask; only the first element of the triple returned
        # by load_masks is used.
        image = load_image(self.image_paths[idx])
        masks, _, _ = load_masks(self.mask_paths[idx])

        # Apply augmentation (image and masks are transformed together)
        augmented = self.aug(image=image, masks=masks)

        # Convert to torch tensors
        image = torch.from_numpy(augmented['image']).float()
        masks = torch.from_numpy(np.array(augmented['masks'])).float()

        return image, masks

In [7]:
# Build the training split and report how many samples it contains
dataset = XrayDataset('xray_seg', 'train')
print(f"Dataset size: {len(dataset)}")

# Fetch the first sample as a quick check of the loading pipeline
image, masks = dataset[0]
print(f"Image shape: {image.shape}")
print(f"Masks shape: {masks.shape}")

# Render the sample; one category id per CATEGORIES entry and a fragment id
# of 1 for each of them are passed to the project helper
vis_img = visualize_sample(
    image.numpy(),
    masks.numpy(),
    category_ids=list(CATEGORIES.values()),
    fragment_ids=[1] * len(CATEGORIES),
)

# Show the rendered overlay without axis ticks
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(vis_img)
ax.axis('off')
plt.show()

TypeError: Lambda.__init__() got an unexpected keyword argument 'keypoint'

## 2. U-Net Model Implementation
Now let's implement our U-Net architecture for segmentation.

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 with a 3x3 kernel leaves the spatial size
    unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Assemble the stages in order: conv, norm, activation (twice).
        stages = []
        stage_in = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            stage_in = out_channels  # the second conv consumes the first one's output

        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """U-Net with five encoder levels (64 -> 1024 channels) and four decoder levels.

    Each encoder level is a ``DoubleConv``; levels are separated by a 2x2
    max-pool. The decoder bilinearly upsamples, concatenates the matching
    encoder output along the channel axis (skip connection) and applies a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_channels``; no
    activation is applied to the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.

    Input spatial sizes do not have to be divisible by 16: when a pooled
    size was odd, the decoder resizes to the skip tensor's size instead of
    doubling. Each side must still be at least 16 so that four poolings do
    not shrink it to zero.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder (Contracting Path)
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.enc5 = DoubleConv(512, 1024)

        # Decoder (Expansive Path); input channels = upsampled + skip
        self.dec4 = DoubleConv(1024 + 512, 512)
        self.dec3 = DoubleConv(512 + 256, 256)
        self.dec2 = DoubleConv(256 + 128, 128)
        self.dec1 = DoubleConv(128 + 64, 64)

        # Final convolution
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        # Pooling and Upsampling
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _upsample_to(self, x, skip):
        """Upsample ``x`` so its spatial size equals that of ``skip``."""
        target = tuple(skip.shape[-2:])
        if (2 * x.shape[-2], 2 * x.shape[-1]) == target:
            # Regular case: exact doubling, same path as before.
            return self.up(x)
        # MaxPool2d floors odd sizes, so plain doubling would come out
        # smaller than the skip tensor and torch.cat would fail. Resize
        # straight to the skip tensor's size instead.
        return nn.functional.interpolate(
            x, size=target, mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        # Decoder with skip connections
        dec4 = self.dec4(torch.cat([self._upsample_to(enc5, enc4), enc4], dim=1))
        dec3 = self.dec3(torch.cat([self._upsample_to(dec4, enc3), enc3], dim=1))
        dec2 = self.dec2(torch.cat([self._upsample_to(dec3, enc2), enc2], dim=1))
        dec1 = self.dec1(torch.cat([self._upsample_to(dec2, enc1), enc1], dim=1))

        return self.final_conv(dec1)

# Pick the compute device and build the network on it
# (one output channel per entry in CATEGORIES)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = UNet(in_channels=1, out_channels=len(CATEGORIES))
model = model.to(device)

# Loss operates on raw logits; Adam optimises all model parameters
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model initialized on: {device}")
print(f"Number of parameters: {num_params}")

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, then for every ``(images, masks)`` batch
    moves both to ``device``, computes ``criterion(model(images), masks)``,
    back-propagates and steps ``optimizer``.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0

    with tqdm(dataloader, desc='Training') as progress:
        for batch_images, batch_masks in progress:
            # Move the batch to the target device
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            # Clear gradients left over from the previous step
            optimizer.zero_grad()

            # Forward pass and loss
            batch_loss = criterion(model(batch_images), batch_masks)

            # Backward pass and parameter update
            batch_loss.backward()
            optimizer.step()

            # Accumulate and show the current batch loss on the progress bar
            loss_value = batch_loss.item()
            running_loss += loss_value
            progress.set_postfix({'loss': f'{loss_value:.4f}'})

    return running_loss / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Puts ``model`` in eval mode and, for every ``(images, masks)`` batch moved
    to ``device``, evaluates ``criterion(model(images), masks)``.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0

    # Gradient tracking is disabled for the whole pass
    with torch.no_grad(), tqdm(dataloader, desc='Validation') as progress:
        for batch_images, batch_masks in progress:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)

            batch_loss = criterion(model(batch_images), batch_masks)

            # Accumulate and show the current batch loss on the progress bar
            loss_value = batch_loss.item()
            running_loss += loss_value
            progress.set_postfix({'loss': f'{loss_value:.4f}'})

    return running_loss / len(dataloader)

# Training configuration
num_epochs = 100
batch_size = 8

# Datasets for the two splits (replace the placeholder path with the real data root)
train_dataset = XrayDataset('path/to/your/data', split='train')
val_dataset = XrayDataset('path/to/your/data', split='val')

# Both loaders share batch size and worker count; only training is shuffled
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Loss history and best-so-far tracker
best_val_loss = float('inf')
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')

    # One optimisation pass over the training data
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)

    # Evaluate on the validation split
    val_loss = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)

    print(f'Training Loss: {train_loss:.4f}')
    print(f'Validation Loss: {val_loss:.4f}')

    # Nothing to save unless the validation loss strictly improved
    if not (val_loss < best_val_loss):
        continue

    # New best: remember it and write a checkpoint
    best_val_loss = val_loss
    checkpoint = {
        'epoch': epoch,  # zero-based epoch index
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint, 'best_model.pth')

# Plot the per-epoch training and validation loss curves on one set of axes
fig, ax = plt.subplots(figsize=(10, 5))
for history, label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    ax.plot(history, label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()