In [1]:
import numpy as np
import networkx as nx
from skimage import io, morphology, filters, exposure, measure
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import h5py
import napari
import pickle

from cell_paint_seg.utils import (
    get_id_to_path,
    get_id_from_name_96,
    combine_soma_cell_labels,
    combine_soma_nucleus_labels,
    check_valid_labels,
)
from cell_paint_seg.image_io import read_ims

In [40]:
def AP2(seg1, seg2, iou_threshold=0.5):
    """Computes the AP2 score between two segmentations.

    An object of ``seg1`` and an object of ``seg2`` are candidates for a match
    when their intersection-over-union (IoU) is strictly greater than
    ``iou_threshold``. A maximum one-to-one matching over the candidate pairs
    gives the true positives (TP); unmatched objects of ``seg1`` and ``seg2``
    are counted as FP and FN, and the score is ``TP / (TP + FP + FN)``. The
    score is symmetric in ``seg1`` and ``seg2``.

    Args:
        seg1 (np.array): first instance segmentation. 0 is background.
        seg2 (np.array): second instance segmentation. 0 is background.
        iou_threshold (float, optional): iou threshold. Defaults to 0.5.

    Returns:
        float: the AP2 score in [0, 1]. NaN when neither segmentation contains
            an object (the ratio is undefined), 0.0 when only one of them does.

    Raises:
        ValueError: if ``seg1`` and ``seg2`` do not have the same shape.
    """
    seg1 = np.asarray(seg1)
    seg2 = np.asarray(seg2)
    if seg1.shape != seg2.shape:
        raise ValueError(
            f"seg1 and seg2 must have the same shape, got {seg1.shape} and {seg2.shape}"
        )

    # A single pass over each image yields every label together with its area.
    labels1, counts1 = np.unique(seg1, return_counts=True)
    labels2, counts2 = np.unique(seg2, return_counts=True)
    area1 = {
        lab: cnt for lab, cnt in zip(labels1.tolist(), counts1.tolist()) if lab != 0
    }
    area2 = {
        lab: cnt for lab, cnt in zip(labels2.tolist(), counts2.tolist()) if lab != 0
    }

    # Degenerate inputs: nothing can be matched, so skip the graph entirely.
    if not area1 and not area2:
        return float("nan")
    if not area1 or not area2:
        return 0.0

    # construct a networkx graph that has a vertex for each object in seg1 and seg2
    G = nx.Graph()
    G.add_nodes_from([(0, i) for i in area1] + [(1, j) for j in area2])

    # add edges between nodes with an iou greater than the threshold
    for i in tqdm(list(area1), desc="Constructing graph", leave=False):
        # Labels of seg2 under object i, with the size of each overlap.
        overlap_labels, overlap_sizes = np.unique(
            seg2[seg1 == i], return_counts=True
        )
        for j, intersection in zip(overlap_labels.tolist(), overlap_sizes.tolist()):
            if j == 0:
                continue
            # |A or B| = |A| + |B| - |A and B|, so no full-image mask is needed.
            union = area1[i] + area2[j] - intersection
            if intersection / union > iou_threshold:
                # For thresholds >= 0.5 at most one j can pass (it must cover
                # more than half of object i), so no early exit is required.
                G.add_edge((0, i), (1, j), weight=1)

    # find the maximum weight matching (all weights are 1, so this maximizes
    # the number of matched pairs)
    matching = nx.max_weight_matching(G)

    # calculate the AP score
    tp = len(matching)
    fp = len(area1) - tp
    fn = len(area2) - tp
    return tp / (tp + fp + fn)

# Segmentation

In [3]:
def _otsu_size_filter_label(channel, min_size=64):
    """Otsu-thresholds one channel, removes small objects, and labels the rest.

    Args:
        channel (np.array): single-channel intensity image.
        min_size (int, optional): passed to ``morphology.remove_small_objects``
            as ``min_size``. Defaults to 64.

    Returns:
        np.array: labeled image produced by ``measure.label``.
    """
    threshold = filters.threshold_otsu(channel)
    foreground = channel > threshold
    foreground = morphology.remove_small_objects(foreground, min_size=min_size)
    return measure.label(foreground)


