In [1]:
# dependencies 
import google.generativeai as genai
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from glob import glob
import os
import json

In [2]:
# Read the API key from the environment instead of hard-coding it in the
# notebook (the previous empty literal could not authenticate and invites
# committing a real key by accident). Falls back to "" when the variable is unset.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))

In [4]:
# Decoding settings: temperature 0 with top_k 1 so the extraction step is as
# repeatable as the API allows; replies are capped at 2048 output tokens.
generation_config = dict(
    temperature=0,
    top_p=1,
    top_k=1,
    max_output_tokens=2048,
)

# BLOCK_NONE for each listed harm category, so that extraction replies are not
# filtered by the safety settings.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]
# Gemini model used as the answer extractor, configured with the decoding and
# safety settings defined above.
model = genai.GenerativeModel(
    'gemini-1.5-flash-latest',
    generation_config=generation_config,
    safety_settings=safety_settings,
)

In [None]:
# Sanity check: a single request to confirm the API key and model are usable
# before launching the batched extraction further down.
model.generate_content("hi")

In [52]:
# NOTE(review): `chart_type` is reassigned in the later configuration cell,
# which also defines `ques_type`; `question_type` is not read anywhere in the
# visible notebook.
chart_type, question_type = "simple", "complex"

In [53]:
# Prompt template for the answer-extraction step. Gemini receives the chart
# question and the evaluated model's long-form answer, and is asked to end its
# reply with "The answer is: <final_answer>"; the evaluation loop parses that
# marker. Filled with `prompt.format(question=..., answer=...)`, so any literal
# braces added to this template must be doubled ("{{" / "}}").
prompt = """You are an expert in getting the answers from a given long answer with steps. These questions were asked about a chart.
Task: Extract the final answer based on the given long sequence of reasoning with answer, given the question.
 
Instructions:
Append to your response and reasoning: 'The answer is: <final_answer>'. 

If a question asks about a column name, give the full and exact name for the column as it is written in answer. 
If a question required multiple outputs and the output contains multiple outputs as well, give it in the form: [<output1>, <output2> ..] where outputs are in sorted order. For example, if the output is 'Australia and India' give the answer as [Australia, India]. 
Ignore percentage signs.
Remove the units from the answer. For example, if the answer is '10 million', give the answer as '10'. 

A few examples:

Question: What is the value of the blue column?
Given Answer: The blue column has the name 'XXX' and the value is 10.
Your Answer: <reasoning>. The answer is: 10

Question: What is the share of people above 65+ years in the small business category?
Given Answer: To find the share of SME owners in small business over 65 years, we need to add the percentages for the '65-69 years' and '70-74 years' age groups. The calculation is as follows: 26.1% (65-69 years) + 11.8% (70-74 years) = 37.9%. So, the share of SME owners in small business over 65 years is 37.9%.
Your Answer: <reasoning>. The answer is: 37.9

Where <reasoning>. is your reasoning and your chain of thought to get to the answer.

You need to carefully look at the question and the given answer. Think step by step.

Question: {question}
Given Answer: {answer}"""

In [54]:
from typing import Optional
import re
  
def modified_relaxed_accuracy(question: str,
                              target: str,
                              prediction: str,
                              max_relative_change: float = 0.05) -> bool:
  """Calculates relaxed correctness.

  The correctness tolerates certain error ratio defined by max_relative_change.
  See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
  "Following Methani et al. (2020), we use a relaxed accuracy measure for the
  numeric answers to allow a minor inaccuracy that may result from the automatic
  data extraction process. We consider an answer to be correct if it is within
  5% of the gold answer. For non-numeric answers, we still need an exact match
  to consider an answer to be correct."

  This version extends that rule:
    * Commas between two digits are removed before parsing ("1,234" -> 1234),
      and a trailing "%" divides the value by 100.
    * If the question mentions "year", numeric tolerance is skipped so years
      are compared as strings.
    * Answers written as "[a, b, ...]" are compared as unordered lists
      (spaces removed, case-sensitive).
    * Any other answer matches if, after removing spaces, either string
      contains the other (case-insensitive); an empty prediction only matches
      an empty target.

  Args:
    question: Question text; only checked for the substring "year".
    target: Gold answer. Non-string values (e.g. numeric labels loaded from
      JSON) are converted with str().
    prediction: Predicted answer; also converted with str().
    max_relative_change: Maximum allowed |prediction - target| / |target| for
      numeric answers.

  Returns:
    Whether the prediction was correct given the specified tolerance. For a
    numeric target of 0 the prediction must be exactly 0.
  """
  def _to_float(text: str) -> Optional[float]:
    text = text.strip()
    try:
      if text.endswith("%"):
        return float(text.rstrip("%")) / 100.0
      return float(text)
    except ValueError:
      return None

  def _remove_commas_from_numbers(text: str) -> str:
    # Only drop commas that sit between two digits (thousands separators).
    # The previous pattern `(\d*),(\d+)` also matched when no digit preceded
    # the comma, so "[Australia,2019]" collapsed to "[Australia2019]".
    return re.sub(r'(?<=\d),(?=\d)', '', text)

  def _remove_spaces(text: str) -> str:
    return text.replace(" ", "")

  def _check_for_years(question: str) -> bool:
    return "year" in question.lower()

  def _check_list(text: str) -> bool:
    return text.startswith("[") and text.endswith("]")

  def _list_of_answers(target: str, prediction: str) -> bool:
    # Order-insensitive comparison of comma-separated items.
    return sorted(target.split(",")) == sorted(prediction.split(","))

  # Gold labels loaded from JSON may be ints/floats rather than strings, which
  # re.sub would reject with a TypeError.
  target = _remove_commas_from_numbers(str(target))
  prediction = _remove_commas_from_numbers(str(prediction))

  prediction_float = _to_float(prediction)
  target_float = _to_float(target)

  if not _check_for_years(question) and prediction_float is not None and target_float is not None:
    if target_float == 0:
      # Relative change is undefined for a zero target; previously the
      # ZeroDivisionError was swallowed and even "0" vs "0" scored False.
      return prediction_float == 0
    relative_change = abs(prediction_float - target_float) / abs(target_float)
    return relative_change <= max_relative_change

  prediction = _remove_spaces(prediction)
  target = _remove_spaces(target)

  target_is_list = _check_list(target)
  prediction_is_list = _check_list(prediction)
  if target_is_list or prediction_is_list:
    # Strip the brackets only from whichever side is actually a list.
    if target_is_list:
      target = target[1:-1]
    if prediction_is_list:
      prediction = prediction[1:-1]
    return _list_of_answers(target, prediction)

  return target.lower() in prediction.lower() or (len(prediction) > 0 and prediction.lower() in target.lower())

