Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 22, 2023
2 parents 92d53f5 + fa0bfc9 commit 216a63c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 22 deletions.
8 changes: 4 additions & 4 deletions localizations/zh_CN.json
Expand Up @@ -634,6 +634,7 @@
"Username": "用户名",
"Password": "密码",
"Options": "选项",
"Sagemaker Endpoints": "Sagemaker终端节点",
"Output": "输出",
"User management (Only available for admin user)": "用户管理 (仅管理员用户可用)",
"Images S3 URI": "图像 S3 位置",
Expand Down Expand Up @@ -676,7 +677,6 @@
"Parameters": "参数",
"Concepts S3 URI": "概念 S3 位置",
"Intervals": "间隔",
"Training Steps Per Image (Epochs) ": "每张图片的训练步数 (Epochs)",
"Max Training Steps": "最大训练步数",
"Pause After N Epochs": "经过若干步后暂停",
"Amount of time to pause between Epochs, in Seconds": "相邻 Epochs 之间暂停的时间",
Expand All @@ -688,7 +688,7 @@
"Batch Size": "批量大小",
"Class Batch Size": "类批量大小",
"Learning Rate": "学习率",
"Lora unet Learning Rate ": "Lora unet 学习率",
"Lora unet Learning Rate": "Lora unet 学习率",
"Lora Text Encoder Learning Rate": "Lora 文本编码器学习率",
"Scale Learning Rate": "规模学习率",
"Learning Rate Scheduler": "学习率调度器",
Expand Down Expand Up @@ -802,7 +802,7 @@
"Only use mid-control when inference": "推理时仅使用 mid-control",
"Enable CFG-Based guidance": "使用基于扩散度指引",

"Lora UNET Learning Rate ": "Lora UNET 学习率",
"Lora UNET Learning Rate": "Lora UNET 学习率",
"512x Model": "512 模型",
"Unfreeze Model": "取消冻结网络",
"General": "通用",
Expand All @@ -825,7 +825,7 @@
"Sample Negative Prompt": "样本图像负向提示词",
"Class Images Per Instance Image": "为实例图片生成分类图片数",
"Saving":"保存",
"Save in .safetensors format": "保存为 .safetensorts 格式",
"Save in .safetensors format": "保存为 .safetensors 格式",
"Checkpoints": "检查点",
"Generate a .ckpt file when saving during training.": "在训练时保存为 .ckpt 格式",
"Generate a .ckpt file when training completes.": "在训练结束时保存为 .ckpt 格式",
Expand Down
47 changes: 42 additions & 5 deletions localizations/zh_TW.json
Expand Up @@ -669,7 +669,6 @@
"Parameters": "參數",
"Concepts S3 URI": "概念 S3 位置",
"Intervals": "間隔",
"Training Steps Per Image (Epochs)": "每張圖片的訓練步數 (Epochs)",
"Max Training Steps": "最大訓練步數",
"Pause After N Epochs": "經過若干步後暫停",
"Amount of time to pause between Epochs, in Seconds": "相鄰 Epochs 之間暫停的時間",
Expand All @@ -681,13 +680,13 @@
"Batch Size": "批量大小",
"Class Batch Size": "類批量大小",
"Learning Rate": "學習率",
"Lora unet Learning Rate ": "Lora unet 學習率",
"Lora Text Encoder Learning Rate ": "Lora 文本編碼器學習率",
"Lora unet Learning Rate": "Lora unet 學習率",
"Lora Text Encoder Learning Rate": "Lora 文本編碼器學習率",
"Scale Learning Rate": "規模學習率",
"Learning Rate Scheduler": "學習率調度器",
"Learning Rate Warmup Steps ": "學習率預熱步驟",
"Learning Rate Warmup Steps": "學習率預熱步驟",
"Image Processing": "圖像處理",
"Resolution ": "分辨率",
"Resolution": "分辨率",
"Center Crop": "居中裁剪",
"Apply Horizontal Flip": "應用水平翻轉",
"Miscellaneous": "雜項",
Expand Down Expand Up @@ -795,5 +794,43 @@
"Only use mid-control when inference": "推理時僅使用 mid-control",
"Enable CFG-Based guidance": "使用基於擴散度指引",

