# Introduction
**Requirements:**
* This notebook needs to be run on an instance that features second-generation Neuron hardware (NeuronCore-v2). We recommend `inf2.8xlarge` instances or larger.

In this notebook, you will compile and run the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) text-to-image pipeline (Stable Diffusion 1.5) on AWS Neuron hardware to 
generate images of size 512x512 pixels.
To run the pipeline, you will use the dedicated [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline) 
abstraction from HuggingFace's [`diffusers`](https://huggingface.co/docs/diffusers/index) library.

Diffusion models are peculiar in the sense that they do not consist of a single model but of a collection of models and stateless components orchestrated into a pipeline. Running a diffusion pipeline on Neuron therefore 
involves compiling multiple models while ensuring that the compiled models can still be handled by the pipeline.

The [`model_index.json` file](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) 
from the HuggingFace model repository lists the different components of the `runwayml/stable-diffusion-v1-5` pipeline:
* The CLIP text encoder and its tokenizer,
* The U-Net,
* The variational auto-encoder (VAE),
* The safety checker,
* The scheduler (default scheduler is the [Pseudo Numerical Methods for Diffusion Models (PNDM) scheduler](https://huggingface.co/docs/diffusers/api/schedulers/pndm))
* The feature extractor

Not all of these components will be run on Neuron devices. The tokenizer and the feature extractor are typically run on CPU. The scheduler is a stateless 
component and will be run on CPU as well. The other components will be compiled and run on Neuron hardware.

Under the hood, the VAE consists of [4 different model components](https://github.com/huggingface/diffusers/blob/v0.20.0/src/diffusers/models/autoencoder_kl.py#L77C8-L77C8): the encoder, the 
quantization conv layer (`quant_conv`), the post-quantization conv layer (`post_quant_conv`) and the decoder. Looking at the pipeline's [`__call__`](https://github.com/huggingface/diffusers/blob/v0.20.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L524) 
method, you will notice that only the VAE's `decode` method is actually called when performing inference. By studying the VAE's [`decode`](https://github.com/huggingface/diffusers/blob/v0.20.0/src/diffusers/models/autoencoder_kl.py#L265) 
method, you can see that only its post quant conv layer and decoder are used. You therefore only need to compile the VAE's decoder and post quant conv layer.

Similarly, the safety checker consists of [2 different model components](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L40): the vision model 
and the visual projection model.

To sum it up, here is the list of the model components you will compile, attach to the Stable Diffusion pipeline and run on Neuron devices:
* The CLIP text encoder
* The U-Net
* The VAE decoder and post quant conv layer
* The safety checker vision model and visual projection model

## Model compilation
First, you will load the pipeline using the `StableDiffusionPipeline.from_pretrained` API. Each model of interest will then be compiled using the PyTorch Neuron (`torch-neuronx`) 
[Tracing API](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html). Concretely, a given model is passed 
to `torch_neuronx.trace` together with example input tensors. The `trace` function traces the model, i.e. extracts the graph of tensor operations, and converts it into an intermediate representation 
in the HLO format, which is fed to the Neuron Compiler (`neuronx-cc`) for compilation. From the compilation result, `trace` creates a Neuron-specific TorchScript program that 
can either be used directly to run the model on Neuron hardware or be serialized for later use. This TorchScript program is embedded in a [`torch.jit.ScriptModule`](https://pytorch.org/docs/master/generated/torch.jit.ScriptModule.html#torch.jit.ScriptModule) 
object, which `trace` finally returns.

The compiled model includes both the compiled execution code and the model weights. The compiler has a type-casting 
capability which, by default, casts FP32 operations performed on the matrix-multiplication engine to the BF16 type. You will keep this capability enabled and use the FP32 checkpoints to compile the model.
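To make the effect of BF16 downcasting concrete, here is a minimal pure-Python sketch of rounding a float32 value to bfloat16 (round-to-nearest-even on the 16 dropped bits). It only illustrates the number format: the helper `float32_to_bf16` is hypothetical, is not how `neuronx-cc` implements auto-casting, and omits NaN and infinity handling.

```python
import struct


def float32_to_bf16(x: float) -> float:
    """Hypothetical helper: round `x` (as float32) to the nearest bfloat16 value.

    bfloat16 keeps the sign bit, the 8 exponent bits and the 7 most significant
    mantissa bits of a float32. NaN/Inf handling is omitted for brevity.
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, based on the lowest bit that is kept
    lsb = (bits >> 16) & 1
    bits = ((bits + 0x7FFF + lsb) >> 16) << 16
    # Reinterpret the shortened bit pattern as a float32 again
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


print(float32_to_bf16(3.14159265))  # 3.140625
print(float32_to_bf16(1.00390625))  # 1.0 (tie, rounded to even)
```

BF16 keeps the FP32 exponent range but only 8 significant bits, which is why matrix multiplications remain numerically well-behaved while individual values lose precision.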

**Notice:** The Neuron Compiler does not support the complex output types that some HuggingFace models use. For these models, you will modify their `forward` method so that it only returns "simple" types such 
as individual tensors or tuples of tensors. Cf. Section 2. Compile the models.

## Runtime compatibility
For each model to be run on Neuron hardware, its corresponding PyTorch module object in the `StableDiffusionPipeline` is replaced by the `ScriptModule` object created at compile time. A `torch.jit.ScriptModule` 
object implements the same interface as the familiar `torch.nn.Module`. In both cases, model inference is performed by calling the object's `__call__` magic method (which in turn calls its `forward` 
method), i.e. by "calling the object". Using a `ScriptModule` therefore feels like calling a function, with the following runtime peculiarities:
* The `ScriptModule` object only accepts positional arguments. The number of arguments and the shapes of the input tensors must match the argument count and tensor shapes used at compile time.
* The `ScriptModule` object you will produce returns tensors or tuple of tensors. These return types may not match the interface expected by the `StableDiffusionPipeline`.

To ensure runtime compatibility, you may need to modify a `ScriptModule`'s `forward` method accordingly, e.g. by making all arguments positional, ensuring tensor shapes are adequate, discarding irrelevant arguments passed by the `StableDiffusionPipeline`, 
wrapping returned tensors in the expected return types, etc. Cf. Section 3. Load the models.
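The kind of adaptation described above can be sketched in plain Python. The helper `make_pipeline_compatible` below is hypothetical: it takes a positional-only callable standing in for a compiled `ScriptModule`, maps positional or keyword arguments onto the argument order used at trace time, discards any other keyword argument, and wraps the returned tuple into an object with named attributes. Strings stand in for tensors, and the actual wrappers used for each model may differ.

```python
from types import SimpleNamespace
from typing import Any, Callable, Sequence, Tuple


def make_pipeline_compatible(
    compiled_forward: Callable[..., Tuple[Any, ...]],
    arg_names: Sequence[str],
    output_names: Sequence[str],
) -> Callable[..., SimpleNamespace]:
    """Hypothetical adapter around a positional-only callable such as a traced ScriptModule.

    `arg_names` lists the arguments used at trace time, in order; `output_names` names the
    elements of the returned tuple. Keyword arguments not in `arg_names` are discarded.
    """
    def forward(*args: Any, **kwargs: Any) -> SimpleNamespace:
        # Positional arguments fill the traced argument slots first...
        bound = dict(zip(arg_names, args))
        # ...and the remaining slots are looked up by name (KeyError if one is missing)
        for name in arg_names[len(args):]:
            bound[name] = kwargs[name]
        outputs = compiled_forward(*(bound[name] for name in arg_names))
        # Wrap the plain tuple into an object exposing named attributes
        return SimpleNamespace(**dict(zip(output_names, outputs)))

    return forward


def fake_compiled_unet(sample, timestep, encoder_hidden_states):
    # Stand-in for a traced model: positional arguments in, plain tuple out
    return (sample,)


unet_forward = make_pipeline_compatible(
    fake_compiled_unet, ("sample", "timestep", "encoder_hidden_states"), ("sample",)
)
out = unet_forward("latents", "t", encoder_hidden_states="text_embeddings", cross_attention_kwargs=None)
print(out.sample)  # latents
```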

## Dependencies
This tutorial requires the following pip packages to be installed:
- `torch-neuronx`
- `neuronx-cc`
- `diffusers==0.29.2`
- `transformers==4.42.3`
- `accelerate==0.31.0`
- `matplotlib`

At this point, your PyTorch Neuron environment should already be set up, so the latest versions of `torch-neuronx` and `neuronx-cc` should already be installed in your environment (see the [PyTorch Neuron setup guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx) from the docs). The remaining dependencies can be installed by running the cell below:

In [None]:
%pip install diffusers==0.29.2 transformers==4.42.3 accelerate==0.31.0 matplotlib Pillow --upgrade --quiet

## Imports

In [None]:
import copy
import datetime
import os
from pathlib import Path
import shutil
from typing import Any, Callable, Dict, Optional, Tuple

from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.jit
import torch.nn
import torch_neuronx
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPVisionModel

try:
    from neuronxcc.nki._private_kernels.attention import attention_isa_kernel  # noqa: E402
except ImportError:
    from neuronxcc.nki.kernels.attention import attention_isa_kernel  # noqa: E402
from torch_neuronx.xla_impl.ops import nki_jit  # noqa: E402
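import diffusers
import math
import torch.nn.functional as F

# Wrap the NKI attention kernel with nki_jit so that it can be invoked from the traced PyTorch code
_flash_fwd_call = nki_jit()(attention_isa_kernel)


def attention_wrapper_without_swap(query, key, value):
    # query, key and value have shape (batch_size, num_heads, seq_len, head_dim)
    bs, n_head, q_len, d_head = query.shape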
    k_len = key.shape[2]
    v_len = value.shape[2]
    q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len))
    k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, k_len))
    v = value.clone().reshape((bs * n_head, v_len, d_head))
    attn_output = torch.zeros((bs * n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device)

    scale = 1 / math.sqrt(d_head)
    _flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap")

    attn_output = attn_output.reshape((bs, n_head, q_len, d_head))

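    return attn_output


class KernelizedAttnProcessor2_0:
    r"""
    Variant of `diffusers`' `AttnProcessor2_0` that runs eligible attention computations with the NKI attention
    kernel (via `attention_wrapper_without_swap`) and falls back to `F.scaled_dot_product_attention` otherwise.
    """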

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            diffusers.utils.deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        if attention_mask is not None or query.shape[3] > query.shape[2] or query.shape[3] > 128 or value.shape[2] == 77:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            hidden_states = attention_wrapper_without_swap(query, key, value)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
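The dispatch condition at the end of `KernelizedAttnProcessor2_0.__call__` decides, for each attention call, whether the NKI kernel or `F.scaled_dot_product_attention` is used. The hypothetical predicate below restates that condition in plain Python so that it can be checked for given shapes; the sequence lengths and head dimensions in the usage lines are illustrative assumptions, not values read from the model.

```python
def uses_nki_attention_kernel(q_len: int, head_dim: int, kv_len: int, has_attention_mask: bool = False) -> bool:
    """Restatement of the kernel-vs-SDPA dispatch in `KernelizedAttnProcessor2_0.__call__`.

    With tensors laid out as (batch, num_heads, seq_len, head_dim):
    q_len = query.shape[2], head_dim = query.shape[3], kv_len = value.shape[2].
    """
    falls_back_to_sdpa = (
        has_attention_mask      # attention_mask is not None
        or head_dim > q_len     # query.shape[3] > query.shape[2]
        or head_dim > 128       # query.shape[3] > 128
        or kv_len == 77         # value.shape[2] == 77, i.e. keys/values come from the 77 CLIP text tokens
    )
    return not falls_back_to_sdpa


# Self-attention over a 64x64 latent (4096 tokens) with a head dimension of 40
print(uses_nki_attention_kernel(q_len=4096, head_dim=40, kv_len=4096))  # True
# Cross-attention over the CLIP text encodings
print(uses_nki_attention_kernel(q_len=4096, head_dim=40, kv_len=77))    # False
```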

In [None]:
NEURON_COMPILER_WORKDIR = Path("neuron_compiler_workdir")
NEURON_COMPILER_WORKDIR.mkdir(exist_ok=True)
NEURON_COMPILER_OUTPUT_DIR = Path("compiled_models")
NEURON_COMPILER_OUTPUT_DIR.mkdir(exist_ok=True)

MODEL_ID = "runwayml/stable-diffusion-v1-5"

## Configuration
Notice that the 512x512 output image size coincides with the default output image size:
* The default height is indeed equal to `unet_sample_size * vae_scaling_factor`, i.e. `64*8=512`.
* Same goes for the default width.

The default output image format is therefore a square image of size 512x512 pixels.

***Remarks:***
* The `unet_sample_size` is part of the [U-Net configuration values](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json),
* The `vae_scaling_factor` is equal to $2^{nb\_vae\_decoder\_blocks-1}=8$ (cf. the [VAE configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json)).
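The arithmetic above can be checked with a few lines of Python. The values below (`sample_size=64` in the U-Net configuration, four entries in the VAE's `block_out_channels`) are assumptions taken from the linked configuration files; the formula for the VAE scaling factor is the same one used later in this notebook (`2**(len(pipe.vae.config.block_out_channels)-1)`).

```python
from typing import Sequence


def vae_scaling_factor(vae_block_out_channels: Sequence[int]) -> int:
    # Every VAE block except the last one halves (encoder) / doubles (decoder) the spatial resolution
    return 2 ** (len(vae_block_out_channels) - 1)


def default_image_size(unet_sample_size: int, vae_block_out_channels: Sequence[int]) -> int:
    # Latent side length times the VAE upscaling factor gives the output image side length
    return unet_sample_size * vae_scaling_factor(vae_block_out_channels)


# Assumed values from the U-Net and VAE configuration files of runwayml/stable-diffusion-v1-5
UNET_SAMPLE_SIZE = 64
VAE_BLOCK_OUT_CHANNELS = [128, 256, 512, 512]

print(vae_scaling_factor(VAE_BLOCK_OUT_CHANNELS))                    # 8
print(default_image_size(UNET_SAMPLE_SIZE, VAE_BLOCK_OUT_CHANNELS))  # 512
```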

**WARNING**: In this notebook, you are going to compile compute-heavy components of the Stable Diffusion pipeline. The compilation step 
compiles the execution path traced ***at compile-time*** into a static Neuron-optimized set of instructions. To avoid runtime errors, tensor 
shapes supplied to compiled models must be identical at compile and runtime. It is therefore highly recommended to use the same loading and generation 
parameters at compile and runtime. If not possible, be very cautious when modifying parameters that impact the shape of input tensors, e.g. 
batch size, `num_images_per_prompt`, `guidance_scale` (whether it is greater than `1.0` or not), etc. Parameters that do not influence tensor shapes can safely take different values 
between compile and runtime, e.g. `num_inference_steps`.
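One way to follow this advice is to compare the shape-relevant parts of the compile-time and runtime configurations before running the pipeline. The hypothetical `check_shape_compatibility` helper below is a minimal sketch assuming that only the batch size, `height`, `width`, `num_images_per_prompt` and whether `guidance_scale > 1.0` affect the input shapes; other parameters such as `num_inference_steps` are deliberately ignored.

```python
from typing import Any, Dict, Tuple


def shape_signature(config: Dict[str, Any], batch_size: int) -> Tuple[Any, ...]:
    """Collect the generation parameters assumed to determine the input shapes of the compiled models."""
    return (
        batch_size,
        config["height"],
        config["width"],
        config.get("num_images_per_prompt", 1),
        config["guidance_scale"] > 1.0,  # classifier-free guidance doubles the U-Net batch
    )


def check_shape_compatibility(compile_config: Dict[str, Any], compile_batch_size: int,
                              runtime_config: Dict[str, Any], runtime_batch_size: int) -> None:
    expected = shape_signature(compile_config, compile_batch_size)
    actual = shape_signature(runtime_config, runtime_batch_size)
    if expected != actual:
        raise ValueError(f"Runtime shape signature {actual} differs from compile-time signature {expected}")


compile_cfg = {"height": 512, "width": 512, "num_images_per_prompt": 1, "guidance_scale": 8.0, "num_inference_steps": 50}
runtime_cfg = {**compile_cfg, "num_inference_steps": 25, "guidance_scale": 7.5}
check_shape_compatibility(compile_cfg, 1, runtime_cfg, 1)  # passes: only shape-neutral values changed
```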

In [None]:
HEIGHT = WIDTH = 512
DTYPE = torch.float32
BATCH_SIZE = 1
NUM_IMAGES_PER_PROMPT = 1

PIPELINE_LOADING_CONFIG = {
    "revision": "main", # FP32 checkpoints
    "use_safetensors": True,
    "low_cpu_mem_usage": True,
}

PIPELINE_GENERATION_CONFIG = {
    "height": HEIGHT,
    "width": WIDTH,
    "num_inference_steps": 50,
    "num_images_per_prompt": NUM_IMAGES_PER_PROMPT,
    "guidance_scale": 8.0,
    "output_type": "pil",
}

In [None]:
# Neuron compiler configuration

# Reminder: By default, the Neuron Compiler (neuronx-cc) casts FP32 operations that use the Neuron matrix-multiplication
# engine (--auto-cast matmult) to the lower-precision BF16 data type (--auto-cast-type bf16). Supported lower-precision
# data types are: tf32, fp16, bf16 and fp8_e4m3.
NEURON_COMPILER_TYPE_CASTING_CONFIG = [
    "--auto-cast=matmult",
    "--auto-cast-type=bf16",
]

NEURON_COMPILER_CLI_ARGS = [
    "--target=inf2",
    "--enable-fast-loading-neuron-binaries",
    *NEURON_COMPILER_TYPE_CASTING_CONFIG,
]

# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/training/torch-neuron-envvars.html
os.environ["NEURON_FUSE_SOFTMAX"] = "1"

## 1. List compile-time and runtime requirements
In this section, you will:
* Load the Stable Diffusion pipeline
* Make the `forward` method of the models you will compile more verbose using a decorator function
* Perform a trial pipeline run on CPU to gather the input/output information you will need in the next sections

In [None]:
def print_args(args: Tuple[Any]) -> None:
    print(f"Positional args (count: {len(args)})")
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            arg = f"Tensor{tuple(arg.shape)}"
        print(f"  - arg({i}): {arg}")

def print_kwargs(kwargs: Dict[str, Any]) -> None:
    print(f"Keyword args (count: {len(kwargs)})")
    for i, (kwarg_name, kwarg_value) in enumerate(kwargs.items()):
        if isinstance(kwarg_value, torch.Tensor):
            kwarg_value = f"Tensor{tuple(kwarg_value.shape)}"
        print(f"  - kwarg({i}): {kwarg_name}={kwarg_value}")

def print_output(output: Any) -> None:
    if isinstance(output, torch.Tensor):
        print(f"Output: Tensor{tuple(output.shape)}")
    elif isinstance(output, BaseModelOutputWithPooling):
        print(f"Output type: {type(output).__name__}")
        print(f"  last_hidden_state=Tensor{tuple(output.last_hidden_state.shape)}")
        print(f"  pooler_output=Tensor{tuple(output.pooler_output.shape)}")
    elif isinstance(output, UNet2DConditionOutput):
        print(f"Output type: {type(output).__name__}")
        print(f"  sample=Tensor{tuple(output.sample.shape)}")
    else:
        print(f"Output type: {type(output).__name__}")

def make_forward_verbose(model: torch.nn.Module, model_name: str) -> torch.nn.Module:
    """
    The `make_forward_verbose` function is implemented as a Python decorator function with custom arguments.
    It decorates an input model. Model decoration consists of:
        1. Decorating the model's forward method using the `make_verbose` decorator function,
        2. Monkey-patching the original method with the decorated one.
    """
    def make_verbose(f: Callable) -> Callable:
        def decorated_forward_method(*args, **kwargs) -> Any:
            print("-"*50)
            print(f"Model: {model_name}")
            print(f"Model type: {type(model).__name__}")
            print_args(args)
            print_kwargs(kwargs)
            output = f(*args, **kwargs)
            print_output(output)
            return output
        return decorated_forward_method
    model.forward = make_verbose(model.forward)
    return model

In [None]:
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **PIPELINE_LOADING_CONFIG)

pipe.text_encoder = make_forward_verbose(model=pipe.text_encoder, model_name="CLIP text encoder")
pipe.unet = make_forward_verbose(model=pipe.unet, model_name="U-Net")
pipe.vae.decoder = make_forward_verbose(model=pipe.vae.decoder, model_name="VAE (decoder)")
pipe.vae.post_quant_conv = make_forward_verbose(model=pipe.vae.post_quant_conv, model_name="VAE (post_quant_conv)")
pipe.safety_checker.vision_model = make_forward_verbose(model=pipe.safety_checker.vision_model, model_name="Safety checker (vision_model)")
pipe.safety_checker.visual_projection = make_forward_verbose(model=pipe.safety_checker.visual_projection, model_name="Safety checker (visual_projection)")

In [None]:
generation_config = {**PIPELINE_GENERATION_CONFIG, "num_inference_steps": 1}
pipe("a photo of an astronaut riding a horse on mars", **generation_config)
del pipe

***Notices:***
* A pipeline loaded using `.from_pretrained` is set in evaluation mode ([`model.eval()`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval)) [by default](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained),
* Gradient calculation is automatically disabled when executing a pipeline by calling its `__call__` magic method ([`torch.no_grad()`](https://pytorch.org/docs/stable/generated/torch.no_grad.html)). 

## 2. Compile the models
***Warning***: Compiling models for Neuron consumes a lot of memory. The following cells therefore carefully manage their variables. For each model to be compiled, you will only keep a copy of the model, the rest of the pipeline will be discarded to save memory.

### 2.1. The CLIP text encoder
From the information gathered above, you can see that at runtime, the CLIP text encoder:
* Gets a single input tensor of size `(batch_size, model_max_length)`, i.e. `(1, 77)`: the tensor of input token IDs (cf. `model_max_length` in the 
[tokenizer configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/tokenizer/tokenizer_config.json)).
* Returns an object of type `transformers.modeling_outputs.BaseModelOutputWithPooling`. This object features 2 non-empty indexable attributes:
  * `last_hidden_state`: Tensor of shape `(batch_size, max_input_seq_length, encoder_projection_dim)`, i.e. `(1, 77, 768)`.
  * `pooler_output`: Tensor of shape `(batch_size, encoder_projection_dim)`, i.e. `(1, 768)`.

***Notices:***
* The text encoder can actually take two input tensors, which correspond to the two tensors returned by the tokenizer, i.e. the tensor of token IDs and the 
corresponding attention mask. In the present case, the attention mask is guaranteed never to be used, i.e. the value of the `attention_mask` argument 
is always `None`. Whether an attention mask is used or not is indeed a property of the text encoder: the attention mask is used [only if](https://github.com/huggingface/diffusers/blob/v0.20.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L328) 
the configuration of the text encoder includes a `use_attention_mask` entry set to `True`. 
[This is not the case for the CLIP text encoder](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/text_encoder/config.json), so the example input tensors do not include an attention mask tensor.
* `model_max_length` is the maximum tokenized input sequence length. Longer sequences are truncated by the tokenizer. Shorter sequences are right-padded 
by the tokenizer.
* `encoder_projection_dim` is the text embedding size. Cf. the [text encoder configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/text_encoder/config.json).
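The padding and truncation behavior can be sketched as follows. This is a simplification under stated assumptions: the hypothetical `pad_or_truncate` helper works on already-tokenized IDs, uses a made-up padding ID, and ignores the special start/end-of-text tokens that the real CLIP tokenizer inserts and preserves when truncating.

```python
from typing import List

MODEL_MAX_LENGTH = 77  # value from the tokenizer configuration
PAD_ID = 0             # hypothetical padding token ID, for illustration only


def pad_or_truncate(token_ids: List[int], max_length: int = MODEL_MAX_LENGTH, pad_id: int = PAD_ID) -> List[int]:
    """Return exactly `max_length` IDs: truncate longer sequences, right-pad shorter ones."""
    ids = token_ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))


short_ids = pad_or_truncate([101, 102, 103])
long_ids = pad_or_truncate(list(range(1, 201)))
print(len(short_ids), short_ids[:4])  # 77 [101, 102, 103, 0]
print(len(long_ids), long_ids[-1])    # 77 77
```

Either way, the text encoder always receives a `(batch_size, 77)` tensor, which is what makes a single static compile-time shape sufficient.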

To be able to compile the `torch.nn.Module` corresponding to the CLIP text encoder, you will modify its `forward` method so that:
* It returns a tuple of tensors instead of a `BaseModelOutputWithPooling` object.

In [None]:
TEXT_ENCODER_COMPILATION_DIR = NEURON_COMPILER_WORKDIR / "text_encoder"
TEXT_ENCODER_COMPILATION_DIR.mkdir(exist_ok=True)

def ensure_text_encoder_forward_neuron_compilable(model: CLIPTextModel) -> CLIPTextModel:
    def decorate_forward_method(f: Callable) -> Callable:
        def decorated_forward_method(*args, **kwargs) -> Tuple[torch.Tensor]:
            kwargs.update({"return_dict": False})
            output = f(*args, **kwargs)
            return output
        return decorated_forward_method
    model.forward = decorate_forward_method(model.forward)
    return model

In [None]:
# To minimize memory pressure, you only keep the model being compiled in RAM
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **PIPELINE_LOADING_CONFIG)

text_encoder = copy.deepcopy(pipe.text_encoder)
text_encoder = ensure_text_encoder_forward_neuron_compilable(text_encoder)

VOCAB_SIZE = pipe.tokenizer.vocab_size
MODEL_MAX_LENGTH = pipe.tokenizer.model_max_length

del pipe

In [None]:
# Execution time: ~1-2mins
example_input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, MODEL_MAX_LENGTH), dtype=torch.int64)

with torch.no_grad():
    text_encoder_neuron = torch_neuronx.trace(
        text_encoder,
        example_input_ids,
        compiler_workdir=TEXT_ENCODER_COMPILATION_DIR,
        compiler_args=[*NEURON_COMPILER_CLI_ARGS, f"--logfile={TEXT_ENCODER_COMPILATION_DIR}/log-neuron-cc.txt"],
    )

# Free up memory
del example_input_ids, text_encoder

Now let's visualize a representation of the graph operations of the `forward` method of the `torch.jit.ScriptModule` returned by `torch_neuronx.trace`:

In [None]:
print(text_encoder_neuron.code)

One can clearly see that the call to the `forward` method:
* Takes a single input tensor,
* Consists of a single Neuron computation step,
* Returns a 2-tuple of tensors.

Now, let's enable both lazy and asynchronous loading for better model-loading performance and serialize the Neuron `ScriptModule` using `torch.jit.save`. 
For more information, refer to the documentation of the [PyTorch Neuron Lazy and Asynchronous Loading API](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-async-lazy-load.html#torch-neuronx-lazy-async-load-api).

In [None]:
torch_neuronx.async_load(text_encoder_neuron)
torch_neuronx.lazy_load(text_encoder_neuron)

torch.jit.save(text_encoder_neuron, NEURON_COMPILER_OUTPUT_DIR / "text_encoder.pt")

In [None]:
# Free up memory
del text_encoder_neuron

### 2.2. The U-Net
#### 2.2.1. Negative prompting
The U-Net predicts the amount of noise in an input sample (latent) at a given time step `t`. The predicted noise is then used by the scheduler to generate the (denoised) 
sample corresponding to the previous time step. The noise prediction is conditioned on the input text prompt. To better guide the model in generating the desired image, the 
`StableDiffusionPipeline` allows the user to describe what they don't want to see in the generated image using a separate text prompt called "negative prompt" (`negative_prompt` argument).
When using negative prompts, two U-Net forward passes are performed at each time step: one to generate a noise prediction conditioned on the prompt, another to generate a noise 
prediction conditioned on the negative prompt. The final noise prediction for a given time step is created from both predictions using a guidance coefficient ("classifier-free guidance"). 
The guidance coefficient (`guidance_scale` generation parameter) is a number greater than 1: the higher the value, the closer the generated image is to the input prompt.
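The way both noise predictions are combined can be written in a couple of lines. The sketch below mirrors the classifier-free guidance formula used in the `diffusers` Stable Diffusion pipeline (`noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)`); plain lists of floats stand in for the noise tensors, and `guided_noise` is a hypothetical helper name.

```python
from typing import List


def guided_noise(noise_pred_uncond: List[float], noise_pred_text: List[float], guidance_scale: float) -> List[float]:
    """Classifier-free guidance: extrapolate from the negative-prompt prediction towards the prompt prediction."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_pred_uncond, noise_pred_text)]


print(guided_noise([0.0, 1.0], [1.0, 1.0], guidance_scale=8.0))  # [8.0, 1.0]
print(guided_noise([0.5], [2.0], guidance_scale=1.0))            # [2.0]
```

With `guidance_scale=1.0` the result is just the prompt-conditioned prediction, which is consistent with guidance being disabled for values less than or equal to `1.0`.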

For example, the `(prompt, negative_prompt)` pair `("portrait of a man", "mustache")` aims at generating portraits of men without a mustache and would probably 
do a better job than a single `"portrait of a man without mustache"` prompt.

***Notice:*** The value of `guidance_scale` actually determines whether classifier-free guidance, to be understood here as "guidance using negative prompts", is used or not. 
If `guidance_scale` is greater than `1.0`, denoised samples are generated using classifier-free guidance, and two U-Net forward passes are performed at each time step. If 
no `negative_prompt` is supplied, an empty, fully padded negative prompt is used under the hood. If `guidance_scale` is less than or equal to `1.0`, classifier-free guidance 
is disabled: a single U-Net forward pass is performed at each time step and any supplied `negative_prompt` is ignored.

From the tensor operations perspective, classifier-free guidance / negative prompts have the following impact:
* The text encoder is used twice: once on the prompt(s), once on the negative prompt(s).
* Two U-Net forward passes are performed at each time step. For increased efficiency, the `diffusers` implementation concatenates the encodings of both prompts into a single tensor 
and creates a new sample input tensor by repeating the sample tensor twice. This trick allows both noise predictions to be computed in a single forward pass. **The U-Net input 
tensors therefore have different shapes depending on whether classifier-free guidance is used or not**.
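The batching trick can be illustrated with plain Python lists, where each list element stands for one entry along dimension 0. The ordering (negative prompt encodings first, then prompt encodings) and the final split into two halves follow the `diffusers` implementation (`torch.cat([negative_prompt_embeds, prompt_embeds])`, `torch.cat([latents] * 2)`, `noise_pred.chunk(2)`); the helper names are hypothetical.

```python
from typing import List, Tuple


def build_cfg_batch(latents: List[str], prompt_embeds: List[str],
                    negative_prompt_embeds: List[str]) -> Tuple[List[str], List[str]]:
    # Negative prompt encodings first, then prompt encodings, concatenated along dim 0
    encoder_hidden_states = negative_prompt_embeds + prompt_embeds
    # The latents are repeated twice so that every encoding gets its own copy of the sample
    latent_model_input = latents * 2
    return latent_model_input, encoder_hidden_states


def split_cfg_output(noise_pred: List[str]) -> Tuple[List[str], List[str]]:
    # First half: negative prompt predictions, second half: prompt predictions
    half = len(noise_pred) // 2
    return noise_pred[:half], noise_pred[half:]


samples, states = build_cfg_batch(["z0"], ["prompt0"], ["negative0"])
print(samples, states)  # ['z0', 'z0'] ['negative0', 'prompt0']
```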

#### 2.2.2. Tensor shapes
The first dimension of the U-Net input tensors is:
* If classifier-free guidance (negative prompts) is disabled: `batch_size*num_images_per_prompt`.
* If classifier-free guidance (negative prompts) is enabled: `2*batch_size*num_images_per_prompt` (cf. above for more information about the `2` factor).

In the present case, classifier-free guidance is **enabled**.

From the information gathered in the first section, you can see that at runtime, the U-Net:
* Gets three input tensors:
  * A `sample` tensor (arg 0) of shape `(2*batch_size*num_images_per_prompt, unet_in_channels, height // vae_scaling_factor, width // vae_scaling_factor)`, i.e. `(2*1*1, 4, 512//8, 512//8)`, i.e. `(2, 4, 64, 64)`:
    * `unet_in_channels` is a [U-Net configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json) value,
    * The VAE scaling factor is $2^{nb\_vae\_decoder\_blocks-1}=8$ (cf. the [VAE configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json)).
  * A 0-dim `timestep` tensor (arg 1). You will need to make it a 1-dim tensor, i.e. a tensor of shape `(2*batch_size*num_images_per_prompt,)`, i.e. `(2,)`.
  * An `encoder_hidden_states` tensor (kwarg 0) of shape `(2*batch_size*num_images_per_prompt, max_input_seq_length, d_model)`, i.e. `(2, 77, 768)`. For each prompt kind, text encodings are repeated `num_images_per_prompt` times.
  This tensor concatenates the representations of both positive and negative prompts along its dimension 0.
* Returns an object of type `diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`. This object features a single non-empty indexable attribute:
  * `sample`: Tensor of shape identical to the input `sample` tensor.
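The shapes listed above can be derived from a handful of parameters. The hypothetical `unet_input_shapes` helper below is a minimal sketch; its default values (4 input channels, a VAE scaling factor of 8, 77 tokens, 768-dimensional encodings) are the ones quoted in this section, and the timestep is given its 1-dim form.

```python
from typing import Dict, Tuple


def unet_input_shapes(
    batch_size: int,
    num_images_per_prompt: int,
    height: int,
    width: int,
    classifier_free_guidance: bool = True,
    unet_in_channels: int = 4,
    vae_scaling_factor: int = 8,
    model_max_length: int = 77,
    encoder_hidden_dim: int = 768,
) -> Dict[str, Tuple[int, ...]]:
    # Classifier-free guidance stacks negative and positive conditioning along dim 0
    n = (2 if classifier_free_guidance else 1) * batch_size * num_images_per_prompt
    return {
        "sample": (n, unet_in_channels, height // vae_scaling_factor, width // vae_scaling_factor),
        "timestep": (n,),
        "encoder_hidden_states": (n, model_max_length, encoder_hidden_dim),
    }


print(unet_input_shapes(1, 1, 512, 512))
# {'sample': (2, 4, 64, 64), 'timestep': (2,), 'encoder_hidden_states': (2, 77, 768)}
```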

To be able to compile the `torch.nn.Module` corresponding to the U-Net, you will modify its `forward` method so that:
* It returns a plain tensor instead of a `UNet2DConditionOutput` object.

#### 2.2.3. Compilation options
Assuming `num_images_per_prompt=1` for simplicity, there are two options when compiling the U-Net model:
* Using example inputs of size `2*batch_size` along their dimension 0. In that case, the compiled U-Net computes the noise predictions for both the positive and the negative text conditioning in a single pass.
* Using example inputs of size `batch_size` along their dimension 0 and applying data parallelism (degree 2) to the compiled U-Net. The computation is then split over two model replicas, 
which makes better use of the available NeuronCores (there are at least 2 NeuronCores available per instance) and contributes to increasing throughput.

**This notebook opts for the second option.** Actual input tensor shapes will therefore be:
* `sample`: `(1, 4, 64, 64)`
* `timestep`: `(1,)`
* `encoder_hidden_states`: `(1, 77, 768)`
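The data-parallel option can be pictured as follows: the runtime batch of size `2*batch_size` is split along dimension 0 into one chunk per replica, each replica sees exactly the compile-time shape, and the outputs are concatenated back. The sketch below runs the replicas sequentially on plain lists and only shows the splitting logic; in practice a wrapper such as `torch_neuronx.DataParallel` would dispatch the chunks to separate NeuronCores. The helper `data_parallel_call` is hypothetical.

```python
from typing import Callable, List, Sequence

Replica = Callable[[List[int]], List[int]]


def data_parallel_call(replicas: Sequence[Replica], batch: List[int]) -> List[int]:
    """Split `batch` into equal chunks along dim 0, run one chunk per replica and concatenate the outputs."""
    if len(batch) % len(replicas) != 0:
        raise ValueError("The batch size must be a multiple of the number of replicas")
    chunk_size = len(batch) // len(replicas)
    outputs: List[int] = []
    for i, replica in enumerate(replicas):
        outputs.extend(replica(batch[i * chunk_size:(i + 1) * chunk_size]))
    return outputs


seen_chunk_sizes: List[int] = []


def fake_unet_replica(chunk: List[int]) -> List[int]:
    # Stand-in for a compiled U-Net replica traced with a batch size of 1
    seen_chunk_sizes.append(len(chunk))
    return [x * 10 for x in chunk]


print(data_parallel_call([fake_unet_replica, fake_unet_replica], [1, 2]))  # [10, 20]
print(seen_chunk_sizes)                                                   # [1, 1]
```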

In [None]:
UNET_COMPILATION_DIR = NEURON_COMPILER_WORKDIR / "unet"
UNET_COMPILATION_DIR.mkdir(exist_ok=True)

def ensure_unet_forward_neuron_compilable(model: UNet2DConditionModel) -> UNet2DConditionModel:
    def decorate_forward_method(f: Callable) -> Callable:
        def decorated_forward_method(*args, **kwargs) -> torch.Tensor:
            kwargs.update({"return_dict": False})
            output_sample, = f(*args, **kwargs)
            return output_sample
        return decorated_forward_method
    model.forward = decorate_forward_method(model.forward)
    return model

In [None]:
# To minimize memory pressure, you only keep the model being compiled in RAM
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **PIPELINE_LOADING_CONFIG)
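# Route the U-Net's attention computations through the kernel-backed processor defined in the imports cell
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = KernelizedAttnProcessor2_0.__call__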
unet = copy.deepcopy(pipe.unet)
unet = ensure_unet_forward_neuron_compilable(unet)

UNET_IN_CHANNELS = pipe.unet.config.in_channels
VAE_SCALING_FACTOR = 2**(len(pipe.vae.config.block_out_channels)-1)
ENCODER_PROJECTION_DIM = pipe.text_encoder.config.hidden_size
MODEL_MAX_LENGTH = pipe.tokenizer.model_max_length

del pipe

Attention computation is a particularly compute-heavy operation. In the following cell, you monkey-patch the attention computation with an implementation that ensures it is matched 
to a highly optimized kernel by the Neuron Compiler. 

You actually only monkey-patch the `get_attention_scores` method of the attention modules used by the cross-attention blocks of the 
U-Net (`UNet2DConditionModel`), i.e. `diffusers.models.attention_processor.Attention`. Here are the key differences from the `diffusers` implementation:
* The code related to the attention mask has been removed. In the first section, you indeed noticed that no attention mask was passed to the U-Net's `forward` method.
* The native implementation uses the `torch.baddbmm` function, which computes `beta*attention_mask + scale*(query @ key.transpose(-1, -2))`. Since no attention mask 
is used (`beta` is then 0), only the `scale*(query @ key.transpose(-1, -2))` term is relevant, which is exactly what a plain `torch.bmm` followed by a multiplication by `self.scale` computes.
* The `torch.nn.functional.softmax` function is used instead of the `torch.Tensor.softmax` method.
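To make the `baddbmm` versus `bmm` point concrete, here is a small pure-Python sketch. It uses nested lists instead of tensors, a single batch entry, and toy values chosen for illustration. It shows that the `baddbmm` form with `beta=0` yields the same scores as a plain product followed by scaling, and that each softmaxed row sums to 1:

```python
import math
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m: Matrix) -> Matrix:
    """Swap the two dimensions, like key.transpose(-1, -2) for one batch entry."""
    return [list(col) for col in zip(*m)]

def baddbmm_like(inp: Matrix, a: Matrix, b: Matrix, beta: float, alpha: float) -> Matrix:
    """beta * inp + alpha * (a @ b), mirroring torch.baddbmm for one batch entry."""
    prod = matmul(a, b)
    return [[beta * i + alpha * p for i, p in zip(row_i, row_p)] for row_i, row_p in zip(inp, prod)]

def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row (the last dimension)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

query = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 query tokens, head dimension 2
key = [[1.0, 2.0], [3.0, -1.0]]               # 2 key tokens, head dimension 2
scale = 1 / math.sqrt(len(query[0]))          # 1/sqrt(head dimension)

unused_input = [[0.0] * len(key) for _ in query]  # contributes nothing since beta == 0
via_baddbmm = baddbmm_like(unused_input, query, transpose(key), beta=0.0, alpha=scale)
via_bmm = [[scale * v for v in row] for row in matmul(query, transpose(key))]
assert via_baddbmm == via_bmm

attention_probs = [softmax(row) for row in via_bmm]  # one probability row per query token
```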

In [None]:
def get_attention_scores(self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    dtype = query.dtype

    if self.upcast_attention:
        query = query.float()
        key = key.float()

    attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2))
    # torch.bmm(input_1, input_2): input_1 of shape (b, n, m), input_2 of shape (b, m, p) -> output of shape (b, n, p)
    # Attention: query of shape (b, n_q, d), key of shape (b, n_k, d) (n_q != n_k for cross-attention). The expected
    # attention scores have shape (b, n_q, n_k), so the key dims need to be rearranged into (b, d, n_k),
    # i.e. swapping dims -1 & -2 with key.transpose(-1, -2)
    
    if self.upcast_softmax:
        attention_scores = attention_scores.float()

    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
    del attention_scores
    attention_probs = attention_probs.to(dtype)
        
    return attention_probs
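The shape reasoning in the comments above can be checked with a small helper. This sketch assumes cross-attention in the first U-Net block of Stable Diffusion 1.5: 8 heads of dimension 40, a 64x64 latent flattened into 4096 query tokens, 77 text tokens as keys, and a batch doubled by classifier-free guidance. These values are for illustration only, and the helper names are hypothetical:

```python
from typing import Tuple

Shape = Tuple[int, ...]

def transpose_last_two(shape: Shape) -> Shape:
    """Shape of `t.transpose(-1, -2)`."""
    return (*shape[:-2], shape[-1], shape[-2])

def bmm_shape(a: Shape, b: Shape) -> Shape:
    """Shape of torch.bmm(a, b): (b, n, m) x (b, m, p) -> (b, n, p)."""
    assert len(a) == len(b) == 3, "bmm expects 3-D inputs"
    assert a[0] == b[0] and a[2] == b[1], f"incompatible shapes {a} and {b}"
    return (a[0], a[1], b[2])

batch_times_heads = 2 * 8                    # CFG-doubled batch x 8 heads (assumed)
query_shape = (batch_times_heads, 4096, 40)  # 64*64 latent positions, head dim 40
key_shape = (batch_times_heads, 77, 40)      # 77 text tokens, head dim 40

scores_shape = bmm_shape(query_shape, transpose_last_two(key_shape))
print(scores_shape)  # (16, 4096, 77)
```

Without the transpose, `bmm_shape(query_shape, key_shape)` fails the inner-dimension check, which is why the key has to be transposed.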

In [None]:
# Monkey-patching
Attention.get_attention_scores = get_attention_scores

In [None]:
# Execution time: ~10mins
example_input_sample = torch.randn((BATCH_SIZE*NUM_IMAGES_PER_PROMPT, UNET_IN_CHANNELS, HEIGHT//VAE_SCALING_FACTOR, WIDTH//VAE_SCALING_FACTOR), dtype=DTYPE)
example_timestep = torch.randint(0, 1000, (BATCH_SIZE*NUM_IMAGES_PER_PROMPT,), dtype=DTYPE)
example_encoder_hidden_states = torch.randn((BATCH_SIZE*NUM_IMAGES_PER_PROMPT, MODEL_MAX_LENGTH, ENCODER_PROJECTION_DIM), dtype=DTYPE)
example_inputs = (example_input_sample, example_timestep, example_encoder_hidden_states)

with torch.no_grad():
    unet_neuron = torch_neuronx.trace(
        unet,
        example_inputs,
        compiler_workdir=UNET_COMPILATION_DIR,
        compiler_args=[*NEURON_COMPILER_CLI_ARGS, f'--logfile={UNET_COMPILATION_DIR}/log-neuron-cc.txt', "--model-type=unet-inference"],
    )

# Free up memory
del example_input_sample, example_timestep, example_encoder_hidden_states, example_inputs, unet

Now let's visualize a representation of the graph operations of the `forward` method of the `torch.jit.ScriptModule` returned by `torch_neuronx.trace`:

In [None]:
print(unet_neuron.code)

One can clearly see that the call to the `forward` method:
* Takes three input tensors,
* Consists of a single Neuron computation step,
* Returns a single tensor.

Now, let's enable both lazy and asynchronous loading for better model loading performance and serialize the Neuron `ScriptModule` by simply using `torch.jit.save`. 

In [None]:
torch_neuronx.async_load(unet_neuron)
torch_neuronx.lazy_load(unet_neuron)

torch.jit.save(unet_neuron, NEURON_COMPILER_OUTPUT_DIR / "unet.pt")

In [None]:
# Free up memory
del unet_neuron

### 2.3. The VAE
At inference time, only the VAE's decoding path is used. It actually involves two models that you need to compile: the post_quant_conv model and the decoder itself.

From the information gathered in the first section, you can see that at runtime, the `forward` method of both elements takes a single tensor as input and returns a single 
tensor as output. There is therefore no need to modify the `forward` method using decoration/monkey-patching for compilation. Since the post_quant_conv model is very lightweight, 
you will compile both at the same time.

The post_quant_conv model input and output tensors are of shape `(batch_size*num_images_per_prompt, unet_out_channels, height // vae_scaling_factor, width // vae_scaling_factor)`, i.e. `(1, 4, 64, 64)`.

***Notice***: The value of the U-Net's `out_channels` configuration (`unet_out_channels` above) is identical to the value of the `latent_channels` [VAE configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json).

For the decoder model:
* Input tensor is identical to the output tensor of the post_quant_conv model.
* Output tensor is of shape `(batch_size*num_images_per_prompt, vae_out_channels, height, width)`, i.e. `(1, 3, 512, 512)`
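The same shape arithmetic applies to the two VAE models. The sketch below assumes the Stable Diffusion 1.5 values (`latent_channels=4`, VAE `out_channels=3`, a scaling factor of 8). It also reflects that `post_quant_conv` is a 1x1 convolution mapping the latent channels onto themselves, so it preserves the shape. The helper is hypothetical:

```python
def vae_decoding_shapes(
    batch_size: int = 1,
    num_images_per_prompt: int = 1,
    height: int = 512,
    width: int = 512,
    latent_channels: int = 4,    # assumed SD 1.5 VAE value
    vae_out_channels: int = 3,   # RGB output
    vae_scaling_factor: int = 8,
) -> dict:
    """Hypothetical helper: input/output shapes of post_quant_conv and the decoder."""
    dim_0 = batch_size * num_images_per_prompt
    latent = (dim_0, latent_channels, height // vae_scaling_factor, width // vae_scaling_factor)
    return {
        "post_quant_conv_in": latent,
        "post_quant_conv_out": latent,  # 1x1 convolution: shape unchanged
        "decoder_in": latent,           # decoder consumes post_quant_conv's output
        "decoder_out": (dim_0, vae_out_channels, height, width),
    }

print(vae_decoding_shapes()["decoder_out"])  # (1, 3, 512, 512)
```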

In [None]:
VAE_COMPILATION_DIR = NEURON_COMPILER_WORKDIR / "vae"
VAE_COMPILATION_DIR.mkdir(exist_ok=True)

In [None]:
# To minimize memory pressure, only the model being compiled is kept in RAM
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **PIPELINE_LOADING_CONFIG)

vae_post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
vae_decoder = copy.deepcopy(pipe.vae.decoder)

LATENT_CHANNELS = pipe.vae.config.latent_channels
VAE_SCALING_FACTOR = 2**(len(pipe.vae.config.block_out_channels)-1)

del pipe

In [None]:
# Execution time: ~7-8mins
example_latent_sample = torch.randn((BATCH_SIZE*NUM_IMAGES_PER_PROMPT, LATENT_CHANNELS, HEIGHT//VAE_SCALING_FACTOR, WIDTH//VAE_SCALING_FACTOR), dtype=DTYPE)

with torch.no_grad():
    VAE_POST_QUANT_COMPILATION_DIR = VAE_COMPILATION_DIR / "post_quant_conv"
    vae_post_quant_conv_neuron = torch_neuronx.trace(
        vae_post_quant_conv,
        example_latent_sample,
        compiler_workdir=VAE_POST_QUANT_COMPILATION_DIR,
        compiler_args=[*NEURON_COMPILER_CLI_ARGS, f'--logfile={VAE_POST_QUANT_COMPILATION_DIR}/log-neuron-cc.txt'],
    )

    VAE_DECODER_COMPILATION_DIR = VAE_COMPILATION_DIR / "decoder"
    vae_decoder_neuron = torch_neuronx.trace(
        vae_decoder,
        example_latent_sample,
        compiler_workdir=VAE_DECODER_COMPILATION_DIR,
        compiler_args=[*NEURON_COMPILER_CLI_ARGS, "--model-type=unet-inference", f'--logfile={VAE_DECODER_COMPILATION_DIR}/log-neuron-cc.txt'],
    )

# Free up memory
del vae_post_quant_conv, vae_decoder, example_latent_sample

Now let's visualize a representation of the graph operations of the `forward` method of the `torch.jit.ScriptModule` returned by `torch_neuronx.trace`:

In [None]:
print(vae_post_quant_conv_neuron.code)

In [None]:
print(vae_decoder_neuron.code)

One can clearly see that in both cases, the call to the `forward` method:
* Takes a single input tensor,
* Consists of a single Neuron computation step,
* Returns a single tensor.

Now, let's enable both lazy and asynchronous loading for better model loading performance and serialize the Neuron `ScriptModule`s by simply using `torch.jit.save`. 

In [None]:
for neuron_model, file_name in zip((vae_post_quant_conv_neuron, vae_decoder_neuron), ("vae_post_quant_conv.pt", "vae_decoder.pt")):
    torch_neuronx.async_load(neuron_model)
    torch_neuronx.lazy_load(neuron_model)
    torch.jit.save(neuron_model, NEURON_COMPILER_OUTPUT_DIR / file_name)

In [None]:
# Free up memory
del vae_post_quant_conv_neuron, vae_decoder_neuron

### 2.4. The safety model
Like the VAE, the safety checker actually consists of two model components that you need to compile separately: the vision model and the visual projection.

Regarding the vision model, from the information gathered in the first section, the safety checker's vision model is a CLIP vision transformer that:
* Gets a single input tensor of shape `(batch_size*num_images_per_prompt, vae_out_channels, feature_extractor_crop_height, feature_extractor_crop_width)`, i.e. `(1, 3, 224, 224)` 
(cf. the `crop_size` [feature extractor configuration](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/feature_extractor/preprocessor_config.json)). 
* Returns an object of type `transformers.modeling_outputs.BaseModelOutputWithPooling`. This object features 2 non-empty indexable attributes:
  * `last_hidden_state`: Tensor of shape `(batch_size*num_images_per_prompt, sequence_length, hidden_size)`, i.e. `(1, 257, 1024)`.
  * `pooler_output`: Tensor of shape `(batch_size*num_images_per_prompt, hidden_size)`, i.e. `(1, 1024)`.
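The 257 in the `last_hidden_state` shape is the number of image patches plus one class token. Here is a minimal sketch of that count. It assumes the safety checker's vision tower is a CLIP ViT-L/14 with a patch size of 14. That value is not stated in this notebook, so check `patch_size` in the safety checker's `vision_config`:

```python
def clip_vision_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens produced by a ViT: one per image patch plus the class token."""
    if image_size % patch_size:
        raise ValueError("image_size must be a multiple of patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1

print(clip_vision_sequence_length(image_size=224, patch_size=14))  # 257
```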

Similarly to the CLIP text encoder, to be able to compile the `torch.nn.Module` corresponding to the safety checker's vision model, you will modify its `forward` method so that it returns a tuple of tensors instead of a `BaseModelOutputWithPooling` object.

Regarding the visual projection, it is actually a simple `torch.nn.Linear` layer that:
* Gets a single input tensor of size `(batch_size*num_images_per_prompt, hidden_size)`, i.e. `(1, 1024)` (the `pooler_output` from the vision model).
* Returns a single tensor of size `(batch_size*num_images_per_prompt, projection_dim)`, i.e. `(1, 768)`.

The `forward` method of the visual projection model can therefore be directly traced and compiled by `torch_neuronx.trace`. Since the visual projection model is very lightweight, 
you will compile both models at the same time.

***Notice:*** Cf. the `hidden_size` and `projection_dim` [safety checker configurations](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/safety_checker/config.json).

In [None]:
SAFETY_CHECKER_COMPILATION_DIR = NEURON_COMPILER_WORKDIR / "safety_checker"
SAFETY_CHECKER_COMPILATION_DIR.mkdir(exist_ok=True)

def ensure_vision_model_forward_neuron_compilable(model: CLIPVisionModel) -> CLIPVisionModel:
    def decorate_forward_method(f: Callable) -> Callable:
        def decorated_forward_method(*args, **kwargs) -> Tuple[torch.Tensor]:
            kwargs.update({"return_dict": False})
            output = f(*args, **kwargs)
            return output
        return decorated_forward_method
    model.forward = decorate_forward_method(model.forward)
    return model

In [None]:
# To minimize memory pressure, only the model being compiled is kept in RAM
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **PIPELINE_LOADING_CONFIG)

safety_checker_vision_model = copy.deepcopy(pipe.safety_checker.vision_model)
safety_checker_visual_projection = copy.deepcopy(pipe.safety_checker.visual_projection)
safety_checker_vision_model = ensure_vision_model_forward_neuron_compilable(safety_checker_vision_model)

VAE_OUT_CHANNELS = pipe.vae.config.out_channels
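# Index crop_size by key rather than relying on the ordering of its values
FEATURE_EXTRACTOR_CROP_HEIGHT = pipe.feature_extractor.crop_size["height"]
FEATURE_EXTRACTOR_CROP_WIDTH = pipe.feature_extractor.crop_size["width"]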
VISION_MODEL_HIDDEN_DIM = pipe.safety_checker.config.vision_config.hidden_size

del pipe

In [None]:
# Execution time: ~2-3mins
example_safety_checker_vision_model_input = torch.randn((BATCH_SIZE*NUM_IMAGES_PER_PROMPT, VAE_OUT_CHANNELS, FEATURE_EXTRACTOR_CROP_HEIGHT, FEATURE_EXTRACTOR_CROP_WIDTH), dtype=DTYPE)
example_safety_checker_visual_projection_input = torch.randn((BATCH_SIZE*NUM_IMAGES_PER_PROMPT, VISION_MODEL_HIDDEN_DIM), dtype=DTYPE)

with torch.no_grad():
    SAFETY_CHECKER_VISION_MODEL_COMPILATION_DIR = SAFETY_CHECKER_COMPILATION_DIR / "vision_model"
    safety_checker_vision_model_neuron = torch_neuronx.trace(
        safety_checker_vision_model,
        example_safety_checker_vision_model_input,
        compiler_workdir=SAFETY_CHECKER_VISION_MODEL_COMPILATION_DIR,
        compiler_args=[*NEURON_COMPILER_CLI_ARGS, f'--logfile={SAFETY_CHECKER_VISION_MODEL_COMPILATION_DIR}/log-neuron-cc.txt'],
    )

    SAFETY_CHECKER_VISUAL_PROJECTION_DIR = SAFETY_CHECKER_COMPILATION_DIR / "visual_projection"
    safety_checker_visual_projection_neuron = torch_neuronx.trace(
        safety_checker_visual_projection,
        example_safety_checker_visual_projection_input,
        compiler_workdir=SAFETY_CHECKER_VISUAL_PROJECTION_DIR,
        compiler_args=[*NEURON_COMPILER_CLI_ARGS, f'--logfile={SAFETY_CHECKER_VISUAL_PROJECTION_DIR}/log-neuron-cc.txt'],
    )

# Free up memory
del safety_checker_vision_model, example_safety_checker_vision_model_input, safety_checker_visual_projection, example_safety_checker_visual_projection_input

Now let's visualize a representation of the graph operations of the `forward` method of the `torch.jit.ScriptModule` returned by `torch_neuronx.trace`:

In [None]:
print(safety_checker_vision_model_neuron.code)

In [None]:
print(safety_checker_visual_projection_neuron.code)

In the case of the vision model, one can clearly see that the call to the `forward` method:
* Takes a single input tensor,
* Consists of a single Neuron computation step,
* Returns a 2-tuple of tensors.

The case of the visual projection model is very similar but only returns a single tensor.

Now, let's enable both lazy and asynchronous loading for better model loading performance and serialize the Neuron `ScriptModule`s by simply using `torch.jit.save`.

In [None]:
for neuron_model, file_name in zip((safety_checker_vision_model_neuron, safety_checker_visual_projection_neuron), ("safety_checker_vision_model.pt", "safety_checker_visual_projection.pt")):
    torch_neuronx.async_load(neuron_model)
    torch_neuronx.lazy_load(neuron_model)
    torch.jit.save(neuron_model, NEURON_COMPILER_OUTPUT_DIR / file_name)

In [None]:
# Free up memory
del safety_checker_vision_model_neuron, safety_checker_visual_projection_neuron

# 3. Load the models
In this section you will:
1. Load the Stable Diffusion pipeline,
2. Load the serialized compiled models,
3. When required, adapt the interface of the loaded Neuron `torch.jit.ScriptModule`s to the `StableDiffusionPipeline` (cf. below),
4. Load the compiled models onto your host's NeuronCores.

## 3.1. Ensuring runtime compatibility
In practice, running Stable Diffusion on Neuron devices consists of replacing the model components of the `StableDiffusionPipeline`, which orchestrates 
the different model components to generate an image from input text. In the first section, you extracted, for each model to be compiled and run on Neuron:
* The positional and keyword arguments passed to its `forward` method by the `StableDiffusionPipeline`,
* The output type of its `forward` method, i.e. the object type the `StableDiffusionPipeline` then expects for further processing.

On the other hand, the `forward` method of a compiled model only accepts positional arguments and returns a single tensor or a tuple of tensors. Depending on the model, there 
is therefore a mismatch between the interface of the compiled model and the interface expected by the `StableDiffusionPipeline`. When studying the output of the first section, one can notice 
that the compiled VAE models (as well as the safety checker's visual projection) can be used as is. For the other models however (text encoder, U-Net, safety checker vision model), the `forward` method needs to be adapted. More concretely, 
you will use adaptor functions that replace the `forward` method of the pipeline's original model with a function that 1) passes the expected inputs, as positional arguments only, to the compiled `torch.jit.ScriptModule` and 2) wraps its outputs into the object type expected by the pipeline.
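The adaptor pattern can be illustrated without Neuron hardware. In the toy sketch below, `ToyCompiledModel` plays the compiled `ScriptModule` (positional arguments only, tuple output). `ToyPipelineModule` plays the pipeline's original module, and `ToyModelOutput` stands in for `BaseModelOutputWithPooling`. All three are hypothetical, and strings replace tensors:

```python
from dataclasses import dataclass
from typing import Callable, Tuple

@dataclass
class ToyModelOutput:
    """Hypothetical stand-in for BaseModelOutputWithPooling."""
    last_hidden_state: str
    pooler_output: str

class ToyCompiledModel:
    """Hypothetical stand-in for a compiled ScriptModule: positional args only, tuple output."""
    def __call__(self, input_ids) -> Tuple[str, str]:
        return f"hidden({input_ids})", f"pooled({input_ids})"

class ToyPipelineModule:
    """Hypothetical stand-in for the pipeline's original torch.nn.Module."""
    def forward(self, input_ids, attention_mask=None, return_dict=True):
        raise RuntimeError("the original implementation should not be called anymore")

def runtime_decorator_factory(neuron_model: Callable) -> Callable[[ToyPipelineModule], ToyPipelineModule]:
    def decorator(model: ToyPipelineModule) -> ToyPipelineModule:
        def neuron_forward_method(*args, **kwargs) -> ToyModelOutput:
            input_ids, *_ = args  # keep only the first positional argument; kwargs are dropped
            last_hidden_state, pooler_output = neuron_model(input_ids)
            return ToyModelOutput(last_hidden_state, pooler_output)  # type the pipeline expects
        model.forward = neuron_forward_method  # monkey-patch this instance's forward
        return model
    return decorator

module = runtime_decorator_factory(ToyCompiledModel())(ToyPipelineModule())
out = module.forward("ids", attention_mask="mask")
print(out.pooler_output)  # pooled(ids)
```

The factory closes over the compiled model, so the pipeline keeps calling the same module object while the work is delegated to the compiled graph.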

## 3.2 U-Net data parallelism
The following code applies data parallelism to the compiled U-Net. The data parallelism degree equals 2, i.e. `torch-neuronx` creates 2 model replicas. Input tensors are 
automatically chunked along their first dimension (dim 0). Since dynamic batching is disabled, input tensors must fulfill the following condition: `runtime_input.shape[0]/data_parallelism_degree == compile_time_input.shape[0]`. 
This condition is fulfilled: for dim 0, the compile-time size is 1, the runtime size is 2 (with classifier-free guidance, the pipeline concatenates the unconditional and conditional inputs along dim 0) and the data parallelism degree is 2.
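The condition above can be mimicked in a simplified pure-Python model. Lists stand in for tensors, and `split_for_data_parallel` is a hypothetical helper. This is not how `torch_neuronx.DataParallel` is implemented; it only illustrates the shape check and the chunking along dim 0:

```python
from typing import List

def split_for_data_parallel(batch: List, dp_degree: int, compile_dim0: int,
                            dynamic_batching: bool = False) -> List[List]:
    """Hypothetical sketch: chunk a batch along dim 0 across `dp_degree` replicas."""
    if not dynamic_batching and len(batch) / dp_degree != compile_dim0:
        raise ValueError(
            f"runtime dim 0 ({len(batch)}) / data parallelism degree ({dp_degree}) "
            f"!= compile-time dim 0 ({compile_dim0})"
        )
    chunk = -(-len(batch) // dp_degree)  # ceiling division, as torch.chunk sizes its chunks
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]

# With CFG, the unconditional and conditional inputs are stacked, so dim 0 == 2
chunks = split_for_data_parallel(["uncond_latent", "cond_latent"], dp_degree=2, compile_dim0=1)
print(chunks)  # [['uncond_latent'], ['cond_latent']]
```

Each replica therefore receives exactly the compile-time batch of 1, which is why the compiled U-Net can be reused without recompiling.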

## 3.3 Dynamic batching
[Dynamic batching](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html#dynamic-batching) enables 
a compiled Neuron model to be called with variable-sized batches. When it is disabled, input tensor shapes must fulfill the condition detailed in section 3.2, otherwise a runtime error is raised. 
Dynamic batching makes it possible to consume input tensors that do not respect this condition without recompiling.

In the present case, dynamic batching would make it possible to use (possibly variable) batch sizes and/or numbers of images per prompt greater than 1 (the compile-time value). However, the property 
must be enabled for all compiled models.

Dynamic batching can be enabled [at compile time](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html#dynamic-batching) 
(disabled by default) or, only when using the PyTorch Neuron DataParallel API, enabled or disabled [at load time](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-data-parallel.html) 
(enabled by default).
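Conceptually, a model compiled for a fixed batch size B can serve a batch of any size N by being run on successive slices of size B, with the last slice padded. The sketch below illustrates only this idea, under that assumption. It does not reproduce the Neuron runtime's implementation, and both functions are hypothetical:

```python
from typing import Callable, List

def run_with_dynamic_batching(fixed_batch_fn: Callable[[List], List], batch: List,
                              compile_batch_size: int, pad_value=None) -> List:
    """Hypothetical illustration: call a fixed-batch-size function on a batch of any size."""
    outputs = []
    for start in range(0, len(batch), compile_batch_size):
        piece = batch[start:start + compile_batch_size]
        n_valid = len(piece)
        piece = piece + [pad_value] * (compile_batch_size - n_valid)  # pad the last slice
        result = fixed_batch_fn(piece)
        outputs.extend(result[:n_valid])  # drop outputs computed on padding
    return outputs

def compiled_for_batch_2(items: List) -> List:
    """Hypothetical 'compiled' function that only accepts batches of exactly 2."""
    assert len(items) == 2, "shape mismatch: compiled for batch size 2"
    return [None if x is None else x * 10 for x in items]

print(run_with_dynamic_batching(compiled_for_batch_2, [1, 2, 3], compile_batch_size=2))  # [10, 20, 30]
```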

In [None]:
# Notice: Text encoder kwargs are discarded since 1) Neuron

def text_encoder_runtime_decorator_factory(neuron_model: torch.jit.ScriptModule) -> Callable[[CLIPTextModel], CLIPTextModel]:
    def text_encoder_decorator(model: CLIPTextModel) -> CLIPTextModel:
        def neuron_forward_method(*args, **kwargs) -> BaseModelOutputWithPooling:
            input_ids, *_ = args # Ensures that only a single (the first) arg is supplied to the compiled model, kwargs are discarded
            last_hidden_state, pooler_output = neuron_model(input_ids)
            return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output)
        model.forward = neuron_forward_method
        return model
    return text_encoder_decorator


def unet_runtime_decorator_factory(neuron_model: torch_neuronx.xla_impl.data_parallel.DataParallel) -> Callable[[UNet2DConditionModel], UNet2DConditionModel]:
    def unet_decorator(model: UNet2DConditionModel) -> UNet2DConditionModel:
        def neuron_forward_method(*args, **kwargs) -> UNet2DConditionOutput:
            sample, timestep, *_ = args
            dim_0 = sample.shape[0]
            if isinstance(timestep, (int, float)):
                timestep = torch.Tensor([timestep]*dim_0).type(DTYPE)
            else:
                # Can either be tensor([999.]) (1-dim) or tensor(999.) (0-dim)
                timestep = timestep.type(DTYPE).repeat(dim_0)
            encoder_hidden_states = kwargs["encoder_hidden_states"]
            output_sample = neuron_model(sample, timestep, encoder_hidden_states)
            return UNet2DConditionOutput(sample=output_sample)
        model.forward = neuron_forward_method
        return model
    return unet_decorator


def vision_model_runtime_decorator_factory(neuron_model: torch.jit.ScriptModule) -> Callable[[CLIPVisionModel], CLIPVisionModel]:
    def vision_model_checker_decorator(model: CLIPVisionModel) -> CLIPVisionModel:
        def neuron_forward_method(*args, **kwargs) -> BaseModelOutputWithPooling:
            pixel_values, *_ = args # Ensures that only a single (the first) arg is supplied to the compiled model
            last_hidden_state, pooler_output = neuron_model(pixel_values)
            return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output)
        model.forward = neuron_forward_method
        return model
    return vision_model_checker_decorator
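The U-Net adaptor normalizes the timestep because the compiled U-Net was traced with one timestep per sample, i.e. a tensor of shape `(BATCH_SIZE*NUM_IMAGES_PER_PROMPT,)`. The scheduler, however, may hand over a Python number or a single-value tensor. Below is a list-based sketch of that normalization. `normalize_timestep` is a hypothetical helper, Python lists stand in for tensors, and the 0-dim tensor case is therefore not modelled:

```python
from typing import List, Sequence, Union

def normalize_timestep(timestep: Union[int, float, Sequence[float]], dim_0: int) -> List[float]:
    """Hypothetical sketch: expand a timestep to one value per row of `sample`."""
    if isinstance(timestep, (int, float)):
        # Python scalar: one copy per sample
        return [float(timestep)] * dim_0
    # Single-value sequence, repeated like `timestep.repeat(dim_0)` on a 1-element tensor
    return [float(t) for t in timestep] * dim_0

print(normalize_timestep(999, dim_0=2))      # [999.0, 999.0]
print(normalize_timestep([999.0], dim_0=2))  # [999.0, 999.0]
```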

In [None]:
# ---- 1. Load the StableDiffusionPipeline ----
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, **PIPELINE_LOADING_CONFIG)

# ---- 2. Load the serialized compiled models ----
# The VAE compiled models and the safety checker's visual projection compiled model can be directly used by the StableDiffusionPipeline
pipe.vae.decoder = torch.jit.load(NEURON_COMPILER_OUTPUT_DIR / "vae_decoder.pt")
pipe.vae.post_quant_conv = torch.jit.load(NEURON_COMPILER_OUTPUT_DIR / "vae_post_quant_conv.pt")

pipe.safety_checker.visual_projection = torch.jit.load(NEURON_COMPILER_OUTPUT_DIR / "safety_checker_visual_projection.pt")

text_encoder_neuron = torch.jit.load(NEURON_COMPILER_OUTPUT_DIR / "text_encoder.pt")
unet_neuron = torch.jit.load(NEURON_COMPILER_OUTPUT_DIR / "unet.pt")
safety_checker_vision_neuron = torch.jit.load(NEURON_COMPILER_OUTPUT_DIR / "safety_checker_vision_model.pt")
# Since lazy loading has been enabled for all compiled models, nothing has been allocated on the NeuronCores at this point

# ---- 3. Adapt the compiled models to the StableDiffusionPipeline ----
ensure_text_encoder_forward_neuron_runnable = text_encoder_runtime_decorator_factory(neuron_model=text_encoder_neuron)
pipe.text_encoder = ensure_text_encoder_forward_neuron_runnable(pipe.text_encoder)
ensure_vision_model_forward_neuron_runnable = vision_model_runtime_decorator_factory(neuron_model=safety_checker_vision_neuron)
pipe.safety_checker.vision_model = ensure_vision_model_forward_neuron_runnable(pipe.safety_checker.vision_model)

# Since the U-Net is the pipeline's bottleneck, you apply data parallelism to the U-Net by creating two model replicas, each 
# on its own NeuronCore
unet_neuron = torch_neuronx.DataParallel(unet_neuron, device_ids=[0, 1], set_dynamic_batching=False)
ensure_unet_forward_neuron_runnable = unet_runtime_decorator_factory(neuron_model=unet_neuron)
pipe.unet = ensure_unet_forward_neuron_runnable(pipe.unet)

# ---- 4. Trigger the loading of Neuron models with a warmup generation ----
generation_config = {**PIPELINE_GENERATION_CONFIG, "num_inference_steps": 1}
pipe("a photo of an astronaut riding a horse on mars", **generation_config);

***Notice:*** When loading multiple models like here, the default behavior of the Neuron Runtime is to automatically distribute the models evenly across all available NeuronCores. 
The Neuron Runtime places each model on the NeuronCore that has the fewest models loaded to it. Multiple models can therefore automatically be placed on the same NeuronCore. 
However, a NeuronCore can only execute one model at a time.
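The placement rule in the notice can be simulated to see how models end up sharing NeuronCores. This is a simplified illustration with a few assumptions. Ties are broken by the lowest core index, and the loading order is hypothetical. The U-Net replicas are left out because `DataParallel` pins them to `device_ids=[0, 1]`:

```python
from typing import Dict, List

def place_models(model_names: List[str], num_cores: int) -> Dict[str, int]:
    """Simplified least-loaded placement, for illustration only."""
    loads = [0] * num_cores
    placement = {}
    for name in model_names:
        core = min(range(num_cores), key=lambda c: loads[c])  # ties -> lowest core index (assumed)
        placement[name] = core
        loads[core] += 1
    return placement

models = ["text_encoder", "vae_post_quant_conv", "vae_decoder",
          "safety_checker_vision_model", "safety_checker_visual_projection"]
print(place_models(models, num_cores=2))
```

With two NeuronCores, several models land on the same core. Because a core executes one model at a time, models that run in sequence in the pipeline can share a core without blocking each other much.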

# 4. Run the Stable Diffusion pipeline

In [None]:
# Optional, for reproducibility, see also:
#   - https://huggingface.co/docs/diffusers/using-diffusers/reusing_seeds
#   - https://huggingface.co/docs/diffusers/using-diffusers/reproducibility

DETERMINISTIC_GENERATION_ENABLED = True
if DETERMINISTIC_GENERATION_ENABLED:
    generators = [torch.Generator().manual_seed(i) for i in range(NUM_IMAGES_PER_PROMPT)]
    PIPELINE_GENERATION_CONFIG.update({"generator": generators})
elif "generator" in PIPELINE_GENERATION_CONFIG:
    PIPELINE_GENERATION_CONFIG.pop("generator")

In [None]:
prompts = [
    # "Positive" prompt only
    ("a photo of an astronaut riding a horse on mars", None),
    ("sonic on the moon", None),
    ("elvis playing guitar while eating a hotdog", None),
    ("saved by the bell", None),
    ("engineers eating lunch at the opera", None),
    # "Positive" & negative prompts
    ("panda eating bamboo on an aircraft", "lowres, out of frame, bad face"),
    ("a digital illustration of a steampunk flying machine in the sky with cogs and mechanisms, 4k, detailed, trending in artstation, fantasy vivid colors", "dark, out of frame, photorealistic"),
    ("kids playing soccer at the FIFA World Cup", "disfigured, poorly drawn face, oversaturated"),
    ("a smiling tomato, cartoon style", "distorted, faded"),
    ("ad for a burger, burger in the center, eiffel tower in the background", "underexposed, boring background"),
    ]

generation_times = []
for prompt, negative_prompt in prompts:
    start_time = datetime.datetime.now()
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, **PIPELINE_GENERATION_CONFIG).images[0]
    generation_times.append((datetime.datetime.now() - start_time).total_seconds())
    plt.imshow(np.asarray(image))
    # Axis labels belong to the current figure, so set them for every new figure
    plt.xlabel("Width (pixels)")
    plt.ylabel("Height (pixels)")
    plt.show()

total_runtime, mean_runtime, std_runtime = sum(generation_times), np.mean(generation_times), np.std(generation_times)
print(f"Runtimes (in seconds) - Total: {total_runtime:3.3f} - Mean {mean_runtime:3.3f} (std: {std_runtime:3.3f})")

## 4.1 Runtime performance
Runtime performance is assessed both in terms of latency (P95 or other relevant quantile of the distribution of generation 
times) and throughput (number of generated images per second).
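Both metrics can be derived from a list like `generation_times` collected above. Here is a minimal sketch using only the standard library. It assumes sequential calls, with each call producing `images_per_call` images, and the helper name is hypothetical. Throughput measured this way reflects sequential generation, not concurrent serving:

```python
import statistics
from typing import Dict, List

def summarize_generation_times(generation_times: List[float], images_per_call: int = 1) -> Dict[str, float]:
    """Hypothetical helper: latency quantiles and throughput from per-call times (seconds)."""
    # quantiles(n=100) returns the 99 cut points P1..P99; index 94 is P95
    p95 = statistics.quantiles(generation_times, n=100, method="inclusive")[94]
    throughput = images_per_call * len(generation_times) / sum(generation_times)
    return {"p50": statistics.median(generation_times), "p95": p95, "images_per_second": throughput}

# Made-up timings, for illustration only
summary = summarize_generation_times([2.0, 2.1, 2.2, 2.0, 2.5])
print(summary)
```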

Performance numbers must be communicated together with the performance-impacting configurations. In the case of Stable 
Diffusion pipelines, these parameters include:
* Number of denoising steps (`num_inference_steps`)
* Number of images generated for each prompt (`num_images_per_prompt`)
* Batch size
* Data type (including type casting options of the Neuron Compiler, e.g. matrix multiplications cast to BF16)
* Whether a safety checker model is included since it can technically be omitted
* Whether classifier-free guidance (CFG) is enabled (`guidance_scale` value)
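Several of these parameters multiply together into the amount of U-Net work per pipeline call. Here is a rough sketch, assuming the `diffusers` `StableDiffusionPipeline` behavior: CFG is enabled when `guidance_scale > 1.0`, and the U-Net then runs on a doubled batch at every denoising step. The helper name is hypothetical:

```python
def unet_rows_per_generation(num_inference_steps: int, batch_size: int = 1,
                             num_images_per_prompt: int = 1, guidance_scale: float = 7.5) -> int:
    """Hypothetical estimate: latent rows processed by the U-Net over one pipeline call."""
    cfg_enabled = guidance_scale > 1.0  # condition used by diffusers' StableDiffusionPipeline
    rows_per_step = batch_size * num_images_per_prompt * (2 if cfg_enabled else 1)
    return num_inference_steps * rows_per_step

print(unet_rows_per_generation(50))                      # 100
print(unet_rows_per_generation(50, guidance_scale=1.0))  # 50
```

This is why results obtained with different step counts or guidance settings are not directly comparable.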

More performance numbers for vision models on Inf2 are available on the [Inf2 Performance page of the Neuron docs](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/benchmarks/inf2/inf2-performance.html#vision-models-inference-performance).

## 4.2 Disabling classifier-free guidance
The compiled models can also be used with classifier-free guidance (CFG) disabled, i.e. for `guidance_scale` values less than or equal to `1.0`. Negative prompts should not be supplied in that case since they are ignored. However, 
the data parallelism applied to the U-Net then loses part of its benefit: the U-Net input batch is no longer doubled, so there is less work to split across the two replicas.
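The CFG combination step shows why negative prompts stop mattering. In `diffusers`, the negative prompt is encoded as the unconditional input. The final noise prediction is `noise_uncond + guidance_scale * (noise_text - noise_uncond)`, which reduces to the text-conditioned prediction at `guidance_scale = 1.0`, so the pipeline skips the unconditional branch entirely at that value and below. Below is an element-wise sketch on plain lists, with toy values chosen for illustration:

```python
from typing import List

def classifier_free_guidance(noise_uncond: List[float], noise_text: List[float],
                             guidance_scale: float) -> List[float]:
    """Element-wise noise_uncond + guidance_scale * (noise_text - noise_uncond)."""
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]

uncond, text = [0.5, -1.0], [1.5, 1.0]
print(classifier_free_guidance(uncond, text, guidance_scale=1.0))  # [1.5, 1.0], i.e. noise_text
print(classifier_free_guidance(uncond, text, guidance_scale=7.5))  # [8.0, 14.0]
```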

In [None]:
%%time
image = pipe(
    prompt="An erupting volcano, Claude Monet, impressionist", 
    **{**PIPELINE_GENERATION_CONFIG, "guidance_scale": 0.0}
    ).images[0]
plt.imshow(np.asarray(image))
plt.show()

# 5. Clean up

In [None]:
HF_CACHE_DIR = Path.home() / ".cache" / "huggingface" / "hub"
shutil.rmtree(HF_CACHE_DIR / "--".join(["models", *MODEL_ID.split("/")]))
shutil.rmtree(NEURON_COMPILER_WORKDIR)
shutil.rmtree(NEURON_COMPILER_OUTPUT_DIR)
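The first `rmtree` call relies on the folder naming of the Hugging Face Hub cache, where a model repository `org/name` is stored under `models--org--name`. Here is a small sketch of that path construction. The helper name is hypothetical, and it assumes the default cache location. If `HF_HOME` or `HF_HUB_CACHE` is set, the cache lives elsewhere:

```python
from pathlib import Path

def hf_hub_cache_repo_dir(model_id: str, cache_dir: Path) -> Path:
    """Hypothetical helper: folder used by the Hub cache for a model repo (models--{org}--{name})."""
    return cache_dir / "--".join(["models", *model_id.split("/")])

print(hf_hub_cache_repo_dir("runwayml/stable-diffusion-v1-5", Path("/cache")).name)
# models--runwayml--stable-diffusion-v1-5
```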