def seg_init_nuc(images, min_size=64):
    """Baseline nucleus segmentation from ``images[1]`` (Otsu + size filter).

    Args:
        images (sequence of np.array): per-channel images; index 1 is used
            (treated as the DNA channel by this notebook).
        min_size (int, optional): small-object size filter. Defaults to 64.

    Returns:
        np.array: labeled nucleus candidates.
    """
    return _otsu_size_filter_label(images[1], min_size=min_size)


def seg_init_soma(images, min_size=64):
    """Baseline soma segmentation from ``images[4]`` (Otsu + size filter).

    Args:
        images (sequence of np.array): per-channel images; index 4 is used
            (treated as the RNA channel by this notebook).
        min_size (int, optional): small-object size filter. Defaults to 64.

    Returns:
        np.array: labeled soma candidates.
    """
    return _otsu_size_filter_label(images[4], min_size=min_size)


def seg_init_cell(images, min_size=64):
    """Baseline cell segmentation from ``images[4]`` (Otsu + size filter).

    NOTE(review): this reads the same channel (index 4) as ``seg_init_soma``,
    so both return identical labels — confirm whether a different channel
    (e.g. the one displayed as "Actin" elsewhere in the notebook) was intended.

    Args:
        images (sequence of np.array): per-channel images; index 4 is used.
        min_size (int, optional): small-object size filter. Defaults to 64.

    Returns:
        np.array: labeled cell candidates.
    """
    return _otsu_size_filter_label(images[4], min_size=min_size)

In [4]:
def get_id_pre_hyph(name):
    """Returns the part of ``name`` before its first hyphen.

    A name without a hyphen is returned unchanged.
    """
    prefix, _, _ = name.partition("-")
    return prefix


def get_id_pre_us(name):
    """Returns the part of ``name`` before its first underscore.

    A name without an underscore is returned unchanged.
    """
    prefix, _, _ = name.partition("_")
    return prefix

In [None]:
# Display `id_to_path_ilastik`.
# NOTE(review): this name is first assigned in a later cell of this notebook, so on a
# fresh kernel this cell raises NameError — TODO move it below that cell or remove it.
id_to_path_ilastik

In [None]:
# Display `id_to_path_cellpose`.
# NOTE(review): this name is first assigned in a later cell of this notebook, so on a
# fresh kernel this cell raises NameError — TODO move it below that cell or remove it.
id_to_path_cellpose

In [None]:
# Evaluate three segmentation methods (Otsu baseline, ilastik, cellpose) against the
# ground-truth test-set segmentations and plot the AP2 score per sample.
# NOTE(review): absolute, machine-specific paths — consider a configurable data root.
dir_test = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/tifs"
dir_seg = (
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/"
)
dir_ilastik = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/segmentations/"
dir_cellpose = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/tifs_cellpose/cellpose_commandline/"

id_to_path = get_id_to_path(
    path_dir=dir_test, tag=".tif", id_from_name=get_id_from_name_96
)
id_to_path_seg = get_id_to_path(dir_seg, tag=".tif", id_from_name=get_id_from_name_96)
id_to_path_ilastik = get_id_to_path(
    dir_ilastik, tag=".tif", id_from_name=get_id_pre_hyph
)
id_to_path_cellpose = get_id_to_path(
    dir_cellpose, tag=".png", id_from_name=get_id_pre_us
)