"Lora UNET Learning Rate": "Lora UNET 學習率",
"512x Model": "512 模型",
"Unfreeze Model": "取消凍結網絡",
"General": "通用",
"Train UNET": "訓練 UNET",
"Step Ratio of Text Encoder Training": "文本編碼器訓練步長比",
"Freeze CLIP Normalization Layers": "凍結 CLIP 歸一化層",
"Clip Skip": "剪輯跳過",
"AdamW Weight Decay": "AdamW 權重衰減",
"Training Steps Per Image (Epochs)": "每張圖片的訓練步數 (Epochs)",
"Amount of time to pause between Epochs": "相鄰 Epochs 之間暫停的時間",
"Save Model Frequency (Epochs)": "保存模型的頻率 (Epchs)",
"Save Prview(s) Frequency (Epochs)": "保存於臉圖的頻率 (Epchs)",
"Set Gradients to None When Zeroing": "歸零時設置梯度為無",
"Prior Loss": "先前損失",
"Scale Prior Loss": "縮放先驗損失",
"Generate Classification Images Using txt2img": "使用文生圖來生成類圖",
"Concept 4": "概念4",
"Sanity Sample Prompt": "完整性採樣提示詞",
"Sanity Sample Seed": "完整性採樣種子",
"Sample Negative Prompt": "樣本圖像負向提示詞",
"Class Images Per Instance Image": "為實例圖片生成分類圖片數",
"Saving":"保存",
"Save in .safetensors format": "保存為 .safetensors 格式",
"Checkpoints": "檢查點",
"Generate a .ckpt file when saving during training.": "在訓練時保存為 .ckpt 格式",
"Generate a .ckpt file when training completes.": "在訓練結束時保存為 .ckpt 格式",
"Generate a .ckpt file when training is canceled.": "在訓練取消時保存為 .ckpt 格式",
"Diffusion Weights": "擴散權重",
"Save separate diffusers snapshots when saving during training.": "在訓練時保存為單獨的 diffusers 快照",
"Save separate diffusers snapshots when training completes.": "在訓練完成時保存為單獨的 diffusers 快照",
"Save separate diffusers snapshots when training is canceled.": "在訓練取消時保存為單獨的 diffusers 快照",
"Basic": "基本",
"Graph Smoothing Steps": "圖形平滑步數",
"Amount of time to pause between Epochs (s)": "Epochs 間隔等待時間",
"Save Preview(s) Frequency (Epochs)": "保存預覽頻率 (Epochs)",
"A generic prompt used to generate a sample image to verify model fidelity.": "用於生成樣本圖像以驗證模型保真度的通用提示。",

"--------": "--------"
}
3 changes: 2 additions & 1 deletion modules/api/api.py
Expand Up @@ -418,7 +418,8 @@ def invocations(self, req: InvocationsRequest):
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
if response.status_code == 200 and response.text != '':
try:
shared.opts.data = json.loads(response.text)
data = json.loads(response.text)
shared.opts.data = data['options']
with self.queue_lock:
sd_models.reload_model_weights()
except Exception as e:
Expand Down
26 changes: 26 additions & 0 deletions modules/shared.py
Expand Up @@ -326,6 +326,20 @@ def list_sagemaker_endpoints():

return sagemaker_endpoints

def intersection(lst1, lst2):
    """Return the common elements of two iterables as a list.

    Duplicates are collapsed and ordering is unspecified (set semantics),
    matching the behaviour callers rely on for endpoint filtering.
    """
    # Set-intersection operator does the conversion and overlap in one pass.
    return list(set(lst1) & set(lst2))

def get_available_sagemaker_endpoints(item):
    """Return the comma-separated SageMaker endpoint names a user may access.

    ``item`` is a user record (dict) whose optional ``attributes`` entry is a
    dict that may carry a ``sagemaker_endpoints`` string.  Returns '' when the
    attributes or the key are absent.

    Fix: the original handled only the '' sentinel for a missing ``attributes``
    entry and raised AttributeError when the entry existed but was None (or any
    other falsy non-dict value).  ``or {}`` normalises every falsy value to an
    empty dict so the lookup is always safe.
    """
    attrs = item.get('attributes') or {}
    return attrs.get('sagemaker_endpoints', '')

