Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 27, 2023
1 parent 71169e0 commit f95f89c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 94 deletions.
52 changes: 6 additions & 46 deletions modules/api/api.py
Expand Up @@ -37,11 +37,6 @@
import uuid
import os
import json
import boto3
cache = dict()
s3_client = boto3.client('s3')
s3_resource= boto3.resource('s3')
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)

def upscaler_to_index(name: str):
try:
Expand Down Expand Up @@ -710,8 +705,8 @@ def get_memory(self):
return MemoryResponse(ram = ram, cuda = cuda)

def post_invocations(self, b64images, quality):
if generated_images_s3uri:
bucket, key = self.get_bucket_and_key(generated_images_s3uri)
if shared.generated_images_s3uri:
bucket, key = shared.get_bucket_and_key(shared.generated_images_s3uri)
images = []
for b64image in b64images:
image = decode_base64_to_image(b64image).convert('RGB')
Expand All @@ -726,7 +721,7 @@ def post_invocations(self, b64images, quality):
image.save(output, format='PNG', quality=95)

image_id = str(uuid.uuid4())
s3_client.put_object(
shared.s3_client.put_object(
Body=output.getvalue(),
Bucket=bucket,
Key=f'{key}/{image_id}.png'
Expand Down Expand Up @@ -759,7 +754,7 @@ def invocations(self, req: InvocationsRequest):
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

if hypernetwork_s3uri !='':
self.download_s3files(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
shared.download_s3files(hypernetwork_s3uri, shared.cmd_opts.hypernetwork_dir)
shared.reload_hypernetworks()

if req.options != None:
Expand All @@ -769,14 +764,14 @@ def invocations(self, req: InvocationsRequest):

if req.task == 'text-to-image':
if embeddings_s3uri != '':
self.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
shared.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'image-to-image':
if embeddings_s3uri != '':
self.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
shared.download_s3files(embeddings_s3uri, shared.cmd_opts.embeddings_dir)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images, quality)
Expand All @@ -803,38 +798,3 @@ def ping(self):
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)

def get_bucket_and_key(self, s3uri):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]
return bucket, key

def download_s3files(self, s3uri, path):
global cache

pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_bucket = s3_resource.Bucket(bucket)
objs = list(s3_bucket.objects.filter(Prefix=key))

if os.path.isfile('cache'):
cache = json.load(open('cache', 'r'))

for obj in objs:
if obj.key == key:
continue
response = s3_client.head_object(
Bucket = bucket,
Key = obj.key
)
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
if obj_key not in cache or cache[obj_key] != response['ETag']:
filename = obj.key[obj.key.rfind('/') + 1 : ]

s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
cache[obj_key] = response['ETag']

json.dump(cache, open('cache', 'w'))
1 change: 1 addition & 0 deletions modules/cmd_args.py
Expand Up @@ -109,3 +109,4 @@
parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID')
parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri')
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
parser.add_argument('--region-name', default='', type=str, help='Region name')
46 changes: 46 additions & 0 deletions modules/shared.py
Expand Up @@ -657,3 +657,49 @@ def html(filename):
return file.read()

return ""

import boto3
import requests

cache = dict()
region_name = boto3.session.Session().region_name if not cmd_opts.train else cmd_opts.region_name
s3_client = boto3.client('s3', region_name=region_name)
endpointUrl = s3_client.meta.endpoint_url
s3_client = boto3.client('s3', endpoint_url=endpointUrl, region_name=region_name)
s3_resource= boto3.resource('s3')
generated_images_s3uri = os.environ.get('generated_images_s3uri', None)

def get_bucket_and_key(s3uri):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]
return bucket, key

def s3_download(s3uri, path):
global cache

pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

if os.path.isfile('cache'):
cache = json.load(open('cache', 'r'))

response = s3_client.head_object(
Bucket=bucket,
Key=key
)
if key not in cache or cache[key] != response['ETag']:
filename = key[key.rfind('/') + 1 : ]

s3_client.download_file(bucket, key, os.path.join(path, filename))
cache[key] = response['ETag']

json.dump(cache, open('cache', 'w'))

def http_download(httpuri, path):
with requests.get(httpuri, stream=True) as r:
r.raise_for_status()
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
52 changes: 4 additions & 48 deletions webui.py
Expand Up @@ -72,14 +72,6 @@
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
from modules import paths
import glob
else:
import requests
cache = dict()
region_name = boto3.session.Session().region_name
s3_client = boto3.client('s3', region_name=region_name)
endpointUrl = s3_client.meta.endpoint_url
s3_client = boto3.client('s3', endpoint_url=endpointUrl, region_name=region_name)
s3_resource= boto3.resource('s3')

startup_timer.record("other imports")

Expand Down Expand Up @@ -258,8 +250,8 @@ def webui():
if launch_api:
models_config_s3uri = os.environ.get('models_config_s3uri', None)
if models_config_s3uri:
bucket, key = get_bucket_and_key(models_config_s3uri)
s3_object = s3_client.get_object(Bucket=bucket, Key=key)
bucket, key = shared.get_bucket_and_key(models_config_s3uri)
s3_object = shared.s3_client.get_object(Bucket=bucket, Key=key)
bytes = s3_object["Body"].read()
payload = bytes.decode('utf8')
huggingface_models = json.loads(payload).get('huggingface_models', None)
Expand Down Expand Up @@ -290,14 +282,14 @@ def webui():
for s3_model in s3_models:
uri = s3_model['uri']
name = s3_model['name']
s3_download(uri, f'/tmp/models/{name}')
shared.s3_download(uri, f'/tmp/models/{name}')

if http_models:
for http_model in http_models:
uri = http_model['uri']
filename = http_model['filename']
name = http_model['name']
http_download(uri, f'/tmp/models/{name}/{filename}')
shared.http_download(uri, f'/tmp/models/{name}/{filename}')

initialize()

Expand Down Expand Up @@ -620,7 +612,6 @@ def train():
)
os.makedirs(os.path.dirname("/opt/ml/model/"), exist_ok=True)
os.makedirs(os.path.dirname("/opt/ml/model/Stable-diffusion/"), exist_ok=True)
os.makedirs(os.path.dirname("/opt/ml/model/ControlNet/"), exist_ok=True)
train_steps=int(db_config.revision)
model_file_basename = f'{db_model_name}_{train_steps}_lora' if db_config.use_lora else f'{db_model_name}_{train_steps}'
f1=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.yaml')
Expand All @@ -637,41 +628,6 @@ def train():
except Exception as e:
traceback.print_exc()
print(e)
else:
def get_bucket_and_key(s3uri):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]
return bucket, key

def s3_download(s3uri, path):
global cache

pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

if os.path.isfile('cache'):
cache = json.load(open('cache', 'r'))

response = s3_client.head_object(
Bucket=bucket,
Key=key
)
if key not in cache or cache[key] != response['ETag']:
filename = key[key.rfind('/') + 1 : ]

s3_client.download_file(bucket, key, os.path.join(path, filename))
cache[key] = response['ETag']

json.dump(cache, open('cache', 'w'))

def http_download(httpuri, path):
with requests.get(httpuri, stream=True) as r:
r.raise_for_status()
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

if __name__ == "__main__":
if cmd_opts.train:
Expand Down

0 comments on commit f95f89c

Please sign in to comment.