# NOTE(review): `id` shadows the builtin; the name is kept because notebook cells
# share one global namespace.
for id, paths in id_to_path_seg.items():
    # Preprocess every channel: float conversion, norm over the last axis,
    # max-normalization, then adaptive histogram equalization.
    # NOTE(review): the norm over axis=-1 assumes each image has a trailing
    # channel axis — TODO confirm against `read_ims`.
    ims = read_ims(id_to_path[id])
    ims = [im.astype("float64") for im in ims]
    ims = [np.linalg.norm(im, axis=-1) for im in ims]
    ims = [im / np.amax(im) for im in ims]
    ims = [exposure.equalize_adapthist(im, clip_limit=0.03) for im in ims]

    ilastik_segs = read_ims(id_to_path_ilastik[id])
    cellpose_segs = read_ims(id_to_path_cellpose[id])

    # Long-format results for every threshold ...
    data_acc = []
    data_ious = []
    data_objects = []
    data_method = []

    # ... and the subset scored at IoU = 0.5 for the bar plot.
    data_acc_05 = []
    data_objects_05 = []
    data_method_05 = []

    for method in ["otsu, size filter", "ilastik", "cellpose"]:
        # seg_preds is ordered to line up with ["nuclei", "somas", "cells"] below.
        if method == "ilastik":
            seg_preds = [ilastik_segs[2], ilastik_segs[1], ilastik_segs[0]]
            # assert check_valid_labels(*seg_preds)
        elif method == "otsu, size filter":
            seg_soma = seg_init_soma(ims)
            seg_preds = [
                combine_soma_nucleus_labels(seg_soma, seg_init_nuc(ims)),
                seg_soma,
                combine_soma_cell_labels(seg_soma, seg_init_cell(ims)),
            ]
            # assert check_valid_labels(*seg_preds)
        elif method == "cellpose":
            # Only two predictions: zip() below stops after "nuclei" and "somas".
            seg_preds = [cellpose_segs[0], cellpose_segs[1]]

        for object_type, seg_pred in zip(["nuclei", "somas", "cells"], seg_preds):
            # Pick the ground-truth file whose path mentions this object type.
            # Fail loudly when there is none, so the ground truth of the previous
            # object type is never silently reused.
            gt_paths = [path for path in paths if object_type in str(path)]
            if not gt_paths:
                raise ValueError(
                    f"No ground-truth segmentation path containing {object_type!r} "
                    f"for sample {id!r}"
                )
            seg = io.imread(gt_paths[-1])

            for iou_threshold in tqdm(
                np.arange(0.4, 0.8, 0.05), desc="Varying iou_threshold", leave=False
            ):
                ap2 = AP2(seg, seg_pred, iou_threshold)
                data_acc.append(ap2)
                data_ious.append(iou_threshold)
                data_objects.append(object_type)
                data_method.append(method)

                # np.arange yields floats; compare with a tolerance rather than ==.
                if np.isclose(iou_threshold, 0.5):
                    data_acc_05.append(ap2)
                    data_objects_05.append(object_type)
                    data_method_05.append(method)

    data = {
        "iou_threshold": data_ious,
        "Accuracy": data_acc,
        "Object type": data_objects,
        "Method": data_method,
    }
    df = pd.DataFrame(data)

    data_05 = {
        "Threat score/AP (IoU=0.5)": data_acc_05,
        "Object type": data_objects_05,
        "Method": data_method_05,
    }
    df_05 = pd.DataFrame(data_05)

    # Bar plot of the score at IoU = 0.5, grouped by object type and method.
    sns.barplot(
        data=df_05, x="Object type", y="Threat score/AP (IoU=0.5)", hue="Method"
    )
    plt.title(id)
    plt.show()

    # Line plot of the AP2 score as a function of iou_threshold
    fig, ax = plt.subplots()
    sns.lineplot(
        x="iou_threshold",
        y="Accuracy",
        hue="Method",
        style="Object type",
        data=df,
        ax=ax,
    )
    plt.title(id)
    plt.show()

In [None]:
# Read a single TIFF from the `tifs_cellpose` folder (file name s039.tif) and show
# its pixel dtype as the cell output.
im = io.imread(
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/tifs_cellpose/s039.tif"
)

im.dtype

## Image

