In [None]:
%pip install -q diffusers accelerate datasets gradio transformers "nncf==2.10.0" "openvino>=2024.1.0" "torch>=2.1" --extra-index-url https://download.pytorch.org/whl/cpu

In [1]:
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline, StableCascadeUNet

# Text prompt for generation; the empty negative prompt is also reused during calibration-data collection.
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
negative_prompt = ""

# Load the "lite" UNet variants from their Hub subfolders (presumably smaller/faster than the full ones — confirm on the model cards).
prior_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade-prior", subfolder="prior_lite")
decoder_unet = StableCascadeUNet.from_pretrained("stabilityai/stable-cascade", subfolder="decoder_lite")

# Build both pipelines, passing the UNets loaded above as their `prior` / `decoder` components.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet)
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet)

2024-05-03 23:06:03.939583: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-03 23:06:03.940048: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-03 23:06:03.942559: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-03 23:06:03.975149: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [5]:
import gc
from pathlib import Path

import torch

import openvino as ov
import nncf


MODELS_DIR = Path("models")

def convert(model: torch.nn.Module, xml_path: str, example_input, input_shape=None):
    """Convert a PyTorch module to OpenVINO IR and save it to ``xml_path``.

    Nothing happens when ``xml_path`` already exists, so re-running the notebook
    reuses previously converted models.

    Args:
        model: PyTorch module to trace; it is switched to eval mode first.
        xml_path: destination ``.xml`` path (``str`` or ``Path``); missing parent
            directories are created.
        example_input: example inputs forwarded to ``ov.convert_model`` for tracing.
        input_shape: optional value for the ``input`` argument of
            ``ov.convert_model``; only forwarded when truthy.
    """
    target = Path(xml_path)
    if target.exists():
        return

    model.eval()
    target.parent.mkdir(parents=True, exist_ok=True)

    conversion_kwargs = {"example_input": example_input}
    if input_shape:
        conversion_kwargs["input"] = input_shape
    with torch.no_grad():
        ov_model = ov.convert_model(model, **conversion_kwargs)
    ov.save_model(ov_model, target)
    del ov_model

    # Reset the torch.jit global caches left behind by tracing, then collect garbage.
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()

    gc.collect()

In [6]:
PRIOR_PRIOR_MODEL_OV_PATH = MODELS_DIR / "prior_prior_model.xml"

convert(
    prior_unet,
    PRIOR_PRIOR_MODEL_OV_PATH,
    example_input={
        "sample": torch.zeros(2, 16, 24, 24),
        "timestep_ratio": torch.ones(2),
        "clip_text_pooled": torch.zeros(2, 1, 1280),
        "clip_text": torch.zeros(2, 77, 1280),
        "clip_img": torch.zeros(2, 1, 768),
    },
    # One static shape per input, in the same order as `example_input`.
    input_shape=[(2, 16, 24, 24), (2,), (2, 1, 1280), (2, 77, 1280), (2, 1, 768)],
)
# `prior_unet` is the same module the pipeline holds as `prior.prior`; drop both
# references, otherwise deleting only the attribute frees no memory.
del prior.prior, prior_unet
gc.collect();

  if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):


In [7]:
DECODER_DECODER_MODEL_OV_PATH = MODELS_DIR / "decoder_decoder_model.xml"

convert(
    decoder.decoder,
    DECODER_DECODER_MODEL_OV_PATH,
    example_input={
        "sample": torch.zeros(1, 4, 256, 256),
        "timestep_ratio": torch.ones(1),
        "clip_text_pooled": torch.zeros(1, 1, 1280),
        "effnet": torch.zeros(1, 16, 24, 24),
    },
    # One static shape per input, in the same order as `example_input`.
    input_shape=[(1, 4, 256, 256), (1,), (1, 1, 1280), (1, 16, 24, 24)],
)
# `decoder_unet` still references the module the pipeline holds as
# `decoder.decoder`; drop both references so the weights can be freed.
del decoder.decoder, decoder_unet
gc.collect();

In [9]:
# Shared OpenVINO runtime; every model below is compiled or read through it.
core = ov.Core()

from collections import namedtuple


