# Annotator Variability and Model Performance Evaluation
This notebook is part of the paper: Automated Segmentation of the Dorsal Root Ganglia (DRG) in MRI by Nauroth-Kreß et al., 2024
The following cells contain the code used for the cacluctation of the inter-annotator metric scores and the model segmentation performance evaluation. After the cells for calculations folow cells with the matching visualization code.

The code is ready to use with any dataset matching the general structure.

In [8]:
import nibabel as nib
import numpy as np
import pandas as pd
from itertools import combinations
from pathlib import Path
from surface_distance import compute_dice_coefficient, compute_surface_distances, compute_average_surface_distance
from typing import List, Dict, Tuple


def load_files(data_dir : Path | str) -> List[nib.Nifti1Image]:
    """Load all nifti files in directory"""
    file_paths = sorted(Path(data_dir).glob('*.nii*'))
    imgs = {Path(path.stem).stem.split('_')[0]: nib.load(path) for path in file_paths}
    return imgs


def calc_metrics(gt : nib.Nifti1Image, pred : nib.Nifti1Image, labels : List[int] = None, spacing : List[int] = None) -> dict:
    """Calculate the DSC and ASD for the a prediction ground truth pair.
    :param labels: List of valid label values, default: None -> all valid.
    :param spacing: Voxel spacing (x,y,z), default: None -> take from gt header
    :return: Dictionary with metric scores per label ready for conversion to a pd.DataFrame
    """
    if not labels:
        labels = np.unique(gt.get_fdata())
    if not spacing:
        spacing = np.round(gt.header.get_zooms(), 2)
        if gt.header.get_zooms() != pred.header.get_zooms():
            raise ValueError('Image spacings do not match!')
    res = dict(
        Label=[],
        DSC=[],
        ASD=[],
    )
    for label in labels:
        if label != 0:
            gt_mask = gt.get_fdata() == label
            pred_mask = pred.get_fdata() == label
            dice = compute_dice_coefficient(gt_mask, pred_mask)
            sdist = compute_surface_distances(gt_mask, pred_mask, spacing_mm=spacing)
            avsdist = compute_average_surface_distance(sdist)[1]
            res['Label'] += [label]
            res['DSC'] += [dice]
            res['ASD'] += [avsdist]
    return res


def eval_model(gt : Dict[str, nib.Nifti1Image], pred : Dict[str, nib.Nifti1Image], valid_labels : List[int] = None, vspacing : List[int] = None) -> pd.DataFrame:
    """Calculate evaluation metrics for each prediction ground truth pair.
    Checks the pairing before calculation.
    :param gt: Dictinary containg subject IDs as keys and the corresponding GT segmentations as values.
    :param pred: Dictinary containg IDs as keys and the corresponding predicted segmentations as values.
    :param valid_labels: List of valid label values, default: None -> all valid.
    :param vspacing: Voxel spacing (x,y,z), default: None -> take from gt header
    :return: Dataframe with all results. Identification via subject IDs and label columns.
    """
    if len(gt) != len(pred):
        raise ValueError('Number of GT and Pred do not match!')
    
    results = []
    for (gt_key, gt_img), (pred_key, pred_img) in zip(gt.items(), pred.items()):
        # check if the IDs match
        if gt_key != pred_key:
            raise ValueError(f'Img IDs do not match: {gt_key} <> {pred_key}!')
        metrics = calc_metrics(gt_img, pred_img, labels=valid_labels, spacing=vspacing)
        tmp_dict = dict(
            Sub_ID = [gt_key for i in range(len(metrics['Label']))],
            **metrics
        )
        results.append(pd.DataFrame.from_dict(tmp_dict))
    
    return pd.concat(results)


