In [3]:
from cell_paint_seg.utils import (
    get_id_to_path,
    get_id_from_name_96,
    check_valid_labels,
    label_celltype,
)
from cell_paint_seg.image_io import read_ims, convert_to_hdf5

from tqdm import tqdm
from skimage import io, exposure, filters, measure, segmentation
import matplotlib.pyplot as plt
import numpy as np
import napari
import random
import os
from pathlib import Path
import time

In [4]:
# Channel names, paired positionally with the per-channel images returned by `read_ims`.
channels = ["ER", "DNA", "Mito", "Actin", "RNA", "Golgi/membrane"]

# Root of experiment 2; every data directory used below lives under it.
exp_dir = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2"

# All unmixed Z-projection TIFFs of the experiment.
dir_all = f"{exp_dir}/Experiment2_DB_Zprojection_unmixed"

# TIFF folders of the two annotated splits.
dir_train_tifs = f"{exp_dir}/train_set/tifs"
dir_test_tifs = f"{exp_dir}/test_set/tifs"

# NOTE: despite its name, `dir_test` (used by the "view" and "check labels" cells)
# currently points at the *train* split. Set it to `dir_test_tifs` to work on test images.
dir_test = dir_train_tifs

In [5]:
# Map every image ID in the full (unsplit) experiment to its list of per-channel files.
id_to_path = get_id_to_path(
    id_from_name=get_id_from_name_96,
    tag=".tif",
    path_dir=dir_all,
)

# Sample

In [None]:
# print(random.sample(list(id_to_path.keys()), 3))

Test-set image IDs: ['s158', 's160', 's039']

Train-set image IDs: ['s001', 's167', 's143']

# Convert

In [None]:
# Channel ordering forwarded to convert_to_hdf5 (its exact semantics are defined
# there - TODO document what -1 means).
order = [-1, 0, 3, 2, 1, 4]

parent_dir = Path(
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/train_set/"
)
tif_path = parent_dir / "tifs"
hdf5_path = parent_dir / "hdf5s"

id_to_path = get_id_to_path(tif_path, tag=".tif", id_from_name=get_id_from_name_96)
image_ids = list(id_to_path.keys())
# Fail with a readable message rather than an IndexError on image_ids[0].
assert image_ids, f"no .tif images found in {tif_path}"
n_files = len(image_ids)
n_channels = len(id_to_path[image_ids[0]])

im_shape = convert_to_hdf5(id_to_path, hdf5_path, order=order, preprocess=True)

# View a sample

In [10]:
sample = "s167"  # one of the train-set IDs listed above

id_to_path = get_id_to_path(
    path_dir=dir_test, tag=".tif", id_from_name=get_id_from_name_96
)

# One image per channel, in the same order as `channels`.
ims = read_ims(id_to_path[sample])

# Display preprocessing. Each loaded image appears to carry a trailing axis that is
# collapsed here by its L2 norm - TODO confirm what that axis holds.
ims = [im.astype("float64") for im in ims]
ims = [np.linalg.norm(im, axis=-1) for im in ims]
# Scale to [0, 1]; leave all-zero channels alone, since 0 / 0 would yield NaNs.
ims = [im / np.amax(im) if np.amax(im) > 0 else im for im in ims]
ims = [exposure.equalize_adapthist(im, clip_limit=0.03) for im in ims]

# RGB composite: red = Mito, green = RNA, blue = DNA.
im_rgb = np.stack([ims[2], ims[4], ims[1]], axis=2)

In [None]:
# Render the RGB composite at high resolution with the axis ticks hidden.
fig, ax = plt.subplots()
ax.imshow(im_rgb)
fig.set_dpi(300)
ax.set_xticks([])
ax.set_yticks([])
plt.show()

# Initial segmentation and napari review

In [7]:
# Where the saved annotations for `sample` live (nuclei, somas, cells); the next
# cell falls back to an automatic segmentation for any that are missing.
seg_dir_v3 = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3"
seg_path_nuc, seg_path_somas, seg_path_cells = (
    f"{seg_dir_v3}/{sample}-{kind}.tif" for kind in ("nuclei", "somas", "cells")
)

In [None]:
# Minimum connected-component sizes (pixels) kept by the automatic fallbacks below.
MIN_NUC_AREA = 64
MIN_SOMA_AREA = 81

