Exploring approaches to scale segmentation

In [None]:
%load_ext autoreload
%autoreload 2

import pathlib

# Root of the SOST scale data; the scale sub-directories (e.g. the ALP one used
# below) are collected from here.
parent_dir = pathlib.Path("~/zebrafish_rdsf/Rabia/SOST scales").expanduser()

# Raise explicitly rather than `assert`, which is stripped under `python -O`.
if not parent_dir.exists():
    raise FileNotFoundError(f"Scale data directory not found: {parent_dir}")

# Skip macOS Finder metadata and the "TIFs" folder; keep every other entry.
scale_dirs = tuple(
    d for d in parent_dir.glob("*") if d.stem not in {".DS_Store", "TIFs"}
)

In [None]:
from tqdm.notebook import tqdm

from scale_morphology.scales import read

# Just look at the ALP scales since they seem easiest to segment.
# Unpacking into a one-element tuple raises ValueError unless exactly one
# directory name contains "ALP".
(alp_dir,) = (d for d in scale_dirs if "ALP" in d.stem)

# Sort so that "the first file" is the same on every machine; glob order is
# filesystem-dependent.
alp_files = sorted(alp_dir.glob("*.lif"))
print(len(alp_files), "files")

# Only read the first .lif file for now; set to None to read all of them.
n_files_to_read = 1

names, images = [], []
for path in alp_files[:n_files_to_read]:
    # Each item from read_lif unpacks into a (name, image) pair. Iterating
    # directly (rather than via `zip(*...)`) also copes with a file that
    # yields no images, where unpacking an empty zip would raise ValueError.
    for name, img in read.read_lif(path):
        names.append(name)
        images.append(img)

In [None]:
"""
Convert to greyscale
"""

import math

import matplotlib.pyplot as plt
from skimage.color import rgb2gray