In [55]:
def get_results(queries, max_workers=10):
    """Run `generate_content` over `queries` concurrently, preserving input order.

    Args:
        queries: Iterable of prompt strings.
        max_workers: Number of worker threads in the pool.

    Returns:
        List with one `generate_content` result per query, in input order
        (`executor.map` yields results in submission order).
    """
    # Pass max_workers to the constructor; the previous code overwrote the
    # private `_max_workers` attribute after construction, which relies on
    # undocumented executor internals.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(generate_content, queries))

def generate_content(query):
    """Send one prompt to the global Gemini `model` and return the reply text.

    Prints "." (no newline) after each successful request as a progress marker.
    Any exception raised by the request or by reading `.text` is printed, and
    the sentinel string 'Error by gemini' is returned instead of raising, so a
    single failure does not abort a batched `executor.map`.
    """
    try:
        reply = model.generate_content(query)
        print(".", end="")
        reply_text = reply.text
    except Exception as err:
        print(err)
        return 'Error by gemini'
    return reply_text

In [63]:
chart_type = "complex"
ques_type = "complex"

# Category names are the part of each result file name before the first ".".
# The path must be an f-string: without the `f` prefix os.listdir received the
# literal text "{chart_type}_{ques_type}".
categories_dir = f'../Results/cog_agent/{chart_type}_{ques_type}/Initial_Run/'
# Hidden entries (e.g. ".DS_Store") are skipped because they would become an
# empty category name; sorting makes the evaluation order deterministic.
categories = sorted(
    file_name.split('.')[0]
    for file_name in os.listdir(categories_dir)
    if not file_name.startswith('.')
)

In [None]:
# Answer markers Gemini is asked to emit (see the prompt), tried in this order.
ANSWER_MARKERS = ('The answer is: ', 'the answer is: ', 'The answer is ')


def _extract_final_answer(response):
    """Return the text after the answer marker in one Gemini reply, or None.

    Surrounding whitespace and a single trailing period are removed first. The
    first marker in ANSWER_MARKERS that occurs in the reply is used, and the
    text after its *last* occurrence is returned, because the prompt asks for
    the marker to be appended at the end of the reply.
    """
    text = response.strip()
    # endswith() is safe on an empty reply, where text[-1] raised IndexError.
    if text.endswith('.'):
        text = text[:-1]
    for marker in ANSWER_MARKERS:
        if marker in text:
            return text.rpartition(marker)[2]
    return None


# Accuracy per category, kept so the results are available after the loop.
category_scores = {}
for category in categories:
    df = pd.read_json('../perturb_jsons/{}_{}/{}.json'.format(chart_type, ques_type, category))
    questions = df['query'].tolist()
    gold = df['label'].tolist()
    pred_path = f'../Results/InternLM_XComposer2VL/{chart_type}_{ques_type}/Initial_Run/{category}.json'
    with open(pred_path, 'r') as pred_file:
        pred = json.load(pred_file)
    assert len(pred) == len(questions), f"{category}: {len(pred)} predictions vs {len(questions)} questions"

    queries = [prompt.format(question=question, answer=answer) for question, answer in zip(questions, pred)]
    model_responses = get_results(queries, max_workers=16)

    extracted_answers = []
    for i, response in enumerate(model_responses):
        answer = _extract_final_answer(response)
        if answer is None:
            # No marker found: report it and fall back to the raw reply, which
            # is still scored (same behaviour as before).
            print(i, "error by gemini")
            answer = response
        # Keep what follows the last '=' and drop a '%' and anything after it.
        extracted_answers.append(answer.strip().split('=')[-1].split('%')[0])

    # str() because gold labels read from JSON may be numeric.
    model_performance = [
        modified_relaxed_accuracy(question, str(target), answer)
        for question, target, answer in zip(questions, gold, extracted_answers)
    ]
    n_correct = sum(model_performance)
    n_total = len(model_performance)
    accuracy = n_correct / n_total if n_total else 0.0
    category_scores[category] = accuracy

    print(f"For Category: {category}")
    print("Model performance: ", n_correct, "out of", n_total, ">>", accuracy)
    print("-------------------------------------------------")

# TODO: persist results. The original cell ended with a bare `json.dump()` (no
# object, no file), which raised TypeError after the first category; the intended
# output object and path are not specified anywhere in the visible notebook.
