# CDACS Model Experiment - Evaluation

## Import necessary libraries

In [None]:
import os
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
import cv2
import numpy as np
from skimage import io, color, measure
from skimage.segmentation import watershed
import scipy
import centrosome.outline
import centrosome.cpmorphology
from PIL import Image
import glob
from tqdm import tqdm

from module.evaluate.nuclei_util_v2 import identifyprimaryobjects

Image.MAX_IMAGE_PIXELS = None

## Define evaluation functions

In [None]:
def filter_on_size(img, size=1):
    """Remove connected foreground components smaller than ``size`` pixels.

    Non-zero pixels of ``img`` are grouped into connected components using full
    connectivity (8-connected for 2-D input, generalised to N-D). Every component
    whose pixel count is below ``size`` is dropped.

    Args:
        img: Array whose non-zero pixels are foreground.
        size: Minimum component area (in pixels) to keep. Components with
            fewer pixels are removed.

    Returns:
        ``np.uint8`` array of the same shape as ``img``: 255 where a kept
        component lies, 0 elsewhere.
    """
    img = np.asarray(img)
    # A 3x3(x3...) block of ones means diagonal neighbours are also connected.
    structure = np.ones((3,) * img.ndim, dtype=bool)
    labeled_image, object_count = scipy.ndimage.label(img, structure)

    # areas[k] = number of pixels carrying label k (k = 0 is the background).
    areas = np.bincount(labeled_image.ravel(), minlength=object_count + 1)
    keep = areas >= size
    keep[0] = False  # the background is never part of the output mask

    # No relabelling is needed: the output is a binary mask, so label ids are irrelevant.
    return (keep[labeled_image] * 255).astype(np.uint8)

def labeled_image_crop_region_counting(mask, labeled_image):
    """Count, for every object label, how many of its pixels are covered by ``mask``.

    Args:
        mask: Array broadcast-compatible with ``labeled_image``; any value > 0
            counts as covered.
        labeled_image: Non-negative integer label image in which 0 is
            background and 1..N are objects.

    Returns:
        1-D ``float64`` array of length ``labeled_image.max() + 1``. Entry ``k``
        is the number of covered pixels of label ``k``. Entry 0 is always 0,
        because background pixels are never counted.
    """
    labeled_image = np.asarray(labeled_image)
    covered = (np.asarray(mask) > 0) & (labeled_image != 0)
    labels = labeled_image[covered].astype(np.intp, copy=False)
    n_bins = int(np.max(labeled_image)) + 1
    # bincount is the direct histogram; the float cast keeps the historical
    # float64 output dtype.
    return np.bincount(labels, minlength=n_bins).astype(np.float64)

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _object_level_confusion(hist_target, hist_pred):
    """Count object-level TP/TN/FP/FN from two per-label overlap histograms.

    ``hist_*[k]`` is the number of mask pixels falling on object label ``k``.
    An object counts as positive when that number is > 0. Index 0 is the
    background label, never an object, so it is dropped here. Otherwise it
    would add one spurious true negative per image.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of Python ints.
    """
    in_target = np.asarray(hist_target)[1:] > 0
    in_pred = np.asarray(hist_pred)[1:] > 0
    tp = int(np.count_nonzero(in_target & in_pred))
    tn = int(np.count_nonzero(~in_target & ~in_pred))
    fp = int(np.count_nonzero(~in_target & in_pred))
    fn = int(np.count_nonzero(in_target & ~in_pred))
    return tp, tn, fp, fn


