Skip to content

Commit

Permalink
fix issues with s3_download
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 28, 2023
1 parent 62bec0b commit 666008c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
6 changes: 3 additions & 3 deletions modules/api/api.py
Expand Up @@ -754,7 +754,7 @@ def invocations(self, req: InvocationsRequest):
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

if hypernetwork_s3uri !='':
shared.download_s3files(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
shared.s3_download(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
shared.reload_hypernetworks()

if req.options != None:
Expand All @@ -764,14 +764,14 @@ def invocations(self, req: InvocationsRequest):

if req.task == 'text-to-image':
if embeddings_s3uri != '':
shared.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
shared.s3_download(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'image-to-image':
if embeddings_s3uri != '':
shared.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
shared.s3_download(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images, quality)
Expand Down
29 changes: 21 additions & 8 deletions modules/shared.py
def s3_download(s3uri, path):
    """Download every S3 object under an s3:// URI prefix into a local directory.

    Lists all keys under the prefix, then downloads each object whose ETag
    differs from the one recorded in the local ``cache`` file (a JSON map of
    ``s3://bucket/key`` -> ETag), updating the cache afterwards.

    Parameters:
        s3uri: prefix of the form ``s3://<bucket>/<key-prefix>``.
        path:  local directory to place downloaded files in (flat — only the
               basename of each key is kept).
    """
    # NOTE(review): reconstructed from the hunk context — the first '/' after
    # the 's3://' scheme separates bucket from key prefix; confirm upstream.
    pos = s3uri.find('/', 5)
    bucket = s3uri[5 : pos]
    key = s3uri[pos + 1 : ]

    # List every object under the prefix. The paginator follows
    # NextContinuationToken automatically, so no manual continuation handling
    # is needed (the previous manual re-paginate also passed Bucket=key,
    # i.e. the wrong bucket name, and never affected the running loop).
    objects = []
    paginator = s3_client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=key):
        objects.extend(page.get('Contents', []))

    # Start from an empty cache so the first run (no 'cache' file yet) does
    # not hit a NameError when the cache is consulted below.
    cache = {}
    if os.path.isfile('cache'):
        with open('cache', 'r') as f:
            cache = json.load(f)

    for obj in objects:
        # list_objects_v2 returns plain dicts — the key lives under 'Key',
        # not an attribute (obj.key would raise AttributeError).
        obj_key = obj['Key']
        cache_key = 's3://{0}/{1}'.format(bucket, obj_key)
        response = s3_client.head_object(
            Bucket = bucket,
            Key = obj_key
        )
        if cache_key not in cache or cache[cache_key] != response['ETag']:
            filename = obj_key[obj_key.rfind('/') + 1 : ]
            # Keys ending in '/' are "directory" placeholders with an empty
            # basename — nothing to download for those.
            if filename == '':
                continue
            s3_client.download_file(bucket, obj_key, os.path.join(path, filename))
            cache[cache_key] = response['ETag']

    with open('cache', 'w') as f:
        json.dump(cache, f)

Expand Down

0 comments on commit 666008c

Please sign in to comment.