Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Live preview #234

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ To generate a prompt from a couple of words, use the /generate command and inclu

### Currently supported options

- live preview
- negative prompts
- swap model/checkpoint (_[see wiki](https://github.com/Kilvoctu/aiyabot/wiki/Model-swapping)_)
- sampling steps
Expand Down Expand Up @@ -49,7 +50,8 @@ To generate a prompt from a couple of words, use the /generate command and inclu
- 🎲 - randomize seed, then generate a new image with same parameters.
- 📋 - view the generated image's information.
- ⬆️ - upscale the generated image with defaults. Batch grids require use of the drop downs
- ❌ - deletes the generated image.
- ❌ - deletes the generated image. In Live preview this button interrupts generation process
- ➡️ - skips the current image generation in live preview and go to next batch (if there's more than 1)
- dropdown menus - batch images produce two drop down menus for the first 25 images.
- The first menu prompts the bot to send only the images that you select at single images
- The second menu prompts the bot to upscale the selected image from the batch.
Expand Down
67 changes: 65 additions & 2 deletions core/stablecog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import contextlib
import threading
import discord
import io
import math
Expand All @@ -15,7 +17,47 @@
from core import viewhandler
from core import settings
from core import settingscog
from threading import Thread

async def update_progress(event_loop, status_message_task, s, queue_object, tries):
status_message = status_message_task.result()
try:
progress_data = s.get(url=f'{settings.global_var.url}/sdapi/v1/progress').json()

if progress_data["current_image"] is None and tries <= 10:
time.sleep(3)
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries + 1))
return

if progress_data["current_image"] is None and tries > 10:
return

image = Image.open(io.BytesIO(base64.b64decode(progress_data["current_image"])))

with contextlib.ExitStack() as stack:
buffer = stack.enter_context(io.BytesIO())
image.save(buffer, 'PNG')
buffer.seek(0)
file = discord.File(fp=buffer, filename=f'{queue_object.seed}.png')

ips = '?'
if progress_data["eta_relative"] != 0:
ips = round((int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2)

view = viewhandler.ProgressView()

await status_message.edit(
content=f'**Author**: {queue_object.ctx.author.id} ({queue_object.ctx.author.name})\n'
f'**Prompt**: `{queue_object.prompt}`\n**Progress**: {round(progress_data.get("progress") * 100, 2)}% '
f'\n{progress_data.get("state").get("sampling_step")}/{queue_object.steps} iterations, '
f'~{ips} it/s'
f'\n**ETA**: {round(progress_data.get("eta_relative"), 2)} seconds',
files=[file], view=view)
except Exception as e:
print('Something goes wrong...', str(e))

time.sleep(1)
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may also be an issue. What I'm experiencing:

  • start /draw
  • live preview updates begin
  • generation finishes (I can see in SD logs it is done and breakpoint hit in aiyabot in dream)
  • live preview hangs for some time
  • eventually, live preview deletes and bot posts finished image

I believe I am seeing this because my generation times are under ~8 seconds (using 3070 and SDXL Lightning) but the update_progress task waits at least 10 tries before finishing, regardless of whether the generation is done yet or not.

Based on my breakpoints what I see happening:

  • start /draw
  • live preview updates begin and queues to event_loop
  • generation finishes but cannot post_dream because queue is still full of update_progress tasks that keep being created
  • eventually update_progress hits 10 tries and exits, queue is cleared, and eventually gets to post_dream where it queues up finished image message

I see that the progress view is also deleted once post_dream starts so that seems like a good order to do things in but waiting for update_progress to finish takes too long.

This isn't perfect but its doing better:

async def update_progress(event_loop, status_message_task, s, queue_object, tries, had_image):
    status_message = status_message_task.result()
    has_image = False
    try:
        progress_data = s.get(url=f'{settings.global_var.url}/sdapi/v1/progress').json()

        if not had_image and progress_data["current_image"] is None and tries <= 10:
            time.sleep(1)
            event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries + 1, had_image))
            return

        if progress_data["current_image"] is None and tries > 10:
            return

        if had_image and not has_image:
            return

        has_image = True
        had_image = True

        image = Image.open(io.BytesIO(base64.b64decode(progress_data["current_image"])))

        with contextlib.ExitStack() as stack:
            buffer = stack.enter_context(io.BytesIO())
            image.save(buffer, 'PNG')
            buffer.seek(0)
            file = discord.File(fp=buffer, filename=f'{queue_object.seed}.png')
        ips = '?'
        if progress_data["eta_relative"] != 0:
            ips = round((int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2)

        view = viewhandler.ProgressView()

        await status_message.edit(
            content=f'**Author**: {queue_object.ctx.author.id} ({queue_object.ctx.author.name})\n'
                    f'**Prompt**: `{queue_object.prompt}`\n**Progress**: {round(progress_data.get("progress") * 100, 2)}% '
                    f'\n{progress_data.get("state").get("sampling_step")}/{queue_object.steps} iterations, '
                    f'~{ips} it/s'
                    f'\n**ETA**: {round(progress_data.get("eta_relative"), 2)} seconds',
            files=[file], view=view)
    except Exception as e:
        print('Something goes wrong...', str(e))

    time.sleep(1)

    event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries, had_image))

