# Tutorial on "Decollate a batch"

## `What is decollate?`

`decollate batch` is a highlight feature of MONAI v0.6. It simplifies the postprocessing transforms and enables flexible downstream operations on a batch of data whose items may have different shapes.

1. As a preprocessing step in a regular PyTorch program, we usually apply transforms to each input item and then `collate` the processed items into a mini-batch (via a PyTorch `DataLoader` with a `collate_fn`). The batched data are then used in the rest of the workflow, e.g. in the model forward pass and the training loss computation:

![image](../figures/collate_batch.png)

2. As of MONAI v0.6, we recommend decollating as the first postprocessing step, to convert batched data (e.g. model predictions) into a list of tensors.
The typical `decollate batch` logic:

![image](../figures/decollate_batch.png)
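The round trip between the two figures can be sketched in plain Python. The helpers `collate` and `decollate` below are hypothetical simplifications: a real `DataLoader` stacks items into batch tensors, and MONAI's `decollate_batch` also handles tensors and metadata, but the idea of converting a list of per-item dictionaries into one batched dictionary and back is the same.

```python
from typing import Any, Dict, List


def collate(items: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Merge a list of per-item dicts into one batch-first dict (hypothetical helper)."""
    keys = items[0].keys()
    return {k: [item[k] for item in items] for k in keys}


def decollate(batch: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Split a batch-first dict back into a list of per-item dicts (hypothetical helper)."""
    keys = list(batch.keys())
    batch_size = len(batch[keys[0]])
    return [{k: batch[k][i] for k in keys} for i in range(batch_size)]


items = [{"img": [1, 2], "seg": [0, 1]}, {"img": [3, 4], "seg": [1, 1]}]
batch = collate(items)       # {'img': [[1, 2], [3, 4]], 'seg': [[0, 1], [1, 1]]}
restored = decollate(batch)  # a list of two dicts, equal to `items`
print(batch)
print(restored == items)
```

After decollating, every item is an independent dictionary again, so per-item postprocessing can be applied with a simple list comprehension.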


## `Why decollate?`
The benefits of this 'decollating' operation are:

(1) Postprocessing transforms can be executed on each item in the output mini-batch separately, so randomized transforms can apply different random behavior to each prediction independently.

(2) Both the preprocessing and postprocessing transforms only need to support `channel-first` input data, which simplifies the transform API design and reduces the input validation burden.

(3) It allows the `Invertd` transform to be applied to the predictions, and the inverted items can have different shapes because they are kept in a list rather than stacked into a single batch tensor.

(4) MONAI metrics support both `batch-first` tensors and lists of `channel-first` tensors, so metrics can be computed directly on the inverted data, even when the items differ in shape.
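Points (3) and (4) can be illustrated with a minimal plain-Python sketch that computes a per-item Dice score on two predictions of different sizes and then averages the scores. The `dice` helper is a hypothetical simplification for flattened binary masks, not MONAI's `DiceMetric`, which operates on channel-first tensors.

```python
from typing import List, Sequence


def dice(pred: Sequence[int], label: Sequence[int]) -> float:
    """Dice coefficient of two flattened binary masks of equal length (hypothetical helper)."""
    if len(pred) != len(label):
        raise ValueError("pred and label must have the same shape")
    intersection = sum(p * y for p, y in zip(pred, label))
    total = sum(pred) + sum(label)
    # define Dice as 1.0 when both masks are empty
    return 1.0 if total == 0 else 2.0 * intersection / total


# Two (prediction, label) pairs with different sizes: they cannot be stacked
# into one batch tensor without padding, but a list holds them as-is.
preds: List[List[int]] = [[1, 1, 0, 0], [1, 0, 1, 1, 0, 0]]
labels: List[List[int]] = [[1, 0, 0, 0], [1, 0, 1, 0, 0, 0]]

scores = [dice(p, y) for p, y in zip(preds, labels)]
mean_dice = sum(scores) / len(scores)
print(scores, mean_dice)
```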



## `How to decollate?`

The rest of this tutorial walks through an example program that executes a typical `collate batch` and `decollate batch` workflow.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb)

## Setup environment

```python
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
```

## Setup imports

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import shutil
import tempfile
import nibabel as nib
import numpy as np
import torch
from glob import glob

from monai.config import print_config
from monai.data import create_test_image_3d, Dataset, DataLoader, decollate_batch
from monai.handlers import from_engine
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    EnsureChannelFirstd,
    EnsureTyped,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    Orientationd,
    Resized,
    SaveImaged,
    ScaleIntensityd,
)
from monai.utils import set_determinism

print_config()
```

```
MONAI version: 0.4.0+544.g5e7345d
Numpy version: 1.21.0
Pytorch version: 1.9.0+cu102
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 5e7345d384ae08011b0e250b93f615d6d5190258

Optional dependencies:
Pytorch Ignite version: 0.4.5
Nibabel version: 3.2.1
scikit-image version: 0.15.0
Pillow version: 7.0.0
Tensorboard version: 1.15.0+nv
gdown version: 3.13.0
TorchVision version: 0.9.0a0
ITK version: 5.1.2
tqdm version: 4.53.0
lmdb version: 1.1.1
psutil version: 5.8.0
pandas version: 1.1.4
einops version: 0.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
```



## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

```python
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
```

```
/workspace/data/medical
```


## Set determinism, logging, device

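```python
set_determinism(seed=0)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# fall back to the CPU when no CUDA device is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```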

## Generate random (image, label) pairs

Generate 5 `image` and `label` pairs for this evaluation task.

```python
for i in range(5):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"img{i:d}.nii.gz"))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"seg{i:d}.nii.gz"))

images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]
```

## Setup preprocessing transforms, dataset, dataloader

```python
preprocessing = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        Orientationd(keys="img", axcodes="RAS"),
        Resized(keys="img", spatial_size=(96, 96, 96), mode="trilinear", align_corners=True),
        ScaleIntensityd(keys="img"),
        EnsureTyped(keys=["img", "seg"]),
    ]
)
dataset = Dataset(data=files, transform=preprocessing)
dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
```

## Setup postprocessing transforms, metrics
Here we invert the preprocessing transforms on the predictions in `pred` and save the results as NIfTI files.

Since all the postprocessing transforms expect `Tensor` input, `EnsureTyped` is applied first to ensure the data type after `decollate_batch`.

```python
postprocessing = Compose(
    [
        EnsureTyped(keys=["pred", "seg"]),  # ensure Tensor type after `decollate`
        Activationsd(keys="pred", sigmoid=True),
        Invertd(
            keys="pred",  # invert the `pred` data field, multiple fields are also supported
            transform=preprocessing,
            orig_keys="img",  # use the preprocessing transforms previously applied to the `img` field
            meta_keys="pred_meta_dict",  # key to store the inverted metadata, each item maps to one of `keys`
            orig_meta_keys="img_meta_dict",  # use the metadata from the `img_meta_dict` field when inverting
            nearest_interp=False,  # keep the preprocessing interpolation mode when inverting
            to_tensor=True,
            device=device,
        ),
        AsDiscreted(keys="pred", threshold_values=True),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir=root_dir, resample=False),
    ]
)
# compute mean Dice on the decollated `predictions` and `labels`, which are lists of `channel-first` tensors
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
```
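The idea behind `Invertd` can be illustrated with a toy sketch: a preprocessing transform records what it did on the item it transformed, and the inverse reads that record from the original key (here `img`) to undo the operation on a different key (here `pred`). The helpers `resize_1d`, `traced_resize`, and `invert_resize`, and the `*_trace` key, are hypothetical, operate on 1D Python lists, and are not MONAI's implementation, which keeps its own record of applied transforms.

```python
from typing import Any, Dict, List


def resize_1d(values: List[float], new_len: int) -> List[float]:
    """Nearest-neighbour resampling of a 1D list to `new_len` elements (hypothetical helper)."""
    old_len = len(values)
    return [values[i * old_len // new_len] for i in range(new_len)]


def traced_resize(item: Dict[str, Any], key: str, new_len: int) -> Dict[str, Any]:
    """Resize `item[key]` and record its original length in a hypothetical `<key>_trace` entry."""
    out = dict(item)
    out[f"{key}_trace"] = {"orig_len": len(item[key])}
    out[key] = resize_1d(item[key], new_len)
    return out


def invert_resize(item: Dict[str, Any], key: str, orig_key: str) -> Dict[str, Any]:
    """Resize `item[key]` back to the length recorded in the trace of `orig_key`."""
    out = dict(item)
    orig_len = item[f"{orig_key}_trace"]["orig_len"]
    out[key] = resize_1d(item[key], orig_len)
    return out


# Two decollated items whose original sizes differ.
items = [{"img": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}, {"img": [0.0, 1.0, 2.0]}]
items = [traced_resize(d, "img", 4) for d in items]  # both become length 4
for d in items:
    d["pred"] = [v * 10 for v in d["img"]]  # stand-in for a model prediction
items = [invert_resize(d, "pred", orig_key="img") for d in items]
print([len(d["pred"]) for d in items])
```

Because the inverted predictions regain their original, different lengths, they can only be held together in a list, which is exactly what decollating provides.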

## Execute the evaluation process with all the above components
Here we use a randomly initialized `UNet` for evaluation; in real-world practice, pretrained weights are usually loaded.

```python
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

model.eval()
with torch.no_grad():
    for data in dataloader:
        images, labels = data["img"].to(device), data["seg"].to(device)
        # define the sliding window size and batch size for window inference
        roi_size = (64, 64, 64)
        sw_batch_size = 4
        data["pred"] = sliding_window_inference(images, roi_size, sw_batch_size, model)
        data["seg"] = labels

        # decollate the batch data into a list of dictionaries, each dictionary maps to one input item
        data = [postprocessing(i) for i in decollate_batch(data)]
        # extract a list of `predictions` and a list of `labels` with the `from_engine` utility
        pred, y = from_engine(["pred", "seg"])(data)
        # compute mean Dice for the current iteration
        dice_metric(y_pred=pred, y=y)
    # aggregate the final mean Dice result
    print(f"evaluation metric: {dice_metric.aggregate().item()}")
    # reset the metric status
    dice_metric.reset()
```
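After decollating, `from_engine(["pred", "seg"])(data)` turns the list of per-item dictionaries into one list per key. A rough plain-Python equivalent of that reshaping is sketched below; `extract_keys` is a hypothetical helper, and MONAI's `from_engine` may handle additional input forms, so treat this only as an illustration.

```python
from typing import Any, Dict, List, Sequence, Tuple


def extract_keys(keys: Sequence[str], data: List[Dict[str, Any]]) -> Tuple[List[Any], ...]:
    """Collect one list per key from a list of per-item dicts (hypothetical helper)."""
    return tuple([item[k] for item in data] for k in keys)


decollated = [
    {"pred": "pred_0", "seg": "seg_0"},
    {"pred": "pred_1", "seg": "seg_1"},
]
pred, y = extract_keys(["pred", "seg"], decollated)
print(pred)  # ['pred_0', 'pred_1']
print(y)     # ['seg_0', 'seg_1']
```

The two resulting lists line up item by item, which is the list-of-`channel-first`-tensors form that `DiceMetric` accepts.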

## Cleanup data directory

Remove the directory if a temporary one was used.

```python
if directory is None:
    shutil.rmtree(root_dir)
```