def factor_int(n):
    """Split ``n`` into a near-square ``(rows, cols)`` pair with ``rows * cols == n``.

    ``rows`` is the largest divisor of ``n`` that does not exceed
    ``ceil(sqrt(n))``, and ``cols = n // rows``. ``rows`` may therefore be
    larger than ``cols`` (e.g. 6 -> (3, 2)); a prime ``n`` gives ``(1, n)``.

    Raises:
        ValueError: if ``n`` is less than 1 (there is no grid for it).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    rows = math.ceil(math.sqrt(n))
    # Integer arithmetic throughout: walk down to the nearest divisor of n.
    while n % rows:
        rows -= 1
    return rows, n // rows


def plot_imgs(images, **plot_kw):
    """Show ``images`` on a near-square grid of axes with the axes hidden.

    The grid shape comes from ``factor_int(len(images))``. Extra keyword
    arguments are forwarded to ``Axes.imshow``; ``cmap`` defaults to
    ``"gray"`` but may be overridden by the caller. Does nothing for an
    empty ``images``.
    """
    if len(images) == 0:
        return

    nrows, ncols = factor_int(len(images))
    # Let callers override the colour map instead of colliding with it.
    plot_kw.setdefault("cmap", "gray")

    # figsize is (width, height) = (columns, rows). squeeze=False keeps `axes`
    # a 2-D array even for a single image, so `.flat` always works.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False
    )
    for axis, img in zip(tqdm(axes.flat), images):
        axis.imshow(img, **plot_kw)
        axis.set_axis_off()
    fig.tight_layout()


# Convert only the first 16 loaded images (presumably to keep the plots manageable).
greyscale = list(map(rgb2gray, tqdm(images[:16])))

In [None]:
# Inspect the greyscale conversions.
plot_imgs(images=greyscale)

In [None]:
"""
Background subtraction - skipped for now
"""

In [None]:
"""
Blur to remove noise
"""
from skimage.filters import gaussian

# Gaussian smoothing (sigma = 3 pixels) of every greyscale image.
blurred = [gaussian(grey_img, sigma=3) for grey_img in tqdm(greyscale)]
plot_imgs(blurred)

In [None]:
"""
Contrast enhance
"""

from skimage.exposure import equalize_adapthist

# Adaptive histogram equalisation (CLAHE) of each blurred image, 2001-pixel kernel.
enhanced = [
    equalize_adapthist(blurred_img, kernel_size=2001)
    for blurred_img in tqdm(blurred)
]

In [None]:
# Inspect the contrast-enhanced images.
plot_imgs(images=enhanced)

In [None]:
"""
Threshold
"""

from skimage.filters import threshold_minimum, threshold_mean


def threshold(i):
    """Return a boolean mask that is True where ``i`` is below either threshold.

    Combines ``threshold_minimum`` and ``threshold_mean`` with a logical OR,
    i.e. ``i < max(t_min, t_mean)``; darker pixels count as foreground.

    ``threshold_minimum`` raises ``RuntimeError`` when it cannot find two
    maxima in the histogram. In that case the mean threshold is used on its
    own, so one awkward image does not abort the whole list comprehension.
    """
    below_mean = i < threshold_mean(i)
    try:
        t_min = threshold_minimum(i)
    except RuntimeError:
        return below_mean
    return (i < t_min) | below_mean


# Binarise every enhanced image, then inspect the result.
thresholded = list(map(threshold, tqdm(enhanced)))

plot_imgs(thresholded)

In [None]:
"""
Remove objects touching the image border (clear_border)
"""
from skimage.segmentation import clear_border
# Drop foreground objects connected to the image border in each binary mask.
a = list(map(clear_border, tqdm(thresholded)))

plot_imgs(a)


In [None]:
"""
Fill holes
"""

In [None]:
"""
Binary opening
"""

In [None]:
# These are only imported in a later cell; import here so this cell runs on its own.
import pathlib

import numpy as np
import tifffile


def _read_binary_mask(path):
    """Read a mask TIFF and return it as a uint8 image with values 0 or 255.

    Any non-zero pixel counts as foreground. Binarising *before* scaling
    avoids uint8 wrap-around: ``255 * mask.astype(np.uint8)`` turns a pixel
    that is already 255 into 1 (255 * 255 mod 256).
    """
    return (tifffile.imread(path) > 0).astype(np.uint8) * 255


# NOTE(review): `mask_dir` is only assigned in a later cell of this notebook;
# make sure it is defined (and points at the intended directory) first.
# Sorted so that index k refers to the same file name in both lists when the
# directories share file names.
masks = [_read_binary_mask(f) for f in sorted(mask_dir.glob("*.tif"))]
old_masks = [
    _read_binary_mask(f)
    for f in sorted((pathlib.Path("segmentation_stuff") / "masks").glob("*.tif"))
]

In [None]:
import napari
import numpy as np
import tifffile
from pathlib import Path

# Masks to review, in sorted order; each one is paired with the raw image of
# the same file name inside `img_dir`.
# NOTE(review): relies on `mask_dir` and `img_dir` already existing in the session.
mask_paths = sorted(mask_dir.glob("*.tif"))
img_paths = {mask_path.name: img_dir / mask_path.name for mask_path in mask_paths}

# Mutable editor state shared by the callbacks: index of the current mask, the
# viewer, and the labels layer (None until the first mask has been loaded).
state = {"i": 0, "viewer": None, "labels": None}


def load_index(i):
    """Display mask number ``i`` and its matching raw image in the viewer.

    The first call adds an "image" layer and a semi-transparent "mask" labels
    layer; later calls swap new data into those same two layers. The window
    title shows the 1-based position and the file name.
    """
    mask_path = mask_paths[i]
    raw = tifffile.imread(img_paths[mask_path.name])
    mask = tifffile.imread(mask_path).astype(np.uint8)

    layers_exist = state["labels"] is not None
    if layers_exist:
        state["image"].data = raw
        state["labels"].data = mask
    else:
        state["image"] = viewer.add_image(raw, name="image")
        state["labels"] = viewer.add_labels(mask, name="mask", opacity=0.5)

    viewer.title = f"{i+1}/{len(mask_paths)} : {mask_path.name}"


def save_current():
    """Write the mask currently being edited to ``out_dir / "cleaned_masks"``.

    The file keeps the original mask's file name and is stored as uint8 with
    every non-zero label set to 255. The output directory is created if it
    does not exist yet, since writing into a missing directory fails.
    """
    name = mask_paths[state["i"]].name
    out_path = out_dir / "cleaned_masks" / name
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tifffile.imwrite(out_path, (state["labels"].data > 0).astype(np.uint8) * 255)
    print(f"Saved {out_path}")


# Open the editing window and keep a handle on it in the shared state as well.
state["viewer"] = viewer = napari.Viewer()


@viewer.bind_key("s")
def _save(v):
    """Key 's': save the currently shown mask and stay on it."""
    save_current()


@viewer.bind_key("n")
def _next(v):
    """Key 'n': save the current mask, then move to the next one if there is one."""
    save_current()
    at_last = state["i"] >= len(mask_paths) - 1
    if at_last:
        return
    state["i"] += 1
    load_index(state["i"])


@viewer.bind_key("p")
def _prev(v):
    """Key 'p': save the current mask, then move back one if not already at the first."""
    save_current()
    at_first = state["i"] <= 0
    if at_first:
        return
    state["i"] -= 1
    load_index(state["i"])


# Show the first mask. Fail with a clear message (instead of an IndexError
# from inside load_index) when the mask directory has no .tif files.
if not mask_paths:
    raise FileNotFoundError(f"No .tif masks found in {mask_dir}")
load_index(0)

In [None]:
"""
Read in a couple of the masks and look at them
"""

import numpy as np
import matplotlib.pyplot as plt


def plot_masks(masks):
    """Draw up to four masks on a 2x2 grid with one shared colour scale and colour bar.

    Only the first four masks are shown, and all grid cells have their axes
    hidden, including unused ones. Every panel uses the same ``vmin``/``vmax``
    so the single colour bar, attached next to the last drawn mask, is valid
    for all of them. An empty ``masks`` yields an empty grid.
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for axis in axes.flat:
        axis.set_axis_off()

    shown = list(masks)[: axes.size]
    if shown:
        # Shared colour limits across every displayed mask.
        vmin = min(np.min(mask) for mask in shown)
        vmax = max(np.max(mask) for mask in shown)
        for mask, axis in zip(shown, axes.flat):
            im = axis.imshow(mask, vmin=vmin, vmax=vmax)
        fig.colorbar(im)

    fig.tight_layout()

In [None]:
import pathlib
import tifffile
from itertools import islice

mask_dir = pathlib.Path("segmentation_stuff") / "masks"

# Read only .tif files: a bare "*" glob also picks up non-image files (e.g.
# macOS .DS_Store) that tifffile cannot read. Sort so that the same four
# masks are shown on every run.
masks = [tifffile.imread(f) for f in islice(sorted(mask_dir.glob("*.tif")), 4)]

plot_masks(masks)