# Stable Diffusion v2 Demo with Torch Compile

## Prerequisites

Install the required packages.

In [None]:
%pip install -q "diffusers>=0.14.0" openvino-nightly "datasets>=2.14.6" "transformers>=4.25.1" "gradio>=4.19" "torch>=2.1" Pillow opencv-python --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q git+https://github.com/anzr299/nncf.git
%pip install -q accelerate

## Stable Diffusion v2 for Text-to-Image Generation

To start, let's look at the Text-to-Image process for Stable Diffusion v2. We will use the [Stable Diffusion v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) model for this purpose. The main difference between Stable Diffusion v2 and Stable Diffusion v2.1 is that v2.1 was trained on more data, for longer, and with less restrictive filtering of the dataset, which gives promising results across a wide range of input text prompts. More details about the model can be found in the [Stability AI blog post](https://stability.ai/blog/stablediffusion2-1-release7-dec-2022) and the original model [repository](https://github.com/Stability-AI/stablediffusion).

### Stable Diffusion in Diffusers library
To work with Stable Diffusion v2, we will use the Hugging Face [Diffusers](https://github.com/huggingface/diffusers) library. To experiment with Stable Diffusion models, Diffusers exposes the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/using-diffusers/conditional_image_generation), similar to the [other Diffusers pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview). The code below demonstrates how to create a `StableDiffusionPipeline` using `stable-diffusion-2-1`:

In [1]:
from torch._export import capture_pre_autograd_graph
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
import numpy as np

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino


In [2]:
from diffusers import StableDiffusion3Pipeline
import torch
import random
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
generator = torch.Generator(device="cpu").manual_seed(42)
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
pipe.to("cpu")

# if using torch < 2.0
# pipe.enable_xformers_memory_efficient_attention()

prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
negative_prompt = "frames, borderline, text, character, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"

# images = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
# images.show()
# images.save("/home/user/Downloads/stable-diffusion-xl-experiments/stable_diffusion_xl_all_pytorch_model.png")
# del images


You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers



In [5]:
pipe.text_encoder = torch.compile(pipe.text_encoder, backend='openvino')
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, backend='openvino')
pipe.vae.decoder = torch.compile(pipe.vae.decoder, backend='openvino')
pipe.vae.encoder = torch.compile(pipe.vae.encoder, backend='openvino')
pipe.unet = torch.compile(pipe.unet, backend='openvino')

In [6]:
def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    model_size_mb = (param_size + buffer_size) / 1024**2

    return model_size_mb

In [7]:
get_model_size(pipe.unet)

9794.096694946289
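
The value above is in MiB, because `get_model_size` divides the byte count by `1024**2`. If every parameter and buffer is stored as float32 (an assumption, not something the output confirms), 9794 MiB corresponds to roughly 2.57 billion elements. The sketch below repeats the same arithmetic without loading a model; the helper name `estimate_size_mb` and the element count used in the loop are hypothetical and chosen only for illustration.

```python
# Hypothetical helper: the same arithmetic as get_model_size, but driven by
# an element count and a dtype name instead of a live torch module.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "int8": 1}


def estimate_size_mb(num_elements: int, dtype: str = "float32") -> float:
    """Return the storage needed for `num_elements` values of `dtype`, in MiB."""
    return num_elements * BYTES_PER_ELEMENT[dtype] / 1024**2


# Assumed example count of 1,000,000 elements (not taken from the model above).
for dtype in BYTES_PER_ELEMENT:
    print(f"{dtype}: {estimate_size_mb(1_000_000, dtype):.2f} MiB")
```

Under these assumptions, storing the same elements as int8 needs a quarter of the float32 footprint, which is the kind of saving weight compression aims for.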

In [8]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
with torch.no_grad():
    # Warm-up run with a single step to trigger compilation
    pipe(prompt, num_inference_steps=1)
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
image.show()
image.save("/home/user/Downloads/stable-diffusion-xl-experiments/stable_diffusion_xl_all_torch_compile_model.png")

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None
Reached
torch.Size([1, 4, 128, 128])


  0%|          | 0/25 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None

### Convert Models to Torch FX Graph

In [3]:
text_encoder_input = torch.ones((1, 77), dtype=torch.long)
text_encoder_2_input = torch.ones((1, 77), dtype=torch.long)
vae_encoder_input = torch.ones((1, 3, 256, 256))
vae_decoder_input = torch.ones((1, 4, 128, 128))
vae_decoder_kwargs = {}
vae_decoder_kwargs["return_dict"] = False
latents_shape = (2, 4, 128, 128)
latents = torch.randn(latents_shape)
t = torch.from_numpy(np.array(1, dtype=np.float32))
added_cond_kwargs = {}
added_cond_kwargs["text_embeds"] = torch.ones((2, 1280))
added_cond_kwargs["time_ids"] = torch.ones((2,6))
unet_kwargs = {}
unet_kwargs["encoder_hidden_states"] = torch.ones((2, 77, 2048))
unet_kwargs["added_cond_kwargs"] = added_cond_kwargs
unet_kwargs["return_dict"] = False
unet_input = (latents, t)

text_encoder_kwargs = {}
text_encoder_kwargs['output_hidden_states'] = True

with torch.no_grad():
    with disable_patching():
        # Put every module in eval mode before capture so that all graphs are traced in inference mode
        text_encoder = capture_pre_autograd_graph(pipe.text_encoder.eval(), args=(text_encoder_input,), kwargs=text_encoder_kwargs)
        text_encoder_2 = capture_pre_autograd_graph(pipe.text_encoder_2.eval(), args=(text_encoder_2_input,), kwargs=text_encoder_kwargs)
        vae_encoder = capture_pre_autograd_graph(pipe.vae.encoder.eval(), args=(vae_encoder_input,))
        vae_decoder = capture_pre_autograd_graph(pipe.vae.decoder.eval(), args=(vae_decoder_input,))
        unet = capture_pre_autograd_graph(pipe.unet.eval(), args=unet_input, kwargs=unet_kwargs)
del added_cond_kwargs
del unet_kwargs
del unet_input
del latents
del t
del vae_encoder_input
del vae_decoder_input
del text_encoder_2_input
del text_encoder_input
del text_encoder_kwargs
del vae_decoder_kwargs

In [4]:
pipe.text_encoder = text_encoder
pipe.text_encoder_2 = text_encoder_2
pipe.vae.decoder = vae_decoder
pipe.vae.encoder = vae_encoder
pipe.unet = unet

### Weights Compression

In [None]:
import datasets
import numpy as np
from tqdm.notebook import tqdm
from typing import Any, Dict, List
import torch

def disable_progress_bar(pipeline, disable=True):
    if not hasattr(pipeline, "_progress_bar_config"):
        pipeline._progress_bar_config = {'disable': disable}
    else:
        pipeline._progress_bar_config['disable'] = disable


class UNetWrapper(torch.nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet
        self.captured_args = []

    def forward(self, *args, **kwargs):
        if np.random.rand() <= 0.7:
            self.captured_args.append((*args, *tuple(kwargs.values())))
        return self.unet(*args, **kwargs)

def collect_calibration_data(ov_pipe, calibration_dataset_size: int, num_inference_steps: int) -> List[Dict]:
    
    original_unet = ov_pipe.unet
    calibration_data = []
    disable_progress_bar(ov_pipe)

    dataset = datasets.load_dataset("google-research-datasets/conceptual_captions", split="train", trust_remote_code=True).shuffle(seed=42)

    wrapped_unet = UNetWrapper(ov_pipe.unet)
    ov_pipe.unet = wrapped_unet
    print(ov_pipe.unet)
    # Run inference for data collection
    pbar = tqdm(total=calibration_dataset_size)
    for i, batch in enumerate(dataset):
        prompt = batch["caption"]
        print(prompt)
        if len(prompt) > ov_pipe.tokenizer.model_max_length:
            continue
        # Run the pipeline
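        # Diffusers pipelines are seeded through `generator`, not through a `seed` argument
        ov_pipe(prompt, num_inference_steps=num_inference_steps, generator=torch.Generator(device="cpu").manual_seed(1))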
        print("pipe completed")
        calibration_data.extend(wrapped_unet.captured_args)
        wrapped_unet.captured_args = []
        pbar.update(len(calibration_data) - pbar.n)
        if pbar.n >= calibration_dataset_size:
            break

    disable_progress_bar(ov_pipe, disable=False)
    ov_pipe.unet = original_unet
    return calibration_data
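
The capture pattern used by `UNetWrapper` can be tried in isolation. The following is a minimal sketch in plain Python, assuming a toy function in place of the UNet; `CallRecorder`, `toy_model`, and the `keep_probability` parameter are hypothetical names introduced only for this illustration.

```python
import random


class CallRecorder:
    """Wrap a callable and keep a random sample of the arguments it receives."""

    def __init__(self, fn, keep_probability=0.7, seed=0):
        self.fn = fn
        self.keep_probability = keep_probability
        self.rng = random.Random(seed)
        self.captured_args = []

    def __call__(self, *args, **kwargs):
        if self.rng.random() <= self.keep_probability:
            # Flatten positional and keyword values into one tuple, as UNetWrapper does.
            self.captured_args.append((*args, *kwargs.values()))
        return self.fn(*args, **kwargs)


def toy_model(x, scale=1):
    # Stand-in for the UNet forward pass.
    return x * scale


# keep_probability=1.0 records every call, which keeps this demo deterministic.
recorder = CallRecorder(toy_model, keep_probability=1.0)
outputs = [recorder(i, scale=2) for i in range(3)]
print(outputs)
print(recorder.captured_args)
```

Sampling only a fraction of the calls, as the notebook does with 0.7, spreads the collected inputs over more prompts for the same calibration set size.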

In [None]:
import pickle
with open("test_faster", "wb") as fp:
    pickle.dump(unet_calibration_data, fp)

In [None]:
with open("test_faster", "rb") as fp:
    unet_calibration_data = pickle.load(fp)
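
Collecting calibration data is slow, so the two cells above cache it on disk. One way to combine them is a helper that loads the file when it exists and collects otherwise. This is a sketch under assumptions: `load_or_collect`, `fake_collect`, and the file name `calibration.pkl` are hypothetical, and `fake_collect` stands in for `collect_calibration_data`.

```python
import os
import pickle
import tempfile


def load_or_collect(path, collect_fn):
    """Load pickled data from `path` if it exists, otherwise call `collect_fn` and save its result."""
    if os.path.exists(path):
        with open(path, "rb") as fp:
            return pickle.load(fp)
    data = collect_fn()
    with open(path, "wb") as fp:
        pickle.dump(data, fp)
    return data


collect_calls = []


def fake_collect():
    # Stand-in for the expensive calibration run; records that it was called.
    collect_calls.append(1)
    return [(1, 2.0), (3, 4.0)]


cache_path = os.path.join(tempfile.mkdtemp(), "calibration.pkl")
first = load_or_collect(cache_path, fake_collect)   # collects and writes the cache
second = load_or_collect(cache_path, fake_collect)  # reads the cache, no second collection
```

Note that `pickle.load` can execute arbitrary code, so only load cache files you created yourself.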

In [None]:
def collect_ops_with_weights(graph_module):
    ops_with_weights = []
    for node in graph_module.graph.nodes:
        if "linear" in node.name:
            ops_with_weights.append(node.name)
    return ops_with_weights


calibration_dataset_size = 30
unet_calibration_data = collect_calibration_data(pipe,
                                                    calibration_dataset_size=calibration_dataset_size,
                                                    num_inference_steps=10)
unet_ignored_scope = collect_ops_with_weights(pipe.unet)

In [None]:
import nncf
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.range_estimator import RangeEstimatorParametersSet

# compressed_text_encoder = nncf.compress_weights(text_encoder)
# compressed_text_encoder_2 = nncf.compress_weights(text_encoder_2)
# compressed_vae_encoder = nncf.compress_weights(vae_encoder)
# compressed_vae_decoder = nncf.compress_weights(vae_decoder)
with disable_patching():
    with torch.no_grad():
        # compressed_unet = nncf.compress_weights(unet, ignored_scope=nncf.IgnoredScope(types=['conv2d']))
        quantized_unet = nncf.quantize( #1
            model=pipe.unet,
            calibration_dataset=nncf.Dataset(unet_calibration_data),
            subset_size=len(unet_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=-1, disable_bias_correction=True)
        )
        # quantized_unet = nncf.quantize( #2
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=-1, disable_bias_correction=True, weights_range_estimator_params=RangeEstimatorParametersSet.MINMAX, activations_range_estimator_params=RangeEstimatorParametersSet.MINMAX)
        # )
        # quantized_unet = nncf.quantize( #3
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=-1)
        # )
        # quantized_unet = nncf.quantize( #4.1
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alphas=AdvancedSmoothQuantParameters(convolution=0.95, matmul=-1))
        # )
        # quantized_unet = nncf.quantize( #4.2
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alphas=AdvancedSmoothQuantParameters(convolution=0.95))
        # )
# del text_encoder
# del text_encoder_2
# del vae_decoder
del unet
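
For intuition about what 8-bit quantization does to a tensor, here is a conceptual sketch of symmetric per-tensor quantization in plain Python. It is not NNCF's implementation: `nncf.quantize` also uses the calibration data to choose activation ranges and applies the settings passed above. The function names and the example weights are hypothetical.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization of a list of floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    # One scale for the whole tensor; fall back to 1.0 for an all-zero tensor.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize_int8(quantized, scale):
    """Map the integers back to approximate float values."""
    return [q * scale for q in quantized]


weights = [0.5, -1.0, 0.25, 0.0]  # made-up example values
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

The rounding error per value is bounded by about half the scale, which is why a few outliers that inflate the scale can hurt accuracy for the whole tensor.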

### Compile Models with OV Backend

In [None]:
# compiled_compressed_text_encoder = torch.compile(compressed_text_encoder, backend='openvino')
# compiled_compressed_text_encoder_2 = torch.compile(compressed_text_encoder_2, backend='openvino')
compiled_compressed_unet = torch.compile(quantized_unet, backend='openvino')
# compiled_compressed_vae_encoder = torch.compile(compressed_vae_encoder, backend='openvino')
# compiled_compressed_vae_decoder = torch.compile(compressed_vae_decoder, backend='openvino')

In [None]:
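# Only the UNet is quantized in this notebook, so the compressed_* variables below
# exist only if the nncf.compress_weights calls above are uncommented.
# del compressed_text_encoder
# del compressed_text_encoder_2
# del compressed_vae_decoder
# del compressed_unet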

In [None]:
# pipe.text_encoder = compiled_compressed_text_encoder
# pipe.text_encoder_2 = compiled_compressed_text_encoder_2
# pipe.vae.encoder = compiled_compressed_vae_encoder
# pipe.vae.decoder = compiled_compressed_vae_decoder
pipe.unet = compiled_compressed_unet

In [None]:
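# Alternative: use the compressed models without torch.compile. These variables exist
# only if the corresponding nncf.compress_weights calls above are uncommented.
# pipe.text_encoder = compressed_text_encoder
# pipe.text_encoder_2 = compressed_text_encoder_2
# pipe.vae.decoder = compressed_vae_decoder
# pipe.unet = compressed_unet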

In [None]:
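# Only compiled_compressed_unet is created above; the other compiled_* variables
# exist only if their torch.compile calls are uncommented.
# del compiled_compressed_text_encoder
# del compiled_compressed_text_encoder_2
# del compiled_compressed_vae_decoder
# del compiled_compressed_vae_encoder
del compiled_compressed_unet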

In [None]:
pipe.text_encoder = text_encoder
pipe.text_encoder_2 = text_encoder_2
pipe.vae.decoder = vae_decoder
# pipe.vae.encoder = vae_encoder
# pipe.unet = unet

### Inference for Compilation

In [None]:
# Warm up the model to trigger the initial compilation
prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
negative_prompt = "frames, borderline, text, character, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"
num_steps = 1
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps, generator=generator).images[0]
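
The single-step run above exists because the first call through a model wrapped with `torch.compile` includes compilation time. When measuring latency, it helps to keep warm-up calls out of the timings. The sketch below shows one way to do that with the standard library; `time_calls` and `fake_inference` are hypothetical names, and `fake_inference` stands in for a real `pipe(...)` call.

```python
import time


def time_calls(fn, warmup=1, runs=3):
    """Call `fn` `warmup` times untimed, then return the wall-clock time of each of `runs` calls."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings


inference_calls = []


def fake_inference():
    # Stand-in for `pipe(...)`; replace with a real pipeline call.
    inference_calls.append(1)
    return sum(range(10_000))


timings = time_calls(fake_inference, warmup=1, runs=3)
print([f"{t:.6f}s" for t in timings])
```

With a real pipeline, the warm-up call should use the same input shapes as the timed calls, since a change in shapes may trigger recompilation.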

In [None]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
image.show()
image.save("stable_diffusion_xl_all_fx_unet_quantized_more_calibration.png")

In [None]:
from diffusers import DiffusionPipeline

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae
)

In [None]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator, output_type="latent").images
    image = refiner(
        prompt=prompt,
        num_inference_steps=25,
        denoising_start=0.8,
        image=image,
    ).images[0]
image.show()
image.save("stable_diffusion_xl_all_fx_unet_quantized_more_calibration.png")

## Running Inference
Generate an image with the same parameters as the original OpenVINO Stable Diffusion model for comparison.

In [None]:
num_steps = 25

with torch.no_grad():
    image = pipe(prompt, num_inference_steps=num_steps).images[0]
image.show()
