Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# GDS Integration

This notebook shows how to integrate GPUDirect Storage (GDS) into MONAI. It covers the following parts:
- What is GPUDirect Storage (GDS)?

    GDS is the newest addition to the GPUDirect family. GPUDirect Peer to Peer (https://developer.nvidia.com/gpudirect) enables a direct memory access (DMA) path between the memory of two graphics processing units (GPUs), and GPUDirect RDMA enables a direct DMA path to a network interface card (NIC). In the same way, GDS enables a direct DMA data path between GPU memory and storage, avoiding a bounce buffer through the CPU. This direct path can increase system bandwidth while decreasing latency and the utilization load on the CPU and GPU.

- GDS hardware and software requirements and how to install GDS.

    1. GDS has been tested on the following NVIDIA GPUs: T10x, T4, A10, Quadro P6000, A100, and V100. For the full list of GPUs that GDS works with, refer to the [Known Limitations](https://docs.nvidia.com/gpudirect-storage/release-notes/index.html#known-limitations) section. For further requirements, see items 3 and 4 in this [link](https://docs.nvidia.com/gpudirect-storage/release-notes/index.html#mofed-fs-req).

    2. To install GDS, follow the detailed steps in this [section](https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/index.html#troubleshoot-install). To verify that GDS was installed successfully, run the following command:
        
        ```/usr/local/cuda-<x>.<y>/gds/tools/gdscheck.py -p``` 
        
        (Replace `<x>` with the major version of the CUDA toolkit and `<y>` with the minor version.)

- A simple demo comparing the time taken with and without GDS.

    In this tutorial, we create a conda environment and install `kvikio`, which provides a Python API for GDS. For other ways to install `kvikio`, refer to https://github.com/rapidsai/kvikio#install.

    ```conda create -n gds_env -c rapidsai-nightly -c conda-forge python=3.10 cuda-version=11.8 kvikio```

- An end-to-end workflow profiling comparison.

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, matplotlib]"
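Since the GDS code path depends on `kvikio`, a quick check that it works can save time before running the rest of the notebook. The block below is a minimal sketch, not part of MONAI or of this tutorial's original code. The helper name `roundtrip` is hypothetical, and the GPU path assumes CuPy is installed alongside `kvikio` to provide GPU buffers. It writes a small array with `kvikio.CuFile` and reads it back into GPU memory. If either package cannot be imported, it falls back to ordinary NumPy file I/O through host memory, which does not use GDS at all.

```python
import os
import tempfile

import numpy as np

try:
    # Assumption: CuPy and KvikIO are installed (e.g. in the conda env above).
    import cupy as cp
    import kvikio
except ImportError:
    cp = None
    kvikio = None


def roundtrip(path: str, n: int = 1000) -> bool:
    """Hypothetical check: write n float32 values to `path`, read them back, compare."""
    if kvikio is not None and cp is not None:
        src = cp.arange(n, dtype=cp.float32)  # buffer lives in GPU memory
        f = kvikio.CuFile(path, "w")
        try:
            f.write(src)  # GPU memory -> file
        finally:
            f.close()
        dst = cp.empty_like(src)
        f = kvikio.CuFile(path, "r")
        try:
            f.read(dst)  # file -> GPU memory
        finally:
            f.close()
        return bool((src == dst).all())
    # Fallback without KvikIO/CuPy: plain host-memory I/O with NumPy (no GDS).
    src = np.arange(n, dtype=np.float32)
    src.tofile(path)
    dst = np.fromfile(path, dtype=np.float32)
    return bool(np.array_equal(src, dst))
```

For example, `roundtrip(os.path.join(tempfile.mkdtemp(), "check.bin"))` should return `True`. A successful GPU-path result only shows that `kvikio` can move data between a file and GPU memory; whether the transfer actually goes through GDS depends on the system configuration reported by the `gdscheck.py -p` command above.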

## Setup imports

In [1]:
import os
import time
import torch
import shutil
import tempfile

import monai
import monai.transforms as mt
from monai.config import print_config
from monai.data.dataset import GDSDataset
from monai.utils import set_determinism

print_config()

MONAI version: 1.3.0
Numpy version: 1.24.3
Pytorch version: 2.3.0.dev20240311
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /Users/<username>/opt/anaconda3/envs/MLA/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.3.0
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.23.2
scipy version: 1.11.2
Pillow version: 10.2.0
Tensorboard version: 2.15.2
gdown version: 4.7.3
TorchVision version: 0.15.2a0
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.1.1
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [2]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/var/folders/85/ql5yb2_14pzc9s20ccfhqs7m0000gn/T/tmpqgohsp9f


## A simple demo showing how to use GDS

### Download dataset and set dataset path

In [3]:
sample_url = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases"
sample_url += "/download/0.8.1/totalSegmentator_mergedLabel_samples.zip"
monai.apps.download_and_extract(sample_url, output_dir=root_dir, filepath="samples.zip")

base_name = os.path.join(root_dir, "totalSegmentator_mergedLabel_samples")
input_data = []
# sort the file names so the sample order does not depend on the file system
for filename in sorted(os.listdir(os.path.join(base_name, "imagesTr"))):
    input_data.append(
        {
            "image": os.path.join(base_name, "imagesTr", filename),
            "label": os.path.join(base_name, "labelsTr", filename),
        }
    )

2024-05-08 13:31:29,803 - INFO - Expected md5 is None, skip md5 check for file samples.zip.
2024-05-08 13:31:29,804 - INFO - File exists: samples.zip, skipped downloading.
2024-05-08 13:31:29,805 - INFO - Writing into directory: /var/folders/85/ql5yb2_14pzc9s20ccfhqs7m0000gn/T/tmpqgohsp9f.


### Set determinism for reproducibility

In [4]:
set_determinism(seed=0)

### Setup transforms

In [6]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
transform = mt.Compose(
    [
        mt.LoadImageD(keys=("image", "label"), image_only=True, ensure_channel_first=True),
        mt.SpacingD(keys=("image", "label"), pixdim=1.5),
        mt.EnsureTypeD(keys=("image", "label"), device=device),
        mt.RandRotateD(
            keys=("image", "label"),
            prob=1.0,
            range_x=0.1,
            range_y=0.1,
            range_z=0.3,
            mode=("bilinear", "nearest"),
        ),
        mt.RandZoomD(keys=("image", "label"), prob=1.0, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
        mt.ResizeWithPadOrCropD(keys=("image", "label"), spatial_size=(200, 210, 220)),
    ]
)
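Both `GDSDataset` and `PersistentDataset` cache the outputs of the transforms that come before the first randomizable transform in the `Compose`; everything from that point on runs again on every access. The sketch below illustrates that split with plain Python. It is not MONAI's implementation: `Step` and `split_at_first_random` are hypothetical stand-ins, and the names simply mirror the order of the pipeline defined above.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class Step:
    """Hypothetical stand-in for a transform: a name plus a 'random' flag."""

    name: str
    is_random: bool


def split_at_first_random(steps: Sequence[Step]) -> Tuple[List[Step], List[Step]]:
    """Return (cacheable prefix, per-epoch suffix), split at the first random step."""
    for i, step in enumerate(steps):
        if step.is_random:
            return list(steps[:i]), list(steps[i:])
    # No random step at all: the whole pipeline could be cached.
    return list(steps), []


# Same order as the Compose pipeline above.
pipeline = [
    Step("LoadImageD", False),
    Step("SpacingD", False),
    Step("EnsureTypeD", False),
    Step("RandRotateD", True),
    Step("RandZoomD", True),
    Step("ResizeWithPadOrCropD", False),
]
cached, per_epoch = split_at_first_random(pipeline)
print([s.name for s in cached])  # ['LoadImageD', 'SpacingD', 'EnsureTypeD']
print([s.name for s in per_epoch])  # ['RandRotateD', 'RandZoomD', 'ResizeWithPadOrCropD']
```

Under this rule, `ResizeWithPadOrCropD` is deterministic but is still recomputed every epoch, because it follows a random transform.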

### Using GDSDataset

In [7]:
cache_dir = os.path.join(root_dir, "gds_cache_dir")
dataset = GDSDataset(data=input_data, transform=transform, cache_dir=cache_dir, device=0)

data_loader = monai.data.ThreadDataLoader(dataset, batch_size=1)

s = time.time()
for i in range(5):
    e = time.time()
    for _x in data_loader:
        pass
    print(f"epoch{i} time", time.time() - e)
print("total time", time.time() - s)

Exception in thread Thread-4 (enqueue_values):
Traceback (most recent call last):
  File "/Users/askelundsgaard/opt/anaconda3/envs/MLA/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
    self.run()
  File "/Users/askelundsgaard/opt/anaconda3/envs/MLA/lib/python3.11/threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/askelundsgaard/opt/anaconda3/envs/MLA/lib/python3.11/site-packages/monai/data/thread_buffer.py", line 49, in enqueue_values
    for src_val in self.src:
  File "/Users/askelundsgaard/opt/anaconda3/envs/MLA/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/Users/askelundsgaard/opt/anaconda3/envs/MLA/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/askelun

epoch0 time 0.4847869873046875
epoch1 time 0.4576592445373535
epoch2 time 0.4713747501373291
epoch3 time 0.4590342044830322
epoch4 time 0.45446085929870605
total time 2.327802896499634
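The same timing loop is used for every dataset in this demo, so it could be factored into a helper. The function below is a hypothetical sketch (`time_epochs` is not a MONAI API). It uses `time.perf_counter`, which is intended for measuring elapsed intervals, and it returns the per-epoch times so the runs can be compared afterwards.

```python
import time
from typing import Iterable, List


def time_epochs(loader: Iterable, num_epochs: int = 5) -> List[float]:
    """Iterate over `loader` num_epochs times; return the wall-clock time of each pass."""
    epoch_times = []
    for epoch in range(num_epochs):
        start = time.perf_counter()
        for _batch in loader:
            pass  # measure data loading only; no training step
        elapsed = time.perf_counter() - start
        epoch_times.append(elapsed)
        print(f"epoch{epoch} time", elapsed)
    print("total time", sum(epoch_times))
    return epoch_times
```

For example, `gds_times = time_epochs(data_loader)`. The loader must be re-iterable, as a `DataLoader` or a list is; a one-shot generator would be exhausted after the first epoch.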


### Using PersistentDataset

In [9]:
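# keep this cache under root_dir, next to the GDS cache, so the cleanup step also removes it
cache_dir_per = os.path.join(root_dir, "persistent_cache_dir")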
dataset = monai.data.PersistentDataset(data=input_data, transform=transform, cache_dir=cache_dir_per)
data_loader = monai.data.ThreadDataLoader(dataset, batch_size=1)

s = time.time()
for i in range(5):
    e = time.time()
    for _x in data_loader:
        pass
    print(f"epoch{i} time", time.time() - e)
print("total time", time.time() - s)

KeyboardInterrupt: 
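Once per-epoch times have been collected for both datasets (for example, by appending `time.time() - e` to a list inside each loop), a small hypothetical helper can summarize the comparison. Both datasets write their cache during the first epoch, so this sketch reports the speedup over all epochs and, separately, over the later epochs that read from the cache. Splitting off the first epoch is a choice made in this sketch, not something the notebook prescribes.

```python
from typing import Dict, Sequence


def compare_epoch_times(gds_times: Sequence[float], baseline_times: Sequence[float]) -> Dict[str, float]:
    """Return baseline/GDS time ratios; values above 1.0 mean the GDS run was faster."""
    if len(gds_times) != len(baseline_times) or len(gds_times) < 2:
        raise ValueError("need two equal-length lists with at least two epochs each")
    total_speedup = sum(baseline_times) / sum(gds_times)
    # Epoch 0 builds the cache for both datasets; later epochs load from it.
    cached_speedup = sum(baseline_times[1:]) / sum(gds_times[1:])
    return {"total_speedup": total_speedup, "cached_epochs_speedup": cached_speedup}
```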

## End-to-end workflow profiling comparison

We also conducted a quantitative analysis of end-to-end workflow performance using the BraTS dataset. To learn how to implement the full pipeline, follow this [tutorial](/home/lab/yliu/tutorials/acceleration/distributed_training/brats_training_ddp.py). The only part of that pipeline that needs to change is the dataset. The end-to-end pipeline was benchmarked on a 32 GB V100 GPU.

### Comparison of total time and per-epoch time
![gds_benchmark_total_epoch_time_comparison](../figures/gds_total_epoch_time_comparison.png)

### Comparison of total time to reach the metrics
![gds_benchmark_achieve_metrics_comparison](../figures/gds_metric_time_epochs.png)

## Cleanup data directory

Remove the directory if a temporary one was used.

In [10]:
if directory is None:
    shutil.rmtree(root_dir)