Exploring segmentation of scale images

In [None]:
%load_ext autoreload
%autoreload 2

import pathlib

parent_dir = pathlib.Path("~/zebrafish_rdsf/Rabia/SOST scales").expanduser()
# Raise explicitly instead of using `assert`: asserts are skipped under
# `python -O` and give no hint about which path was missing.
if not parent_dir.exists():
    raise FileNotFoundError(f"Data directory not found: {parent_dir}")

# Every entry directly under parent_dir, minus the ones we never want to treat
# as a scale directory.
ignored_stems = {".DS_Store", "TIFs"}
scale_dirs = tuple(d for d in parent_dir.glob("*") if d.stem not in ignored_stems)

In [None]:
from tqdm.notebook import tqdm

from scale_morphology.scales import read

# Just look at the ALP scales since they seem easiest to segment.
# The single-element unpacking deliberately fails unless exactly one directory
# has "ALP" in its name.
(alp_dir,) = (d for d in scale_dirs if "ALP" in d.stem)

# Sorted so the files are processed in the same order on every run
alp_files = sorted(alp_dir.glob("*.lif"))
print(len(alp_files), "files")

# Parallel lists: names[k] is the name of images[k]
names, images = [], []
for path in alp_files:
    # Iterate the (name, image) pairs directly; the previous `zip(*...)`
    # transpose raised ValueError for a file that yielded no images.
    for name, img in read.read_lif(path):
        names.append(name)
        images.append(img)

In [None]:
import pathlib

# Root folder for everything this notebook writes
out_dir = pathlib.Path("segmentation_stuff")

# One subfolder each for the overlay previews, the masks and the source images
overlaid_dir, mask_dir, img_dir = (
    out_dir / sub for sub in ("overlaid", "masks", "imgs")
)

for directory in (overlaid_dir, mask_dir, img_dir):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:
"""
Build our training set by first using classical CV to perform segmentation
"""

import tifffile
from matplotlib import colors
import matplotlib.pyplot as plt
from scale_morphology.scales import segmentation


def clear2black_cmap() -> colors.Colormap:
    """
    Two-entry colormap: fully transparent white first, fully opaque black second
    """
    to_rgba = colors.colorConverter.to_rgba
    endpoints = [
        to_rgba("white", alpha=0),
        to_rgba("black", alpha=1),
    ]
    return colors.ListedColormap(endpoints, "clear2black")


cmap = clear2black_cmap()
for name, img in zip(tqdm(names), images):
    tif_name = name + ".tif"
    mask_path = mask_dir / tif_name

    # An existing mask file marks this image as already processed
    if mask_path.exists():
        continue
    mask = segmentation.classical_segmentation(img)

    tifffile.imwrite(img_dir / tif_name, img)

    # Preview: the image with the mask drawn semi-transparently on top
    fig, axis = plt.subplots(figsize=(5, 5))
    try:
        axis.imshow(img)
        axis.imshow(mask, cmap=cmap, alpha=0.5)
        axis.set_axis_off()
        fig.savefig(overlaid_dir / (name + ".png"))
    finally:
        # Always release the figure, even if drawing or saving fails
        plt.close(fig)

    # Write the mask last: because its existence is the "done" marker checked
    # above, an interrupted run must not leave a mask without its image and
    # overlay (the rerun would skip it and never write them).
    tifffile.imwrite(mask_path, mask)

In [None]:
import napari
import numpy as np
import tifffile
from pathlib import Path

# The masks on disk, in sorted order, define the review sequence
mask_paths = sorted(mask_dir.glob("*.tif"))

# File name -> path of the same-named source image in img_dir
img_paths = {m.name: img_dir / m.name for m in mask_paths}

# Mutable review state: current index, the viewer, and the labels layer
state = dict(i=0, viewer=None, labels=None)


def load_index(i):
    """
    Show the image/mask pair at position ``i`` of ``mask_paths`` in the viewer.

    If a cleaned mask has already been saved for this file, it is shown instead
    of the raw mask. Otherwise, navigating back to an image would display the
    unedited mask and the next save would overwrite the cleaned one.

    The image and labels layers are created on the first call and have their
    data replaced on later calls.
    """
    name = mask_paths[i].name
    im = tifffile.imread(img_paths[name])

    cleaned_path = out_dir / "cleaned_masks" / name
    if cleaned_path.exists():
        # Cleaned masks are stored as 0/255; collapse them back to one label
        mask = (tifffile.imread(cleaned_path) > 0).astype(np.uint8)
    else:
        mask = tifffile.imread(mask_paths[i]).astype(np.uint8)

    if state["labels"] is None:
        state["image"] = viewer.add_image(im, name="image")
        state["labels"] = viewer.add_labels(mask, name="mask", opacity=0.5)
    else:
        state["image"].data = im
        state["labels"].data = mask
    viewer.title = f"{i+1}/{len(mask_paths)} : {name}"


def save_current():
    """
    Write the labels layer for the current index to ``cleaned_masks``.

    The mask is binarised before saving: any non-zero label becomes 255 and
    everything else 0, stored as uint8 under the same file name as the raw mask.
    """
    name = mask_paths[state["i"]].name
    out_path = out_dir / "cleaned_masks" / name
    # The setup cell only creates the overlaid/masks/imgs folders, so make sure
    # the destination exists before writing.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tifffile.imwrite(out_path, (state["labels"].data > 0).astype(np.uint8) * 255)
    print(f"Saved {out_path}")


# Create the napari window and keep a handle to it in the shared state too
state["viewer"] = viewer = napari.Viewer()


@viewer.bind_key("s")
def _save(v):
    """Key "s": save the mask currently shown, without changing image."""
    save_current()


@viewer.bind_key("n")
def _next(v):
    """Key "n": save the current mask, then show the next image if there is one."""
    save_current()

    last_index = len(mask_paths) - 1
    if state["i"] >= last_index:
        # Already on the final image; nothing further to load
        return

    state["i"] += 1
    load_index(state["i"])


@viewer.bind_key("p")
def _prev(v):
    """Key "p": save the current mask, then show the previous image if there is one."""
    save_current()

    if state["i"] <= 0:
        # Already on the first image; nothing earlier to load
        return

    state["i"] -= 1
    load_index(state["i"])


# Start the review on the first mask. With no masks on disk, say so instead of
# failing with a bare IndexError.
if mask_paths:
    load_index(0)
else:
    print(f"No masks found in {mask_dir}")