Skip to content

Commit

Permalink
big cleanup and revise for stable-diffusion-webui
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Dec 2, 2022
1 parent 26ae67b commit 18815ab
Show file tree
Hide file tree
Showing 8 changed files with 890 additions and 229 deletions.
35 changes: 30 additions & 5 deletions modules/api/api.py
Expand Up @@ -18,9 +18,10 @@
import json
import os
import boto3
from modules import sd_hijack
from modules import sd_hijack, hypernetworks
from typing import Union
import traceback
import requests

def upscaler_to_index(name: str):
try:
Expand Down Expand Up @@ -347,18 +348,42 @@ def invocations(self, req: InvocationsRequest):
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

try:
username = req.username
default_options = shared.opts.data
if username != '':
inputs = {
'action': 'get',
'username': username
}
api_endpoint = os.environ['api_endpoint']
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
if response.status_code == 200 and response.text != '':
shared.opts.data = json.loads(response.text)

self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
hypernetworks.hypernetwork.apply_strength()

if req.task == 'text-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return self.text2imgapi(req.txt2img_payload)
response = self.text2imgapi(req.txt2img_payload)
shared.opts.data = default_options
return response
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return self.img2imgapi(req.img2img_payload)
response = self.img2imgapi(req.img2img_payload)
shared.opts.data = default_options
return response
elif req.task == 'extras-single-image':
return self.extras_single_image_api(req.extras_single_payload)
response = self.extras_single_image_api(req.extras_single_payload)
shared.opts.data = default_options
return response
elif req.task == 'extras-batch-images':
return self.extras_batch_images_api(req.extras_batch_payload)
response = self.extras_batch_images_api(req.extras_batch_payload)
shared.opts.data = default_options
return response
elif req.task == 'sd-models':
return self.get_sd_models()
else:
Expand Down
1 change: 1 addition & 0 deletions modules/api/models.py
Expand Up @@ -242,6 +242,7 @@ class ArtistItem(BaseModel):

class InvocationsRequest(BaseModel):
task: str
username: Optional[str]
txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI]
img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
extras_single_payload: Optional[ExtrasSingleImageRequest]
Expand Down
48 changes: 34 additions & 14 deletions modules/sd_models.py
Expand Up @@ -15,9 +15,6 @@
import requests
import json

api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else ''
endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else ''

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

Expand Down Expand Up @@ -53,18 +50,41 @@ def list_models():
checkpoints_list.clear()

if shared.cmd_opts.pureui:
response = requests.get(url=f'{api_endpoint}/sd/models')
model_list = json.loads(response.text)

for model in model_list:
h = model['hash']
filename = model['filename']
title = model['title']
short_model_name = model['model_name']
config = model['config']

checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else ''
endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else ''
class SDModel:
def __init__(self, sd_model_name, sd_model_hash, sd_model_checkpoint, sd_checkpoint_info):
self.sd_model_name = sd_model_name
self.sd_model_hash = sd_model_hash
self.sd_model_checkpoint = sd_model_checkpoint
self.sd_checkpoint_info = sd_checkpoint_info

response = requests.get(url=f'{api_endpoint}/sd/models')
if response.status_code == 200:
model_list = json.loads(response.text)

for model in model_list:
h = model['hash']
filename = model['filename']
title = model['title']
short_model_name = model['model_name']
config = model['config']

if 'sd_model_checkpoint' not in shared.opts.data:
shared.opts.data['sd_model_checkpoint'] = title

checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)

sd_model_checkpoint = shared.opts.data['sd_model_checkpoint']
sd_checkpoint_info = checkpoints_list[sd_model_checkpoint]
sd_model_name = checkpoints_list[sd_model_checkpoint].model_name
sd_model_hash = checkpoints_list[sd_model_checkpoint].hash
shared.sd_model = SDModel(
sd_model_name,
sd_model_hash,
sd_model_checkpoint,
sd_checkpoint_info
)
else:
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])

Expand Down
41 changes: 33 additions & 8 deletions modules/shared.py
Expand Up @@ -98,8 +98,9 @@
parser.add_argument("--train-args", type=str, help='Train args', default='')
parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri')
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
parser.add_argument('--industrial-model', default='', type=str, help='Industrial Model')
parser.add_argument('--region-name', type=str, help='Region Name')
parser.add_argument('--username', default='', type=str, help='Username')
parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint')

cmd_opts = parser.parse_args()
restricted_opts = {
Expand Down Expand Up @@ -131,6 +132,14 @@
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None

if cmd_opts.pureui:
username = ''
api_endpoint = os.environ['api_endpoint']
industrial_model = ''
endpoint_name = ''
endpoint_names = []
default_options = {}

def reload_hypernetworks():
global hypernetworks

Expand Down Expand Up @@ -472,10 +481,13 @@ def load(self, filename):
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)

if cmd_opts.pureui:
opts.show_progressbar = False
api_endpoint = os.environ['api_endpoint']
global api_endpoint, industrial_model, default_options

if 'industrial_model' not in opts.data:
#opts.show_progressbar = False
response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
if response.status_code == 200:
industrial_model = response.text
else:
model_name = 'stable-diffusion-webui'
model_description = model_name
inputs = {
Expand All @@ -493,8 +505,8 @@ def load(self, filename):
if response.status_code == 200:
body = json.loads(response.text)
industrial_model = body['id']
opts.data['industrial_model'] = industrial_model
opts.save(config_filename)

default_options = self.data

def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
Expand Down Expand Up @@ -534,8 +546,6 @@ def reorder(self):

progress_print_out = sys.stdout

userid = ''

class TotalTQDM:
def __init__(self):
self._tqdm = None
Expand Down Expand Up @@ -577,3 +587,18 @@ def clear(self):
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]

if cmd_opts.pureui:
def init_endpoints():
global endpoint_name, endpoint_names, industrial_model, api_endpoint

endpoints = []
params = {
'industrial_model': industrial_model
}
response = requests.get(url=f'{api_endpoint}/endpoint', params=params)
if response.status_code == 200:
for endpoint_item in json.loads(response.text):
endpoints.append(endpoint_item['EndpointName'])
endpoint_name = endpoints[0] if len(endpoints) > 0 else ''
endpoint_names = endpoints

0 comments on commit 18815ab

Please sign in to comment.