Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,23 +504,24 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
"""
Efficient add_noise implementation. Broadcasts the per-timestep scaling coefficients over the sample shape and avoids device/dtype thrashing.
"""
# Move timesteps to the same device as original_samples before using them as indices
timesteps = timesteps.to(original_samples.device)

sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

# Move alphas_cumprod to correct device/dtype if needed
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)

# Indexing into vector: this produces shape (batch,)
sqrt_alpha_prod = alphas_cumprod[timesteps].sqrt()
sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[timesteps]).sqrt()

# Expand to broadcast over sample shape
target_shape = [original_samples.shape[0]] + [1] * (original_samples.ndim - 1)
sqrt_alpha_prod = sqrt_alpha_prod.view(*target_shape)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.view(*target_shape)

# Vectorized noisy sample computation
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

Expand Down