This tests if `forward` still works after making ControlNetXSModel savable via `register_to_config` & `save_pretrained`

In [1]:
import torch

In [2]:
device = 'mps'
device_dtype = torch.float32

## Load the model

In [3]:
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline

In [4]:
model = "stabilityai/stable-diffusion-xl-base-1.0"
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=device_dtype)

In [5]:
sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(model, vae=vae, torch_dtype=device_dtype).to(device)
sdxl_unet = sdxl_pipe.unet

ConnectionError: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/stabilityai/stable-diffusion-xl-base-1.0 (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x2a976c850>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))"), '(Request ID: c473dcf0-99c1-4187-ab98-debc018b7195)')

In [None]:
from diffusers.models.controlnetxs import ControlNetXSModel

In [None]:
cnxs = ControlNetXSModel.from_pretrained('weights/cnxs').to(device)

In [None]:
cnxs.base_model = sdxl_unet

___

## Prepare the input to it

(Later, this will be done by the corresponding pipeline)

In [None]:
prompt = 'A turtle'

In [None]:
from diffusers.utils import load_image

In [None]:
# Try the hosted copy of the input image first; on any failure fall back to a
# copy on the local disk (hard-coded path, only valid on the author's machine).
try:
    original_image = load_image('https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png')
except Exception:
    original_image = load_image('/Users/umer/Desktop/input_image_vermeer.png')

# `image` gets overwritten by later cells, `original_image` keeps the untouched input.
image = original_image

I will need to use some preprocessing functions from `StableDiffusionControlNetPipeline`, so let's create an instance

In [None]:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

(We need a `StableDiffusionControlNetPipeline` instead of a regular `StableDiffusionPipeline` because it has a `prepare_image` method)

In [None]:
controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-canny', torch_dtype=device_dtype)
pipe = StableDiffusionControlNetPipeline.from_pretrained('runwayml/stable-diffusion-v1-5', controlnet=controlnet, torch_dtype=device_dtype).to(device)

In [None]:
guidance_scale = 7.5
do_classifier_free_guidance = guidance_scale > 1.0

In [None]:
# 2. Define call parameters
batch_size = 1 # because prompt is a single string
num_images_per_prompt  = 1

In [None]:
# `encode_prompt` returns four values; only the first two (the prompt embeddings
# and the negative prompt embeddings) are used here, the other two are discarded.
prompt_embeds, negative_prompt_embeds, _, _ = sdxl_pipe.encode_prompt(prompt)
# Concatenate with the negative embeddings first, then the prompt embeddings
# (torch.cat joins along dim 0 by default).
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds.shape

In [None]:
# 6. Prepare latent variables
num_channels_latents = cnxs.base_model.config.in_channels # we're using our unet here!
num_channels_latents

In [None]:
# Default values for prepare_image
height, width = None, None
generator = None
latents = None
guess_mode = False

In [None]:
# 4. Prepare image
image = pipe.prepare_image(
    image=image,
    width=width,
    height=height,
    batch_size=batch_size * num_images_per_prompt,
    num_images_per_prompt=num_images_per_prompt,
    device=device,
    dtype=controlnet.dtype,
    do_classifier_free_guidance=do_classifier_free_guidance,
    guess_mode=guess_mode,
)
height, width = image.shape[-2:]
height, width, image.shape

In [None]:
# Default values for set_timesteps
num_inference_steps = 50

In [None]:
# 5. Prepare timesteps
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps

In [None]:
# 6. Prepare latent variables
latents = pipe.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)

In [None]:
latents.shape

In [None]:
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order

i,t = 0, timesteps[0] # NOTE: We only do 1 step for testing

In [None]:
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

In [None]:
# controlnet(s) inference
# guess_mode == False
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds

In [None]:
latent_model_input.shape, t, prompt_embeds.shape

We still need the hint

In [None]:
hint_image = original_image

In [None]:
import cv2
class CannyDetector:
    """Callable object that forwards its arguments to ``cv2.Canny``."""

    def __call__(self, img, low_threshold, high_threshold):
        """Run ``cv2.Canny`` on ``img`` with the two thresholds and return its result."""
        edge_map = cv2.Canny(img, low_threshold, high_threshold)
        return edge_map

