# U-Net 3D Paper CT Segmentation

This notebook implements a 3D U-Net model for segmenting pelvis from CT scans.

## Setup and Imports

First, let's check our environment and dependencies.

In [1]:
!pip install SimpleITK
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from tqdm import tqdm
from google.colab import drive
from google.colab import files


# Print versions for reproducibility
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
print(f"SimpleITK version: {sitk.__version__}")
print(f"NumPy version: {np.__version__}")



ModuleNotFoundError: No module named 'google'

## 2. Data Exploration
Let's examine our dataset structure and visualize some samples.

In [None]:
# Define paths
drive.mount('/content/drive')
data_dir = '/content/drive/MyDrive/data'
BASE_DIR = '/content/drive/MyDrive/ct_segmentation'  # Create a dedicated project folder

# Create organized subdirectories
PATHS = {
    'data': f'{BASE_DIR}/data',
    'models': f'{BASE_DIR}/models',
    'checkpoints': f'{BASE_DIR}/checkpoints',
    'results': f'{BASE_DIR}/results',
    'logs': f'{BASE_DIR}/logs'
}

# Create all directories
for path in PATHS.values():
    os.makedirs(path, exist_ok=True)

# Update data paths
images_path = os.path.join(PATHS['data'], 'PENGWIN_CT_train_images')
labels_path = os.path.join(PATHS['data'], 'PENGWIN_CT_train_labels')

# Update model save paths
MODEL_SAVE_PATH = os.path.join(PATHS['models'], 'best_unet_model.pth')
CHECKPOINT_DIR = PATHS['checkpoints']

# List and count files
image_files = sorted([f for f in os.listdir(images_path) if f.endswith('.mha')])
label_files = sorted([f for f in os.listdir(labels_path) if f.endswith('.mha')])

Found 100 image files in ../data\PENGWIN_CT_train_images
Found 100 label files in ../data\PENGWIN_CT_train_labels

First 5 image files:
- 001.mha
- 002.mha
- 003.mha
- 004.mha
- 005.mha

Original image dimensions: (401, 512, 512)
Image spacing: (0.78125, 0.78125, 0.800000011920929)
Value range: [-1023.00, 2775.00] HU


## 3. Dataset Implementation