def refresh_sagemaker_endpoints(username):
global industrial_model, api_endpoint, sagemaker_endpoints

Expand All @@ -344,6 +358,18 @@ def refresh_sagemaker_endpoints(username):
for endpoint_item in json.loads(response.text):
sagemaker_endpoints.append(endpoint_item['EndpointName'])

# to filter user's available endpoints
inputs = {
'action': 'get',
'username': username
}
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
if response.status_code == 200 and response.text != '':
data = json.loads(response.text)
eps = get_available_sagemaker_endpoints(data)
if eps != '':
sagemaker_endpoints = intersection(eps.split(','), sagemaker_endpoints)

return sagemaker_endpoints

options_templates.update(options_section(('sd', "Stable Diffusion"), {
Expand Down
26 changes: 15 additions & 11 deletions modules/ui.py
Expand Up @@ -688,7 +688,7 @@ def update_username():
if response.status_code == 200:
items = []
for item in json.loads(response.text):
items.append([item['username'], item['password'], item['options'] if 'options' in item else ''])
items.append([item['username'], item['password'], item['options'] if 'options' in item else '', shared.get_sagemaker_endpoints(item)])
return gr.update(value=shared.username), gr.update(value=items if items != [] else None)
else:
return gr.update(value=shared.username), gr.update()
Expand Down Expand Up @@ -1994,13 +1994,13 @@ def sagemaker_train_hypernetwork(

with gr.Blocks(analytics_enabled=False) as user_interface:
user_dataframe = gr.Dataframe(
headers=["Username", "Password", "Options"],
headers=["Username", "Password", "Options", "Sagemaker Endpoints"],
row_count=2,
col_count=(3,"fixed"),
col_count=(4,"fixed"),
label="User management (Only available for admin user)",
interactive=True,
visible=True,
datatype=["str","str","str"],
datatype=["str","str","str", "str"],
type="array"
)

Expand All @@ -2018,14 +2018,18 @@ def save_userdata(user_dataframe, request: gr.Request):
if not access_token or tokens[access_token] != 'admin':
return gr.update()
items = []
for item in user_dataframe:
items.append(
{
'username': item[0],
'password': item[1],
'options': item[2]
for user_df in user_dataframe:
item = {
'username': user_df[0],
'password': user_df[1],
'options': user_df[2],
'attributes': {},
}
if user_df[3] != '':
item['attributes'] = {
'sagemaker_endpoints': user_df[3]
}
)
items.append(item)
inputs = {
'action': 'save',
'items': items
Expand Down
18 changes: 17 additions & 1 deletion webui.py
@@ -1,4 +1,5 @@
import os
import shutil
import threading
import time
import importlib
Expand Down Expand Up @@ -358,7 +359,8 @@ def train():
}
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
if response.status_code == 200 and response.text != '':
opts.data = json.loads(response.text)
data = json.loads(response.text)
opts.data = data['options']
modules.sd_models.load_model()

if train_task == 'embedding':
Expand Down Expand Up @@ -741,6 +743,20 @@ def train():
lora_models_s3uri,
os.path.join(lora_model_dir, f'{db_model_name}_*.pt')
)
#automatic tar latest checkpoint and upload to s3 by zheng on 2023.03.22
os.makedirs(os.path.dirname("/opt/ml/model/"), exist_ok=True)
train_steps=int(db_config.revision)
f1=os.path.join(sd_models_path, db_model_name, f'{db_model_name}_{train_steps}.yaml')
if os.path.exists(f1):
shutil.copy(f1,"/opt/ml/model/")
if db_save_safetensors:
f2=os.path.join(sd_models_path, db_model_name, f'{db_model_name}_{train_steps}.safetensors')
if os.path.exists(f2):
shutil.copy(f2,"/opt/ml/model/")
else:
f2=os.path.join(sd_models_path, db_model_name, f'{db_model_name}_{train_steps}.ckpt')
if os.path.exists(f2):
shutil.copy(f2,"/opt/ml/model/")
except Exception as e:
traceback.print_exc()
print(e)
Expand Down

0 comments on commit 216a63c

Please sign in to comment.