In [None]:
from pathlib import Path

import cv2
import h5py
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm

## Check folders/files

In [None]:
# Sanity check of the expected folder layout before running the analysis:
#   ../output/<group>/**/<image id>_masks.tif      segmentation masks
#   ../masks_SECOND/<group>/<image id>/*.h5        hole / exclusion masks
neuron_mask_dir = Path("../output/")
groups = sorted(p for p in neuron_mask_dir.iterdir() if p.is_dir())

second_mask_dir = Path("../masks_SECOND/")

print("We found the following groups:")
for g in groups:
    masks_ls = sorted(g.rglob("*_masks.tif"))
    print(f"{g.name}: {len(masks_ls)} images")

    second_group_dir = second_mask_dir / g.name
    assert second_group_dir.exists(), (
        f"Second group dir {second_group_dir} does not exist"
    )

    for mask_path in masks_ls:
        img_id = mask_path.stem.replace("_masks", "")

        # every segmentation mask needs a matching per-image folder ...
        second_img_dir = second_group_dir / img_id
        assert second_img_dir.exists(), f"Second mask {second_img_dir} does not exist"

        # ... containing at least one .h5 file. A default is passed to next()
        # because a path yielded by glob() always exists: without the default
        # an empty folder surfaces as a bare StopIteration, not as a readable
        # assertion message.
        second_mask_path = next(second_img_dir.glob("*.h5"), None)
        assert second_mask_path is not None, (
            f"No .h5 mask found in {second_img_dir}"
        )


## Distance analysis

In [None]:
# upper bound for distance calculation: the distance-bin edges run from 0 up
# to this value (presumably micrometres, going by the `_um` suffix)
upper_limit_um = 1000
# bin width of the distance bins, in the same unit as upper_limit_um
bin_width_um = 50
# conversion factor: the pixel distance map is multiplied by it before binning
# (presumably micrometres per pixel of the hole mask -- TODO confirm)
conv_fct = 0.344
# compression factor: component centroids are divided by it before they are
# used to index the hole/exclusion masks, and bin areas are scaled by its
# square (presumably the downscale factor of the segmentation masks relative
# to the hole masks -- TODO confirm)
comp_fct = 0.5

In [None]:
def read_h5_mask(mask_path):
    """Load the hole mask and the optional exclusion mask from an HDF5 file.

    Parameters
    ----------
    mask_path : str or Path
        HDF5 file holding a ``"hole"`` dataset and, optionally, an
        ``"exclusions"`` dataset.

    Returns
    -------
    hole_mask : np.ndarray of bool
        Contents of the ``"hole"`` dataset.
    exclusion_mask : np.ndarray of bool or None
        Contents of the ``"exclusions"`` dataset, or ``None`` when the file
        has no such dataset.

    Raises
    ------
    KeyError
        If the file has no ``"hole"`` dataset.
    """
    with h5py.File(mask_path, "r") as f:
        # The rest of the notebook inverts these masks with `~` and uses them
        # for boolean indexing; both only behave as intended on bool arrays,
        # so the dtype is normalised here (a no-op for masks stored as bool).
        hole_mask = f["hole"][:].astype(bool)
        exclusion_mask = (
            f["exclusions"][:].astype(bool) if "exclusions" in f else None
        )
    return hole_mask, exclusion_mask

In [None]:
def extract_centroids(cp_pred, exclusion_mask, comp_fct=0.5):
    """Return the centroids of the connected foreground components of a mask.

    Parameters
    ----------
    cp_pred : np.ndarray
        Single-channel 8-bit image; non-zero pixels are foreground
        (8-connectivity is used to group them into components).
    exclusion_mask : np.ndarray or None
        Mask indexed as ``[y, x]`` with the rescaled centroid coordinates.
        Centroids that land on a truthy pixel are dropped. ``None`` disables
        the filtering.
    comp_fct : float, default 0.5
        Centroid coordinates are divided by this factor (presumably to map
        them from the compressed segmentation grid onto the grid of
        ``exclusion_mask`` -- confirm against the preprocessing).

    Returns
    -------
    np.ndarray
        Array of shape ``(n, 2)`` with one ``(x, y)`` row per kept component.
    """
    _, _, _, centroids = cv2.connectedComponentsWithStats(cp_pred, connectivity=8)
    # Row 0 describes the background component (label 0), not a detected
    # object, so it must not be reported as a centroid.
    centroids = centroids[1:] / comp_fct

    if exclusion_mask is not None:
        rows = centroids[:, 1].astype(int)
        cols = centroids[:, 0].astype(int)

        # cast so that `~` is a logical NOT even for integer-typed masks
        excluded = exclusion_mask[rows, cols].astype(bool)
        centroids = centroids[~excluded]

    return centroids


