Segmentation Illustration
====
I've settled on a hybrid pipeline - we first use classical techniques to get a rough mask.
Then we use the bounding box of this rough mask + some points from within as a prior, and the Meta Segment Anything Model (SAM) to do the actual segmentation.

Finally, I manually tidied up these masks to get the final dataset that I worked with.

Initial Segmentation
----
First we'll make some rough masks by doing a classical CV pipeline.

In [None]:
%load_ext autoreload
%autoreload 2

import pathlib

# Root data folder; "~" is expanded to the current user's home directory
parent_dir = pathlib.Path("~/zebrafish_rdsf/Rabia/SOST scales").expanduser()
if not parent_dir.exists():
    # An explicit error rather than an assert, so the check survives `python -O`
    # and the message says which path is missing
    raise FileNotFoundError(f"Data directory not found: {parent_dir}")

# Every entry directly under the root, except the ones named ".DS_Store" or "TIFs"
scale_dirs = tuple(
    d for d in parent_dir.glob("*") if d.stem not in {".DS_Store", "TIFs"}
)

In [None]:
from tqdm.notebook import tqdm

from scale_morphology.scales import read

# Just look at the ALP scales since they seem easiest to segment
# (the tuple unpacking raises unless exactly one directory matches)
(alp_dir,) = (d for d in scale_dirs if "ALP" in d.stem)

# Sorted so that the image order is the same on every run; glob order is arbitrary
alp_files = sorted(alp_dir.glob("*.lif"))
print(len(alp_files), "files")

# Parallel lists: names[k] is the name of images[k]
names, images = [], []
for path in alp_files:
    # Iterating the (name, image) pairs directly means a file that yields no
    # images is simply skipped, instead of breaking a zip(*...) unpacking
    for name, img in read.read_lif(path):
        names.append(name)
        images.append(img)

In [None]:
"""
Convert to greyscale
"""

import math

import matplotlib.pyplot as plt


def factor_int(n):
    """
    Split n into two whole-number factors, e.g. to lay n panels out on a grid

    The first factor is the largest divisor of n that does not exceed
    ceil(sqrt(n)); the second is n divided by the first.

    :param n: number to factor; must be a positive whole number.
    :returns: tuple of two ints whose product is n.
    :raises ValueError: if n is not a positive whole number.

    """
    count = int(n)
    if count != n or count < 1:
        raise ValueError(f"n must be a positive whole number, got {n!r}")

    # Integer arithmetic throughout - no float division or float comparison
    first = math.ceil(math.sqrt(count))
    while count % first:
        first -= 1
    return first, count // first


def plot_imgs(images, **plot_kw):
    """
    Show images on a grid of subplots, one image per panel, with the axes hidden

    :param images: sequence of arrays that imshow can display; must not be empty.
    :param plot_kw: extra keyword arguments passed on to imshow. The colormap
                    defaults to "grey" but can be overridden with cmap=...

    """
    n_rows, n_cols = factor_int(len(images))

    # Let a caller-supplied cmap win rather than clash with the default
    plot_kw.setdefault("cmap", "grey")

    # figsize is (width, height), so the column count comes first.
    # squeeze=False keeps `axes` a 2D array even for a single panel, so
    # axes.flat always works
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False
    )
    for axis, img in zip(tqdm(axes.flat), images):
        axis.imshow(img, **plot_kw)
        axis.set_axis_off()
    fig.tight_layout()


# Preview the first 16 images as loaded, before any processing
plot_imgs(images[:16])

In [None]:
from skimage.color import rgb2gray

# Convert every loaded image to greyscale
# NOTE(review): rgb2gray presumably needs the channels on the last axis -
# confirm that this is what read_lif returns
greyscale = [rgb2gray(i) for i in tqdm(images)]

In [None]:
"""
Blur to remove noise
Contrast enhance
Threshold
Remove small objects
Binary opening
Fill holes
Remove small objects
"""

import sys
import numpy as np
from scipy.ndimage import binary_fill_holes, label
from skimage.segmentation import clear_border
from skimage.exposure import equalize_adapthist
from skimage.morphology import binary_opening, disk
from skimage.filters import gaussian, threshold_minimum, threshold_mean


