From 74ebbb5aa1feb290412eb552fd030f99e7642b69 Mon Sep 17 00:00:00 2001 From: Can Zhao Date: Wed, 12 Feb 2025 06:30:18 +0000 Subject: [PATCH 01/50] New version, Accelerated maisi Signed-off-by: Can Zhao --- .../configs/inference.json | 38 +++-- .../scripts/rectified_flow.py | 149 ++++++++++++++++++ models/maisi_ct_generative/scripts/sample.py | 101 ++++++------ models/maisi_ct_generative/scripts/utils.py | 17 +- 4 files changed, 234 insertions(+), 71 deletions(-) create mode 100644 models/maisi_ct_generative/scripts/rectified_flow.py diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 283305ed..c60b4945 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -10,8 +10,8 @@ "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", - "trained_diffusion_path": "$@model_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'", + "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19200.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet_current.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", @@ -27,8 +27,9 @@ "anatomy_list": [ "liver" ], + "modality": "ct", "controllable_anatomy_size": [], - "num_inference_steps": 1000, + "num_inference_steps": 30, "mask_generation_num_inference_steps": 1000, "random_seed": null, "spatial_dims": 3, @@ -67,7 +68,7 @@ 96, 96 ], - "autoencoder_sliding_window_infer_overlap": 0.6667, + "autoencoder_sliding_window_infer_overlap": 0.3333, "autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", "spatial_dims": "@spatial_dims", @@ -124,9 +125,12 @@ ], "num_res_blocks": 2, "use_flash_attention": true, - "include_top_region_index_input": true, - "include_bottom_region_index_input": true, - "include_spacing_input": true + "include_top_region_index_input": false, + "include_bottom_region_index_input": false, + "include_spacing_input": true, + "num_class_embeds": 128, + "resblock_updown": true, + "include_fc": true }, "controlnet_def": { "_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi", @@ -153,11 +157,10 @@ "num_res_blocks": 2, "use_flash_attention": true, "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [ - 8, - 32, - 64 - ] + "conditioning_embedding_num_channels": [8, 32, 64], + "num_class_embeds": 128, + "resblock_updown": true, + "include_fc": true }, "mask_generation_autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", @@ -239,12 +242,12 @@ "load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)", "mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']", "noise_scheduler": { - "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", + "_target_": "scripts.rectified_flow.RFlowScheduler", "num_train_timesteps": 1000, - "beta_start": 0.0015, - "beta_end": 0.0195, - "schedule": "scaled_linear_beta", - "clip_sample": false + "use_discrete_timesteps": false, + "use_timestep_transform": true, + "sample_method": "logit-normal", + "scale":1.2 }, "mask_generation_noise_scheduler": { "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", @@ -269,6 +272,7 @@ ], "body_region": "@body_region", "anatomy_list": "@anatomy_list", + "modality": "@modality", "all_mask_files_json": "@all_mask_files_json", "all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json", "all_mask_files_base_dir": "@all_mask_files_base_dir", diff --git a/models/maisi_ct_generative/scripts/rectified_flow.py b/models/maisi_ct_generative/scripts/rectified_flow.py new file mode 100644 index 00000000..0ffcae3d --- /dev/null +++ b/models/maisi_ct_generative/scripts/rectified_flow.py @@ -0,0 +1,149 @@ +import numpy as np +import torch +from torch.distributions import LogisticNormal +from monai.networks.schedulers import Scheduler +from typing import Any + +# code modified from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py + + +def timestep_transform( + t, + input_img_size, + base_img_size=32*32*32, + scale=1.0, + num_train_timesteps=1000, + spatial_dim = 3, +): + t = t / num_train_timesteps + resolution = input_img_size + ratio_space = (input_img_size / base_img_size).pow(1./spatial_dim) + + ratio = ratio_space * scale + new_t = ratio * t / (1 + (ratio - 1) * t) + + new_t = new_t * num_train_timesteps + return new_t + + +class RFlowScheduler(Scheduler): + def __init__( + self, + num_train_timesteps=1000, + num_inference_steps=10, + use_discrete_timesteps=False, + sample_method="uniform", + loc=0.0, + scale=1.0, + use_timestep_transform=False, + transform_scale=1.0, + steps_offset: int = 0, + ): + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.use_discrete_timesteps = use_discrete_timesteps + + # sample method + assert sample_method in ["uniform", "logit-normal"] + # assert ( + # sample_method == "uniform" or not use_discrete_timesteps + # ), "Only uniform sampling is supported for discrete timesteps" + self.sample_method = sample_method + if sample_method == "logit-normal": + self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale])) + self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device) + + # timestep transform + self.use_timestep_transform = use_timestep_transform + self.transform_scale = transform_scale + self.steps_offset = steps_offset + + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + """ + compatible with diffusers add_noise() + """ + timepoints = timesteps.float() / self.num_train_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + + return timepoints * original_samples + (1 - timepoints) * noise + + def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, input_img_size: int |None = None, base_img_size: int = 32*32*32) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + input_img_size: int, H*W*D of the image, used with self.use_timestep_transform is True. + base_img_size: int, reference H*W*D size, used with self.use_timestep_transform is True. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + # prepare timesteps + timesteps = [(1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + if self.use_timestep_transform: + timesteps = [timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) for t in timesteps] + timesteps = np.array(timesteps).astype(np.float16) + if self.use_discrete_timesteps: + timesteps = timesteps.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + print(self.timesteps) + + def sample_timesteps(self, x_start): + if self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_train_timesteps + + if self.use_discrete_timesteps: + t = t.long() + + if self.use_timestep_transform: + input_img_size = torch.prod(torch.tensor(x_start.shape[-3:])) + base_img_size = 32*32*32 + t = timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) + + return t + + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep = None) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep. Core function to propagate the diffusion + process from the learned model outputs. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + Returns: + pred_prev_sample: Predicted previous sample + None + """ + v_pred = model_output + if next_timestep is None: + dt = 1.0 / self.num_inference_steps + else: + dt = timestep - next_timestep + dt = dt / self.num_train_timesteps + z = sample + v_pred * dt + + return z, None diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index d161597f..25984a18 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -21,6 +21,7 @@ import torch from monai.data import MetaTensor from monai.inferers import sliding_window_inference +from monai.inferers.inferer import SlidingWindowInferer from monai.inferers.inferer import DiffusionInferer from monai.transforms import Compose, SaveImage from monai.utils import set_determinism @@ -29,8 +30,24 @@ from .augmentation import augmentation from .find_masks import find_masks from .quality_check import is_outlier -from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels - +from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels, dynamic_infer + + +modality_mapping = { + "unknown":0, + "ct":1, + "ct_wo_contrast":2, + "ct_contrast":3, + "mri":8, + "mri_t1":9, + "mri_t2":10, + "mri_flair":11, + "mri_pd":12, + "mri_dwi":13, + "mri_adc":14, + "mri_ssfp":15, + "mri_mra":16 +} # current version only support "ct" class ReconModel(torch.nn.Module): """ @@ -123,28 +140,16 @@ def ldm_conditional_sample_one_mask( conditioning=anatomy_size.to(device), ) # decode latents to synthesized masks - if math.prod(latent_shape[1:]) < math.prod(autoencoder_sliding_window_infer_size): - synthetic_mask = recon_model(latents).cpu().detach() - else: - synthetic_mask = ( - sliding_window_inference( - inputs=latents, - roi_size=( - autoencoder_sliding_window_infer_size[0], - autoencoder_sliding_window_infer_size[1], - autoencoder_sliding_window_infer_size[2], - ), - sw_batch_size=1, - predictor=recon_model, - mode="gaussian", - overlap=autoencoder_sliding_window_infer_overlap, - sw_device=device, - device=torch.device("cpu"), - progress=True, - ) - .cpu() - .detach() - ) + inferer = SlidingWindowInferer( + roi_size= autoencoder_sliding_window_infer_size, + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + device=torch.device("cpu"), + sw_device=device + ) + synthetic_mask = dynamic_infer(inferer, recon_model, latents) synthetic_mask = torch.softmax(synthetic_mask, dim=1) synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) # mapping raw index to 132 labels @@ -174,8 +179,7 @@ def ldm_conditional_sample_one_image( scale_factor, device, combine_label_or, - top_region_index_tensor, - bottom_region_index_tensor, + modality_tensor, spacing_tensor, latent_shape, output_size, @@ -239,19 +243,18 @@ def ldm_conditional_sample_one_image( latents = initialize_noise_latents(latent_shape, device) * noise_factor # synthesize latents - noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps, input_img_size = torch.prod(torch.tensor(latent_shape[-3:])) ) for t in tqdm(noise_scheduler.timesteps, ncols=110): # Get controlnet output down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis + x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis, class_labels = modality_tensor, ) latent_model_input = latents noise_pred = diffusion_unet( x=latent_model_input, timesteps=torch.Tensor((t,)).to(device), - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, spacing_tensor=spacing_tensor, + class_labels = modality_tensor, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ) @@ -263,25 +266,17 @@ def ldm_conditional_sample_one_image( # decode latents to synthesized images logging.info("---- Start decoding latent features into images... ----") + inferer = SlidingWindowInferer( + roi_size= autoencoder_sliding_window_infer_size, + sw_batch_size=1, + progress=True, + mode="gaussian", + overlap=autoencoder_sliding_window_infer_overlap, + device=torch.device("cpu"), + sw_device=device + ) start_time = time.time() - if math.prod(latent_shape[1:]) < math.prod(autoencoder_sliding_window_infer_size): - synthetic_images = recon_model(latents) - else: - synthetic_images = sliding_window_inference( - inputs=latents, - roi_size=( - min(output_size[0] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[0]), - min(output_size[1] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[1]), - min(output_size[2] // 4 // 4 * 3, autoencoder_sliding_window_infer_size[2]), - ), - sw_batch_size=1, - predictor=recon_model, - mode="gaussian", - overlap=autoencoder_sliding_window_infer_overlap, - sw_device=device, - device=torch.device("cpu"), - progress=True, - ) + synthetic_images = dynamic_infer(inferer, recon_model, latents) synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu() end_time = time.time() logging.info(f"---- Image decoding time: {end_time - start_time} seconds ----") @@ -474,6 +469,7 @@ def __init__( self, body_region, anatomy_list, + modality, all_mask_files_json, all_anatomy_size_condtions_json, all_mask_files_base_dir, @@ -520,6 +516,7 @@ def __init__( # intialize variables self.body_region = body_region self.anatomy_list = [label_dict[organ] for organ in anatomy_list] + self.modality_int = modality_mapping[modality] self.all_mask_files_json = all_mask_files_json self.data_root = all_mask_files_base_dir self.label_dict_remap_json = label_dict_remap_json @@ -677,9 +674,10 @@ def sample_multiple_images(self, num_img): # generate image/label pairs to_generate = True try_time = 0 + modality_tensor = torch.ones_like(spacing_tensor[:,0]).long()*self.modality_int while to_generate: synthetic_images, synthetic_labels = self.sample_one_pair( - combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + combine_label_or, modality_tensor, spacing_tensor ) # synthetic image quality check pass_quality_check = self.quality_check( @@ -741,7 +739,7 @@ def select_mask(self, candidate_mask_files, num_img): return selected_mask_files def sample_one_pair( - self, combine_label_or_aug, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + self, combine_label_or_aug, modality_tensor, spacing_tensor ): """ Generate a single pair of synthetic image and mask. @@ -764,8 +762,7 @@ def sample_one_pair( scale_factor=self.scale_factor, device=self.device, combine_label_or=combine_label_or_aug, - top_region_index_tensor=top_region_index_tensor, - bottom_region_index_tensor=bottom_region_index_tensor, + modality_tensor = modality_tensor, spacing_tensor=spacing_tensor, latent_shape=self.latent_shape, output_size=self.output_size, diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index 0cd46590..f7a05cbb 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -661,7 +661,6 @@ def __call__(self, img: NdarrayOrTensor): out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype) return out - def dynamic_infer(inferer, model, images): """ Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer. @@ -680,4 +679,18 @@ def dynamic_infer(inferer, model, images): if torch.numel(images[0:1, 0:1, ...]) < math.prod(inferer.roi_size): return model(images) else: - return inferer(network=model, inputs=images) + # Extract the spatial dimensions from the images tensor (H, W, D) + spatial_dims = images.shape[2:] + orig_roi = inferer.roi_size + + # Check that roi has the same number of dimensions as spatial_dims + if len(orig_roi) != len(spatial_dims): + raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).") + + # Iterate and adjust each ROI dimension + adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)] + inferer.roi_size = adjusted_roi + output = inferer(network=model, inputs=images) + inferer.roi_size = orig_roi + return output + From baa3583955183a906424b488d2daa3534b7182b7 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 06:39:24 +0000 Subject: [PATCH 02/50] update meta Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/metadata.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index 010a70d1..40bacb52 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", - "version": "0.4.6", + "version": "1.0.0", "changelog": { + "1.0.0": "accelerated maisi, inference only, is not compartible with previous maisi diffusion model weights" "0.4.6": "add TensorRT support", "0.4.5": "update README", "0.4.4": "update issue for IgniteInfo", From 6b806a124eff56dba50c592c73a4b0170d57f533 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 06:45:17 +0000 Subject: [PATCH 03/50] update meta Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/metadata.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json index 40bacb52..f46c2307 100644 --- a/models/maisi_ct_generative/configs/metadata.json +++ b/models/maisi_ct_generative/configs/metadata.json @@ -2,7 +2,7 @@ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", "version": "1.0.0", "changelog": { - "1.0.0": "accelerated maisi, inference only, is not compartible with previous maisi diffusion model weights" + "1.0.0": "accelerated maisi, inference only, is not compartible with previous maisi diffusion model weights", "0.4.6": "add TensorRT support", "0.4.5": "update README", "0.4.4": "update issue for IgniteInfo", From 9b5b1dd98a24fe693c6c3e529dc358326bd38497 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 06:52:38 +0000 Subject: [PATCH 04/50] update sample Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 48 ++++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 25984a18..c95710a9 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -244,21 +244,39 @@ def ldm_conditional_sample_one_image( # synthesize latents noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps, input_img_size = torch.prod(torch.tensor(latent_shape[-3:])) ) - for t in tqdm(noise_scheduler.timesteps, ncols=110): - # Get controlnet output - down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis, class_labels = modality_tensor, - ) - latent_model_input = latents - noise_pred = diffusion_unet( - x=latent_model_input, - timesteps=torch.Tensor((t,)).to(device), - spacing_tensor=spacing_tensor, - class_labels = modality_tensor, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - latents, _ = noise_scheduler.step(noise_pred, t, latents) + # synthesize latents + guidance_scale = 0 # API for classifier-free guidence, not used in this version + all_next_timesteps = torch.cat((noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))) + for t, next_t in tqdm(zip(noise_scheduler.timesteps, all_next_timesteps), total=min(len(noise_scheduler.timesteps), len(all_next_timesteps))): + timesteps = torch.Tensor((t,)).to(device) + if guidance_scale == 0: + down_block_res_samples, mid_block_res_sample = controlnet( + x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, + class_labels = modality_tensor, + ) + predicted_velocity = diffusion_unet( + x=latents, + timesteps=timesteps, + spacing_tensor=spacing_tensor, + class_labels = modality_tensor, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + else: + down_block_res_samples, mid_block_res_sample = controlnet( + x=torch.cat([latents] * 2), timesteps=torch.cat([timesteps] * 2), controlnet_cond=torch.cat([controlnet_cond_vis] * 2), + class_labels = torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), + ) + model_t, model_uncond = diffusion_unet( + x=torch.cat([latents] * 2), + timesteps=timesteps, + spacing_tensor=torch.cat([timesteps] * 2), + class_labels = torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).chunk(2) + predicted_velocity = model_uncond + guidance_scale * (model_t - model_uncond) + latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep= next_t) end_time = time.time() logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----") del noise_pred From 7358801b2560550da4b222246ed401497a4ce807 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 06:53:16 +0000 Subject: [PATCH 05/50] update sample Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index c95710a9..7a2fc2c6 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -279,7 +279,7 @@ def ldm_conditional_sample_one_image( latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep= next_t) end_time = time.time() logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----") - del noise_pred + del predicted_velocity torch.cuda.empty_cache() # decode latents to synthesized images From e8ef29fad216e761829d63cc51e870a82b202d31 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 06:58:44 +0000 Subject: [PATCH 06/50] update sample Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 7a2fc2c6..7325bf49 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -221,7 +221,7 @@ def ldm_conditional_sample_one_image( recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) - with torch.no_grad(), torch.amp.autocast("cuda"): + with torch.no_grad(), torch.amp.autocast("cuda", enabled=True): logging.info("---- Start generating latent features... ----") start_time = time.time() # generate segmentation mask @@ -581,7 +581,7 @@ def __init__( self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap # quality check args - self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times + self.max_try_time = 1 # if not pass quality check, will try self.max_try_time times with open(real_img_median_statistics, "r") as json_file: self.median_statistics = json.load(json_file) self.label_int_dict = { From 7099ceb50eb6eb25194f238c59fa9869a86f4254 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Feb 2025 06:59:13 +0000 Subject: [PATCH 07/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/maisi_ct_generative/configs/inference.json | 8 ++++++-- .../maisi_ct_generative/scripts/rectified_flow.py | 14 +++++++------- models/maisi_ct_generative/scripts/sample.py | 6 +++--- models/maisi_ct_generative/scripts/utils.py | 5 ++--- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index c60b4945..e85da210 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -157,7 +157,11 @@ "num_res_blocks": 2, "use_flash_attention": true, "conditioning_embedding_in_channels": 8, - "conditioning_embedding_num_channels": [8, 32, 64], + "conditioning_embedding_num_channels": [ + 8, + 32, + 64 + ], "num_class_embeds": 128, "resblock_updown": true, "include_fc": true @@ -247,7 +251,7 @@ "use_discrete_timesteps": false, "use_timestep_transform": true, "sample_method": "logit-normal", - "scale":1.2 + "scale": 1.2 }, "mask_generation_noise_scheduler": { "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", diff --git a/models/maisi_ct_generative/scripts/rectified_flow.py b/models/maisi_ct_generative/scripts/rectified_flow.py index 0ffcae3d..0c7918bb 100644 --- a/models/maisi_ct_generative/scripts/rectified_flow.py +++ b/models/maisi_ct_generative/scripts/rectified_flow.py @@ -95,16 +95,16 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N f" maximal {self.num_train_timesteps} timesteps." ) - self.num_inference_steps = num_inference_steps + self.num_inference_steps = num_inference_steps # prepare timesteps - timesteps = [(1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)] + timesteps = [(1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)] if self.use_discrete_timesteps: timesteps = [int(round(t)) for t in timesteps] if self.use_timestep_transform: timesteps = [timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) for t in timesteps] - timesteps = np.array(timesteps).astype(np.float16) + timesteps = np.array(timesteps).astype(np.float16) if self.use_discrete_timesteps: - timesteps = timesteps.astype(np.int64) + timesteps = timesteps.astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.steps_offset print(self.timesteps) @@ -119,12 +119,12 @@ def sample_timesteps(self, x_start): t = t.long() if self.use_timestep_transform: - input_img_size = torch.prod(torch.tensor(x_start.shape[-3:])) + input_img_size = torch.prod(torch.tensor(x_start.shape[-3:])) base_img_size = 32*32*32 t = timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) - + return t - + def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep = None) -> tuple[torch.Tensor, Any]: """ Predict the sample at the previous timestep. Core function to propagate the diffusion diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 7325bf49..78fb8047 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -38,7 +38,7 @@ "ct":1, "ct_wo_contrast":2, "ct_contrast":3, - "mri":8, + "mri":8, "mri_t1":9, "mri_t2":10, "mri_flair":11, @@ -248,7 +248,7 @@ def ldm_conditional_sample_one_image( guidance_scale = 0 # API for classifier-free guidence, not used in this version all_next_timesteps = torch.cat((noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))) for t, next_t in tqdm(zip(noise_scheduler.timesteps, all_next_timesteps), total=min(len(noise_scheduler.timesteps), len(all_next_timesteps))): - timesteps = torch.Tensor((t,)).to(device) + timesteps = torch.Tensor((t,)).to(device) if guidance_scale == 0: down_block_res_samples, mid_block_res_sample = controlnet( x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, @@ -261,7 +261,7 @@ def ldm_conditional_sample_one_image( class_labels = modality_tensor, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, - ) + ) else: down_block_res_samples, mid_block_res_sample = controlnet( x=torch.cat([latents] * 2), timesteps=torch.cat([timesteps] * 2), controlnet_cond=torch.cat([controlnet_cond_vis] * 2), diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index f7a05cbb..ee5483fd 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -682,15 +682,14 @@ def dynamic_infer(inferer, model, images): # Extract the spatial dimensions from the images tensor (H, W, D) spatial_dims = images.shape[2:] orig_roi = inferer.roi_size - + # Check that roi has the same number of dimensions as spatial_dims if len(orig_roi) != len(spatial_dims): raise ValueError(f"ROI length ({len(orig_roi)}) does not match spatial dimensions ({len(spatial_dims)}).") - + # Iterate and adjust each ROI dimension adjusted_roi = [min(roi_dim, img_dim) for roi_dim, img_dim in zip(orig_roi, spatial_dims)] inferer.roi_size = adjusted_roi output = inferer(network=model, inputs=images) inferer.roi_size = orig_roi return output - From dbe8612382d3aaa29f3cf073fd2141f2c9de3474 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 07:04:23 +0000 Subject: [PATCH 08/50] update config Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index c60b4945..a1cd0e1b 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -11,7 +11,7 @@ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19200.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet_current.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet_epoch87.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", @@ -64,11 +64,11 @@ 64 ], "autoencoder_sliding_window_infer_size": [ - 96, - 96, - 96 + 80, + 80, + 80 ], - "autoencoder_sliding_window_infer_overlap": 0.3333, + "autoencoder_sliding_window_infer_overlap": 0.25, "autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", "spatial_dims": "@spatial_dims", From 8064b5430bc18ce716e01862daa18a7269a43e7a Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 07:13:54 +0000 Subject: [PATCH 09/50] reformat Signed-off-by: Can-Zhao --- .../scripts/rectified_flow.py | 47 +++++++--- models/maisi_ct_generative/scripts/sample.py | 90 +++++++++++-------- 2 files changed, 87 insertions(+), 50 deletions(-) diff --git a/models/maisi_ct_generative/scripts/rectified_flow.py b/models/maisi_ct_generative/scripts/rectified_flow.py index 0c7918bb..c2a1485f 100644 --- a/models/maisi_ct_generative/scripts/rectified_flow.py +++ b/models/maisi_ct_generative/scripts/rectified_flow.py @@ -1,8 +1,9 @@ +from typing import Any + import numpy as np import torch -from torch.distributions import LogisticNormal from monai.networks.schedulers import Scheduler -from typing import Any +from torch.distributions import LogisticNormal # code modified from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py @@ -10,14 +11,14 @@ def timestep_transform( t, input_img_size, - base_img_size=32*32*32, + base_img_size=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, - spatial_dim = 3, + spatial_dim=3, ): t = t / num_train_timesteps resolution = input_img_size - ratio_space = (input_img_size / base_img_size).pow(1./spatial_dim) + ratio_space = (input_img_size / base_img_size).pow(1.0 / spatial_dim) ratio = ratio_space * scale new_t = ratio * t / (1 + (ratio - 1) * t) @@ -58,7 +59,6 @@ def __init__( self.transform_scale = transform_scale self.steps_offset = steps_offset - def add_noise( self, original_samples: torch.FloatTensor, @@ -78,7 +78,13 @@ def add_noise( return timepoints * original_samples + (1 - timepoints) * noise - def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None, input_img_size: int |None = None, base_img_size: int = 32*32*32) -> None: + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device | None = None, + input_img_size: int | None = None, + base_img_size: int = 32 * 32 * 32, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -97,11 +103,21 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N self.num_inference_steps = num_inference_steps # prepare timesteps - timesteps = [(1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps)] + timesteps = [ + (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps) + ] if self.use_discrete_timesteps: timesteps = [int(round(t)) for t in timesteps] if self.use_timestep_transform: - timesteps = [timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) for t in timesteps] + timesteps = [ + timestep_transform( + t, + input_img_size=input_img_size, + base_img_size=base_img_size, + num_train_timesteps=self.num_train_timesteps, + ) + for t in timesteps + ] timesteps = np.array(timesteps).astype(np.float16) if self.use_discrete_timesteps: timesteps = timesteps.astype(np.int64) @@ -120,12 +136,19 @@ def sample_timesteps(self, x_start): if self.use_timestep_transform: input_img_size = torch.prod(torch.tensor(x_start.shape[-3:])) - base_img_size = 32*32*32 - t = timestep_transform(t, input_img_size=input_img_size, base_img_size=base_img_size, num_train_timesteps=self.num_train_timesteps) + base_img_size = 32 * 32 * 32 + t = timestep_transform( + t, + input_img_size=input_img_size, + base_img_size=base_img_size, + num_train_timesteps=self.num_train_timesteps, + ) return t - def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep = None) -> tuple[torch.Tensor, Any]: + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None + ) -> tuple[torch.Tensor, Any]: """ Predict the sample at the previous timestep. Core function to propagate the diffusion process from the learned model outputs. diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 78fb8047..20a61727 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -21,8 +21,7 @@ import torch from monai.data import MetaTensor from monai.inferers import sliding_window_inference -from monai.inferers.inferer import SlidingWindowInferer -from monai.inferers.inferer import DiffusionInferer +from monai.inferers.inferer import DiffusionInferer, SlidingWindowInferer from monai.transforms import Compose, SaveImage from monai.utils import set_determinism from tqdm import tqdm @@ -30,24 +29,30 @@ from .augmentation import augmentation from .find_masks import find_masks from .quality_check import is_outlier -from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels, dynamic_infer - +from .utils import ( + binarize_labels, + dynamic_infer, + general_mask_generation_post_process, + get_body_region_index_from_mask, + remap_labels, +) modality_mapping = { - "unknown":0, - "ct":1, - "ct_wo_contrast":2, - "ct_contrast":3, - "mri":8, - "mri_t1":9, - "mri_t2":10, - "mri_flair":11, - "mri_pd":12, - "mri_dwi":13, - "mri_adc":14, - "mri_ssfp":15, - "mri_mra":16 -} # current version only support "ct" + "unknown": 0, + "ct": 1, + "ct_wo_contrast": 2, + "ct_contrast": 3, + "mri": 8, + "mri_t1": 9, + "mri_t2": 10, + "mri_flair": 11, + "mri_pd": 12, + "mri_dwi": 13, + "mri_adc": 14, + "mri_ssfp": 15, + "mri_mra": 16, +} # current version only support "ct" + class ReconModel(torch.nn.Module): """ @@ -141,13 +146,13 @@ def ldm_conditional_sample_one_mask( ) # decode latents to synthesized masks inferer = SlidingWindowInferer( - roi_size= autoencoder_sliding_window_infer_size, + roi_size=autoencoder_sliding_window_infer_size, sw_batch_size=1, progress=True, mode="gaussian", overlap=autoencoder_sliding_window_infer_overlap, device=torch.device("cpu"), - sw_device=device + sw_device=device, ) synthetic_mask = dynamic_infer(inferer, recon_model, latents) synthetic_mask = torch.softmax(synthetic_mask, dim=1) @@ -243,40 +248,51 @@ def ldm_conditional_sample_one_image( latents = initialize_noise_latents(latent_shape, device) * noise_factor # synthesize latents - noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps, input_img_size = torch.prod(torch.tensor(latent_shape[-3:])) ) + noise_scheduler.set_timesteps( + num_inference_steps=num_inference_steps, input_img_size=torch.prod(torch.tensor(latent_shape[-3:])) + ) # synthesize latents - guidance_scale = 0 # API for classifier-free guidence, not used in this version - all_next_timesteps = torch.cat((noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype))) - for t, next_t in tqdm(zip(noise_scheduler.timesteps, all_next_timesteps), total=min(len(noise_scheduler.timesteps), len(all_next_timesteps))): + guidance_scale = 0 # API for classifier-free guidence, not used in this version + all_next_timesteps = torch.cat( + (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype)) + ) + for t, next_t in tqdm( + zip(noise_scheduler.timesteps, all_next_timesteps), + total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)), + ): timesteps = torch.Tensor((t,)).to(device) if guidance_scale == 0: down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, - class_labels = modality_tensor, + x=latents, + timesteps=timesteps, + controlnet_cond=controlnet_cond_vis, + class_labels=modality_tensor, ) predicted_velocity = diffusion_unet( x=latents, timesteps=timesteps, spacing_tensor=spacing_tensor, - class_labels = modality_tensor, + class_labels=modality_tensor, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ) else: down_block_res_samples, mid_block_res_sample = controlnet( - x=torch.cat([latents] * 2), timesteps=torch.cat([timesteps] * 2), controlnet_cond=torch.cat([controlnet_cond_vis] * 2), - class_labels = torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), + x=torch.cat([latents] * 2), + timesteps=torch.cat([timesteps] * 2), + controlnet_cond=torch.cat([controlnet_cond_vis] * 2), + class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), ) model_t, model_uncond = diffusion_unet( x=torch.cat([latents] * 2), timesteps=timesteps, spacing_tensor=torch.cat([timesteps] * 2), - class_labels = torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), + class_labels=torch.cat([modality_tensor, torch.zeros_like(modality_tensor)]), down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ).chunk(2) predicted_velocity = model_uncond + guidance_scale * (model_t - model_uncond) - latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep= next_t) + latents, _ = noise_scheduler.step(predicted_velocity, t, latents, next_timestep=next_t) end_time = time.time() logging.info(f"---- Latent features generation time: {end_time - start_time} seconds ----") del predicted_velocity @@ -285,13 +301,13 @@ def ldm_conditional_sample_one_image( # decode latents to synthesized images logging.info("---- Start decoding latent features into images... ----") inferer = SlidingWindowInferer( - roi_size= autoencoder_sliding_window_infer_size, + roi_size=autoencoder_sliding_window_infer_size, sw_batch_size=1, progress=True, mode="gaussian", overlap=autoencoder_sliding_window_infer_overlap, device=torch.device("cpu"), - sw_device=device + sw_device=device, ) start_time = time.time() synthetic_images = dynamic_infer(inferer, recon_model, latents) @@ -692,7 +708,7 @@ def sample_multiple_images(self, num_img): # generate image/label pairs to_generate = True try_time = 0 - modality_tensor = torch.ones_like(spacing_tensor[:,0]).long()*self.modality_int + modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int while to_generate: synthetic_images, synthetic_labels = self.sample_one_pair( combine_label_or, modality_tensor, spacing_tensor @@ -756,9 +772,7 @@ def select_mask(self, candidate_mask_files, num_img): selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) return selected_mask_files - def sample_one_pair( - self, combine_label_or_aug, modality_tensor, spacing_tensor - ): + def sample_one_pair(self, combine_label_or_aug, modality_tensor, spacing_tensor): """ Generate a single pair of synthetic image and mask. @@ -780,7 +794,7 @@ def sample_one_pair( scale_factor=self.scale_factor, device=self.device, combine_label_or=combine_label_or_aug, - modality_tensor = modality_tensor, + modality_tensor=modality_tensor, spacing_tensor=spacing_tensor, latent_shape=self.latent_shape, output_size=self.output_size, From e8a79066964f2dc0c728a43797399c634661a233 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 07:18:45 +0000 Subject: [PATCH 10/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/rectified_flow.py | 12 ++---------- models/maisi_ct_generative/scripts/sample.py | 5 +---- models/maisi_ct_generative/scripts/utils.py | 1 + 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/models/maisi_ct_generative/scripts/rectified_flow.py b/models/maisi_ct_generative/scripts/rectified_flow.py index c2a1485f..d8cee94d 100644 --- a/models/maisi_ct_generative/scripts/rectified_flow.py +++ b/models/maisi_ct_generative/scripts/rectified_flow.py @@ -9,12 +9,7 @@ def timestep_transform( - t, - input_img_size, - base_img_size=32 * 32 * 32, - scale=1.0, - num_train_timesteps=1000, - spatial_dim=3, + t, input_img_size, base_img_size=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 ): t = t / num_train_timesteps resolution = input_img_size @@ -60,10 +55,7 @@ def __init__( self.steps_offset = steps_offset def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, + self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor ) -> torch.FloatTensor: """ compatible with diffusers add_noise() diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 20a61727..3051e5be 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -263,10 +263,7 @@ def ldm_conditional_sample_one_image( timesteps = torch.Tensor((t,)).to(device) if guidance_scale == 0: down_block_res_samples, mid_block_res_sample = controlnet( - x=latents, - timesteps=timesteps, - controlnet_cond=controlnet_cond_vis, - class_labels=modality_tensor, + x=latents, timesteps=timesteps, controlnet_cond=controlnet_cond_vis, class_labels=modality_tensor ) predicted_velocity = diffusion_unet( x=latents, diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py index ee5483fd..43cc62d7 100644 --- a/models/maisi_ct_generative/scripts/utils.py +++ b/models/maisi_ct_generative/scripts/utils.py @@ -661,6 +661,7 @@ def __call__(self, img: NdarrayOrTensor): out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype) return out + def dynamic_infer(inferer, model, images): """ Perform dynamic inference using a model and an inferer, typically a monai SlidingWindowInferer. From 96965aa5219b238237be5162bc7a11c9d0822c72 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 12 Feb 2025 07:25:59 +0000 Subject: [PATCH 11/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/rectified_flow.py | 1 - models/maisi_ct_generative/scripts/sample.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/models/maisi_ct_generative/scripts/rectified_flow.py b/models/maisi_ct_generative/scripts/rectified_flow.py index d8cee94d..6bdcb00a 100644 --- a/models/maisi_ct_generative/scripts/rectified_flow.py +++ b/models/maisi_ct_generative/scripts/rectified_flow.py @@ -12,7 +12,6 @@ def timestep_transform( t, input_img_size, base_img_size=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 ): t = t / num_train_timesteps - resolution = input_img_size ratio_space = (input_img_size / base_img_size).pow(1.0 / spatial_dim) ratio = ratio_space * scale diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 3051e5be..0bc8b4dc 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -11,7 +11,6 @@ import json import logging -import math import os import random import time @@ -20,7 +19,6 @@ import monai import torch from monai.data import MetaTensor -from monai.inferers import sliding_window_inference from monai.inferers.inferer import DiffusionInferer, SlidingWindowInferer from monai.transforms import Compose, SaveImage from monai.utils import set_determinism From 814d9921025a120a0f89009dbf5a985b535bfc45 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 04:42:51 +0000 Subject: [PATCH 12/50] resume original setting Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 0bc8b4dc..93f67bd2 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -592,7 +592,7 @@ def __init__( self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap # quality check args - self.max_try_time = 1 # if not pass quality check, will try self.max_try_time times + self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times with open(real_img_median_statistics, "r") as json_file: self.median_statistics = json.load(json_file) self.label_int_dict = { From 4e75b296d6f6a658109e4973db7b6c4a7d4daf92 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 04:43:14 +0000 Subject: [PATCH 13/50] fully trained checkpoints Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 9794f3c1..112bb0a8 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -10,8 +10,8 @@ "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", - "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19200.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet_epoch87.pt'", + "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19350.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet_epoch150.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", From abe510cb509f1bdb8aaf5ff439e4e68f821ec401 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 05:17:23 +0000 Subject: [PATCH 14/50] simplify find_mask Signed-off-by: Can-Zhao --- .../maisi_ct_generative/scripts/find_masks.py | 95 ++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index de626552..3e6ccc90 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -52,7 +52,7 @@ def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]: return body_region_indices -def find_masks( +def find_masks_original( body_region: str | Sequence[str], anatomy_list: int | Sequence[int], spacing: Sequence[float] | float = 1.0, @@ -153,3 +153,96 @@ def find_masks( raise ValueError("Cannot find body region with given anatomy list.") return candidate_masks + +def find_masks( + body_region: str | Sequence[str], + anatomy_list: int | Sequence[int], + spacing: Sequence[float] | float = 1.0, + output_size: Sequence[int] = (512, 512, 512), + check_spacing_and_output_size: bool = False, + database_filepath: str = "./configs/database.json", + mask_foldername: str = "./datasets/masks/", +): + """ + Find candidate masks that fullfills all the requirements. + They shoud contain all the anatomies in `anatomy_list`. + If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor free. + If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`. + Args: + anatomy_list: list of input anatomy. The found candidate mask will include these anatomies. + spacing: list of three floats, voxel spacing. If providing a single number, will use it for all the three dimensions. + output_size: list of three int, expected candidate mask spatial size. + check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size` + and voxel size of `spacing`. + database_filepath: path for the json file that stores the information of all the candidate masks. + mask_foldername: directory that saves all the candidate masks. + Return: + candidate_masks, list of dict, each dict contains information of one candidate mask that fullfills all the requirements. + """ + # check and preprocess input + if isinstance(anatomy_list, int): + anatomy_list = [anatomy_list] + + spacing = ensure_tuple_rep(spacing, 3) + + if not os.path.exists(mask_foldername): + zip_file_path = mask_foldername + ".zip" + + if not os.path.isfile(zip_file_path): + raise ValueError(f"Please download {zip_file_path} following the instruction in ./datasets/README.md.") + + print(f"Extracting {zip_file_path} to {os.path.dirname(zip_file_path)}") + extractall(filepath=zip_file_path, output_dir=os.path.dirname(zip_file_path), file_type="zip") + print(f"Unzipped {zip_file_path} to {mask_foldername}.") + + if not os.path.isfile(database_filepath): + raise ValueError(f"Please download {database_filepath} following the instruction in ./datasets/README.md.") + with open(database_filepath, "r") as f: + db = json.load(f) + + # select candidate_masks + candidate_masks = [] + for _item in db: + if not set(anatomy_list).issubset(_item["label_list"]): + continue + + # extract region indice (top_index and bottom_index) for candidate mask + top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0] + top_index = top_index[0] + bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0] + bottom_index = bottom_index[0] + + # whether to keep this mask, default to be True. + keep_mask = True + + for tumor_label in [23, 24, 26, 27, 128]: + # we skip those mask with tumors if users do not provide tumor label in anatomy_list + if tumor_label not in anatomy_list and tumor_label in _item["label_list"]: + keep_mask = False + + if check_spacing_and_output_size: + # if the output_size and spacing are different with user's input, skip it + for axis in range(3): + if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]: + keep_mask = False + + if keep_mask: + # if decide to keep this mask, we pack the information of this mask and add to final output. + candidate = { + "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]), + "spacing": _item["spacing"], + "dim": _item["dim"], + "top_region_index": _item["top_region_index"], + "bottom_region_index": _item["bottom_region_index"], + } + + # Conditionally add the label to the candidate dictionary + if "label_filename" in _item: + candidate["label"] = os.path.join(mask_foldername, _item["label_filename"]) + + candidate_masks.append(candidate) + + if len(candidate_masks) == 0 and not check_spacing_and_output_size: + raise ValueError("Cannot find body region with given anatomy list.") + + return candidate_masks From e0e6101c88b4522a41ca1637c64b8e42b37ba6cb Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 05:24:32 +0000 Subject: [PATCH 15/50] simplify find_mask Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/find_masks.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index 3e6ccc90..f5159e0f 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -206,12 +206,6 @@ def find_masks( if not set(anatomy_list).issubset(_item["label_list"]): continue - # extract region indice (top_index and bottom_index) for candidate mask - top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0] - top_index = top_index[0] - bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0] - bottom_index = bottom_index[0] - # whether to keep this mask, default to be True. keep_mask = True From 3c8e5a647361df8102b1cd19f54a60a57139511c Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 18:38:07 +0000 Subject: [PATCH 16/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/find_masks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index f5159e0f..62b1a6c4 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -154,6 +154,7 @@ def find_masks_original( return candidate_masks + def find_masks( body_region: str | Sequence[str], anatomy_list: int | Sequence[int], From a5ec051fa9dddd2da88081ddf7f3e647c8c565d0 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 18:43:43 +0000 Subject: [PATCH 17/50] reduce max_try, usually if it failes one tim, it is likely the mask has issue, and it keps failing. Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 93f67bd2..d73e7bf3 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -592,7 +592,7 @@ def __init__( self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap # quality check args - self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times + self.max_try_time = 3 # if not pass quality check, will try self.max_try_time times with open(real_img_median_statistics, "r") as json_file: self.median_statistics = json.load(json_file) self.label_int_dict = { @@ -669,7 +669,6 @@ def sample_multiple_images(self, num_img): need_resample = True selected_mask_files = self.select_mask(candidate_mask_files, num_img) - logging.info(f"Images will be generated based on {selected_mask_files}.") if len(selected_mask_files) != num_img: raise ValueError( ( @@ -680,6 +679,7 @@ def sample_multiple_images(self, num_img): for item in selected_mask_files: logging.info("---- Start preparing masks... ----") start_time = time.time() + logging.info(f"Image will be generated based on {item}.") if len(self.controllable_anatomy_size) > 0: # generate a synthetic mask (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( @@ -746,6 +746,10 @@ def sample_multiple_images(self, num_img): "Generated image/label pair did not pass quality check, will re-generate another pair." ) try_time += 1 + if try_time > self.max_try_time: + logging.info( + "Generated image/label pair did not pass quality check. Please consider changing spacing and output_size to facilitate a more realistic setting." + ) return output_filenames def select_mask(self, candidate_mask_files, num_img): From 6acf87e55a6ada0cc4a3820de0a9b7bde2e83e5e Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 19:11:28 +0000 Subject: [PATCH 18/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index d73e7bf3..155e2e5e 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -748,7 +748,8 @@ def sample_multiple_images(self, num_img): try_time += 1 if try_time > self.max_try_time: logging.info( - "Generated image/label pair did not pass quality check. Please consider changing spacing and output_size to facilitate a more realistic setting." + "Generated image/label pair did not pass quality check. + Please consider changing spacing and output_size to facilitate a more realistic setting." ) return output_filenames From aaa56de553e870223e2a408d34efad5569474650 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 14 Feb 2025 19:11:48 +0000 Subject: [PATCH 19/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/maisi_ct_generative/scripts/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 155e2e5e..3b6ba99e 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -748,7 +748,7 @@ def sample_multiple_images(self, num_img): try_time += 1 if try_time > self.max_try_time: logging.info( - "Generated image/label pair did not pass quality check. + "Generated image/label pair did not pass quality check. Please consider changing spacing and output_size to facilitate a more realistic setting." ) return output_filenames From 7b68d37d495e49bcad78ecf511be88c56ff62716 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 14 Feb 2025 19:13:39 +0000 Subject: [PATCH 20/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 155e2e5e..ab14cfd3 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -748,8 +748,8 @@ def sample_multiple_images(self, num_img): try_time += 1 if try_time > self.max_try_time: logging.info( - "Generated image/label pair did not pass quality check. - Please consider changing spacing and output_size to facilitate a more realistic setting." + "Generated image/label pair did not pass quality check. " + "Please consider changing spacing and output_size to facilitate a more realistic setting." ) return output_filenames From 11290dd28e729c1267dd14cdbec3d3c1fe87b2ab Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Feb 2025 02:44:05 +0000 Subject: [PATCH 21/50] remove unused code Signed-off-by: Can-Zhao --- .../maisi_ct_generative/scripts/find_masks.py | 105 ------------------ models/maisi_ct_generative/scripts/sample.py | 26 +---- 2 files changed, 5 insertions(+), 126 deletions(-) diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index 62b1a6c4..6d86d5af 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -52,109 +52,6 @@ def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]: return body_region_indices -def find_masks_original( - body_region: str | Sequence[str], - anatomy_list: int | Sequence[int], - spacing: Sequence[float] | float = 1.0, - output_size: Sequence[int] = (512, 512, 512), - check_spacing_and_output_size: bool = False, - database_filepath: str = "./configs/database.json", - mask_foldername: str = "./datasets/masks/", -): - """ - Find candidate masks that fullfills all the requirements. - They shoud contain all the body region in `body_region`, all the anatomies in `anatomy_list`. - If there is no tumor specified in `anatomy_list`, we also expect the candidate masks to be tumor free. - If check_spacing_and_output_size is True, the candidate masks need to have the expected `spacing` and `output_size`. - Args: - body_region: list of input body region string. If single str, will be converted to list of str. - The found candidate mask will include these body regions. - anatomy_list: list of input anatomy. The found candidate mask will include these anatomies. - spacing: list of three floats, voxel spacing. If providing a single number, will use it for all the three dimensions. - output_size: list of three int, expected candidate mask spatial size. - check_spacing_and_output_size: whether we expect candidate mask to have spatial size of `output_size` - and voxel size of `spacing`. - database_filepath: path for the json file that stores the information of all the candidate masks. - mask_foldername: directory that saves all the candidate masks. - Return: - candidate_masks, list of dict, each dict contains information of one candidate mask that fullfills all the requirements. - """ - # check and preprocess input - body_region = convert_body_region(body_region) - - if isinstance(anatomy_list, int): - anatomy_list = [anatomy_list] - - spacing = ensure_tuple_rep(spacing, 3) - - if not os.path.exists(mask_foldername): - zip_file_path = mask_foldername + ".zip" - - if not os.path.isfile(zip_file_path): - raise ValueError(f"Please download {zip_file_path} following the instruction in ./datasets/README.md.") - - print(f"Extracting {zip_file_path} to {os.path.dirname(zip_file_path)}") - extractall(filepath=zip_file_path, output_dir=os.path.dirname(zip_file_path), file_type="zip") - print(f"Unzipped {zip_file_path} to {mask_foldername}.") - - if not os.path.isfile(database_filepath): - raise ValueError(f"Please download {database_filepath} following the instruction in ./datasets/README.md.") - with open(database_filepath, "r") as f: - db = json.load(f) - - # select candidate_masks - candidate_masks = [] - for _item in db: - if not set(anatomy_list).issubset(_item["label_list"]): - continue - - # extract region indice (top_index and bottom_index) for candidate mask - top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0] - top_index = top_index[0] - bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0] - bottom_index = bottom_index[0] - - # whether to keep this mask, default to be True. - keep_mask = True - - # if candiate mask does not contain all the body_region, skip it - for _idx in body_region: - if _idx > bottom_index or _idx < top_index: - keep_mask = False - - for tumor_label in [23, 24, 26, 27, 128]: - # we skip those mask with tumors if users do not provide tumor label in anatomy_list - if tumor_label not in anatomy_list and tumor_label in _item["label_list"]: - keep_mask = False - - if check_spacing_and_output_size: - # if the output_size and spacing are different with user's input, skip it - for axis in range(3): - if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]: - keep_mask = False - - if keep_mask: - # if decide to keep this mask, we pack the information of this mask and add to final output. - candidate = { - "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]), - "spacing": _item["spacing"], - "dim": _item["dim"], - "top_region_index": _item["top_region_index"], - "bottom_region_index": _item["bottom_region_index"], - } - - # Conditionally add the label to the candidate dictionary - if "label_filename" in _item: - candidate["label"] = os.path.join(mask_foldername, _item["label_filename"]) - - candidate_masks.append(candidate) - - if len(candidate_masks) == 0 and not check_spacing_and_output_size: - raise ValueError("Cannot find body region with given anatomy list.") - - return candidate_masks - - def find_masks( body_region: str | Sequence[str], anatomy_list: int | Sequence[int], @@ -227,8 +124,6 @@ def find_masks( "pseudo_label": os.path.join(mask_foldername, _item["pseudo_label_filename"]), "spacing": _item["spacing"], "dim": _item["dim"], - "top_region_index": _item["top_region_index"], - "bottom_region_index": _item["bottom_region_index"], } # Conditionally add the label to the candidate dictionary diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index ab14cfd3..778f91a2 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -202,8 +202,6 @@ def ldm_conditional_sample_one_image( scale_factor (float): Scaling factor for the latent space. device (torch.device): The device to run the computation on. combine_label_or (torch.Tensor): The combined label tensor. - top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. - bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. spacing_tensor (torch.Tensor): Tensor specifying the spacing. latent_shape (tuple): The shape of the latent space. output_size (tuple): The desired output size of the image. @@ -625,11 +623,7 @@ def __init__( monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]), monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"), monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8), - monai.transforms.Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)), - monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)), monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), - monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2), - monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2), monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), ] ) @@ -682,14 +676,14 @@ def sample_multiple_images(self, num_img): logging.info(f"Image will be generated based on {item}.") if len(self.controllable_anatomy_size) > 0: # generate a synthetic mask - (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( + (combine_label_or, spacing_tensor) = ( self.prepare_one_mask_and_meta_info(anatomy_size_condtion) ) else: # read in mask file mask_file = item["mask_file"] if_aug = item["if_aug"] - (combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( + (combine_label_or, spacing_tensor) = ( self.read_mask_information(mask_file) ) if need_resample: @@ -778,8 +772,7 @@ def sample_one_pair(self, combine_label_or_aug, modality_tensor, spacing_tensor) Args: combine_label_or_aug (torch.Tensor): Combined label tensor or augmented label. - top_region_index_tensor (torch.Tensor): Tensor specifying the top region index. - bottom_region_index_tensor (torch.Tensor): Tensor specifying the bottom region index. + modality_tensor (torch.Tensor): Tensor specifying the image modality. spacing_tensor (torch.Tensor): Tensor specifying the spacing. Returns: @@ -875,13 +868,9 @@ def prepare_one_mask_and_meta_info(self, anatomy_size_condtion): combine_label_or = MetaTensor(combine_label_or, affine=affine) combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) - top_region_index, bottom_region_index = get_body_region_index_from_mask(combine_label_or) - spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2 - top_region_index_tensor = torch.FloatTensor(top_region_index).unsqueeze(0).half().to(self.device) * 1e2 - bottom_region_index_tensor = torch.FloatTensor(bottom_region_index).unsqueeze(0).half().to(self.device) * 1e2 - return combine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + return combine_label_or, spacing_tensor def sample_one_mask(self, anatomy_size): """ @@ -969,13 +958,11 @@ def read_mask_information(self, mask_file): """ val_data = self.val_transforms(mask_file) - for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]: + for key in ["pseudo_label", "spacing", ]: val_data[key] = val_data[key].unsqueeze(0).to(self.device) return ( val_data["pseudo_label"], - val_data["top_region_index"], - val_data["bottom_region_index"], val_data["spacing"], ) @@ -1030,9 +1017,6 @@ def find_closest_masks(self, num_img): else: raise e # get region_index after resample - top_region_index, bottom_region_index = get_body_region_index_from_mask(label) - c["top_region_index"] = top_region_index - c["bottom_region_index"] = bottom_region_index c["spacing"] = self.spacing c["dim"] = self.output_size From 62446d74ad08b89ba1676a3c68043c7dac0475d9 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Feb 2025 03:16:22 +0000 Subject: [PATCH 22/50] remove unused code Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/find_masks.py | 1 - models/maisi_ct_generative/scripts/sample.py | 1 - 2 files changed, 2 deletions(-) diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py index 6d86d5af..f8f3a14d 100644 --- a/models/maisi_ct_generative/scripts/find_masks.py +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -53,7 +53,6 @@ def convert_body_region(body_region: str | Sequence[str]) -> Sequence[int]: def find_masks( - body_region: str | Sequence[str], anatomy_list: int | Sequence[int], spacing: Sequence[float] | float = 1.0, output_size: Sequence[int] = (512, 512, 512), diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 778f91a2..8621e8a8 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -647,7 +647,6 @@ def sample_multiple_images(self, num_img): need_resample = False # find candidate mask and save to candidate_mask_files candidate_mask_files = find_masks( - self.body_region, self.anatomy_list, self.spacing, self.output_size, From 1d28f195893bf98cce5313c1eabaf234a30b9461 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Feb 2025 03:51:30 +0000 Subject: [PATCH 23/50] make quality range more reasonable Signed-off-by: Can-Zhao --- .../scripts/quality_check.py | 4 ++-- models/maisi_ct_generative/scripts/sample.py | 22 +++++++++++++------ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/models/maisi_ct_generative/scripts/quality_check.py b/models/maisi_ct_generative/scripts/quality_check.py index fe34661f..a41b6e6e 100644 --- a/models/maisi_ct_generative/scripts/quality_check.py +++ b/models/maisi_ct_generative/scripts/quality_check.py @@ -109,8 +109,8 @@ def is_outlier(statistics, image_data, label_data, label_int_dict): for label_name, stats in statistics.items(): # Get the thresholds from the statistics - low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs - high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs + low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"]) # or "sigma_12_low" depending on your needs + high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs # Retrieve the corresponding label integers labels = label_int_dict.get(label_name, []) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 8621e8a8..507e29a7 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -31,7 +31,6 @@ binarize_labels, dynamic_infer, general_mask_generation_post_process, - get_body_region_index_from_mask, remap_labels, ) @@ -459,8 +458,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing else: logging.info( ( - "`controllable_anatomy_size` is empty.\nWe will synthesize based on `body_region`: " - f"({body_region}) and `anatomy_list`: ({anatomy_list})." + f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list})." ) ) # check body_region format @@ -980,7 +978,6 @@ def find_closest_masks(self, num_img): """ # first check the database based on anatomy list candidates = find_masks( - self.body_region, self.anatomy_list, self.spacing, self.output_size, @@ -991,19 +988,30 @@ def find_closest_masks(self, num_img): if len(candidates) < num_img: raise ValueError(f"candidate masks are less than {num_img}).") + # loop through the database and find closest combinations new_candidates = [] for c in candidates: diff = 0 + include_c = True for axis in range(3): + if abs(c["dim"][axis]) < self.output_size[axis]-64: + # we cannot upsample the mask too much + include_c = False + break + # check diff in FOV + diff += abs((abs(c["dim"][axis]*c["spacing"][axis]) - self.output_size[axis]*self.spacing[axis]) / 10) # check diff in dim - diff += abs((c["dim"][axis] - self.output_size[axis]) / 100) + diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100) # check diff in spacing - diff += abs(c["spacing"][axis] - self.spacing[axis]) - new_candidates.append((c, diff)) + diff += abs(abs(c["spacing"][axis]) - self.spacing[axis]) + if include_c: + new_candidates.append((c, diff)) + # choose top-2*num_img candidates (at least 5) new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)] final_candidates = [] + # check top-2*num_img candidates and update spacing after resampling image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True) for c, _ in new_candidates: From bf08109e0d0c01c9b5ce41073c552cb4d7160f4c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Feb 2025 03:51:48 +0000 Subject: [PATCH 24/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/maisi_ct_generative/scripts/sample.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 507e29a7..2fcc06fa 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -988,7 +988,7 @@ def find_closest_masks(self, num_img): if len(candidates) < num_img: raise ValueError(f"candidate masks are less than {num_img}).") - + # loop through the database and find closest combinations new_candidates = [] for c in candidates: @@ -1007,11 +1007,11 @@ def find_closest_masks(self, num_img): diff += abs(abs(c["spacing"][axis]) - self.spacing[axis]) if include_c: new_candidates.append((c, diff)) - + # choose top-2*num_img candidates (at least 5) new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)] final_candidates = [] - + # check top-2*num_img candidates and update spacing after resampling image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True) for c, _ in new_candidates: From 740fc98eeb3e47e2e49e88779cdb2f2f3e18ba85 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Feb 2025 04:22:41 +0000 Subject: [PATCH 25/50] make hish thresh of bone to 1000, change logic of regeneration when quality failed, will regenerate bsed on a new mask Signed-off-by: Can-Zhao --- .../scripts/quality_check.py | 3 + models/maisi_ct_generative/scripts/sample.py | 104 +++++++++--------- 2 files changed, 57 insertions(+), 50 deletions(-) diff --git a/models/maisi_ct_generative/scripts/quality_check.py b/models/maisi_ct_generative/scripts/quality_check.py index a41b6e6e..e4539b9a 100644 --- a/models/maisi_ct_generative/scripts/quality_check.py +++ b/models/maisi_ct_generative/scripts/quality_check.py @@ -112,6 +112,9 @@ def is_outlier(statistics, image_data, label_data, label_int_dict): low_thresh = min(stats["sigma_6_low"], stats["percentile_0_5"]) # or "sigma_12_low" depending on your needs high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs + if label_name == "bone": + high_thresh = 1000. + # Retrieve the corresponding label integers labels = label_int_dict.get(label_name, []) masked_data = get_masked_data(label_data, image_data, labels) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 507e29a7..b738cf95 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -660,14 +660,18 @@ def sample_multiple_images(self, num_img): need_resample = True selected_mask_files = self.select_mask(candidate_mask_files, num_img) - if len(selected_mask_files) != num_img: + if len(selected_mask_files) < num_img: raise ValueError( ( - f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img ({num_img}). " + f"len(selected_mask_files) ({len(selected_mask_files)}) < num_img ({num_img}). " "This should not happen. Please revisit function select_mask(self, candidate_mask_files, num_img)." ) ) - for item in selected_mask_files: + num_generated_img = 0 + for index_s in range(len(selected_mask_files)): + item = selected_mask_files[index_s] + if num_generated_img >= num_img: + break logging.info("---- Start preparing masks... ----") start_time = time.time() logging.info(f"Image will be generated based on {item}.") @@ -695,53 +699,53 @@ def sample_multiple_images(self, num_img): to_generate = True try_time = 0 modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int - while to_generate: - synthetic_images, synthetic_labels = self.sample_one_pair( - combine_label_or, modality_tensor, spacing_tensor + # start generation + synthetic_images, synthetic_labels = self.sample_one_pair( + combine_label_or, modality_tensor, spacing_tensor + ) + # synthetic image quality check + pass_quality_check = self.quality_check( + synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() + ) + if pass_quality_check or (num_img - num_generated_img)>=(len(selected_mask_files)-index_s): + if not pass_quality_check: + logging.info( + "Generated image/label pair did not pass quality check, but will still save them. " + "Please consider changing spacing and output_size to facilitate a more realistic setting." ) - # synthetic image quality check - pass_quality_check = self.quality_check( - synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() + num_generated_img = num_generated_img +1 + # save image/label pairs + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" + synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) + img_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_image", + output_ext=self.image_output_ext, + separate_folder=False, + ) + img_saver(synthetic_images[0]) + synthetic_images_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext + ) + # filter out the organs that are not in anatomy_list + # synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) + label_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_label", + output_ext=self.label_output_ext, + separate_folder=False, + ) + label_saver(synthetic_labels[0]) + synthetic_labels_filename = os.path.join( + self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext + ) + output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) + to_generate = False + else: + logging.info( + "Generated image/label pair did not pass quality check, will re-generate another pair." ) - if pass_quality_check or try_time > self.max_try_time: - # save image/label pairs - output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" - synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) - img_saver = SaveImage( - output_dir=self.output_dir, - output_postfix=output_postfix + "_image", - output_ext=self.image_output_ext, - separate_folder=False, - ) - img_saver(synthetic_images[0]) - synthetic_images_filename = os.path.join( - self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext - ) - # filter out the organs that are not in anatomy_list - synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) - label_saver = SaveImage( - output_dir=self.output_dir, - output_postfix=output_postfix + "_label", - output_ext=self.label_output_ext, - separate_folder=False, - ) - label_saver(synthetic_labels[0]) - synthetic_labels_filename = os.path.join( - self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext - ) - output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) - to_generate = False - else: - logging.info( - "Generated image/label pair did not pass quality check, will re-generate another pair." - ) - try_time += 1 - if try_time > self.max_try_time: - logging.info( - "Generated image/label pair did not pass quality check. " - "Please consider changing spacing and output_size to facilitate a more realistic setting." - ) return output_filenames def select_mask(self, candidate_mask_files, num_img): @@ -758,7 +762,7 @@ def select_mask(self, candidate_mask_files, num_img): selected_mask_files = [] random.shuffle(candidate_mask_files) - for n in range(num_img): + for n in range(num_img*self.max_try_time): mask_file = candidate_mask_files[n % len(candidate_mask_files)] selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) return selected_mask_files @@ -999,7 +1003,7 @@ def find_closest_masks(self, num_img): # we cannot upsample the mask too much include_c = False break - # check diff in FOV + # check diff in FOV, major metric diff += abs((abs(c["dim"][axis]*c["spacing"][axis]) - self.output_size[axis]*self.spacing[axis]) / 10) # check diff in dim diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100) From 708cd68b3e711aa0ad8ae68d325f4b3709af96fd Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Feb 2025 04:29:53 +0000 Subject: [PATCH 26/50] add input restirction for z-axis FOV Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index b738cf95..aa1bdda9 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -393,13 +393,13 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." ) - if output_size[0] * spacing[0] < 256: + if output_size[0] * spacing[0] < 256 or output_size[2] * spacing[2] < 128: fov = [output_size[axis] * spacing[axis] for axis in range(3)] raise ValueError( ( f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). " f"The FOV will be {fov}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least " - "384mm for other body regions like abdomen. There is no such restriction for z-axis." + "384mm for other body regions like abdomen. For z-axis, we require it to be at least 128mm." ) ) From ce3532de24108a8d7e487a27775c0e54d0c8f12b Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 19 Feb 2025 05:45:32 +0000 Subject: [PATCH 27/50] filter other labels Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 5821beaa..8d2c5a7d 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -729,7 +729,7 @@ def sample_multiple_images(self, num_img): self.output_dir, "sample_" + output_postfix + "_image" + self.image_output_ext ) # filter out the organs that are not in anatomy_list - # synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) + synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) label_saver = SaveImage( output_dir=self.output_dir, output_postfix=output_postfix + "_label", From 5e1f5643798563735f99b5c80acf3d5c83b07c0c Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 26 Feb 2025 18:48:05 +0000 Subject: [PATCH 28/50] random seed Signed-off-by: Can-Zhao --- .../configs/inference.json | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 112bb0a8..412a5b69 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -11,27 +11,26 @@ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19350.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet_epoch150.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet_epoch208_v9.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", - "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", - "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_3000.json'", + "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_4000'", + "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_4000.json'", "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'", "label_dict_json": "$@bundle_root + '/configs/label_dict.json'", "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", "real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'", "num_output_samples": 1, "body_region": [ - "abdomen" ], "anatomy_list": [ - "liver" + "hepatic tumor" ], "modality": "ct", "controllable_anatomy_size": [], "num_inference_steps": 30, "mask_generation_num_inference_steps": 1000, - "random_seed": null, + "random_seed": 0, "spatial_dims": 3, "image_channels": 1, "latent_channels": 4, @@ -44,8 +43,8 @@ ], "image_output_ext": ".nii.gz", "label_output_ext": ".nii.gz", - "spacing_xy": 1.0, - "spacing_z": 1.0, + "spacing_xy": 1, + "spacing_z": 1, "spacing": [ "@spacing_xy", "@spacing_xy", @@ -68,7 +67,7 @@ 80, 80 ], - "autoencoder_sliding_window_infer_overlap": 0.25, + "autoencoder_sliding_window_infer_overlap": 0.4, "autoencoder_def": { "_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi", "spatial_dims": "@spatial_dims", @@ -97,7 +96,7 @@ "use_checkpointing": false, "use_convtranspose": false, "norm_float16": true, - "num_splits": 8, + "num_splits": 2, "dim_split": 1 }, "diffusion_unet_def": { @@ -308,6 +307,7 @@ "autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap" }, "run": [ + "$monai.utils.set_determinism(seed=@random_seed)", "$@ldm_sampler.sample_multiple_images(@num_output_samples)" ], "evaluator": null From e8f0a59e6b9ca06cda5bef650681e06333d8689a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Feb 2025 18:50:17 +0000 Subject: [PATCH 29/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/maisi_ct_generative/configs/inference.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 412a5b69..e1c67686 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -21,8 +21,7 @@ "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", "real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'", "num_output_samples": 1, - "body_region": [ - ], + "body_region": [], "anatomy_list": [ "hepatic tumor" ], From b8339888b27e868ee46873592f9c289023a8fb1a Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 26 Feb 2025 20:10:57 +0000 Subject: [PATCH 30/50] FOV max limit Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 8d2c5a7d..36bf9c7f 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -393,7 +393,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." ) - if output_size[0] * spacing[0] < 256 or output_size[2] * spacing[2] < 128: + if output_size[0] * spacing[0] < 256 or output_size[2] * spacing[2] < 128 or output_size[0] * spacing[0] >640 or output_size[2] * spacing[2] > 2000: fov = [output_size[axis] * spacing[axis] for axis in range(3)] raise ValueError( ( From 73564c0f32a3d4810ed62aa14f1207ce144c50b0 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 26 Feb 2025 20:12:48 +0000 Subject: [PATCH 31/50] FOV max limit Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 412a5b69..bb8be50c 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -11,7 +11,7 @@ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19350.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet_epoch208_v9.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet_current.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_4000'", @@ -30,12 +30,12 @@ "controllable_anatomy_size": [], "num_inference_steps": 30, "mask_generation_num_inference_steps": 1000, - "random_seed": 0, + "random_seed": 10, "spatial_dims": 3, "image_channels": 1, "latent_channels": 4, - "output_size_xy": 512, - "output_size_z": 512, + "output_size_xy": 256, + "output_size_z": 256, "output_size": [ "@output_size_xy", "@output_size_xy", @@ -43,8 +43,8 @@ ], "image_output_ext": ".nii.gz", "label_output_ext": ".nii.gz", - "spacing_xy": 1, - "spacing_z": 1, + "spacing_xy": 1.5, + "spacing_z": 1.5, "spacing": [ "@spacing_xy", "@spacing_xy", From 33cece04a003efa606a0346c11907cc203e05fae Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 26 Feb 2025 20:17:40 +0000 Subject: [PATCH 32/50] FOV max limit Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 36bf9c7f..acf33eac 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -399,7 +399,8 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing ( f"`'spacing'({spacing}mm) and 'output_size'({output_size}) together decide the output field of view (FOV). " f"The FOV will be {fov}mm. We recommend the FOV in x and y axis to be at least 256mm for head, and at least " - "384mm for other body regions like abdomen. For z-axis, we require it to be at least 128mm." + "384mm for other body regions like abdomen, and less than 640mm. " + "For z-axis, we require it to be at least 128mm and less than 2000mm." ) ) From f834762fd86011b4db4f90723fd77a961e09e2c9 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 27 Feb 2025 05:53:19 +0000 Subject: [PATCH 33/50] fix deterministic issue Signed-off-by: Can-Zhao --- .../scripts/augmentation.py | 33 +++++++++++-------- models/maisi_ct_generative/scripts/sample.py | 3 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/models/maisi_ct_generative/scripts/augmentation.py b/models/maisi_ct_generative/scripts/augmentation.py index 6317781f..85e6ff76 100644 --- a/models/maisi_ct_generative/scripts/augmentation.py +++ b/models/maisi_ct_generative/scripts/augmentation.py @@ -60,7 +60,7 @@ def dilate3d(input_tensor, erosion=3): return output.squeeze(0).squeeze(0) -def augmentation_tumor_bone(pt_nda, output_size): +def augmentation_tumor_bone(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 128] = 1 @@ -74,6 +74,7 @@ def augmentation_tumor_bone(pt_nda, output_size): scale_range=(0.15, 0.15, 0), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum((real_l_volume_ > 0).float()) ########################### @@ -112,7 +113,7 @@ def augmentation_tumor_bone(pt_nda, output_size): return pt_nda -def augmentation_tumor_liver(pt_nda, output_size): +def augmentation_tumor_liver(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 1] = 1 @@ -129,6 +130,7 @@ def augmentation_tumor_liver(pt_nda, output_size): scale_range=(0.2, 0.2, 0.2), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_ == 2) ########################### @@ -161,7 +163,7 @@ def augmentation_tumor_liver(pt_nda, output_size): return pt_nda -def augmentation_tumor_lung(pt_nda, output_size): +def augmentation_tumor_lung(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 23] = 1 @@ -177,6 +179,7 @@ def augmentation_tumor_lung(pt_nda, output_size): scale_range=(0.15, 0.15, 0.15), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_) # before move lung tumor maks, full the original location by lung labels @@ -224,7 +227,7 @@ def augmentation_tumor_lung(pt_nda, output_size): return pt_nda -def augmentation_tumor_pancreas(pt_nda, output_size): +def augmentation_tumor_pancreas(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 4] = 1 @@ -241,6 +244,7 @@ def augmentation_tumor_pancreas(pt_nda, output_size): scale_range=(0.1, 0.1, 0.1), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_ == 2) ########################### @@ -273,7 +277,7 @@ def augmentation_tumor_pancreas(pt_nda, output_size): return pt_nda -def augmentation_tumor_colon(pt_nda, output_size): +def augmentation_tumor_colon(pt_nda, output_size, random_seed): volume = pt_nda.squeeze(0) real_l_volume_ = torch.zeros_like(volume) real_l_volume_[volume == 27] = 1 @@ -289,6 +293,7 @@ def augmentation_tumor_colon(pt_nda, output_size): scale_range=(0.1, 0.1, 0.1), padding_mode="zeros", ) + elastic.set_random_state(seed=random_seed) tumor_szie = torch.sum(real_l_volume_) ########################### @@ -330,37 +335,39 @@ def augmentation_tumor_colon(pt_nda, output_size): return pt_nda -def augmentation_body(pt_nda): +def augmentation_body(pt_nda, random_seed): volume = pt_nda.squeeze(0) zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0) + zoom.set_random_state(seed=random_seed) + volume = zoom(volume) pt_nda = volume.unsqueeze(0) return pt_nda -def augmentation(pt_nda, output_size): +def augmentation(pt_nda, output_size, random_seed): label_list = torch.unique(pt_nda) label_list = list(label_list.cpu().numpy()) if 128 in label_list: print("augmenting bone lesion/tumor") - pt_nda = augmentation_tumor_bone(pt_nda, output_size) + pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed) elif 26 in label_list: print("augmenting liver tumor") - pt_nda = augmentation_tumor_liver(pt_nda, output_size) + pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed) elif 23 in label_list: print("augmenting lung tumor") - pt_nda = augmentation_tumor_lung(pt_nda, output_size) + pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed) elif 24 in label_list: print("augmenting pancreas tumor") - pt_nda = augmentation_tumor_pancreas(pt_nda, output_size) + pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed) elif 27 in label_list: print("augmenting colon tumor") - pt_nda = augmentation_tumor_colon(pt_nda, output_size) + pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed) else: print("augmenting body") - pt_nda = augmentation_body(pt_nda) + pt_nda = augmentation_body(pt_nda, random_seed) return pt_nda diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index acf33eac..1885ab62 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -532,6 +532,7 @@ def __init__( Args: Various parameters related to model configuration, input settings, and output specifications. """ + self.random_seed = random_seed if random_seed is not None: set_determinism(seed=random_seed) @@ -692,7 +693,7 @@ def sample_multiple_images(self, num_img): combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) # mask augmentation if if_aug: - combine_label_or = augmentation(combine_label_or, self.output_size) + combine_label_or = augmentation(combine_label_or, self.output_size, random_seed=self.random_seed) end_time = time.time() logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----") torch.cuda.empty_cache() From 1484429eb6db407dc23ea9a954b68a7ecb2e8687 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 05:53:59 +0000 Subject: [PATCH 34/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/maisi_ct_generative/scripts/augmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/augmentation.py b/models/maisi_ct_generative/scripts/augmentation.py index 85e6ff76..64469403 100644 --- a/models/maisi_ct_generative/scripts/augmentation.py +++ b/models/maisi_ct_generative/scripts/augmentation.py @@ -340,7 +340,7 @@ def augmentation_body(pt_nda, random_seed): zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0) zoom.set_random_state(seed=random_seed) - + volume = zoom(volume) pt_nda = volume.unsqueeze(0) From 4486c615cf714fb62bab13009e5392bb8f54c669 Mon Sep 17 00:00:00 2001 From: binliu Date: Wed, 5 Mar 2025 22:16:03 +0800 Subject: [PATCH 35/50] rename the model name Signed-off-by: binliu --- models/maisi_ct_generative/configs/inference.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index ca221596..8aeffd7e 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -9,9 +9,9 @@ "output_dir": "$@bundle_root + '/output'", "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", - "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", - "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19350.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet_current.pt'", + "trained_autoencoder_path": "$@model_dir + '/autoencoder.pt'", + "trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_4000'", From fd6b7020a2f901108dbebf27ab33a22307791d61 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Fri, 7 Mar 2025 19:27:54 +0000 Subject: [PATCH 36/50] update doc for inference Signed-off-by: Can-Zhao --- models/maisi_ct_generative/docs/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/models/maisi_ct_generative/docs/README.md b/models/maisi_ct_generative/docs/README.md index be7937b1..c46e6645 100644 --- a/models/maisi_ct_generative/docs/README.md +++ b/models/maisi_ct_generative/docs/README.md @@ -4,7 +4,7 @@ This bundle is for Nvidia MAISI (Medical AI for Synthetic Imaging), a 3D Latent The inference workflow of MAISI is depicted in the figure below. It first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Then it decodes the denoised latent features into images using the trained autoencoder.

