Merge pull request #87 from Project-MONAI/66-add-latent-diffusion-inferer

Add latent diffusion inferer
SANCHES-Pedro authored Nov 30, 2022
2 parents 6da5d43 + 96fa095 commit 5f0fd39
Showing 5 changed files with 279 additions and 2 deletions.
2 changes: 1 addition & 1 deletion generative/inferers/__init__.py
@@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .inferer import DiffusionInferer
from .inferer import DiffusionInferer, LatentDiffusionInferer
100 changes: 99 additions & 1 deletion generative/inferers/inferer.py
@@ -10,7 +10,7 @@
# limitations under the License.


from typing import Callable, Optional
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -100,3 +100,101 @@ def sample(
return image, intermediates
else:
return image


class LatentDiffusionInferer(DiffusionInferer):
"""
LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
    be used to perform a single forward pass for a training iteration, and sample from the model.
Args:
scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
        scale_factor: scale factor by which the latent representation is multiplied before it is processed
            by the diffusion model, and by which it is divided before being decoded by the first stage model.
"""

def __init__(self, scheduler: nn.Module, scale_factor: float = 1.0) -> None:
super().__init__(scheduler=scheduler)
self.scale_factor = scale_factor

def __call__(
self,
inputs: torch.Tensor,
autoencoder_model: Callable[..., torch.Tensor],
diffusion_model: Callable[..., torch.Tensor],
noise: torch.Tensor,
condition: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Implements the forward pass for a supervised training iteration.
Args:
            inputs: input image from which the latent representation is extracted and to which noise is added.
autoencoder_model: first stage model.
diffusion_model: diffusion model.
noise: random noise, of the same shape as the latent representation.
condition: conditioning for network input.
"""
with torch.no_grad():
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor

prediction = super().__call__(
inputs=latent,
diffusion_model=diffusion_model,
noise=noise,
condition=condition,
)

return prediction

def sample(
self,
input_noise: torch.Tensor,
autoencoder_model: Callable[..., torch.Tensor],
diffusion_model: Callable[..., torch.Tensor],
scheduler: Optional[Callable[..., torch.Tensor]] = None,
        save_intermediates: bool = False,
        intermediate_steps: int = 100,
        conditioning: Optional[torch.Tensor] = None,
        verbose: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
"""
Args:
input_noise: random noise, of the same shape as the desired latent representation.
autoencoder_model: first stage model.
diffusion_model: model to sample from.
            scheduler: diffusion scheduler. If none is provided, the class attribute scheduler will be used.
            save_intermediates: whether to return intermediates along the sampling process.
            intermediate_steps: if save_intermediates is True, saves every n steps.
            conditioning: conditioning for network input.
            verbose: if true, prints the progress bar of the sampling process.
"""
outputs = super().sample(
input_noise=input_noise,
diffusion_model=diffusion_model,
scheduler=scheduler,
save_intermediates=save_intermediates,
intermediate_steps=intermediate_steps,
conditioning=conditioning,
verbose=verbose,
)

if save_intermediates:
latent, latent_intermediates = outputs
else:
latent = outputs

with torch.no_grad():
            image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)

if save_intermediates:
intermediates = []
for latent_intermediate in latent_intermediates:
with torch.no_grad():
                    intermediates.append(
                        autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
                    )
            return image, intermediates
        else:
            return image
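
For orientation, here is a hypothetical end-to-end sketch of the new inferer (not part of this diff). It reuses the constructor arguments and tensor shapes from the test file added below; the hyperparameter values are illustrative only, and with the default scale_factor=1.0 the latent scaling is a no-op.

import torch

from generative.inferers import LatentDiffusionInferer
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet
from generative.schedulers import DDPMScheduler

# stage 1 model: compresses 1-channel 32x32 images into 3-channel 8x8 latents
autoencoder = AutoencoderKL(
    spatial_dims=2, in_channels=1, out_channels=1, num_channels=8, latent_channels=3,
    ch_mult=[1, 1, 1], attention_levels=[False, False, False], num_res_blocks=1,
    with_encoder_nonlocal_attn=False, with_decoder_nonlocal_attn=False, norm_num_groups=8,
)
# diffusion model: a UNet operating on the latent shape
unet = DiffusionModelUNet(
    spatial_dims=2, in_channels=3, out_channels=3, model_channels=8, norm_num_groups=8,
    attention_resolutions=[8], num_res_blocks=1, channel_mult=[1, 1, 1], num_heads=1,
)
scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

images = torch.randn(1, 1, 32, 32)  # input batch
noise = torch.randn(1, 3, 8, 8)  # noise with the latent shape

# training iteration: encode the images, add noise in latent space, predict the noise
prediction = inferer(inputs=images, autoencoder_model=autoencoder, diffusion_model=unet, noise=noise)

# sampling: denoise random latents, then decode them back to image space
scheduler.set_timesteps(num_inference_steps=10)
sample = inferer.sample(input_noise=noise, autoencoder_model=autoencoder, diffusion_model=unet)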
9 changes: 9 additions & 0 deletions generative/networks/nets/autoencoderkl.py
@@ -615,3 +615,12 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
z = self.sampling(z_mu, z_sigma)
reconstruction = self.decode(z)
return reconstruction, z_mu, z_sigma

def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
z_mu, z_sigma = self.encode(x)
z = self.sampling(z_mu, z_sigma)
return z

def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
image = self.decode(z)
return image
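
A note on these helpers, illustrated by a minimal hypothetical round-trip (constructor arguments copied from the test file below): encode_stage_2_inputs draws z through self.sampling, so it is stochastic and repeated calls on the same input yield different latents, while decode_stage_2_outputs is deterministic.

import torch

from generative.networks.nets import AutoencoderKL

ae = AutoencoderKL(
    spatial_dims=2, in_channels=1, out_channels=1, num_channels=8, latent_channels=3,
    ch_mult=[1, 1, 1], attention_levels=[False, False, False], num_res_blocks=1,
    with_encoder_nonlocal_attn=False, with_decoder_nonlocal_attn=False, norm_num_groups=8,
)
x = torch.randn(1, 1, 32, 32)
z = ae.encode_stage_2_inputs(x)  # stochastic: z is sampled from N(z_mu, z_sigma)
x_hat = ae.decode_stage_2_outputs(z)  # deterministic decode back to image space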
10 changes: 10 additions & 0 deletions generative/networks/nets/vqvae.py
@@ -359,3 +359,13 @@ def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
reconstruction = self.decode(quantizations)

return reconstruction, quantization_losses

def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
z = self.encoder(x)
e, _ = self.quantize(z)
return e

def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
e, _ = self.quantize(z)
image = self.decode(e)
return image
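
Unlike the KL variant, the VQVAE helpers pass through the quantizer: encode_stage_2_inputs returns quantized latents, and decode_stage_2_outputs re-quantizes its input before decoding, so continuous latents produced by the diffusion model are snapped to the nearest codebook entries. A minimal hypothetical sketch, with constructor arguments copied from the test file below:

import torch

from generative.networks.nets import VQVAE

vq = VQVAE(
    spatial_dims=2, in_channels=1, out_channels=1, num_levels=2,
    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),
    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
    num_res_layers=1, num_channels=[8, 8], num_res_channels=[8, 8],
    num_embeddings=16, embedding_dim=3,
)
x = torch.randn(1, 1, 32, 32)
e = vq.encode_stage_2_inputs(x)  # quantized latents (nearest codebook entries)
x_hat = vq.decode_stage_2_outputs(e)  # re-quantizes its input, then decodes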
160 changes: 160 additions & 0 deletions tests/test_latent_diffusion_inferer.py
@@ -0,0 +1,160 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from generative.inferers import LatentDiffusionInferer
from generative.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet
from generative.schedulers import DDPMScheduler

TEST_CASES = [
[
"AutoencoderKL",
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_channels": 8,
"latent_channels": 3,
"ch_mult": [1, 1, 1],
"attention_levels": [False, False, False],
"num_res_blocks": 1,
"with_encoder_nonlocal_attn": False,
"with_decoder_nonlocal_attn": False,
"norm_num_groups": 8,
},
{
"spatial_dims": 2,
"in_channels": 3,
"out_channels": 3,
"model_channels": 8,
"norm_num_groups": 8,
"attention_resolutions": [8],
"num_res_blocks": 1,
"channel_mult": [1, 1, 1],
"num_heads": 1,
},
(1, 1, 32, 32),
(1, 3, 8, 8),
],
[
"VQVAE",
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_levels": 2,
"downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
"upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
"num_res_layers": 1,
"num_channels": [8, 8],
"num_res_channels": [8, 8],
"num_embeddings": 16,
"embedding_dim": 3,
},
{
"spatial_dims": 2,
"in_channels": 3,
"out_channels": 3,
"model_channels": 8,
"norm_num_groups": 8,
"attention_resolutions": [8],
"num_res_blocks": 1,
"channel_mult": [1, 1, 1],
"num_heads": 1,
},
(1, 1, 32, 32),
(1, 3, 8, 8),
],
]


class TestDiffusionSamplingInferer(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
if model_type == "AutoencoderKL":
autoencoder_model = AutoencoderKL(**autoencoder_params)
if model_type == "VQVAE":
autoencoder_model = VQVAE(**autoencoder_params)
stage_2 = DiffusionModelUNet(**stage_2_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
autoencoder_model.to(device)
stage_2.to(device)
        autoencoder_model.eval()
        stage_2.train()
        inputs = torch.randn(input_shape).to(device)
noise = torch.randn(latent_shape).to(device)
scheduler = DDPMScheduler(
num_train_timesteps=10,
)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)
        prediction = inferer(inputs=inputs, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(TEST_CASES)
def test_sample_shape(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
if model_type == "AutoencoderKL":
autoencoder_model = AutoencoderKL(**autoencoder_params)
if model_type == "VQVAE":
autoencoder_model = VQVAE(**autoencoder_params)
stage_2 = DiffusionModelUNet(**stage_2_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
autoencoder_model.to(device)
stage_2.to(device)
        autoencoder_model.eval()
        stage_2.train()
noise = torch.randn(latent_shape).to(device)
scheduler = DDPMScheduler(
num_train_timesteps=10,
)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)
sample = inferer.sample(
input_noise=noise, autoencoder_model=autoencoder_model, diffusion_model=stage_2, scheduler=scheduler
)
self.assertEqual(sample.shape, input_shape)

@parameterized.expand(TEST_CASES)
def test_sample_intermediates(self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape):
if model_type == "AutoencoderKL":
autoencoder_model = AutoencoderKL(**autoencoder_params)
if model_type == "VQVAE":
autoencoder_model = VQVAE(**autoencoder_params)
stage_2 = DiffusionModelUNet(**stage_2_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
autoencoder_model.to(device)
stage_2.to(device)
        autoencoder_model.eval()
        stage_2.train()
noise = torch.randn(latent_shape).to(device)
scheduler = DDPMScheduler(
num_train_timesteps=10,
)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)
sample, intermediates = inferer.sample(
input_noise=noise,
autoencoder_model=autoencoder_model,
diffusion_model=stage_2,
scheduler=scheduler,
save_intermediates=True,
intermediate_steps=1,
)
self.assertEqual(len(intermediates), 10)
self.assertEqual(intermediates[0].shape, input_shape)


if __name__ == "__main__":
unittest.main()
