Segment Confocal Images
====
Test to see if we can directly segment images from the confocal scanner, which will be faster.

In [None]:
"""
Read in the images
"""

import pathlib

# Root folder holding the microscopy images; "~" is expanded to the user's home
parent_dir = pathlib.Path(
    "~/zebrafish_rdsf/Carran/Postgrad/Slidescanner images circularity tests- for Rich"
).expanduser()
# Stop straight away if the root folder cannot be found
assert parent_dir.exists()

# One sub-folder per group of images
male_dir = parent_dir.joinpath("Single channel 3m male onto")
female_dir = parent_dir.joinpath("Single channel 3y 6m female onto")

# Both sub-folders must be present before any images are read
assert male_dir.exists()
assert female_dir.exists()

In [None]:
"""
Load the scale images - there are only 3 in each directory
"""

import tqdm
import tifffile

# Sort the matches so the images are always handled in the same order
male_paths = sorted(male_dir.glob("*.tif"))
female_paths = sorted(female_dir.glob("*.tif"))

# Read every TIFF into memory
male_imgs = list(map(tifffile.imread, male_paths))
female_imgs = list(map(tifffile.imread, female_paths))

# Sanity check: exactly three images are expected per directory
assert len(male_imgs) == 3
assert len(female_imgs) == 3

In [None]:
import matplotlib.pyplot as plt

# 2 x 3 grid of panels: one per image, male images first
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))

for axis, img in zip(axes.flat, [*male_imgs, *female_imgs]):
    axis.imshow(img, cmap="binary")
    # Pixel coordinates are not informative here, so hide the ticks
    axis.set(xticks=[], yticks=[])

In [None]:
"""Try to enhance contrast"""

from tqdm.notebook import tqdm
from skimage.filters import gaussian
from skimage.exposure import equalize_adapthist


def preprocess(grey_img, *, sigma=3, kernel_size=2001):
    """
    Smooth a greyscale image, then boost its local contrast

    The image is Gaussian-blurred and the result is passed through
    adaptive histogram equalisation.

    :param grey_img: greyscale image to enhance
    :param sigma: standard deviation of the Gaussian blur
    :param kernel_size: kernel size passed to the adaptive histogram
        equalisation
    :returns: the blurred, contrast-equalised image

    """
    blurred = gaussian(grey_img, sigma=sigma)
    # Giant kernel seems to work best, since the scale is also very large
    # Unfortunately this does make things slow
    return equalize_adapthist(blurred, kernel_size=kernel_size)


# Enhance every image (male first, then female), with a progress bar
# because each one is slow to process
preprocessed = [preprocess(raw_img) for raw_img in tqdm([*male_imgs, *female_imgs])]

In [None]:
# Same 2 x 3 layout as before, now showing the contrast-enhanced images
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))

for axis, img in zip(axes.flat, preprocessed):
    axis.imshow(img, cmap="binary")
    # Hide the tick marks on both axes
    axis.set(xticks=[], yticks=[])

In [None]:
"""
Download SAM model weights
"""

import requests

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
sam_dir = pathlib.Path("checkpoints/")
sam_dir.mkdir(exist_ok=True)
sam_path = sam_dir / "sam_vit_h_4b8939.pth"

if not sam_path.is_file():
    # Download to a temporary name first: if the transfer fails part-way we
    # must not leave a file at sam_path, otherwise the check above would skip
    # the download next time and the broken checkpoint would be loaded
    partial_path = sam_path.with_name(sam_path.name + ".part")

    # Stream the body in chunks rather than holding the whole checkpoint in memory
    with requests.get(url, stream=True, timeout=60) as response:
        # Don't save an HTTP error page as if it were the model weights
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    # Only a complete download ever gets the final name
    partial_path.replace(sam_path)

In [None]:
from segment_anything import sam_model_registry, SamPredictor

# Device that the model is moved to for inference
device = "cuda"

# Build the "vit_h" model from the downloaded checkpoint
sam = sam_model_registry["vit_h"](checkpoint=sam_path)
sam.to(device)
sam.eval()

# Predictor wrapper used for prompting the model below
model = SamPredictor(sam)

In [None]:
"""
Use SAM with no prior to see if it works
"""

import cv2
import numpy as np


def pos_points(img_shape, offset=500):
    """
    Prompt points for an image: its centre, plus four points around it

    :param img_shape: shape of the image; only the first two entries are
        used, the first as the height and the second as the width
    :param offset: how far (in pixels) to move the four outer points away
        from the centre
    :returns: array of shape (5, 2) holding the centre followed by the points
        shifted by +offset and -offset along the second coordinate, then
        by +offset and -offset along the first coordinate. The first
        coordinate of each point is derived from the width, the second from
        the height.

    No bounds checking is done: if offset is larger than half the image size,
    the outer points will lie outside the image.

    """
    height, width = img_shape[0], img_shape[1]
    centre = np.array([width // 2, height // 2])

    # Unit steps away from the centre, scaled by how far to move our points
    steps = np.array([[0, 0], [0, 1], [0, -1], [1, 0], [-1, 0]])
    return centre + steps * offset


def segment(img, sam_model):
    """
    Segment with SAM model

    :param img: single-channel (greyscale) image to segment
    :param sam_model: predictor to use; it must provide `set_image` and
        `predict` (e.g. a `SamPredictor`)
    :returns: the masks and scores returned by `sam_model.predict`

    """
    # Turn the greyscale image back to RGB, since this is what SAM expects
    # NOTE(review): SamPredictor.set_image is documented as taking uint8 pixel
    # values in [0, 255], but a float32 image is passed here - confirm that
    # the installed version handles float input as intended
    rgb = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_GRAY2RGB)

    sam_model.set_image(rgb)

    # Enforce that the scale overlaps with the centre of the image
    points = pos_points(img.shape)

    # Predict with the predictor we were given (and set the image on), not
    # with whatever global predictor happens to exist
    masks, scores, _ = sam_model.predict(
        point_coords=points,
        # Label 1 for every point: they are all positive prompts
        point_labels=np.ones(len(points), dtype=int),
        multimask_output=False,
    )
    return masks, scores


# Segment every preprocessed image, keeping the results in image order
masks, scores = [], []
for img in tqdm(preprocessed):
    img_masks, img_scores = segment(img, model)

    masks.append(img_masks)
    scores.append(img_scores)

In [None]:
from scale_morphology.scales import plotting

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))


for axis, img, mask, score in zip(axes.flat, preprocessed, masks, scores):
    # The prompt points are only needed by the optional scatter below
    pts = pos_points(img.shape)

    # Draw the image first, then lay the predicted mask over the top of it
    axis.imshow(img, cmap="binary")
    axis.imshow(mask.squeeze(), cmap=plotting.clear2colour_cmap("red"))

    # Uncomment to also show where the prompt points were placed
    # axis.scatter(*pts.T, color="b", s=5)

    axis.set(xticks=[], yticks=[])