From c5e65c3235655ca1da0c26691106597e1425bd5b Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 26 Mar 2024 10:12:15 +0000 Subject: [PATCH] Allow a quantized flag to be passed to the LatentDiffusionInferer methods (__call__ and get_likelihood); the flag is forwarded to encode_stage_2_inputs only when autoencoder_model is a VQVAE. The tests set this flag randomly (when the autoencoder is not a VQVAE, the flag is ignored). Ran the tests and the reformatting. controlnet.py has been changed for reformatting purposes only. --- generative/inferers/inferer.py | 36 ++++++++++++++++++++++---- generative/networks/nets/controlnet.py | 23 +++++++++------- tests/test_latent_diffusion_inferer.py | 9 +++++++ 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 48eb7f6d..2426fe41 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -23,7 +23,7 @@ from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import optional_import -from generative.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet +from generative.networks.nets import VQVAE, SPADEAutoencoderKL, SPADEDiffusionModelUNet tqdm, has_tqdm = optional_import("tqdm", name="tqdm") @@ -362,6 +362,7 @@ def __call__( condition: torch.Tensor | None = None, mode: str = "crossattn", seg: torch.Tensor | None = None, + quantized: bool = True, ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -375,9 +376,14 @@ def __call__( condition: conditioning for network input. mode: Conditioning mode for the network. seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM + are quantized or not. 
""" with torch.no_grad(): - latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + autoencode = autoencoder_model.encode_stage_2_inputs + if isinstance(autoencoder_model, VQVAE): + autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized) + latent = autoencode(inputs) * self.scale_factor if self.ldm_latent_shape is not None: latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) @@ -496,6 +502,7 @@ def get_likelihood( resample_latent_likelihoods: bool = False, resample_interpolation_mode: str = "nearest", seg: torch.Tensor | None = None, + quantized: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. @@ -517,12 +524,18 @@ def get_likelihood( or 'trilinear; seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of SPADEAutoencoderKL, segmentation must be provided. + quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM + are quantized or not. 
""" if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): raise ValueError( f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" ) - latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + + autoencode = autoencoder_model.encode_stage_2_inputs + if isinstance(autoencoder_model, VQVAE): + autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized) + latents = autoencode(inputs) * self.scale_factor if self.ldm_latent_shape is not None: latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0) @@ -882,6 +895,7 @@ def __call__( condition: torch.Tensor | None = None, mode: str = "crossattn", seg: torch.Tensor | None = None, + quantized: bool = True, ) -> torch.Tensor: """ Implements the forward pass for a supervised training iteration. @@ -897,9 +911,14 @@ def __call__( condition: conditioning for network input. mode: Conditioning mode for the network. seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided. + quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM + are quantized or not. 
""" with torch.no_grad(): - latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + autoencode = autoencoder_model.encode_stage_2_inputs + if isinstance(autoencoder_model, VQVAE): + autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized) + latent = autoencode(inputs) * self.scale_factor if self.ldm_latent_shape is not None: latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0) @@ -1036,6 +1055,7 @@ def get_likelihood( resample_latent_likelihoods: bool = False, resample_interpolation_mode: str = "nearest", seg: torch.Tensor | None = None, + quantized: bool = True, ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Computes the log-likelihoods of the latent representations of the input. @@ -1059,13 +1079,19 @@ def get_likelihood( or 'trilinear; seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of SPADEAutoencoderKL, segmentation must be provided. + quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM + are quantized or not. 
""" if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"): raise ValueError( f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}" ) - latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor + with torch.no_grad(): + autoencode = autoencoder_model.encode_stage_2_inputs + if isinstance(autoencoder_model, VQVAE): + autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized) + latents = autoencode(inputs) * self.scale_factor if cn_cond.shape[2:] != latents.shape[2:]: cn_cond = F.interpolate(cn_cond, latents.shape[2:]) diff --git a/generative/networks/nets/controlnet.py b/generative/networks/nets/controlnet.py index caedf736..e6d736b0 100644 --- a/generative/networks/nets/controlnet.py +++ b/generative/networks/nets/controlnet.py @@ -41,6 +41,7 @@ from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding + class ControlNetConditioningEmbedding(nn.Module): """ Network to encode the conditioning into a latent space. @@ -120,10 +121,9 @@ def zero_module(module): nn.init.zeros_(p) return module -def copy_weights_to_controlnet(controlnet : nn.Module, - diffusion_model: nn.Module, - verbose: bool = True) -> None: - ''' + +def copy_weights_to_controlnet(controlnet: nn.Module, diffusion_model: nn.Module, verbose: bool = True) -> None: + """ Copy the state dict from the input diffusion model to the ControlNet, printing, if user requires it, the output keys that have matched and those that haven't. @@ -131,15 +131,18 @@ def copy_weights_to_controlnet(controlnet : nn.Module, controlnet: instance of ControlNet diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet verbose: if True, the matched and unmatched keys will be printed. 
- ''' + """ - output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False) + output = controlnet.load_state_dict(diffusion_model.state_dict(), strict=False) if verbose: dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys] - print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:" - f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:" - f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:" - f"\n{'; '.join(output.unexpected_keys)}") + print( + f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:" + f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:" + f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:" + f"\n{'; '.join(output.unexpected_keys)}" + ) + class ControlNet(nn.Module): """ diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 3b5e8833..adcd481e 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -13,6 +13,7 @@ import unittest +import numpy as np import torch from parameterized import parameterized @@ -329,6 +330,7 @@ def test_prediction_shape( seg=input_seg, noise=noise, timesteps=timesteps, + quantized=np.random.choice([True, False]), ) else: prediction = inferer( @@ -472,6 +474,7 @@ def test_get_likelihoods( scheduler=scheduler, save_intermediates=True, seg=input_seg, + quantized=np.random.choice([True, False]), ) else: sample, intermediates = inferer.get_likelihood( @@ -480,6 +483,7 @@ def test_get_likelihoods( diffusion_model=stage_2, scheduler=scheduler, save_intermediates=True, + quantized=np.random.choice([True, False]), ) self.assertEqual(len(intermediates), 10) self.assertEqual(intermediates[0].shape, latent_shape) @@ -525,6 +529,7 @@ def 
test_resample_likelihoods( save_intermediates=True, resample_latent_likelihoods=True, seg=input_seg, + quantized=np.random.choice([True, False]), ) else: sample, intermediates = inferer.get_likelihood( @@ -534,6 +539,7 @@ def test_resample_likelihoods( scheduler=scheduler, save_intermediates=True, resample_latent_likelihoods=True, + quantized=np.random.choice([True, False]), ) self.assertEqual(len(intermediates), 10) self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) @@ -590,6 +596,7 @@ def test_prediction_shape_conditioned_concat( condition=conditioning, mode="concat", seg=input_seg, + quantized=np.random.choice([True, False]), ) else: prediction = inferer( @@ -600,6 +607,7 @@ def test_prediction_shape_conditioned_concat( timesteps=timesteps, condition=conditioning, mode="concat", + quantized=np.random.choice([True, False]), ) self.assertEqual(prediction.shape, latent_shape) @@ -713,6 +721,7 @@ def test_sample_shape_different_latents( noise=noise, timesteps=timesteps, seg=input_seg, + quantized=np.random.choice([True, False]), ) else: prediction = inferer(