# DRG Feature Analysis
This notebook is part of the paper: Automated Segmentation of the Dorsal Root Ganglia (DRG) in MRI by Nauroth-Kreß et al., 2023
The following cells contain the code used to calculate DRG features from the predicted and ground truth labels. The calculation cells are followed by cells containing the matching visualization code.

The code is ready to use with any dataset matching the general structure.

In [17]:
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
from scipy import stats
from typing import List, Dict, Tuple


def load_files(data_dir : Path | str) -> Dict[str, nib.Nifti1Image]:
    """Load all nifti files in directory

    Every file in ``data_dir`` whose name matches ``*.nii*`` is loaded with ``nibabel.load``.
    The image ID is the file name with up to two trailing suffixes removed (e.g. ``.nii.gz``),
    cut at the first underscore. Files are visited in sorted path order; if several files
    yield the same ID, the one visited last replaces the earlier ones.

    :param data_dir: Path to directory containing a set of nifti images
    :return: Dictionary with image IDs derived from the file name as keys and nibabel.Nifti1Images as values
    """
    imgs = {}
    for path in sorted(Path(data_dir).glob('*.nii*')):
        # stem is applied twice so that double extensions such as ".nii.gz" are removed completely
        img_id = Path(path.stem).stem.split('_')[0]
        imgs[img_id] = nib.load(path)
    return imgs


def extract_vol(limg : nib.Nifti1Image) -> Dict[int, int]:
    """Compute the volume of every non-zero label of a label image.

    The volume of a label is the number of voxels carrying that label multiplied by the
    product of the header zooms (each zoom rounded to two decimals).

    :param limg: Label image
    :return: Dictionary with label values as keys and label volumes as values
    """
    # size of a single voxel, taken from the image header
    voxel_sizes = np.round(limg.header.get_zooms(), 2)
    vol_factor = voxel_sizes.prod()

    label_arr = limg.get_fdata()
    volumes = {}
    for label in np.unique(label_arr):
        if label == 0:
            # the label value 0 is excluded from the result
            continue
        n_voxels = np.count_nonzero(label_arr == label)
        volumes[label] = n_voxels * vol_factor
    return volumes


def extract_int(img : nib.Nifti1Image, limg : nib.Nifti1Image) -> Tuple[Dict[int, np.ndarray], Dict[int, float]]:
    """Collect the voxel intensities and the mean intensity of every non-zero label.

    :param img: Intensity image
    :param limg: Label image
    :return: Two dictionaries keyed by label value: the first holds numpy.ndarrays with the
        intensities of all voxels of the label, the second the mean intensity rounded to three decimals
    """
    intensities = img.get_fdata()
    label_arr = limg.get_fdata()

    voxel_ints = {}
    mean_int = {}
    for label in np.unique(label_arr):
        if label == 0:
            # the label value 0 is excluded from the result
            continue
        # boolean mask of the label image selects the matching intensity voxels
        values = intensities[label_arr == label]
        voxel_ints[label] = values
        mean_int[label] = np.round(np.mean(values), 3)
    return voxel_ints, mean_int
    

