In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors as plt_colors
import pandas as pd

from numena.image.color import rgb2bgr
from numena.image.drawing import draw_overlay
from numena.image.contour import contours_fill, contours_find
from numena.geometry import MicroEntity2D
from numena.figure import Figure

from kartezio.model.registry import registry
from kartezio.dataset import read_dataset

from generated_classes import ModelWGA, ModelDiO
from WGA.train_model import preprocessing as preprocessing_WGA
from DiO.train_model import preprocessing as preprocessing_DiO

import seaborn as sns

In [None]:
# Root folder of the experiment dataset, passed to `read_dataset`.
DATASET = "./Experiment"


def _hex_to_rgb255(hex_color):
    """Convert a '#RRGGBB' string into a list of three 0-255 ints in (R, G, B) order."""
    return [int(round(channel * 255)) for channel in plt_colors.hex2color(hex_color)]


# Each channel's display color: hex for matplotlib/seaborn, 0-255 RGB list for drawing on images.
COLOR_WGA_HEX = "#0099DB"
COLOR_WGA = _hex_to_rgb255(COLOR_WGA_HEX)

COLOR_DIO_HEX = "#B55088"
COLOR_DIO = _hex_to_rgb255(COLOR_DIO_HEX)

In [None]:
def __create_entity(mask):
    """Validate a single-object mask and rasterize its contour.

    Parameters
    ----------
    mask : image whose contours are extracted with ``contours_find``.

    Returns
    -------
    (mask_bin, cx, cy)
        ``mask_bin`` is a uint8 image shaped like ``mask``, with the single contour
        filled with 255. ``cx``/``cy`` are the integer (truncated) centroid
        coordinates computed from the contour's image moments.

    Raises
    ------
    ValueError
        If the mask does not yield exactly one contour, if that contour has fewer
        than 4 points, or if its zeroth moment (area) is 0.
    """
    contours = contours_find(mask)
    if len(contours) != 1:
        raise ValueError("Can't create entity with ambiguous mask!")

    contour = contours[0]
    if len(contour) < 4:
        raise ValueError("Can't create entity with less than 4 points!")

    moments = cv2.moments(contour)
    area = moments["m00"]
    if area == 0:
        # Guard the centroid divisions below.
        raise ValueError("Can't create entity with M['m00'] == 0")

    centroid_x = int(moments["m10"] / area)
    centroid_y = int(moments["m01"] / area)

    # Allocate the output only once the contour has passed every check.
    filled = contours_fill(np.zeros_like(mask, dtype=np.uint8), [contour], color=255)
    return filled, centroid_x, centroid_y


class Particle2D(MicroEntity2D):
    """A 2D particle: a MicroEntity2D with no extra behavior, named for this analysis."""


def create_particle(name, mask, custom_data=None):
    """Build a Particle2D called ``name`` from a mask containing exactly one object.

    The mask is validated and filled by ``__create_entity``, whose ValueError
    propagates when the mask is ambiguous, too small, or has zero area.
    ``custom_data`` is forwarded unchanged to the Particle2D constructor.
    """
    filled_mask, centroid_x, centroid_y = __create_entity(mask)
    return Particle2D(
        name,
        filled_mask,
        centroid_x,
        centroid_y,
        custom_data=custom_data,
    )

In [None]:
# Load the Kartezio dataset stored under the DATASET folder.
dataset = read_dataset(DATASET)

In [None]:
# Instantiate the two generated Kartezio models (one for WGA, one for DiO, by name).
model_WGA = ModelWGA()
model_DiO = ModelDiO()
# Metric from Kartezio's registry; the analysis loop uses its `mask_ious` to pair
# WGA particles with DiO particles.
matching = registry.metrics.instantiate("CAP")

In [None]:
# Unpack the training split. Each x[k] is indexed later as xi[0] (WGA) and xi[1] (DiO).
# NOTE(review): y is presumably the ground truth and v the visual images — confirm;
# y is not used in the visible cells.
x, y, v = dataset.train_xyv

In [None]:
# Run both models on all training inputs, passing each its own preprocessing via
# `reformat_x`. The analysis loop reads the "mask" and "labels" entries of each
# per-image prediction.
p_WGA = model_WGA.predict(x, reformat_x=preprocessing_WGA)

