# 02 - Cell Segmentation Pipeline

This notebook walks through a basic segmentation pipeline for X-ray fluorescence
(XRF) elemental maps. The steps are:

1. Select a reference element and apply Otsu thresholding
2. Morphological cleanup (binary opening followed by binary closing)
3. Connected-component labeling
4. Filter regions by size (area in pixels). **TODO:** shape-based criteria
   (e.g. eccentricity, solidity) are not applied yet.

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu
from skimage.morphology import binary_opening, binary_closing, disk
from skimage.measure import label, regionprops

%matplotlib inline

## 1. Load Reference Element Map

In [None]:
DATA_PATH = "../data/sample_xrf.h5"
REFERENCE_ELEMENT = "Zn"  # zinc often highlights cell bodies

with h5py.File(DATA_PATH, "r") as f:
    # Depending on the h5py version and how the strings were stored, channel
    # names come back either as bytes or as str; normalise both to str.
    names = [
        n.decode() if isinstance(n, bytes) else str(n)
        for n in f["/MAPS/channel_names"][:]
    ]
    if REFERENCE_ELEMENT not in names:
        raise ValueError(
            f"Reference element {REFERENCE_ELEMENT!r} not found in {DATA_PATH}; "
            f"available channels: {names}"
        )
    idx = names.index(REFERENCE_ELEMENT)
    # Read just this channel into memory while the file is still open.
    ref_map = f["/MAPS/XRF_fits"][idx]

# Downstream thresholding and imshow treat this as a single 2-D image.
assert ref_map.ndim == 2, f"expected a 2-D map, got shape {ref_map.shape}"
print(f"Reference element: {REFERENCE_ELEMENT}, shape: {ref_map.shape}")

## 2. Otsu Thresholding

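Otsu's method automatically selects the intensity threshold that minimizes the
weighted intra-class (within-class) variance of the two resulting pixel
populations. Equivalently, it maximizes the between-class variance. Pixels
strictly above the threshold form the foreground mask.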

In [None]:
thresh = threshold_otsu(ref_map)
binary = ref_map > thresh  # foreground = pixels strictly above the Otsu threshold

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Panel 0: raw intensity map
axes[0].imshow(ref_map, cmap="inferno")
axes[0].set_title(f"{REFERENCE_ELEMENT} intensity")

# Panel 1: intensity histogram with the chosen threshold marked
axes[1].hist(ref_map.ravel(), bins=128)
axes[1].axvline(thresh, color="r", ls="--", label=f"Otsu={thresh:.3f}")
axes[1].set_xlabel(f"{REFERENCE_ELEMENT} intensity")
axes[1].set_ylabel("Pixel count")
axes[1].legend()
axes[1].set_title("Histogram + threshold")

# Panel 2: resulting binary mask
axes[2].imshow(binary, cmap="gray")
axes[2].set_title("Binary mask")

# Hide ticks/frames on the image panels only; the histogram keeps its axes.
for ax in (axes[0], axes[2]):
    ax.axis("off")
plt.tight_layout()
plt.show()

## 3. Morphological Operations

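Clean up the mask in two passes with a disk-shaped structuring element. Opening
(erosion, then dilation) first removes isolated foreground specks smaller than
the disk. Closing (dilation, then erosion) then fills small holes and gaps
inside the remaining foreground. Opening comes first so that noise specks are
not grown and merged into neighbouring regions by the closing step.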

In [None]:
MORPH_RADIUS = 3  # disk structuring-element radius, in pixels

selem = disk(MORPH_RADIUS)
# Opening removes small foreground specks; closing then fills small holes/gaps.
opened = binary_opening(binary, selem)
cleaned = binary_closing(opened, selem)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(binary, cmap="gray")
axes[0].set_title("Before morphology")
axes[1].imshow(cleaned, cmap="gray")
axes[1].set_title("After opening + closing")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()

## 4. Connected Components and Filtering

In [None]:
labels = label(cleaned)
regions = regionprops(labels, intensity_image=ref_map)

MIN_AREA = 50    # pixels; smallest component area kept (inclusive)
MAX_AREA = 5000  # pixels; largest component area kept (inclusive)

kept_regions = [r for r in regions if MIN_AREA <= r.area <= MAX_AREA]
# Build the mask in one vectorized pass rather than one full-image
# comparison per region (which scales as n_regions * n_pixels).
filtered_mask = np.isin(labels, [r.label for r in kept_regions])

print(f"Total components: {len(regions)}")
print(f"After area filter [{MIN_AREA}, {MAX_AREA}]: {len(kept_regions)}")

fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(ref_map, cmap="inferno", alpha=0.6)
# The mask only holds 0/1, so a single iso-line at 0.5 traces each region
# boundary; making the level explicit avoids stacked near-duplicate contours
# on Matplotlib versions that auto-pick several levels.
ax.contour(filtered_mask, levels=[0.5], colors="cyan", linewidths=0.8)
ax.set_title(f"Segmented regions (n={len(kept_regions)})")
ax.axis("off")
plt.tight_layout()
plt.show()