In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import umap.umap_ as umap
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import StandardScaler

from kartezio.dataset import read_dataset
from kartezio.fitness import FitnessAP
from kartezio.inference import KartezioModel
from kartezio.model.components import KartezioStacker
from kartezio.plot import plot_watershed
from numena.features.profiling import CellStainingProfile, ProfilingInfo
from numena.geometry import Cell2D
from numena.image.morphology import get_kernel
from numena.image.threshold import threshold_tozero
from numena.io.drive import Directory
from numena.io.image import imread_tiff

from train_model import preprocessing

In [None]:
# Trained Kartezio model to load, given relative to ./models/ (the loading cell builds
# f"./models/{MODEL_NAME}"): a directory name followed by the "elite.json" file inside it.
MODEL_NAME = "16425-b443195a-85d8-439a-992d-2f9112f1319c/elite.json"

In [None]:
def labels_to_cells(labels_image, size=(1024, 1024)):
    """Convert a label image into a list of ``Cell2D`` objects, one per non-zero label.

    For every label other than 0 (background), the binary mask of that label is
    morphologically closed with ``get_kernel("circle", 2)``, resized to ``size`` and
    handed to ``Cell2D.from_mask``. Labels for which ``from_mask`` returns ``None`` are
    skipped.

    Args:
        labels_image: integer label image (presumably 2-D, as produced by the model's
            watershed step — TODO confirm); 0 is treated as background.
        size: ``(width, height)`` target passed to ``cv2.resize`` (OpenCV's ``dsize``
            order). Defaults to 1024x1024, presumably the resolution of the raw images
            the cells are later profiled on — TODO confirm.

    Returns:
        list of ``Cell2D``, each named ``"cell_<label>"``.
    """
    kernel = get_kernel("circle", 2)
    cells = []
    for label in np.unique(labels_image):
        if label == 0:  # background
            continue
        mask = (labels_image == label).astype(np.uint8)
        # Close small gaps in the mask before resizing it.
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.resize(mask, size)
        cell = Cell2D.from_mask(mask, f"cell_{label}")
        if cell is not None:
            cells.append(cell)
    return cells

