## Packages etc

In [1]:
!pip install datasets
!pip install optimum
!pip install auto-gptq
!pip install -U transformers
!pip install -U bitsandbytes
!pip install editdistance

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [2]:
from google.colab import userdata
import os
from huggingface_hub import login

hf_token = userdata.get('hf_token')
os.environ["HUGGINGFACE_TOKEN"] = hf_token
login(token=hf_token)

## Basics

In [None]:
import numpy as np

# Original matrix
A = np.array([
    [1.2, -2.5, 3.7],
    [-4.1, 5.6, -6.9],
    [7.3, -8.2, 9.0]
])

# Symmetric quantization
def symmetric_quantize(x, bits=8):
    max_val = np.max(np.abs(x))
    scale = max_val / (2**(bits-1) - 1)
    quantized = np.round(x / scale).astype(np.int8)
    return quantized, scale

def symmetric_dequantize(q, scale):
    return q * scale

# Asymmetric quantization
def asymmetric_quantize(x, bits=8):
    min_val = np.min(x)
    max_val = np.max(x)
    scale = (max_val - min_val) / (2**bits - 1)
    zero_point = np.round(-min_val / scale).astype(np.uint8)
    quantized = np.round(x / scale + zero_point).astype(np.uint8)
    return quantized, scale, zero_point

def asymmetric_dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

# Symmetric quantization
q_sym, scale_sym = symmetric_quantize(A)
A_dequant_sym = symmetric_dequantize(q_sym, scale_sym)
error_sym = A - A_dequant_sym

print("Symmetric Quantization:")
print(f"Scale factor: {scale_sym:.6f}")
print("Quantized matrix:")
print(q_sym)
print("Dequantized matrix:")
print(np.around(A_dequant_sym, decimals=4))
print("Quantization error:")
print(np.around(error_sym, decimals=4))
print(f"Mean absolute error: {np.mean(np.abs(error_sym)):.6f}")

# Asymmetric quantization
q_asym, scale_asym, zero_point = asymmetric_quantize(A)
A_dequant_asym = asymmetric_dequantize(q_asym, scale_asym, zero_point)
error_asym = A - A_dequant_asym

print("\nAsymmetric Quantization:")
print(f"Scale factor: {scale_asym:.6f}")
print(f"Zero point: {zero_point}")
print("Quantized matrix:")
print(q_asym)
print("Dequantized matrix:")
print(np.around(A_dequant_asym, decimals=4))
print("Quantization error:")
print(np.around(error_asym, decimals=4))
print(f"Mean absolute error: {np.mean(np.abs(error_asym)):.6f}")

Symmetric Quantization:
Scale factor: 0.070866
Quantized matrix:
[[  17  -35   52]
 [ -58   79  -97]
 [ 103 -116  127]]
Dequantized matrix:
[[ 1.2047 -2.4803  3.685 ]
 [-4.1102  5.5984 -6.874 ]
 [ 7.2992 -8.2205  9.    ]]
Quantization error:
[[-0.0047 -0.0197  0.015 ]
 [ 0.0102  0.0016 -0.026 ]
 [ 0.0008  0.0205  0.    ]]
Mean absolute error: 0.010936

Asymmetric Quantization:
Scale factor: 0.067451
Zero point: 122
Quantized matrix:
[[140  85 177]
 [ 61 205  20]
 [230   0 255]]
Dequantized matrix:
[[ 1.2141 -2.4957  3.7098]
 [-4.1145  5.5984 -6.88  ]
 [ 7.2847 -8.229   8.971 ]]
Quantization error:
[[-0.0141 -0.0043 -0.0098]
 [ 0.0145  0.0016 -0.02  ]
 [ 0.0153  0.029   0.029 ]]
Mean absolute error: 0.015294


## Idefics2-8b activation stats

In [4]:
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
print(model)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/74.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/7 [00:00<?, ?it/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.64G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/4.25G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' in the processor's config. Make sure to move your template to its own file.


preprocessor_config.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.64k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/92.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

Idefics2ForConditionalGeneration(
  (model): Idefics2Model(
    (vision_model): Idefics2VisionTransformer(
      (embeddings): Idefics2VisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4900, 1152)
      )
      (encoder): Idefics2Encoder(
        (layers): ModuleList(
          (0-26): 27 x Idefics2EncoderLayer(
            (self_attn): Idefics2VisionAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): Idefics2VisionMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in

In [4]:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from collections import defaultdict
import numpy as np
import requests
from PIL import Image
from io import BytesIO

def collect_activation_statistics(model_id, image_urls, text_prompts):
    """
    Collect activation statistics using provided image URLs
    """
    print("Loading model and processor...")
    # Load model in FP16 for activation collection
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(model_id)

    activation_stats = defaultdict(list)

    # Register hooks
    hooks = []

    # Simple hook function that captures activation data
    def activation_hook(name):
        def hook(module, inputs, outputs):
            print(f"Hook triggered for {name}")
            # Make sure we have input data
            if not inputs or not isinstance(inputs[0], torch.Tensor):
                print(f"  No valid inputs for {name}")
                return

            # Get input activations
            activation = inputs[0].detach()

            # Calculate and store statistics
            if activation.numel() > 0:
                with torch.no_grad():
                    try:
                        # Convert to float32 for calculations if needed
                        if activation.dtype in [torch.float16, torch.bfloat16]:
                            abs_vals = activation.abs().to(torch.float32)
                        else:
                            abs_vals = activation.abs()

                        # Handle potential NaN values
                        if torch.isnan(abs_vals).any():
                            print(f"  Warning: NaN values detected in {name}")
                            return

                        # Compute statistics directly on GPU when possible
                        mean_val = float(abs_vals.mean().cpu())

                        # For quantiles, ensure we have float32
                        flat_vals = abs_vals.reshape(-1)
                        p99_9 = float(torch.quantile(flat_vals, 0.999).cpu())
                        max_val = float(abs_vals.max().cpu())

                        stats = {
                            'abs_mean': mean_val,
                            'p99.9': p99_9,
                            'max': max_val
                        }

                        activation_stats[name].append(stats)
                        print(f"  Collected stats for {name}: p99.9={stats['p99.9']:.4f}")
                    except Exception as e:
                        print(f"  Error calculating statistics for {name}: {e}")
                        # Try a simpler approach for statistics
                        try:
                            abs_vals_cpu = activation.abs().cpu().to(torch.float32)
                            stats = {
                                'abs_mean': float(abs_vals_cpu.mean()),
                                'p99.9': float(np.percentile(abs_vals_cpu.numpy().flatten(), 99.9)),
                                'max': float(abs_vals_cpu.max())
                            }
                            activation_stats[name].append(stats)
                            print(f"  Collected stats using fallback for {name}: p99.9={stats['p99.9']:.4f}")
                        except Exception as e2:
                            print(f"  Failed to collect statistics (fallback) for {name}: {e2}")

        return hook

    # Register hooks for Linear layers in vision component
    print("Registering hooks...")
    for name, module in model.named_modules():
        # Target specific components that might have outlier activations
        if isinstance(module, torch.nn.Linear) and any(comp in name for comp in [
            "vision_model.encoder",
            "connector",
            "text_model.layers"
        ]):
            hooks.append(module.register_forward_hook(activation_hook(name)))
            print(f"Registered hook for {name}")

    print(f"Registered {len(hooks)} hooks")

    # Process each image URL
    for i, (image_url, text_prompt) in enumerate(zip(image_urls, text_prompts)):
        try:
            print(f"Processing image {i+1}/{len(image_urls)}")

            # Download and process the image
            print(f"Downloading image from {image_url}")
            response = requests.get(image_url)
            image = Image.open(BytesIO(response.content))

            # Create message format
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": text_prompt},
                    ]
                }
            ]

            # Process inputs
            print("Preparing model inputs")
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = processor(text=prompt, images=[image], return_tensors="pt")
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Run forward pass
            print("Running model inference")
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=1)

            print("Completed inference")

        except Exception as e:
            print(f"Error processing image {i+1}: {e}")

    # Remove hooks
    print("Removing hooks")
    for hook in hooks:
        hook.remove()

    # Aggregate statistics
    aggregated_stats = {}
    for name, samples in activation_stats.items():
        if samples:
            aggregated_stats[name] = {
                'mean_abs_mean': np.mean([s['abs_mean'] for s in samples]),
                'mean_p99.9': np.mean([s['p99.9'] for s in samples]),
                'max_p99.9': max([s['p99.9'] for s in samples]),
                'max_max': max([s['max'] for s in samples])
            }

    return aggregated_stats

# Using the provided images
image_urls = [
    "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
    "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
    "https://static.vecteezy.com/system/resources/thumbnails/002/098/203/small_2x/silver-tabby-cat-sitting-on-green-background-free-photo.jpg",
    "https://public-site.marketing.pandadoc-static.com/app/uploads/1040-2017.png",
    "https://images.fineartamerica.com/images-medium-large-5/carpark-viewed-from-above-with-cars-ken-welsh--design-pics.jpg",
    "https://joyfoodsunshine.com/wp-content/uploads/2022/05/summer-fruit-salad-recipe-1.jpg"
]

text_prompts = [
    "What famous landmarks can you see in this image?",
    "Describe this cityscape.",
    "What do you see in this image?",
    "What kind of document is this?",
    "How many cars are in this image?",
    "What is this?"
]

# Run the analysis
activation_stats = collect_activation_statistics(
    "HuggingFaceM4/idefics2-8b",
    image_urls=image_urls,
    text_prompts=text_prompts
)

