# Evaluate Brain Tumor Segmentation Data

In this notebook we will learn:
- how we can evaluate a pre-trained model checkpoint for brain tumor segmentation using MONAI and Weights & Biases.
- how we can visually compare the ground-truth labels with the predicted labels.

## 🌴 Setup and Installation

First, let us install the latest versions of both MONAI and Weights & Biases.

In [None]:
!pip install -q -U monai wandb

## 🌳 Initialize a W&B Run

We will start a new W&B run to start tracking our experiment.

In [None]:
import wandb

# Start a W&B run with job type "evaluate" in the
# "brain-tumor-segmentation" project of the "lifesciences" entity.
wandb.init(
    project="brain-tumor-segmentation",
    entity="lifesciences",
    job_type="evaluate"
)

# Handle to the run configuration; later cells record their settings on it.
config = wandb.config

In [None]:
from monai.utils import set_determinism

# Record the seed on the run config, then hand the same value to
# MONAI's set_determinism.
config.seed = 0
set_determinism(seed=config.seed)

## 💿 Loading and Transforming the Data

In [None]:
from utils import ConvertToMultiChannelBasedOnBratsClassesd
from monai.transforms import (
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)


# Preprocessing pipeline applied to each sample dictionary
# (entries "image" and "label"), in the order listed below.
transforms = Compose(
    [
        # Load the Nifti data referenced by the "image" and "label" entries
        LoadImaged(keys=["image", "label"]),
        # Ensure the loaded image is in channels-first format
        # (only the "image" key is listed here)
        EnsureChannelFirstd(keys="image"),
        # Ensure the data is a PyTorch Tensor or numpy array
        EnsureTyped(keys=["image", "label"]),
        # Convert labels to multi-channels based on brats18 classes
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # Reorient image and label to the orientation given by the "RAS" axis codes
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Resample image and label to the specified pixel dimension;
        # one interpolation mode is given per key ("bilinear", "nearest")
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        # Normalize input image intensity (nonzero=True, channel_wise=True)
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

In [None]:
from monai.apps import DecathlonDataset


# Declare the dataset artifact as an input of this run and download it
artifact = wandb.use_artifact(
    "lifesciences/brain-tumor-segmentation/decathlon_brain_tumor:latest", type="dataset"
)
artifact_dir = artifact.download()

# Create the dataset for the validation split
# of the brain tumor segmentation dataset.
# The data is read from the downloaded artifact directory, so
# download=False is passed.
val_dataset = DecathlonDataset(
    root_dir=artifact_dir,
    task="Task01_BrainTumour",
    transform=transforms,
    section="validation",
    download=False,
    cache_rate=0.0,
    num_workers=4,
)

In [None]:
import torch
from monai.networks.nets import SegResNet

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Track the network hyperparameters on the run config
config.model_blocks_down = [1, 2, 2, 4]
config.model_blocks_up = [1, 1, 1]
config.model_init_filters = 16
config.model_in_channels = 4
config.model_out_channels = 3
config.model_dropout_prob = 0.2

# Build the SegResNet from the tracked hyperparameters ...
model = SegResNet(
    blocks_down=config.model_blocks_down,
    blocks_up=config.model_blocks_up,
    init_filters=config.model_init_filters,
    in_channels=config.model_in_channels,
    out_channels=config.model_out_channels,
    dropout_prob=config.model_dropout_prob,
)
# ... and place it on the selected device
model = model.to(device)

In [None]:
import os

# Declare the trained checkpoint artifact as an input of this run and download it
model_artifact = wandb.use_artifact(
    "lifesciences/brain-tumor-segmentation/8vmqcqao-checkpoint:latest",
    type="model",
)
model_artifact_dir = model_artifact.download()

# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved from a GPU session still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
checkpoint_path = os.path.join(model_artifact_dir, "model.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch the network to evaluation mode before running inference
model.eval()

In [None]:
from monai.inferers import sliding_window_inference

# Window size that `inference` passes to sliding_window_inference as roi_size
config.inference_roi_size = (240, 240, 160)


def inference(model, input):
    """Run sliding-window inference with `model` under CUDA autocast.

    Args:
        model: callable used as the sliding-window predictor.
        input: batch handed to ``sliding_window_inference`` as ``inputs``
            (presumably shaped (batch, channel, H, W, D); confirm against caller).

    Returns:
        The output of ``sliding_window_inference`` for ``input``.
    """
    with torch.cuda.amp.autocast():
        prediction = sliding_window_inference(
            inputs=input,
            roi_size=config.inference_roi_size,
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )
    return prediction

In [None]:
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete

# Dice metrics with the background channel included: one uses the "mean"
# reduction, the other the "mean_batch" reduction.
# NOTE(review): neither metric is referenced by the cells shown below; confirm
# whether a Dice evaluation step is intended.
dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
# Post-processing for raw network outputs: sigmoid activation, then
# binarization at a 0.5 threshold.
postprocessing_transforms = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
)

In [None]:
from tqdm.auto import tqdm


def get_target_area_percentage(segmentation_map):
    """Return the percentage of elements in ``segmentation_map`` equal to 1.

    Args:
        segmentation_map: array-like mask exposing ``flatten()`` and
            elementwise comparison (the callers in this notebook pass
            NumPy slices).

    Returns:
        float: ``100 * (number of elements equal to 1) / (number of elements)``,
        or ``0.0`` when the map has no elements.
    """
    flat_map = segmentation_map.flatten()
    total = len(flat_map)
    if total == 0:
        # An empty map has no target area; avoid dividing by zero.
        return 0.0
    # Count the matches with a vectorized comparison rather than converting
    # the whole map to a Python list first.
    matches = int((flat_map == 1).sum())
    return matches * 100 / total


# (key used in mask / area names, human-readable class label), listed in the
# order of the label channels they are read from (channel 0, 1, 2).
_TUMOR_CLASSES = (
    ("Tumor-Core", "Tumor Core"),
    ("Whole-Tumor", "Whole Tumor"),
    ("Enhancing-Tumor", "Enhancing Tumor"),
)


def _build_overlay_image(background, ground_truth, prediction, class_key, class_label):
    """Build one wandb.Image overlaying ground-truth and predicted masks of a class."""
    return wandb.Image(
        background,
        masks={
            f"ground-truth/{class_key}": {
                "mask_data": ground_truth,
                "class_labels": {0: "background", 1: class_label},
            },
            f"prediction/{class_key}": {
                # Scaled by 2 so predicted pixels carry class id 2, keeping
                # them distinct from the ground truth's class id 1.
                "mask_data": prediction * 2,
                "class_labels": {0: "background", 2: class_label},
            },
        },
    )


def _area_percentages(label_volume, slice_idx):
    """Map each tumor class to its area percentage in one slice of `label_volume`."""
    return {
        f"{class_key}-Area-Percentage": get_target_area_percentage(
            label_volume[channel, :, :, slice_idx]
        )
        for channel, (class_key, _) in enumerate(_TUMOR_CLASSES)
    }


def log_predictions_into_tables(
    sample_image,
    sample_label,
    predicted_label,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    """Add one table row per slice comparing ground-truth and predicted masks.

    The three inputs are moved to the CPU and converted to NumPy arrays. They
    are indexed as ``[channel, :, :, slice]``; the image's channel 0 is used as
    the background and label channels 0, 1 and 2 are read as Tumor Core,
    Whole Tumor and Enhancing Tumor respectively.

    Args:
        sample_image: 4-D tensor; its last axis is iterated slice by slice.
        sample_label: ground-truth tensor with at least three channels.
        predicted_label: predicted tensor with at least three channels.
        split: value stored in the row's first column.
        data_idx: value stored in the row's second column.
        table: ``wandb.Table`` receiving the rows; required despite the
            ``None`` default.

    Returns:
        The same ``table`` object, after the rows have been added.

    Raises:
        ValueError: if ``table`` is ``None``.
    """
    if table is None:
        raise ValueError("`table` must be a wandb.Table to add prediction rows to")

    sample_image = sample_image.cpu().numpy()
    sample_label = sample_label.cpu().numpy()
    predicted_label = predicted_label.cpu().numpy()
    _, _, _, num_slices = sample_image.shape

    for slice_idx in tqdm(range(num_slices), leave=False):
        background = sample_image[0, :, :, slice_idx]
        wandb_images = [
            _build_overlay_image(
                background,
                sample_label[channel, :, :, slice_idx],
                predicted_label[channel, :, :, slice_idx],
                class_key,
                class_label,
            )
            for channel, (class_key, class_label) in enumerate(_TUMOR_CLASSES)
        ]
        tumor_area_percentage = {
            "Ground-Truth": _area_percentages(sample_label, slice_idx),
            "Prediction": _area_percentages(predicted_label, slice_idx),
        }
        table.add_data(
            split, data_idx, slice_idx, tumor_area_percentage, *wandb_images
        )
    return table

In [None]:
# create the prediction table
prediction_table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Tumor-Area-Pixel-Percentage",
        "Prediction/Tumor-Core",
        "Prediction/Whole-Tumor",
        "Prediction/Enhancing-Tumor",
    ]
)

# Number of validation samples to visualize; a non-positive value means "all"
config.max_prediction_images_visualized = 1

# Perform inference and visualization
with torch.no_grad():
    max_samples = (
        min(config.max_prediction_images_visualized, len(val_dataset))
        if config.max_prediction_images_visualized > 0
        else len(val_dataset)
    )
    progress_bar = tqdm(
        enumerate(val_dataset[:max_samples]),
        total=max_samples,
        desc="Generating Predictions:",
    )
    for data_idx, sample in progress_bar:
        # Add a leading batch dimension and move the image to the device.
        # The label is consumed by the logging helper only, so it is not
        # copied to the device.
        test_input = torch.unsqueeze(sample["image"], 0).to(device)
        test_output = inference(model, test_input)
        # Drop the batch dimension, then apply sigmoid + thresholding
        test_output = postprocessing_transforms(test_output[0])
        prediction_table = log_predictions_into_tables(
            sample_image=sample["image"],
            sample_label=sample["label"],
            predicted_label=test_output,
            data_idx=data_idx,
            split="validation",
            table=prediction_table,
        )

    wandb.log({"Evaluation/Tumor-Segmentation-Prediction": prediction_table})

In [None]:
# End the experiment by finishing the active W&B run
wandb.finish()