def eval_anno_var(ds_dir: Path | str) -> pd.DataFrame:
    """Calcluate metrics for all possible annotator combinations.
    The annotator segmentation images have to be in annotator specific subdirectories.
    :param ds_dir: Path to directoy containing annotator subdirectories.
    :return: Dataframe with all metrics for each possile anotator pair.
    """
    anno_segs = {anno_dir.name: load_files(anno_dir) for anno_dir in Path(ds_dir, 'annotator_labels').iterdir() if anno_dir.is_dir()}
    all_combs = combinations(anno_segs, 2)
    all_comb_dfs = {f'{a}-{b}': eval_model(anno_segs[a], anno_segs[b]) for a, b in all_combs}
    out_df = pd.concat(all_comb_dfs.values(), keys=all_comb_dfs.keys(), names=['Combination']).reset_index(level='Combination')
    return out_df


def eval_on_dataset(ds_dir : Path | str) -> pd.DataFrame:
    """Caculate metrics for al models on a data set.
    Dataset directory must contain a model subdirectory containing a ground_truth and a predictions subdirectory.
    :param ds_dir: Path to directoy containing model subdirectories.
    :return: Dataframe with all metrics for each model.
    """
    model_preds = {model_dir.name: load_files(model_dir) for model_dir in Path(ds_dir, 'model_predictions').iterdir() if model_dir.is_dir()}
    gt = load_files(Path(ds_dir, 'staple_gt'))
    model_dfs = {model: eval_model(gt, pred) for model, pred in model_preds.items()}
    out_df =  pd.concat(model_dfs.values(), keys=model_dfs.keys(), names=['Model']).reset_index(level='Model')
    return out_df


def calc_mean_scores(df, col1, col2, metrics=['DSC', 'ASD'], round=2, total_mean_col=None) -> pd.DataFrame | Tuple[pd.DataFrame]:
    """Calculate the mean for a subset defined by two identification columns.
    :param df: Dataframe with two identification columns and an arbitrary number of value columns
    :param col1: Name of first level identification column
    :param col2: Name of second level identification column
    :param metrics: Name of value / metric columns to calculate mean on
    :param round: Number of decimal places
    :param total_mean_col: Name of first level identification column to calculate mean over all second level subsets
    :return: Dataframe with the mean values per metric and subsets | Tuple of Dataframe with the mean values per metric and second level subsets and Dataframe with the mean values per metric and first level subsets
    """
    res_list = []
    for c1 in df[col1].unique():
        for c2 in df[col2].unique():
            tmp_df = df[
                (df[col1]==c1) &
                (df[col2]==c2)
            ][metrics]
            tmp_dict = dict(
                **{col1: c1, col2: c2},
                **{metric: tmp_df[metric].mean().round(round) for metric in metrics}
            )
            res_list.append(tmp_dict)
    res = pd.DataFrame.from_records(res_list)
    if total_mean_col:
        total_res_list = []
        for c3 in res[total_mean_col].unique():
            tmp_dict = dict(
                **{total_mean_col: c3},
                **{metric: res[res[total_mean_col]==c3][metric].mean().round(round) for metric in metrics}
            )
            total_res_list.append(tmp_dict)
        return res, pd.DataFrame.from_records(total_res_list)
    return res

## Annotator Variability
Execute the following cell to compare the variability between annotators by calculating the metric scores for each possible annotator pair on all labels.
Input the path to the directory containing the testset directories.


In [12]:
data_dir = Path(input('Input paths to the directory containing test set subdirectories:\n'))

# evaluate on test sets
ts = {path.name: eval_anno_var(path) for path in data_dir.iterdir() if path.is_dir()}
av_full = pd.concat(
    ts.values(),
    keys=ts.keys(), 
    names=['DataSet']
).reset_index(level='DataSet')

# calculate the mean scores
res, total_res = calc_mean_scores(av_full, 'DataSet', 'Combination', total_mean_col='DataSet')
print('Metric scores all annotator combinations')
display(res)
print('Mean metric scores over all annotators')
display(total_res)

Metric scores all annotator combinations


