Commit 88c1224

[DirectML] Fix samplers.

lshqqytiger committed May 8, 2024
1 parent e2cbdab commit 88c1224

Showing 2 changed files with 121 additions and 104 deletions.
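The common thread in both files: the old code forced DirectML to finish queued work by calling `sample.__str__()`, relying on the host-side string conversion as an implicit synchronization point; this commit replaces that with an explicit `torch.dml.synchronize_tensor()` call. `torch.dml` is this fork's DirectML backend module, not stock PyTorch; a minimal stand-in for the helper, under that assumption, might look like:

import torch

def synchronize_tensor(sample: torch.Tensor) -> None:
    # Hypothetical stand-in: reading a single element back to the host
    # blocks until the queued kernels producing `sample` have finished --
    # the same side effect the old `sample.__str__()` hack depended on.
    if sample.device.type == "privateuseone" and sample.numel():  # torch-directml device type
        sample.flatten()[0].item()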
223 changes: 120 additions & 103 deletions modules/dml/hijack/diffusers.py
@@ -1,7 +1,7 @@
-from typing import Optional, Union, Tuple
import torch
import diffusers
import diffusers.utils.torch_utils
+from typing import Optional, Union, Tuple


def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, prev_timestep, model_output):
@@ -17,7 +17,7 @@ def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, pr
    # sample -> x_t
    # model_output -> e_θ(x_t, t)
    # prev_sample -> x_(t−δ)
-    sample.__str__() # PNDM Sampling does not work without 'stringify'. (because it depends on PLMS)
+    torch.dml.synchronize_tensor(sample) # DML synchronize
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
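The synchronize call lands immediately before scalar schedule values are read out and mixed with the live sample tensor. For orientation, the algebra this PNDM/DDIM-style step builds on (the standard epsilon-parameterized update, not this repo's exact code):

import torch

# Standard schedule: cumulative products of (1 - beta_t).
alphas_cumprod = torch.cumprod(1 - torch.linspace(1e-4, 2e-2, 1000), dim=0)
t, t_prev = 801, 781
x_t, eps = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)

a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
pred_x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()          # invert the forward process
x_prev = a_prev.sqrt() * pred_x0 + (1 - a_prev).sqrt() * eps   # deterministic (eta=0) step
print(x_prev.shape)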
@@ -53,51 +53,68 @@ def PNDMScheduler__get_prev_sample(self, sample: torch.FloatTensor, timestep, pr


def UniPCMultistepScheduler_multistep_uni_p_bh_update(
-    self: diffusers.UniPCMultistepScheduler,
+    self,
    model_output: torch.FloatTensor,
-    prev_timestep: int,
-    sample: torch.FloatTensor,
-    order: int,
+    *args,
+    sample: torch.FloatTensor = None,
+    order: int = None,
+    **kwargs,
) -> torch.FloatTensor:
    """
    One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
    Args:
        model_output (`torch.FloatTensor`):
-            direct outputs from learned diffusion model at the current timestep.
-        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
+            The direct output from the learned diffusion model at the current timestep.
+        prev_timestep (`int`):
+            The previous discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
-            current instance of sample being created by diffusion process.
-        order (`int`): the order of UniP at this step, also the p in UniPC-p.
+            A current instance of a sample created by the diffusion process.
+        order (`int`):
+            The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
    Returns:
-        `torch.FloatTensor`: the sample tensor at the previous timestep.
+        `torch.FloatTensor`:
+            The sample tensor at the previous timestep.
    """
-    timestep_list = self.timestep_list
+    if sample is None:
+        if len(args) > 1:
+            sample = args[1]
+        else:
+            raise ValueError(" missing `sample` as a required keyward argument")
+    if order is None:
+        if len(args) > 2:
+            order = args[2]
+        else:
+            raise ValueError(" missing `order` as a required keyward argument")
    model_output_list = self.model_outputs

-    s0, t = self.timestep_list[-1], prev_timestep
+    s0 = self.timestep_list[-1]
    m0 = model_output_list[-1]
    x = sample

    if self.solver_p:
        x_t = self.solver_p.step(model_output, s0, x).prev_sample
        return x_t

-    sample.__str__() # UniPC Sampling does not work without 'stringify'.
-    lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
-    alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
-    sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
+    torch.dml.synchronize_tensor(sample) # DML synchronize
+    sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    h = lambda_t - lambda_s0
    device = sample.device

    rks = []
    D1s = []
    for i in range(1, order):
-        si = timestep_list[-(i + 1)]
+        si = self.step_index - i
        mi = model_output_list[-(i + 1)]
-        lambda_si = self.lambda_t[si]
+        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
+        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)
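Besides the synchronization change, this hunk tracks a diffusers API shift: the scheduler no longer exposes precomputed `lambda_t`/`alpha_t`/`sigma_t` tables indexed by timestep, so the hijack derives them from `self.sigmas` via `_sigma_to_alpha_sigma_t`. A hedged reconstruction of that conversion, assuming diffusers' variance-preserving convention:

import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor):
    # Assumed VP convention: alpha_t^2 + sigma_t^2 = 1, with sigma = sigma_t / alpha_t.
    alpha_t = 1 / (sigma ** 2 + 1) ** 0.5
    return alpha_t, sigma * alpha_t

sigmas = torch.tensor([14.6, 10.0, 6.5])
alpha, sig = sigma_to_alpha_sigma_t(sigmas)
lam = torch.log(alpha) - torch.log(sig)  # half log-SNR; h = lambda_t - lambda_s0 above
print(lam)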
@@ -143,14 +160,14 @@ def UniPCMultistepScheduler_multistep_uni_p_bh_update(
    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
-            pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
-            pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
+            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - sigma_t * B_h * pred_res
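The einsum subscript change is a shape generalization, not a numerical one: `"k,bkc...->bc..."` still takes the rho-weighted sum over the k stacked history differences, but the ellipsis stops pinning the trailing layout to exactly h×w, so the same contraction also serves latents with extra dimensions. A quick check:

import torch

rhos_p = torch.rand(2)                   # k weights
D1s_img = torch.rand(1, 2, 4, 8, 8)      # (b, k, c, h, w): either spec works
D1s_vid = torch.rand(1, 2, 4, 3, 8, 8)   # (b, k, c, f, h, w): only the "..." spec works

out_img = torch.einsum("k,bkc...->bc...", rhos_p, D1s_img)
out_vid = torch.einsum("k,bkc...->bc...", rhos_p, D1s_vid)
print(out_img.shape, out_vid.shape)      # (1, 4, 8, 8) and (1, 4, 3, 8, 8)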
@@ -170,91 +187,91 @@ def LCMScheduler_step(
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[diffusers.schedulers.scheduling_lcm.LCMSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

if self.step_index is None:
self._init_step_index(timestep)

# 1. get previous step value
prev_step_index = self.step_index + 1
if prev_step_index < len(self.timesteps):
prev_timestep = self.timesteps[prev_step_index]
else:
prev_timestep = timestep

# 2. compute alphas, betas
sample.__str__()
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

# 3. Get scalings for boundary conditions
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

# 4. Compute the predicted original sample x_0 based on the model parameterization
if self.config.prediction_type == "epsilon": # noise-prediction
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
elif self.config.prediction_type == "sample": # x-prediction
predicted_original_sample = model_output
elif self.config.prediction_type == "v_prediction": # v-prediction
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for `LCMScheduler`."
)

# 5. Clip or threshold "predicted x_0"
if self.config.thresholding:
predicted_original_sample = self._threshold_sample(predicted_original_sample)
elif self.config.clip_sample:
predicted_original_sample = predicted_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)

# 6. Denoise model output using boundary conditions
denoised = c_out * predicted_original_sample + c_skip * sample

# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
# Noise is not used for one-step sampling.
if len(self.timesteps) > 1:
noise = diffusers.utils.torch_utils.randn_tensor(model_output.shape, generator=generator, device=model_output.device)
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

if self.step_index is None:
self._init_step_index(timestep)

# 1. get previous step value
prev_step_index = self.step_index + 1
if prev_step_index < len(self.timesteps):
prev_timestep = self.timesteps[prev_step_index]
else:
prev_timestep = timestep

# 2. compute alphas, betas
torch.dml.synchronize_tensor(sample) # DML synchronize
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

# 3. Get scalings for boundary conditions
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

# 4. Compute the predicted original sample x_0 based on the model parameterization
if self.config.prediction_type == "epsilon": # noise-prediction
predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
elif self.config.prediction_type == "sample": # x-prediction
predicted_original_sample = model_output
elif self.config.prediction_type == "v_prediction": # v-prediction
predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
" `v_prediction` for `LCMScheduler`."
)

# 5. Clip or threshold "predicted x_0"
if self.config.thresholding:
predicted_original_sample = self._threshold_sample(predicted_original_sample)
elif self.config.clip_sample:
predicted_original_sample = predicted_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)

# 6. Denoise model output using boundary conditions
denoised = c_out * predicted_original_sample + c_skip * sample

# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
# Noise is not used for one-step sampling.
if len(self.timesteps) > 1:
noise = diffusers.utils.torch_utils.randn_tensor(model_output.shape, generator=generator, device=model_output.device)
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised

# upon completion increase step index by one
self._step_index += 1
# upon completion increase step index by one
self._step_index += 1

if not return_dict:
return (prev_sample, denoised)
if not return_dict:
return (prev_sample, denoised)

return diffusers.schedulers.scheduling_lcm.LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
return diffusers.schedulers.scheduling_lcm.LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)


diffusers.LCMScheduler.step = LCMScheduler_step
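The assignment above is the whole hijack mechanism: importing the module overwrites the scheduler method in place, so any pipeline built afterwards picks up the patched version. A hedged smoke test, assuming the webui root is on `sys.path`:

import diffusers
import modules.dml.hijack.diffusers as dml_hijack  # importing applies the patches

sched = diffusers.LCMScheduler()
# The bound method now resolves to the patched module-level function.
assert sched.step.__func__ is dml_hijack.LCMScheduler_step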
2 changes: 1 addition & 1 deletion modules/dml/hijack/stablediffusion.py
@@ -51,7 +51,7 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
    sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
    # select parameters corresponding to the currently considered timestep
-    alphas[index].__str__() # synchronize DML device
+    torch.dml.synchronize_tensor(alphas[index]) # DML synchronize
    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
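Same pattern as in the diffusers hunks: a scalar is pulled out of a schedule tensor and broadcast into per-batch parameters, so the DirectML queue has to be drained before the read. A minimal sketch of the broadcast step (values are placeholders):

import torch

b, device = 2, "cpu"  # placeholder device; on DirectML the read needs the flush above
alphas = torch.linspace(0.9991, 0.0001, 50)
index = 10
# Each parameter becomes (b, 1, 1, 1) so it broadcasts over the (b, c, h, w) latent.
a_t = torch.full((b, 1, 1, 1), alphas[index].item(), device=device)
print(a_t.shape)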
