In [1]:
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MAISI Inference Tutorial

This tutorial illustrates how to use the trained MAISI model and codebase to generate synthetic 3D images and paired masks.

## Setup environment

In [2]:
# Install MONAI (weekly build with the nibabel and tqdm extras) only if `import monai` fails.
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
# Install nibabel only if `import nibabel` fails.
!python -c "import nibabel" || pip install -q "nibabel"
# Install matplotlib only if `import matplotlib` fails.
!python -c "import matplotlib" || pip install -q matplotlib
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Setup imports

In [3]:
import json
import torch
import monai
from monai.config import print_config
import argparse
from monai.utils import set_determinism
from scripts.utils import define_instance, load_autoencoder_ckpt
from scripts.sample import check_input, LDMSampler

# Print the MONAI, PyTorch and optional-dependency versions so the run can be reproduced.
print_config()

MONAI version: 1.3.1+25.g64ea76d8
Numpy version: 1.26.4
Pytorch version: 2.3.0+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 64ea76d83a92b7cf7f13c8f93498d50037c3324c
MONAI __file__: /mnt/drive3/wenao/anaconda3/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: 0.22.0
scipy version: 1.11.4
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.18.1+cu121
tqdm version: 4.65.0
lmdb version: 1.4.1
psutil version: 5.9.0
pandas version: 2.1.4
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: 4.41.2
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional d

## Read in environment setting, including data directory, model directory, and output directory

The data directory, model directory, and output directory settings are saved in `./configs/environment.json`.

In [4]:
# Collect every setting on a single namespace so that later cells can read
# them as attributes (args.<key>).
args = argparse.Namespace()

environment_file = './configs/environment.json'
# Use a context manager so the file handle is closed as soon as parsing ends
# instead of being left open for the lifetime of the kernel.
with open(environment_file, "r") as f:
    env_dict = json.load(f)
# Copy each top-level JSON key onto the namespace.
for k, v in env_dict.items():
    setattr(args, k, v)
print("Environment variables have been read in.")

Environment variables have been read in.


## Read in configuration setting, including network definition, body region and anatomy to generate, etc.

The configuration settings that are not related to the working directories, including the network definitions and the body region and anatomy to generate, are saved in `./configs/config_maisi.json`.

In [5]:
config_file = './configs/config_maisi.json'
# Use a context manager so the file handle is closed as soon as parsing ends.
with open(config_file, "r") as f:
    config_dict = json.load(f)
# Copy each top-level JSON key onto the shared namespace; keys that already
# exist on `args` are overwritten.
for k, v in config_dict.items():
    setattr(args, k, v)

# check the format of inference inputs
check_input(
    args.body_region,
    args.anatomy_list,
    args.label_dict_json,
    args.output_size,
    args.spacing,
    args.controllable_anatomy_size,
)
# Latent volume shape: the latent channel count followed by each of the three
# output dimensions divided by 4.
# NOTE(review): 4 is presumably the autoencoder's spatial compression factor - confirm.
latent_shape = [
    args.latent_channels,
    args.output_size[0] // 4,
    args.output_size[1] // 4,
    args.output_size[2] // 4,
]
print("Configuration variables have been read in.")

controllable_anatomy_size is not empty. We will ignore body_region and anatomy_list and synthesize based on controllable_anatomy_size ([['hepatic tumor', 0.3], ['liver', 0.5]]).
The generate results will have voxel size to be [1.5, 1.5, 2.0]mm, and volume size to be [256, 256, 256].
Configuration variables have been read in.


## Define networks and noise scheduler, load network weights.

The networks and noise scheduler are defined in config_file. We will read them in and load the model weights.

In [6]:
# Instantiate the noise schedulers for image generation and mask generation
# from their definitions in the configuration.
noise_scheduler = define_instance(args, "noise_scheduler")
mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler")

device = torch.device("cuda")

# Checkpoints are deserialized onto the CPU (map_location="cpu") so that loading
# does not depend on the CUDA device index they were saved from and the
# checkpoint tensors do not take up extra GPU memory. load_state_dict() then
# copies the values into the modules, which already live on `device`.

# Image autoencoder.
autoencoder = define_instance(args, "autoencoder_def").to(device)
checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path)
autoencoder.load_state_dict(checkpoint_autoencoder)

# Image diffusion U-Net; its checkpoint also stores the latent scale factor.
diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path, map_location="cpu")
diffusion_unet.load_state_dict(checkpoint_diffusion_unet['unet_state_dict'])
scale_factor = checkpoint_diffusion_unet['scale_factor'].to(device)