def read_cp_mask(mask_path):
    """Read a segmentation mask image from disk as a single-channel image.

    Parameters
    ----------
    mask_path : str or Path
        Path of the image file.

    Returns
    -------
    np.ndarray
        The image as loaded by OpenCV with ``cv2.IMREAD_GRAYSCALE``.

    Raises
    ------
    OSError
        If OpenCV cannot read the file.
    """
    # NOTE(review): IMREAD_GRAYSCALE without IMREAD_ANYDEPTH yields an 8-bit
    # image; if the masks are stored as 16-bit label images, confirm that the
    # depth conversion is intended.
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports failure by returning None instead of raising
    if mask is None:
        raise OSError(f"Could not read mask image: {mask_path}")
    return mask


def binned_analysis(binned, hole_mask, centroids, bins, comp_fct):
    """Count centroids and measure area for each distance bin.

    Parameters
    ----------
    binned : np.ndarray of int
        Image of non-negative bin labels as produced by ``np.digitize``;
        label ``i`` (1-based) belongs to the ``i``-th bin. Labels outside
        ``1 .. len(bins) - 1`` are ignored.
    hole_mask : np.ndarray
        Mask of the hole; its non-zero pixel count is subtracted from the
        pixel count of the first bin.
    centroids : np.ndarray
        Array of shape ``(n, 2)`` with ``(x, y)`` rows, in the pixel grid of
        ``binned``.
    bins : sequence
        Bin edges; ``len(bins) - 1`` bins are analysed.
    comp_fct : float
        Pixel counts are multiplied by ``comp_fct ** 2`` to obtain the area.

    Returns
    -------
    density : np.ndarray
        ``counts / area`` per bin (``nan``/``inf`` where ``area`` is zero).
    counts : np.ndarray
        Number of centroids per bin (float).
    area : np.ndarray
        Area per bin (float).

    Notes
    -----
    NOTE(review): the area is the pixel count scaled by ``comp_fct ** 2``,
    i.e. it is expressed in pixels of the ``comp_fct``-scaled grid. The caller
    in this notebook goes on to convert it as if it were square micrometres;
    confirm whether the micrometre-per-pixel factor should be used instead.
    """
    n_bins = len(bins) - 1

    # Bin label under every centroid, looked up once instead of once per bin.
    rows = centroids[:, 1].astype(int)
    cols = centroids[:, 0].astype(int)
    centroid_labels = binned[rows, cols]

    # One histogram pass each replaces `n_bins` full-image comparisons.
    # Index 0 (below the first edge) and indices past the last bin are cut.
    counts = np.bincount(centroid_labels, minlength=n_bins + 1)[1 : n_bins + 1]
    pixels = np.bincount(binned.ravel(), minlength=n_bins + 1)[1 : n_bins + 1]
    counts = counts.astype(float)
    pixels = pixels.astype(float)

    if n_bins > 0:
        # The first bin contains the hole itself; take its pixels out.
        # count_nonzero (rather than sum) keeps this correct for 0/255 masks.
        pixels[0] -= np.count_nonzero(hole_mask)

    area = pixels * comp_fct**2

    # Empty bins legitimately have zero area; keep the nan/inf result but do
    # not flood the notebook output with RuntimeWarnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        density = counts / area

    return density, counts, area

