# Deep Imputation of BraTS dataset with MONAI

The dataset comes from http://medicaldecathlon.com/.  
Modality: Multimodal multisite MRI data (FLAIR, T1w, T1gd, T2w)  
Size: 750 4D volumes (484 Training + 266 Testing)  
Source: BRATS 2016 and 2017 datasets.  
Challenge: **Drop some of the modalities randomly and reconstruct them by imputing with a 3D U-Net**

## Setup environment

In [None]:
# Install MONAI (with the nibabel and tqdm extras) and matplotlib only if importing them fails.
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
# !python -c "import onnxruntime" || pip install -q onnxruntime
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Setup imports

In [None]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism
# import onnxruntime
from tqdm import tqdm

import torch

print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified, a temporary directory will be used.

In [None]:
# Default data directory for this notebook. setdefault() keeps any value that
# was already exported in the environment, so the variable remains overridable
# instead of being silently replaced by this machine-specific path.
os.environ.setdefault("MONAI_DATA_DIRECTORY", "/scratch1/sachinsa/monai_data_1")

In [None]:
# Resolve the working directory: use MONAI_DATA_DIRECTORY when it is set
# (creating it if needed), otherwise fall back to a fresh temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    os.makedirs(directory, exist_ok=True)
    root_dir = directory
print(root_dir)

## Set deterministic training for reproducibility

In [None]:
# Fix the random seed through MONAI's set_determinism so repeated runs are reproducible.
set_determinism(seed=0)

In [None]:
# Two independent handles on the "training" section of Task01_BrainTumour.
# No transform is passed here; each dataset gets its own pipeline assigned
# later via `.transform`. cache_rate=0.0 means no items are cached.
train_ds_1 = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    # transform=train_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)

# Same configuration as train_ds_1; used to try a second transform pipeline.
train_ds_2 = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    # transform=train_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)

In [None]:
import numpy as np

def generate_boolean_list(num_channels=4, drop_prob=0.2):
    """Randomly pick the channel indices to drop, never selecting all of them.

    Each channel is selected independently with probability ``drop_prob``.
    A draw that selects every channel is rejected and redrawn, so at least
    one channel is always kept.

    Args:
        num_channels: number of channels to choose from (default 4).
        drop_prob: per-channel selection probability, in ``[0, 1)`` (default 0.2).

    Returns:
        A list of the selected channel indices, in ascending order; may be empty.

    Raises:
        ValueError: if ``num_channels < 1`` or ``drop_prob`` is outside ``[0, 1)``
            (with ``drop_prob >= 1`` every draw would be rejected and the loop
            would never terminate).
    """
    if num_channels < 1:
        raise ValueError(f"num_channels must be >= 1, got {num_channels}")
    if not 0.0 <= drop_prob < 1.0:
        raise ValueError(f"drop_prob must be in [0, 1), got {drop_prob}")

    while True:
        # np.random.seed(42)
        drop_mask = np.random.rand(num_channels) < drop_prob

        # Accept only draws that leave at least one channel untouched.
        if not drop_mask.all():
            return np.flatnonzero(drop_mask).tolist()

# Quick sanity check: draw one random set of indices and display it.
true_indices = generate_boolean_list()
print("Indices where True:", true_indices)

Challenge (unsolved): How can we ensure that only the input images have some modalities dropped, while the target (output) images keep all modalities intact?

In [None]:
from monai.transforms import (
    RandCoarseDropoutD,
)

class RandomDropd(MapTransform):
    """Dictionary transform that zeroes out a random subset of channels.

    For every configured key, a set of channel indices is drawn with
    ``generate_boolean_list()`` and those entries along the first axis are
    set to 0. The draw never selects all channels.
    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` with random channels zeroed.

        Args:
            data: mapping from keys to arrays/tensors indexed by channel on
                their first axis.

        Returns:
            A new dict. Entries whose channels were dropped are fresh copies,
            so the arrays held by the input mapping are left untouched.
        """
        d = dict(data)
        for key in self.keys:
            drop_indices = generate_boolean_list()
            if not drop_indices:
                continue
            # dict(data) is only a shallow copy: writing into d[key] directly
            # would also modify the caller's array, so copy before zeroing.
            image = d[key]
            image = image.clone() if isinstance(image, torch.Tensor) else image.copy()
            image[drop_indices] = 0
            d[key] = image
        return d

