Skip to content

Commit

Permalink
improve dreambooth batch:
Browse files Browse the repository at this point in the history
- img2img support
- KerrasDPM support

updated libraries & added samples.py for testing
  • Loading branch information
anotherjesse committed Jun 25, 2023
1 parent 09f693c commit b636ff4
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 51 deletions.
12 changes: 6 additions & 6 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
build:
gpu: true
cuda: "11.6"
cuda: "11.8"
python_version: "3.10"
python_packages:
- "diffusers==0.11.1"
- "torch==1.13.0"
- "diffusers==0.17.1"
- "torch==2.0.1"
- "ftfy==6.1.1"
- "scipy==1.9.3"
- "transformers==4.25.1"
- "accelerate==0.15.0"
- "scipy==1.10.1"
- "transformers==4.30.2"
- "accelerate==0.20.3"
system_packages:
- "unzip"

Expand Down
103 changes: 58 additions & 45 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
PNDMScheduler,
LMSDiscreteScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
Expand All @@ -23,20 +25,42 @@
from transformers import CLIPFeatureExtractor
import shutil
import subprocess
from diffusers.utils import load_image

SAFETY_MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"

DEFAULT_HEIGHT = 512
DEFAULT_WIDTH = 512
DEFAULT_SCHEDULER = "DDIM"
DEFAULT_GUIDANCE_SCALE = 7.5
DEFAULT_NUM_INFERENCE_STEPS = 50
DEFAULT_STRENGTH = 0.8

# grab instance_prompt from weights,
# unless empty string or not existent

DEFAULT_PROMPT = "a photo of an astronaut riding a horse on mars"


class KerrasDPM:
def from_config(config):
return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)


SCHEDULERS = {
"DDIM": DDIMScheduler,
"DPMSolverMultistep": DPMSolverMultistepScheduler,
"HeunDiscrete": HeunDiscreteScheduler,
"KerrasDPM": KerrasDPM,
"K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
"K_EULER": EulerDiscreteScheduler,
"KLMS": LMSDiscreteScheduler,
"PNDM": PNDMScheduler,
"UniPCMultistep": UniPCMultistepScheduler,
}


class Predictor(BasePredictor):
def setup(self):
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
Expand Down Expand Up @@ -106,53 +130,53 @@ def load_weights(self, url):
feature_extractor=self.txt2img_pipe.feature_extractor,
).to("cuda")
print("Loaded pipelines in {:.2f} seconds".format(time.time() - start_time))

self.txt2img_pipe.set_progress_bar_config(disable=True)
self.img2img_pipe.set_progress_bar_config(disable=True)
self.url = url

def generate_images(self, images, output_dir):
with torch.autocast("cuda"), torch.inference_mode():
pipeline = self.txt2img_pipe
pipeline.set_progress_bar_config(disable=True)
for info in tqdm(images, desc="Generating samples"):
inputs = info.get("input") or info.get("inputs")
name = info["name"]
prompt = inputs["prompt"]
negative_prompt = inputs.get("negative_prompt")
width = int(inputs.get("width", 512))
height = int(inputs.get("height", 512))
print(name)

num_outputs = int(inputs.get("num_outputs", 1))
disable_safety_check = bool(
inputs.get("disable_safety_check", False)
)
num_inference_steps = int(inputs.get("num_inference_steps", 50))
guidance_scale = float(inputs.get("guidance_scale", 7.5))
scheduler = inputs.get("scheduler", "DDIM")
seed = inputs.get("seed")
if seed is None:
seed = int.from_bytes(os.urandom(2), "big")

kwargs = {
"prompt": [inputs["prompt"]] * num_outputs,
"num_inference_steps": int(inputs.get("num_inference_steps", DEFAULT_NUM_INFERENCE_STEPS)),
"guidance_scale": float(inputs.get("guidance_scale", DEFAULT_GUIDANCE_SCALE)),
}