In [48]:
# Open one train-set sample in napari with its images and label layers.
# NOTE(review): this cell rebinds `dir_ilastik`, `dir_cellpose`, `id_to_path`,
# `id_to_path_ilastik` and `id_to_path_cellpose`, which the test-set evaluation cell
# also assigns; later cells see whichever of the two cells ran last.
dir_im = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/train_set/tifs"
# dir_seg = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/train_set/seg_gt_v3/"
dir_ilastik = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/train_set/segmentations_v1/"
dir_cellpose = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/train_set/segmentations/"

id_to_path = get_id_to_path(
    path_dir=dir_im, tag=".tif", id_from_name=get_id_from_name_96
)
# id_to_path_seg = get_id_to_path(dir_seg, tag=".tif", id_from_name=get_id_from_name_96)
id_to_path_ilastik = get_id_to_path(
    dir_ilastik, tag=".tif", id_from_name=get_id_pre_hyph
)
id_to_path_cellpose = get_id_to_path(
    dir_cellpose, tag=".tif", id_from_name=get_id_pre_hyph
)

for id, paths in id_to_path.items():
    # Only sample s167 is visualized; every other id is skipped.
    if id != "s167":
        continue
    # Same preprocessing as the evaluation cell: float conversion, norm over the
    # last axis, max-normalization, adaptive histogram equalization.
    ims = read_ims(id_to_path[id])
    ims = [im.astype("float64") for im in ims]
    ims = [np.linalg.norm(im, axis=-1) for im in ims]
    ims = [im / np.amax(im) for im in ims]
    ims = [exposure.equalize_adapthist(im, clip_limit=0.03) for im in ims]
    # im_rgb is only referenced by the commented-out add_image call at the bottom.
    im_rgb = np.stack([ims[3], ims[4], ims[1]], axis=2)

    # ilastik_segs is loaded but not displayed below.
    ilastik_segs = read_ims(id_to_path_ilastik[id])
    cellpose_segs = read_ims(id_to_path_cellpose[id])

    # NOTE(review): the meaning of each cellpose_segs index is taken from the layer
    # names given here; confirm it against the ordering returned by read_ims.
    # NOTE(review): scale=(0.6, 0.6) is presumably the pixel size — TODO confirm units.
    viewer = napari.Viewer()
    viewer.add_image(ims[1], scale=(0.6, 0.6), name="DNA")
    viewer.add_labels(cellpose_segs[8], scale=(0.6, 0.6), name="nuclei")
    viewer.add_labels(cellpose_segs[2], scale=(0.6, 0.6), name="alive nuclei")
    viewer.add_labels(cellpose_segs[5], scale=(0.6, 0.6), name="dead nuclei")

    viewer.add_image(ims[4], scale=(0.6, 0.6), name="RNA")
    viewer.add_labels(cellpose_segs[7], scale=(0.6, 0.6), name="soma")
    viewer.add_labels(cellpose_segs[1], scale=(0.6, 0.6), name="alive soma")
    viewer.add_labels(cellpose_segs[4], scale=(0.6, 0.6), name="dead soma")

    viewer.add_image(ims[3], scale=(0.6, 0.6), name="Actin")
    viewer.add_labels(cellpose_segs[6], scale=(0.6, 0.6), name="cell")
    viewer.add_labels(cellpose_segs[0], scale=(0.6, 0.6), name="alive cell")
    viewer.add_labels(cellpose_segs[3], scale=(0.6, 0.6), name="dead cell")
    # viewer.add_labels(ctype, scale=(0.6, 0.6), name="celltype")
    # viewer.add_image(im_rgb, scale=(0.6, 0.6), rgb=True)

In [None]:
# Side-by-side figure of the ground-truth cell segmentation, the ilastik cell
# segmentation and the Otsu baseline for sample s039.
# NOTE(review): `id_to_path_ilastik` and `id_to_path` are whatever the most recently
# executed cell bound them to; if the train-set napari cell ran last they point at
# train_set while `gt` is read from test_set — TODO confirm the intended run order.
gt = io.imread(
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/s039-cells.tif"
)
ilastik_segs = read_ims(id_to_path_ilastik["s039"])
# Index 0 is the entry the evaluation cell uses as the ilastik "cells" prediction.
ilastik = ilastik_segs[0]

