In [10]:
import numpy as np
import tifffile
from scipy import ndimage
from scipy.ndimage.morphology import binary_opening
from skimage import measure
from skimage.morphology import disk
import warnings
from config import *
import warnings
import matplotlib.pyplot as plt


class Gtgrid:
    def __init__(self, img_gt, img_baseline, area=200):
        self.img_gt = img_gt
        self.img_baseline = img_baseline
        self.img_gt = self.remove_nuclei_border(self.img_gt)
        self.img_baseline = self.remove_nuclei_border(self.img_baseline)

    def remove_nuclei_border(self, img, margin=3):
        uniques, counts = np.unique(img, return_counts=True)
        for rep in uniques:
            tmp = np.where(img == rep)
            x_min, x_max = np.min(tmp[0]), np.max(tmp[0])
            y_min, y_max = np.min(tmp[1]), np.max(tmp[1])
            bool_erase = (
                (x_min <= margin)
                or (x_max >= img.shape[0] - margin)
                or (y_min <= margin)
                or (y_max >= img.shape[0] - margin)
            )
            bool_erase = bool_erase
            if bool_erase:
                img[img == rep] = 0
        return img

    def create_dictionnary(self):
        dic = {}
        """
        trouvons les associations en premier
        """
        uniques_gt = np.unique(self.img_gt)[1:]
        uniques_baseline = np.unique(self.img_baseline)[1:]
        for unique_gt in uniques_gt:
            for unique_baseline in uniques_baseline:
                nuclei_baseline = self.img_baseline == unique_baseline
                nuclei_gt = self.img_gt == unique_gt
                iou = np.sum(nuclei_baseline * nuclei_gt) / np.sum(
                    np.maximum(nuclei_baseline, nuclei_gt)
                )

                if iou > 0.5:
                    dic[unique_gt] = unique_baseline

        used_gt = list(dic.keys())
        used_baseline = list(dic.values())
        all_gts = list(uniques_gt)
        all_baselines = list(uniques_baseline)
        not_used_gt = np.array(list(set(all_gts) - set(used_gt)))
        not_used_baseline = np.array(list(set(all_baselines) - set(used_baseline)))

        """cherchons si stardist n'a pas mergé des noyaux"""
        not_used_baseline1 = not_used_baseline.copy()
        not_used_gt1 = not_used_gt.copy()

        for nb_nuclei_baseline1 in not_used_baseline1:
            nuclei_baseline = self.img_baseline == nb_nuclei_baseline1
            i = 0
            l = []
            for nb_nuclei_gt1 in not_used_gt1:
                nucleit_gt = self.img_gt == nb_nuclei_gt1
                iou = np.sum(nuclei_baseline * nuclei_gt) / np.sum(nuclei_gt)

                if iou > 0.5:
                    l.append(nb_nuclei_gt1)

            if len(l) >= 2:
                for element in l:
                    dic[element] = nb_nuclei_baseline1

                not_used_baseline1 = np.delete(
                    not_used_baseline1,
                    np.where(not_used_baseline1 == nb_nuclei_baseline1),
                )
                not_used_gt1 = np.delete(
                    not_used_gt1, np.where(np.isin(not_used_gt1, l))[0].flatten()
                )

        not_used_baseline2 = not_used_baseline1.copy()
        not_used_gt2 = not_used_gt1.copy()

        """cherchons si stardist n'a pas splité des noyaux"""
        # print(not_used_baseline2)
        for nb_nuclei_gt2 in not_used_gt2:
            nuclei_gt = self.img_gt == nb_nuclei_gt2
            l = []
            for nb_nuclei_baseline2 in not_used_baseline2:
                nuclei_baseline = self.img_baseline == nb_nuclei_baseline2
                iou = np.sum(nuclei_gt * nuclei_baseline) / np.sum(nuclei_baseline)
                if iou >= 0.5:
                    # print('hey')
                    # plt.imshow(nuclei_baseline)
                    # plt.show()
                    l.append(nb_nuclei_baseline2)

            if len(l) >= 2:
                dic[nb_nuclei_gt2] = l
                not_used_gt2 = np.delete(
                    not_used_gt2, np.where(not_used_gt2 == nb_nuclei_gt2)
                )
                not_used_baseline2 = np.delete(
                    not_used_baseline2,
                    np.where(np.isin(not_used_baseline2, l))[0].flatten(),
                )

        """
        if a nuclei appears several times in the values of the dictionnary,
        the baseline has merged some gt nucleis
        """
        values_dic = list(dic.values())
        unique, counts = np.unique(self.flatten(values_dic), return_counts=True)
        baseline_merges = []
        for merge in np.where(counts > 1)[0].flatten():
            baseline_merges.append(unique[merge])

        """
        if a list is in the values of the dictionnary, the baseline splitted a nuclei_gt
        """

        gt_splits = np.array([u for u in list(dic.keys()) if type(dic[u]) == list])
        return (
            np.array(baseline_merges).astype(int),
            gt_splits.astype(int),
            not_used_gt2.astype(int),
            not_used_baseline2.astype(int),
        )

    def get_size_of_error(self, binary, margin=15):
        indexes = np.argwhere(binary)
        bottom_left = np.min(indexes, 0)
        upper_right = np.max(indexes, 0)
        return np.max(np.abs(bottom_left - upper_right)) + margin

    def flatten(self, list_):
        new_list = []
        for v in list_:
            if type(v) != list:
                new_list.append(v)
            else:
                new_list = new_list + v
        return new_list

    def create_dic_errors(self):
        all_errors = self.create_dictionnary()
        centers = []
        sizes = []
        receptive_field_sizes = np.array([5, 13, 29, 61, 125])

        keys = ["merge", "split", "fn", "fp"]
        dic = {}
        for key, errors, img in zip(
            keys,
            all_errors,
            (self.img_baseline, self.img_gt, self.img_gt, self.img_baseline),
        ):
            scale1 = np.zeros((256, 256))
            scale2 = np.zeros((128, 128))
            scale3 = np.zeros((64, 64))
            scale4 = np.zeros((32, 32))
            scale5 = np.zeros((16, 16))

            scales = [scale1, scale2, scale3, scale4, scale5]

            for nb in errors:
                error = img == nb
                rows, columns = np.where(error)[0:2]
                min_rows, max_rows = np.min(rows), np.max(rows)
                min_columns, max_columns = np.min(columns), np.max(columns)
                mean_rows = (max_rows + min_rows) // 2
                mean_columns = (max_columns + min_columns) // 2
                size = self.get_size_of_error(error)
                print(size)
                scale = np.searchsorted(receptive_field_sizes, size).astype(int)
                factor_resize = 1 / (2 ** (scale))
                row, col = np.round(factor_resize * mean_rows).astype(int), np.round(
                    factor_resize * mean_columns
                ).astype(int)
                row, col = np.clip(row,0,(256//(2**scale))-1).astype(int), np.clip(col,0,(256/(2**scale))-1).astype(int)
                print(scale,row,col)
                scales[scale][row, col] = 1
            scales = [u for u in scales if np.sum(u) > 0]
            dic[key] = scales

        all_scales_used = []
        for key, value in dic.items():
            for l in value:
                all_scales_used.append(l.shape[0])
        return dic, np.unique(all_scales_used)

    def create_grid(self):
        dic = self.create_dic_errors()[0]
        u = np.zeros((256, 256, 4))
        for i, (error, array) in enumerate(dic.items()):
            u_error = np.zeros((256, 256))
            if len(array) != 0:
                for arr in array:
                    factor_resize = 256 / arr.shape[0]
                    coord_image = (
                        (np.stack(np.where(arr)).T + 1 / 2) * factor_resize
                    ).astype(int)
                    for coord in coord_image:
                        u_error[coord[0], coord[1]] = 1
        return u


  from scipy.ndimage.morphology import binary_opening


In [11]:
from config import *
# from create_gt_grids import Gtgrid

filename = 'nuclei_3_1351.tif'
img_gt = tifffile.imread(os.path.join(path_gt, filename))
img_baseline = tifffile.imread(os.path.join(path_stardist_modified, filename))
grid_creator = Gtgrid(img_gt, img_baseline, area=0)
grid = grid_creator.create_grid()

30
3 14 25
28
2 26 59
43
3 12 27
44
3 18 20
31
3 27 28
43
3 21 23
45
3 7 10
43
3 4 8
25
2 60 52
39
3 19 26
43
3 8 26
37
3 28 29
55
3 16 4
33
3 30 20
36
3 3 4
33
3 17 15
179
5 2 4


IndexError: list index out of range