# Brain parcellation with TorchIO and HighRes3DNet

We are going to perform a full brain parcellation of a 3D T1-weighted MRI using [TorchIO](https://torchio.readthedocs.io/) and a pre-trained [PyTorch](https://pytorch.org/) deep learning model, in fewer than 50 lines of code and less than two minutes.

## TorchIO

TorchIO is a Python package to prepare 3D medical images for deep learning pipelines. Check out the [documentation](https://torchio.readthedocs.io/) and a longer [Colab notebook](https://colab.research.google.com/drive/112NTL8uJXzcMw4PQbUvMQN-WHlVwQS3i) containing many examples.

## HighRes3DNet

HighRes3DNet is a 3D residual network presented by Li et al. in [On the Compactness, Efficiency, and Representation of 3D Convolutional Networks: Brain Parcellation as a Pretext Task](https://link.springer.com/chapter/10.1007/978-3-319-59050-9_28). The authors shared the weights of a model trained to reproduce the full brain parcellation described in [Geodesic Information Flows: Spatially-Variant Graphs and Their Application to Segmentation and Fusion](https://spiral.imperial.ac.uk/bitstream/10044/1/30755/4/07086081.pdf), also known as GIF parcellation.

The weights were ported from TensorFlow as shown in [my entry to the MICCAI educational challenge 2019](https://github.com/fepegar/miccai-educational-challenge-2019).

First, we will install TorchIO and download two useful files: a visualization module and a color table for the GIF labels:

```python
!pip install -q torchio
!curl -sS -O https://raw.githubusercontent.com/fepegar/miccai-educational-challenge-2019/master/visualization.py
!curl -sS -O https://raw.githubusercontent.com/fepegar/highresnet/master/GIFNiftyNet.ctbl
```

```python
import datetime
from pathlib import Path

import torch
torch.set_grad_enabled(False);  # inference only, so gradients are not needed
import numpy as np
import torchio as tio
from tqdm.notebook import tqdm
from torchvision.datasets.utils import download_and_extract_archive
import visualization
plot_volume = visualization.plot_volume_interactive
%config InlineBackend.figure_format = 'retina'
print('TorchIO version:', tio.__version__)
print('Last run:', datetime.date.today())
```

```text
If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Credits instructions: https://torchio.readthedocs.io/#credits

TorchIO version: 0.17.48
Last run: 2020-10-18
```

## Preparing the data

Let's download a T1-weighted MRI hosted on the NiftyNet model zoo:

```python
root_dir = Path('data')
download_and_extract_archive(
    'https://github.com/NifTK/NiftyNetModelZoo/raw/5-reorganising-with-lfs/highres3dnet_brain_parcellation/data.tar.gz',
    root_dir
)
mri_path = list(root_dir.glob('*.nii.gz'))[0]
mri_path
```

```text
Using downloaded and verified file: data/data.tar.gz
Extracting data/data.tar.gz to data
PosixPath('data/OAS1_0145_MR2_mpr_n4_anon_sbj_111.nii.gz')
```

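We will now create an instance of `torchio.ScalarImage`, a subclass of [`torchio.Image`](https://torchio.readthedocs.io/data/dataset.html#image) for intensity images, and pass it to a [`torchio.Subject`](https://torchio.readthedocs.io/data/dataset.html#subject). TorchIO also ships the Colin 27 template as `tio.datasets.Colin27`, which we load as an alternative subject: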

```python
subject_oasis = tio.Subject(t1=tio.ScalarImage(mri_path))
subject_colin = tio.datasets.Colin27()
subject = subject_oasis  # try subject_colin instead!
```

```text
Using cache found in /root/.cache/torchio/mni_colin27_1998_nifti
```

We are going to apply four [preprocessing transforms](https://torchio.readthedocs.io/transforms/preprocessing.html) to our instance of `Subject`:

1. [`ToCanonical`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.ToCanonical) reorients our image using [NiBabel](https://nipy.org/nibabel/) so that it is in [RAS+ orientation](https://nipy.org/nibabel/image_orientation.html)
2. [`Resample`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.Resample) changes the voxel spacing using [SimpleITK](https://simpleitk.org/); `Resample(1)` produces 1 mm isotropic voxels. This is analogous to using [`torchvision.transforms.Resize`](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Resize)
3. [`HistogramStandardization`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.HistogramStandardization) is a sophisticated method based on histogram landmarks learned from a training set. It was presented by Nyúl et al. in [New Variants of a Method of MRI Scale Standardization](https://pubmed.ncbi.nlm.nih.gov/10784285/)
4. [`ZNormalization`](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.ZNormalization) produces an image with zero mean and unit variance. Both normalization transforms compute their statistics using only the foreground, defined here as the voxels whose intensity is above the mean intensity
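The arithmetic behind the last three steps is simple enough to sketch. The pure-Python code below is a simplified illustration, not TorchIO's implementation, and all function names are hypothetical. `resampled_shape` computes how many voxels each axis has after resampling, assuming the physical extent of the image is preserved. `piecewise_linear_map` is the core of Nyúl's method: the image's own landmarks (percentiles of its foreground intensities) are mapped linearly, segment by segment, onto precomputed reference landmarks. `z_normalize_foreground` standardizes the image using statistics from voxels above the mean intensity only. The percentile positions and reference values in the demo are made up.

```python
import math
from bisect import bisect_right
from statistics import mean, pstdev
from typing import List, Sequence, Tuple


def resampled_shape(
    shape: Sequence[int], spacing: Sequence[float], new_spacing: Sequence[float]
) -> Tuple[int, ...]:
    """Voxels per axis after resampling, assuming the physical extent is kept."""
    return tuple(
        max(1, round(n * s / new_s)) for n, s, new_s in zip(shape, spacing, new_spacing)
    )


def percentile(values: Sequence[float], q: float) -> float:
    """Percentile with linear interpolation, q in [0, 100]."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100
    lower = math.floor(position)
    upper = min(lower + 1, len(ordered) - 1)
    return ordered[lower] + (ordered[upper] - ordered[lower]) * (position - lower)


def piecewise_linear_map(
    values: Sequence[float], source: Sequence[float], target: Sequence[float]
) -> List[float]:
    """Map intensities so that each source landmark lands on its target landmark.

    `source` must be strictly increasing; values outside its range are
    extrapolated linearly from the first or last segment.
    """
    last_segment = len(source) - 2
    mapped = []
    for v in values:
        # Segment [source[i], source[i + 1]] that contains v
        i = min(max(bisect_right(source, v) - 1, 0), last_segment)
        x0, x1 = source[i], source[i + 1]
        y0, y1 = target[i], target[i + 1]
        mapped.append(y0 + (v - x0) * (y1 - y0) / (x1 - x0))
    return mapped


def z_normalize_foreground(values: Sequence[float]) -> List[float]:
    """Zero mean and unit std, with statistics from voxels above the mean.

    Assumes the foreground is not constant (otherwise its std is zero).
    """
    threshold = mean(values)
    foreground = [v for v in values if v > threshold]
    mu, sigma = mean(foreground), pstdev(foreground)
    return [(v - mu) / sigma for v in values]


if __name__ == "__main__":
    print(resampled_shape((256, 256, 160), (1.0, 1.0, 1.2), (1.0, 1.0, 1.0)))
    voxels = [0, 3, 8, 15, 40, 42, 45, 50, 55, 90]
    # Hypothetical landmark positions (percentiles) and reference values
    source = [percentile(voxels, q) for q in (0, 50, 100)]
    target = [0.0, 40.0, 100.0]
    standardized = piecewise_linear_map(voxels, source, target)
    print(z_normalize_foreground(standardized))
```

In TorchIO, the reference landmarks are passed through the `landmarks` argument, and `masking_method=tio.ZNormalization.mean` selects the same "above the mean" foreground used here.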

```python
# From NiftyNet model zoo
LI_LANDMARKS = "4.4408920985e-16 8.06305571158 15.5085721044 18.7007018006 21.5032879029 26.1413278906 29.9862059045 33.8384058795 38.1891334787 40.7217966068 44.0109152758 58.3906435207 100.0"
li_landmarks = np.array([float(n) for n in LI_LANDMARKS.split()])

transforms = [
    tio.ToCanonical(),  # reorient to RAS+
    tio.Resample(1),  # 1 mm isotropic spacing
    tio.HistogramStandardization(
        landmarks={'t1': li_landmarks},
        masking_method=tio.ZNormalization.mean,  # foreground: voxels above the mean
    ),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
]
transform = tio.Compose(transforms)
preprocessed = transform(subject)
```

The transforms are chained together using [`torchio.Compose`](https://torchio.readthedocs.io/transforms/augmentation.html#compose) or [`torchvision.transforms.Compose`](https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.Compose). As you can see, the interface is similar to [`torchvision.transforms`](https://pytorch.org/docs/stable/torchvision/transforms.html). Let's take a look at the input volume:

```python
plot_volume(subject.t1.numpy().squeeze())
```


The orientation labels in the viewer are incorrect because `plot_volume` expects an image in RAS+ orientation. That is why we added `ToCanonical` to our list of transforms; the preprocessed volume is displayed correctly:

```python
plot_volume(preprocessed.t1.numpy().squeeze())
```
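`ToCanonical` reorders and flips the voxel axes so that they point towards Right, Anterior and Superior. The orientation of an image can be read from its voxel-to-world affine, and NiBabel provides `nibabel.aff2axcodes` for this. The pure-Python sketch below is a simplified, greedy version for illustration only (the name `axis_codes` is hypothetical): for each voxel axis, it picks the world axis with the largest affine component and uses the sign of that component. Unlike NiBabel, it does not resolve ambiguous cases where two voxel axes point mostly along the same world axis.

```python
from typing import List, Sequence, Tuple

POSITIVE = ("R", "A", "S")  # world x, y, z point Right, Anterior, Superior (RAS+)
NEGATIVE = ("L", "P", "I")


def axis_codes(affine: Sequence[Sequence[float]]) -> Tuple[str, ...]:
    """Greedy orientation codes of the voxel axes of a 4x4 voxel-to-world affine."""
    codes: List[str] = []
    for j in range(3):  # column j of the 3x3 block = direction of voxel axis j
        column = [affine[i][j] for i in range(3)]
        magnitudes = [abs(c) for c in column]
        world_axis = magnitudes.index(max(magnitudes))
        codes.append((POSITIVE if column[world_axis] > 0 else NEGATIVE)[world_axis])
    return tuple(codes)


if __name__ == "__main__":
    identity = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
    print(axis_codes(identity))
```

For the identity affine it returns `('R', 'A', 'S')`, which is the orientation every image has after `ToCanonical`.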


## Pretrained model

We will use the wonderful [PyTorch Hub](https://pytorch.org/hub/) to download the pretrained model from GitHub:

```python
repo = 'fepegar/highresnet'
model_name = 'highres3dnet'
model = torch.hub.load(repo, model_name, pretrained=True)
# Always build a torch.device, whether or not a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
model.to(device);
```

```text
Using cache found in /root/.cache/torch/hub/fepegar_highresnet_master
Device: cuda
```

## Inference

Since the whole image is too large to fit in the available GPU memory, we need to [perform inference using image patches](https://torchio.readthedocs.io/data/patch_based.html).

We will use a [`GridSampler`](https://torchio.readthedocs.io/data/patch_inference.html#grid-sampler) to extract patches from all the necessary locations in the image, and a `GridAggregator` to assemble the predicted patches back into a full volume. `GridSampler` is a subclass of [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset), so we can easily load batches of patches using a [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataloader).

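We will add a small overlap between patches to reduce border effects: predictions near the edges of a patch are less reliable because the network sees less spatial context there.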
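The locations visited by a grid sampler can be worked out with simple arithmetic. The pure-Python sketch below is a simplified illustration, not TorchIO's code, and the function names are hypothetical. Along each axis it steps by `patch_size - overlap` and, if the regular steps leave voxels uncovered at the end, adds a last patch flush with the border; 3D locations are all combinations of the per-axis starts. It assumes the patch fits inside the volume and ignores any padding the real sampler may apply.

```python
from itertools import product
from typing import List, Sequence, Tuple


def axis_starts(size: int, patch_size: int, overlap: int) -> List[int]:
    """Start indices of patches that cover [0, size) along one axis."""
    if patch_size > size:
        raise ValueError("patch_size must not exceed the axis size")
    step = patch_size - overlap  # consecutive patches share `overlap` voxels
    if step <= 0:
        raise ValueError("overlap must be smaller than patch_size")
    starts = list(range(0, size - patch_size + 1, step))
    if starts[-1] + patch_size < size:
        starts.append(size - patch_size)  # extra patch flush with the far border
    return starts


def grid_locations(
    shape: Sequence[int], patch_size: int, overlap: int
) -> List[Tuple[int, ...]]:
    """Corner (i, j, k) of every patch: all combinations of per-axis starts."""
    per_axis = [axis_starts(n, patch_size, overlap) for n in shape]
    return list(product(*per_axis))
```

For example, along a hypothetical axis of 160 voxels with `patch_size = 128` and `patch_overlap = 4`, `axis_starts(160, 128, 4)` gives `[0, 32]`: the regular step of 124 voxels would leave the end uncovered, so the second patch is placed flush with the border and overlaps the first one by far more than 4 voxels.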


```python
patch_overlap = 4
patch_size = 128
grid_sampler = tio.inference.GridSampler(
    preprocessed,
    patch_size,
    patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(grid_sampler)  # batch_size=1 by default
aggregator = tio.inference.GridAggregator(grid_sampler)

model.eval()
for patches_batch in tqdm(patch_loader):
    input_tensor = patches_batch['t1'][tio.DATA].to(device)
    locations = patches_batch[tio.LOCATION]  # where each patch sits in the volume
    logits = model(input_tensor)
    # Most likely label per voxel; keepdim preserves the channel dimension
    labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
    aggregator.add_batch(labels, locations)
output_tensor = aggregator.get_output_tensor()
plot_volume(
    output_tensor.numpy().squeeze(),
    enhance=False,
    colors_path='GIFNiftyNet.ctbl',
)
```
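Conceptually, the aggregator writes each patch's labels back at its location, and the overlap makes it possible to discard voxels at the inner edges of each patch. The one-dimensional, pure-Python sketch below illustrates one such strategy: it drops `overlap // 2` voxels at every inner patch border, so those voxels are taken from the neighboring patch instead. It is not TorchIO's implementation and the function names are hypothetical; `argmax_labels` mirrors the `logits.argmax(...)` step of the loop for a single patch stored as per-voxel channel scores.

```python
from typing import List, Sequence, Tuple


def argmax_labels(logits: Sequence[Sequence[float]]) -> List[int]:
    """Per-voxel label: index of the channel with the highest score."""
    return [max(range(len(scores)), key=scores.__getitem__) for scores in logits]


def stitch_1d(
    size: int, patches: Sequence[Tuple[int, Sequence[int]]], overlap: int
) -> List[int]:
    """Write patch labels back into a 1D volume of `size` voxels.

    `overlap // 2` voxels are dropped at each inner patch border (but not at
    the borders of the volume). Later patches overwrite earlier ones.
    """
    crop = overlap // 2
    output = [-1] * size  # -1 marks voxels that no patch has written
    for start, labels in patches:
        length = len(labels)
        lo = crop if start > 0 else 0
        hi = length - crop if start + length < size else length
        output[start + lo:start + hi] = labels[lo:hi]
    return output


if __name__ == "__main__":
    truth = [0, 0, 1, 1, 2, 2, 1, 1, 0, 0]
    # Fake one-hot "logits" for 3 classes, one list of scores per voxel
    logits = [[1.0 if c == label else 0.0 for c in range(3)] for label in truth]
    starts = [0, 2, 4, 6]  # patch size 4, overlap 2
    patches = [(s, argmax_labels(logits[s:s + 4])) for s in starts]
    print(stitch_1d(len(truth), patches, overlap=2))
```

If a patch predicts wrong labels only at its edges, for example `stitch_1d(6, [(0, [1, 1, 1, 9]), (2, [9, 2, 2, 2])], 2)`, the cropped voxels are discarded and the result is `[1, 1, 1, 2, 2, 2]`.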


The result is not as good as a full [GIF](https://spiral.imperial.ac.uk/bitstream/10044/1/30755/4/07086081.pdf) or [FreeSurfer](https://surfer.nmr.mgh.harvard.edu/fswiki/recon-all) parcellation, but hey, it's 600 times faster!

## Conclusion

We have seen how to combine TorchIO and PyTorch Hub to infer a full brain parcellation of a 3D T1-weighted MRI, using patch-based inference with a pre-trained convolutional neural network.

TorchIO is looking for feedback and contributors. Don't hesitate to [open an issue](https://github.com/fepegar/torchio/issues/new/choose) in the repository with questions or feature requests!