In [1]:
import torch
import os
import re
import json
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import PeftModel, PeftConfig


import logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

from huggingface_hub import login
login(token='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')


# # adapter_path = "../SFT/paligemma2-sat-sft/checkpoint-2000"
# adapter_path = "../GRPO/paligemma2-sat-grpo/checkpoint-2000"
# peft_config = PeftConfig.from_pretrained(adapter_path)


# base_model = PaliGemmaForConditionalGeneration.from_pretrained(
#     peft_config.base_model_name_or_path,
#     torch_dtype="auto",
#     device_map="auto"
# )


# model = PeftModel.from_pretrained(base_model, adapter_path).to(device)
# model.eval()
# processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path)

model_id = "google/paligemma2-3b-mix-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to(device).eval()
processor = AutoProcessor.from_pretrained(model_id)


2025-05-14 18:42:21.469784: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747262542.468856    7573 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747262542.713843    7573 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747262544.832974    7573 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747262544.833023    7573 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1747262544.833026    7573 computation_placer.cc:177] computation placer alr

Using device: cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [2]:
def extract_choice_answer(text):
    """从文本中提取选择题答案"""
    text = text.lower().strip()
    pattern1 = r'\(([a-z])\)'  # (A), (B), (C)
    pattern2 = r'\b([a-z])\b'  # A, B, C

    match = re.search(pattern1, text)
    if match:
        return match.group(1).lower()

    match = re.search(pattern2, text)
    if match:
        return match.group(1).lower()

    return text

def extract_gt_answer(gt_text):
    """提取答案中的标准答案"""
    gt_text = gt_text.lower().strip()
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", gt_text)
    if match:
        return match.group(1).strip().lower()
    return gt_text.strip().lower()

def change_prompt(prompt):
  # print(prompt)
  new_prompt = prompt.replace("Which object is closer to","Which object is far away from")
  # print(new_prompt)
  return new_prompt

def change_answer(answer):
    # print("old", answer)
    if "(A)" in answer:
        new_answer = answer.replace("(A)", "(B)")
    elif "(B)" in answer:
        new_answer = answer.replace("(B)", "(A)")
    else:
        new_answer = answer
        print(new_answer)
    return new_answer
    
def normalize_answer(text, task_type=None):
    """规范化答案格式"""
    if not text:
        return ""

    text = text.lower().strip()

    if "\n" in text:
      text = text.split("\n")[-1].strip()
    if task_type in ['Count', 'Relation', 'Distance', 'Depth']:
    # if task_type in ['Depth']:
        choice = extract_choice_answer(text)
        if choice and choice.lower() in "abcde":
            return choice.lower()

    # numbers = re.findall(r"\d+", text)
    # if numbers:
    #     return numbers[0]

    # if "yes" in text:
    #     return "yes"
    # elif "no" in text:
    #     return "no"

    return text

def evaluate_batch_with_sampling(model, processor, batch, batch_size, num_samples=4):
    """评估一批样本，每个样本采样多次"""
    images_raw = batch["image"]  # List of PIL images
    questions = batch["problem"] if "problem" in batch else batch["prompt"]
    gt_answers_raw = batch["solution"] if "solution" in batch else batch["answer"]
    gt_answers = [str(ans).lower() for ans in gt_answers_raw]

    # 预处理输入
    inputs = processor(text=questions, images=images_raw, return_tensors="pt", padding=True).to(device)

    all_responses = []

    if num_samples==1:
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    do_sample=False, 
                    eos_token_id=processor.tokenizer.eos_token_id,
                    pad_token_id=processor.tokenizer.pad_token_id,
                )
                # 解码输出
                responses_batch = processor.batch_decode(outputs, skip_special_tokens=True)
                all_responses.append(responses_batch)
    else:
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    do_sample=True, 
                    temperature=1.0,
                    # top_p=0.9,
                    eos_token_id=processor.tokenizer.eos_token_id,
                    pad_token_id=processor.tokenizer.pad_token_id,
                )
    
                # 解码输出
                responses_batch = processor.batch_decode(outputs, skip_special_tokens=True)
                all_responses.append(responses_batch)
    
    # 转置结果，以便按样本组织
    # [[sample1_try1, sample1_try2, ...], [sample2_try1, sample2_try2, ...], ...]
    responses_by_sample = list(zip(*all_responses))
    
    return responses_by_sample, gt_answers, questions

