Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# PyTorch 2.0 Integration

This notebook shows how to use `torch.compile` in a MONAI pipeline. It covers the following parts:
- What is `torch.compile`?

    `torch.compile` is the main new API in PyTorch 2.0: it wraps your model and returns a compiled version of it. It is a fully additive (and optional) feature, so PyTorch 2.0 is 100% backward compatible by definition.

- A simple demo showing how to use `torch.compile`.

- Using `torch.compile` in a bundle.

- Comparison of results

    We ran an end-to-end pipeline based on ["fast_training_tutorial.ipynb"](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb) and observed a 10% speedup with `torch.compile`.

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, matplotlib]"
!pip install -q "torch>=2.1.0"

## Setup imports

In [None]:
import os
import time
import torch
import tempfile

import monai
import monai.transforms as mt
from monai.config import print_config
from monai.utils import set_determinism
from monai.bundle import download, create_workflow
from monai.engines import SupervisedTrainer

print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [2]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/workspace/data


## A simple demo showing how to use `torch.compile`

In [3]:
sample_url = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases"
sample_url += "/download/0.8.1/totalSegmentator_mergedLabel_samples.zip"
monai.apps.download_and_extract(sample_url, output_dir=root_dir, filepath="samples.zip")

base_name = os.path.join(root_dir, "totalSegmentator_mergedLabel_samples")
input_data = []
for filename in os.listdir(os.path.join(base_name, "imagesTr")):
    input_data.append(
        {
            "image": os.path.join(base_name, "imagesTr", filename),
            "label": os.path.join(base_name, "labelsTr", filename),
        }
    )

2024-01-09 09:01:27,769 - INFO - Expected md5 is None, skip md5 check for file samples.zip.
2024-01-09 09:01:27,769 - INFO - File exists: samples.zip, skipped downloading.
2024-01-09 09:01:27,771 - INFO - Writing into directory: /workspace/data.
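`os.listdir` returns entries in arbitrary order, so the order of `input_data` (and therefore which random transform is applied to which sample) is not guaranteed to be the same across file systems, even after calling `set_determinism`. Below is a minimal sketch of the same pairing step that sorts filenames and skips images whose label file is missing. `build_datalist` is a hypothetical helper, not a MONAI API, and silently skipping unlabeled images is an assumed policy you may prefer to replace with an error.

```python
import os
from typing import Dict, List


def build_datalist(base_name: str) -> List[Dict[str, str]]:
    """Pair every image in ``imagesTr`` with the same-named file in ``labelsTr``.

    Filenames are sorted so the list order does not depend on ``os.listdir``.
    """
    image_dir = os.path.join(base_name, "imagesTr")
    label_dir = os.path.join(base_name, "labelsTr")
    datalist = []
    for filename in sorted(os.listdir(image_dir)):
        label_path = os.path.join(label_dir, filename)
        if not os.path.isfile(label_path):
            # assumed policy (hypothetical): skip images that have no label file
            continue
        datalist.append({"image": os.path.join(image_dir, filename), "label": label_path})
    return datalist
```

With this helper, `input_data = build_datalist(base_name)` would replace the loop above.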


### Set determinism for reproducibility

In [4]:
set_determinism(seed=0)

### Setup transforms

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
transform = mt.Compose(
    [
        mt.LoadImageD(keys=("image", "label"), image_only=True, ensure_channel_first=True),
        mt.SpacingD(keys=("image", "label"), pixdim=1.5),
        mt.EnsureTypeD(keys=("image", "label"), device=device),
        mt.RandRotateD(
            keys=("image", "label"),
            prob=1.0,
            range_x=0.1,
            range_y=0.1,
            range_z=0.3,
            mode=("bilinear", "nearest"),
        ),
        mt.RandZoomD(keys=("image", "label"), prob=1.0, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
        mt.ResizeWithPadOrCropD(keys=("image", "label"), spatial_size=(96, 96, 96)),
        # add `FromMetaTensorD` to convert `MetaTensor` to `torch.Tensor`
        mt.FromMetaTensorD(keys=("image", "label")),
    ]
)

### Create model

Here we use `create_workflow` to get the network instance from the `wholeBody_ct_segmentation` bundle. You can also initialize your own network instead.

In [6]:
bundle_dir = "./bundle"
os.makedirs(bundle_dir, exist_ok=True)

bundle = download("wholeBody_ct_segmentation", bundle_dir=bundle_dir)
config_file = os.path.join(bundle_dir, "wholeBody_ct_segmentation/configs/train.json")
train_workflow = create_workflow(config_file=str(config_file), workflow_type="train")


def create_model():
    return train_workflow.network_def.to(device)

2024-01-09 09:01:30,861 - INFO - --- input summary of monai.bundle.scripts.download ---
2024-01-09 09:01:30,863 - INFO - > name: 'wholeBody_ct_segmentation'
2024-01-09 09:01:30,864 - INFO - > bundle_dir: './bundle'
2024-01-09 09:01:30,865 - INFO - > source: 'monaihosting'
2024-01-09 09:01:30,865 - INFO - > remove_prefix: 'monai_'
2024-01-09 09:01:30,866 - INFO - > progress: True
2024-01-09 09:01:30,867 - INFO - ---


2024-01-09 09:01:31,054 - INFO - Expected md5 is None, skip md5 check for file bundle/wholeBody_ct_segmentation_v0.2.1.zip.
2024-01-09 09:01:31,055 - INFO - File exists: bundle/wholeBody_ct_segmentation_v0.2.1.zip, skipped downloading.
2024-01-09 09:01:31,056 - INFO - Writing into directory: bundle.
2024-01-09 09:01:31,968 - INFO - --- input summary of monai.bundle.scripts.run ---
2024-01-09 09:01:31,969 - INFO - > config_file: './bundle/wholeBody_ct_segmentation/configs/train.json'
2024-01-09 09:01:31,970 - INFO - > workflow_type: 'train'
2024-01-09 09:01:31,971 - INFO - 

### Without compile

In [7]:
epoch_num = 100
dataset = monai.data.CacheDataset(data=input_data, transform=transform, cache_rate=1.0, num_workers=4)
data_loader = monai.data.DataLoader(dataset, batch_size=1)

model = create_model()
s = time.time()
for i in range(epoch_num):
    e = time.time()
    for batch_data in data_loader:
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        out = model(inputs)
    if i <= 5:
        print(f"epoch{i} time", time.time() - e)
print("total time", time.time() - s)

Loading dataset: 100%|██████████| 20/20 [00:05<00:00,  3.46it/s]


epoch0 time 2.3300938606262207
epoch1 time 1.0478227138519287
epoch2 time 1.0480997562408447
epoch3 time 1.0515162944793701
epoch4 time 1.0385167598724365
epoch5 time 1.0458405017852783
total time 105.7378671169281
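The loop above reads `time.time()` without waiting for the GPU. CUDA kernels are launched asynchronously, so host-side timestamps can attribute queued GPU work to the wrong epoch. Below is a minimal sketch of a more careful epoch timer, assuming you pass `torch.cuda.synchronize` as the `sync` callback when running on a GPU. `time_epochs` is a hypothetical helper, not part of MONAI or PyTorch.

```python
import time
from typing import Callable, Iterable, List, Optional


def time_epochs(
    run_step: Callable[[object], object],
    batches: Iterable,
    num_epochs: int,
    sync: Optional[Callable[[], None]] = None,
) -> List[float]:
    """Return the wall-clock duration of each epoch in seconds.

    ``batches`` must be re-iterable (e.g. a DataLoader or a list).
    ``sync`` is called before and after each epoch so that pending
    asynchronous device work is counted in the epoch that launched it.
    """
    durations = []
    for _ in range(num_epochs):
        if sync is not None:
            sync()  # drain work left over from before this epoch
        start = time.perf_counter()
        for batch in batches:
            run_step(batch)
        if sync is not None:
            sync()  # wait until this epoch's work has actually finished
        durations.append(time.perf_counter() - start)
    return durations
```

For example, `time_epochs(lambda batch: model(batch["image"].to(device)), data_loader, epoch_num, sync=torch.cuda.synchronize)`. The sketch uses `time.perf_counter()` because it is a monotonic, high-resolution clock intended for measuring intervals.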


### With compile

The only difference is that we wrap the model with `torch.compile`. As the [PyTorch `torch.compile` tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) explains, the first epoch takes longer because the model must be compiled, but subsequent epochs run significantly faster than in eager mode. In the run below, the first epoch takes about 15.8 s (versus 2.3 s in eager mode), while each later epoch drops from about 1.05 s to about 0.53 s.

In [8]:
dataset = monai.data.CacheDataset(data=input_data, transform=transform, cache_rate=1.0, num_workers=4)
data_loader = monai.data.DataLoader(dataset, batch_size=1)

model = torch.compile(create_model())
s = time.time()
for i in range(epoch_num):
    e = time.time()
    for batch_data in data_loader:
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        out = model(inputs)
    if i <= 5:
        print(f"epoch{i} time", time.time() - e)
print("total time", time.time() - s)

Loading dataset: 100%|██████████| 20/20 [00:04<00:00,  4.15it/s]


epoch0 time 15.756214141845703
epoch1 time 0.528465986251831
epoch2 time 0.5261788368225098
epoch3 time 0.5370438098907471
epoch4 time 0.532045841217041
epoch5 time 0.5341622829437256
total time 67.98181772232056
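The logged epoch times can be turned into a rough cost model for compilation. The sketch below copies the numbers from the two runs above and estimates the steady-state speedup, the one-off compile overhead in epoch 0, and how many further epochs are needed to repay it. It assumes epochs 1 to 5 are representative of all later epochs; these are forward-only timings from a single run, not a general benchmark.

```python
import math
from statistics import mean

# Epoch times in seconds, copied from the eager and compiled runs above.
eager_first = 2.3300938606262207
eager_steady = [1.0478227138519287, 1.0480997562408447, 1.0515162944793701,
                1.0385167598724365, 1.0458405017852783]
compiled_first = 15.756214141845703
compiled_steady = [0.528465986251831, 0.5261788368225098, 0.5370438098907471,
                   0.532045841217041, 0.5341622829437256]
eager_total = 105.7378671169281
compiled_total = 67.98181772232056

# Ratio of average per-epoch times once compilation is done.
steady_speedup = mean(eager_steady) / mean(compiled_steady)
# Seconds saved per epoch after warm-up, and extra seconds spent in the first epoch.
saving_per_epoch = mean(eager_steady) - mean(compiled_steady)
compile_overhead = compiled_first - eager_first
# Number of post-warm-up epochs needed before the compile overhead is repaid.
break_even_epochs = math.ceil(compile_overhead / saving_per_epoch)
# Speedup over the whole 100-epoch run, overhead included.
overall_speedup = eager_total / compiled_total

print(f"steady-state speedup: {steady_speedup:.2f}x")
print(f"compile overhead repaid after {break_even_epochs} more epochs")
print(f"overall speedup (100 epochs): {overall_speedup:.2f}x")
```

With these logs the steady-state speedup is about 1.97x, the overhead is repaid after 27 further epochs, and the full 100-epoch run is about 1.56x faster. For short runs or one-off inference, compilation may not pay off. The larger gain here than the 10% end-to-end figure is presumably because this demo times only the forward pass, while a full training pipeline also spends time in stages that compilation does not accelerate.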


## Using `torch.compile` in a bundle

To use `torch.compile` in a bundle, simply set `compile=True` in `SupervisedTrainer` or `SupervisedEvaluator`. When `compile=True`, the engine internally converts `MetaTensor` inputs to `torch.Tensor` before calling the network; you can track the related [ticket](https://github.com/pytorch/pytorch/issues/117026).

In [None]:
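network = create_model()
# `optimizer` and `loss_function` are required arguments of `SupervisedTrainer`;
# Adam and DiceCELoss are example choices, use whatever suits your task
optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)
loss_function = monai.losses.DiceCELoss(to_onehot_y=True, softmax=True)

trainer = SupervisedTrainer(
    device=device,
    max_epochs=epoch_num,
    train_data_loader=data_loader,
    # pass the uncompiled network: with `compile=True` the trainer applies `torch.compile` itself
    network=network,
    optimizer=optimizer,
    loss_function=loss_function,
    # inferer=SimpleInferer(),
    # postprocessing=post_transform,
    # amp=args.amp,
    # key_train_metric={
    #     "train_dice": MeanDice(
    #         include_background=False,
    #         output_transform=from_engine(["pred", "label"]),
    #     )
    # },
    compile=True,
    # optionally pass `compile_kwargs`, a dict of arguments forwarded to `torch.compile()`
    compile_kwargs={},
)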
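In a bundle, the trainer is normally defined in a JSON config such as `train.json` rather than in Python, so the same switch can be set there. Below is a minimal sketch that edits an already-loaded config dictionary. `enable_compile` is a hypothetical helper, and its default path `("train", "trainer")` is an assumption about the usual MONAI model-zoo layout, so check where your bundle's `train.json` defines the trainer.

```python
import copy
from typing import Any, Dict, Optional, Sequence


def enable_compile(
    config: Dict[str, Any],
    trainer_path: Sequence[str] = ("train", "trainer"),  # assumed bundle layout
    compile_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Return a copy of a bundle config with ``compile`` switched on for one engine."""
    new_config = copy.deepcopy(config)  # leave the caller's dict untouched
    node = new_config
    for key in trainer_path:
        node = node[key]  # raises KeyError if the assumed layout does not match
    node["compile"] = True
    if compile_kwargs:
        # forwarded by the engine to `torch.compile()`
        node["compile_kwargs"] = dict(compile_kwargs)
    return new_config
```

For example, load the file with `json.load`, call `enable_compile(config, compile_kwargs={"mode": "reduce-overhead"})`, and write the result back. Alternatively, MONAI's `ConfigParser` can override a single entry by id, e.g. `parser["train#trainer#compile"] = True`, without editing the file.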

## Comparison of results

We applied `torch.compile` to the pipeline in fast_training_tutorial.ipynb and observed a 10% speedup.

![compile_benchmark_total_epoch_time_comparison](../figures/total_epoch_time_comparison-compile.png)