Unnamed: 0,DataSet,Combination,DSC,ASD
0,HE,annotator2-annotator1,0.83,0.36
1,HE,annotator2-annotator3,0.85,0.28
2,HE,annotator1-annotator3,0.92,0.13
3,FD,annotator2-annotator1,0.85,0.34
4,FD,annotator2-annotator3,0.85,0.25
5,FD,annotator1-annotator3,0.9,0.13


Mean metric scores over all annotators


Unnamed: 0,DataSet,DSC,ASD
0,HE,0.87,0.26
1,FD,0.87,0.24


Save the results as .csv file.

In [13]:
av_full.to_csv(Path(input('Input save path:\n'), 'annotator_variability_scores.csv'), index=False)

## Model Performance Evaluation
### Calculation
Execute the following cell to calculate DSC and ASD for the test set labels predicted by each model.
Input the path to the directory containing the test set subdirectories.


In [19]:
data_dir = Path(input('Input paths to the directory containing test set subdirectories:\n'))

# evaluate on test sets
ts = {path.name: eval_on_dataset(path) for path in data_dir.iterdir() if path.is_dir()}
eval_full = pd.concat(
    ts.values(), 
    keys=ts.keys(), 
    names=['DataSet']
).reset_index(level='DataSet')
print('Detail metric scores')
display(eval_full.head())

# calculate the mean scores
print('Mean metric scores of all models')
calc_mean_scores(eval_full, 'DataSet', 'Model')

Detail metric scores


Unnamed: 0,DataSet,Model,Sub_ID,Label,DSC,ASD
0,HE,DC-CE-LCD,Sub-Ctrl007,1.0,0.923077,0.087373
1,HE,DC-CE-LCD,Sub-Ctrl007,2.0,0.929843,0.090706
2,HE,DC-CE-LCD,Sub-Ctrl007,3.0,0.897041,0.121676
3,HE,DC-CE-LCD,Sub-Ctrl007,4.0,0.945571,0.06396
0,HE,DC-CE-LCD,Sub-Ctrl015,1.0,0.78676,0.674039


Mean metric scores of all models


Unnamed: 0,DataSet,Model,DSC,ASD
0,HE,DC-CE-LCD,0.88,4.04
1,HE,DC-TopK,0.88,0.19
2,HE,DC-CE,0.85,8.6
3,FD,DC-CE-LCD,0.88,0.16
4,FD,DC-TopK,0.89,0.16
5,FD,DC-CE,0.89,0.17


Save the results as .csv file.

In [20]:
eval_full.to_csv(Path(input('Input save path: ', 'model_performance_scores.csv')), index=False)

### Visualization

In [21]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from scipy import stats
from matplotlib.patches import PathPatch


# taken from the answer of Thomas Kühn (Jan 31, 2018 at 11:41) on the stackoverflow thread: 
# Set space between boxplots in Python Graphs generated nested box plots with Seaborn? 
def adjust_box_widths(g, fac):
    """
    Adjust the withs of a seaborn-generated boxplot.
    """

    # iterating through Axes instances
    for ax in g.axes:

        # iterating through axes artists:
        for c in ax.get_children():

            # searching for PathPatches
            if isinstance(c, PathPatch):
                # getting current width of box:
                p = c.get_path()
                verts = p.vertices
                verts_sub = verts[:-1]
                xmin = np.min(verts_sub[:, 0])
                xmax = np.max(verts_sub[:, 0])
                xmid = 0.5*(xmin+xmax)
                xhalf = 0.5*(xmax - xmin)

                # setting new width of box
                xmin_new = xmid-fac*xhalf
                xmax_new = xmid+fac*xhalf
                verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new
                verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new

                # setting new width of median line
                for l in ax.lines:
                    if np.all(l.get_xdata() == [xmin, xmax]):
                        l.set_xdata([xmin_new, xmax_new])


