Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# MONAI 201 Tutorial: Advanced Training Techniques

Welcome to MONAI 201! This tutorial builds upon [MONAI 101](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and introduces advanced training techniques and best practices for production-ready medical AI models.

## What You'll Learn

This intermediate tutorial covers advanced concepts that are essential for building robust medical AI systems:

- **Advanced Training Workflow**: Enhanced training with validation monitoring
- **Model Evaluation**: Comprehensive evaluation using `SupervisedEvaluator`
- **Experiment Tracking**: TensorBoard integration for training visualization
- **Model Checkpointing**: Save and restore model states during training
- **Production Best Practices**: Techniques used in real-world medical AI applications

## Prerequisites

- Complete [MONAI 101](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) or have basic MONAI knowledge
- Understanding of deep learning concepts (training, validation, etc.)
- Familiarity with PyTorch basics

## Requirements

- **GPU Memory**: Approximately 7GB
- **Runtime**: About 10 minutes
- **Level**: Intermediate

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/2d_classification/monai_201.ipynb)

## Setup environment

```python
!python -c "import monai" || pip install -q "monai-weekly[ignite, tqdm, tensorboard]"
```

## Setup imports

```python
import logging
import numpy as np
import os
from pathlib import Path
import sys
import tempfile
import torch
import ignite

from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import (
    StatsHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
    CheckpointSaver,
    CheckpointLoader,
    ClassificationSaver,
)
from monai.handlers.utils import from_engine
from monai.inferers import SimpleInferer
from monai.networks.nets import densenet121
from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose, AsDiscreted

print_config()
```

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. This allows you to save results and reuse downloads. If it is not specified, a temporary directory will be used.

```python
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is not None:
    os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
```

```text
/workspace/Data
```


## Prepare Data with MONAI Transforms

We'll prepare our data with the same transforms as in MONAI 101, but this time we'll also create a validation dataset. Keeping the validation data separate from the training data is what lets us monitor training progress and detect overfitting.
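For intuition, here is a minimal NumPy sketch of what the last two transforms in the pipeline do to a single-channel 2D image. It assumes that the loaded image has shape `(H, W)` with no channel axis, and that `ScaleIntensityD` is used with its default `[0, 1]` output range. The helpers `ensure_channel_first_2d` and `scale_intensity` are hypothetical names for illustration, not MONAI functions.

```python
import numpy as np


def ensure_channel_first_2d(img: np.ndarray) -> np.ndarray:
    """Add a leading channel axis: (H, W) -> (1, H, W). Assumes a single-channel 2D image."""
    return img[np.newaxis, ...]


def scale_intensity(img: np.ndarray, minv: float = 0.0, maxv: float = 1.0) -> np.ndarray:
    """Min-max rescale the whole array into [minv, maxv]."""
    lo, hi = float(img.min()), float(img.max())
    if hi == lo:
        # constant image: this sketch maps every pixel to minv
        return np.full(img.shape, minv, dtype=np.float32)
    norm = (img.astype(np.float32) - lo) / (hi - lo)  # now in [0, 1]
    return norm * (maxv - minv) + minv


img = np.array([[0, 128], [64, 255]], dtype=np.uint8)  # toy 2x2 grayscale "image"
out = scale_intensity(ensure_channel_first_2d(img))
print(out.shape)  # (1, 2, 2)
print(out.min(), out.max())  # 0.0 1.0
```

The channel axis matters because the network is built with `in_channels=1` and expects channel-first input.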

```python
transform = Compose(
    [
        LoadImageD(keys="image", image_only=True),
        EnsureChannelFirstD(keys="image"),
        ScaleIntensityD(keys="image"),
    ]
)

# If you use the MedNIST dataset, please acknowledge the source.
dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section="training", download=True)
valdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="validation", download=False)
```

```text
2024-02-27 08:31:31,955 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2024-02-27 08:31:31,955 - INFO - File exists: /workspace/Data/MedNIST.tar.gz, skipped downloading.
2024-02-27 08:31:31,956 - INFO - Non-empty folder exists in /workspace/Data/MedNIST, skipped extracting.
Loading dataset: 100%|██████████| 47164/47164 [00:19<00:00, 2393.21it/s]
Loading dataset: 100%|██████████| 5895/5895 [00:02<00:00, 2465.05it/s]
```
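The progress bars report 47,164 training and 5,895 validation images. Those counts are consistent with an 80/10/10 split of a 58,954-image pool, which, to our understanding, is what `MedNISTDataset` produces with its default `val_frac=0.1` and `test_frac=0.1` and a fixed seed. The plain-Python sketch below illustrates the idea of a seeded, disjoint three-way split. `split_indices` is a hypothetical helper; MONAI's own shuffling (and therefore which images land in which section) differs.

```python
import random


def split_indices(n: int, val_frac: float = 0.1, test_frac: float = 0.1, seed: int = 0):
    """Shuffle 0..n-1 with a fixed seed and cut it into disjoint train/val/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same split on every run
    n_test = int(n * test_frac)
    n_val = int(n * val_frac)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]
    return train, val, test


train, val, test = split_indices(58954)
print(len(train), len(val), len(test))  # 47164 5895 5895
```

Because the split is derived from a fixed seed, separate `MedNISTDataset` instances (as with `dataset` and `valdata` above) can select non-overlapping sections without sharing state.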


## Advanced Training Setup with Evaluation and Monitoring

Next, we build a training setup that adds periodic validation, experiment tracking, and checkpointing, patterns that are common in production medical AI pipelines.

### Key Components

1. **`SupervisedEvaluator`**: Runs the model on the validation set and computes the validation accuracy
2. **`TensorBoardStatsHandler`**: Logs the training loss and validation metrics for visualization in TensorBoard
3. **`CheckpointSaver`**: Saves model checkpoints at a fixed epoch interval and at the end of training
4. **`ValidationHandler`**: Triggers the evaluator from the trainer at a specified interval (here, every epoch)

Together, these components let you follow the model's progress during training and catch problems such as overfitting (training loss keeps falling while validation accuracy stalls or drops) early.
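The evaluator in the next cell scores the model with `ignite.metrics.Accuracy`. For multi-class outputs, that metric takes the arg-max of each prediction's class scores and reports the fraction that matches the labels. The plain-Python sketch below reproduces that computation on made-up scores. `multiclass_accuracy` is a hypothetical helper, not part of Ignite or MONAI.

```python
from typing import Sequence


def multiclass_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # arg-max over class scores
        correct += int(predicted == label)
    return correct / len(labels)


# made-up scores for 4 samples and 3 classes
scores = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.8, 3.0], [1.0, 0.2, 0.1]]
labels = [0, 1, 1, 0]
print(multiclass_accuracy(scores, labels))  # 0.75
```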

```python
max_epochs = 5
save_interval = 2  # save an intermediate checkpoint every 2 epochs
out_dir = "./eval"  # checkpoints are written here
model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to("cuda:0")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Evaluator: computes the validation accuracy ("val_acc") on the validation split
evaluator = SupervisedEvaluator(
    device=torch.device("cuda:0"),
    val_data_loader=DataLoader(valdata, batch_size=512, shuffle=False, num_workers=4),
    network=model,
    inferer=SimpleInferer(),
    key_val_metric={"val_acc": ignite.metrics.Accuracy(from_engine(["pred", "label"]))},
    # log epoch-level metrics only, to stdout and to TensorBoard
    val_handlers=[StatsHandler(iteration_log=False), TensorBoardStatsHandler(iteration_log=False)],
)

trainer = SupervisedTrainer(
    device=torch.device("cuda:0"),
    max_epochs=max_epochs,
    train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),
    network=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),
    loss_function=torch.nn.CrossEntropyLoss(),
    inferer=SimpleInferer(),
    train_handlers=[
        # run the evaluator at the end of every epoch
        ValidationHandler(validator=evaluator, epoch_level=True, interval=1),
        # save the model every `save_interval` epochs, plus a final "checkpoint.pt"
        CheckpointSaver(
            save_dir=out_dir,
            save_dict={"model": model},
            save_interval=save_interval,
            save_final=True,
            final_filename="checkpoint.pt",
        ),
        # log training progress to stdout
        StatsHandler(),
        # log the per-iteration training loss to TensorBoard under the tag "train_loss"
        TensorBoardStatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True)),
    ],
)
```
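With `max_epochs = 5`, `save_interval = 2` and `interval=1` on the `ValidationHandler`, it can help to spell out when each handler acts. The sketch below simply enumerates that schedule in plain Python. It assumes that handlers registered for the same epoch-end event fire in the order they appear in `train_handlers`, and that the final checkpoint is written once training completes. `handler_schedule` is a hypothetical helper, not a MONAI API.

```python
def handler_schedule(max_epochs: int, val_interval: int, save_interval: int, final_filename: str = "checkpoint.pt"):
    """List (epoch, action) pairs in the order the handlers are assumed to fire."""
    events = []
    for epoch in range(1, max_epochs + 1):
        if epoch % val_interval == 0:
            events.append((epoch, "validate"))  # ValidationHandler(interval=val_interval)
        if epoch % save_interval == 0:
            events.append((epoch, "save interval checkpoint"))  # CheckpointSaver(save_interval=...)
    events.append((max_epochs, f"save final {final_filename}"))  # CheckpointSaver(save_final=True)
    return events


for epoch, action in handler_schedule(max_epochs=5, val_interval=1, save_interval=2):
    print(epoch, action)
```

Under these assumptions, validation runs after every epoch, interval checkpoints are written after epochs 2 and 4, and `checkpoint.pt` is written after epoch 5. That final file is the one loaded for inference later.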

## Run the training

```python
trainer.run()
```
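After `trainer.run()` finishes, the evaluator has logged `val_acc` once per epoch. One simple way to use those values is to find the best epoch and count how many epochs have passed since it. A growing count while the training loss keeps falling is a typical sign of overfitting. The sketch below uses made-up accuracy values; `best_epoch` and `patience` are hypothetical names for illustration.

```python
from typing import Sequence, Tuple


def best_epoch(val_acc: Sequence[float]) -> Tuple[int, float, int]:
    """Return (best_epoch, best_value, epochs_since_best); epochs are 1-based, ties keep the earliest."""
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    return best_idx + 1, val_acc[best_idx], len(val_acc) - 1 - best_idx


patience = 2  # hypothetical: how many epochs without improvement we tolerate
val_acc = [0.80, 0.88, 0.91, 0.90, 0.89]  # made-up values, one per epoch
epoch, value, stalled = best_epoch(val_acc)
print(f"best val_acc {value:.2f} at epoch {epoch}; {stalled} epoch(s) without improvement")
if stalled >= patience:
    print("validation accuracy has stopped improving: consider stopping or regularizing")
```

Ignite's `EarlyStopping` handler (in `ignite.handlers`) automates a patience-based rule of this kind.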

## Visualize Training Progress with TensorBoard

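TensorBoard displays the logged metrics as interactive plots. With the handlers configured above, you can view:
- the per-iteration training loss (tag `train_loss`)
- the validation accuracy after each epoch (tag `val_acc`)

TensorBoard can also display learning-rate schedules and model graphs, but this notebook does not log them.

To view the results, uncomment and run the following cell. In Jupyter or Colab, the `%tensorboard` magic displays TensorBoard inline in the notebook. The handlers write their event files to `./runs` by default.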

```python
# %load_ext tensorboard
# %tensorboard --logdir ./runs
```

## Inference

First, prepare the test dataset and collect the class names (the sorted names of the class folders):

```python
dataset_dir = Path(root_dir, "MedNIST")
class_names = sorted(x.name for x in dataset_dir.iterdir() if x.is_dir())
testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section="test", download=False, runtime_cache=True)
```
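To our understanding, `MedNISTDataset` assigns its integer labels in the same sorted folder order, so `class_names[i]` is the name of class `i`. Python sorts strings by code point, which places uppercase letters before lowercase ones: `CXR` therefore sorts before `ChestCT`. The self-contained sketch below reproduces the `class_names` expression on a temporary directory containing the six MedNIST class folders. `list_class_names` is a hypothetical wrapper.

```python
import tempfile
from pathlib import Path


def list_class_names(dataset_dir: Path) -> list:
    """Same expression as the cell above: sorted names of the sub-directories."""
    return sorted(x.name for x in dataset_dir.iterdir() if x.is_dir())


with tempfile.TemporaryDirectory() as tmp:
    demo_root = Path(tmp)
    for name in ["HeadCT", "ChestCT", "Hand", "CXR", "BreastMRI", "AbdomenCT"]:
        (demo_root / name).mkdir()
    (demo_root / "notes.txt").write_text("not a class")  # plain files are skipped by is_dir()
    demo_names = list_class_names(demo_root)

print(demo_names)  # ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
```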

Next, we set up a new `SupervisedEvaluator` that runs the trained model over every image in the test split and writes the predictions to a CSV file. Two validation handlers (`val_handlers`) do the work around inference: `CheckpointLoader` loads the saved weights from `checkpoint.pt` before the run starts (an error is raised if the file does not exist), and `ClassificationSaver` saves the predicted classes.
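If the training cells were skipped, the evaluator only fails once it starts. A small guard can check for the checkpoint earlier and produce a clearer message. The sketch below is a hypothetical helper built on `pathlib`, not part of MONAI. It is demonstrated on a temporary directory with an empty stand-in file.

```python
import tempfile
from pathlib import Path


def require_checkpoint(path) -> Path:
    """Raise a descriptive error if the checkpoint file is missing (hypothetical helper)."""
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(f"checkpoint not found: {p} (run the training cells first)")
    return p


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.pt"
    try:
        require_checkpoint(ckpt)
    except FileNotFoundError as err:
        print(err)
    ckpt.write_bytes(b"")  # empty stand-in for a real checkpoint file
    print(require_checkpoint(ckpt).name)  # checkpoint.pt
```

In this notebook, it would be called as `require_checkpoint(f"{out_dir}/checkpoint.pt")` just before building the evaluator.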

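```python
evaluator = SupervisedEvaluator(
    device=torch.device("cuda:0"),
    # batch_size=1: the saver below reads metadata from the first item of each batch only
    val_data_loader=DataLoader(testdata, batch_size=1, num_workers=0),
    network=model,
    inferer=SimpleInferer(),
    # convert the 6 class scores into a single predicted class index
    postprocessing=AsDiscreted(keys="pred", argmax=True),
    val_handlers=[
        # restore the final trained weights into `model` before inference starts
        CheckpointLoader(load_path=f"{out_dir}/checkpoint.pt", load_dict={"model": model}),
        # write one "<image file name>,<predicted class index>" row per image
        ClassificationSaver(
            batch_transform=lambda batch: batch[0]["image"].meta, output_transform=from_engine(["pred"])
        ),
    ],
)

evaluator.run()
```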

```text
INFO:ignite.engine.engine.SupervisedEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs
INFO:ignite.engine.engine.SupervisedEvaluator:Restored all variables from ./eval/checkpoint.pt
INFO:ignite.engine.engine.SupervisedEvaluator:Epoch[1] Complete. Time taken: 00:01:24.338
INFO:ignite.engine.engine.SupervisedEvaluator:Engine run complete. Time taken: 00:01:24.390
```


By default, `ClassificationSaver` writes the results to `predictions.csv` in the current working directory. Both the directory and the file name can be changed with its `output_dir` and `filename` arguments.
Each row contains the image file name followed by the predicted class index, written as a float (for example `3.0`). To make the output readable, use these indices to look up the corresponding entries in `class_names`.

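```python
import csv
from itertools import islice

max_items_to_print = 10
with open("./predictions.csv", newline="") as f:
    # each row: image file name, predicted class index (stored as a float, e.g. "3.0")
    for fn, idx in islice(csv.reader(f), max_items_to_print):
        print(fn, class_names[int(float(idx))])
```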

```text
/workspace/Data/MedNIST/AbdomenCT/006070.jpeg AbdomenCT
/workspace/Data/MedNIST/BreastMRI/006574.jpeg BreastMRI
/workspace/Data/MedNIST/ChestCT/009858.jpeg ChestCT
/workspace/Data/MedNIST/CXR/007398.jpeg CXR
/workspace/Data/MedNIST/Hand/005663.jpeg Hand
/workspace/Data/MedNIST/HeadCT/006896.jpeg HeadCT
/workspace/Data/MedNIST/HeadCT/007179.jpeg HeadCT
/workspace/Data/MedNIST/CXR/001190.jpeg CXR
/workspace/Data/MedNIST/ChestCT/005138.jpeg ChestCT
/workspace/Data/MedNIST/BreastMRI/000023.jpeg BreastMRI
```
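As the output shows, each MedNIST image sits in a folder named after its true class. You can therefore compute test accuracy directly from `predictions.csv` by comparing each predicted name with the parent folder name. The sketch below assumes the two-column layout described above (file name, then a float class index, no header row). `accuracy_from_predictions` is a hypothetical helper, and the toy rows are made up, with the last one deliberately wrong.

```python
import csv
import tempfile
from pathlib import Path
from typing import Sequence


def accuracy_from_predictions(csv_path, class_names: Sequence[str]) -> float:
    """Compare each predicted class with the class folder in the image path."""
    correct = total = 0
    with open(csv_path, newline="") as f:
        for row in csv.reader(f):
            if not row:
                continue  # skip blank lines
            filename, idx = row[0], row[1]
            predicted = class_names[int(float(idx))]  # indices are written as floats, e.g. "2.0"
            true_label = Path(filename).parent.name  # e.g. .../MedNIST/CXR/007398.jpeg -> "CXR"
            correct += int(predicted == true_label)
            total += 1
    return correct / total if total else float("nan")


names = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
toy_rows = [  # made-up rows in the same "file name, index" layout
    "/data/MedNIST/AbdomenCT/a.jpeg,0.0",
    "/data/MedNIST/CXR/b.jpeg,2.0",
    "/data/MedNIST/HeadCT/c.jpeg,5.0",
    "/data/MedNIST/Hand/d.jpeg,5.0",  # deliberately wrong: predicted HeadCT
]
with tempfile.TemporaryDirectory() as tmp:
    csv_file = Path(tmp) / "predictions.csv"
    csv_file.write_text("\n".join(toy_rows) + "\n")
    print(accuracy_from_predictions(csv_file, names))  # 0.75
```

In the notebook, `accuracy_from_predictions("./predictions.csv", class_names)` would report the accuracy over the whole test split.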
