In [None]:
import torch
import os
import re
import json
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import PeftModel, PeftConfig


import logging

# Notebook-wide logging: timestamped INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# Module-level device; the evaluation helpers below move processor outputs to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

from huggingface_hub import login

# SECURITY: an access token must never be committed in a notebook. The token
# that used to be hardcoded here has been exposed and should be revoked on the
# Hugging Face account. The token is now read from the HF_TOKEN environment
# variable; when it is unset, `login` receives None and falls back to its own
# prompt for credentials.
login(token=os.environ.get("HF_TOKEN"))

# 
# Location of the fine-tuned PEFT adapter checkpoint to evaluate. Switch between
# the two lines below to evaluate the SFT or the GRPO checkpoint.
adapter_path = "../SFT/paligemma2-sat-sft/checkpoint-2000"
# adapter_path = "../GRPO/paligemma2-sat-grpo/checkpoint-1350"
# The adapter config records which base model the adapter was trained on
# (read below through `base_model_name_or_path`).
peft_config = PeftConfig.from_pretrained(adapter_path)


# Load the base PaliGemma weights; dtype and device placement are delegated to
# the library via the "auto" settings.
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    torch_dtype="auto",
    device_map="auto"
)


# Wrap the base model with the adapter weights and switch to inference mode.
model = PeftModel.from_pretrained(base_model, adapter_path)
model.eval()
# The processor is loaded from the base model id, not from the adapter directory.
processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path)

# Alternative: evaluate the un-adapted baseline model instead of the adapter.
# model_id = "google/paligemma2-3b-mix-224"
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to(device).eval()
# processor = AutoProcessor.from_pretrained(model_id)


2025-05-05 15:43:42.577952: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-05 15:43:44.501896: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746474225.410635    7577 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746474225.636757    7577 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746474227.461736    7577 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Using device: cuda


2025-05-05 15:44:05,336 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
def extract_choice_answer(text):
    """Pull a single-letter multiple-choice answer out of free text.

    The text is lower-cased and stripped, then searched with two patterns in
    priority order: a parenthesised letter such as ``(a)``, then any
    stand-alone single letter delimited by word boundaries.

    Args:
        text: Raw answer string.

    Returns:
        The first matching letter (lower case), or the lower-cased, stripped
        input unchanged when neither pattern matches.
    """
    cleaned = text.lower().strip()

    # Ordered by priority: the parenthesised form wins over a bare letter.
    letter_patterns = (
        r'\(([a-z])\)',  # (A), (B), (C)
        r'\b([a-z])\b',  # A, B, C
    )
    for letter_pattern in letter_patterns:
        found = re.search(letter_pattern, cleaned)
        if found is not None:
            return found.group(1).lower()

    return cleaned

def extract_gt_answer(gt_text):
    """Return the ground-truth answer, unwrapping an ``<answer>`` tag if present.

    Args:
        gt_text: Ground-truth string, optionally containing
            ``<answer>...</answer>`` markup.

    Returns:
        The lower-cased, stripped content of the first ``<answer>`` tag, or the
        lower-cased, stripped input when no such tag is found.
    """
    lowered = gt_text.lower().strip()
    tagged = re.search(r"<answer>\s*(.*?)\s*</answer>", lowered)
    payload = tagged.group(1) if tagged else lowered
    return payload.strip().lower()

def normalize_answer(text, task_type=None):
    """Canonicalise an answer string so predictions and targets can be compared.

    Steps: lower-case and strip the text, keep only its last line, and, for the
    multiple-choice task types, reduce it to the choice extracted by
    ``extract_choice_answer`` when that choice is accepted by the ``"abcde"``
    membership check.

    Args:
        text: Answer string; any falsy value yields ``""``.
        task_type: Task name. Choice extraction is attempted only for
            'Count', 'Relation', 'Distance' and 'Depth'.

    Returns:
        The normalised answer string.
    """
    if not text:
        return ""

    # Only the final line is kept; earlier lines are discarded.
    last_line = text.lower().strip().rsplit("\n", 1)[-1].strip()

    if task_type not in ['Count', 'Relation', 'Distance', 'Depth']:
        return last_line

    letter = extract_choice_answer(last_line)
    if letter and letter.lower() in "abcde":
        return letter.lower()

    return last_line