def extract_features(imgs : Dict[str, nib.Nifti1Image], limgs : Dict[str, nib.Nifti1Image], return_vints : bool =False, label_dict : Dict[int, str] = {1: 'S1l', 2: 'S1r', 3: 'L5l', 4: 'L5r'}) -> pd.DataFrame:
    """Extract volume and intensity features from intensity-label-image-pairs.

    Intensity and label images are paired by their IDs. One row is produced per image and
    label found in the label image. If all image IDs can be converted to int, the Sub_ID
    column is converted to int, otherwise the IDs are kept unchanged.

    :param imgs: Dictionary of image IDs and intensity images
    :param limgs: Dictionary of image IDs and label images
    :param return_vints: Optional, if True return the voxel intensities of each label additional to the mean intensities, default: False
    :param label_dict: Optional, mapping of int label values to label strings (only read, never modified), default: {1: 'S1l', 2: 'S1r', 3: 'L5l', 4: 'L5r'}
    :return: Extracted features per image and label in tabular form
    :raises ValueError: If the IDs of the intensity and label images differ, or if the labels of the volume and intensity entries differ
    :raises KeyError: If a label image contains a label value that is missing in label_dict
    """
    img_ids = sorted(imgs)
    limg_ids = sorted(limgs)
    # Compare the complete ID sets up front. A pairwise zip would silently drop images
    # whenever one of the dictionaries holds more entries than the other.
    if img_ids != limg_ids:
        only_imgs = sorted(set(img_ids) - set(limg_ids))
        only_limgs = sorted(set(limg_ids) - set(img_ids))
        raise ValueError(f'IDs of intensity and label images do not match!\n{only_imgs} <> {only_limgs}')

    records = []
    for img_id in img_ids:
        limg = limgs[img_id]
        volumes = extract_vol(limg)
        voxel_ints, mean_ints = extract_int(imgs[img_id], limg)
        # Check if the labels of the volume, mean intensity, and voxel intensity entries match
        if not (volumes.keys() == mean_ints.keys() == voxel_ints.keys()):
            raise ValueError('Labels of volume and intensity entry do not match!')
        for label, vol in volumes.items():
            record = dict(
                Sub_ID=img_id,
                Label=label_dict[label],
                Volume=vol,
            )
            if return_vints:
                record['VoxelInt'] = voxel_ints[label]
            record['MeanInt'] = mean_ints[label]
            records.append(record)

    df = pd.DataFrame.from_records(records)
    # Without any record there is no Sub_ID column to convert
    if 'Sub_ID' in df.columns:
        try:
            df['Sub_ID'] = df['Sub_ID'].astype(int)
        except (ValueError, TypeError, OverflowError):
            # non-numeric IDs are kept as they are
            pass

    return df


def add_meta_info(df : pd.DataFrame, meta_df : pd.DataFrame, meta_keys: List) -> pd.DataFrame:
    """Add meta information columns to a dataframe.
    The meta information is added based on subject identifiers. Both dataframes must have a Sub_ID column.
    Neither the target dataframe nor the meta information dataframe needs to contain all identifiers. If the
    target dataframe contains IDs not present in the meta information dataframe the cells will be filled with
    np.nan values.

    :param df: Target dataframe, must have a Sub_ID column for mapping
    :param meta_df: Dataframe containing meta information, must have a Sub_ID column for mapping
    :param meta_keys: Names of meta information columns that should be added to the target dataframe
    :return: Copy of the target dataframe with added meta information columns
    :raises ValueError: If an ID of the target dataframe occurs more than once in the meta information dataframe
    :raises KeyError: If a name in meta_keys is not a column of the meta information dataframe
    """
    out_df = df.copy()
    if df.empty:
        # nothing to look up, only add the (empty) columns
        for key in meta_keys:
            out_df[key] = np.nan
        return out_df

    # Only the meta rows that belong to IDs of the target dataframe are relevant
    needed = meta_df[meta_df['Sub_ID'].isin(df['Sub_ID'])]
    duplicated_ids = needed.loc[needed['Sub_ID'].duplicated(), 'Sub_ID'].unique()
    if len(duplicated_ids) > 0:
        raise ValueError(f'Meta information is ambiguous, Sub_IDs listed more than once: {list(duplicated_ids)}')

    lookup = needed.set_index('Sub_ID', drop=False)
    for key in meta_keys:
        # map() leaves np.nan for IDs without meta information; the values are assigned by
        # position so that duplicated index labels in the target dataframe cannot mix up rows
        out_df[key] = df['Sub_ID'].map(lookup[key]).to_numpy()
    return out_df


## DRG Feature calculation
Execute the following cell to load the two test sets and calculate the label volumes and mean intensities for each image.<br>
Input the path to the directory containing the testset directories.

Each test set directory (HE and FD) must contain the subdirectories images (intensity images, nifti files), model_predictions/DC-TopK (label images, nifti files) and staple_gt (label images, nifti files), as well as a metainformation.csv file with a Sub_ID and a Sex column.<br>