# SACHIN NOTE: Only "image", "label" are supported keys
# Image-only pipeline that includes the random channel drop. Note that the
# drop is applied to the single "image" entry, which is also what serves as
# the reconstruction target downstream.
train_ds_1.transform = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        EnsureTyped(keys=["image"]),
        # zero out a random subset of channels
        RandomDropd(keys=["image"]), #, prob=0.1),
        Orientationd(keys=["image"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image"],
        #     pixdim=(1.0, 1.0, 1.0),
        #     mode=("bilinear", "nearest"),
        # ),
        # RandCoarseDropoutD(keys=["image"], holes = 2, spatial_size = (96, 96, 96), fill_value = 0, prob=1.0),
        # fixed-size random crop followed by random flips along each spatial axis
        RandSpatialCropd(keys=["image"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=2),
        # per-channel intensity normalisation, then random intensity scale/shift
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys=["image"], factors=0.1, prob=1.0),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=1.0),
    ]
)

# Display slice 60 (last axis) of the first four channels of sample 2 from train_ds_1.
val_data_example = train_ds_1[2]['image']
print(f"image shape: {val_data_example.shape}")
plt.figure("image", (24, 6))
for i in range(4):
    channel_slice = val_data_example[i, :, :, 60].detach().cpu()
    plt.subplot(1, 4, i + 1)
    plt.title(f"image channel {i}")
    plt.imshow(channel_slice, cmap="gray")
    plt.colorbar()
plt.show()


######################################

class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Turn an integer BraTS label map into three stacked binary channels.

    Label values:
        1 - peritumoral edema
        2 - GD-enhancing tumor
        3 - necrotic and non-enhancing tumor core

    Output channels, in order:
        TC (Tumor core)      = label 2 or label 3
        WT (Whole tumor)     = label 1, 2 or 3
        ET (Enhancing tumor) = label 2
    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            label = d[key]
            enhancing = label == 2
            # TC: union of labels 2 and 3
            tumor_core = torch.logical_or(enhancing, label == 3)
            # WT: TC plus label 1
            whole_tumor = torch.logical_or(tumor_core, label == 1)
            # stack as a new leading channel axis and convert the masks to float
            d[key] = torch.stack([tumor_core, whole_tumor, enhancing], axis=0).float()
        return d
    
# Image + label pipeline for the second dataset handle. Spatial transforms are
# applied to both keys; intensity transforms are applied to the image only.
train_ds_2.transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        # the label gets its channel axis here (three binary channels)
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # resample to 1.0 spacing: bilinear for the image, nearest for the label
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)

# Display slice 60 (last axis) of the first four channels of sample 2 from train_ds_2.
val_data_example = train_ds_2[2]['image']
print(f"image shape: {val_data_example.shape}")
plt.figure("image", (24, 6))
for i in range(4):
    channel_slice = val_data_example[i, :, :, 60].detach().cpu()
    plt.subplot(1, 4, i + 1)
    plt.title(f"image channel {i}")
    plt.imshow(channel_slice, cmap="gray")
    plt.colorbar()
plt.show()

## Setup transforms for training and validation


In [None]:
# Training pipeline: image only. No channel dropping happens here; the
# RandomDropd step below is still commented out.
train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        EnsureTyped(keys=["image"]),
        # RandomDropd(keys=["in_image"], prob=0.1), # TODO
        Orientationd(keys=["image"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image"],
        #     pixdim=(1.0, 1.0, 1.0),
        #     mode=("bilinear", "nearest"),
        # ),
        # fixed-size random crop, then random flips along each spatial axis
        RandSpatialCropd(keys=["image"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=2),
        # per-channel normalisation followed by random intensity scale/shift
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys=["image"], factors=0.1, prob=1.0),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=1.0),
    ]
)
# Validation pipeline: deterministic. It has no cropping and no random
# augmentation, so validation volumes keep their full spatial size.
val_transform = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        EnsureTyped(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        # Spacingd(
        #     keys=["image"],
        #     pixdim=(1.0, 1.0, 1.0),
        #     mode=("bilinear", "nearest"),
        # ),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    ]
)

## Quickly load data with DecathlonDataset

Here we use `DecathlonDataset` to automatically download and extract the dataset.
It inherits from MONAI `CacheDataset`. To use less memory, you can set `cache_num=N` to cache only N items for training and use the default arguments to cache all the items for validation; the right choice depends on your memory size.