# Same preprocessing as the evaluation cell, then the Otsu baseline predictions.
ims = read_ims(id_to_path["s039"])
ims = [im.astype("float64") for im in ims]
ims = [np.linalg.norm(im, axis=-1) for im in ims]
ims = [im / np.amax(im) for im in ims]
ims = [exposure.equalize_adapthist(im, clip_limit=0.03) for im in ims]
seg_soma = seg_init_soma(ims)
seg_preds = [
    combine_soma_nucleus_labels(seg_soma, seg_init_nuc(ims)),
    seg_soma,
    combine_soma_cell_labels(seg_soma, seg_init_cell(ims)),
]
# The third entry is the combined soma/cell labeling.
baseline = seg_preds[2]

# show gt, ilastik, and baseline in a row of subplots
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(gt)
ax[0].set_title("Ground truth")
ax[1].imshow(ilastik)
ax[1].set_title("Ilastik")
ax[2].imshow(baseline)
ax[2].set_title("Baseline")
plt.show()

# Cell type

In [88]:
# Directory configuration and id -> path lookups for the cell-type evaluation.
# NOTE(review): absolute, machine-specific paths — consider a configurable data root.
path_im = (
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/hdf5s/"
)

# Ground-truth and predicted segmentations.
path_seg_gt = (
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/"
)
path_seg_pred = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/segmentations/"

# Ground-truth cell types (pickle files) and predicted cell types; the predictions
# live in the same directory as `path_im`.
path_ctype_gt = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/celltypes/"
path_ctype_pred = (
    "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/hdf5s/"
)

id_to_path_seg_gt = get_id_to_path(
    path_seg_gt, tag=".tif", id_from_name=get_id_from_name_96
)
id_to_path_seg_pred = get_id_to_path(
    path_seg_pred, tag=".tif", id_from_name=get_id_pre_hyph
)

id_to_path_ctype_gt = get_id_to_path(
    path_ctype_gt, tag=".pickle", id_from_name=get_id_from_name_96
)
# NOTE(review): tag="Object" presumably selects the "*_Object Predictions.h5" files
# and excludes the raw image .h5 files in the same folder — verify against get_id_to_path.
id_to_path_ctype_pred = get_id_to_path(
    path_ctype_pred, tag="Object", id_from_name=get_id_pre_us
)

In [None]:
# Score the predicted alive/dead objects against the ground truth for every sample.
# Long-format accumulators for every threshold ...
data_acc = []
data_ious = []
data_celltype = []
data_sample = []

# ... and the subset scored at IoU = 0.5 for the bar plot.
data_acc_05 = []
data_celltype_05 = []
data_sample_05 = []