In [None]:
    def extract_patches(self, image, label):
        """Extract patches from image and label"""
        patches_img = []
        patches_label = []
        
        D, H, W = image.shape
        
        # Calculate steps for each dimension
        d_steps = range(0, D - self.patch_size[0] + 1, self.stride[0])
        h_steps = range(0, H - self.patch_size[1] + 1, self.stride[1])
        w_steps = range(0, W - self.patch_size[2] + 1, self.stride[2])
        
        # If image is smaller than patch size, pad it
        if D < self.patch_size[0]:
            d_steps = [0]
        if H < self.patch_size[1]:
            h_steps = [0]
        if W < self.patch_size[2]:
            w_steps = [0]
            
        for d in d_steps:
            for h in h_steps:
                for w in w_steps:
                    # Extract patches
                    d_end = min(d + self.patch_size[0], D)
                    h_end = min(h + self.patch_size[1], H)
                    w_end = min(w + self.patch_size[2], W)
                    
                    patch_img = image[d:d_end, h:h_end, w:w_end]
                    patch_label = label[d:d_end, h:h_end, w:w_end]
                    
                    # Pad if necessary
                    if patch_img.shape != self.patch_size:
                        pad_d = self.patch_size[0] - patch_img.shape[0]
                        pad_h = self.patch_size[1] - patch_img.shape[1]
                        pad_w = self.patch_size[2] - patch_img.shape[2]
                        
                        patch_img = np.pad(patch_img, 
                                         ((0, pad_d), (0, pad_h), (0, pad_w)), 
                                         mode='constant')
                        patch_label = np.pad(patch_label, 
                                           ((0, pad_d), (0, pad_h), (0, pad_w)), 
                                           mode='constant')
                    
                    patches_img.append(patch_img)
                    patches_label.append(patch_label)
        
        return np.array(patches_img), np.array(patches_label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image and label
        image = sitk.GetArrayFromImage(sitk.ReadImage(self.image_paths[idx])).astype(np.float32)
        label = sitk.GetArrayFromImage(sitk.ReadImage(self.label_paths[idx])).astype(np.float32)
        
        # Preprocessing
        image = np.clip(image, -1000, 1000)
        image = (image + 1000) / 2000
        label = (label > 0).astype(np.float32)
        
        # Extract patches
        patches_img, patches_label = self.extract_patches(image, label)
        
        # Randomly select one patch during training
        patch_idx = np.random.randint(len(patches_img))
        
        # Add channel dimension
        image_patch = np.expand_dims(patches_img[patch_idx], axis=0)
        label_patch = np.expand_dims(patches_label[patch_idx], axis=0)
        
        return torch.tensor(image_patch), torch.tensor(label_patch)

In [None]:
def validate(model, val_image_path, val_label_path, patch_size=(128, 128, 128)):
    """
    Validate model performance on a single validation image
    Returns Dice score
    """
    model.eval()
    device = next(model.parameters()).device
    
    # Load validation image and label
    val_image = sitk.GetArrayFromImage(sitk.ReadImage(val_image_path)).astype(np.float32)
    val_label = sitk.GetArrayFromImage(sitk.ReadImage(val_label_path)).astype(np.float32)
    
    # Preprocess
    val_image = np.clip(val_image, -1000, 1000)
    val_image = (val_image + 1000) / 2000
    val_label = (val_label > 0).astype(np.float32)
    
    # Predict using sliding window
    prediction = predict_volume(model, val_image, patch_size=patch_size)
    
    # Apply sigmoid and threshold
    prediction = (prediction > 0.5).astype(np.float32)
    
    # Calculate Dice score
    intersection = np.sum(prediction * val_label)
    dice_score = (2. * intersection) / (np.sum(prediction) + np.sum(val_label) + 1e-7)
    
    return dice_score

def predict_volume(model, image, patch_size=(128, 128, 128), stride=(64, 64, 64)):
    """Predict segmentation using sliding window approach"""
    model.eval()
    device = next(model.parameters()).device
    
    D, H, W = image.shape
    output = np.zeros_like(image, dtype=np.float32)
    weight = np.zeros_like(image, dtype=np.float32)
    
    # Calculate steps
    d_steps = range(0, D - patch_size[0] + 1, stride[0])
    h_steps = range(0, H - patch_size[1] + 1, stride[1])
    w_steps = range(0, W - patch_size[2] + 1, stride[2])
    
    # Handle edge cases
    if D < patch_size[0]:
        d_steps = [0]
    if H < patch_size[1]:
        h_steps = [0]
    if W < patch_size[2]:
        w_steps = [0]
    
    with torch.no_grad():
        for d in d_steps:
            for h in h_steps:
                for w in w_steps:
                    # Extract patch
                    d_end = min(d + patch_size[0], D)
                    h_end = min(h + patch_size[1], H)
                    w_end = min(w + patch_size[2], W)
                    
                    patch = image[d:d_end, h:h_end, w:w_end]
                    
                    # Pad if necessary
                    if patch.shape != patch_size:
                        pad_d = patch_size[0] - patch.shape[0]
                        pad_h = patch_size[1] - patch.shape[1]
                        pad_w = patch_size[2] - patch.shape[2]
                        
                        patch = np.pad(patch, 
                                     ((0, pad_d), (0, pad_h), (0, pad_w)), 
                                     mode='constant')
                    
                    # Predict
                    patch = torch.tensor(patch).float().unsqueeze(0).unsqueeze(0).to(device)
                    pred = torch.sigmoid(model(patch)).cpu().numpy()[0, 0]
                    
                    # Unpad if necessary
                    if pad_d > 0 or pad_h > 0 or pad_w > 0:
                        pred = pred[:d_end-d, :h_end-h, :w_end-w]
                    
                    # Add to output with weight
                    output[d:d_end, h:h_end, w:w_end] += pred
                    weight[d:d_end, h:h_end, w:w_end] += 1
    
    # Average overlapping predictions
    output = np.divide(output, weight, where=weight!=0)
    return output

In [None]:
class ShapeModelEstimator:
    def __init__(self, n_components=10):
        self.n_components = n_components
        self.pca = None
        self.mean_shape = None
        
    def fit(self, training_masks):
        """训练形状模型"""
        # 将掩码转换为距离图
        distance_maps = []
        for mask in training_masks:
            dist_map = self._compute_distance_map(mask)
            distance_maps.append(dist_map.flatten())
            
        # 执行PCA
        self.pca = PCA(n_components=self.n_components)
        self.pca.fit(distance_maps)
        self.mean_shape = self.pca.mean_.reshape(mask.shape)
        
    def estimate_shape(self, segmentation, alpha=1.0, beta=1.0, gamma=0.1):
        """使用level set方法估计形状"""
        phi = self._compute_distance_map(segmentation)
        
        for _ in range(50):  # 迭代优化
            # 计算图像力
            F = self._compute_image_force(segmentation)
            
            # 计算形状模型力
            shape_force = self._compute_shape_force(phi)
            
            # 计算曲率
            curvature = self._compute_curvature(phi)
            
            # 更新level set
            dphi = (alpha * F + beta * shape_force + gamma * curvature)
            phi += dphi
            
        return phi
    
    def _compute_distance_map(self, binary_mask):
        """计算二值掩码的距离图"""
        return ndimage.distance_transform_edt(binary_mask) - \
               ndimage.distance_transform_edt(1 - binary_mask)
    
    def _compute_image_force(self, seg):
        """计算图像力"""
        return ndimage.gaussian_gradient_magnitude(seg, sigma=1)
    
    def _compute_shape_force(self, phi):
        """计算形状模型力"""
        # 投影到PCA空间并重建
        flat_phi = phi.flatten()
        projection = self.pca.transform([flat_phi])[0]
        reconstruction = self.pca.inverse_transform([projection])[0]
        return reconstruction.reshape(phi.shape) - phi
    
    def _compute_curvature(self, phi):
        """计算曲率"""
        return ndimage.laplace(phi)

In [None]:
class CTScanDataset(Dataset):
    def __init__(self, images_path, labels_path, patch_size=(128, 128, 72), 
                 shape_context=None, transform=True):
        super().__init__()
        self.patch_size = patch_size
        self.transform = transform
        self.shape_context = shape_context
        
        # 加载数据
        self.images = self._load_images(images_path)
        self.labels = self._load_labels(labels_path)
        
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        # 获取patch
        image_patch, label_patch = self._get_patch(image, label)
        
        # 数据增强
        if self.transform:
            image_patch, label_patch = self._augment(image_patch, label_patch)
        
        # 添加形状上下文通道
        if self.shape_context is not None:
            shape_context_patch = self._get_shape_context_patch(idx)
            image_patch = np.concatenate([image_patch, shape_context_patch], axis=0)
        else:
            # 添加空白形状上下文
            blank_context = np.zeros((2,) + image_patch.shape[1:], dtype=np.float32)
            image_patch = np.concatenate([image_patch, blank_context], axis=0)
        
        return torch.from_numpy(image_patch), torch.from_numpy(label_patch)
    
    def _augment(self, image, label):
        """实现数据增强"""
        # 随机平移
        shift = np.random.randint(-20, 20, size=3)
        image = ndimage.shift(image, shift, mode='constant')
        label = ndimage.shift(label, shift, mode='constant')
        
        # 随机旋转
        angle = np.random.randint(-15, 15)
        image = ndimage.rotate(image, angle, axes=(1,2), mode='constant')
        label = ndimage.rotate(label, angle, axes=(1,2), mode='constant')
        
        # 随机缩放
        scale = np.random.uniform(0.9, 1.1)
        image = ndimage.zoom(image, scale, mode='constant')
        label = ndimage.zoom(label, scale, mode='constant')
        
        return image, label

## 4. Model Architecture Visualization

In [None]:
class UNet3D(nn.Module):
    def __init__(self, in_channels=3, out_channels=6):  # 3 input channels (1 CT + 2 shape), 6 classes
        super(UNet3D, self).__init__()
        
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(inplace=True)
            )
        
        # 3 max pooling layers as mentioned in paper
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        
        # Bottleneck
        self.bottleneck = conv_block(256, 512)
        
        # 3 up-sampling layers as mentioned in paper
        self.upconv3 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)  # 512 due to skip connection
        self.upconv2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)  # 256 due to skip connection
        self.upconv1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)   # 128 due to skip connection
        
        self.final = nn.Conv3d(64, out_channels, kernel_size=1)
        
    def forward(self, x):
        # Encoding
        e1 = self.encoder1(x)
        e2 = self.encoder2(self.pool(e1))
        e3 = self.encoder3(self.pool(e2))
        
        # Bottleneck
        b = self.bottleneck(self.pool(e3))
        
        # Decoding with skip connections
        d3 = self.upconv3(b)
        d3 = self.decoder3(torch.cat([d3, e3], dim=1))
        d2 = self.upconv2(d3)
        d2 = self.decoder2(torch.cat([d2, e2], dim=1))
        d1 = self.upconv1(d2)
        d1 = self.decoder1(torch.cat([d1, e1], dim=1))
        
        return self.final(d1)

