# Recovering Quality after Quantizing Models to 4 Bits

<a target="_blank" href="https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/recovery.ipynb">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### 1. Loading the Sana Model

First, load the Sana model in `bfloat16` precision and move it to the GPU. We will then generate an image with it to use as a quality reference.

In [None]:
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")


Generate an image with the original, unquantized model. It serves as the quality reference for the smashed model later on.

In [None]:
prompt = "A crow walking along a river near a foggy cliff, with cute yellow ducklings following it in a line, at sunset."
pipe(prompt).images[0]

### 2. Initializing the SmashConfig

In [None]:
from pruna import SmashConfig

smash_config = SmashConfig({
    # Quantize the model weights to 4 bits
    "diffusers_int8": {
        "weight_bits": 4
    },
    # Recover quality after quantization, which lets you push to lower bit widths without compromising quality
    "text_to_image_perp": {
        # you can increase or reduce 'batch_size' depending on your GPU, or use 'gradient_accumulation_steps' with it
        "batch_size": 8,
        "num_epochs": 4,
        "validate_every_n_epoch": 0.5 # run validation every half epoch
    }
})
# Attach a text-to-image dataset, used for recovery
smash_config.add_data("COCO")
smash_config.data.limit_datasets((256, 64, 1))  # training on 256 samples and validating on 64
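
As a rough sanity check on the settings above, the number of training steps and validation runs can be estimated with plain arithmetic. This is only a sketch: `recovery_schedule` and its parameter names are hypothetical helpers, not part of Pruna, and it assumes the recoverer iterates over the training split once per epoch and validates at evenly spaced intervals.

```python
import math


def recovery_schedule(num_train, batch_size, num_epochs,
                      validate_every_n_epoch, grad_accum_steps=1):
    """Estimate step counts for a recovery run (illustrative arithmetic only)."""
    # Number of batches drawn from the training split in one epoch.
    batches_per_epoch = math.ceil(num_train / batch_size)
    # With gradient accumulation, several batches share one optimizer step.
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    return {
        "effective_batch_size": batch_size * grad_accum_steps,
        "batches_per_epoch": batches_per_epoch,
        "total_optimizer_steps": steps_per_epoch * num_epochs,
        "validation_runs": int(num_epochs / validate_every_n_epoch),
    }


# Values taken from the configuration above: 256 samples, batch size 8, 4 epochs.
schedule = recovery_schedule(256, 8, 4, 0.5)
print(schedule)

# Assumed alternative for a smaller GPU: halve the batch size and
# accumulate gradients over 2 batches to keep the same effective batch size.
small_gpu = recovery_schedule(256, 4, 4, 0.5, grad_accum_steps=2)
print(small_gpu)
```

Under these assumptions, both variants perform the same number of optimizer steps with the same effective batch size, which is why lowering `batch_size` together with gradient accumulation is a reasonable way to fit the run on a smaller GPU.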

### 3. Smashing the Model

Now, smash the model. This takes about 9 minutes on an L40S GPU, although the exact time depends on how many samples are used for recovery.
Recovery logging is handled through __Weights & Biases__, so make sure it is installed and set up in your environment.

In [None]:
from pruna import smash

smashed_model = smash(
    model=pipe,
    smash_config=smash_config,
)
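
To get an intuition for why a recovery step helps after 4-bit quantization, the sketch below rounds a set of random weights to 4 and 8 bits and compares the round-trip error. It uses a simple symmetric absmax scheme chosen purely for illustration; this is an assumption and not necessarily the scheme the quantizer above applies. The function names are hypothetical.

```python
import random


def quantize_dequantize(weights, bits):
    """Symmetric absmax round-to-nearest quantization followed by dequantization."""
    qmax = 2 ** (bits - 1) - 1  # 7 for 4 bits, 127 for 8 bits
    scale = max(abs(w) for w in weights) / qmax
    restored = []
    for w in weights:
        # Map to the nearest integer level and clamp to the representable range.
        q = max(-qmax, min(qmax, round(w / scale)))
        restored.append(q * scale)
    return restored


def mean_abs_error(a, b):
    """Average absolute difference between two equally long sequences."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


random.seed(0)
# Toy stand-in for one layer's weights (not taken from the Sana model).
weights = [random.gauss(0.0, 0.02) for _ in range(10_000)]

restored4 = quantize_dequantize(weights, 4)
restored8 = quantize_dequantize(weights, 8)
err4 = mean_abs_error(weights, restored4)
err8 = mean_abs_error(weights, restored8)
print(f"4-bit mean abs error: {err4:.6f}")
print(f"8-bit mean abs error: {err8:.6f}")
```

With only 15 representable levels in this scheme, the 4-bit round trip loses noticeably more information than the 8-bit one. Recovery fine-tunes the model so that its outputs compensate for this kind of error.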

### 4. Running the Model

Finally, run the quantized and recovered model with the same prompt so the result can be compared against the reference image. Because its weights are quantized, it has a lower memory footprint than the original.

In [None]:
smashed_model(prompt).images[0]
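
The memory saving can be approximated with a back-of-the-envelope calculation. The sketch below assumes roughly 1.6 billion parameters, inferred from `1600M` in the checkpoint name, and assumes every weight is stored at the stated bit width. It ignores the text encoder, the VAE, quantization metadata such as scales, any layers kept in higher precision, and activations, so treat the numbers as a rough bound rather than a measurement. `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


NUM_PARAMS = 1.6e9  # assumption: inferred from "1600M" in the checkpoint name

bf16_gb = weight_memory_gb(NUM_PARAMS, 16)  # original bfloat16 weights
int4_gb = weight_memory_gb(NUM_PARAMS, 4)   # weights quantized to 4 bits

print(f"bfloat16 weights: ~{bf16_gb:.1f} GB")
print(f"4-bit weights:    ~{int4_gb:.1f} GB")
print(f"reduction factor: ~{bf16_gb / int4_gb:.1f}x")
```

To measure the actual footprint on your GPU, `torch.cuda.max_memory_allocated()` reports the peak memory allocated by tensors, which you can compare before and after smashing.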

### Wrap up