This notebook tests new code in diffusers: `ControlNetXSModel.__init__`.

### Create ControlNet-XS

In [1]:
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.controlnetxs import ControlNetXSModel

In [2]:
# Build the ControlNet-XS model under test via its `create_as_in_paper` factory.
# NOTE(review): presumably this uses the configuration described in the paper -- confirm against the diffusers source.
cnxs = ControlNetXSModel.create_as_in_paper()

In [3]:
# Show the classes of the two sub-models held by the ControlNet-XS wrapper.
tuple(type(submodel) for submodel in (cnxs.base_model, cnxs.control_model))

(diffusers.models.unet_2d_condition.UNet2DConditionModel,
 diffusers.models.unet_2d_condition.UNet2DConditionModel)

### Prepare input

I will need to use some preprocessing functions from `StableDiffusionControlNetPipeline`, so let's create an instance.

(We need a `StableDiffusionControlNetPipeline` instead of a regular `StableDiffusionPipeline` because it has a `prepare_image` method)


In [4]:
import cv2
import torch
import numpy as np
from einops import repeat
from torch import tensor

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.utils import load_image

In [5]:
# Target device and tensor dtype used for every model and input in this notebook.
device, device_dtype = 'cpu', torch.float32

In [6]:
# Load a Canny ControlNet checkpoint and plug it into a Stable Diffusion v1.5
# ControlNet pipeline. The pipeline is only used for its preprocessing helpers.
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/sd-controlnet-canny',
    torch_dtype=device_dtype,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=device_dtype,
).to(device)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [7]:
def text_enc(txt):
    """Encode `txt` with the pipeline's tokenizer and text encoder.

    Args:
        txt: The text to encode.

    Returns:
        The first element of the text encoder's output for the tokenized text.

    Relies on the module-level `pipe` and `device`.
    """
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    # Tokenize the argument itself rather than a global `prompt`, so the function
    # encodes whatever text it is given and does not depend on notebook state.
    text_input = tokenizer(txt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    return text_encoder(text_input.input_ids.to(device))[0]

In [8]:
class CannyDetector:
    """Callable wrapper that forwards to `cv2.Canny`."""

    def __call__(self, img, low_threshold, high_threshold):
        """Return `cv2.Canny(img, low_threshold, high_threshold)`."""
        edge_map = cv2.Canny(img, low_threshold, high_threshold)
        return edge_map

def get_canny_edges(image, size=512, threshold=(50, 200)):
    """Run the Canny detector on `image` and rescale the result by 1/255.

    Args:
        image: Anything `np.asarray` can convert; it is cast to `uint8` first.
        size: Currently unused; kept so existing callers keep working.
        threshold: Thresholds unpacked into the `CannyDetector` call.

    Returns:
        The detector output divided by 255 (same spatial size as the input;
        presumably in [0, 1] since Canny emits 8-bit values -- confirm).
    """
    detector = CannyDetector()
    pixels = np.asarray(image).astype(np.uint8)
    raw_edges = detector(pixels, *threshold)  # edge map at the input's resolution
    return raw_edges / 255.

In [9]:
def prepare_input(prompt, image, cnxs):
    """Build the inputs for one ControlNet-XS forward pass (first denoising step only).

    The steps follow the ControlNet pipeline's preprocessing: encode the prompt,
    preprocess the conditioning image to obtain the working resolution, set up the
    scheduler, sample initial latents, and compute a Canny-edge hint.

    Args:
        prompt: Text prompt (a single string, so the batch size is 1).
        image: Conditioning image. It is preprocessed with `pipe.prepare_image`
            to determine height/width, and its Canny edges are used as the hint.
        cnxs: ControlNet-XS model; only `cnxs.base_model.config.in_channels` is read.

    Returns:
        Tuple `(x, t, c, context, hint)`:
            x: scaled latents, duplicated along the batch axis for classifier-free guidance.
            t: the first scheduler timestep.
            c: negative and positive prompt embeddings concatenated (negative first).
            context: an empty dict.
            hint: Canny edges repeated to 3 channels and to the batch size of `x`,
                moved to `device` with dtype `device_dtype`.

    Relies on the module-level `pipe`, `controlnet`, `device` and `device_dtype`.
    """
    guidance_scale = 7.5
    do_classifier_free_guidance = guidance_scale > 1.0

    batch_size = 1  # because prompt is a single string
    num_images_per_prompt = 1
    num_inference_steps = 50

    # Encode the prompt; with classifier-free guidance the negative embeddings go first.
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    )
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # The latent channel count comes from our own base UNet, not the pipeline's.
    num_channels_latents = cnxs.base_model.config.in_channels

    # Preprocess the conditioning image. The result is kept under its own name so
    # the original `image` stays available for edge detection below.
    control_image = pipe.prepare_image(
        image=image,
        width=None,
        height=None,
        batch_size=batch_size * num_images_per_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        dtype=controlnet.dtype,
        do_classifier_free_guidance=do_classifier_free_guidance,
        guess_mode=False,
    )
    height, width = control_image.shape[-2:]

    # Prepare timesteps. NOTE: We only do 1 step for testing.
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)
    t = pipe.scheduler.timesteps[0]

    # Sample the initial latents (no generator and no precomputed latents are passed).
    latents = pipe.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        None,  # generator
        None,  # latents
    )

    # Expand the latents if we are doing classifier-free guidance.
    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

    # Compute the hint from the ORIGINAL image: the preprocessed tensor is a batched
    # float tensor and cannot yield the 2-D `h w` edge map that `repeat` expects.
    # NOTE(review): assumes the original image already has the (height, width)
    # chosen by `prepare_image` -- confirm for images whose size is not preserved.
    edges = get_canny_edges(image)
    num_samples = latent_model_input.shape[0]  # keep the hint batch in sync with the latents
    edges = repeat(tensor(edges), 'h w -> b c h w', b=num_samples, c=3)

    # x,t,c,context,hint
    return latent_model_input, t, prompt_embeds, {}, edges.to(device, dtype=device_dtype)

In [10]:
# Text prompt passed to `prepare_input` below.
prompt = 'A turtle'

In [11]:
# Fetch the reference image from the Hugging Face Hub; on any failure fall back
# to a local copy.
try:
    original_image = load_image('https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png')
except Exception:
    # NOTE(review): machine-specific absolute path; it only exists on the author's machine.
    original_image = load_image('/Users/umer/Desktop/input_image_vermeer.png')
image = original_image

In [12]:
# Unpack the five model inputs: latents, timestep, prompt embeddings, context, hint.
x, t, c, context, hint = prepare_input(prompt=prompt, image=image, cnxs=cnxs)

### Run ControlNet-XS

In [13]:
# Single forward pass of ControlNet-XS on the inputs returned by `prepare_input`.
result = cnxs(x,t,c,context,hint)

In [16]:
# Inspect the output shape (the recorded run shows torch.Size([2, 4, 64, 64])).
result.shape

torch.Size([2, 4, 64, 64])