class PriorPriorWrapper:
    """OpenVINO-backed stand-in for the Stable Cascade prior UNet.

    Compiles the IR at ``prior_path`` on CPU and exposes the small attribute
    surface (``config``, ``parameters()``) accessed in the original workflow,
    so the wrapper can be assigned to ``prior.prior`` in the pipeline.
    """

    def __init__(self, prior_path):
        self.prior = core.compile_model(prior_path, "CPU")
        # accessed in the original workflow
        config_type = namedtuple("PriorWrapperConfig", ["clip_image_in_channels", "in_channels"])
        self.config = config_type(clip_image_in_channels=768, in_channels=16)
        # accessed in the original workflow: yields a single empty float32 tensor
        self.parameters = lambda: (torch.zeros(size, dtype=torch.float32) for size in (0,))

    def __call__(self, sample, timestep_ratio, clip_text_pooled, clip_text=None, clip_img=None, **kwargs):
        """Run one prior step through OpenVINO; extra keyword arguments are ignored.

        Returns a one-element list holding the model's first output as a torch tensor.
        """
        feed = dict(
            sample=sample,
            timestep_ratio=timestep_ratio,
            clip_text_pooled=clip_text_pooled,
            clip_text=clip_text,
            clip_img=clip_img,
        )
        first_output = self.prior(feed)[0]
        return [torch.from_numpy(first_output)]


# The compiled prior is attached to the pipeline in a later cell, together with the
# decoder; compiling it here as well only duplicated compile time and memory.

In [10]:
class DecoderWrapper:
    """OpenVINO-backed stand-in for the Stable Cascade decoder UNet.

    Compiles the IR at ``decoder_path`` on CPU; ``dtype`` is accessed in the
    original workflow.
    """

    # accessed in the original workflow
    dtype = torch.float32

    def __init__(self, decoder_path):
        self.decoder = core.compile_model(decoder_path, "CPU")

    def __call__(self, sample, timestep_ratio, clip_text_pooled, effnet, **kwargs):
        """Run one decoder step through OpenVINO; extra keyword arguments are ignored.

        Returns a one-element list holding the model's first output as a torch tensor.
        """
        feed = dict(
            sample=sample,
            timestep_ratio=timestep_ratio,
            clip_text_pooled=clip_text_pooled,
            effnet=effnet,
        )
        first_output = self.decoder(feed)[0]
        return [torch.from_numpy(first_output)]

In [11]:
# Attach the compiled OpenVINO wrappers in place of the deleted PyTorch UNets.
prior.prior = PriorPriorWrapper(PRIOR_PRIOR_MODEL_OV_PATH)
decoder.decoder = DecoderWrapper(DECODER_DECODER_MODEL_OV_PATH)

In [12]:
class CompiledModelDecorator(ov.CompiledModel):
    """``ov.CompiledModel`` that records the inputs of every inference call.

    Used to capture calibration data for NNCF: each call appends its ``inputs``
    argument to ``data_cache`` and then runs normal inference.
    """

    def __init__(self, compiled_model):
        super().__init__(compiled_model)
        # Inputs of every inference request, in call order.
        self.data_cache = []

    def __call__(self, inputs=None, *args, **kwargs):
        # Record `inputs` explicitly: the previous `append(*args)` raised TypeError
        # when inputs were passed by keyword or alongside other positional
        # arguments (e.g. share_inputs).
        self.data_cache.append(inputs)
        return super().__call__(inputs, *args, **kwargs)

In [20]:
import pickle
import datasets
from tqdm.notebook import tqdm
from transformers import set_seed

# Fix the global random seed before calibration data is collected.
set_seed(1)

