Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# GDS Integration

This notebook shows how to integrate GPUDirect Storage (GDS) into MONAI. It covers the following parts:
- What is GPUDirect Storage (GDS)?

    GDS is the newest addition to the GPUDirect family. GPUDirect Peer to Peer (https://developer.nvidia.com/gpudirect) enables a direct memory access (DMA) path between the memory of two graphics processing units (GPUs), and GPUDirect RDMA enables a direct DMA path to a network interface card (NIC). Similarly, GDS enables a direct DMA data path between GPU memory and storage, avoiding a bounce buffer through the CPU. This direct path can increase system bandwidth while decreasing latency and the utilization load on the CPU and GPU.

- GDS hardware and software requirements and how to enable GDS.

    1. GDS has been tested on the following NVIDIA GPUs: T10x, T4, A10, Quadro P6000, A100, and V100. For a full list of GPUs that GDS works with, refer to the [Known Limitations](https://docs.nvidia.com/gpudirect-storage/release-notes/index.html#known-limitations) section. For further requirements, refer to sections 3 and 4 in this [link](https://docs.nvidia.com/gpudirect-storage/release-notes/index.html#mofed-fs-req).

    2. To enable GDS on bare metal, follow the detailed steps provided in this [section](https://docs.nvidia.com/dgx/dgx-os-6-user-guide/additional_software.html#installing-gpudirect-storage-support). To verify successful GDS installation, run the following command:
        
        ```/usr/local/cuda-<x>.<y>/gds/tools/gdscheck.py -p``` 
        
        (Replace `<x>` with the major version of the CUDA toolkit and `<y>` with the minor version.)

- `GDSDataset`, a subclass of `PersistentDataset`.

    In this tutorial, we implement a `GDSDataset` that inherits from `PersistentDataset` and overrides its `_cachecheck` method to create and load the cache using GDS.

- A simple demo comparing the time taken with and without GDS.

    In this tutorial, we create a conda environment and install `kvikio`, which provides a Python API for GDS. To install `kvikio` using other methods, refer to https://github.com/rapidsai/kvikio#install.

    ```conda create -n gds_env -c rapidsai-nightly -c conda-forge python=3.10 cuda-version=11.8 kvikio```
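
To make the caching idea behind `GDSDataset` concrete: a raw binary file stores only the array's bytes, so the shape has to be saved separately and reapplied on load. The sketch below is a CPU-only illustration that uses NumPy's own `tofile` and `fromfile` as stand-ins for the `kvikio.numpy` versions used in this notebook. The helper names `save_raw` and `load_raw` are hypothetical, and `float32` data is assumed.

```python
import os
import tempfile

import numpy as np


def save_raw(array, path):
    """Write the raw float32 bytes plus a '-shape' sidecar file (hypothetical helper)."""
    # A raw dump has no header, so the shape goes into its own file.
    np.asarray(array.shape, dtype=np.int64).tofile(f"{path}-shape")
    np.ascontiguousarray(array, dtype=np.float32).tofile(path)


def load_raw(path):
    """Read the flat float32 buffer and restore its shape from the sidecar file."""
    shape = np.fromfile(f"{path}-shape", dtype=np.int64)
    flat = np.fromfile(path, dtype=np.float32)
    return flat.reshape(tuple(shape))


with tempfile.TemporaryDirectory() as tmp_dir:
    original = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
    target = os.path.join(tmp_dir, "example-image")
    save_raw(original, target)
    restored = load_raw(target)

print(restored.shape, np.array_equal(original, restored))  # (2, 3, 4) True
```

With GDS, the same pattern applies, except that the data file is read directly into GPU memory instead of passing through a host buffer.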

## Setup environment

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"

## Setup imports

In [None]:
import os
import time
import cupy
import torch
import shutil
import tempfile
import numpy as np
from typing import Any
from pathlib import Path
from copy import deepcopy
from collections.abc import Callable, Sequence
from kvikio.numpy import fromfile, tofile

import monai
import monai.transforms as mt
from monai.config import print_config
from monai.data.utils import pickle_hashing
from monai.utils import convert_to_tensor, convert_to_cupy, set_determinism

print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [9]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmppwi1_sx5


## GDSDataset

In [10]:
class GDSDataset(monai.data.PersistentDataset):
    def __init__(
        self,
        data: Sequence,
        transform: Sequence[Callable] | Callable,
        cache_dir: Path | str | None,
        hash_func: Callable[..., bytes] = pickle_hashing,
        hash_transform: Callable[..., bytes] | None = None,
        reset_ops_id: bool = True,
        device: int | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            data=data,
            transform=transform,
            cache_dir=cache_dir,
            hash_func=hash_func,
            hash_transform=hash_transform,
            reset_ops_id=reset_ops_id,
            **kwargs,
        )
        self.device = device
        self.shape_dict = {}

    def _cachecheck(self, item_transformed):
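        """
        Return the pre-transformed version of ``item_transformed``.

        On a cache hit, each key is read from disk into GPU memory with GDS;
        on a miss, the item is transformed and written to the cache.
        """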
        hashfile = None
        # compute a cache id
        if self.cache_dir is not None:
            data_item_md5 = self.hash_func(item_transformed).decode("utf-8")
            data_item_md5 += self.transform_hash
            hashfile = self.cache_dir / f"{data_item_md5}.pt"

        if hashfile is not None and hashfile.is_file():  # cache hit
            with cupy.cuda.Device(self.device):
                item = {}
                for k in item_transformed:
                    if f"{hashfile}-{k}" in self.shape_dict:
                        shape_k = self.shape_dict[f"{hashfile}-{k}"]
                    else:
                        shape_k = fromfile(f"{hashfile}-{k}-shape", dtype=int)
                        self.shape_dict[f"{hashfile}-{k}"] = shape_k
                    # the raw file has no dtype header; float32 is assumed for every key
                    item[k] = fromfile(f"{hashfile}-{k}", dtype=np.float32, like=cupy.empty(()))
                    item[k] = convert_to_tensor(item[k].reshape(shape_k), device=f"cuda:{self.device}")
                return item

        # create new cache
        _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed
        if hashfile is None:
            return _item_transformed

        for k in _item_transformed:  # {'image': ..., 'label': ...}
            item_k_shape = convert_to_cupy(_item_transformed[k].shape, wrap_sequence=True)
            tofile(item_k_shape, f"{hashfile}-{k}-shape")
            item_k = _item_transformed[k]
            tofile(item_k, f"{hashfile}-{k}")
        open(hashfile, "a").close()  # create an empty marker file so later lookups register a cache hit
        return _item_transformed
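
The file layout that `_cachecheck` produces can be sketched without a GPU. For each item there is one empty marker file named `<hash>.pt`, and for each key a data file `<hash>.pt-<key>` and a shape file `<hash>.pt-<key>-shape`. The helper `example_cache_paths` below is hypothetical, and its MD5-over-pickle hash is only an assumed stand-in for the dataset's `hash_func`, so the digests will not match those MONAI generates.

```python
import hashlib
import pickle
from pathlib import Path


def example_cache_paths(cache_dir, item, transform_hash=""):
    """Return the marker path and per-key (data, shape) paths for ``item``."""
    # Stand-in hash: MD5 over the pickled, key-sorted item (an assumption).
    payload = pickle.dumps(sorted(item.items()))
    digest = hashlib.md5(payload).hexdigest()
    hashfile = Path(cache_dir) / f"{digest}{transform_hash}.pt"
    # Mirror the naming used in ``_cachecheck``: "<hashfile>-<key>" and "<hashfile>-<key>-shape".
    per_key = {key: (f"{hashfile}-{key}", f"{hashfile}-{key}-shape") for key in item}
    return hashfile, per_key


marker, files = example_cache_paths(
    "gds_cache_dir", {"image": "img.nii.gz", "label": "seg.nii.gz"}
)
print(marker.name)
for key, (data_path, shape_path) in files.items():
    print(key, Path(data_path).name, Path(shape_path).name)
```

Because the marker file is created last, a lookup only counts as a cache hit after every data and shape file for that item has been written.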

## A simple demo showing how to use GDS

### Download dataset and set dataset path

In [11]:
sample_url = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases"
sample_url += "/download/0.8.1/totalSegmentator_mergedLabel_samples.zip"
monai.apps.download_and_extract(
    sample_url, output_dir=root_dir, filepath="samples.zip"
)

base_name = os.path.join(root_dir, "totalSegmentator_mergedLabel_samples")
input_data = []
for filename in os.listdir(os.path.join(base_name, "imagesTr")):
    input_data.append(
        {
            "image": os.path.join(base_name, "imagesTr", filename),
            "label": os.path.join(base_name, "labelsTr", filename),
        }
    )

2023-07-10 22:35:29,985 - INFO - Expected md5 is None, skip md5 check for file samples.zip.
2023-07-10 22:35:29,987 - INFO - File exists: samples.zip, skipped downloading.
2023-07-10 22:35:29,988 - INFO - Writing into directory: /tmp/tmppwi1_sx5.


### Set determinism for reproducibility

In [12]:
set_determinism(seed=0)

### Setup transforms

In [13]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
transform = mt.Compose(
    [
        mt.LoadImageD(keys=("image", "label"), image_only=True, ensure_channel_first=True),
        mt.EnsureTypeD(keys=("image", "label"), device=device),
        mt.SpacingD(keys=("image", "label"), pixdim=1.5),
        mt.RandRotateD(
            keys=("image", "label"), prob=1.0, range_x=0.1, range_y=0.1, range_z=0.3, mode=("bilinear", "nearest"),
        ),
        mt.RandZoomD(keys=("image", "label"), prob=1.0, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
        mt.ResizeWithPadOrCropD(keys=("image", "label"), spatial_size=(200, 210, 220)),
    ]
)

### Using GDSDataset

In [None]:
cache_dir = os.path.join(root_dir, "gds_cache_dir")
dataset = GDSDataset(data=input_data, transform=transform, cache_dir=cache_dir, device=0)

data_loader = monai.data.ThreadDataLoader(dataset, batch_size=1)

s = time.time()
for _ in range(1):
    e = time.time()
    for batch in data_loader:
        print(batch["image"].dtype, batch["image"].device, batch["image"].shape)
    print("epoch time", time.time() - e)
print("total time", time.time() - s)

### Using PersistentDataset

In [None]:
cache_dir_per = os.path.join(root_dir, "persistent_cache_dir")
dataset = monai.data.PersistentDataset(data=input_data, transform=transform, cache_dir=cache_dir_per)
data_loader = monai.data.ThreadDataLoader(dataset, batch_size=1)

s = time.time()
for _ in range(1):
    e = time.time()
    for batch in data_loader:
        print(batch["image"].dtype, batch["image"].device, batch["image"].shape)
    print("epoch time", time.time() - e)
print("total time", time.time() - s)
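
Both loops above run a single epoch. If the cache directory starts empty, that epoch includes the cost of running the transforms and writing the cache, so it does not isolate read speed. One way to separate the two, sketched here under the assumption that the loader can be iterated more than once, is to record each epoch's time individually. The helper `time_epochs` is hypothetical, and a plain list stands in for the data loader.

```python
import time


def time_epochs(loader, num_epochs=2):
    """Iterate ``loader`` ``num_epochs`` times and return the seconds taken per epoch."""
    durations = []
    for _ in range(num_epochs):
        start = time.perf_counter()
        for _batch in loader:
            pass  # a real loop would consume the batch here
        durations.append(time.perf_counter() - start)
    return durations


# Any re-iterable object works; a list stands in for a data loader here.
epoch_times = time_epochs([{"image": 0}, {"image": 1}], num_epochs=2)
for index, seconds in enumerate(epoch_times):
    print(f"epoch {index}: {seconds:.6f} s")
```

Applied to the two data loaders above, the first entry would mostly reflect cache creation and the later entries cache reads, which is the part where GDS and the default `PersistentDataset` differ.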

## Cleanup data directory

Remove the directory if a temporary one was used.

In [16]:
if directory is None:
    shutil.rmtree(root_dir)