## HuggingFace Stable Diffusion XL 1.0 (1024x1024) Inference on Inf2

**Introduction**

This notebook demonstrates how to compile and run the HuggingFace Stable Diffusion XL (1024x1024) model, with a LoRA adapter fused into the UNet, for accelerated inference on Neuron.

This Jupyter notebook should be run on an Inf2 instance (`inf2.8xlarge` or larger).

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx). You can select the kernel from the 'Kernel -> Change Kernel' option on the top of this Jupyter notebook page.

**Install Dependencies**

This tutorial requires the following pip packages to be installed:
- `torch-neuronx`
- `neuronx-cc`
- `diffusers==0.20.2`
- `transformers==4.33.1`
- `accelerate==0.22.0`
- `matplotlib`

`torch-neuronx` and `neuronx-cc` are installed when you configure your environment by following the Inf2 setup guide. Install the remaining dependencies with the cell below:

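```python
# Suppresses tokenizer warnings, making errors easier to detect.
# (The comment sits on its own line: a trailing comment after %env would become part of the value.)
%env TOKENIZERS_PARALLELISM=True
!pip install diffusers==0.20.2 transformers==4.33.1 accelerate==0.22.0 matplotlib
```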
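After the install finishes, you can confirm that the pinned versions are the ones Python will actually see before starting a long compile. This is a minimal sketch using the standard-library `importlib.metadata`; the `find_version_mismatches` helper is hypothetical and not part of the tutorial.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the dependency list above.
PINNED = {"diffusers": "0.20.2", "transformers": "4.33.1", "accelerate": "0.22.0"}


def find_version_mismatches(pins, get_version=version):
    """Return {package: installed_version_or_None} for every pin that is not satisfied."""
    mismatches = {}
    for package, wanted in pins.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_version_mismatches(PINNED)
    print(problems or "all pinned versions installed")
```

The `get_version` parameter only exists so the lookup can be swapped out, for example in a test; in the notebook the default `importlib.metadata.version` is what you want.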


**Imports**

```python
import os

import numpy as np
import torch
import torch.nn as nn
import torch_neuronx
from diffusers import DiffusionPipeline
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.attention_processor import Attention
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import time
import copy

# Tracing wrappers for the UNet and the text encoders, from the local util module
from util import NeuronUNet, UNetWrap, TextEncoderOutputWrapper, TraceableTextEncoder

from IPython.display import clear_output

clear_output(wait=False)
```

**Compile the model into an optimized TorchScript and save the TorchScript**

In the following section, we compile parts of the Stable Diffusion XL pipeline for execution on Neuron. This only needs to be done once: after you have compiled and saved the models by running the following code, you can reuse them any number of times without recompiling. In particular, we compile:

1. the two text encoders (`text_encoder` and `text_encoder_2`);
2. the VAE decoder;
3. the UNet, with the LoRA adapter fused into its weights; and
4. the VAE `post_quant_conv` layer.

These blocks were chosen because they represent the bulk of the compute in the pipeline, and performance benchmarking has shown that running them on Neuron yields a significant performance benefit.
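Because compilation only needs to happen once, one option is to skip a component's compile step when its saved `model.pt` already exists. This is a minimal sketch, not part of the notebook: `artifact_path` and `compile_or_load` are hypothetical helpers that follow the `<root>/<component>/model.pt` layout the compile cell uses.

```python
import os
from typing import Any, Callable

# Component sub-directories used by the compile cell (one model.pt per component).
COMPONENTS = ("text_encoder", "text_encoder_2", "vae_decoder", "unet", "vae_post_quant_conv")


def artifact_path(root: str, component: str) -> str:
    """Where a compiled component is saved: <root>/<component>/model.pt."""
    return os.path.join(root, component, "model.pt")


def compile_or_load(path: str,
                    compile_fn: Callable[[], Any],
                    save_fn: Callable[[Any, str], None],
                    load_fn: Callable[[str], Any]) -> Any:
    """Load a saved artifact if it exists; otherwise compile it once and save it."""
    if os.path.exists(path):
        return load_fn(path)
    model = compile_fn()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    save_fn(model, path)
    return model
```

In the notebook, `save_fn` and `load_fn` would be `torch.jit.save` and `torch.jit.load`. To get the full benefit, `compile_fn` should also do the `DiffusionPipeline.from_pretrained` loading for its component, so that the expensive pipeline load is skipped too when the artifact already exists.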

Several points are worth noting:
1. To save RAM (these compiles need a lot of it!), before tracing each model we make a deep copy of the part of the pipeline that is to be traced (e.g. the UNet or the VAE decoder) and then delete the pipeline object from memory with `del pipe`. This lets the compile succeed on instance types with less RAM.
2. When compiling each part of the pipeline, we pass sample input(s) to `torch_neuronx.trace`. When there are multiple inputs, they are passed together as a tuple. For details on how to use `torch_neuronx.trace`, refer to the [torch_neuronx.trace API reference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html).
3. When compiling the UNet, we use the double-wrapper structure `NeuronUNet(UNetWrap(...))` imported from `util`. We also replace the original `get_attention_scores` method of the `Attention` class with the optimized `get_attention_scores_neuron` function. This function is not defined in this notebook, so it must be available in the session before the compile cell runs.
4. Before the UNet is wrapped and traced, the LoRA weights in `lora_weight_dir` are loaded with `pipe.load_lora_weights` and merged into the base model with `pipe.fuse_lora()`, so the compiled UNet already contains the adapter.
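The copy-then-delete pattern from point 1 can be factored into a small helper. This is a minimal sketch assuming the pipeline is produced by a zero-argument loader; `extract_submodule` is hypothetical, and the explicit `gc.collect()` is an addition the notebook itself does not make.

```python
import copy
import gc
from typing import Any, Callable


def extract_submodule(load_pipeline: Callable[[], Any], attr_path: str) -> Any:
    """Load a pipeline, deep-copy one sub-module (e.g. "vae.decoder"), and free the rest.

    Mirrors the copy.deepcopy(...) + `del pipe` pattern in the compile cell:
    once this returns, only the copied sub-module is still referenced here.
    """
    pipe = load_pipeline()
    target = pipe
    for name in attr_path.split("."):
        target = getattr(target, name)  # walk e.g. pipe -> vae -> decoder
    submodule = copy.deepcopy(target)
    del pipe, target  # drop our references to the full pipeline
    gc.collect()      # ask Python to reclaim unreachable objects right away
    return submodule
```

In the notebook this could be called as `extract_submodule(lambda: DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True), "vae.decoder")`.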

```python
COMPILER_WORKDIR_ROOT = 'sdxl_compile_dir_1024'

# Model ID for the SD XL base pipeline
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# Folder containing the LoRA adapter (pytorch_lora_weights.safetensors)
lora_weight_dir = "lora"

# --- Compile text encoders and save ---

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)

# Apply wrappers to make the text encoders traceable
traceable_text_encoder = copy.deepcopy(TraceableTextEncoder(pipe.text_encoder))
traceable_text_encoder_2 = copy.deepcopy(TraceableTextEncoder(pipe.text_encoder_2))

del pipe

# Sample token ids (77 tokens) used for tracing: start-of-text (49406), two prompt tokens,
# end-of-text (49407), then padding. text_encoder pads with 49407; text_encoder_2 pads with 0.
text_input_ids_1 = torch.tensor([[49406, 736, 1615] + [49407] * 74])
text_input_ids_2 = torch.tensor([[49406, 736, 1615, 49407] + [0] * 73])

# Text encoder 1
neuron_text_encoder = torch_neuronx.trace(
    traceable_text_encoder,
    text_input_ids_1,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder'),
)

text_encoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder/model.pt')
torch.jit.save(neuron_text_encoder, text_encoder_filename)

# Text encoder 2
neuron_text_encoder_2 = torch_neuronx.trace(
    traceable_text_encoder_2,
    text_input_ids_2,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder_2'),
)

text_encoder_2_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder_2/model.pt')
torch.jit.save(neuron_text_encoder_2, text_encoder_2_filename)

# Delete unused objects
del traceable_text_encoder, traceable_text_encoder_2

# --- Compile VAE decoder and save ---

# Only keep the model being compiled in RAM to minimize memory pressure
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)
decoder = copy.deepcopy(pipe.vae.decoder)
del pipe

# Compile the VAE decoder (latents are 1/8 of the 1024x1024 image resolution)
decoder_in = torch.randn([1, 4, 128, 128])
decoder_neuron = torch_neuronx.trace(
    decoder,
    decoder_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),
)

# Save the compiled VAE decoder
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
torch.jit.save(decoder_neuron, decoder_filename)

# Delete unused objects
del decoder

# --- Compile UNet and save ---

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)

# Load the LoRA weights and fuse the adapter into the base model
pipe.load_lora_weights(lora_weight_dir, weight_name="pytorch_lora_weights.safetensors")
pipe.fuse_lora()

# Replace the original attention-score computation with the Neuron-optimized version.
# NOTE: get_attention_scores_neuron is not defined in this notebook; it must be defined
# (or imported, e.g. from util) before this cell runs, otherwise this line raises NameError.
Attention.get_attention_scores = get_attention_scores_neuron

# Apply the double wrapper to deal with the custom return type
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))

# Only keep the model being compiled in RAM to minimize memory pressure
unet = copy.deepcopy(pipe.unet.unetwrap)
del pipe

# Compile UNet - FP32
sample_1b = torch.randn([1, 4, 128, 128])
timestep_1b = torch.tensor(999).float().expand((1,))
encoder_hidden_states_1b = torch.randn([1, 77, 2048])
added_cond_kwargs_1b = {"text_embeds": torch.randn([1, 1280]),
                        "time_ids": torch.randn([1, 6])}
example_inputs = (
    sample_1b,
    timestep_1b,
    encoder_hidden_states_1b,
    added_cond_kwargs_1b["text_embeds"],
    added_cond_kwargs_1b["time_ids"],
)

unet_neuron = torch_neuronx.trace(
    unet,
    example_inputs,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),
    compiler_args=["--model-type=unet-inference"],
)

# Save the compiled UNet
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')
torch.jit.save(unet_neuron, unet_filename)

# Delete unused objects
del unet

# --- Compile VAE post_quant_conv and save ---

# Only keep the model being compiled in RAM to minimize memory pressure
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True)
post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
del pipe

# Compile the VAE post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 128, 128])
post_quant_conv_neuron = torch_neuronx.trace(
    post_quant_conv,
    post_quant_conv_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),
)

# Save the compiled VAE post_quant_conv
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')
torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)

# Delete unused objects
del post_quant_conv
```
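The fixed shapes of the sample inputs in the compile cell follow from the SDXL architecture: the VAE works on latents 8 times smaller than the image in each spatial dimension, the UNet's `encoder_hidden_states` concatenate the 768-wide hidden states of `text_encoder` with the 1280-wide hidden states of `text_encoder_2` (768 + 1280 = 2048), `text_embeds` is the 1280-wide pooled output of `text_encoder_2`, and `time_ids` packs original size, crop offset and target size (2 + 2 + 2 = 6). The sketch below derives those shapes for a given resolution; `sdxl_trace_shapes` and `default_time_ids` are hypothetical helpers, not notebook or diffusers APIs.

```python
VAE_SCALE_FACTOR = 8             # the SDXL VAE downsamples each spatial dimension by 8
LATENT_CHANNELS = 4              # channels of the VAE latent space
MAX_TOKENS = 77                  # CLIP context length for both text encoders
TEXT_ENCODER_DIMS = (768, 1280)  # hidden sizes of text_encoder and text_encoder_2


def sdxl_trace_shapes(height=1024, width=1024, batch=1):
    """Shapes of the sample inputs used to trace the SDXL UNet and VAE at a resolution."""
    if height % VAE_SCALE_FACTOR or width % VAE_SCALE_FACTOR:
        raise ValueError("height and width must be multiples of 8")
    latent = [batch, LATENT_CHANNELS, height // VAE_SCALE_FACTOR, width // VAE_SCALE_FACTOR]
    return {
        "sample": latent,  # also the input shape of the VAE decoder and post_quant_conv
        "timestep": [batch],
        # hidden states of both text encoders, concatenated on the last axis
        "encoder_hidden_states": [batch, MAX_TOKENS, sum(TEXT_ENCODER_DIMS)],
        # pooled output of text_encoder_2
        "text_embeds": [batch, TEXT_ENCODER_DIMS[1]],
        # original_size (2) + crop top-left (2) + target_size (2)
        "time_ids": [batch, 6],
    }


def default_time_ids(height=1024, width=1024):
    """time_ids for an uncropped image generated at its original size."""
    original_size, crop_top_left, target_size = (height, width), (0, 0), (height, width)
    return list(original_size + crop_top_left + target_size)
```

For 1024x1024 this reproduces the shapes traced above, for example `[1, 4, 128, 128]` for `sample` and `[1, 77, 2048]` for `encoder_hidden_states`. The compile cell fills these shapes with random tensors rather than real values.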


**Load the saved model and run it**

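Now that the models are compiled, you can reload them and generate images for any number of prompts. Note the use of the `torch_neuronx.DataParallel` API to load the UNet onto two NeuronCores for data-parallel inference: the UNet was traced with a batch size of 1, and with classifier-free guidance the pipeline runs the UNet on a batch of 2 (unconditional and conditional), which `DataParallel` splits across the two cores. Currently the UNet is the only part of the pipeline that runs data-parallel on two cores; all other parts of the pipeline run on a single NeuronCore.

Edit the prompts below to see what you can create.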

```python
# --- Load all compiled models and run the pipeline ---
COMPILER_WORKDIR_ROOT = 'sdxl_compile_dir_1024'
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
text_encoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder/model.pt')
text_encoder_2_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder_2/model.pt')
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')

pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)

# Load the compiled UNet onto two NeuronCores for data-parallel inference
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
device_ids = [0, 1]
pipe.unet.unetwrap = torch_neuronx.DataParallel(
    torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False
)

# Load the other compiled models onto a single NeuronCore
pipe.vae.decoder = torch.jit.load(decoder_filename)
pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)
pipe.text_encoder = TextEncoderOutputWrapper(torch.jit.load(text_encoder_filename), pipe.text_encoder)
pipe.text_encoder_2 = TextEncoderOutputWrapper(torch.jit.load(text_encoder_2_filename), pipe.text_encoder_2)

# Run the pipeline
prompt = """
photo of <<TOK>> pencil sketch, young and beautiful, face front, centered
"""

negative_prompt = """
ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, 
watermark, grainy, signature, cut off, draft, amateur, multiple, gross, weird, uneven, furnishing, decorating, decoration, furniture, text, poor, low, basic, worst, juvenile, 
unprofessional, failure, crayon, oil, label, thousand hands
"""

seed = 491057365
generator = [torch.Generator().manual_seed(seed)]

image = pipe(prompt,
             num_inference_steps=50,
             guidance_scale=7,
             negative_prompt=negative_prompt,
             generator=generator).images[0]
```
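Conceptually, each denoising step in the cell above runs the UNet on a batch of two: the negative-prompt (unconditional) branch and the prompt branch. The sketch below models, with plain Python lists standing in for tensors, how that batch could be divided across `device_ids = [0, 1]` and how diffusers then combines the two noise predictions using `guidance_scale`. `split_batch` and `apply_cfg` are hypothetical helpers; `split_batch` is a simplified model of `torch_neuronx.DataParallel` with `set_dynamic_batching=False`, assuming each core must receive exactly the traced batch size, and is not its actual implementation.

```python
def split_batch(batch, num_cores, compiled_batch_size=1):
    """Conceptual model of data-parallel dispatch: split dim 0 evenly across cores.

    Assumes dynamic batching is off, so every core must receive exactly the
    batch size the model was traced with.
    """
    expected = num_cores * compiled_batch_size
    if len(batch) != expected:
        raise ValueError(
            f"batch of {len(batch)} != {num_cores} cores x {compiled_batch_size} per core"
        )
    return [batch[i * compiled_batch_size:(i + 1) * compiled_batch_size]
            for i in range(num_cores)]


def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: uncond + scale * (text - uncond), element-wise."""
    if len(noise_uncond) != len(noise_text):
        raise ValueError("predictions must have the same length")
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


# One denoising step, with plain values standing in for latent / noise tensors.
uncond_chunk, text_chunk = split_batch(["negative-prompt branch", "prompt branch"], num_cores=2)
guided = apply_cfg([0.0, 1.0], [0.5, 1.0], guidance_scale=7)
print(uncond_chunk, text_chunk, guided)  # ['negative-prompt branch'] ['prompt branch'] [3.5, 1.0]
```

A single prompt without guidance would give the UNet a batch of 1, which this model of a two-core, batch-1-per-core setup would reject; that is one reason the traced batch size, the number of cores, and `guidance_scale > 1` need to agree.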