def compute_accuracy_with_sampling(responses_by_sample, ground_truths, task_type=None):
    """计算采样准确率 - 只要有一次答对就算对"""
    assert len(responses_by_sample) == len(ground_truths)
    correct = 0
    
    for sample_responses, gt in zip(responses_by_sample, ground_truths):
        gt_norm = normalize_answer(extract_gt_answer(gt), task_type)
        
        # 检查该样本的所有尝试是否有一次答对
        sample_correct = False
        for pred in sample_responses:
            pred_norm = normalize_answer(pred, task_type)
            print(pred_norm,gt_norm)
            if pred_norm == gt_norm:
                sample_correct = True
                break
        
        if sample_correct:
            correct += 1
        # else:
        #     # 可选：输出所有错误预测
        #     print(f"❌ Wrong — Preds: {sample_responses} | GT: '{gt}' → {gt_norm}")

    return correct / len(responses_by_sample)

def evaluate_cvbench(model, processor, num_samples=4):
    """评估CV-Bench数据集，使用多次采样策略"""
    print(f"\n=== Evaluating CVBench dataset with {num_samples} samples per question ===")

    try:
        ds = load_dataset("nyu-visionx/CV-Bench")
        print(f"Loaded CVBench dataset with {len(ds['test'])} test samples")
    except Exception as e:
        print(f"Error loading CVBench dataset: {e}")
        return {}

    # 定义要评估的任务
    # tasks = ['Count', 'Relation', 'Distance', 'Depth']
    tasks = ['Distance']  # 如果只想测试一个任务，可以取消注释此行

    # 按任务分类样本
    task_data = {}
    for task in tasks:
        task_data[task] = [item for item in ds['test'] if item['task'] == task]
        print(f"Found {len(task_data[task])} samples for {task} task")

    # 评估每个任务
    results = {}
    batch_size = 10  # 可根据GPU内存调整

    for task in tasks:
        print(f"\n--- Evaluating {task} task with {num_samples} samples per question ---")
        task_samples = task_data[task]

        # 确保样本数量是batch_size的整数倍
        num_samples_total = (len(task_samples) // batch_size) * batch_size
        if num_samples_total == 0:
            print(f"Skipping {task} task: not enough samples")
            continue

        task_samples = task_samples[:num_samples_total]

        responses_by_sample = []
        ground_truths = []

        # 批量处理样本
        for i in tqdm(range(0, len(task_samples), batch_size)):
            
            
            batch = task_samples[i:i+batch_size]
            new_prompt=[change_prompt(item["prompt"])for item in batch]
            new_answer=[change_answer(item["answer"])for item in batch]
            
            batch_dict = {
                "image": [item["image"] for item in batch],
                "prompt": ['<image>'+ new_p for new_p in new_prompt],
                "answer": [new_a for new_a in new_answer]
            }

            # 评估批次，使用多次采样
            batch_responses, batch_gt, questions = evaluate_batch_with_sampling(
                model, processor, batch_dict, batch_size, num_samples=num_samples
            )

            # 记录结果
            responses_by_sample.extend(batch_responses)
            ground_truths.extend(batch_gt)

            # 定期输出中间结果
            if i % 100 == 0:
                accuracy = compute_accuracy_with_sampling(responses_by_sample, ground_truths, task)
                print(f"{task} Accuracy with {num_samples} samples: {accuracy:.2%}")

        # 计算任务准确率
        accuracy = compute_accuracy_with_sampling(responses_by_sample, ground_truths, task)
        print(f"{task} Accuracy with {num_samples} samples: {accuracy:.2%}")
        results[task] = accuracy

    # 计算总体准确率
    if results:
        overall_accuracy = sum(results.values()) / len(results)
        print(f"\nCVBench Overall Accuracy with {num_samples} samples: {overall_accuracy:.2%}")
        results["overall"] = overall_accuracy
        results["sampling_info"] = f"{num_samples} samples per question, correct if any sample is correct"

    return results

def save_results(results, output_file="cvbench_results_sampling.json"):
    """保存评估结果"""
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_file}")

In [3]:
results = evaluate_cvbench(model, processor, num_samples=1)
save_results(results)

    # 输出摘要
print("\n=== Summary of Results ===")
for dataset, result in results.items():
    if isinstance(result, dict):
        print(f"{dataset} results:")
        for task, acc in result.items():
            print(f"  {task}: {acc:.2%}")


