Skip to content

Commit

Permalink
Add serialization of commands in embed, a link to completed entry, an…
Browse files Browse the repository at this point in the history
…d a thumbnail with link
  • Loading branch information
Jimmy committed Sep 12, 2022
1 parent eecd4cd commit c85d7c4
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 14 deletions.
Empty file added __init__.py
Empty file.
103 changes: 89 additions & 14 deletions bot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse
import asyncio
import json
import os
import pathlib
import random
import string
import sys
import time

from io import BytesIO
Expand All @@ -20,6 +22,16 @@
from discord.ui import Button, View
from docarray import Document, DocumentArray

SELF_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(SELF_DIR))

from serializers import (
serialize_image_request,
serialize_interpolate_request,
serialize_riff_request,
serialize_upscale_request,
)


parser = argparse.ArgumentParser()
parser.add_argument('token', help='Discord token')
Expand Down Expand Up @@ -51,6 +63,7 @@
type=int, required=False)
args = parser.parse_args()


# Load up diffusers NSFW detection model and the NSFW wordlist detector.
nsfw_toxic_detection_fn = None
nsfw_wordlist: list[str] = []
Expand Down Expand Up @@ -477,6 +490,21 @@ async def select_aspect_ratio(self, interaction: discord.Interaction,
await interaction.response.defer()


async def send_alert_embed(
channel: discord.abc.GuildChannel,
author_id: str,
work_msg: discord.Message,
serialized_cmd: str,
):
guild_id = str(channel.guild.id)
channel_id = str(channel.id)
completed_id = str(work_msg.id)
embed = discord.Embed()
embed.description = f'Your request has finished. [Please view it here](https://discord.com/channels/{guild_id}/{channel_id}/{completed_id}).'
embed.add_field(name="Command Executed", value=serialized_cmd, inline=False)
embed.set_thumbnail(url=work_msg.attachments[0].url)
await channel.send(f'Job completed for <@{author_id}>.', embed=embed)


async def _image(
channel: discord.abc.GuildChannel,
Expand Down Expand Up @@ -556,25 +584,38 @@ async def _image(

file = to_discord_file_and_maybe_check_safety(image_loc)
if seed_search is True:
await work_msg.edit(
content=f'Image generation for prompt "{prompt}" by <@{author_id}> complete. The ID for your images is `{short_id}`.',
seed_lst = []
for i, _s in enumerate(seeds):
seed_lst.append(f'{i}: {_s}')
seeds_str = ', '.join(seed_lst)

work_msg = await work_msg.edit(
content=f'Image generation for prompt "{prompt}" by <@{author_id}> complete. The ID for your images is `{short_id}`. Seeds used were {seeds_str}',
attachments=[file])
elif typ == 'promptarray':
work_msg = await work_msg.edit(
content=f'Image generation for prompt array "{prompt}" by <@{author_id}> complete. The ID for your images is `{short_id}`.',
attachments=[file])
else:
btns = FourImageButtons(message_id=work_msg.id, short_id=short_id)
btns.serialize_to_json_and_store()
client.add_view(btns, message_id=work_msg.id)
await work_msg.edit(
work_msg = await work_msg.edit(
content=f'Image generation for prompt "{prompt}" by <@{author_id}> complete. The ID for your images is `{short_id}`.',
attachments=[file],
view=btns)
if seed_search is True:
await channel.send(short_id)
if seeds is not None:
seed_lst = []
for i, _s in enumerate(seeds):
seed_lst.append(f'{i}: {_s}')
seeds_str = ', '.join(seed_lst)
await channel.send(f'Seeds used were {seeds_str}')

serialized_cmd = serialize_image_request(
prompt=prompt,
height=height,
sampler=sampler,
scale=scale,
seed=seed,
seed_search=seed_search,
steps=steps,
width=width)
await send_alert_embed(channel, author_id, work_msg, serialized_cmd)

except Exception as e:
await channel.send(f'Got unknown error on prompt "{prompt}": {str(e)}')
finally:
Expand Down Expand Up @@ -712,10 +753,24 @@ async def _riff(
btns = FourImageButtons(message_id=work_msg.id, short_id=short_id)
btns.serialize_to_json_and_store()
client.add_view(btns, message_id=work_msg.id)
await work_msg.edit(
work_msg = await work_msg.edit(
content=f'Image generation for riff on `{docarray_id}` index {str(idx)} for <@{author_id}> complete. The ID for your new images is `{short_id}`.',
attachments=[file],
view=btns)

serialized_cmd = serialize_riff_request(
docarray_id=docarray_id,
idx=idx,
height=height,
iterations=iterations,
latentless=latentless,
prompt=prompt,
sampler=sampler,
scale=scale,
seed=seed,
strength=strength,
width=width)
await send_alert_embed(channel, author_id, work_msg, serialized_cmd)
except Exception as e:
await channel.send(f'Got unknown error on riff "{docarray_id}" index {str(idx)}: {str(e)}')
finally:
Expand Down Expand Up @@ -800,6 +855,9 @@ async def _interpolate(
):
global currently_fetching_ai_image
author_id = str(user.id)

prompt1 = prompt1.strip()
prompt2 = prompt2.strip()

if args.restrict_all_to_channel:
if channel.id != args.restrict_all_to_channel:
Expand Down Expand Up @@ -851,9 +909,20 @@ async def _interpolate(
short_id = output['id']

file = to_discord_file_and_maybe_check_safety(image_loc)
await work_msg.edit(
work_msg = await work_msg.edit(
content=f'Image generation for interpolate on `{prompt1}` to `{prompt2}` for <@{author_id}> complete. The ID for your new images is `{short_id}`.',
attachments=[file])

serialized_cmd = serialize_interpolate_request(
prompt1=prompt1,
prompt2=prompt2,
height=height,
sampler=sampler,
scale=scale,
seed=seed,
strength=strength,
width=width)
await send_alert_embed(channel, author_id, work_msg, serialized_cmd)
except Exception as e:
await channel.send(f'Got unknown error on interpolate `{prompt1}` to `{prompt2}`: {str(e)}')
finally:
Expand Down Expand Up @@ -959,10 +1028,16 @@ async def _upscale(
raise Exception(err)
image_loc = output['image_loc']

print()

file = to_discord_file_and_maybe_check_safety(image_loc)
await work_msg.edit(
work_msg = await work_msg.edit(
content=f'Image generation for upscale on `{docarray_id}` index {str(idx)} for <@{author_id}> complete.',
attachments=[file])

serialized_cmd = serialize_upscale_request(docarray_id=docarray_id,
idx=idx)
await send_alert_embed(channel, author_id, work_msg, serialized_cmd)
completed = True
except Exception as e:
await channel.send(f'Got unknown error on upscale "{docarray_id}" index {str(idx)}: {str(e)}')
Expand Down
167 changes: 167 additions & 0 deletions serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from typing import Optional


def remove_quotes_from_cmd_kwargs(cmd_kwargs):
split = cmd_kwargs.split(',')
keywordarg_list = []
for keywordarg in split:
as_pair = keywordarg.split('=')
if as_pair[1].startswith('\'') and as_pair[1].endswith('\''):
as_pair[1] = as_pair[1][1:-1]
if as_pair[1] == 'False':
continue
keywordarg_list.append(f'{as_pair[0]}={as_pair[1]}')
return ', '.join(keywordarg_list)


def prompt_un_parenthesis(prompt):
'''
Handle parenthesis in slash command prompts.
'''
if '(' in prompt:
prompt = prompt.replace('(', '「')
if ')' in prompt:
prompt = prompt.replace(')', '」')
return prompt


def serialize_image_request(
prompt: str,

height: Optional[int]=None,
sampler: Optional[str]=None,
scale: Optional[float]=None,
seed: Optional[int]=None,
seed_search: bool=None,
steps: Optional[int]=None,
width: Optional[int]=None,
):
'''
Serialize an image request to >image format.
'''
options = ''
if height is not None:
options += f'{height=},'
if sampler is not None:
options += f'{sampler=},'
if scale is not None:
options += f'{scale=},'
if seed is not None:
options += f'{seed=},'
if seed_search is not None:
options += f'{seed_search=},'
if steps is not None:
options += f'{steps=},'
if width is not None:
options += f'{width=},'
if len(options) > 0 and options[-1] == ',':
options = f'{options[:-1]}'
options = remove_quotes_from_cmd_kwargs(options)

prompt = prompt_un_parenthesis(prompt)

as_string = f'>image {prompt}'
if options == '':
return as_string
return f'{as_string} ({options})'


def serialize_riff_request(
docarray_id: str,
idx: int,

height: Optional[int]=None,
iterations: Optional[int]=None,
latentless: bool=False,
prompt: Optional[str]=None,
sampler: Optional[str]=None,
scale: Optional[float]=None,
seed: Optional[int]=None,
strength: Optional[float]=None,
width: Optional[int]=None,
):
'''
Serialize a riff request to >riff format.
'''
options = ''
if height is not None:
options += f'{height=},'
if iterations is not None:
options += f'{iterations=},'
if latentless is not None:
options += f'{latentless=},'
if prompt is not None:
prompt = prompt_un_parenthesis(prompt)
options += f'{prompt=},'
if sampler is not None:
options += f'{sampler=},'
if scale is not None:
options += f'{scale=},'
if seed is not None:
options += f'{seed=},'
if strength is not None:
options += f'{strength=},'
if width is not None:
options += f'{width=},'
if len(options) > 0 and options[-1] == ',':
options = f'{options[:-1]}'
options = remove_quotes_from_cmd_kwargs(options)

as_string = f'>riff {docarray_id} {idx}'
if options == '':
return as_string
return f'{as_string} ({options})'


def serialize_interpolate_request(
prompt1: str,
prompt2: str,

height: Optional[int]=None,
sampler: Optional[str]=None,
scale: Optional[float]=None,
seed: Optional[int]=None,
strength: Optional[float]=None,
width: Optional[int]=None,
):
'''
Serialize an interpolate request to >interpolate format.
'''
options = ''
if height is not None:
options += f'{height=},'
if sampler is not None:
options += f'{sampler=},'
if scale is not None:
options += f'{scale=},'
if seed is not None:
options += f'{seed=},'
if strength is not None:
options += f'{strength=},'
if width is not None:
options += f'{width=},'
if len(options) > 0 and options[-1] == ',':
options = f'{options[:-1]}'
options = remove_quotes_from_cmd_kwargs(options)

prompt1 = prompt_un_parenthesis(prompt1)
if '|' in prompt1:
prompt1 = prompt1.replace('|', '')
prompt2 = prompt_un_parenthesis(prompt2)
if '|' in prompt2:
prompt2 = prompt2.replace('|', '')

as_string = f'>interpolate {prompt1} | {prompt2}'
if options == '':
return as_string
return f'{as_string} ({options})'


def serialize_upscale_request(
docarray_id: str,
idx: int,
):
'''
Serialize an upscale request to '>upscale'.
'''
return f'>upscale {docarray_id} {idx}'

0 comments on commit c85d7c4

Please sign in to comment.