## HuggingFace Stable Diffusion XL 1.0 Base and Refiner (1024x1024) Inference on Inf2

**Introduction**

This notebook demonstrates how to compile and run the HuggingFace Stable Diffusion XL (1024x1024) model, both refiner and base, for accelerated inference on Neuron.

This Jupyter notebook should be run on an Inf2 instance (`inf2.8xlarge` or larger).

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx). You can select the kernel from the 'Kernel -> Change Kernel' option on the top of this Jupyter notebook page.

**Install Dependencies**

This tutorial requires the following pip packages to be installed:
- `torch-neuronx`
- `neuronx-cc`
- `diffusers==0.29.2`
- `transformers==4.42.3`
- `accelerate==0.31.0`
- `matplotlib`
- `safetensors==0.5.3`

`torch-neuronx` and `neuronx-cc` will be installed when you configure your environment following the Inf2 setup guide. The remaining dependencies can be installed below:

In [None]:
%env TOKENIZERS_PARALLELISM=True #Supresses tokenizer warnings making errors easier to detect
!pip install diffusers==0.29.2 transformers==4.42.3 accelerate==0.31.0 safetensors==0.5.3 matplotlib

**Imports**

In [None]:
import os
 
import numpy as np
import torch
import torch.nn as nn
import torch_neuronx
import diffusers
from diffusers import DiffusionPipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.attention_processor import Attention
 
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import time
import copy
from IPython.display import clear_output

try:
    from neuronxcc.nki._private_kernels.attention import attention_isa_kernel  # noqa: E402
except ImportError:
    from neuronxcc.nki.kernels.attention import attention_isa_kernel  # noqa: E402
from torch_neuronx.xla_impl.ops import nki_jit  # noqa: E402
import math
import torch.nn.functional as F
from typing import Optional

# JIT-wrap the NKI attention kernel so it can be invoked directly with torch tensors;
# called by attention_wrapper_without_swap below.
_flash_fwd_call = nki_jit()(attention_isa_kernel)
def attention_wrapper_without_swap(query, key, value):
    """Run attention through the NKI ``AttentionMMSoftmaxMMWithoutSwap`` kernel.

    Args:
        query: 4-D tensor unpacked as (batch, heads, q_len, head_dim).
        key: 4-D tensor; dim 2 is the key length. Assumed to share batch/heads/head_dim
            with ``query`` -- TODO confirm for every caller.
        value: 4-D tensor; dim 2 is the value length.

    Returns:
        bfloat16 tensor of shape (batch, heads, q_len, head_dim), filled by the kernel
        using a softmax scale of ``1 / sqrt(head_dim)``.
    """
    batch, heads, q_len, head_dim = query.shape
    flat = batch * heads

    # The kernel takes Q and K with head_dim ahead of the sequence axis: (B*H, D, S).
    q_t = query.clone().permute(0, 1, 3, 2).reshape((flat, head_dim, q_len))
    k_t = key.clone().permute(0, 1, 3, 2).reshape((flat, head_dim, key.shape[2]))
    # V stays sequence-major: (B*H, S, D).
    v_flat = value.clone().reshape((flat, value.shape[2], head_dim))

    # Output buffer handed to the kernel, which writes the result into it.
    out = torch.zeros((flat, q_len, head_dim), dtype=torch.bfloat16, device=q_t.device)
    _flash_fwd_call(q_t, k_t, v_flat, 1 / math.sqrt(head_dim), out, kernel_name="AttentionMMSoftmaxMMWithoutSwap")

    return out.reshape((batch, heads, q_len, head_dim))
class KernelizedAttnProcessor2_0:
    r"""
    Attention processor that sends eligible attention products to the NKI flash-attention kernel.

    The control flow mirrors diffusers' ``AttnProcessor2_0``:
    optional spatial/group norm -> Q/K/V projection -> split into heads -> attention ->
    output projection -> optional residual -> rescale.

    The attention step calls ``attention_wrapper_without_swap`` only when all of these hold:
    - there is no attention mask;
    - ``head_dim <= 128`` and ``head_dim <= q_len``;
    - the key/value length is not 77, which is the ``encoder_hidden_states`` length used
      for tracing in this notebook, i.e. the cross-attention case.
    Everything else falls back to ``F.scaled_dot_product_attention``.

    NOTE: ``NeuronUNet.__init__`` copies only ``__call__`` onto diffusers' ``AttnProcessor2_0``.
    ``__call__`` therefore must not rely on ``self`` having any other attribute or helper
    of this class.
    """

    def __init__(self):
        has_sdpa = hasattr(F, "scaled_dot_product_attention")
        if has_sdpa:
            return
        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply attention for one diffusers ``Attention`` module.

        Args:
            attn: the ``Attention`` module supplying projections, norms and flags.
            hidden_states: query-side input, either 3-D or 4-D ``(batch, channel, height, width)``.
                A 4-D input is flattened to ``(batch, height*width, channel)`` internally.
            encoder_hidden_states: key/value-side input. Defaults to ``hidden_states``
                (self-attention).
            attention_mask: optional mask, prepared by ``attn`` and viewed as
                ``(batch, heads, -1, key_len)``.
            temb: forwarded to ``attn.spatial_norm`` when that norm is present.
            *args, **kwargs: accepted for compatibility. Extra positional args or a ``scale``
                kwarg only trigger ``diffusers.utils.deprecate``.

        Returns:
            A 4-D input comes back as ``(batch, channel, height, width)``; a 3-D input as
            ``(batch, seq, features)``.
        """
        scale_was_passed = kwargs.get("scale", None) is not None
        if args or scale_was_passed:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            diffusers.utils.deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_image_input = input_ndim == 4
        if is_image_input:
            batch_size, channel, height, width = hidden_states.shape
            # (B, C, H, W) -> (B, H*W, C)
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention wants the mask as (batch, heads, source_len, target_len)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = key.shape[-1] // attn.heads
        # (B, S, heads*head_dim) -> (B, heads, S, head_dim)
        query, key, value = (
            t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value)
        )

        # TODO: add support for attn.scale when we move to Torch 2.1
        q_len, q_head_dim = query.shape[2], query.shape[3]
        kv_len = value.shape[2]
        use_sdpa_fallback = (
            attention_mask is not None
            or q_head_dim > q_len
            or q_head_dim > 128
            or kv_len == 77
        )
        if use_sdpa_fallback:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            hidden_states = attention_wrapper_without_swap(query, key, value)

        # (B, heads, S, head_dim) -> (B, S, heads*head_dim), back in the projection dtype
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if is_image_input:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor

# Clear any output this cell has produced so far.
clear_output(wait=False)

**Define utility classes and functions**

The following section defines some utility classes and functions. In particular, we define a double-wrapper for the UNet. These wrappers enable `torch_neuronx.trace` to trace the wrapped models for compilation with the Neuron compiler. In addition, the `get_attention_scores_neuron` utility function performs optimized attention score calculation. It is used to replace the original `get_attention_scores` function in the `diffusers` package via a monkey patch (see the "Compile UNet (base) and save" step in the next code block for usage).

In [None]:
def get_attention_scores_neuron(self, query, key, attn_mask):
    """Replacement for diffusers' ``Attention.get_attention_scores``.

    Returns ``softmax(query @ key^T * self.scale + attn_mask)`` taken over the key axis.

    When ``query`` and ``key`` have identical shapes (self-attention), the product is
    formed the other way round (``key @ query^T``). The softmax is then taken over dim 1
    and the result is permuted back. This gives the same probabilities as the direct path,
    presumably with a layout that is friendlier to Neuron.

    Args:
        self: the ``Attention`` module; only ``self.scale`` is read.
        query: 3-D tensor, assumed (batch*heads, q_len, head_dim) as in diffusers.
        key: 3-D tensor, assumed (batch*heads, k_len, head_dim).
        attn_mask: optional additive mask, broadcastable to (batch*heads, q_len, k_len),
            or None. The previous version silently ignored it.

    Returns:
        Attention probabilities of shape (batch*heads, q_len, k_len).
    """
    if query.size() == key.size():
        # scores_t[b, j, i] = <key_j, query_i> * scale  ->  (B, k_len, q_len)
        scores_t = torch.bmm(key, query.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            # The mask is laid out (.., q_len, k_len); match the swapped score layout.
            scores_t = scores_t + attn_mask.transpose(-1, -2)
        # Softmax over the key axis (dim 1 in this layout), then back to (B, q_len, k_len).
        return scores_t.softmax(dim=1).permute(0, 2, 1)

    scores = torch.bmm(query, key.transpose(-1, -2)) * self.scale
    if attn_mask is not None:
        scores = scores + attn_mask
    return scores.softmax(dim=-1)
 

def custom_badbmm(a, b, scale):
    """Return ``(a @ b) * scale`` for batched 3-D tensors: a ``torch.bmm`` followed by a scalar multiply."""
    return torch.bmm(a, b) * scale
 

class UNetWrap(nn.Module):
    """Adapter giving a diffusers UNet a flat, positional-only forward signature for tracing.

    The SDXL extra conditioning tensors are taken as separate arguments and packed back
    into the ``added_cond_kwargs`` dict the UNet expects. The UNet is always called with
    ``return_dict=False``.
    """

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None):
        conditioning = dict(text_embeds=text_embeds, time_ids=time_ids)
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            added_cond_kwargs=conditioning,
            return_dict=False,
        )
    
    
class NeuronUNet(nn.Module):
    """Pipeline-facing stand-in for the UNet that forwards to a ``UNetWrap`` (or its traced replacement).

    Copies ``config``, ``in_channels``, ``add_embedding`` and ``device`` from the wrapped UNet
    (presumably because the diffusers pipeline reads them from ``pipe.unet``). Returns a
    ``UNet2DConditionOutput``.

    Side effect: constructing an instance globally replaces ``AttnProcessor2_0.__call__``
    with ``KernelizedAttnProcessor2_0.__call__``.
    """

    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        inner_unet = unetwrap.unet
        self.config = inner_unet.config
        self.in_channels = inner_unet.in_channels
        self.add_embedding = inner_unet.add_embedding
        self.device = inner_unet.device
        # Route every AttnProcessor2_0 through the NKI-kernel processor before tracing.
        diffusers.models.attention_processor.AttnProcessor2_0.__call__ = KernelizedAttnProcessor2_0.__call__

    def forward(self, sample, timestep, encoder_hidden_states, timestep_cond=None, added_cond_kwargs=None, return_dict=False, cross_attention_kwargs=None):
        """Run the wrapped UNet.

        ``added_cond_kwargs`` must provide ``"text_embeds"`` and ``"time_ids"``.
        ``timestep_cond``, ``return_dict`` and ``cross_attention_kwargs`` are accepted for
        signature compatibility but ignored.
        """
        batch = sample.shape[0]
        # Cast to float and broadcast one timestep per sample.
        timesteps = timestep.float().expand((batch,))
        noise_pred = self.unetwrap(
            sample,
            timesteps,
            encoder_hidden_states,
            added_cond_kwargs["text_embeds"],
            added_cond_kwargs["time_ids"],
        )[0]
        return UNet2DConditionOutput(sample=noise_pred)

    

**Compile the model into an optimized TorchScript and save the TorchScript**

In the following section, we will compile parts of the Stable Diffusion pipeline for execution on Neuron. Note that this only needs to be done once: after you have compiled and saved the model by running the following section of code, you can reuse it any number of times without having to recompile. In particular, we will compile:
1. The base's UNet
2. The refiner's UNet
3. The VAE decoder (shared between base & refiner)
4. The VAE post_quant_conv (shared between base & refiner)

The text encoders are not compiled in this notebook and stay on CPU (note that text_encoder_2 is shared between base & refiner).
These blocks are chosen because they represent the bulk of the compute in the pipeline, and performance benchmarking has shown that running them on Neuron yields significant performance benefit.

Several points worth noting are:
1. In order to save RAM (these compiles need lots of RAM!), before tracing each model, we make a deepcopy of the part of the pipeline (i.e. the UNet or the VAE decoder) that is to be traced, and then delete the pipeline object from memory with `del pipe`. This trick allows the compile to succeed on instance types with a smaller amount of RAM.
2. When compiling each part of the pipeline, we need to pass `torch_neuronx.trace` sample input(s). When there are multiple inputs, they are passed together as a tuple. For details on how to use `torch_neuronx.trace`, please refer to our documentation here: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html
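3. Note that while compiling the UNets, we make use of the double-wrapper structures defined above. In addition, we use the optimized `get_attention_scores_neuron` function to replace the original `get_attention_scores` function in the `Attention` class prior to UNet compilation. Constructing `NeuronUNet` also patches diffusers' `AttnProcessor2_0` to use `KernelizedAttnProcessor2_0`, which routes eligible attention calls to a Neuron NKI flash-attention kernel.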

In [None]:
COMPILER_WORKDIR_ROOT = 'sdxl_base_and_refiner_compile_dir_1024'

# Model ID for SD XL version pipeline
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Compiler flags shared by the UNets and the VAE decoder.
unet_compiler_args = ["--model-type=unet-inference"]


def _trace_and_save(model, example_inputs, subdir, compiler_args=None):
    """Trace ``model`` for Neuron and save the resulting TorchScript.

    Args:
        model: module to compile.
        example_inputs: sample tensor, or tuple of tensors, passed to ``torch_neuronx.trace``.
        subdir: sub-directory of ``COMPILER_WORKDIR_ROOT``. It is used as the compiler work
            dir, and the TorchScript is written to
            ``<COMPILER_WORKDIR_ROOT>/<subdir>/model.pt``.
        compiler_args: optional extra compiler arguments. They are left out of the
            ``trace`` call entirely when None.

    Returns:
        Path of the saved ``model.pt``. The traced module is not returned, so the caller
        does not keep it alive.
    """
    trace_kwargs = {"compiler_workdir": os.path.join(COMPILER_WORKDIR_ROOT, subdir)}
    if compiler_args is not None:
        trace_kwargs["compiler_args"] = compiler_args
    traced = torch_neuronx.trace(model, example_inputs, **trace_kwargs)
    filename = os.path.join(COMPILER_WORKDIR_ROOT, subdir, 'model.pt')
    torch.jit.save(traced, filename)
    return filename


# --- Compile UNet (base) and save ---
pipe_base = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)

# Replace original cross-attention module with custom cross-attention module for better performance
Attention.get_attention_scores = get_attention_scores_neuron

# Apply double wrapper to deal with custom return type
# (constructing NeuronUNet also patches AttnProcessor2_0 to use the kernelized processor)
pipe_base.unet = NeuronUNet(UNetWrap(pipe_base.unet))

# Only keep the model being compiled in RAM to minimize memory pressure
unet_base = copy.deepcopy(pipe_base.unet.unetwrap)
del pipe_base

# Compile base unet - fp32
sample_1b = torch.randn([1, 4, 128, 128])
timestep_1b = torch.tensor(999).float().expand((1,))
encoder_hidden_states_1b = torch.randn([1, 77, 2048])
added_cond_kwargs_1b = {"text_embeds": torch.randn([1, 1280]),
                        "time_ids": torch.randn([1, 6])}
example_inputs = (sample_1b, timestep_1b, encoder_hidden_states_1b, added_cond_kwargs_1b["text_embeds"], added_cond_kwargs_1b["time_ids"],)

unet_base_filename = _trace_and_save(unet_base, example_inputs, 'unet_base', compiler_args=unet_compiler_args)
del unet_base


# --- Compile UNet (refiner) and save ---
pipe_refiner = DiffusionPipeline.from_pretrained(refiner_model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)

# Apply double wrapper to deal with custom return type
pipe_refiner.unet = NeuronUNet(UNetWrap(pipe_refiner.unet))

# Only keep the model being compiled in RAM to minimize memory pressure
unet_refiner = copy.deepcopy(pipe_refiner.unet.unetwrap)
del pipe_refiner

# Compile refiner unet - fp32 - some shapes are different than the base model
encoder_hidden_states_refiner_1b = torch.randn([1, 77, 1280])
added_cond_kwargs_refiner_1b = {"text_embeds": torch.randn([1, 1280]),
                                "time_ids": torch.randn([1, 5])}
example_inputs = (sample_1b, timestep_1b, encoder_hidden_states_refiner_1b, added_cond_kwargs_refiner_1b["text_embeds"], added_cond_kwargs_refiner_1b["time_ids"],)

unet_refiner_filename = _trace_and_save(unet_refiner, example_inputs, 'unet_refiner', compiler_args=unet_compiler_args)
del unet_refiner


# --- Compile VAE decoder and post_quant_conv and save ---
# Both parts come from the same base pipeline. Load it once, keep only the two VAE
# sub-modules in RAM, and free the rest before compiling. (The previous version loaded
# the full pipeline twice.)
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)
decoder = copy.deepcopy(pipe.vae.decoder)
post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
del pipe

# Compile vae decoder
decoder_in = torch.randn([1, 4, 128, 128])
decoder_filename = _trace_and_save(decoder, decoder_in, 'vae_decoder', compiler_args=unet_compiler_args)
del decoder

# Compile vae post_quant_conv (no extra compiler args, as before)
post_quant_conv_in = torch.randn([1, 4, 128, 128])
post_quant_conv_filename = _trace_and_save(post_quant_conv, post_quant_conv_in, 'vae_post_quant_conv')
del post_quant_conv

**Helper function to run the base and refiner**

In [None]:
def run_refiner_and_base(base, refiner, prompt, n_steps=40, high_noise_frac=0.8):
    """Generate one image by splitting denoising between the base and refiner pipelines.

    The base pipeline is called with ``denoising_end=high_noise_frac`` and
    ``output_type="latent"``. The refiner then continues from those latents with
    ``denoising_start=high_noise_frac``. Both use ``n_steps`` inference steps.

    Returns:
        The first image produced by the refiner.
    """
    shared = {"prompt": prompt, "num_inference_steps": n_steps}
    latents = base(**shared, denoising_end=high_noise_frac, output_type="latent").images
    return refiner(**shared, denoising_start=high_noise_frac, image=latents).images[0]

**Load the saved model and run it**

Now that the model is compiled, you can reload it with any number of prompts. Note the use of the `torch_neuronx.DataParallel` API to load each UNet (base and refiner) onto two Neuron cores for data-parallel inference. Currently the UNets are the only parts of the pipeline that run data-parallel on two cores. The compiled VAE decoder and post_quant_conv run on a single Neuron core, and the remaining components (such as the text encoders) run on CPU.

Edit the Prompts below to see what you can create.

In [None]:
# --- Load all compiled models ---
COMPILER_WORKDIR_ROOT = 'sdxl_base_and_refiner_compile_dir_1024'
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"

unet_base_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet_base/model.pt')
unet_refiner_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet_refiner/model.pt')
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')

# Both UNets are replicated on the same two Neuron cores for data-parallel inference.
device_ids = [0, 1]

pipe_base = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)

# Load the compiled UNet (base) onto two neuron cores.
pipe_base.unet = NeuronUNet(UNetWrap(pipe_base.unet))
pipe_base.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_base_filename), device_ids, set_dynamic_batching=False)

# Load the compiled VAE parts onto a single neuron core.
pipe_base.vae.decoder = torch.jit.load(decoder_filename)
pipe_base.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)

# The refiner reuses the base pipeline's text_encoder_2 and its (Neuron-compiled) VAE.
pipe_refiner = DiffusionPipeline.from_pretrained(
    refiner_model_id,
    text_encoder_2=pipe_base.text_encoder_2,
    vae=pipe_base.vae,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)

# Load the compiled UNet (refiner) onto the same two neuron cores.
pipe_refiner.unet = NeuronUNet(UNetWrap(pipe_refiner.unet))
pipe_refiner.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_refiner_filename), device_ids, set_dynamic_batching=False)


# Run pipeline
prompt = ["a photo of an astronaut riding a horse on mars",
          "sonic on the moon",
          "elvis playing guitar while eating a hotdog",
          "saved by the bell",
          "engineers eating lunch at the opera",
          "panda eating bamboo on a plane",
          "A digital illustration of a steampunk flying machine in the sky with cogs and mechanisms, 4k, detailed, trending in artstation, fantasy vivid colors",
          "kids playing soccer at the FIFA World Cup"
         ]

# Define how many steps and what % of steps to be run on each expert (80/20) here
n_steps = 40
high_noise_frac = 0.8

# First do a warmup run so all the asynchronous loads can finish
image_warmup = run_refiner_and_base(pipe_base, pipe_refiner, prompt[0], n_steps, high_noise_frac)

total_time = 0
for x in prompt:
    # perf_counter is a monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    image = run_refiner_and_base(pipe_base, pipe_refiner, x, n_steps, high_noise_frac)
    total_time += time.perf_counter() - start_time

    image.save("image.png")
    # Label every figure. plt.show() finishes the current figure, so labels set once
    # before the loop would only appear on the first image.
    plt.title("Image")
    plt.xlabel("X pixel scaling")
    plt.ylabel("Y pixels scaling")
    plt.imshow(mpimg.imread("image.png"))
    plt.show()
print("Average time: ", np.round((total_time/len(prompt)), 2), "seconds")


**Now have Fun**

Uncomment the cell below to experiment interactively with different prompts.

In [None]:
# # Define how many steps and what % of steps to be run on each expert (80/20) here
# n_steps = 40
# high_noise_frac = 0.8

# user_input = ""
# print("Enter Prompt, type exit to quit")
# while user_input != "exit": 
#     total_time = 0
#     user_input = input("What prompt would you like to give?  ")
#     if user_input == "exit":
#         break
#     start_time = time.time()
#     image = run_refiner_and_base(pipe_base, pipe_refiner, user_input, n_steps, high_noise_frac)
#     total_time = total_time + (time.time()-start_time)
#     image.save("image2.png")

#     plt.title("Image")
#     plt.xlabel("X pixel scaling")
#     plt.ylabel("Y pixels scaling")

#     # Read back the file just saved (image2.png). Reading image.png would show a stale image from the previous cell.
#     image = mpimg.imread("image2.png")
#     plt.imshow(image)
#     plt.show()
#     print("time: ", np.round(total_time, 2), "seconds")
