Skip to content

Commit

Permalink
add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them)
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Aug 8, 2023
1 parent 2d8e4a6 commit 8285a14
Show file tree
Hide file tree
Showing 6 changed files with 455 additions and 172 deletions.
3 changes: 2 additions & 1 deletion modules/sd_samplers.py
@@ -1,11 +1,12 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared

# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401

# Concatenation of the sampler lists exported by the k-diffusion, compvis and
# timesteps sampler modules, in that order.
all_samplers = [
    *sd_samplers_kdiffusion.samplers_data_k_diffusion,
    *sd_samplers_compvis.samplers_data_compvis,
    *sd_samplers_timesteps.samplers_data_timesteps,
]
# Name -> entry lookup over all_samplers; if two entries share a name, the
# one that appears later in the list is the one kept.
all_samplers_map = {sampler.name: sampler for sampler in all_samplers}

Expand Down
50 changes: 18 additions & 32 deletions modules/sd_samplers_cfg_denoiser.py
Expand Up @@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""

def __init__(self, model):
def __init__(self, model, sampler):
super().__init__()
self.inner_model = model
self.mask = None
Expand All @@ -48,6 +48,7 @@ def __init__(self, model):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.sampler = sampler

def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
Expand All @@ -65,6 +66,9 @@ def combine_denoised_for_edit_model(self, x_out, cond_scale):

return denoised

def get_pred_x0(self, x_in, x_out, sigma):
    """Return the tensor to treat as the predicted x0 for previews.

    This implementation hands back ``x_out`` untouched; ``x_in`` and
    ``sigma`` are accepted but not used here.
    NOTE(review): presumably a hook for subclasses whose model output is not
    already a denoised image -- confirm against the subclasses.
    """
    return x_out

def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
Expand All @@ -78,6 +82,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):

assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

if self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x

batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]

Expand Down Expand Up @@ -170,20 +177,23 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):

devices.test_for_nans(x_out, "unet")

if opts.live_preview_content == "Prompt":
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

if opts.live_preview_content == "Prompt":
preview = self.sampler.last_latent
elif opts.live_preview_content == "Negative prompt":
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
else:
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)

sd_samplers_common.store_latent(preview)

after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
Expand All @@ -192,27 +202,3 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
self.step += 1
return denoised


