In [1]:
import matplotlib.pyplot as plt
import matplotlib.collections as clt
import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from tifffile import imread

from skimage.measure import label, regionprops
import pickle
from scipy.spatial.distance import cdist
from tqdm.notebook import tqdm

In [2]:
# Directory prefix for every volume read by this notebook. Paths are built by
# plain string concatenation, so the trailing '/' is required.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_path = '/data/duanb/results/segmentation/NeuN/TheFirstnnUNet/'

In [3]:
# Base file names (without the '.tif' extension) of the volumes to evaluate.
# The visible cells only ever use the first entry (datasets[0]).
datasets = ['20210929_HDCF_R56-8_fibers_6-11_hiRes_NeuN_histMatched_medFilt']

In [4]:
# Load one volume per key. 'raw' is read directly from data_path; every other
# key is read from the sub-directory of the same name. The file name is the
# same (datasets[0] + '.tif') in all cases.
samples = {
    key: imread(f"{data_path}{'' if key == 'raw' else key + '/'}{datasets[0]}.tif")
    for key in ('raw', '10', '30', '50', '100', '400')
}

In [5]:
def stat_results(seg, ref):
    """Compare each object in ``seg`` with the ``ref`` object whose centroid is nearest.

    Both inputs are relabelled with ``label`` and measured with ``regionprops``.
    Every object found in ``seg`` is paired with the object in ``ref`` that has
    the smallest centroid-to-centroid distance (``cdist`` default metric).
    Several ``seg`` objects may be paired with the same ``ref`` object.

    Parameters
    ----------
    seg : array
        Segmentation to evaluate.
    ref : array
        Reference segmentation.

    Returns
    -------
    dists : ndarray, shape (n_seg,)
        Centroid distance between each ``seg`` object and its nearest ``ref`` object.
    area : ndarray, shape (n_seg,)
        ``area`` of the ``seg`` object minus ``area`` of its paired ``ref`` object.
    centroid : ndarray, shape (n_seg, ndim)
        Centroid of the ``seg`` object minus centroid of its paired ``ref`` object.
    diameter : ndarray, shape (n_seg,)
        Difference of ``equivalent_diameter_area`` (``seg`` minus ``ref``).

    Raises
    ------
    ValueError
        If ``seg`` contains objects but ``ref`` contains none, so there is
        nothing to pair with.
    """
    seg = label(seg)
    ref = label(ref)

    props = regionprops(seg)
    ref_props = regionprops(ref)

    if not props:
        # Nothing to evaluate: return correctly shaped empty results instead of
        # letting cdist fail on a 1-D empty array.
        return np.empty(0), np.empty(0), np.empty((0, np.ndim(seg))), np.empty(0)
    if not ref_props:
        raise ValueError('ref contains no labelled objects to match against')

    centroids = np.array([p.centroid for p in props], dtype=float)
    ref_centroids = np.array([p.centroid for p in ref_props], dtype=float)

    dists = cdist(centroids, ref_centroids)

    # Column index of the nearest reference object for every seg object.
    nearest = np.argmin(dists, axis=-1)

    area = np.array([p.area - ref_props[j].area for p, j in zip(props, nearest)])
    centroid = centroids - ref_centroids[nearest]
    diameter = np.array([
        p.equivalent_diameter_area - ref_props[j].equivalent_diameter_area
        for p, j in zip(props, nearest)
    ])

    return dists[np.arange(len(nearest)), nearest], area, centroid, diameter