for sample in id_to_path_seg_gt.keys():

    # NOTE(review): the indices [2] and [1] pick one file per sample from the path
    # lists; which object type they correspond to depends on the ordering returned
    # by get_id_to_path — TODO confirm.
    f_seg_gt = id_to_path_seg_gt[sample][2]
    f_seg_pred = id_to_path_seg_pred[sample][1]

    f_ctype_gt = id_to_path_ctype_gt[sample]
    f_ctype_pred = id_to_path_ctype_pred[sample]

    seg_pred = io.imread(f_seg_pred)
    seg_gt = io.imread(f_seg_gt)

    # gt: entries are indexed as (.., object label, type code).
    # NOTE: pickle.load executes arbitrary code; only use on trusted files.
    with open(f_ctype_gt, "rb") as f:
        types = pickle.load(f)
    obj_to_type_gt = {entry[1]: int(entry[2]) for entry in types}
    # Background gets code 3 so it lands in neither the alive (1) nor dead (2) mask.
    obj_to_type_gt[0] = 3

    assert set(obj_to_type_gt.keys()) == set(np.unique(seg_gt))

    # Split the ground truth into alive-only and dead-only label images.
    seg_gt_alive = np.zeros_like(seg_gt)
    seg_gt_dead = np.zeros_like(seg_gt)
    for obj, ctype in obj_to_type_gt.items():
        if ctype not in (1, 2):
            continue
        obj_mask = seg_gt == obj
        if ctype == 1:
            seg_gt_alive[obj_mask] = seg_gt[obj_mask]
        else:
            seg_gt_dead[obj_mask] = seg_gt[obj_mask]

    # predicted: per-pixel type image, sampled at the first pixel of each object.
    with h5py.File(f_ctype_pred, "r") as f:
        obj_type = np.squeeze(f["exported_data"][:])

    obj_to_type_pred = {}
    for obj in np.unique(seg_pred):
        loc = np.argwhere(seg_pred == obj)[0, :]
        obj_to_type_pred[obj] = obj_type[loc[0], loc[1]]

    assert set(obj_to_type_pred.keys()) == set(np.unique(seg_pred))

    # Split the prediction into alive-only and dead-only label images.
    seg_pred_alive = np.zeros_like(seg_pred)
    seg_pred_dead = np.zeros_like(seg_pred)
    for obj, ctype in obj_to_type_pred.items():
        if ctype not in (1, 2):
            continue
        obj_mask = seg_pred == obj
        if ctype == 1:
            seg_pred_alive[obj_mask] = seg_pred[obj_mask]
        else:
            seg_pred_dead[obj_mask] = seg_pred[obj_mask]

    # score each cell type over a range of IoU thresholds
    for ctype, segs in zip(
        ["alive", "dead"],
        [(seg_gt_alive, seg_pred_alive), (seg_gt_dead, seg_pred_dead)],
    ):
        # Dedicated names so the full segmentations in seg_gt / seg_pred survive.
        seg_gt_ctype, seg_pred_ctype = segs
        for iou_threshold in tqdm(
            np.arange(0.5, 0.99, 0.05), desc="Varying iou_threshold", leave=False
        ):
            ap2 = AP2(seg_gt_ctype, seg_pred_ctype, iou_threshold)
            data_acc.append(ap2)
            data_ious.append(iou_threshold)
            data_celltype.append(ctype)
            data_sample.append(sample)

            # np.arange yields floats; compare with a tolerance rather than ==.
            if np.isclose(iou_threshold, 0.5):
                data_acc_05.append(ap2)
                data_celltype_05.append(ctype)
                data_sample_05.append(sample)


data = {
    "iou_threshold": data_ious,
    "Accuracy": data_acc,
    "Cell Type": data_celltype,
    "Sample": data_sample,
}
df = pd.DataFrame(data)

data_05 = {
    "Threat score/AP (IoU=0.5)": data_acc_05,
    "Cell Type": data_celltype_05,
    "Sample": data_sample_05,
}
df_05 = pd.DataFrame(data_05)

# Bar plot of the score at IoU = 0.5 per cell type and sample.
sns.barplot(data=df_05, x="Cell Type", y="Threat score/AP (IoU=0.5)", hue="Sample")
plt.show()

# Line plot of the score against the IoU threshold, aggregated over samples.
fig, ax = plt.subplots()
sns.lineplot(x="iou_threshold", y="Accuracy", hue="Cell Type", data=df, ax=ax)
plt.show()

In [None]:
# Single-sample (s039) version of the cell-type evaluation.
# NOTE(review): absolute, machine-specific paths — consider a configurable data root.
f_im = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/hdf5s/s039.h5"

f_seg_gt = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/s039-somas.tif"
f_seg_pred = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/segmentations/s039-ch8sk1fk1fl1.tif"

f_ctype_gt = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/celltypes/s039.pickle"
f_ctype_pred = "/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/hdf5s/s039_Object Predictions.h5"

seg_pred = io.imread(f_seg_pred)
seg_gt = io.imread(f_seg_gt)

# Raw image stack; kept in `ims` for the image-preparation cell further down.
with h5py.File(f_im, "r") as f:
    ims = f["image"][:]