def threshold(i):
    """
    Mark the pixels that are darker than either of two automatic thresholds

    :param i: greyscale image.
    :returns: boolean array, True where the pixel value is below the
              minimum-method threshold or below the mean threshold.

    """
    below_minimum = i < threshold_minimum(i)
    below_mean = i < threshold_mean(i)
    return below_minimum | below_mean


def clear_border_keep_large(img):
    """
    Remove objects touching the image border, unless that removes almost everything

    If less than 10% of the original foreground would survive, a warning is
    printed to stderr and the input is returned unchanged.

    :param img: binary mask.
    :returns: the border-cleared mask, or the input itself.

    """
    area_before = np.sum(img)
    without_border = clear_border(img)

    too_much_removed = np.sum(without_border) < 0.1 * area_before
    if not too_much_removed:
        return without_border

    print("large obj touching border", file=sys.stderr)
    return img


def _largest_connected_component(binary_array):
    """
    Return the largest connected component of a binary array, as a binary array

    :param binary_array: Binary array.
    :returns: Largest connected component.

    """
    labelled, _ = label(binary_array, np.ones((3, 3)))

    # Find the size of each component
    sizes = np.bincount(labelled.ravel())
    sizes[0] = 0

    retval = labelled == np.argmax(sizes)
    return retval


def _apply_stage(stage, inputs, desc):
    """Run one pipeline stage over every image, with a labelled progress bar"""
    return [stage(i) for i in tqdm(inputs, desc=desc)]


# Structuring element for binary opening
elem = disk(10)

# Each stage consumes the previous stage's output; the intermediate lists are
# kept because later cells use them
blurred = _apply_stage(lambda i: gaussian(i, sigma=3), greyscale, "Blurring")
enhanced = _apply_stage(
    lambda i: equalize_adapthist(i, kernel_size=2001), blurred, "Enhance contrast"
)
thresholded = _apply_stage(threshold, enhanced, "thresholding")
cleared = _apply_stage(clear_border_keep_large, thresholded, "clearing borders")
opened = _apply_stage(lambda i: binary_opening(i, elem), cleared, "opening")
filled = _apply_stage(binary_fill_holes, opened, "filling")

# Keep only the biggest object in each mask
final = _apply_stage(_largest_connected_component, filled, "Removing small objs")

In [None]:
# Preview the first 25 rough masks from the classical pipeline
plot_imgs(final[:25])

In [None]:
"""
Save them to disk, in case we need them later
"""

import tifffile

# Output directory, relative to the working directory; later cells write here too
out_dir = pathlib.Path("segmentation/")

mask_prior_dir = out_dir / "mask_priors"
mask_prior_dir.mkdir(exist_ok=True, parents=True)

# One TIFF per rough mask, named after the image it was made from
# NOTE(review): assumes the names are unique and contain no path separators -
# otherwise files would be overwritten or the write would fail; confirm
for mask, name in zip(final, tqdm(names)):
    tifffile.imwrite(mask_prior_dir / (name + ".tif"), mask)

In [None]:
"""Also save the originals"""

# out_dir itself was already created by the previous cell
img_dir = out_dir / "images"
img_dir.mkdir(exist_ok=True)

# Same file names as the masks, so an image and its masks can be paired by name
for img, name in zip(images, tqdm(names)):
    tifffile.imwrite(img_dir / (name + ".tif"), img)

Transformer-based segmentation
----
Now we've got some masks that are almost there, we want to tidy things up a little.

One could do this by hand but that would be extremely slow, and quite subjective. Instead, we can tidy the masks up, at least initially, by using SAM to segment out the scales.
This doesn't "just work" out of the box, though - if we try to segment the scales without any prior, it will not give a better result than the above classical pipeline - it may also label bubbles, leached stain, the different parts of the scale etc. as different objects, which we don't want.

We can instead use the above masks as a prior for the SAM.
We will build a bounding-box around the above segmentation masks, choose some points from within them, and feed these in to SAM to be used as a prior for the segmentation.


In [None]:
from scipy.ndimage import binary_erosion