def collect_calibration_data(prior, decoder, subset_size):
    """Collect (or load cached) calibration inputs for the OpenVINO prior and decoder.

    On a cache miss, the compiled models inside the pipeline wrappers
    (``prior.prior.prior`` / ``decoder.decoder.decoder``) are temporarily replaced
    by ``CompiledModelDecorator`` recorders. Captions from the shuffled
    ``conceptual_captions`` train split are run through prior -> decoder until at
    least ``subset_size`` prior requests are recorded. The recordings are pickled to
    ``prior_calibration_dataset/{subset_size}.pkl`` and
    ``decoder_calibration_dataset/{subset_size}.pkl``, and later calls load them.

    Args:
        prior: prior pipeline whose ``prior`` attribute wraps a compiled model in ``.prior``.
        decoder: decoder pipeline whose ``decoder`` attribute wraps a compiled model in ``.decoder``.
        subset_size: minimum number of prior inference requests to record.

    Returns:
        ``(prior_calibration_dataset, decoder_calibration_dataset)``: lists of recorded model inputs.
    """
    prior_calibration_dataset_filepath = Path(f"prior_calibration_dataset/{subset_size}.pkl")
    decoder_calibration_dataset_filepath = Path(f"decoder_calibration_dataset/{subset_size}.pkl")

    # Reuse the cache only when BOTH files exist. Checking just the prior file
    # crashed when the decoder file was missing.
    if prior_calibration_dataset_filepath.exists() and decoder_calibration_dataset_filepath.exists():
        # pickle.load can execute arbitrary code: only load files this function wrote.
        with open(prior_calibration_dataset_filepath, "rb") as f:
            prior_calibration_dataset = pickle.load(f)
        with open(decoder_calibration_dataset_filepath, "rb") as f:
            decoder_calibration_dataset = pickle.load(f)
        return prior_calibration_dataset, decoder_calibration_dataset

    original_prior = prior.prior.prior
    original_decoder = decoder.decoder.decoder
    prior_recorder = CompiledModelDecorator(original_prior)
    decoder_recorder = CompiledModelDecorator(original_decoder)
    prior.prior.prior = prior_recorder
    decoder.decoder.decoder = decoder_recorder
    try:
        dataset = datasets.load_dataset("conceptual_captions", split="train").shuffle(seed=42)
        with tqdm(total=subset_size) as pbar:
            for batch in dataset:
                caption = batch["caption"]
                # model_max_length is a token limit, so compare it with the token
                # count (a character count is not a token count).
                if len(prior.tokenizer(caption).input_ids) > prior.tokenizer.model_max_length:
                    continue
                prior_output = prior(
                    prompt=caption,
                    height=1024,
                    width=1024,
                    negative_prompt=negative_prompt,
                    guidance_scale=4.0,
                    num_images_per_prompt=1,
                    num_inference_steps=20,
                )
                # The decoder only runs so its inputs get recorded; images are discarded.
                decoder(
                    image_embeddings=prior_output.image_embeddings,
                    prompt=caption,
                    negative_prompt=negative_prompt,
                    guidance_scale=0.0,
                    output_type="pil",
                    num_inference_steps=20,
                )
                collected = len(prior_recorder.data_cache)
                pbar.update(min(collected, subset_size) - pbar.n)
                if collected >= subset_size:
                    break
    finally:
        # Always restore the plain compiled models, so a failure midway does not
        # leave the pipelines wrapped in recorders.
        prior.prior.prior = original_prior
        decoder.decoder.decoder = original_decoder

    prior_calibration_dataset = prior_recorder.data_cache
    decoder_calibration_dataset = decoder_recorder.data_cache
    for filepath, data in (
        (prior_calibration_dataset_filepath, prior_calibration_dataset),
        (decoder_calibration_dataset_filepath, decoder_calibration_dataset),
    ):
        filepath.parent.mkdir(exist_ok=True, parents=True)
        with open(filepath, "wb") as f:
            pickle.dump(data, f)

    return prior_calibration_dataset, decoder_calibration_dataset

In [21]:
PRIOR_PRIOR_INT8_PATH = MODELS_DIR / "prior_prior_int8.xml"
DECODER_INT8_PATH = MODELS_DIR / "decoder_int8.xml"
# Minimum number of prior inference requests to record for calibration. Defined
# unconditionally, so later cells can rely on it even when collection is skipped.
subset_size = 40

if not (PRIOR_PRIOR_INT8_PATH.exists() and DECODER_INT8_PATH.exists()):
    prior_calibration_dataset, decoder_calibration_dataset = collect_calibration_data(prior, decoder, subset_size=subset_size)

In [22]:
import nncf
from nncf.scopes import IgnoredScope
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters

if not PRIOR_PRIOR_INT8_PATH.exists():
    prior_model = core.read_model(PRIOR_PRIOR_MODEL_OV_PATH)
    # NOTE(review): the saved run of this cell ended in `IndexError: list index out
    # of range`; the cause is not visible here. Earlier attempts tuned SmoothQuant
    # via `advanced_parameters` — TODO investigate.
    quantized_prior_prior = nncf.quantize(
        model=prior_model,
        calibration_dataset=nncf.Dataset(prior_calibration_dataset),
        # Use every recorded request (same policy as the decoder quantization cell).
        subset_size=len(prior_calibration_dataset),
        preset=nncf.QuantizationPreset.PERFORMANCE,
        model_type=nncf.ModelType.TRANSFORMER,
    )
    ov.save_model(quantized_prior_prior, PRIOR_PRIOR_INT8_PATH)

Output()

Output()

IndexError: list index out of range

In [24]:
if not DECODER_INT8_PATH.exists():
    decoder_model = core.read_model(DECODER_DECODER_MODEL_OV_PATH)
    # NOTE(review): the saved run of this cell ended in `IndexError: list index out
    # of range`; the cause is not visible here. Earlier attempts tried disabling
    # bias correction / tuning SmoothQuant via `advanced_parameters` — TODO investigate.
    quantized_decoder = nncf.quantize(
        model=decoder_model,
        calibration_dataset=nncf.Dataset(decoder_calibration_dataset),
        # Use every recorded decoder request.
        subset_size=len(decoder_calibration_dataset),
        model_type=nncf.ModelType.TRANSFORMER,
    )
    ov.save_model(quantized_decoder, DECODER_INT8_PATH)

Output()

Output()

IndexError: list index out of range