In [None]:
import os
import re
import sys  
import json
import torch 
import argparse
import numpy as np
from PIL import Image  
from tqdm import tqdm
from utils import model_gen
from transformers import AutoModelForCausalLM, AutoTokenizer  

ckpt_path = 'internlm/internlm-xcomposer2-vl-7b'
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
# Load the weights directly in bf16 (instead of fp32-then-cast) to avoid a
# ~2x peak-memory spike; device_map="cuda" already places the model on the GPU.
model = AutoModelForCausalLM.from_pretrained(
    ckpt_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
).eval()
# Final cast kept as a safety net in case the remote model code builds any
# submodule in a different dtype; it is a no-op when everything is already bf16.
model = model.to(torch.bfloat16)
# The generation helper (utils.model_gen) is handed only the model, so attach
# the tokenizer to it — presumably model_gen reads model.tokenizer; verify in utils.
model.tokenizer = tokenizer

In [None]:
# Select which Q-Bench LLVisionQA split / language to evaluate.
lang = "en"  # "en" | "zh"
split = "dev"  # "dev" | "test"

if split not in ("dev", "test"):
    # Without this check the zh branch would silently map any typo to the test set.
    raise ValueError(f"split must be 'dev' or 'test', got {split!r}")

if lang == "en":
    llvqa_json_path = f"data/qbench/llvisionqa_{split}.json"
elif lang == "zh":
    zh_split = "验证集" if split == "dev" else "测试集"  # dev set / test set
    llvqa_json_path = f"data/qbench/质衡-问答-{zh_split}.json"
else:
    raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert  Q-Bench into more languages.")

# Explicit UTF-8: the zh file contains Chinese text and would fail to decode
# under a non-UTF-8 locale default encoding (e.g. cp1252/gbk on Windows).
with open(llvqa_json_path, encoding="utf-8") as f:
    llvqa_data = json.load(f)

# Per-cell accumulators indexed [llddata["type"]][llddata["concern"]]; the 3x4
# shape assumes type in 0..2 and concern in 0..3 — TODO confirm against the dataset.
correct = np.zeros((3, 4))
all_ = np.zeros((3, 4))
# The first capital letter A-D anywhere in the model output is taken as its choice.
pattern = re.compile(r'[A-D]')
answers = {}
for llddata in tqdm(llvqa_data):
    t, c = llddata["type"], llddata["concern"]

    # Build "A. <cand> B. <cand> ..." and remember the letter of the correct answer.
    # Reset per item so a stale letter from a previous question is never reused and
    # the test split (no "correct_ans") does not hit a NameError below.
    correct_choice = None
    options_prompt = ''
    for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
        options_prompt += f"{choice} {ans} "
        if "correct_ans" in llddata and ans == llddata["correct_ans"]:
            correct_choice = choice[0]

    text = '[UNUSED_TOKEN_146]user\nQuestion: {}\nContext: N/A\nOptions: {}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\nThe answer is'.format(
                llddata["question"], options_prompt)

    # NOTE(review): the "llv_dev" directory is used for both splits — confirm layout.
    img_path = f"data/qbench/llv_dev/{split}/" + llddata["img_path"]
    # 1st dialogue turn; torch.autocast("cuda") matches the deprecated
    # torch.cuda.amp.autocast() (same default fp16 autocast dtype).
    with torch.autocast("cuda"):
        response = model_gen(model, text, img_path)

    matches = pattern.findall(response)
    if matches:
        res = matches[0]
    else:
        # No option letter found: record "E" so the item is counted as wrong.
        res = 'E'
        print("Error: [Response]: {}, [Correct Ans]: {}".format(response, correct_choice))

    llddata["response"] = res
    answers[llddata["img_path"]] = res
    all_[t][c] += 1
    if split == 'dev' and res == correct_choice:
        correct[t][c] += 1
        
# Accuracy by question type (rows of the table), by concern (columns), and overall.
# Only meaningful on the dev split — the test split carries no ground truth.
# Cells with zero questions yield nan; the division warnings are silenced.
if split == 'dev':
    with np.errstate(divide='ignore', invalid='ignore'):
        print(correct.sum(1) / all_.sum(1))
        print(correct.sum(0) / all_.sum(0))
        print(correct.sum() / all_.sum())

# Name the output after the evaluated split/language so runs don't overwrite
# each other (the default dev/en run keeps the original filename).
os.makedirs('Output', exist_ok=True)
torch.save(answers, f'Output/QBench_{split}_{lang}_InternLM_XComposer_VL.json.pth')