- MAISI inference scheme + MAISI inference scheme

MAISI is based on the following papers: @@ -13,6 +13,8 @@ MAISI is based on the following papers: [**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf) +[**Rectified Flow:** Liu, Xingchao, and Chengyue Gong. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." ICLR 2023.](https://arxiv.org/pdf/2209.03003) + #### Example synthetic image An example result from inference is shown below: ![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png) @@ -27,11 +29,11 @@ The information for the inference input, like body region and anatomy to generat - `"num_output_samples"`: int, the number of output image/mask pairs it will generate. - `"spacing"`: voxel size of generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a resolution of 1.5×1.5×2.0 mm. The spacing for x and y axes has to be between 0.5 and 3.0 mm and the spacing for the z axis has to be between 0.5 and 5.0 mm. -- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512×512×256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For eample, if set them to `[1.5, 1.5, 2.0]`mm and `[512, 512, 256]`, the FOV is 768×768×512 mm. We recommend output_size is the FOV in x and y axis are same and to be at least 256mm for head, and at least 384mm for other body regions like abdomen. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768]. +- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512×512×256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For eample, if set them to `[1.5, 1.5, 2.0]`mm and `[512, 512, 256]`, the FOV is 768×768×512 mm. We recommend output_size is the FOV in x and y axis are same and to be at least 256mm for head, at least 384mm for other body regions like abdomen, and no larger than 640mm. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768]. - `"controllable_anatomy_size"`: a list of controllable anatomy and its size scale (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain liver that have a median size, with size around 50% percentile, and hepatic tumor that is relatively small, with around 30% percentile. In addition, if the size scale is set to -1, it indicates that the organ does not exist or should be removed. The output will contain paired image and segmentation mask for the controllable anatomy. The following organs support generation with a controllable size: ``["liver", "gallbladder", "stomach", "pancreas", "colon", "lung tumor", "bone lesion", "hepatic tumor", "colon cancer primaries", "pancreatic tumor"]``. The raw output of the current mask generation model has a fixed size of $256^3$ voxels with a spacing of $1.5^3$ mm. If the "output_size" differs from this default, the generated masks will be resampled to the desired `"output_size"` and `"spacing"`. Note that resampling may degrade the quality of the generated masks and could trigger multiple inference attempts if the images fail to pass the [image quality check](../scripts/quality_check.py). -- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower". +- `"body_region"`: Deprecated, please leave it as empty `"[]"`. - `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json". - `"autoencoder_sliding_window_infer_size"`: in order to save GPU memory, we use sliding window inference when decoding latents to image when `"output_size"` is large. This is the patch size of the sliding window. Small value will reduce GPU memory but increase time cost. They need to be divisible by 16. - `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. Large value will reduce the stitching artifacts when stitching patches during sliding window inference, but increase time cost. If you do not observe seam lines in the generated image result, you can use a smaller value to save inference time. From ebec91bd49535cd9fb6c4a988108aa7ef5031bae Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Sat, 8 Mar 2025 00:51:17 +0000 Subject: [PATCH 37/50] change to 3000 Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index ca221596..15c490a6 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -10,12 +10,12 @@ "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", - "trained_diffusion_path": "$@model_dir + '/diff_unet_ckpt_epoch19350.pt'", - "trained_controlnet_path": "$@model_dir + '/controlnet_current.pt'", + "trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet.pt'", "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", - "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_4000'", - "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_4000.json'", + "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", + "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_3000.json'", "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'", "label_dict_json": "$@bundle_root + '/configs/label_dict.json'", "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", From 71121a9d45ce346607828144adacd2d3d80c9826 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Sat, 15 Mar 2025 04:40:27 +0000 Subject: [PATCH 38/50] back to 3000 Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/train.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/train.json b/models/maisi_ct_generative/configs/train.json index 7138d616..f48ece5c 100644 --- a/models/maisi_ct_generative/configs/train.json +++ b/models/maisi_ct_generative/configs/train.json @@ -268,4 +268,4 @@ "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())", "$@train#trainer.run()" ] -} +} \ No newline at end of file From 33e587ae7e0c9fca9fe50bfd1491c6aaf754cc2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Mar 2025 04:41:09 +0000 Subject: [PATCH 39/50] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/maisi_ct_generative/configs/train.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/train.json b/models/maisi_ct_generative/configs/train.json index f48ece5c..7138d616 100644 --- a/models/maisi_ct_generative/configs/train.json +++ b/models/maisi_ct_generative/configs/train.json @@ -268,4 +268,4 @@ "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())", "$@train#trainer.run()" ] -} \ No newline at end of file +} From 3cea98388a8d6bd85ebc617a22beda9c4ccbf1a6 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Sat, 15 Mar 2025 04:42:17 +0000 Subject: [PATCH 40/50] revert Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index bf9ced1f..40c1a9c8 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -23,18 +23,18 @@ "num_output_samples": 1, "body_region": [], "anatomy_list": [ - "hepatic tumor" + "liver" ], "modality": "ct", "controllable_anatomy_size": [], "num_inference_steps": 30, "mask_generation_num_inference_steps": 1000, - "random_seed": 10, + "random_seed": null, "spatial_dims": 3, "image_channels": 1, "latent_channels": 4, - "output_size_xy": 256, - "output_size_z": 256, + "output_size_xy": 512, + "output_size_z": 512, "output_size": [ "@output_size_xy", "@output_size_xy", @@ -42,8 +42,8 @@ ], "image_output_ext": ".nii.gz", "label_output_ext": ".nii.gz", - "spacing_xy": 1.5, - "spacing_z": 1.5, + "spacing_xy": 1.0, + "spacing_z": 1.0, "spacing": [ "@spacing_xy", "@spacing_xy", From b775a644768d35feaf8a2fa761202f84335befd9 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Sat, 15 Mar 2025 04:44:07 +0000 Subject: [PATCH 41/50] revert Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 40c1a9c8..705fbc47 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -248,7 +248,7 @@ "num_train_timesteps": 1000, "use_discrete_timesteps": false, "use_timestep_transform": true, - "sample_method": "logit-normal", + "sample_method": "uniform", "scale": 1.2 }, "mask_generation_noise_scheduler": { From c1317b5e44155e571691d9310ef610bb1737f4c6 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Sat, 15 Mar 2025 04:44:28 +0000 Subject: [PATCH 42/50] revert Signed-off-by: Can-Zhao --- models/maisi_ct_generative/configs/inference.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json index 705fbc47..35f76248 100644 --- a/models/maisi_ct_generative/configs/inference.json +++ b/models/maisi_ct_generative/configs/inference.json @@ -248,8 +248,7 @@ "num_train_timesteps": 1000, "use_discrete_timesteps": false, "use_timestep_transform": true, - "sample_method": "uniform", - "scale": 1.2 + "sample_method": "uniform" }, "mask_generation_noise_scheduler": { "_target_": "monai.networks.schedulers.ddpm.DDPMScheduler", From db604e2e98794fff8ae3dda484791013712d27a2 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Sat, 15 Mar 2025 04:48:22 +0000 Subject: [PATCH 43/50] reformat Signed-off-by: Can-Zhao --- .../scripts/quality_check.py | 2 +- models/maisi_ct_generative/scripts/sample.py | 73 +++++++------------ 2 files changed, 27 insertions(+), 48 deletions(-) diff --git a/models/maisi_ct_generative/scripts/quality_check.py b/models/maisi_ct_generative/scripts/quality_check.py index e4539b9a..bff49b6d 100644 --- a/models/maisi_ct_generative/scripts/quality_check.py +++ b/models/maisi_ct_generative/scripts/quality_check.py @@ -113,7 +113,7 @@ def is_outlier(statistics, image_data, label_data, label_int_dict): high_thresh = max(stats["sigma_6_high"], stats["percentile_99_5"]) # or "sigma_12_high" depending on your needs if label_name == "bone": - high_thresh = 1000. + high_thresh = 1000.0 # Retrieve the corresponding label integers labels = label_int_dict.get(label_name, []) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 1885ab62..9f46d6d5 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -27,12 +27,7 @@ from .augmentation import augmentation from .find_masks import find_masks from .quality_check import is_outlier -from .utils import ( - binarize_labels, - dynamic_infer, - general_mask_generation_post_process, - remap_labels, -) +from .utils import binarize_labels, dynamic_infer, general_mask_generation_post_process, remap_labels modality_mapping = { "unknown": 0, @@ -393,7 +388,12 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." ) - if output_size[0] * spacing[0] < 256 or output_size[2] * spacing[2] < 128 or output_size[0] * spacing[0] >640 or output_size[2] * spacing[2] > 2000: + if ( + output_size[0] * spacing[0] < 256 + or output_size[2] * spacing[2] < 128 + or output_size[0] * spacing[0] > 640 + or output_size[2] * spacing[2] > 2000 + ): fov = [output_size[axis] * spacing[axis] for axis in range(3)] raise ValueError( ( @@ -458,9 +458,7 @@ def check_input(body_region, anatomy_list, label_dict_json, output_size, spacing ) else: logging.info( - ( - f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list})." - ) + (f"`controllable_anatomy_size` is empty.\nWe will synthesize based on `anatomy_list`: ({anatomy_list}).") ) # check body_region format available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] @@ -647,12 +645,7 @@ def sample_multiple_images(self, num_img): need_resample = False # find candidate mask and save to candidate_mask_files candidate_mask_files = find_masks( - self.anatomy_list, - self.spacing, - self.output_size, - True, - self.all_mask_files_json, - self.data_root, + self.anatomy_list, self.spacing, self.output_size, True, self.all_mask_files_json, self.data_root ) if len(candidate_mask_files) < num_img: # if we cannot find enough masks based on the exact match of anatomy list, spacing, and output size, @@ -679,16 +672,12 @@ def sample_multiple_images(self, num_img): logging.info(f"Image will be generated based on {item}.") if len(self.controllable_anatomy_size) > 0: # generate a synthetic mask - (combine_label_or, spacing_tensor) = ( - self.prepare_one_mask_and_meta_info(anatomy_size_condtion) - ) + (combine_label_or, spacing_tensor) = self.prepare_one_mask_and_meta_info(anatomy_size_condtion) else: # read in mask file mask_file = item["mask_file"] if_aug = item["if_aug"] - (combine_label_or, spacing_tensor) = ( - self.read_mask_information(mask_file) - ) + (combine_label_or, spacing_tensor) = self.read_mask_information(mask_file) if need_resample: combine_label_or = self.ensure_output_size_and_spacing(combine_label_or) # mask augmentation @@ -702,20 +691,18 @@ def sample_multiple_images(self, num_img): try_time = 0 modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int # start generation - synthetic_images, synthetic_labels = self.sample_one_pair( - combine_label_or, modality_tensor, spacing_tensor - ) + synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor) # synthetic image quality check pass_quality_check = self.quality_check( synthetic_images.cpu().detach().numpy(), combine_label_or.cpu().detach().numpy() ) - if pass_quality_check or (num_img - num_generated_img)>=(len(selected_mask_files)-index_s): + if pass_quality_check or (num_img - num_generated_img) >= (len(selected_mask_files) - index_s): if not pass_quality_check: logging.info( - "Generated image/label pair did not pass quality check, but will still save them. " - "Please consider changing spacing and output_size to facilitate a more realistic setting." - ) - num_generated_img = num_generated_img +1 + "Generated image/label pair did not pass quality check, but will still save them. " + "Please consider changing spacing and output_size to facilitate a more realistic setting." + ) + num_generated_img = num_generated_img + 1 # save image/label pairs output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" @@ -745,9 +732,7 @@ def sample_multiple_images(self, num_img): output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) to_generate = False else: - logging.info( - "Generated image/label pair did not pass quality check, will re-generate another pair." - ) + logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.") return output_filenames def select_mask(self, candidate_mask_files, num_img): @@ -764,7 +749,7 @@ def select_mask(self, candidate_mask_files, num_img): selected_mask_files = [] random.shuffle(candidate_mask_files) - for n in range(num_img*self.max_try_time): + for n in range(num_img * self.max_try_time): mask_file = candidate_mask_files[n % len(candidate_mask_files)] selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) return selected_mask_files @@ -961,13 +946,10 @@ def read_mask_information(self, mask_file): """ val_data = self.val_transforms(mask_file) - for key in ["pseudo_label", "spacing", ]: + for key in ["pseudo_label", "spacing"]: val_data[key] = val_data[key].unsqueeze(0).to(self.device) - return ( - val_data["pseudo_label"], - val_data["spacing"], - ) + return (val_data["pseudo_label"], val_data["spacing"]) def find_closest_masks(self, num_img): """ @@ -984,12 +966,7 @@ def find_closest_masks(self, num_img): """ # first check the database based on anatomy list candidates = find_masks( - self.anatomy_list, - self.spacing, - self.output_size, - False, - self.all_mask_files_json, - self.data_root, + self.anatomy_list, self.spacing, self.output_size, False, self.all_mask_files_json, self.data_root ) if len(candidates) < num_img: @@ -1001,12 +978,14 @@ def find_closest_masks(self, num_img): diff = 0 include_c = True for axis in range(3): - if abs(c["dim"][axis]) < self.output_size[axis]-64: + if abs(c["dim"][axis]) < self.output_size[axis] - 64: # we cannot upsample the mask too much include_c = False break # check diff in FOV, major metric - diff += abs((abs(c["dim"][axis]*c["spacing"][axis]) - self.output_size[axis]*self.spacing[axis]) / 10) + diff += abs( + (abs(c["dim"][axis] * c["spacing"][axis]) - self.output_size[axis] * self.spacing[axis]) / 10 + ) # check diff in dim diff += abs((abs(c["dim"][axis]) - self.output_size[axis]) / 100) # check diff in spacing From f8dabcaafc581c9fc0b41f2143c69800b908a63d Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Mon, 17 Mar 2025 23:40:57 +0000 Subject: [PATCH 44/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 9f46d6d5..3264f1df 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -688,7 +688,6 @@ def sample_multiple_images(self, num_img): torch.cuda.empty_cache() # generate image/label pairs to_generate = True - try_time = 0 modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int # start generation synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor) From d20a1e2b02a6653770713f93bc019167774fc68a Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Mar 2025 04:14:44 +0000 Subject: [PATCH 45/50] reformat Signed-off-by: Can-Zhao --- models/maisi_ct_generative/scripts/sample.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py index 3264f1df..d22d26e2 100644 --- a/models/maisi_ct_generative/scripts/sample.py +++ b/models/maisi_ct_generative/scripts/sample.py @@ -687,7 +687,6 @@ def sample_multiple_images(self, num_img): logging.info(f"---- Mask preparation time: {end_time - start_time} seconds ----") torch.cuda.empty_cache() # generate image/label pairs - to_generate = True modality_tensor = torch.ones_like(spacing_tensor[:, 0]).long() * self.modality_int # start generation synthetic_images, synthetic_labels = self.sample_one_pair(combine_label_or, modality_tensor, spacing_tensor) @@ -729,7 +728,6 @@ def sample_multiple_images(self, num_img): self.output_dir, "sample_" + output_postfix + "_label" + self.label_output_ext ) output_filenames.append([synthetic_images_filename, synthetic_labels_filename]) - to_generate = False else: logging.info("Generated image/label pair did not pass quality check, will re-generate another pair.") return output_filenames From dcd7eb83b1bfc2fb8f3ffe01b352dd167a08907b Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Mar 2025 06:02:54 +0000 Subject: [PATCH 46/50] model link Signed-off-by: Can-Zhao --- models/maisi_ct_generative/large_files.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml index 97b26885..dacb00ed 100644 --- a/models/maisi_ct_generative/large_files.yml +++ b/models/maisi_ct_generative/large_files.yml @@ -3,13 +3,13 @@ large_files: url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt" hash_val: "917cfb1e49631c8a713e3bb7c758fbca" hash_type: "md5" -- path: "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt" - url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt" - hash_val: "623bd02ff223b70d280cc994fcb70a69" +- path: "models/diff_unet_ckpt.pt" + url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/diff_unet_ckpt_rflow_epoch19350.pt" + hash_val: "10501d59a3066802087c82ebd7a71719" hash_type: "md5" -- path: "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt" - url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt" - hash_val: "6c36572335372f405a0e85c760fa6dee" +- path: "models/controlnet.pt" + url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/controlnet_rflow_epoch208.pt" + hash_val: "49933da32826c0f7ca17016ccd13e23b" hash_type: "md5" - path: "models/mask_generation_autoencoder.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/mask_generation_autoencoder.pt" From f8c5740af76002de31d3680c9087abc54f6b2e22 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Mar 2025 06:04:18 +0000 Subject: [PATCH 47/50] model link Signed-off-by: Can-Zhao --- models/maisi_ct_generative/large_files.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml index dacb00ed..45abac2b 100644 --- a/models/maisi_ct_generative/large_files.yml +++ b/models/maisi_ct_generative/large_files.yml @@ -1,9 +1,9 @@ large_files: -- path: "models/autoencoder_epoch273.pt" +- path: "models/autoencoder.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt" hash_val: "917cfb1e49631c8a713e3bb7c758fbca" hash_type: "md5" -- path: "models/diff_unet_ckpt.pt" +- path: "models/diff_unet.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/diff_unet_ckpt_rflow_epoch19350.pt" hash_val: "10501d59a3066802087c82ebd7a71719" hash_type: "md5" From 347ad7111c8ed206e37c39d04b4dfad91d8a8460 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Tue, 18 Mar 2025 18:44:25 +0000 Subject: [PATCH 48/50] add back old ckpt Signed-off-by: Can-Zhao --- models/maisi_ct_generative/large_files.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml index 45abac2b..4dcc9631 100644 --- a/models/maisi_ct_generative/large_files.yml +++ b/models/maisi_ct_generative/large_files.yml @@ -3,6 +3,14 @@ large_files: url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt" hash_val: "917cfb1e49631c8a713e3bb7c758fbca" hash_type: "md5" +- path: "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt" + url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1_alternative.pt" + hash_val: "623bd02ff223b70d280cc994fcb70a69" + hash_type: "md5" +- path: "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt" + url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt" + hash_val: "6c36572335372f405a0e85c760fa6dee" + hash_type: "md5" - path: "models/diff_unet.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/diff_unet_ckpt_rflow_epoch19350.pt" hash_val: "10501d59a3066802087c82ebd7a71719" From 359d21246aa2b16b5347f69cc117a2d10bcb53dd Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 19 Mar 2025 03:41:29 +0000 Subject: [PATCH 49/50] model name Signed-off-by: Can-Zhao --- models/maisi_ct_generative/large_files.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml index 4dcc9631..17400eab 100644 --- a/models/maisi_ct_generative/large_files.yml +++ b/models/maisi_ct_generative/large_files.yml @@ -11,7 +11,7 @@ large_files: url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt" hash_val: "6c36572335372f405a0e85c760fa6dee" hash_type: "md5" -- path: "models/diff_unet.pt" +- path: "models/diffusion_unet.pt" url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/diff_unet_ckpt_rflow_epoch19350.pt" hash_val: "10501d59a3066802087c82ebd7a71719" hash_type: "md5" From 000d5e64257d9264422888e21f42982642f87682 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 19 Mar 2025 05:46:48 +0000 Subject: [PATCH 50/50] update test Signed-off-by: Can-Zhao --- ci/unit_tests/test_maisi_ct_generative.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/ci/unit_tests/test_maisi_ct_generative.py b/ci/unit_tests/test_maisi_ct_generative.py index e2b9db54..94389480 100644 --- a/ci/unit_tests/test_maisi_ct_generative.py +++ b/ci/unit_tests/test_maisi_ct_generative.py @@ -85,17 +85,6 @@ } ] -TEST_CASE_INFER_ERROR = [ - { - "bundle_root": "models/maisi_ct_generative", - "num_output_samples": 1, - "output_size": [256, 256, 256], - "body_region": ["head"], - "anatomy_list": ["colon cancer primaries"], - }, - "Cannot find body region with given anatomy list.", -] - TEST_CASE_INFER_ERROR_2 = [ { "bundle_root": "models/maisi_ct_generative", @@ -277,7 +266,7 @@ def test_infer_config(self, override): else: self.assertTrue(output_file.endswith(".nii.gz")) - @parameterized.expand([TEST_CASE_INFER_ERROR, TEST_CASE_INFER_ERROR_7]) + @parameterized.expand([TEST_CASE_INFER_ERROR_7]) def test_infer_config_error_input(self, override, expected_error): # update override override["output_dir"] = self.output_dir