In [8]:
# Reload edited project modules (such as `src.helpers`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Callbacks - Visualize and Decode Latents

In [9]:
from typing import List
import torch
from diffusers import StableDiffusionPipeline
import PIL

from src import helpers

The `StableDiffusionPipeline.__call__` function accepts the following parameters for attaching a callback to the inference loop (newer diffusers releases deprecate them in favor of `callback_on_step_end`):

---

```python
callback (`Callable`, *optional*):
    A function that will be called every `callback_steps` steps during inference. The function will be
    called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
    The frequency at which the `callback` function will be called. If not specified, the callback will be
    called at every step.
```

In [10]:
# NOTE(review): the `runwayml` repository has reportedly been removed from the
# Hugging Face Hub; the community mirror is
# "stable-diffusion-v1-5/stable-diffusion-v1-5". Confirm and switch if the
# download fails.
model_id = "runwayml/stable-diffusion-v1-5"

# Choose device and precision together. float16 is only worthwhile on the GPU.
# On CPU it is slow, and depending on the PyTorch version some ops are not
# implemented for Half, so the CPU fallback uses float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
    use_safetensors=True,
)

pipe = pipe.to(device)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [11]:
# One prompt term per line so the prompt is easy to tweak; the terms are
# joined with ", " into the single string the pipeline expects.
# NOTE(review): "himalayn" looks like a typo for "himalayan"; left as-is so the
# generated images stay reproducible.
prompt = ", ".join(
    [
        "flock of sheep are having selfie with a grazing on grassland",
        "himalayan background extra detailed",
        "highly realistic",
        "extra detailed",
        "himalayn landscape",
        "hyper realistic",
    ]
)
# Qualities the pipeline should steer away from.
negative_prompt = ", ".join(
    [
        "low-res",
        "low quality",
        "jpeg artifacts",
        "blurry",
        "grainy",
        "distorted",
        "ugly",
        "out of frame",
        "disfigured",
        "bad anatomy",
        "watermarked",
    ]
)
num_inference_steps = 30

In [12]:
def my_callback(step: int, timestep: int, latents: torch.Tensor) -> None:
    """Print what the pipeline hands to the callback: step index, timestep and latent shape."""
    fields = (
        f"Step: {step}",
        f"Timestep: {timestep}",
        f"Latents: {latents.shape}",
    )
    print(" | ".join(fields))


# Short 5-step run, only to observe when the callback fires and with which arguments.
_ = pipe(
    prompt=prompt,
    num_inference_steps=5,
    callback=my_callback,
    callback_steps=1,
)

  0%|          | 0/5 [00:00<?, ?it/s]

Step: 1 | Timestep: 601 | Latents: torch.Size([1, 4, 64, 64])
Step: 2 | Timestep: 601 | Latents: torch.Size([1, 4, 64, 64])
Step: 3 | Timestep: 401 | Latents: torch.Size([1, 4, 64, 64])
Step: 4 | Timestep: 201 | Latents: torch.Size([1, 4, 64, 64])
Step: 5 | Timestep: 1 | Latents: torch.Size([1, 4, 64, 64])


In [13]:
# Size handed to `helpers.resize` for the decoded image shown in each frame.
decoded_image_resize = 192

# Accumulators filled by the visualization callback, one entry per recorded step.
frames = []  # per step: latent-channel images followed by the decoded image
frame_titles = []  # per step: "Step ... | Timestep ..." title
captions = []  # per step: one caption per image in the frame


def convert_latents_to_imgs(latents: torch.Tensor) -> List[PIL.Image.Image]:
    """
    Render each latent channel as a grayscale image for visualization purposes.

    Note: this does not use the VAE to decode the latents. Each channel is
    only rescaled and passed through the image processor's denormalization,
    so the result is a rough picture of what each latent dimension holds.

    Args:
        latents: Latents as passed to the pipeline callback, assumed to be
            (batch_size, latent_channels, height, width). Only the first
            sample of the batch is visualized.

    Returns:
        One single-channel PIL image per latent channel.
    """
    latents = 1 / pipe.vae.config.scaling_factor * latents
    # Take the first sample and give each latent channel its own size-1
    # channel axis: (latent_channels, 1, height, width). Every channel then
    # becomes one grayscale image. The previous permute(1, 0, 2, 3) put the
    # batch axis into the image-channel slot, which only works for batch 1.
    image_batch = latents[0].unsqueeze(1)
    do_denormalize = [True] * image_batch.shape[0]
    return pipe.image_processor.postprocess(
        image_batch, output_type="pil", do_denormalize=do_denormalize
    )  # type: ignore


@torch.no_grad()
def decode_latents(latents: torch.Tensor) -> List[PIL.Image.Image]:
    """
    Decode latents into PIL images with the pipeline's VAE.

    Runs without gradient tracking so that calling it directly from a cell
    (outside the pipeline call) does not record the VAE forward pass.

    Args:
        latents: Latents as passed to the pipeline callback, assumed to be
            (batch_size, latent_channels, height, width) in the pipeline's
            scaled latent space; the scaling is undone here.

    Returns:
        One PIL image per sample in the batch.
    """
    # Undo the latent scaling before handing the latents to the VAE decoder.
    latents = 1 / pipe.vae.config.scaling_factor * latents
    image_batch = pipe.vae.decode(latents, return_dict=False)[0]
    # image_batch shape: (batch, channels, height, width)
    do_denormalize = [True] * image_batch.shape[0]
    return pipe.image_processor.postprocess(
        image_batch, output_type="pil", do_denormalize=do_denormalize
    )  # type: ignore


def my_callback(step: int, timestep: int, latents: torch.Tensor) -> None:
    """
    Record one animation frame for the current denoising step.

    A frame is one image per latent channel followed by the VAE-decoded image.
    It is appended to the module-level `frames` list, together with its title
    (in `frame_titles`) and its per-image captions (in `captions`).
    """
    # The pipeline accepts width/height arguments, but the model works best at
    # the resolution it was trained on, so the decoded image is resized here
    # instead, which keeps the animation fast.
    decoded = helpers.resize(decode_latents(latents), decoded_image_resize)
    # The raw latent channels, shown without going through the VAE.
    latent_views = convert_latents_to_imgs(latents)

    channel_captions = [f"Latent Dim {idx}" for idx in range(latents.shape[1])]
    frames.append(latent_views + decoded)
    frame_titles.append(f"Step {step} | Timestep {timestep}")
    captions.append(channel_captions + ["Decoded Image"])


# Full-length run; `my_callback` records one frame after every denoising step.
i = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=num_inference_steps,
    # The per-frame captions assume a single sample per step.
    num_images_per_prompt=1,
    callback=my_callback,
    callback_steps=1,
)

# Fail with a clear message instead of an IndexError on `captions[0]` when the
# callback did not record anything.
if not frames:
    raise RuntimeError("No frames were recorded by `my_callback`; nothing to animate.")

# Every frame uses the same image layout, so the first caption list applies to all.
helpers.plot_anim(
    frames,
    frame_titles,
    captions[0],
    interval=500,
    n_rows=3,
    embed_scale=1.25,
    save_fname="anim_diffusion_latents_decode",
)

  0%|          | 0/30 [00:00<?, ?it/s]

Number of frames: 30
Number of images per frame: 5