# Analyze results
if activation_stats:
    # Sort modules by p99.9 values
    sorted_stats = dict(sorted(
        activation_stats.items(),
        key=lambda x: x[1]['mean_p99.9'],
        reverse=True
    ))

    print("\nTop modules with highest activation values:")
    for i, (name, stats) in enumerate(sorted_stats.items()):
        print(f"{name}: p99.9={stats['mean_p99.9']:.4f}")
        if i >= 9:  # Show top 10
            break

    # Determine reasonable threshold
    p99_9_values = [stats['mean_p99.9'] for stats in activation_stats.values()]
    if p99_9_values:
        threshold = np.percentile(p99_9_values, 95)  # 95th percentile
        print(f"\nRecommended threshold: {threshold:.4f}")

        # Identify outlier modules
        print("\nModules to consider skipping (values > 1.2x threshold):")
        for name, stats in sorted_stats.items():
            if stats['mean_p99.9'] > threshold * 1.2:
                print(f"  {name}: p99.9={stats['mean_p99.9']:.4f}")

        # Generate skip list
        print("\nRecommended llm_int8_skip_modules configuration:")
        print("llm_int8_skip_modules=[")
        # Original modules known to need skipping
        for module in [
            "model.vision_model.post_layernorm",
            "model.vision_model.embeddings",
            "model.connector",
            "model.text_model.norm",
            "model.text_model.layers.*.input_layernorm",
            "model.text_model.layers.*.post_attention_layernorm",
            "lm_head"
        ]:
            print(f"    \"{module}\",")
        # Add modules with extreme outliers
        for name, stats in sorted_stats.items():
            if stats['mean_p99.9'] > threshold * 1.5:  # Very high outliers
                print(f"    \"{name}\",")
        print("]")
else:
    print("No statistics were collected.")

Loading model and processor...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/74.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/7 [00:00<?, ?it/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.64G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/4.25G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' in the processor's config. Make sure to move your template to its own file.


preprocessor_config.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.64k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/92.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.04k [00:00<?, ?B/s]

Registering hooks...
Registered hook for model.vision_model.encoder.layers.0.self_attn.k_proj
Registered hook for model.vision_model.encoder.layers.0.self_attn.v_proj
Registered hook for model.vision_model.encoder.layers.0.self_attn.q_proj
Registered hook for model.vision_model.encoder.layers.0.self_attn.out_proj
Registered hook for model.vision_model.encoder.layers.0.mlp.fc1
Registered hook for model.vision_model.encoder.layers.0.mlp.fc2
Registered hook for model.vision_model.encoder.layers.1.self_attn.k_proj
Registered hook for model.vision_model.encoder.layers.1.self_attn.v_proj
Registered hook for model.vision_model.encoder.layers.1.self_attn.q_proj
Registered hook for model.vision_model.encoder.layers.1.self_attn.out_proj
Registered hook for model.vision_model.encoder.layers.1.mlp.fc1
Registered hook for model.vision_model.encoder.layers.1.mlp.fc2
Registered hook for model.vision_model.encoder.layers.2.self_attn.k_proj
Registered hook for model.vision_model.encoder.layers.2.self_a

In [5]:
!pip install datasets matplotlib pandas tqdm



## Quants

In [6]:
import os
import torch
import json
import datetime
from transformers import Idefics2ForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
import transformers
import bitsandbytes as bnb
import gc


# Set paths for Google Colab
BASE_DIR = "/content/quantized_models"
os.makedirs(BASE_DIR, exist_ok=True)

# Create function to quantize and save model
def quantize_and_save_model(quantization_type, quantization_config, model_name_suffix):
    """Quantize and save the IDEFICS2 model with the specified configuration"""

    MODEL_NAME = f"idefics2-8b-{model_name_suffix}"
    SAVE_PATH = os.path.join(BASE_DIR, MODEL_NAME)

    # Create directory
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Clear GPU memory
    torch.cuda.empty_cache()

    # Load processor
    processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
    processor.save_pretrained(SAVE_PATH)

    # Load quantized model
    print(f"Loading and quantizing model in {quantization_type} format...")
    model_quantized = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.bfloat16
    )

    print(f"Saving quantized model to {SAVE_PATH}...")
    model_quantized.save_pretrained(
        SAVE_PATH,
        safe_serialization=True,
        max_shard_size="2GB"
    )

    # Save metadata
    metadata = {
        "base_model": "HuggingFaceM4/idefics2-8b",
        "quantization": model_name_suffix,
        "version": "1.0.0",
        "creation_date": datetime.datetime.now().isoformat(),
        "framework_version": {
            "transformers": transformers.__version__,
            "torch": torch.__version__,
            "bitsandbytes": bnb.__version__
        },
        "note": f"Optimized {quantization_type} quantization with calibrated threshold and skip modules"
    }

    with open(os.path.join(SAVE_PATH, "model_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"Model successfully saved to {SAVE_PATH}")

    # Clear GPU memory again
    del model_quantized
    torch.cuda.empty_cache()
    import gc
    gc.collect()

# Quantize to INT8 with our calibrated parameters
print("=== Quantizing to INT8 ===")
int8_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=36.0,  # Our calibrated value
    llm_int8_has_fp16_weight=False,
    llm_int8_skip_modules=[
        "model.vision_model.post_layernorm",
        "model.vision_model.embeddings",
        "model.connector",
        "model.text_model.norm",
        "model.text_model.layers.*.input_layernorm",
        "model.text_model.layers.*.post_attention_layernorm",
        "lm_head",
        "model.vision_model.encoder.layers.21.mlp.fc1",
        "model.vision_model.encoder.layers.22.mlp.fc1",
        "model.vision_model.encoder.layers.23.mlp.fc1",
        "model.vision_model.encoder.layers.24.mlp.fc1",
        "model.vision_model.encoder.layers.25.mlp.fc1",
        "model.vision_model.encoder.layers.26.mlp.fc1",

    ],
    llm_int8_enable_fp32_cpu_offload=False,
    bnb_8bit_compute_dtype=torch.bfloat16
)

quantize_and_save_model(
    "INT8",
    int8_config,
    "8bit-calibrated"
)

# Quantize to INT4 (NF4)
print("\n=== Quantizing to INT4 (NF4) ===")
int4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # For 4-bit, we keep the same skip modules as 8-bit
    # Plus additional layer norms from the vision model
    llm_int8_skip_modules=[
        "model.vision_model.post_layernorm",
        "model.vision_model.embeddings",
        "model.vision_model.encoder.layers.*.layer_norm1",
        "model.vision_model.encoder.layers.*.layer_norm2",
        "model.connector",
        "model.text_model.norm",
        "model.text_model.layers.*.input_layernorm",
        "model.text_model.layers.*.post_attention_layernorm",
        "model.text_model.embed_tokens",
        "lm_head",
        "model.vision_model.encoder.layers.24.mlp.fc1",
        "model.vision_model.encoder.layers.23.mlp.fc1",
        "model.vision_model.encoder.layers.25.mlp.fc1",
    ]
)

quantize_and_save_model(
    "INT4",
    int4_config,
    "4bit-nf4-calibrated"
)

# Print final status
print("\nQuantization complete! Models saved to Google Drive:")
print(f"INT8 model: {os.path.join(BASE_DIR, 'idefics2-8b-8bit-calibrated')}")
print(f"INT4 model: {os.path.join(BASE_DIR, 'idefics2-8b-4bit-nf4-calibrated')}")

# Test loading a quantized model to verify
print("\nVerifying that we can load the quantized model...")

def test_load_model(model_path):
    try:
        # Load model
        model = Idefics2ForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto"
        )
        processor = AutoProcessor.from_pretrained(model_path)

        print(f"Successfully loaded model from {model_path}")

        # Simple test with a dummy input
        test_text = "What can you see in this image?"
        test_input = processor(text=test_text, return_tensors="pt").to("cuda")

        # Just run a very small generation to verify functionality
        with torch.no_grad():
            outputs = model.generate(
                **test_input,
                max_new_tokens=1,
                do_sample=False
            )

        print("Model successfully executed a test generation")

        # Clean up
        del model
        torch.cuda.empty_cache()
        gc.collect()

        return True
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

print("\nTesting INT8 model:")
int8_success = test_load_model(os.path.join(BASE_DIR, "idefics2-8b-8bit-calibrated"))

print("\nQuantization process complete!")

=== Quantizing to INT8 ===
Loading and quantizing model in INT8 format...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Saving quantized model to /content/quantized_models/idefics2-8b-8bit-calibrated...
Model successfully saved to /content/quantized_models/idefics2-8b-8bit-calibrated

=== Quantizing to INT4 (NF4) ===
Loading and quantizing model in INT4 format...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Saving quantized model to /content/quantized_models/idefics2-8b-4bit-nf4-calibrated...
Model successfully saved to /content/quantized_models/idefics2-8b-4bit-nf4-calibrated

Quantization complete! Models saved to Google Drive:
INT8 model: /content/quantized_models/idefics2-8b-8bit-calibrated
INT4 model: /content/quantized_models/idefics2-8b-4bit-nf4-calibrated

Verifying that we can load the quantized model...

Testing INT8 model:


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Successfully loaded model from /content/quantized_models/idefics2-8b-8bit-calibrated
Model successfully executed a test generation

Quantization process complete!


## DocVQA evals

In [None]:
!pip install flash-attn --no-build-isolation

Collecting flash-attn
  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.7.4.post1-cp311-cp311-linux_x86_64.whl size=187815463 sha256=d944fc7d2f962bce83fc4708c2fc0c21eaf8255962a0b350ae919362a51b7ef2
  Stored in directory: /root/.cache/pip/wheels/3d/88/d8/284b89f56af7d5bf366b10d6b8e251ac8a7c7bf3f04203fb4f
Successfully built flash-attn
Installing collected packages: flash-attn
Successfully installed flash-attn-2.7.4.post1


In [7]:
import torch
import editdistance
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    Idefics2ForConditionalGeneration,
    BitsAndBytesConfig
)
import gc
import os
import time
import json
import numpy as np
from IPython.display import display, HTML


