## HuggingFace PixArt Sigma 1k resolution inference on trn2

**Introduction**

This notebook demonstrates how to compile and run the HuggingFace PixArt Sigma 1k (1024×1024) resolution model for accelerated inference on Neuron.
This Jupyter notebook should be run on a trn2 instance. This tutorial has a similar structure to the `hf_pretrained_pixart_sigma_inference_on_inf2.ipynb` notebook, so we will not repeat the explanations given there and will jump directly to the code.

In [None]:
# install dependencies
!pip install -r ./requirements.txt

In [None]:
# download the pretrained PixArt Sigma model.
!python neuron_pixart_sigma/cache_hf_model.py

In [None]:
# Show the installed Neuron-related packages and their versions (case-insensitive match).
!pip list | grep -i neuron

In [None]:
# compile the component models. 
!sh compile_latency_optimized.sh

In [None]:
# imports
from diffusers import PixArtSigmaPipeline

import neuronx_distributed
import numpy as npy
import time
import torch
import torch_neuronx

from neuron_pixart_sigma.neuron_commons import InferenceTextEncoderWrapper
from neuron_pixart_sigma.neuron_commons import InferenceTransformerWrapper
from neuron_pixart_sigma.neuron_commons import SimpleWrapper

In [None]:
# Directory holding the compiled Neuron component models
# (presumably written by compile_latency_optimized.sh — confirm).
COMPILED_MODELS_DIR = "compile_workdir_latency_optimized"
# Local HuggingFace cache the pipeline is loaded from (no download is attempted).
HUGGINGFACE_CACHE_DIR = "pixart_sigma_hf_cache_dir_1024"
# Neuron device ids the VAE decoder and post_quant_conv are replicated across.
VAE_DEVICE_IDS = (0, 1, 2, 3)


def _load_vae_data_parallel(model_path):
    """Load a TorchScript VAE component and wrap it in torch_neuronx.DataParallel.

    A fresh list of device ids is passed on every call so the wrappers never
    share a mutable list object.
    """
    # NOTE(review): the positional ``False`` is kept from the original call;
    # confirm which DataParallel parameter it maps to in the installed
    # torch_neuronx version (e.g. dynamic batching vs. dim).
    return torch_neuronx.DataParallel(
        torch.jit.load(model_path), list(VAE_DEVICE_IDS), False
    )


if __name__ == "__main__":
    # Load the reference pipeline from the local cache only; local_files_only
    # makes a missing cache an error instead of triggering a download.
    pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
        cache_dir=HUGGINGFACE_CACHE_DIR,
    )

    text_encoder_model_path = f"{COMPILED_MODELS_DIR}/text_encoder"
    transformer_model_path = f"{COMPILED_MODELS_DIR}/transformer"
    decoder_model_path = f"{COMPILED_MODELS_DIR}/decoder/model.pt"
    post_quant_conv_model_path = f"{COMPILED_MODELS_DIR}/post_quant_conv/model.pt"

    # Text sequence length given to the encoder wrapper; presumably must
    # match the length used at compile time — TODO confirm.
    seqlen = 300
    text_encoder_wrapper = InferenceTextEncoderWrapper(
        torch.bfloat16, pipe.text_encoder, seqlen
    )
    # The text encoder and transformer artifacts are directories loaded with
    # neuronx_distributed's parallel_model_load.
    text_encoder_wrapper.t = neuronx_distributed.trace.parallel_model_load(
        text_encoder_model_path
    )

    transformer_wrapper = InferenceTransformerWrapper(pipe.transformer)
    transformer_wrapper.transformer = neuronx_distributed.trace.parallel_model_load(
        transformer_model_path
    )

    # The VAE pieces are single TorchScript files replicated across devices.
    vae_decoder_wrapper = SimpleWrapper(pipe.vae.decoder)
    vae_decoder_wrapper.model = _load_vae_data_parallel(decoder_model_path)

    vae_post_quant_conv_wrapper = SimpleWrapper(pipe.vae.post_quant_conv)
    vae_post_quant_conv_wrapper.model = _load_vae_data_parallel(
        post_quant_conv_model_path
    )

    # Swap the CPU submodules for their Neuron-backed wrappers.
    pipe.text_encoder = text_encoder_wrapper
    pipe.transformer = transformer_wrapper
    pipe.vae.decoder = vae_decoder_wrapper
    pipe.vae.post_quant_conv = vae_post_quant_conv_wrapper

    # Run pipeline. One shared kwargs dict keeps the warmup and the real run
    # identical.
    prompt = "a photo of an astronaut riding a horse on mars"
    negative_prompt = "mountains"
    generation_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_images_per_prompt=1,
        height=1024,
        width=1024,
        num_inference_steps=25,
    )

    # First do a warmup run so all the asynchronous loads can finish; its
    # output is discarded.
    pipe(**generation_kwargs)

    start = time.perf_counter()
    images = pipe(**generation_kwargs).images
    print(f"Inference latency: {time.perf_counter() - start:.2f} s")

    for idx, img in enumerate(images):
        img.save(f"image_{idx}.png")