In [4]:
! pip install num2words

Collecting num2words
  Downloading num2words-0.5.14-py3-none-any.whl.metadata (13 kB)
Collecting docopt>=0.6.2 (from num2words)
  Downloading docopt-0.6.2.tar.gz (25 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hDownloading num2words-0.5.14-py3-none-any.whl (163 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.5/163.5 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hBuilding wheels for collected packages: docopt
  Building wheel for docopt (pyproject.toml) ... [?25ldone
[?25h  Created wheel for docopt: filename=docopt-0.6.2-py2.py3-none-any.whl size=13781 sha256=ed70be4043316e44e524d92ef4e12405730c1670576965922996e3b16d6fac59
  Stored in directory: /home/divyansh/.cache/pip/wheels/1a/bf/a1/4cee4f7678c68c5875ca89eaccf460593539805c3906722228
Successfully built docopt
Installing collected packages: docopt, num2words
Su

In [6]:
import argparse
import json
import math
import os
import re
from itertools import islice
from typing import Dict, List, Union, Any

import numpy as np
from num2words import num2words
from PIL import Image
from tqdm import tqdm

In [None]:
def load_dataset(dataset_path: str) -> List[Dict]:
    """Load an evaluation dataset from a JSON file.

    Args:
        dataset_path: Path to a JSON file. The evaluator iterates over the result and
            indexes each item by key, so the file is expected to hold a list of dicts.

    Returns:
        The decoded JSON document.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding,
    # which would break on non-ASCII queries/answers (e.g. on Windows).
    with open(dataset_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def is_numeric(text: str) -> bool:
    """Return True if ``text`` parses as a finite number once commas are removed.

    Non-string input (e.g. an int/float label decoded from JSON) is converted with
    ``str()`` first instead of raising. ``"nan"``/``"inf"``/``"infinity"`` are accepted
    by ``float()`` but rejected here, because they cannot take part in a tolerance
    comparison.
    """
    cleaned_text = str(text).replace(',', '')
    try:
        value = float(cleaned_text)
    except ValueError:
        return False
    return math.isfinite(value)

def normalize_answer(answer: str) -> str:
    """Normalize an answer string for lenient comparison.

    Steps: convert to ``str`` and lowercase; replace the articles "a", "an", "the" with
    a space; drop punctuation other than '.' and '-'; keep a '-' only as a number's
    leading minus sign; keep a '.' only as a decimal point; collapse whitespace.

    Example: ``"The Answer: -23.67."`` -> ``"answer -23.67"``.
    """
    answer = str(answer).lower()
    answer = re.sub(r'\b(a|an|the)\b', ' ', answer)
    # '.' and '-' survive this pass; the next two rules decide which occurrences to keep.
    answer = re.sub(r'[^\w\s.\-]', '', answer)
    # Keep '-' only as a minus sign: directly before a digit (or ".digit") and not after
    # a word character, so "-5" stays signed while "well-known" / "2019-2020" are joined.
    answer = re.sub(r'-(?!\.?\d)|(?<=\w)-', '', answer)
    # Keep '.' only as a decimal point: directly before a digit and not after a letter or
    # underscore, so "23.67." -> "23.67" and ".5" stays ".5" while "e.g." -> "eg".
    answer = re.sub(r'\.(?!\d)|(?<=[^\W\d])\.', '', answer)
    answer = re.sub(r'\s+', ' ', answer).strip()

    return answer

def _normalized_list_elements(label: str) -> List[str]:
    """Split a bracketed list label such as "[a, b]" into its normalized, non-empty elements."""
    content = label.strip()[1:-1].strip()
    if not content:
        return []
    normalized = (normalize_answer(element.strip()) for element in content.split(','))
    # Drop elements that normalize to "" -- "" is a substring of every prediction.
    return [element for element in normalized if element]


def _number_as_words(label_normalized: str) -> str:
    """Spell a normalized numeric label with num2words and normalize the result.

    Returns "" when num2words cannot convert the value, so the caller simply skips the
    word comparison instead of failing the whole check.
    """
    try:
        # Normalizing the words the same way as prediction tokens makes hyphenated
        # number names comparable (normalize_answer strips the hyphen on both sides).
        return normalize_answer(num2words(label_normalized))
    except (ArithmeticError, ValueError, TypeError, NotImplementedError):
        return ""


def _matches_numeric(pred: str, label_normalized: str, rel_tol: float) -> bool:
    """Return True if any whitespace-separated token of ``pred`` matches the numeric label.

    A token matches when it parses as a number within ``rel_tol`` relative error of the
    label (within 1e-10 absolute error when the label is zero), or when it equals the
    label spelled out in words (single-token number names only).
    """
    try:
        label_val = float(label_normalized.replace(',', ''))
    except ValueError:
        return False
    label_as_words = _number_as_words(label_normalized)

    for raw_token in pred.split():
        token = normalize_answer(raw_token)
        if not token:
            continue
        if is_numeric(token):
            try:
                pred_val = float(token.replace(',', ''))
            except ValueError:
                continue
            if abs(label_val) < 1e-10:
                # Relative error is undefined for a zero label: require (near-)exact equality.
                if abs(pred_val - label_val) < 1e-10:
                    return True
            elif abs(pred_val - label_val) / abs(label_val) <= rel_tol:
                return True
        elif label_as_words and token == label_as_words:
            return True
    return False


def _matches_word_sequence(pred: str, label: str) -> bool:
    """Return True if every normalized word of ``label`` appears in ``pred`` in the same order.

    Other words may appear in between; words that normalize to "" (e.g. articles) are ignored.
    """
    pred_words = [word for word in map(normalize_answer, pred.split()) if word]
    label_words = [word for word in map(normalize_answer, label.split()) if word]
    remaining = iter(pred_words)
    # `word in remaining` consumes the iterator up to (and including) the first match,
    # so each label word must be found after the previous one.
    return all(word in remaining for word in label_words)


def check_answer_correctness(pred: str, label: str, rel_tol: float = 0.05) -> bool:
    """Decide whether a model prediction answers a ChartQA-style label correctly.

    Checks, in order:
      1. the normalized label is a substring of the normalized prediction;
      2. for a bracketed list label ("[a, b]"), every normalized element is a substring
         of the normalized prediction;
      3. for a numeric label, the result is decided by a per-token numeric comparison
         (``rel_tol`` relative tolerance) or a match against the label spelled in words;
      4. otherwise, the label's words appear in the prediction in the same order.

    Args:
        pred: Model output. ``None`` is treated as ""; other non-strings go through ``str()``.
        label: Ground-truth answer, converted the same way.
        rel_tol: Maximum relative error for numeric matches (default 0.05, i.e. 5%).

    Returns:
        True if any applicable check passes. A label that normalizes to "" only matches a
        prediction that also normalizes to "".
    """
    pred = "" if pred is None else str(pred)
    label = "" if label is None else str(label)
    pred_normalized = normalize_answer(pred)
    label_normalized = normalize_answer(label)

    if not label_normalized:
        # "" is a substring of everything; without this guard any prediction would pass.
        return not pred_normalized

    if label_normalized in pred_normalized:
        return True

    stripped_label = label.strip()
    if stripped_label.startswith('[') and stripped_label.endswith(']'):
        elements = _normalized_list_elements(stripped_label)
        if elements and all(element in pred_normalized for element in elements):
            return True

    if is_numeric(label_normalized):
        # Numeric labels are decided by the numeric comparison alone (no word-order fallback).
        return _matches_numeric(pred, label_normalized, rel_tol)

    return _matches_word_sequence(pred, label)

class ChartQAEvaluator:
    """Score pre-computed model predictions against ChartQA-style labels.

    Examples are dicts with ``"output"`` (the model's prediction) and ``"label"`` keys,
    plus ``"query"`` for display; correctness is decided by ``check_answer_correctness``.
    """

    def __init__(self, verbose: bool = True):
        """
        Initialize the evaluator.

        Args:
            verbose: If True (the default), print the query, label, prediction and
                verdict of every example while evaluating.
        """
        self.verbose = verbose

    def process_single_example(self, example: Dict) -> Dict[str, Any]:
        """Score a single example.

        Args:
            example: Record with ``"output"`` and ``"label"`` keys (``"query"`` is only printed).

        Returns:
            A dict with the original ``example``, its ``prediction`` and ``is_correct``.
            If anything raises (e.g. a missing key), the error is printed and the example
            is scored as incorrect with an empty prediction and the message under
            ``"error"``, so one bad record does not abort the whole run.
        """
        try:
            prediction = example['output']

            is_correct = check_answer_correctness(prediction, example["label"])
            if self.verbose:
                print("\n``````````````````````````````````````````````````````````")
                # .get(): a missing query is display-only and must not affect scoring.
                print(f"Query: {example.get('query', '')}")
                print(f"label: {example['label']}")
                print(f"Prediction: {prediction}")
                print(f"Truth: {is_correct}")
                print("``````````````````````````````````````````````````````````")

            return {
                "example": example,
                "prediction": prediction,
                "is_correct": is_correct
            }
        except Exception as e:
            # Deliberately broad: a malformed record is reported and scored as wrong.
            print(f"Error processing : {e}")
            return {
                "example": example,
                "prediction": "",
                "is_correct": False,
                "error": str(e)
            }

    def evaluate_dataset(self, dataset: List[Dict]) -> Dict[str, Any]:
        """Score every example and aggregate accuracies.

        Examples are bucketed as numeric / non-numeric using the normalized label, i.e. the
        same test ``check_answer_correctness`` uses to choose its numeric comparison (so a
        label like "22%" counts as numeric).

        Returns:
            Overall, numeric and non-numeric accuracy (0.0 for an empty bucket), the example
            counts, and the per-example results under ``"detailed_results"``.
        """
        results = []
        numeric_results = []
        non_numeric_results = []

        for example in tqdm(dataset, desc="Evaluating"):
            result = self.process_single_example(example)
            results.append(result)

            # Records that are not dicts (or lack a label) were already scored as errors;
            # don't let bucketing crash the whole evaluation loop.
            label = example.get("label", "") if isinstance(example, dict) else ""
            if is_numeric(normalize_answer(str(label))):
                numeric_results.append(result["is_correct"])
            else:
                non_numeric_results.append(result["is_correct"])

        overall_results = [r["is_correct"] for r in results]

        # np.mean([]) is NaN (with a warning), which json.dump would write as invalid JSON.
        overall_accuracy = np.mean(overall_results) if overall_results else 0
        numeric_accuracy = np.mean(numeric_results) if numeric_results else 0
        non_numeric_accuracy = np.mean(non_numeric_results) if non_numeric_results else 0

        return {
            "overall_accuracy": float(overall_accuracy),
            "numeric_accuracy": float(numeric_accuracy),
            "non_numeric_accuracy": float(non_numeric_accuracy),
            "num_examples": len(dataset),
            "num_numeric": len(numeric_results),
            "num_non_numeric": len(non_numeric_results),
            "detailed_results": results
        }



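### Set the paths in the cell below
- `input_dir`: directory containing the model-response files. Each file holds a JSON list of records in the following format:
```json
[
    {
        "imgname": "multi_col_100294.png",
        "query": "What is the average of all the dark blue bars?",
        "label": "22.33",
        "output": "23.67."
    }
]
```
- `output_dir`: directory where the evaluation results are written, one `<input file name>_eval_results.json` file per input file.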

In [None]:
def main(input_dir: str = "results", output_dir: str = "eval_results") -> None:
    """Evaluate every model-response file in ``input_dir`` and write one report per file.

    Each regular, non-hidden file ``<name>`` in ``input_dir`` is processed in sorted order.
    Its scored results are written to ``<output_dir>/<name>_eval_results.json``, and an
    accuracy summary is printed.

    Args:
        input_dir: Directory containing model-response JSON files.
        output_dir: Directory for the evaluation reports; created if missing.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    evaluator = ChartQAEvaluator()

    for file_name in sorted(os.listdir(input_dir)):
        input_path = os.path.join(input_dir, file_name)
        # Skip sub-directories (e.g. Jupyter's .ipynb_checkpoints) and hidden files
        # (e.g. .DS_Store): they are not response files and would break json.load.
        if file_name.startswith('.') or not os.path.isfile(input_path):
            continue
        output_path = os.path.join(output_dir, file_name + "_eval_results.json")
        # Presumably each response file is named after the model that produced it.
        model_name = os.path.splitext(file_name)[0]

        dataset = load_dataset(input_path)
        results = evaluator.evaluate_dataset(dataset)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print("\nEvaluation Results:")
        print(f"Model Name: {model_name}")
        print(f"Overall Accuracy: {results['overall_accuracy']:.4f}")
        print(f"Numeric Accuracy: {results['numeric_accuracy']:.4f} ({results['num_numeric']} examples)")
        print(f"Non-Numeric Accuracy: {results['non_numeric_accuracy']:.4f} ({results['num_non_numeric']} examples)")


if __name__ == "__main__":
    main()