def setup_model(model_type="original"):
    """Setup model with specified quantization type."""
    # Clear GPU memory
    torch.cuda.empty_cache()
    gc.collect()

    # Load processor with efficiency improvements
    processor = AutoProcessor.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        do_image_splitting=False,  # Disable image splitting for memory efficiency
        size={"longest_edge": 448, "shortest_edge": 378}  # Lower resolution
    )

    # Common kwargs for all model variants
    common_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16,  # Using fp16 for better compatibility with Flash Attention
        "low_cpu_mem_usage": True,
        # "_attn_implementation": "flash_attention_2"  # Enable Flash Attention 2
    }

    if model_type == "original":
        model = Idefics2ForConditionalGeneration.from_pretrained(
            "HuggingFaceM4/idefics2-8b",
            **common_kwargs
        )
    elif model_type == "int8":
        model = Idefics2ForConditionalGeneration.from_pretrained(
            "/content/quantized_models/idefics2-8b-8bit-calibrated",
            **common_kwargs
        )
    elif model_type == "int4":
        model = Idefics2ForConditionalGeneration.from_pretrained(
            "/content/quantized_models/idefics2-8b-4bit-nf4-calibrated",
            **common_kwargs
        )

    # Print model configuration details for verification
    print(f"\n=== Model Configuration ({model_type}) ===")
    print(f"Flash Attention 2: Enabled")
    print(f"Image Splitting: Disabled")
    print(f"Float Type: fp16")
    print(f"Quantization: {model_type}")

    model.eval()
    return processor, model

def clean_answer(text):
    """Remove only the trailing period if it exists."""
    return text.rstrip('.')

