Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Transforms updating meta data

MONAI transforms update the meta data as necessary when they are applied. This is done by storing the image as a `MetaTensor`, a subclass of `torch.Tensor`. It behaves just like a normal `torch.Tensor`, but it also carries meta data, such as the image's affine transformation matrix.

This allows us to apply some pre-processing transforms, run inference on an image with our network, and then save the inferred segmentation to file. Although the output differs from the input in size and voxel spacing, the input image and output segmentation should align as expected in an external image viewer (a screenshot from ITK-SNAP is given, but any viewer should do the trick).
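
An external viewer can align the two files because each one stores an affine that maps voxel indices (i, j, k) to world coordinates (typically millimetres). The numpy sketch below uses two hypothetical affines, a 1 mm input grid and a grid that was cropped and resampled to 0.5 mm in-plane, to show that different voxel indices in the two grids can refer to the same physical point. The `voxel_to_world` helper is our own illustration, not a MONAI function.

```python
import numpy as np

def voxel_to_world(affine: np.ndarray, ijk) -> np.ndarray:
    """Map voxel indices (i, j, k) to world coordinates with a 4x4 affine."""
    ijk1 = np.append(np.asarray(ijk, dtype=float), 1.0)  # homogeneous coordinates
    return (affine @ ijk1)[:3]

# Hypothetical input grid: 1 mm voxels, origin at (0, 0, 0)
input_affine = np.eye(4)

# Hypothetical output grid: origin moved by cropping 10 voxels per axis,
# voxels resampled to 0.5 mm in-plane
output_affine = np.array([
    [0.5, 0.0, 0.0, 10.0],
    [0.0, 0.5, 0.0, 10.0],
    [0.0, 0.0, 1.0, 10.0],
    [0.0, 0.0, 0.0, 1.0],
])

# Different voxel indices, same physical location
p_in = voxel_to_world(input_affine, (20, 20, 20))
p_out = voxel_to_world(output_affine, (20, 20, 10))
print(p_in, p_out)
```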

We only use a single image in this training script, which lets us obtain near-perfect results on that image very quickly. The process would be exactly the same with more images.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/MetaTensor/modules/transforms_update_meta_data.ipynb)

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"

## Setup imports

In [1]:
import os
import shutil
from tqdm import trange
import torch
from glob import glob
import tempfile

from monai.config import print_config
import monai.transforms as mt
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.apps import download_and_extract

print_config()

MONAI version: 1.1.0+11.g7de6c336.dirty
Numpy version: 1.22.2
Pytorch version: 1.13.0+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 7de6c33656a99087ca3b89a817b0879cf093febc
MONAI __file__: /workspace/Code/MONAI/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.10
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: 2.11.0
gdown version: 4.6.0
TorchVision version: 0.14.0+cu117
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.2
pandas version: 1.1.5
einops version: 0.6.0
transformers version: 4.21.3
mlflow version: 2.0.1
pynrrd version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU if no GPU

## Setup data directory

In [2]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/workspace/Data


## Download the data

The dataset comes from http://medicaldecathlon.com/.  
Target: Glioma segmentation (necrotic/active tumour and oedema)  
Modality: Multimodal multisite MRI data (FLAIR, T1w, T1gd, T2w)  
Size: 750 4D volumes (484 Training + 266 Testing)  
Source: BRATS 2016 and 2017 datasets.  
Challenge: Complex and heterogeneously-located targets
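
Each case is a 4D volume that bundles the four MRI sequences into one array. The numpy sketch below assumes the sequences sit on the last axis and uses a small zero-filled stand-in instead of real data; it shows the axis move that a channel-first transform such as `EnsureChannelFirstd` is meant to perform, and the four resulting channels are why the network later uses `in_channels=4`. Since only the non-spatial axis moves, the affine should not need to change.

```python
import numpy as np

# Hypothetical, shrunken stand-in for one case: (H, W, D, modality), assuming
# the four MRI sequences are stacked along the last axis
volume = np.zeros((24, 24, 16, 4), dtype=np.float32)
for m in range(volume.shape[-1]):
    volume[..., m] = m  # fill each modality with its own index so we can track it

# Move the modality axis to the front to get a channel-first (C, H, W, D) layout
channel_first = np.moveaxis(volume, -1, 0)
print(channel_first.shape)
```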

The figure below shows image patches with the tumor sub-regions that are annotated in the different modalities (top left) and the final labels for the whole dataset (right).
(Figure taken from the [BraTS IEEE TMI paper](https://ieeexplore.ieee.org/document/6975210/))

![image](../figures/brats_tasks.png)

The image patches show from left to right:
1. The whole tumor (yellow), visible in T2-FLAIR (Fig. A).
1. The tumor core (red), visible in T2 (Fig. B).
1. The enhancing tumor structures (light blue), visible in T1Gd, surrounding the cystic/necrotic components of the core (green) (Fig. C).
1. The segmentations are combined to generate the final labels of the tumor sub-regions (Fig. D): edema (yellow), non-enhancing solid core (red), necrotic/cystic core (green), enhancing core (blue).

In [4]:
task = "Task01_BrainTumour"
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/" + task + ".tar"

compressed_file = os.path.join(root_dir, task + ".tar")
data_dir = os.path.join(root_dir, task)
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir)

images = sorted(glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
labels = sorted(glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [{"image": image, "label": label}
              for image, label in zip(images, labels)]
data = [data_dicts[0]]

## Transforms

Of the transforms applied below, we expect `Orientationd`, `CropForegroundd`, `Spacingd`, and `DivisiblePadd` to modify the image's affine transformation matrix.
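
As a rough illustration of why cropping and resampling touch the affine, the numpy sketch below updates a 4x4 affine by hand. Cropping moves the origin to the world position of the first kept voxel, and changing the voxel spacing rescales the columns of the upper-left 3x3 block. The helpers `crop_affine` and `respace_affine` are hypothetical and simplified (the resampling helper keeps the origin fixed); they are not MONAI's implementation, which may also adjust the origin.

```python
import numpy as np

def crop_affine(affine: np.ndarray, start) -> np.ndarray:
    """Affine of a sub-volume whose first voxel is `start` in the original grid."""
    new = affine.copy()
    # New origin = world position of the old voxel `start`
    new[:3, 3] = affine[:3, :3] @ np.asarray(start, dtype=float) + affine[:3, 3]
    return new

def respace_affine(affine: np.ndarray, new_pixdim) -> np.ndarray:
    """Simplified resampling: rescale each voxel axis to `new_pixdim`, origin unchanged."""
    new = affine.copy()
    old_pixdim = np.linalg.norm(affine[:3, :3], axis=0)  # current voxel size per axis
    # Broadcasting over the last axis scales column j by new_pixdim[j] / old_pixdim[j]
    new[:3, :3] = affine[:3, :3] * (np.asarray(new_pixdim, dtype=float) / old_pixdim)
    return new

affine = np.eye(4)  # hypothetical 1 mm grid with its origin at (0, 0, 0)
cropped = crop_affine(affine, (30, 40, 10))
resampled = respace_affine(cropped, (0.5, 0.5, 1.0))
print(resampled)
```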

In [5]:
keys = ("image", "label")
t = mt.Compose([
    mt.LoadImaged(keys),
    mt.EnsureChannelFirstd(keys),
    mt.Orientationd(keys, "RAI"),
    mt.CropForegroundd(keys, source_key="image"),
    mt.Spacingd(keys, pixdim=[0.5, 0.5, 1], mode=("bilinear", "nearest")),
    mt.ScaleIntensityd("image"),
    mt.DivisiblePadd(keys, 8),
])
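
`DivisiblePadd(keys, 8)` pads each spatial axis up to the next multiple of 8, presumably because the UNet used below has three stride-2 levels (`strides=(2, 2, 2)`), so each axis is halved three times. The numpy sketch below computes symmetric padding amounts and the resulting shift of the origin. `divisible_pad` and `pad_affine` are hypothetical helpers, and placing any odd extra voxel after the data is an assumption about the padding mode.

```python
import numpy as np

def divisible_pad(shape, k: int):
    """Per-axis (before, after) padding so that every size becomes a multiple of k."""
    pads = []
    for size in shape:
        total = (-size) % k  # voxels needed to reach the next multiple of k
        before = total // 2  # split symmetrically; the odd voxel goes after
        pads.append((before, total - before))
    return pads

def pad_affine(affine: np.ndarray, pads) -> np.ndarray:
    """Padding `before` voxels moves the origin back by that many voxels."""
    new = affine.copy()
    before = np.array([p[0] for p in pads], dtype=float)
    new[:3, 3] = affine[:3, 3] - affine[:3, :3] @ before
    return new

# Three stride-2 levels halve each axis three times, hence a multiple of 2**3 = 8
shape = (141, 173, 139)  # hypothetical spatial size after cropping and resampling
pads = divisible_pad(shape, k=2 ** 3)
padded_affine = pad_affine(np.eye(4), pads)
print(pads, padded_affine[:3, 3])
```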

## Simple dataset and dataloader

In [6]:
ds = CacheDataset(data, transform=t)
dl = DataLoader(ds)

Loading dataset: 100%|██████████| 1/1 [00:03<00:00,  3.16s/it]


## Quick training loop

In [7]:
model = UNet(
    spatial_dims=3,
    in_channels=4,
    out_channels=4,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
).to(device)
loss_function = DiceLoss(softmax=True, to_onehot_y=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-2)

max_epochs = 50
tr = trange(max_epochs)
for _ in tr:
    for batch in dl:
        inputs, labels = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
    tr.set_description(f"Loss: {loss.item():.4f}")

Loss: 0.1242: 100%|██████████| 50/50 [00:09<00:00,  5.18it/s]
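
`DiceLoss(softmax=True, to_onehot_y=True)` applies a softmax to the network output and one-hot encodes the integer labels before computing a soft Dice score per class. The numpy sketch below spells out that computation for a single image. It is a simplified stand-in (a single smoothing constant `eps`, a plain mean over classes), not a reproduction of MONAI's implementation.

```python
import numpy as np

def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    shifted = logits - logits.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def one_hot(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Turn (1, *spatial) integer labels into a (num_classes, *spatial) one-hot array."""
    classes = np.arange(num_classes).reshape(-1, *([1] * (labels.ndim - 1)))
    return (classes == labels).astype(float)

def soft_dice_loss(logits: np.ndarray, labels: np.ndarray, eps: float = 1e-5) -> float:
    """Mean over classes of 1 - soft Dice, for one (C, *spatial) prediction."""
    probs = softmax(logits, axis=0)
    target = one_hot(labels, logits.shape[0])
    spatial = tuple(range(1, logits.ndim))  # sum over the spatial axes only
    intersection = (probs * target).sum(axis=spatial)
    denominator = probs.sum(axis=spatial) + target.sum(axis=spatial)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(np.mean(1.0 - dice))

labels = np.random.default_rng(0).integers(0, 4, size=(1, 4, 4, 4))
confident = soft_dice_loss(100.0 * one_hot(labels, 4), labels)  # correct, very confident logits
uninformed = soft_dice_loss(np.zeros((4, 4, 4, 4)), labels)  # uniform class probabilities
print(confident, uninformed)
```

A confident, correct prediction drives the loss towards 0, while uniform probabilities keep it well above 0.5.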


## Post transforms

Run inference on the image (here, the same image used for training), take the argmax to get the class of each voxel, and then save the result to file.

In [8]:
out_path = tempfile.mkdtemp()

post_trans = mt.Compose([
    mt.AsDiscrete(argmax=True),
    mt.SaveImage(output_dir=out_path, output_ext=".nii", resample=False),
])
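model.eval()  # inference mode
with torch.no_grad():  # gradients are not needed for inference
    out = [
        post_trans(i)
        for batch in dl  # loop over batches first, then over the items of each batch
        for i in decollate_batch(model(batch["image"].to(device)))
    ]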

2022-06-16 10:22:00,304 INFO image_writer.py:193 - writing: /tmp/tmppc6xwd66/BRATS_001/BRATS_001_trans.nii


ignoring the tracking transform info.
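
`AsDiscrete(argmax=True)` turns the per-class scores into a single label per voxel. A numpy equivalent for one decollated item is sketched below with made-up scores; keeping a singleton channel axis reflects our understanding of `AsDiscrete`'s default behaviour and is an assumption here. Because the spatial shape is unchanged, the segmentation can reuse the affine of the network input, which is what lets it be saved in the right place.

```python
import numpy as np

# Hypothetical network output for one image: (C=4 classes, H, W, D) scores
rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 6, 6, 5))

# Highest-scoring class per voxel, keeping a singleton channel axis: (1, H, W, D)
seg = np.argmax(scores, axis=0)[np.newaxis]
print(seg.shape)
```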


## Viewing results

If you load the input image and the inferred segmentation into an external viewer, you should see that, although the two images have different sizes and voxel spacings, they are correctly aligned.

![transforms_update_meta.png](attachment:transforms_update_meta.png)
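
Beyond checking by eye, a rough numerical check is that the segmentation's world-space extent falls within the input image's. The sketch below computes the world coordinates of the corner voxel centres of each grid from its affine and shape. The affines and shapes are hypothetical; in practice you could read them from the saved files, e.g. with nibabel's `nib.load(path).affine` and `.shape`. Padding such as `DivisiblePadd` can push the segmentation slightly past the image edge, so treat a `False` result as a prompt to look closer rather than proof of misalignment.

```python
import itertools
import numpy as np

def world_extent(affine: np.ndarray, shape):
    """Min/max world coordinates of the corner voxel centres of a grid."""
    corners = np.array(list(itertools.product(*[(0, n - 1) for n in shape])), dtype=float)
    corners_h = np.c_[corners, np.ones(len(corners))]  # homogeneous coordinates
    world = (affine @ corners_h.T).T[:, :3]
    return world.min(axis=0), world.max(axis=0)

# Hypothetical grids: the input image, and a segmentation that was cropped
# and resampled to 0.5 mm in-plane before being saved
img_affine, img_shape = np.eye(4), (240, 240, 155)
seg_affine = np.array([
    [0.5, 0.0, 0.0, 40.0],
    [0.0, 0.5, 0.0, 30.0],
    [0.0, 0.0, 1.0, 5.0],
    [0.0, 0.0, 0.0, 1.0],
])
seg_shape = (288, 336, 144)

img_lo, img_hi = world_extent(img_affine, img_shape)
seg_lo, seg_hi = world_extent(seg_affine, seg_shape)
inside = bool(np.all(seg_lo >= img_lo) and np.all(seg_hi <= img_hi))
print(seg_lo, seg_hi, inside)
```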

## Cleanup data directory

Remove the directory if a temporary one was used.

In [3]:
if directory is None:
    shutil.rmtree(root_dir)