# gt: entries are indexed as (.., object label, type code).
# NOTE: pickle.load executes arbitrary code; only use on trusted files.
with open(f_ctype_gt, "rb") as f:
    types = pickle.load(f)
obj_to_type_gt = {entry[1]: int(entry[2]) for entry in types}
# Background gets code 3 so it lands in neither the alive (1) nor dead (2) mask.
obj_to_type_gt[0] = 3

assert set(obj_to_type_gt.keys()) == set(np.unique(seg_gt))

# Split the ground truth into alive-only and dead-only label images.
seg_gt_alive = np.zeros_like(seg_gt)
seg_gt_dead = np.zeros_like(seg_gt)
for obj, ctype in obj_to_type_gt.items():
    if ctype not in (1, 2):
        continue
    obj_mask = seg_gt == obj
    if ctype == 1:
        seg_gt_alive[obj_mask] = seg_gt[obj_mask]
    else:
        seg_gt_dead[obj_mask] = seg_gt[obj_mask]

# predicted: per-pixel type image, sampled at the first pixel of each object.
with h5py.File(f_ctype_pred, "r") as f:
    obj_type = np.squeeze(f["exported_data"][:])

obj_to_type_pred = {}
for obj in np.unique(seg_pred):
    loc = np.argwhere(seg_pred == obj)[0, :]
    obj_to_type_pred[obj] = obj_type[loc[0], loc[1]]

assert set(obj_to_type_pred.keys()) == set(np.unique(seg_pred))


# Split the prediction into alive-only and dead-only label images.
seg_pred_alive = np.zeros_like(seg_pred)
seg_pred_dead = np.zeros_like(seg_pred)
for obj, ctype in obj_to_type_pred.items():
    if ctype not in (1, 2):
        continue
    obj_mask = seg_pred == obj
    if ctype == 1:
        seg_pred_alive[obj_mask] = seg_pred[obj_mask]
    else:
        seg_pred_dead[obj_mask] = seg_pred[obj_mask]


# score each cell type over a range of IoU thresholds
data_acc = []
data_ious = []
data_celltype = []

data_acc_05 = []
data_celltype_05 = []


for ctype, segs in zip(
    ["alive", "dead"], [(seg_gt_alive, seg_pred_alive), (seg_gt_dead, seg_pred_dead)]
):
    # Dedicated names so the full segmentations in seg_gt / seg_pred survive.
    seg_gt_ctype, seg_pred_ctype = segs
    for iou_threshold in tqdm(
        np.arange(0.5, 0.99, 0.05), desc="Varying iou_threshold", leave=False
    ):
        ap2 = AP2(seg_gt_ctype, seg_pred_ctype, iou_threshold)
        data_acc.append(ap2)
        data_ious.append(iou_threshold)
        data_celltype.append(ctype)

        # np.arange yields floats; compare with a tolerance rather than ==.
        if np.isclose(iou_threshold, 0.5):
            data_acc_05.append(ap2)
            data_celltype_05.append(ctype)


data = {"iou_threshold": data_ious, "Accuracy": data_acc, "Cell Type": data_celltype}
df = pd.DataFrame(data)

data_05 = {"Threat score/AP (IoU=0.5)": data_acc_05, "Cell Type": data_celltype_05}
df_05 = pd.DataFrame(data_05)

# Bar plot of the score at IoU = 0.5 per cell type.
sns.barplot(data=df_05, x="Cell Type", y="Threat score/AP (IoU=0.5)")
plt.show()

# Line plot of the score against the IoU threshold.
fig, ax = plt.subplots()
sns.lineplot(x="iou_threshold", y="Accuracy", hue="Cell Type", data=df, ax=ax)
plt.show()

# Cell type with true segmentation

In [None]:

# Cell-type classification accuracy when the classifier is run on the ground-truth
# segmentation: one row-normalized confusion matrix per sample.
data_id = []
data_gt = []
data_pred = []