The main difference is that it returns from update_progress if the previous task got an image and the current task did not.

EDIT: I think it may be better to do the "has" check but check for empty string in progress_data["state"]["job"] instead. With a small 1-2 try buffer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Version using jobs...this seems to work better

async def update_progress(event_loop, status_message_task, s, queue_object, tries, any_job, tries_since_no_job):
    status_message = status_message_task.result()
    try:
        progress_data = s.get(url=f'{settings.global_var.url}/sdapi/v1/progress').json()
        job_name = progress_data.get('state').get('job')
        if job_name != '':
            any_job = True

        if progress_data["current_image"] is None:
            if job_name == '':
                if any_job:
                    if tries_since_no_job >= 2:
                        return
                    time.sleep(1)
                    event_loop.create_task(
                        update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job + 1))
                    return
                else:
                    # escape hatch
                    if tries > 10:
                        return
                    time.sleep(1)
                    event_loop.create_task(
                        update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job))
                    return
            else:
                time.sleep(1)
                event_loop.create_task(
                    update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0))
                return

        image = Image.open(io.BytesIO(base64.b64decode(progress_data["current_image"])))

        with contextlib.ExitStack() as stack:
            buffer = stack.enter_context(io.BytesIO())
            image.save(buffer, 'PNG')
            buffer.seek(0)
            file = discord.File(fp=buffer, filename=f'{queue_object.seed}.png')
        ips = '?'
        if progress_data["eta_relative"] != 0:
            ips = round(
                (int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2)

        view = viewhandler.ProgressView()

        await status_message.edit(
            content=f'**Author**: {queue_object.ctx.author.id} ({queue_object.ctx.author.name})\n'
                    f'**Prompt**: `{queue_object.prompt}`\n**Progress**: {round(progress_data.get("progress") * 100, 2)}% '
                    f'\n{progress_data.get("state").get("sampling_step")}/{queue_object.steps} iterations, '
                    f'~{ips} it/s'
                    f'\n**ETA**: {round(progress_data.get("eta_relative"), 2)} seconds',
            files=[file], view=view)
    except Exception as e:
        print('Something goes wrong...', str(e))

    time.sleep(1)

    event_loop.create_task(
        update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe I am seeing this because my generation times are under ~8 seconds (using 3070 and SDXL Lightning) but the update_progress task waits at least 10 tries before finishing, regardless of whether the generation is done yet or not.

When there was a time.sleep(1) no errors but messages 404
That's why here I replaced 1 with 3
Bad choice.


class StableCog(commands.Cog, name='Stable Diffusion', description='Create images from natural language.'):
ctx_parse = discord.ApplicationContext
Expand Down Expand Up @@ -359,6 +401,23 @@ def post(self, event_loop: queuehandler.GlobalQueue.post_event_loop, post_queue_
def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: queuehandler.DrawObject):
try:
start_time = time.time()

status_message_task = event_loop.create_task(queue_object.ctx.channel.send(
f'**Author**: {queue_object.ctx.author.id} ({queue_object.ctx.author.name})\n'
f'**Prompt**: `{queue_object.prompt}`\n**Progress**: initialization...'
f'\n0/{queue_object.steps} iteractions, 0.00 it/s'
f'\n**Relative ETA**: initialization...'))

def worker():
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, 0))
return