In [18]:
# Directory that holds the 'HE' and 'FD' test set subdirectories
data_dir = input('Path to directory containing test set subdirectories:\n')

# extract the features from the data set of healthy volunteers (HE)
he_imgs = load_files(Path(data_dir, 'HE/images/'))
he_pred = load_files(Path(data_dir, 'HE/model_predictions/DC-TopK'))
he_gt = load_files(Path(data_dir, 'HE/staple_gt/'))
he_fpred = extract_features(he_imgs, he_pred)
he_fgt = extract_features(he_imgs, he_gt)

# add meta information
he_meta = pd.read_csv(Path(data_dir, 'HE/metainformation.csv'))
he_fpred = add_meta_info(he_fpred, he_meta, ['Sex'])
he_fgt = add_meta_info(he_fgt, he_meta, ['Sex'])

# extract the features from the data set of FD patients (FD)
fd_imgs = load_files(Path(data_dir, 'FD/images/'))
fd_pred = load_files(Path(data_dir, 'FD/model_predictions/DC-TopK'))
fd_gt = load_files(Path(data_dir, 'FD/staple_gt/'))
fd_fpred = extract_features(fd_imgs, fd_pred)
fd_fgt = extract_features(fd_imgs, fd_gt)

# add meta information
fd_meta = pd.read_csv(Path(data_dir, 'FD/metainformation.csv'))
fd_fpred = add_meta_info(fd_fpred, fd_meta, ['Sex'])
# The result has to be stored in fd_fgt, otherwise the FD ground truth rows of the
# combined data frame below would not contain the Sex column values.
fd_fgt = add_meta_info(fd_fgt, fd_meta, ['Sex'])

# stack into combined data frames, the data set name ends up in the 'DataSet' column
fpred = pd.concat([he_fpred, fd_fpred], keys=['HE', 'FD']).reset_index(level=0, names='DataSet')
fgt = pd.concat([he_fgt, fd_fgt], keys=['HE', 'FD']).reset_index(level=0, names='DataSet')

print('DRG features predicted labels')
display(fpred.head())
print('\nDRG features gt labels')
display(fgt.head())

Save the results as .csv files.

In [16]:
# Ask for the target directory and write one .csv file per feature table
save_dir = input('Path to save directory:\n')

for _features, _file_name in (
    (fpred, 'DRG_features_pred.csv'),
    (fgt, 'DRG_features_gt.csv'),
):
    _features.to_csv(Path(save_dir, _file_name))

## Visualization

In [21]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from typing import List, Dict, Tuple