In [None]:
import numpy as np
def get_canny_edges(image, size=512, threshold=(50, 200)):
    """Compute a Canny edge map for ``image`` and divide it by 255.

    Args:
        image: anything ``np.array`` can convert; it is cast to ``uint8``
            before edge detection.
        size: accepted but not used by this function.
        threshold: unpacked positionally into ``CannyDetector.__call__`` after
            the image (by default the low and high thresholds).

    Returns:
        The detector output divided by ``255.`` (presumably mapping the edge
        map into [0, 1] -- depends on what ``cv2.Canny`` returns).
    """
    pixels = np.array(image).astype(np.uint8)
    detector = CannyDetector()
    # The detector output is not resized here.
    raw_edges = detector(pixels, *threshold)
    return raw_edges / 255.

In [None]:
import matplotlib.pyplot as plt

In [None]:
edges = get_canny_edges(hint_image)

In [None]:
edges.shape

In [None]:
plt.imshow(edges);

In [None]:
num_samples=2

from einops import repeat
edges = repeat(torch.tensor(edges), 'h w -> b c h w', b=num_samples, c=3)

In [None]:
image.shape, edges.shape

In [None]:
x = latent_model_input
t = t
c = {}
hint = edges.to(device, dtype=device_dtype)
no_control = False

## Run the model!

In [None]:
cnxs.debug = True
result = cnxs(x, t, prompt_embeds, c, hint)

In [None]:
result.sample.shape

**Running 1 step works!**

In [None]:
cnxs.debug = False

___

## Let's now run the denoising loop with `no_control=True`

Let's fix the random seed so we get the same results as the paper. (The paper uses `pytorch_lightning.utilities.seed.seed_everything`, which doesn't exist anymore.)

In [None]:
import random
import numpy as np
import torch

def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    The three generators are seeded in that order with the same value.
    Returns ``None``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

RANDOM_SEED_IN_PAPER = 1999158951

Some prep to better inspect the denoising process

In [None]:
def save_latents(i,t,lat,lats):
    """Record one denoising step in ``lats``; intended as a pipeline callback.

    Args:
        i: step index; ``0`` marks the start of a new run.
        t: timestep value passed by the caller.
        lat: latents of this step.
        lats: list that collects ``(i, t, lat)`` tuples. It is mutated in place.

    Returns:
        None.
    """
    # Reset the caller's list at the start of a run. This must happen in place:
    # assigning a new list to ``lats`` would only rebind the local name, so the
    # caller's list would keep stale entries and never receive the step-0 entry.
    if i == 0:
        lats.clear()
    lats.append((i,t,lat))

from PIL import Image, ImageOps
from tqdm.notebook import tqdm
from functools import partial

def lat2img(lat, resize_to=None, output_type='pil'):
    """Decode latents with the VAE of the global ``cnxs_pipe`` and post-process them.

    Args:
        lat: latents; they are divided by the VAE's ``scaling_factor`` before decoding.
        resize_to: if given and ``output_type == 'pil'``, every image is resized
            to this size. For any other ``output_type`` a notice is printed and
            the images are returned unresized.
        output_type: forwarded to ``cnxs_pipe.image_processor.postprocess``.

    Returns:
        Whatever ``postprocess`` returns, or a list of resized images.
    """
    vae = cnxs_pipe.vae
    with torch.no_grad():
        decoded = vae.decode(lat / vae.config.scaling_factor, return_dict=False)[0]
        ims = cnxs_pipe.image_processor.postprocess(decoded, output_type=output_type)

    if resize_to is None:
        return ims
    if output_type == 'pil':
        return [im.resize(resize_to) for im in ims]
    print(f'Not resizing as output_type = {output_type} requested')
    return ims

def only_lat(o):
    """Return the last element of ``o`` when it is a tuple, otherwise ``o`` itself."""
    if not isinstance(o, tuple):
        return o
    return o[-1]
