Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Live preview #234

Closed
wants to merge 13 commits into from
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ To generate a prompt from a couple of words, use the /generate command and inclu

### Currently supported options

- live preview
- negative prompts
- swap model/checkpoint (_[see wiki](https://github.com/Kilvoctu/aiyabot/wiki/Model-swapping)_)
- sampling steps
Expand Down Expand Up @@ -49,7 +50,8 @@ To generate a prompt from a couple of words, use the /generate command and inclu
- 🎲 - randomize seed, then generate a new image with same parameters.
- 📋 - view the generated image's information.
- ⬆️ - upscale the generated image with defaults. Batch grids require use of the drop downs
- ❌ - deletes the generated image.
- ❌ - deletes the generated image. In Live preview this button interrupts generation process
- ➡️ - skips the current image generation in live preview and go to next batch (if there's more than 1)
- dropdown menus - batch images produce two drop down menus for the first 25 images.
- The first menu prompts the bot to send only the images that you select at single images
- The second menu prompts the bot to upscale the selected image from the batch.
Expand Down
5 changes: 5 additions & 0 deletions core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
display_ignored_words = "False"
# These words will be added to the beginning of the negative prompt
negative_prompt_prefix = []
# the time, in seconds, between when AIYA checks for generation progress from SD -- can be a float
preview_update_interval = 3


# the fallback channel defaults template for AIYA if nothing is set
Expand Down Expand Up @@ -125,6 +127,7 @@ class GlobalVar:
negative_prompt_prefix = []
spoiler = False
spoiler_role = None
preview_update_interval = 3


global_var = GlobalVar()
Expand Down Expand Up @@ -512,6 +515,8 @@ def populate_global_vars():
global_var.prompt_ignore_list = [x for x in config['prompt_ignore_list']]
global_var.display_ignored_words = config['display_ignored_words']
global_var.negative_prompt_prefix = [x for x in config['negative_prompt_prefix']]
if config['preview_update_interval'] is not None:
global_var.preview_update_interval = float(config['preview_update_interval'])

# create persistent session since we'll need to do a few API calls
s = authenticate_user()
Expand Down
91 changes: 89 additions & 2 deletions core/stablecog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import contextlib
import threading
import discord
import io
import math
Expand All @@ -15,7 +17,71 @@
from core import viewhandler
from core import settings
from core import settingscog
from threading import Thread

async def update_progress(event_loop, status_message_task, s, queue_object, tries, any_job, tries_since_no_job):
status_message = status_message_task.result()
try:
progress_data = s.get(url=f'{settings.global_var.url}/sdapi/v1/progress').json()
job_name = progress_data.get('state').get('job')
if job_name != '':
any_job = True

if progress_data["current_image"] is None:
if job_name == '':
if any_job:
if tries_since_no_job >= 2:
return
time.sleep(settings.global_var.preview_update_interval)
event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job + 1))
return
else:
# escape hatch
if tries > 10:
return
time.sleep(settings.global_var.preview_update_interval)
event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job))
return
else:
time.sleep(settings.global_var.preview_update_interval)
event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0))
return

image = Image.open(io.BytesIO(base64.b64decode(progress_data["current_image"])))

with contextlib.ExitStack() as stack:
buffer = stack.enter_context(io.BytesIO())
image.save(buffer, 'PNG')
buffer.seek(0)
filename = f'{queue_object.seed}.png'
if queue_object.spoiler:
filename = f'SPOILER_{queue_object.seed}.png'
fp = buffer
file = discord.File(fp, filename)
ips = '?'
if progress_data["eta_relative"] != 0:
ips = round(
(int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2)

view = viewhandler.ProgressView()

await status_message.edit(
content=f'**Author**: {queue_object.ctx.author.id} ({queue_object.ctx.author.name})\n'
f'**Prompt**: `{queue_object.prompt}`\n**Progress**: {round(progress_data.get("progress") * 100, 2)}% '
f'\n{progress_data.get("state").get("sampling_step")}/{queue_object.steps} iterations, '
f'~{ips} it/s'
f'\n**ETA**: {round(progress_data.get("eta_relative"), 2)} seconds',
files=[file], view=view)
except Exception as e:
print('Something goes wrong...', str(e))

time.sleep(1)

event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0))

