# Auto-segmentation Inference & Analysis

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AustralianCancerDataNetwork/pydicer/blob/main/examples/AutoSegmentation.ipynb)

A common task when working with medical imaging data is to run an auto-segmentation model
(inference) on the images in your dataset. If manual definitions of the same structures are
available in your dataset, you will typically want to compare the auto and manual
segmentations by computing metrics and producing plots and visualisations.

This example notebook will guide you through the process of performing model inference and analysis
of those structures. We will use a single atlas-based segmentation model for demonstration
purposes.

> Warning: The auto-segmentation results produced by the example in this notebook are poor. The
> atlas-based segmentation function is optimised for runtime and is provided purely to demonstrate
> how to run and analyse an auto-segmentation model.

In [None]:
try:
    from pydicer import PyDicer
except ImportError:
    !pip install pydicer
    from pydicer import PyDicer

import os
import logging

from pathlib import Path

import SimpleITK as sitk

from platipy.imaging.registration.utils import apply_transform
from platipy.imaging.registration.linear import linear_registration
from platipy.imaging.registration.deformable import fast_symmetric_forces_demons_registration

from pydicer.utils import fetch_converted_test_data

from pydicer.generate.segmentation import segment_image, read_all_segmentation_logs

from pydicer.analyse.compare import (
    compute_contour_similarity_metrics,
    get_all_similarity_metrics_for_dataset,
    prepare_similarity_metric_analysis
)


## Setup PyDicer

For this example, we will use the LCTSC test data which has already been converted using PyDicer.
We also initialise our PyDicer object.

In [None]:
working_directory = fetch_converted_test_data("./lctsc_autoseg", dataset="LCTSC")
pydicer = PyDicer(working_directory)

## Prepare Atlas

Since we will use a single atlas-based segmentation model, we must
split our data, selecting one case as our `atlas` and the remaining cases as the `validation` set.
We use the PyDicer `dataset preparation` module to split our dataset.

In [None]:
df = pydicer.read_converted_data()

# Specify the patient ID to use as the atlas case
atlas_dataset = "atlas"
atlas_case = "LCTSC-Train-S1-001"
df_atlas = df[df.patient_id==atlas_case]

# And the remaining cases will make up our validation set
validation_dataset = "validation"
df_validation = df[df.patient_id!=atlas_case]

# Use the dataset preparation module to prepare these two data subsets
pydicer.dataset.prepare_from_dataframe(atlas_dataset, df_atlas)
pydicer.dataset.prepare_from_dataframe(validation_dataset, df_validation)

## Define Segmentation Function

Now that our `atlas` and `validation` sets are ready, we define the function which runs our
simple single atlas-based auto-segmentation model. This example uses the [atlas-based
segmentation tools available in PlatiPy](https://pyplati.github.io/platipy/_examples/atlas_segmentation.html#).

In a real-world scenario, you will have your own segmentation model you wish to apply. You should
wrap this model in a function like the one below, which accepts an image and returns a dictionary
of structures.

To get started, we recommend you try running the [TotalSegmentator](https://github.com/wasserth/TotalSegmentator)
model on your CT data. PyDicer already provides a function which runs this model: see
[run_total_segmentator](https://australiancancerdatanetwork.github.io/pydicer/generate.html#pydicer.generate.models.run_total_segmentator).

In [None]:
def single_atlas_segmentation(img):
    """Segment an image using a single atlas case

    Args:
        img (SimpleITK.Image): The SimpleITK image to segment.

    Returns:
        dict: The segmented structure dictionary
    """

    # Load the atlas case image
    atlas_img_row = df_atlas[df_atlas.modality=="CT"].iloc[0]
    atlas_img = sitk.ReadImage(str(Path(atlas_img_row.path).joinpath("CT.nii.gz")))

    # Load the atlas case structures
    atlas_structures = {}
    atlas_struct_row = df_atlas[df_atlas.modality=="RTSTRUCT"].iloc[0]
    for struct_path in Path(atlas_struct_row.path).glob("*.nii.gz"):
        struct_name = struct_path.name.replace(".nii.gz", "")
        atlas_structures[struct_name] = sitk.ReadImage(str(struct_path))

    # Use a simple linear (similarity) registration to align the atlas image to the input image
    img_ct_atlas_reg_linear, tfm_linear = linear_registration(
        fixed_image = img,
        moving_image = atlas_img,
        reg_method='similarity',
        metric='mean_squares',
        optimiser='gradient_descent',
        shrink_factors=[4, 2],
        smooth_sigmas=[2, 0],
        sampling_rate=1.0,
        number_of_iterations=50,
    )

    # Perform a fast deformable registration
    img_ct_atlas_reg_dir, tfm_dir, dvf = fast_symmetric_forces_demons_registration(
        img,
        img_ct_atlas_reg_linear,
        ncores=4,
        isotropic_resample=True,
        resolution_staging=[4],
        iteration_staging=[20],
    )

    # Combine the two transforms
    tfm_combined = sitk.CompositeTransform((tfm_linear, tfm_dir))

    # Apply the transform to the atlas structures
    auto_segmentations = {}
    for s in atlas_structures:
        auto_segmentations[s] = apply_transform(
            atlas_structures[s],
            reference_image=img,
            transform=tfm_combined
        )

    return auto_segmentations
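
The label-propagation idea behind the function above can be shown on a toy 1D signal. This is a
minimal sketch and not PlatiPy's algorithm: it assumes a pure integer translation found by
exhaustive search over a mean-squares metric, and the helper names (`mean_squares`, `shift`,
`register_translation`) are hypothetical.

```python
def mean_squares(a, b):
    """Mean squared intensity difference between two equal-length signals."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def shift(signal, offset, fill=0):
    """Shift a 1D signal right by `offset` samples (left if negative), padding with `fill`."""
    n = len(signal)
    out = [fill] * n
    for i, value in enumerate(signal):
        j = i + offset
        if 0 <= j < n:
            out[j] = value
    return out


def register_translation(fixed, moving, max_offset=5):
    """Exhaustively search for the integer offset that best aligns `moving` to `fixed`."""
    candidates = range(-max_offset, max_offset + 1)
    return min(candidates, key=lambda o: mean_squares(fixed, shift(moving, o)))


# Toy "atlas" image with its manual label, and a "target" image to segment
atlas_image = [0, 0, 5, 9, 5, 0, 0, 0, 0, 0]
atlas_label = [0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
target_image = [0, 0, 0, 0, 0, 5, 9, 5, 0, 0]

# Register the atlas image to the target, then apply the same transform to the atlas label
offset = register_translation(target_image, atlas_image)
auto_label = shift(atlas_label, offset)
```

The offset that aligns the atlas image to the target is reused unchanged on the atlas label, which
is the same role `tfm_combined` plays when it is passed to `apply_transform`.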

## Run Auto-segmentation

The `segment_dataset` function will run over all images in our dataset and will pass the images to
a function we define for segmentation. We pass in the name of our `validation_dataset` so that only
the images in this dataset will be segmented.

In [None]:
segment_id = "atlas" # Used to generate the ID of the resulting auto-segmented structure sets

pydicer.segment_dataset(segment_id, single_atlas_segmentation, dataset_name=validation_dataset)

We can use PyDicer's `visualisation` module to generate snapshots of the resulting
auto-segmentations.

In [None]:
pydicer.visualise.visualise(force=False)

## Read Segmentation Logs

After running the auto-segmentation across the dataset, we can fetch the logs to confirm that
everything went well. This also lets us inspect the runtime of the segmentation. If something
went wrong, we can use these logs to help debug the issue.

In [None]:
# Read the segmentation log DataFrame
df_logs = read_all_segmentation_logs(working_directory)
df_logs

In [None]:
# Use some Pandas magic to produce some stats on the segmentation runtime
df_success = df_logs[df_logs.success_flag]
agg_stats = ["mean", "std", "max", "min", "count"]
df_success[["segment_id", "total_time_seconds"]].groupby("segment_id").agg(agg_stats)
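
To make the log fields used above more concrete, the sketch below shows one way a segmentation
call could be timed and flagged. It is an illustration only and not PyDicer's implementation:
`run_with_log`, the `error` field and the two toy models are hypothetical, while `segment_id`,
`success_flag` and `total_time_seconds` mirror the column names used in the cell above.

```python
import time


def run_with_log(segment_function, img, segment_id):
    """Hypothetical helper: run a segmentation function and build a log record for it."""
    start = time.perf_counter()
    try:
        result = segment_function(img)
        success, error = True, ""
    except Exception as exc:  # record the failure instead of stopping the whole run
        result, success, error = {}, False, str(exc)

    record = {
        "segment_id": segment_id,
        "success_flag": success,
        "total_time_seconds": time.perf_counter() - start,
        "error": error,  # hypothetical field, useful when debugging
    }
    return result, record


def good_model(img):
    """Toy model returning a dictionary of structures."""
    return {"Heart": img}


def broken_model(img):
    """Toy model which always fails."""
    raise ValueError("image has no voxels")


_, ok_log = run_with_log(good_model, "image-placeholder", "atlas")
_, bad_log = run_with_log(broken_model, "image-placeholder", "atlas")
```

Filtering on `success_flag` before aggregating, as done above, keeps failed runs from distorting
the runtime statistics.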

## Auto-segmentation Analysis

Now that our auto-segmentation model has been run on our `validation` set, we can compare these
structures to the manual structures available on this dataset. PyDicer provides functionality to
compute similarity metrics, but we must first prepare a DataFrame containing our auto structure
sets (`df_target`) and a separate DataFrame with our manual structure sets (`df_reference`).

In [None]:
df = pydicer.read_converted_data(dataset_name=validation_dataset)
df_structs = df[df.modality=="RTSTRUCT"]

df_reference = df_structs[~df_structs.hashed_uid.str.startswith(f"{segment_id}_")]
df_target = df_structs[df_structs.hashed_uid.str.startswith(f"{segment_id}_")]

In [None]:
df_reference

In [None]:
df_target

### Compute Similarity 

We use the `compute_contour_similarity_metrics` function to compute the metrics comparing our
target structures to our reference structures.

We can specify which metrics we want to compute. In this example we compute the Dice Similarity
Coefficient (DSC), Hausdorff Distance, Mean Surface Distance and the Surface DSC.
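
As a rough intuition for two of these metrics, the sketch below computes the DSC and the
Hausdorff Distance on toy 2D voxel index sets. It is not PyDicer's implementation: it works in
voxel units on every voxel rather than on surfaces in physical units, assumes neither structure
is empty, and the function names `dice` and `hausdorff` are hypothetical.

```python
import math


def dice(a, b):
    """Dice Similarity Coefficient: 2|A ∩ B| / (|A| + |B|), for non-empty voxel sets."""
    a, b = set(a), set(b)
    return 2 * len(a & b) / (len(a) + len(b))


def hausdorff(a, b):
    """Symmetric Hausdorff distance between two non-empty sets of points."""

    def directed(p, q):
        # Largest distance from any point in p to its nearest point in q
        return max(min(math.dist(x, y) for y in q) for x in p)

    return max(directed(a, b), directed(b, a))


# Two 2x2 squares which overlap in one column of voxels
manual = {(0, 0), (0, 1), (1, 0), (1, 1)}
auto = {(0, 1), (1, 1), (0, 2), (1, 2)}

dsc_value = dice(manual, auto)  # 2 shared voxels out of 4 + 4
hd_value = hausdorff(manual, auto)  # the squares are offset by one voxel
```

A DSC of 1 means perfect overlap and 0 means none, whereas distance metrics are better when
smaller.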

> Structure names must match exactly, so we use a structure name mapping to standardise our
> structure names prior to computing the similarity metrics.

In [None]:
# Add our structure name mapping
mapping_id = "nnunet_lctsc"
mapping = {
    "Esophagus": [],
    "Heart": [],
    "Lung_L": ["L_Lung", "Lung_Left"],
    "Lung_R": ["Lung_Right"],
    "SpinalCord": ["SC"],
}
pydicer.add_structure_name_mapping(mapping, mapping_id)

# Specify the metrics we want to compute
compute_metrics = ["DSC", "hausdorffDistance", "meanSurfaceDistance", "surfaceDSC"]

# Compute the similarity metrics
compute_contour_similarity_metrics(
    df_target,
    df_reference,
    segment_id,
    compute_metrics=compute_metrics,
    mapping_id=mapping_id
)
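
The mapping above pairs each standardised name with a list of alternative names. The sketch below
shows how such a lookup could resolve a raw structure name. It is an assumption about the
behaviour for illustration, not PyDicer's code, and `standardise_name` is a hypothetical helper.

```python
mapping = {
    "Esophagus": [],
    "Heart": [],
    "Lung_L": ["L_Lung", "Lung_Left"],
    "Lung_R": ["Lung_Right"],
    "SpinalCord": ["SC"],
}


def standardise_name(name, mapping):
    """Hypothetical lookup: return the standardised name, or None if the name is not mapped."""
    for standard_name, variants in mapping.items():
        if name == standard_name or name in variants:
            return standard_name
    return None


resolved = [standardise_name(n, mapping) for n in ["L_Lung", "Heart", "SC", "Brain"]]
```

Under this assumption, a structure whose name is neither a key nor a listed variant (such as
`Brain` here) has no standard name and so cannot be paired between the target and reference sets.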

### Fetch the similarity metrics

Here we fetch the metrics computed and output some stats. Note that if a segmentation fails, 
surface metrics will return NaN and will be excluded from these stats.

In [None]:
# Fetch the similarity metrics
df_metrics = get_all_similarity_metrics_for_dataset(
    working_directory,
    dataset_name=validation_dataset,
    structure_mapping_id=mapping_id
)

# Aggregate the stats using Pandas
df_metrics[
    ["segment_id", "structure", "metric", "value"]
    ].groupby(
        ["segment_id", "structure", "metric"]
    ).agg(
        ["mean", "std", "min", "max", "count"]
    )
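
To see what excluding NaN values means for these statistics, the sketch below reproduces the
aggregation by hand on made-up numbers. It uses only the standard library and mirrors the pandas
defaults (NaN values skipped, sample standard deviation). The values are illustrative and not
results from this dataset.

```python
import math
import statistics

# Illustrative metric values for one structure; the NaN stands for a failed case
values = [0.82, float("nan"), 0.78, 0.80]

# Drop NaN entries before aggregating
valid = [v for v in values if not math.isnan(v)]

summary = {
    "mean": statistics.mean(valid),
    "std": statistics.stdev(valid),  # sample standard deviation (n - 1 in the denominator)
    "min": min(valid),
    "max": max(valid),
    "count": len(valid),  # counts only the non-NaN values
}
```

Comparing `count` against the number of cases in the dataset is a quick way to spot structures
for which some metrics could not be computed.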


### Perform Analysis

There are various plots and visualisations you may wish to produce following computation of
similarity metrics. The `prepare_similarity_metric_analysis` function generates several plots
which serve as a useful starting point when analysing your auto-segmentation results.

Plots are generated in a directory you provide. In this example, plots and tables (`.csv`) are
output in the `analysis/atlas` subdirectory of the working directory. Run the following cell,
then navigate to that directory to explore the results.

In [None]:
analysis_output_directory = working_directory.joinpath(
    "analysis",
    segment_id
)
analysis_output_directory.mkdir(parents=True, exist_ok=True)

prepare_similarity_metric_analysis(
    working_directory,
    analysis_output_directory=analysis_output_directory,
    dataset_name=validation_dataset,
    structure_mapping_id=mapping_id
)