def process_single_question(processor, model, image, question):
    """Process a single VQA query."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{question} Answer with ONLY the specific text or value from the document, no explanations or full sentences."},
            ]
        }
    ]

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        text=prompt,
        images=[image],
        return_tensors="pt"
    )
    inputs = {k: v.cuda() for k, v in inputs.items()}

    # Measure inference time
    start_time = time.time()
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=30,
            num_beams=1,
            do_sample=False
        )
    end_time = time.time()

    response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    response = response.split("Assistant: ")[-1].strip()

    # Clear memory
    del inputs
    del generated_ids
    torch.cuda.empty_cache()

    return clean_answer(response), end_time - start_time

def get_anls(s1, s2):
    """Calculate Average Normalized Levenshtein Similarity."""
    s1 = s1.lower().strip()
    s2 = s2.lower().strip()
    iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))
    return iou if iou >= 0.5 else 0.0

def benchmark_model(model_type, dataset, sample_size=50):
    """Benchmark a specific model type on DocVQA."""
    print(f"\n=== Benchmarking {model_type} model ===")

    # Load model
    processor, model = setup_model(model_type)
    analyze_memory_usage(model)

    # Track memory usage
    gpu_memory_start = torch.cuda.memory_reserved() / 1024**3

    # Prepare results tracking
    scores = []
    inference_times = []
    results_log = []

    # Process only a subset for quicker benchmarking
    subset = dataset.select(range(min(sample_size, len(dataset))))

    # Track total QA pairs
    total_qa_pairs = sum(len(example["qa"]) for example in subset)
    print(f"Processing {len(subset)} documents with {total_qa_pairs} QA pairs")

    # Track progress with a single progress bar for all QA pairs
    progress_bar = tqdm(total=total_qa_pairs, desc=f"{model_type} Progress")

    try:
        for example_idx, example in enumerate(subset):
            image = example["image"]

            for qa in example["qa"]:
                question = qa["question"]
                ground_truth_answers = qa["answers"]

                # Process the question
                model_answer, inference_time = process_single_question(
                    processor, model, image, question
                )

                # Calculate best ANLS score
                best_anls = max(
                    get_anls(model_answer, gt)
                    for gt in ground_truth_answers
                )

                # Track metrics
                scores.append(best_anls)
                inference_times.append(inference_time)

                # Log detailed results
                results_log.append({
                    'question': question,
                    'model_answer': model_answer,
                    'ground_truth': ground_truth_answers,
                    'anls': best_anls,
                    'inference_time': inference_time
                })

                # Update progress
                progress_bar.update(1)

                # Show a few examples
                if len(scores) <= 3:
                    print(f"\nQuestion: {question}")
                    print(f"Model answer: {model_answer}")
                    print(f"Ground truth: {ground_truth_answers}")
                    print(f"ANLS: {best_anls:.4f}")

                # Show metrics every 10 examples
                if len(scores) % 10 == 0:
                    current_avg = sum(scores) / len(scores)
                    avg_time = sum(inference_times) / len(inference_times)
                    print(f"\nCurrent metrics after {len(scores)} examples:")
                    print(f"Average ANLS: {current_avg:.4f}")
                    print(f"Average inference time: {avg_time:.2f}s")

    except Exception as e:
        print(f"Error during evaluation: {e}")

    finally:
        # Close progress bar
        progress_bar.close()

        # Calculate final metrics
        final_anls = sum(scores) / len(scores) if scores else 0
        avg_inference_time = sum(inference_times) / len(inference_times) if inference_times else 0

        # Get memory usage
        gpu_memory_peak = torch.cuda.max_memory_reserved() / 1024**3

        # Clean up
        del model
        torch.cuda.empty_cache()
        gc.collect()

        # Return metrics
        return {
            "model_type": model_type,
            "anls_score": final_anls,
            "avg_inference_time": avg_inference_time,
            "num_examples": len(scores),
            "gpu_memory_peak_gb": gpu_memory_peak,
            "sample_results": results_log[:5]  # First 5 examples
        }

def analyze_memory_usage(model):
    """Analyze detailed memory usage of the model."""
    print("\n=== Memory Usage Analysis ===")

    # 1. Model parameters memory
    param_count = sum(p.numel() for p in model.parameters())
    param_size_gb = param_count * 2 / (1024**3)  # assuming bfloat16
    print(f"Parameter count: {param_count:,}")
    print(f"Parameters size (bfloat16): {param_size_gb:.2f} GB")

    # 2. Check if there are any unexpected FP32 parameters
    fp32_params = sum(p.numel() for p in model.parameters() if p.dtype == torch.float32)
    if fp32_params > 0:
        fp32_size_gb = fp32_params * 4 / (1024**3)
        print(f"WARNING: Found {fp32_params:,} parameters in FP32 ({fp32_size_gb:.2f} GB)")

    # 3. Check CUDA memory allocations
    if hasattr(torch.cuda, 'memory_allocated'):
        allocated = torch.cuda.memory_allocated() / (1024**3)
        reserved = torch.cuda.memory_reserved() / (1024**3)
        print(f"CUDA memory allocated: {allocated:.2f} GB")
        print(f"CUDA memory reserved: {reserved:.2f} GB")

    # 4. Check if KV cache is being properly managed
    print("\nModule types that might contribute to memory usage:")
    module_types = {}
    for name, module in model.named_modules():
        module_type = type(module).__name__
        if module_type not in module_types:
            module_types[module_type] = 1
        else:
            module_types[module_type] += 1

    for module_type, count in sorted(module_types.items(), key=lambda x: x[1], reverse=True)[:10]:
        print(f"  {module_type}: {count}")

def run_docvqa_benchmarks():
    """Run DocVQA benchmarks on all model variants."""
    print("Starting IDEFICS2 DocVQA benchmark")

    # Load a small sample of DocVQA validation set
    print("Loading DocVQA validation dataset...")
    dataset = load_dataset("vikhyatk/docvqa-val", split="validation")

    # Determine sample size based on available memory
    sample_size = 100

    # Run benchmarks for each model type
    results = {}

    try:
        # Original bfloat16 model
        results["original"] = benchmark_model("original", dataset, sample_size)

        # INT8 model with calibrated parameters
        results["int8"] = benchmark_model("int8", dataset, sample_size)

        # INT4 model
        results["int4"] = benchmark_model("int4", dataset, sample_size)

    except Exception as e:
        print(f"Error during benchmarking: {e}")

    # Display summary table
    print("\n=== DocVQA Benchmark Results ===")
    print("Model Type | ANLS Score | Avg Inference Time (s) | Peak GPU Memory (GB)")
    print("----------|------------|------------------------|--------------------")

    for model_type in ["original", "int8", "int4"]:
        if model_type in results:
            r = results[model_type]
            print(f"{model_type.ljust(10)} | {r['anls_score']:.4f} | {r['avg_inference_time']:.4f} | {r['gpu_memory_peak_gb']:.2f}")

    # Calculate relative metrics if original model results exist
    if "original" in results and results["original"]["anls_score"] > 0:
        print("\n=== Relative Performance ===")
        orig_anls = results["original"]["anls_score"]
        orig_time = results["original"]["avg_inference_time"]
        orig_mem = results["original"]["gpu_memory_peak_gb"]

        for model_type in ["int8", "int4"]:
            if model_type in results:
                r = results[model_type]
                anls_ratio = r["anls_score"] / orig_anls if orig_anls > 0 else 0
                time_ratio = orig_time / r["avg_inference_time"] if r["avg_inference_time"] > 0 else 0
                mem_ratio = r["gpu_memory_peak_gb"] / orig_mem if orig_mem > 0 else 0

                print(f"{model_type} vs original:")
                print(f"  - Accuracy: {anls_ratio:.2f}x ({anls_ratio*100:.1f}%)")
                print(f"  - Speed: {time_ratio:.2f}x faster")
                print(f"  - Memory: {mem_ratio:.2f}x ({mem_ratio*100:.1f}% of original)")

    # Save detailed results
    with open("/content/idefics2_docvqa_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print("\nDetailed results saved to /content/idefics2_docvqa_results.json")

    return results

run_docvqa_benchmarks()

Starting IDEFICS2 DocVQA benchmark
Loading DocVQA validation dataset...


README.md:   0%|          | 0.00/404 [00:00<?, ?B/s]

validation-00000-of-00002.parquet:   0%|          | 0.00/418M [00:00<?, ?B/s]

validation-00001-of-00002.parquet:   0%|          | 0.00/415M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/1286 [00:00<?, ? examples/s]


=== Benchmarking original model ===


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]


=== Model Configuration (original) ===
Flash Attention 2: Enabled
Image Splitting: Disabled
Float Type: fp16
Quantization: original

=== Memory Usage Analysis ===
Parameter count: 8,402,768,112
Parameters size (bfloat16): 15.65 GB
CUDA memory allocated: 31.34 GB
CUDA memory reserved: 31.55 GB

Module types that might contribute to memory usage:
  Linear: 411
  MistralRMSNorm: 65
  LayerNorm: 55
  SiLU: 36
  MistralDecoderLayer: 32
  MistralAttention: 32
  MistralMLP: 32
  Idefics2EncoderLayer: 27
  Idefics2VisionAttention: 27
  Idefics2VisionMLP: 27
Processing 100 documents with 508 QA pairs


original Progress:   0%|          | 1/508 [00:01<10:40,  1.26s/it]


Question: On which date should the form be completed and send ?
Model answer: September 28
Ground truth: ['September 28']
ANLS: 1.0000


original Progress:   0%|          | 2/508 [00:01<07:18,  1.15it/s]


Question: What is the date mentioned in the form ?
Model answer: September 29, 1966
Ground truth: ['September 28, 1988']
ANLS: 0.8333


original Progress:   1%|          | 3/508 [00:02<05:25,  1.55it/s]


Question: What is the year of operating plan ?
Model answer: 1969
Ground truth: ['1989']
ANLS: 0.7500


original Progress:   2%|▏         | 10/508 [00:04<03:22,  2.46it/s]


Current metrics after 10 examples:
Average ANLS: 0.6925
Average inference time: 0.44s


original Progress:   4%|▍         | 20/508 [00:08<03:32,  2.29it/s]


Current metrics after 20 examples:
Average ANLS: 0.6705
Average inference time: 0.38s


original Progress:   6%|▌         | 30/508 [00:13<04:07,  1.93it/s]


Current metrics after 30 examples:
Average ANLS: 0.6056
Average inference time: 0.39s


original Progress:   8%|▊         | 40/508 [00:17<02:39,  2.93it/s]


Current metrics after 40 examples:
Average ANLS: 0.5625
Average inference time: 0.38s


original Progress:  10%|▉         | 50/508 [00:20<03:26,  2.21it/s]


Current metrics after 50 examples:
Average ANLS: 0.5975
Average inference time: 0.36s


original Progress:  12%|█▏        | 60/508 [00:26<03:44,  1.99it/s]


Current metrics after 60 examples:
Average ANLS: 0.5791
Average inference time: 0.37s


original Progress:  14%|█▍        | 70/508 [00:31<03:34,  2.04it/s]


Current metrics after 70 examples:
Average ANLS: 0.5565
Average inference time: 0.38s


original Progress:  16%|█▌        | 80/508 [00:35<02:40,  2.66it/s]


Current metrics after 80 examples:
Average ANLS: 0.5498
Average inference time: 0.37s


original Progress:  18%|█▊        | 90/508 [00:39<02:54,  2.40it/s]


Current metrics after 90 examples:
Average ANLS: 0.5675
Average inference time: 0.36s


original Progress:  20%|█▉        | 100/508 [00:43<02:35,  2.62it/s]


Current metrics after 100 examples:
Average ANLS: 0.5748
Average inference time: 0.37s


original Progress:  22%|██▏       | 110/508 [00:49<02:25,  2.73it/s]


Current metrics after 110 examples:
Average ANLS: 0.5497
Average inference time: 0.37s


original Progress:  24%|██▎       | 120/508 [00:53<02:50,  2.27it/s]


Current metrics after 120 examples:
Average ANLS: 0.5675
Average inference time: 0.37s


original Progress:  26%|██▌       | 130/508 [00:58<03:41,  1.71it/s]


Current metrics after 130 examples:
Average ANLS: 0.5550
Average inference time: 0.38s


original Progress:  28%|██▊       | 140/508 [01:02<02:09,  2.85it/s]


Current metrics after 140 examples:
Average ANLS: 0.5672
Average inference time: 0.38s


original Progress:  30%|██▉       | 150/508 [01:05<02:02,  2.92it/s]


Current metrics after 150 examples:
Average ANLS: 0.5791
Average inference time: 0.37s


original Progress:  31%|███▏      | 160/508 [01:09<02:16,  2.56it/s]


Current metrics after 160 examples:
Average ANLS: 0.5688
Average inference time: 0.37s


original Progress:  33%|███▎      | 170/508 [01:13<02:25,  2.32it/s]


Current metrics after 170 examples:
Average ANLS: 0.5729
Average inference time: 0.37s


original Progress:  35%|███▌      | 180/508 [01:17<02:17,  2.39it/s]


Current metrics after 180 examples:
Average ANLS: 0.5550
Average inference time: 0.37s


original Progress:  37%|███▋      | 190/508 [01:21<01:55,  2.75it/s]


Current metrics after 190 examples:
Average ANLS: 0.5519
Average inference time: 0.36s


original Progress:  39%|███▉      | 200/508 [01:26<02:55,  1.76it/s]


Current metrics after 200 examples:
Average ANLS: 0.5407
Average inference time: 0.37s


original Progress:  41%|████▏     | 210/508 [01:30<01:58,  2.52it/s]


Current metrics after 210 examples:
Average ANLS: 0.5480
Average inference time: 0.37s


original Progress:  43%|████▎     | 220/508 [01:34<01:40,  2.86it/s]


Current metrics after 220 examples:
Average ANLS: 0.5536
Average inference time: 0.37s


original Progress:  45%|████▌     | 230/508 [01:39<01:45,  2.64it/s]


Current metrics after 230 examples:
Average ANLS: 0.5481
Average inference time: 0.37s


original Progress:  47%|████▋     | 240/508 [01:43<02:08,  2.09it/s]


Current metrics after 240 examples:
Average ANLS: 0.5392
Average inference time: 0.37s


original Progress:  49%|████▉     | 250/508 [01:47<01:42,  2.52it/s]


Current metrics after 250 examples:
Average ANLS: 0.5417
Average inference time: 0.37s


original Progress:  51%|█████     | 260/508 [01:51<01:38,  2.53it/s]


Current metrics after 260 examples:
Average ANLS: 0.5342
Average inference time: 0.37s


original Progress:  53%|█████▎    | 270/508 [01:55<01:50,  2.15it/s]


Current metrics after 270 examples:
Average ANLS: 0.5276
Average inference time: 0.37s


original Progress:  55%|█████▌    | 280/508 [02:01<01:46,  2.13it/s]


Current metrics after 280 examples:
Average ANLS: 0.5194
Average inference time: 0.37s


original Progress:  57%|█████▋    | 290/508 [02:06<01:59,  1.83it/s]


Current metrics after 290 examples:
Average ANLS: 0.5101
Average inference time: 0.37s


original Progress:  59%|█████▉    | 300/508 [02:10<01:37,  2.13it/s]


Current metrics after 300 examples:
Average ANLS: 0.5034
Average inference time: 0.38s


original Progress:  61%|██████    | 310/508 [02:15<01:30,  2.19it/s]


Current metrics after 310 examples:
Average ANLS: 0.5035
Average inference time: 0.38s


original Progress:  63%|██████▎   | 320/508 [02:19<01:03,  2.97it/s]


Current metrics after 320 examples:
Average ANLS: 0.5054
Average inference time: 0.38s


original Progress:  65%|██████▍   | 330/508 [02:23<01:25,  2.07it/s]


Current metrics after 330 examples:
Average ANLS: 0.5117
Average inference time: 0.38s


original Progress:  67%|██████▋   | 340/508 [02:28<01:15,  2.24it/s]


Current metrics after 340 examples:
Average ANLS: 0.5189
Average inference time: 0.38s


original Progress:  69%|██████▉   | 350/508 [02:31<01:01,  2.57it/s]


Current metrics after 350 examples:
Average ANLS: 0.5203
Average inference time: 0.37s


original Progress:  71%|███████   | 360/508 [02:35<00:42,  3.52it/s]


Current metrics after 360 examples:
Average ANLS: 0.5169
Average inference time: 0.37s


original Progress:  73%|███████▎  | 370/508 [02:39<00:55,  2.50it/s]


Current metrics after 370 examples:
Average ANLS: 0.5174
Average inference time: 0.37s


original Progress:  75%|███████▍  | 380/508 [02:44<01:04,  2.00it/s]


Current metrics after 380 examples:
Average ANLS: 0.5079
Average inference time: 0.37s


original Progress:  77%|███████▋  | 390/508 [02:48<00:47,  2.46it/s]


Current metrics after 390 examples:
Average ANLS: 0.5067
Average inference time: 0.38s


original Progress:  79%|███████▊  | 400/508 [02:53<00:45,  2.40it/s]


Current metrics after 400 examples:
Average ANLS: 0.5056
Average inference time: 0.38s


original Progress:  81%|████████  | 410/508 [02:57<00:31,  3.14it/s]


Current metrics after 410 examples:
Average ANLS: 0.5103
Average inference time: 0.38s


original Progress:  83%|████████▎ | 420/508 [03:01<00:33,  2.60it/s]


Current metrics after 420 examples:
Average ANLS: 0.5066
Average inference time: 0.37s


original Progress:  85%|████████▍ | 430/508 [03:05<00:28,  2.72it/s]


Current metrics after 430 examples:
Average ANLS: 0.5043
Average inference time: 0.37s


original Progress:  87%|████████▋ | 440/508 [03:09<00:22,  2.96it/s]


Current metrics after 440 examples:
Average ANLS: 0.5092
Average inference time: 0.37s


original Progress:  89%|████████▊ | 450/508 [03:13<00:23,  2.47it/s]


Current metrics after 450 examples:
Average ANLS: 0.5119
Average inference time: 0.37s


original Progress:  91%|█████████ | 460/508 [03:18<00:25,  1.86it/s]


Current metrics after 460 examples:
Average ANLS: 0.5073
Average inference time: 0.37s


original Progress:  93%|█████████▎| 470/508 [03:23<00:17,  2.14it/s]


Current metrics after 470 examples:
Average ANLS: 0.5104
Average inference time: 0.37s


original Progress:  94%|█████████▍| 480/508 [03:27<00:15,  1.77it/s]


Current metrics after 480 examples:
Average ANLS: 0.5162
Average inference time: 0.38s


original Progress:  96%|█████████▋| 490/508 [03:32<00:08,  2.03it/s]


Current metrics after 490 examples:
Average ANLS: 0.5153
Average inference time: 0.38s


original Progress:  99%|█████████▊| 501/508 [03:36<00:02,  2.79it/s]


Current metrics after 500 examples:
Average ANLS: 0.5117
Average inference time: 0.37s


original Progress: 100%|██████████| 508/508 [03:38<00:00,  2.32it/s]



=== Benchmarking int8 model ===


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]


=== Model Configuration (int8) ===
Flash Attention 2: Enabled
Image Splitting: Disabled
Float Type: fp16
Quantization: int8

=== Memory Usage Analysis ===
Parameter count: 8,402,768,112
Parameters size (bfloat16): 15.65 GB
CUDA memory allocated: 24.49 GB
CUDA memory reserved: 24.57 GB

Module types that might contribute to memory usage:
  Linear8bitLt: 380
  MistralRMSNorm: 65
  LayerNorm: 55
  SiLU: 36
  MistralDecoderLayer: 32
  MistralAttention: 32
  MistralMLP: 32
  Linear: 31
  Idefics2EncoderLayer: 27
  Idefics2VisionAttention: 27
Processing 100 documents with 508 QA pairs


int8 Progress:   0%|          | 1/508 [00:00<07:48,  1.08it/s]


Question: On which date should the form be completed and send ?
Model answer: September 28
Ground truth: ['September 28']
ANLS: 1.0000


int8 Progress:   0%|          | 2/508 [00:02<11:19,  1.34s/it]


Question: What is the date mentioned in the form ?
Model answer: September 29, 1969
Ground truth: ['September 28, 1988']
ANLS: 0.8333


int8 Progress:   1%|          | 3/508 [00:03<10:10,  1.21s/it]


Question: What is the year of operating plan ?
Model answer: 1969
Ground truth: ['1989']
ANLS: 0.7500


int8 Progress:   2%|▏         | 10/508 [00:10<08:41,  1.05s/it]


Current metrics after 10 examples:
Average ANLS: 0.6369
Average inference time: 1.04s


int8 Progress:   4%|▍         | 20/508 [00:22<09:50,  1.21s/it]


Current metrics after 20 examples:
Average ANLS: 0.6427
Average inference time: 1.08s


int8 Progress:   6%|▌         | 30/508 [00:35<10:59,  1.38s/it]


Current metrics after 30 examples:
Average ANLS: 0.5893
Average inference time: 1.11s


int8 Progress:   8%|▊         | 40/508 [00:44<06:53,  1.13it/s]


Current metrics after 40 examples:
Average ANLS: 0.5462
Average inference time: 1.05s


int8 Progress:  10%|▉         | 50/508 [00:53<09:24,  1.23s/it]


Current metrics after 50 examples:
Average ANLS: 0.5844
Average inference time: 1.02s


int8 Progress:  12%|█▏        | 60/508 [01:09<09:13,  1.23s/it]


Current metrics after 60 examples:
Average ANLS: 0.5589
Average inference time: 1.07s


int8 Progress:  14%|█▍        | 70/508 [01:21<09:26,  1.29s/it]


Current metrics after 70 examples:
Average ANLS: 0.5343
Average inference time: 1.09s


int8 Progress:  16%|█▌        | 80/508 [01:32<07:02,  1.01it/s]


Current metrics after 80 examples:
Average ANLS: 0.5178
Average inference time: 1.08s


int8 Progress:  18%|█▊        | 90/508 [01:42<07:56,  1.14s/it]


Current metrics after 90 examples:
Average ANLS: 0.5394
Average inference time: 1.07s


int8 Progress:  20%|█▉        | 100/508 [01:54<07:05,  1.04s/it]


Current metrics after 100 examples:
Average ANLS: 0.5468
Average inference time: 1.07s


int8 Progress:  22%|██▏       | 110/508 [02:07<06:05,  1.09it/s]


Current metrics after 110 examples:
Average ANLS: 0.5242
Average inference time: 1.09s


int8 Progress:  24%|██▎       | 120/508 [02:19<08:05,  1.25s/it]


Current metrics after 120 examples:
Average ANLS: 0.5377
Average inference time: 1.09s


int8 Progress:  26%|██▌       | 130/508 [02:34<10:01,  1.59s/it]


Current metrics after 130 examples:
Average ANLS: 0.5275
Average inference time: 1.12s


int8 Progress:  28%|██▊       | 140/508 [02:45<05:55,  1.04it/s]


Current metrics after 140 examples:
Average ANLS: 0.5417
Average inference time: 1.11s


int8 Progress:  30%|██▉       | 150/508 [02:53<05:38,  1.06it/s]


Current metrics after 150 examples:
Average ANLS: 0.5527
Average inference time: 1.09s


int8 Progress:  31%|███▏      | 160/508 [03:03<05:57,  1.03s/it]


Current metrics after 160 examples:
Average ANLS: 0.5441
Average inference time: 1.08s


int8 Progress:  33%|███▎      | 170/508 [03:16<06:50,  1.22s/it]


Current metrics after 170 examples:
Average ANLS: 0.5492
Average inference time: 1.09s


int8 Progress:  35%|███▌      | 180/508 [03:28<06:26,  1.18s/it]


Current metrics after 180 examples:
Average ANLS: 0.5298
Average inference time: 1.09s


int8 Progress:  37%|███▋      | 190/508 [03:37<04:59,  1.06it/s]


Current metrics after 190 examples:
Average ANLS: 0.5254
Average inference time: 1.08s


int8 Progress:  39%|███▉      | 200/508 [03:50<08:05,  1.58s/it]


Current metrics after 200 examples:
Average ANLS: 0.5156
Average inference time: 1.09s


int8 Progress:  41%|████▏     | 210/508 [04:01<05:07,  1.03s/it]


Current metrics after 210 examples:
Average ANLS: 0.5276
Average inference time: 1.09s


int8 Progress:  43%|████▎     | 220/508 [04:12<04:45,  1.01it/s]


Current metrics after 220 examples:
Average ANLS: 0.5341
Average inference time: 1.09s


int8 Progress:  45%|████▌     | 230/508 [04:25<04:42,  1.02s/it]


Current metrics after 230 examples:
Average ANLS: 0.5272
Average inference time: 1.09s


int8 Progress:  47%|████▋     | 240/508 [04:37<05:46,  1.29s/it]


Current metrics after 240 examples:
Average ANLS: 0.5201
Average inference time: 1.09s


int8 Progress:  49%|████▉     | 250/508 [04:51<06:24,  1.49s/it]


Current metrics after 250 examples:
Average ANLS: 0.5214
Average inference time: 1.10s


int8 Progress:  51%|█████     | 260/508 [05:02<05:29,  1.33s/it]


Current metrics after 260 examples:
Average ANLS: 0.5146
Average inference time: 1.10s


int8 Progress:  53%|█████▎    | 270/508 [05:13<04:45,  1.20s/it]


Current metrics after 270 examples:
Average ANLS: 0.5116
Average inference time: 1.10s


int8 Progress:  55%|█████▌    | 280/508 [05:28<04:50,  1.27s/it]


Current metrics after 280 examples:
Average ANLS: 0.5062
Average inference time: 1.11s


int8 Progress:  57%|█████▋    | 290/508 [05:41<04:57,  1.36s/it]


Current metrics after 290 examples:
Average ANLS: 0.4991
Average inference time: 1.12s


int8 Progress:  59%|█████▉    | 300/508 [05:52<04:17,  1.24s/it]


Current metrics after 300 examples:
Average ANLS: 0.4891
Average inference time: 1.12s


int8 Progress:  61%|██████    | 310/508 [06:06<04:16,  1.29s/it]


Current metrics after 310 examples:
Average ANLS: 0.4871
Average inference time: 1.12s


int8 Progress:  63%|██████▎   | 320/508 [06:17<02:46,  1.13it/s]


Current metrics after 320 examples:
Average ANLS: 0.4903
Average inference time: 1.12s


int8 Progress:  65%|██████▍   | 330/508 [06:28<03:49,  1.29s/it]


Current metrics after 330 examples:
Average ANLS: 0.4973
Average inference time: 1.12s


int8 Progress:  67%|██████▋   | 340/508 [06:40<03:17,  1.18s/it]


Current metrics after 340 examples:
Average ANLS: 0.5046
Average inference time: 1.12s


int8 Progress:  69%|██████▉   | 350/508 [06:50<02:43,  1.03s/it]


Current metrics after 350 examples:
Average ANLS: 0.5035
Average inference time: 1.11s


int8 Progress:  71%|███████   | 360/508 [06:59<01:52,  1.32it/s]


Current metrics after 360 examples:
Average ANLS: 0.5020
Average inference time: 1.10s


int8 Progress:  73%|███████▎  | 370/508 [07:10<02:36,  1.13s/it]


Current metrics after 370 examples:
Average ANLS: 0.5029
Average inference time: 1.10s


int8 Progress:  75%|███████▍  | 380/508 [07:23<02:56,  1.38s/it]


Current metrics after 380 examples:
Average ANLS: 0.4937
Average inference time: 1.11s


int8 Progress:  77%|███████▋  | 390/508 [07:37<02:09,  1.09s/it]


Current metrics after 390 examples:
Average ANLS: 0.4907
Average inference time: 1.11s


int8 Progress:  79%|███████▊  | 400/508 [07:48<01:59,  1.11s/it]


Current metrics after 400 examples:
Average ANLS: 0.4906
Average inference time: 1.11s


int8 Progress:  81%|████████  | 410/508 [07:59<01:22,  1.19it/s]


Current metrics after 410 examples:
Average ANLS: 0.4956
Average inference time: 1.11s


int8 Progress:  83%|████████▎ | 420/508 [08:09<01:30,  1.03s/it]


Current metrics after 420 examples:
Average ANLS: 0.4950
Average inference time: 1.11s


int8 Progress:  85%|████████▍ | 430/508 [08:20<01:17,  1.01it/s]


Current metrics after 430 examples:
Average ANLS: 0.4929
Average inference time: 1.11s


int8 Progress:  87%|████████▋ | 440/508 [08:31<01:08,  1.01s/it]


Current metrics after 440 examples:
Average ANLS: 0.4970
Average inference time: 1.10s


int8 Progress:  89%|████████▊ | 450/508 [08:43<00:59,  1.03s/it]


Current metrics after 450 examples:
Average ANLS: 0.5011
Average inference time: 1.10s


int8 Progress:  91%|█████████ | 460/508 [08:52<00:45,  1.06it/s]


Current metrics after 460 examples:
Average ANLS: 0.4963
Average inference time: 1.10s


int8 Progress:  93%|█████████▎| 470/508 [09:05<00:49,  1.31s/it]


Current metrics after 470 examples:
Average ANLS: 0.4996
Average inference time: 1.10s


int8 Progress:  94%|█████████▍| 480/508 [09:18<00:44,  1.61s/it]


Current metrics after 480 examples:
Average ANLS: 0.5063
Average inference time: 1.11s


int8 Progress:  96%|█████████▋| 490/508 [09:30<00:21,  1.19s/it]


Current metrics after 490 examples:
Average ANLS: 0.5038
Average inference time: 1.11s


int8 Progress:  98%|█████████▊| 500/508 [09:40<00:09,  1.17s/it]


Current metrics after 500 examples:
Average ANLS: 0.4998
Average inference time: 1.10s


int8 Progress: 100%|██████████| 508/508 [09:48<00:00,  1.16s/it]



=== Benchmarking int4 model ===


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


=== Model Configuration (int4) ===
Flash Attention 2: Enabled
Image Splitting: Disabled
Float Type: fp16
Quantization: int4

=== Memory Usage Analysis ===
Parameter count: 4,715,009,264
Parameters size (bfloat16): 8.78 GB
CUDA memory allocated: 21.14 GB
CUDA memory reserved: 21.23 GB

Module types that might contribute to memory usage:
  Linear4bit: 383
  MistralRMSNorm: 65
  LayerNorm: 55
  SiLU: 36
  MistralDecoderLayer: 32
  MistralAttention: 32
  MistralMLP: 32
  Linear: 28
  Idefics2EncoderLayer: 27
  Idefics2VisionAttention: 27
Processing 100 documents with 508 QA pairs


int4 Progress:   0%|          | 1/508 [00:00<06:28,  1.30it/s]


Question: On which date should the form be completed and send ?
Model answer: September 28
Ground truth: ['September 28']
ANLS: 1.0000


int4 Progress:   0%|          | 2/508 [00:01<07:13,  1.17it/s]


Question: What is the date mentioned in the form ?
Model answer: September 29, 1966
Ground truth: ['September 28, 1988']
ANLS: 0.8333


int4 Progress:   1%|          | 3/508 [00:02<06:13,  1.35it/s]


Question: What is the year of operating plan ?
Model answer: 1969
Ground truth: ['1989']
ANLS: 0.7500


int4 Progress:   2%|▏         | 10/508 [00:07<05:34,  1.49it/s]


Current metrics after 10 examples:
Average ANLS: 0.6928
Average inference time: 0.66s


int4 Progress:   4%|▍         | 20/508 [00:13<06:09,  1.32it/s]


Current metrics after 20 examples:
Average ANLS: 0.6823
Average inference time: 0.62s


int4 Progress:   6%|▌         | 30/508 [00:21<06:35,  1.21it/s]


Current metrics after 30 examples:
Average ANLS: 0.5802
Average inference time: 0.67s


int4 Progress:   8%|▊         | 40/508 [00:27<04:12,  1.85it/s]


Current metrics after 40 examples:
Average ANLS: 0.5435
Average inference time: 0.64s


int4 Progress:  10%|▉         | 50/508 [00:33<04:33,  1.68it/s]


Current metrics after 50 examples:
Average ANLS: 0.5787
Average inference time: 0.61s


int4 Progress:  12%|█▏        | 60/508 [00:41<05:14,  1.43it/s]


Current metrics after 60 examples:
Average ANLS: 0.5481
Average inference time: 0.61s


int4 Progress:  14%|█▍        | 70/508 [00:49<06:12,  1.17it/s]


Current metrics after 70 examples:
Average ANLS: 0.5294
Average inference time: 0.62s


int4 Progress:  16%|█▌        | 80/508 [00:55<04:16,  1.67it/s]


Current metrics after 80 examples:
Average ANLS: 0.5208
Average inference time: 0.61s


int4 Progress:  18%|█▊        | 90/508 [01:01<04:39,  1.49it/s]


Current metrics after 90 examples:
Average ANLS: 0.5426
Average inference time: 0.61s


int4 Progress:  20%|█▉        | 100/508 [01:08<04:10,  1.63it/s]


Current metrics after 100 examples:
Average ANLS: 0.5323
Average inference time: 0.62s


int4 Progress:  22%|██▏       | 110/508 [01:15<03:23,  1.96it/s]


Current metrics after 110 examples:
Average ANLS: 0.5021
Average inference time: 0.61s


int4 Progress:  24%|██▎       | 120/508 [01:21<04:41,  1.38it/s]


Current metrics after 120 examples:
Average ANLS: 0.5156
Average inference time: 0.61s


int4 Progress:  26%|██▌       | 130/508 [01:30<05:51,  1.08it/s]


Current metrics after 130 examples:
Average ANLS: 0.5071
Average inference time: 0.63s


int4 Progress:  28%|██▊       | 140/508 [01:36<03:20,  1.84it/s]


Current metrics after 140 examples:
Average ANLS: 0.5241
Average inference time: 0.62s


int4 Progress:  30%|██▉       | 150/508 [01:41<03:17,  1.81it/s]


Current metrics after 150 examples:
Average ANLS: 0.5455
Average inference time: 0.61s


int4 Progress:  31%|███▏      | 160/508 [01:47<03:32,  1.64it/s]


Current metrics after 160 examples:
Average ANLS: 0.5342
Average inference time: 0.60s


int4 Progress:  33%|███▎      | 170/508 [01:53<03:49,  1.47it/s]


Current metrics after 170 examples:
Average ANLS: 0.5368
Average inference time: 0.60s


int4 Progress:  35%|███▌      | 180/508 [02:01<03:54,  1.40it/s]


Current metrics after 180 examples:
Average ANLS: 0.5181
Average inference time: 0.61s


int4 Progress:  37%|███▋      | 190/508 [02:07<02:58,  1.78it/s]


Current metrics after 190 examples:
Average ANLS: 0.5119
Average inference time: 0.60s


int4 Progress:  39%|███▉      | 200/508 [02:14<03:55,  1.31it/s]


Current metrics after 200 examples:
Average ANLS: 0.5035
Average inference time: 0.61s


int4 Progress:  41%|████▏     | 210/508 [02:21<03:14,  1.53it/s]


Current metrics after 210 examples:
Average ANLS: 0.5102
Average inference time: 0.61s


int4 Progress:  43%|████▎     | 220/508 [02:28<02:51,  1.68it/s]


Current metrics after 220 examples:
Average ANLS: 0.5175
Average inference time: 0.61s


int4 Progress:  45%|████▌     | 230/508 [02:34<02:39,  1.75it/s]


Current metrics after 230 examples:
Average ANLS: 0.5123
Average inference time: 0.61s


int4 Progress:  47%|████▋     | 240/508 [02:40<02:25,  1.85it/s]


Current metrics after 240 examples:
Average ANLS: 0.5093
Average inference time: 0.60s


int4 Progress:  49%|████▉     | 250/508 [02:47<03:29,  1.23it/s]


Current metrics after 250 examples:
Average ANLS: 0.5054
Average inference time: 0.61s


int4 Progress:  51%|█████     | 260/508 [02:53<02:42,  1.52it/s]


Current metrics after 260 examples:
Average ANLS: 0.4991
Average inference time: 0.61s


int4 Progress:  53%|█████▎    | 270/508 [02:59<02:29,  1.60it/s]


Current metrics after 270 examples:
Average ANLS: 0.4973
Average inference time: 0.60s


int4 Progress:  55%|█████▌    | 280/508 [03:09<02:54,  1.31it/s]


Current metrics after 280 examples:
Average ANLS: 0.4901
Average inference time: 0.61s


int4 Progress:  57%|█████▋    | 290/508 [03:15<02:39,  1.37it/s]


Current metrics after 290 examples:
Average ANLS: 0.4797
Average inference time: 0.61s


int4 Progress:  59%|█████▉    | 300/508 [03:22<02:30,  1.38it/s]


Current metrics after 300 examples:
Average ANLS: 0.4704
Average inference time: 0.61s


int4 Progress:  61%|██████    | 310/508 [03:29<02:10,  1.52it/s]


Current metrics after 310 examples:
Average ANLS: 0.4659
Average inference time: 0.62s


int4 Progress:  63%|██████▎   | 320/508 [03:36<01:40,  1.86it/s]


Current metrics after 320 examples:
Average ANLS: 0.4702
Average inference time: 0.62s


int4 Progress:  65%|██████▍   | 330/508 [03:43<02:15,  1.32it/s]


Current metrics after 330 examples:
Average ANLS: 0.4747
Average inference time: 0.62s


int4 Progress:  67%|██████▋   | 340/508 [03:49<01:49,  1.53it/s]


Current metrics after 340 examples:
Average ANLS: 0.4810
Average inference time: 0.62s


int4 Progress:  69%|██████▉   | 350/508 [03:55<01:37,  1.63it/s]


Current metrics after 350 examples:
Average ANLS: 0.4840
Average inference time: 0.61s


int4 Progress:  71%|███████   | 360/508 [04:00<01:06,  2.23it/s]


Current metrics after 360 examples:
Average ANLS: 0.4817
Average inference time: 0.61s


int4 Progress:  73%|███████▎  | 370/508 [04:06<01:15,  1.83it/s]


Current metrics after 370 examples:
Average ANLS: 0.4830
Average inference time: 0.61s


int4 Progress:  75%|███████▍  | 380/508 [04:14<01:48,  1.18it/s]


Current metrics after 380 examples:
Average ANLS: 0.4810
Average inference time: 0.61s


int4 Progress:  77%|███████▋  | 390/508 [04:21<01:19,  1.49it/s]


Current metrics after 390 examples:
Average ANLS: 0.4799
Average inference time: 0.61s


int4 Progress:  79%|███████▊  | 400/508 [04:28<01:15,  1.42it/s]


Current metrics after 400 examples:
Average ANLS: 0.4829
Average inference time: 0.61s


int4 Progress:  81%|████████  | 410/508 [04:34<00:48,  2.01it/s]


Current metrics after 410 examples:
Average ANLS: 0.4882
Average inference time: 0.61s


int4 Progress:  83%|████████▎ | 420/508 [04:40<00:55,  1.57it/s]


Current metrics after 420 examples:
Average ANLS: 0.4883
Average inference time: 0.61s


int4 Progress:  85%|████████▍ | 430/508 [04:46<00:42,  1.84it/s]


Current metrics after 430 examples:
Average ANLS: 0.4874
Average inference time: 0.61s


int4 Progress:  87%|████████▋ | 440/508 [04:52<00:37,  1.81it/s]


Current metrics after 440 examples:
Average ANLS: 0.4920
Average inference time: 0.61s


int4 Progress:  89%|████████▊ | 450/508 [05:00<00:38,  1.52it/s]


Current metrics after 450 examples:
Average ANLS: 0.4935
Average inference time: 0.61s


int4 Progress:  91%|█████████ | 460/508 [05:07<00:33,  1.44it/s]


Current metrics after 460 examples:
Average ANLS: 0.4878
Average inference time: 0.61s


int4 Progress:  93%|█████████▎| 470/508 [05:13<00:27,  1.36it/s]


Current metrics after 470 examples:
Average ANLS: 0.4900
Average inference time: 0.61s


int4 Progress:  94%|█████████▍| 480/508 [05:21<00:25,  1.11it/s]


Current metrics after 480 examples:
Average ANLS: 0.4958
Average inference time: 0.61s


int4 Progress:  96%|█████████▋| 490/508 [05:28<00:13,  1.30it/s]


Current metrics after 490 examples:
Average ANLS: 0.4943
Average inference time: 0.61s


int4 Progress:  98%|█████████▊| 500/508 [05:34<00:05,  1.56it/s]


Current metrics after 500 examples:
Average ANLS: 0.4881
Average inference time: 0.61s


int4 Progress: 100%|██████████| 508/508 [05:39<00:00,  1.50it/s]



=== DocVQA Benchmark Results ===
Model Type | ANLS Score | Avg Inference Time (s) | Peak GPU Memory (GB)
----------|------------|------------------------|--------------------
original   | 0.5092 | 0.3738 | 31.71
int8       | 0.4974 | 1.1002 | 31.71
int4       | 0.4853 | 0.6103 | 31.71

=== Relative Performance ===
int8 vs original:
  - Accuracy: 0.98x (97.7%)
  - Speed: 0.34x faster
  - Memory: 1.00x (100.0% of original)
int4 vs original:
  - Accuracy: 0.95x (95.3%)
  - Speed: 0.61x faster
  - Memory: 1.00x (100.0% of original)

Detailed results saved to /content/idefics2_docvqa_results.json


{'original': {'model_type': 'original',
  'anls_score': 0.5091671280968654,
  'avg_inference_time': 0.37379346588465173,
  'num_examples': 508,
  'gpu_memory_peak_gb': 31.712890625,
  'sample_results': [{'question': 'On which date should the form be completed and send ?',
    'model_answer': 'September 28',
    'ground_truth': ['September 28'],
    'anls': 1.0,
    'inference_time': 1.1801178455352783},
   {'question': 'What is the date mentioned in the form ?',
    'model_answer': 'September 29, 1966',
    'ground_truth': ['September 28, 1988'],
    'anls': 0.8333333333333334,
    'inference_time': 0.5389549732208252},
   {'question': 'What is the year of operating plan ?',
    'model_answer': '1969',
    'ground_truth': ['1989'],
    'anls': 0.75,
    'inference_time': 0.330643892288208},
   {'question': 'What is the area of focus ?',
    'model_answer': 'Nutritional Science',
    'ground_truth': ['external influences', 'External Influences'],
    'anls': 0.0,
    'inference_time': 0

## TallyQA evals

In [8]:
import torch
from transformers import AutoProcessor, Idefics2ForConditionalGeneration, BitsAndBytesConfig
from datasets import load_dataset
from PIL import Image
import requests
import time
from tqdm import tqdm
import gc
import numpy as np
import traceback


def load_tallyqa_dataset():
    print("Loading TallyQA dataset from vikhyatk/tallyqa-test...")

    try:
        # Correctly load the dataset by properly accessing the test split
        dataset = load_dataset("vikhyatk/tallyqa-test")['test']
        print(f"Loaded dataset with {len(dataset)} items")

        # Create structured data for benchmarking
        structured_data = []

        # Use a smaller sample (300 items) for quicker benchmarking
        sample_size = min(300, len(dataset))

        for i in range(sample_size):
            item = dataset[i]
            if 'image' in item and 'qa' in item:
                image = item['image']
                qa_list = item['qa']

                for qa in qa_list:
                    # Take just one question per image to keep the benchmark faster
                    structured_data.append({
                        "image": image,
                        "question": qa['question'],
                        "answers": [qa['answer']],
                        "question_type": "counting",
                        "issimple": qa.get('is_simple', True),
                        "question_id": len(structured_data)
                    })
                    break  # Just use the first question per image

        print(f"Prepared {len(structured_data)} QA pairs for benchmarking")

        if len(structured_data) > 0:
            return structured_data
        else:
            print("No valid data found, using synthetic dataset")

    except Exception as e:
        print(f"Error loading TallyQA dataset: {e}")


def load_model(model_type):
    """
    Load IDEFICS-2 model with specified quantization
    """
    print(f"Loading IDEFICS-2 model with {model_type} precision...")

    # Clear GPU memory
    torch.cuda.empty_cache()
    gc.collect()

    # Common kwargs for processor
    processor_kwargs = {
        "do_image_splitting": False,  # Disable image splitting for memory efficiency
        "size": {"longest_edge": 448, "shortest_edge": 378}  # Lower resolution
    }

    # Common kwargs for all model variants
    common_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.float16,  # Using fp16 for better efficiency
        "low_cpu_mem_usage": True,
    }

    # Check if Flash Attention 2 is available
    try:
        import importlib
        flash_attn_spec = importlib.util.find_spec("flash_attn")
        has_flash_attn = flash_attn_spec is not None

        if has_flash_attn:
            common_kwargs["attn_implementation"] = "flash_attention_2"
            print("Flash Attention 2 is available and will be used.")
    except:
        print("Flash Attention 2 is not available, using standard attention.")

    # Load the appropriate model based on model_type
    if model_type == "fp16":
        processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", **processor_kwargs)
        model = Idefics2ForConditionalGeneration.from_pretrained(
            "HuggingFaceM4/idefics2-8b",
            **common_kwargs
        )
    elif model_type == "int8":
        processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", **processor_kwargs)
        model = Idefics2ForConditionalGeneration.from_pretrained(
            "/content/quantized_models/idefics2-8b-8bit-calibrated",
            **common_kwargs
        )
    elif model_type == "int4":
        processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", **processor_kwargs)
        model = Idefics2ForConditionalGeneration.from_pretrained(
            "/content/quantized_models/idefics2-8b-4bit-nf4-calibrated",
            **common_kwargs
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Print model configuration details for verification
    print(f"\n=== Model Configuration ({model_type}) ===")
    print(f"Flash Attention 2: {'Enabled' if common_kwargs.get('attn_implementation') == 'flash_attention_2' else 'Disabled'}")
    print(f"Image Splitting: Disabled")
    print(f"Float Type: fp16")
    print(f"Quantization: {model_type}")

    model.eval()
    return model, processor


# Updated counting answer evaluation function
def evaluate_counting_answer(model_answer, ground_truth):
    """
    Evaluate counting question responses by extracting numbers
    """
    import re

    # Convert ground truth to integer
    gt_count = int(ground_truth[0])

    # Extract numbers from model answer
    numbers = re.findall(r'\b\d+\b', model_answer)

    if not numbers:
        return 0.0  # No number found in response

    # Use the first number found in the response
    try:
        predicted_count = int(numbers[0])

        # Calculate accuracy based on exact match or threshold
        if predicted_count == gt_count:
            return 1.0  # Exact match
        else:
            # Alternative: threshold-based score for close answers
            error = abs(predicted_count - gt_count)
            if gt_count > 0:
                relative_error = error / gt_count
                if relative_error <= 0.1:  # Within 10% error
                    return 0.5

            # For small counts, allow off-by-one
            if gt_count <= 10 and error == 1:
                return 0.5  # Close answer (off by 1)

            return 0.0  # Wrong answer
    except:
        return 0.0


def benchmark_model_tallyqa(model, processor, dataset, model_type):
    device = next(model.parameters()).device
    print(f"Running benchmark for {model_type} model on {device}")

    start_time = time.time()
    scores = []
    inference_times = []
    simple_scores = []
    complex_scores = []
    example_outputs = []

    # Make sure we have a dataset before proceeding
    if dataset is None or len(dataset) == 0:
        print("Empty dataset, skipping benchmark")
        return {
            "accuracy": 0.0,
            "simple_accuracy": 0.0,
            "complex_accuracy": 0.0,
            "avg_inference_time": 0.0,
            "total_time": 0.0,
            "memory_gb": 0.0,
            "examples": []
        }

    # Track peak memory usage
    torch.cuda.reset_peak_memory_stats()

    # Use tqdm function properly
    for i, example in enumerate(tqdm(dataset[:300])):  # Limit to 300 for faster results
        image = example["image"]
        question = example["question"]
        answers = example["answers"]
        issimple = example["issimple"]

        # Format prompt specifically for IDEFICS2 format for counting questions
        # Using the chat template
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": f"Look at the image and answer this counting question: {question} Give just a number as your answer."}
                ]
            }
        ]

        try:
            # Apply chat template and process inputs
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = processor(
                text=prompt,
                images=[image],
                return_tensors="pt"
            )
            inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

            # Measure inference time
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            inf_start = time.time()

            with torch.inference_mode():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=20,
                    do_sample=False
                )
                model_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                model_answer = model_answer.split("Assistant: ")[-1].strip()

            torch.cuda.synchronize() if torch.cuda.is_available() else None
            inference_time = time.time() - inf_start
            inference_times.append(inference_time)

            # Evaluate counting accuracy
            score = evaluate_counting_answer(model_answer, answers)
            scores.append(score)

            # Track simple vs complex questions
            if issimple:
                simple_scores.append(score)
            else:
                complex_scores.append(score)

            # Save examples for later inspection
            if i < 5:
                example_outputs.append({
                    "question": question,
                    "model_answer": model_answer,
                    "ground_truth": answers[0],
                    "score": score,
                    "inference_time": inference_time
                })
        except Exception as e:
            print(f"Error processing example {i}: {e}")
            traceback.print_exc()
            continue

        # Clear cache after each iteration to prevent memory buildup
        if i % 10 == 0:
            torch.cuda.empty_cache()

    # Calculate metrics
    avg_score = sum(scores) / len(scores) if scores else 0
    avg_simple = sum(simple_scores) / len(simple_scores) if simple_scores else 0
    avg_complex = sum(complex_scores) / len(complex_scores) if complex_scores else 0
    avg_time = sum(inference_times) / len(inference_times) if inference_times else 0
    total_time = time.time() - start_time

    # Calculate peak memory usage
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)

    return {
        "accuracy": avg_score,
        "simple_accuracy": avg_simple,
        "complex_accuracy": avg_complex,
        "avg_inference_time": avg_time,
        "total_time": total_time,
        "peak_memory_gb": peak_memory_gb,
        "examples": example_outputs
    }


def run_tallyqa_benchmarks():
    # Load TallyQA dataset
    print("Loading TallyQA dataset...")
    structured_data = load_tallyqa_dataset()

    # Check if we actually got data
    if structured_data is None or len(structured_data) == 0:
        print("Failed to load dataset. Cannot proceed with benchmarking.")
        return {"error": "Failed to load dataset"}

    results = {}

    # Benchmark each model type separately
    for model_type in ["fp16", "int8", "int4"]:
        print(f"\n=== Benchmarking {model_type} model on TallyQA ===")
        try:
            # Clear memory
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

            # Load model
            model, processor = load_model(model_type)
            results[model_type] = benchmark_model_tallyqa(model, processor, structured_data, model_type)

            # Print immediate results
            print(f"\nResults for {model_type}:")
            print(f"Accuracy: {results[model_type]['accuracy']:.4f}")
            print(f"Simple Questions Accuracy: {results[model_type]['simple_accuracy']:.4f}")
            print(f"Complex Questions Accuracy: {results[model_type]['complex_accuracy']:.4f}")
            print(f"Average Inference Time: {results[model_type]['avg_inference_time']:.4f}s")
            print(f"Peak Memory Usage: {results[model_type]['peak_memory_gb']:.2f} GB")

            # Clean up
            del model, processor
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

        except Exception as e:
            print(f"Error benchmarking {model_type} model: {e}")
            traceback.print_exc()
            results[model_type] = {"error": str(e)}

    # Print comparison summary
    print("\n=== TallyQA Benchmark Summary for IDEFICS2 ===")
    print("Model | Overall Acc | Simple Acc | Complex Acc | Avg Time (s) | Peak Memory (GB)")
    print("------|-------------|------------|-------------|--------------|----------------")
    for model_type in ["fp16", "int8", "int4"]:
        if model_type in results and "error" not in results[model_type]:
            print(f"{model_type} | {results[model_type]['accuracy']:.4f} | "
                  f"{results[model_type]['simple_accuracy']:.4f} | "
                  f"{results[model_type]['complex_accuracy']:.4f} | "
                  f"{results[model_type]['avg_inference_time']:.4f} | "
                  f"{results[model_type]['peak_memory_gb']:.2f}")

    # Performance comparisons
    if all(model_type in results and "error" not in results[model_type] for model_type in ["fp16", "int8", "int4"]):
        print("\n=== Performance Comparison ===")

        # Avoid division by zero
        if results["fp16"]["accuracy"] > 0:
            print(f"INT8/FP16 accuracy ratio: {results['int8']['accuracy']/results['fp16']['accuracy']:.4f}x")
            print(f"INT4/FP16 accuracy ratio: {results['int4']['accuracy']/results['fp16']['accuracy']:.4f}x")

        if results["int8"]["avg_inference_time"] > 0 and results["int4"]["avg_inference_time"] > 0:
            print(f"INT8/FP16 speed ratio: {results['fp16']['avg_inference_time']/results['int8']['avg_inference_time']:.2f}x")
            print(f"INT4/FP16 speed ratio: {results['fp16']['avg_inference_time']/results['int4']['avg_inference_time']:.2f}x")

        print(f"INT8/FP16 memory ratio: {results['int8']['peak_memory_gb']/results['fp16']['peak_memory_gb']:.2f}x")
        print(f"INT4/FP16 memory ratio: {results['int4']['peak_memory_gb']/results['fp16']['peak_memory_gb']:.2f}x")

    return results

run_tallyqa_benchmarks()

Loading TallyQA dataset...
Loading TallyQA dataset from vikhyatk/tallyqa-test...


README.md:   0%|          | 0.00/476 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/19 [00:00<?, ?files/s]

test-00000-of-00019.parquet:   0%|          | 0.00/464M [00:00<?, ?B/s]

test-00001-of-00019.parquet:   0%|          | 0.00/452M [00:00<?, ?B/s]

test-00002-of-00019.parquet:   0%|          | 0.00/456M [00:00<?, ?B/s]

test-00003-of-00019.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

test-00004-of-00019.parquet:   0%|          | 0.00/451M [00:00<?, ?B/s]

test-00005-of-00019.parquet:   0%|          | 0.00/449M [00:00<?, ?B/s]

test-00006-of-00019.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

test-00007-of-00019.parquet:   0%|          | 0.00/454M [00:00<?, ?B/s]

test-00008-of-00019.parquet:   0%|          | 0.00/457M [00:00<?, ?B/s]

test-00009-of-00019.parquet:   0%|          | 0.00/452M [00:00<?, ?B/s]

test-00010-of-00019.parquet:   0%|          | 0.00/453M [00:00<?, ?B/s]

test-00011-of-00019.parquet:   0%|          | 0.00/449M [00:00<?, ?B/s]

test-00012-of-00019.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

test-00013-of-00019.parquet:   0%|          | 0.00/515M [00:00<?, ?B/s]

test-00014-of-00019.parquet:   0%|          | 0.00/582M [00:00<?, ?B/s]

test-00015-of-00019.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

test-00016-of-00019.parquet:   0%|          | 0.00/482M [00:00<?, ?B/s]

test-00017-of-00019.parquet:   0%|          | 0.00/540M [00:00<?, ?B/s]

test-00018-of-00019.parquet:   0%|          | 0.00/489M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/26451 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

Loaded dataset with 26451 items
Prepared 300 QA pairs for benchmarking

=== Benchmarking fp16 model on TallyQA ===
Loading IDEFICS-2 model with fp16 precision...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]


=== Model Configuration (fp16) ===
Flash Attention 2: Disabled
Image Splitting: Disabled
Float Type: fp16
Quantization: fp16
Running benchmark for fp16 model on cuda:0


100%|██████████| 300/300 [00:56<00:00,  5.28it/s]



Results for fp16:
Accuracy: 0.8183
Simple Questions Accuracy: 0.8183
Complex Questions Accuracy: 0.0000
Average Inference Time: 0.1766s
Peak Memory Usage: 31.49 GB

=== Benchmarking int8 model on TallyQA ===
Loading IDEFICS-2 model with int8 precision...


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]


=== Model Configuration (int8) ===
Flash Attention 2: Disabled
Image Splitting: Disabled
Float Type: fp16
Quantization: int8
Running benchmark for int8 model on cuda:0


100%|██████████| 300/300 [02:56<00:00,  1.70it/s]



Results for int8:
Accuracy: 0.8183
Simple Questions Accuracy: 0.8183
Complex Questions Accuracy: 0.0000
Average Inference Time: 0.5751s
Peak Memory Usage: 24.64 GB

=== Benchmarking int4 model on TallyQA ===
Loading IDEFICS-2 model with int4 precision...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


=== Model Configuration (int4) ===
Flash Attention 2: Disabled
Image Splitting: Disabled
Float Type: fp16
Quantization: int4
Running benchmark for int4 model on cuda:0


100%|██████████| 300/300 [01:39<00:00,  3.00it/s]



Results for int4:
Accuracy: 0.8333
Simple Questions Accuracy: 0.8333
Complex Questions Accuracy: 0.0000
Average Inference Time: 0.3202s
Peak Memory Usage: 21.39 GB

=== TallyQA Benchmark Summary for IDEFICS2 ===
Model | Overall Acc | Simple Acc | Complex Acc | Avg Time (s) | Peak Memory (GB)
------|-------------|------------|-------------|--------------|----------------
fp16 | 0.8183 | 0.8183 | 0.0000 | 0.1766 | 31.49
int8 | 0.8183 | 0.8183 | 0.0000 | 0.5751 | 24.64
int4 | 0.8333 | 0.8333 | 0.0000 | 0.3202 | 21.39

=== Performance Comparison ===
INT8/FP16 accuracy ratio: 1.0000x
INT4/FP16 accuracy ratio: 1.0183x
INT8/FP16 speed ratio: 0.31x
INT4/FP16 speed ratio: 0.55x
INT8/FP16 memory ratio: 0.78x
INT4/FP16 memory ratio: 0.68x


{'fp16': {'accuracy': 0.8183333333333334,
  'simple_accuracy': 0.8183333333333334,
  'complex_accuracy': 0,
  'avg_inference_time': 0.17658637603123983,
  'total_time': 56.778003215789795,
  'peak_memory_gb': 31.485440731048584,
  'examples': [{'question': 'How many people are there?',
    'model_answer': '2',
    'ground_truth': '2',
    'score': 1.0,
    'inference_time': 0.1619579792022705},
   {'question': 'How many cars are parked?',
    'model_answer': '2.',
    'ground_truth': '2',
    'score': 1.0,
    'inference_time': 0.17934894561767578},
   {'question': 'How many outlets are in the wall?',
    'model_answer': '2.',
    'ground_truth': '4',
    'score': 0.0,
    'inference_time': 0.17829036712646484},
   {'question': 'How many windows are there?',
    'model_answer': '1.',
    'ground_truth': '1',
    'score': 1.0,
    'inference_time': 0.18502259254455566},
   {'question': 'How many apples are on the table?',
    'model_answer': '1.',
    'ground_truth': '1',
    'score': 1