# BraTS Data Augmentation Using the NVIDIA DALI Library

> Because the NVIDIA DALI library can augment input images on the GPU, this notebook tests and verifies whether the library is usable in our project.

In [2]:
import random
import itertools
import os
from typing import Tuple, Literal

import json

import matplotlib.pyplot as plt
from pathlib import Path
import os
from typing import Dict

import torch
import torchio as tio
import numpy as np
import nibabel as nib
import nvidia.dali.fn as fn
import nvidia.dali.math as math
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline

from monai.transforms import CropForeground
from monai.transforms import SpatialPad

In [3]:
import SimpleITK as sitk
from ipywidgets import interact, interactive, IntSlider, ToggleButtons

In [4]:
def probabilistic_augmentation(probability, augmented, original):
    """Pick ``augmented`` when a random coin flip succeeds, else ``original``.

    The coin flip is cast to a boolean and used as a multiplicative weight:
    ``augmented`` is multiplied by the flag and ``original`` by its negation,
    and the two products are summed, so only one of them contributes.

    Args:
        probability: Value forwarded to ``fn.random.coin_flip``.
        augmented: Operand kept when the coin flip is true.
        original: Operand kept when the coin flip is false.

    Returns:
        The expression ``flag * augmented + (flag ^ True) * original``.
    """
    coin = fn.random.coin_flip(probability=probability)
    use_augmented = fn.cast(coin, dtype=types.DALIDataType.BOOL)
    # XOR with True flips the boolean flag
    use_original = use_augmented ^ True

    return use_augmented * augmented + use_original * original

