# Pruna

Pruna is a model-optimization framework designed for developers, enabling you to deliver faster, more efficient models. With Pruna, you can combine multiple optimization algorithms in fewer than 10 lines of code.

There are four main algorithm groups for optimizing 🤗 Diffusers pipelines: quantizers, compilers, cachers, and factorizers. The overview below shows which algorithms are currently supported in each category. Depending on your pipeline, different combinations yield the best results.

<img src="../assets/images/diffusers_combinations.png" alt="Algorithms for 🤗 Diffusers pipelines" width="400" />

Pruna is available on PyPI, so you can install it using pip:

```bash
pip install pruna
```

### FLUX.1 [dev]

To optimize FLUX, combine Pruna’s FORA cacher with torch.compile and TorchAO’s dynamic quantization for the best results.

<img src="../assets/images/flux_combination.png" alt="Algorithm Combination for FLUX" width="400" />

The snippet below is all you need to use this combination.

```python
import torch
from diffusers import FluxPipeline

from pruna import SmashConfig, smash

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

smash_config = SmashConfig()
# Cacher: reuse expensive backbone outputs across denoising steps
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2  # or 3 for even faster inference
# Compiler: torch.compile, in the mode torchao needs to deliver speedups
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
# Quantizer: int8 dynamic quantization, skipping normalization and embedding modules
smash_config["quantizer"] = "torchao"
smash_config["torchao_quant_type"] = "int8dq"
smash_config["torchao_excluded_modules"] = "norm+embedding"
smashed_pipe = smash(pipe, smash_config)

smashed_pipe("a knitted purple prune").images[0]
```

This combination accelerates inference by up to 4.2× and cuts peak GPU memory usage from 34.7 GB to 28.0 GB while maintaining virtually the same output quality. The times were measured on an NVIDIA L40S GPU, so your results may vary on different hardware.
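For reference, the memory figures above correspond to a relative reduction of roughly 19%, and the best-case 4.2× speedup means a generation takes about a quarter of the original time. Here $t_{\text{base}}$ and $t_{\text{smashed}}$ are notation introduced only for this calculation, denoting per-image latency before and after smashing:

$$
\frac{34.7\,\text{GB} - 28.0\,\text{GB}}{34.7\,\text{GB}} = \frac{6.7}{34.7} \approx 0.193,
\qquad
\frac{t_{\text{smashed}}}{t_{\text{base}}} = \frac{1}{4.2} \approx 0.24
$$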

<img src="../assets/images/flux_smashed_comparison.png" alt="FLUX image comparison"/>

You’re now set to enjoy lightning-fast inference with FLUX.1 [dev]. To dive into the algorithms we used, explore their hyperparameters, and learn how each impacts output quality, keep reading.

## Cacher

Diffusion models generate images by starting with pure noise and gradually removing it over multiple inference steps until the final image emerges. At each step, the partially denoised image is fed into a neural network backbone (for example, a transformer), which predicts the noise that should be subtracted from it.

Recent papers have shown that consecutive backbone passes share many similarities. In particular, the outputs of expensive operations within the backbone tend to remain almost the same from one step to the next. This finding motivates the use of caching: if these outputs only differ slightly, we can compute them once and reuse them in subsequent steps.

<img src="../assets/images/fora_caching.png" alt="FORA caching" width="600" />

In Pruna, each caching algorithm provides an `interval` hyperparameter to tune caching aggressiveness. An `interval` of 3, for example, recomputes the cached operations every third step and reuses their outputs for the two steps in between.
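The interval mechanism can be sketched in plain Python. This is a toy, framework-free illustration under simplifying assumptions, not Pruna's implementation: the loop, the update rule `x - 0.1 * cached`, and the function name `denoise_with_cache` are hypothetical, and a single scalar stands in for the image. Real cachers such as FORA cache the outputs of specific operations inside the backbone rather than one value.

```python
from typing import Callable, Tuple


def denoise_with_cache(
    x: float,
    num_steps: int,
    interval: int,
    expensive_op: Callable[[float, int], float],
) -> Tuple[float, int]:
    """Toy denoising loop that reuses a cached 'expensive' output.

    The expensive operation runs on step 0 and then on every `interval`-th
    step; on the steps in between, its last output is reused as-is.
    Returns the final value and how many times the operation actually ran.
    """
    if interval < 1:
        raise ValueError("interval must be >= 1")
    cached = 0.0
    num_computed = 0
    for step in range(num_steps):
        if step % interval == 0:
            cached = expensive_op(x, step)  # full computation refreshes the cache
            num_computed += 1
        # Cheap per-step update that consumes the (possibly stale) cached output.
        x = x - 0.1 * cached
    return x, num_computed


# Count full computations over 28 steps (an illustrative step count).
for interval in (1, 2, 3, 4):
    _, computed = denoise_with_cache(1.0, 28, interval, lambda x, step: x)
    print(f"interval={interval}: {computed} of 28 steps computed")
```

With 28 steps, intervals of 1, 2, 3, and 4 run the expensive operation 28, 14, 10, and 7 times. Larger intervals are therefore faster, but the reused outputs also grow staler between refreshes, which is where the quality cost comes from.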

### FLUX.1 [dev]

For FLUX, you can choose among three cachers: FORA, PAB, and FasterCache.

```python
import torch
from diffusers import FluxPipeline

from pruna import SmashConfig, smash

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

cacher = "fora"  # or "pab" or "fastercache"
interval = 2  # or 3, 4

smash_config = SmashConfig()
smash_config["cacher"] = cacher
smash_config[f"{cacher}_interval"] = interval  # e.g. "fora_interval"
smashed_pipe = smash(pipe, smash_config)

smashed_pipe("a knitted purple prune").images[0]
```

To understand how different cacher-interval configurations impact our model, we measured both inference time and quality metrics on DrawBench:

<img src="../assets/images/flux_cacher_comparison.png" alt="FLUX cacher comparison" width="600" />

As the chart shows, the FORA cacher achieves the largest speedup while also scoring highest on the ARNIQA metric. The chart also makes the trade-off controlled by the interval parameter clear: as the interval increases, inference becomes faster but quality drops. In particular, increasing the interval from 3 to 4 lowers quality for every cacher.

## Compiler

While caching reduces the number of times expensive operations in the backbone are computed, another way to obtain a speedup is to optimize these operations themselves. Compilers analyze the computations in the backbone and determine which operations can be fused or executed with more efficient kernels (for example, Triton kernels). The only drawback is that the first execution takes longer, because the model is compiled during that call.
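To make "fusion" concrete, here is a conceptual sketch in plain Python. It illustrates the idea only; it is not how torch.compile or Stable Fast are invoked, nor what they generate, and both function names are hypothetical. The unfused version makes three separate passes and materializes an intermediate list after each one; the fused version applies multiply, add, and ReLU to each element in a single pass. On a GPU, each pass roughly corresponds to a kernel launch plus a round trip through memory, which is what fusion saves.

```python
from typing import List


def scale_shift_relu_unfused(x: List[float], a: float, b: float) -> List[float]:
    # Three separate passes ("kernels"), each writing a full intermediate result.
    scaled = [v * a for v in x]
    shifted = [v + b for v in scaled]
    return [max(v, 0.0) for v in shifted]


def scale_shift_relu_fused(x: List[float], a: float, b: float) -> List[float]:
    # One pass: multiply, add, and ReLU per element, with no intermediate lists.
    return [max(v * a + b, 0.0) for v in x]


x = [-2.0, -0.5, 0.0, 1.5, 3.0]
print(scale_shift_relu_fused(x, 2.0, 1.0))
print(scale_shift_relu_unfused(x, 2.0, 1.0))
```

In this toy example both versions return identical results, since fusion changes how the work is scheduled rather than what is computed. Real compilers may reorder floating-point operations, so their outputs can differ very slightly from the uncompiled model.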

### FLUX.1 [dev]

For FLUX, you can choose between two compilers: torch.compile and Stable Fast.

```python
import torch
from diffusers import FluxPipeline

from pruna import SmashConfig, smash

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

compiler = "torch_compile"  # or "stable_fast"

smash_config = SmashConfig()
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2  # or 3, 4
smash_config["compiler"] = compiler
smashed_pipe = smash(pipe, smash_config)

smashed_pipe("a knitted purple prune").images[0]
```

When paired with FORA, we compile the Transformer block by block instead of all at once. This approach slashes cold-start latency by nearly 50%.

<img src="../assets/images/flux_latency.png" alt="FLUX compiler warm up time" width="600" />

As the plots below demonstrate, compilation has virtually no impact on output quality. For FLUX, torch.compile achieves greater speedups than Stable Fast.

<img src="../assets/images/flux_compiler_comparison.png" alt="FLUX compiler comparison" width="600" />

## Quantizer

Quantization lowers the precision of the numbers used to represent a model’s parameters and calculations. By employing lower bit widths (for example, converting 16-bit floating-point values to 8-bit integers), it reduces model size and speeds up inference.
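A minimal sketch of the idea in plain Python, assuming symmetric per-tensor int8 quantization (a single scale for the whole tensor); production libraries such as torchao use finer-grained scales and optimized kernels. `quantize_int8` and `dequantize_int8` are hypothetical helpers. Each value is stored as an integer in [-127, 127] plus one shared floating-point scale, so a 16-bit tensor shrinks to roughly half its size, and the rounding error per value is at most half the scale.

```python
from typing import List, Tuple


def quantize_int8(values: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization, so that x ≈ q * scale."""
    max_abs = max(abs(v) for v in values)
    # Map the largest magnitude to 127; fall back to 1.0 for an all-zero tensor.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Recover approximate floating-point values from int8 codes."""
    return [v * scale for v in q]


weights = [0.5, -1.27, 0.003, 1.0, -0.25]
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
print(q)
print("max error:", max(abs(w - r) for w, r in zip(weights, restored)), "bound:", scale / 2)
```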

### FLUX.1 [dev]

Given its compute-intensive attention mechanism, FLUX benefits most from dynamic quantization, which quantizes both weights and activations. For instance, torchao's dynamic quantization can yield an additional speedup on top of torch.compile. Because modules such as normalization layers can be sensitive to dynamic quantization, we make it easy to exclude them. To get speedups with torchao, you need to use torch.compile's "max-autotune-no-cudagraphs" mode, which increases cold-start time compared to the default compile mode.
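The distinction between quantizing weights and quantizing activations can be sketched with a toy linear layer in plain Python. This is an illustration under simplifying assumptions, similar in spirit to the int8 dynamic quantization selected by `"int8dq"` but not torchao's implementation: `DynamicInt8Linear` is hypothetical, it uses one scale per weight row and one scale per input vector, and it omits the bias. Weights are quantized once, ahead of time, while the activation scale is computed from each new input at call time, which is what "dynamic" refers to. Because both operands are int8, the dot products can be accumulated in integer arithmetic, which is what lets int8 matrix-multiply hardware deliver speedups.

```python
from typing import List, Tuple


def _quantize(values: List[float]) -> Tuple[List[int], float]:
    # Symmetric int8 quantization with a single scale for the whole vector.
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return [max(-127, min(127, round(v / scale))) for v in values], scale


class DynamicInt8Linear:
    """Toy bias-free linear layer with int8 weights and int8 activations."""

    def __init__(self, weight: List[List[float]]):
        # Static part: each weight row is quantized once, at construction time.
        self.rows = [_quantize(row) for row in weight]

    def __call__(self, x: List[float]) -> List[float]:
        # Dynamic part: the activation scale depends on this particular input.
        qx, sx = _quantize(x)
        out = []
        for qw, sw in self.rows:
            acc = sum(a * b for a, b in zip(qw, qx))  # integer-only accumulation
            out.append(acc * sw * sx)  # rescale back to floating point
        return out


W = [[0.2, -0.5, 0.1], [1.0, 0.3, -0.7]]
x = [0.9, -0.4, 2.0]
layer = DynamicInt8Linear(W)
exact = [sum(w * v for w, v in zip(row, x)) for row in W]
print("int8:", layer(x))
print("float:", exact)
```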

```python
import torch
from diffusers import FluxPipeline

from pruna import SmashConfig, smash

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

smash_config = SmashConfig()
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2  # or 3, 4
smash_config["compiler"] = "torch_compile"
# torchao only yields speedups with this torch.compile mode
smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
smash_config["quantizer"] = "torchao"
smash_config["torchao_quant_type"] = "int8dq"
smash_config["torchao_excluded_modules"] = "norm+embedding"  # or "none"
smashed_pipe = smash(pipe, smash_config)

smashed_pipe("a knitted purple prune").images[0]
```

As the plots below show, when these sensitive modules are excluded, dynamic quantization barely affects quality while still providing a speedup. It also cuts peak GPU memory usage from 34.7 GB to 28.0 GB.

<img src="../assets/images/flux_quantizer_comparison.png" alt="FLUX quantizer comparison" width="600" />