class TorchHijack:
    """Stand-in for the ``torch`` module that serves pre-generated noise.

    ``randn_like`` hands out the tensors queued in ``sampler_noises`` in
    order; every other attribute is forwarded to the real ``torch`` module.
    """

    def __init__(self, sampler_noises):
        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
        # implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        """Forward attributes that are missing on this object to ``torch``.

        Python only calls ``__getattr__`` after normal lookup has failed, so
        ``randn_like`` (a real method of this class) never reaches this hook
        and needs no special case here.

        Raises:
            AttributeError: if ``torch`` has no attribute named ``item``.
        """
        try:
            return getattr(torch, item)
        except AttributeError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None

    def randn_like(self, x):
        """Return the next queued noise tensor if its shape matches ``x``.

        The next queued tensor is consumed even when its shape does not match.
        When the queue is empty or the shape differs, the result comes from
        ``devices.randn_like(x)`` instead.
        """
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        return devices.randn_like(x)

140 changes: 139 additions & 1 deletion modules/sd_samplers_common.py
@@ -1,9 +1,11 @@
from collections import namedtuple
import inspect
from collections import namedtuple, deque
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
import k_diffusion.sampling

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

Expand Down Expand Up @@ -127,3 +129,139 @@ def torchsde_randn(size, dtype, device, seed):


replace_torchsde_browinan()


class TorchHijack:
    """Stand-in for the ``torch`` module that serves pre-generated noise.

    ``randn_like`` hands out the tensors queued in ``sampler_noises`` in
    order; every other attribute is forwarded to the real ``torch`` module.
    """

    def __init__(self, sampler_noises):
        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
        # implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        """Forward attributes that are missing on this object to ``torch``.

        Python only calls ``__getattr__`` after normal lookup has failed, so
        ``randn_like`` (a real method of this class) never reaches this hook
        and needs no special case here.

        Raises:
            AttributeError: if ``torch`` has no attribute named ``item``.
        """
        try:
            return getattr(torch, item)
        except AttributeError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None

    def randn_like(self, x):
        """Return the next queued noise tensor if its shape matches ``x``.

        The next queued tensor is consumed even when its shape does not match.
        When the queue is empty or the shape differs, the result comes from
        ``devices.randn_like(x)`` instead.
        """
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        return devices.randn_like(x)


class Sampler:
    """Shared state and parameter plumbing for sampler implementations.

    Holds the sampling function, the eta / sigma defaults, the last latent
    produced so far, and the wrapped model objects (``model_wrap`` and
    ``model_wrap_cfg``), which start as ``None``.
    NOTE(review): ``initialize`` dereferences ``model_wrap_cfg``, so it is
    presumably assigned by subclasses before use -- confirm.
    """

    def __init__(self, funcname):
        self.funcname = funcname
        # NOTE(review): func starts out equal to funcname; initialize() passes
        # it to inspect.signature, so it presumably gets replaced with a
        # callable elsewhere -- confirm against the subclasses.
        self.func = funcname
        self.extra_params = []
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

        # Baseline sigma settings; initialize() only forwards values that differ from these.
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'

        self.conditioning_key = shared.sd_model.model.conditioning_key

        self.model_wrap = None
        self.model_wrap_cfg = None

    def callback_state(self, d):
        """Record the current step from ``d['i']`` and advance the progress bar.

        Raises:
            InterruptedException: when ``stop_at`` is set and the step is past it.
        """
        step = d['i']

        if self.stop_at is not None and step > self.stop_at:
            raise InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Reset the shared step counters, then run ``func`` and return its result.

        If ``func`` raises ``RecursionError`` or ``InterruptedException`` the
        exception is swallowed and ``self.last_latent`` is returned instead.
        """
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Return ``p.steps``."""
        return p.steps

    def initialize(self, p) -> dict:
        """Prepare per-run state from ``p`` and build kwargs for ``self.func``.

        Copies mask / step / image-CFG settings onto ``model_wrap_cfg``,
        resolves ``eta`` and ``s_min_uncond``, and installs a ``TorchHijack``
        as ``k_diffusion.sampling.torch``. It also writes the resolved sigma
        values back onto ``p`` and records non-default settings in
        ``p.extra_generation_params``.

        Returns:
            Keyword arguments to pass to ``self.func``. It holds only the
            parameters that ``self.func`` accepts (for ``extra_params`` and
            ``eta``), plus sigma settings that differ from this sampler's
            defaults.
        """
        cfg = self.model_wrap_cfg
        cfg.mask = getattr(p, 'mask', None)
        cfg.nmask = getattr(p, 'nmask', None)
        cfg.step = 0
        cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        # Inspect the sampling function once instead of once per candidate parameter.
        func_params = inspect.signature(self.func).parameters

        extra_params_kwargs = {
            param_name: getattr(p, param_name)
            for param_name in self.extra_params
            if hasattr(p, param_name) and param_name in func_params
        }

        if 'eta' in func_params:
            if self.eta != 1.0:
                p.extra_generation_params[self.eta_infotext_field] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if self.extra_params:
            # Values from opts take precedence; p supplies the fallback.
            sigma_settings = (
                ('s_churn', 'Sigma churn', getattr(opts, 's_churn', p.s_churn)),
                ('s_tmin', 'Sigma tmin', getattr(opts, 's_tmin', p.s_tmin)),
                ('s_tmax', 'Sigma tmax', getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax),  # 0 = inf
                ('s_noise', 'Sigma noise', getattr(opts, 's_noise', p.s_noise)),
            )

            # Only settings that differ from this sampler's defaults are passed on and recorded.
            for attr_name, infotext_label, value in sigma_settings:
                if value != getattr(self, attr_name):
                    extra_params_kwargs[attr_name] = value
                    setattr(p, attr_name, value)
                    p.extra_generation_params[infotext_label] = value

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler

        # The lower bound ignores zero / negative sigmas.
        sigma_min = sigmas[sigmas > 0].min()
        sigma_max = sigmas.max()

        # Seeds belonging to the batch currently being processed.
        batch_start = p.iteration * p.batch_size
        batch_end = (p.iteration + 1) * p.batch_size
        current_iter_seeds = p.all_seeds[batch_start:batch_end]

        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)



0 comments on commit 8285a14

Please sign in to comment.