Skip to content

Commit

Permalink
revise & big cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Nov 25, 2022
1 parent e44eebc commit 0cc9ee8
Show file tree
Hide file tree
Showing 10 changed files with 754 additions and 226 deletions.
4 changes: 3 additions & 1 deletion launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def tests(argv):
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv:
if '--train' in sys.argv:
webui.train()
elif '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()
Expand Down
73 changes: 63 additions & 10 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
from modules.paths import script_path
import json
import os
import boto3
from modules import sd_hijack
from typing import Union
import traceback

def upscaler_to_index(name: str):
try:
Expand Down Expand Up @@ -77,8 +84,11 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.app.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=InvocationsResponse)
self.app.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse, List[SDModelItem]])
self.app.add_api_route("/ping", self.ping, methods=["GET"], response_model=PingResponse)
self.cache = dict()
self.s3_client = boto3.client('s3')
self.s3_resource= boto3.resource('s3')

def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
Expand Down Expand Up @@ -166,10 +176,8 @@ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)

reqDict['image'] = decode_base64_to_image(reqDict['image'])

with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)

return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
Expand All @@ -183,9 +191,9 @@ def prepareFiles(file):
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')


with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)

return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

def pnginfoapi(self, req: PNGInfoRequest):
Expand Down Expand Up @@ -306,15 +314,60 @@ def get_artists_categories(self):
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]

def download_s3files(self, s3uri, path):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_bucket = self.s3_resource.Bucket(bucket)
objs = list(s3_bucket.objects.filter(Prefix=key))

if os.path.isfile('cache'):
self.cache = json.load(open('cache', 'r'))

for obj in objs:
response = self.s3_client.head_object(
Bucket = bucket,
Key = obj.key
)
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
if obj_key not in self.cache or self.cache[obj_key] != response['ETag']:
filename = obj.key[obj.key.rfind('/') + 1 : ]

self.s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
self.cache[obj_key] = response['ETag']

json.dump(self.cache, open('cache', 'w'))

def invocations(self, req: InvocationsRequest):
if req.task == 'text-to-image':
return self.text2imgapi(req.payload)
elif req.task == 'image-to-image':
return self.img2imgapi(req.payload)
else:
raise NotImplementedError
print('-------invocation------')
print(req)

embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

try:
if req.task == 'text-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return self.text2imgapi(req.txt2img_payload)
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return self.img2imgapi(req.img2img_payload)
elif req.task == 'extras-single-image':
return self.extras_single_image_api(req.extras_single_payload)
elif req.task == 'extras-batch-images':
return self.extras_batch_images_api(req.extras_batch_payload)
elif req.task == 'sd-models':
return self.get_sd_models()
else:
raise NotImplementedError
except Exception as e:
traceback.print_exc()

def ping(self):
print('-------ping------')
return {'status': 'Healthy'}

def launch(self, server_name, port):
Expand Down
14 changes: 6 additions & 8 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List
from typing import Union
from typing import Optional

API_NOT_ALLOWED = [
"self",
Expand Down Expand Up @@ -142,7 +142,7 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")

class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
image: str = Field(title="Image", description="The generated image in base64 format.")

class FileData(BaseModel):
data: str = Field(title="File data", description="Base64 representation of the file")
Expand Down Expand Up @@ -242,12 +242,10 @@ class ArtistItem(BaseModel):

class InvocationsRequest(BaseModel):
task: str
payload: Union[StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI]

class InvocationsResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI]
img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
extras_single_payload: Optional[ExtrasSingleImageRequest]
extras_batch_payload: Optional[ExtrasBatchImagesRequest]

class PingResponse(BaseModel):
status: str
9 changes: 7 additions & 2 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]

self.scripts = p.scripts
self.script_args = p.script_args

def js(self):
obj = {
"prompt": self.prompt,
Expand Down Expand Up @@ -472,10 +475,8 @@ def infotext(iteration=0, position_in_batch=0):

if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()

if p.scripts is not None:
p.scripts.process(p)

infotexts = []
output_images = []

Expand Down Expand Up @@ -609,6 +610,8 @@ def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstp
self.firstphase_height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
self.scripts = modules.scripts.scripts_txt2img
self.scripts.setup_scripts(False)

def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
Expand Down Expand Up @@ -740,6 +743,8 @@ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strengt
self.mask = None
self.nmask = None
self.image_conditioning = None
self.scripts = modules.scripts.scripts_img2img
self.scripts.setup_scripts(True)

def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
Expand Down
96 changes: 59 additions & 37 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
import requests
import json

api_endpoint = os.environ['api_endpoint'] if 'api_endpoint' in os.environ else ''
endpoint_name = os.environ['endpoint_name'] if 'endpoint_name' in os.environ else ''

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
Expand Down Expand Up @@ -44,44 +49,61 @@ def checkpoint_tiles():


def list_models():
global checkpoints_list
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])

def modeltitle(path, shorthash):
abspath = os.path.abspath(path)

if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)

if name.startswith("\\") or name.startswith("/"):
name = name[1:]

shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]

return f'{name} [{shorthash}]', shortname

cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)

basename, _ = os.path.splitext(filename)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config

checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)

if shared.cmd_opts.pureui:
response = requests.get(url=f'{api_endpoint}/sd/models')
model_list = json.loads(response.text)

for model in model_list:
h = model['hash']
filename = model['filename']
title = model['title']
short_model_name = model['model_name']
config = model['config']

checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)

else:
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])

def modeltitle(path, shorthash):
abspath = os.path.abspath(path)

if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)

if name.startswith("\\") or name.startswith("/"):
name = name[1:]

shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]

return f'{name} [{shorthash}]', shortname

cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)

basename, _ = os.path.splitext(filename)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config

checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)


def get_closet_checkpoint_match(searchString):
Expand Down
Loading

0 comments on commit 0cc9ee8

Please sign in to comment.