def evaluate_method(image_path_list, target_path_list, pred_path_list, min_pred_area=11):
    """Evaluate predicted masks against ground-truth masks at the object level.

    For each sample:
      1. Objects are segmented from the min-max-normalised first channel of the
         input image with ``identifyprimaryobjects``. Its third return value is
         used as the label image.
      2. Each object is marked positive in the target (or prediction) if the
         first channel of the ground-truth (or size-filtered predicted) mask
         covers at least one of its pixels.
    Object-level TP/TN/FP/FN are summed over all samples. The totals, together
    with sensitivity, specificity and accuracy, are printed. A ratio with a
    zero denominator is printed as ``nan``.

    Args:
        image_path_list: Paths of the input images.
        target_path_list: Paths of the ground-truth masks, aligned element-wise
            with ``image_path_list``.
        pred_path_list: Paths of the predicted masks, aligned element-wise with
            ``image_path_list``.
        min_pred_area: Predicted components smaller than this many pixels are
            discarded before comparison (default 11, the previous hard-coded value).

    Raises:
        ValueError: If the three path lists have different lengths.
    """
    if not (len(image_path_list) == len(target_path_list) == len(pred_path_list)):
        raise ValueError(
            "image, target and prediction path lists must have the same length "
            f"(got {len(image_path_list)}, {len(target_path_list)}, {len(pred_path_list)})"
        )

    total_tp = total_tn = total_fp = total_fn = 0

    samples = zip(image_path_list, target_path_list, pred_path_list)
    for image_path, target_path, pred_path in tqdm(samples, total=len(image_path_list)):
        print(os.path.basename(image_path).split("_input")[0])
        img = np.array(Image.open(image_path))
        mask_target = np.array(Image.open(target_path))[:, :, 0]
        mask_pred = filter_on_size(np.array(Image.open(pred_path))[:, :, 0], min_pred_area)

        img_r = img[:, :, 0].astype(np.float64)
        value_range = img_r.max() - img_r.min()
        if value_range == 0:
            # A constant channel has no contrast, so there are no objects to score.
            # Skipping it also avoids the 0/0 normalisation, which produced NaNs.
            continue
        img_r = (img_r - img_r.min()) / value_range * 255

        _, _, r_labeled_image = identifyprimaryobjects(img_r, exclude_border_objects=False)
        hist_target = labeled_image_crop_region_counting(mask_target, r_labeled_image)
        hist_pred = labeled_image_crop_region_counting(mask_pred, r_labeled_image)

        tp, tn, fp, fn = _object_level_confusion(hist_target, hist_pred)
        total_tp += tp
        total_tn += tn
        total_fp += fp
        total_fn += fn

    print("total_tp.", total_tp)
    print("total_tn.", total_tn)
    print("total_fp.", total_fp)
    print("total_fn.", total_fn)
    print("Sensitivity.", _safe_ratio(total_tp, total_tp + total_fn))
    print("Specificity.", _safe_ratio(total_tn, total_tn + total_fp))
    print("Acc.", _safe_ratio(total_tp + total_tn, total_tp + total_tn + total_fp + total_fn))

## Evaluation

In [None]:
root = os.path.join('inferences', 'CDACS_HECR')

# glob.glob returns files in arbitrary order, so three independent globs can pair
# an input with another sample's masks. Derive the gt/pred paths from each input
# path instead. Assumes each sample's files share the prefix before
# "_input.png" / "_gt.png" / "_pred.png" — confirm against the inference export.
image_path_list = sorted(glob.glob(os.path.join(root, "*_input.png")))
target_path_list = [p[:-len("_input.png")] + "_gt.png" for p in image_path_list]
pred_path_list = [p[:-len("_input.png")] + "_pred.png" for p in image_path_list]

evaluate_method(image_path_list, target_path_list, pred_path_list)

In [None]:
root = os.path.join('inferences', 'CDACS_IFCR')

# glob.glob returns files in arbitrary order, so three independent globs can pair
# an input with another sample's masks. Derive the gt/pred paths from each input
# path instead. Assumes each sample's files share the prefix before
# "_input.png" / "_gt.png" / "_pred.png" — confirm against the inference export.
image_path_list = sorted(glob.glob(os.path.join(root, "*_input.png")))
target_path_list = [p[:-len("_input.png")] + "_gt.png" for p in image_path_list]
pred_path_list = [p[:-len("_input.png")] + "_pred.png" for p in image_path_list]

evaluate_method(image_path_list, target_path_list, pred_path_list)