#  Comparison of Segmentation Methods on X-ray Microtomography Data

This notebook evaluates and compares three segmentation approaches applied to X-ray Microtomography (Micro-CT) images:

1. **Pre-trained U-Net** – A convolutional encoder–decoder widely used in biomedical segmentation.  
2. **Vision Transformer (ViT)** from *timm* – A transformer-based model adapted for image segmentation.  
3. **Cross-Teaching Ensemble** – A hybrid method where U-Net and ViT exchange pseudo-labels during training and produce a combined prediction.

The goal is to benchmark these models on the same dataset, assess their segmentation quality using Dice score, IoU, precision, and recall, and visualize qualitative differences through sample predictions. This provides insight into how classical CNN-based, transformer-based, and ensemble methods perform on Micro-CT segmentation tasks.


Imports

In [10]:
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset
import timm
from skimage.filters import threshold_otsu
from sklearn.metrics import confusion_matrix
from sklearn.metrics import jaccard_score

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

Load and Visualize Data

In [11]:

# ============================================================
#   Helper: Read image as grayscale float32 numpy array
# ============================================================
def load_image(path):
    """Read an image file and return it as a grayscale float32 array.

    Args:
        path: Path of the image file to open with PIL.

    Returns:
        A numpy float32 array holding the image after conversion to PIL
        mode "L" (single-channel, 8-bit grayscale).
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been converted, instead of waiting for garbage collection.
    with Image.open(path) as img:
        gray = img.convert("L")
    return np.asarray(gray, dtype=np.float32)


# ============================================================
#   Evaluation-only Dataset for X-ray Microtomography Segmentation
# ============================================================
class XrayMicroCTDataset(Dataset):
    """
    (This is reformatted from the Unet_TransferLearn.py script)
    Evaluation-only dataset:
        ✓ No augmentations
        ✓ No random operations
        ✓ Strictly deterministic
        ✓ Otsu thresholding for masks (optional)
        ✓ Fixed normalization (min-max or z-score)
        ✓ Returns image + mask as torch tensors

    Suitable for inference, visualization, and model comparison.

    Images and masks are paired by position after sorting the file paths of
    each directory, so both directories must sort into the same order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks, or None for image-only use.
        resize_to: Side length (pixels) of the square every image/mask is
            resized to.
        normalize: "minmax", "zscore", or None. Any other value leaves the
            image unchanged.
        use_otsu: If True, binarise masks with an Otsu threshold; otherwise
            use a fixed threshold of 0.5.

    Raises:
        AssertionError: If no image files are found, or if a mask directory
            is given and its file count differs from the image count.
    """

    def __init__(
        self,
        image_dir,
        mask_dir=None,
        resize_to=512,
        normalize="minmax",   # "minmax", "zscore", or None
        use_otsu=True,
    ):
        self.image_paths = self._list_files(image_dir)
        self.mask_paths = self._list_files(mask_dir) if mask_dir else None

        self.resize_to = resize_to
        self.normalize = normalize
        self.use_otsu = use_otsu

        # Raised explicitly (rather than via `assert`) so the checks survive
        # `python -O`; the exception type is kept for backward compatibility.
        if not self.image_paths:
            raise AssertionError("No images found!")
        # Compare against None, not truthiness: an empty mask directory must
        # fail here instead of later with an IndexError in __getitem__.
        if self.mask_paths is not None and len(self.image_paths) != len(self.mask_paths):
            raise AssertionError("Image and mask count mismatch.")

    @staticmethod
    def _list_files(directory):
        """Return the sorted regular files directly inside *directory*."""
        # Sub-directories also match "*" but cannot be opened as images.
        return sorted(
            path for path in glob(os.path.join(directory, "*"))
            if os.path.isfile(path)
        )

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def _resize(self, arr):
        """Resize a 2-D array to (resize_to, resize_to) with bilinear interpolation."""
        img = Image.fromarray(arr.astype(np.float32))
        img = img.resize((self.resize_to, self.resize_to), Image.BILINEAR)
        return np.array(img).astype(np.float32)

    def _normalize(self, arr):
        """Apply the configured intensity normalisation to *arr*."""
        if self.normalize == "minmax":
            mn, mx = arr.min(), arr.max()
            if mx > mn:
                return (arr - mn) / (mx - mn)
            # Degenerate (constant) image: no range to scale by.
            return np.zeros_like(arr)

        if self.normalize == "zscore":
            # The small constant guards against division by zero.
            return (arr - arr.mean()) / (arr.std() + 1e-6)

        # None or an unrecognised mode: return the data untouched.
        return arr

    def __getitem__(self, idx):
        """Return ``(image, mask, filename)`` for sample *idx*.

        ``image`` is a (1, H, W) tensor, ``mask`` is a binary (1, H, W)
        tensor or None when the dataset has no masks, and ``filename`` is the
        base name of the image file.
        """
        # -----------------------------
        # Load image
        # -----------------------------
        image_path = self.image_paths[idx]
        img = self._normalize(self._resize(load_image(image_path)))

        # Convert to tensor: (1, H, W)
        img_t = torch.from_numpy(img).unsqueeze(0)

        # -----------------------------
        # Load mask (if exists)
        # -----------------------------
        mask_t = None
        if self.mask_paths is not None:
            mask = self._resize(load_image(self.mask_paths[idx]))

            # Binarise after resizing so interpolated edge values are
            # snapped back to {0, 1}.
            if self.use_otsu:
                threshold = threshold_otsu(mask)
            else:
                # NOTE(review): masks come from 8-bit images, so 0.5 treats
                # any non-zero value as foreground — confirm this is intended.
                threshold = 0.5
            mask = (mask > threshold).astype(np.float32)

            mask_t = torch.from_numpy(mask).unsqueeze(0)

        return img_t, mask_t, os.path.basename(image_path)

In [12]:
# ============================================================
#   Visualization Helpers
# ============================================================

def show_image(img_tensor, title="Image"):
    """Display a single image tensor as a grayscale figure.

    Args:
        img_tensor: Tensor to show; singleton dimensions are squeezed away
            and the data is copied to the CPU before plotting.
        title: Title drawn above the image.
    """
    pixels = img_tensor.squeeze().cpu().numpy()

    plt.figure(figsize=(4, 4))
    plt.imshow(pixels, cmap="gray")
    plt.title(title)
    plt.axis("off")
    plt.show()


def show_image_and_mask(img, mask, filename=""):
    """Plot an image and its binary mask side by side.

    Args:
        img: Image tensor (squeezed and moved to the CPU before plotting).
        mask: Mask tensor (squeezed and moved to the CPU before plotting).
        filename: Text used as the figure's overall title.
    """
    panels = (
        ("Image", img.squeeze().cpu().numpy()),
        ("Mask (binary)", mask.squeeze().cpu().numpy()),
    )

    plt.figure(figsize=(10, 4))
    plt.suptitle(filename)

    for column, (label, pixels) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        plt.imshow(pixels, cmap="gray")
        plt.title(label)
        plt.axis("off")

    plt.show()


def compare_predictions(img, mask, unet_pred, vit_pred, ensemble_pred, filename=""):
    """Show input, ground truth and the three model outputs in one row.

    Every argument except *filename* is a tensor; each one is squeezed and
    copied to the CPU before being drawn as a grayscale panel.
    """
    panels = [
        ("Input Image", img),
        ("Ground Truth Mask", mask),
        ("U-Net Prediction", unet_pred),
        ("ViT Prediction", vit_pred),
        ("Ensemble Output", ensemble_pred),
    ]
    # Convert everything up front so a bad tensor fails before any figure
    # is created.
    rendered = [(label, tensor.squeeze().cpu().numpy()) for label, tensor in panels]

    plt.figure(figsize=(16, 6))
    plt.suptitle(f"Segmentation Comparison – {filename}", fontsize=14)

    for column, (label, pixels) in enumerate(rendered, start=1):
        plt.subplot(1, 5, column)
        plt.imshow(pixels, cmap="gray")
        plt.title(label)
        plt.axis("off")

    plt.show()


In [13]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

IMAGE_DIR = "/path/to/labeled/images"           # ⚠️ UPDATE ME
MASK_DIR  = "/path/to/labeled/masks"            # ⚠️ UPDATE ME

# Create dataset
dataset = XrayMicroCTDataset(
    image_dir=IMAGE_DIR,
    mask_dir=MASK_DIR,
    resize_to=512,
    normalize="minmax",
    use_otsu=True
)

# Keep one batched sample around for quick single-image experiments.
# Each dataset item is (image, mask, filename); the image gets a leading
# batch dimension so it can be fed straight to the models.
sample_img, sample_mask, sample_name = dataset[0]
sample_img = sample_img.unsqueeze(0).to(DEVICE)

AssertionError: No images found!

In [None]:
# Sanity check: show the first image next to its binarised mask.
first_sample = dataset[0]
img, mask, fname = first_sample
show_image_and_mask(img, mask, filename=fname)

Load in Models and Model Architectures

In [None]:
# -----------------------------
# Vision Transformer Segmentation Head
# -----------------------------
class ViTSegmentationHead(nn.Module):
    """Per-pixel classification head applied to a 2-D feature map.

    Normalises the channel vector of every pixel with LayerNorm, projects it
    to ``num_classes`` channels with a 1x1 convolution and applies a sigmoid.

    Args:
        embed_dim: Number of channels of the incoming feature map.
        num_classes: Number of output channels.
    """

    def __init__(self, embed_dim=1024, num_classes=1):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.conv = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features):
        """Map features of shape (B, embed_dim, H, W) to (B, num_classes, H, W)."""
        # LayerNorm normalises over the LAST dimension, so the channel axis
        # has to be moved there and back; applying it directly to a
        # (B, C, H, W) tensor would normalise over W and fail unless W == C.
        channels_last = features.permute(0, 2, 3, 1)
        normed = self.norm(channels_last).permute(0, 3, 1, 2)
        return torch.sigmoid(self.conv(normed))

# -----------------------------
# Vision Transformer Segmentation Model
# -----------------------------
class ViTSegmentation(nn.Module):
    """Frozen timm ViT backbone with a small transposed-conv decoder.

    The backbone's patch tokens are reshaped into a square feature map,
    upsampled 8x by three stride-2 transposed convolutions and passed to a
    ``ViTSegmentationHead``.

    Args:
        num_classes: Number of output channels of the segmentation head.
        img_size: Input side length handed to the timm model.
    """

    def __init__(self, num_classes=1, img_size=224):
        super().__init__()

        # Only `timm` itself is imported at file level, so the factory must
        # be reached through the module.
        self.vit = timm.create_model(
            'vit_large_patch16_224',
            pretrained=True,
            img_size=img_size,
            in_chans=1,
            num_classes=0
        )
        # The backbone is used as a fixed feature extractor.
        for param in self.vit.parameters():
            param.requires_grad = False

        embed_dim = self.vit.embed_dim
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, 512, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, embed_dim, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.seg_head = ViTSegmentationHead(embed_dim=embed_dim, num_classes=num_classes)

    @staticmethod
    def _tokens_to_feature_map(tokens, num_prefix_tokens=1):
        """Reshape (B, N, C) tokens into a square (B, C, side, side) map.

        Leading prefix tokens (e.g. a class token) are dropped first when
        that leaves a perfect-square number of patch tokens; otherwise the
        sequence is used as-is if it is already a perfect square.

        Raises:
            ValueError: If *tokens* is not 3-D or no square grid fits.
        """
        if tokens.dim() != 3:
            raise ValueError(
                f"Expected tokens of shape (B, N, C), got {tuple(tokens.shape)}"
            )
        batch, n_tokens, channels = tokens.shape
        for n_patches in (n_tokens - num_prefix_tokens, n_tokens):
            if n_patches <= 0:
                continue
            side = int(round(n_patches ** 0.5))
            if side * side == n_patches:
                patches = tokens[:, n_tokens - n_patches:]
                return patches.transpose(1, 2).reshape(batch, channels, side, side)
        raise ValueError(
            f"Cannot arrange {n_tokens} tokens "
            f"({num_prefix_tokens} prefix) into a square feature map"
        )

    def forward(self, x):
        """Return a sigmoid segmentation map for the image batch *x*."""
        vit_out = self.vit.forward_features(x)
        # Without dropping the prefix tokens the sequence length is not a
        # perfect square and the reshape below would fail.
        # NOTE(review): assumes the backbone exposes `num_prefix_tokens`;
        # falls back to a single class token otherwise — confirm for the
        # installed timm version.
        prefix = getattr(self.vit, "num_prefix_tokens", 1)
        feat_2d = self._tokens_to_feature_map(vit_out, prefix)
        upsampled = self.decoder(feat_2d)
        return self.seg_head(upsampled)

    def load_vit_weights(self, path, allow_pickle=False):
        """Load weights from a ``.npz`` archive into this model.

        Two layouts are accepted: an archive whose entry names are parameter
        names, or an archive with a single ``params`` entry holding a dict of
        arrays (which is stored pickled and therefore needs
        ``allow_pickle=True``).

        Args:
            path: Path to the ``.npz`` file.
            allow_pickle: Forwarded to ``numpy.load``. Unpickling can execute
                arbitrary code, so enable it only for files you trust.

        Raises:
            ValueError: If ``params`` is not a dict of arrays (numpy itself
                raises ValueError for pickled data when ``allow_pickle`` is
                False).
        """
        with np.load(path, allow_pickle=allow_pickle) as archive:
            if "params" in archive.files:
                params = archive["params"]
                if params.dtype == object and params.ndim == 0:
                    params = params.item()
                if not isinstance(params, dict):
                    raise ValueError(
                        f"'params' in {path} is not a dict of named arrays"
                    )
                raw = dict(params)
            else:
                raw = {name: archive[name] for name in archive.files}

        # load_state_dict needs a mapping of tensors, not numpy arrays.
        state_dict = {
            name: value if isinstance(value, torch.Tensor)
            else torch.as_tensor(np.asarray(value))
            for name, value in raw.items()
        }
        result = self.load_state_dict(state_dict, strict=False)
        print(f"Loaded custom ViT weights from {path}")
        # strict=False hides name mismatches, so report them explicitly.
        print(
            f"  missing keys: {len(result.missing_keys)}, "
            f"unexpected keys: {len(result.unexpected_keys)}"
        )


# -----------------------------
# U-Net Loader Function
# -----------------------------
def load_unet_model(unet_path, device="cuda"):
    """Build the project U-Net and load weights from a checkpoint.

    Args:
        unet_path: Path to a checkpoint readable by ``torch.load``. It may be
            a bare state dict or a dict wrapping one under ``state_dict`` or
            ``model_state_dict``.
        device: Device the model is created on and the checkpoint mapped to.

    Returns:
        The U-Net in eval mode on *device*.
    """
    # Imported locally because the file does not import `sys` at the top.
    import sys

    # Avoid piling up duplicate entries when the cell is re-run.
    if ".." not in sys.path:
        sys.path.append("..")  # ⚠️ MAY NEED UPDATE
    from unet_pytorch import create_unet_for_porosity  # ⚠️ CHECK PATH

    model = create_unet_for_porosity().to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(unet_path, map_location=device)

    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break
    model.load_state_dict(state_dict)

    model.eval()
    print(f"Loaded U-Net model from {unet_path}")
    return model


# -----------------------------
# Cross Teaching Trainer (for ensemble inference)
# -----------------------------
class CrossTeachingTrainer:
    """Inference-time ensemble that averages a U-Net and a ViT prediction.

    Both models are switched to eval mode on construction. The models are
    NOT moved to *device*; the caller is expected to have placed them there.

    Args:
        unet: U-Net style module taking an image batch.
        vit: ViT style module taking an image batch.
        device: Device the input images are moved to before prediction.
        unet_size: (H, W) the images are resized to for the U-Net.
        vit_size: (H, W) the images are resized to for the ViT.
    """

    def __init__(self, unet, vit, device="cuda",
                 unet_size=(512, 512), vit_size=(224, 224)):
        self.unet = unet.eval()
        self.vit = vit.eval()
        self.device = device
        self.unet_size = tuple(unet_size)
        self.vit_size = tuple(vit_size)

    def eval(self):
        """Put both wrapped models in eval mode and return ``self``.

        Mirrors ``nn.Module.eval`` so the ensemble can be handled like a
        model by code that calls ``.eval()`` on everything it evaluates.
        """
        self.unet.eval()
        self.vit.eval()
        return self

    @torch.no_grad()
    def ensemble_predict(self, images):
        """Return the mean of both model outputs at the input resolution.

        Args:
            images: Image batch with spatial dimensions last.

        Returns:
            Tensor with the spatial size of *images*, the element-wise
            average of the two resized predictions.
        """
        images = images.to(self.device)
        out_size = images.shape[-2:]

        def resize(tensor, size):
            return F.interpolate(tensor, size=size, mode="bilinear", align_corners=False)

        # Each model sees the input at its own working resolution.
        unet_pred = self.unet(resize(images, self.unet_size))
        vit_pred = self.vit(resize(images, self.vit_size))

        # NOTE(review): averaging assumes both models output values on the
        # same scale (e.g. probabilities) — confirm for the U-Net.
        return (resize(unet_pred, out_size) + resize(vit_pred, out_size)) / 2

# Metric for segmentation performance
def dice_score(pred, target, eps=1e-6):
    """Hard Dice coefficient between two tensors.

    Both inputs are binarised at 0.5 before the overlap is measured.

    Args:
        pred: Predicted values (tensor).
        target: Reference values (tensor).
        eps: Constant added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor ``2 * |P ∩ T| / (|P| + |T| + eps)``.
    """
    hard_pred = (pred > 0.5).float()
    hard_target = (target > 0.5).float()

    overlap = (hard_pred * hard_target).sum()
    numerator = 2 * overlap
    denominator = hard_pred.sum() + hard_target.sum() + eps
    return numerator / denominator

In [None]:
UNET_PATH = "/path/to/unet_checkpoint.pth"     # ⚠️ UPDATE ME
VIT_PRETRAIN = "/path/to/vit_weights.npz"       # ⚠️ UPDATE ME
XUNET_PATH = "/path/to/Xunet_checkpoint.pth"     # ⚠️ UPDATE ME
XVIT_PRETRAIN = "/path/to/vit_weights.npz"       #⚠️ UPDATE ME

# Baseline models (evaluated individually)
unet = load_unet_model(UNET_PATH, device=DEVICE)
vit = ViTSegmentation().to(DEVICE)

# Cross-taught counterparts (combined in the ensemble)
Xunet = load_unet_model(XUNET_PATH, device=DEVICE)
Xvit = ViTSegmentation().to(DEVICE)

# Load the custom ViT weights for both ViT models
vit.load_vit_weights(VIT_PRETRAIN)
Xvit.load_vit_weights(XVIT_PRETRAIN)

# load_unet_model already returns the U-Nets in eval mode; do the same for
# the ViTs so standalone predictions never run in training mode.
vit.eval()
Xvit.eval()

# Ensemble
# NOTE(review): built from the cross-taught pair, which was otherwise loaded
# but never used — confirm this is the intended ensemble.
ensemble = CrossTeachingTrainer(Xunet, Xvit, device=DEVICE)

Inference Helpers

In [None]:
@torch.no_grad()
def predict_unet(unet, image_tensor, device="cuda", size=(512, 512)):
    """Run the U-Net on a batch and return its output at the input resolution.

    Args:
        unet: Model to call.
        image_tensor: Image batch with spatial dimensions last; it must
            already be on the same device as the model.
        device: Accepted for call-site compatibility; not used.
        size: (H, W) the batch is resized to before the forward pass.

    Returns:
        The model output, bilinearly resized to the spatial size of
        *image_tensor*.
    """
    original_size = image_tensor.shape[-2:]
    resized = F.interpolate(image_tensor, size=size,
                            mode="bilinear", align_corners=False)
    pred = unet(resized)
    return F.interpolate(pred, size=original_size,
                         mode="bilinear", align_corners=False)


@torch.no_grad()
def predict_vit(vit, image_tensor, device="cuda", size=(224, 224)):
    """Run the ViT on a batch and return its output at the input resolution.

    Args:
        vit: Model to call.
        image_tensor: Image batch with spatial dimensions last; it must
            already be on the same device as the model.
        device: Accepted for call-site compatibility; not used.
        size: (H, W) the batch is resized to before the forward pass.

    Returns:
        The model output, bilinearly resized to the spatial size of
        *image_tensor*.
    """
    original_size = image_tensor.shape[-2:]
    resized = F.interpolate(image_tensor, size=size,
                            mode="bilinear", align_corners=False)
    pred = vit(resized)
    return F.interpolate(pred, size=original_size,
                         mode="bilinear", align_corners=False)

Visualize

In [None]:
def visualize_predictions(image, mask, unet_pred, vit_pred, ensemble_pred):
    """Draw the input, ground truth and three predictions on a 2x3 grid.

    All arguments are tensors; each is squeezed and copied to the CPU before
    being shown as a grayscale panel. The sixth grid cell stays empty.
    """
    labelled = (
        ("Image", image),
        ("Ground Truth", mask),
        ("U-Net", unet_pred),
        ("ViT", vit_pred),
        ("Ensemble", ensemble_pred),
    )
    # Convert first so a bad tensor fails before any figure is opened.
    panels = [(label, tensor.squeeze().cpu().numpy()) for label, tensor in labelled]

    plt.figure(figsize=(15, 8))

    for slot, (label, pixels) in enumerate(panels, start=1):
        plt.subplot(2, 3, slot)
        plt.imshow(pixels, cmap="gray")
        plt.title(label)
        plt.axis("off")

    plt.tight_layout()
    plt.show()
idx = 0
# Each dataset item is (image, mask, filename).
image, mask, fname = dataset[idx]

image_tensor = image.unsqueeze(0).to(DEVICE)   # shape: (1, 1, H, W)
mask_tensor  = mask.unsqueeze(0).to(DEVICE)

# Get model predictions
unet_pred = predict_unet(unet, image_tensor, device=DEVICE)
vit_pred = predict_vit(vit, image_tensor, device=DEVICE)
ensemble_pred = ensemble.ensemble_predict(image_tensor)

# Visualize
visualize_predictions(
    image_tensor,
    mask_tensor,
    unet_pred,
    vit_pred,
    ensemble_pred
)

In [None]:
# Predictions of every model for the batched sample image.
pred_unet = predict_unet(unet, sample_img, device=DEVICE)
pred_vit = predict_vit(vit, sample_img, device=DEVICE)
pred_ensemble = ensemble.ensemble_predict(sample_img)

Evaluations

In [None]:
# -----------------------------
# Metric utilities
# -----------------------------

def dice_score(pred, target, eps=1e-6):
    """Compute Dice score for binary segmentation.

    The inputs are used as given (no thresholding), so they should already
    be binary masks.

    Args:
        pred: Predicted mask (array-like with ``flatten``).
        target: Ground-truth mask of the same size.
        eps: Constant added to the denominator to avoid division by zero.
            Kept as a keyword so the signature matches the earlier
            ``dice_score`` definition that this one replaces.

    Returns:
        ``2 * sum(pred * target) / (sum(pred) + sum(target) + eps)``.
    """
    pred = pred.flatten()
    target = target.flatten()

    intersection = (pred * target).sum()
    return (2 * intersection) / (pred.sum() + target.sum() + eps)


def iou_score(pred, target, eps=1e-6):
    """Compute Intersection-over-Union.

    The inputs are used as given (no thresholding), so they should already
    be binary masks.

    Args:
        pred: Predicted mask (array-like with ``flatten``).
        target: Ground-truth mask of the same size.
        eps: Constant added to the denominator to avoid division by zero.

    Returns:
        ``intersection / (union + eps)``.
    """
    pred = pred.flatten()
    target = target.flatten()

    intersection = (pred * target).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = pred.sum() + target.sum() - intersection
    return intersection / (union + eps)


def precision_recall(pred, target, eps=1e-6):
    """Compute precision and recall using TP, FP, FN.

    The inputs are used as given (no thresholding), so they should already
    be binary masks.

    Args:
        pred: Predicted mask (array-like with ``flatten``).
        target: Ground-truth mask of the same size.
        eps: Constant added to both denominators to avoid division by zero.

    Returns:
        Tuple ``(precision, recall)``.
    """
    pred = pred.flatten()
    target = target.flatten()

    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)

    return precision, recall


# -----------------------------
# Evaluation function
# -----------------------------

def evaluate_models(dataset, unet, vit, ensemble, device="cuda"):
    """
    Evaluate U-Net, ViT, and Ensemble segmentation models using
    consistent metrics (Dice, IoU, Precision, Recall)
    on the provided XrayMicroCTDataset.

    Args:
        dataset: Iterable yielding ``(image, mask, filename)``. Items may be
            single samples (image of shape (C, H, W)) or batches (image of
            shape (B, C, H, W)), e.g. from a DataLoader.
        unet: U-Net model, evaluated through ``predict_unet``.
        vit: ViT model, evaluated through ``predict_vit``.
        ensemble: Object exposing ``ensemble_predict(images)``.
        device: Device the images and masks are moved to.

    Returns:
        A pandas DataFrame with one row per model holding the per-image mean
        of each metric. The table is also shown (rich display in a notebook,
        plain text otherwise).

    Raises:
        ValueError: If a sample has no mask to evaluate against.
    """
    model_names = ("U-Net", "ViT", "Ensemble")
    metric_keys = ("dice", "iou", "prec", "rec")

    # Metric accumulators: one list of per-image values per model and metric.
    metrics = {name: {key: [] for key in metric_keys} for name in model_names}

    unet.eval()
    vit.eval()
    # The ensemble may be a plain wrapper object without an eval() method.
    ensemble_eval = getattr(ensemble, "eval", None)
    if callable(ensemble_eval):
        ensemble_eval()

    # -------------------------
    # Main evaluation loop
    # -------------------------
    with torch.no_grad():
        for img_t, mask_t, fname in dataset:
            if mask_t is None:
                raise ValueError(f"Sample {fname!r} has no mask; cannot evaluate.")

            # Move to device and make sure there is a batch dimension
            # (single samples are 3-D, DataLoader batches already 4-D).
            img = img_t.to(device)
            mask = mask_t.to(device)
            if img.dim() == 3:
                img = img.unsqueeze(0)
            if mask.dim() == 3:
                mask = mask.unsqueeze(0)

            # Ground truth (binary), converted to numpy once per item
            targets = (mask > 0.5).float().cpu().numpy()

            # ---- Predictions ----
            outputs = {
                "U-Net": predict_unet(unet, img, device=device),
                "ViT": predict_vit(vit, img, device=device),
                "Ensemble": ensemble.ensemble_predict(img),
            }

            # ---- Metrics (per image, also inside a batch) ----
            for name in model_names:
                preds = (outputs[name] > 0.5).float().cpu().numpy()
                for pred, target in zip(preds, targets):
                    precision, recall = precision_recall(pred, target)
                    metrics[name]["dice"].append(dice_score(pred, target))
                    metrics[name]["iou"].append(iou_score(pred, target))
                    metrics[name]["prec"].append(precision)
                    metrics[name]["rec"].append(recall)

    # -------------------------
    # Create summary table
    # -------------------------
    def column(key):
        return [np.mean(metrics[name][key]) for name in model_names]

    summary = pd.DataFrame({
        "Model": list(model_names),
        "Dice": column("dice"),
        "IoU": column("iou"),
        "Precision": column("prec"),
        "Recall": column("rec"),
    })

    print("\n=== Segmentation Performance Summary ===")
    # `display` only exists inside IPython/Jupyter; fall back to plain text
    # so the function also works when run as a script.
    try:
        show = display
    except NameError:
        show = None
    if show is not None:
        show(summary.style.set_caption("Model Comparison on X-ray Micro-CT Segmentation"))
    else:
        print(summary.to_string(index=False))

    return summary



In [None]:
# Evaluate with the objects created earlier in this notebook: the dataset,
# the two baseline models and the ensemble wrapper.
summary = evaluate_models(
    dataset,
    unet,
    vit,
    ensemble,
    device=DEVICE
)