# How to set `batch_transform` and `output_transform`

MONAI defines many ignite-style event handlers that provide general-purpose logic during training or evaluation, such as stats logging, MLflow tracking, TensorBoard logging, and metrics.

Usually, a handler has a callable argument named `batch_transform` or `output_transform` that prepares the expected data from ignite's `engine.state.batch` or `engine.state.output`. This tutorial shows how to set `batch_transform` and `output_transform` for different kinds of handlers and different data shapes.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/modules/batch_output_transform.ipynb)

## Setup environment

```python
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm, ignite]"
```

Note: you may need to restart the kernel to use updated packages.


## Setup imports

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.metrics import Accuracy

from monai.data import CacheDataset, create_test_image_3d, DataLoader
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
    CheckpointSaver,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
    from_engine,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    EnsureChannelFirstd,
    AsDiscreted,
    Compose,
    KeepLargestConnectedComponentd,
    LoadImaged,
    RandCropByPosNegLabeld,
    ScaleIntensityd,
    EnsureTyped,
)
from monai.utils import get_torch_version_tuple

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
```

## Data shape in `engine.state.batch` and `engine.state.output`

All MONAI engines and handlers are built on **PyTorch Ignite** concepts, which define the engine workflow with a `State` object that holds the data context and an event-handler mechanism for attaching and triggering components. For more details about `engine.state` and the batch / output transforms, please refer to: https://pytorch.org/ignite/concepts.html#state.
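To make the pattern concrete, here is a toy, pure-Python model of an engine that stores each iteration's data on a `state` object and fires an event that handlers listen to. The names `ToyEngine`, `ToyState`, `ToyLossLogger` and the `"ITERATION_COMPLETED"` string are hypothetical; this is not ignite's implementation, only a sketch of where a handler applies its `output_transform`.

```python
# Toy model of the State / event-handler pattern (hypothetical, NOT the ignite API).
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class ToyState:
    """Per-iteration data context, playing the role of `engine.state`."""
    iteration: int = 0
    batch: Any = None
    output: Any = None


class ToyEngine:
    def __init__(self, process_fn: Callable[[Any], Any]):
        self.process_fn = process_fn
        self.state = ToyState()
        self._handlers: Dict[str, List[Callable[["ToyEngine"], None]]] = {}

    def add_event_handler(self, event: str, handler: Callable[["ToyEngine"], None]) -> None:
        self._handlers.setdefault(event, []).append(handler)

    def run(self, data) -> None:
        for batch in data:
            self.state.iteration += 1
            self.state.batch = batch
            self.state.output = self.process_fn(batch)
            # trigger every handler attached to this event
            for handler in self._handlers.get("ITERATION_COMPLETED", []):
                handler(self)


class ToyLossLogger:
    """Handler that uses `output_transform` to pick what it needs from `engine.state.output`."""

    def __init__(self, output_transform: Callable[[Any], Any] = lambda x: x):
        self.output_transform = output_transform
        self.logged: List[Any] = []

    def attach(self, engine: ToyEngine) -> None:
        engine.add_event_handler("ITERATION_COMPLETED", self)

    def __call__(self, engine: ToyEngine) -> None:
        value = self.output_transform(engine.state.output)
        if value is not None:  # a transform returning None skips logging for this iteration
            self.logged.append((engine.state.iteration, value))


# the "model" returns a dict output with a scalar loss (the mean of the batch values)
engine = ToyEngine(process_fn=lambda batch: {"pred": batch, "loss": sum(batch) / len(batch)})
logger = ToyLossLogger(output_transform=lambda x: x["loss"])
silent = ToyLossLogger(output_transform=lambda x: None)
logger.attach(engine)
silent.attach(engine)
engine.run([[1.0, 3.0], [2.0, 4.0]])
print(logger.logged)  # [(1, 2.0), (2, 3.0)]
print(silent.logged)  # []
```

The engine never needs to know what a handler wants; each handler's transform selects its own view of the shared state, which is the same division of labor the MONAI handlers below rely on.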

First of all, let's take a look at the possible data shape in `engine.state.batch` and `engine.state.output`.

### engine.state.batch
(1) In a typical ignite program, `batch` is one item yielded by the PyTorch DataLoader, for example `{"image": MetaTensor, "label": MetaTensor, "image_meta_dict": Dict}`, where `image` and `label` are batch-first arrays and `image_meta_dict` is a dictionary of meta information for the input images in which every value is batched:
```
image.shape = [2, 4, 64, 64, 64]  # here 2 is batch size, 4 is channels
label.shape = [2, 3, 64, 64, 64]
image_meta_dict = {"filename_or_obj": ["/data/image1.nii", "/data/image2.nii"]}
```

(2) MONAI engines automatically `decollate` the batch data into a list of channel-first items after every iteration. For more details about `decollate`, please refer to: https://github.com/Project-MONAI/tutorials/blob/main/modules/decollate_batch.ipynb.

The `engine.state.batch` example in (1) is decollated into a list of dictionaries:
`[{"image": MetaTensor, "label": MetaTensor, "image_meta_dict": Dict}, {"image": MetaTensor, "label": MetaTensor, "image_meta_dict": Dict}]`.

Each item of the list looks like:
```
image.shape = [4, 64, 64, 64]  # channel-first array for one image (4 channels)
label.shape = [3, 64, 64, 64]
image_meta_dict = {"filename_or_obj": "/data/image1.nii"}
```

### engine.state.output
(1) In a typical ignite program, `output` is the output data of the current iteration, for example `{"pred": MetaTensor, "label": MetaTensor, "loss": scalar}`, where `pred` and `label` are batch-first arrays and `loss` is the scalar loss value of the current iteration:
```
pred.shape = [2, 3, 64, 64, 64]  # here 2 is batch size, 3 is channels
label.shape = [2, 3, 64, 64, 64]
loss = 0.4534
```

(2) MONAI engines also automatically `decollate` the output data into a list of channel-first items after every iteration.
The `engine.state.output` example in (1) is decollated into a list of dictionaries:
`[{"pred": MetaTensor, "label": MetaTensor, "loss": 0.4534}, {"pred": MetaTensor, "label": MetaTensor, "loss": 0.4534}]`. Note that the scalar `loss` value is replicated into every item of the decollated list.
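The decollation described in both subsections can be sketched in plain Python. This is a simplified illustration, not `monai.data.decollate_batch`: it assumes batched values are Python lists whose first dimension is the batch size (standing in for tensors), splits nested dictionaries such as the meta dict recursively, and replicates scalars such as `loss` into every item. The function name `decollate_sketch` is hypothetical.

```python
# Simplified decollation sketch (hypothetical helper, NOT monai.data.decollate_batch).
from typing import Any, Dict, List


def decollate_sketch(batch: Dict[str, Any], batch_size: int) -> List[Dict[str, Any]]:
    """Split a batch-first dict into one dict per sample.

    - lists are indexed along their first (batch) dimension
    - nested dicts (e.g. meta dicts) are decollated recursively
    - any other value (e.g. a scalar loss) is replicated into every item
    """
    items: List[Dict[str, Any]] = [{} for _ in range(batch_size)]
    for key, value in batch.items():
        if isinstance(value, dict):
            per_item = decollate_sketch(value, batch_size)
        elif isinstance(value, list):
            if len(value) != batch_size:
                raise ValueError(f"{key!r} has {len(value)} entries, expected {batch_size}")
            per_item = value
        else:
            per_item = [value] * batch_size
        for i in range(batch_size):
            items[i][key] = per_item[i]
    return items


batch = {
    "image": ["image1-data", "image2-data"],  # stand-ins for channel-first arrays
    "image_meta_dict": {"filename_or_obj": ["/data/image1.nii", "/data/image2.nii"]},
}
output = {"pred": ["pred1", "pred2"], "label": ["label1", "label2"], "loss": 0.4534}

print(decollate_sketch(batch, batch_size=2)[0])
# {'image': 'image1-data', 'image_meta_dict': {'filename_or_obj': '/data/image1.nii'}}
print(decollate_sketch(output, batch_size=2)[1])
# {'pred': 'pred2', 'label': 'label2', 'loss': 0.4534}
```

Because the scalar is copied into every item, reading `loss` from the first item gives the same value as reading it from any other item.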

## Define `batch_transform` and `output_transform` to extract data

Now let's analyze the cases of extracting data from `engine.state.batch` or `engine.state.output`. To simplify this, MONAI provides a utility function, `monai.handlers.from_engine`, that handles all the common cases automatically.

(1) To get the meta data from a dictionary-format `engine.state.batch`, set `batch_transform=lambda x: x["image_meta_dict"]` (or `batch_transform=lambda x: x["image"].meta` to read the metadata attached to the `MetaTensor` itself).

(2) To get the meta data from a decollated list of dictionaries in `engine.state.batch`, set `batch_transform=lambda x: [i["image_meta_dict"] for i in x]` or `batch_transform=from_engine("image_meta_dict")`.

(3) Metrics usually expect a `(pred, label)` tuple as input. If `engine.state.output` is a dictionary, set `output_transform=lambda x: (x["pred"], x["label"])`. If it is a decollated list, set `output_transform=lambda x: ([i["pred"] for i in x], [i["label"] for i in x])` or `output_transform=from_engine(["pred", "label"])`.

(4) To get a scalar value such as `loss`: if `engine.state.output` is a dictionary, set `output_transform=lambda x: x["loss"]`. If it is a decollated list, set `output_transform=lambda x: x[0]["loss"]` or `output_transform=from_engine(["loss"], first=True)`. Here `first=True` is needed so that `from_engine` returns the scalar from the first item only, instead of a list holding one identical copy per item.
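The four cases above can be checked with a small stand-in for `from_engine`. This is a simplified sketch written for illustration, not MONAI's source code: `from_engine_sketch` is a hypothetical name, and it assumes that a single key yields a single value while several keys yield a tuple. The real `monai.handlers.from_engine` may differ in details.

```python
# Simplified stand-in for monai.handlers.from_engine (hypothetical, NOT MONAI's source).
from typing import Any, Callable, Sequence, Union


def from_engine_sketch(keys: Union[str, Sequence[str]], first: bool = False) -> Callable[[Any], Any]:
    keys = (keys,) if isinstance(keys, str) else tuple(keys)

    def _wrapper(data: Any) -> Any:
        if isinstance(data, dict):
            # dictionary format: one value per key
            ret = [data[k] for k in keys]
        elif isinstance(data, list) and data and isinstance(data[0], dict):
            # decollated list: gather each key across all items, or take it from the first item only
            ret = [data[0][k] if first else [item[k] for item in data] for k in keys]
        else:
            raise TypeError(f"unsupported data type: {type(data).__name__}")
        return tuple(ret) if len(ret) > 1 else ret[0]

    return _wrapper


decollated = [
    {"pred": "p1", "label": "l1", "loss": 0.4534},
    {"pred": "p2", "label": "l2", "loss": 0.4534},
]
preds, labels = from_engine_sketch(["pred", "label"])(decollated)
print(preds, labels)  # ['p1', 'p2'] ['l1', 'l2']
print(from_engine_sketch(["loss"], first=True)(decollated))  # 0.4534
print(from_engine_sketch(["loss"])(decollated))  # [0.4534, 0.4534]
print(from_engine_sketch(["pred", "label"])({"pred": "P", "label": "L"}))  # ('P', 'L')
```

The last two calls show why case (4) sets `first=True`: without it, the replicated loss comes back as a list rather than the scalar that logging handlers expect.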

## Setup data directory

Now let's set up a complete demo program with MONAI engines and show how to set `batch_transform` and `output_transform` in the handlers.

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

```python
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")
```

root dir is: /workspace/data/medical


## Prepare synthetic data for testing

Here we generate 40 (image, label) pairs: 20 for training and 20 for validation.

```python
for i in range(40):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"img{i:d}.nii.gz"))
    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"seg{i:d}.nii.gz"))

images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
train_files = [{"image": img, "label": seg} for img, seg in zip(images[:20], segs[:20])]
val_files = [{"image": img, "label": seg} for img, seg in zip(images[-20:], segs[-20:])]
```

## Setup train / val transforms, dataset, DataLoader, post transforms

```python
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
        ),
        EnsureTyped(keys=["image", "label"]),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"], channel_dim=-1),
        ScaleIntensityd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

# create a training data loader
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
# use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
# create a validation data loader
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

val_post_transforms = Compose(
    [
        EnsureTyped(keys="pred"),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
        KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
    ]
)
```

## Setup network, optimizer, loss function, etc.

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
loss = DiceLoss(sigmoid=True)
opt = torch.optim.Adam(net.parameters(), 3e-4)
```

## Setup `handlers` and `metrics` for train and validation

This is the main topic of this tutorial: set the `batch_transform` and `output_transform` arguments so that each handler and metric extracts the data it expects.

```python
val_handlers = [
    # disable per-iteration logging during validation; epoch-level metrics are still logged
    StatsHandler(output_transform=lambda x: None),
    TensorBoardStatsHandler(log_dir="./runs/", output_transform=lambda x: None),
    TensorBoardImageHandler(
        log_dir="./runs/",
        # pick image and label from the decollated batch, and pred from the decollated output
        batch_transform=from_engine(["image", "label"]),
        output_transform=from_engine(["pred"]),
    ),
    CheckpointSaver(save_dir="./runs/", save_dict={"net": net}, save_key_metric=True),
]

# enable AMP only on a CUDA device with PyTorch >= 1.6
use_amp = device.type == "cuda" and get_torch_version_tuple() >= (1, 6)

evaluator = SupervisedEvaluator(
    device=device,
    val_data_loader=val_loader,
    network=net,
    inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5),
    postprocessing=val_post_transforms,
    # metrics expect a (pred, label) tuple built from the decollated output
    key_val_metric={
        "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"]))
    },
    additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
    val_handlers=val_handlers,
    amp=use_amp,
)

train_handlers = [
    ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
    # loss is replicated in every decollated item, so take it from the first one only
    StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
    TensorBoardStatsHandler(
        log_dir="./runs/",
        tag_name="train_loss",
        output_transform=from_engine(["loss"], first=True),
    ),
]

trainer = SupervisedTrainer(
    device=device,
    max_epochs=5,
    train_data_loader=train_loader,
    network=net,
    optimizer=opt,
    loss_function=loss,
    inferer=SimpleInferer(),
    train_handlers=train_handlers,
    amp=use_amp,
)
```

## Execute training to verify the settings

```python
trainer.run()
```

## Cleanup data directory

Remove the directory if a temporary one was used.

```python
if directory is None:
    shutil.rmtree(root_dir)
```