Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed May 5, 2023
1 parent 3cef84f commit 480a6a8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 44 deletions.
34 changes: 34 additions & 0 deletions modules/shared.py
Expand Up @@ -15,6 +15,8 @@
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
from botocore.exceptions import ClientError
import glob

demo = None

Expand Down Expand Up @@ -719,3 +721,35 @@ def http_download(httpuri, path):
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

def upload_s3files(s3uri, file_path_with_pattern):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

try:
for file_path in glob.glob(file_path_with_pattern):
file_name = os.path.basename(file_path)
__s3file = f'{key}{file_name}'
print(file_path, __s3file)
s3_client.upload_file(file_path, bucket, __s3file)
except ClientError as e:
print(e)
return False
return True

def upload_s3folder(s3uri, file_path):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

try:
for path, _, files in os.walk(file_path):
for file in files:
dest_path = path.replace(file_path,"")
__s3file = f'{key}{dest_path}/{file}'
__local_file = os.path.join(path, file)
print(__local_file, __s3file)
s3_client.upload_file(__local_file, bucket, __s3file)
except Exception as e:
print(e)
49 changes: 5 additions & 44 deletions webui.py
Expand Up @@ -71,7 +71,6 @@
from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
from modules import paths
import glob

startup_timer.record("other imports")

Expand Down Expand Up @@ -394,44 +393,6 @@ def webui():
startup_timer.record("initialize extra networks")

if cmd_opts.train:
def upload_s3files(s3uri, file_path_with_pattern):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_resource = boto3.resource('s3')
s3_bucket = s3_resource.Bucket(bucket)

try:
for file_path in glob.glob(file_path_with_pattern):
file_name = os.path.basename(file_path)
__s3file = f'{key}{file_name}'
print(file_path, __s3file)
s3_bucket.upload_file(file_path, __s3file)
except ClientError as e:
print(e)
return False
return True

def upload_s3folder(s3uri, file_path):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_resource = boto3.resource('s3')
s3_bucket = s3_resource.Bucket(bucket)

try:
for path, _, files in os.walk(file_path):
for file in files:
dest_path = path.replace(file_path,"")
__s3file = f'{key}{dest_path}/{file}'
__local_file = os.path.join(path, file)
print(__local_file, __s3file)
s3_bucket.upload_file(__local_file, __s3file)
except Exception as e:
print(e)

def train():
initialize()

Expand Down Expand Up @@ -585,28 +546,28 @@ def train():

try:
print('Uploading SD Models...')
upload_s3files(
shared.upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.yaml')
)
if db_save_safetensors:
upload_s3files(
shared.upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.safetensors')
)
else:
upload_s3files(
shared.upload_s3files(
sd_models_s3uri,
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.ckpt')
)
print('Uploading DB Models...')
upload_s3folder(
shared.upload_s3folder(
f'{db_models_s3uri}{db_model_name}',
os.path.join(db_model_dir, db_model_name)
)
if db_config.use_lora:
print('Uploading Lora Models...')
upload_s3files(
shared.upload_s3files(
lora_models_s3uri,
os.path.join(lora_model_dir, f'{db_model_name}_*.pt')
)
Expand Down

0 comments on commit 480a6a8

Please sign in to comment.