def multigrid_feature_correlation_pred_ref(
        df_pred : pd.DataFrame,
        df_gt : pd.DataFrame,
        feature : str,
        xlabel : str | None = None,
        ylabel : str | None = None,
        legend : bool=False,
        axe_lim_dist : float = 10,
        outlier_threshold : float = 2,
        anno_pos : Tuple[float, float]=(500,390),
        font_scale : float=1
    ) -> sns.FacetGrid:
    """
    Plots one scatter plot per data set showing the correlation between a feature calculated
    with ground truth labels (x-axis) and with predicted labels (y-axis). Each plot gets a
    regression line, an annotation with the regression equation and the Pearson correlation
    coefficient, and a dashed identity line. Outliers are excluded from the regression based
    on a given threshold value and drawn as white markers.

    :param df_pred: Dataframe containing label features of one or multiple dataset (DataSet and Sub_ID columns required), calculated with predicted labels
    :param df_gt: Dataframe containing label features of one or multiple dataset (Sub_ID column required), calculated with ground truth labels
    :param feature: Name of a feature column
    :param xlabel: Optional, change xlabel to value, if None derive from column names, default: None
    :param ylabel: Optional, change ylabel to value, if None derive from column names, default: None
    :param legend: Optional, if True show legend, default: False
    :param axe_lim_dist: Optional, distance between biggest/smallest value and axe limits, default: 10
    :param outlier_threshold: Optional, this value will be multiplied with the standard deviation of the pred/gt ratio to calculate a threshold value for exclusion of outliers, no thresholding if 0, default: 2
    :param anno_pos: Optional, position of the annotation field in axes fraction coordinates, default: (500,390).
        NOTE(review): the default lies far outside the 0-1 range of axes fractions, so the annotation is
        presumably not visible unless a position is passed - confirm and adjust the default.
    :param font_scale: Optional, font scaling factor, default: 1
    :return: seaborn.FacetGrid object
    """
    sns.set(font_scale=font_scale)
    sns.set_style('ticks')
    if not xlabel:
        xlabel = f'{feature}_ref'
    if not ylabel:
        ylabel = f'{feature}_pred'

    # Combine predicted and ground truth values into one df.
    # NOTE(review): both frames are joined on their Sub_ID index, which usually holds every
    # ID several times (one row per label). This presumably requires both frames to list
    # their rows in the same order - verify against the callers.
    data = pd.concat(
        [
            df_pred.set_index('Sub_ID')[['DataSet', feature]].rename(columns={feature: ylabel}),
            df_gt.set_index('Sub_ID')[feature].rename(xlabel, axis=0)
        ],
        axis=1
    ).reset_index()

    # Make grid, both axes share the same limits derived from the two value columns
    values = data[[xlabel, ylabel]]
    axe_lim = (values.min().min() - axe_lim_dist, values.max().max() + axe_lim_dist)
    grid = sns.FacetGrid(
        data=data, col='DataSet',
        sharex=True,
        sharey=True,
        xlim=axe_lim,
        ylim=axe_lim,
        height=5,
    )

    # Add scatterplots
    grid.map(sns.scatterplot, xlabel, ylabel, color='black', label='Data Points')

    # Split into regression data and outliers based on the pred/gt ratio
    if outlier_threshold:
        corr = data[ylabel] / data[xlabel]
        center = np.mean(corr)
        spread = outlier_threshold * np.std(corr)
        is_inlier = np.logical_and(center - spread < corr, corr < center + spread)
        reg_data = data[is_inlier]
        outliers = data[~is_inlier]
    else:
        reg_data = data
        outliers = None

    # Decorate every facet (one per data set) in the same way
    for dataset, ax in grid.axes_dict.items():
        if outliers is not None:
            ds_outliers = outliers[outliers.DataSet == dataset]
            sns.scatterplot(
                x=ds_outliers[xlabel],
                y=ds_outliers[ylabel],
                marker='o', edgecolor='black', linewidth=1, color='white',
                ax=ax,
                label='Outliers',
            )

        # calc and add regression line
        ds_reg_data = reg_data[reg_data.DataSet == dataset]
        reg = stats.linregress(ds_reg_data[xlabel], ds_reg_data[ylabel])
        ax.axline((0, reg.intercept), slope=reg.slope, color='black', alpha=0.5, label='Regression Line')
        ax.annotate(
            # '+-' appears for negative intercepts and is shortened to '-'
            f'y={np.round(reg.slope,2)}*x+{np.round(reg.intercept,2)}\nPCC: {np.round(reg.rvalue,2)}'.replace('+-','-'),
            xy=anno_pos,
            xycoords='axes fraction',
            bbox=dict(facecolor='none', edgecolor='black', boxstyle='round'),
        )

        # add identity line (slope 1 through the origin)
        ax.axline((0, 0), slope=1, color='red', linestyle='--', alpha=0.5, label='PCC +1')

        # drop the per-facet legend, a shared legend is added below on request
        facet_legend = ax.get_legend()
        if facet_legend is not None:
            facet_legend.remove()

    if legend and grid.axes_dict:
        # all facets carry the same artists, so the handles of the first one are sufficient
        first_ax = next(iter(grid.axes_dict.values()))
        handles, labels = first_ax.get_legend_handles_labels()
        grid.add_legend(dict(zip(labels, handles)))
    return grid