status_thread = threading.Thread(target=worker)

def start_thread(*args):
status_thread.start()

status_message_task.add_done_callback(start_thread)

# construct a payload for data model, then the normal payload
model_payload = {
Expand Down Expand Up @@ -517,6 +576,10 @@ def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: q
queue_object.view.input_tuple = new_tuple

# set up discord message
def post_dream():
event_loop.create_task(status_message_task.result().delete())
Thread(target=post_dream, daemon=True).start()

content = f'> for {queue_object.ctx.author.name}'
noun_descriptor = "drawing" if image_count == 1 else f'{image_count} drawings'
draw_time = '{0:.3f}'.format(end_time - start_time)
Expand Down Expand Up @@ -583,10 +646,10 @@ def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: q
embed = discord.Embed(title='txt2img failed', description=f'{e}\n{traceback.print_exc()}',
color=settings.global_var.embed_color)
event_loop.create_task(queue_object.ctx.channel.send(embed=embed))

# check each queue for any remaining tasks
queuehandler.process_queue()



def setup(bot):
bot.add_cog(StableCog(bot))

Expand Down
41 changes: 41 additions & 0 deletions core/viewhandler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import discord
import random
import re
import requests
import os
from discord.ui import InputText, Modal, View

Expand Down Expand Up @@ -299,7 +300,47 @@ async def callback(self, interaction: discord.Interaction):
else:
await queuehandler.process_dream(draw_dream, queuehandler.DrawObject(stablecog.StableCog(self), *prompt_tuple, DrawView(prompt_tuple)))
await interaction.response.send_message(f'<@{interaction.user.id}>, {settings.messages()}\nQueue: ``{len(queuehandler.GlobalQueue.queue)}``{prompt_output}')
# view that holds the interrupt button for progress
class ProgressView(View):
def __init__(self):
super().__init__(timeout=None)

@discord.ui.button(
custom_id="button_interrupt",
emoji="❌")
async def button_interrupt(self, button, interaction):
try:
if str(interaction.user.id) not in interaction.message.content:
await interaction.response.send_message("Cannot interrupt other people's tasks!", ephemeral=True)
return
button.disabled = True
s = settings.authenticate_user()
s.post(url=f'{settings.global_var.url}/sdapi/v1/interrupt')
await interaction.response.edit_message(view=self)
except Exception as e:
button.disabled = True
await interaction.response.send_message("I have no idea why, but I broke. Either the request has fallen "
"through "
"or I no longer have the message in my cache.\n"
f"Good luck:\n`{str(e)}`", ephemeral=True)
@discord.ui.button(
custom_id="button_skip",
emoji="➡️")
async def button_skip(self, button, interaction):
try:
if str(interaction.user.id) not in interaction.message.content:
await interaction.response.send_message("Cannot skip other people's tasks!", ephemeral=True)
return
button.disabled = True
s = settings.authenticate_user()
s.post(url=f'{settings.global_var.url}/sdapi/v1/skip')
await interaction.response.edit_message(view=self)
except Exception as e:
button.disabled = True
await interaction.response.send_message("I have no idea why, but I broke. Either the request has fallen "
"through "
"or I no longer have the message in my cache.\n"
f"Good luck:\n`{str(e)}`", ephemeral=True)

# creating the view that holds the buttons for /draw output
class DrawView(View):
Expand Down
Loading