In [1]:
# dependencies 
import torch 
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from glob import glob
import os
import json

In [None]:
from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Download the Qwen-VL checkpoint from ModelScope and load it for inference.
model_dir = snapshot_download('qwen/Qwen-VL')

# trust_remote_code=True lets transformers execute the Python code bundled with
# the checkpoint (e.g. the tokenizer's `from_list_format` used further down).
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# Prompts are generated in padded batches; a decoder-only model needs the
# padding on the left so every prompt ends exactly where generation begins.
tokenizer.pad_token = '<|endoftext|>'
tokenizer.padding_side = 'left'

model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="cuda:0", trust_remote_code=True)
model.eval()  # inference only: disables dropout etc.

In [16]:
# Dataset split under evaluation. "<chart_type>_<question_type>" selects the
# perturbation JSONs, the plot images and the results folder of this run.
chart_type, question_type = "complex", "complex"

results_root = f"../Results/Qwen-VL/{chart_type}_{question_type}/"
os.makedirs(results_root, exist_ok=True)

In [17]:
from typing import Optional
import re
  
def modified_relaxed_accuracy(question: str,
                              target: str,
                              prediction: str,
                              max_relative_change: float = 0.05) -> bool:
  """Calculates relaxed correctness of a predicted chart-QA answer.

  The correctness tolerates certain error ratio defined by max_relative_change.
  See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
  “Following Methani et al. (2020), we use a relaxed accuracy measure for the
  numeric answers to allow a minor inaccuracy that may result from the automatic
  data extraction process. We consider an answer to be correct if it is within
  5% of the gold answer. For non-numeric answers, we still need an exact match
  to consider an answer to be correct.”

  This version takes into account more cases:
    * Thousands separators ("1,000") are removed before numeric parsing, and a
      trailing "%" turns a number into a fraction ("5%" -> 0.05).
    * If both answers parse as numbers, they are compared by relative change.
      A target of exactly 0 requires a prediction of exactly 0, since the
      relative change is undefined there.
    * Questions whose lower-cased text contains "year" (so also "years") are
      never compared numerically; they fall through to string matching.
    * Bracketed lists ("[a,b]") are compared as order-insensitive,
      case-sensitive collections of comma-separated items (spaces ignored).
      Brackets may be present on only one side.
    * Otherwise, ignoring case and spaces, the answer is correct if the target
      occurs in the prediction, or if a non-empty prediction occurs in the target.

  Args:
    question: Question text; only used to detect year questions.
    target: Gold answer. Non-string values (e.g. numbers parsed from JSON) are
      converted with str().
    prediction: Predicted answer, converted with str() as well.
    max_relative_change: Maximum relative change for numeric answers.

  Returns:
    Whether the prediction was correct given the specified tolerance.
  """
  def _to_float(text: str) -> Optional[float]:
    # "12.5%" -> 0.125; anything float() rejects -> None.
    try:
      if text.endswith("%"):
        return float(text.rstrip("%")) / 100.0
      return float(text)
    except ValueError:
      return None

  def _remove_commas_from_numbers(text: str) -> str:
    # Remove only thousands separators: a comma preceded by a digit and
    # followed by exactly three digits ("1,000,000" -> "1000000"). Separators
    # of lists such as "[10,20,30]" or "[a,1]" are left intact.
    return re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', text)

  def _remove_spaces(text: str) -> str:
    return text.replace(" ", "")

  def _check_for_years(question: str) -> bool:
    return "year" in question.lower()

  def _check_list(text: str) -> bool:
    return text.startswith("[") and text.endswith("]")

  def _list_of_answers(target: str, prediction: str) -> bool:
    # Order-insensitive comparison of comma-separated items.
    return sorted(target.split(",")) == sorted(prediction.split(","))

  prediction = _remove_commas_from_numbers(str(prediction))
  target = _remove_commas_from_numbers(str(target))

  prediction_float = _to_float(prediction)
  target_float = _to_float(target)

  if (not _check_for_years(str(question))
      and prediction_float is not None and target_float is not None):
    if target_float == 0:
      # Relative change is undefined for a zero target; require an exact zero.
      return prediction_float == 0
    relative_change = abs(prediction_float - target_float) / abs(target_float)
    return relative_change <= max_relative_change

  prediction = _remove_spaces(prediction)
  target = _remove_spaces(target)

  target_is_list = _check_list(target)
  prediction_is_list = _check_list(prediction)
  if target_is_list or prediction_is_list:
    # Strip the brackets only from the side(s) that actually have them.
    return _list_of_answers(target[1:-1] if target_is_list else target,
                            prediction[1:-1] if prediction_is_list else prediction)

  target = target.lower()
  prediction = prediction.lower()
  return target in prediction or (prediction in target and len(prediction) > 0)