def scatterbox(data, x, y, hue: str = None, order : List[str] = None, palette : list = None, broken : bool = None, xlabel : str = None, ylabel : str = None, legend : bool = True, ltitle : str = None, lloc : str = 'best', labels : List[str] = None, ylim : Tuple[float, float] = None, font_scale : int = 1, despine : bool = None, figsize : Tuple[int, int] = None):
    """Combined box and strip plot with or without broken y axis.

    :param data: pd.DataFrame containing the data
    :param x: Column name to plot as x
    :param y: Column name to plot as y
    :param hue: Column name hue encoding, default: None -> no hue encoding
    :param order: X order, default: None -> derive from dataframe
    :param palette: Color palette, default: None -> seaborn default palette
    :parm broken: Toggle broken axis, default: False
    :param xlabel: X axis label, default: None -> column name as label
    :param ylabel: Y axis label, default: None -> column name as label
    :param legend: Toggle legend, default: True
    :param ltitle: Legend title, default: None
    :param lloc: Legend location, default: 'best'
    :param labels: Legend labels, default: None -> derive from dataframe
    :param ylim: Y axis limits, default: None -> seaborn auto axis limits
    :param font_scale: Font scaling factor, default: 1
    :param despine: Toggle despine, default: False
    :param figsize: Figure size (width, height)
    :return: Figure object, axe object
    """
    count = len(data[hue].unique())
    # scale all plot fonts
    sns.set(font_scale=font_scale)
    sns.set_style('ticks')
    
    # we need two axe objects for a broken y axis
    if broken:
        fig,(ax,ax2) = plt.subplots(2, 1, figsize=figsize, sharex=True, height_ratios=broken['height_ratio'] if 'height_ratio' in broken else (abs(broken['bot_lim'][0])/abs(broken['bot_lim'][1]), abs(broken['top_lim'][0])/abs(broken['top_lim'][1])))
        ax.spines['bottom'].set_visible(False)
        ax2.spines['top'].set_visible(False)
    else:
        fig, ax = plt.subplots(1,1, figsize=figsize)

    # generate a box plot from the data
    plot = sns.boxplot(
        x=x, y=y, hue=hue, data=data, order=order,
        palette=palette, 
        ax=ax,
    )
    # overlay with a strip plot
    plot = sns.stripplot(
        x=x, y=y, hue=hue, data=data, order=order,
        jitter=True, dodge=True, marker='o', edgecolor='black', linewidth=1, alpha=0.7, palette=['white', 'white'],
        ax=ax,
    )
    
    # reformat the axes of the two axe objects (axo)
    if broken:
        # remove the x axis and set y range - upper axo
        plot.set(
            xlabel=None,
            xticklabels=[],
            ylabel=None,
            ylim=broken['top_lim'],
        )
        # specify custom y ticks - upper axo
        if 'top_ticks' in broken:
            plot.set(
                yticks=broken['top_ticks']
            )
        plot.tick_params(bottom=False)
        # fill lower axo with the same data
        plot2 = sns.boxplot(
            x=x, y=y, hue=hue, data=data, order=order,
            palette=palette, 
            ax=ax2,
        )
        plot2 = sns.stripplot(
            x=x, y=y, hue=hue, data=data, order=order,
            jitter=True, dodge=True, marker='o', edgecolor='black', linewidth=1, alpha=0.7, palette=['white', 'white'],
            ax=ax2,
        )
        # set y range - lower axo
        plot2.set(
            ylabel=None,
            ylim=broken['bot_lim'],
        )
        # specify custom y ticks - lower axo
        if 'bot_ticks' in broken:
            plot2.set(
                yticks=broken['bot_ticks']
            )
        # set custom x label
        if xlabel:
            if xlabel=='off':
                plot2.set(xlabel=None)
            else:
                plot2.set(xlabel=xlabel)
        # set custom y label
        if ylabel:
            if ylabel != 'off':
                plt.annotate(ylabel, (0.025, 0.5), xytext=(0.015, 0.5), rotation=90, xycoords='figure fraction', va='center')
        else:
            plt.annotate(y, (0.025, 0.5), xytext=(0.045, 0.5), rotation=90, xycoords='figure fraction', va='center')

        d = .5  # proportion of vertical to horizontal extent of the slanted line

        # add boken axe markers
        kwargs = dict(
            marker=[(-1, -d), (1, d)], markersize=12,
            linestyle="none", color='k', mec='k', mew=1, clip_on=False
        )
        if despine:
            ax.plot([0, ], [0, ], transform=ax.transAxes, **kwargs)
            ax2.plot([0, ], [1, ], transform=ax2.transAxes, **kwargs)
        else:
            ax.plot([0, 1], [0, 0], transform=ax.transAxes, **kwargs)
            ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)
    else:
        plot.set(
            ylim=ylim
        )
        # set custom x label
        if xlabel:
            if xlabel == 'off':
                plot.set(xlabel=None)
            else:
                plot.set(xlabel=xlabel)
        # set custom y label
        if ylabel:
            if ylabel == 'off':
                plot.set(ylabel=None)
            else:
                plot.set(ylabel=ylabel)
    # dspine toggle
    if despine:
        if broken:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax2.spines['right'].set_visible(False)
        else:
            sns.despine(right=True)
    
    # add plot legend
    ax.get_legend().remove()
    if broken:
        ax2.get_legend().remove()
    if legend:
        handles, _labels = ax.get_legend_handles_labels()
        if not labels:
            labels=_labels
        if lloc=='best':
            ax.legend(handles, labels, loc=lloc, title=ltitle, frameon=True)
            fig.tight_layout()
        else:
            fig.legend(handles, labels, loc=lloc, title=ltitle, frameon=False)
            fig.tight_layout()
            if lloc==7:
                fig.subplots_adjust(right=0.75)
    else:
        fig.tight_layout()
    if broken:
        fig.subplots_adjust(hspace=0.05)
    return fig, plot


