Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#4 from xraxra/master
Browse files Browse the repository at this point in the history
weighted prompts similar to Midjourney
  • Loading branch information
hlky committed Aug 24, 2022
2 parents 5ae104b + f2f866e commit eae4ec8
Showing 1 changed file with 76 additions and 8 deletions.
84 changes: 76 additions & 8 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")


def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, normalize_prompt_weights=True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
torch_gc()
Expand Down Expand Up @@ -452,7 +452,26 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
uc = model.get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)

# split the prompt if it has : for weighting
# TODO for speed it might help to have this occur when all_prompts filled??
subprompts,weights = split_weighted_subprompts(prompts[0])
# get total weight for normalizing, this gets weird if large negative values used
totalPromptWeight = sum(weights)

# sub-prompt weighting used if more than 1
if len(subprompts) > 1:
c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
for i in range(0,len(subprompts)): # normalize each prompt and add it
weight = weights[i]
if normalize_prompt_weights:
weight = weight / totalPromptWeight
#print(f"{subprompts[i]} {weight*100.0}%")
# note if alpha negative, it functions same as torch.sub
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just behave like usual
c = model.get_learned_conditioning(prompts)

shape = [opt_C, height // opt_f, width // opt_f]

# we manually generate all input noises because each one should have a specific seed
Expand Down Expand Up @@ -518,7 +537,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
return output_images, seed, info, stats


def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, skip_grid: bool, skip_save: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, skip_grid: bool, skip_save: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, normalize_prompt_weights: bool):
outpath = opt.outdir or "outputs/txt2img-samples"
err = False
seed = seed_to_int(seed)
Expand Down Expand Up @@ -556,8 +575,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN

use_GFPGAN=use_GFPGAN,
normalize_prompt_weights=normalize_prompt_weights
)

del sampler
Expand Down Expand Up @@ -631,6 +650,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
gr.Textbox(label="Seed ('random' to randomize)", lines=1, value="random"),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.Checkbox(label="Normalize Prompt Weights (ensure sum of weights add up to 1.0)", value=True),
],
outputs=[
gr.Gallery(label="Images"),
Expand All @@ -645,7 +665,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
)


def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool):
outpath = opt.outdir or "outputs/img2img-samples"
err = False
seed = seed_to_int(seed)
Expand Down Expand Up @@ -758,7 +778,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
use_GFPGAN=use_GFPGAN,
normalize_prompt_weights=normalize_prompt_weights
)

del sampler
Expand Down Expand Up @@ -797,7 +818,8 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
gr.Textbox(label="Seed ('random' to randomize)", lines=1, value="random"),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"),
gr.Checkbox(label="Normalize Prompt Weights (ensure sum of weights add up to 1.0)", value=True),
],
outputs=[
gr.Gallery(),
Expand All @@ -816,6 +838,52 @@ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name)
(img2img_interface, "img2img")
]

def split_weighted_subprompts(text):
    """Split *text* into sub-prompts and their weights.

    Grabs all text up to the first occurrence of ':' as a sub-prompt and
    takes the value following ':' as its weight; the weight value is assumed
    to be terminated by a space or comma. If ':' has no value defined, the
    weight defaults to 1.0. Repeats until no text remains.
    TODO this could probably be done with less code

    Returns:
        (prompts, weights): parallel lists of sub-prompt strings and floats.
    """
    remaining = len(text)
    prompts = []
    weights = []
    while remaining > 0:
        if ":" in text:
            idx = text.index(":")  # first occurrence from start
            # grab up to index as sub-prompt
            prompt = text[:idx]
            remaining -= idx
            # remove from main text
            text = text[idx+1:]
            # find value for weight, assume it is followed by a space or comma
            idx = len(text)  # default is read to end of text
            if " " in text:
                idx = min(idx, text.index(" "))  # want the closer idx
            if "," in text:
                idx = min(idx, text.index(","))  # want the closer idx
            if idx != 0:
                try:
                    weight = float(text[:idx])
                except ValueError:  # couldn't treat as float
                    print(f"Warning: '{text[:idx]}' is not a value, are you missing a space or comma after a value?")
                    weight = 1.0
            else:  # no value found
                weight = 1.0
            # remove the parsed weight from main text
            remaining -= idx
            text = text[idx+1:]
            # append the sub-prompt and its weight
            prompts.append(prompt)
            weights.append(weight)
        else:  # no : found
            if len(text) > 0:  # there is still text though
                # take remainder as weight 1
                prompts.append(text)
                weights.append(1.0)
            remaining = 0
    return prompts, weights

def run_GFPGAN(image, strength):
image = image.convert("RGB")

Expand Down

0 comments on commit eae4ec8

Please sign in to comment.