<a href="https://colab.research.google.com/github/jeaneigsi/cookbook/blob/main/Idefics2_inference_optimization_tradeoffs_between_memory_and_speed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This Colab notebook provides the code to run the memory/speed benchmarks used to optimize Idefics2's inference.

In [None]:
!pip install -q transformers bitsandbytes autoawq accelerate

In [None]:
import requests
import torch
from PIL import Image
from io import BytesIO

from transformers import AutoProcessor, AutoModelForVision2Seq, AwqConfig, BitsAndBytesConfig
from transformers.image_utils import load_image
import time

# Device the processed inputs are moved to; the model is also moved here with
# .to(DEVICE), except when it is loaded with the bitsandbytes configuration.
DEVICE = "cuda:0"

In [None]:
# These are the only variables to change
# Forwarded to the processor as `do_image_splitting`.
IMAGE_SPLITTING = False
# Forwarded to the model loader as `torch_dtype`.
DTYPE = torch.float16 # [torch.float32, torch.float16, torch.bfloat16]
# Forwarded to the model loader as `_attn_implementation`.
# Must be None when AWQ_FUSING is True (the loading cell raises otherwise).
ATTN_IMPLEMENTATION = "flash_attention_2" # [None, "flash_attention_2"]
# "bitsandbytes" builds a 4-bit NF4 BitsAndBytesConfig; "awq" switches to the
# "HuggingFaceM4/idefics2-8b-AWQ" checkpoint; any other non-None value raises.
QUANTIZATION_SCHEME = "awq" # [None, "bitsandbytes", "awq"]
# Only read when QUANTIZATION_SCHEME == "awq": builds an AwqConfig with fused modules.
AWQ_FUSING = False

In [None]:
# Resolve the checkpoint and quantization settings from the configuration
# constants, then load the processor and the model.
model_name = "HuggingFaceM4/idefics2-8b"
# Reset on every run so a config left over from a previous execution of this
# cell can never leak into the current one.
quantization_config = None

if QUANTIZATION_SCHEME == "bitsandbytes":
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
elif QUANTIZATION_SCHEME == "awq":
    if AWQ_FUSING:
        # The restriction applies to fused AWQ modules only; non-fused AWQ
        # is loaded with whatever ATTN_IMPLEMENTATION is set.
        if ATTN_IMPLEMENTATION is not None:
            raise ValueError("Cannot use flash attention with AWQ fusing")
        quantization_config = AwqConfig(
            bits=4,
            fuse_max_seq_len=4096,
            modules_to_fuse={
                "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
                "mlp": ["gate_proj", "up_proj", "down_proj"],
                "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
                "use_alibi": False,
                "num_attention_heads": 32,
                "num_key_value_heads": 8,
                "hidden_size": 4096,
            },
        )
    # AWQ runs use the dedicated AWQ checkpoint, with or without fusing.
    model_name = "HuggingFaceM4/idefics2-8b-AWQ"
elif QUANTIZATION_SCHEME is not None:
    raise ValueError("Unknown configuration")

# The processor is always loaded from the base checkpoint, whatever the
# checkpoint used for the model weights.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=IMAGE_SPLITTING,
)

# NOTE(review): `_attn_implementation` is kept as in the original; the
# documented keyword appears to be `attn_implementation` -- confirm against the
# installed transformers version before switching.
load_kwargs = {
    "torch_dtype": DTYPE,
    "_attn_implementation": ATTN_IMPLEMENTATION,
}
# Only pass a quantization config when one was built above (bitsandbytes or
# fused AWQ); non-fused AWQ and unquantized runs load without the argument.
if quantization_config is not None:
    load_kwargs["quantization_config"] = quantization_config

model = AutoModelForVision2Seq.from_pretrained(model_name, **load_kwargs)

# The bitsandbytes model is deliberately not moved with .to(): presumably it is
# already placed on the GPU at load time and .to() may be rejected for 4-bit
# models -- TODO confirm for the installed transformers version.
if QUANTIZATION_SCHEME != "bitsandbytes":
    model = model.to(DEVICE)


In [None]:
# Build a two-image conversation and convert it into tensors placed on DEVICE.
# Passing the image URLs (instead of the loaded PIL images) to the processor
# is also possible.
def _image_turn(question):
    """Return a user message made of one image placeholder followed by `question`."""
    return {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": question},
        ],
    }


image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")

messages = [
    _image_turn("What do we see in this image?"),
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
        ],
    },
    _image_turn("And how about this image?"),
]

# Render the chat into a prompt that ends with the generation prompt, then
# tokenize it together with both images.
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
encoded = processor(text=prompt, images=[image1, image2], return_tensors="pt")
inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}


In [None]:
# Benchmark: run the same generation NB_REPEATS times, then report the elapsed
# wall-clock time, the peak GPU memory and the output of the last generation.
NB_REPEATS = 20

# CUDA work is queued asynchronously, so drain the queue before starting and
# before stopping the clock; perf_counter is the monotonic clock meant for
# measuring durations.
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NB_REPEATS):
    generated_ids = model.generate(**inputs, max_new_tokens=500)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# The label follows NB_REPEATS instead of hard-coding the repeat count.
print(f"time for {NB_REPEATS} generations:", elapsed)
print("max memory allocated:", torch.cuda.max_memory_allocated())
# Count only the newly generated tokens by slicing off the prompt tokens.
print("number of tokens generated:", len(generated_ids[:, inputs["input_ids"].size(1):][0]))
print(processor.batch_decode(generated_ids, skip_special_tokens=True))