Skip to content

Commit

Permalink
update webui.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 9, 2023
1 parent 78b3a69 commit 5be73c9
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions modules/api/api.py
Expand Up @@ -137,6 +137,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.cache = dict()
self.s3_client = boto3.client('s3')
self.s3_resource= boto3.resource('s3')
self.generated_images_s3uri = os.environ.get('generated_images_s3uri', None)

def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
Expand Down Expand Up @@ -399,6 +400,25 @@ def download_s3files(self, s3uri, path):

json.dump(self.cache, open('cache', 'w'))

def post_invocations(self, b64images):
if self.generated_images_s3uri:
bucket, key = self.get_bucket_and_key(self.generated_images_s3uri)
images = []
for b64image in b64images:
image = decode_base64_to_image(b64image).convert('RGB')
output = io.BytesIO()
image.save(output, format='JPEG')
image_id = str(uuid.uuid4())
self.s3_client.put_object(
Body=output.getvalue(),
Bucket=bucket,
Key=f'{key}/{image_id}.jpg'
)
images.append(f's3://{bucket}/{key}/{image_id}.jpg')
return images
else:
return b64images

def invocations(self, req: InvocationsRequest):
print('-------invocation------')
print(req)
Expand Down Expand Up @@ -433,24 +453,26 @@ def invocations(self, req: InvocationsRequest):
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images)
shared.opts.data = default_options
return response
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images)
shared.opts.data = default_options
return response
elif req.task == 'extras-single-image':
response = self.extras_single_image_api(req.extras_single_payload)
response.image = self.post_invocations([response.image])[0]
shared.opts.data = default_options
return response
elif req.task == 'extras-batch-images':
response = self.extras_batch_images_api(req.extras_batch_payload)
response.images = self.post_invocations(response.images)
shared.opts.data = default_options
return response
elif req.task == 'sd-models':
return self.get_sd_models()
else:
raise NotImplementedError
except Exception as e:
Expand All @@ -463,3 +485,9 @@ def ping(self):
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)

def get_bucket_and_key(self, s3uri):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]
return bucket, key

0 comments on commit 5be73c9

Please sign in to comment.