Commit

working on multi-controlnet at once
anotherjesse committed Jun 14, 2023
1 parent 1194962 commit d7a870e
Showing 3 changed files with 142 additions and 100 deletions.
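The change replaces the single "structure" input (one of canny / depth / hed / hough / normal / pose / scribble / seg) with a separate *_image and *_conditioning_scale input per ControlNet, so several ControlNets can condition the same generation at once. Under the hood this leans on diffusers' multi-ControlNet support: the pipeline accepts a list of ControlNetModels, one control image per model, and a list of per-model conditioning scales. A minimal sketch of that general pattern, independent of this repo (model IDs and file names below are illustrative, not taken from the commit):

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
    from diffusers.utils import load_image

    # Two ControlNets conditioning one generation (illustrative model IDs).
    canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    mlsd = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd", torch_dtype=torch.float16)

    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=[canny, mlsd],  # passing a list selects the multi-ControlNet path
        torch_dtype=torch.float16,
    ).to("cuda")

    out = pipe(
        "taylor swift in a mid century modern bedroom",
        [load_image("canny_edges.png"), load_image("mlsd_lines.png")],  # one control image per controlnet
        controlnet_conditioning_scale=[0.9, 0.6],  # per-controlnet weights
        num_inference_steps=30,
    ).images[0]
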
4 changes: 2 additions & 2 deletions cog.yaml
@@ -5,10 +5,10 @@ build:
cuda: "11.6"
python_version: "3.9"
python_packages:
- "diffusers==0.15.1"
- "diffusers==0.17.1"
- "torch==2.0.0"
- "opencv-contrib-python-headless==4.6.0.66"
- "controlnet-aux==0.0.3"
- "controlnet-aux==0.0.5"
- "transformers==4.29.1"
- "xformers==0.0.19"
- "accelerate==0.19.0"
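The cog.yaml hunk bumps diffusers from 0.15.1 to 0.17.1 and controlnet-aux from 0.0.3 to 0.0.5. The newer diffusers is what lets predict.py pass a plain list as controlnet=: the pipeline wraps a list or tuple of ControlNetModels into its multi-ControlNet container, while controlnet-aux supplies the detectors used as pre-processors (HED, MLSD, Midas, OpenPose). A quick import check, as a sketch; the exact module path for MultiControlNetModel is an assumption about this diffusers version:

    # Sanity-check that the pinned packages expose what the new predict.py relies on.
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
    from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel  # wraps a list of controlnets
    from controlnet_aux import HEDdetector, MLSDdetector, MidasDetector, OpenposeDetector

    print(ControlNetModel, MultiControlNetModel, HEDdetector, OpenposeDetector)
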
176 changes: 125 additions & 51 deletions predict.py
@@ -128,21 +128,61 @@ def setup(self):
@torch.inference_mode()
def predict(
self,
image: Path = Input(description="Input image"),
prompt: str = Input(description="Prompt for the model"),
# FIXME: support multiple structures by having inputs canny_image, depth_image, ...
structure: str = Input(
description="Structure to condition on",
choices=[
"canny",
"depth",
"hed",
"hough", # FIXME(ja): why do we call it hough when the controlnet is called mlsd: https://huggingface.co/lllyasviel/sd-controlnet-mlsd
"normal",
"pose",
"scribble",
"seg",
],
canny_image: Path = Input(
description="Control image for canny controlnet", default=None
),
canny_conditioning_scale: float = Input(
description="Conditioning scale for canny controlnet",
default=1,
),
depth_image: Path = Input(
description="Control image for depth controlnet", default=None
),
depth_conditioning_scale: float = Input(
description="Conditioning scale for depth controlnet",
default=1,
),
hed_image: Path = Input(
description="Control image for hed controlnet", default=None
),
hed_conditioning_scale: float = Input(
description="Conditioning scale for hed controlnet", default=1
),
hough_image: Path = Input(
description="Control image for hough controlnet", default=None
),
hough_conditioning_scale: float = Input(
description="Conditioning scale for hough controlnet",
default=1,
),
normal_image: Path = Input(
description="Control image for normal controlnet", default=None
),
normal_conditioning_scale: float = Input(
description="Conditioning scale for normal controlnet",
default=1,
),
pose_image: Path = Input(
description="Control image for pose controlnet", default=None
),
pose_conditioning_scale: float = Input(
description="Conditioning scale for pose controlnet",
default=1,
),
scribble_image: Path = Input(
description="Control image for scribble controlnet", default=None
),
scribble_conditioning_scale: float = Input(
description="Conditioning scale for scribble controlnet",
default=1,
),
seg_image: Path = Input(
description="Control image for seg controlnet", default=None
),
seg_conditioning_scale: float = Input(
description="Conditioning scale for seg controlnet",
default=1,
),
num_samples: int = Input(
description="Number of samples (higher values may OOM)",
@@ -194,38 +234,42 @@ def predict(
if len(MISSING_WEIGHTS) > 0:
raise Exception("missing weights")

pipe = self.select_pipe(structure)
pipe, processed_control_images, controlnet_conditioning_scale = self.build_pipe(
{
"canny": [canny_image, canny_conditioning_scale],
"depth": [depth_image, depth_conditioning_scale],
"hed": [hed_image, hed_conditioning_scale],
"hough": [hough_image, hough_conditioning_scale],
"normal": [normal_image, normal_conditioning_scale],
"pose": [pose_image, pose_conditioning_scale],
"scribble": [scribble_image, scribble_conditioning_scale],
"seg": [seg_image, seg_conditioning_scale],
},
low_threshold=low_threshold,
high_threshold=high_threshold,
)
pipe.enable_xformers_memory_efficient_attention()
pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)

if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")

# Load input_image
input_image = Image.open(image)
input_image = self.process_image(
input_image,
structure,
low_threshold=low_threshold,
high_threshold=high_threshold,
)
scale = float(image_resolution) / (min(processed_control_images[0].size))

scale = float(image_resolution) / (min(input_image.size))

def quick_rescale(dim, scale):
"""quick rescale to a multiple of 64, as per original controlnet"""
dim *= scale
return int(np.round(dim / 64.0)) * 64
width = quick_rescale(input_image.size[0], scale)
height = quick_rescale(input_image.size[1], scale)

width = quick_rescale(processed_control_images[0].size[0], scale)
height = quick_rescale(processed_control_images[0].size[1], scale)

generator = torch.Generator("cuda").manual_seed(seed)

outputs = pipe(
prompt,
input_image,
processed_control_images,
height=height,
width=width,
num_inference_steps=steps,
@@ -234,6 +278,7 @@ def quick_rescale(dim, scale):
negative_prompt=negative_prompt,
num_images_per_prompt=num_samples,
generator=generator,
controlnet_conditioning_scale=controlnet_conditioning_scale,
)
output_paths = []
for i, sample in enumerate(outputs.images):
@@ -242,36 +287,65 @@ def quick_rescale(dim, scale):
output_paths.append(Path(output_path))
return output_paths

def select_pipe(self, structure):
return StableDiffusionControlNetPipeline(
def build_pipe(self, inputs, low_threshold=100, high_threshold=200):
control_nets = []
processed_control_images = []
conditioning_scales = []

if inputs["canny"][0] is not None:
control_nets.append(self.controlnets["canny"])
img = Image.open(inputs["canny"][0])
processed_control_images.append(
self.canny(img, low_threshold, high_threshold)
)
conditioning_scales.append(inputs["canny"][1])
if inputs["depth"][0] is not None:
control_nets.append(self.controlnets["depth"])
img = Image.open(inputs["depth"][0])
processed_control_images.append(self.midas(img))
conditioning_scales.append(inputs["depth"][1])
if inputs["hed"][0] is not None:
control_nets.append(self.controlnets["hed"])
img = Image.open(inputs["hed"][0])
processed_control_images.append(self.hed(img))
conditioning_scales.append(inputs["hed"][1])
if inputs["hough"][0] is not None:
control_nets.append(self.controlnets["hough"])
img = Image.open(inputs["hough"][0])
processed_control_images.append(self.mlsd(img))
conditioning_scales.append(inputs["hough"][1])
if inputs["normal"][0] is not None:
control_nets.append(self.controlnets["normal"])
img = Image.open(inputs["normal"][0])
processed_control_images.append(self.midas(img, depth_and_normal=True)[1])
conditioning_scales.append(inputs["normal"][1])
if inputs["pose"][0] is not None:
control_nets.append(self.controlnets["pose"])
img = Image.open(inputs["pose"][0])
processed_control_images.append(self.pose(img))
conditioning_scales.append(inputs["pose"][1])
if inputs["scribble"][0] is not None:
control_nets.append(self.controlnets["scribble"])
img = Image.open(inputs["scribble"][0])
processed_control_images.append(self.hed(img, scribble=True))
conditioning_scales.append(inputs["scribble"][1])
if inputs["seg"][0] is not None:
control_nets.append(self.controlnets["seg"])
img = Image.open(inputs["seg"][0])
processed_control_images.append(self.seg_preprocessor(img))
conditioning_scales.append(inputs["seg"][1])

pipe = StableDiffusionControlNetPipeline(
vae=self.pipe.vae,
text_encoder=self.pipe.text_encoder,
tokenizer=self.pipe.tokenizer,
unet=self.pipe.unet,
scheduler=self.pipe.scheduler,
safety_checker=self.pipe.safety_checker,
feature_extractor=self.pipe.feature_extractor,
controlnet=self.controlnets[structure],
controlnet=control_nets,
)

def process_image(self, image, structure, low_threshold=100, high_threshold=200):
if structure == "canny":
input_image = self.canny(image, low_threshold, high_threshold)
elif structure == "depth":
input_image = self.midas(image)
elif structure == "hed":
input_image = self.hed(image)
elif structure == "hough":
input_image = self.mlsd(image)
elif structure == "normal":
input_image = self.midas(image, depth_and_normal=True)[1]
elif structure == "pose":
input_image = self.pose(image)
elif structure == "scribble":
input_image = self.hed(image, scribble=True)
elif structure == "seg":
input_image = self.seg_preprocessor(image)
return input_image
return pipe, processed_control_images, conditioning_scales

def seg_preprocessor(self, image):
image = image.convert("RGB")
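Note that output sizing is still derived from a single image: scale maps the short side of the first processed control image to image_resolution, and quick_rescale snaps both dimensions to multiples of 64, so multiple control images are implicitly assumed to share an aspect ratio. A worked example of that arithmetic, reusing quick_rescale from the diff above:

    import numpy as np

    def quick_rescale(dim, scale):
        """quick rescale to a multiple of 64, as per original controlnet"""
        dim *= scale
        return int(np.round(dim / 64.0)) * 64

    # A 600x400 control image with image_resolution=512:
    scale = 512.0 / min(600, 400)      # 1.28
    print(quick_rescale(600, scale))   # 768
    print(quick_rescale(400, scale))   # 512
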
62 changes: 15 additions & 47 deletions samples.py
@@ -29,67 +29,35 @@ def gen(output_fn, **kwargs):


def main():
gen(
"sample.depth.png",
structure="depth",
prompt="taylor swift, best quality, extremely detailed",
image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
seed=42,
steps=30,
)
gen(
"sample.canny.png",
prompt="taylor swift, best quality, extremely detailed",
structure="canny",
image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
seed=42,
steps=20,
)
gen(
"sample.normal.png",
structure="normal",
prompt="taylor swift, best quality, extremely detailed",
image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
prompt="taylor swift in a mid century modern bedroom",
canny_image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
seed=42,
steps=30,
)
gen(
"sample.seg.png",
structure="seg",
prompt="mid century modern bedroom",
image="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/room_512x512.png",
seed=42,
steps=30,
)
gen(
"sample.hed.png",
structure="hed",
image="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/bird_512x512.png",
prompt="rainbow bird, best quality, extremely detailed",
seed=42,
steps=30,
)
gen(
"sample.pose.png",
structure="pose",
image="https://hf.co/datasets/YiYiXu/controlnet-testing/resolve/main/yoga1.jpeg",
prompt="farmer yoga pose, best quality, extremely detailed",
"sample.hough.png",
prompt="taylor swift in a mid century modern bedroom",
hough_image="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/room_512x512.png",
seed=42,
steps=30,
)
gen(
"sample.hough.png",
structure="hough",
prompt="mid century modern bedroom",
image="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/room_512x512.png",
"sample.both.png",
prompt="taylor swift in a mid century modern bedroom",
hough_image="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/room_512x512.png",
canny_image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
seed=42,
steps=30,
)
gen(
"sample.scribble.png",
structure="scribble",
prompt="rainbow turtle, psychedelic, best quality, extremely detailed",
image="https://replicate.delivery/pbxt/IJE6zP4jtdwxe7SffC7te9DPHWHW99dMXED5AWamlBNcvxn0/user_1.png",
"sample.scaled.png",
prompt="taylor swift in a mid century modern bedroom",
hough_image="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/room_512x512.png",
hough_conditioning_scale=0.6,
canny_image="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png",
canny_conditioning_scale=0.9,
seed=42,
steps=30,
)
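The rewritten samples exercise a single control (hough), two controls stacked (canny + hough), and two controls with different weights. In the multi-ControlNet path each ControlNet produces its own residuals for the UNet, which are weighted by that controlnet's conditioning scale and then summed; that is why hough_conditioning_scale=0.6 softens the line-segment guidance relative to canny_conditioning_scale=0.9 in the last sample. A rough conceptual sketch of that combination step (simplified, not diffusers' actual MultiControlNetModel code):

    # Conceptual sketch: weight and merge per-controlnet residuals before the UNet sees them.
    def combine_residuals(per_controlnet_residuals, scales):
        """per_controlnet_residuals: one (down_block_residuals, mid_block_residual) pair per ControlNet."""
        total_down, total_mid = None, None
        for (down, mid), scale in zip(per_controlnet_residuals, scales):
            down = [d * scale for d in down]   # weight this controlnet's residuals
            mid = mid * scale
            if total_down is None:
                total_down, total_mid = down, mid
            else:
                total_down = [t + d for t, d in zip(total_down, down)]
                total_mid = total_mid + mid
        return total_down, total_mid  # added into the UNet's skip connections during denoising
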