# NOTE(review): `id` shadows the builtin, and `path_seg_gt` / `path_ctype_gt` reuse
# names that the cell-type configuration cell binds to directories; the names are
# kept because notebook cells share one global namespace.
for id in ["s039", "s158", "s160"]:
    path_seg_gt = f"/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/seg_npuint32/{id}-somas.tif"
    path_ctype_true_seg = f"/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/hdf5s/{id}_Object Predictions_gt.h5"
    path_ctype_gt = f"/Users/thomasathey/Documents/shavit-lab/fraenkel/96_well/exp2/test_set/seg_gt_v3/celltypes/{id}.pickle"

    # read the segmentation
    seg = io.imread(path_seg_gt)

    # read in the ground truth cell type labels
    # NOTE: pickle.load executes arbitrary code; only use on trusted files.
    with open(path_ctype_gt, "rb") as f:
        types = pickle.load(f)

    # read in the per-pixel cell types predicted on the ground truth segmentation
    with h5py.File(path_ctype_true_seg, "r") as f:
        ctype_true_seg = f["exported_data"][:]
        ctype_true_seg = np.squeeze(ctype_true_seg)

    for gt_entry in types:
        obj = gt_entry[1]
        ctype = int(gt_entry[2])
        # The prediction is sampled at the first pixel belonging to the object.
        loc = np.argwhere(seg == obj)[0, :]

        # Code 1 means alive; every other code is reported as dead.
        data_gt.append("alive" if ctype == 1 else "dead")
        data_pred.append(
            "alive" if ctype_true_seg[loc[0], loc[1]] == 1 else "dead"
        )
        data_id.append(id)


data = {"Sample": data_id, "True Class": data_gt, "Predicted Class": data_pred}
df = pd.DataFrame(data)

for id in ["s039", "s158", "s160"]:
    df_sub = df[df["Sample"] == id]

    df_cmat = df_sub.groupby(["True Class", "Predicted Class"]).size().unstack(fill_value=0)

    # Normalize each true-class row to sum to 1. DataFrame.div works for any number
    # of rows (a sample may contain a single true class) and returns a float frame
    # instead of writing float rows into integer columns in place.
    df_cmat = df_cmat.div(df_cmat.sum(axis=1), axis=0)
    sns.heatmap(df_cmat, annot=True)
    plt.title(f"{id} - Normalized along true class")

    plt.savefig(f"/Users/thomasathey/Documents/shavit-lab/fraenkel/presentation/96-well/figures/ctype-acc/{id}.svg")
    plt.show()

## Image

In [None]:
# Split the image stack into per-channel 2D images, normalize each to its maximum and
# apply adaptive histogram equalization.
# NOTE(review): expects `ims` to be the array loaded from the HDF5 image in the
# single-sample cell-type cell (it is indexed as ims[:, :, i] and needs `.shape`).
# The cell rebinds `ims` to a list, so running it a second time fails — it is not
# idempotent; TODO write the result to a new name.
ims = [ims[:, :, i] for i in range(ims.shape[-1])]
ims = [im.astype("float64") for im in ims]
ims = [im / np.amax(im) for im in ims]
ims = [exposure.equalize_adapthist(im, clip_limit=0.03) for im in ims]
# Three-channel composite of channels 2, 5 and 4; not used by the viewer cell below.
im_rgb = np.stack([ims[2], ims[5], ims[4]], axis=2)

In [None]:
# Show channel 5 of the prepared images with the predicted alive / dead label masks.
# Relies on `ims`, `seg_pred_alive` and `seg_pred_dead` left in the kernel by earlier cells.
# NOTE(review): scale=(0.6, 0.6) is presumably the pixel size — TODO confirm units.
viewer = napari.Viewer()
viewer.add_image(ims[5], scale=(0.6, 0.6))  # , rgb=True)
viewer.add_labels(seg_pred_alive, scale=(0.6, 0.6), name="pred. alive")
viewer.add_labels(seg_pred_dead, scale=(0.6, 0.6), name="pred. dead")