In [None]:
class MeanStackerSaver(KartezioStacker):
    """Mean stacker that also dumps its intermediate images as Viridis heatmaps.

    Meant for introspection:

    * ``stack`` writes every input ``Y[i]`` as ``output_z{i}.png`` and returns the
      pixel-wise mean of the inputs cast to uint8;
    * ``post_stack`` writes that mean as ``output_mean.png``, applies
      ``threshold_tozero(..., self.threshold)`` and writes the result as ``labels.png``.

    All files go to ``output_dir`` (default ``./results``). The directory is created on
    demand because ``cv2.imwrite`` does not raise on a missing folder, it just returns
    False and nothing is written.
    """

    def _to_json_kwargs(self):
        # Returns no extra kwargs, so `threshold` / `output_dir` are not part of this
        # stacker's JSON form (presumably fine for a local debugging stacker).
        return {}

    def __init__(self, name="mean_stacker_saver", abbv="MEAN-saver", arity=1, threshold=4,
                 output_dir="./results"):
        super().__init__(name, abbv, arity)
        self.threshold = threshold
        self.output_dir = output_dir

    def _save_heatmap(self, image, filename):
        """Write ``image`` as a Viridis heatmap to ``output_dir/filename``.

        ``cv2.applyColorMap`` expects an 8-bit image.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        heatmap_color = cv2.applyColorMap(image, cv2.COLORMAP_VIRIDIS)
        cv2.imwrite(os.path.join(self.output_dir, filename), heatmap_color)

    def stack(self, Y):
        for i, y in enumerate(Y):
            self._save_heatmap(y, f"output_z{i}.png")
        return np.mean(np.array(Y), axis=0).astype(np.uint8)

    def post_stack(self, x, index):
        yi = x.copy()
        self._save_heatmap(yi, "output_mean.png")
        output = threshold_tozero(yi, self.threshold)
        self._save_heatmap(output, "labels.png")
        return output

In [None]:
# Load the experiment folder, the staining-profiling configuration and every raw stack.
experiment_dataset = Directory("./dataset/experiment")
infos = ProfilingInfo(experiment_dataset / "INFOS.json")
profiling = CellStainingProfile(infos)

# Raw images keyed by the integer part of their file name before the first "." (e.g.
# "12.tif" -> 12). The whole array is stored: later cells pass it unchanged to
# `profiling.get_profile`.
mapping = {
    int(filepath.name.split(".")[0]): imread_tiff(filepath)
    for filepath in experiment_dataset.ls("raw/*.tif")
}

In [None]:
# Read the experiment folder with Kartezio's dataset reader.
dataset = read_dataset(str(experiment_dataset._path))

model_path = f"./models/{MODEL_NAME}"

# Model used as-is for evaluation / segmentation.
model = KartezioModel(model_path, FitnessAP(thresholds=0.7))

# A second, independently loaded copy of the same model whose stacker is replaced by
# MeanStackerSaver so that intermediate outputs are saved during prediction. Loading a
# separate instance keeps `model` untouched (assumes the two instances do not share
# their parser object — TODO confirm).
model_introspection = KartezioModel(model_path, FitnessAP(thresholds=0.7))
model_introspection._model.parser.stacker = MeanStackerSaver()

In [None]:
# Segment each test image, profile every detected cell on the matching raw image and
# collect one feature row per cell (profiling features followed by "area").

# ordered_keys[i] is the raw-file number (key of `mapping`) of the i-th test image. The
# order is hard-coded — presumably it matches the order used by `read_dataset` (TODO confirm).
ordered_keys = [1, 4, 9, 14, 16, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 15, 17]
INTROSPECTION_IMAGE_IDX = 2  # test image that is re-run through `model_introspection`
MIN_CELL_AREA, MAX_CELL_AREA = 2000, 14000  # exclusive bounds applied to `cell.area`

cell_features = []
cells_object = []
feature_name = None
p, f, t = model.eval(dataset, subset="test", reformat_x=preprocessing)
if len(dataset.test_x) > len(ordered_keys):
    raise ValueError(
        f"{len(dataset.test_x)} test images but only {len(ordered_keys)} entries in ordered_keys"
    )
os.makedirs("./results", exist_ok=True)  # cv2.imwrite silently fails on a missing folder

for image_idx in range(len(dataset.test_x)):
    prediction = p[image_idx]
    plot_watershed(
        dataset.test_v[image_idx],
        prediction["mask"],
        prediction["markers"],
        prediction["labels"],
        gt=dataset.test_y[image_idx][0],
    )
    cells = labels_to_cells(prediction["labels"])
    channels = mapping[ordered_keys[image_idx]]

    if image_idx == INTROSPECTION_IMAGE_IDX:
        # The swapped-in stacker writes its own heatmaps during this call; `_` avoids
        # overwriting `t` returned by model.eval above.
        y_hat, _ = model_introspection.predict([dataset.test_x[image_idx]], reformat_x=preprocessing)
        dt = y_hat[0]["mask"]
        dt_max = dt.max()
        # Rescale to 0-255 for the colour map; an all-zero output would otherwise divide by 0.
        if dt_max > 0:
            dt = ((dt / dt_max) * 255).astype(np.uint8)
        else:
            dt = np.zeros(dt.shape, dtype=np.uint8)
        heatmap_color = cv2.applyColorMap(dt, cv2.COLORMAP_VIRIDIS)
        cv2.imwrite("./results/distance_transform.png", heatmap_color)

    for cell in cells:
        cells_object.append(cell)
        profile = profiling.get_profile(cell, channels)
        # Build new lists instead of appending to the ones returned by get_profile, so
        # "area" is added exactly once whatever get_profile does with its lists.
        cell_features.append(list(profile[0]) + [cell.area])
        feature_name = list(profile[1]) + ["area"]

if feature_name is None:
    raise RuntimeError("No cells were detected in the test images; cannot build the feature table.")

df_all_cells = pd.DataFrame(cell_features, columns=feature_name)
df = df_all_cells[(df_all_cells.area > MIN_CELL_AREA) & (df_all_cells.area < MAX_CELL_AREA)]

In [None]:
# Standardize a subset of the per-cell features and embed them in 2-D with UMAP.

# Column positions in `df` (profiling features in get_profile order, "area" last).
# TODO: replace with explicit column names so the selection survives feature changes.
UMAP_FEATURE_COLUMNS = [0, 1, 2, 3, 4, 5, 6, 10, 13, 16]
# UMAP is stochastic; a fixed seed makes the embedding (and the saved figures) reproducible.
UMAP_RANDOM_STATE = 42

X = df.values[:, UMAP_FEATURE_COLUMNS]
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_emb = umap.UMAP(
    n_neighbors=15, n_components=2, random_state=UMAP_RANDOM_STATE
).fit_transform(X_scaled)
df_plot = df.assign(**{"UMAP-1": X_emb[:, 0], "UMAP-2": X_emb[:, 1]})

In [None]:
# One UMAP scatter plot per staining column, coloured by that column, saved to ./results.

# Only re-applies the x label; seaborn already labels both axes from `x` / `y`.
pparam = dict(xlabel="UMAP-1")
STAININGS = ["Perf+", "GzmB+", "Lamp+", "GzmB-Perf-Lamp"]

os.makedirs("./results", exist_ok=True)  # fig.savefig raises if the folder is missing

for staining in STAININGS:
    # NOTE(review): "science" / "nature" are SciencePlots styles; recent SciencePlots
    # versions only register them after `import scienceplots` — confirm availability.
    with plt.style.context(["science", "nature"]):
        fig, ax = plt.subplots()
        # UMAP coordinates are not interpretable, so hide every tick and tick label.
        ax.tick_params(
            axis="both",   # both the x- and y-axis
            which="both",  # major and minor ticks
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False,
            labelleft=False,
        )
        sns.scatterplot(
            data=df_plot, x="UMAP-1", y="UMAP-2", hue=staining,
            palette="viridis", edgecolor="none", ax=ax,
        )
        ax.legend(title=staining)
        ax.set(**pparam)
        fig.savefig(f"./results/Fig4_defg_{staining}.png", dpi=300)