Skip to content

Commit

Permalink
model file loaded dynamically from s3
Browse files Browse the repository at this point in the history
  • Loading branch information
xie river committed Apr 2, 2023
1 parent 7146b33 commit 37bb2b1
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 11 deletions.
2 changes: 1 addition & 1 deletion localizations/zh_CN.json
Expand Up @@ -840,7 +840,7 @@
"Save Preview(s) Frequency (Epochs)": "保存预览频率 (Epochs)",
"A generic prompt used to generate a sample image to verify model fidelity.": "用于生成样本图像以验证模型保真度的通用提示。",
"Job detail":"训练任务详情",
"S3 bucket name for uploading train images":"上传训练图片集的S3桶名",
"S3 bucket name for uploading/downloading images":"上传训练图片集或者下载生成图片的S3桶名",
"Output S3 folder":"S3文件夹目录",
"Upload Train Images to S3":"上传训练图片到S3",
"Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名",
Expand Down
2 changes: 1 addition & 1 deletion modules/api/api.py
Expand Up @@ -457,7 +457,7 @@ def invocations(self, req: InvocationsRequest):
traceback.print_exc()

def ping(self):
print('-------ping------')
# print('-------ping------')
return {'status': 'Healthy'}

def launch(self, server_name, port):
Expand Down
2 changes: 1 addition & 1 deletion modules/shared.py
Expand Up @@ -423,7 +423,7 @@ def refresh_sagemaker_endpoints(username):
}))

options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"train_files_s3bucket":OptionInfo("","S3 bucket name for uploading train images",component_args=hide_dirs),
"train_files_s3bucket":OptionInfo("","S3 bucket name for uploading/downloading images",component_args=hide_dirs),
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
Expand Down
37 changes: 31 additions & 6 deletions modules/ui.py
Expand Up @@ -86,6 +86,29 @@
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}

## Begin output images uploaded to s3 by River
s3_resource = boto3.resource('s3')

def save_images_to_s3(full_fillnames,timestamp):
username = shared.username
sagemaker_endpoint = shared.opts.sagemaker_endpoint
bucket_name = opts.train_files_s3bucket
if bucket_name == '':
return 'Error, please configure a S3 bucket at settings page first'
s3_bucket = s3_resource.Bucket(bucket_name)
folder_name = f"output-images/{username}/{sagemaker_endpoint}/{timestamp}"
try:
for i, fname in enumerate(full_fillnames):
filename = fname.split('/')[-1]
object_name = f"{folder_name}/{filename}"
s3_bucket.upload_file(fname,object_name)
print (f'upload file [{i}]:{filename} to s3://{bucket_name}/{object_name}')
except ClientError as e:
print(e)
return e
return f"s3://{bucket_name}/{folder_name}"
## End output images uploaded to s3 by River


sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
Expand Down Expand Up @@ -147,7 +170,7 @@ def __init__(self, d=None):

os.makedirs(opts.outdir_save, exist_ok=True)

with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
with open(os.path.join(opts.outdir_save, "log.csv"), "w", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
writer = csv.writer(file)
if at_start:
Expand All @@ -163,16 +186,19 @@ def __init__(self, d=None):
break

fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

filename = os.path.relpath(fullfn, path)
print(f'fullfn:{fullfn},\n txt_fullfn:{txt_fullfn} \nfilename:{filename}')
filenames.append(filename)
fullfns.append(fullfn)
if txt_fullfn:
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)

writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])


timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S')
logfile = os.path.join(opts.outdir_save, "log.csv")
s3folder = save_images_to_s3(fullfns,timestamp)
save_images_to_s3([logfile],timestamp)
# Make Zip
if do_make_zip:
zip_filepath = os.path.join(path, "images.zip")
Expand All @@ -184,7 +210,7 @@ def __init__(self, d=None):
zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath)

return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}, \nS3 folder:\n{s3folder}")



Expand Down Expand Up @@ -1466,7 +1492,6 @@ def update_orig(image, state):
with gr.Row().style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
## Begin add s3 images upload interface by River
s3_resource = boto3.resource('s3')
def upload_to_s3(imgs):
username = shared.username
timestamp = datetime.now(timezone(timedelta(hours=+8))).strftime('%Y-%m-%dT%H:%M:%S')
Expand Down
129 changes: 127 additions & 2 deletions webui.py
Expand Up @@ -36,6 +36,9 @@
from modules.shared import cmd_opts, opts
import modules.hypernetworks.hypernetwork
import boto3
import threading
import time

