Skip to content

Commit

Permalink
try fix key error
Browse files Browse the repository at this point in the history
move some code around and make sure call payloadformatter setup(), also cleanup comments
  • Loading branch information
Kilvoctu committed Oct 20, 2022
1 parent 8534a11 commit 5b7c713
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 42 deletions.
41 changes: 2 additions & 39 deletions core/PayloadFormatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,21 @@
from enum import Enum
import platform
from core import stablecog
responsestr = {}

responsestr = {}

# only need to get the schema once
def setup():
global responsestr
response_format = requests.get("http://127.0.0.1:7860/config")
responsestr = response_format.json()


# prob don't need to do this lmao
class PayloadFormat(Enum):
TXT2IMG = 0
IMG2IMG = 1
UPSCALE = 2


def do_format(StableCog, payload_format: PayloadFormat):

# dependencies have ids that point to components. these components (usually) have a label (like "Sampling steps")
# and a default value (like "20"). we find the dependency we want (key "js" must have value "submit" for txt2img,
# "submit_img2img" for img2img, and "get_extras_tab_index" for upscale).
# then iterate through the ids in that dependency and match them with the corresponding id in the components.
# store the label:value pairs in txt2imgjson.
# example:
# {"components":[
# { "id": 6,
# "props":{
# "label":"Prompt",
# "value":""
# }
# }, etc ],
# "dependencies":[
# { "inputs":{
# 6,etc
# },
# "js":"submit", etc
# }]
# }
#
# dict["dependencies"]["input"][0] equals 6 which is the id of the component for Prompt
setup()
dependenciesjson = responsestr["dependencies"]
componentsjson = responsestr["components"]
dependencylist = []
Expand All @@ -61,11 +35,7 @@ def do_format(StableCog, payload_format: PayloadFormat):
dependencylist.append(i.copy())
except:
dependencylist.append(i)
# later on, json payload uses the function index to determine what parameters to accept.
# function index is the position in dependencies in the schema that the function appears,
# so txt2img is the 13th function (in this version, could change in the future)
if dependenciesjson[dep]["js"] == "submit" and txt2img_fn_index == 0:
# not sure if it's different on linux but this is a guess
txt2img_fn_index = dep
elif dependenciesjson[dep]["js"] == "submit_img2img" and img2img_fn_index == 0:
img2img_fn_index = dep
Expand All @@ -75,30 +45,23 @@ def do_format(StableCog, payload_format: PayloadFormat):
for identifier in dependencylist:
for component in componentsjson:
if identifier == component["id"]:
# one of the labels is empty
if component["props"].get("name") == "label":
labelvaluetuplelist.append(("", 0))
# img2img has a duplicate label that messes things up
elif component["props"].get("label") == "Image for img2img" and component["props"].get("elem_id") != "img2img_image":
labelvaluetuplelist.append(("", None))
# upscale has a duplicate label that messes things up
elif component["props"].get("label") == "Source" and component["props"].get("elem_id") == "pnginf_image":
labelvaluetuplelist.append(("", None))
# only gonna use the one upscaler, idc
elif component["props"].get("label") == "Upscaler 1":
labelvaluetuplelist.append((component["props"].get("label"), "ESRGAN_4x"))
# slightly changing the img2img Script label so it doesn't clash with another label of the same name
elif component["props"].get("label") == "Script" and len(component["props"].get("choices")) > 3:
labelvaluetuplelist.append(("Scripts", "None"))
elif component["props"].get("label") == "Sampling method":
labelvaluetuplelist.append(("Sampling method", "Euler a"))
StableCog.sampling_methods = component["props"].get("choices")
# these are the labels and values we actually care about
else:
labelvaluetuplelist.append((component["props"].get("label"), component["props"].get("value")))
break

# iterate through txt2imgjson, find a label you're looking for, and store the index for later use by StableCog
for i in range(0, len(labelvaluetuplelist)):
if labelvaluetuplelist[i][0] == "Prompt":
StableCog.prompt_ind = i
Expand Down
6 changes: 3 additions & 3 deletions core/stablecog.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, bot):
self.wait_message = []
self.bot = bot
self.url = 'http://127.0.0.1:7860/api/predict'
#initialize indices for PayloadFormatter
#initialize indices for PayloadFormatter then update
self.prompt_ind = 0
self.exclude_ind = 0
self.sample_ind = 0
Expand All @@ -54,6 +54,7 @@ def __init__(self, bot):
self.seed_ind = 0
self.denoise_ind = 0
self.data_ind = 0
PayloadFormatter.do_format(self, PayloadFormatter.PayloadFormat.TXT2IMG)

@commands.slash_command(name = "draw", description = "Create an image")
@option(
Expand Down Expand Up @@ -130,8 +131,7 @@ async def dream_handler(self, ctx: discord.ApplicationContext, *,
strength: Optional[float] = 0.75,
init_image: Optional[discord.Attachment] = None,):
print(f'Request -- {ctx.author.name}#{ctx.author.discriminator} -- Prompt: {prompt}')
#apply indices from PayloadFormatter and confirm
PayloadFormatter.do_format(self, PayloadFormatter.PayloadFormat.TXT2IMG)
#confirm indices from PayloadFormatter
print(f'Indices-prompt:{self.prompt_ind}, exclude:{self.exclude_ind}, steps:{self.sample_ind}, height:{self.resy_ind}, width:{self.resx_ind}, cfg scale:{self.conform_ind}, sampler:{self.sampling_methods_ind}, seed:{self.seed_ind}')
if init_image:
PayloadFormatter.do_format(self, PayloadFormatter.PayloadFormat.IMG2IMG)
Expand Down

0 comments on commit 5b7c713

Please sign in to comment.