show the parameter of the model

In [None]:

    def print_model_summary(model):
        """Print model architecture summary"""
        print("Model Architecture:")
        print("------------------")
        total_params = 0
        for name, param in model.named_parameters():
            param_count = param.numel()
            total_params += param_count
            print(f"{name}: {list(param.shape)} ({param_count:,} parameters)")
        print(f"\nTotal parameters: {total_params:,}")

    # Initialize model and print summary
    model = UNet3D(in_channels=1, out_channels=1)
    print_model_summary(model)

Model Architecture:
------------------
encoder1.0.weight: [64, 1, 3, 3, 3] (1,728 parameters)
encoder1.0.bias: [64] (64 parameters)
encoder1.2.weight: [64, 64, 3, 3, 3] (110,592 parameters)
encoder1.2.bias: [64] (64 parameters)
encoder2.0.weight: [128, 64, 3, 3, 3] (221,184 parameters)
encoder2.0.bias: [128] (128 parameters)
encoder2.2.weight: [128, 128, 3, 3, 3] (442,368 parameters)
encoder2.2.bias: [128] (128 parameters)
encoder3.0.weight: [256, 128, 3, 3, 3] (884,736 parameters)
encoder3.0.bias: [256] (256 parameters)
encoder3.2.weight: [256, 256, 3, 3, 3] (1,769,472 parameters)
encoder3.2.bias: [256] (256 parameters)
encoder4.0.weight: [512, 256, 3, 3, 3] (3,538,944 parameters)
encoder4.0.bias: [512] (512 parameters)
encoder4.2.weight: [512, 512, 3, 3, 3] (7,077,888 parameters)
encoder4.2.bias: [512] (512 parameters)
bottleneck.0.weight: [1024, 512, 3, 3, 3] (14,155,776 parameters)
bottleneck.0.bias: [1024] (1,024 parameters)
bottleneck.2.weight: [1024, 1024, 3, 3, 3] (28,311,552 p

## 5. Training Configuration and Progress Tracking

In [None]:

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Use paths from PATHS dictionary
images_path = os.path.join(PATHS['data'], 'PENGWIN_CT_train_images')
labels_path = os.path.join(PATHS['data'], 'PENGWIN_CT_train_labels')

# Hyperparameters
LEARNING_RATE = 0.001
NUM_EPOCHS = 3
BATCH_SIZE = 4

# Dataset and DataLoader setup
train_transform = transforms.Compose([
    transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.float32)),
    transforms.Normalize((0.5,), (0.5,))
])

