Skip to content

Commit

Permalink
save generated imagess and models with user specific path
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 9, 2023
1 parent 5be73c9 commit 4331462
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
18 changes: 10 additions & 8 deletions modules/api/api.py
Expand Up @@ -137,7 +137,6 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.cache = dict()
self.s3_client = boto3.client('s3')
self.s3_resource= boto3.resource('s3')
self.generated_images_s3uri = os.environ.get('generated_images_s3uri', None)

def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
Expand Down Expand Up @@ -400,9 +399,12 @@ def download_s3files(self, s3uri, path):

json.dump(self.cache, open('cache', 'w'))

def post_invocations(self, b64images):
if self.generated_images_s3uri:
bucket, key = self.get_bucket_and_key(self.generated_images_s3uri)
def post_invocations(self, username, b64images):
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)

if generated_images_s3uri:
generated_images_s3uri = f'{generated_images_s3uri}{username}/'
bucket, key = self.get_bucket_and_key(generated_images_s3uri)
images = []
for b64image in b64images:
image = decode_base64_to_image(b64image).convert('RGB')
Expand Down Expand Up @@ -453,24 +455,24 @@ def invocations(self, req: InvocationsRequest):
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images)
response.images = self.post_invocations(username, response.images)
shared.opts.data = default_options
return response
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images)
response.images = self.post_invocations(username, response.images)
shared.opts.data = default_options
return response
elif req.task == 'extras-single-image':
response = self.extras_single_image_api(req.extras_single_payload)
response.image = self.post_invocations([response.image])[0]
response.image = self.post_invocations(username, [response.image])[0]
shared.opts.data = default_options
return response
elif req.task == 'extras-batch-images':
response = self.extras_batch_images_api(req.extras_batch_payload)
response.images = self.post_invocations(response.images)
response.images = self.post_invocations(username, response.images)
shared.opts.data = default_options
return response
else:
Expand Down
10 changes: 5 additions & 5 deletions webui.py
Expand Up @@ -838,28 +838,28 @@ def train():
print('Uploading SD Models...')
if db_config.v2:
upload_s3files(
sd_models_s3uri,
f'{sd_models_s3uri}/{username}/',
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.yaml')
)
if db_config.save_safetensors:
upload_s3files(
sd_models_s3uri,
f'{sd_models_s3uri}/{username}/',
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.safetensors')
)
else:
upload_s3files(
sd_models_s3uri,
f'{sd_models_s3uri}/{username}/',
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.ckpt')
)
print('Uploading DB Models...')
upload_s3folder(
f'{db_models_s3uri}{db_model_name}',
f'{db_models_s3uri}{username}/{db_model_name}',
os.path.join(db_model_dir, db_model_name)
)
if db_config.use_lora:
print('Uploading Lora Models...')
upload_s3files(
lora_models_s3uri,
f'{lora_models_s3uri}/{username}/',
os.path.join(lora_model_dir, f'{db_model_name}_*.pt')
)
#automatic tar latest checkpoint and upload to s3 by zheng on 2023.03.22
Expand Down

0 comments on commit 4331462

Please sign in to comment.