In [6]:
def find_most_overlapped(ref, seg, ref_props, props, max_dist=10):
    """Pair every labelled object of ``ref`` with the ``seg`` label covering most of it.

    For label ``i`` (1-based) of ``ref`` the values of ``seg`` on the voxels
    where ``ref == i`` are counted. The most frequent value wins; on a tie the
    smallest value wins, so background (0) beats any label it ties with. An
    object is reported as unmatched when the winner is background, when the
    centroids of the two objects are more than ``max_dist`` apart, or when
    ``ref`` has no voxel carrying label ``i``.

    Parameters
    ----------
    ref, seg : array
        Label images of identical shape (0 = background).
    ref_props, props : sequence
        Per-object measurements for ``ref`` and ``seg``. Entry ``k`` must
        describe label ``k + 1`` and expose a ``centroid`` attribute.
    max_dist : float, optional
        Largest centroid distance (same units as ``centroid``) accepted for a
        match. Defaults to 10.

    Returns
    -------
    matched : list of (int, int)
        ``(ref_index, seg_index)`` pairs, both 0-based indices into
        ``ref_props`` / ``props``.
    no_match : list of int
        0-based indices into ``ref_props`` of the unmatched objects.

    Raises
    ------
    IndexError
        If ``ref`` and ``seg`` do not have the same shape.
    """
    if np.shape(ref) != np.shape(seg):
        raise IndexError('ref and seg must have the same shape')

    ref_flat = np.ravel(ref)
    seg_flat = np.ravel(seg)
    n_ref = len(ref_props)

    # Group the foreground voxels of ``ref`` by label once, instead of scanning
    # the whole volume with ``ref == i`` for every single label.
    fg = np.flatnonzero(ref_flat)
    fg_labels = ref_flat[fg]
    order = np.argsort(fg_labels, kind='stable')
    fg, fg_labels = fg[order], fg_labels[order]
    # bounds[i - 1]:bounds[i] is the slice of ``fg`` holding the voxels of label i.
    bounds = np.searchsorted(fg_labels, np.arange(1, n_ref + 2))

    matched = []
    no_match = []
    for i in tqdm(range(1, n_ref + 1)):
        overlap = seg_flat[fg[bounds[i - 1]:bounds[i]]]
        if overlap.size == 0:
            # Label i does not occur in ``ref``: nothing to vote on.
            no_match.append(i - 1)
            continue

        vals, counts = np.unique(overlap, return_counts=True)
        best = int(vals[counts.argmax()])
        if best == 0:
            no_match.append(i - 1)
            continue

        # Reject pairs whose centroids are too far apart (outliers).
        offset = np.asarray(ref_props[i - 1].centroid) - np.asarray(props[best - 1].centroid)
        if np.linalg.norm(offset) <= max_dist:
            matched.append((i - 1, best - 1))
        else:
            no_match.append(i - 1)

    return matched, no_match

In [7]:
def stat_resultsV2(seg, ref):
    """Overlap-based TP / FN / FP statistics of ``seg`` against ``ref``.

    Both inputs are relabelled with ``label`` and measured with ``regionprops``.
    Objects are paired with ``find_most_overlapped`` in both directions:
    reference objects without a partner are false negatives, segmented objects
    without a partner are false positives, paired reference objects are true
    positives. The four counts (ref matched, ref unmatched, seg matched,
    seg unmatched) are printed.

    Parameters
    ----------
    seg : array
        Segmentation to evaluate.
    ref : array
        Reference segmentation.

    Returns
    -------
    dict
        ``'TP'``: ``'dists'`` (list of centroid distances of the matched pairs)
        and ``'centroid'`` (array of shape ``(n_matched, ndim)``, reference
        centroid minus segmented centroid).
        ``'FN'``: ``'area'`` of the unmatched reference objects and ``'rate'``,
        the fraction of reference objects left unmatched.
        ``'FP'``: ``'area'`` of the unmatched segmented objects and ``'rate'``,
        the fraction of segmented objects left unmatched.
        A rate is ``nan`` when the corresponding volume contains no objects.
    """
    seg = label(seg)
    ref = label(ref)

    props = regionprops(seg)
    ref_props = regionprops(ref)

    # find pairs of segmentation id
    ref_matched, ref_no_match = find_most_overlapped(ref, seg, ref_props, props)
    seg_matched, seg_no_match = find_most_overlapped(seg, ref, props, ref_props)

    print(len(ref_matched), len(ref_no_match), len(seg_matched), len(seg_no_match))

    # Only the matched pairs are needed, so compute their centroid offsets
    # directly instead of building a dense (n_ref x n_seg) distance matrix.
    if ref_matched:
        offsets = np.array(
            [np.subtract(ref_props[r].centroid, props[s].centroid) for r, s in ref_matched],
            dtype=float,
        )
    else:
        offsets = np.empty((0, np.ndim(seg)))

    n_ref = len(ref_matched) + len(ref_no_match)
    n_seg = len(seg_matched) + len(seg_no_match)

    return {'TP': {'dists': list(np.linalg.norm(offsets, axis=1)),
                   'centroid': offsets},
            'FN': {'area': np.array([ref_props[r].area for r in ref_no_match]),
                   # Undefined (nan) rather than ZeroDivisionError when ref is empty.
                   'rate': len(ref_no_match) / n_ref if n_ref else float('nan')},
            'FP': {'area': np.array([props[s].area for s in seg_no_match]),
                   'rate': len(seg_no_match) / n_seg if n_seg else float('nan')}
           }


