## Volume averaged Jaccard index estimation for a segmented image and a ground truth image

This notebook estimates the volume-averaged Jaccard index (VJI) between a predicted segmentation and a ground-truth segmentation of the same image.

Original repository: https://gitlab.inria.fr/mosaic/publications/seg_compare

In [None]:
import os
import numpy as np
import pandas as pd
import scipy.ndimage as nd
import matplotlib.pyplot as plt

from timagetk.io import imread
from timagetk.components import LabelledImage
from timagetk.visu.util import glasbey

from ctrl.image_overlap import fast_image_overlap3d

%matplotlib inline

### Data and parameters

Make sure to modify paths and filenames appropriately to match your file architecture.

In [None]:
main_directory = os.path.dirname(os.getcwd()) # for example

# - Set the predicted segmentation path
#   (os.path.join inserts the separator that plain string concatenation would omit)
target_path = os.path.join(main_directory, 'path to segmented image')

# - Set the ground-truth segmentation path
reference_path = os.path.join(main_directory, 'path to ground truth image')

In [None]:
# Index along the third array axis of the slice shown in the visualization cell below
slice_index = 50

### Image loading

We assume that the images have the same shape and voxelsize. We also assume that the background label is equal to 1 in all images.

In [None]:
# Label treated as background in both images (filtered out of the VJI computation below)
BACKGROUND_LABEL = 1

# - Read both segmentations and wrap them as LabelledImage (0 declared as the "no label" id);
#   the target is read first, then the reference.
target_img, reference_img = (
    LabelledImage(imread(path), no_label_id=0) for path in (target_path, reference_path)
)

### Image slice visualization

In [None]:
def plot_segmentation(img_seg, title, ax, slice_index, extent, background_label=1):
    """Display one slice of a labelled image with cell contours and label ids.

    Parameters
    ----------
    img_seg : LabelledImage
        Segmented image; its array is sliced as ``[:, :, slice_index]`` and
        its ``voxelsize`` is used to place the label annotations.
    title : str
        Subplot title.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    slice_index : int
        Index along the third array axis of the slice to display.
    extent : tuple
        ``(left, right, bottom, top)`` forwarded to ``imshow``/``contour``;
        expected in the same units as ``img_seg.voxelsize`` so that the
        annotations line up with the image.
    background_label : int, optional
        Label that is neither outlined nor annotated (default 1).
    """
    # Transposed: rows follow array axis 1, columns follow array axis 0.
    slice_img = img_seg.get_array()[:, :, slice_index].T
    labels = np.unique(slice_img)

    # center_of_mass returns (row, col) voxel indices; because of the transpose,
    # rows are spaced by voxelsize[1] and columns by voxelsize[0].
    row_col_spacing = np.array([img_seg.voxelsize[1], img_seg.voxelsize[0]], dtype=float)
    centers = np.asarray(nd.center_of_mass(np.ones_like(slice_img), slice_img, labels),
                         dtype=float) * row_col_spacing
    barycenter = dict(zip(labels, centers))

    # Labels modulo 256 so they map onto the 256-entry colormap range.
    ax.imshow(slice_img % 256, cmap='glasbey', vmin=0, vmax=255, extent=extent, interpolation='none')
    ax.set_title(title)

    # - Add contour and label annotation
    for label in labels:
        if label == background_label:
            continue
        ax.contour(slice_img == label, linewidths=0.1, extent=extent,
                   origin='upper', colors='k')
        # text() takes (x, y) = (column, row)
        ax.text(barycenter[label][1],
                barycenter[label][0],
                str(label), ha='center', va='center', size=12)
    ax.axis('off')

def plot_contour(img_seg, ax, slice_index, extent, color='r', background_label=1):
    """Draw the cell contours and barycenters of one slice of a labelled image.

    Meant to be called several times on the same axes (one per image) to
    superimpose segmentations.

    Parameters
    ----------
    img_seg : LabelledImage
        Segmented image; its array is sliced as ``[:, :, slice_index]`` and
        its ``voxelsize`` is used to place the barycenters.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    slice_index : int
        Index along the third array axis of the slice to display.
    extent : tuple
        ``(left, right, bottom, top)`` forwarded to ``imshow``/``contour``;
        expected in the same units as ``img_seg.voxelsize``.
    color : str, optional
        Color of the contours and barycenter markers (default ``'r'``).
    background_label : int, optional
        Label that is neither outlined nor marked (default 1).
    """
    # Transposed: rows follow array axis 1, columns follow array axis 0.
    slice_img = img_seg.get_array()[:, :, slice_index].T
    labels = np.unique(slice_img)

    # center_of_mass returns (row, col) voxel indices; because of the transpose,
    # rows are spaced by voxelsize[1] and columns by voxelsize[0].
    row_col_spacing = np.array([img_seg.voxelsize[1], img_seg.voxelsize[0]], dtype=float)
    centers = np.asarray(nd.center_of_mass(np.ones_like(slice_img), slice_img, labels),
                         dtype=float) * row_col_spacing
    barycenter = dict(zip(labels, centers))

    # Fully transparent image (alpha=0); presumably only there to set up the axes extent/orientation.
    ax.imshow(np.zeros_like(slice_img), cmap='Reds', vmin=0, vmax=1, extent=extent, interpolation='none', alpha=0)
    ax.set_title('Superimpose the edges and barycenter')

    # - Add contour and label annotation
    for label in labels:
        if label == background_label:
            continue
        ax.contour(slice_img == label, extent=extent,
                   origin='upper', colors=color)
        # scatter() takes (x, y) = (column, row)
        ax.scatter(barycenter[label][1],
                   barycenter[label][0],
                   s=12, color=color)
    ax.axis('off')
    
figure = plt.figure(0)
figure.clf()

# imshow-style extent (left, right, bottom, top); top = 0 puts the origin at the upper-left corner
extent = (0, reference_img.extent[0], reference_img.extent[1], 0)

# - Left panel: ground-truth segmentation
plot_segmentation(img_seg=reference_img, title='Ground-truth segmentation',
                  ax=figure.add_subplot(1, 3, 1), slice_index=slice_index, extent=extent)

# - Middle panel: both segmentations' contours drawn on the same axes
plot_contour(img_seg=reference_img, ax=figure.add_subplot(1, 3, 2),
             slice_index=slice_index, extent=extent, color='tab:green')
plot_contour(img_seg=target_img, ax=figure.gca(),
             slice_index=slice_index, extent=extent, color='tab:red')

# - Right panel: predicted segmentation
plot_segmentation(img_seg=target_img, title='Predicted segmentation',
                  ax=figure.add_subplot(1, 3, 3), slice_index=slice_index, extent=extent)

figure.set_size_inches(20, 15)
figure.subplots_adjust(wspace=0, hspace=0)
figure.tight_layout()

### VJI computation

Compute the Jaccard index of every pair of overlapping (target, reference) cells in the images.

In [None]:
# - Cell labels of each image, background excluded
target_cells, reference_cells = (
    [lab for lab in img.labels() if lab != BACKGROUND_LABEL]
    for img in (target_img, reference_img)
)

# - Jaccard index of every (target, reference) pair of cells;
#   pairs with an empty intersection are not returned.
df_jaccard = fast_image_overlap3d(
    mother_seg=target_img, daughter_seg=reference_img,
    mother_label=target_cells, daughter_label=reference_cells,
    method='jaccard', ds=1, verbose=False,
)
# Rename the returned columns positionally: (mother, daughter, value) -> (target, reference, jaccard)
df_jaccard.columns = ['target', 'reference', 'jaccard']

print(df_jaccard.head(15))

For each reference cell, identify the target cell that maximizes their Jaccard index.

In [None]:
# - For each reference labels find the target labels that maximize the jaccard index
df_jaccard = df_jaccard.loc[df_jaccard.groupby('reference')['jaccard'].idxmax()]
print(df_jaccard[:5])

Add the missing reference cells (those that do not intersect any target cell) and compute the volume of each reference cell.

In [None]:
# - Make sure every reference label is in the dataframe: a reference cell lying
#   entirely in the target background overlaps no target cell and was not returned above.
missing_cells = set(reference_cells) - set(df_jaccard.reference.values)

if missing_cells:
    # - Add them with a jaccard index of 0 (target label 0 = no match).
    #   Build all rows at once and concatenate: DataFrame.append was removed in pandas 2.0
    #   and appending row by row is quadratic. Sorted for a deterministic row order.
    missing_rows = pd.DataFrame(
        [{'target': 0, 'reference': lab, 'jaccard': 0} for lab in sorted(missing_cells)]
    )
    df_jaccard = pd.concat([df_jaccard, missing_rows], ignore_index=True)

# - Add the corresponding volumes (in voxel unit): number of voxels carrying each label
ref_labels = reference_img.labels()
cell_ref_volume = nd.sum(np.ones_like(reference_img), reference_img, ref_labels)
cell_ref_volume = dict(zip(ref_labels, cell_ref_volume))

df_jaccard['volume'] = df_jaccard['reference'].map(cell_ref_volume)
print(df_jaccard[:5])

Compute the weighted Jaccard index by multiplying each cell's Jaccard index by the volume of the reference cell.

In [None]:
# - Per-cell contribution to the VJI numerator: J_i * V_i
df_jaccard['weighted_jaccard'] = df_jaccard['jaccard'].mul(df_jaccard['volume'])
print(df_jaccard.head(5))

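The volume-averaged Jaccard index is obtained by summing the weighted Jaccard indices and dividing by the total volume of the ground-truth cells:

$$\mathrm{VJI} = \frac{\sum_i V_i \, J_i}{\sum_i V_i}$$

where $V_i$ is the volume of reference cell $i$ and $J_i$ is its best Jaccard index.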

In [None]:
# VJI = sum_i(J_i * V_i) / sum_i(V_i), i running over the reference cells
sum_weighted_ji = sum(df_jaccard['weighted_jaccard'].values)
total_cell_volume = sum(df_jaccard['volume'].values)
vji = sum_weighted_ji / total_cell_volume

reference_name = os.path.basename(reference_path)
target_name = os.path.basename(target_path)
print(f'Reference img: {reference_name} \nTarget img: {target_name}')
print(f"Volume average jaccard index : {np.around(vji, 5)}")

## Save the per-cell Jaccard indices to a CSV file
Note that this CSV file is needed for the 3D visualization of segmentation quality with MorphoNet.

In [None]:
save_directory = main_directory # set the save folder

# - One row per reference cell: matched target label, jaccard, volume, weighted_jaccard.
#   Written into save_directory so that changing the save folder above takes effect.
df_jaccard.to_csv(os.path.join(save_directory, 'ji_results.csv'), index=False)