Copyright 2024 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title Default title text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
!pip install vertexai
!pip install datasets

Collecting vertexai
  Using cached vertexai-1.71.1-py3-none-any.whl.metadata (10 kB)
Collecting google-cloud-aiplatform==1.71.1 (from google-cloud-aiplatform[all]==1.71.1->vertexai)
  Using cached google_cloud_aiplatform-1.71.1-py2.py3-none-any.whl.metadata (32 kB)
Collecting google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,<3.0.0dev,>=1.34.1 (from google-api-core[grpc]!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,<3.0.0dev,>=1.34.1->google-cloud-aiplatform==1.71.1->google-cloud-aiplatform[all]==1.71.1->vertexai)
  Using cached google_api_core-2.24.2-py3-none-any.whl.metadata (3.0 kB)
Collecting google-auth<3.0.0dev,>=2.14.1 (from google-cloud-aiplatform==1.71.1->google-cloud-aiplatform[all]==1.71.1->vertexai)
  Using cached google_auth-2.38.0-py2.py3-none-any.whl.metadata (4.8 kB)
Collecting proto-plus<2.0.0dev,>=1.22.3 (from google-cloud-aiplatform==1.71.1->google-cloud-aiplatform[all]==1.71.1->vertexai)
  Using cached proto_plus-1.26.1

In [1]:
#@title Imports

import io
import json
import vertexai
from vertexai.preview.generative_models import GenerativeModel, Part, Image
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#@title Utils

def pil_image_to_bytes(img, format='PNG'):
  if not img:
    return None
  img_byte_arr = io.BytesIO()
  img.save(img_byte_arr, format=format)
  return img_byte_arr.getvalue()


def get_multi_image_question_parts(record, question_appendix, max_num_images=6):
  """Converts the question into multiple parts corresponding to texts/images.

  Args:
    record: This is a dictionary with a key "question" containing the text of
            the question, and multiple keys named "image_i" corresponding to
            the i-th image.
    question_appendix: A text to be appended at the end of the question.
    max_num_images: Maximum number of images in a problem (6 for ReMI).

  Returns:
    A list of vertexai parts where each part is either a piece of text or an
    image. For example, a question such "Find the difference between <image_1>
    and <image_2> in terms of minutes." will become an array of the form:
    ["Find the difference between ", image1, " and ", image2, " in terms of
    minutes"].
    Note: The parts for other APIs can be prepared similarly, but using the
    specific parts class from those APIs.
  """

  parts, images = [], []
  for i in range(max_num_images):
    img_bytes = pil_image_to_bytes(record[f'image_{i+1}'])
    if img_bytes:
      images.append(Part.from_data(img_bytes, mime_type='image/png'))

  question = 'Question: ' + record['question'] + question_appendix
  for j, img in enumerate(images):
    question_parts = question.split(f'<image{j+1}>')
    parts += [question_parts[0], img] if question_parts[0] else [img]
    question = question_parts[1]
  return parts + [question] if question else []

In [4]:
#@title Evaluate


def strip_json(s):
  """Takes a text possibly containing a json, and extrcts the json part."""
  return s[s.index('{') if '{' in s else 0: s.rindex('}') + 1 if '}' in s else 0]


def is_float(x):
  try: float(x); return True
  except (ValueError, TypeError): return False


def exact_match(label, pred, eps=0.01):
  return (
      str(label) == str(pred) or
      (is_float(label) and is_float(pred) and float(label) == float(pred)) or
      ((pred == 'f(x)' and label == 'f') or (pred == 'g(x)' and label == 'g') or (pred == 'h(x)' and label == 'h')) or
      (f'({pred})' == label or f'({label})' == pred) or
      (str(pred).replace(' ', '') == str(label) or str(label).replace(' ', '') == str(pred)) or
      relaxed_accuracy(label, pred, eps=eps))

def relaxed_accuracy(label, pred, eps=0.03):
  if not is_float(label) or not is_float(pred): return False
  return (1-eps) * float(label) <= float(pred) <= (1+eps) * float(label)

