## Stable Diffusion on Intel Sapphire Rapids (SPR) with Intel Extension for PyTorch (IPEX)

In [None]:
import torch

from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

In [None]:
# Hugging Face Hub access token.
# SECURITY: access tokens must never be committed to a notebook. Any token that
# was previously stored in this cell should be treated as leaked and revoked.
# The token is read from the environment instead; export HF_TOKEN (or the older
# HUGGING_FACE_HUB_TOKEN) before starting the kernel.
# MY_TOKEN is None when neither variable is set.
# NOTE(review): None is expected to make from_pretrained fall back to its
# default (anonymous / cached-login) behaviour — confirm for the installed
# diffusers/transformers versions.
import os

MY_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

In [None]:
# IPython magic: sets ONEDNN_VERBOSE=1 in this kernel's environment.
# NOTE(review): presumably turns on oneDNN's verbose primitive logging so the
# kernels used during inference can be inspected — confirm, and drop this cell
# for quiet runs.
%env ONEDNN_VERBOSE=1

In [None]:
# Load the four Stable Diffusion components; each one lives in its own
# subfolder of the same Hub repository and is fetched with the same token.
sd_repo = "CompVis/stable-diffusion-v1-4"
hub_auth = {"use_auth_token": MY_TOKEN}

tokenizer = CLIPTokenizer.from_pretrained(sd_repo, subfolder="tokenizer", **hub_auth)
text_encoder = CLIPTextModel.from_pretrained(sd_repo, subfolder="text_encoder", **hub_auth)
vae = AutoencoderKL.from_pretrained(sd_repo, subfolder="vae", **hub_auth)
unet = UNet2DConditionModel.from_pretrained(sd_repo, subfolder="unet", **hub_auth)

In [None]:
import intel_extension_for_pytorch as ipex

# Put the UNet in inference mode and let IPEX optimize it for bfloat16.
# eval() returns the module itself, so its result can be passed straight
# into ipex.optimize().
# Alternatives that are intentionally left disabled in this notebook:
#   - fp32 optimization, i.e. ipex.optimize(unet) without a dtype
#   - converting the UNet to the torch.channels_last memory format first
#   - giving text_encoder the same eval() + bfloat16 optimize treatment
unet = ipex.optimize(unet.eval(), dtype=torch.bfloat16)

In [None]:
# Build the pieces that are not loaded elsewhere in the notebook first, then
# assemble the Stable Diffusion pipeline from all components.
pndm_scheduler = PNDMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    skip_prk_steps=True,
)
safety_checker_model = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
clip_feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "openai/clip-vit-base-patch32"
)

pipeline = StableDiffusionPipeline(
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    scheduler=pndm_scheduler,
    safety_checker=safety_checker_model,
    feature_extractor=clip_feature_extractor,
)

**Single image inference**

In [None]:
# Text prompt for the single-image run.
prompt = "Painting of a frog with hat on a bicycle cycling in New York City at a beautiful dusk with a traffic jam and moody people in the style of Picasso"

# Fixed seed on a CPU generator so the output is deterministic.
generator = torch.Generator("cpu")
generator.manual_seed(777)

# Run the 50-step generation under CPU autocast.
with torch.cpu.amp.autocast():
    result = pipeline(prompt, num_inference_steps=50, generator=generator)

# Keep the first (and only requested) image and write it to disk.
image = result.images[0]
image.save("frog_test.png")

**Batched inference**

In [None]:
from PIL import Image

def image_grid(imgs, rows: int, cols: int):
    """Tile images into a single ``rows`` x ``cols`` grid image.

    Images are placed in row-major order (left to right, then top to
    bottom). The cell size is taken from the first image, so all images
    are expected to share that size.

    Args:
        imgs: Sequence of PIL images to tile; must contain exactly
            ``rows * cols`` entries.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)``
        is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of `assert`: asserts are stripped under `python -O`,
    # which would let a mismatched grid through silently.
    if len(imgs) == 0:
        raise ValueError("image_grid() requires at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
        )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Row-major placement: index i lands in row i // cols, column i % cols.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
# Number of images generated by one batched pipeline call.
num_images = 3

# The same prompt repeated num_images times; one image per prompt is expected back.
prompt = ["Painting of a frog with hat on a bicycle cycling in New York City at a beautiful dusk with a traffic jam and moody people in the style of Picasso"] * num_images

with torch.cpu.amp.autocast():
    images = pipeline(prompt).images

# Single-row grid; the column count follows num_images so the grid stays
# valid when the batch size is changed.
grid = image_grid(images, rows=1, cols=num_images)

grid.save("frog_batch.png")