Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Mar 4, 2023
1 parent 47dd0bb commit d62af5c
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 502 deletions.
68 changes: 45 additions & 23 deletions modules/api/api.py
Expand Up @@ -26,6 +26,9 @@
from typing import Union
import traceback
import requests
import piexif
import piexif.helper
import numpy as np

def upscaler_to_index(name: str):
try:
Expand Down Expand Up @@ -54,23 +57,48 @@ def decode_base64_to_image(encoding):
encoding = encoding.split(";")[1].split(",")[1]
return Image.open(BytesIO(base64.b64decode(encoding)))

def encode_to_base64(image):
if type(image) is str:
return image
elif type(image) is Image.Image:
return encode_pil_to_base64(image)
elif type(image) is np.ndarray:
return encode_np_to_base64(image)
else:
return ""

def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:

# Copy any text-only metadata
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
if opts.samples_format.lower() == 'png':
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)

elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
})
if opts.samples_format.lower() in ("jpg", "jpeg"):
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
else:
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)

else:
raise HTTPException(status_code=500, detail="Invalid image format")

image.save(
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
)
bytes_data = output_bytes.getvalue()

return base64.b64encode(bytes_data)

def encode_np_to_base64(image):
pil = Image.fromarray(image)
return encode_pil_to_base64(pil)

class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
Expand Down Expand Up @@ -138,16 +166,13 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
shared.state.begin()

with self.queue_lock:
processed = process_images(p)

if p.script_args is not None:
processed = p.scripts.run(p, *p.script_args)
processed = p.scripts.run(p, *p.script_args)
if processed is None:
processed = process_images(p)

shared.state.end()

b64images = list(map(encode_pil_to_base64, processed.images))
b64images = list(map(encode_to_base64, processed.images))

return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

Expand Down Expand Up @@ -185,16 +210,13 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
shared.state.begin()

with self.queue_lock:
processed = process_images(p)

if p.script_args is not None:
processed = p.scripts.run(p, *p.script_args)
processed = p.scripts.run(p, *p.script_args)
if processed is None:
processed = process_images(p)

shared.state.end()

b64images = list(map(encode_pil_to_base64, processed.images))
b64images = list(map(encode_to_base64, processed.images))

if not img2imgreq.include_init_images:
img2imgreq.init_images = None
Expand All @@ -210,7 +232,7 @@ def extras_single_image_api(self, req: ExtrasSingleImageRequest):
with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)

return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
return ExtrasSingleImageResponse(image=encode_to_base64(result[0][0]), html_info=result[1])

def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
Expand All @@ -226,7 +248,7 @@ def prepareFiles(file):
with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)

return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
return ExtrasBatchImagesResponse(images=list(map(encode_to_base64, result[0])), html_info=result[1])

def pnginfoapi(self, req: PNGInfoRequest):
if(not req.image.strip()):
Expand Down Expand Up @@ -260,7 +282,7 @@ def progressapi(self, req: ProgressRequest = Depends()):

current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
current_image = encode_to_base64(shared.state.current_image)

return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)

Expand Down
6 changes: 0 additions & 6 deletions modules/call_queue.py
Expand Up @@ -413,22 +413,16 @@ def f(*args, **kwargs):
username = args[len(args) - 2]
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
print('username:', username)
print('sagemaker_endpoint:', sagemaker_endpoint)
res = sagemaker_inference('text-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.img2img.img2img:
username = args[len(args) - 2]
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
print('username:', username)
print('sagemaker_endpoint:', sagemaker_endpoint)
res = sagemaker_inference('image-to-image', 'sync', username, sagemaker_endpoint, *args, **kwargs)
elif cmd_opts.pureui and func == modules.extras.run_extras:
username = args[len(args) - 2]
sagemaker_endpoint = args[len(args) -1]
args = args[:-2]
print('username:', username)
print('sagemaker_endpoint:', sagemaker_endpoint)
res = sagemaker_inference('extras', 'sync', username, sagemaker_endpoint, *args, **kwargs)
else:
shared.state.begin()
Expand Down
15 changes: 11 additions & 4 deletions modules/processing.py
Expand Up @@ -22,7 +22,8 @@
import modules.styles
import logging
import base64
import io
from io import BytesIO
from numpy import asarray

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
Expand Down Expand Up @@ -68,6 +69,10 @@ def apply_overlay(image, paste_loc, index, overlays):

return image

def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
return Image.open(BytesIO(base64.b64decode(encoding)))

class StableDiffusionProcessing():
"""
Expand Down Expand Up @@ -119,9 +124,11 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
self.script_args = json.loads(script_args) if script_args != None else None

if self.script_args:
for key in self.script_args:
if key == 'image' or key == 'mask':
self.script_arg[key] = Image.open(io.BytesIO(base64.b64decode(self.script_args[key])))
for idx in range(len(self.script_args)):
if(isinstance(self.script_args[idx], dict)):
for key in self.script_args[idx]:
if key == 'image' or key == 'mask':
self.script_args[idx][key] = asarray(decode_base64_to_image(self.script_args[idx][key]))

if not seed_enable_extras:
self.subseed = -1
Expand Down
15 changes: 7 additions & 8 deletions modules/shared.py
Expand Up @@ -137,14 +137,13 @@
hypernetworks = {}
loaded_hypernetwork = None

if cmd_opts.pureui:
api_endpoint = os.environ['api_endpoint']
industrial_model = ''
default_options = {}
username_state = None
sagemaker_endpoint_component = None
sd_model_checkpoint_component = None
create_train_dreambooth_component = None
api_endpoint = os.environ['api_endpoint']
industrial_model = ''
default_options = {}
username_state = None
sagemaker_endpoint_component = None
sd_model_checkpoint_component = None
create_train_dreambooth_component = None

def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
Expand Down

0 comments on commit d62af5c

Please sign in to comment.