Copyright (c) MONAI Consortium  
Licensed under the Apache License, Version 2.0 (the "License");  
you may not use this file except in compliance with the License.  
You may obtain a copy of the License at  
&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  
Unless required by applicable law or agreed to in writing, software  
distributed under the License is distributed on an "AS IS" BASIS,  
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
See the License for the specific language governing permissions and  
limitations under the License.

# Brain tumor 3D segmentation with MONAI

This tutorial shows how to construct a training workflow for a multi-label segmentation task.

It contains the following features:
1. Transforms for dictionary format data.
1. Define a new transform according to MONAI transform API.
1. Load NIfTI images with metadata, load a list of images and stack them.
1. Randomly adjust intensity for data augmentation.
1. Cache IO and transforms to accelerate training and validation.
1. 3D SegResNet model, Dice loss function, Mean Dice metric for 3D segmentation task.
1. Deterministic training for reproducibility.

The dataset comes from http://medicaldecathlon.com/.  
Target: Glioma segmentation (necrotic/active tumour and oedema)  
Modality: Multimodal multisite MRI data (FLAIR, T1w, T1gd, T2w)  
Size: 750 4D volumes (484 Training + 266 Testing)  
Source: BRATS 2016 and 2017 datasets.  
Challenge: Complex and heterogeneously-located targets

The figure below shows image patches with the tumor sub-regions annotated in the different modalities (top left) and the final labels for the whole dataset (right).
(Figure taken from the [BraTS IEEE TMI paper](https://ieeexplore.ieee.org/document/6975210/))

![image](../figures/brats_tasks.png)

The image patches show from left to right:
1. the whole tumor (yellow) visible in T2-FLAIR (Fig. A).
1. the tumor core (red) visible in T2 (Fig. B).
1. the enhancing tumor structures (light blue) visible in T1Gd, surrounding the cystic/necrotic components of the core (green) (Fig. C).
1. The segmentations are combined to generate the final labels of the tumor sub-regions (Fig. D): edema (yellow), non-enhancing solid core (red), necrotic/cystic core (green), enhancing core (blue).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb)

In [None]:
cd "/kaggle/input/unetr/unetr_plus_plus-main"

In [None]:
!pip install batchgenerators

In [None]:
!pip install timm

In [None]:
!pip install thop

In [None]:
!pip install monai

In [None]:
!pip install einops

In [None]:
from torch import nn
from typing import Tuple, Union
from unetr_pp.network_architecture.neural_network import SegmentationNetwork
from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
from unetr_pp.network_architecture.tumor.model_components import UnetrPPEncoder, UnetrUpBlock


class UNETR_PP(SegmentationNetwork):
    """
    UNETR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"

    A ``UnetrPPEncoder`` produces four hierarchical feature maps that are fused back by four
    ``UnetrUpBlock`` decoder stages; a ``UnetResBlock`` (``encoder1``) applied directly to the
    input provides the full-resolution skip connection for the last decoder stage.

    NOTE(review): ``feat_size=(4, 4, 4)`` and every decoder ``out_size`` are hard-coded, with the
    last stage at 128*128*128, so the network appears to assume 128x128x128 inputs — confirm
    before using other patch sizes.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            feature_size: int = 16,
            hidden_size: int = 256,
            num_heads: int = 4,
            pos_embed: str = "perceptron",
            norm_name: Union[Tuple, str] = "instance",
            dropout_rate: float = 0.0,
            depths=None,
            dims=None,
            conv_op=nn.Conv3d,
            do_ds=True,

    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels (also stored as ``self.num_classes``).
            feature_size: dimension of network feature size.
            hidden_size: dimension of the last encoder; used by ``proj_feat`` to reshape the deepest
                encoder output (presumably must equal ``dims[-1]`` — TODO confirm).
            num_heads: number of attention heads.
            pos_embed: position embedding layer type; only validated here ("conv" or "perceptron"),
                not otherwise used by this class.
            norm_name: feature normalization type and arguments.
            dropout_rate: fraction of the input units to drop; only range-checked here, not passed on.
            depths: number of blocks for each stage; defaults to ``[3, 3, 3, 3]`` when None.
            dims: number of channel maps for the stages; passed to ``UnetrPPEncoder`` unchanged.
            conv_op: type of convolution operation.
            do_ds: use deep supervision to compute the loss; adds the ``out2``/``out3`` heads.

        Raises:
            AssertionError: if ``dropout_rate`` is outside [0, 1].
            KeyError: if ``pos_embed`` is not "conv" or "perceptron".
        """

        super().__init__()
        if depths is None:
            depths = [3, 3, 3, 3]
        self.do_ds = do_ds
        self.conv_op = conv_op
        self.num_classes = out_channels
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if pos_embed not in ["conv", "perceptron"]:
            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")

        # spatial size of the deepest encoder output, used by proj_feat (assumes 128^3 input — TODO confirm)
        self.feat_size = (4, 4, 4)
        self.hidden_size = hidden_size

        self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)

        # full-resolution convolutional branch on the raw input; skip connection for decoder2
        self.encoder1 = UnetResBlock(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=8*8*8,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=16*16*16,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            out_size=32*32*32,
        )
        # last stage upsamples by 4 (32^3 -> 128^3 per the out_size values) and uses a conv decoder
        self.decoder2 = UnetrUpBlock(
            spatial_dims=3,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=(4, 4, 4),
            norm_name=norm_name,
            out_size=128*128*128,
            conv_decoder=True,
        )
        self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
        if self.do_ds:
            # deep-supervision heads on the two lower-resolution decoder outputs
            self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
            self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape ``x`` to ``(B, *feat_size, hidden_size)`` and move channels first.

        Returns a contiguous ``(B, hidden_size, *feat_size)`` tensor. ``x`` presumably holds
        flattened tokens of shape ``(B, prod(feat_size), hidden_size)`` — TODO confirm.
        """
        x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x_in):
        """
        Args:
            x_in: input volume; presumably ``(B, in_channels, 128, 128, 128)`` given the
                hard-coded sizes — TODO confirm.

        Returns:
            With ``do_ds``: the list ``[out1(out), out2(dec1), out3(dec2)]``, i.e. full-resolution
            logits first, then two lower-resolution deep-supervision outputs.
            Otherwise: the full-resolution logits only.
        """
        x_output, hidden_states = self.unetr_pp_encoder(x_in)
        convBlock = self.encoder1(x_in)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders
        dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
        dec3 = self.decoder5(dec4, enc3)
        dec2 = self.decoder4(dec3, enc2)
        dec1 = self.decoder3(dec2, enc1)

        out = self.decoder2(dec1, convBlock)
        if self.do_ds:
            logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
        else:
            logits = self.out1(out)

        return logits

In [None]:
import torch
import einops

# Smoke test: build UNETR++ for 4 input modalities / 3 output channels and run one forward
# pass on a random 128^3 volume to check that the encoder/decoder shapes line up.
model = UNETR_PP(
    in_channels=4,
    out_channels=3,
    feature_size=16,
    num_heads=4,
    depths=[3, 3, 3, 3],
    dims=[32, 64, 128, 256],
    do_ds=True,
)
# not named `input`, to avoid shadowing the Python builtin for the rest of the session
dummy_input = torch.rand(1, 4, 128, 128, 128)

# Shape check only: no_grad avoids building and retaining the autograd graph of a 128^3 forward pass.
with torch.no_grad():
    output = model(dummy_input)

In [None]:
# With do_ds=True the model returns a list: full-resolution logits followed by two
# lower-resolution deep-supervision outputs; print how many there are and their shapes.
print(len(output))
print(output[0].shape)
print(output[1].shape)
print(output[2].shape)

In [None]:
# Show the module tree of the smoke-test model.
print(model)

In [None]:
# NOTE(review): repeats the previous cell's output verbatim; this cell can be removed.
print(model)

In [None]:
!pip install monai

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

## Setup imports

In [None]:

from __future__ import annotations

import os
import shutil
import sys
import warnings
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

import numpy as np

from monai.apps.tcia import (
    download_tcia_series_instance,
    get_tcia_metadata,
    get_tcia_ref_uid,
    match_tcia_ref_uid_in_study,
)
# from monai.apps.utils import download_and_extract
from monai.config.type_definitions import PathLike
from monai.data import (
    CacheDataset,
    PydicomReader,
    load_decathlon_datalist,
    load_decathlon_properties,
    partition_dataset,
    select_cross_validation_folds,
)
from monai.transforms import LoadImaged, Randomizable
from monai.utils import ensure_tuple

def _basename(p: PathLike) -> str:
    """get the last part of the path (removing the trailing slash if it exists)"""
    sep = os.path.sep + (os.path.altsep or "") + "/ "
    return Path(f"{p}".rstrip(sep)).name

def extractall(
    filepath: PathLike,
    output_dir: PathLike = ".",
    hash_val: str | None = None,
    hash_type: str = "md5",
    file_type: str = "",
    has_base: bool = True,
) -> None:
    """
    Extract file to the output directory.
    Expected file types are: `zip`, `tar.gz` and `tar`.

    Extraction is skipped when the target folder already exists and is non-empty.

    Args:
        filepath: the file path of compressed file.
        output_dir: target directory to save extracted files.
        hash_val: expected hash value to validate the compressed file.
            if None, skip hash validation.
        hash_type: 'md5' or 'sha1', defaults to 'md5'.
        file_type: string of file type for decompressing. Leave it empty to infer the type from the filepath basename.
        has_base: whether the extracted files have a base folder. This flag is used when checking if the existing
            folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped
            to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should
            be False.

    Raises:
        RuntimeError: When the hash validation of the ``filepath`` compressed file fails.
        NotImplementedError: When the ``filepath`` file extension is not one of [zip", "tar.gz", "tar"].

    """
    # Imported here because the notebook never imports these modules at top level.
    import logging
    import tarfile
    import zipfile

    if has_base:
        # the extracted files will be in this folder
        cache_dir = Path(output_dir, _basename(filepath).split(".")[0])
    else:
        cache_dir = Path(output_dir)
    if cache_dir.exists() and next(cache_dir.iterdir(), None) is not None:
        logging.getLogger(__name__).info(f"Non-empty folder exists in {cache_dir}, skipped extracting.")
        return
    filepath = Path(filepath)
    if hash_val:
        # only needed when validation is requested; lives in MONAI's apps utilities
        from monai.apps.utils import check_hash

        if not check_hash(filepath, hash_val, hash_type):
            raise RuntimeError(
                f"{hash_type} check of compressed file failed: " f"filepath={filepath}, expected {hash_type}={hash_val}."
            )
    print(f"Writing into directory: {output_dir}.")
    _file_type = file_type.lower().strip()
    if filepath.name.endswith("zip") or _file_type == "zip":
        # context manager closes the archive even if extraction fails
        with zipfile.ZipFile(filepath) as zip_file:
            zip_file.extractall(output_dir)
        return
    if filepath.name.endswith(("tar", "tar.gz")) or "tar" in _file_type:
        with tarfile.open(filepath) as tar_file:
            if hasattr(tarfile, "data_filter"):
                # reject members that would escape output_dir (absolute paths, "..", unsafe links)
                tar_file.extractall(output_dir, filter="data")
            else:
                tar_file.extractall(output_dir)
        return
    raise NotImplementedError(
        f'Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name={filepath} type={file_type}.'
    )


def download_and_extract(
    url: str,
    filepath: PathLike = "",
    output_dir: PathLike = ".",
    hash_val: str | None = None,
    hash_type: str = "md5",
    file_type: str = "",
    has_base: bool = True,
    progress: bool = True,
) -> None:
    """
    Download file from URL (or reuse a local archive) and extract it to the output directory.

    Args:
        url: source URL link to download file. If ``url`` is the path of an existing local file
            (e.g. an archive attached as an offline input), nothing is downloaded and that file is
            validated (when ``hash_val`` is given) and extracted directly.
        filepath: the file path of the downloaded compressed file.
            use this option to keep the directly downloaded compressed file, to avoid further repeated downloads.
        output_dir: target directory to save extracted files.
            default is the current directory.
        hash_val: expected hash value to validate the downloaded file.
            if None, skip hash validation.
        hash_type: 'md5' or 'sha1', defaults to 'md5'.
        file_type: string of file type for decompressing. Leave it empty to infer the type from url's base file name.
        has_base: whether the extracted files have a base folder. This flag is used when checking if the existing
            folder is a result of `extractall`, if it is, the extraction is skipped. For example, if A.zip is unzipped
            to folder structure `A/*.png`, this flag should be True; if B.zip is unzipped to `*.png`, this flag should
            be False.
        progress: whether to display progress bar.
    """
    import tempfile

    local_archive = Path(url)
    if local_archive.is_file():
        # offline path: no network access needed
        extractall(
            filepath=local_archive,
            output_dir=output_dir,
            hash_val=hash_val,
            hash_type=hash_type,
            file_type=file_type,
            has_base=has_base,
        )
        return

    from monai.apps.utils import download_url

    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = filepath or Path(tmp_dir, _basename(url)).resolve()
        # download_url validates hash_val itself and skips the download if filename already matches
        download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
        extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
        
class DecathlonDataset(Randomizable, CacheDataset):
    """
    Decathlon-style dataset adapted for this notebook. Cases are read from a local
    Decathlon-format datalist JSON (``datalist_dir / datalist_filename``) instead of being
    downloaded, and are split into training / validation / test subsets.
    It's based on :py:class:`monai.data.CacheDataset` to accelerate the training process.

    All sections are carved from the same shuffled list. Use the same ``seed`` for every
    section to get disjoint splits. With ``n = len(datalist)`` the shuffled positions are:

    - validation: ``[0, val_num)``
    - test: ``[val_num, val_num + test_num)``
    - training: ``[max(int(n * val_frac), val_num + test_num), n)``

    Args:
        root_dir: must be an existing directory (validated only; data is not read from it).
        task: one of the keys of :py:attr:`resource` (validated only).
        section: ``training``, ``validation`` or ``test``. Any other value loads the JSON's
            ``test`` list and takes the test slice of it.
        transform: transforms to execute operations on input data; defaults to
            ``LoadImaged(["image", "label"])``.
        download: currently unused; downloading is disabled for this offline setup.
        seed: random seed used to shuffle the datalist before splitting, default is 0.
        val_frac: minimum fraction of the shuffled datalist kept out of the training split, default is 0.2.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
        num_workers: the number of worker threads if computing cache in the initialization.
        progress: whether to display a progress bar when computing the transform cache content.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
        runtime_cache: whether to compute cache at the runtime. See: :py:class:`monai.data.CacheDataset`.
        datalist_dir: directory containing the datalist JSON file.
        val_num: number of cases in the validation split.
        test_num: number of cases in the test split.

    Raises:
        ValueError: When ``root_dir`` is not a directory.
        ValueError: When ``task`` is not a key of :py:attr:`resource`.
    """

    resource = {
        "Task01_BrainTumour": "/kaggle/input/segresnet-data/BraTS-MEN-Train",
        "Task02_Heart": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar",
        "Task03_Liver": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tar",
        "Task04_Hippocampus": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar",
        "Task05_Prostate": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task05_Prostate.tar",
        "Task06_Lung": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar",
        "Task07_Pancreas": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task07_Pancreas.tar",
        "Task08_HepaticVessel": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task08_HepaticVessel.tar",
        "Task09_Spleen": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar",
        "Task10_Colon": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task10_Colon.tar",
    }
    md5 = {
        "Task01_BrainTumour": "240a19d752f0d9e9101544901065d872",
        "Task02_Heart": "06ee59366e1e5124267b774dbd654057",
        "Task03_Liver": "a90ec6c4aa7f6a3d087205e23d4e6397",
        "Task04_Hippocampus": "9d24dba78a72977dbd1d2e110310f31b",
        "Task05_Prostate": "35138f08b1efaef89d7424d2bcc928db",
        "Task06_Lung": "8afd997733c7fc0432f71255ba4e52dc",
        "Task07_Pancreas": "4f7080cfca169fa8066d17ce6eb061e4",
        "Task08_HepaticVessel": "641d79e80ec66453921d997fbf12a29c",
        "Task09_Spleen": "410d4a301da4e5b2f6f86ec3ddba524e",
        "Task10_Colon": "bad7a188931dc2f6acf72b08eb6202d0",
    }

    # name of the Decathlon-format datalist JSON inside ``datalist_dir``
    datalist_filename = "dataset (1).json"

    def __init__(
        self,
        root_dir: PathLike,
        task: str,
        section: str,
        transform: Sequence[Callable] | Callable = (),
        download: bool = False,
        seed: int = 0,
        val_frac: float = 0.2,
        cache_num: int = sys.maxsize,
        cache_rate: float = 1.0,
        num_workers: int = 1,
        progress: bool = True,
        copy_cache: bool = True,
        as_contiguous: bool = True,
        runtime_cache: bool = False,
        datalist_dir: PathLike = "/kaggle/input/segres-json",
        val_num: int = 100,
        test_num: int = 146,
    ) -> None:
        root_dir = Path(root_dir)
        if not root_dir.is_dir():
            raise ValueError("Root directory root_dir must be a directory.")
        self.section = section
        self.val_frac = val_frac
        self.val_num = val_num
        self.test_num = test_num
        self.set_random_state(seed=seed)
        if task not in self.resource:
            raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.")
        self.indices: np.ndarray = np.array([])
        data = self._generate_data_list(datalist_dir)
        # as `release` key has typo in Task04 config file, ignore it.
        property_keys = [
            "name",
            "description",
            "reference",
            "licence",
            "tensorImageSize",
            "modality",
            "labels",
            "numTraining",
            "numTest",
        ]
        try:
            self._properties = load_decathlon_properties(Path(datalist_dir) / self.datalist_filename, property_keys)
        except (KeyError, ValueError):
            # the custom datalist JSON may lack some MSD metadata keys; keep get_properties() usable
            self._properties = {}
        if transform == ():
            transform = LoadImaged(["image", "label"])
        CacheDataset.__init__(
            self,
            data=data,
            transform=transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
            progress=progress,
            copy_cache=copy_cache,
            as_contiguous=as_contiguous,
            runtime_cache=runtime_cache,
        )

    def get_indices(self) -> np.ndarray:
        """
        Get the indices of datalist used in this dataset.

        """
        return self.indices

    def randomize(self, data: np.ndarray) -> None:
        """Shuffle ``data`` in place with this dataset's random state."""
        self.R.shuffle(data)

    def get_properties(self, keys: Sequence[str] | str | None = None) -> dict:
        """
        Get the loaded properties of dataset with specified keys.
        If no keys specified, return all the loaded properties.

        """
        if keys is None:
            return self._properties
        if self._properties is not None:
            return {key: self._properties[key] for key in ensure_tuple(keys)}
        return {}

    def _generate_data_list(self, dataset_dir: PathLike) -> list[dict]:
        """Load the datalist JSON found in ``dataset_dir`` and return this section's split."""
        # the types of the item in data list should be compatible with the dataloader
        datalist_path = Path(dataset_dir) / self.datalist_filename
        # every known section is carved out of the JSON's "training" list
        section = "training" if self.section in ["training", "validation", "test"] else "test"
        datalist = load_decathlon_datalist(datalist_path, True, section)
        return self._split_datalist(datalist)

    def _split_datalist(self, datalist: list[dict]) -> list[dict]:
        """Shuffle ``datalist`` positions and keep the slice that belongs to ``self.section``."""
        length = len(datalist)
        indices = np.arange(length)
        self.randomize(indices)

        val_end = self.val_num
        test_end = val_end + self.test_num
        # Training starts after BOTH held-out slices so no validation/test case leaks into it;
        # val_frac can only enlarge the held-out region, never shrink it.
        train_start = max(int(length * self.val_frac), test_end)

        if self.section == "training":
            self.indices = indices[train_start:]
        elif self.section == "validation":
            self.indices = indices[:val_end]
        else:
            self.indices = indices[val_end:test_end]
        return [datalist[i] for i in self.indices]

In [None]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
# from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
# from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism

import torch

print_config()

## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [None]:
# Use $MONAI_DATA_DIRECTORY when it is set; otherwise fall back to a fresh temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

## Set deterministic training for reproducibility

In [None]:
# Fix random seeds via MONAI so that augmentation and weight initialisation are reproducible.
# NOTE(review): a later cell sets torch.backends.cudnn.benchmark = True, which may relax the
# determinism configured here — confirm whether bit-exact reproducibility is required.
set_determinism(seed=0)

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert integer label maps into three stacked binary channels, in this order:
    TC (tumor core), WT (whole tumor) and ET (enhancing tumor).

    Each channel is the union of the label values configured for it. The defaults reproduce
    this transform's original grouping, which documented the labels as:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    giving TC = {2, 3}, WT = {1, 2, 3}, ET = {2}.

    NOTE(review): the data paths in this notebook point at BraTS-MEN data. BraTS 2023 releases
    are commonly documented as 1 = non-enhancing tumor core, 2 = surrounding edema,
    3 = enhancing tumor. That would correspond to ``tc_labels=(1, 3)``, ``et_labels=(3,)``.
    Confirm against the dataset's label definitions before trusting the defaults.

    Args:
        keys: keys of the label maps to convert.
        allow_missing_keys: don't raise an exception if a key is missing from the data dict.
        tc_labels: label values merged into the TC channel.
        wt_labels: label values merged into the WT channel.
        et_labels: label values merged into the ET channel.

    Raises:
        ValueError: if any of the label groups is empty.
    """

    def __init__(
        self,
        keys,
        allow_missing_keys: bool = False,
        tc_labels: Sequence[int] = (2, 3),
        wt_labels: Sequence[int] = (1, 2, 3),
        et_labels: Sequence[int] = (2,),
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.channel_labels = (tuple(tc_labels), tuple(wt_labels), tuple(et_labels))
        if not all(self.channel_labels):
            raise ValueError("each of tc_labels, wt_labels and et_labels needs at least one label value.")

    @staticmethod
    def _union_mask(label, values):
        """Boolean mask that is True wherever ``label`` equals any of ``values``."""
        mask = label == values[0]
        for value in values[1:]:
            mask = torch.logical_or(mask, label == value)
        return mask

    def __call__(self, data):
        d = dict(data)
        # key_iterator raises KeyError for missing keys unless allow_missing_keys=True
        for key in self.key_iterator(d):
            label = d[key]
            channels = [self._union_mask(label, values) for values in self.channel_labels]
            d[key] = torch.stack(channels, axis=0).float()
        return d

In [None]:
def _common_preprocessing():
    """Fresh instances of the loading / label-conversion / resampling steps shared by both pipelines."""
    return [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
    ]


# training: random 128^3 crop, flips along each spatial axis, then intensity normalisation and jitter
train_transform = Compose(
    _common_preprocessing()
    + [RandSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 128], random_size=False)]
    + [RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis) for axis in range(3)]
    + [
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
# validation: deterministic preprocessing + normalisation only (full volumes, no cropping)
val_transform = Compose(
    _common_preprocessing() + [NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True)]
)

## Define a new transform to convert brain tumor labels

Here we convert the multi-class labels into a multi-label segmentation task in one-hot format.

## Setup transforms for training and validation

## Quickly load data with DecathlonDataset

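Here we use the `DecathlonDataset` class defined above. Unlike MONAI's built-in version, it does not download anything: it reads the case list from a local Decathlon-format JSON file and splits it into training, validation and test subsets.
It inherits from MONAI `CacheDataset`. If you want to use less memory, you can set `cache_num=N` to cache only N items for training and keep the default arguments to cache all items for validation; the right choice depends on your memory size. (The cell below sets `cache_rate=0.0`, so nothing is cached.)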

In [None]:
# here we don't cache any data in case out of memory issue
def _make_split(section, transform, download, shuffle):
    """Build one uncached DecathlonDataset split (to avoid out-of-memory) and its batch-size-1 loader."""
    dataset = DecathlonDataset(
        root_dir=root_dir,
        task="Task01_BrainTumour",
        transform=transform,
        section=section,
        download=download,
        cache_rate=0.0,
        num_workers=4,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=shuffle, num_workers=4)
    return dataset, loader


train_ds, train_loader = _make_split("training", train_transform, download=True, shuffle=True)
val_ds, val_loader = _make_split("validation", val_transform, download=False, shuffle=False)
test_ds, test_loader = _make_split("test", val_transform, download=False, shuffle=False)

print(len(train_ds), len(val_ds), len(test_ds))

## Check data shape and visualize

In [None]:
# pick one image from DecathlonDataset to visualize and check the 4 channels
def _plot_channels(volume, fig_name, fig_size, title_prefix, n_channels, cmap=None):
    """Plot channels 0..n_channels-1 of ``volume`` at slice index 60 of the last axis, side by side."""
    plt.figure(fig_name, fig_size)
    for idx in range(n_channels):
        plt.subplot(1, n_channels, idx + 1)
        plt.title(f"{title_prefix} channel {idx}")
        plt.imshow(volume[idx, :, :, 60].detach().cpu(), cmap=cmap)
    plt.show()


# pick one validation case and inspect its 4 image channels and 3 label channels
val_data_example = val_ds[2]
example_image = val_data_example["image"]
example_label = val_data_example["label"]

print(f"image shape: {example_image.shape}")
_plot_channels(example_image, "image", (24, 6), "image", 4, cmap="gray")

print(f"label shape: {example_label.shape}")
_plot_channels(example_label, "label", (18, 6), "label", 3)

## Create Model, Loss, Optimizer

In [None]:
!pip install thop

In [None]:
from thop import profile
from thop import clever_format

In [None]:
max_epochs = 100
val_interval = 1
VAL_AMP = True


# fall back to CPU so the notebook still runs (slowly) on a machine without a GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNETR_PP(
    in_channels=4,
    out_channels=3,
    feature_size=16,
    num_heads=4,
    depths=[3, 3, 3, 3],
    dims=[32, 64, 128, 256],
    do_ds=True,
).to(device)
print(model)
image_size = 128
inputs = torch.rand(1, 4, image_size, image_size, image_size).to(device)
input_size = (1, 4, 128, 128, 128)
# Optional FLOPs / parameter count with thop (disabled):
# input_tensor = torch.randn(*input_size).to(device)
# flops, params = profile(model, inputs=(input_tensor,))
# flops, params = clever_format([flops, params], "%.2f")
# print(f"FLOPs: {flops}, Params: {params}")

# Shape check only. Without no_grad, `outputs` would keep the autograd graph of this 128^3
# forward pass alive (and occupying device memory) into the first training step.
with torch.no_grad():
    outputs = model(inputs)
print(outputs[0].shape)
# model.load_state_dict(torch.load("/kaggle/input/segseresnet-checkpoints/SegSeResNet_best_model_198.pth"))

loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

# sigmoid + 0.5 threshold: each output channel is turned into an independent binary mask
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(128, 128, 128),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)


# use amp to accelerate training
scaler = torch.cuda.amp.GradScaler()
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

## Execute a typical PyTorch training process

In [None]:
import warnings

best_metric = -1
best_metric_epoch = -1
# [best mean Dice values, epochs they were reached at, seconds since training start]
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
val_epoch_loss_values = []
metric_values = []
train_metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

start_epoch = 0
latest_checkpoint_path = "/kaggle/input/best-metric-model-77/best_metric_model_77.pth"
if os.path.exists(latest_checkpoint_path):
    # The file is loaded straight into the model, so only weights are restored;
    # optimizer (Adam moments) and scheduler state start fresh.
    checkpoint = torch.load(latest_checkpoint_path)
    model.load_state_dict(checkpoint)
    # NOTE(review): the file name says epoch 77 but training resumes at epoch 80
    # - confirm which epoch these weights come from.
    start_epoch = 80
    # Fast-forward the cosine schedule so the learning rate continues from
    # `start_epoch` instead of restarting at the initial value.
    with warnings.catch_warnings():
        # Stepping a scheduler before any optimizer.step() emits a UserWarning.
        warnings.simplefilter("ignore", UserWarning)
        for _ in range(start_epoch):
            lr_scheduler.step()
    print("checkpoint loaded")
else:
    print("No checkpoint. Starting from scratch")

total_start = time.time()

for epoch in range(start_epoch, max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    epoch_metric = 0
    val_epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            # Only outputs[0] (the head compared against `labels`) enters the loss.
            loss = loss_function(outputs[0], labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        outputs = [post_trans(i) for i in decollate_batch(outputs[0])]
        dice_metric(y_pred=outputs, y=labels)
        train_metric = dice_metric.aggregate().item()
        # Reset every step: `train_metric` is then the Dice of this batch only,
        # and no training batches leak into the validation aggregate below,
        # which reuses the same `dice_metric` object.
        dice_metric.reset()
        epoch_metric += train_metric

        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}"
            f", train_loss: {loss.item():.4f}"
            f", current mean dice: {train_metric:.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    lr_scheduler.step()
    epoch_loss /= step
    epoch_metric /= step
    epoch_loss_values.append(epoch_loss)
    train_metric_values.append(epoch_metric)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}"
          f", average metric: {epoch_metric:.4f}")
    print("train_time: ", time.time()-epoch_start)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_start = time.time()
        # Start the validation aggregates from a clean state.
        dice_metric.reset()
        dice_metric_batch.reset()
        with torch.no_grad():
            val_step = 0
            for val_data in val_loader:
                val_step += 1
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                val_outputs = inference(val_inputs)
                loss = loss_function(val_outputs[0], val_labels)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs[0])]
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)
                val_epoch_loss += loss.item()

            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            # Per-channel Dice, reported in channel order as TC, WT, ET.
            metric_batch = dice_metric_batch.aggregate()
            metric_tc = metric_batch[0].item()
            metric_values_tc.append(metric_tc)
            metric_wt = metric_batch[1].item()
            metric_values_wt.append(metric_wt)
            metric_et = metric_batch[2].item()
            metric_values_et.append(metric_et)
            dice_metric.reset()
            dice_metric_batch.reset()

            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join("/kaggle/working/", f"best_metric_model_{epoch+1}.pth"),
                )
                print("saved new best metric model")
            # Also keep a weights-only snapshot of every validated epoch.
            chk_file_name = "epoch_" + str(epoch+1) + "_model.pth"
            torch.save(
                    model.state_dict(),
                    os.path.join("/kaggle/working/", chk_file_name),
                )
            val_epoch_loss /= val_step
            val_epoch_loss_values.append(val_epoch_loss)
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f", val_loss: {val_epoch_loss:.4f}"
                f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
        print("val_time: ", time.time()-val_start)
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")
total_time = time.time() - total_start

In [None]:
# Dump each recorded history as "<length> <values>".
for history in (val_epoch_loss_values, epoch_loss_values, train_metric_values, metric_values):
    print(len(history), history)

In [None]:
import shutil

# Start the working metrics CSV from a copy of a previously saved metrics file.
source_path, destination_path = (
    "/kaggle/input/metrics-loss/SegResNet_data.csv",
    "/kaggle/working/SegResNet_data_3.csv",
)

shutil.copy(source_path, destination_path)  # overwrites any existing copy
print(f"File '{source_path}' copied to '{destination_path}'.")


In [None]:
import csv

# CSV file path (the copy made by the previous cell)
csv_file_path = '/kaggle/working/SegResNet_data_3.csv'

# Column names
fieldnames = ['epoch_loss_values', 'train_metric_values', 'val_epoch_loss_values', 'metric_values']

# One row per trained epoch. Validation only runs on epochs where
# (epoch + 1) % val_interval == 0, so the validation lists can be shorter than
# the training lists. Attach the next validation result only to validated
# epochs and leave the validation cells blank otherwise. With val_interval == 1
# this is a plain row-by-row pairing.
val_results = iter(zip(val_epoch_loss_values, metric_values))
csv_rows = []
for offset, (train_loss, train_dice) in enumerate(zip(epoch_loss_values, train_metric_values)):
    validated = (start_epoch + offset + 1) % val_interval == 0
    val_loss, val_dice = next(val_results, ("", "")) if validated else ("", "")
    csv_rows.append({
        'epoch_loss_values': train_loss,
        'train_metric_values': train_dice,
        'val_epoch_loss_values': val_loss,
        'metric_values': val_dice,
    })

# Append mode: rows go after the existing content and no header row is written
# (the file is assumed to already have one - TODO confirm).
with open(csv_file_path, 'a', newline='') as csvfile:
    csv.DictWriter(csvfile, fieldnames=fieldnames).writerows(csv_rows)



In [None]:
# Report the tracked best epoch and wall-clock time from the training cell
# (previously hard-coded as 16 and 1).
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}, total time: {total_time:.4f}."
)

## Plot the loss and metric

In [None]:
# Epoch numbers (1-based) actually run by the training cell. When training was
# resumed, the first recorded epoch is start_epoch + 1, not 1.
train_epochs = [start_epoch + i + 1 for i in range(len(epoch_loss_values))]
# Validation only ran on epochs where (epoch + 1) % val_interval == 0.
val_epochs = [
    epoch + 1
    for epoch in range(start_epoch, max_epochs)
    if (epoch + 1) % val_interval == 0
]

plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
plt.xlabel("epoch")
plt.plot(train_epochs, epoch_loss_values, color="red")
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
plt.xlabel("epoch")
plt.plot(val_epochs[: len(metric_values)], metric_values, color="green")
plt.show()

# Per-region validation Dice. A distinct figure label avoids drawing onto the
# "train" figure above if it is still open (non-inline backends).
plt.figure("val_regions", (18, 6))
region_curves = [
    ("Val Mean Dice TC", metric_values_tc, "blue"),
    ("Val Mean Dice WT", metric_values_wt, "brown"),
    ("Val Mean Dice ET", metric_values_et, "purple"),
]
for panel, (title, values, color) in enumerate(region_curves, start=1):
    plt.subplot(1, 3, panel)
    plt.title(title)
    plt.xlabel("epoch")
    plt.plot(val_epochs[: len(values)], values, color=color)
plt.show()

## Check best model output with the input image and label

In [None]:
pip install nibabel

In [None]:
import nibabel as nib

# Load the four MRI sequences of one BraTS case and show slice 78 (last axis) of each.
case_id = "BraTS-GLI-00814-000"
case_dir = f"/kaggle/input/trainingdata-brats-part2/{case_id}"
modalities = ["t1c", "t1n", "t2f", "t2w"]
volumes = {
    modality: nib.load(f"{case_dir}/{case_id}-{modality}.nii").get_fdata()
    for modality in modalities
}
print(", ".join(f"{modality} shape: {vol.shape}" for modality, vol in volumes.items()))

plt.figure("image", (24, 6))
for panel, (modality, vol) in enumerate(volumes.items(), start=1):
    plt.subplot(1, 4, panel)
    plt.title(f"{case_id}-{modality}.nii")
    plt.imshow(vol[:, :, 78], cmap="gray")
plt.show()

In [None]:
import nibabel as nib
from matplotlib.colors import ListedColormap
img_add = "/kaggle/input/trainingdata-brats-part2/BraTS-GLI-00814-000/BraTS-GLI-00814-000-t2f.nii"
label_add = "/kaggle/input/trainingdata-brats-part2/BraTS-GLI-00814-000/BraTS-GLI-00814-000-seg.nii"
img = nib.load(img_add).get_fdata()
label = nib.load(label_add).get_fdata()
# One colour per integer label value: 0 black, 1 blue, 2 orange, 3 green.
# Assumes the seg volume uses labels 0..3 (TODO confirm); values above 3 are
# drawn with the last colour.
seg_colors = ['black', 'blue', 'orange', 'green']
cmap = ListedColormap(seg_colors)

print(f"image shape: {img.shape}, label shape: {label.shape}")
plt.figure("image", (18, 6))
plt.subplot(1, 3, 1)
plt.title("BraTS-GLI-00814-000-t2f.nii")
plt.imshow(img[:, :, 78], cmap="gray")
plt.subplot(1, 3, 2)
plt.title("BraTS-GLI-00814-000-seg.nii")
plt.imshow(label[:, :, 78], interpolation="nearest")
plt.subplot(1, 3, 3)
plt.title("BraTS-GLI-00814-000-seg.nii")
# Pin the colour range to the label range: imshow autoscales per slice by
# default, so a slice missing some labels would otherwise shift the colours.
plt.imshow(label[:, :, 78], cmap=cmap, vmin=0, vmax=len(seg_colors) - 1, interpolation="nearest")
plt.show()

In [None]:
def _merge_channels(multi_channel):
    """Collapse 3 binary channels (C, H, W, D) into one (H, W, D) label map.

    Channel 1 -> 2, then channel 0 -> 1, then channel 2 -> 4; later
    assignments overwrite earlier ones where channels overlap.
    """
    merged = torch.zeros(multi_channel.shape[1:])
    merged[multi_channel[1] == 1] = 2
    merged[multi_channel[0] == 1] = 1
    merged[multi_channel[2] == 1] = 4
    return merged


# Load the weights from the checkpoint *file* (the evaluation cell below uses
# the same path); "/kaggle/input/97model" on its own is the dataset directory.
model.load_state_dict(torch.load("/kaggle/input/97model/best_metric_model_97.pth"))
model.eval()
slice_num = 78
# Fetch the sample once; each val_ds[...] access may re-run the dataset transforms.
sample = val_ds[1]
with torch.no_grad():
    # select one image to evaluate and visualize the model output
    val_input = sample["image"].unsqueeze(0).to(device)
    raw_output = inference(val_input)
    # With a multi-head model the result is a sequence whose first element is
    # the batched main output; otherwise it is that batched tensor itself.
    main_output = raw_output[0] if isinstance(raw_output, (list, tuple)) else raw_output
    # Drop the batch dimension (batch size 1) -> (C, H, W, D), and move to the
    # CPU so it can be combined with the CPU label tensors below.
    val_output = post_trans(main_output[0]).cpu()

    # NOTE(review): assumes the image channels are stacked in this order - confirm
    # against the loading transforms.
    titles_modalities = ["t1c", "t1n", "t2f", "t2w"]
    plt.figure("image", (24, 6))
    for i, modality in enumerate(titles_modalities):
        plt.subplot(1, 4, i + 1)
        plt.title(modality)
        plt.imshow(sample["image"][i, :, :, slice_num].detach().cpu(), cmap="gray")
    plt.show()

    gt_label = sample["label"]
    seg_label = _merge_channels(gt_label)
    seg_out = _merge_channels(val_output)

    titles = ["Tumor Core (TC) i/p Channel", "Whole Tumor (WT) i/p Channel", "Enhancing tumor (ET) i/p Channel"]
    plt.figure("image", (24, 6))
    for i in range(3):
        plt.subplot(1, 4, i + 1)
        plt.title(titles[i])
        plt.imshow(gt_label[i, :, :, slice_num].detach().cpu())
    plt.subplot(1, 4, 4)
    plt.title("Ground Truth Mask")
    plt.imshow(seg_label[:, :, slice_num])
    plt.show()

    titles_output = ["Tumor Core (TC) o/p Channel", "Whole Tumor (WT) o/p Channel", "Enhancing tumor (ET) o/p Channel"]
    plt.figure("image", (24, 6))
    for i in range(3):
        plt.subplot(1, 4, i + 1)
        plt.title(titles_output[i])
        plt.imshow(val_output[i, :, :, slice_num].detach().cpu())
    plt.subplot(1, 4, 4)
    plt.title("Predicted Mask")
    plt.imshow(seg_out[:, :, slice_num])
    plt.show()

## Evaluation on original image spacings

In [None]:
# Transforms for evaluation at the original resolution: load image and label,
# convert the label to multi-channel form, and apply the orientation, spacing,
# and intensity steps to "image" only.
# NOTE(review): neither val_org_transforms nor post_transforms is referenced by
# the evaluation loop in the next cell as written - confirm where they are meant
# to be wired in.
val_org_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image"]),
        # custom BraTS label conversion; presumably yields TC / WT / ET channels - confirm
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# Post-processing of predictions: invert the operations of `val_transform` on
# "pred" (using the metadata recorded for "image"), return the result as a CPU
# tensor, then apply a sigmoid and threshold at 0.5 to get binary masks.
post_transforms = Compose(
    [
        Invertd(
            keys="pred",
            transform=val_transform,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
            device="cpu",
        ),
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold=0.5),
    ]
)

In [None]:
from monai.metrics import HausdorffDistanceMetric
from monai.metrics import MeanIoU
from monai.metrics import ROCAUCMetric
from monai.metrics import SurfaceDiceMetric
from monai.metrics import SurfaceDistanceMetric

# Every metric treats all three output channels as classes to score
# (include_background=True), matching the Dice metrics. Previously the mean HD95
# used include_background=False and silently dropped channel 0 (TC).
hd_metric = HausdorffDistanceMetric(include_background=True, reduction="mean", percentile=95)
hd_metric_batch = HausdorffDistanceMetric(include_background=True, reduction="mean_batch", percentile=95)

sd_metric = SurfaceDistanceMetric(include_background=True, reduction="mean")
sd_metric_batch = SurfaceDistanceMetric(include_background=True, reduction="mean_batch")

meanIoU_metric = MeanIoU(include_background=True, reduction="mean")
meanIoU_metric_batch = MeanIoU(include_background=True, reduction="mean_batch")

# NOTE(review): no `spacing` is passed, so these thresholds are presumably in
# voxel units; 0.01 then amounts to an exact-boundary match - confirm intent.
surfaceDice_metric = SurfaceDiceMetric(include_background=True, reduction="mean", class_thresholds=(0.01, 0.01, 0.01))
surfaceDice_metric_batch = SurfaceDiceMetric(include_background=True, reduction="mean_batch", class_thresholds=(0.01, 0.01, 0.01))

# Every metric object, in update order: (overall mean, per-channel) per family.
all_metrics = [
    dice_metric, dice_metric_batch,
    hd_metric, hd_metric_batch,
    meanIoU_metric, meanIoU_metric_batch,
    sd_metric, sd_metric_batch,
    surfaceDice_metric, surfaceDice_metric_batch,
]


def _tc_wt_et(per_channel):
    """Split a 3-element per-channel result into (TC, WT, ET) Python floats."""
    return tuple(per_channel[k].item() for k in range(3))


model.load_state_dict(torch.load("/kaggle/input/97model/best_metric_model_97.pth"))
model.eval()
# dice_metric / dice_metric_batch are shared with the training cell; clear any
# leftover state so only test-set predictions are aggregated.
for m in all_metrics:
    m.reset()
count = 0
with torch.no_grad():
    infer_start = time.time()
    for val_data in test_loader:
        count = count + 1
        print(count)
        val_inputs, val_labels = (
            val_data["image"].to(device),
            val_data["label"].to(device),
        )
        val_outputs = inference(val_inputs)
        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs[0])]
        for m in all_metrics:
            m(y_pred=val_outputs, y=val_labels)
    print("infer_time: ", time.time()-infer_start)

    metric_org = dice_metric.aggregate().item()
    metric_batch_org = dice_metric_batch.aggregate()

    hd_metric_org = hd_metric.aggregate().item()
    hd_metric_batch_org = hd_metric_batch.aggregate()

    meanIoU_metric_org = meanIoU_metric.aggregate().item()
    meanIoU_metric_batch_org = meanIoU_metric_batch.aggregate()

    sd_metric_org = sd_metric.aggregate().item()
    sd_metric_batch_org = sd_metric_batch.aggregate()

    surfaceDice_metric_org = surfaceDice_metric.aggregate().item()
    surfaceDice_metric_batch_org = surfaceDice_metric_batch.aggregate()

    for m in all_metrics:
        m.reset()

metric_tc, metric_wt, metric_et = _tc_wt_et(metric_batch_org)
hd_metric_tc, hd_metric_wt, hd_metric_et = _tc_wt_et(hd_metric_batch_org)
meanIoU_metric_tc, meanIoU_metric_wt, meanIoU_metric_et = _tc_wt_et(meanIoU_metric_batch_org)
sd_metric_tc, sd_metric_wt, sd_metric_et = _tc_wt_et(sd_metric_batch_org)
surfaceDice_metric_tc, surfaceDice_metric_wt, surfaceDice_metric_et = _tc_wt_et(surfaceDice_metric_batch_org)

# NOTE(review): the loop scores whatever `test_loader` yields; post_transforms
# (Invertd) is not applied, so "original image spacing" holds only if
# test_loader itself produces data at that spacing - confirm.
print("Metric on original image spacing: ", metric_org)
print(f"metric_tc: {metric_tc:.4f}")
print(f"metric_wt: {metric_wt:.4f}")
print(f"metric_et: {metric_et:.4f}")

print("HD Metric on original image spacing: ", hd_metric_org)
print(f"HD metric_tc: {hd_metric_tc:.4f}")
print(f"HD metric_wt: {hd_metric_wt:.4f}")
print(f"HD metric_et: {hd_metric_et:.4f}")

print("MeanIoU Metric on original image spacing: ", meanIoU_metric_org)
print(f"MeanIoU metric_tc: {meanIoU_metric_tc:.4f}")
print(f"MeanIoU metric_wt: {meanIoU_metric_wt:.4f}")
print(f"MeanIoU metric_et: {meanIoU_metric_et:.4f}")

print("SD Metric on original image spacing: ", sd_metric_org)
print(f"SD metric_tc: {sd_metric_tc:.4f}")
print(f"SD metric_wt: {sd_metric_wt:.4f}")
print(f"SD metric_et: {sd_metric_et:.4f}")

print("Surface Dice Metric on original image spacing: ", surfaceDice_metric_org)
print(f"Surface Dice metric_tc: {surfaceDice_metric_tc:.4f}")
print(f"Surface Dice metric_wt: {surfaceDice_metric_wt:.4f}")
print(f"Surface Dice metric_et: {surfaceDice_metric_et:.4f}")

## Cleanup data directory

Remove the data directory if a temporary one was used.

In [None]:
# Delete the data root only when no explicit `directory` was configured.
# `directory` and `root_dir` are defined in an earlier cell (not visible here);
# presumably `directory is None` means a temporary directory was created - confirm.
if directory is None:
    shutil.rmtree(root_dir)