def evaluate_batch_with_sampling(model, processor, batch, batch_size, num_samples=4):
    """Sample several generations for every item of a batch.

    The batch is encoded once and ``model.generate`` is then invoked
    ``num_samples`` times with sampling enabled (temperature 1.0) on the same
    inputs. Inputs are moved to the module-level ``device``.

    Args:
        model: Generation model exposing ``generate``.
        processor: Processor used for encoding and for ``batch_decode``; its
            tokenizer supplies the EOS and PAD token ids.
        batch: Mapping with an "image" entry, questions under "problem" (or
            "prompt" as fallback) and answers under "solution" (or "answer"
            as fallback).
        batch_size: Not used by this function; kept for interface compatibility.
        num_samples: Number of sampling rounds per batch.

    Returns:
        A tuple ``(responses_by_sample, gt_answers, questions)`` where
        ``responses_by_sample[i]`` is a tuple holding the ``num_samples``
        decoded outputs for item ``i``, ``gt_answers`` are the lower-cased
        string answers and ``questions`` is the question list as given.
    """
    images_raw = batch["image"]  # presumably a list of PIL images — set up by the caller
    question_key = "problem" if "problem" in batch else "prompt"
    questions = batch[question_key]
    answer_key = "solution" if "solution" in batch else "answer"
    gt_answers = [str(raw_answer).lower() for raw_answer in batch[answer_key]]

    inputs = processor(text=questions, images=images_raw, return_tensors="pt", padding=True).to(device)

    # One entry per sampling round; each entry is the decoded batch of that round.
    rounds = []
    with torch.no_grad():
        for _ in range(num_samples):
            generated = model.generate(
                **inputs,
                do_sample=True,
                temperature=1.0,
                # top_p=0.9,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            )
            rounds.append(processor.batch_decode(generated, skip_special_tokens=True))

    # Transpose rounds -> items:
    # [[item1_try1, item1_try2, ...], [item2_try1, item2_try2, ...], ...]
    responses_by_sample = list(zip(*rounds))

    return responses_by_sample, gt_answers, questions

def compute_accuracy_with_sampling(responses_by_sample, ground_truths, task_type=None):
    """Compute "any-of-k" accuracy over sampled responses.

    An item counts as correct when at least one of its sampled responses,
    after ``normalize_answer``, equals the normalised ground truth (which is
    first unwrapped with ``extract_gt_answer``).

    Args:
        responses_by_sample: Sequence where entry ``i`` holds all sampled
            responses for item ``i``.
        ground_truths: Ground-truth answer strings, aligned with
            ``responses_by_sample``.
        task_type: Forwarded to ``normalize_answer``.

    Returns:
        Fraction of items with at least one correct response, or ``0.0`` when
        there are no items (previously this raised ``ZeroDivisionError``).

    Raises:
        AssertionError: If the two inputs differ in length.
    """
    assert len(responses_by_sample) == len(ground_truths)

    total = len(responses_by_sample)
    if total == 0:
        # Nothing to score: report 0 rather than dividing by zero.
        return 0.0

    correct = 0
    for sample_responses, gt in zip(responses_by_sample, ground_truths):
        gt_norm = normalize_answer(extract_gt_answer(gt), task_type)
        # any() stops at the first matching response, like the original break.
        if any(normalize_answer(pred, task_type) == gt_norm for pred in sample_responses):
            correct += 1

    return correct / total

def evaluate_cvbench(model, processor, num_samples=4, batch_size=50):
    """Evaluate a model on the CV-Bench test split with any-of-k sampling.

    For each of the four tasks ('Count', 'Relation', 'Distance', 'Depth') the
    test items are processed in batches, ``num_samples`` responses are sampled
    per item, and an item is scored correct when any sampled response matches
    the ground truth.

    Every item of a task is evaluated, including the final partial batch.
    Earlier versions truncated each task to a multiple of the batch size and
    silently dropped the remainder.

    Args:
        model: Generation model passed through to
            ``evaluate_batch_with_sampling``.
        processor: Matching processor passed through likewise.
        num_samples: Number of sampled responses per question.
        batch_size: Number of items per generation batch (default 50).

    Returns:
        Dict mapping each evaluated task name to its accuracy, plus "overall"
        (unweighted mean of the per-task accuracies) and "sampling_info"
        (a descriptive string). An empty dict is returned when the dataset
        cannot be loaded or no task has any samples.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"\n=== Evaluating CVBench dataset with {num_samples} samples per question ===")

    try:
        ds = load_dataset("nyu-visionx/CV-Bench")
        print(f"Loaded CVBench dataset with {len(ds['test'])} test samples")
    except Exception as e:
        # Best effort: report the problem and hand back an empty result set.
        print(f"Error loading CVBench dataset: {e}")
        return {}

    tasks = ['Count', 'Relation', 'Distance', 'Depth']
    # tasks = ['Count']

    # Bucket the test split by task in one pass instead of re-iterating the
    # whole split once per task.
    task_data = {task: [] for task in tasks}
    for item in ds['test']:
        bucket = task_data.get(item['task'])
        if bucket is not None:
            bucket.append(item)
    for task in tasks:
        print(f"Found {len(task_data[task])} samples for {task} task")

    results = {}

    for task in tasks:
        print(f"\n--- Evaluating {task} task with {num_samples} samples per question ---")
        task_samples = task_data[task]

        if not task_samples:
            print(f"Skipping {task} task: not enough samples")
            continue

        responses_by_sample = []
        ground_truths = []

        for i in tqdm(range(0, len(task_samples), batch_size)):
            # The last slice may be shorter than batch_size; it is still evaluated.
            batch = task_samples[i:i + batch_size]

            batch_dict = {
                "image": [item["image"] for item in batch],
                "prompt": ['<image>' + item["prompt"] for item in batch],
                "answer": [item["answer"] for item in batch],
            }

            batch_responses, batch_gt, _ = evaluate_batch_with_sampling(
                model, processor, batch_dict, len(batch), num_samples=num_samples
            )

            responses_by_sample.extend(batch_responses)
            ground_truths.extend(batch_gt)

            # Running accuracy on every second batch (the same cadence the
            # original `i % 100 == 0` gave with batch_size=50).
            if (i // batch_size) % 2 == 0:
                accuracy = compute_accuracy_with_sampling(responses_by_sample, ground_truths, task)
                print(f"{task} Accuracy with {num_samples} samples: {accuracy:.2%}")

        accuracy = compute_accuracy_with_sampling(responses_by_sample, ground_truths, task)
        print(f"{task} Accuracy with {num_samples} samples: {accuracy:.2%}")
        results[task] = accuracy

    if results:
        # Unweighted mean over tasks: each task counts equally regardless of
        # how many samples it has.
        overall_accuracy = sum(results.values()) / len(results)
        print(f"\nCVBench Overall Accuracy with {num_samples} samples: {overall_accuracy:.2%}")
        results["overall"] = overall_accuracy
        # NOTE: this entry is a string, not a number; consumers that format
        # every value as a percentage must special-case it.
        results["sampling_info"] = f"{num_samples} samples per question, correct if any sample is correct"

    return results

def save_results(results, output_file="cvbench_results_sampling.json"):
    """Write the results mapping to a JSON file and announce the path.

    Args:
        results: JSON-serialisable object to store.
        output_file: Destination path; the file is overwritten if it exists.
    """
    # UTF-8 plus ensure_ascii=False keeps non-ASCII characters readable on disk.
    with open(output_file, mode='w', encoding='utf-8') as handle:
        json.dump(results, handle, ensure_ascii=False, indent=2)
    print(f"Results saved to {output_file}")

In [None]:
# Run the full evaluation and persist the raw results.
results = evaluate_cvbench(model, processor, num_samples=4)
save_results(results)


def _format_metric(value):
    """Render numeric metrics as percentages and anything else verbatim.

    The results dict mixes float accuracies with descriptive strings (e.g.
    "sampling_info"); applying the ``%`` format code to a string raises
    ``ValueError``, so non-numeric values are passed through ``str``.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:.2%}"
    return str(value)


