# Pre-processing Volumes

In [1]:
#|default_exp preprocessing_volumes

In [2]:
import sys
sys.path.append("../lib")

In [3]:
#|export
from dicom_to_nifti import *

We need to load the CSV containing the series UIDs and corresponding metadata, particularly, the imaging modality, that defines the domain of a series and is used to normalize the spacing between the voxels of each volume:

In [4]:
import os
import polars as pl

base_path_dicom = os.environ['RSNA_IAD_DATA_DIR']
series_path_dicom = f"{base_path_dicom}/series"
base_path_nifti = f"{os.environ['RSNA_IAD_DATA_DIR']}/nifti"
series_path_nifti = f"{base_path_nifti}/series"

df = pl.read_csv(f"{base_path_dicom}/train.csv")
display(df.head(2))

LOCATION_LABELS_COLNAME = [
    'Left Infraclinoid Internal Carotid Artery',
    'Right Infraclinoid Internal Carotid Artery',
    'Left Supraclinoid Internal Carotid Artery',
    'Right Supraclinoid Internal Carotid Artery',
    'Left Middle Cerebral Artery',
    'Right Middle Cerebral Artery',
    'Anterior Communicating Artery',
    'Left Anterior Cerebral Artery',
    'Right Anterior Cerebral Artery',
    'Left Posterior Communicating Artery',
    'Right Posterior Communicating Artery',
    'Basilar Tip',
    'Other Posterior Circulation',
]

LABELS_COLNAME =  LOCATION_LABELS_COLNAME + ['Aneurysm Present']

