In [None]:
%pip install kagglehub
import kagglehub
import os 
# Download the latest version of the dataset into ./data (relative to the
# current working directory) and keep the path that kagglehub reports back.
path = kagglehub.dataset_download(
    "nikhilroxtomar/brain-tumor-segmentation",
    output_dir=str(os.path.join(os.getcwd(), "data")),
)
print("Path to dataset files:", path)

## Data loading


In [None]:
import os
import random
from pathlib import Path

import numpy as np
from PIL import Image
from skimage.filters import threshold_otsu, threshold_sauvola

# Resolve dataset root (expects data/images and data/masks)
DATA_DIR = Path(os.getcwd()) / "data"
IMAGES_DIR = DATA_DIR / "images"
MASKS_DIR = DATA_DIR / "masks"

# Fail early with a message that says which of the two folders is missing.
if not IMAGES_DIR.exists() or not MASKS_DIR.exists():
    raise FileNotFoundError(
        f"Expected images/masks under {DATA_DIR}, but found images={IMAGES_DIR.exists()}, masks={MASKS_DIR.exists()}"
    )

# Build image-mask pairs by filename stem (.png only).
# rglob() yields paths in no guaranteed order, so sort them: otherwise the
# order of `pairs` (and therefore any seeded random.sample/choice drawn from
# it) would differ between filesystems and runs.
image_files = sorted(IMAGES_DIR.rglob("*.png"))

# Sorting here makes the "last one wins" resolution of duplicate stems
# (same filename in different sub-folders) deterministic as well.
mask_index = {p.stem: p for p in sorted(MASKS_DIR.rglob("*.png"))}

# Keep only images that have a mask with the same stem; unmatched images are skipped.
pairs = [(img, mask_index[img.stem]) for img in image_files if img.stem in mask_index]

if not pairs:
    raise RuntimeError("No matching image/mask pairs found. Check filenames in images/ and masks/.")

print(f"Found {len(pairs)} image/mask pairs.")

In [None]:
def load_grayscale(path: Path) -> np.ndarray:
    """Read an image file as a float32 grayscale array scaled to [0, 1].

    The image at ``path`` is converted to single-channel 8-bit ("L") mode
    and its pixel values are divided by 255. An all-zero image is returned
    as-is, without the division.
    """
    pixels = np.asarray(Image.open(path).convert("L"), dtype=np.float32)
    return pixels / 255.0 if pixels.max() > 0 else pixels


def load_mask(path: Path) -> np.ndarray:
    """Read a mask file and return it as a binary uint8 array of 0s and 1s.

    The file is converted to single-channel 8-bit ("L") mode; every
    non-zero pixel becomes 1, so masks stored as 0/255 or 0/1 both work.
    """
    gray = Image.open(path).convert("L")
    values = np.asarray(gray, dtype=np.uint8)
    # For unsigned data "!= 0" is the same test as "> 0".
    foreground = np.not_equal(values, 0)
    return foreground.astype(np.uint8)


def dice_score(pred: np.ndarray, gt: np.ndarray) -> float:
    """Dice coefficient 2*|P∩G| / (|P| + |G|) between two masks.

    Both inputs are cast to boolean (non-zero means foreground). When both
    masks are empty the score is defined as 1.0.
    """
    pred_mask = pred.astype(bool)
    gt_mask = gt.astype(bool)
    # Overlap is computed first so that incompatible shapes raise before
    # the empty-mask shortcut below can return.
    overlap = (pred_mask & gt_mask).sum()
    denominator = pred_mask.sum() + gt_mask.sum()
    if denominator == 0:
        return 1.0
    return 2.0 * overlap / denominator


def jaccard_score(pred: np.ndarray, gt: np.ndarray) -> float:
    """Jaccard index (intersection over union) between two masks.

    Both inputs are cast to boolean (non-zero means foreground). When the
    union is empty, i.e. both masks are empty, the score is defined as 1.0.
    """
    pred_mask, gt_mask = pred.astype(bool), gt.astype(bool)
    union_count = (pred_mask | gt_mask).sum()
    # The intersection is only evaluated when the union is non-empty.
    return (pred_mask & gt_mask).sum() / union_count if union_count != 0 else 1.0


def otsu_threshold(img: np.ndarray) -> np.ndarray:
    """Binarize ``img`` with a single global Otsu threshold.

    Returns a uint8 array of the same shape, with 1 where the pixel is at
    or above the threshold and 0 elsewhere.
    """
    return np.greater_equal(img, threshold_otsu(img)).astype(np.uint8)


def sauvola_threshold(img: np.ndarray, window_size: int = 55, k: float = 0.30) -> np.ndarray:
    """Binarize ``img`` with Sauvola's local (per-pixel) threshold.

    Args:
        img: Grayscale image array.
        window_size: Size of the local neighbourhood passed to
            ``threshold_sauvola``; scikit-image requires an odd value.
        k: Sauvola ``k`` parameter passed to ``threshold_sauvola``.

    Returns:
        uint8 array of the same shape, with 1 where the pixel is at or
        above its local threshold and 0 elsewhere.
    """
    # Forward the tuning parameters; previously they were accepted but
    # ignored, so the library defaults were silently used instead.
    thresh = threshold_sauvola(img, window_size=window_size, k=k)
    return (img >= thresh).astype(np.uint8)

In [None]:
# Visual sanity check: show one randomly chosen image with its ground-truth
# mask and the two thresholded predictions in a 2x2 grid.
import matplotlib.pyplot as plt

sample_img, sample_mask = random.choice(pairs)
img = load_grayscale(sample_img)
mask = load_mask(sample_mask)
otsu_pred = otsu_threshold(img)
sauv_pred = sauvola_threshold(img)

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes = axes.ravel()

# (array, title) for each panel, in row-major display order.
panels = [
    (img, "Sample Image"),
    (mask, "Sample Mask"),
    (otsu_pred, "Otsu Threshold"),
    (sauv_pred, "Sauvola Threshold"),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel, cmap="gray")
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Score both thresholding methods on a seeded random subset of at most
# 100 image/mask pairs.
random.seed(42)
np.random.seed(42)

sample_size = min(100, len(pairs))
subset = random.sample(pairs, sample_size)

# Per-method lists of per-image scores.
metrics = {method: {"dice": [], "jaccard": []} for method in ("otsu", "sauvola")}

for img_path, mask_path in subset:
    img = load_grayscale(img_path)
    gt = load_mask(mask_path)

    pred_otsu = otsu_threshold(img)
    pred_sauv = sauvola_threshold(img)

    # Same two metrics for each prediction, compared against the same mask.
    for name, pred in (("otsu", pred_otsu), ("sauvola", pred_sauv)):
        metrics[name]["dice"].append(dice_score(pred, gt))
        metrics[name]["jaccard"].append(jaccard_score(pred, gt))

In [None]:
def summarize(scores: list[float]) -> float:
    """Return the arithmetic mean of ``scores`` as a Python float.

    An empty list yields NaN rather than raising or warning.
    """
    if not scores:
        return float("nan")
    return float(np.mean(scores))

# Print the mean Dice and Jaccard score for each method.
print(f"Evaluated {sample_size} images")
for heading, key in (("Global Otsu", "otsu"), ("Sauvola", "sauvola")):
    print(f"\n{heading}:")
    print(f"  Dice:    {summarize(metrics[key]['dice']):.4f}")
    print(f"  Jaccard: {summarize(metrics[key]['jaccard']):.4f}")