image = inputs.get("image")
if image is not None:
kwargs['image'] = load_image(image)
kwargs['strength'] = float(inputs.get('strength', DEFAULT_STRENGTH))
pipeline = self.img2img_pipe
else:
seed = int(seed)
pipeline = self.txt2img_pipe
kwargs["width"] = int(inputs.get("width", DEFAULT_WIDTH))
kwargs["height"] = int(inputs.get("height", DEFAULT_HEIGHT))

pipeline.scheduler = make_scheduler(
scheduler, pipeline.scheduler.config
)
if disable_safety_check:
negative_prompt = inputs.get("negative_prompt")
if negative_prompt is not None:
kwargs["negative_prompt"] = [negative_prompt] * num_outputs

scheduler = inputs.get("scheduler", DEFAULT_SCHEDULER)
pipeline.scheduler = SCHEDULERS[scheduler].from_config(pipeline.scheduler.config)

if bool(inputs.get("disable_safety_check", False)):
pipeline.safety_checker = None
else:
pipeline.safety_checker = self.safety_checker

seed = int(inputs.get("seed", int.from_bytes(os.urandom(2), "big")))
generator = torch.Generator("cuda").manual_seed(seed)
output = pipeline(
prompt=[prompt] * num_outputs
if prompt is not None
else None,
negative_prompt=[negative_prompt] * num_outputs
if negative_prompt is not None
else None,
guidance_scale=guidance_scale,
generator=generator,
num_inference_steps=num_inference_steps,
width=width,
height=height,
**kwargs,
)

for i, image in enumerate(output.images):
Expand Down Expand Up @@ -196,14 +220,3 @@ def predict(
print(file_path)
results.append(file_path)
return results


def make_scheduler(name, config):
return {
"PNDM": PNDMScheduler.from_config(config),
"KLMS": LMSDiscreteScheduler.from_config(config),
"DDIM": DDIMScheduler.from_config(config),
"K_EULER": EulerDiscreteScheduler.from_config(config),
"K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler.from_config(config),
"DPMSolverMultistep": DPMSolverMultistepScheduler.from_config(config),
}[name]
60 changes: 60 additions & 0 deletions samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import base64
import requests
import sys
import json


def gen(output_fn, **kwargs):
print("Generating", output_fn)
url = "http://localhost:5000/predictions"
response = requests.post(url, json={"input": kwargs})
data = response.json()

try:
for idx, datauri in enumerate(data["output"]):
base64_encoded_data = datauri.split(",")[1]
data = base64.b64decode(base64_encoded_data)
with open(f"{idx}-{output_fn}", "wb") as f:
f.write(data)
except:
print("Error!")
print("input:", kwargs)
print(data["logs"])
sys.exit(1)


def main():
gen(
"sample.batch.png",
images=json.dumps(
[
{
"name": "txt2img",
"inputs": {
"prompt": "a macro photograph of male bfirsh black arts movement magic realism funk art by Gino Severini, Ric Estradamirror shades, ray - tracing, sexy gaze, one light",
"negative_prompt": "childish, poorly drawn, ugly",
"num_outputs": 4,
"scheduler": "PNDM",
"disable_safety_check": False,
"seed": 13510,
},
},
{
"name": "img2img",
"inputs": {
"prompt": "a macro photograph of male bfirsh black arts movement magic realism funk art by Gino Severini, Ric Estradamirror shades, ray - tracing, sexy gaze, one light",
"negative_prompt": "childish, poorly drawn, ugly",
"image": "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/human_512x512.png",
"scheduler": "KerrasDPM",
"disable_safety_check": False,
"seed": 42,
},
},
]
),
weights="https://replicate.delivery/pbxt/BxsckHvjQWpyGZL7vj2nn7N8lLi7ATfGVY7YtIErz4utTCBIA/output.zip",
)


if __name__ == "__main__":
main()

0 comments on commit b636ff4

Please sign in to comment.