Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.transforms import CenterSpatialCrop, SpatialPad
from monai.utils import optional_import

from generative.networks.nets import SPADEAutoencoderKL, SPADEDiffusionModelUNet
from generative.networks.nets import VQVAE, SPADEAutoencoderKL, SPADEDiffusionModelUNet

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")

Expand Down Expand Up @@ -362,6 +362,7 @@ def __call__(
condition: torch.Tensor | None = None,
mode: str = "crossattn",
seg: torch.Tensor | None = None,
quantized: bool = True,
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.
Expand All @@ -375,9 +376,14 @@ def __call__(
condition: conditioning for network input.
mode: Conditioning mode for the network.
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
are quantized or not.
"""
with torch.no_grad():
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
autoencode = autoencoder_model.encode_stage_2_inputs
if isinstance(autoencoder_model, VQVAE):
autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
latent = autoencode(inputs) * self.scale_factor

if self.ldm_latent_shape is not None:
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
Expand Down Expand Up @@ -496,6 +502,7 @@ def get_likelihood(
resample_latent_likelihoods: bool = False,
resample_interpolation_mode: str = "nearest",
seg: torch.Tensor | None = None,
quantized: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
Computes the log-likelihoods of the latent representations of the input.
Expand All @@ -517,12 +524,18 @@ def get_likelihood(
or 'trilinear';
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
is instance of SPADEAutoencoderKL, segmentation must be provided.
quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
are quantized or not.
"""
if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
raise ValueError(
f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
)
latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

autoencode = autoencoder_model.encode_stage_2_inputs
if isinstance(autoencoder_model, VQVAE):
autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
latents = autoencode(inputs) * self.scale_factor

if self.ldm_latent_shape is not None:
latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
Expand Down Expand Up @@ -882,6 +895,7 @@ def __call__(
condition: torch.Tensor | None = None,
mode: str = "crossattn",
seg: torch.Tensor | None = None,
quantized: bool = True,
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.
Expand All @@ -897,9 +911,14 @@ def __call__(
condition: conditioning for network input.
mode: Conditioning mode for the network.
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
are quantized or not.
"""
with torch.no_grad():
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
autoencode = autoencoder_model.encode_stage_2_inputs
if isinstance(autoencoder_model, VQVAE):
autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
latent = autoencode(inputs) * self.scale_factor

if self.ldm_latent_shape is not None:
latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
Expand Down Expand Up @@ -1036,6 +1055,7 @@ def get_likelihood(
resample_latent_likelihoods: bool = False,
resample_interpolation_mode: str = "nearest",
seg: torch.Tensor | None = None,
quantized: bool = True,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
Computes the log-likelihoods of the latent representations of the input.
Expand All @@ -1059,13 +1079,19 @@ def get_likelihood(
or 'trilinear';
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
is instance of SPADEAutoencoderKL, segmentation must be provided.
quantized: if autoencoder_model is a VQVAE, quantized controls whether the latents to the LDM
are quantized or not.
"""
if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
raise ValueError(
f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
)

latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
with torch.no_grad():
autoencode = autoencoder_model.encode_stage_2_inputs
if isinstance(autoencoder_model, VQVAE):
autoencode = partial(autoencoder_model.encode_stage_2_inputs, quantized=quantized)
latents = autoencode(inputs) * self.scale_factor

if cn_cond.shape[2:] != latents.shape[2:]:
cn_cond = F.interpolate(cn_cond, latents.shape[2:])
Expand Down
23 changes: 13 additions & 10 deletions generative/networks/nets/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding


class ControlNetConditioningEmbedding(nn.Module):
"""
Network to encode the conditioning into a latent space.
Expand Down Expand Up @@ -120,26 +121,28 @@ def zero_module(module):
nn.init.zeros_(p)
return module

def copy_weights_to_controlnet(controlnet : nn.Module,
diffusion_model: nn.Module,
verbose: bool = True) -> None:
'''

def copy_weights_to_controlnet(controlnet: nn.Module, diffusion_model: nn.Module, verbose: bool = True) -> None:
"""
Copy the state dict from the input diffusion model to the ControlNet and, if verbose is True, print the
keys that matched and those that did not.

Args:
controlnet: instance of ControlNet
diffusion_model: instance of DiffusionModelUnet or SPADEDiffusionModelUnet
verbose: if True, the matched and unmatched keys will be printed.
'''
"""

output = controlnet.load_state_dict(diffusion_model.state_dict(), strict = False)
output = controlnet.load_state_dict(diffusion_model.state_dict(), strict=False)
if verbose:
dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys]
print(f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:"
f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:"
f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:"
f"\n{'; '.join(output.unexpected_keys)}")
print(
f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:"
f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:"
f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:"
f"\n{'; '.join(output.unexpected_keys)}"
)


class ControlNet(nn.Module):
"""
Expand Down
9 changes: 9 additions & 0 deletions tests/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import unittest

import numpy as np
import torch
from parameterized import parameterized

Expand Down Expand Up @@ -329,6 +330,7 @@ def test_prediction_shape(
seg=input_seg,
noise=noise,
timesteps=timesteps,
quantized=np.random.choice([True, False]),
)
else:
prediction = inferer(
Expand Down Expand Up @@ -472,6 +474,7 @@ def test_get_likelihoods(
scheduler=scheduler,
save_intermediates=True,
seg=input_seg,
quantized=np.random.choice([True, False]),
)
else:
sample, intermediates = inferer.get_likelihood(
Expand All @@ -480,6 +483,7 @@ def test_get_likelihoods(
diffusion_model=stage_2,
scheduler=scheduler,
save_intermediates=True,
quantized=np.random.choice([True, False]),
)
self.assertEqual(len(intermediates), 10)
self.assertEqual(intermediates[0].shape, latent_shape)
Expand Down Expand Up @@ -525,6 +529,7 @@ def test_resample_likelihoods(
save_intermediates=True,
resample_latent_likelihoods=True,
seg=input_seg,
quantized=np.random.choice([True, False]),
)
else:
sample, intermediates = inferer.get_likelihood(
Expand All @@ -534,6 +539,7 @@ def test_resample_likelihoods(
scheduler=scheduler,
save_intermediates=True,
resample_latent_likelihoods=True,
quantized=np.random.choice([True, False]),
)
self.assertEqual(len(intermediates), 10)
self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
Expand Down Expand Up @@ -590,6 +596,7 @@ def test_prediction_shape_conditioned_concat(
condition=conditioning,
mode="concat",
seg=input_seg,
quantized=np.random.choice([True, False]),
)
else:
prediction = inferer(
Expand All @@ -600,6 +607,7 @@ def test_prediction_shape_conditioned_concat(
timesteps=timesteps,
condition=conditioning,
mode="concat",
quantized=np.random.choice([True, False]),
)
self.assertEqual(prediction.shape, latent_shape)

Expand Down Expand Up @@ -713,6 +721,7 @@ def test_sample_shape_different_latents(
noise=noise,
timesteps=timesteps,
seg=input_seg,
quantized=np.random.choice([True, False]),
)
else:
prediction = inferer(
Expand Down