def accuracy_with_tolerance(label, pred, tolerance=10):
  if not is_float(label) or not is_float(pred): return False
  return float(label) - tolerance <= float(pred) <= float(label) + tolerance

def get_pred(model_response):
  model_response_json = strip_json(model_response).replace('\\"', '').replace('\\', '')
  try:
    pred = str(json.loads(model_response_json)['answer']).lower()
    return pred.split('%')[0].strip()
  except (KeyError, json.JSONDecodeError):
    return 'BAD_JSON'

def prep_label(label):
  return label.lower().replace('\\', '').split('%')[0].strip()

def evaluate(task, labels, model_responses):
  correct = 0
  for orig_label, model_response in zip(labels, model_responses):
    pred, label = get_pred(model_response), prep_label(orig_label)
    if task == 'RefCoco':
      correct += 1 if str(pred) in label.split(',') else 0  # whether pred is in label
    elif task in ['GeomShape', 'GeomCost']:
      correct += 1 if relaxed_accuracy(label, pred, eps=0.03) else 0
    elif task == 'Clocks':
      correct += 1 if accuracy_with_tolerance(label, pred, tolerance=10) else 0
    else:
      correct += 1 if exact_match(label, pred) else 0
  return correct / len(labels)

In [5]:
#@title Prepare Data for Model Call

QUESTION_APPENDICES = {
    'Collisions': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the numeric value corresponding to your final answer. If it is a yes or no question, the answer field must be 0 for no and 1 for yes.',
    'Clocks': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the numeric value corresponding to your final answer.',
    'Schedule': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains a string corresponding to your final answer.',
    'EmojiAlgebra': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the numeric value corresponding to your final answer.',
    'Charts': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains a string or numerical value corresponding to your final answer.',
    'CodeEdit': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the line of code corresponding to your final answer.',
    'GeomShape': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the numeric value corresponding to your final answer.',
    'GeomCost': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the numeric value corresponding to your final answer.',
    'FuncRead': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains a string or numeric value corresponding to your final answer.',
    'RefCoco': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the numeric value corresponding to your final answer.',
    'IQ': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains the a string corresponding to your final choice.',
    'Isomorphism': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field must be 1 if the two graphs are isomorphic and 0 otherwise.',
    'Maps': ' Output only a valid JSON string with two fields: "explanation" and "answer". Do not output anything else. The explanation field contains your reasoning. The answer field contains a string corresponding to your final answer.'}

prompts, labels = {}, {}
from tqdm import tqdm

for example in tqdm(load_dataset("mehrankazemi/ReMI")['test']):
  task = example['task']
  prompts[task] = prompts.get(task, []) + [get_multi_image_question_parts(example, QUESTION_APPENDICES[task])]
  # print(prompts[task])
  # print(labels.get(task, []))
  # print(example['label'])
  # 修复这里的错误：将字符串放入列表中再连接
  labels[task] = labels.get(task, []) + [example['label']]

100%|██████████| 2600/2600 [05:46<00:00,  7.50it/s]


In [6]:
print(prompts['Maps'][2])

['Question: Here are two images. The first image is image A ', inline_data {
  mime_type: "image/png"
  data: "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\001U\000\000\001U\010\002\000\000\000\320n\353y\000\0008\350IDATx\234\355\335y|T\325\331\007\360\347n\263'\231\354\tYH\200\020\366-\004\022\020P\020\025\005*b@E\321\266\026\264\266U\264n\255\365\325\276V_Z\005\367*\265*Z\\022\327\342\006\002*\010a\227\035\302\022B2YIB2\231\365.\347\274\177L\014Y&!\231\314\222\311<\337\017\237vrg\356\275O\314\375\3159w;\227\241\224\002\n\r\222$UWW\207\205\205EDD\004\272\026\004\000P]]\255(\312\200\001\003\002U\000\353\213\205fggG\2652d\310\220e\313\226\235>}\332\027\353\362\235C\207\016\351t\272\307\036{\254\365\304\305\213\027\353t:\273\335\036\250\252<c6\233\357\271\347\236\230\230\230\224\224\024\243\321\230\221\221\221\237\237\037\350\242z\254\335v5|\370\360\033n\270a\353\326\255\201\256\313\023\257\276\372\352\300\201\003\023\022\022\222\222\222\302\303\303\037y\344\021\247\323\351\3772x_,\324l6\

In [7]:
# 使用transformers加载Qwen模型
def init_qwen_model():
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    import torch
    from PIL import Image as PILImage
    
    model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True,max_pixels=2048*28*28)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16
    )
    return model, processor

