# Segmentation Tutorial: Classical Methods + UNet

This tutorial combines basic classical image segmentation techniques with a minimal UNet training/inference workflow on the NuInsSeg dataset contained in this repository. Training is deliberately short — a few epochs (`EPOCHS`, 10 by default) on a small subset — so the notebook finishes quickly.

What you'll do:
- Explore classical segmentation: grayscale, Otsu thresholding, simple morphology, Sobel edges.
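- Train a small UNet for a few epochs (`EPOCHS`) on a limited subset.
- Load a pretrained checkpoint (downloaded to `segmentation/runs/expD2/` if available, otherwise the checkpoint saved by the training step under `segmentation/runs/`) and run inference.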

Notes:
- Data: this uses the `NuInsSeg/` tree (e.g., `human spleen/tissue images` + `mask binary`).
- Keep the subset size small (`LIMIT`) and the image size moderate (`IMG_SIZE`).
- GPU is recommended but not required.


In [1]:
# Setup 1: Ensure correct working directory
#
# Every later cell assumes the working directory is the repository root, i.e. the
# directory containing `segmentation/train_unet` (imported in Setup 2) and
# `segmentation/NuInsSeg` (DATA_ROOT). On Colab the repository is cloned into
# ./segmentation_repo first. The cell is safe to re-run.

import os
import subprocess
import sys

REPO_URL = 'https://github.com/KikuchiJun1/MicroTas-2025-Workshop-11.git'
REPO_DIR = 'segmentation_repo'
# Path that only exists relative to the repository root.
REPO_MARKER = os.path.join('segmentation', 'train_unet')

if os.path.isdir(REPO_MARKER):
    # Already at the repo root (local checkout, or this cell was re-run after chdir).
    print("✓ Repository directory detected")
elif os.path.isdir('train_unet') and os.path.isdir(os.path.join('..', REPO_MARKER)):
    # Notebook was launched from inside `segmentation/`: move up to the repo root so
    # `segmentation.train_unet` imports and the relative data paths resolve.
    os.chdir('..')
    print("✓ Repository directory detected")
elif os.path.isdir(os.path.join(REPO_DIR, REPO_MARKER)):
    # Cloned earlier in this runtime (e.g. kernel restarted in /content): reuse it
    # instead of letting `git clone` fail on an existing destination.
    os.chdir(REPO_DIR)
    print("✓ Repository directory detected")
else:
    print("Cloning repository...")
    subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)
    print("Repository cloned successfully!")

print(f"Working directory: {os.getcwd()}")
print()

Cloning repository...
Repository cloned successfully!
Working directory: /content/segmentation_repo



In [2]:
# Setup 2: Import Required Libraries

import random, time
from PIL import Image, ImageFilter, ImageOps
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from segmentation.train_unet.dataset_nuinsseg import discover_pairs, NuInsSegDataset
from segmentation.train_unet.transforms import build_transforms
from segmentation.train_unet.model_unet import UNet
from segmentation.train_unet.utils import set_seed, bce_dice_loss, dice_coeff, iou_score, AverageMeter, ensure_dir, count_parameters

print("All libraries imported successfully!")

All libraries imported successfully!


In [3]:
# Setup 3: Configuration
#
# All tunable settings for the rest of the notebook. (Pretrained weights are
# downloaded separately in Setup 4.)

SEED = 42
set_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

# Configuration
DATA_ROOT = 'segmentation/NuInsSeg'
INCLUDE = ['human_spleen']   # passed to discover_pairs(include=...) to select the subset
IMG_SIZE = 256               # side length used for training and the interactive cell
LIMIT = 40      # total pairs used (keep small for speed)
VAL_FRAC = 0.2
TEST_FRAC = 0.1
BATCH_SIZE = 4
EPOCHS = 10     # short training run; lower it for a quicker demo
OUTDIR = 'segmentation/runs/tutorial_classical_unet'
ensure_dir(OUTDIR)


Device: cpu


In [None]:
# Setup 4: Download Pretrained Weights (Optional)
#
# Fetches a pretrained UNet checkpoint from Google Drive with `gdown`. A failure is
# non-fatal: the inference cell falls back to the model trained in this notebook.

WEIGHTS_URL = 'https://drive.google.com/uc?id=1tjfTRYWWf1OlsHml1-DltZuIA6O71_ro'
WEIGHTS_PATH = os.path.join('segmentation', 'runs', 'expD2', 'best.pt')

if not os.path.exists(WEIGHTS_PATH):
    os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
    print("Attempting to download pretrained weights from Google Drive...")
    try:
        # Install gdown into the kernel's own interpreter; a bare `pip` on PATH may
        # belong to a different Python installation.
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'gdown'], check=True)
        # Download weights
        subprocess.run(['gdown', WEIGHTS_URL, '-O', WEIGHTS_PATH], check=True)
        print(f"✓ Weights downloaded successfully to {WEIGHTS_PATH}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError also covers FileNotFoundError when the `gdown` executable is not on PATH.
        print(f"⚠️ Could not download weights: {e}")
        print("   Weights will be optional - training model will be used for inference")
        # Remove any partial file so a re-run retries the download and the inference
        # cell does not try to load a truncated checkpoint.
        if os.path.isfile(WEIGHTS_PATH):
            os.remove(WEIGHTS_PATH)
else:
    print(f"✓ Weights already exist at {WEIGHTS_PATH}")

print("\nSetup complete! Ready to start segmentation tutorial.")

Attempting to download pretrained weights from Google Drive...


In [None]:
# Helper functions: visualization, Otsu threshold, Sobel edges, simple morphology

def to_gray_uint8(pil_img):
    """Return `pil_img` converted to PIL 'L' (grayscale) mode as a uint8 numpy array."""
    gray_img = pil_img.convert('L')
    return np.asarray(gray_img, dtype=np.uint8)

def overlay_mask(rgb_img, mask, color=(255, 0, 0), alpha=0.4):
    """Tint the masked pixels of an RGB image with `color`.

    Args:
        rgb_img: a PIL image (anything with a `.convert` method is converted to RGB)
            or an HxWx3 array-like of 0..255 values.
        mask: HxW array-like; pixels with value > 0 are tinted.
        color: RGB tint colour.
        alpha: opacity of the tint inside the mask, in [0, 1].

    Returns:
        HxWx3 uint8 array. Pixels outside the mask keep their original value; the
        input image is not modified.
    """
    if hasattr(rgb_img, 'convert'):  # PIL image
        base = np.asarray(rgb_img.convert('RGB'))
    else:
        base = np.asarray(rgb_img)
    m = np.asarray(mask) > 0
    # Blend only inside the mask. Blending the whole frame with a zero-valued overlay
    # would darken every unmasked pixel by a factor of (1 - alpha).
    out = base.astype(np.float64)  # astype copies, so `base` / the caller's array stay untouched
    out[m] = out[m] * (1.0 - alpha) + np.asarray(color, dtype=np.float64) * alpha
    return out.clip(0, 255).astype(np.uint8)

def show_triplet(img, pred_mask, gt_mask, title=''):
    """Plot image, predicted mask and ground-truth mask in one row of three panels.

    Masks are drawn in grayscale; `title`, when non-empty, becomes the figure title.
    """
    panels = (
        (img, None, 'Image'),
        (pred_mask, 'gray', 'Prediction'),
        (gt_mask, 'gray', 'Ground Truth'),
    )
    plt.figure(figsize=(12, 4))
    for position, (data, cmap, label) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(data, cmap=cmap)  # cmap=None -> matplotlib default, as for the raw image
        plt.axis('off')
        plt.title(label)
    if title:
        plt.suptitle(title)
    plt.show()

def otsu_threshold(gray_uint8):
    """Otsu's global threshold for an 8-bit grayscale image.

    Chooses the level t that maximises the between-class variance of the two
    classes {levels <= t} and {levels > t}; the darker class is therefore
    `gray <= t`. Ties resolve to the lowest such t.

    Args:
        gray_uint8: array (any shape) of non-negative integer intensities, normally uint8.

    Returns:
        int threshold. Returns 0 for an empty image, and the single intensity level
        for an image with only one level (no split separates two non-empty classes).
    """
    hist = np.bincount(np.asarray(gray_uint8).ravel(), minlength=256).astype(np.float64)
    total = hist.sum()
    if total == 0:
        return 0
    p = hist / total
    omega = np.cumsum(p)                      # probability of class {levels <= t}
    mu = np.cumsum(p * np.arange(hist.size))  # first moment of that class
    mu_t = mu[-1]                             # global mean intensity
    denom = omega * (1.0 - omega)
    denom[denom == 0] = np.nan                # one class empty -> variance undefined
    sigma_b2 = (mu_t * omega - mu) ** 2 / denom
    if np.all(np.isnan(sigma_b2)):
        # Single-level image: nanargmax would raise on an all-NaN array.
        return int(np.argmax(hist))
    return int(np.nanargmax(sigma_b2))

def _to_tensor01(x):
    return torch.from_numpy(x.astype(np.float32))[None, None]  # [1,1,H,W]

def dilate_bin(mask_uint8, k=3, iters=1):
    """Binary dilation of a 2-D mask with a k x k square window.

    Args:
        mask_uint8: 2-D mask (typically {0, 1} uint8).
        k: window side; must be a positive odd integer so stride-1 pooling with
            padding k // 2 keeps the input shape (an even k would grow it by one).
        iters: number of dilation passes; values below 1 still run one pass.

    Returns:
        uint8 array of 0/1 with the same shape as `mask_uint8`.

    Raises:
        ValueError: if k is not a positive odd integer.
    """
    if k < 1 or k % 2 == 0:
        raise ValueError(f'k must be a positive odd integer, got {k}')
    x = _to_tensor01(mask_uint8)
    for _ in range(max(1, iters)):
        # Max over the k x k neighbourhood == dilation.
        x = torch.nn.functional.max_pool2d(x, kernel_size=k, stride=1, padding=k // 2)
    # Index instead of squeeze(): squeeze() would also drop an H or W of size 1.
    return (x[0, 0].numpy() > 0.5).astype(np.uint8)

def erode_bin(mask_uint8, k=3, iters=1):
    """Binary erosion of a 2-D mask with a k x k square window.

    Implemented as -maxpool(-x) (a min-pool). max_pool2d pads implicitly with -inf,
    so pixels outside the image never erode the border.

    Args:
        mask_uint8: 2-D mask (typically {0, 1} uint8).
        k: window side; must be a positive odd integer so the output keeps the input
            shape (an even k would grow it by one).
        iters: number of erosion passes; values below 1 still run one pass.

    Returns:
        uint8 array of 0/1 with the same shape as `mask_uint8`.

    Raises:
        ValueError: if k is not a positive odd integer.
    """
    if k < 1 or k % 2 == 0:
        raise ValueError(f'k must be a positive odd integer, got {k}')
    x = _to_tensor01(mask_uint8)
    for _ in range(max(1, iters)):
        x = -torch.nn.functional.max_pool2d(-x, kernel_size=k, stride=1, padding=k // 2)
    # Index instead of squeeze(): squeeze() would also drop an H or W of size 1.
    return (x[0, 0].numpy() > 0.5).astype(np.uint8)

def open_bin(mask_uint8, k=3, iters=1):
    """Morphological opening: erosion followed by dilation (removes small specks)."""
    eroded = erode_bin(mask_uint8, k, iters)
    opened = dilate_bin(eroded, k, iters)
    return opened

def close_bin(mask_uint8, k=3, iters=1):
    """Morphological closing: dilation followed by erosion (fills small gaps)."""
    dilated = dilate_bin(mask_uint8, k, iters)
    closed = erode_bin(dilated, k, iters)
    return closed

def compute_metrics(pred_mask_uint8, gt_mask_uint8):
    """Dice and IoU between a hard predicted mask and a ground-truth mask.

    Reuses the project's `dice_coeff` / `iou_score`, which the training loop calls
    with raw model logits. Any pixel > 0 counts as foreground in either mask, so
    {0, 1} and {0, 255} masks both work.

    Returns:
        (dice, iou) as Python floats.
    """
    pred = torch.from_numpy((np.asarray(pred_mask_uint8) > 0).astype(np.float32))[None, None]
    target = torch.from_numpy((np.asarray(gt_mask_uint8) > 0).astype(np.float32))[None, None]
    # Encode the hard mask as saturated logits (+20 / -20, sigmoid ~= 1 / 0). Passing
    # raw {0, 1} is only correct if the metric thresholds strictly above 0 or above a
    # sigmoid of 0.5. A soft-Dice or `>=` implementation would score background as 0.5.
    logits = pred * 40.0 - 20.0
    with torch.no_grad():
        d = dice_coeff(logits, target).item()
        i = iou_score(logits, target).item()
    return d, i

def split_pairs_threeway(pairs, val_frac, test_frac, seed=42):
    """Deterministically split `pairs` into (train, val, test) lists.

    Indices are shuffled with `random.Random(seed)`. The first slice of the shuffled
    order becomes test, the next slice becomes val, and the rest (in shuffled order)
    becomes train. A non-zero fraction that rounds down to 0 still reserves one item
    when at least one item is available.

    Raises:
        ValueError: if val_frac + test_frac >= 1.0.
    """
    if val_frac + test_frac >= 1.0:
        raise ValueError('val_frac + test_frac must be < 1.0')

    total = len(pairs)
    order = list(range(total))
    random.Random(seed).shuffle(order)

    n_test = int(total * test_frac)
    if n_test == 0 and test_frac > 0 and total >= 1:
        n_test = 1
    n_val = int(total * val_frac)
    if n_val == 0 and val_frac > 0 and total - n_test >= 1:
        n_val = 1

    test_set = set(order[:n_test])
    val_set = set(order[n_test:n_test + n_val])
    held_out = test_set | val_set

    train = [pairs[k] for k in order if k not in held_out]
    # val/test follow the sets' iteration order (not the shuffled order).
    val = [pairs[k] for k in val_set]
    test = [pairs[k] for k in test_set]
    return train, val, test

def collate_batch(batch):
    """DataLoader collate_fn for samples exposing .image, .mask, .img_path, .msk_path.

    Returns:
        (images, masks, paths): images and masks stacked along a new leading batch
        dim, and a list of (img_path, msk_path) tuples in batch order.
    """
    images, masks, paths = [], [], []
    for sample in batch:
        images.append(sample.image)
        masks.append(sample.mask)
        paths.append((sample.img_path, sample.msk_path))
    return torch.stack(images), torch.stack(masks), paths


In [None]:
# Discover pairs and preview
#
# Collects (image, mask) path pairs, caps them at LIMIT, splits them into
# train/val/test, and shows the first few images with their ground-truth overlay.

pairs = discover_pairs(DATA_ROOT, include=INCLUDE)
if not pairs:
    # Fail here with an actionable message rather than later inside plotting/training.
    raise FileNotFoundError(f'No image/mask pairs found under {DATA_ROOT!r} for include={INCLUDE}')
if LIMIT and LIMIT > 0 and len(pairs) > LIMIT:
    pairs = pairs[:LIMIT]
print(f'Found {len(pairs)} image/mask pairs (limited).')

train_pairs, val_pairs, test_pairs = split_pairs_threeway(pairs, VAL_FRAC, TEST_FRAC, seed=42)
print(f'Train: {len(train_pairs)} | Val: {len(val_pairs)} | Test: {len(test_pairs)}')

# Preview a few samples: left = image, right = image with the GT mask tinted green.
num_preview = min(3, len(pairs))
fig, axes = plt.subplots(num_preview, 2, figsize=(12, 4 * num_preview), squeeze=False)
for row, (img_p, msk_p) in zip(axes, pairs[:num_preview]):
    img = Image.open(img_p).convert('RGB')
    msk = (np.asarray(Image.open(msk_p).convert('L')) > 0).astype(np.uint8)
    over = overlay_mask(img, msk * 255, color=(0, 255, 0), alpha=0.35)
    row[0].imshow(img)
    row[0].set_title(os.path.basename(img_p))
    row[1].imshow(over)
    row[1].set_title('Overlay GT mask')
    for ax in row:
        ax.axis('off')
plt.show()


In [None]:
# Classical method: Otsu threshold + morphology
#
# A learning-free baseline: Gaussian blur -> grayscale -> Otsu threshold ->
# morphological opening, scored with the same Dice/IoU helpers used for the UNet.
# The darker Otsu class is treated as foreground (the nuclei).

val_eval = []
n_demo = min(4, len(val_pairs))
for img_p, msk_p in val_pairs[:n_demo]:
    img = Image.open(img_p).convert('RGB')
    gt = (np.asarray(Image.open(msk_p).convert('L')) > 0).astype(np.uint8)
    # slight blur to denoise before thresholding
    img_blur = img.filter(ImageFilter.GaussianBlur(radius=1))
    gray = to_gray_uint8(img_blur)
    t = otsu_threshold(gray)
    # Otsu's lower class is every level <= t (the returned level belongs to it);
    # `gray < t` would wrongly move pixels equal to t into the background.
    pred = (gray <= t).astype(np.uint8)
    # small opening to remove speckles
    pred = open_bin(pred, k=3, iters=1)
    d, iou = compute_metrics(pred, gt)
    val_eval.append((d, iou))
    show_triplet(img, pred * 255, gt * 255, title=f'Otsu+Open | Dice={d:.3f}, IoU={iou:.3f}')

if val_eval:
    md = float(np.mean([x[0] for x in val_eval]))
    mi = float(np.mean([x[1] for x in val_eval]))
    print(f'Classical (Otsu+Open) on {len(val_eval)} samples -> Dice={md:.3f}, IoU={mi:.3f}')
else:
    print('No validation pairs to demo classical method.')


In [None]:
# Prepare Datasets and DataLoaders

# Build transforms for training and validation.
# augment=True presumably enables augmentation only in train_t (val_t is used for
# evaluation) — see segmentation/train_unet/transforms.py.
train_t, val_t = build_transforms(IMG_SIZE, augment=True)
print(f"✓ Transforms created for IMG_SIZE={IMG_SIZE}")

# NuInsSegDataset loads image/mask pairs and applies the given transform.
train_ds = NuInsSegDataset(train_pairs, transform=train_t, img_size=IMG_SIZE)
val_ds = NuInsSegDataset(val_pairs, transform=val_t, img_size=IMG_SIZE)
print(f"✓ Training dataset: {len(train_ds)} samples")
print(f"✓ Validation dataset: {len(val_ds)} samples")

NUM_WORKERS = 2  # background loader processes
# Pinned (page-locked) host memory only speeds up host->GPU copies. On a CPU-only
# runtime PyTorch warns that it cannot be used, so enable it only with CUDA.
PIN_MEMORY = device.type == 'cuda'

loader_kwargs = dict(num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_batch)
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,                       # reshuffle training data every epoch
    **loader_kwargs,
)
val_loader = DataLoader(
    val_ds,
    batch_size=max(1, BATCH_SIZE // 2),  # smaller batch for validation
    shuffle=False,                      # fixed order for comparable validation
    **loader_kwargs,
)
print(f"✓ Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")
print()

In [None]:
# Initialize UNet Model
#
# A deliberately small UNet so training is fast. The printed summary is derived
# from the same constants used to build the model, so the two cannot disagree.

BASE_CH = 16  # base channel count passed to UNet; larger values (e.g. 64) = more parameters, slower training
LR = 1e-3     # Adam learning rate

model_small = UNet(n_channels=3, n_classes=1, base_ch=BASE_CH).to(device)
total_params = count_parameters(model_small)
print("✓ UNet Model created")
print("  - Input channels: 3 (RGB)")
print("  - Output channels: 1 (binary segmentation)")
print(f"  - Base channels: {BASE_CH}")
print(f"  - Total parameters: {total_params:,}")

# Adam: adaptive per-parameter learning rates, a robust default for small models.
opt = torch.optim.Adam(model_small.parameters(), lr=LR)
print(f"✓ Optimizer: Adam (lr={LR:g})")

# Best mean validation Dice seen so far (updated by the training cell).
best_dice = -1.0
print()

In [None]:
# Train UNet Model
#
# Each epoch does one optimisation pass over train_loader (BCE + Dice loss), then a
# no-grad pass over val_loader. Whenever mean validation Dice improves, the weights
# are saved to OUTDIR/best.pt. `model_small`, `opt` and `best_dice` come from the
# previous cell, so re-running only this cell continues training from where it stopped.

import warnings

ckpt_path = os.path.join(OUTDIR, 'best.pt')

# Silence library warnings only while training. A global
# warnings.filterwarnings('ignore') would also hide every warning raised by
# later cells for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    for epoch in range(EPOCHS):
        # ===== TRAINING PHASE =====
        model_small.train()  # training-mode behaviour for layers such as dropout / batch norm, if present
        train_loss = AverageMeter()  # sample-weighted average loss over this epoch

        for step, (imgs, msks, _) in enumerate(train_loader, 1):
            imgs = imgs.to(device, non_blocking=True)
            msks = msks.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            logits = model_small(imgs)          # raw (pre-sigmoid) scores
            loss = bce_dice_loss(logits, msks)  # binary cross-entropy + Dice loss
            loss.backward()
            opt.step()

            train_loss.update(loss.item(), imgs.size(0))

            # Progress every 10 steps
            if step % 10 == 0:
                print(f'  Epoch {epoch+1}/{EPOCHS} | Step {step}/{len(train_loader)} | Loss {train_loss.avg:.4f}')

        # ===== VALIDATION PHASE =====
        model_small.eval()  # inference-mode behaviour (e.g. running batch-norm statistics)
        val_loss = AverageMeter()
        val_dice = AverageMeter()
        val_iou = AverageMeter()

        with torch.no_grad():  # no autograd graph needed for evaluation
            for imgs, msks, _ in val_loader:
                imgs = imgs.to(device, non_blocking=True)
                msks = msks.to(device, non_blocking=True)

                logits = model_small(imgs)
                loss = bce_dice_loss(logits, msks)

                val_loss.update(loss.item(), imgs.size(0))
                val_dice.update(dice_coeff(logits, msks).item(), imgs.size(0))
                val_iou.update(iou_score(logits, msks).item(), imgs.size(0))

        print(f'Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss.avg:.4f} | Val Loss: {val_loss.avg:.4f}')
        print(f'           | Dice: {val_dice.avg:.4f} | IoU: {val_iou.avg:.4f}')

        # Keep only the checkpoint with the best validation Dice.
        if val_dice.avg > best_dice:
            best_dice = val_dice.avg
            torch.save({'epoch': epoch+1, 'model': model_small.state_dict()}, ckpt_path)
            print(f'  ✓ Saved best checkpoint to {ckpt_path}')

print("=" * 60)
print(f"Training complete! Best Dice Score: {best_dice:.4f}")
print("=" * 60)
print()

In [None]:
# Visualize Validation Predictions
#
# Runs the freshly trained model on the first few validation samples (already
# transformed by val_t) and shows image / prediction / ground truth.

model_small.eval()
n_show = min(2, len(val_ds))
print(f"Showing {n_show} validation predictions:\n")

with torch.no_grad():
    for i in range(n_show):
        sample = val_ds[i]

        # Add a batch dim, predict, and threshold the sigmoid probability at 0.5.
        batch = sample.image.unsqueeze(0).to(device)
        prob = torch.sigmoid(model_small(batch))[0, 0].cpu().numpy()
        pred = (prob >= 0.5).astype(np.uint8)

        # Back to HxWx3 uint8 for display; assumes the transform leaves pixels in [0, 1].
        rgb = (sample.image.clamp(0, 1).cpu().permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        gt = sample.mask[0].cpu().numpy().astype(np.uint8)

        show_triplet(rgb, pred * 255, gt * 255, title=f'UNet Prediction (Sample {i+1}/{n_show})')

print("✓ Validation visualization complete!")
print()

In [None]:
# Pretrained Inference
#
# Tries each existing checkpoint in order of preference (downloaded pretrained
# weights, then the checkpoint saved by the training cell), rebuilds a matching
# UNet, and runs it on held-out images. Falls back to `model_small` if no
# checkpoint can be loaded.


def load_unet_checkpoint(path, base_chs=(64, 32, 16)):
    """Rebuild a UNet from a checkpoint whose base channel width is unknown.

    Args:
        path: file written by `torch.save`, either a raw state dict or a dict with
            the state dict under 'model' (the format the training cell saves).
        base_chs: candidate `base_ch` values, tried in order with strict loading.

    Returns:
        The first UNet (on `device`, in eval mode) that accepts the weights, or None
        if the file cannot be read or no candidate architecture matches.
    """
    try:
        # Safe loader: only tensors and plain containers are unpickled.
        ckpt = torch.load(path, map_location='cpu', weights_only=True)
    except Exception as safe_err:
        # NOTE(security): weights_only=False runs the full pickle machinery and can
        # execute code embedded in the file. Only acceptable for checkpoints you trust.
        print(f'⚠️ Safe load failed ({safe_err}); retrying with weights_only=False (trusted files only)')
        try:
            ckpt = torch.load(path, map_location='cpu', weights_only=False)
        except Exception as e:
            print(f'⚠️ Failed to load checkpoint: {e}')
            return None

    state = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else None
    if not isinstance(state, dict):
        print(f'⚠️ Unexpected checkpoint contents: {type(ckpt).__name__}')
        return None
    print(f'✓ Checkpoint loaded ({len(state)} state-dict entries)')

    last_err = None
    for base_ch in base_chs:
        try:
            candidate = UNet(n_channels=3, n_classes=1, base_ch=base_ch)
            candidate.load_state_dict(state, strict=True)
        except Exception as e:  # architecture mismatch -> try the next width
            last_err = e
            continue
        print(f'✓ Loaded pretrained UNet with base_ch={base_ch}')
        return candidate.to(device).eval()
    print(f'⚠️ Could not load with any base_ch configuration (last error: {last_err})')
    return None


def predict_mask(model, pil_img, size=IMG_SIZE, threshold=0.5):
    """Binary {0,1} mask for `pil_img` at the image's native resolution.

    The image is resized to size x size. IMG_SIZE is what the training dataset is
    built with and what the interactive cell uses. Pixels are scaled to [0, 1] and
    the sigmoid probability is thresholded. The mask is then resized back with
    nearest-neighbour so it aligns with the original image and its ground truth.
    """
    resized = pil_img.resize((size, size), Image.BILINEAR)
    x = torch.from_numpy(np.asarray(resized, dtype=np.uint8)).float().permute(2, 0, 1) / 255.0
    with torch.no_grad():
        prob = torch.sigmoid(model(x.unsqueeze(0).to(device)))[0, 0].cpu().numpy()
    small = ((prob >= threshold) * 255).astype(np.uint8)
    full = Image.fromarray(small).resize(pil_img.size, Image.NEAREST)
    return (np.asarray(full) > 0).astype(np.uint8)


ckpt_candidates = [
    os.path.join('segmentation', 'runs', 'expD2', 'best.pt'),  # Pretrained weights
    os.path.join(OUTDIR, 'best.pt'),                            # Newly trained weights
]

print('Looking for checkpoint...')
for cand in ckpt_candidates:
    print(f'Candidate: {cand} → {"✓ Found" if os.path.isfile(cand) else "✗ Not found"}')
print()

model_pre = None
ckpt_path = None
existing = [p for p in ckpt_candidates if os.path.isfile(p)]
if not existing:
    print('⚠️ No checkpoint found; using the trained model for inference.')
for cand in existing:
    print(f'✓ Loading checkpoint from: {cand}')
    model_pre = load_unet_checkpoint(cand)
    if model_pre is not None:
        ckpt_path = cand
        break
if model_pre is None:
    if existing:
        print('  Falling back to the trained model')
    model_pre = model_small

# Run inference on test/validation samples
print()
print("-" * 60)
print("Running Inference")
print("-" * 60)

model_pre.eval()

# Use test set if available; otherwise use validation set
use_test = len(test_pairs) > 0
sample_pairs = test_pairs if use_test else val_pairs
n_infer = min(6, len(sample_pairs))
print(f'Inferring on {n_infer} {"test" if use_test else "validation"} samples:\n')

for i, (img_p, msk_p) in enumerate(sample_pairs[:n_infer]):
    img = Image.open(img_p).convert('RGB')
    gt = (np.asarray(Image.open(msk_p).convert('L')) > 0).astype(np.uint8)
    pred = predict_mask(model_pre, img)
    d, iou = compute_metrics(pred, gt)
    show_triplet(img, pred * 255, gt * 255,
                 title=f'Inference {i+1}/{n_infer}: {os.path.basename(img_p)} | Dice={d:.3f}, IoU={iou:.3f}')

print("=" * 60)
print("✓ Inference complete!")
print("=" * 60)

In [None]:
# Interactive Image Segmentation
#
# Segments a single image with `model_pre` (chosen in the inference cell). The image
# is either one of the preset images below or a file uploaded to /content on Colab.
# The model runs at IMG_SIZE x IMG_SIZE; the mask is then resized back to the
# original resolution for the full-size view and the statistics.

from pathlib import Path
import shutil

repo_dir = Path.cwd()
content_dir = Path("/content")  # Colab's default upload location

preset_root = repo_dir / "segmentation" / "NuInsSeg" / "other_images"
if not preset_root.is_dir():
    # Colab fallback when the working directory is not the cloned repo root.
    alt_root = Path("/content/segmentation_repo/segmentation/NuInsSeg/other_images")
    if alt_root.is_dir():
        preset_root = alt_root

PRESET_NAMES = [
    "human_bladder_03",
    "human_brain_11",
    "human_kidney_02",
    "mouse_muscle_tibia_02",
    "mouse_thymus_01",
]
PRESET_IMAGES = {name: preset_root / f"{name}.png" for name in PRESET_NAMES}

print("Preset image options:")
for name in PRESET_NAMES:
    print(f"  • {name} → {PRESET_IMAGES[name]}")

# ========================== USER SETTINGS ==========================
IMAGE_CHOICE = "mouse_thymus_01"  # Pick one of PRESET_NAMES or set to None
CUSTOM_FILENAME = None             # e.g., "my_upload.jpg" if you uploaded a file
# ===================================================================

uploaded_path = None

# 1) A preset image, if one was requested.
if IMAGE_CHOICE:
    key = IMAGE_CHOICE.strip()
    preset_path = PRESET_IMAGES.get(key)
    if preset_path is not None and preset_path.is_file():
        uploaded_path = preset_path
        print(f"Using preset image: {key}")
        print(f"   Path: {preset_path}")
    else:
        if preset_path is None:
            # Unknown name: list the valid choices rather than "not found at None".
            print(f"Unknown preset '{key}'. Valid choices: {', '.join(PRESET_NAMES)}")
        else:
            print(f"Preset image '{key}' not found at {preset_path}")
        print("Falling back to CUSTOM_FILENAME workflow…")

# 2) Otherwise a user-uploaded file, looked up in /content and then in the repo dir.
if uploaded_path is None and CUSTOM_FILENAME:
    IMAGE_FILENAME = CUSTOM_FILENAME.strip()
    source_path = content_dir / IMAGE_FILENAME
    dest_path = repo_dir / IMAGE_FILENAME

    if source_path.is_file():
        if dest_path.resolve() != source_path.resolve():
            shutil.copy2(source_path, dest_path)
            print(f"Copied {IMAGE_FILENAME} from /content to {repo_dir}")
        uploaded_path = dest_path
    elif dest_path.is_file():
        uploaded_path = dest_path
    else:
        print()
        print("Troubleshooting:")
        print("  • Upload the image to /content/ in the Colab file browser")
        print("  • Keep the filename (with extension) exactly the same")
        print()
        if content_dir.exists():
            print("Files in /content:")
            for f in sorted(content_dir.iterdir()):
                if f.suffix.lower() in {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}:
                    print(f"  • {f.name}")
        uploaded_path = None

print()
print("=" * 60)
print("STEP 1: Find Your Image")
print("=" * 60)
print()

if uploaded_path:
    print(f"Found your image: {uploaded_path.name}")
else:
    print("Set IMAGE_CHOICE to one of the preset names above or provide CUSTOM_FILENAME.")

img = None
if uploaded_path:
    print()
    print("=" * 60)
    print("PROCESSING YOUR IMAGE...")
    print("=" * 60)
    print()
    # Only file-decoding problems are reported as "invalid image". Errors from the
    # model or plotting propagate with a real traceback instead of a misleading hint.
    try:
        img = Image.open(uploaded_path).convert('RGB')
    except (OSError, ValueError) as e:  # PIL's UnidentifiedImageError subclasses OSError
        print(f"Error processing image: {e}")
        print("   Make sure your file is a valid image (JPG, PNG, etc.)")
else:
    print()
    print("=" * 60)
    print("NO IMAGE TO PROCESS")
    print("=" * 60)

if img is not None:
    orig_size = img.size
    print(f"Original size: {orig_size[0]} × {orig_size[1]} pixels")

    img_resized = img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
    print(f"Resized to: {IMG_SIZE} × {IMG_SIZE} pixels (for AI model)")
    print()

    img_np = np.asarray(img_resized, dtype=np.uint8)
    img_t = torch.from_numpy(img_np).float().permute(2, 0, 1) / 255.0
    img_t = img_t.unsqueeze(0).to(device)

    print("Running AI segmentation...")
    model_pre.eval()
    with torch.no_grad():
        logits = model_pre(img_t)
        prob = torch.sigmoid(logits)[0, 0].cpu().numpy()  # per-pixel foreground probability
        pred_mask = (prob >= 0.5).astype(np.uint8)

    print("Segmentation complete!")
    print()

    print("-" * 60)
    print("RESULTS - 4 Different Views:")
    print("-" * 60)
    print()

    plt.figure(figsize=(16, 4))

    plt.subplot(1, 4, 1)
    plt.imshow(img_resized)
    plt.axis('off')
    plt.title('Your Image\n(resized)', fontsize=12, fontweight='bold')

    plt.subplot(1, 4, 2)
    plt.imshow(pred_mask, cmap='gray')
    plt.axis('off')
    plt.title('Detected Regions\n(binary mask)', fontsize=12, fontweight='bold')

    plt.subplot(1, 4, 3)
    im = plt.imshow(prob, cmap='hot')
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.axis('off')
    # The sigmoid output is a foreground probability, not a confidence: a value
    # near 0 is a confident *background* call.
    plt.title('Foreground Probability\n(hotter = more likely detected)', fontsize=12, fontweight='bold')

    plt.subplot(1, 4, 4)
    overlay = overlay_mask(img_resized, pred_mask * 255, color=(0, 255, 255), alpha=0.5)
    plt.imshow(overlay)
    plt.axis('off')
    plt.title('Overlay\n(cyan = detected)', fontsize=12, fontweight='bold')

    plt.suptitle(f'Segmentation Results for: {uploaded_path.name}',
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

    print()
    print("-" * 60)
    print("FULL RESOLUTION RESULT:")
    print("-" * 60)
    print()

    # Nearest-neighbour keeps the upscaled mask strictly binary (0/255).
    mask_fullsize = Image.fromarray((pred_mask * 255).astype(np.uint8))
    mask_fullsize = mask_fullsize.resize(orig_size, Image.NEAREST)
    print(f"Mask resized back to original: {orig_size[0]} × {orig_size[1]} pixels")
    print()

    plt.figure(figsize=(14, 7))

    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.axis('off')
    plt.title('Your Original Image', fontsize=14, fontweight='bold')

    plt.subplot(1, 2, 2)
    overlay_full = overlay_mask(img, np.array(mask_fullsize), color=(0, 255, 255), alpha=0.5)
    plt.imshow(overlay_full)
    plt.axis('off')
    plt.title('Segmentation Result\n(Full Resolution)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

    print()
    print("=" * 60)
    print("STATISTICS:")
    print("=" * 60)
    mask_array = np.array(mask_fullsize)
    total_pixels = mask_array.size
    detected_pixels = np.sum(mask_array > 0)
    percentage = (detected_pixels / total_pixels) * 100

    print(f"  • Total pixels: {total_pixels:,}")
    print(f"  • Detected pixels: {detected_pixels:,}")
    print(f"  • Coverage: {percentage:.2f}% of image")
    # Mean of the per-pixel foreground probability at IMG_SIZE resolution.
    print(f"  • Mean foreground probability: {prob.mean():.3f}")
    print()

    print("=" * 60)
    print("ALL DONE! Hope you enjoyed it!")
    print("=" * 60)

<small>These images are in the test set: human_bladder_03, human_bladder_09, human_brain_11, human_brain_1, human_brain_9, human_cardia_7, human_epiglottis_2, human_jejunum_04, human_kidney_02, human_kidney_04, human_kidney_06, human_liver_13, human_liver_21, human_liver_22, human_liver_27, human_liver_37, human_liver_39, human_melanoma_08, human_muscle_7, human_oesophagus_11, human_oesophagus_22, human_oesophagus_30, human_oesophagus_34, human_oesophagus_37, human_oesophagus_43, human_pancreas_14, human_pancreas_21, human_pancreas_39, human_peritoneum_6, human_peritoneum_9, human_placenta_26, human_placenta_28, human_placenta_31, human_pylorus_5, human_rectum_10, human_rectum_11, human_rectum_7, human_salivory_11, human_salivory_17, human_salivory_28, human_salivory_32, human_spleen_29, human_testis_10, human_tongue_08, human_tongue_16, human_tongue_17, human_tongue_33, human_tongue_36, human_tonsile_4, human_tonsile_6, human_umbilical_cord_06, mouse_heart_01, mouse_heart_15, mouse_heart_23, mouse_kidney_10, mouse_kidney_19, mouse_kidney_24, mouse_liver_01, mouse_liver_21, mouse_liver_32, mouse_muscle_tibia_02, mouse_muscle_tibia_04, mouse_muscle_tibia_24, mouse_subscapula_16, mouse_subscapula_30, mouse_thymus_01.</small>