In [None]:
# here we don't cache any data in case out of memory issue
# Training split with the training pipeline; this call downloads the data if needed.
train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)
# Validation split of the same task; download=False because the data was fetched above.
val_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

Consider a subset of train and validation dataset for debugging the training workflow

In [None]:
# Only Subset is taken from torch. DataLoader is deliberately NOT re-imported
# from torch.utils.data, because that would shadow the MONAI DataLoader
# imported at the top of the notebook for every later cell.
from torch.utils.data import Subset

# Small fixed subsets (first 30 training / first 20 validation items) for
# quickly debugging the training workflow.
train_subs = Subset(train_ds, list(range(30)))
val_subs = Subset(val_ds, list(range(20)))

# These replace the full-dataset loaders created earlier.
train_loader = DataLoader(train_subs, batch_size=1, shuffle=True, num_workers=2)
val_loader = DataLoader(val_subs, batch_size=1, shuffle=False, num_workers=2)
print(len(train_loader), len(val_loader))

## Check data shape and visualize

In [None]:
# Take one validation sample and display slice 60 (last axis) of its first
# four channels as a visual sanity check.
val_data_example = val_ds[2]
print(f"image shape: {val_data_example['image'].shape}")
plt.figure("image", (24, 6))
for i in range(4):
    channel_slice = val_data_example["image"][i, :, :, 60].detach().cpu()
    plt.subplot(1, 4, i + 1)
    plt.title(f"image channel {i}")
    plt.imshow(channel_slice, cmap="gray")
    plt.colorbar()
plt.show()

## Create Model, Loss, Optimizer

**Define a 3D Unet**

In [None]:
from monai.networks.nets import UNet

# NOTE(review): the device is hard-coded to the first CUDA GPU; there is no CPU fallback.
device = torch.device("cuda:0")
# Small 3D U-Net: three feature levels (4, 8, 16 channels) joined by two stride-2 steps.
model = UNet(
    spatial_dims=3, # 3D
    in_channels=4,
    out_channels=8, # 4 estimated means + 4 estimated log std devs (the loss reads channels 4: as log sigma)
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2
).to(device)

In [None]:
def GaussianLikelihood(expected_img, output_img):
    """Gaussian negative log-likelihood (additive constant dropped), averaged over all elements.

    ``output_img`` is split along dim 1: channels ``[:4]`` are the predicted
    mean and channels ``[4:]`` the predicted log standard deviation. Working
    with log-sigma keeps sigma = exp(log_sigma) strictly positive, which is
    how the sigma <= 0 issue is avoided.

    Args:
        expected_img: target tensor; must broadcast against the four mean channels.
        output_img: network output holding the mean and log-std channels.

    Returns:
        Scalar tensor: mean of ``(target - mean)^2 / (2 * sigma^2) + log_sigma``.
    """
    mean_pred, log_sigma = output_img[:, :4, ...], output_img[:, 4:, ...]

    # 2 * sigma^2, written with sigma = exp(log_sigma)
    two_variance = 2*torch.exp(2*log_sigma)
    squared_residual = (expected_img - mean_pred)**2

    return torch.mean(squared_residual / two_variance + log_sigma)





import torch.nn as nn
from monai.metrics import MSEMetric

# Training schedule: validation runs every `val_interval` epochs.
max_epochs = 20 # 300
val_interval = 4
# When True, inference() runs the model under CUDA autocast.
VAL_AMP = True

# Define the loss function
loss_function = GaussianLikelihood #nn.MSELoss()
# loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
# Cosine learning-rate schedule spanning the whole run.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# mse_metric = DiceMetric(include_background=True, reduction="mean")
# mse_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

# Validation metrics: mean squared error (lower is better).
mse_metric = MSEMetric(reduction="mean")
mse_metric_batch = MSEMetric(reduction="mean_batch")

# NOTE(review): sigmoid + threshold post-processing; every use of post_trans in
# the visible cells is commented out.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input):
    """Run ``model`` on ``input``, under CUDA autocast when ``VAL_AMP`` is true.

    Args:
        input: batch passed unchanged to ``model``.

    Returns:
        The raw model output.
    """
    if not VAL_AMP:
        return model(input)

    with torch.cuda.amp.autocast():
        return model(input)


# use amp to accelerate training
# (the scaler wraps the backward pass and optimizer step in the training loop)
scaler = torch.cuda.amp.GradScaler()
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True



In [None]:
import pdb; 

