Conversion + Segmentation of LIF files
====
This notebook walks through converting and segmenting the ALP-stained LIF files stored on the RDSF.

If you want to segment something else (i.e. you have taken new images and want to segment them),
I'd suggest writing a new notebook, taking inspiration from this one or any others in this directory.

- Read in the LIF files
- Save them as TIFs (they're nicer for me to work with)
- Segment them with my pipeline

Part 1 - converting to TIF
----

In [None]:
"""
Get paths to LIF files
"""

import pathlib

# the directory containing microscopy images
parent_dir = pathlib.Path(
    "~/zebrafish_rdsf/Carran/Postgrad/Scale images from WT_spp1_sost"
).expanduser()
assert parent_dir.exists()

lif_paths = list(parent_dir.glob("*"))

In [None]:
"""
Get rid of some weird ones
"""
lif_paths = [l for l in lif_paths if "out of focus" not in str(l) and not l.is_dir() and not l.stem == ".DS_Store"]

In [None]:
"""
For now, select only the ALP stained ones, because they're easier to segment
"""

import pandas as pd
from scale_morphology.scales import metadata


lif_df = pd.DataFrame(columns=["path"])
lif_df["path"] = lif_paths
lif_df["stain"] = lif_df["path"].apply(
    lambda x: metadata.stain(str(x).replace("lif", "tif"))
)

In [None]:
"""
Extract ages
"""

lif_df["age"] = lif_df["path"].apply(lambda x: metadata.age(x.stem))

In [None]:
"""
Select only the ALP ones, for now
"""
alp_df = lif_df[lif_df["stain"] == "ALP"]

In [None]:
"""
Read them into arrays
"""

from tqdm import tqdm
from scale_morphology.scales import read

names, imgs, lifs = [], [], []
for path in tqdm(alp_df["path"]):
    name, img = zip(*read.read_lif(path))

    names += name
    imgs += img
    lifs += [path.name for _ in name]

In [None]:
"""
Munge names and save as TIFs
"""

import tifffile

out_dir = parent_dir / "TIFs"
try:
    out_dir.mkdir(exist_ok=False)
    for name, img, lif in zip(tqdm(names), imgs, lifs, strict=True):
        path = (
            name
            + "__"
            + lif.replace(".lif", "").replace(".", "_").replace(" ", "_")
            + ".tif"
        )
        tifffile.imwrite(out_dir / path, img)
except FileExistsError:
    print("dir", out_dir, "exists")

Part 2 - segmenting out the scales
----
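Assuming the above has run, we will now segment out the scales stored in the TIF files.
The input and output directories are read from `config.yaml` (keys `input_tif_dir` and `auto_segmentation_dir`).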

In [None]:
"""
Get the desired output directory from the config file
"""

import pathlib
import yaml

with open("hi.txt", "w") as f:
    f.write("hi\n")

cfg_path = pathlib.Path("config.yaml")
with open(cfg_path, "r") as f:
    out_dir = (
        pathlib.Path(yaml.safe_load(f)["auto_segmentation_dir"]).expanduser().resolve()
    )
out_dir.mkdir(exist_ok=True, parents=True)

In [None]:
"""
Get the input directory - from my RDSF mount point
"""

with open(cfg_path, "r") as f:
    in_dir = pathlib.Path(yaml.safe_load(f)["input_tif_dir"]).expanduser().resolve()

assert in_dir.is_dir(), in_dir
in_paths = sorted(list(in_dir.glob("*.tif")))

out_paths = [out_dir / (name.stem + "_segmentation.tif") for name in in_paths]

In [None]:
"""
Download SAM model weights
"""

import requests

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
sam_dir = pathlib.Path("checkpoints/")
sam_dir.mkdir(exist_ok=True)
sam_path = sam_dir / "sam_vit_h_4b8939.pth"

if not sam_path.is_file():
    with open(sam_path, "wb") as f:
        f.write(requests.get(url).content)

In [None]:
"""
Segment every input TIF that doesn't already have an output file
"""

import tifffile
from tqdm import tqdm

from scale_morphology.scales import segmentation

for tif_path, mask_path in zip(tqdm(in_paths), out_paths, strict=True):
    # outputs left over from an earlier run are not recomputed
    if not mask_path.exists():
        img = tifffile.imread(tif_path)
        mask = segmentation.segment_alp(
            img, device="cuda", model_type="vit_h", model_checkpoint=sam_path
        )
        tifffile.imwrite(mask_path, mask)

        # drop our references to the arrays before moving on to the next image
        del img, mask

In [None]:
"""
Now that the segmentation has run, have a look at them just to check
"""

import numpy as np
import tifffile
import matplotlib.pyplot as plt

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
indices = np.random.randint(0, 600, size=16)

for i, axis in zip(indices, axes.flat):
    axis.imshow(tifffile.imread(in_paths[i]))
    axis.imshow(tifffile.imread(out_paths[i]), alpha=0.5)
    axis.set_axis_off()

fig.tight_layout()
fig.savefig("tmp_random_images.png")