Skip to content

Commit

Permalink
Merge pull request #3511 from bamarillo/master
Browse files Browse the repository at this point in the history
[API][Feature] Add extras endpoints
  • Loading branch information
AUTOMATIC1111 committed Oct 29, 2022
2 parents 737eb28 + 0edf100 commit d5f31f1
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 56 deletions.
104 changes: 54 additions & 50 deletions modules/api/api.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,37 @@
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException
import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
from PIL import Image

sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
from modules.extras import run_extras

class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")

class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)

def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
return reqDict

class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])

def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
if "," in base64_string:
base64_string = base64_string.split(",")[1]
imgdata = base64.b64decode(base64_string)
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)

def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
Expand All @@ -60,15 +51,9 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
with self.queue_lock:
processed = process_images(p)

b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))

return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())

b64images = list(map(encode_pil_to_base64, processed.images))

return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
Expand All @@ -83,7 +68,7 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):

mask = img2imgreq.mask
if mask:
mask = self.__base64_to_image(mask)
mask = decode_base64_to_image(mask)


populate = img2imgreq.copy(update={ # Override __init__ params
Expand All @@ -98,29 +83,48 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):

imgs = []
for img in init_images:
img = self.__base64_to_image(img)
img = decode_base64_to_image(img)
imgs = [img] * p.batch_size

p.init_images = imgs
# Override object param
with self.queue_lock:
processed = process_images(p)

b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
b64images = list(map(encode_pil_to_base64, processed.images))

if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None

return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)

def extrasapi(self):
raise NotImplementedError
reqDict['image'] = decode_base64_to_image(reqDict['image'])

with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)

return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)

def prepareFiles(file):
file = decode_base64_to_file(file.data, file_path=file.name)
file.orig_name = file.name
return file

reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')

with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)

return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

def pnginfoapi(self):
raise NotImplementedError

Expand Down
55 changes: 49 additions & 6 deletions modules/api/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from array import array
from inflection import underscore
from typing import Any, Dict, Optional
import inspect
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect

from modules.shared import sd_upscalers

API_NOT_ALLOWED = [
"self",
Expand Down Expand Up @@ -105,4 +105,47 @@ def generate_model(self):
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
).generate_model()

class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str

class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str

class ExtrasBaseRequest(BaseModel):
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")

class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")

class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")

class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")

class FileData(BaseModel):
data: str = Field(title="File data", description="Base64 representation of the file")
name: str = Field(title="File name")

class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")

class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")

0 comments on commit d5f31f1

Please sign in to comment.