SeriesInstanceUID,PatientAge,PatientSex,Modality,Left Infraclinoid Internal Carotid Artery,Right Infraclinoid Internal Carotid Artery,Left Supraclinoid Internal Carotid Artery,Right Supraclinoid Internal Carotid Artery,Left Middle Cerebral Artery,Right Middle Cerebral Artery,Anterior Communicating Artery,Left Anterior Cerebral Artery,Right Anterior Cerebral Artery,Left Posterior Communicating Artery,Right Posterior Communicating Artery,Basilar Tip,Other Posterior Circulation,Aneurysm Present
str,i64,str,str,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
"""1.2.826.0.1.3680043.8.498.1000…",64,"""Female""","""MRA""",0,0,0,0,0,0,0,0,0,0,0,0,0,0
"""1.2.826.0.1.3680043.8.498.1000…",76,"""Female""","""MRA""",0,0,0,0,0,0,0,0,0,0,0,0,0,0


We need some functions to load NifTI files:

In [5]:
#|export
import nibabel
import torch

def nifti_load(series_base_path, serie_uid):
    nifti = nibabel.load(f"{series_base_path}/{serie_uid}.nii.gz")
    return nifti

def nifti_process(nifti):
    volume = torch.from_numpy(nifti.get_fdata()).to(torch.float32)
    metadata = nifti.header
    return volume, metadata 

def nifti_get_spacing(metadata):
    return torch.tensor(nifti.header.get_zooms())

In [6]:
# Get volume from DICOM and corresponding nifti to test transformations on
serie_uid = df["SeriesInstanceUID"][0]
ds_l = dicom_serie_load(series_path_dicom, serie_uid)
volume, ds_metadata_l = dicom_serie_process(ds_l)
dicom_volume_to_nifti(volume, ds_metadata_l, serie_uid, series_path_nifti)

In [7]:
nifti = nifti_load(series_path_nifti, serie_uid)
volume, nifti_metadata = nifti_process(nifti)

In [8]:
volume = volume.unsqueeze(0).unsqueeze(0)

In [9]:
volume.shape

torch.Size([1, 1, 188, 512, 512])

In [10]:
nifti_get_spacing(nifti_metadata)

tensor([0.3516, 0.3516, 0.5000])

All transformations are applied to a whole volume.

We need functions to plot either a single volume, or two volumes, to see the original vs the transformed one:

In [11]:
#|export
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib widget

def plot_volume(volume, fig, ax):

    volume = volume[0,0]

    plt.ioff()
    if fig is None or ax is None:
        fig, ax = plt.subplots(figsize=(5,5))

    ax_img = ax.imshow(volume[0])  # Create and cache plot element to modify

    slider = widgets.IntSlider(value=0, min=0, max=volume.shape[0]-1, step=1, description='Slice', readout=True, readout_format='d')

    def update(slice_num):
        ax_img.set_data(volume[slice_num])
        fig.canvas.draw_idle()

    slider.observe(lambda change: update(change["new"]), names="value")

    return widgets.VBox([fig.canvas, slider], layout=widgets.Layout(align_items="center"))

In [12]:
plot_volume(volume, None, None)

VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…

In [13]:
#|export
def plot_2volumes(volume1, volume2):

    ui1 = plot_volume(volume1, None, None)
    ui2 = plot_volume(volume2, None, None)
    return widgets.HBox([ui1, ui2])

The `NormalizeSpacing` transformation is to normalize the spacing between the voxels, by: using the pre-calculated voxel spacing by domain; calculating the correct volume size as specified in the previous post "Pre-calculating voxel spacings"; and interpolating the voxels using PyTorch's `interpolate` function:

In [14]:
#|export
import torch.nn.functional as F

class NormalizeSpacing:

    def __init__(self, interp_mode, domain_spacings_dict, get_metadata):
        self.interp_mode = interp_mode
        self.domain_spacings_dict = domain_spacings_dict
        self.get_metadata = get_metadata        

    def transform(self, volume, domain, spacing):

        assert volume.dtype == torch.float32, "Pixel array data has to be of type torch.float32."
        assert len(volume.shape) == 5, "Volume array must have 5 dimensions, i.e. (batch_size, channel_count, z, y, x)."
                
        domain_spacing = torch.tensor(self.domain_spacings_dict[domain])
        target_size = (torch.tensor(volume.shape[2:], dtype=torch.float32) / spacing * domain_spacing).to(torch.int32)
        volume = F.interpolate(volume, size=tuple(target_size), mode=self.interp_mode)
    
        return volume

The `get_metadata` function is defined in transformations that need to get metadata from the DICOMs of a series or from other sources. This function needs to get the modality (domain) from the previously loaded DataFrame, and the spacing from the DICOMs:

In [15]:
volume_domain_median_spacing_dict = {
    "CTA": (0.46875, 0.46875, 0.8),
    "MRA": (0.410156, 0.410156, 0.6),
    "MRI T1post": (0.5, 0.5, 1.2),
    "MRI T2": (0.5, 0.5, 5.)
}

# DICOM
def get_metadata_dicom(idx, ds_metadata_l):
    modality = df["Modality"][idx]
    spacing = dicom_serie_get_spacing(ds_metadata_l)
    return modality, spacing

transform = NormalizeSpacing("nearest", volume_domain_median_spacing_dict, get_metadata_dicom)
transform_metadata = transform.get_metadata(0, ds_metadata_l)
transformed_volume = transform.transform(volume, *transform_metadata)
print(transformed_volume.shape)

# NifTI
def get_metadata_nifti(idx, metadata):
    modality = df["Modality"][idx]
    spacing = nifti_get_spacing(metadata)
    return modality, spacing

transform = NormalizeSpacing("nearest", volume_domain_median_spacing_dict, get_metadata_nifti)
transform_metadata = transform.get_metadata(0, nifti_metadata)
transformed_volume = transform.transform(volume, *transform_metadata)
print(transformed_volume.shape)

torch.Size([1, 1, 219, 597, 614])
torch.Size([1, 1, 219, 597, 614])


In [16]:
plot_2volumes(volume, transformed_volume)

HBox(children=(VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'…

The `ResizeInterp` resizes the image by interpolating with the PyTorch's `interpolate` function. It can interpolate voxels using nearest neighbors ("nearest"/"nearest-exact"), "trilinear" and "area" (average pooling):

In [17]:
#|export
class ResizeInterp:

    get_metadata = None
    
    def __init__(self, target_size, mode):
        self.target_size = target_size
        self.mode = mode    

    def transform(self, volume):
    
        assert volume.dtype == torch.float32, "Pixel array data has to be of type torch.float32"
        
        volume = F.interpolate(volume, size=self.target_size, mode=self.mode)
        return volume

In [18]:
transform = ResizeInterp((32, 224, 224), "nearest")
transformed_volume = transform.transform(volume)
print(transformed_volume.shape)

torch.Size([1, 1, 32, 224, 224])


In [19]:
plot_2volumes(volume, transformed_volume)

HBox(children=(VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'…

The `PercentileCropIntensity` transformation clamps voxel values between two percentiles, so outliers are removed:

In [20]:
#|export
import numpy as np

class PercentileCropIntensity:

    get_metadata = None

    def __init__(self, percentiles):
        self.percentiles = percentiles
    
    def transform(self, volume):
        percentiles = np.percentile(volume.flatten().detach().cpu().numpy(), self.percentiles)
        volume = torch.clamp(volume, min=percentiles[0], max=percentiles[1])
        return volume

In [21]:
transform = PercentileCropIntensity((0.5, 99.5))
transformed_volume = transform.transform(volume)
print(volume.min(), volume.max())
print(transformed_volume.min(), transformed_volume.max())

tensor(0.) tensor(682.)
tensor(0.) tensor(131.)


In [22]:
plot_2volumes(volume, transformed_volume)

HBox(children=(VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'…

The `StandardizeIntensity` transformation normalizes the voxel values so they have mean 0 and standard deviation 1:

In [23]:
#|export
class StandardizeIntensity:

    get_metadata = None
    
    def transform(self, volume):

        assert volume.dtype == torch.float32, \
        "Pixel array data has to be of type torch.float32"
    
        volume = (volume - volume.mean()) / volume.std()
        
        return volume

In [24]:
transform = StandardizeIntensity()
transformed_volume = transform.transform(volume)
print(transformed_volume.mean(), transformed_volume.std())

tensor(-5.0727e-09) tensor(1.)


In [25]:
plot_2volumes(volume, transformed_volume)

HBox(children=(VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'…

Now let's use the PyTorch `Dataset` class to later use the `DataLoader` class to load and transform the volumes in parallel.
First we implement `Dataset` classes that load and transform a single series:

In [29]:
#|export
import torch

class Dataset(torch.utils.data.Dataset):

    def __init__(self, path, series_uid_l, labels, transforms):
        self.path = path
        self.series_uid_l = series_uid_l
        self.labels = labels
        self.transforms = transforms

        self.n = len(self.series_uid_l)
    
    def __len__(self):
        return self.n

In [30]:
#|export
import pydicom

class DicomDataset(Dataset):

    def __init__(self, path, series_uid_l, labels, transforms):
        super().__init__(path, series_uid_l, labels, transforms)

    def __getitem__(self, idx):

        serie_uid = self.series_uid_l[idx]
        label = self.labels[idx]

        ds_l = dicom_serie_load(self.path, serie_uid)
        volume, ds_metadata_l = dicom_serie_process(ds_l)
        volume = volume.unsqueeze(0).unsqueeze(0)

        for transform in self.transforms:
            if transform.get_metadata is not None:
                transform_metadata = transform.get_metadata(idx, ds_metadata_l)
                volume = transform.transform(volume, *transform_metadata)
            else:
                volume = transform.transform(volume)

        return volume[0], label

In [31]:
transforms = [
    NormalizeSpacing("trilinear", volume_domain_median_spacing_dict, get_metadata_dicom),
    PercentileCropIntensity(percentiles=(0.5, 99.5)),
    StandardizeIntensity(), 
    ResizeInterp((32, 224, 224), "nearest")
]

In [32]:
dicom_dataset = DicomDataset(series_path_dicom, 
                             list(df["SeriesInstanceUID"]), 
                             torch.from_numpy(df[LABELS_COLNAME].to_numpy()),
                             transforms)

volume, label = next(iter(dicom_dataset))
print(volume.shape, label)

torch.Size([1, 32, 224, 224]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [35]:
plot_volume(volume.unsqueeze(0), None, None)

VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…

In [36]:
#|export
class NiftiDataset(Dataset):

    def __init__(self, path, series_uid_l, labels, transforms):
        super().__init__(path, series_uid_l, labels, transforms)

    def __getitem__(self, idx):

        serie_uid = self.series_uid_l[idx]
        label = self.labels[idx]

        nifti = nifti_load(self.path, serie_uid)
        volume, metadata = nifti_process(nifti)
        volume = volume.unsqueeze(0).unsqueeze(0)

        for transform in self.transforms:
            if transform.get_metadata is not None:
                transform_metadata = transform.get_metadata(idx, metadata)
                volume = transform.transform(volume, *transform_metadata)
            else:
                volume = transform.transform(volume)

        return volume[0], label

In [39]:
transforms = [
    NormalizeSpacing("trilinear", volume_domain_median_spacing_dict, get_metadata_nifti),
    PercentileCropIntensity(percentiles=(0.5, 99.5)),
    StandardizeIntensity(), 
    ResizeInterp((32, 224, 224), "nearest")
]

In [40]:
nifti_dataset = NiftiDataset(series_path_nifti, 
                             list(df["SeriesInstanceUID"]), 
                             torch.from_numpy(df[LABELS_COLNAME].to_numpy()),
                             transforms)

volume, label = next(iter(nifti_dataset))
print(volume.shape, label)

torch.Size([1, 32, 224, 224]) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [41]:
plot_volume(volume.unsqueeze(0), None, None)

VBox(children=(Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Ba…

Now lets build the DataLoaders to load and transform the data in parallel and, optionally, returns it in batches.

In [43]:
#|export
import multiprocessing

class DicomDataLoader:

    def __init__(self, path, series_uid_l, labels, unit_transforms,
                 shuffle, batch_size, pin_memory, num_workers):
        dataset = DicomDataset(path, series_uid_l, labels, unit_transforms)
        self.loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                  shuffle=shuffle, pin_memory=pin_memory,
                                                  num_workers=num_workers)

    def __iter__(self):
        for volumes, labels in self.loader:
            yield volumes, labels

In [44]:
transforms = [
    NormalizeSpacing("trilinear", volume_domain_median_spacing_dict, get_metadata_dicom),
    PercentileCropIntensity(percentiles=(0.5, 99.5)),
    StandardizeIntensity(), 
    ResizeInterp((32, 224, 224), "nearest")
]

dicom_loader = DicomDataLoader(series_path_dicom, 
                               list(df["SeriesInstanceUID"]), 
                               torch.from_numpy(df[LABELS_COLNAME].to_numpy()),
                               transforms, 
                               shuffle=True, batch_size=4, pin_memory=True, num_workers=4)

In [45]:
for volumes, labels in dicom_loader:
    print(volumes.shape, labels)
    break



torch.Size([4, 1, 32, 224, 224]) tensor([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])


In [47]:
#|export
class NiftiDataLoader:

    def __init__(self, path, series_uid_l, labels, unit_transforms,
                 shuffle, batch_size, pin_memory, num_workers):
        dataset = NiftiDataset(path, series_uid_l, labels, unit_transforms)
        self.loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                  shuffle=shuffle, pin_memory=pin_memory,
                                                  num_workers=num_workers)

    def __iter__(self):
        for volumes, labels in self.loader:
            yield volumes, labels

In [48]:
transforms = [
    NormalizeSpacing("trilinear", volume_domain_median_spacing_dict, get_metadata_nifti),
    PercentileCropIntensity(percentiles=(0.5, 99.5)),
    StandardizeIntensity(), 
    ResizeInterp((32, 224, 224), "nearest")
]

nifti_loader = NiftiDataLoader(series_path_nifti, 
                               list(map(lambda x: x[:-7], os.listdir(series_path_nifti))), 
                               torch.from_numpy(df[LABELS_COLNAME].to_numpy()),
                               transforms, 
                               shuffle=True, batch_size=4, pin_memory=True, num_workers=4)

In [49]:
for volumes, labels in nifti_loader:
    print(volumes.shape, labels)
    break

torch.Size([4, 1, 32, 224, 224]) tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])


In [50]:
from nbdev.export import nb_export

In [51]:
nb_export("4_preprocessing_volumes.ipynb", "../lib")