### Optimizing and Deploying AI Models with Pruna and Hugging Face

`Goal`: Create an end-to-end tutorial to optimize the black-forest-labs/FLUX.1-dev model using Pruna and deploy it on the Hugging Face Hub.

`Model`: [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)

`Dataset`: [data-is-better-together/open-image-preferences-v1-binarized](https://huggingface.co/datasets/data-is-better-together/open-image-preferences-v1-binarized)

To complete the tutorial, you need to install the pruna SDK along with a few third-party libraries via pip. It is recommended to run this notebook in a new virtual environment.


In [None]:
pip install pruna 

In [None]:
pip install datasets huggingface_hub gradio diffusers
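
Before continuing, it can help to confirm that the packages actually landed in the active environment. The snippet below is a minimal sketch that uses only the standard library; the `REQUIRED` list and the `installed_versions` helper are illustrative names, not part of Pruna.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages installed in the two cells above
REQUIRED = ["pruna", "datasets", "huggingface_hub", "gradio", "diffusers"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(REQUIRED).items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` was probably installed into a different environment than the one the notebook kernel uses.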

You need to log in to the Hugging Face Hub to download the model weights. Run the cell below to authenticate.

In [None]:
from huggingface_hub import login

login()


Smash Configuration:

To optimize the model, we first define which optimization methods Pruna should apply; here these are a compiler, a quantizer, and a cacher. To learn more, see the [SmashConfig guide](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html).

In [None]:
import torch
from diffusers import FluxPipeline
from pruna import smash, SmashConfig

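# Load the pipeline in bfloat16 and move it to the GPU.
# Note: this loads FLUX.1-schnell; swap in "black-forest-labs/FLUX.1-dev"
# to optimize the dev checkpoint named in the goal.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")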

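# Configure Pruna smash: one compiler, one quantizer, one cacher
smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"
smash_config["quantizer"] = "hqq_diffusers"
# DeepCache was designed for U-Net pipelines; check Pruna's compatibility
# notes before relying on it with FLUX's transformer backbone.
smash_config["cacher"] = "deepcache"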

# Smash the pipeline
smashed_pipe = smash(model=pipe, smash_config=smash_config)

# Save the smashed model to disk
smashed_pipe.save_pretrained("saved_models/FLUX.1-schnell-smashed")
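
To check whether smashing actually made inference faster, the base and smashed pipelines can be timed on the same prompt. The helper below is a minimal sketch using only the standard library; `time_call` and `speedup` are hypothetical names introduced here, and the warm-up runs are there because compiled models typically pay a one-off compilation cost on their first calls.

```python
import statistics
import time


def time_call(fn, *args, warmup=1, repeats=3, **kwargs):
    """Return the median wall-clock seconds of fn(*args, **kwargs)."""
    for _ in range(warmup):
        fn(*args, **kwargs)  # warm-up runs absorb one-off costs such as compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_seconds, optimized_seconds):
    """How many times faster the optimized call is than the baseline."""
    return baseline_seconds / optimized_seconds
```

For example, `time_call(smashed_pipe, "a photo of a cat")` would time the smashed pipeline. For a fair comparison, time the base pipeline before smashing it, or load a fresh unoptimized copy, since the pipeline passed to `smash` may no longer behave like the original.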


Upload the model to the Hugging Face Hub

In [None]:
from huggingface_hub import HfApi
from pruna.engine.pruna_model import PrunaModel

# Reload the saved model to confirm that it loads correctly
smashed_model = PrunaModel.from_pretrained("saved_models/FLUX.1-schnell-smashed")

# Create the repository (a no-op if it already exists)
api = HfApi()
repo_url = api.create_repo("FLUX.1-schnell-smashed", exist_ok=True, private=False)

# Upload the saved folder in a single commit
api.upload_folder(
    folder_path="saved_models/FLUX.1-schnell-smashed",
    repo_id=repo_url.repo_id,
    commit_message="add FLUX.1-schnell-smashed",
)
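
Model folders for large diffusion pipelines can run to many gigabytes, so it is worth checking what is about to be uploaded. The following is a small sketch using only the standard library; `folder_summary` and `human_size` are hypothetical helper names, not Hub or Pruna functions.

```python
from pathlib import Path


def folder_summary(folder):
    """Return (file_count, total_bytes) for every file under folder."""
    files = [p for p in Path(folder).rglob("*") if p.is_file()]
    return len(files), sum(p.stat().st_size for p in files)


def human_size(num_bytes):
    """Format a byte count with a binary unit suffix."""
    for unit in ("B", "KB", "MB", "GB"):
        if num_bytes < 1024:
            return f"{num_bytes:.1f} {unit}"
        num_bytes /= 1024
    return f"{num_bytes:.1f} TB"
```

Calling `folder_summary("saved_models/FLUX.1-schnell-smashed")` and passing the byte count to `human_size` gives a quick estimate of the upload size.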


Load Dataset

In [None]:
from datasets import load_dataset

# load the binarized Open Image Preferences prompts
ds = load_dataset("data-is-better-together/open-image-preferences-v1-binarized", split="train")

# preview 10 examples
for example in ds.select(range(10)):
    print(example["prompt"])
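
Taking the first 10 rows is convenient but not necessarily representative, and a preference dataset may contain the same prompt more than once. As a sketch, the hypothetical helper below deduplicates a list of prompts and draws a reproducible random sample; it works on plain Python strings, so it assumes the prompts have already been pulled out of the dataset.

```python
import random


def sample_unique_prompts(prompts, k, seed=0):
    """Drop blank and duplicate prompts, then draw up to k of them reproducibly."""
    # dict.fromkeys keeps the first occurrence of each prompt and preserves order
    unique = list(dict.fromkeys(p.strip() for p in prompts if p and p.strip()))
    rng = random.Random(seed)  # local generator, so global random state is untouched
    return rng.sample(unique, min(k, len(unique)))
```

With the dataset loaded above, `sample_unique_prompts(ds["prompt"], k=10)` would return ten distinct prompts that stay the same across runs with the same seed.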


Evaluate the model

In [None]:
import torch
from datasets import load_dataset
from diffusers import AutoPipelineForText2Image
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.engine.pruna_model import PrunaModel
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task

# Step 1: Load the dataset and keep a small subset (first 10 examples only)
dataset = load_dataset("data-is-better-together/open-image-preferences-v1-binarized", split="train")
subset = dataset.select(range(10))
# Assumption: Pruna's image-generation collate function expects "text" and "image"
# columns, and the preferred image is stored under "chosen". Adjust to your schema.
subset = subset.rename_columns({"prompt": "text", "chosen": "image"})

# Step 2: Load the FLUX.1-dev base model in bfloat16 and wrap it for evaluation
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)
model = PrunaModel(pipe)

# Step 3: Define the evaluation task (metrics plus data).
# Assumption: from_datasets takes a (train, val, test) tuple; the same subset is reused
# for all three. Verify the argument names against your installed Pruna version.
datamodule = PrunaDataModule.from_datasets(
    datasets=(subset, subset, subset),
    collate_fn="image_generation_collate",
    collate_fn_args={"img_size": 512},
    dataloader_args={"batch_size": 1},
)
task = Task(request=["cmmd"], datamodule=datamodule, device="cuda")

# Step 4: Run the evaluation
agent = EvaluationAgent(task)
results = agent.evaluate(model)

# Step 5: Print the results
print("Evaluation Results on Selected Subset:")
print(results)
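
A single score says little on its own; the useful question is how the smashed model compares with the base model on the same metric. The sketch below assumes the scores from two evaluation runs have been collected into plain `{metric_name: score}` dictionaries, which may require adapting whatever your Pruna version returns. `compare_metrics` is a hypothetical helper, not a Pruna function.

```python
def compare_metrics(base, smashed):
    """Return (metric, base, smashed, delta, relative_change_percent) rows.

    Only metrics present in both dictionaries are compared.
    """
    rows = []
    for name in sorted(base.keys() & smashed.keys()):
        delta = smashed[name] - base[name]
        # Relative change is undefined when the base score is zero
        rel = delta / base[name] * 100 if base[name] else float("nan")
        rows.append((name, base[name], smashed[name], delta, rel))
    return rows


def print_comparison(rows):
    """Print one aligned line per compared metric."""
    for name, b, s, delta, rel in rows:
        print(f"{name:<12} base={b:.4f} smashed={s:.4f} delta={delta:+.4f} ({rel:+.1f}%)")
```

Whether a positive or negative delta is an improvement depends on the metric: CMMD is a distance, so lower values are better.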


Gradio Demo

In [None]:
import gradio as gr
from pruna.engine.pruna_model import PrunaModel

# Load the smashed FLUX model saved earlier
pipe = PrunaModel.from_pretrained("saved_models/FLUX.1-schnell-smashed")

# Define the generation function
def generate(prompt):
    return pipe(prompt).images[0]

# Create the Gradio interface
gr.Interface(fn=generate, inputs="text", outputs="image").launch()