# ControlNet: first copy the matching weights from the diffusion U-Net, then
# load the trained ControlNet weights on top with strict key matching.
controlnet = define_instance(args, "controlnet_def").to(device)
checkpoint_controlnet = torch.load(args.trained_controlnet_path, map_location="cpu")
monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
controlnet.load_state_dict(checkpoint_controlnet['controlnet_state_dict'], strict=True)

# Mask-generation autoencoder.
mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
checkpoint_mask_generation_autoencoder = torch.load(
    args.trained_mask_generation_autoencoder_path, map_location="cpu"
)
mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)

# Mask-generation diffusion U-Net; its scale factor comes from the configuration
# rather than from the checkpoint.
mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
checkpoint_mask_generation_diffusion_unet = torch.load(
    args.trained_mask_generation_diffusion_path, map_location="cpu"
)
mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet, strict=True)
mask_generation_scale_factor = args.mask_generation_scale_factor

print("All the trained model weights have been loaded.")

2024-07-07 06:31:28,641 - INFO - 'dst' model updated: 158 of 206 variables.
All the trained model weights have been loaded.


## Define the LDM Sampler, which contains functions that will perform the inference.

In [7]:
# Positional arguments, kept in exactly the order LDMSampler receives them.
sampler_positional_args = (
    # generation request and mask / label resources
    args.body_region,
    args.anatomy_list,
    args.all_mask_files_json,
    args.all_anatomy_size_condtions_json,
    args.all_mask_files_base_dir,
    args.label_dict_json,
    args.label_dict_remap_json,
    # image-generation networks, scheduler and scale factor
    autoencoder,
    diffusion_unet,
    controlnet,
    noise_scheduler,
    scale_factor,
    # mask-generation networks, scale factor and scheduler
    mask_generation_autoencoder,
    mask_generation_diffusion_unet,
    mask_generation_scale_factor,
    mask_generation_noise_scheduler,
    # device, shapes, output location and size control
    device,
    latent_shape,
    args.mask_generation_latent_shape,
    args.output_size,
    args.output_dir,
    args.controllable_anatomy_size,
)

# Keyword options for output format, voxel spacing, step counts and seeding.
sampler_keyword_args = {
    "image_output_ext": args.image_output_ext,
    "label_output_ext": args.label_output_ext,
    "spacing": args.spacing,
    "num_inference_steps": 1000,
    "mask_generation_num_inference_steps": 1000,
    "random_seed": args.random_seed,
}

ldm_sampler = LDMSampler(*sampler_positional_args, **sampler_keyword_args)

controllable_anatomy_size is given, mask generation is triggered!
LDM sampler initialized.


## Perform the inference. 
It takes around 3 minutes to generate one [256, 256, 256] image/mask pair on a single 80 GB A100 GPU. The time cost per image pair scales roughly linearly with the output size.

In [8]:
print(f"The generated image/mask pairs will be saved in {args.output_dir}.")
# Long-running step: generate args.num_output_samples image/mask pairs.
ldm_sampler.sample_multiple_images(args.num_output_samples)
# Plain string literal: the message has no placeholders, so no f-prefix is needed.
print("MAISI image/mask generation finished")

The generated image/mask pairs will be saved in ./output.
controllable_anatomy_size: [['hepatic tumor', 0.3], ['liver', 0.5]]
provide_anatomy_size: [None, 0.5, None, None, None, None, None, 0.3, None, None]
candidate_condition: [-1.0, 0.463429, 0.486525, 0.287951, 0.278651, -1.0, -1.0, 0.310093, -1.0, -1.0]
final candidate_condition: [-1.0, 0.5, 0.486525, 0.287951, 0.278651, -1.0, -1.0, 0.3, -1.0, -1.0]
Prepare mask...


100%|███████████████████████████████████████| 1000/1000 [01:03<00:00, 15.75it/s]
  return F.conv3d(


target_tumor_label for postprocess: 26
Current model does not support hepatic vessel by size control, so we treat generated hepatic vessel as part of liver for better visiaulization.
Mask preparation time: 87.0064799785614 seconds.
Start generating latent features...


100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [01:15<00:00, 13.21it/s]


Latent features generation time: 75.73426008224487 seconds
Start decoding latent features into images...
Image decoding time: 10.641048669815063 seconds
2024-07-07 06:34:25,599 INFO image_writer.py:197 - writing: output/sample_20240707_063425_584977_image.nii.gz
2024-07-07 06:34:26,892 INFO image_writer.py:197 - writing: output/sample_20240707_063425_584977_label.nii.gz
MAISI image/mask generation finished
