Copyright (c) MONAI Consortium

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
    
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# MRI Brain image generation

* This tutorial illustrates a generative model for creating 3D brain MRIs from Gaussian noise.

## Setup environment

* installing the required libraries

In [None]:
!pip install -q "monai-weekly[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub]"
!pip install -q nibabel
!pip install -q ipywidgets
!pip install -q opencv-python-headless
!pip install -q matplotlib
%matplotlib inline

## Setup imports
* importing libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai.transforms import AsDiscrete
from monai.config import print_config
from monai.transforms import LoadImage, Orientation

print_config()

* plotting helper functions

In [10]:
def find_label_center_loc(x):
    """
    Pick a representative "center" index of the non-zero elements along every axis.

    For each dimension of ``x`` the distinct indices at which ``x`` is non-zero are
    collected (sorted, via ``torch.unique``) and the one at position ``count // 2``
    is chosen.

    Args:
        x (torch.Tensor): Mask tensor of any rank, e.g. [H, W, D] or [C, H, W, D].

    Returns:
        list: One entry per dimension of ``x``: a 0-d tensor holding the chosen index,
              or None when ``x`` contains no non-zero element.
    """

    def _middle_of_distinct(positions):
        # torch.unique returns the distinct positions in sorted order
        distinct = torch.unique(positions)
        if distinct.numel() == 0:
            return None
        return distinct[distinct.numel() // 2]

    return [_middle_of_distinct(positions) for positions in torch.where(x != 0)]


def normalize_label_to_uint8(colorize, label, n_label):
    """
    Colorize an integer label slice into an RGB uint8 image.

    The label is one-hot encoded with ``AsDiscrete(to_onehot=n_label)`` and mapped to
    RGB by a 1x1 convolution whose weights are ``colorize``. The result is clipped to
    [0, 1] and scaled to [0, 255].

    Args:
        colorize (torch.Tensor): Colorization weights of shape [3, n_label, 1, 1]
            (RGB output channels x label input channels). Required.
        label (torch.Tensor): Label slice. ``permute(1, 0, 2, 3)`` below needs the
            one-hot result to be 4-D, so the label must be 4-D as well, presumably
            [1, 1, H, W] (one-hot along dim 0 -> [n_label, 1, H, W]) — TODO confirm.
        n_label (int): Number of label classes used for one-hot encoding.

    Returns:
        numpy.ndarray: Colorized image as uint8 with shape [H, W, 3].

    Raises:
        ValueError: If ``colorize`` is None.
    """
    if colorize is None:
        # Fail early with a clear message instead of an opaque error inside F.conv2d.
        raise ValueError("colorize weights of shape [3, n_label, 1, 1] are required to draw a label mask")

    with torch.no_grad():
        onehot = AsDiscrete(to_onehot=n_label)(label)
        # [n_label, 1, H, W] -> [1, n_label, H, W]: a batch of one with labels as channels
        onehot = onehot.permute(1, 0, 2, 3)
        rgb = F.conv2d(onehot, weight=colorize)  # [1, 3, H, W]
        # Drop only the batch dim; a bare squeeze() would also collapse size-1 spatial dims.
        rgb = torch.clip(rgb, 0, 1)[0].permute(1, 2, 0).cpu().numpy()  # [H, W, 3]

    return (rgb * 255).astype(np.uint8)


def visualize_one_slice_in_3d(image, axis: int = 2, center=None, mask_bool=True, n_label=105, colorize=None):
    """
    Extract a 2D slice from a 3D image or label tensor and turn it into an RGB array.

    The last three dimensions of ``image`` are treated as spatial, so both
    [C, H, W, D] and [B, C, H, W, D] inputs are accepted.

    Args:
        image (torch.Tensor): Input volume whose last three dims are spatial,
            e.g. [1, H, W, D] or [1, 1, H, W, D].
        axis (int, optional): Spatial axis to slice along (0, 1, or 2). Defaults to 2.
        center (int, optional): Slice index along ``axis``. If None, the middle slice is used.
        mask_bool (bool, optional): If True, treat the input as a label mask and colorize it
            with ``normalize_label_to_uint8``. Defaults to True.
        n_label (int, optional): Number of labels in the mask; used only if mask_bool is True.
            Defaults to 105.
        colorize (torch.Tensor, optional): Colorization weights [3, n_label, 1, 1];
            needed when mask_bool is True.

    Returns:
        numpy.ndarray: Array of shape [h, w, 3] (h, w are the two remaining spatial dims).
            uint8 when mask_bool is True; otherwise the grayscale slice (after ``squeeze``)
            replicated into 3 float32 channels.

    Raises:
        ValueError: If ``axis`` is not 0, 1, or 2.
    """
    # Validate before touching the shape so a bad axis reports ValueError, not IndexError.
    if axis not in (0, 1, 2):
        raise ValueError("axis should be in [0,1,2]")

    if center is None:
        # Spatial dims are the last three, matching the `...` indexing below.
        center = image.shape[-3:][axis] // 2

    if axis == 0:
        draw_img = image[..., center, :, :]
    elif axis == 1:
        draw_img = image[..., :, center, :]
    else:
        draw_img = image[..., :, :, center]

    if mask_bool:
        return normalize_label_to_uint8(colorize, draw_img, n_label)

    draw_img = draw_img.squeeze().cpu().numpy().astype(np.float32)
    return np.stack((draw_img,) * 3, axis=-1)


def show_image(image, title="mask"):
    """
    Draw ``image`` in the left cell of a 1x2 grid on a 24x12-inch figure named "check".

    Args:
        image (numpy.ndarray): Image to show, [H, W] (grayscale) or [H, W, 3] (RGB).
        title (str, optional): Title placed above the image. Defaults to "mask".
    """
    plt.figure(num="check", figsize=(24, 12))
    left_axes = plt.subplot(1, 2, 1)
    left_axes.set_title(title)
    plt.imshow(image)
    plt.show()


def to_shape(a, shape):
    """
    Zero-pad an array up to a desired shape.

    Padding is split evenly between both sides of each dimension, with any odd
    remainder added at the end. Dimensions already at least as large as the target
    are left untouched (nothing is cropped).

    Args:
        a (numpy.ndarray): Input array, e.g. [X, Y, Z].
        shape (tuple): Desired output shape; must have the same length as ``a.ndim``.

    Returns:
        numpy.ndarray: Padded array. Each dim is ``max(a.shape[i], shape[i])``.

    Raises:
        ValueError: If ``shape`` and ``a`` have different numbers of dimensions.
    """
    if len(shape) != a.ndim:
        raise ValueError(f"target shape {tuple(shape)} has {len(shape)} dims but input has {a.ndim}")

    pad_width = []
    for target, current in zip(shape, a.shape):
        # Clamp to 0: np.pad rejects negative widths, and oversize dims must be kept as-is.
        deficit = max(target - current, 0)
        pad_width.append((deficit // 2, deficit - deficit // 2))

    return np.pad(a, pad_width, mode="constant")


def get_xyz_plot(image, center_loc_axis, mask_bool=True, n_label=105, colorize=None, target_class_index=0):
    """
    Build one 2D image showing three orthogonal slices of a 3D volume side by side.

    The volume is flipped along all three spatial axes, one slice is taken along each
    axis, every slice is zero-padded to a square of side ``max(H, W, D)``, and the
    three squares are concatenated horizontally.

    Args:
        image (torch.Tensor): Input volume [1, H, W, D].
        center_loc_axis (list): Slice index (or None for the middle slice) for each of
            the three spatial axes.
        mask_bool (bool, optional): Treat ``image`` as a label mask. Defaults to True.
        n_label (int, optional): Number of labels for mask colorization. Defaults to 105.
        colorize (torch.Tensor, optional): Colorization weights [3, n_label, 1, 1];
            needed when mask_bool is True.
        target_class_index (int, optional): Not used by this function; kept for
            API compatibility. Defaults to 0.

    Returns:
        numpy.ndarray: Array of shape [M, 3*M, 3] with M = max(H, W, D).
    """
    side = max(image.shape[1:])  # largest spatial extent of [1, H, W, D]
    # Loop-invariant: add a batch dim and flip all spatial axes once.
    flipped = torch.flip(image.unsqueeze(0), [-3, -2, -1])

    panels = []
    for axis in range(3):
        panel = visualize_one_slice_in_3d(
            flipped,
            axis,
            center=center_loc_axis[axis],
            mask_bool=mask_bool,
            n_label=n_label,
            colorize=colorize,
        )
        panel = panel.transpose([2, 1, 0])  # [h, w, 3] -> [3, w, h]
        panels.append(to_shape(panel, (3, side, side)))

    # Concatenate along width, then back to channels-last: [3, M, 3M] -> [M, 3M, 3].
    return np.concatenate(panels, axis=2).transpose([1, 2, 0])



## Execute inference

* We use a pre-trained 3D latent diffusion generative model for volumetric brain MRI (BraTS).

* The model was trained on BraTS 2016 and 2017 data from the [Medical Segmentation Decathlon](http://medicaldecathlon.com/) using a latent diffusion model [1].

![model workflow](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_network.png)

* The model generates images resembling the FLAIR MRIs in the BraTS 2016 and 2017 data. It was trained as a 3D latent diffusion model and takes Gaussian random noise as input to produce an image.

* The following code generates a synthetic image from randomly sampled noise.

In [None]:
## method 1
# !python -m monai.bundle run --config_file "configs/inference.json"

## method 2
# import monai
# monai.bundle.run(config_file="configs/inference.json")

## method 3
def generate_mri_sample(config_path: str, output_dir: str):
    """
    Run a MONAI bundle config and report which NIfTI file it produced.

    ``.nii.gz`` files anywhere under ``output_dir`` are listed before and after
    ``monai.bundle.run``. If new files appeared, the one with the latest
    modification time is returned.

    Args:
        config_path: Path to the bundle config passed as ``config_file``.
        output_dir: Directory searched recursively for ``*.nii.gz`` files.

    Returns:
        str | None: Path of the newest newly created ``.nii.gz`` file, or None if
        no new file was found.
    """
    import glob
    import os

    pattern = os.path.join(output_dir, "**", "*.nii.gz")

    def _snapshot():
        return set(glob.glob(pattern, recursive=True))

    before = _snapshot()

    import monai
    monai.bundle.run(config_file=config_path)

    created = _snapshot() - before
    # Several files may appear; pick the most recently modified one.
    return max(created, key=os.path.getmtime) if created else None

filepath = generate_mri_sample(config_path="configs/inference.json", output_dir="output/")
# generate_mri_sample returns None when no new .nii.gz appeared; stop here with a clear
# message instead of failing later inside the image loader.
if filepath is None:
    raise RuntimeError("Inference finished but no new .nii.gz file was found under 'output/'.")

#### Example synthetic image
An example result from inference is shown below:
![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_example_generation_v2.png)

## Visualize the results


In [None]:
visualize_image_filename = filepath  # e.g. "./output/0/0_sample_20250129_060631.nii.gz"
print(f"Visualizing {visualize_image_filename} ...")

# Load the generated volume once. No separate mask file exists: the same volume,
# cast to uint8, is used as the "mask" for choosing slice centers.
loader = LoadImage(image_only=True, ensure_channel_first=True)
orientation = Orientation(axcodes="RAS")
image_volume = orientation(loader(visualize_image_filename))
mask_volume = image_volume.to(torch.uint8)

# Clip intensities to [-200, 500], then rescale to [0, 1] for display.
# NOTE(review): this window comes from a CT (HU) example; confirm it suits this MRI output.
image_volume = torch.clip(image_volume, -200, 500)
image_volume = image_volume - torch.min(image_volume)
max_intensity = torch.max(image_volume)
if max_intensity > 0:  # a constant volume would otherwise become 0/0 = NaN
    image_volume = image_volume / max_intensity

# create a random color map for mask visualization
colorize = torch.clip(torch.cat([torch.zeros(3, 1, 1, 1), torch.randn(3, 200, 1, 1)], 1), 0, 1)
target_class_index = 4

# find center voxel location for 2D slice visualization
center_loc_axis = find_label_center_loc(torch.flip(mask_volume[0, ...] == target_class_index, [-3, -2, -1]))

vis_image = get_xyz_plot(image_volume, center_loc_axis, mask_bool=False)
show_image(vis_image, title="image")

# interactive plot

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
from IPython.display import HTML

# First channel of the normalized volume; frames step through its first spatial axis.
img = image_volume[0]
fig = plt.figure()
frames = [
    [plt.imshow(img[i], aspect=0.74, cmap="gray", animated=True)]
    for i in range(img.shape[0])
]

ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
# To save to disk instead: ani.save('mri.mp4')
HTML(ani.to_jshtml())


# MRI Image Segmentation

* Here we segment the MRI generated above using a pretrained segmentation model.


In [None]:
import os
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose
)

In [14]:
VAL_AMP = True

# Build the SegResNet segmentation network (4 input channels, 3 output channels) for
# inference; no loss or optimizer is needed here. Fall back to CPU when no GPU is present.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)

# Sigmoid, then threshold at 0.5 -> a binary mask per output channel.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

# define inference method
def inference(input):
    """
    Run sliding-window inference of the global ``model`` over ``input``.

    Windows of 240x240x160 with 0.5 overlap are evaluated one at a time. When
    ``VAL_AMP`` is set and the input is on a CUDA device, inference runs under
    CUDA autocast (mixed precision); otherwise it runs in full precision.

    Args:
        input (torch.Tensor): Batched volume, presumably [B, 4, H, W, D] to match the
            model's 4 input channels — TODO confirm.

    Returns:
        torch.Tensor: Raw model output (logits) assembled over the whole volume.
    """

    def _compute(inputs):
        return sliding_window_inference(
            inputs=inputs,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

    # torch.cuda.amp.autocast() is deprecated; torch.autocast("cuda") is the replacement.
    # Autocast is only meaningful when the data actually lives on a CUDA device.
    if VAL_AMP and input.device.type == "cuda":
        with torch.autocast(device_type="cuda"):
            return _compute(input)
    return _compute(input)

In [None]:
# map_location lets a checkpoint saved on GPU load onto whichever `device` was selected;
# weights_only=True limits unpickling to tensors/plain containers, which is all a
# state_dict needs (requires torch >= 1.13).
checkpoint_path = os.path.join("./models", "best_metric_model.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()

layer = 10  # slice index along the last spatial axis used for display
with torch.no_grad():
    # Replicate the single generated channel 4x to match in_channels=4
    # (assumes image_volume is [1, H, W, D] -> [1, 4, H, W, D]).
    val_input = image_volume.unsqueeze(1).repeat(1, 4, 1, 1, 1).to(device)
    val_output = inference(val_input)
    print(val_input.shape, val_output.shape)
    # Sigmoid + 0.5 threshold on the first batch item -> 3 binary channel masks.
    val_output = post_trans(val_output[0])
    # Collapse the 3 output channels into one map for display.
    summed_output = torch.sum(val_output, dim=0, keepdim=True)

    input_slice = val_input[0][1, :, :, layer].detach().cpu()
    mask_slice = summed_output[0][:, :, layer].detach().cpu()

    plt.figure("image", (16, 12))
    plt.subplot(2, 2, 1)
    plt.title("input image")
    plt.imshow(input_slice, cmap="gray")

    plt.subplot(2, 2, 2)
    plt.title("output mask")
    plt.imshow(mask_slice)

    plt.subplot(2, 2, 3)
    plt.title("mask on input")
    plt.imshow(input_slice + 0.3 * mask_slice, cmap="gray")
    plt.show()

# Suggested pipeline for continuous training and integration

* The following pipeline suggests a fully automated workflow for generating new synthetic MRI tumor segmentation samples, which can be used to train new versions of this segmentation model.

![pipeline diagram](./docs/monai-imagegen.drawio.png)

In [None]:
# END