Skip to content

Commit

Permalink
revise webui.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Jan 3, 2023
1 parent 52c313e commit f41e8be
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions webui.py
Expand Up @@ -238,42 +238,29 @@ def webui():
modules.sd_models.list_models()
print('Restarting Gradio')

def upload_s3file(s3uri, file_path, file_name):
s3_client = boto3.client('s3', region_name = cmd_opts.region_name)

def upload_s3files(s3uri, file_path_with_pattern):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

binary = io.BytesIO(open(file_path, 'rb').read())
key = key + file_name
s3_resource = boto3.resource('s3')
s3_bucket = s3_resource.Bucket(bucket)

try:
s3_client.upload_fileobj(binary, bucket, key)
for file_path in glob.glob(file_path_with_pattern):
file_name = os.path.basename(file_path)
__s3file = f'{key}/{file_name}'
print(file_path, __s3file)
s3_bucket.upload_file(file_path, __s3file)
except ClientError as e:
print(e)
return False
return True

def upload_s3files(s3uri, file_path_with_pattern):
s3_client = boto3.client('s3', region_name = cmd_opts.region_name)

pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

for file_name in glob.glob(file_path_with_pattern):
binary = io.BytesIO(open(file_name, 'rb').read())
key = key + file_name
try:
s3_client.upload_fileobj(binary, bucket, key)
except ClientError as e:
print(e)
return False
return True

def upload_s3folder(s3uri, file_path):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_resource = boto3.resource('s3')
s3_bucket = s3_resource.Bucket(bucket)
Expand All @@ -282,7 +269,7 @@ def upload_s3folder(s3uri, file_path):
for path, _, files in os.walk(file_path):
for file in files:
dest_path = path.replace(file_path,"")
__s3file = os.path.normpath(s3uri + dest_path + '/' + file)
__s3file = f'{key}{dest_path}/{file}'
__local_file = os.path.join(path, file)
print(__local_file, __s3file)
s3_bucket.upload_file(__local_file, __s3file)
Expand Down Expand Up @@ -400,7 +387,7 @@ def train():
*txt2img_preview_params
)
try:
upload_s3file(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name)), '{0}.pt'.format(train_embedding_name))
upload_s3files(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name)))
except Exception as e:
traceback.print_exc()
print(e)
Expand Down Expand Up @@ -666,13 +653,19 @@ def train():
db_model_dir = os.path.join(db_model_dir, "dreambooth")

try:
print('Uploading SD Models...')
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml')
)
upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.pt')
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt')
)
print('Uploading DB Models...')
upload_s3folder(
db_models_s3uri,
db_model_dir
f'{db_models_s3uri}/{db_model_name}',
os.path.join(db_model_dir, db_model_name)
)
except Exception as e:
traceback.print_exc()
Expand Down

0 comments on commit f41e8be

Please sign in to comment.