print("\n=== Summary of Results ===")
for dataset, result in results.items():
    if isinstance(result, dict):
        print(f"{dataset} results:")
        for task, acc in result.items():
            print(f"  {task}: {_format_metric(acc)}")
    else:
        print(f"{dataset}: {_format_metric(result)}")


=== Evaluating CVBench dataset with 4 samples per question ===
Loaded CVBench dataset with 2638 test samples
Found 788 samples for Count task
Found 650 samples for Relation task
Found 600 samples for Distance task
Found 600 samples for Depth task

--- Evaluating Count task with 4 samples per question ---


  7%|▋         | 1/15 [00:11<02:42, 11.57s/it]

Count Accuracy with 4 samples: 78.00%


 20%|██        | 3/15 [00:21<01:18,  6.54s/it]

Count Accuracy with 4 samples: 68.00%


 33%|███▎      | 5/15 [00:32<00:59,  5.92s/it]

Count Accuracy with 4 samples: 64.40%


 47%|████▋     | 7/15 [00:42<00:43,  5.44s/it]

Count Accuracy with 4 samples: 66.29%


 60%|██████    | 9/15 [00:51<00:29,  4.98s/it]

Count Accuracy with 4 samples: 67.11%


 73%|███████▎  | 11/15 [01:01<00:19,  4.77s/it]

