Commit
dynamic models loading
xie river committed Apr 27, 2023
1 parent ea7592c commit 1d92fc5
Showing 5 changed files with 326 additions and 7 deletions.
2 changes: 1 addition & 1 deletion modules/api/api.py
@@ -796,7 +796,7 @@ def invocations(self, req: InvocationsRequest):
            return InvocationsErrorResponse(error = str(e))

    def ping(self):
        print('-------ping------')
        # print('-------ping------')
        return {'status': 'Healthy'}

    def launch(self, server_name, port):
9 changes: 9 additions & 0 deletions modules/script_callbacks.py
@@ -93,6 +93,7 @@ def __init__(self, imgs, cols, rows):
    callbacks_infotext_pasted=[],
    callbacks_script_unloaded=[],
    callbacks_before_ui=[],
    callbacks_update_cn_models=[]
)


@@ -224,6 +225,12 @@ def before_ui_callback():
        except Exception:
            report_exception(c, 'before_ui')

def update_cn_models_callback():
    for c in callback_map['callbacks_update_cn_models']:
        try:
            c.callback()
        except Exception:
            report_exception(c, 'callbacks_update_cn_models')

def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
@@ -247,6 +254,8 @@ def remove_callbacks_for_function(callback_func):
        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
            callback_list.remove(callback_to_remove)

def on_update_cn_models(callback):
    """register a function to be called when the ControlNet model list has been updated by the S3 sync and UI model lists should be refreshed"""
    add_callback(callback_map['callbacks_update_cn_models'], callback)

def on_app_started(callback):
    """register a function to be called when the webui started, the gradio `Block` component and
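For context, a minimal sketch of how an extension script could hook the new callback; only script_callbacks.on_update_cn_models comes from this commit, and the refresh handler below is hypothetical:

from modules import script_callbacks

def refresh_cn_model_list():
    # hypothetical handler: re-scan the local ControlNet model folder and
    # rebuild whatever dropdown choices the extension exposes
    print("ControlNet model list refreshed after S3 sync")

# registered callbacks are invoked by update_cn_models_callback() whenever
# the S3 sync detects new or changed ControlNet models
script_callbacks.on_update_cn_models(refresh_cn_model_list)
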
50 changes: 49 additions & 1 deletion modules/shared.py
@@ -4,7 +4,7 @@
import os
import sys
import time

import threading
from PIL import Image
import gradio as gr
import tqdm
@@ -18,6 +18,54 @@

demo = None

models_s3_bucket = None
s3_folder_sd = None
s3_folder_cn = None
s3_folder_lora = None
syncLock = threading.Lock()
tmp_models_dir = '/tmp/models'
tmp_cache_dir = '/tmp/model_sync_cache'
class ModelsRef:
    def __init__(self):
        self.models_ref = {}

    def get_models_ref_dict(self):
        return self.models_ref

    def add_models_ref(self, model_name):
        if model_name in self.models_ref:
            self.models_ref[model_name] += 1
        else:
            self.models_ref[model_name] = 0

    def remove_model_ref(self, model_name):
        if self.models_ref.get(model_name):
            del self.models_ref[model_name]

    def get_models_ref(self, model_name):
        return self.models_ref.get(model_name)

    def get_least_ref_model(self):
        sorted_models = sorted(self.models_ref.items(), key=lambda item: item[1])
        if sorted_models:
            least_ref_model, least_counter = sorted_models[0]
            return least_ref_model, least_counter
        else:
            return None, None

    def pop_least_ref_model(self):
        sorted_models = sorted(self.models_ref.items(), key=lambda item: item[1])
        if sorted_models:
            least_ref_model, least_counter = sorted_models[0]
            del self.models_ref[least_ref_model]
            return least_ref_model, least_counter
        else:
            return None, None

sd_models_Ref = ModelsRef()
cn_models_Ref = ModelsRef()
lora_models_Ref = ModelsRef()

parser = cmd_args.parser

script_loading.preload_extensions(extensions_dir, parser)
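The ModelsRef bookkeeping above counts how often each downloaded model has been requested so the S3 sync can evict the least-referenced file first. A minimal sketch of the intended behaviour; the "<file> [<hash>]" key format follows check_space_s3_download in modules/sync_models.py, and the title below is made up:

from modules import shared

title = 'v1-5-pruned-emaonly.ckpt [abc1234d]'  # hypothetical "<file> [<hash>]" key

shared.sd_models_Ref.add_models_ref(title)   # first call initialises the counter to 0
shared.sd_models_Ref.add_models_ref(title)   # every later call increments it
print(shared.sd_models_Ref.get_models_ref(title))   # -> 1
print(shared.sd_models_Ref.get_least_ref_model())   # -> ('v1-5-pruned-emaonly.ckpt [abc1234d]', 1)
print(shared.sd_models_Ref.pop_least_ref_model())   # removes and returns the same entry
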
236 changes: 236 additions & 0 deletions modules/sync_models.py
@@ -0,0 +1,236 @@
import os,threading,psutil,json,time
import boto3
import modules.shared as shared
import modules.sd_models as sd_models
import modules.script_callbacks as script_callbacks
from modules.shared import syncLock

FREESPACE = 20
def check_space_s3_download(s3_client, bucket_name, s3_folder, local_folder, file, size, mode):
    print(f"bucket_name:{bucket_name},s3_folder:{s3_folder},file:{file}")
    if file == '' or file is None:
        print('Debug log:file is empty, return')
        return True
    src = s3_folder + '/' + file
    dist = os.path.join(local_folder, file)
    os.makedirs(os.path.dirname(dist), exist_ok=True)
    # Get disk usage statistics
    disk_usage = psutil.disk_usage('/tmp')
    freespace = disk_usage.free/(1024**3)
    print(f"Total space: {disk_usage.total/(1024**3)}, Used space: {disk_usage.used/(1024**3)}, Free space: {freespace}")
    if freespace - size >= FREESPACE:
        try:
            s3_client.download_file(bucket_name, src, dist)
            # init ref cnt to 0 when the model file is downloaded for the first time
            hash = sd_models.model_hash(dist)
            if mode == 'sd':
                shared.sd_models_Ref.add_models_ref('{0} [{1}]'.format(file, hash))
            elif mode == 'cn':
                shared.cn_models_Ref.add_models_ref('{0} [{1}]'.format(os.path.splitext(file)[0], hash))
            elif mode == 'lora':
                shared.lora_models_Ref.add_models_ref('{0} [{1}]'.format(os.path.splitext(file)[0], hash))
            print(f'download_file success:from {bucket_name}/{src} to {dist}')
        except Exception as e:
            print(f'download_file error: from {bucket_name}/{src} to {dist}')
            print(f"An error occurred: {e}")
            return False
        return True
    else:
        return False

def free_local_disk(local_folder, size, mode):
    disk_usage = psutil.disk_usage('/tmp')
    freespace = disk_usage.free/(1024**3)
    if freespace - size >= FREESPACE:
        return
    models_Ref = None
    if mode == 'sd':
        models_Ref = shared.sd_models_Ref
    elif mode == 'cn':
        models_Ref = shared.cn_models_Ref
    elif mode == 'lora':
        models_Ref = shared.lora_models_Ref
    model_name, ref_cnt = models_Ref.get_least_ref_model()
    print(f'shared.{mode}_models_Ref:{models_Ref.get_models_ref_dict()} -- model_name:{model_name}')
    if model_name and ref_cnt:
        filename = model_name[:model_name.rfind("[")]
        os.remove(os.path.join(local_folder, filename))
        disk_usage = psutil.disk_usage('/tmp')
        freespace = disk_usage.free/(1024**3)
        print(f"Remove file: {os.path.join(local_folder, filename)} now left space:{freespace}")
    else:
        # if ref_cnt == 0, then delete the oldest zero-ref model
        zero_ref_models = set([model[:model.rfind(" [")] for model, count in models_Ref.get_models_ref_dict().items() if count == 0])
        local_files = set(os.listdir(local_folder))
        # join with local
        files = [(os.path.join(local_folder, file), os.path.getctime(os.path.join(local_folder, file))) for file in zero_ref_models.intersection(local_files)]
        if len(files) == 0:
            print(f"No files to remove in folder: {local_folder}, please remove some files in S3 bucket")
            return
        files.sort(key=lambda x: x[1])
        oldest_file = files[0][0]
        os.remove(oldest_file)
        disk_usage = psutil.disk_usage('/tmp')
        freespace = disk_usage.free/(1024**3)
        print(f"Remove file: {oldest_file} now left space:{freespace}")
        filename = os.path.basename(oldest_file)

def list_s3_objects(s3_client, bucket_name, prefix=''):
    objects = []
    paginator = s3_client.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
    # iterate over pages
    for page in page_iterator:
        # loop through objects in page
        if 'Contents' in page:
            for obj in page['Contents']:
                _, ext = os.path.splitext(obj['Key'].lstrip('/'))
                if ext in ['.pt', '.pth', '.ckpt', '.safetensors', '.yaml']:
                    objects.append(obj)
        # if there are more pages to fetch, continue
        if 'NextContinuationToken' in page:
            page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix,
                                               ContinuationToken=page['NextContinuationToken'])
    return objects


def initial_s3_download(s3_client, s3_folder, local_folder, cache_dir, mode):
    # Create tmp folders
    os.makedirs(os.path.dirname(local_folder), exist_ok=True)
    os.makedirs(os.path.dirname(cache_dir), exist_ok=True)
    print(f'create dir: {os.path.dirname(local_folder)}')
    print(f'create dir: {os.path.dirname(cache_dir)}')
    s3_file_name = os.path.join(cache_dir, f's3_files_{mode}.json')
    # Create an empty file if it does not exist
    if os.path.isfile(s3_file_name) == False:
        s3_files = {}
        with open(s3_file_name, "w") as f:
            json.dump(s3_files, f)
    # List all objects in the S3 folder
    s3_objects = list_s3_objects(s3_client=s3_client, bucket_name=shared.models_s3_bucket, prefix=s3_folder)
    # only download one model at initialization
    fnames_dict = {}
    # if there are v2 models, one root may have two files (.ckpt and .yaml)
    for obj in s3_objects:
        filename = obj['Key'].replace(s3_folder, '').lstrip('/')
        root, ext = os.path.splitext(filename)
        model = fnames_dict.get(root)
        if model:
            model.append(filename)
        else:
            fnames_dict[root] = [filename]
    tmp_s3_files = {}
    for obj in s3_objects:
        etag = obj['ETag'].strip('"').strip("'")
        size = obj['Size']/(1024**3)
        filename = obj['Key'].replace(s3_folder, '').lstrip('/')
        tmp_s3_files[filename] = [etag, size]

    # only fetch the first model to download
    if mode == 'sd':
        s3_files = {}
        try:
            _, file_names = next(iter(fnames_dict.items()))
            for fname in file_names:
                s3_files[fname] = tmp_s3_files.get(fname)
                check_space_s3_download(s3_client, shared.models_s3_bucket, s3_folder, local_folder, fname, tmp_s3_files.get(fname)[1], mode)
        except Exception as e:
            print(e)

        print(f'-----s3_files---{s3_files}')
        # save the latest listing
        with open(s3_file_name, "w") as f:
            json.dump(s3_files, f)



def sync_s3_folder(local_folder, cache_dir, mode):
    s3 = boto3.client('s3')

    def sync(mode):
        # print(f'sync:{mode}')
        if mode == 'sd':
            s3_folder = shared.s3_folder_sd
        elif mode == 'cn':
            s3_folder = shared.s3_folder_cn
        elif mode == 'lora':
            s3_folder = shared.s3_folder_lora
        else:
            s3_folder = ''
        # Check and create tmp folders
        os.makedirs(os.path.dirname(local_folder), exist_ok=True)
        os.makedirs(os.path.dirname(cache_dir), exist_ok=True)
        s3_file_name = os.path.join(cache_dir, f's3_files_{mode}.json')
        # Create an empty file if it does not exist
        if os.path.isfile(s3_file_name) == False:
            s3_files = {}
            with open(s3_file_name, "w") as f:
                json.dump(s3_files, f)

        # List all objects in the S3 folder
        s3_objects = list_s3_objects(s3_client=s3, bucket_name=shared.models_s3_bucket, prefix=s3_folder)
        # Check if there are any new or deleted files
        s3_files = {}
        for obj in s3_objects:
            etag = obj['ETag'].strip('"').strip("'")
            size = obj['Size']/(1024**3)
            key = obj['Key'].replace(s3_folder, '').lstrip('/')
            s3_files[key] = [etag, size]

        # to compare the latest S3 listing with the one saved locally last time,
        # read the local json first
        s3_files_local = {}
        with open(s3_file_name, "r") as f:
            s3_files_local = json.load(f)
        # print(f's3_files:{s3_files}')
        # print(f's3_files_local:{s3_files_local}')
        # save the latest listing
        with open(s3_file_name, "w") as f:
            json.dump(s3_files, f)
        mod_files = set()
        new_files = set([key for key in s3_files if key not in s3_files_local])
        del_files = set([key for key in s3_files_local if key not in s3_files])
        registerflag = False
        # compare etag changes
        for key in set(s3_files_local.keys()).intersection(s3_files.keys()):
            local_etag = s3_files_local.get(key)[0]
            if local_etag and local_etag != s3_files[key][0]:
                mod_files.add(key)
        # Delete vanished files from the local folder
        for file in del_files:
            if os.path.isfile(os.path.join(local_folder, file)):
                os.remove(os.path.join(local_folder, file))
                print(f'remove file {os.path.join(local_folder, file)}')
        # Add new files
        for file in new_files.union(mod_files):
            registerflag = True
            retry = 3  # retry limit to prevent a dead loop in case other folders are empty
            while retry:
                ret = check_space_s3_download(s3, shared.models_s3_bucket, s3_folder, local_folder, file, s3_files[file][1], mode)
                # if there was not enough free space, evict a model and retry
                if ret:
                    retry = 0
                else:
                    free_local_disk(local_folder, s3_files[file][1], mode)
                    retry = retry - 1
        if registerflag:
            if mode == 'sd':
                # Refresh the model list
                sd_models.list_models()
            # cn models sync not supported temporarily due to an unfixed bug
            elif mode == 'cn':
                script_callbacks.update_cn_models_callback()
            elif mode == 'lora':
                print('Nothing To do')


    # Create a thread function to keep syncing with the S3 folder
    def sync_thread(mode):
        while True:
            syncLock.acquire()
            sync(mode)
            syncLock.release()
            time.sleep(30)

    thread = threading.Thread(target=sync_thread, args=(mode,))
    thread.start()
    print(f'{mode}_sync thread start')
    return thread
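
A side note on list_s3_objects above: boto3's list_objects_v2 paginator already follows continuation tokens on its own, so the explicit NextContinuationToken handling is redundant. A minimal equivalent sketch, with bucket and prefix as placeholders:

import os
import boto3

def list_model_objects(bucket_name, prefix=''):
    # the paginator yields every page, walking ContinuationToken internally
    s3 = boto3.client('s3')
    allowed = ('.pt', '.pth', '.ckpt', '.safetensors', '.yaml')
    objects = []
    for page in s3.get_paginator('list_objects_v2').paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get('Contents', []):
            if os.path.splitext(obj['Key'])[1] in allowed:
                objects.append(obj)
    return objects
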
36 changes: 31 additions & 5 deletions webui.py
@@ -63,6 +63,8 @@
import json
import shutil
import traceback
from modules.sync_models import initial_s3_download,sync_s3_folder


if cmd_opts.train:
    from botocore.exceptions import ClientError
@@ -299,6 +301,26 @@ def webui():
        name = http_model['name']
        http_download(uri, f'/tmp/models/{name}/{filename}')

    print(os.system('df -h'))
    sd_models_tmp_dir = f"{shared.tmp_models_dir}/Stable-diffusion/"
    cn_models_tmp_dir = f"{shared.tmp_models_dir}/ControlNet/"
    lora_models_tmp_dir = f"{shared.tmp_models_dir}/Lora/"
    cache_dir = f"{shared.tmp_cache_dir}/"
    session = boto3.Session()
    region_name = session.region_name
    sts_client = session.client('sts')
    account_id = sts_client.get_caller_identity()['Account']
    sg_s3_bucket = f"sagemaker-{region_name}-{account_id}"
    if not shared.models_s3_bucket:
        shared.models_s3_bucket = os.environ['sg_default_bucket'] if os.environ.get('sg_default_bucket') else sg_s3_bucket
    shared.s3_folder_sd = "stable-diffusion-webui/models/Stable-diffusion"
    shared.s3_folder_cn = "stable-diffusion-webui/models/ControlNet"
    shared.s3_folder_lora = "stable-diffusion-webui/models/Lora"
    # only download the cn models and the first sd model from the default bucket, to accelerate startup
    initial_s3_download(s3_client, shared.s3_folder_sd, sd_models_tmp_dir, cache_dir, 'sd')
    sync_s3_folder(sd_models_tmp_dir, cache_dir, 'sd')
    sync_s3_folder(cn_models_tmp_dir, cache_dir, 'cn')
    sync_s3_folder(lora_models_tmp_dir, cache_dir, 'lora')
    initialize()

    while 1:
@@ -667,11 +689,15 @@ def s3_download(s3uri, path):
    json.dump(cache, open('cache', 'w'))

def http_download(httpuri, path):
    with requests.get(httpuri, stream=True) as r:
        r.raise_for_status()
        with open(path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    try:
        with requests.get(httpuri, stream=True) as r:
            r.raise_for_status()
            with open(path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
    except Exception as e:
        print(f'http_download Exception:{e}')


if __name__ == "__main__":
    if cmd_opts.train:
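
The webui.py hunk above resolves the model bucket at startup: it uses the sg_default_bucket environment variable when present and otherwise falls back to the account's sagemaker-{region_name}-{account_id} bucket. A minimal sketch of pointing the sync at a custom bucket; the bucket name is a placeholder and the variable must be set before webui() resolves it:

import os

# set in the process/container environment before launch
os.environ['sg_default_bucket'] = 'my-sd-models-bucket'   # hypothetical bucket
# models are then expected under these prefixes inside that bucket:
#   stable-diffusion-webui/models/Stable-diffusion
#   stable-diffusion-webui/models/ControlNet
#   stable-diffusion-webui/models/Lora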
