From 5be73c960274cd12f42d6a32e2ad74a95acc0e32 Mon Sep 17 00:00:00 2001 From: xieyongliang Date: Sun, 9 Apr 2023 21:25:27 +0800 Subject: [PATCH] update webui.py --- modules/api/api.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index b442f8bce4a..edda5c941d9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -137,6 +137,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock): self.cache = dict() self.s3_client = boto3.client('s3') self.s3_resource= boto3.resource('s3') + self.generated_images_s3uri = os.environ.get('generated_images_s3uri', None) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -399,6 +400,25 @@ def download_s3files(self, s3uri, path): json.dump(self.cache, open('cache', 'w')) + def post_invocations(self, b64images): + if self.generated_images_s3uri: + bucket, key = self.get_bucket_and_key(self.generated_images_s3uri) + images = [] + for b64image in b64images: + image = decode_base64_to_image(b64image).convert('RGB') + output = io.BytesIO() + image.save(output, format='JPEG') + image_id = str(uuid.uuid4()) + self.s3_client.put_object( + Body=output.getvalue(), + Bucket=bucket, + Key=f'{key}/{image_id}.jpg' + ) + images.append(f's3://{bucket}/{key}/{image_id}.jpg') + return images + else: + return b64images + def invocations(self, req: InvocationsRequest): print('-------invocation------') print(req) @@ -433,24 +453,26 @@ def invocations(self, req: InvocationsRequest): self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.text2imgapi(req.txt2img_payload) + response.images = self.post_invocations(response.images) shared.opts.data = default_options return response elif req.task == 'image-to-image': self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir)) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() response = self.img2imgapi(req.img2img_payload) + response.images = self.post_invocations(response.images) shared.opts.data = default_options return response elif req.task == 'extras-single-image': response = self.extras_single_image_api(req.extras_single_payload) + response.image = self.post_invocations([response.image])[0] shared.opts.data = default_options return response elif req.task == 'extras-batch-images': response = self.extras_batch_images_api(req.extras_batch_payload) + response.images = self.post_invocations(response.images) shared.opts.data = default_options return response - elif req.task == 'sd-models': - return self.get_sd_models() else: raise NotImplementedError except Exception as e: @@ -463,3 +485,9 @@ def ping(self): def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) + + def get_bucket_and_key(self, s3uri): + pos = s3uri.find('/', 5) + bucket = s3uri[5 : pos] + key = s3uri[pos + 1 : ] + return bucket, key