Count Accuracy with 4 samples: 69.27%


 87%|████████▋ | 13/15 [01:10<00:09,  4.66s/it]

Count Accuracy with 4 samples: 71.38%


100%|██████████| 15/15 [01:19<00:00,  5.29s/it]


Count Accuracy with 4 samples: 72.80%
Count Accuracy with 4 samples: 72.80%

--- Evaluating Relation task with 4 samples per question ---


  8%|▊         | 1/13 [00:04<00:56,  4.75s/it]

Relation Accuracy with 4 samples: 84.00%


 23%|██▎       | 3/13 [00:14<00:47,  4.76s/it]

Relation Accuracy with 4 samples: 82.00%


 38%|███▊      | 5/13 [00:23<00:38,  4.76s/it]

Relation Accuracy with 4 samples: 81.20%


 54%|█████▍    | 7/13 [00:33<00:28,  4.73s/it]

Relation Accuracy with 4 samples: 80.57%


 69%|██████▉   | 9/13 [00:42<00:18,  4.65s/it]

Relation Accuracy with 4 samples: 81.11%


 85%|████████▍ | 11/13 [00:51<00:09,  4.63s/it]

Relation Accuracy with 4 samples: 82.18%


100%|██████████| 13/13 [01:00<00:00,  4.68s/it]


Relation Accuracy with 4 samples: 82.92%
Relation Accuracy with 4 samples: 82.92%

--- Evaluating Distance task with 4 samples per question ---


  8%|▊         | 1/12 [00:10<01:59, 10.90s/it]

Distance Accuracy with 4 samples: 90.00%


 25%|██▌       | 3/12 [00:32<01:38, 10.93s/it]

Distance Accuracy with 4 samples: 93.33%


 42%|████▏     | 5/12 [00:54<01:15, 10.83s/it]

Distance Accuracy with 4 samples: 92.00%


 58%|█████▊    | 7/12 [01:15<00:53, 10.76s/it]

Distance Accuracy with 4 samples: 92.29%


 75%|███████▌  | 9/12 [01:37<00:32, 10.90s/it]

Distance Accuracy with 4 samples: 91.78%


 92%|█████████▏| 11/12 [02:00<00:11, 11.09s/it]

Distance Accuracy with 4 samples: 90.55%


100%|██████████| 12/12 [02:11<00:00, 10.96s/it]


Distance Accuracy with 4 samples: 90.00%

--- Evaluating Depth task with 4 samples per question ---


  8%|▊         | 1/12 [00:10<01:58, 10.75s/it]

Depth Accuracy with 4 samples: 92.00%


 25%|██▌       | 3/12 [00:32<01:36, 10.74s/it]

Depth Accuracy with 4 samples: 88.00%


 42%|████▏     | 5/12 [00:53<01:14, 10.65s/it]

Depth Accuracy with 4 samples: 90.00%


 58%|█████▊    | 7/12 [01:14<00:52, 10.57s/it]

Depth Accuracy with 4 samples: 90.86%


 75%|███████▌  | 9/12 [01:35<00:31, 10.63s/it]

Depth Accuracy with 4 samples: 91.33%


 92%|█████████▏| 11/12 [01:57<00:10, 10.64s/it]

Depth Accuracy with 4 samples: 91.09%


100%|██████████| 12/12 [02:07<00:00, 10.63s/it]


Depth Accuracy with 4 samples: 91.17%

CVBench Overall Accuracy with 4 samples: 84.22%
Results saved to cvbench_results_sampling.json

=== Summary of Results ===
Count: 72.80%
Relation: 82.92%
Distance: 90.00%
Depth: 91.17%
overall: 84.22%


ValueError: Unknown format code '%' for object of type 'str'