In [18]:
# One evaluation category per JSON file in ../perturb_jsons/<chart_type>_<question_type>/.
# Only *.json entries are kept, so stray entries such as .DS_Store or
# .ipynb_checkpoints do not turn into bogus (e.g. empty) category names. Only the
# final extension is dropped, so "a.b.json" maps back to category "a.b".
perturb_dir = "../perturb_jsons/{}_{}".format(chart_type, question_type)
categories = sorted(
    os.path.splitext(filename)[0]
    for filename in os.listdir(perturb_dir)
    if filename.endswith(".json")
)

# Accumulators filled per category by the evaluation loop.
global_answers = {}        # category -> list of decoded model outputs
category_wise_scores = {}  # category -> number of correct answers

In [None]:
batch_size = 8  # prompts per model.generate call
results_dir = f"../Results/Qwen-VL/{chart_type}_{question_type}"


def _generate_responses(queries, batch_size=batch_size):
    """Generate one decoded string per query, in the same order as `queries`.

    If model.generate raises for a batch, the error is printed and an empty
    string is recorded for every query of that batch. This keeps the responses
    aligned with questions and labels, instead of reusing the previous batch's
    outputs (or failing with NameError on the first batch).
    """
    responses = []
    for start in range(0, len(queries), batch_size):
        batch = queries[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(model.device)
        try:
            with torch.no_grad():
                outputs = model.generate(**inputs)
        except Exception as exc:  # keep the run going, but never reuse stale outputs
            print("bad output:", repr(exc))
            responses.extend([''] * len(batch))
        else:
            responses.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))
        print("." * len(batch), end='')
    print()
    return responses


def _extract_answer(response):
    """Reduce a decoded generation to the bare answer string.

    Keeps the text after the last 'Answer:', cuts it at the first '%', then keeps
    the part after the last '=' (e.g. "... Answer: 3 + 4 = 7%" -> "7").
    """
    answer = response.split('Answer:')[-1].strip()
    answer = answer.split('%')[0].strip()
    return answer.split('=')[-1].strip()


for category in categories:
    print("running for category:", category)
    print()
    # dtype=False keeps the JSON values as written; pandas' default dtype
    # inference may coerce numeric-looking strings (image ids, labels) to numbers.
    df = pd.read_json('../perturb_jsons/{}_{}/{}.json'.format(chart_type, question_type, category), dtype=False)
    questions = df['query'].tolist()
    # The scorer works on strings; JSON numbers stay numbers, so stringify them.
    gold_labels = [str(label) for label in df['label'].tolist()]
    imagenames = df['imgname'].tolist()
    perturbations = df['perturbation'].tolist()
    imagenames = [f'../final_data/{chart_type}_{question_type}/plots/{pert}/{img}.png'
                  for img, pert in zip(imagenames, perturbations)]

    # Qwen-VL prompt: the chart image followed by "<question> Answer:".
    queries = [
        tokenizer.from_list_format([
            {'image': image_path},
            {'text': question + ' Answer:'},
        ])
        for question, image_path in zip(questions, imagenames)
    ]

    model_responses = _generate_responses(queries)
    global_answers[category] = model_responses
    print("answers generated!")

    final_responses = [_extract_answer(response) for response in model_responses]
    model_performance = [
        modified_relaxed_accuracy(question, label, answer)
        for question, label, answer in zip(questions, gold_labels, final_responses)
    ]

    category_wise_scores[category] = sum(model_performance)  # number of correct answers
    if model_performance:
        print('Model accuracy:', sum(model_performance) / len(model_performance))
    else:
        print('Model accuracy: n/a (empty category)')
    print()

    # store results and raw answers after each category so partial runs are kept
    with open(f"{results_dir}/results.json", "w") as f:
        json.dump(category_wise_scores, f)
    with open(f"{results_dir}/answers.json", "w") as f:
        json.dump(global_answers, f)