def bbox_from_mask(m, pad=16):
    """
    Padded bounding box around the nonzero pixels of a 2D mask

    :param m: 2D mask.
    :param pad: margin, in pixels, added on every side before the box is
                clipped to the image bounds.
    :returns: int32 array [x_min, y_min, x_max, y_max], or None if the mask
              has no nonzero pixels.

    """
    rows, cols = np.nonzero(m)
    if not rows.size:
        return None

    n_rows, n_cols = m.shape

    left = max(0, cols.min() - pad)
    top = max(0, rows.min() - pad)
    right = min(n_cols - 1, cols.max() + pad)
    bottom = min(n_rows - 1, rows.max() + pad)

    return np.array([left, top, right, bottom], dtype=np.int32)


def sample_pos_points(m, n=4, erosion_iterations=100):
    """
    Pick up to n points that lie inside a 2D mask

    The mask is eroded first so that the points sit away from its edge; if
    the erosion removes everything, the points are taken from the original
    mask instead. Points are evenly spaced through the list of candidate
    pixels in the order np.nonzero returns them.

    :param m: 2D binary mask.
    :param n: maximum number of points to return.
    :param erosion_iterations: number of binary-erosion passes applied before
                               sampling; lower it for small masks or
                               low-resolution images.
    :returns: integer array of shape (k, 2) holding (x, y) coordinates, where
              k = min(n, number of candidate pixels).

    """
    interior = binary_erosion(m, iterations=erosion_iterations)
    ys, xs = np.nonzero(interior if interior.any() else m)

    idx = np.linspace(0, ys.size - 1, num=min(n, ys.size)).astype(int)
    # Column order is (x, y), not (row, column)
    return np.stack([xs[idx], ys[idx]], axis=1)


# Prompts for SAM, one per rough mask: a padded bounding box and some interior points
# NOTE: a box is None when its mask is empty
boxes = [bbox_from_mask(i, pad=32) for i in tqdm(final)]
points = [sample_pos_points(i, n=10) for i in tqdm(final)]

In [None]:
# Sanity check: draw each rough mask together with its box and points.
# The 5x5 grid means that at most the first 25 masks are shown
fig, axes = plt.subplots(5, 5)
for axis, img, box, point in zip(axes.flat, final, boxes, points):
    axis.imshow(img, cmap="binary")

    # box is [x_min, y_min, x_max, y_max]; vertical lines mark the x limits...
    axis.axvline(box[0])
    axis.axvline(box[2])

    # ...and horizontal lines mark the y limits
    axis.axhline(box[1])
    axis.axhline(box[3])

    # Each point is an (x, y) pair
    for pt in point:
        axis.scatter(*pt)

In [None]:
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Use the GPU when there is one; otherwise fall back to the CPU so that the
# notebook still runs
device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_h"
sam_checkpoint = pathlib.Path("checkpoints") / "sam_vit_h_4b8939.pth"
if not sam_checkpoint.is_file():
    raise FileNotFoundError(f"SAM checkpoint not found: {sam_checkpoint}")

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device).eval()

predictor = SamPredictor(sam)

In [None]:
import cv2

# One mask per image, in the same order as `names`
sam_masks = []
with torch.inference_mode():
    for img, pts, box in zip(tqdm(enhanced), points, boxes):
        if box is None:
            # The rough mask was empty, so there is nothing to prompt SAM with.
            # Store an empty mask to keep this list aligned with `names`
            sam_masks.append(np.zeros(img.shape[:2], dtype=bool))
            continue

        # SamPredictor.set_image expects an HWC uint8 image with values in
        # [0, 255]; the contrast-enhanced images are floats in [0, 1]
        img_u8 = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
        grey = cv2.cvtColor(img_u8, cv2.COLOR_GRAY2RGB)
        predictor.set_image(grey)

        # All prompt points are foreground (label 1)
        masks, scores, _ = predictor.predict(
            point_coords=pts,
            point_labels=np.ones((pts.shape[0],), dtype=np.int32),
            box=box,
            multimask_output=True,
        )

        # Several candidate masks come back; keep the highest-scoring one
        sam_masks.append(masks[np.argmax(scores)])

In [None]:
# Preview the first 25 SAM masks
plot_imgs(sam_masks[:25])

In [None]:
# Overlay each SAM mask on its contrast-enhanced image; the 5x5 grid shows at
# most the first 25
fig, axes = plt.subplots(5, 5, figsize=(10, 10))

