diff --git a/webui.py b/webui.py index 5a967c72bbd..bb7c9b28ff7 100644 --- a/webui.py +++ b/webui.py @@ -238,42 +238,29 @@ def webui(): modules.sd_models.list_models() print('Restarting Gradio') -def upload_s3file(s3uri, file_path, file_name): - s3_client = boto3.client('s3', region_name = cmd_opts.region_name) - +def upload_s3files(s3uri, file_path_with_pattern): pos = s3uri.find('/', 5) bucket = s3uri[5 : pos] key = s3uri[pos + 1 : ] - binary = io.BytesIO(open(file_path, 'rb').read()) - key = key + file_name + s3_resource = boto3.resource('s3') + s3_bucket = s3_resource.Bucket(bucket) + try: - s3_client.upload_fileobj(binary, bucket, key) + for file_path in glob.glob(file_path_with_pattern): + file_name = os.path.basename(file_path) + __s3file = f'{key}/{file_name}' + print(file_path, __s3file) + s3_bucket.upload_file(file_path, __s3file) except ClientError as e: print(e) return False return True -def upload_s3files(s3uri, file_path_with_pattern): - s3_client = boto3.client('s3', region_name = cmd_opts.region_name) - - pos = s3uri.find('/', 5) - bucket = s3uri[5 : pos] - key = s3uri[pos + 1 : ] - - for file_name in glob.glob(file_path_with_pattern): - binary = io.BytesIO(open(file_name, 'rb').read()) - key = key + file_name - try: - s3_client.upload_fileobj(binary, bucket, key) - except ClientError as e: - print(e) - return False - return True - def upload_s3folder(s3uri, file_path): pos = s3uri.find('/', 5) bucket = s3uri[5 : pos] + key = s3uri[pos + 1 : ] s3_resource = boto3.resource('s3') s3_bucket = s3_resource.Bucket(bucket) @@ -282,7 +269,7 @@ def upload_s3folder(s3uri, file_path): for path, _, files in os.walk(file_path): for file in files: dest_path = path.replace(file_path,"") - __s3file = os.path.normpath(s3uri + dest_path + '/' + file) + __s3file = f'{key}{dest_path}/{file}' __local_file = os.path.join(path, file) print(__local_file, __s3file) s3_bucket.upload_file(__local_file, __s3file) @@ -400,7 +387,7 @@ def train(): *txt2img_preview_params ) try: - upload_s3file(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name)), '{0}.pt'.format(train_embedding_name)) + upload_s3files(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name))) except Exception as e: traceback.print_exc() print(e) @@ -666,13 +653,19 @@ def train(): db_model_dir = os.path.join(db_model_dir, "dreambooth") try: + print('Uploading SD Models...') + upload_s3files( + sd_models_s3uri, + os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml') + ) upload_s3files( sd_models_s3uri, - os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.pt') + os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt') ) + print('Uploading DB Models...') upload_s3folder( - db_models_s3uri, - db_model_dir + f'{db_models_s3uri}/{db_model_name}', + os.path.join(db_model_dir, db_model_name) ) except Exception as e: traceback.print_exc()