Copyright (c) MONAI Consortium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# MAISI Inference Tutorial

This tutorial illustrates how to use the trained MAISI model and codebase to generate synthetic 3D images and paired masks.

## Setup environment

To run this tutorial, please install `xformers` by following the [official guide](https://github.com/facebookresearch/xformers#installing-xformers).

In [None]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import generative" || pip install monai-generative
!python -c "import matplotlib" || pip install -q matplotlib
!python -c "import os; print(os.path.isfile('models/autoencoder_epoch273.pt'))" || python -m monai.bundle download_large_files --bundle_path .
%matplotlib inline

## Setup imports

In [None]:
import json
import torch
import monai
from monai.config import print_config
import argparse
from scripts.utils import define_instance, load_autoencoder_ckpt
from scripts.sample import check_input, LDMSampler

# Print MONAI's configuration report (from monai.config) so the run environment is recorded.
print_config()

## Read in environment setting, including data directory, model directory, and output directory

The paths of the data directory, model directory, and output directory are stored in `./configs/environment.json`.

In [None]:
# Every setting read from the JSON files is stored as an attribute on this
# single namespace, which the following cells read from.
args = argparse.Namespace()

environment_file = "./configs/environment.json"
# Open the file in a context manager so the handle is closed as soon as the
# JSON has been parsed, instead of being left open for the kernel's lifetime.
with open(environment_file, "r") as environment_fp:
    env_dict = json.load(environment_fp)
# Copy each top-level key onto `args` and echo it for the reader.
for k, v in env_dict.items():
    setattr(args, k, v)
    print(f"{k}: {v}")
print("Global config variables have been loaded.")

## Read in configuration setting, including network definition, body region and anatomy to generate, etc.

The information used for both training and inference, like network definition, is stored in "./configs/config_maisi.json". Training and inference should use the same "./configs/config_maisi.json".

The information for the inference input, like body region and anatomy to generate, is stored in "./configs/config_infer.json". Please feel free to play with it.
- `"num_output_samples"`: int, the number of output image/mask pairs it will generate.
- `"spacing"`: voxel size of generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a resolution of 1.5x1.5x2.0 mm.
- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512x512x256. They need to be divisible by 16. If your GPU memory is limited, reduce these values.
- `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a relatively small hepatic tumor (around the 30th percentile). The output will contain a paired image and segmentation mask for the controllable anatomy.
- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
- `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".


In [None]:
# Load the settings shared by training and inference (e.g. network definitions)
# and copy each top-level key onto `args`.
config_file = "./configs/config_maisi.json"
# Context managers close the JSON files right after parsing; the previous
# `json.load(open(...))` form left the handles open.
with open(config_file, "r") as config_fp:
    config_dict = json.load(config_fp)
for k, v in config_dict.items():
    setattr(args, k, v)

# check the format of inference inputs
config_infer_file = "./configs/config_infer.json"
with open(config_infer_file, "r") as config_infer_fp:
    config_infer_dict = json.load(config_infer_fp)
# Inference settings are applied last, so a key present in both files takes
# its value from config_infer.json. Each inference setting is echoed.
for k, v in config_infer_dict.items():
    setattr(args, k, v)
    print(f"{k}: {v}")

check_input(
    args.body_region,
    args.anatomy_list,
    args.label_dict_json,
    args.output_size,
    args.spacing,
    args.controllable_anatomy_size,
)
# Latent shape is [channels, X // 4, Y // 4, Z // 4]: each spatial dimension of
# the requested output size is divided by 4.
latent_shape = [args.latent_channels, args.output_size[0] // 4, args.output_size[1] // 4, args.output_size[2] // 4]
print("Network definition and inference inputs have been loaded.")

## Initialize networks and noise scheduler, then load the trained model weights.

The networks and noise scheduler are defined in `config_file`. We will read them in and load the model weights.

In [None]:
# Build the two noise schedulers from their named entries on `args`: one for
# image generation and one for mask generation.
noise_scheduler = define_instance(args, "noise_scheduler")
mask_generation_noise_scheduler = define_instance(args, "mask_generation_noise_scheduler")

# Every network below is moved to this device; the device is fixed to CUDA here.
device = torch.device("cuda")

# Image autoencoder: build it, then load its weights through the project helper.
autoencoder = define_instance(args, "autoencoder_def").to(device)
checkpoint_autoencoder = load_autoencoder_ckpt(args.trained_autoencoder_path)
autoencoder.load_state_dict(checkpoint_autoencoder)

# Image diffusion UNet: the checkpoint is indexed by "unet_state_dict" (weights)
# and "scale_factor" (moved to `device` and passed to the sampler later).
diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)
diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"])
scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)

# ControlNet: first copy state from the diffusion UNet via MONAI's
# copy_model_state, then load the trained ControlNet weights with strict key
# checking. The order of these two calls matters.
controlnet = define_instance(args, "controlnet_def").to(device)
checkpoint_controlnet = torch.load(args.trained_controlnet_path)
monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)

# Mask-generation autoencoder: the loaded checkpoint object is used directly as
# the state dict.
mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path)
mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)

# Mask-generation diffusion UNet: the checkpoint is likewise used directly as
# the state dict. Unlike the image model, its scale factor is read from `args`
# rather than from the checkpoint.
mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)
mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet, strict=True)
mask_generation_scale_factor = args.mask_generation_scale_factor

print("All the trained model weights have been loaded.")

## Define the LDM Sampler, which contains functions that will perform the inference.

In [None]:
# Assemble the sampler from the loaded settings, networks and schedulers.
# Most arguments are passed positionally, so their order must match the
# LDMSampler signature in scripts/sample.py.
# NOTE(review): `all_anatomy_size_condtions_json` (sic) is an attribute name set
# from the JSON keys loaded onto `args`; keep the spelling in sync with the config.
ldm_sampler = LDMSampler(
    args.body_region,
    args.anatomy_list,
    args.all_mask_files_json,
    args.all_anatomy_size_condtions_json,
    args.all_mask_files_base_dir,
    args.label_dict_json,
    args.label_dict_remap_json,
    autoencoder,
    diffusion_unet,
    controlnet,
    noise_scheduler,
    scale_factor,
    mask_generation_autoencoder,
    mask_generation_diffusion_unet,
    mask_generation_scale_factor,
    mask_generation_noise_scheduler,
    device,
    latent_shape,
    args.mask_generation_latent_shape,
    args.output_size,
    args.output_dir,
    args.controllable_anatomy_size,
    image_output_ext=args.image_output_ext,
    label_output_ext=args.label_output_ext,
    spacing=args.spacing,
    num_inference_steps=args.num_inference_steps,
    mask_generation_num_inference_steps=args.mask_generation_num_inference_steps,
    random_seed=args.random_seed,
)

## Perform the inference. 
Time cost: It will take around 80s to generate a [256,256,256] mask, and another 105s to generate the corresponding [256,256,256] image on one A100 using 34G GPU memory. The time cost per image is roughly linear in the output size (number of voxels). E.g., if we want to generate a [512,512,512] image/mask pair, it will take around 80s to generate a [256,256,256] mask and then resample it to a [512,512,512] mask. The time for [512,512,512] image generation will be 8 times longer, around 8x105s~14min.

In [None]:
# Tell the reader where the results go, then run the sampler for the number of
# image/mask pairs requested by `num_output_samples` in the inference config.
print(f"The generated image/mask pairs will be saved in {args.output_dir}.")
ldm_sampler.sample_multiple_images(args.num_output_samples)
print("MAISI image/mask generation finished")