diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index ee94b1ebdb..f023db490e 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import inspect import math import warnings from abc import ABC, abstractmethod @@ -861,6 +862,96 @@ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] self.scheduler = scheduler + @staticmethod + def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool: + try: + return kwarg in inspect.signature(scheduler.step).parameters + except (TypeError, ValueError): + return False + + @staticmethod + def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor: + if isinstance(step_output, tuple): + return step_output[0] + if isinstance(step_output, Mapping): + return step_output["prev_sample"] + if hasattr(step_output, "prev_sample"): + return step_output.prev_sample + raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.") + + @staticmethod + def _get_scheduler_name(scheduler: Scheduler) -> str: + if hasattr(scheduler, "_get_name"): + return scheduler._get_name() + return scheduler.__class__.__name__ + + @staticmethod + def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any: + config = getattr(scheduler, "config", None) + if isinstance(config, Mapping): + if name in config: + return config[name] + elif config is not None and hasattr(config, name): + return getattr(config, name) + + if hasattr(scheduler, name): + return getattr(scheduler, name) + return default + + @staticmethod + def _get_posterior_mean( + scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor + ) -> torch.Tensor: + alpha_t = scheduler.alphas[timestep] + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * scheduler.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + return x_0_coefficient * x_0 + x_t_coefficient * x_t + + def _get_posterior_variance( + self, scheduler: Scheduler, timestep: int | torch.Tensor, predicted_variance: torch.Tensor | None = None + ) -> torch.Tensor: + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * scheduler.betas[timestep] + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + + if variance_type == "fixed_small": + variance = torch.clamp(variance, min=1e-20) + elif variance_type == "fixed_large": + variance = scheduler.betas[timestep] + elif variance_type == "learned" and predicted_variance is not None: + return predicted_variance + elif variance_type == "learned_range" and predicted_variance is not None: + min_log = variance + max_log = scheduler.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def _scheduler_step( + self, + scheduler: Scheduler, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + next_timestep: int | torch.Tensor | None = None, + ) -> torch.Tensor: + step_kwargs = {} + if self._scheduler_step_supports_kwarg(scheduler, "return_dict"): + step_kwargs["return_dict"] = False + + if isinstance(scheduler, RFlowScheduler): + step_output = scheduler.step(model_output, timestep, sample, next_timestep, **step_kwargs) # type: ignore + else: + step_output = scheduler.step(model_output, timestep, sample, **step_kwargs) # type: ignore + + return self._get_previous_sample_from_step_output(step_output) + def __call__( # type: ignore[override] self, inputs: torch.Tensor, @@ -940,7 +1031,12 @@ def sample( scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -984,10 +1080,9 @@ def sample( model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 2. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1028,10 +1123,10 @@ def get_likelihood( if not scheduler: scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": + scheduler_name = self._get_scheduler_name(scheduler) + if scheduler_name != "DDPMScheduler": raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" + f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}" ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1046,7 +1141,7 @@ def get_likelihood( total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffusion_model = ( partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) @@ -1059,7 +1154,8 @@ def get_likelihood( model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: predicted_variance = None @@ -1072,15 +1168,17 @@ def get_likelihood( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": + prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type") + if prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": + elif prediction_type == "sample": pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": + elif prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + if self._get_scheduler_config_value(scheduler, "clip_sample"): + clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0) + pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -1092,11 +1190,15 @@ def get_likelihood( predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] + posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = self._get_posterior_variance( + scheduler=scheduler, timestep=t, predicted_variance=predicted_variance + ) log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + log_predicted_variance = ( + torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance + ) if t == 0: # compute -log p(x_0|x_1) @@ -1509,7 +1611,12 @@ def sample( # type: ignore[override] scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -1583,10 +1690,9 @@ def sample( # type: ignore[override] model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 3. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1631,10 +1737,10 @@ def get_likelihood( # type: ignore[override] if not scheduler: scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": + scheduler_name = self._get_scheduler_name(scheduler) + if scheduler_name != "DDPMScheduler": raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" + f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}" ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1647,7 +1753,7 @@ def get_likelihood( # type: ignore[override] total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): @@ -1680,7 +1786,8 @@ def get_likelihood( # type: ignore[override] mid_block_additional_residual=mid_block_res_sample, ) # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: predicted_variance = None @@ -1693,15 +1800,17 @@ def get_likelihood( # type: ignore[override] # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": + prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type") + if prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": + elif prediction_type == "sample": pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": + elif prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + if self._get_scheduler_config_value(scheduler, "clip_sample"): + clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0) + pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -1713,11 +1822,15 @@ def get_likelihood( # type: ignore[override] predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] + posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = self._get_posterior_variance( + scheduler=scheduler, timestep=t, predicted_variance=predicted_variance + ) log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + log_predicted_variance = ( + torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance + ) if t == 0: # compute -log p(x_0|x_1) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index c7ac4b77e6..f9be2e5b45 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -518,6 +518,7 @@ ApplyTransformToPoints, AsChannelLast, CastToType, + ChannelWise, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, CuCIM, @@ -536,6 +537,7 @@ RandIdentity, RandImageFilter, RandLambda, + RandChannelWise, RandTorchIO, RandTorchVision, RemoveRepeatedChannel, @@ -568,6 +570,9 @@ CastToTyped, CastToTypeD, CastToTypeDict, + ChannelWised, + ChannelWiseD, + ChannelWiseDict, ClassesToIndicesd, ClassesToIndicesD, ClassesToIndicesDict, @@ -631,6 +636,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandChannelWised, + RandChannelWiseD, + RandChannelWiseDict, RandTorchIOd, RandTorchIOD, RandTorchIODict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ed4b149e6b..51ac15bbc4 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -81,6 +81,8 @@ "EnsureType", "RepeatChannel", "RemoveRepeatedChannel", + "ChannelWise", + "RandChannelWise", "SplitDim", "CastToType", "ToTensor", @@ -288,6 +290,82 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return out +class ChannelWise(Transform): + """ + Apply a given transform to each channel of the input array independently and + concatenate the results back along the channel dimension. + + Args: + transform: a callable transform to apply to each channel. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, transform: Callable) -> None: + self.transform = transform + + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if img.shape[0] == 0: + return img + + results = [] + for i in range(img.shape[0]): + res = self.transform(img[[i], ...]) + results.append(res) + + if isinstance(img, torch.Tensor): + return torch.cat(results, dim=0) + return np.concatenate(results, axis=0) + + +class RandChannelWise(RandomizableTransform): + """ + Randomizable version of :py:class:`monai.transforms.ChannelWise`, the input + `transform` will be applied independently to each channel. + + Args: + transform: a callable transform to apply to each channel. + prob: probability of applying the transform to the entire image. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, transform: Callable, prob: float = 1.0) -> None: + RandomizableTransform.__init__(self, prob) + self.transform = transform + + def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandChannelWise: + super().set_random_state(seed, state) + if hasattr(self.transform, "set_random_state"): + self.transform.set_random_state(seed, state) + return self + + def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + """ + Apply the transform to `img`. + """ + if randomize: + self.randomize(None) + if not self._do_transform: + return img + + if img.shape[0] == 0: + return img + + results = [] + for i in range(img.shape[0]): + res = self.transform(img[[i], ...]) + results.append(res) + + if isinstance(img, torch.Tensor): + return torch.cat(results, dim=0) + return np.concatenate(results, axis=0) + + + class SplitDim(Transform, MultiSampleTrait): """ Given an image of size X along a certain dimension, return a list of length X containing diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7dd24a3880..e93468a0ba 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -38,6 +38,7 @@ ApplyTransformToPoints, AsChannelLast, CastToType, + ChannelWise, ClassesToIndices, ConvertToMultiChannelBasedOnBratsClasses, CuCIM, @@ -52,6 +53,7 @@ LabelToMask, Lambda, MapLabelValue, + RandChannelWise, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -88,6 +90,9 @@ "ConcatItemsD", "ConcatItemsDict", "ConcatItemsd", + "ChannelWiseD", + "ChannelWiseDict", + "ChannelWised", "ConvertToMultiChannelBasedOnBratsClassesD", "ConvertToMultiChannelBasedOnBratsClassesDict", "ConvertToMultiChannelBasedOnBratsClassesd", @@ -131,6 +136,9 @@ "FlattenSubKeysd", "FlattenSubKeysD", "FlattenSubKeysDict", + "RandChannelWiseD", + "RandChannelWiseDict", + "RandChannelWised", "RandCuCIMd", "RandCuCIMD", "RandCuCIMDict", @@ -338,6 +346,70 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class ChannelWised(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.ChannelWise`. + """ + + backend = ChannelWise.backend + + def __init__(self, keys: KeysCollection, transform: Callable, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + transform: a callable transform to apply to each channel. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.converter = ChannelWise(transform=transform) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + +class RandChannelWised(MapTransform, RandomizableTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.RandChannelWise`. + """ + + backend = RandChannelWise.backend + + def __init__(self, keys: KeysCollection, transform: Callable, prob: float = 1.0, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + transform: a callable transform to apply to each channel. + prob: probability of applying the transform to the entire image. + allow_missing_keys: don't raise exception if key is missing. + """ + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.converter = RandChannelWise(transform=transform, prob=1.0) + + def set_random_state( + self, seed: int | None = None, state: np.random.RandomState | None = None + ) -> RandChannelWised: + super().set_random_state(seed, state) + if hasattr(self.converter, "set_random_state"): + self.converter.set_random_state(seed, state) + return self + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + if not self._do_transform: + return d + + for key in self.key_iterator(d): + d[key] = self.converter(d[key], randomize=False) + return d + + class SplitDimd(MapTransform, MultiSampleTrait): backend = SplitDim.backend @@ -2032,6 +2104,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N AsChannelLastD = AsChannelLastDict = AsChannelLastd EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld +ChannelWiseD = ChannelWiseDict = ChannelWised +RandChannelWiseD = RandChannelWiseDict = RandChannelWised RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitDimD = SplitDimDict = SplitDimd CastToTypeD = CastToTypeDict = CastToTyped diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 81874ed3a8..9e1a3072dd 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -24,6 +24,7 @@ _, has_scipy = optional_import("scipy") _, has_einops = optional_import("einops") +DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler") TEST_CASES = [ [ @@ -126,6 +127,63 @@ def test_ddpm_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_call(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=[32, 64], + attention_levels=[False, True], + num_res_blocks=1, + num_head_channels=32, + ) + model.to(device) + model.eval() + scheduler = DiffusersDDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon") + scheduler.set_timesteps(num_inference_steps=50) + inferer = DiffusionInferer(scheduler=scheduler) + + batch_size = 2 + image_size = 32 + inputs = torch.randn(batch_size, 1, image_size, image_size).to(device) + noise = torch.randn_like(inputs) + timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,)).long().to(device) + with torch.no_grad(): + prediction = inferer(inputs=inputs, diffusion_model=model, noise=noise, timesteps=timesteps) + + self.assertEqual(prediction.shape, inputs.shape) + scheduler.set_timesteps(num_inference_steps=2) + sample = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler, verbose=False) + self.assertEqual(sample.shape, inputs.shape) + + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_get_likelihood(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=[8], + norm_num_groups=8, + attention_levels=[True], + num_res_blocks=1, + num_head_channels=8, + ) + model.to(device) + model.eval() + inputs = torch.randn(2, 1, 8, 8).to(device) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon") + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=inputs, diffusion_model=model, scheduler=scheduler, save_intermediates=True + ) + self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, inputs.shape) + self.assertEqual(likelihood.shape[0], inputs.shape[0]) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, input_shape): diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index ab80363cde..23dd594d8e 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -23,6 +23,7 @@ from monai.utils import optional_import _, has_einops = optional_import("einops") +DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler") TEST_CASES = [ [ "AutoencoderKL", @@ -414,6 +415,46 @@ def test_sample_shape( ) self.assertEqual(sample.shape, input_shape) + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_sample_shape(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1 = AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = DiffusionModelUNet( + spatial_dims=2, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(1, 3, 4, 4).to(device) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon") + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, (1, 1, 8, 8)) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sample_shape_with_cfg( diff --git a/tests/test_channel_wise.py b/tests/test_channel_wise.py new file mode 100644 index 0000000000..5940eb87d6 --- /dev/null +++ b/tests/test_channel_wise.py @@ -0,0 +1,49 @@ +import unittest + +import numpy as np + +from monai.transforms import ChannelWise, RandChannelWise, RandGaussianNoise, ScaleIntensity +from monai.utils import set_determinism + + +class TestChannelWise(unittest.TestCase): + def test_channel_wise_deterministic(self): + # Test applying a deterministic transform channel-wise + data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]) # shape (2, 2, 2) + + # ScaleIntensity applies to the whole input array independently + transform = ChannelWise(transform=ScaleIntensity()) + out = transform(data) + + # Channel 0 scaled + np.testing.assert_allclose(out[0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + # Channel 1 scaled + np.testing.assert_allclose(out[1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + self.assertEqual(out.shape, data.shape) + + def test_rand_channel_wise(self): + # Test applying a randomized transform channel-wise + data = np.zeros((3, 4, 4)) + + set_determinism(seed=0) + # Apply random noise with high standard deviation to see the difference + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + # All channels should have different noise values + self.assertFalse(np.allclose(out[0], out[1])) + self.assertFalse(np.allclose(out[1], out[2])) + self.assertFalse(np.allclose(out[0], out[2])) + + # Output shape should be exactly the same + self.assertEqual(out.shape, data.shape) + + def test_prob_zero(self): + # Test when RandChannelWise prob is 0.0 + data = np.zeros((2, 2, 2)) + transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) + out = transform(data) + np.testing.assert_allclose(out, data) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_channel_wised.py b/tests/test_channel_wised.py new file mode 100644 index 0000000000..e2242be7f3 --- /dev/null +++ b/tests/test_channel_wised.py @@ -0,0 +1,49 @@ +import unittest + +import numpy as np + +from monai.transforms import ChannelWised, RandChannelWised, RandGaussianNoise, ScaleIntensity +from monai.utils import set_determinism + + +class TestChannelWised(unittest.TestCase): + def test_channel_wise_deterministic(self): + # Test applying a deterministic transform channel-wise + data = {"image": np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])} # shape (2, 2, 2) + + # ScaleIntensity applies to the whole input array independently + transform = ChannelWised(keys=["image"], transform=ScaleIntensity()) + out = transform(data) + + # Channel 0 scaled + np.testing.assert_allclose(out["image"][0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + # Channel 1 scaled + np.testing.assert_allclose(out["image"][1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5) + self.assertEqual(out["image"].shape, data["image"].shape) + + def test_rand_channel_wise(self): + # Test applying a randomized transform channel-wise + data = {"image": np.zeros((3, 4, 4))} + + set_determinism(seed=0) + # Apply random noise with high standard deviation to see the difference + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0)) + out = transform(data) + + # All channels should have different noise values + self.assertFalse(np.allclose(out["image"][0], out["image"][1])) + self.assertFalse(np.allclose(out["image"][1], out["image"][2])) + self.assertFalse(np.allclose(out["image"][0], out["image"][2])) + + # Output shape should be exactly the same + self.assertEqual(out["image"].shape, data["image"].shape) + + def test_prob_zero(self): + # Test when RandChannelWised prob is 0.0 + data = {"image": np.zeros((2, 2, 2))} + transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0) + out = transform(data) + np.testing.assert_allclose(out["image"], data["image"]) + +if __name__ == "__main__": + unittest.main()