diff --git a/ci/unit_tests/test_maisi_ct_generative.py b/ci/unit_tests/test_maisi_ct_generative.py index e2b9db54..94389480 100644 --- a/ci/unit_tests/test_maisi_ct_generative.py +++ b/ci/unit_tests/test_maisi_ct_generative.py @@ -85,17 +85,6 @@ } ] -TEST_CASE_INFER_ERROR = [ - { - "bundle_root": "models/maisi_ct_generative", - "num_output_samples": 1, - "output_size": [256, 256, 256], - "body_region": ["head"], - "anatomy_list": ["colon cancer primaries"], - }, - "Cannot find body region with given anatomy list.", -] - TEST_CASE_INFER_ERROR_2 = [ { "bundle_root": "models/maisi_ct_generative", @@ -277,7 +266,7 @@ def test_infer_config(self, override): else: self.assertTrue(output_file.endswith(".nii.gz")) - @parameterized.expand([TEST_CASE_INFER_ERROR, TEST_CASE_INFER_ERROR_7]) + @parameterized.expand([TEST_CASE_INFER_ERROR_7]) def test_infer_config_error_input(self, override, expected_error): # update override override["output_dir"] = self.output_dir diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 283305ed..35f76248 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -9,9 +9,9 @@ "output_dir": "$@bundle_root + '/output'", "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", - "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", - "trained_diffusion_path": "$@model_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'", + "trained_autoencoder_path": "$@model_dir + '/autoencoder.pt'", + "trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", @@ -21,14 +21,13 @@ "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", "real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'", "num_output_samples": 1, - "body_region": [ - "abdomen" - ], + "body_region": [], "anatomy_list": [ "liver" ], + "modality": "ct", "controllable_anatomy_size": [], - "num_inference_steps": 1000, + "num_inference_steps": 30, "mask_generation_num_inference_steps": 1000, "random_seed": null, "spatial_dims": 3, @@ -63,11 +62,11 @@ 64 ], "autoencoder_sliding_window_infer_size": [ - 96, - 96, - 96 + 80, + 80, + 80 ], - "autoencoder_sliding_window_infer_overlap": 0.6667, + "autoencoder_sliding_window_infer_overlap": 0.4, "autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", "spatial_dims": "@spatial_dims", @@ -96,7 +95,7 @@ "use_checkpointing": false, "use_convtranspose": false, "norm_float16": true, - "num_splits": 8, + "num_splits": 2, "dim_split": 1 }, "diffusion_unet_def": { @@ -124,9 +123,12 @@ ], "num_res_blocks": 2, "use_flash_attention": true, - "include_top_region_index_input": true, - "include_bottom_region_index_input": true, - "include_spacing_input": true + "include_top_region_index_input": false, + "include_bottom_region_index_input": false, + "include_spacing_input": true, + "num_class_embeds": 128, + "resblock_updown": true, + "include_fc": true }, "controlnet_def": { "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", @@ -157,7 +159,10 @@ 8, 32, 64 - ] + ], + "num_class_embeds": 128, + "resblock_updown": true, + "include_fc": true }, "mask_generation_autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", @@ -239,12 +244,11 @@ "load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)", "mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']", "noise_scheduler": { - "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", + "_target_": "scripts.rectified_flow.RFlowScheduler", "num_train_timesteps": 1000, - "beta_start": 0.0015, - "beta_end": 0.0195, - "schedule": "scaled_linear_beta", - "clip_sample": false + "use_discrete_timesteps": false, + "use_timestep_transform": true, + "sample_method": "uniform" }, "mask_generation_noise_scheduler": { "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", @@ -269,6 +273,7 @@ ], "body_region": "@body_region", "anatomy_list": "@anatomy_list", + "modality": "@modality", "all_mask_files_json": "@all_mask_files_json", "all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json", "all_mask_files_base_dir": "@all_mask_files_base_dir", @@ -300,6 +305,7 @@ "autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap" }, "run": [ + "$monai.utils.set_determinism(seed=@random_seed)", "$@ldm_sampler.sample_multiple_images(@num_output_samples)" ], "evaluator": null diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index 010a70d1..f46c2307 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", - "version": "0.4.6", + "version": "1.0.0", "changelog": { + "1.0.0": "accelerated maisi, inference only, is not compartible with previous maisi diffusion model weights", "0.4.6": "add TensorRT support", "0.4.5": "update README", "0.4.4": "update issue for IgniteInfo", diff --git a/models/maisi_ct_generative/docs/README.md b/models/maisi_ct_generative/docs/README.md index be7937b1..c46e6645 100644 --- a/models/maisi_ct_generative/docs/README.md +++ b/models/maisi_ct_generative/docs/README.md @@ -4,7 +4,7 @@ This bundle is for Nvidia MAISI (Medical AI for Synthetic Imaging), a 3D Latent The inference workflow of MAISI is depicted in the figure below. It first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Then it decodes the denoised latent features into images using the trained autoencoder.

- MAISI inference scheme + MAISI inference scheme

MAISI is based on the following papers: @@ -13,6 +13,8 @@ MAISI is based on the following papers: [**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf) +[**Rectified Flow:** Liu, Xingchao, and Chengyue Gong. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." ICLR 2023.](https://arxiv.org/pdf/2209.03003) + #### Example synthetic image An example result from inference is shown below: ![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png) @@ -27,11 +29,11 @@ The information for the inference input, like body region and anatomy to generat - `"num_output_samples"`: int, the number of output image/mask pairs it will generate. - `"spacing"`: voxel size of generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a resolution of 1.5×1.5×2.0 mm. The spacing for x and y axes has to be between 0.5 and 3.0 mm and the spacing for the z axis has to be between 0.5 and 5.0 mm. -- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512×512×256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For eample, if set them to `[1.5, 1.5, 2.0]`mm and `[512, 512, 256]`, the FOV is 768×768×512 mm. We recommend output_size is the FOV in x and y axis are same and to be at least 256mm for head, and at least 384mm for other body regions like abdomen. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768]. +- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512×512×256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For eample, if set them to `[1.5, 1.5, 2.0]`mm and `[512, 512, 256]`, the FOV is 768×768×512 mm. We recommend output_size is the FOV in x and y axis are same and to be at least 256mm for head, at least 384mm for other body regions like abdomen, and no larger than 640mm. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768]. - `"controllable_anatomy_size"`: a list of controllable anatomy and its size scale (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain liver that have a median size, with size around 50% percentile, and hepatic tumor that is relatively small, with around 30% percentile. In addition, if the size scale is set to -1, it indicates that the organ does not exist or should be removed. The output will contain paired image and segmentation mask for the controllable anatomy. The following organs support generation with a controllable size: ``["liver", "gallbladder", "stomach", "pancreas", "colon", "lung tumor", "bone lesion", "hepatic tumor", "colon cancer primaries", "pancreatic tumor"]``. The raw output of the current mask generation model has a fixed size of $256^3$ voxels with a spacing of $1.5^3$ mm. If the "output_size" differs from this default, the generated masks will be resampled to the desired `"output_size"` and `"spacing"`. Note that resampling may degrade the quality of the generated masks and could trigger multiple inference attempts if the images fail to pass the [image quality check](../scripts/quality_check.py). -- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower". +- `"body_region"`: Deprecated, please leave it as empty `"[]"`. - `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json". - `"autoencoder_sliding_window_infer_size"`: in order to save GPU memory, we use sliding window inference when decoding latents to image when `"output_size"` is large. This is the patch size of the sliding window. Small value will reduce GPU memory but increase time cost. They need to be divisible by 16. - `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. Large value will reduce the stitching artifacts when stitching patches during sliding window inference, but increase time cost. If you do not observe seam lines in the generated image result, you can use a smaller value to save inference time. diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml index 97b26885..17400eab 100644 --- a/models/maisi_ct_generative/large_files.yml +++ b/models/maisi_ct_generative/large_files.yml @@ -1,5 +1,5 @@ large_files: -- path: "models/autoencoder_epoch273.pt" +- path: "models/autoencoder.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt" hash_val: "917cfb1e49631c8a713e3bb7c758fbca" hash_type: "md5" @@ -11,6 +11,14 @@ large_files: url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt" hash_val: "6c36572335372f405a0e85c760fa6dee" hash_type: "md5" +- path: "models/diffusion_unet.pt" + url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/diff_unet_ckpt_rflow_epoch19350.pt" + hash_val: "10501d59a3066802087c82ebd7a71719" + hash_type: "md5" +- path: "models/controlnet.pt" + url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/controlnet_rflow_epoch208.pt" + hash_val: "49933da32826c0f7ca17016ccd13e23b" + hash_type: "md5" - path: "models/mask_generation_autoencoder.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/mask_generation_autoencoder.pt" hash_val: "b177778820f412abc9218cdb7ce3b653" diff --git a/models/maisi_ct_generative/scripts/augmentation.py b/models/maisi_ct_generative/scripts/augmentation.py index 6317781f..64469403 100644 --- a/models/maisi_ct_generative/scripts/augmentation.py +++ b/models/maisi_ct_generative/scripts/augmentation.py @@ -60,7 +60,7 @@ def dilate3d(input_tensor, erosion=3): return output.squeeze(0).squeeze(0) -def augmentation_tumor_bone(pt_nda, output_size): +def augmentation_tumor_bone(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 128] = 1 @@ -74,6 +74,7 @@ def augmentation_tumor_bone(pt_nda, output_size): scale_range=(0.15, 0.15, 0), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum((real_l_volume_ > 0).float()) ########################### @@ -112,7 +113,7 @@ def augmentation_tumor_bone(pt_nda, output_size): return pt_nda -def augmentation_tumor_liver(pt_nda, output_size): +def augmentation_tumor_liver(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 1] = 1 @@ -129,6 +130,7 @@ def augmentation_tumor_liver(pt_nda, output_size): scale_range=(0.2, 0.2, 0.2), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_ == 2) ########################### @@ -161,7 +163,7 @@ def augmentation_tumor_liver(pt_nda, output_size): return pt_nda -def augmentation_tumor_lung(pt_nda, output_size): +def augmentation_tumor_lung(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 23] = 1 @@ -177,6 +179,7 @@ def augmentation_tumor_lung(pt_nda, output_size): scale_range=(0.15, 0.15, 0.15), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_) # before move lung tumor maks, full the original location by lung labels @@ -224,7 +227,7 @@ def augmentation_tumor_lung(pt_nda, output_size): return pt_nda -def augmentation_tumor_pancreas(pt_nda, output_size): +def augmentation_tumor_pancreas(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 4] = 1 @@ -241,6 +244,7 @@ def augmentation_tumor_pancreas(pt_nda, output_size): scale_range=(0.1, 0.1, 0.1), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_ == 2) ########################### @@ -273,7 +277,7 @@ def augmentation_tumor_pancreas(pt_nda, output_size): return pt_nda -def augmentation_tumor_colon(pt_nda, output_size): +def augmentation_tumor_colon(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 27] = 1 @@ -289,6 +293,7 @@ def augmentation_tumor_colon(pt_nda, output_size): scale_range=(0.1, 0.1, 0.1), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_) ########################### @@ -330,37 +335,39 @@ def augmentation_tumor_colon(pt_nda, output_size): return pt_nda -def augmentation_body(pt_nda): +def augmentation_body(pt_nda, random_seed): volume = pt_nda.squeeze(0) zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0) + zoom.set_random_state(seed=random_seed) + volume = zoom(volume) pt_nda = volume.unsqueeze(0) return pt_nda -def augmentation(pt_nda, output_size): +def augmentation(pt_nda, output_size, random_seed): label_list = torch.unique(pt_nda) label_list = list(label_list.cpu().numpy()) if 128 in label_list: print("augmenting bone lesion/tumor") - pt_nda = augmentation_tumor_bone(pt_nda, output_size) + pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed) elif 26 in label_list: print("augmenting liver tumor") - pt_nda = augmentation_tumor_liver(pt_nda, output_size) + pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed) elif 23 in label_list: print("augmenting lung tumor") - pt_nda = augmentation_tumor_lung(pt_nda, output_size) + pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed) elif 24 in label_list: print("augmenting pancreas tumor") - pt_nda = augmentation_tumor_pancreas(pt_nda, output_size) + pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed) elif 27 in label_list: print("augmenting colon tumor") - pt_nda = augmentation_tumor_colon(pt_nda, output_size) + pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed) else: print("augmenting body") - pt_nda = augmentation_body(pt_nda) + pt_nda = augmentation_body(pt_nda, random_seed) return pt_nda diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index de626552..f8f3a14d 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -53,7 +53,6 @@ def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]: def find_masks( - body_region: str | Sequence[str], anatomy_list: int | Sequence[int], spacing: Sequence[float] | float = 1.0, output_size: Sequence[int] = (512, 512, 512), @@ -63,12 +62,10 @@ def find_masks( ): """ Find candidate masks that fullfills all the requirements. - They shoud contain all the body region in `body_region`, all the anatomies in `anatomy_list`. + They shoud contain all the anatomies in `anatomy_list`. If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor free. If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`. Args: - body_region: list of input body region string. If single str, will be converted to list of str. - The found candidate mask will include these body regions. anatomy_list: list of input anatomy. The found candidate mask will include these anatomies. spacing: list of three floats, voxel spacing. If providing a single number, will use it for all the three dimensions. output_size: list of three int, expected candidate mask spatial size. @@ -80,8 +77,6 @@ def find_masks( candidate_masks, list of dict, each dict contains information of one candidate mask that fullfills all the requirements. """ # check and preprocess input - body_region = convert_body_region(body_region) - if isinstance(anatomy_list, int): anatomy_list = [anatomy_list] @@ -108,20 +103,9 @@ def find_masks( if not set(anatomy_list).issubset(_item["label_list"]): continue - # extract region indice (top_index and bottom_index) for candidate mask - top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0] - top_index = top_index[0] - bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0] - bottom_index = bottom_index[0] - # whether to keep this mask, default to be True. keep_mask = True - # if candiate mask does not contain all the body_region, skip it - for _idx in body_region: - if _idx > bottom_index or _idx < top_index: - keep_mask = False - for tumor_label in [23, 24, 26, 27, 128]: # we skip those mask with tumors if users do not provide tumor label in anatomy_list if tumor_label not in anatomy_list and tumor_label in _item["label_list"]: @@ -139,8 +123,6 @@ def find_masks( "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]), "spacing": _item["spacing"], "dim": _item["dim"], - "top_region_index": _item["top_region_index"], - "bottom_region_index": _item["bottom_region_index"], } # Conditionally add the label to the candidate dictionary diff --git a/models/maisi_ct_generative/scripts/quality_check.py b/models/maisi_ct_generative/scripts/quality_check.py index fe34661f..bff49b6d 100644 --- a/models/maisi_ct_generative/scripts/quality_check.py +++ b/models/maisi_ct_generative/scripts/quality_check.py @@ -109,8 +109,11 @@ def is_outlier(statistics, image_data, label_data, label_int_dict): for label_name, stats in statistics.items(): # Get the thresholds from the statistics - low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs - high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs + low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"]) # or "sigma_12_low" depending on your needs + high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs + + if label_name == "bone": + high_thresh = 1000.0 # Retrieve the corresponding label integers labels = label_int_dict.get(label_name, []) diff --git a/models/maisi_ct_generative/scripts/rectified_flow.py b/models/maisi_ct_generative/scripts/rectified_flow.py new file mode 100644 index 00000000..6bdcb00a --- /dev/null +++ b/models/maisi_ct_generative/scripts/rectified_flow.py @@ -0,0 +1,163 @@ +from typing import Any + +import numpy as np +import torch +from monai.networks.schedulers import Scheduler +from torch.distributions import LogisticNormal + +# code modified from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py + + +def timestep_transform( + t, input_img_size, base_img_size=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 +): + t = t / num_train_timesteps + ratio_space = (input_img_size / base_img_size).pow(1.0 / spatial_dim) + + ratio = ratio_space * scale + new_t = ratio * t / (1 + (ratio - 1) * t) + + new_t = new_t * num_train_timesteps + return new_t + + +class RFlowScheduler(Scheduler): + def __init__( + self, + num_train_timesteps=1000, + num_inference_steps=10, + use_discrete_timesteps=False, + sample_method="uniform", + loc=0.0, + scale=1.0, + use_timestep_transform=False, + transform_scale=1.0, + steps_offset: int = 0, + ): + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.use_discrete_timesteps = use_discrete_timesteps + + # sample method + assert sample_method in ["uniform", "logit-normal"] + # assert ( + # sample_method == "uniform" or not use_discrete_timesteps + # ), "Only uniform sampling is supported for discrete timesteps" + self.sample_method = sample_method + if sample_method == "logit-normal": + self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale])) + self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device) + + # timestep transform + self.use_timestep_transform = use_timestep_transform + self.transform_scale = transform_scale + self.steps_offset = steps_offset + + def add_noise( + self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + """ + compatible with diffusers add_noise() + """ + timepoints = timesteps.float() / self.num_train_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + + return timepoints * original_samples + (1 - timepoints) * noise + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device | None = None, + input_img_size: int | None = None, + base_img_size: int = 32 * 32 * 32, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + input_img_size: int, H*W*D of the image, used with self.use_timestep_transform is True. + base_img_size: int, reference H*W*D size, used with self.use_timestep_transform is True. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + # prepare timesteps + timesteps = [ + (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps) + ] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + if self.use_timestep_transform: + timesteps = [ + timestep_transform( + t, + input_img_size=input_img_size, + base_img_size=base_img_size, + num_train_timesteps=self.num_train_timesteps, + ) + for t in timesteps + ] + timesteps = np.array(timesteps).astype(np.float16) + if self.use_discrete_timesteps: + timesteps = timesteps.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + print(self.timesteps) + + def sample_timesteps(self, x_start): + if self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_train_timesteps + + if self.use_discrete_timesteps: + t = t.long() + + if self.use_timestep_transform: + input_img_size = torch.prod(torch.tensor(x_start.shape[-3:])) + base_img_size = 32 * 32 * 32 + t = timestep_transform( + t, + input_img_size=input_img_size, + base_img_size=base_img_size, + num_train_timesteps=self.num_train_timesteps, + ) + + return t + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None + ) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep. Core function to propagate the diffusion + process from the learned model outputs. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + None + """ + v_pred = model_output + if next_timestep is None: + dt = 1.0 / self.num_inference_steps + else: + dt = timestep - next_timestep + dt = dt / self.num_train_timesteps + z = sample + v_pred * dt + + return z, None diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index d161597f..d22d26e2 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -11,7 +11,6 @@ import json import logging -import math import os import random import time @@ -20,8 +19,7 @@ import monai import torch from monai.data import MetaTensor -from monai.inferers import sliding_window_inference -from monai.inferers.inferer import DiffusionInferer +from monai.inferers.inferer import DiffusionInferer, SlidingWindowInferer from monai.transforms import Compose, SaveImage from monai.utils import set_determinism from tqdm import tqdm @@ -29,7 +27,23 @@ from .augmentation import augmentation from .find_masks import find_masks from .quality_check import is_outlier -from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels +from .utils import binarize_labels, dynamic_infer, general_mask_generation_post_process, remap_labels + +modality_mapping = { + "unknown": 0, + "ct": 1, + "ct_wo_contrast": 2, + "ct_contrast": 3, + "mri": 8, + "mri_t1": 9, + "mri_t2": 10, + "mri_flair": 11, + "mri_pd": 12, + "mri_dwi": 13, + "mri_adc": 14, + "mri_ssfp": 15, + "mri_mra": 16, +} # current version only support "ct" class ReconModel(torch.nn.Module): @@ -123,28 +137,16 @@ def ldm_conditional_sample_one_mask( conditioning=anatomy_size.to(device), ) # decode latents to synthesized masks - if math.prod(latent_shape[1:]) < math.prod(autoencoder_sliding_window_infer_size): - synthetic_mask = recon_model(latents).cpu().detach() - else: - synthetic_mask = ( - sliding_window_inference( - inputs=latents, - roi_size=( - autoencoder_sliding_window_infer_size[0], - autoencoder_sliding_window_infer_size[1], - autoencoder_sliding_window_infer_size[2], - ), - sw_batch_size=1, - predictor=recon_model, - mode="gaussian", - overlap=autoencoder_sliding_window_infer_overlap, - sw_device=device, - device=torch.device("cpu"), - progress=True, - ) - .cpu() - .detach() - ) + inferer = SlidingWindowInferer( + roi_size=autoencoder_sliding_window_infer_size, + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + device=torch.device("cpu"), + sw_device=device, + ) + synthetic_mask = dynamic_infer(inferer, recon_model, latents) synthetic_mask = torch.softmax(synthetic_mask, dim=1) synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) # mapping raw index to 132 labels @@ -174,8 +176,7 @@ def ldm_conditional_sample_one_image( scale_factor, device, combine_label_or, - top_region_index_tensor, - bottom_region_index_tensor, + modality_tensor, spacing_tensor, latent_shape, output_size, @@ -195,8 +196,6 @@ def ldm_conditional_sample_one_image( scale_factor (float): Scaling factor for the latent space. device (torch.device): The device to run the computation on. combine_label_or (torch.Tensor): The combined label tensor. - top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. - bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. spacing_tensor (torch.Tensor): Tensor specifying the spacing. latent_shape (tuple): The shape of the latent space. output_size (tuple): The desired output size of the image. @@ -217,7 +216,7 @@ def ldm_conditional_sample_one_image( recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) - with torch.no_grad(), torch.amp.autocast("cuda"): + with torch.no_grad(), torch.amp.autocast("cuda", enabled=True): logging.info("---- Start generating latent features... ----") start_time = time.time() # generate segmentation mask @@ -239,49 +238,66 @@ def ldm_conditional_sample_one_image( latents = initialize_noise_latents(latent_shape, device) * noise_factor # synthesize latents - noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) - for t in tqdm(noise_scheduler.timesteps, ncols=110): - # Get controlnet output - down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis - ) - latent_model_input = latents - noise_pred = diffusion_unet( - x=latent_model_input, - timesteps=torch.Tensor((t,)).to(device), - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, - spacing_tensor=spacing_tensor, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - latents, _ = noise_scheduler.step(noise_pred, t, latents) + noise_scheduler.set_timesteps( + num_inference_steps=num_inference_steps, input_img_size=torch.prod(torch.tensor(latent_shape[-3:])) + ) + # synthesize latents + guidance_scale = 0 # API for classifier-free guidence, not used in this version + all_next_timesteps = torch.cat( + (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype)) + ) + for t, next_t in tqdm( + zip(noise_scheduler.timesteps, all_next_timesteps), + total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)), + ): + timesteps = torch.Tensor((t,)).to(device) + if guidance_scale == 0: + down_block_res_samples, mid_block_res_sample = controlnet( + x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, class_labels=modality_tensor + ) + predicted_velocity = diffusion_unet( + x=latents, + timesteps=timesteps, + spacing_tensor=spacing_tensor, + class_labels=modality_tensor, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + down_block_res_samples, mid_block_res_sample = controlnet( + x=torch.cat([latents] * 2), + timesteps=torch.cat([timesteps] * 2), + controlnet_cond=torch.cat([controlnet_cond_vis] * 2), + class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), + ) + model_t, model_uncond = diffusion_unet( + x=torch.cat([latents] * 2), + timesteps=timesteps, + spacing_tensor=torch.cat([timesteps] * 2), + class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).chunk(2) + predicted_velocity = model_uncond + guidance_scale * (model_t - model_uncond) + latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep=next_t) end_time = time.time() logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----") - del noise_pred + del predicted_velocity torch.cuda.empty_cache() # decode latents to synthesized images logging.info("---- Start decoding latent features into images... ----") + inferer = SlidingWindowInferer( + roi_size=autoencoder_sliding_window_infer_size, + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + device=torch.device("cpu"), + sw_device=device, + ) start_time = time.time() - if math.prod(latent_shape[1:]) < math.prod(autoencoder_sliding_window_infer_size): - synthetic_images = recon_model(latents) - else: - synthetic_images = sliding_window_inference( - inputs=latents, - roi_size=( - min(output_size[0] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[0]), - min(output_size[1] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[1]), - min(output_size[2] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[2]), - ), - sw_batch_size=1, - predictor=recon_model, - mode="gaussian", - overlap=autoencoder_sliding_window_infer_overlap, - sw_device=device, - device=torch.device("cpu"), - progress=True, - ) + synthetic_images = dynamic_infer(inferer, recon_model, latents) synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu() end_time = time.time() logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----") @@ -372,13 +388,19 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." ) - if output_size[0] * spacing[0] < 256: + if ( + output_size[0] * spacing[0] < 256 + or output_size[2] * spacing[2] < 128 + or output_size[0] * spacing[0] > 640 + or output_size[2] * spacing[2] > 2000 + ): fov = [output_size[axis] * spacing[axis] for axis in range(3)] raise ValueError( ( f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). " f"The FOV will be {fov}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least " - "384mm for other body regions like abdomen. There is no such restriction for z-axis." + "384mm for other body regions like abdomen, and less than 640mm. " + "For z-axis, we require it to be at least 128mm and less than 2000mm." ) ) @@ -436,10 +458,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing ) else: logging.info( - ( - "`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: " - f"({body_region}) and `anatomy_list`: ({anatomy_list})." - ) + (f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list}).") ) # check body_region format available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] @@ -474,6 +493,7 @@ def __init__( self, body_region, anatomy_list, + modality, all_mask_files_json, all_anatomy_size_condtions_json, all_mask_files_base_dir, @@ -510,6 +530,7 @@ def __init__( Args: Various parameters related to model configuration, input settings, and output specifications. """ + self.random_seed = random_seed if random_seed is not None: set_determinism(seed=random_seed) @@ -520,6 +541,7 @@ def __init__( # intialize variables self.body_region = body_region self.anatomy_list = [label_dict[organ] for organ in anatomy_list] + self.modality_int = modality_mapping[modality] self.all_mask_files_json = all_mask_files_json self.data_root = all_mask_files_base_dir self.label_dict_remap_json = label_dict_remap_json @@ -566,7 +588,7 @@ def __init__( self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap # quality check args - self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times + self.max_try_time = 3 # if not pass quality check, will try self.max_try_time times with open(real_img_median_statistics, "r") as json_file: self.median_statistics = json.load(json_file) self.label_int_dict = { @@ -599,11 +621,7 @@ def __init__( monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]), monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"), monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8), - monai.transforms.Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)), - monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)), monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), - monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2), - monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2), monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), ] ) @@ -627,13 +645,7 @@ def sample_multiple_images(self, num_img): need_resample = False # find candidate mask and save to candidate_mask_files candidate_mask_files = find_masks( - self.body_region, - self.anatomy_list, - self.spacing, - self.output_size, - True, - self.all_mask_files_json, - self.data_root, + self.anatomy_list, self.spacing, self.output_size, True, self.all_mask_files_json, self.data_root ) if len(candidate_mask_files) < num_img: # if we cannot find enough masks based on the exact match of anatomy list, spacing, and output size, @@ -643,82 +655,81 @@ def sample_multiple_images(self, num_img): need_resample = True selected_mask_files = self.select_mask(candidate_mask_files, num_img) - logging.info(f"Images will be generated based on {selected_mask_files}.") - if len(selected_mask_files) != num_img: + if len(selected_mask_files) < num_img: raise ValueError( ( - f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). " + f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). " "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." ) ) - for item in selected_mask_files: + num_generated_img = 0 + for index_s in range(len(selected_mask_files)): + item = selected_mask_files[index_s] + if num_generated_img >= num_img: + break logging.info("---- Start preparing masks... ----") start_time = time.time() + logging.info(f"Image will be generated based on {item}.") if len(self.controllable_anatomy_size) > 0: # generate a synthetic mask - (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( - self.prepare_one_mask_and_meta_info(anatomy_size_condtion) - ) + (combine_label_or, spacing_tensor) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion) else: # read in mask file mask_file = item["mask_file"] if_aug = item["if_aug"] - (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( - self.read_mask_information(mask_file) - ) + (combine_label_or, spacing_tensor) = self.read_mask_information(mask_file) if need_resample: combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) # mask augmentation if if_aug: - combine_label_or = augmentation(combine_label_or, self.output_size) + combine_label_or = augmentation(combine_label_or, self.output_size, random_seed=self.random_seed) end_time = time.time() logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----") torch.cuda.empty_cache() # generate image/label pairs - to_generate = True - try_time = 0 - while to_generate: - synthetic_images, synthetic_labels = self.sample_one_pair( - combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor - ) - # synthetic image quality check - pass_quality_check = self.quality_check( - synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() - ) - if pass_quality_check or try_time > self.max_try_time: - # save image/label pairs - output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" - synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) - img_saver = SaveImage( - output_dir=self.output_dir, - output_postfix=output_postfix + "_image", - output_ext=self.image_output_ext, - separate_folder=False, - ) - img_saver(synthetic_images[0]) - synthetic_images_filename = os.path.join( - self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext - ) - # filter out the organs that are not in anatomy_list - synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) - label_saver = SaveImage( - output_dir=self.output_dir, - output_postfix=output_postfix + "_label", - output_ext=self.label_output_ext, - separate_folder=False, - ) - label_saver(synthetic_labels[0]) - synthetic_labels_filename = os.path.join( - self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext - ) - output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) - to_generate = False - else: + modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int + # start generation + synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor) + # synthetic image quality check + pass_quality_check = self.quality_check( + synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() + ) + if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s): + if not pass_quality_check: logging.info( - "Generated image/label pair did not pass quality check, will re-generate another pair." + "Generated image/label pair did not pass quality check, but will still save them. " + "Please consider changing spacing and output_size to facilitate a more realistic setting." ) - try_time += 1 + num_generated_img = num_generated_img + 1 + # save image/label pairs + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" + synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) + img_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_image", + output_ext=self.image_output_ext, + separate_folder=False, + ) + img_saver(synthetic_images[0]) + synthetic_images_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext + ) + # filter out the organs that are not in anatomy_list + synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) + label_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_label", + output_ext=self.label_output_ext, + separate_folder=False, + ) + label_saver(synthetic_labels[0]) + synthetic_labels_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext + ) + output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) + else: + logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.") return output_filenames def select_mask(self, candidate_mask_files, num_img): @@ -735,21 +746,18 @@ def select_mask(self, candidate_mask_files, num_img): selected_mask_files = [] random.shuffle(candidate_mask_files) - for n in range(num_img): + for n in range(num_img * self.max_try_time): mask_file = candidate_mask_files[n % len(candidate_mask_files)] selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) return selected_mask_files - def sample_one_pair( - self, combine_label_or_aug, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor - ): + def sample_one_pair(self, combine_label_or_aug, modality_tensor, spacing_tensor): """ Generate a single pair of synthetic image and mask. Args: combine_label_or_aug (torch.Tensor): Combined label tensor or augmented label. - top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. - bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. + modality_tensor (torch.Tensor): Tensor specifying the image modality. spacing_tensor (torch.Tensor): Tensor specifying the spacing. Returns: @@ -764,8 +772,7 @@ def sample_one_pair( scale_factor=self.scale_factor, device=self.device, combine_label_or=combine_label_or_aug, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, + modality_tensor=modality_tensor, spacing_tensor=spacing_tensor, latent_shape=self.latent_shape, output_size=self.output_size, @@ -846,13 +853,9 @@ def prepare_one_mask_and_meta_info(self, anatomy_size_condtion): combine_label_or = MetaTensor(combine_label_or, affine=affine) combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) - top_region_index, bottom_region_index = get_body_region_index_from_mask(combine_label_or) - spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2 - top_region_index_tensor = torch.FloatTensor(top_region_index).unsqueeze(0).half().to(self.device) * 1e2 - bottom_region_index_tensor = torch.FloatTensor(bottom_region_index).unsqueeze(0).half().to(self.device) * 1e2 - return combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + return combine_label_or, spacing_tensor def sample_one_mask(self, anatomy_size): """ @@ -940,15 +943,10 @@ def read_mask_information(self, mask_file): """ val_data = self.val_transforms(mask_file) - for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]: + for key in ["pseudo_label", "spacing"]: val_data[key] = val_data[key].unsqueeze(0).to(self.device) - return ( - val_data["pseudo_label"], - val_data["top_region_index"], - val_data["bottom_region_index"], - val_data["spacing"], - ) + return (val_data["pseudo_label"], val_data["spacing"]) def find_closest_masks(self, num_img): """ @@ -965,30 +963,37 @@ def find_closest_masks(self, num_img): """ # first check the database based on anatomy list candidates = find_masks( - self.body_region, - self.anatomy_list, - self.spacing, - self.output_size, - False, - self.all_mask_files_json, - self.data_root, + self.anatomy_list, self.spacing, self.output_size, False, self.all_mask_files_json, self.data_root ) if len(candidates) < num_img: raise ValueError(f"candidate masks are less than {num_img}).") + # loop through the database and find closest combinations new_candidates = [] for c in candidates: diff = 0 + include_c = True for axis in range(3): + if abs(c["dim"][axis]) < self.output_size[axis] - 64: + # we cannot upsample the mask too much + include_c = False + break + # check diff in FOV, major metric + diff += abs( + (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10 + ) # check diff in dim - diff += abs((c["dim"][axis] - self.output_size[axis]) / 100) + diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100) # check diff in spacing - diff += abs(c["spacing"][axis] - self.spacing[axis]) - new_candidates.append((c, diff)) + diff += abs(abs(c["spacing"][axis]) - self.spacing[axis]) + if include_c: + new_candidates.append((c, diff)) + # choose top-2*num_img candidates (at least 5) new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)] final_candidates = [] + # check top-2*num_img candidates and update spacing after resampling image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True) for c, _ in new_candidates: @@ -1001,9 +1006,6 @@ def find_closest_masks(self, num_img): else: raise e # get region_index after resample - top_region_index, bottom_region_index = get_body_region_index_from_mask(label) - c["top_region_index"] = top_region_index - c["bottom_region_index"] = bottom_region_index c["spacing"] = self.spacing c["dim"] = self.output_size diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index 0cd46590..43cc62d7 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -680,4 +680,17 @@ def dynamic_infer(inferer, model, images): if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size): return model(images) else: - return inferer(network=model, inputs=images) + # Extract the spatial dimensions from the images tensor (H, W, D) + spatial_dims = images.shape[2:] + orig_roi = inferer.roi_size + + # Check that roi has the same number of dimensions as spatial_dims + if len(orig_roi) != len(spatial_dims): + raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).") + + # Iterate and adjust each ROI dimension + adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)] + inferer.roi_size = adjusted_roi + output = inferer(network=model, inputs=images) + inferer.roi_size = orig_roi + return output