# Nuclei: load the saved annotation if present, otherwise Otsu-threshold the DNA
# channel (ims[1]) and keep components of at least MIN_NUC_AREA pixels.
if os.path.exists(seg_path_nuc):
    print(f"nucleus file found: {seg_path_nuc}")
    seg_init_nuc = io.imread(seg_path_nuc)
else:
    threshold = filters.threshold_otsu(ims[1])
    seg_init = measure.label(ims[1] > threshold)
    regprops = measure.regionprops(seg_init)

    # Surviving components are renumbered 1..k in regionprops order.
    seg_init_nuc = np.zeros_like(seg_init)
    counter = 1
    for props in regprops:
        if props["area"] >= MIN_NUC_AREA:
            # props.coords holds exactly this region's pixels, avoiding a
            # full-image comparison per region.
            seg_init_nuc[tuple(props.coords.T)] = counter
            counter += 1

# Somas: load the saved annotation if present, otherwise Otsu-threshold the RNA
# channel (ims[4]), union it with the nuclei, drop small components, and split the
# resulting foreground into regions seeded by the nucleus labels.
if os.path.exists(seg_path_somas):
    print(f"soma file found: {seg_path_somas}")
    seg_init_soma = io.imread(seg_path_somas)
else:
    threshold = filters.threshold_otsu(ims[4])
    seg_init = ims[4] > threshold  # RNA channel
    seg_init = np.logical_or(seg_init, seg_init_nuc > 0)  # add nucleus
    seg_init = measure.label(seg_init)
    regprops = measure.regionprops(seg_init)

    mask_init = np.zeros_like(seg_init)
    for props in regprops:
        if props["area"] >= MIN_SOMA_AREA:
            mask_init[tuple(props.coords.T)] = 1

    seg_init_soma = segmentation.watershed(
        mask_init, markers=seg_init_nuc, mask=mask_init
    )

# Cells: load the saved annotation if present, otherwise start from the somas.
if os.path.exists(seg_path_cells):
    print(f"cell file found: {seg_path_cells}")
    seg_init_cells = io.imread(seg_path_cells)
else:
    # Copy rather than alias: napari Labels layers paint into the array they are
    # given, so a shared array would make edits on "cells" also change "somas".
    seg_init_cells = seg_init_soma.copy()

In [None]:
# Microns per pixel for every layer; the scale bar below is labelled in "um".
layer_scale = (0.6, 0.6)

viewer = napari.Viewer()
for channel_im, channel_name in zip(ims, channels):
    viewer.add_image(channel_im, name=channel_name, scale=layer_scale)
viewer.add_image(im_rgb, name="rgb", rgb=True, scale=layer_scale)

print(
    f"max nuc: {np.amax(seg_init_nuc)}, max soma: {np.amax(seg_init_soma)}, max cell: {np.amax(seg_init_cells)}"
)
# Editable label layers, added in the order nuclei -> somas -> cells.
for layer_name, layer_labels in (
    ("nuclei", seg_init_nuc),
    ("somas", seg_init_soma),
    ("cells", seg_init_cells),
):
    viewer.add_labels(layer_labels, name=layer_name, scale=layer_scale)

viewer.scale_bar.visible = True
viewer.scale_bar.unit = "um"

# Clean up labels

In [None]:
dir_seg = (
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/"
)

# Pixel size used to report positions in microns (same 0.6 as the napari layer scale).
UM_PER_PX = 0.6
# An inner object is matched to an outer one when strictly more than this fraction
# of its pixels lies inside it.
MIN_RECALL = 0.9


def _split_disconnected_labels(seg, comp):
    """Give every extra connected piece of each label in `seg` a fresh label, in place.

    The first piece (in regionprops order) keeps its label; the others get labels
    above the array's original maximum. `comp` is only used in the printed message.
    """
    new_label = np.amax(seg) + 1
    for label in np.unique(seg):
        if label == 0:  # background
            continue
        lbl = measure.label(seg == label)
        if np.amax(lbl) != 1:
            regprops = measure.regionprops(lbl)
            print(
                f"{comp} {label} has {np.amax(lbl)} disconnected components - renaming {[np.multiply(props.centroid, UM_PER_PX) for props in regprops]}..."
            )
            for props in regprops[1:]:
                seg[lbl == props["label"]] = new_label
                new_label += 1


