Skip to content

Commit

Permalink
fix issues with dreambooth training
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 7, 2023
1 parent 082f5b1 commit 231076a
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions modules/shared.py
Expand Up @@ -137,36 +137,37 @@
hypernetworks = {}
loaded_hypernetwork = None

api_endpoint = os.environ['api_endpoint']
industrial_model = ''
default_options = {}
sagemaker_endpoint_component = None
sd_model_checkpoint_component = None
create_train_dreambooth_component = None

response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
if response.status_code == 200:
industrial_model = response.text
else:
model_name = 'stable-diffusion-webui'
model_description = model_name
inputs = {
'model_algorithm': 'stable-diffusion-webui',
'model_name': model_name,
'model_description': model_description,
'model_extra': '{"visible": "false"}',
'model_samples': '',
'file_content': {
'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()]
}
}

response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
if not cmd_opts.train:
api_endpoint = os.environ['api_endpoint']
industrial_model = ''
default_options = {}
sagemaker_endpoint_component = None
sd_model_checkpoint_component = None
create_train_dreambooth_component = None

response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
if response.status_code == 200:
body = json.loads(response.text)
industrial_model = body['id']
industrial_model = response.text
else:
print(response.text)
model_name = 'stable-diffusion-webui'
model_description = model_name
inputs = {
'model_algorithm': 'stable-diffusion-webui',
'model_name': model_name,
'model_description': model_description,
'model_extra': '{"visible": "false"}',
'model_samples': '',
'file_content': {
'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()]
}
}

response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
if response.status_code == 200:
body = json.loads(response.text)
industrial_model = body['id']
else:
print(response.text)

def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
Expand Down

0 comments on commit 231076a

Please sign in to comment.