def lats2imgs(lats, resize_to=None, output_type='pil',pbar=True):
    """Decode every entry of ``lats`` with ``lat2img``.

    Args:
        lats: iterable of latents or of tuples whose last element is the latent
            (see ``only_lat``).
        resize_to: forwarded to ``lat2img``.
        output_type: forwarded to ``lat2img``; for ``'pt'`` each result is
            moved to the CPU via ``.cpu()`` afterwards.
        pbar: wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        A list with one ``lat2img`` result per entry.
    """
    source = tqdm(lats) if pbar else lats

    decoded = []
    for entry in source:
        decoded.append(lat2img(only_lat(entry), resize_to, output_type))

    if output_type != 'pt':
        return decoded
    return [im.cpu() for im in decoded]

real_idx = None
def plot_latents_to_pil_grid(lats, every=5, cols=7, im_size=(300, 300), pbar=True, border=2, return_ims=True, output_type='pil'):
    """Decode a subset of recorded latents and paste them into one grid image.

    Args:
        lats: sequence of ``(i, t, latent)`` tuples.
        every: keep the entries whose ``i`` is a multiple of this value; the
            entry with ``i == len(lats) - 1`` is always kept as well.
        cols: number of columns in the grid.
        im_size: size passed to ``lat2img`` as ``resize_to``.
        pbar: wrap decoding in a ``tqdm`` progress bar.
        border: width of the black frame added around each image; also the
            amount added to ``im_size`` to obtain the grid cell size.
        return_ims: when true, also return the decoded (unframed) images.
        output_type: forwarded to ``lat2img``. The framing and pasting steps use
            PIL, so presumably only ``'pil'`` works end to end -- TODO confirm.

    Returns:
        ``(grid_image, ims)`` if ``return_ims`` else just ``grid_image``.

    Side effects:
        Rebinds the module-level ``real_idx`` to a callable mapping a grid
        position ``o`` to ``min(len(lats) - 1, every * o)``.
    """
    # ImageDraw is required for the background lines but is not among the
    # names imported next to Image/ImageOps, so import it locally.
    from PIL import ImageDraw

    global real_idx

    total = len(lats)
    real_idx = partial(lambda o,every,total: min(total-1,every*o), every=every, total=total)

    selected = [lat for i, _, lat in lats if i % every == 0 or i == total - 1]
    if pbar: selected = tqdm(selected)
    ims = [lat2img(lat, resize_to=im_size, output_type=output_type)[0] for lat in selected]
    # Honour the `border` argument instead of a hard-coded width.
    ims_bordered = [ImageOps.expand(im, border=border, fill='black') for im in ims]
    # NOTE(review): a framed image is 2*border larger than im_size per axis while
    # a cell is only `border` larger, so neighbouring frames overlap -- confirm
    # this is intended.
    cell_w, cell_h = im_size[0] + border, im_size[1] + border

    rows = -(-len(ims) // cols)  # ceiling division
    grid_w, grid_h = cols * cell_w, rows * cell_h

    grid_image = Image.new('RGB', (grid_w, grid_h), color='grey')
    # Draw diagonal white lines so empty cells are distinguishable from images.
    draw = ImageDraw.Draw(grid_image)
    for xy in range(0, 2 * max(grid_w, grid_h) + 1, 100):
        draw.line([(xy, 0), (0, xy)], fill="white", width=1)

    for idx, img in enumerate(ims_bordered):
        row, col = divmod(idx, cols)
        grid_image.paste(img, (col * cell_w, row * cell_h))

    if return_ims:
        return grid_image, ims
    return grid_image

Okay, let's go

In [None]:
prompt = 'cinematic, shoe in the streets, made from meat, photorealistic shoe, highly detailed'
neg_prompt = 'lowres, bad anatomy, worst quality, low quality'

In [None]:
image = load_image('input_images/shoe.png')
edges = get_canny_edges(image)

In [None]:
plt.imshow(edges)

In [None]:
cnxs.no_control = True

In [None]:
from diffusers.pipelines.controlnet_xs.pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline

In [None]:
cnxs_pipe = StableDiffusionXLControlNetXSPipeline(
    vae=sdxl_pipe.vae,
    text_encoder=sdxl_pipe.text_encoder,
    text_encoder_2=sdxl_pipe.text_encoder_2,
    tokenizer=sdxl_pipe.tokenizer,
    tokenizer_2=sdxl_pipe.tokenizer_2,
    unet=sdxl_pipe.unet,
    controlnet=cnxs,
    scheduler=sdxl_pipe.scheduler,
)

In [None]:
# takes ~3min to run on cpu / ~1.5min on mps
seed_everything(RANDOM_SEED_IN_PAPER)
lats_sdxl = []
result = cnxs_pipe(prompt, negative_prompt=neg_prompt,image=edges, callback=partial(save_latents, lats=lats_sdxl))

In [None]:
type(result), type(result.images[0])

In [None]:
result.images[0].resize((500,500))

**It works!**

In [None]:
grid, ims_sdxl = plot_latents_to_pil_grid(lats_sdxl)
grid

## Let's now run the denoising loop with control

In [None]:
cnxs.no_control = False

In [None]:
# RuntimeError: Given ... expected input [2, 1, 512, 512] to have 3 channels, but got 1 channels instead
# -> hint has 1 channel, but should have 3
# let'cheat
# Let's cheat: stack three copies of the edge map along a new leading dimension
# (torch.stack's default dim=0) so the hint has 3 channels.
edges_tensor = torch.tensor(edges)
three_edges = torch.stack((edges_tensor,edges_tensor,edges_tensor))
three_edges.shape

In [None]:
# # Produces shape mismatch at h_ctrl += guided_hint of [2, 32, 64, 64] != [4, 32, 64, 64]
# # As cnxs.forward did run, I assume the error is due to negative prompting. Maybe it doubles the inputs from 2 -> 4
# # (Although it shouldn't. We should have 3 inputs: uncond, cond, neg.)
# Re-seed so this run starts from the same RNG state as the earlier run.
seed_everything(RANDOM_SEED_IN_PAPER)
# Collects the (i, t, latents) tuples appended by the `save_latents` callback.
lats_cnxs = []
result_controlled = cnxs_pipe(prompt, negative_prompt=neg_prompt,image=three_edges, callback=partial(save_latents, lats=lats_cnxs))

In [None]:
grid, ims_cnxs = plot_latents_to_pil_grid(lats_cnxs, every=1, cols=10, im_size=(200,200)) 
grid

**1.** In the 2nd (=10th) image I can see the contour of the shoe as in the guidance image. But over time, the image gets destroyed.<br/>
**2.** The 1st (=5th) image looks extremely weird.

**Q:** Why is the 2nd image so dark? Are the average predicted noises of sdxl and cnxs different?

**Hypothesis re 1:** The image vanishes when using cnxs, because it gets pushed outside the (0,1) boundary. This might be because I need to rescale the noise predicted by cnxs (which I currently don't do).

In [None]:
def percentage_outside_range(x, lo=0,hi=1):
    """Return the fraction of the values in ``x`` that lie below ``lo`` or above ``hi``.

    ``x`` is converted with ``np.array`` and flattened first. Despite the name,
    the result is a fraction in [0, 1] (for ``lo <= hi``), not a percentage.
    """
    values = np.array(x).flatten()
    too_low = np.sum(values < lo)
    too_high = np.sum(values > hi)
    return (too_low + too_high) / len(values)

for inp,res in (([3],1),([1],0),([[0,0],[-0.5,2]],0.5)): assert percentage_outside_range(inp)==res

Because the conversion to PIL clamps the images to `(0,255)` (which corresponds to `(0,1)` in numpy), we need to run `lats2imgs` again with `output_type='pt'`

In [None]:
ims_cnxs_pt = lats2imgs(lats_cnxs,output_type='pt')

for i,im in enumerate(ims_cnxs_pt):
    p = percentage_outside_range(im)
    print(f'Step {real_idx(i)}] {p:.2f} of values outside range')

**A:** No, the values are all in `(0,1)`

Hmm, but at least image 0 had values < 0! I saw that yesterday. Where are they?