In [5]:
def parse_json_to_dict(data_descriptors_path:str, phase:str) -> Dict:
    """Load the ``<phase>.json`` descriptor stored in a descriptors directory.

    Args:
        data_descriptors_path: Directory holding the descriptor files. A
            relative path is resolved against the current working directory;
            an absolute path is used as is.
        phase: Descriptor file name without the ``.json`` extension.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: If the descriptor file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    json_path = Path.cwd() / data_descriptors_path / f"{phase}.json"
    # Explicit encoding: parsing must not depend on the platform locale
    with open(json_path, 'r', encoding="utf-8") as descriptor_file:
        return json.load(descriptor_file)

In [6]:
class NiftiIterator(object):
    """Endless, reshuffling iterator over (image, label) NIfTI pairs.

    Every step reads all modality volumes of one subject, z-score normalizes
    the non-zero voxels of each modality, concatenates the modalities along
    the first axis and reads the matching label volume. When ``crop`` is
    enabled, both are cut to the bounding box computed on the image and then
    passed through ``SpatialPad(spatial_size=(128, 128, 128))``.

    The paths are reshuffled each time the index is back at zero, so on a
    non-empty dataset iteration never stops by itself.
    """

    def __init__(self, image_paths:list, label_paths:list, crop:bool=True) -> None:
        """Store the paths and, if requested, build the crop/pad transforms.

        Args:
            image_paths: One entry per subject; each entry is an iterable of
                modality file paths.
            label_paths: One label file path per subject, aligned with
                ``image_paths``.
            crop: Whether to crop to the image foreground and pad afterwards.
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.crop = crop
        self.dataset_len = len(self.image_paths)
        # Position of the next sample. Set here as well as in __iter__ so
        # that __next__ works even when it is called before iter().
        self.idx = 0
        # Attribute to track image ids for each data pair
        self.seen_data = []

        if self.crop:
            # Objects to tackle image pre-processing
            self.cropper = CropForeground(
                select_fn=lambda x: x != 0, margin=0, return_coords=True
            )
            self.padder = SpatialPad(spatial_size=(128, 128, 128))

    def __iter__(self):
        """Restart from the first position and return the iterator itself."""
        self.idx = 0

        return self

    def __next__(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, normalize and optionally crop/pad the next subject.

        Returns:
            ``(image, label)`` where ``image`` is a float tensor holding the
            concatenated modalities and ``label`` is a uint8 tensor.

        Raises:
            StopIteration: If the dataset is empty.
        """
        if self.dataset_len == 0:
            # Nothing to serve; follow the iterator protocol instead of
            # failing later on the shuffle or the modulo below.
            raise StopIteration

        if self.idx == 0:
            self._shuffle(self.image_paths, self.label_paths)

        image_modalities_paths = self.image_paths[self.idx]
        label_path = self.label_paths[self.idx]
        # The parent directory name of the label file is recorded as the id
        self.seen_data.append(label_path.split("/")[-2])

        image_modalities_data = [
            self._normalize(self._read_nifti_image(path))
            for path in image_modalities_paths
        ]

        image = torch.cat(
            image_modalities_data, axis=0
        ).to(torch.float)
        label = self._read_nifti_image(label_path).to(torch.uint8)

        if self.crop:
            # The bounding box is computed on the image and applied to both
            # tensors so that image and label stay aligned.
            bbox_start, bbox_end = self.cropper.compute_bounding_box(image)
            image = image[
                0:len(image_modalities_paths),
                bbox_start[0]:bbox_end[0],
                bbox_start[1]:bbox_end[1],
                bbox_start[2]:bbox_end[2]
            ].contiguous()
            label = label[
                :1,
                bbox_start[0]:bbox_end[0],
                bbox_start[1]:bbox_end[1],
                bbox_start[2]:bbox_end[2]
            ].contiguous()

            # Pad the cropped tensors with the (128, 128, 128) spatial padder
            image = self.padder(image).as_tensor()
            label = self.padder(label).as_tensor()

        # Move to the next sample; wrapping back to zero triggers a reshuffle
        # on the following call, which starts a new pass over the data.
        self.idx = (self.idx + 1) % self.dataset_len

        return image, label

    def _normalize(self, image):
        """Z-score the non-zero voxels of ``image``; zero voxels stay zero.

        The input tensor is never modified: the result is always a new float
        tensor. If the non-zero voxels have no usable standard deviation
        (a single voxel, or all voxels equal) they are only mean-centered,
        which avoids NaN/inf values in the output.

        Args:
            image: Tensor to normalize.

        Returns:
            A new float tensor with the same shape as ``image``.
        """
        # copy=True guarantees a fresh tensor even if the input is float
        normalized_image = image.to(torch.float, copy=True)
        foreground = normalized_image != 0
        non_zero_voxels = normalized_image[foreground]
        if non_zero_voxels.numel() == 0:
            return normalized_image

        centered = non_zero_voxels - torch.mean(non_zero_voxels)
        # torch.std is undefined for a single element, so skip it there
        std = torch.std(non_zero_voxels) if non_zero_voxels.numel() > 1 else None
        if std is not None and torch.isfinite(std) and std > 0:
            centered = centered / std
        normalized_image[foreground] = centered

        return normalized_image

    def _shuffle(self, image_paths, label_paths) -> None:
        """Shuffle image and label paths together and store the new order.

        Args:
            image_paths: Per-subject modality path lists.
            label_paths: Per-subject label paths, aligned with ``image_paths``.
        """
        pairs = list(zip(image_paths, label_paths))
        random.shuffle(pairs)

        # Rebuilt with comprehensions so an empty dataset yields empty lists
        self.image_paths = [image_path for image_path, _ in pairs]
        self.label_paths = [label_path for _, label_path in pairs]

    def _read_nifti_image(self, image_path:str) -> torch.Tensor:
        """Return the data tensor of the image stored at ``image_path``."""
        image_data = tio.ScalarImage(image_path)[tio.DATA]

        return image_data

    def reset(self) -> None:
        """Drop the first ``dataset_len`` entries of ``seen_data``."""
        self.seen_data = self.seen_data[self.dataset_len:]

In [7]:
class GenericPipeline(Pipeline):
    """Base pipeline that resolves subject file paths and owns the iterator.

    The constructor reads the descriptor of the requested phase, derives the
    modality and label file paths of every subject and wraps them in a
    ``NiftiIterator`` stored as ``self.nift_iterator``.
    """

    def __init__(self, data_path:str, data_descriptors_path:str,
            phase: Literal["train", "validation", "test"], n_modalities:int,
            batch_size:int, num_threads:int, device_id:int, dim:int,
            patch_size:Tuple, load_to_gpu:bool, has_labels:bool, crop:bool
        ):
        """Store the configuration and build the NIfTI iterator.

        Args:
            data_path: Root directory with one sub-directory per subject.
            data_descriptors_path: Directory with the ``<phase>.json`` files.
            phase: Name of the descriptor to load.
            n_modalities: Number of modalities per subject (2 or 4).
            batch_size: Forwarded to the base ``Pipeline``.
            num_threads: Forwarded to the base ``Pipeline``.
            device_id: Forwarded to the base ``Pipeline``.
            dim: Stored as ``self.dim``; not used inside this class.
            patch_size: Crop shape used by ``_crop``.
            load_to_gpu: Stored as ``self.load_to_gpu``; not used inside
                this class.
            has_labels: Stored as ``self.has_labels``; not used inside this
                class.
            crop: Forwarded to ``NiftiIterator``.

        Raises:
            ValueError: If ``n_modalities`` is neither 2 nor 4.
        """
        super().__init__(batch_size, num_threads, device_id)
        self.data_path = data_path
        self.data_descriptors_path = data_descriptors_path
        self.phase = phase
        self.n_modalities = n_modalities
        self.dim = dim
        self.patch_size = patch_size
        self.load_to_gpu = load_to_gpu
        self.has_labels = has_labels

        image_paths, label_paths = self._get_image_paths(
            data_path=self.data_path,
            data_descriptors_path=self.data_descriptors_path,
            phase=self.phase,
            n_modalities=self.n_modalities
        )

        self.nift_iterator = NiftiIterator(
            image_paths=image_paths,
            label_paths=label_paths,
            crop=crop
        )

    def _get_image_paths(self, data_path:str, data_descriptors_path:str,
            phase:Literal["train", "validation", "test"],
            n_modalities:int
        ):
        """Build the modality and label file paths of every listed subject.

        The subject names come from the ``"ids"`` entry of the phase
        descriptor; each name is used both as the sub-directory and as the
        file name prefix.

        Args:
            data_path: Root directory with one sub-directory per subject.
            data_descriptors_path: Directory with the ``<phase>.json`` files.
            phase: Name of the descriptor to load.
            n_modalities: 2 selects flair and t1ce; 4 selects flair, t1,
                t1ce and t2.

        Returns:
            ``(image_paths, label_paths)``: a list with one list of modality
            paths per subject, and a list with one label path per subject.

        Raises:
            ValueError: If ``n_modalities`` is neither 2 nor 4.
        """
        subject_paths = parse_json_to_dict(
            data_descriptors_path=data_descriptors_path,
            phase=phase
        )["ids"]

        if n_modalities == 2:
            modalities = ["flair", "t1ce"]
        elif n_modalities == 4:
            modalities = ["flair", "t1", "t1ce", "t2"]
        else:
            raise ValueError(
                f"Number of Modalities must be 2 or 4. Received {n_modalities}"
            )

        image_paths = [
            [
                f"{data_path}/{subject_path}/{subject_path}_{modality}.nii.gz"
                for modality in modalities
            ]
            for subject_path in subject_paths
        ]
        label_paths = [
            f"{data_path}/{subject_path}/{subject_path}_seg.nii.gz"
            for subject_path in subject_paths
        ]

        return image_paths, label_paths

    def _crop(self, data):
        """Crop ``data`` to ``self.patch_size``, padding out-of-bounds areas."""
        return fn.crop(data, crop=self.patch_size, out_of_bounds_policy="pad")

    def _crop_fn(self, image, label):
        """Apply the same patch-size crop to ``image`` and ``label``."""
        # The helper is named ``_crop``; there is no ``crop`` method on
        # this class, so calling ``self.crop`` raised AttributeError.
        image, label = self._crop(image), self._crop(label)

        return image, label

class TrainPipeline(GenericPipeline):
    """Training pipeline: foreground-biased patch crop plus augmentations.

    ``define_graph`` chains, in order: biased crop, zoom, random flips,
    Gaussian noise, Gaussian blur, brightness scaling and contrast scaling.
    Only the crop, zoom and flips touch the label.
    """

    def __init__(self, data_path:str, data_descriptors_path:str,
            phase: Literal["train", "validation", "test"], n_modalities:int,
            batch_size:int, num_threads:int, device_id:int, dim:int,
            patch_size:Tuple, load_to_gpu:bool, has_labels:bool, crop:bool):
        """Forward every argument to ``GenericPipeline`` and cache the patch
        shape as integer and float constants."""
        super().__init__(
            data_path=data_path,
            data_descriptors_path=data_descriptors_path,
            phase=phase,
            n_modalities=n_modalities,
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=device_id,
            dim=dim,
            patch_size=patch_size,
            load_to_gpu=load_to_gpu,
            has_labels=has_labels,
            crop=crop,
        )
        # Integer version is used as slice shape, float version as resize size
        self.crop_shape = types.Constant(
            np.array(self.patch_size), dtype=types.INT64
        )
        self.crop_shape_float = types.Constant(
            np.array(self.patch_size), dtype=types.FLOAT
        )

    def _biased_crop_fn(self, image, label):
        """Cut a patch of ``patch_size`` around a randomly chosen label box.

        Returns:
            The image and label patches, both moved to the GPU.
        """
        # Random object bounding box on the label; foreground_prob is 0.4
        box_start, box_end = fn.segmentation.random_object_bbox(
            label,
            background=0,
            format="start_end",
            cache_objects=False,
            foreground_prob=0.4,
        )
        # Random crop window anchored relative to that box; the leading 1
        # in the crop shape stands for the channel axis.
        window_anchor = fn.roi_random_crop(
            label,
            roi_start=box_start,
            roi_end=box_end,
            crop_shape=[1] + list(self.patch_size),
        )
        # Keep only the three spatial coordinates of the anchor
        spatial_anchor = fn.slice(window_anchor, 1, 3, axes=[0])
        image_patch, label_patch = fn.slice(
            [image, label],
            spatial_anchor,
            self.crop_shape,
            axis_names="DHW",
            out_of_bounds_policy="pad",
        )

        return image_patch.gpu(), label_patch.gpu()

    def _resize(self, data, interpolation_type):
        """Resize ``data`` to the patch shape with the given interpolation."""
        return fn.resize(
            data, interp_type=interpolation_type, size=self.crop_shape_float
        )

    def _zoom_fn(self, image, label):
        """Crop by a random scale factor, then resize back to the patch shape.

        With probability 0.15 the factor is drawn uniformly from
        (1.0, 1.4); otherwise it is 1.0. The image is resized with cubic
        interpolation and the label with nearest neighbour.
        """
        drawn_scale = fn.random.uniform(range=(1.0, 1.4))
        zoom = probabilistic_augmentation(0.15, drawn_scale, 1.0)
        depth, height, width = (zoom * side for side in self.patch_size)

        window = {
            "crop_h": height,
            "crop_w": width,
            "crop_d": depth,
            "out_of_bounds_policy": "pad",
        }
        image = fn.crop(image, **window)
        label = fn.crop(label, **window)
        image = self._resize(image, types.DALIInterpType.INTERP_CUBIC)
        label = self._resize(label, types.DALIInterpType.INTERP_NN)

        return image, label

    def _flips_fn(self, image, label):
        """Flip image and label together; each axis flips with p = 0.5."""
        flip_horizontal = fn.random.coin_flip(probability=0.5)
        flip_vertical = fn.random.coin_flip(probability=0.5)
        flip_depthwise = fn.random.coin_flip(probability=0.5)

        # The same three decisions are shared by image and label
        flipped_image = fn.flip(
            image,
            horizontal=flip_horizontal,
            vertical=flip_vertical,
            depthwise=flip_depthwise,
        )
        flipped_label = fn.flip(
            label,
            horizontal=flip_horizontal,
            vertical=flip_vertical,
            depthwise=flip_depthwise,
        )

        return flipped_image, flipped_label

    def _noise_fn(self, image):
        """With probability 0.15 add normal noise with stddev in (0.0, 0.33)."""
        noise_std = fn.random.uniform(range=(0.0, 0.33))
        noisy = image + fn.random.normal(image, stddev=noise_std)

        return probabilistic_augmentation(0.15, noisy, image)

    def _blur_fn(self, image):
        """With probability 0.15 apply a Gaussian blur, sigma in (0.5, 1.5)."""
        blur_sigma = fn.random.uniform(range=(0.5, 1.5))
        blurred = fn.gaussian_blur(image, sigma=blur_sigma)

        return probabilistic_augmentation(0.15, blurred, image)

    def _brightness_fn(self, image):
        """Multiply the image by a factor that is 1.0 unless, with
        probability 0.15, it is drawn uniformly from (0.7, 1.3)."""
        drawn_factor = fn.random.uniform(range=(0.7, 1.3))
        factor = probabilistic_augmentation(0.15, drawn_factor, 1.0)

        return image * factor

    def _contrast_fn(self, image):
        """Scale the image and clamp it to its original min/max range.

        The factor is 1.0 unless, with probability 0.15, it is drawn
        uniformly from (0.65, 1.5).
        """
        drawn_factor = fn.random.uniform(range=(0.65, 1.5))
        factor = probabilistic_augmentation(0.15, drawn_factor, 1.0)
        scaled = image * factor
        lower = fn.reductions.min(image)
        upper = fn.reductions.max(image)

        return math.clamp(scaled, lower, upper)

    def define_graph(self):
        """Wire the external source through the crop and all augmentations.

        Returns:
            The ``(image, label)`` pair produced by the pipeline.
        """
        # One (float image, uint8 label) sample per call of the iterator
        image, label = fn.external_source(
            source=self.nift_iterator,
            num_outputs=2,
            dtype=[types.FLOAT, types.UINT8],
            batch=False,
        )
        image = fn.reshape(image, layout="CDHW")
        label = fn.reshape(label, layout="CDHW")

        # Spatial transforms are applied to image and label alike
        image, label = self._biased_crop_fn(image, label)
        image, label = self._zoom_fn(image, label)
        image, label = self._flips_fn(image, label)

        # Intensity transforms only apply to the image
        image = self._noise_fn(image)
        image = self._blur_fn(image)
        image = self._brightness_fn(image)
        image = self._contrast_fn(image)

        return (image, label)

In [8]:
# Number of samples per batch; the run loop below uses it to slice the
# recorded subject ids per step.
batch_size = 2

In [9]:
# Training pipeline over the BraTS 2021 training data with two modalities.
# The batch size comes from the ``batch_size`` variable so that the step
# count and the per-step id slicing in later cells stay consistent with it.
pipeline = TrainPipeline(
    data_path="../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021",
    data_descriptors_path="../src/data/descriptors/",
    phase="train", n_modalities=2, batch_size=batch_size, num_threads=1,
    device_id=0, dim=4, patch_size=(128, 128, 128), load_to_gpu=True,
    has_labels=True, crop=True
)

In [10]:
# Build the pipeline before the first call to run().
pipeline.build()

In [134]:
# Number of (image, label) path pairs held by the iterator.
pipeline.nift_iterator.dataset_len

875

In [135]:
# Steps needed to draw every sample once; rounded up so that a final,
# partially new batch is still counted. Uses ``batch_size`` rather than a
# literal so it cannot drift from the pipeline configuration.
n_steps = int(np.ceil(pipeline.nift_iterator.dataset_len / batch_size))

In [136]:
# Print the modality paths of the last three entries, oldest of them first.
for offset in (-3, -2, -1):
    print(pipeline.nift_iterator.image_paths[offset])

['../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_01624/BraTS2021_01624_flair.nii.gz', '../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_01624/BraTS2021_01624_t1ce.nii.gz']
['../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_01621/BraTS2021_01621_flair.nii.gz', '../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_01621/BraTS2021_01621_t1ce.nii.gz']
['../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_00545/BraTS2021_00545_flair.nii.gz', '../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_00545/BraTS2021_00545_t1ce.nii.gz']


In [137]:
# Modality paths of the last entry in the iterator's current path order.
pipeline.nift_iterator.image_paths[-1]

['../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_00545/BraTS2021_00545_flair.nii.gz',
 '../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_00545/BraTS2021_00545_t1ce.nii.gz']

In [138]:
# Run the pipeline for n_steps and report, per step, the batch shapes and
# the subject ids the iterator recorded for that batch.
for step in range(n_steps):
    output = pipeline.run()
    if step == 0:
        # After the first run the iterator has shuffled its paths
        print(pipeline.nift_iterator.image_paths[-1])
    # Unpack into ``images`` (not ``image``): that is the name printed below
    images, labels = output
    first = step * batch_size
    paths = pipeline.nift_iterator.seen_data[first : first + batch_size]
    print(60*"*")
    print(step)
    print(images.as_cpu().as_array().shape)
    print(labels.as_cpu().as_array().shape)
    print(paths)
    print(60*"*")
# Drop the first dataset_len recorded ids and show what is left over
pipeline.nift_iterator.reset()
print(pipeline.nift_iterator.seen_data)

['../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_01117/BraTS2021_01117_flair.nii.gz', '../datasets/RSNA_ASNR_MICCAI_BraTS2021_TrainingData_16July2021/BraTS2021_01117/BraTS2021_01117_t1ce.nii.gz']
************************************************************
0
(2, 2, 128, 128, 128)
(2, 1, 128, 128, 128)
['BraTS2021_00457', 'BraTS2021_00620']
************************************************************
************************************************************
1
(2, 2, 128, 128, 128)
(2, 1, 128, 128, 128)
['BraTS2021_01224', 'BraTS2021_01036']
************************************************************
************************************************************
2
(2, 2, 128, 128, 128)
(2, 1, 128, 128, 128)
['BraTS2021_00192', 'BraTS2021_01450']
************************************************************
************************************************************
3
(2, 2, 128, 128, 128)
(2, 1, 128, 128, 128)
['BraTS2021_01433', 'BraTS2021_01124']
***********

In [12]:
# Run the pipeline once and keep its outputs.
output = pipeline.run()

In [13]:
# The pipeline returns the image batch first and the label batch second.
images, labels = output

In [28]:
# Copy the image batch to the CPU, wrap it in a torch tensor and show its shape.
torch.Tensor(images.as_cpu().as_array()).shape

torch.Size([2, 2, 128, 128, 128])

In [11]:
# Subject ids recorded by the iterator so far. NiftiIterator stores them in
# ``seen_data``; it defines no ``visualized_subjects`` attribute.
pipeline.nift_iterator.seen_data

['BraTS2021_00646', 'BraTS2021_00590', 'BraTS2021_01645', 'BraTS2021_00132']

In [27]:
@interact
def generate_3d_image(
    layer = (0, 127),
    view = ["axial", "sagittal", "coronal"],
):
    """Show one slice of channel 0 of the second image in the batch.

    ``layer`` selects the slice index and ``view`` selects which array axis
    is indexed: "axial" indexes axis 0, "coronal" axis 1, "sagittal" axis 2.
    Any other ``view`` value leaves the volume unsliced.
    """
    # NOTE(review): whether axis 0 really is the axial direction depends on
    # the orientation of the loaded volumes - confirm.
    array_view = images.as_cpu().as_array()[1][0]
    keep_all = slice(None)
    index_by_view = (
        ("axial", (layer, keep_all, keep_all)),
        ("coronal", (keep_all, layer, keep_all)),
        ("sagittal", (keep_all, keep_all, layer)),
    )
    for view_name, plane_index in index_by_view:
        if view == view_name:
            array_view = array_view[plane_index]
            break

    plt.figure(figsize=(10, 5))
    plt.imshow(array_view, cmap="gray")
    plt.show()

interactive(children=(IntSlider(value=63, description='layer', max=127), Dropdown(description='view', options=…

In [28]:
@interact
def generate_3d_image(
    layer = (0, 127),
    view = ["axial", "sagittal", "coronal"],
):
    """Show one slice of channel 1 of the second image in the batch.

    ``layer`` selects the slice index and ``view`` selects which array axis
    is indexed: "axial" indexes axis 0, "coronal" axis 1, "sagittal" axis 2.
    Any other ``view`` value leaves the volume unsliced.
    """
    # NOTE(review): whether axis 0 really is the axial direction depends on
    # the orientation of the loaded volumes - confirm.
    array_view = images.as_cpu().as_array()[1][1]
    keep_all = slice(None)
    index_by_view = (
        ("axial", (layer, keep_all, keep_all)),
        ("coronal", (keep_all, layer, keep_all)),
        ("sagittal", (keep_all, keep_all, layer)),
    )
    for view_name, plane_index in index_by_view:
        if view == view_name:
            array_view = array_view[plane_index]
            break

    plt.figure(figsize=(10, 5))
    plt.imshow(array_view, cmap="gray")
    plt.show()

interactive(children=(IntSlider(value=63, description='layer', max=127), Dropdown(description='view', options=…

In [29]:
@interact
def generate_3d_image(
    layer = (0, 127),
    view = ["axial", "sagittal", "coronal"],
):
    """Show one slice of channel 0 of the second label in the batch.

    ``layer`` selects the slice index and ``view`` selects which array axis
    is indexed: "axial" indexes axis 0, "coronal" axis 1, "sagittal" axis 2.
    Any other ``view`` value leaves the volume unsliced.
    """
    # NOTE(review): whether axis 0 really is the axial direction depends on
    # the orientation of the loaded volumes - confirm.
    array_view = labels.as_cpu().as_array()[1][0]
    keep_all = slice(None)
    index_by_view = (
        ("axial", (layer, keep_all, keep_all)),
        ("coronal", (keep_all, layer, keep_all)),
        ("sagittal", (keep_all, keep_all, layer)),
    )
    for view_name, plane_index in index_by_view:
        if view == view_name:
            array_view = array_view[plane_index]
            break

    plt.figure(figsize=(10, 5))
    plt.imshow(array_view, cmap="gray")
    plt.show()

interactive(children=(IntSlider(value=63, description='layer', max=127), Dropdown(description='view', options=…