<a href="https://colab.research.google.com/github/fepegar/torchio-notebooks/blob/main/notebooks/Brain_parcellation_with_TorchIO_and_HighRes3DNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Brain parcellation with TorchIO and HighRes3DNet

We are going to perform a full brain parcellation of a 3D T1-weighted MRI using [TorchIO](https://torchio.readthedocs.io/) and a pre-trained [PyTorch](https://pytorch.org/) deep learning model in fewer than 50 lines of code and under two minutes.

## TorchIO

TorchIO is a Python package to prepare 3D medical images for deep learning pipelines. Check out the [documentation](https://torchio.readthedocs.io/) and a longer [Colab notebook](https://colab.research.google.com/drive/112NTL8uJXzcMw4PQbUvMQN-WHlVwQS3i) containing many examples.

## HighRes3DNet

HighRes3DNet is a 3D residual network presented by Li et al. in [On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task](https://link.springer.com/chapter/10.1007/978-3-319-59050-9_28). The authors shared the weights of the model they trained to perform full brain parcellation as in [Geodesic Information Flows: Spatially-Variant Graphs and Their Application to Segmentation and Fusion](https://spiral.imperial.ac.uk/bitstream/10044/1/30755/4/07086081.pdf), also known as GIF parcellation.

The weights were ported from TensorFlow as shown in [my entry to the MICCAI educational challenge 2019](https://github.com/fepegar/miccai-educational-challenge-2019).

First we will install TorchIO and download a couple of useful files:

In [45]:
!pip install -q torchio==0.18.15
!curl -sS -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/visualization.py
!curl -sS -O https://raw.githubusercontent.com/fepegar/highresnet/master/GIFNiftyNet.ctbl

In [46]:
import datetime
from pathlib import Path

import torch
torch.set_grad_enabled(False);
import numpy as np
import torchio as tio
from tqdm.notebook import tqdm, trange
from torchvision.datasets.utils import download_and_extract_archive
import visualization
plot_volume = visualization.plot_volume_interactive
%config InlineBackend.figure_format = 'retina'
torch.manual_seed(20202021)
print('TorchIO version:', tio.__version__)
print('Last run:', datetime.date.today())

TorchIO version: 0.18.15
Last run: 2021-01-01


## Preparing the data

Let's download a T1-weighted MRI hosted on the NiftyNet model zoo:

In [47]:
root_dir = Path('data')
download_and_extract_archive(
    'https://github.com/NifTK/NiftyNetModelZoo/raw/5-reorganising-with-lfs/highres3dnet_brain_parcellation/data.tar.gz',
    root_dir
)
mri_path = list(root_dir.glob('*.nii.gz'))[0]
mri_path

Using downloaded and verified file: data/data.tar.gz
Extracting data/data.tar.gz to data


PosixPath('data/OAS1_0145_MR2_mpr_n4_anon_sbj_111.nii.gz')

We will now create an instance of [`torchio.Image`](https://torchio.readthedocs.io/data/dataset.html#image) and pass it to a [`torchio.Subject`](https://torchio.readthedocs.io/data/dataset.html#subject):

In [48]:
subject_oasis = tio.Subject(t1=tio.ScalarImage(mri_path))
subject_colin = tio.datasets.Colin27()
subject = subject_oasis  # try subject_colin instead!

We are going to apply four [preprocessing transforms](https://torchio.readthedocs.io/transforms/preprocessing.html) to our instance of `Subject`:

1. [`ToCanonical`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.ToCanonical) reorients our image using [NiBabel](https://nipy.org/nibabel/) so that it is in [RAS+ orientation](https://nipy.org/nibabel/image_orientation.html)
2. [`Resample`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.Resample) changes the voxel spacing using [SimpleITK](https://simpleitk.org/). This is analogous to using [`torchvision.transforms.Resize`](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Resize)
3. [`HistogramStandardization`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.HistogramStandardization) is a sophisticated method based on training histogram landmarks. It was presented by Nyúl in [New Variants of a Method of MRI Scale Standardization](https://pubmed.ncbi.nlm.nih.gov/10784285/)
4. [`ZNormalization`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.ZNormalization) generates an image with zero mean and unit variance. Both normalization transforms use the foreground values (computed as the values above the mean intensity) to calculate the corresponding statistics
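
As a rough illustration of the foreground-based statistics described in item 4, the NumPy sketch below standardizes a toy volume using only the voxels above the mean intensity. The helper `z_normalize_foreground` and the toy data are assumptions made for this example; this is not TorchIO's implementation.

```python
import numpy as np


def z_normalize_foreground(volume):
    """Standardize a volume using statistics of its foreground voxels only."""
    # Foreground: voxels brighter than the mean intensity of the whole volume
    foreground = volume > volume.mean()
    values = volume[foreground]
    return (volume - values.mean()) / values.std()


# Toy volume: dark background with a bright cube standing in for the head
rng = np.random.default_rng(0)
volume = np.zeros((8, 8, 8))
volume[2:6, 2:6, 2:6] = rng.normal(100, 10, size=(4, 4, 4))

normalized = z_normalize_foreground(volume)
foreground = volume > volume.mean()
print('Foreground mean:', normalized[foreground].mean())
print('Foreground std: ', normalized[foreground].std())
```

Note that the zero mean and unit variance hold for the foreground voxels. The background ends up with large negative values because it was excluded from the statistics.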

In [49]:
# From NiftyNet model zoo
li_landmarks = np.array((0, 8, 16, 19, 22, 26, 30, 34, 38, 41, 44, 58, 100))

transforms = [
    tio.ToCanonical(),
    tio.Resample(1),
    tio.HistogramStandardization(landmarks={'t1': li_landmarks}, masking_method=tio.ZNormalization.mean),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
]
transform = tio.Compose(transforms)
preprocessed = transform(subject)
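
To give an intuition for what `HistogramStandardization` does with `li_landmarks`, here is a simplified NumPy sketch of landmark-based standardization: percentiles of the foreground intensities are mapped piecewise-linearly onto the trained landmarks. The percentile positions in `PERCENTILES` are an assumption chosen to give one position per landmark, `standardize_histogram` is a hypothetical helper, and `np.interp` clamps values outside the landmark range, so the details differ from TorchIO's implementation.

```python
import numpy as np

# Landmarks from the cell above (NiftyNet model zoo)
li_landmarks = np.array((0, 8, 16, 19, 22, 26, 30, 34, 38, 41, 44, 58, 100))
# Assumed percentile positions, one per landmark (not taken from this notebook)
PERCENTILES = np.array((1, 10, 20, 25, 30, 40, 50, 60, 70, 75, 80, 90, 99))


def standardize_histogram(volume, landmarks):
    """Map the foreground percentiles of a volume onto trained landmarks."""
    # Foreground: voxels brighter than the mean intensity
    foreground = volume > volume.mean()
    image_landmarks = np.percentile(volume[foreground], PERCENTILES)
    # Piecewise-linear mapping; values outside the range are clamped
    return np.interp(volume, image_landmarks, landmarks)


# Toy volume with a skewed intensity distribution and arbitrary scale
rng = np.random.default_rng(0)
volume = rng.gamma(shape=2.0, scale=300.0, size=(16, 16, 16))
standardized = standardize_histogram(volume, li_landmarks)
print('Range before:', volume.min(), volume.max())
print('Range after: ', standardized.min(), standardized.max())
```

The mapping is monotonic, so the ordering of intensities is preserved while the scale becomes comparable across images.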

The transforms are chained together using [`torchio.Compose`](https://torchio.readthedocs.io/transforms/augmentation.html#compose) or [`torchvision.transforms.Compose`](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Compose). As you can see, the interface is similar to [`torchvision.transforms`](https://pytorch.org/docs/stable/torchvision/transforms.html). Let's take a look at the input volume:

In [50]:
plot_volume(subject.t1.numpy().squeeze())

The labels are incorrect because `plot_volume` expects an image in [RAS+ orientation](https://nipy.org/nibabel/image_orientation.html). It was a good idea to add `ToCanonical` to our list of transforms:

In [51]:
plot_volume(preprocessed.t1.numpy().squeeze())

## Pretrained model

We will use the wonderful [PyTorch Hub](https://pytorch.org/hub/) to download the pretrained model from GitHub:

In [52]:
repo = 'fepegar/highresnet'
model_name = 'highres3dnet'
model = torch.hub.load(repo, model_name, pretrained=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
model.to(device).eval();

Device: cuda


Using cache found in /root/.cache/torch/hub/fepegar_highresnet_master


## Inference


In [53]:
input_tensor = preprocessed.t1.data[None].to(device)  # add batch dim
with torch.cuda.amp.autocast():
    logits = model(input_tensor)
full_volume_output_tensor = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True).cpu()
plot_volume(
    full_volume_output_tensor.numpy().squeeze(),
    enhance=False,
    colors_path='GIFNiftyNet.ctbl',
)
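
The `argmax` call in the cell above turns per-class scores into a label map. The NumPy sketch below shows the same step on toy logits. The shapes and the number of classes are made up for illustration, and `CHANNELS_DIMENSION` is assumed to be 1, matching the `dim=1` used later in this notebook.

```python
import numpy as np

CHANNELS_DIMENSION = 1  # assumed: axis holding one score per class

# Toy logits with shape (batch, classes, x, y, z)
rng = np.random.default_rng(0)
num_classes = 5  # toy value, not the number of GIF labels
logits = rng.normal(size=(1, num_classes, 4, 4, 4))

# Index of the highest score at each voxel, keeping a singleton channel axis
labels = logits.argmax(axis=CHANNELS_DIMENSION)
labels = np.expand_dims(labels, CHANNELS_DIMENSION)

# Softmax is monotonic, so taking argmax of probabilities gives the same labels
shifted = logits - logits.max(axis=CHANNELS_DIMENSION, keepdims=True)
exponentials = np.exp(shifted)
probabilities = exponentials / exponentials.sum(axis=CHANNELS_DIMENSION, keepdims=True)
labels_from_probabilities = probabilities.argmax(axis=CHANNELS_DIMENSION)

print(labels.shape)
print(np.array_equal(labels[:, 0], labels_from_probabilities))
```

This is why no softmax is needed when only the most likely label per voxel is required.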

Since the image might be too large for the available GPU, we could [perform inference using image patches](https://torchio.readthedocs.io/data/patch_based.html) instead.

We will use a [`GridSampler`](https://torchio.readthedocs.io/data/patch_inference.html#grid-sampler) to extract patches from all the necessary locations in the image. `GridSampler` is a subclass of [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset), so we can easily extract batches of patches using a [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataloader).

We will add a small overlap between patches to avoid the border effect.
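
To see how a grid of overlapping patches can cover a whole volume, here is a small pure-Python sketch that computes patch start indices along each axis. The helper `patch_starts` and the image shape are assumptions for illustration; the patch size and overlap are the values used in the next cell. It is a sketch of the general idea rather than TorchIO's exact algorithm.

```python
import itertools


def patch_starts(image_size, patch_size, overlap):
    """Start indices of patches covering one axis with a given overlap."""
    step = patch_size - overlap
    starts = list(range(0, image_size - patch_size + 1, step))
    # Add a final patch shifted back so that it ends exactly at the border
    if starts[-1] != image_size - patch_size:
        starts.append(image_size - patch_size)
    return starts


image_shape = (160, 256, 256)  # assumed example shape
patch_size, patch_overlap = 128, 4

starts_per_axis = [
    patch_starts(size, patch_size, patch_overlap)
    for size in image_shape
]
# Every combination of start indices is the corner of one 3D patch
locations = list(itertools.product(*starts_per_axis))
print(starts_per_axis)
print('Number of patches:', len(locations))
```

The last patch along each axis overlaps its neighbour by more than the requested amount, which is the price of keeping all patches the same size.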

In [54]:
patch_overlap = 4
patch_size = 128
grid_sampler = tio.inference.GridSampler(
    preprocessed,
    patch_size,
    patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(grid_sampler)
aggregator = tio.inference.GridAggregator(grid_sampler)

for patches_batch in tqdm(patch_loader, unit='batch'):
    input_tensor = patches_batch['t1'][tio.DATA].to(device)
    locations = patches_batch[tio.LOCATION]
    with torch.cuda.amp.autocast():
        logits = model(input_tensor)
    labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
    aggregator.add_batch(labels, locations)
patchwise_output_tensor = aggregator.get_output_tensor()
plot_volume(
    patchwise_output_tensor.numpy().squeeze(),
    enhance=False,
    colors_path='GIFNiftyNet.ctbl',
)

The result is not as good as a full [GIF](https://spiral.imperial.ac.uk/bitstream/10044/1/30755/4/07086081.pdf) or [FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/recon-all) parcellation but hey, it's 600 times faster!

## Test-time augmentation and uncertainty estimation

[Test-time augmentation (TTA)](https://www.nature.com/articles/s41598-020-61808-3) can be used to improve the results. We will apply a random transform to the image, infer the segmentation in the transformed space and apply the inverse transform to bring the result back to the original space. After repeating this several times, we can use majority voting to obtain a more robust segmentation result.
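
The transform, infer, invert and vote loop can be sketched with NumPy on a toy 2D image. This example swaps the random affine and elastic transforms used in this notebook for deterministic 90-degree rotations, and `toy_model` is a made-up stand-in for the network that always gets one pixel wrong.

```python
import numpy as np


def toy_model(image):
    """Stand-in for a network: thresholds, but always gets pixel (0, 2) wrong."""
    labels = (image > 0.5).astype(np.int64)
    labels[0, 2] = 1 - labels[0, 2]
    return labels


image = np.zeros((6, 6))
image[1:5, 1:5] = 1.0
truth = (image > 0.5).astype(np.int64)

predictions = []
for k in range(4):
    augmented = np.rot90(image, k)  # transform the image
    predicted = toy_model(augmented)  # infer in the transformed space
    predictions.append(np.rot90(predicted, -k))  # back to the original space

votes = np.stack(predictions)
# Binary labels: foreground where more than half of the predictions say so
tta_labels = (votes.mean(axis=0) > 0.5).astype(np.int64)

print('Errors without TTA:', (toy_model(image) != truth).sum())
print('Errors with TTA:   ', (tta_labels != truth).sum())
```

After inversion, the systematic error lands on a different pixel for each rotation, so the vote removes it. Real networks are not this well behaved, but the mechanics are the same.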

In [55]:
num_augmentations = 20
crop = tio.Crop((0, 0, 10, 30, 40, 40))
cropped = crop(preprocessed)
cropped.clear_history()  # so that inverse of transforms so far are not applied to output
results = []
augment = tio.OneOf({
    tio.RandomAffine(image_interpolation='nearest'): 0.75,
    tio.RandomElasticDeformation(image_interpolation='nearest'): 0.25,
})
for _ in trange(num_augmentations):
    augmented = augment(cropped)
    input_tensor = augmented.t1.data[None].to(device)
    with torch.cuda.amp.autocast():
        logits = model(input_tensor)
    full_volume_output_tensor = logits.argmax(dim=1, keepdim=True).cpu()
    augmented.t1.data = full_volume_output_tensor[0]
    back = augmented.apply_inverse_transform(warn=False)
    results.append(back.t1.data)
result = torch.stack(results).long()


In [56]:
tta_result_tensor = result.mode(dim=0).values  # majority voting
plot_volume(
    tta_result_tensor.numpy().squeeze(),
    enhance=False,
    colors_path='GIFNiftyNet.ctbl',
)

The result using TTA should be more robust.
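
For multi-class label maps, majority voting means picking the most frequent label at each voxel, which is what `result.mode(dim=0)` does above. A minimal NumPy sketch of the same idea follows; `majority_vote` is a hypothetical helper, and in this sketch ties go to the smallest label.

```python
import numpy as np


def majority_vote(label_maps, num_classes):
    """Most frequent label at each voxel; ties go to the smallest label."""
    stacked = np.stack(label_maps)
    # counts[label] holds how many predictions chose that label at each voxel
    counts = np.stack([
        (stacked == label).sum(axis=0)
        for label in range(num_classes)
    ])
    return counts.argmax(axis=0)


# Three toy predictions for a 2x2 image with labels 0 to 3
label_maps = [
    np.array([[0, 1], [2, 2]]),
    np.array([[0, 1], [1, 2]]),
    np.array([[3, 1], [1, 0]]),
]
voted = majority_vote(label_maps, num_classes=4)
print(voted)
```

Counting one label at a time keeps the sketch short, but it would be slow for many classes on a full volume.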

In [57]:
torch.save(tta_result_tensor, 'seg.pth')

### Voxel-wise aleatoric uncertainty estimation

We can [estimate the aleatoric uncertainty](https://www.sciencedirect.com/science/article/pii/S0925231219301961) using our multiple results. We will use the technique from [Li et al.](https://link.springer.com/chapter/10.1007/978-3-319-59050-9_28): the uncertainty at each voxel is the fraction of predictions that differ from the most frequent prediction during our TTA.

In [58]:
different = torch.stack([
    tensor != tta_result_tensor
    for tensor in results
])
uncertainty = different.float().mean(dim=0)
plot_volume(
    uncertainty.numpy().squeeze(),
    enhance=False,
)

As expected, the highest uncertainty values are in voxels at the boundaries between structures.
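
The uncertainty measure can be checked by hand on a toy example. The NumPy sketch below uses four made-up predictions for four voxels, together with their most frequent labels, and computes the fraction of predictions that disagree with the consensus at each voxel.

```python
import numpy as np

# Each row is one prediction, each column is one voxel (made-up labels)
predictions = np.array([
    [0, 1, 2, 2],
    [0, 1, 1, 2],
    [0, 1, 1, 2],
    [0, 3, 3, 2],
])
# Most frequent label at each voxel
consensus = np.array([0, 1, 1, 2])

# Fraction of predictions that differ from the consensus at each voxel
different = predictions != consensus
uncertainty = different.mean(axis=0)
print(uncertainty)
```

Voxels where all predictions agree get an uncertainty of 0, and the value grows as more predictions disagree with the consensus.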

## Conclusion

We have seen how to combine TorchIO and PyTorch Hub to infer a full brain parcellation of a 3D T1-weighted MRI with a pre-trained convolutional neural network, using either the full volume or patches.

We have also used TorchIO to perform test-time augmentation and estimate a voxel-wise aleatoric uncertainty of our prediction.

TorchIO is looking for feedback and contributors. Don't hesitate to [open an issue](https://github.com/fepegar/torchio/issues/new/choose) in the repository with questions or feature requests!