Skip to content

Commit

Permalink
Add function/tool trigger and generation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Trojaner committed Apr 17, 2024
1 parent 61687fb commit b69c265
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 91 deletions.
204 changes: 177 additions & 27 deletions ext_modules/image_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import time
from datetime import date
from pathlib import Path
from typing import cast
from typing import Any, cast
from partial_json_parser import loads
from PIL import Image
from webuiapi import WebUIApiResult
from modules.logging_colors import logger
from ..context import GenerationContext
Expand Down Expand Up @@ -39,6 +41,11 @@ def normalize_prompt(prompt: str) -> str:
.replace("!", ",")
.replace("?", ",")
.replace("&", "")
.replace("\r", "")
.replace("\n", ", ")
.replace("*", "")
.replace("#", "")
.replace(".,", ",")
.replace(",,", ",")
.replace(", ,", ",")
.replace(";", ",")
Expand All @@ -54,7 +61,7 @@ def normalize_prompt(prompt: str) -> str:

def generate_html_images_for_context(
context: GenerationContext,
) -> tuple[str | None, str, str, str, str]:
) -> tuple[str, str | None, str | None, str | None, str | None, str | None]:
"""
Generates images for the given context using Stable Diffusion
and returns the result as HTML output
Expand All @@ -64,6 +71,8 @@ def generate_html_images_for_context(

sd_client = context.sd_client

output_text = context.output_text or ""

rules_prompt = ""
rules_negative_prompt = ""

Expand Down Expand Up @@ -104,23 +113,21 @@ def generate_html_images_for_context(
]

if (
context.output_text
and context.output_text != ""
output_text
and output_text != ""
and RegexGenerationRuleMatch.OUTPUT.value in rule["match"]
):
match_against.append(html.unescape(context.output_text).strip())
match_against.append(html.unescape(output_text).strip())

if (
context.output_text
and context.output_text != ""
output_text
and output_text != ""
and RegexGenerationRuleMatch.OUTPUT_SENTENCE.value
in rule["match"]
):
match_against += [
x.strip()
for x in re.split(
delimiters_regex_pattern, context.output_text
)
for x in re.split(delimiters_regex_pattern, output_text)
if x.strip() != ""
]

Expand Down Expand Up @@ -154,6 +161,7 @@ def generate_html_images_for_context(
for action in rule["actions"]:
if action["name"] == "skip_generation":
return (
output_text,
None,
"",
"",
Expand Down Expand Up @@ -197,20 +205,157 @@ def generate_html_images_for_context(
exc_info=True,
)

context_prompt = ""
context_prompt = None

if context.params.trigger_mode == TriggerMode.INTERACTIVE and (
context.params.interactive_mode_prompt_generation_mode
== InteractiveModePromptGenerationMode.GENERATED_TEXT
or InteractiveModePromptGenerationMode.DYNAMIC
):
context_prompt = html.unescape(context.output_text or "")
context_prompt = html.unescape(output_text or "")

if context.params.trigger_mode == TriggerMode.CONTINUOUS and (
context.params.continuous_mode_prompt_generation_mode
== ContinuousModePromptGenerationMode.GENERATED_TEXT
):
context_prompt = html.unescape(context.output_text or "")
context_prompt = html.unescape(output_text or "")

if context.params.trigger_mode == TriggerMode.TOOL:
output_text = html.unescape(output_text or "").strip()
logger.info("\n\noutput_text:\n%s\n\n", output_text)

json_search = re.search(
r"(\b)?([{\[].*[\]}])(\b)?", output_text, flags=re.I | re.M | re.S | re.U
)

if not json_search:
logger.warning(
"No JSON output found in the output text: %s.\nTry enabling JSON grammar rules to avoid such errors.",
output_text,
)

json_text_original = json_search.group(0) if json_search else "{}"

try:
json_text = (
json_text_original.strip()
.replace("\r\n", "\n")
.replace("'", "")
.replace("“", '"') # yes, this actually happened.
.replace("”", '"') # llms are really creative and crazy...
.replace(
"{{", "{ {"
) # for some reason the json parser doesnt like this
.replace("}}", "} }")
)
except Exception as e:
logger.warning(
"JSON extraction from text failed: %s\n%s.\n\nTry enabling JSON grammar rules to avoid such errors.",
repr(e),
output_text,
)

json_text = "{}"

output_text = (
output_text.replace(json_text_original + "\n", "")
.replace("\n" + json_text_original, "")
.replace(json_text_original, "")
.replace("Action: ```json\n", "")
.replace("Action: ```json", "")
.replace("Action:\n", "")
.replace("Action:", "")
.replace("\n```json", "")
.replace("```json", "")
.replace("```json\n", "")
.replace("\n```", "")
.replace("```", "")
.strip("\r\n")
.strip("\n")
.strip()
)

json = None

if json_search and json_text and json_text not in ["[]", "{}", "()"]:
try:
json = loads(json_text)
except Exception as e:
logger.warning(
"Failed to parse JSON from output text: %s\n%s\n\nTry enabling JSON grammar rules to avoid such errors.",
repr(e),
json_text,
exc_info=True,
)

if json is not None:
tools: list[Any] = json if isinstance(json, list) else [json]

for tool in tools:
tool_name: str = (
tool.get("tool", None)
or tool.get("tool name", None)
or tool.get("tool_name", None)
or tool.get("tool call", None)
or tool.get("tool_call", None)
or tool.get("name", None)
or tool.get("function", None)
or tool.get("function_name", None)
or tool.get("function name", None)
or tool.get("function_call", None)
or tool.get("function call", None)
)

tool_params: dict = (
tool.get("tool_parameters", None)
or tool.get("tool parameters", None)
or tool.get("parameters", None)
or tool.get("tool_params", None)
or tool.get("tool params", None)
or tool.get("params", None)
or tool.get("tool_arguments", None)
or tool.get("tool arguments", None)
or tool.get("arguments", None)
or tool.get("tool_args", None)
or tool.get("tool args", None)
or tool.get("args", None)
)

if not tool_name or not tool_params:
continue

if tool_name.lower() in [
"generate_image",
"generate image",
"generateimage",
]:
context_prompt = (
tool_params.get("text", None)
or tool_params.get("prompt", None)
or tool_params.get("query", None)
or ""
)

if tool_name.lower() in ["add_text", "add text", "addtext"]:
tool_text = (
tool_params.get("text", None)
or tool_params.get("prompt", None)
or tool_params.get("query", None)
or ""
)
output_text = tool_text + (
"\n" + output_text if output_text else ""
)

if context_prompt is None:
return (
output_text,
None,
None,
None,
None,
None,
)

if ":" in context_prompt:
context_prompt = (
Expand All @@ -234,18 +379,19 @@ def generate_html_images_for_context(
generated_negative_prompt, context.params.base_negative_prompt
)

logger.info
(
"[SD WebUI Integration] Using stable-diffusion-webui to generate images."
+ (
(
f"\n"
f" Prompt: {full_prompt}\n"
f" Negative Prompt: {full_negative_prompt}"
)
if context.params.debug_mode_enabled
else ""
debug_info = (
(
f"\n"
f" Prompt: {full_prompt}\n"
f" Negative Prompt: {full_negative_prompt}"
)
if context.params.debug_mode_enabled
else ""
)

logger.info(
"[SD WebUI Integration] Using stable-diffusion-webui to generate images. %s",
debug_info,
)

try:
Expand All @@ -262,9 +408,11 @@ def generate_html_images_for_context(
denoising_strength=context.params.hires_fix_denoising_strength,
hr_sampler=context.params.hires_fix_sampler,
hr_force=context.params.hires_fix_enabled,
hr_second_pass_steps=context.params.hires_fix_sampling_steps
if context.params.hires_fix_enabled
else 0,
hr_second_pass_steps=(
context.params.hires_fix_sampling_steps
if context.params.hires_fix_enabled
else 0
),
steps=context.params.sampling_steps,
cfg_scale=context.params.cfg_scale,
width=context.params.width,
Expand Down Expand Up @@ -292,6 +440,7 @@ def generate_html_images_for_context(
if len(response.images) == 0:
logger.error("[SD WebUI Integration] Failed to generate any images.")
return (
output_text,
None,
generated_prompt,
generated_negative_prompt,
Expand Down Expand Up @@ -404,6 +553,7 @@ def generate_html_images_for_context(
attempt_vram_reallocation(VramReallocationTarget.LLM, context)

return (
output_text,
formatted_result.rstrip("\n"),
generated_prompt,
generated_negative_prompt,
Expand Down
5 changes: 4 additions & 1 deletion params.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


class TriggerMode(str, Enum):
TOOL = "tool"
CONTINUOUS = "continuous"
INTERACTIVE = "interactive"
MANUAL = "manual"
Expand Down Expand Up @@ -176,7 +177,9 @@ class RegexGenerationRule:
@dataclass
class UserPreferencesParams:
save_images: bool = field(default=True)
trigger_mode: TriggerMode = field(default=TriggerMode.INTERACTIVE)
trigger_mode: TriggerMode = field(default=TriggerMode.TOOL)
tool_mode_force_json_output_enabled: bool = field(default=True)
tool_mode_force_json_output_schema: str = field(default="")
interactive_mode_input_trigger_regex: str = field(
default=".*(send|upload|add|show|attach|generate)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)(s?)" # noqa E501
)
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ pylint
Pillow
types-Pillow
types-requests
stringcase
stringcase
partial-json-parser
git+https://github.com/evanrichards/json-schema-logits-processor.git#egg=json-schema-logits-processor
Loading

0 comments on commit b69c265

Please sign in to comment.