import traceback
from botocore.exceptions import ClientError
import requests
Expand Down Expand Up @@ -64,6 +67,21 @@ def initialize():
modules.scripts.load_scripts()
return

## auto reload new models from s3 add by River
sd_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/"
cn_models_tmp_dir = "/opt/ml/code/stable-diffusion-webui/models/ControlNet/"
session = boto3.Session()
region_name = session.region_name
sts_client = session.client('sts')
account_id = sts_client.get_caller_identity()['Account']
sg_defaul_bucket_name = f"sagemaker-{region_name}-{account_id}"
s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion"
s3_folder_cn = "stable-diffusion-webui/models/ControlNet"

sync_s3_folder(sg_defaul_bucket_name,s3_folder_sd,sd_models_tmp_dir,'sd')
sync_s3_folder(sg_defaul_bucket_name,s3_folder_cn,cn_models_tmp_dir,'cn')
## end

modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
Expand Down Expand Up @@ -182,6 +200,114 @@ def user_auth(username, password):

return response.status_code == 200


def register_sd_models(sd_models_dir):
print ('---register_sd_models()----')
if 'endpoint_name' in os.environ:
items = []
api_endpoint = os.environ['api_endpoint']
endpoint_name = os.environ['endpoint_name']
print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}')
for file in os.listdir(sd_models_dir):
if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')):
hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file))
item = {}
item['model_name'] = file
item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file)
item['hash'] = hash
item['title'] = '{0} [{1}]'.format(file, hash)
item['endpoint_name'] = endpoint_name
items.append(item)
inputs = {
'items': items
}
params = {
'module': 'Stable-diffusion'
}
if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
print(response)

def register_cn_models(cn_models_dir):
print ('---register_cn_models()----')
if 'endpoint_name' in os.environ:
items = []
api_endpoint = os.environ['api_endpoint']
endpoint_name = os.environ['endpoint_name']
print(f'api_endpoint:{api_endpoint}\nendpoint_name:{endpoint_name}')

inputs = {
'items': items
}
params = {
'module': 'ControlNet'
}
for file in os.listdir(cn_models_dir):
if os.path.isfile(os.path.join(cn_models_dir, file)) and \
(file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')):
hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file))
item = {}
item['model_name'] = file
item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash)
item['endpoint_name'] = endpoint_name
items.append(item)

if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
print(response)


def sync_s3_folder(bucket_name, s3_folder, local_folder,mode):
print(f"sync S3 bucket '{bucket_name}', folder '{s3_folder}' for new files...")
# Create tmp folders
os.makedirs(os.path.dirname(local_folder), exist_ok=True)
print(f'create dir: {os.path.dirname(local_folder)}')
# Create an S3 client
s3 = boto3.client('s3')
def sync():
# List all objects in the S3 folder
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=s3_folder)
# Check if there are any new or deleted files
s3_files = set()
for obj in response.get('Contents', []):
s3_files.add(obj['Key'].replace(s3_folder, '').lstrip('/'))

local_files = set(os.listdir(local_folder))

new_files = s3_files - local_files
del_files = local_files - s3_files

# Copy new files to local folder
for file in new_files:
s3.download_file(bucket_name, s3_folder + '/' + file, os.path.join(local_folder, file))
print(f'download_file:from {bucket_name}/{s3_folder}/{file} to {os.path.join(local_folder, file)}')

# Delete vanished files from local folder
for file in del_files:
os.remove(os.path.join(local_folder, file))
print(f'remove file {os.path.join(local_folder, file)}')
# If there are changes
if len(new_files) | len(del_files):
if mode == 'sd':
register_sd_models(local_folder)
elif mode == 'cn':
register_cn_models(local_folder)
else:
print(f'unsupported mode:{mode}')
# Create a thread function to keep syncing with the S3 folder
def sync_thread():
while True:
sync()
time.sleep(60)
# Initialize at launch
sync()
# Start the thread
thread = threading.Thread(target=sync_thread)
thread.start()
return thread


def webui():
launch_api = cmd_opts.api
initialize()
Expand Down Expand Up @@ -218,7 +344,7 @@ def webui():

if launch_api:
create_api(app)

cmd_sd_models_path = cmd_opts.ckpt_dir
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
if cmd_sd_models_path is not None:
Expand Down Expand Up @@ -274,7 +400,6 @@ def webui():
if api_endpoint.startswith('http://') or api_endpoint.startswith('https://'):
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs, params=params)
print(response)

modules.script_callbacks.app_started_callback(shared.demo, app)

wait_on_server(shared.demo)
Expand Down

0 comments on commit 37bb2b1

Please sign in to comment.