# 

# BraTS Data Augmentation using DALI NVIDIA LIBRARY

> As the NVIDIA DALI LIBRARY allows one to augment input images on GPU, this notebook aims to test and verify the usability of the library in our project.

In [2]:
import random
import itertools
import os
from typing import Tuple, Literal

import json

from pathlib import Path
import os
from typing import Dict

import torch
import torchio as tio
import numpy as np
import nibabel as nib
import nvidia.dali.fn as fn
import nvidia.dali.math as math
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def probabilistic_augmentation(probability, augmented, original):
    condiction = fn.cast(
        fn.random.coin_flip(probability=probability),
        dtype=types.DALIDataType.BOOL
    )
    negative_condition = condiction ^ True

    return condiction * augmented + negative_condition * original

In [4]:
def parse_json_to_dict(data_descriptors_path:str, phase:str) -> Dict:
    root_dir = os.path.join(Path.cwd(), Path(data_descriptors_path))
    json_path = os.path.join(root_dir, Path(f"{phase}.json"))
    with open(json_path, 'r') as j:
        parsed_json = json.load(j)
    return parsed_json

In [5]:
class NiftiIterator(object):
    def __init__(self, image_paths:str, label_paths:str) -> None:
        self.image_paths = image_paths
        self.label_paths = label_paths
        self._shuffle(self.image_paths, self.label_paths)

    def __iter__(self):
        self.i = 0
        self.n = len(self.image_paths)

        return self

    def __next__(self) -> Tuple:
        image_modalities_paths = self.image_paths[self.i]
        label_path = self.label_paths[self.i]

        image_modalities_data = []
        for path in image_modalities_paths:
            image_data = self._read_nifti_image(path)
            image_modalities_data.append(image_data)
        
        image = torch.concatenate(
            image_modalities_data, axis=0
        ).to(torch.float64)
        label = self._read_nifti_image(label_path).to(torch.uint8)

        """Iterates to the next sample, or if reached the end, reshufle
        paths and start again.
        """
        self.i = (self.i + 1) % self.n
        if self.i == 0:
            self._shuffle(self.image_paths, self.label_paths)

        image = image.unsqueeze(0)
        label = label.unsqueeze(0)
        print(image.shape)
        print(label.shape)

        return (image, label)

    def _shuffle(self, image_paths, label_paths) -> None:
        temp_list = list(zip(image_paths, label_paths))
        random.shuffle(temp_list)
        image_paths, label_paths = zip(*temp_list)
        
        self.image_paths = list(image_paths)
        self.label_paths = list(label_paths)

    def _read_nifti_image(self, image_path:str) -> torch.Tensor:
        image_data = tio.ScalarImage(image_path)[tio.DATA]

        return image_data

In [6]:
class GenericPipeline(Pipeline):
    def __init__(self, data_path:str, data_descriptors_path:str,
            phase: Literal["train", "validation", "test"], n_modalities:int,
            batch_size:int, num_threads:int, device_id:int, dim:int,
            patch_size:Tuple, load_to_gpu:bool, has_labels:bool
        ):
        super().__init__(batch_size, num_threads, device_id)
        self.data_path = data_path
        self.data_descriptors_path = data_descriptors_path
        self.phase = phase
        self.n_modalities = n_modalities
        self.dim = dim
        self.patch_size = patch_size
        self.load_to_gpu = load_to_gpu
        self.has_labels = has_labels

        image_paths, label_paths = self._get_image_paths(
            data_path=self.data_path,
            data_descriptors_path=self.data_descriptors_path,
            phase=self.phase,
            n_modalities=self.n_modalities
        )

        self.nift_iterator = NiftiIterator(
            image_paths=image_paths,
            label_paths=label_paths
        )        

    def _get_image_paths(self, data_path:str, data_descriptors_path:str,
            phase:Literal["train", "validation", "test"],
            n_modalities:int
        ):
        subject_paths = parse_json_to_dict(
            data_descriptors_path=data_descriptors_path,
            phase=phase
        )["ids"]

        if n_modalities == 2:
            modalities = ["flair", "t1ce"]
        elif n_modalities == 4:
            modalities = ["flair", "t1", "t1ce", "t2"]
        else:
            raise ValueError(
                f"Number of Modalities must be 2 or 4. Received {n_modalities}"
            )

        image_paths = list(
            map(lambda subject_path: [
                f"{data_path}/{subject_path}/{subject_path}_{modality}.nii.gz" \
                    for modality in modalities
            ], subject_paths)
        )
        label_paths = list(
            map(lambda subject_path: 
                f"{data_path}/{subject_path}/{subject_path}_seg.nii.gz"
            , subject_paths)
        )

        return image_paths, label_paths

    def _crop(self, data):
        return fn.crop(data, crop=self.patch_size, out_of_bounds_policy="pad")

    def _crop_fn(self, image, label):
        image, label = self.crop(image), self.crop(label)

        return image, label