In [None]:
# The validation metric is an MSE, so lower is better: start from +inf and
# keep the smallest value seen.
best_metric = float("inf")
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

# Channels blanked out of the network input. Fixed for now; switch to
# generate_boolean_list() for a random draw per step.
mask_indices = [0, 2]

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        # The complete image is the reconstruction target. The network input
        # is an explicit copy, so zeroing channels can never leak into the target.
        outputs_gt = batch_data["image"].to(device)
        inputs = outputs_gt.clone()
        inputs[:, mask_indices, ...] = 0
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_function(outputs_gt, outputs)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_loader)}"
            f", train_loss: {loss.item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                # Evaluate on the validation batch itself (not the last
                # training batch), masked the same way as the training input
                # so the metric measures imputation quality.
                val_outputs_gt = val_data["image"].to(device)
                val_inputs = val_outputs_gt.clone()
                val_inputs[:, mask_indices, ...] = 0
                # Validation volumes are not cropped by val_transform, so run
                # the network on training-sized windows and stitch the result
                # back to the full input size.
                val_outputs = sliding_window_inference(
                    inputs=val_inputs,
                    roi_size=(224, 224, 144),
                    sw_batch_size=1,
                    predictor=inference,
                    overlap=0.5,
                )
                # keep only the predicted means (first 4 output channels)
                val_outputs = val_outputs[:, :4, ...]
                # val_outputs = [post_trans(i) for i in val_outputs]
                mse_metric(y_pred=val_outputs, y=val_outputs_gt)
                mse_metric_batch(y_pred=val_outputs, y=val_outputs_gt)

            metric = mse_metric.aggregate().item()
            metric_values.append(metric)
            metric_batch = mse_metric_batch.aggregate()
            mse_metric.reset()
            mse_metric_batch.reset()

            # A smaller MSE is an improvement.
            if metric < best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(root_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean mse: {metric:.4f}"
                f"\nbest mean metric: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

In [None]:
# Summarise the run: best validation metric, the epoch it occurred at, and total wall-clock time.
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")

## Plot the loss and metric

In [None]:
# Left panel: average training loss per epoch.
# Right panel: validation MSE, recorded once every `val_interval` epochs.
plt.figure("train", (12, 6))

plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y, color="red")

plt.subplot(1, 2, 2)
plt.title("Val Mean MSE")
x = [val_interval * k for k in range(1, len(metric_values) + 1)]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y, color="green")

plt.show()

## Check best pytorch model output with the input image and label

In [None]:
# NOTE(review): val_output is only assigned in the model-evaluation cell further down; run that cell first.
val_output.shape

In [None]:
# NOTE(review): val_input is only assigned in the model-evaluation cell further down; run that cell first.
val_input.shape

In [None]:
# Restore the best checkpoint. map_location places the weights on `device`
# regardless of which device they were saved from.
model.load_state_dict(
    torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()
with torch.no_grad():
    # select one image to evaluate and visualize the model output
    val_input = val_ds[6]["image"].unsqueeze(0).to(device)
    mask_indices = [0, 2] # generate_boolean_list()
    val_input[:, mask_indices, ...] = 0
    val_output = inference(val_input)
    # val_output = post_trans(val_output[0])

    # network input: the masked channels appear blank
    plt.figure("image", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"image channel {i}")
        plt.imshow(val_input[0, i, :, :, 70].detach().cpu(), cmap="gray")
    plt.show()

    # output channels 0-3: predicted mean of each image channel
    # (.float() because the output may be half precision under autocast)
    plt.figure("output mean", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"output channel {i}")
        plt.imshow(val_output[0, i, :, :, 70].detach().float().cpu(), cmap="gray")
    plt.show()

    # output channels 4-7 hold the predicted LOG std dev (that is how
    # GaussianLikelihood reads them), so exponentiate to display the std dev itself
    plt.figure("output std", (24, 6))
    for i in range(4):
        plt.subplot(1, 4, i + 1)
        plt.title(f"output channel {i}")
        predicted_std = torch.exp(val_output[0, i + 4, :, :, 70].detach().float())
        plt.imshow(predicted_std.cpu(), cmap="gray")
    plt.show()

## Cleanup data directory

Remove the directory if a temporary one was used.

In [None]:
# `directory` is None only when MONAI_DATA_DIRECTORY was not set, in which case
# root_dir came from tempfile.mkdtemp() and can be deleted.
if directory is None:
    shutil.rmtree(root_dir)