In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='2'
import sys  
import json
import torch
import numpy as np
from PIL import Image 
from tqdm import tqdm
import datetime
from collections import defaultdict


from datasets import load_dataset
from chartmoe import ChartMoE_Robot

# Load the MME benchmark and keep only its "test" split, then build the
# ChartMoE model interface that the evaluation loop queries via `robot.chat`.
mme_splits = load_dataset("lmms-lab/MME")
mme_data = mme_splits['test']

robot = ChartMoE_Robot()

  from .autonotebook import tqdm as notebook_tqdm


Set max length to 4096


Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.42s/it]


In [2]:
# MME sub-task categories grouped by headline score.
# Names must match the dataset's `category` field exactly: the evaluation loop
# checks membership in the "Perception" list and treats every other category
# as Cognition. Key order (Perception, then Cognition) is the order used when
# the per-type totals are summed.
eval_type_dict = dict(
    Perception=[
        "existence", "count", "position", "color", "posters",
        "celebrity", "scene", "landmark", "artwork", "OCR",
    ],
    Cognition=[
        "commonsense_reasoning", "numerical_calculation",
        "text_translation", "code_reasoning",
    ],
)

In [3]:
def parse_pred_ans(pred_ans):
    """Map a free-form model reply to "yes", "no" or "other" (brought from Otter Eval).

    Normalization: lower-case, strip surrounding whitespace, drop every ".".
    Then:
      * an exact "yes" / "no" is returned unchanged;
      * a single character maps "y" -> "yes", "n" -> "no", anything else -> "other";
      * otherwise only the first four characters are inspected, and "yes"
        takes priority over "no".

    Note: the prefix test is a substring match, so replies starting with e.g.
    "not" or "none" are labelled "no". This is kept as-is so scores remain
    comparable with the Otter helper.

    Args:
        pred_ans: raw text generated by the model.

    Returns:
        One of "yes", "no", "other".
    """
    normalized = pred_ans.lower().strip().replace(".", "")

    if normalized in ("yes", "no"):
        return normalized

    if len(normalized) == 1:
        return {"y": "yes", "n": "no"}.get(normalized, "other")

    head = normalized[:4]
    if "yes" in head:
        return "yes"
    if "no" in head:
        return "no"
    return "other"

In [4]:
# Run ChartMoE on every MME sample and score its yes/no answer against the
# ground truth. A full pass takes close to an hour, so the per-sample records
# are cached in MME_RESULTS_PATH and reloaded on re-runs (set FORCE_RERUN to
# redo inference, e.g. after changing the model or generation settings).
MME_RESULTS_PATH = "mme_results.jsonl"
FORCE_RERUN = False

if os.path.exists(MME_RESULTS_PATH) and not FORCE_RERUN:
    with open(MME_RESULTS_PATH) as f:
        results = [json.loads(line) for line in f if line.strip()]
else:
    results = []
    for d in tqdm(mme_data):
        # MME answers are strictly yes/no; validate before spending GPU time on the sample.
        gt_ans = d["answer"].lower().strip().replace(".", "")
        assert gt_ans in ["yes", "no"], f"unexpected MME ground-truth answer: {d['answer']!r}"

        image = d['image'].convert("RGB")
        question = d['question']
        category = d['category']

        # torch.cuda.amp.autocast() is deprecated; torch.autocast("cuda") is the equivalent.
        with torch.autocast(device_type="cuda"):
            pred, _ = robot.chat(
                image=image,
                question=question,
                temperature=1.0,  # NOTE(review): presumably ignored since do_sample=False — depends on robot.chat
                max_new_tokens=500,
                num_beams=3,
                do_sample=False,
                repetition_penalty=1.0,
            )

        pred_ans = parse_pred_ans(pred)
        score = 1.0 if pred_ans == gt_ans else 0.0

        # NOTE(review): "percetion" is misspelled; left unchanged so result files keep the same keys.
        key_name = "mme_percetion_score" if category in eval_type_dict["Perception"] else "mme_cognition_score"
        results.append({key_name: {"question_id": d["question_id"], "category": category, "score": score}})

    with open(MME_RESULTS_PATH, 'w') as f:
        for res in results:
            f.write(f"{json.dumps(res)}\n")

100%|██████████| 2374/2374 [55:47<00:00,  1.41s/it] 


In [14]:
# MME scoring. Answers come in pairs sharing a question_id (two yes/no
# questions per item). For each pair:
#   acc  = accuracy over its two answers (0, 50 or 100)
#   acc+ = 100 if both answers are correct, else 0
# A category's score is the mean of acc + acc+ over its pairs (max 200), and
# the overall score is the sum of the category scores.

# Flatten into a NEW name: rebinding `results` here made the cell raise a
# TypeError on re-run, because `results` was no longer the list of
# {key_name: record} dicts produced by the evaluation cell.
flat_results = [next(iter(res.values())) for res in results]

category2score = defaultdict(dict)
for result in flat_results:
    category2score[result["category"]].setdefault(result["question_id"], []).append(result["score"])

category2avg_score = {}
for category, question2scores in category2score.items():
    category_total = 0.0
    for question_id, pair_scores in question2scores.items():
        assert len(pair_scores) == 2, f"MME only supports pairwise evaluation (question_id={question_id!r})"
        acc = sum(pair_scores) / len(pair_scores) * 100.0
        acc_plus = (sum(pair_scores) == 2) * 100.0
        category_total += acc + acc_plus
    category2avg_score[category] = category_total / len(question2scores)

total_score = sum(category2avg_score.values())
total_score

2214.1502601040415


In [16]:
# Per-category MME score: mean over question pairs of acc + acc+, so each category is out of 200.
category2avg_score

{'code_reasoning': 117.5,
 'artwork': 186.25,
 'celebrity': 163.23529411764707,
 'numerical_calculation': 147.5,
 'text_translation': 155.0,
 'count': 170.0,
 'color': 170.0,
 'commonsense_reasoning': 138.57142857142858,
 'position': 158.33333333333334,
 'OCR': 125.0,
 'landmark': 170.25,
 'scene': 157.0,
 'existence': 180.0,
 'posters': 175.51020408163265}

In [21]:
# Headline MME numbers: the Perception and Cognition totals are the sums of
# their per-category scores (categories as listed in eval_type_dict).
scores = defaultdict(int)
for eval_type, categories in eval_type_dict.items():
    scores[eval_type] = sum(category2avg_score[name] for name in categories)
scores

defaultdict(int,
            {'Perception': 1655.578831532613, 'Cognition': 558.5714285714286})

In [22]:
# Reference MME scores to compare this run against.
# NOTE(review): provenance is not recorded in this notebook; presumably the
# published Perception / Cognition scores of a baseline model — TODO cite the source.
REF_PERCEPTION = 1712.0
REF_COGNITION = 530.7
REF_PERCEPTION + REF_COGNITION

2242.7

In [23]:
# This run's total, with Perception and Cognition each rounded to one decimal
# (as headline numbers are usually reported). Derived from `scores` instead of
# hand-copied so it cannot drift from the computed results.
round(scores["Perception"], 1) + round(scores["Cognition"], 1)

2214.2