p_DIO = model_DiO.predict(x, reformat_x=preprocessing_DiO)

In [None]:
def plot_particles_2(WGA_labels, DiO_labels, both_labels, output_dir="./results"):
    """Show WGA, DiO and fused particle renderings side by side and save each one.

    Parameters
    ----------
    WGA_labels, DiO_labels, both_labels :
        Color images to display. They are converted with ``rgb2bgr`` before being
        written with OpenCV (presumably they are RGB — confirm with callers).
    output_dir : str, default "./results"
        Directory receiving ``WGA_labels.png``, ``DiO_labels.png`` and
        ``WGA-DiO_labels.png``; created if it does not exist.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)

    fig = Figure(title="Particles predictions and fusion", size=(12, 4))
    fig.create_panels(rows=1, cols=3)

    # (panel title, image, output file name), in panel-index order.
    panels = (
        ("WGA particles", WGA_labels, "WGA_labels.png"),
        ("DiO particles", DiO_labels, "DiO_labels.png"),
        ("WGA+DiO particles", both_labels, "WGA-DiO_labels.png"),
    )
    for index, (title, image, filename) in enumerate(panels):
        panel = fig.get_panel(index)
        panel.set_title(title)
        panel.axis("off")
        panel.imshow(image)

        path = os.path.join(output_dir, filename)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(path, rgb2bgr(image)):
            print(f"could not write {path}")


def plot_particles(WGA, WGA_mask, WGA_labels, DiO, DiO_mask, DiO_labels):
    """Draw a 2x3 figure: raw channel, predicted mask and labels, for WGA then DiO.

    Each panel is titled, has its axis hidden, and shows its image with the
    colormap listed in ``panel_specs``.
    """
    fig = Figure(title="Particles predictions made by Kartezio", size=(12, 8))
    fig.create_panels(rows=2, cols=3)

    # (title, image, colormap) in panel-index order 0..5.
    panel_specs = (
        ("WGA channel", WGA, "Greens"),
        ("Kartezio Mask", WGA_mask, "viridis"),
        ("Kartezio Labels", WGA_labels, "viridis"),
        ("DiO channel", DiO, "Reds"),
        ("Kartezio Mask", DiO_mask, "viridis"),
        ("Kartezio Labels", DiO_labels, "viridis"),
    )
    for position, (title, image, cmap) in enumerate(panel_specs):
        axes = fig.get_panel(position)
        axes.set_title(title)
        axes.axis("off")
        axes.imshow(image, cmap=cmap)

    
# Accumulators for the per-particle scatter plot. Every particle appends exactly one
# entry to each list, so the lists stay index-aligned.
plot_wga = []
plot_dio = []
plot_wga_sum = []
plot_dio_sum = []
classes = []
sizes = []
markers_plot = []

# Pixel area -> µm² (pixel pitch 0.1066667 µm — NOTE(review): confirm against acquisition metadata).
PIXEL_TO_MICRON_SQUARE = 0.1066667 * 0.1066667
# Minimum IoU for a WGA particle and a DiO particle to be fused into one WGA+/DiO+ particle.
IOU_THRESHOLD = 0.05
# Drawing color of fused particles; same color as "#FEAE34" used for WGA+/DiO+ in the scatter plot.
COLOR_BOTH = [254, 174, 52]
# Gain applied to the model masks before rendering them as 8-bit heatmaps.
HEATMAP_SCALE = 50
# Seeded generator so the random label colors are identical on every run.
_label_color_rng = np.random.default_rng(0)


def _blank_canvas(image):
    """Return an all-zero uint8 3-channel canvas with the same height and width as ``image``."""
    return np.zeros((image.shape[0], image.shape[1], 3), dtype=np.uint8)


def _save_heatmap(mask, path):
    """Write ``mask`` to ``path`` as a viridis heatmap.

    Values are scaled by HEATMAP_SCALE and clipped to [0, 255] before the uint8
    cast, so large values saturate instead of wrapping around.
    """
    scaled = np.clip(np.asarray(mask, dtype=np.float64) * HEATMAP_SCALE, 0, 255).astype(np.uint8)
    cv2.imwrite(path, cv2.applyColorMap(scaled, cv2.COLORMAP_VIRIDIS))


def _label_particles(labels, prefix):
    """Turn a label image into particles plus a randomly colored overlay.

    Every non-zero label becomes a particle named ``particle_<prefix>_<label>``.
    Label 0 is skipped (treated as background). Labels rejected by
    ``create_particle`` are reported and skipped.

    Returns
    -------
    particles : dict mapping label value -> particle
    color_labels : uint8 image with each accepted particle overlaid in its own color
    """
    particles = {}
    color_labels = _blank_canvas(labels)
    label_values = np.unique(labels)
    colors = _label_color_rng.integers(256, size=(len(label_values), 3), dtype=np.uint8)
    for label, color in zip(label_values, colors):
        if label == 0:
            continue
        label_mask = (labels == label).astype(np.uint8)
        try:
            particles[label] = create_particle(f"particle_{prefix}_{label}", label_mask)
            color_labels = draw_overlay(color_labels, label_mask, color=color.tolist(), alpha=0.8)
        except Exception as e:
            print(e)
            print(f"missed one particle {label}")
    return particles, color_labels


def _record_particle(particle, wga, dio, category, marker):
    """Append one particle's intensities, class, marker and area (µm²) to the plot lists.

    All values are computed before anything is appended, so an exception leaves
    the lists aligned.
    """
    mean_wga = particle.get_mean(wga)
    mean_dio = particle.get_mean(dio)
    sum_wga = particle.get_sum(wga)
    sum_dio = particle.get_sum(dio)
    area = particle.area * PIXEL_TO_MICRON_SQUARE
    plot_wga.append(mean_wga)
    plot_dio.append(mean_dio)
    plot_wga_sum.append(sum_wga)
    plot_dio_sum.append(sum_dio)
    classes.append(category)
    markers_plot.append(marker)
    sizes.append(area)


for xi, pi_dio, pi_wga, _visual in zip(x, p_DIO, p_WGA, v):
    WGA, DiO = xi[0], xi[1]
    WGA_mask, WGA_labels = pi_wga["mask"], pi_wga["labels"]
    DiO_mask, DiO_labels = pi_dio["mask"], pi_dio["labels"]

    # NOTE: fixed file names, so only the last image's heatmaps remain on disk.
    _save_heatmap(WGA_mask, "./results/WGA_dt.png")
    _save_heatmap(DiO_mask, "./results/DiO_dt.png")

    WGA_particles, WGA_color_labels = _label_particles(WGA_labels, "WGA")
    DiO_particles, DiO_color_labels = _label_particles(DiO_labels, "DiO")

    # --- Fuse overlapping WGA / DiO particles into WGA+/DiO+ particles ---------------
    scores = matching.mask_ious(WGA_labels, DiO_labels)
    both_final_particles = _blank_canvas(DiO_labels)
    WGA_DiO_particles = []
    wga_to_rm = set()
    dio_to_rm = set()

    for i, (iou, idx) in enumerate(zip(scores[0], scores[1])):
        # Written as `not >` so NaN scores are skipped, exactly like the original `if iou > ...`.
        if not iou > IOU_THRESHOLD:
            continue
        # Score entry i pairs WGA label i + 1 (labels assumed to be 1..N) with DiO label idx.
        wga_lbl = i + 1
        if wga_lbl not in WGA_particles or idx not in DiO_particles:
            # One side was rejected by create_particle earlier; nothing to fuse.
            print(f"missed one particle {idx}-{i}")
            continue
        dio_part = DiO_particles[idx]
        fused_mask = (dio_part.mask | WGA_particles[wga_lbl].mask).astype(np.uint8)
        both_final_particles = draw_overlay(both_final_particles, dio_part.mask, color=COLOR_DIO, alpha=1.0)
        try:
            particle = create_particle(f"particle_{wga_lbl}-{idx}", fused_mask)
            _record_particle(particle, WGA, DiO, "WGA+/DiO+", "o")
            # Mark both sources as consumed right after recording so they are not counted twice.
            wga_to_rm.add(wga_lbl)
            dio_to_rm.add(idx)
            WGA_DiO_particles.append(particle)
            both_final_particles = draw_overlay(both_final_particles, particle.mask, color=COLOR_BOTH, alpha=1.0)
        except Exception as e:
            print(e)
            print(f"missed one particle {idx}-{i}")

    # --- Unmatched particles keep their single-channel class ------------------------
    WGA_final_particles = _blank_canvas(WGA_labels)
    for lbl, WGA_part in WGA_particles.items():
        WGA_final_particles = draw_overlay(WGA_final_particles, WGA_part.mask, color=COLOR_WGA, alpha=1.0)
        if lbl not in wga_to_rm:
            both_final_particles = draw_overlay(both_final_particles, WGA_part.mask, color=COLOR_WGA, alpha=1.0)
            _record_particle(WGA_part, WGA, DiO, "WGA+/DiO-", "^")

    DiO_final_particles = _blank_canvas(DiO_labels)
    for lbl, DiO_part in DiO_particles.items():
        DiO_final_particles = draw_overlay(DiO_final_particles, DiO_part.mask, color=COLOR_DIO, alpha=1.0)
        if lbl not in dio_to_rm:
            both_final_particles = draw_overlay(both_final_particles, DiO_part.mask, color=COLOR_DIO, alpha=1.0)
            _record_particle(DiO_part, WGA, DiO, "WGA-/DiO+", "X")

    plot_particles(WGA, WGA_mask, WGA_color_labels, DiO, DiO_mask, DiO_color_labels)
    plot_particles_2(WGA_final_particles, DiO_final_particles, both_final_particles)


In [None]:
# Per-particle table: one row per particle (fused, WGA-only or DiO-only).
# The column names are reused verbatim by the scatter plot cell.
data = {
    "Mean WGA Intensity (Log)": plot_wga,
    "Mean DiO Intensity (Log)": plot_dio,
    "Classified": classes,
    "\nSize (µ²)": sizes,
    "markers": markers_plot,
}
df = pd.DataFrame(data)
mean_area = df["\nSize (µ²)"].mean()
# Radius of a disk whose area equals the mean particle area (A = pi * r^2).
r = np.sqrt(mean_area / np.pi)

print(f"mean area (µm²): {mean_area}")
print(f"equivalent radius (µm): {r}")
print(f"equivalent diameter (µm): {2 * r}")
print(f"WGA+/DiO- particles: {(df['Classified'] == 'WGA+/DiO-').sum()}")
print(f"total particles: {len(df)}")

In [None]:
# Class -> color and class -> marker. A dict palette pins each class to its color
# regardless of the order in which classes first appear in `data`.
palette = {
    "WGA+/DiO+": "#FEAE34",
    "WGA+/DiO-": COLOR_WGA_HEX,
    "WGA-/DiO+": COLOR_DIO_HEX,
}
markers = {
    "WGA+/DiO+": "o",
    "WGA+/DiO-": "^",
    "WGA-/DiO+": "X",
}

# The "science"/"nature" styles come from the SciencePlots package.
# NOTE(review): recent SciencePlots versions need `import scienceplots` to register them.
with plt.style.context(["science", "nature"]):
    fig, ax = plt.subplots()
    ax.tick_params(
        axis="both",
        which="both",
        bottom=True,
        top=False,
        right=False,
        left=True,
    )
    ax.set_yscale("log")
    ax.set_xscale("log")
    sns.scatterplot(
        data=data,
        x="Mean WGA Intensity (Log)",
        y="Mean DiO Intensity (Log)",
        hue="Classified",
        # seaborn only applies `markers` when a `style` variable is mapped.
        style="Classified",
        markers=markers,
        s=100,
        ax=ax,
        palette=palette,
        size="\nSize (µ²)",
    )
    ax.set(xlabel=None, ylabel=None)
    ax.legend(loc=2, bbox_to_anchor=(1, 1), facecolor="white")
    # The legend sits outside the axes; "tight" keeps it from being cropped out of the PNG.
    fig.savefig("./results/Fig3_d.png", dpi=300, bbox_inches="tight")