def _overlapping_labels(seg_inner, inner_label, seg_outer):
    """Return, in ascending order, the non-background labels of `seg_outer` that
    cover more than MIN_RECALL of the pixels labelled `inner_label` in `seg_inner`.

    One histogram of the outer labels under the inner mask replaces a full-image
    comparison per outer label; labels with no overlap have recall 0 and never match.
    """
    inner_mask = seg_inner == inner_label
    outer_labels, counts = np.unique(seg_outer[inner_mask], return_counts=True)
    recalls = counts / np.count_nonzero(inner_mask)
    return [
        outer_label
        for outer_label, recall in zip(outer_labels, recalls)
        if outer_label != 0 and recall > MIN_RECALL
    ]


def _relabel(seg, old_to_new):
    """Return a new array where each key label of `old_to_new` becomes its value; all other pixels are 0."""
    out = np.zeros_like(seg)
    for old_label, new_label in old_to_new.items():
        out[seg == old_label] = new_label
    return out


id_to_path_seg = get_id_to_path(dir_seg, tag=".tif", id_from_name=get_id_from_name_96)

# Cleaned (nuclei, somas, cells) per sample, so earlier samples are not lost when
# the loop moves on. NOTE(review): nothing here writes them back to disk.
cleaned_segs = {}

for sample_id, paths in id_to_path_seg.items():
    loaded = {}
    for path in paths:
        for kind in ("nuclei", "somas", "cells"):
            if kind in str(path):
                loaded[kind] = io.imread(path)
                break
    missing = [kind for kind in ("nuclei", "somas", "cells") if kind not in loaded]
    if missing:
        # Skip instead of silently reusing the arrays left over from the previous sample.
        print(f"{sample_id}: missing {missing} segmentation(s) - skipping")
        continue
    seg_nuc, seg_soma, seg_cell = loaded["nuclei"], loaded["somas"], loaded["cells"]

    print("Separating components...")
    for comp, seg in zip(["Nuc", "Soma", "Cell"], [seg_nuc, seg_soma, seg_cell]):
        _split_disconnected_labels(seg, comp)

    print("Matching somas to nuclei...")
    soma_to_nuc = {}
    for nuc_label in tqdm(np.unique(seg_nuc)):
        if nuc_label == 0:
            continue
        matches = _overlapping_labels(seg_nuc, nuc_label, seg_soma)
        for soma_label in matches:
            # A soma may contain at most one nucleus.
            assert soma_label not in soma_to_nuc.keys()
            soma_to_nuc[soma_label] = nuc_label
        if not matches:
            wher = np.where(seg_nuc == nuc_label)
            print(
                f"nuc {nuc_label} not found in soma {(wher[0][0]*UM_PER_PX, wher[1][0]*UM_PER_PX)}"
            )
    # Somas take their nucleus' label; somas without a nucleus are dropped.
    seg_soma = _relabel(seg_soma, soma_to_nuc)

    print("Matching somas to cells...")
    cell_to_soma = {}
    for soma_label in tqdm(np.unique(seg_soma)):
        if soma_label == 0:
            continue
        matches = _overlapping_labels(seg_soma, soma_label, seg_cell)
        for cell_label in matches:
            if cell_label in cell_to_soma.keys():
                # Reported but tolerated: the later soma wins.
                print(f"cell {cell_label} already matched to soma")
            cell_to_soma[cell_label] = soma_label
        if not matches:
            wher = np.where(seg_soma == soma_label)
            print(
                f"soma {soma_label} not found in cells {(wher[0][0]*UM_PER_PX, wher[1][0]*UM_PER_PX)}"
            )
    seg_cell = _relabel(seg_cell, cell_to_soma)

    print("Subsetting...")
    # Make nuclei lie inside somas and somas inside cells: grow each soma over its
    # nucleus, then each cell over its (grown) soma.
    for label in np.unique(seg_nuc):
        if label == 0:
            continue
        seg_soma[seg_nuc == label] = label
        seg_cell[seg_soma == label] = label

    print("Relabelling consecutively...")
    # Nucleus labels (ascending) -> 1..k, applied identically to all three images.
    consecutive = {
        old_label: new_label
        for new_label, old_label in enumerate(
            (lab for lab in np.unique(seg_nuc) if lab != 0), start=1
        )
    }
    seg_nuc = _relabel(seg_nuc, consecutive)
    seg_soma = _relabel(seg_soma, consecutive)
    seg_cell = _relabel(seg_cell, consecutive)

    cleaned_segs[sample_id] = (seg_nuc, seg_soma, seg_cell)