Execute the following cell to display the plots in a matplotlib pop-up window.<br>
The window allows for manual modifications of border width etc. and saving.

In [None]:
%matplotlib qt

You have to either execute the calculation cell above or load the results from a .csv file.

In [None]:
eval_full = pd.read_csv(input('Path to model evaluation result file:\n'))
eval_full.head()

#### Dice Similarity Coefficient (DSC)
For the paper figures the following attributes were adjusted in the pop-up window:
- top=0.95
- bottom=0.15
- left=0.17
- right=0.96

In [None]:
df = eval_full

fig, plot = scb_dsc = scatterbox(
    data=df, x="Model", y="DSC", hue='DataSet', order=['DC-CE', 'DC-CE-LCD', 'DC-TopK'],
    palette='Greys', 
    xlabel='Model', ylabel='DSC', ltitle= 'Test Set', labels=['HE', 'FD'],
    broken=dict(bot_lim=(-0.05,0.15), top_lim=(0.45,1), bot_ticks=[0,0.1], height_ratio=(10,3)),
    font_scale=1.4,
    despine=True,
    lloc='best',
    figsize=(5,5)
)
adjust_box_widths(fig, 0.9)

#### Average Surface Distance (ASD)
For the paper figures the following attributes were adjusted in the pop-up window:
- top=0.95
- bottom=0.15
- left=0.17
- right=0.96

In [None]:
df = eval_full

fig, plot = scatterbox(
    data=df, x="Model", y="ASD", hue='DataSet', order=['DC-CE', 'DC-CE-LCD', 'DC-TopK'],
    palette='Greys',
    xlabel='Model', ylabel='ASD', ltitle='Test Set', labels=['HE', 'FD'],
    font_scale=1.4,
    broken=dict(top_lim=(75,225), bot_lim=(0,1.1), top_ticks=[100,225], height_ratio=(1,10)),
    despine=True,
    lloc='best',
    figsize=(5,5)
)