Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Oct 5, 2022
1 parent 67d011b commit c26732f
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 19 deletions.
2 changes: 1 addition & 1 deletion modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def infotext(iteration=0, position_in_batch=0):
#c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
Expand Down
114 changes: 108 additions & 6 deletions modules/prompt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,26 @@ def get_schedule(prompt):


ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])


def get_learned_conditioning(model, prompts, steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
Input:
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
Output:
[
[
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
],
[
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
]
]
"""
res = []

prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
Expand All @@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
cache[prompt] = cond_schedule
res.append(cond_schedule)

return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
return res


re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")


def get_multicond_prompt_list(prompts):
res_indexes = []

prompt_flat_list = []
prompt_indexes = {}

for prompt in prompts:
subprompts = re_AND.split(prompt)

indexes = []
for subprompt in subprompts:
text, weight = re_weight.search(subprompt).groups()

weight = float(weight) if weight is not None else 1.0

index = prompt_indexes.get(text, None)
if index is None:
index = len(prompt_flat_list)
prompt_flat_list.append(text)
prompt_indexes[text] = index

indexes.append((index, weight))

res_indexes.append(indexes)

return res_indexes, prompt_flat_list, prompt_indexes


class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0):
self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight


class MulticondLearnedConditioning:
def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch


def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
param = c.schedules[0][0].cond
res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c.schedules):
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
"""

res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)

learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)

res = []
for indexes in res_indexes:
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])

return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)


def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0
for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at:
Expand All @@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
return res


def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond

tensors = []
conds_list = []

for batch_no, composable_prompts in enumerate(c.batch):
conds_for_batch = []

for cond_index, composable_prompt in enumerate(composable_prompts):
target_index = 0
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
if current_step <= end_at:
target_index = current
break

conds_for_batch.append((len(tensors), composable_prompt.weight))
tensors.append(composable_prompt.schedules[target_index].cond)

conds_list.append(conds_for_batch)

return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)


re_attention = re.compile(r"""
\\\(|
\\\)|
Expand Down
35 changes: 25 additions & 10 deletions modules/sd_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,12 @@ def number_of_needed_noises(self, p):
return 0

def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor

if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
Expand Down Expand Up @@ -183,19 +186,31 @@ def __init__(self, model):
self.step = 0

def forward(self, x, sigma, uncond, cond, cond_scale):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]

x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
cond_in = torch.cat([tensor, uncond])

if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
else:
uncond = self.inner_model(x, sigma, cond=uncond)
cond = self.inner_model(x, sigma, cond=cond)
denoised = uncond + (cond - uncond) * cond_scale
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])

denoised_uncond = x_out[-batch_size:]
denoised = torch.clone(denoised_uncond)

for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
Expand Down
6 changes: 4 additions & 2 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
from modules import prompt_parser
from modules.images import apply_filename_pattern, get_next_sequence_number
import modules.textual_inversion.ui

Expand Down Expand Up @@ -394,7 +394,9 @@ def copy_seed(gen_info_string: str, index):

def update_token_counter(text, steps):
try:
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)

except Exception:
# a parsing error can happen here during typing, and we don't want to bother the user with
# messages related to it in console
Expand Down

6 comments on commit c26732f

@ilikenwf
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks launch for me:

stable-diffusion-webui/modules/prompt_parser.py", line 207, in <module>
    def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
TypeError: 'type' object is not subscriptable

@raefu
Copy link
Contributor

@raefu raefu commented on c26732f Oct 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c: list[list[ScheduledPromptConditioning]] is Python 3.9 syntax. Needs from typing import List; c: List[List[ScheduledPromptConditioning]] to support 3.8.

@ilikenwf
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @raefu - my conda environment had an old symlink.

@ilikenwf
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c: list[list[ScheduledPromptConditioning]] is Python 3.9 syntax. Needs from typing import List; c: List[List[ScheduledPromptConditioning]] to support 3.8.

from typing import List as list

would work universally would it not?

@bmaltais
Copy link

@bmaltais bmaltais commented on c26732f Oct 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the wiki be updated with how to make use of all the recent prompt syntax additions? Hard to know the syntax when there is no documentation for it. I can't find any documentation on the syntax to use this anywhere.

@0924249460
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IJJJK OK

Please sign in to comment.