In [None]:
# NOTE(review): `ims`/`im_rgb` come from the "view" cell's `sample`, while
# seg_nuc/seg_soma/seg_cell hold the last sample processed by the clean-up loop;
# they may not be the same image.
viewer = napari.Viewer()
for channel_name, channel_im in zip(channels, ims):
    viewer.add_image(channel_im, name=channel_name, scale=(0.6, 0.6))

viewer.add_image(im_rgb, name="rgb", rgb=True, scale=(0.6, 0.6))

max_nuc, max_soma, max_cell = np.amax(seg_nuc), np.amax(seg_soma), np.amax(seg_cell)
print(f"max nuc: {max_nuc}, max soma: {max_soma}, max cell: {max_cell}")

cleaned_layers = {"nuclei": seg_nuc, "somas": seg_soma, "cells": seg_cell}
for layer_name in cleaned_layers:
    viewer.add_labels(cleaned_layers[layer_name], name=layer_name, scale=(0.6, 0.6))

# Scale is 0.6 um per pixel, so label the scale bar in microns.
viewer.scale_bar.visible = True
viewer.scale_bar.unit = "um"

# Check labels

In [None]:
dir_seg = (
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/"
)

id_to_path = get_id_to_path(
    path_dir=dir_test, tag=".tif", id_from_name=get_id_from_name_96
)
id_to_path_seg = get_id_to_path(dir_seg, tag=".tif", id_from_name=get_id_from_name_96)

# For every image, open a napari viewer with each label image over its reference
# channel, run check_valid_labels(nuclei, somas, cells) and print the cell count.
# Loop variables are deliberately not named `sample`/`ims`, so the sample chosen in
# the "view" cell is not overwritten.
for sample_id in id_to_path.keys():
    if sample_id not in id_to_path_seg:
        # dir_test and dir_seg can point at different splits; report and move on
        # instead of raising KeyError on the first mismatch.
        print(f"{sample_id}: no segmentation in {dir_seg} - skipping")
        continue

    chan_ims = read_ims(id_to_path[sample_id])
    # NOTE(review): the indexing below assumes the label files are returned as
    # [cells, nuclei, somas] (alphabetical by suffix) - confirm get_id_to_path's ordering.
    segs = read_ims(id_to_path_seg[sample_id])

    chan_ims = [im.astype("float64") for im in chan_ims]
    chan_ims = [np.linalg.norm(im, axis=-1) for im in chan_ims]
    # Leave all-zero channels alone, since 0 / 0 would yield NaNs.
    chan_ims = [im / np.amax(im) if np.amax(im) > 0 else im for im in chan_ims]
    chan_ims = [exposure.equalize_adapthist(im, clip_limit=0.03) for im in chan_ims]

    viewer = napari.Viewer()
    viewer.add_image(chan_ims[1], scale=(0.6, 0.6), name="DNA")
    viewer.add_labels(segs[1], scale=(0.6, 0.6), name="nuclei")

    viewer.add_image(chan_ims[4], scale=(0.6, 0.6), name="RNA")
    viewer.add_labels(segs[2], scale=(0.6, 0.6), name="soma")

    viewer.add_image(chan_ims[3], scale=(0.6, 0.6), name="Actin")
    viewer.add_labels(segs[0], scale=(0.6, 0.6), name="cell")

    check_valid_labels(segs[1], segs[2], segs[0])
    print(f"{sample_id}: {len(np.unique(segs[0]))-1} cells")

# Classify cell type

In [None]:
import napari
import numpy as np

# Scratch prototype: bind the "a" key in a viewer showing a dummy 10x10x10 image.
viewer = napari.Viewer()
viewer.add_image(np.ones((10, 10, 10)))


@viewer.bind_key("a")
def get_input(viewer):
    """Key callback: add a new points layer holding one point at (1, 1) with size 1.

    The previously imported-but-unused `magicgui.widgets.request_values` was dropped;
    it only added an import (and a possible ImportError) on every key press.
    """
    viewer.add_points(np.array([[1, 1]]), size=np.array([1]))


# Hand control to napari's event loop so the key binding is live.
napari.run()

In [None]:
# Interactive cell-type labelling on the test split. `channels=[3, 4, 1]` presumably
# indexes the channel list above (Actin, RNA, DNA) - TODO confirm in label_celltype.
test_root = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set"
dir_im = f"{test_root}/tifs"
dir_seg = f"{test_root}/seg_gt_v3/"
outdir = f"{test_root}/seg_gt_v3/celltypes"

label_celltype(
    path_dir_im=dir_im,
    path_dir_gt=dir_seg,
    out_dir=outdir,
    channels=[3, 4, 1],
)