=== Evaluating CVBench dataset with 1 samples per question ===
Loaded CVBench dataset with 2638 test samples
Found 600 samples for Distance task

--- Evaluating Distance task with 1 samples per question ---


  2%|▏         | 1/60 [00:07<07:39,  7.79s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
Distance Accuracy with 1 samples: 10.00%


 18%|█▊        | 11/60 [00:23<01:19,  1.61s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
c a
c b
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c a
c b
c a
c a
c a
c a
a a
c a
c a
c a
c b
b a
c a
c a
b b
c b
c b
c a
c a
c b
c b
b b
c b
c b
c a
c b
c a
c b
c b
b b
c b
c a
c a
c a
b b
c a
c a
b a
c b
c a
b a
c b
c a
c a
c b
c a
c b
c b
c b
c b
c a
c a
c b
c b
c a
c b
b b
c b
c a
c a
c b
c b
c a
c a
b b
c a
c b
c a
b b
c b
c b
c b
c b
c a
c a
c b
c b
c b
c a
c a
c a
c b
c b
a b
c b
c a
c a
c a
Distance Accuracy with 1 samples: 8.18%


 35%|███▌      | 21/60 [00:38<01:00,  1.55s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
c a
c b
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c a
c b
c a
c a
c a
c a
a a
c a
c a
c a
c b
b a
c a
c a
b b
c b
c b
c a
c a
c b
c b
b b
c b
c b
c a
c b
c a
c b
c b
b b
c b
c a
c a
c a
b b
c a
c a
b a
c b
c a
b a
c b
c a
c a
c b
c a
c b
c b
c b
c b
c a
c a
c b
c b
c a
c b
b b
c b
c a
c a
c b
c b
c a
c a
b b
c a
c b
c a
b b
c b
c b
c b
c b
c a
c a
c b
c b
c b
c a
c a
c a
c b
c b
a b
c b
c a
c a
c a
c b
c b
c a
c a
b a
c a
c a
c b
c a
c a
c a
c b
b b
c b
c b
c a
c b
c b
c b
c b
c b
c b
c a
c b
c a
c a
c b
c b
c b
c a
b a
c b
c b
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b
c b
c b
c b
b b
c b
b b
c a
c b
a a
c a
c a
c a
b a
c a
c b
c a
c b
c a
c a
c a
c a
c a
c b
c b
a b
c a
c a
c b
c b
c b
c a
c a
c a
c a
b a
c b
c b
c b
c a
b b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c b
c b
c a
c a
b a
Distance Accuracy with 1 samples: 6.67%


 52%|█████▏    | 31/60 [00:54<00:43,  1.51s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
c a
c b
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c a
c b
c a
c a
c a
c a
a a
c a
c a
c a
c b
b a
c a
c a
b b
c b
c b
c a
c a
c b
c b
b b
c b
c b
c a
c b
c a
c b
c b
b b
c b
c a
c a
c a
b b
c a
c a
b a
c b
c a
b a
c b
c a
c a
c b
c a
c b
c b
c b
c b
c a
c a
c b
c b
c a
c b
b b
c b
c a
c a
c b
c b
c a
c a
b b
c a
c b
c a
b b
c b
c b
c b
c b
c a
c a
c b
c b
c b
c a
c a
c a
c b
c b
a b
c b
c a
c a
c a
c b
c b
c a
c a
b a
c a
c a
c b
c a
c a
c a
c b
b b
c b
c b
c a
c b
c b
c b
c b
c b
c b
c a
c b
c a
c a
c b
c b
c b
c a
b a
c b
c b
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b
c b
c b
c b
b b
c b
b b
c a
c b
a a
c a
c a
c a
b a
c a
c b
c a
c b
c a
c a
c a
c a
c a
c b
c b
a b
c a
c a
c b
c b
c b
c a
c a
c a
c a
b a
c b
c b
c b
c a
b b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c b
c b
c a
c a
b a
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c a
c b
c b
c b
c a
c a
c b
c b
c b
c b
c a
c a
c a
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b


 68%|██████▊   | 41/60 [01:09<00:29,  1.53s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
c a
c b
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c a
c b
c a
c a
c a
c a
a a
c a
c a
c a
c b
b a
c a
c a
b b
c b
c b
c a
c a
c b
c b
b b
c b
c b
c a
c b
c a
c b
c b
b b
c b
c a
c a
c a
b b
c a
c a
b a
c b
c a
b a
c b
c a
c a
c b
c a
c b
c b
c b
c b
c a
c a
c b
c b
c a
c b
b b
c b
c a
c a
c b
c b
c a
c a
b b
c a
c b
c a
b b
c b
c b
c b
c b
c a
c a
c b
c b
c b
c a
c a
c a
c b
c b
a b
c b
c a
c a
c a
c b
c b
c a
c a
b a
c a
c a
c b
c a
c a
c a
c b
b b
c b
c b
c a
c b
c b
c b
c b
c b
c b
c a
c b
c a
c a
c b
c b
c b
c a
b a
c b
c b
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b
c b
c b
c b
b b
c b
b b
c a
c b
a a
c a
c a
c a
b a
c a
c b
c a
c b
c a
c a
c a
c a
c a
c b
c b
a b
c a
c a
c b
c b
c b
c a
c a
c a
c a
b a
c b
c b
c b
c a
b b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c b
c b
c a
c a
b a
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c a
c b
c b
c b
c a
c a
c b
c b
c b
c b
c a
c a
c a
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b


 85%|████████▌ | 51/60 [01:25<00:14,  1.60s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
c a
c b
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c a
c b
c a
c a
c a
c a
a a
c a
c a
c a
c b
b a
c a
c a
b b
c b
c b
c a
c a
c b
c b
b b
c b
c b
c a
c b
c a
c b
c b
b b
c b
c a
c a
c a
b b
c a
c a
b a
c b
c a
b a
c b
c a
c a
c b
c a
c b
c b
c b
c b
c a
c a
c b
c b
c a
c b
b b
c b
c a
c a
c b
c b
c a
c a
b b
c a
c b
c a
b b
c b
c b
c b
c b
c a
c a
c b
c b
c b
c a
c a
c a
c b
c b
a b
c b
c a
c a
c a
c b
c b
c a
c a
b a
c a
c a
c b
c a
c a
c a
c b
b b
c b
c b
c a
c b
c b
c b
c b
c b
c b
c a
c b
c a
c a
c b
c b
c b
c a
b a
c b
c b
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b
c b
c b
c b
b b
c b
b b
c a
c b
a a
c a
c a
c a
b a
c a
c b
c a
c b
c a
c a
c a
c a
c a
c b
c b
a b
c a
c a
c b
c b
c b
c a
c a
c a
c a
b a
c b
c b
c b
c a
b b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c b
c b
c a
c a
b a
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c a
c b
c b
c b
c a
c a
c b
c b
c b
c b
c a
c a
c a
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b


100%|██████████| 60/60 [01:40<00:00,  1.67s/it]

c b
c b
c b
c b
c a
c a
c b
b b
c a
c b
c a
c b
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c a
c b
c a
c a
c a
c a
a a
c a
c a
c a
c b
b a
c a
c a
b b
c b
c b
c a
c a
c b
c b
b b
c b
c b
c a
c b
c a
c b
c b
b b
c b
c a
c a
c a
b b
c a
c a
b a
c b
c a
b a
c b
c a
c a
c b
c a
c b
c b
c b
c b
c a
c a
c b
c b
c a
c b
b b
c b
c a
c a
c b
c b
c a
c a
b b
c a
c b
c a
b b
c b
c b
c b
c b
c a
c a
c b
c b
c b
c a
c a
c a
c b
c b
a b
c b
c a
c a
c a
c b
c b
c a
c a
b a
c a
c a
c b
c a
c a
c a
c b
b b
c b
c b
c a
c b
c b
c b
c b
c b
c b
c a
c b
c a
c a
c b
c b
c b
c a
b a
c b
c b
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b
c b
c b
c b
b b
c b
b b
c a
c b
a a
c a
c a
c a
b a
c a
c b
c a
c b
c a
c a
c a
c a
c a
c b
c b
a b
c a
c a
c b
c b
c b
c a
c a
c a
c a
b a
c b
c b
c b
c a
b b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c b
c b
c a
c a
b a
c a
c a
c a
c a
c b
c a
c a
c b
c a
c a
c b
c b
c b
c b
c b
c a
c b
c b
c b
c a
c a
c b
c b
c b
c b
c a
c a
c a
c a
c b
c a
c a
c b
c b
c b
c b
c a
c a
c b
c b