In [None]:
def stat_resultsV3(seg, ref):
    """Overlap-based FN / FP / TP statistics of ``seg`` against ``ref``, with object sizes.

    Both inputs are relabelled with ``label`` and measured with ``regionprops``.
    Objects are paired with ``find_most_overlapped`` in both directions:
    reference objects without a partner are false negatives, segmented objects
    without a partner are false positives, paired reference objects are true
    positives. The four counts (ref matched, ref unmatched, seg matched,
    seg unmatched) are printed.

    Parameters
    ----------
    seg : array
        Segmentation to evaluate.
    ref : array
        Reference segmentation.

    Returns
    -------
    dict
        ``'FN'``: ``'num_pixels'`` of the unmatched reference objects and
        ``'rate'``, the fraction of reference objects left unmatched.
        ``'FP'``: ``'num_pixels'`` of the unmatched segmented objects and
        ``'rate'``, the fraction of segmented objects left unmatched.
        ``'TP'``: ``'dists'`` (list of centroid distances of the matched
        pairs), ``'centroid'`` (array of shape ``(n_matched, ndim)``, reference
        centroid minus segmented centroid) and ``'num_pixels'`` of the matched
        reference objects.
        A rate is ``nan`` when the corresponding volume contains no objects.
    """
    seg = label(seg)
    ref = label(ref)

    props = regionprops(seg)
    ref_props = regionprops(ref)

    # find pairs of segmentation id
    ref_matched, ref_no_match = find_most_overlapped(ref, seg, ref_props, props)
    seg_matched, seg_no_match = find_most_overlapped(seg, ref, props, ref_props)

    print(len(ref_matched), len(ref_no_match), len(seg_matched), len(seg_no_match))

    # Only the matched pairs are needed, so compute their centroid offsets
    # directly instead of building a dense (n_ref x n_seg) distance matrix.
    if ref_matched:
        offsets = np.array(
            [np.subtract(ref_props[r].centroid, props[s].centroid) for r, s in ref_matched],
            dtype=float,
        )
    else:
        offsets = np.empty((0, np.ndim(seg)))

    n_ref = len(ref_matched) + len(ref_no_match)
    n_seg = len(seg_matched) + len(seg_no_match)

    return {
            'FN': {'num_pixels': np.array([ref_props[r].num_pixels for r in ref_no_match]),
                   # Undefined (nan) rather than ZeroDivisionError when ref is empty.
                   'rate': len(ref_no_match) / n_ref if n_ref else float('nan')},
            'FP': {'num_pixels': np.array([props[s].num_pixels for s in seg_no_match]),
                   'rate': len(seg_no_match) / n_seg if n_seg else float('nan')},
            'TP': {'dists': list(np.linalg.norm(offsets, axis=1)),
                   'centroid': offsets,
                   'num_pixels': np.array([ref_props[r].num_pixels for r, _ in ref_matched])}
           }

In [None]:
# Evaluate every volume except 'raw' against the 'raw' volume and pickle each
# result dict into the current working directory, one file per key.
for k in samples:
    if k != 'raw':
        res = stat_resultsV3(samples[k], samples['raw'])
        with open(f'{datasets[0]}_cs{k}.pkl', 'wb') as fp:
            pickle.dump(res, fp)