In [None]:
# Per group: compute, for every image, the centroid count, area and density in
# concentric distance bins around the hole, save a preview image of the bins,
# and write one CSV per metric with one row per image.

# The bin edges and the column labels are the same for every image.
bins = np.arange(0, upper_limit_um + bin_width_um, bin_width_um)
bin_str = [f"{b}-{b + bin_width_um}" for b in bins[:-1]]

for g in groups:
    masks_ls = sorted(g.rglob("*_masks.tif"))
    print(f"{g.name}: {len(masks_ls)} images")

    second_group_dir = second_mask_dir / g.name
    preview_group_dir = Path(f"../results/bins_preview/{g.name}")
    preview_group_dir.mkdir(exist_ok=True, parents=True)

    density_res_ls = []
    count_res_ls = []
    area_res_ls = []

    for mask_path in tqdm(masks_ls):
        img_id = mask_path.stem.replace("_masks", "")

        second_img_dir = second_group_dir / img_id
        # fail with a readable message instead of a bare StopIteration
        second_mask_path = next(second_img_dir.glob("*.h5"), None)
        if second_mask_path is None:
            raise FileNotFoundError(f"No .h5 mask found in {second_img_dir}")

        # read the masks
        cp_mask = read_cp_mask(mask_path)
        hole_mask, exclusion_mask = read_h5_mask(second_mask_path)

        # centroids inside the hole or inside an exclusion region are dropped,
        # so both masks are merged before filtering
        if exclusion_mask is not None:
            exclusion_mask = hole_mask | exclusion_mask
        else:
            exclusion_mask = hole_mask
        cp_centroids = extract_centroids(cp_mask, exclusion_mask, comp_fct=comp_fct)

        # distance of every pixel to the nearest hole pixel, rescaled by
        # conv_fct, then assigned to a distance bin
        dist_map = cv2.distanceTransform(
            (~hole_mask).astype(np.uint8), cv2.DIST_L2, cv2.DIST_MASK_PRECISE
        )
        dist_map = dist_map * conv_fct
        binned_dist_map = np.digitize(dist_map, bins)

        # preview: distance bins, the hole on top, and the kept centroids
        fig, ax = plt.subplots()
        ax.imshow(binned_dist_map, cmap="viridis_r")
        ax.imshow(np.ma.masked_where(~hole_mask, hole_mask), cmap="plasma", alpha=1)
        ax.scatter(
            cp_centroids[:, 0], cp_centroids[:, 1], c="r", s=1, marker=".", alpha=0.5
        )
        ax.axis("off")
        fig.savefig(
            preview_group_dir / f"{img_id}.png",
            bbox_inches="tight",
            pad_inches=0,
            dpi=150,
        )
        # close explicitly so figures do not pile up over the loop
        plt.close(fig)

        density, count, area = binned_analysis(
            binned_dist_map, hole_mask, cp_centroids, bins, comp_fct
        )
        # NOTE(review): these conversions assume `area` is in square
        # micrometres, but binned_analysis scales pixel counts by
        # comp_fct**2 rather than by the micrometre-per-pixel factor --
        # confirm the units.
        density = density * 1e6  # convert to mm^2
        area = area / 1e6  # convert to mm^2

        # one record per image and metric: image id + one column per bin
        density_res_ls.append({"image_id": img_id, **dict(zip(bin_str, density))})
        count_res_ls.append({"image_id": img_id, **dict(zip(bin_str, count))})
        area_res_ls.append({"image_id": img_id, **dict(zip(bin_str, area))})

        # release the large per-image arrays before the next iteration
        del cp_mask, hole_mask, exclusion_mask, cp_centroids, dist_map, binned_dist_map

    density_df = pd.DataFrame(density_res_ls)
    count_df = pd.DataFrame(count_res_ls)
    area_df = pd.DataFrame(area_res_ls)

    density_df.to_csv(f"../results/{g.name}_density.csv", index=False)
    count_df.to_csv(f"../results/{g.name}_count.csv", index=False)
    area_df.to_csv(f"../results/{g.name}_area.csv", index=False)
