## HuggingFace Stable Diffusion Inpainting (512x512) Inference on Inf2

**Introduction**

This notebook demonstrates how to compile and run the HuggingFace Stable Diffusion Inpainting (512x512) model for accelerated inference on Neuron.

Run this Jupyter notebook on an Inf2 instance (`inf2.8xlarge` or larger).

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx). You can select the kernel from the 'Kernel -> Change Kernel' option at the top of this Jupyter notebook page.

**Install Dependencies**

This tutorial requires the following pip packages to be installed:
- `torch-neuronx`
- `neuronx-cc`
- `diffusers==0.20.0`
- `transformers==4.26.1`
- `accelerate==0.16.0`
- `matplotlib`

`torch-neuronx` and `neuronx-cc` will be installed when you configure your environment following the Inf2 setup guide. The remaining dependencies can be installed below:

In [None]:
# Suppress tokenizer warnings so that errors are easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install diffusers==0.20.0 transformers==4.26.1 accelerate==0.16.0 matplotlib

**Imports**

In [None]:
import os
 
import numpy as np
import torch
import torch.nn as nn
import torch_neuronx
from diffusers import StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.attention_processor import Attention
 
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import time
import copy
from IPython.display import clear_output
from PIL import Image

clear_output(wait=False)

**Define utility classes and functions**

The following section defines some utility classes and functions. In particular, we define a double wrapper for the UNet. These wrappers enable `torch_neuronx.trace` to trace the wrapped model for compilation with the Neuron compiler. In addition, the `get_attention_scores_neuron` utility function performs an optimized attention score calculation and replaces the original `get_attention_scores` function in the `diffusers` package via a monkey patch (see the code under "Compile UNet and save" in the compilation cell for usage).

In [None]:
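def get_attention_scores_neuron(self, query, key, attn_mask):
    # Same-shape case (e.g. self-attention): compute the transposed scores,
    # normalize over dim=1, then permute back to (batch, query, key) order.
    if query.size() == key.size():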
        attention_scores = custom_badbmm(
            key,
            query.transpose(-1, -2),
            self.scale
        )
        attention_probs = attention_scores.softmax(dim=1).permute(0,2,1)

    else:
        attention_scores = custom_badbmm(
            query,
            key.transpose(-1, -2),
            self.scale
        )
        attention_probs = attention_scores.softmax(dim=-1)
  
    return attention_probs
 

def custom_badbmm(a, b, scale):
    bmm = torch.bmm(a, b)
    scaled = bmm * scale
    return scaled
 

class UNetWrap(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet
 
    def forward(self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None):
        out_tuple = self.unet(sample,
                              timestep,
                              encoder_hidden_states,
                              return_dict=False)
        return out_tuple
    
    
class NeuronUNet(nn.Module):
    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        self.config = unetwrap.unet.config
        self.in_channels = unetwrap.unet.in_channels
        self.device = unetwrap.unet.device
 
    def forward(self, sample, timestep, encoder_hidden_states, added_cond_kwargs=None, return_dict=False, cross_attention_kwargs=None):
        sample = self.unetwrap(sample,
                               timestep.float().expand((sample.shape[0],)),
                               encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)
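
To see why the same-shape branch of `get_attention_scores_neuron` yields the same probabilities as the standard formulation, the following sketch compares both on random CPU tensors. It uses only `torch`; the helper names `scores_standard` and `scores_transposed` are illustrative and are not part of `diffusers` or `torch_neuronx`.

```python
import torch

def scores_standard(query, key, scale):
    # softmax(Q K^T * scale), normalized over the key dimension
    return (torch.bmm(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)

def scores_transposed(query, key, scale):
    # Same math as the same-shape branch above: compute (K Q^T) * scale,
    # normalize over dim=1 (the key dimension), then swap the last two dims.
    scores = torch.bmm(key, query.transpose(-1, -2)) * scale
    return scores.softmax(dim=1).permute(0, 2, 1)

torch.manual_seed(0)
query = torch.randn(2, 16, 8)  # (batch * heads, tokens, head_dim)
key = torch.randn(2, 16, 8)
scale = 8 ** -0.5

standard = scores_standard(query, key, scale)
transposed = scores_transposed(query, key, scale)
max_abs_diff = (standard - transposed).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two paths are mathematically identical; only the layout of the intermediate tensor differs, which is what the tutorial describes as the Neuron-specific optimization.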

In [None]:
COMPILER_WORKDIR_ROOT = 'sd_inpainting_compile_dir_512'

# Model ID for the Stable Diffusion inpainting pipeline
model_id = "runwayml/stable-diffusion-inpainting"

# --- Compile VAE decoder and save ---

# Only keep the model being compiled in RAM to minimize memory pressure
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
decoder = copy.deepcopy(pipe.vae.decoder)
del pipe

# Compile vae decoder
decoder_in = torch.randn([1, 4, 64, 64])
decoder_neuron = torch_neuronx.trace(
    decoder, 
    decoder_in, 
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),
)


# Save the compiled vae decoder
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
torch.jit.save(decoder_neuron, decoder_filename)

# delete unused objects
del decoder

# --- Compile UNet and save ---

pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)

# Replace original cross-attention module with custom cross-attention module for better performance
Attention.get_attention_scores = get_attention_scores_neuron

# Apply double wrapper to deal with custom return type
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))

# Only keep the model being compiled in RAM to minimize memory pressure
unet = copy.deepcopy(pipe.unet.unetwrap)
del pipe

# Compile unet - FP32
sample_1b = torch.randn([1, 9, 64, 64])
timestep_1b = torch.tensor(999).float().expand((1,))
encoder_hidden_states_1b = torch.randn([1, 77, 768])
example_inputs = (sample_1b, timestep_1b, encoder_hidden_states_1b)

unet_neuron = torch_neuronx.trace(
    unet,
    example_inputs,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),
    compiler_args=["--model-type=unet-inference"]
)

# save compiled unet
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')
torch.jit.save(unet_neuron, unet_filename)

# delete unused objects
del unet


# --- Compile VAE post_quant_conv and save ---

# Only keep the model being compiled in RAM to minimize memory pressure
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)
del pipe

# Compile vae post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 64, 64])
post_quant_conv_neuron = torch_neuronx.trace(
    post_quant_conv, 
    post_quant_conv_in,
    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),
)

# Save the compiled vae post_quant_conv
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')
torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)

# delete unused objects
del post_quant_conv
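
The example input shapes used for tracing follow from the 512x512 resolution. The sketch below derives them in plain Python, assuming the usual Stable Diffusion v1 layout: a VAE that downsamples by a factor of 8, 4 latent channels, and an inpainting UNet whose 9 input channels are 4 noisy-latent channels, 1 mask channel, and 4 masked-image latent channels. The function name `tracing_shapes` and its parameters are hypothetical.

```python
def tracing_shapes(height=512, width=512, batch=1,
                   vae_scale_factor=8, latent_channels=4,
                   text_tokens=77, text_hidden=768):
    """Return the example input shapes used to trace each sub-model."""
    latent_h = height // vae_scale_factor
    latent_w = width // vae_scale_factor
    # Inpainting UNet input = noisy latents + 1-channel mask + masked-image latents
    unet_channels = latent_channels + 1 + latent_channels
    return {
        "vae_decoder": [batch, latent_channels, latent_h, latent_w],
        "vae_post_quant_conv": [batch, latent_channels, latent_h, latent_w],
        "unet_sample": [batch, unet_channels, latent_h, latent_w],
        "unet_timestep": [batch],
        "unet_encoder_hidden_states": [batch, text_tokens, text_hidden],
    }

shapes = tracing_shapes()
for name, shape in shapes.items():
    print(f"{name}: {shape}")
```

Because the models are traced with fixed example shapes, a different output resolution would, under these assumptions, need its own compilation run.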

**Load the saved model and run it**

Now that the model is compiled, you can reload it and run it with any number of prompts. Note the use of the `torch_neuronx.DataParallel` API to load the UNet onto two Neuron cores for data-parallel inference. Currently, the UNet is the only part of the pipeline that runs data-parallel on two cores; all other parts of the pipeline run on a single Neuron core.

Edit the prompt below to see what you can create.

In [None]:
# --- Load all compiled models ---
COMPILER_WORKDIR_ROOT = 'sd_inpainting_compile_dir_512'
model_id = "runwayml/stable-diffusion-inpainting"
decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')
unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')
post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')

pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Load the compiled UNet onto two neuron cores.
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
device_ids = [0,1]
pipe.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False)

# Load other compiled models onto a single neuron core.
pipe.vae.decoder = torch.jit.load(decoder_filename)
pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)
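
As a rough CPU-only illustration of what data-parallel execution does here: the UNet is traced with batch size 1, and with classifier-free guidance the pipeline typically sends the unconditional and text-conditioned inputs together as a batch of 2, so each of the two cores can take one row. The sketch below mimics that split with `torch.chunk`. `fake_core` is a hypothetical stand-in for one compiled UNet replica, the shards run sequentially rather than concurrently, and nothing here calls `torch_neuronx`.

```python
import torch

def fake_core(core_id, sample):
    # Hypothetical stand-in for one compiled UNet replica (traced at batch size 1).
    assert sample.shape[0] == 1, "each replica expects the compiled batch size"
    return sample[:, :4] * 0.5  # 4-channel output, shaped like a noise prediction

def data_parallel_forward(sample, num_cores=2):
    # Split along dim 0, run one shard per core, then concatenate the results.
    shards = torch.chunk(sample, num_cores, dim=0)
    outputs = [fake_core(i, shard) for i, shard in enumerate(shards)]
    return torch.cat(outputs, dim=0)

# Batch of 2: one unconditional and one text-conditioned input
batch = torch.randn(2, 9, 64, 64)
noise_pred = data_parallel_forward(batch)
print(tuple(noise_pred.shape))
```

With `set_dynamic_batching=False`, the assumption in this sketch is that the incoming batch must divide evenly into shards of the compiled batch size, which is why the batch of 2 matches the two device IDs.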


### Load Image and Load Image Mask
`image` and `mask_image` should be PIL images. In the mask, white pixels mark the region to inpaint and black pixels mark the region to keep unchanged.

See https://huggingface.co/runwayml/stable-diffusion-inpainting for sample images. 

Save the image and the image mask and modify the filenames below as needed.

In [None]:
# File names of the input image and mask to load
image_filename = 'image.png'
mask_image_filename = 'image_mask.png'

# Load images
image = Image.open(image_filename)
mask_image = Image.open(mask_image_filename)
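
If you do not have a mask at hand, one can be drawn programmatically. The sketch below is an illustrative addition rather than part of the original tutorial: it builds a 512x512 black mask with a white rectangle (the region to repaint) using Pillow. The helper name `make_box_mask` and the rectangle coordinates are assumptions chosen for the example.

```python
from PIL import Image, ImageDraw

def make_box_mask(size=(512, 512), box=(128, 128, 384, 384)):
    # "L" is 8-bit grayscale: 0 (black) keeps pixels, 255 (white) marks pixels to inpaint
    mask = Image.new("L", size, 0)
    ImageDraw.Draw(mask).rectangle(box, fill=255)
    return mask

example_mask = make_box_mask()
print(example_mask.size, example_mask.mode)
```

The resulting image can be passed as `mask_image` directly or saved with `example_mask.save(...)` and loaded like the files above.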

In [None]:
# Prompt
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"

# Run pipeline
total_time = 0
start_time = time.time()
image_output = pipe(prompt=prompt, image=image, mask_image=mask_image).images[0]
total_time = total_time + (time.time() - start_time)
image_output.save('image_output.png')
plt.title("Image")
plt.xlabel("X pixels scaling")
plt.ylabel("Y pixels scaling")
plt.imshow(image_output)
plt.show()
print("Total time: ", np.round(total_time, 2), "seconds")
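
A single timed call can be dominated by one-time warm-up costs, so it can be useful to discard the first run and average over a few more. The sketch below is a generic helper under that assumption; `benchmark` and its parameters are hypothetical names, and the `time.sleep` call stands in for `pipe(...)` so that the block runs anywhere.

```python
import time

def benchmark(fn, warmup=1, runs=3):
    """Call fn() `warmup` times untimed, then return the mean latency of `runs` timed calls."""
    for _ in range(warmup):
        fn()
    latencies = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        latencies.append(time.perf_counter() - start)
    return sum(latencies) / len(latencies)

# Stand-in workload; on an Inf2 instance this could instead be
# lambda: pipe(prompt=prompt, image=image, mask_image=mask_image)
mean_latency = benchmark(lambda: time.sleep(0.01))
print(f"mean latency: {mean_latency:.3f} seconds")
```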

**Now have Fun**

Uncomment the cell below to experiment interactively with different prompts.

In [None]:
# user_input = ""
# print("Enter Prompt, type exit to quit")
# while user_input != "exit":
#     total_time = 0
#     user_input = input("What prompt would you like to give?  ")
#     if user_input == "exit":
#         break
#     start_time = time.time()
#     image_output = pipe(prompt=user_input, image=image, mask_image=mask_image).images[0]
#     total_time = total_time + (time.time()-start_time)
#     image_output.save("image_output.png")

#     plt.title("Image")
#     plt.xlabel("X pixels scaling")
#     plt.ylabel("Y pixels scaling")

#     image_output = mpimg.imread("image_output.png")
#     plt.imshow(image_output)
#     plt.show()
#     print("time: ", np.round(total_time, 2), "seconds")
