Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# PyTorch 2.0 Integration

This notebook introduces how to use `torch.compile` in a MONAI pipeline. It covers the following parts:
- What is torch.compile?

    `torch.compile` is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature, so PyTorch 2.0 is 100% backward compatible by definition.

- A simple demo showing how to use `torch.compile`.

- An end-to-end workflow profiling comparison.
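The wrap-and-call pattern described above can be sketched as follows. This is a minimal illustration, not this notebook's model: `TinyNet` and the `timed` helper are hypothetical. It uses `backend="eager"` only so the sketch runs without a C++/Triton toolchain. That backend captures the graph with TorchDynamo but runs it with ordinary eager kernels, so real speedups have to be measured with the default backend. The first call to a compiled module includes tracing/compilation time, which is why it is timed separately from later calls.

```python
import time

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical toy model, used only to illustrate torch.compile."""

    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def timed(fn, *args):
    """Hypothetical helper: return (result, seconds), synchronizing CUDA so queued GPU work is counted."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


torch.manual_seed(0)
model = TinyNet().eval()
# backend="eager" avoids needing a compiler; the default backend ("inductor") generates optimized kernels.
compiled_model = torch.compile(model, backend="eager")

x = torch.randn(8, 16)
with torch.no_grad():
    eager_out, eager_s = timed(model, x)
    first_out, first_s = timed(compiled_model, x)  # includes graph capture / compilation
    second_out, second_s = timed(compiled_model, x)  # reuses the captured graph

print(f"eager {eager_s:.4f}s | compiled, 1st call {first_s:.4f}s | compiled, 2nd call {second_s:.4f}s")
```

Because the compiled module wraps the original one and shares its parameters, its outputs should match the eager model's. Only the execution path changes.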

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, matplotlib]"
!pip install -q "torch>=2.1.0"
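To confirm at runtime that the installed PyTorch satisfies the `>=2.1.0` requirement, a small version comparison like the one below may help. The helpers `version_tuple` and `meets_minimum` are hypothetical, not part of PyTorch or MONAI. They only parse the leading `major.minor[.patch]` digits and compare numerically, so local suffixes such as `+cu118` are ignored. Pre-release tags such as `2.1.0a0` are also treated as the final release, which is looser than pip's own rules.

```python
import re


def version_tuple(version: str) -> tuple:
    """Hypothetical helper: parse the leading 'major.minor[.patch]' of e.g. '2.1.0+cu118' into ints."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component is treated as 0, so '2.3' -> (2, 3, 0).
    return tuple(int(part or 0) for part in match.groups())


def meets_minimum(installed: str, minimum: str = "2.1.0") -> bool:
    """Hypothetical helper: numeric (not lexical) comparison, so '10.0' counts as newer than '2.1'."""
    return version_tuple(installed) >= version_tuple(minimum)
```

In the notebook this would be called as `meets_minimum(torch.__version__)`. If it returns `False` after the install cell, the kernel may still be using a previously imported PyTorch build, and restarting the kernel is the usual fix.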

## Setup imports

In [None]:
import os
import time
import torch
import shutil
import tempfile

import monai
import monai.transforms as mt
from monai.config import print_config
from monai.utils import set_determinism

print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [None]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

## A simple demo showing how to use `torch.compile`

In [None]:
sample_url = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases"
sample_url += "/download/0.8.1/totalSegmentator_mergedLabel_samples.zip"
monai.apps.download_and_extract(sample_url, output_dir=root_dir, filepath="samples.zip")

base_name = os.path.join(root_dir, "totalSegmentator_mergedLabel_samples")
input_data = []
for filename in sorted(os.listdir(os.path.join(base_name, "imagesTr"))):
    input_data.append(
        {
            "image": os.path.join(base_name, "imagesTr", filename),
            "label": os.path.join(base_name, "labelsTr", filename),
        }
    )
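The cell above assumes that every image in `imagesTr` has a label with the same file name in `labelsTr`. A minimal sketch of the same pairing with an explicit check is shown below. `build_data_list` is a hypothetical helper. It sorts the listing so the sample order does not depend on the file system, and it fails early if a label is missing instead of letting `LoadImageD` fail later. The demo uses a throwaway directory with hypothetical empty files.

```python
import os
import tempfile


def build_data_list(base_name: str, image_dir: str = "imagesTr", label_dir: str = "labelsTr") -> list:
    """Hypothetical helper: pair each image with the same-named label file, in sorted order."""
    data = []
    for filename in sorted(os.listdir(os.path.join(base_name, image_dir))):
        image_path = os.path.join(base_name, image_dir, filename)
        label_path = os.path.join(base_name, label_dir, filename)
        if not os.path.isfile(label_path):
            raise FileNotFoundError(f"no label for {filename}: expected {label_path}")
        data.append({"image": image_path, "label": label_path})
    return data


# Demo on a temporary directory with hypothetical (empty) files created in reverse order.
with tempfile.TemporaryDirectory() as tmp:
    for sub in ("imagesTr", "labelsTr"):
        os.makedirs(os.path.join(tmp, sub))
        for name in ("case_b.nii.gz", "case_a.nii.gz"):
            open(os.path.join(tmp, sub, name), "wb").close()
    pairs = build_data_list(tmp)
    print([os.path.basename(p["image"]) for p in pairs])  # ['case_a.nii.gz', 'case_b.nii.gz']
```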

### Set determinism for reproducibility

In [None]:
set_determinism(seed=0)

### Setup transforms

In [None]:
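device = "cuda:0" if torch.cuda.is_available() else "cpu"
transform = mt.Compose(
    [
        # load image and label as channel-first arrays
        mt.LoadImageD(keys=("image", "label"), image_only=True, ensure_channel_first=True),
        # resample to 1.5 (usually mm) voxel spacing; nearest interpolation keeps label values discrete
        mt.SpacingD(keys=("image", "label"), pixdim=1.5, mode=("bilinear", "nearest")),
        # move the data to `device` so the following random transforms run there
        mt.EnsureTypeD(keys=("image", "label"), device=device),
        # random rotation (ranges in radians); bilinear for the image, nearest for the label
        mt.RandRotateD(
            keys=("image", "label"),
            prob=1.0,
            range_x=0.1,
            range_y=0.1,
            range_z=0.3,
            mode=("bilinear", "nearest"),
        ),
        # random zoom with a factor between 0.8 and 1.2
        mt.RandZoomD(keys=("image", "label"), prob=1.0, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
        # pad or crop to a fixed spatial size so samples can be batched
        mt.ResizeWithPadOrCropD(keys=("image", "label"), spatial_size=(200, 210, 220)),
    ]
)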