# Initialize dataset with patch-based approach
train_dataset = CTScanDataset(
    images_path=images_path,
    labels_path=labels_path,
    patch_size=(128, 128, 128),
    stride=(64, 64, 64)
)

# Rest of your training code remains the same
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model initialization
model = UNet3D(in_channels=1, out_channels=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

class TrainingMonitor:
    def __init__(self):
        self.train_losses = []
        self.current_epoch = 0
        
    def update(self, epoch_loss):
        self.train_losses.append(epoch_loss)
        self.current_epoch += 1
        
    def plot_progress(self):
        plt.figure(figsize=(10, 5))
        plt.plot(self.train_losses, label='Training Loss')
        plt.title('Training Progress')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        plt.show()
        
    def print_stats(self):
        print(f"Current epoch: {self.current_epoch}")
        print(f"Best loss: {min(self.train_losses):.4f}")
        print(f"Current loss: {self.train_losses[-1]:.4f}")

# Initialize training monitor
monitor = TrainingMonitor()

Using device: cuda


## 5. Training Loop

In [None]:
# Update save paths at the start
best_model_dice_path = os.path.join(PATHS['models'], 'best_model_dice.pth')
best_model_loss_path = os.path.join(PATHS['models'], 'best_model_loss.pth')

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """两步训练过程"""
    # Step 1: 使用空白形状上下文训练
    print("Step 1: Training with blank shape context...")
    for epoch in range(100):  # 论文中提到100个epoch
        train_epoch(model, train_loader, criterion, optimizer, device)
        validate_model(model, val_loader, device)
        
    # 创建形状模型
    shape_estimator = ShapeModelEstimator()
    initial_predictions = get_predictions(model, train_loader, device)
    shape_estimator.fit(initial_predictions)
    
    # Step 2: 使用形状上下文再训练
    print("Step 2: Training with shape context...")
    for epoch in range(40):  # 论文中提到40个epoch
        # 更新形状上下文
        predictions = get_predictions(model, train_loader, device)
        shape_contexts = shape_estimator.estimate_shape(predictions)
        
        # 更新数据集的形状上下文
        train_loader.dataset.shape_context = shape_contexts
        
        # 训练
        train_epoch(model, train_loader, criterion, optimizer, device)
        # 在每个epoch后进行验证
        validate_model(model, val_loader, device)

Epoch 1/3:  92%|█████████▏| 46/50 [44:26<04:32, 68.08s/it, loss=0.318]  

In [None]:
def test_model(model, test_loader, shape_estimator, device, max_iterations=3):
    """迭代测试过程"""
    model.eval()
    results = []
    
    with torch.no_grad():
        for images, _ in test_loader:
            # 初始预测（使用空白形状上下文）
            current_pred = model(images.to(device))
            
            # 迭代改进
            for _ in range(max_iterations):
                # 估计形状
                shape_context = shape_estimator.estimate_shape(current_pred.cpu().numpy())
                
                # 将形状上下文添加到输入
                shape_input = torch.from_numpy(shape_context).to(device)
                combined_input = torch.cat([images, shape_input], dim=1)
                
                # 重新预测
                new_pred = model(combined_input)
                
                # 检查收敛
                if torch.abs(new_pred - current_pred).mean() < 1e-4:
                    break
                    
                current_pred = new_pred
            
            results.append(current_pred)
    
    return results

## 6. Visualize Training Results

In [None]:
class TrainingMonitor:
    def __init__(self):
        self.train_losses = []
        self.val_dices = []
        self.current_epoch = 0
        self.plot_dir = PATHS['results']
        
    def plot_progress(self):
        plt.figure(figsize=(15, 5))
        
        # Plot training loss
        plt.subplot(1, 2, 1)
        plt.plot(self.train_losses, label='Training Loss')
        plt.title('Training Loss Over Time')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        
        # Plot validation dice if available
        if self.val_dices:
            plt.subplot(1, 2, 2)
            plt.plot(range(0, len(self.val_dices)*5, 5), self.val_dices, label='Validation Dice')
            plt.title('Validation Dice Over Time')
            plt.xlabel('Epoch')
            plt.ylabel('Dice Score')
            plt.grid(True)
            plt.legend()
        
        plt.tight_layout()
        
        # Save plot
        plot_path = os.path.join(self.plot_dir, f'training_progress_epoch_{self.current_epoch:03d}.png')
        plt.savefig(plot_path)
        plt.show()
        plt.close()