Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __call__(
inputs: torch.Tensor,
diffusion_model: Callable[..., torch.Tensor],
noise: torch.Tensor,
timesteps: torch.Tensor,
condition: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expand All @@ -48,10 +49,9 @@ def __call__(
inputs: Input image to which noise is added.
diffusion_model: diffusion model.
noise: random noise, of the same shape as the input.
timesteps: random timesteps.
condition: Conditioning for network input.
"""
num_timesteps = self.scheduler.num_train_timesteps
timesteps = torch.randint(0, num_timesteps, (inputs.shape[0],), device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)

Expand Down Expand Up @@ -123,6 +123,7 @@ def __call__(
autoencoder_model: Callable[..., torch.Tensor],
diffusion_model: Callable[..., torch.Tensor],
noise: torch.Tensor,
timesteps: torch.Tensor,
condition: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expand All @@ -133,6 +134,7 @@ def __call__(
autoencoder_model: first stage model.
diffusion_model: diffusion model.
noise: random noise, of the same shape as the latent representation.
timesteps: random timesteps.
condition: conditioning for network input.
"""
with torch.no_grad():
Expand All @@ -142,6 +144,7 @@ def __call__(
inputs=latent,
diffusion_model=diffusion_model,
noise=noise,
timesteps=timesteps,
condition=condition,
)

Expand Down
37 changes: 36 additions & 1 deletion generative/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class DDIMScheduler(nn.Module):
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type: prediction type of the scheduler function, one of `epsilon` (predicting the noise of the
diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""

def __init__(
Expand All @@ -66,6 +69,7 @@ def __init__(
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
) -> None:
super().__init__()
self.beta_schedule = beta_schedule
Expand All @@ -79,6 +83,12 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
raise ValueError(
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
)

self.prediction_type = prediction_type
self.num_train_timesteps = num_train_timesteps
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
Expand Down Expand Up @@ -171,7 +181,14 @@ def step(

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

# 4. Clip "predicted x_0"
if self.clip_sample:
Expand Down Expand Up @@ -231,3 +248,21 @@ def add_noise(

noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
32 changes: 30 additions & 2 deletions generative/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
beta_schedule: str = "linear",
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
) -> None:
super().__init__()
self.beta_schedule = beta_schedule
Expand All @@ -74,6 +75,13 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
raise ValueError(
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
)

self.prediction_type = prediction_type

self.num_train_timesteps = num_train_timesteps
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
Expand Down Expand Up @@ -170,10 +178,12 @@ def step(

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
elif self.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output

# 3. Clip "predicted x_0"
if self.clip_sample:
Expand Down Expand Up @@ -233,3 +243,21 @@ def add_noise(

noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)

sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
3 changes: 2 additions & 1 deletion tests/test_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def test_call(self, model_params, input_shape):
)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
sample = inferer(inputs=input, noise=noise, diffusion_model=model)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
self.assertEqual(sample.shape, input_shape)

@parameterized.expand(TEST_CASES)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params,
)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)
prediction = inferer(inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
prediction = inferer(
inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise, timesteps=timesteps
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(TEST_CASES)
Expand Down
12 changes: 10 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,13 @@
" # Generate random noise\n",
" noise = torch.randn_like(images).to(device)\n",
"\n",
" # Create timesteps\n",
" timesteps = torch.randint(\n",
" 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n",
" ).long()\n",
"\n",
" # Get model prediction\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n",
"\n",
" loss = F.mse_loss(noise_pred.float(), noise.float())\n",
"\n",
Expand All @@ -806,7 +811,10 @@
" with torch.no_grad():\n",
" with autocast(enabled=True):\n",
" noise = torch.randn_like(images).to(device)\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n",
" timesteps = torch.randint(\n",
" 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n",
" ).long()\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n",
" val_loss = F.mse_loss(noise_pred.float(), noise.float())\n",
"\n",
" val_epoch_loss += val_loss.item()\n",
Expand Down
12 changes: 10 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,13 @@
# Generate random noise
noise = torch.randn_like(images).to(device)

# Create timesteps
timesteps = torch.randint(
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()

# Get model prediction
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

loss = F.mse_loss(noise_pred.float(), noise.float())

Expand All @@ -233,7 +238,10 @@
with torch.no_grad():
with autocast(enabled=True):
noise = torch.randn_like(images).to(device)
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)
timesteps = torch.randint(
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
val_loss = F.mse_loss(noise_pred.float(), noise.float())

val_epoch_loss += val_loss.item()
Expand Down
16 changes: 11 additions & 5 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@
")\n",
"model.to(device)\n",
"\n",
"num_train_timesteps = 1000\n",
"scheduler = DDPMScheduler(\n",
" num_train_timesteps=1000,\n",
" num_train_timesteps=num_train_timesteps,\n",
")\n",
"\n",
"optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)\n",
Expand Down Expand Up @@ -433,13 +434,17 @@
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, condition_name: Optional[str] = None):\n",
" def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n",
" self.condition_name = condition_name\n",
" self.num_train_timesteps = num_train_timesteps\n",
"\n",
" def get_noise(self, images):\n",
" \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n",
" return torch.randn_like(images)\n",
"\n",
" def get_timesteps(self, images):\n",
" return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n",
"\n",
" def __call__(\n",
" self,\n",
" batchdata: Dict[str, torch.Tensor],\n",
Expand All @@ -449,8 +454,9 @@
" ):\n",
" images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n",
" noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n",
" timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n",
"\n",
" kwargs = {\"noise\": noise}\n",
" kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n",
"\n",
" if self.condition_name is not None and isinstance(batchdata, Mapping):\n",
" kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n",
Expand Down Expand Up @@ -2159,7 +2165,7 @@
" val_data_loader=val_loader,\n",
" network=model,\n",
" inferer=inferer,\n",
" prepare_batch=DiffusionPrepareBatch(),\n",
" prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n",
" key_val_metric={\"val_mean_abs_error\": MeanAbsoluteError(output_transform=from_engine([\"pred\", \"label\"]))},\n",
" val_handlers=val_handlers,\n",
")\n",
Expand All @@ -2178,7 +2184,7 @@
" optimizer=optimizer,\n",
" loss_function=torch.nn.MSELoss(),\n",
" inferer=inferer,\n",
" prepare_batch=DiffusionPrepareBatch(),\n",
" prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n",
" key_train_metric={\"train_acc\": MeanSquaredError(output_transform=from_engine([\"pred\", \"label\"]))},\n",
" train_handlers=train_handlers,\n",
")\n",
Expand Down
16 changes: 11 additions & 5 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@
)
model.to(device)

num_train_timesteps = 1000
scheduler = DDPMScheduler(
num_train_timesteps=1000,
num_train_timesteps=num_train_timesteps,
)

optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)
Expand All @@ -203,13 +204,17 @@ class DiffusionPrepareBatch(PrepareBatch):

"""

def __init__(self, condition_name: Optional[str] = None):
def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images):
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
return torch.randn_like(images)

def get_timesteps(self, images):
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
Expand All @@ -219,8 +224,9 @@ def __call__(
):
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

kwargs = {"noise": noise}
kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
Expand All @@ -244,7 +250,7 @@ def __call__(
val_data_loader=val_loader,
network=model,
inferer=inferer,
prepare_batch=DiffusionPrepareBatch(),
prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
key_val_metric={"val_mean_abs_error": MeanAbsoluteError(output_transform=from_engine(["pred", "label"]))},
val_handlers=val_handlers,
)
Expand All @@ -263,7 +269,7 @@ def __call__(
optimizer=optimizer,
loss_function=torch.nn.MSELoss(),
inferer=inferer,
prepare_batch=DiffusionPrepareBatch(),
prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
key_train_metric={"train_acc": MeanSquaredError(output_transform=from_engine(["pred", "label"]))},
train_handlers=train_handlers,
)
Expand Down
Loading