class TrainPipeline(GenericPipeline):
    def __init__(self, data_path:str, data_descriptors_path:str,
            phase: Literal["train", "validation", "test"], n_modalities:int,
            batch_size:int, num_threads:int, device_id:int, dim:int,
            patch_size:Tuple, load_to_gpu:bool, has_labels:bool):
        super().__init__(
            data_path=data_path, data_descriptors_path=data_descriptors_path,
            phase=phase, n_modalities=n_modalities, batch_size=batch_size,
            num_threads=num_threads, device_id=device_id, dim=dim,
            patch_size=patch_size, load_to_gpu=load_to_gpu, 
            has_labels=has_labels
        )
        self.crop_shape = types.Constant(
            np.array(self.patch_size), dtype=types.INT64
        )
        self.crop_shape_float = types.Constant(
            np.array(self.patch_size), dtype=types.FLOAT
        )

    def _biased_crop_fn(self, image, label):
        # With probability of 0.4 the patch selected via random biased crop is
        # going to hold foreground voxels.
        roi_start, roi_end = fn.segmentation.random_object_bbox(
            label, background=0, format="start_end", cache_objects=False,
            foreground_prob=0.4
        )
        # Generates a Random Crop Window which coints the roi defined by
        # random_object_bbox.
        anchor = fn.roi_random_crop(
            label, roi_start=roi_start, roi_end=roi_end, 
            crop_shape=[1, *self.patch_size]
        )
        # Drop channels from anchor
        anchor = fn.slice(anchor, 1, 3, axes=[0])
        image, label = fn.slice(
            [image, label], anchor, self.crop_shape, axis_names="DHW",
            out_of_bounds_policy="pad"
        )

        print("Executing Biased Crop FN")

        return image.gpu(), label.gpu()

    def define_graph(self):
        image, label = fn.external_source(
            source=self.nift_iterator, num_outputs=2, 
            dtype=[types.FLOAT64, types.UINT8], batch=False
        )
        image, label = self._biased_crop_fn(image, label)

        return image, label


In [7]:
pipeline = TrainPipeline(
    data_path="../../msc_repo/datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021",
    data_descriptors_path="../src/data/descriptors/",
    phase="train", n_modalities=2, batch_size=1, num_threads=1, device_id=0,
    dim=4, patch_size=(128, 128, 128), load_to_gpu=True, has_labels=True
)

In [8]:
pipeline.build()

Executing Biased Crop FN


In [9]:
output = pipeline.run()

torch.Size([1, 2, 240, 240, 155])
torch.Size([1, 1, 240, 240, 155])
torch.Size([1, 2, 240, 240, 155])
torch.Size([1, 1, 240, 240, 155])


RuntimeError: Critical error in pipeline:
Error when executing CPU operator ROIRandomCrop encountered:
[/opt/dali/dali/pipeline/operator/arg_helper.h:223] Assert on "is_uniform(view_.shape) && expected_sh_span == view_.shape.tensor_shape_span(0)" failed: Expected uniform shape for argument "roi_start" but got shape {5}
Stacktrace (11 entries):
[frame 0]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali_operators.so(+0x5a6262) [0x7fc015c3e262]
[frame 1]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali_operators.so(+0x13fb8b5) [0x7fc016a938b5]
[frame 2]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali_operators.so(+0x13f69dc) [0x7fc016a8e9dc]
[frame 3]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali.so(void dali::Executor<dali::AOT_WS_Policy<dali::UniformQueuePolicy>, dali::UniformQueuePolicy>::RunHelper<dali::HostWorkspace>(dali::OpNode&, dali::HostWorkspace&)+0x80d) [0x7fc035e66cad]
[frame 4]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali.so(dali::Executor<dali::AOT_WS_Policy<dali::UniformQueuePolicy>, dali::UniformQueuePolicy>::RunCPUImpl()+0x218) [0x7fc035e6bd58]
[frame 5]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali.so(dali::Executor<dali::AOT_WS_Policy<dali::UniformQueuePolicy>, dali::UniformQueuePolicy>::RunCPU()+0xe) [0x7fc035e6c81e]
[frame 6]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali.so(+0xb818d) [0x7fc035e2318d]
[frame 7]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali.so(+0x130de4) [0x7fc035e9bde4]
[frame 8]: /data_lids/home/crispim/.local/lib/python3.8/site-packages/nvidia/dali/libdali.so(+0x72a8af) [0x7fc0364958af]
[frame 9]: /lib/x86_64-linux-gnu/libpthread.so.0(+0x8609) [0x7fc115090609]
[frame 10]: /lib/x86_64-linux-gnu/libc.so.6(clone+0x43) [0x7fc1151ca163]

Current pipeline object is no longer valid.

In [None]:
output

(TensorListGPU(
     [[[[[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]
 
        [[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]
 
        ...
 
        [[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]
 
        [[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]]
 
 
       [[[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]
 
        [[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]
 
        ...
 
        [[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]]
 
        [[0. 0. ... 0. 0.]
         [0. 0. ... 0. 0.]
         ...
         [0. 0. ..