Execute the following cell to display the plots in a matplotlib pop-up window.<br>
The window allows for manual modifications of border width etc. and saving.

In [11]:
# IPython magic: select the Qt backend of matplotlib for the following plots
%matplotlib qt

You have to either execute the calculation cell above or load the results from a .csv file.

In [None]:
# Load previously calculated feature tables instead of re-running the calculation cell.
# NOTE(review): these file names differ from the ones written by the save cell
# ('DRG_features_pred.csv' / 'DRG_features_gt.csv') - adjust the paths to your files.
_pred_csv = '../data/feature_extraction/TestSet/DRG_features.csv'
_gt_csv = '../data/feature_extraction/TestSet/reference_DRG_features.csv'
fpred, fgt = pd.read_csv(_pred_csv), pd.read_csv(_gt_csv)

**Volume Correlation**
For the paper figures the following attributes were adjusted in the pop-up window:
- top=0.92
- bottom=0.15
- left=0.125
- right=0.98

In [None]:
# Correlation of the label volumes calculated with ground truth and predicted labels
multigrid_feature_correlation_pred_ref(
    df_pred=fpred,
    df_gt=fgt,
    feature='Volume',
    xlabel='Label Size GT [vx.]',
    ylabel='Label Size Pred. [vx.]',
    legend=False,
    axe_lim_dist=50,
    outlier_threshold=2,
    anno_pos=(0.42,0.05),
    font_scale=1.6,
)

**DRG Mean Signal Intensity Correlation**
For the paper figures the following attributes were adjusted in the pop-up window:
- top=0.92
- bottom=0.15
- left=0.1
- right=0.98


In [None]:
# Correlation of the mean intensities calculated with ground truth and predicted labels
multigrid_feature_correlation_pred_ref(
    df_pred=fpred,
    df_gt=fgt,
    feature='MeanInt',
    xlabel='Signal Intensity GT [a.u.]',
    ylabel='Signal Intensity Pred. [a.u.]',
    legend=False,
    outlier_threshold=2,
    anno_pos=(0.55, 0.05),
    font_scale=1.4,
)

**DRG Mean Signal Intensity Comparison**
For the paper figures the following attributes were adjusted in the pop-up window:
- top=0.92
- bottom=0.08
- left=0.1
- right=0.98


In [None]:
# generate new dataframe containing only males but both label types (gt, predicted)
int_diff_male = pd.concat([fgt[fgt.Sex=='M'], fpred[fpred.Sex=='M']], keys=['GT', 'Prediction'], names=['Type']).reset_index(level='Type')

sns.set(font_scale=1.4)
sns.set_style('ticks')
grid = sns.catplot(int_diff_male, x='DataSet', y='MeanInt', col='Type', kind='box', palette='Greys')
grid.map(sns.stripplot, 'DataSet', 'MeanInt', marker='o', color='white', edgecolor='black', linewidth=1, alpha=0.7)
grid.set_ylabels('DRG Mean Signal Intensity')
grid.set_xlabels('Data Set')
grid.set_titles('{col_name}')

# Test for significant difference
t_gt, p_gt = stats.ttest_ind(
    int_diff_male[(int_diff_male['Type'] == 'GT') & (int_diff_male['DataSet'] == 'HE')]['MeanInt'].to_list(),
    int_diff_male[(int_diff_male['Type'] == 'GT') & (int_diff_male['DataSet'] == 'FD')]['MeanInt'].to_list(),
    equal_var=False,
)
t_pred, p_pred = stats.ttest_ind(
    int_diff_male[(int_diff_male['Type'] == 'Prediction') & (int_diff_male['DataSet'] == 'HE')]['MeanInt'].to_list(),
    int_diff_male[(int_diff_male['Type'] == 'Prediction') & (int_diff_male['DataSet'] == 'FD')]['MeanInt'].to_list(),
    equal_var=False,
)
print(f'GT:         t-statistic: {t_gt}; p-value: {p_gt}\nPrediction: t-statistic: {t_pred}; p-value: {p_pred}')