qwen_model, qwen_processor = init_qwen_model()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 5/5 [00:05<00:00,  1.19s/it]


In [9]:
# ... existing code ...

import torch
from tqdm import tqdm  # 确保导入tqdm

def qwen_multimodal_call(prompt_parts):
    """对Qwen模型进行调用"""
    from PIL import Image as PILImage
    
    # 解析prompt_parts为Qwen接受的格式
    messages = [{"role": "user", "content": []}]
    
    for part in prompt_parts:
        if isinstance(part, str):
            messages[0]["content"].append({"type": "text", "text": part})
        else:  # 这是一个图像Part
            img_bytes = part.inline_data.data
            img = PILImage.open(io.BytesIO(img_bytes))
            messages[0]["content"].append({"type": "image", "image": img})
    
    # 构建Qwen模型的输入
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    
    image_inputs = [content["image"] for content in messages[0]["content"] 
                   if content.get("type") == "image"]
    
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt"
    )
    
    inputs = inputs.to(qwen_model.device)
    
    # 生成回复
    with torch.no_grad():  # 禁用梯度计算以节省内存
        generated_ids = qwen_model.generate(**inputs, max_new_tokens=512)  # 减少max_new_tokens
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
    
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    
    torch.cuda.empty_cache()  # 清理缓存
    return output_text

model_responses = {}
for task in prompts:
    model_responses[task] = []
    # if task == 'FuncRead':
    #     continue
    for prompt in tqdm(prompts[task], desc=f"Processing {task}"):
        model_response = qwen_multimodal_call(prompt)
        # print("model_response:" + model_response)
        model_responses[task].append(model_response)

# ... existing code ...

Processing FuncRead: 100%|██████████| 200/200 [20:45<00:00,  6.23s/it]
Processing Maps: 100%|██████████| 200/200 [23:21<00:00,  7.01s/it]
Processing RefCoco: 100%|██████████| 200/200 [08:08<00:00,  2.44s/it]
Processing GeomCost: 100%|██████████| 200/200 [25:03<00:00,  7.52s/it]
Processing Collisions: 100%|██████████| 200/200 [23:22<00:00,  7.01s/it]
Processing Isomorphism: 100%|██████████| 200/200 [10:31<00:00,  3.16s/it]
Processing Schedule: 100%|██████████| 200/200 [09:13<00:00,  2.77s/it]
Processing GeomShape:  20%|██        | 40/200 [11:14<44:57, 16.86s/it]


KeyboardInterrupt: 

In [10]:
#@title Run the evaluator
task_scores = {}
for task in model_responses:
    score = evaluate(task, labels[task], model_responses[task])
    task_scores[task] = score
    print(f"{task}: {score:.2f}")

# 计算总体平均分
average_score = sum(task_scores.values()) / len(task_scores)
print("\n" + "="*50)
print(f"ReMI 总分: {average_score:.2f}")
print("="*50)

# 以表格形式展示所有分数
print("\n详细任务分数:")
for task, score in task_scores.items():
    print(f"{task:<15}: {score:.2f}")

FuncRead: 0.31
Maps: 0.30
RefCoco: 0.38
GeomCost: 0.36
Collisions: 0.42
Isomorphism: 0.68
Schedule: 0.30
GeomShape: 0.04

ReMI 总分: 0.35

详细任务分数:
FuncRead       : 0.31
Maps           : 0.30
RefCoco        : 0.38
GeomCost       : 0.36
Collisions     : 0.42
Isomorphism    : 0.68
Schedule       : 0.30
GeomShape      : 0.04