class StableCog(commands.Cog, name='Stable Diffusion', description='Create images from natural language.'):
ctx_parse = discord.ApplicationContext
Expand Down Expand Up @@ -381,6 +447,23 @@ def post(self, event_loop: queuehandler.GlobalQueue.post_event_loop, post_queue_
def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: queuehandler.DrawObject):
try:
start_time = time.time()

status_message_task = event_loop.create_task(queue_object.ctx.channel.send(
f'**Author**: {queue_object.ctx.author.id} ({queue_object.ctx.author.name})\n'
f'**Prompt**: `{queue_object.prompt}`\n**Progress**: initialization...'
f'\n0/{queue_object.steps} iteractions, 0.00 it/s'
f'\n**Relative ETA**: initialization...'))

def worker():
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, 0, False, 0))
return

status_thread = threading.Thread(target=worker)

def start_thread(*args):
status_thread.start()

status_message_task.add_done_callback(start_thread)

# construct a payload for data model, then the normal payload
model_payload = {
Expand Down Expand Up @@ -539,6 +622,10 @@ def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: q
queue_object.view.input_tuple = new_tuple

# set up discord message
def post_dream():
event_loop.create_task(status_message_task.result().delete())
Thread(target=post_dream, daemon=True).start()

content = f'> for {queue_object.ctx.author.name}'
noun_descriptor = "drawing" if image_count == 1 else f'{image_count} drawings'
draw_time = '{0:.3f}'.format(end_time - start_time)
Expand Down Expand Up @@ -611,10 +698,10 @@ def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: q
embed = discord.Embed(title='txt2img failed', description=f'{e}\n{traceback.print_exc()}',
color=settings.global_var.embed_color)
event_loop.create_task(queue_object.ctx.channel.send(embed=embed))

# check each queue for any remaining tasks
queuehandler.process_queue()



def setup(bot):
bot.add_cog(StableCog(bot))

Expand Down
41 changes: 41 additions & 0 deletions core/viewhandler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import discord
import random
import re
import requests
import os
from discord.ui import InputText, Modal, View

Expand Down Expand Up @@ -298,7 +299,47 @@ async def callback(self, interaction: discord.Interaction):
else:
await queuehandler.process_dream(draw_dream, queuehandler.DrawObject(stablecog.StableCog(self), *prompt_tuple, DrawView(prompt_tuple)))
await interaction.response.send_message(f'<@{interaction.user.id}>, {settings.messages()}\nQueue: ``{len(queuehandler.GlobalQueue.queue)}``{prompt_output}')
# view that holds the interrupt button for progress
class ProgressView(View):
def __init__(self):
super().__init__(timeout=None)

@discord.ui.button(
custom_id="button_interrupt",
emoji="❌")
async def button_interrupt(self, button, interaction):
try:
if str(interaction.user.id) not in interaction.message.content:
await interaction.response.send_message("Cannot interrupt other people's tasks!", ephemeral=True)
return
button.disabled = True
s = settings.authenticate_user()
s.post(url=f'{settings.global_var.url}/sdapi/v1/interrupt')
await interaction.response.edit_message(view=self)
except Exception as e:
button.disabled = True
await interaction.response.send_message("I have no idea why, but I broke. Either the request has fallen "
"through "
"or I no longer have the message in my cache.\n"
f"Good luck:\n`{str(e)}`", ephemeral=True)
@discord.ui.button(
custom_id="button_skip",
emoji="➡️")
async def button_skip(self, button, interaction):
try:
if str(interaction.user.id) not in interaction.message.content:
await interaction.response.send_message("Cannot skip other people's tasks!", ephemeral=True)
return
button.disabled = True
s = settings.authenticate_user()
s.post(url=f'{settings.global_var.url}/sdapi/v1/skip')
await interaction.response.edit_message(view=self)
except Exception as e:
button.disabled = True
await interaction.response.send_message("I have no idea why, but I broke. Either the request has fallen "
"through "
"or I no longer have the message in my cache.\n"
f"Good luck:\n`{str(e)}`", ephemeral=True)

# creating the view that holds the buttons for /draw output
class DrawView(View):
Expand Down
Loading