Skip to content

Commit

Permalink
update api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 19, 2023
1 parent 0acee0b commit d5d5a9a
Showing 1 changed file with 47 additions and 6 deletions.
53 changes: 47 additions & 6 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@
import asyncio
from typing import Union
import traceback
import modules.sd_models
import modules.sd_vae

from modules.sd_vae import reload_vae_weights, refresh_vae_list
from modules.hypernetworks import hypernetwork
from modules.paths_internal import script_path
import uuid
import os
import json
import boto3
cache = dict()
s3_client = boto3.client('s3')
s3_resource= boto3.resource('s3')
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)

def upscaler_to_index(name: str):
Expand Down Expand Up @@ -739,24 +742,35 @@ def invocations(self, req: InvocationsRequest):

if req.vae != None:
shared.opts.data['sd_vae'] = req.vae
modules.sd_vae.refresh_vae_list()
refresh_vae_list()

if req.model != None:
sd_model_checkpoint = shared.opts.sd_model_checkpoint
shared.opts.sd_model_checkpoint = req.model
with self.queue_lock:
modules.sd_models.reload_model_weights()
reload_model_weights()
if sd_model_checkpoint == shared.opts.sd_model_checkpoint:
modules.sd_vae.reload_vae_weights()
reload_vae_weights()

quality = req.quality

embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
hypernetwork.apply_strength()

try:
if req.task == 'text-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
Expand Down Expand Up @@ -786,3 +800,30 @@ def get_bucket_and_key(self, s3uri):
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]
return bucket, key

def download_s3files(self, s3uri, path):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_bucket = s3_resource.Bucket(bucket)
objs = list(s3_bucket.objects.filter(Prefix=key))

if os.path.isfile('cache'):
cache = json.load(open('cache', 'r'))

for obj in objs:
if obj.key == key:
continue
response = s3_client.head_object(
Bucket = bucket,
Key = obj.key
)
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
if obj_key not in cache or cache[obj_key] != response['ETag']:
filename = obj.key[obj.key.rfind('/') + 1 : ]

s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
cache[obj_key] = response['ETag']

json.dump(cache, open('cache', 'w'))

0 comments on commit d5d5a9a

Please sign in to comment.