for axis, img, mask in zip(axes.flat, tqdm(enhanced), sam_masks):
    axis.imshow(img, cmap="binary")
    # Half-transparent, so the image stays visible underneath the mask
    axis.imshow(mask, alpha=0.5)
    axis.set_axis_off()

In [None]:
sam_mask_dir = out_dir / "sam_masks"
# exist_ok so that re-running the cell does not fail, like the other save cells
sam_mask_dir.mkdir(exist_ok=True, parents=True)

# One TIFF per SAM mask, named after the image it belongs to
for mask, name in zip(sam_masks, tqdm(names)):
    tifffile.imwrite(sam_mask_dir / (name + ".tif"), mask)

Edit the masks manually
----

The SAM masks are still not perfect. Some (but not all) of them will need editing - the below cells will open a GUI that can be used to tidy up any of the masks that aren't sensible.

In [None]:
# NOTE(review): `mask_dir` was not defined anywhere in this notebook. The text
# says it is the SAM masks that get edited, so it is pointed at the directory
# they were saved to - confirm
mask_dir = sam_mask_dir

# Masks scaled to 0/255 for viewing; sorted so the order matches the file names
# NOTE(review): the scaling assumes the files hold 0/1 (or boolean) values - a
# file that already holds 0/255 would overflow uint8; confirm
masks = [
    255 * tifffile.imread(f).astype(np.uint8) for f in sorted(mask_dir.glob("*.tif"))
]
old_masks = [
    255 * tifffile.imread(f).astype(np.uint8)
    for f in (pathlib.Path("segmentation_stuff") / "masks").glob("*.tif")
]

In [None]:
import napari
import numpy as np
import tifffile
from pathlib import Path

mask_paths = sorted(mask_dir.glob("*.tif"))
# The original image for each mask has the same file name, in img_dir
img_paths = {p.name: img_dir / p.name for p in mask_paths}

# Edited masks are written here. Created up front, so that saving cannot fail
# because the directory is missing
cleaned_mask_dir = out_dir / "cleaned_masks"
cleaned_mask_dir.mkdir(exist_ok=True, parents=True)

# Index of the image on screen, plus handles to the viewer and its two layers
state = {"i": 0, "viewer": None, "image": None, "labels": None}


def load_index(i):
    """
    Show image i and its mask in the viewer

    If a cleaned mask has already been saved for this image it is shown
    instead of the original mask. Otherwise, going back to an image would
    bring up the unedited mask, and the next save would overwrite the edits.

    :param i: index into mask_paths.

    """
    name = mask_paths[i].name
    im = tifffile.imread(img_paths[name])

    cleaned_path = cleaned_mask_dir / name
    if cleaned_path.exists():
        # Cleaned masks are stored as 0/255; show them as label 1
        mask = (tifffile.imread(cleaned_path) > 0).astype(np.uint8)
    else:
        mask = tifffile.imread(mask_paths[i]).astype(np.uint8)

    if state["labels"] is None:
        # First call: create the layers
        state["image"] = viewer.add_image(im, name="image")
        state["labels"] = viewer.add_labels(mask, name="mask", opacity=0.5)
    else:
        # Later calls: reuse the layers and swap their data
        state["image"].data = im
        state["labels"].data = mask
    viewer.title = f"{i+1}/{len(mask_paths)} : {name}"


def save_current():
    """
    Write the mask currently on screen to the cleaned-masks directory

    Every nonzero label is saved as 255 and the background as 0.

    """
    name = mask_paths[state["i"]].name
    out_path = cleaned_mask_dir / name
    tifffile.imwrite(out_path, (state["labels"].data > 0).astype(np.uint8) * 255)
    print(f"Saved {out_path}")


viewer = napari.Viewer()
state["viewer"] = viewer


# Key bindings: "s" saves, "n" saves and moves to the next image,
# "p" saves and moves to the previous image
@viewer.bind_key("s")
def _save(v):
    save_current()


@viewer.bind_key("n")
def _next(v):
    save_current()
    if state["i"] < len(mask_paths) - 1:
        state["i"] += 1
        load_index(state["i"])


@viewer.bind_key("p")
def _prev(v):
    save_current()
    if state["i"] > 0:
        state["i"] -= 1
        load_index(state["i"])


load_index(0)