In [1]:
import logging
import re
from copy import deepcopy

from beartype import beartype
from beartype.typing import List, Union, Callable


LOGGER = logging.getLogger()


dot_decimal_not = r'^\.(\d+)'
replacement = r'0.\1'
#LOGGER.info("Extracting first number in answer.....................................")
def first_number(answer:str, equals:bool = True):
    # We want to extract the first number in each answer:
    # Warn if there is an equal sign in the answer or more than one number
    try:
        numbers = re.findall(r'[-+]?[0-9]*\.?[0-9]+', answer)
        if len(numbers) > 1:
            pass
            # LOGGER.debug(f"Warning: More than one number in answer: {answer}")
        first_number = numbers[0]

        first_number = re.sub(dot_decimal_not, replacement, first_number)
        # Equal sign in answer:
        if equals and "=" in answer:
            new_answer = answer.split("=")[1]
            # LOGGER.warn(f"Warning: Equal sign in answer: {answer}")
            # Get first number after equal sign:
            try:
                new_first_number = re.match(r'[-+]?[0-9]*\.?[0-9]+', new_answer)[0]
                return new_first_number
            except:
                # LOGGER.warn(f"Warning: No number after equal sign: {answer}")
                # print(f"Warning: No number after equal sign: {answer}")
                pass
        return first_number
    except Exception as e:
        return ""

def extract_answer(
        response: str,
):
    response = response.replace("Answer 2:", "Answer:")
    response = response.replace("Answer 1:", "Answer:")
    response = response.replace("Answer 3:", "Answer:")
    response = response.replace("Answer 4:", "Answer:")
    response = response.replace("Answer 5:", "Answer:")
    # Check that the model has generated a sentence with "Answer(<response>)"
    # Check with re if "answer " or "answer:" is in low_responses:
    if re.search(r"answer", response, flags=re.IGNORECASE) is None:
        #LOGGER.info(f"The model has not generated a sentence with 'Answer'  .......  {'FAIL':>10}   ||  Processed total: {(i+1)*100/len(responses):0.1f}%,  Correct extraction: {correctly_extracted_count/(i+1)}")#logging.info
        return response
    response = re.split(r"answer:?", response, flags=re.IGNORECASE)

    # Add all parts of the response after "Answer" to the answer:
    response = "answer".join(response[1:])
    # LOGGER.info(f"The model has generated a sentence with 'Answer'  .......  {'SUCCESS':>10}")#logging.info
    # LOGGER.debug(f"Response {i}: {response}")
    return response

    
def remove_tool_calls(answer:str, tool_name:str=None, start_token:str="<TOOL>", end_token:str="→"):
    if tool_name is not None:
        tool_use_pattern = rf" {re.escape(start_token)}(?={tool_name}).*?{re.escape(end_token)}"
    else:
        tool_use_pattern = rf" {re.escape(start_token)}.*?{re.escape(end_token)}"
    # We want to extract the first number in each answer:
    # Warn if there is an equal sign in the answer or more than one number
    answer = re.sub(tool_use_pattern, "", answer)
    return answer



def stats(
        question_result:List[bool],
        tool_history:List[List[int]],
        answer_type:str=""
):
    
    LOGGER.info("Calculating stats.....................................")
    TOOL_STATUS = ["", " good call", " bad call"]
    REALMS = ["global", "correct", "incorrect"]

    examples_in_realm_count = {
        "global":len(question_result),
        "correct":question_result.count(True),
        "incorrect":question_result.count(False),
    }

    seen_tools = []

    # Tool history is a list of the history of tools used for each answer.
    # Each tool history is a list of tool ids.

    # We want the following stats:
    # 1. Exact match accuracy
    # 2. Include match accuracy
    # For each of the correct/incorrect groups, for both exact and includes:
    # We want stats on tool use:
    # 1. Number of tools used
    # 2. Tool type distribution
    # In total:
    # 1. Average tool use
    # 2. Average tool type distribution

    tool_stats = {
        "max number of tools used":0,
        "min number of tools used":1000,
    }
    tool_stats = {key+tool_status:value for key, value in tool_stats.items() for tool_status in TOOL_STATUS}

    default_stat = {
            "per tool stats":{},
    } | tool_stats

    stats = {realm:deepcopy(default_stat) for realm in REALMS}

    # Tool use stats:
    # Will add same stats per tool, in dictionary with key as tool id
    default_call_count = {f"total {realm}" + tool_status:0 for realm in REALMS for tool_status in TOOL_STATUS}

    call_count = deepcopy(default_call_count)
    

    def compute_stats(tool_status, history_list, realm):
        nonlocal stats, call_count

        call_count[f"total {realm}" + tool_status] += len(history_list)
        stats[realm]["max number of tools used" + tool_status] = max(stats[realm]["max number of tools used" + tool_status], len(history_list))
        stats[realm]["min number of tools used" + tool_status] = min(stats[realm]["min number of tools used" + tool_status], len(history_list))
        for id in history_list:
            call_count[id][f"total {realm}" + tool_status] += 1
            stats[realm]["per tool stats"][id]["max number of tools used" + tool_status] = max(stats[realm]["per tool stats"][id]["max number of tools used" + tool_status], history_list.count(id))
            stats[realm]["per tool stats"][id]["min number of tools used" + tool_status] = min(stats[realm]["per tool stats"][id]["min number of tools used" + tool_status], history_list.count(id))

    for i, (t_history, result) in enumerate(zip(tool_history, question_result, strict=True)):

        call_list = []
        good_call_list = []
        bad_call_list = []
        
        for use in t_history:
            call_list.append(use["id"])
            if "status" not in use:
                LOGGER.warn(f"WARNING: No status for tool {use}")
                LOGGER.warn(f"data is id: {i}")
                bad_call_list.append(use["id"])
            elif use["status"] != 0:
                bad_call_list.append(use["id"])
            elif use["status"] == 0:
                good_call_list.append(use["id"])

        use_cases = [call_list, good_call_list, bad_call_list]

        for tool_status, use_case in zip(TOOL_STATUS, use_cases):

            # Initialize per tool stats for new tools:
            for id in use_case:
                if id not in call_count:
                    seen_tools.append(id)
                    call_count[id] = deepcopy(default_call_count)
                    for realm in REALMS:
                        stats[realm]["per tool stats"][id] = deepcopy(tool_stats)

            if tool_status == TOOL_STATUS[0]:
                examples_in_realm_count[REALMS[0]] += 1
            compute_stats(tool_status, use_case, REALMS[0])

            realm = "correct" if result else "incorrect"

            compute_stats(tool_status, use_case, realm)
                


    # Update averages in stats with counts:
    for realm in REALMS:
        realm_total = max(examples_in_realm_count[realm], 1)
        for tool_status in TOOL_STATUS:
            stats[realm]["average number of tools used" + tool_status] = call_count[f"total {realm}" + tool_status]/realm_total
            stats[realm][f"total tools used" + tool_status] = call_count[f"total {realm}" + tool_status]
            stats[realm][f"total examples" + tool_status] = examples_in_realm_count[realm]
            for id in seen_tools:
                stats[realm]["per tool stats"][id]["average number of tools used" + tool_status] = call_count[id][f"total {realm}" + tool_status]/realm_total
                stats[realm]["per tool stats"][id][f"total tools used" + tool_status] = call_count[id][f"total {realm}" + tool_status]
                stats[realm]["per tool stats"][id][f"total examples" + tool_status] = examples_in_realm_count[realm]

    # Accuracies:
    stats[REALMS[0]]["correct accuracy"] = examples_in_realm_count["correct"]/examples_in_realm_count[REALMS[0]]
    
    # Tree search through dictionary and print key branch and final leaf values:
    def print_dict(d, indent=0):
        for key, value in d.items():
            if isinstance(value, dict):
                LOGGER.info('  ' * indent + str(key))
                print_dict(value, indent+1)
            else:
                LOGGER.info('  ' * indent + str(key) + ": " + str(value))

    LOGGER.info(f"Stats for {answer_type}:")
    print_dict(stats)
    # Save stats to json file:
    # Count number of json files to give it an id:
    # Make stats dir in current directory if non existent:
    if not os.path.exists("./stats"):
        os.makedirs("./stats")
    file_id = len([f for f in os.listdir("./stats") if f.endswith(".json")])
    if answer_type != "":
        answer_type = "_" + answer_type
    with open(f"stats/stat_{file_id}"+answer_type+".json", "w") as f:
        json.dump(stats, f, indent=4)
    LOGGER.info(f"Stats saved to stats/stat_{file_id}"+answer_type+".json")

    return stats


def exact_acc(
        list_a:List,
        list_b:List,
        convert:Callable=str,
):
    count = 0
    for a, b in zip(list_a, list_b, strict=True):
        try:
            if convert(a) == convert(b):
                count += 1
        except ValueError:
            pass
        
    
    return count/len(list_a)

def incl_acc(
        list_a:List,
        list_b:List,
        convert:Callable=str,
        det: bool=False
):
    count = 0
    result = []
    for a, b in zip(list_a, list_b, strict=True):
        try:
            if convert(a) in convert(b):
                count += 1
                result.append(True)
            else:
                result.append(False)
        except ValueError:
            result.append(False)
            pass

    if det:
        return count/len(list_a), result

    return count/len(list_a)


def eval_asdiv(
    responses:List[str],
    solutions:List[str],
):
    results = []
    correct_count = 0
    for resp, sol in zip(responses, solutions, strict=True):
        sol = sol.strip()
        try:
            sol = float(sol)
            # sol is a number so we get first number in response:
            resp = first_number(resp)
            if resp == "":
                results.append(False)
            elif str(float(resp)) == str(sol):
                correct_count += 1
                results.append(True)
            else:
                results.append(False)
        except:
            # sol is not a number. It can be either a "yes" or the name of a person, or a time

            # Check if sol is in first 20 words of response:
            resp = " ".join(resp.split()[:20])
            if str(sol) in resp:
                correct_count += 1
                results.append(True)
            else:
                results.append(False)

    return correct_count/len(solutions), results

In [17]:
import pandas as pd
import os
import sys

# Load csv file:                                                                                              First                 
# "/vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/asdiv-responses-PLEASE_AY_24.csv"
# "/vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/gms8k-easy-responses-PLEASE_AY_25.csv"
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/old_results/asdiv-responses-1-shot_basic_asdiv_A_basic_1-shot_27.csv
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/old_results/asdiv-responses-DX-2_asdiv_DX-2_25.csv 
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/old_results/asdiv-responses-dx2_ASDIV_16_8_DX-2_28.csv           close to 0       
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/old_results/asdiv-responses-No_duplicates_no_add_DX-3_29.csv    close to 0
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/good_results/asdiv-responses-Sunny_DX-5_0.csv                0.037815126050420166
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/asdiv-responses-Nighty_DX-5-bare_0.csv               51 
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/gms8k-easy-responses-Nighty_DX-5-bare_0.csv           0 everything except 0.15 for incl bare
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/asdiv-responses-OBI2_DZ-5.2-bare_0.csv                   0.4495798319327731 
# asdiv-responses-task_explan_Nighty_DZ-5.2-COT_0.csv    0.37         
# asdiv-responses-task_explan_Nighty_DZ-5.2-COT_2.csv    0.3781
# asdiv-responses-task_explan_Nighty_DZ-5.2-COT-Shuffle_0.csv    0.48739
# asdiv-responses-ARG-single-tool_arg_training_0.csv     0.008403
# asdiv-responses-please22_mickey_0.csv   0.48
# asdiv-responses-mickypad_mickey_1.csv


# BASELINE:
#/vol/bitbucket/jg2619/augmenting_llms/benchmarks/old_results/asdiv-responses-basic_asdiv_A_basic_0-shot_26.csv
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/triviaQA-small-responses-BARE-BASELINE_GPTJ_baseline_0.csv
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/asdiv-responses-mickey-task-all-tools_mickey-task_0.csv
# /vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/triviaqa-responses-full_GPTJ_baseline_2.csv
# "/vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/asdiv-full-responses-basebaseB_GPTJ_Master_0.csv"
# "/vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/asdiv-full-responses-goodmorning_med_0.csv"

file_dir = "/vol/bitbucket/jg2619/augmenting_llms/benchmarks/results/ASDiv-full-responses-wewe_med-no-token-mono-low-k_0.csv"

df = pd.read_csv(file_dir)


In [17]:
count = 0
count_correct = 0
for i, response in enumerate(df["responses"]):
    first_num = first_number(response, equals=False)

    if first_num != "":
        response = response[:response.index(first_num)]

    if "[Calculator" in response:
        count += 1
        if results_asdiv_exa[i]:
            count_correct += 1

print(count)
print(count_correct)

679
156


In [3]:
import json
import random

def load_ASDiv(subset = False):

    from bs4 import BeautifulSoup
    import re

    with open('/vol/bitbucket/jg2619/augmenting_llms/benchmarks/ASDiv/ASDiv.xml', 'r') as f:
        data = f.read()
    Bs_data = BeautifulSoup(data, "lxml")

    problem_ids = []

    with open('/vol/bitbucket/jg2619/augmenting_llms/benchmarks/ASDiv/fold0.txt', 'r') as f:
        for line in f.readlines():
            problem_ids.append(line.strip())

    if not subset: problem_ids = [f"nluds-{id:04d}" for id in range(1,2306)]
    data = []

    for id in problem_ids:
        problem = Bs_data.find("problem", id=id)
        # Remove units, provided as " (unit)"
        answer = re.sub(' \(.+\)', '', problem.answer.text)

        question = str(problem.find(string=True, recursive=False)[3:-3] + " " + problem.question.text)

        data.append({"question":question, "answer":answer})

    return data

def load_triviaQA(small=False):
    with open('/vol/bitbucket/jg2619/augmenting_llms/benchmarks/TriviaQA/triviaqa-unfiltered/short-unfiltered-web-dev-types.json') as f:
        data = json.load(f)["Data"]
        # Get a random sample of 300 examples, with a fixed seed:
        if small: data = random.Random(42).sample(data, 900)
    return data

trivia_data = load_triviaQA(small = False)
answer_aliases = [d["answer_aliases"] for d in trivia_data]

asdiv_data = load_ASDiv(subset= False)
asdiv_corr = [d["answer"] for d in asdiv_data]


def trivia_qa_accuracy(
        model_answers:List[str],
        solutions:List[List[str]],
):
    correct_includes = 0

    for correct, model_ans in zip(solutions, model_answers):
        # Get first 20 words of model_ans:
        model_ans = " ".join(model_ans.split()[:20])
        incl_done = False
        for c in correct:
            if not incl_done and str(c) in str(model_ans):
                correct_includes += 1
                incl_done = True

    return correct_includes/len(solutions)



len(trivia_data)

# Boolean list of data with answer_type ""WikipediaEntity"" 
# (i.e. the answer is a wikipedia entity)

wikipedia_entities = [d for d in trivia_data if d["answer_type"] == "WikipediaEntity"]
wiki_bool = [d["answer_type"] == "WikipediaEntity" for d in trivia_data]

wiki_trivia = [d["answer"] for d in wikipedia_entities]
answer_aliases = [d["answer_aliases"] for d in trivia_data]

print(len(wiki_trivia))



9768


In [6]:
df.columns

Index(['questions', 'correct_answers', 'responses', 'extracted',
       'tool_histories', 'first_numbers_arr', 'first_numbers_end'],
      dtype='object')

In [5]:
import ast

stats(
    question_result = results_asdiv_exa,
    tool_history = [ast.literal_eval(h.replace("TypeError(\"", "\"Error(").replace("\")}", ")\"}")) for h in df["tool_histories"].tolist()],
    answer_type = "ASDiv",
)

  LOGGER.warn(f"data is id: {i}")
data is id: 29
data is id: 29


data is id: 30
data is id: 30
data is id: 32
data is id: 32
data is id: 35
data is id: 35
data is id: 39
data is id: 39
data is id: 44


{'global': {'per tool stats': {0: {'max number of tools used': 3,
    'max number of tools used good call': 1,
    'max number of tools used bad call': 2,
    'min number of tools used': 1,
    'min number of tools used good call': 1,
    'min number of tools used bad call': 1,
    'average number of tools used': 0.5021691973969631,
    'total tools used': 2315,
    'total examples': 4610,
    'average number of tools used good call': 0.4945770065075922,
    'total tools used good call': 2280,
    'total examples good call': 4610,
    'average number of tools used bad call': 0.007592190889370932,
    'total tools used bad call': 35,
    'total examples bad call': 4610}},
  'max number of tools used': 3,
  'max number of tools used good call': 1,
  'max number of tools used bad call': 2,
  'min number of tools used': 0,
  'min number of tools used good call': 0,
  'min number of tools used bad call': 0,
  'average number of tools used': 0.5021691973969631,
  'total tools used': 2315,
  

In [8]:
len(df)

2305

In [15]:
from functools import partial

DEBUG = False
responses = list(df['responses'].values)
correct_answers = asdiv_corr

generous_acc, result = incl_acc(correct_answers, responses, det = True)
print(f"Generous accuracy: {generous_acc}")


remove_calls_arr = partial(remove_tool_calls, start_token="[", end_token="→")
remove_calls_end = partial(remove_tool_calls, start_token="[", end_token="]")
extracted_answers = list(map(extract_answer, responses))
call_less_arr_responses = list(map(remove_calls_arr, responses))
call_less_end_responses = list(map(remove_calls_end, responses))
extracted_cla_responses = list(map(extract_answer, call_less_arr_responses))
extracted_cle_responses = list(map(extract_answer, call_less_end_responses))

first_numbers_arr = list(map(first_number, call_less_arr_responses))
first_numbers_end = list(map(first_number, call_less_end_responses))
first_numbers_cla = list(map(first_number, extracted_cla_responses))
first_numbers_cle = list(map(first_number, extracted_cle_responses))

fn_arr_acc = exact_acc(first_numbers_arr, correct_answers, float)
fn_end_acc = exact_acc(first_numbers_end, correct_answers, float)
fn_cla_acc = exact_acc(first_numbers_cla, correct_answers, float)
fn_cle_acc = exact_acc(first_numbers_cle, correct_answers, float)

asdiv_acc_a, results_asdiv_a = eval_asdiv(call_less_arr_responses, correct_answers)
asdiv_acc_e, results_asdiv_e = eval_asdiv(call_less_end_responses, correct_answers)
asdiv_acc_exa, results_asdiv_exa = eval_asdiv(extracted_cla_responses, correct_answers)
asdiv_acc_exe, results_asdiv_exe = eval_asdiv(extracted_cle_responses, correct_answers)

print(f"First number in answer (arr): {fn_arr_acc}")
print(f"First number in answer (end): {fn_end_acc}")
print(f"First number in answer (extracted arr): {fn_cla_acc}")
print(f"First number in answer (extracted end): {fn_cle_acc}")

print(f"ASDiv accuracy (arr): {asdiv_acc_a}")
print(f"ASDiv accuracy (end): {asdiv_acc_e}")
print(f"ASDiv accuracy (extracted arr): {asdiv_acc_exa}")
print(f"ASDiv accuracy (extracted end): {asdiv_acc_exe}")

print(f"Sum of asdiv results ex arr: {sum(results_asdiv_exa)}")
print(f"Len of asdiv results ex arr: {len(results_asdiv_exa)}")
print(f"Acc of asdiv results ex arr: {sum(results_asdiv_exa)/len(results_asdiv_exa)}")

assert len(results_asdiv_a) == len(results_asdiv_e) == len(results_asdiv_exa) == len(results_asdiv_exe) == len(correct_answers)



Generous accuracy: 0.39739696312364425
First number in answer (arr): 0.19956616052060738
First number in answer (end): 0.14620390455531454
First number in answer (extracted arr): 0.22125813449023862
First number in answer (extracted end): 0.17483731019522777
ASDiv accuracy (arr): 0.2086767895878525
ASDiv accuracy (end): 0.15531453362255965
ASDiv accuracy (extracted arr): 0.227765726681128
ASDiv accuracy (extracted end): 0.18134490238611714
Sum of asdiv results ex arr: 525
Len of asdiv results ex arr: 2305
Acc of asdiv results ex arr: 0.227765726681128


In [11]:

dot_decimal_not = r'^\.(\d+)'
replacement = r'0.\1'
#LOGGER.info("Extracting first number in answer.....................................")

def first_number(answer:str, equals:bool = True) -> str:
    # We want to extract the first number in each answer:
    # Warn if there is an equal sign in the answer or more than one number
    try:
        numbers = re.findall(r'[-+]?[0-9]*\.?[0-9]+', answer)
        if len(numbers) > 1:
            pass
            #LOGGER.debug(f"Warning: More than one number in answer: {answer}")
        first_number = numbers[0]

        first_number = re.sub(dot_decimal_not, replacement, first_number)
        # Equal sign in answer:
        if equals and "=" in answer:
            new_answer = answer.split("=")[1]
            # Get first number after equal sign:
            try:
                new_first_number = re.match(r'[-+]?[0-9]*\.?[0-9]+', new_answer)[0]
                return new_first_number
            except:
                LOGGER.warn(f"Warning: No number after equal sign: {answer}")
                print(f"Warning: No number after equal sign: {answer}")

        return first_number
    except Exception as e:
        return ""



def exact_acc(
        list_a:List,
        list_b:List,
        convert=str,
):
    count = 0
    for a, b in zip(list_a, list_b, strict=True):
        try:
            if convert(a) == convert(b):
                count += 1
        except ValueError:
            pass

    return count/len(list_a)

def incl_acc(
        list_a:List,
        list_b:List,
        convert=str,
        det: bool=False
):
    count = 0
    result = []
    for a, b in zip(list_a, list_b, strict=True):
        try:
            if convert(a) in convert(b):
                count += 1
                result.append(True)
            else:
                result.append(False)
        except ValueError:
            result.append(False)
            pass

    if det:
        return count/len(list_a), result

    return count/len(list_a)

In [18]:
from functools import partial

print(df.columns)
full_responses = list(df['responses'].values)

DEBUG = False
end_token = "</TOOL>"

correct_answers = list(df['correct_answers'].values)

remove_tool = partial(remove_tool_calls, start_token = "<TOOL>", end_token = end_token)

eval_answers = list(map(remove_tool, full_responses))

print(len(df))
print(len(wiki_bool))

wiki_eval_answers = [d for i, d in enumerate(eval_answers) if wiki_bool[i]]
wiki_answer_aliases = [d for i, d in enumerate(answer_aliases) if wiki_bool[i]]

print(len(df), flush=True)

new_line = '\n'
correct_includes = 0
includes_in_bare = 0
first_num_correct = 0
extr_first_num_correct = 0
print("Correct Ans :  Extracted Ans  : Response")
extracted = eval_answers # df['model_answers'][i]
for i in range(len(df)):
    if not wiki_bool[i]:
        continue
    if DEBUG:
        print(f"Correct answer:{correct_answers[i]:<12}  -  Extracted_answers: {extracted[i]:<12}  - Responses: {df['responses'][i].replace(new_line,'')}")
        print(f"First number in answer:{first_numbers[i]:<12}  -  Extracted_answers: {first_numbers_extracted[i]:<12}")
    # Correct answer:<5  -  Extracted_answers: <5  - Responses: 
    #print(f"Correct answer:{df['correct_answers'][i]:<12}  -  Extracted_answers: {df['extracted_answers'][i]:<12}  - Responses: {df['responses'][i].replace(new_line,'')}")
    if str(correct_answers[i]) in str(extracted[i]):
        correct_includes += 1
    if str(correct_answers[i]) in df['responses'][i]:
        includes_in_bare += 1

exact_extr_acc, incl_extr_acc = calculate_accuracy(extracted, correct_answers, stat_name="Accuracy (extracted)")
_, flex_acc = calculate_accuracy(full_responses, correct_answers, stat_name="Accuracy (full)")
print(len(eval_answers))
if "trivia" in file_dir:
    list_acc, list_acc_incl = calculate_accuracy(eval_answers, answer_aliases, stat_name="Accuracy (list)")
triviaacc = trivia_qa_accuracy(eval_answers, answer_aliases)
wiki_triviaacc = trivia_qa_accuracy(wiki_eval_answers, wiki_answer_aliases)


print(f"Included accuracy (extracted): {correct_includes/len(df)}")
print(f"Included accuracy (extracted): {incl_extr_acc}")
print(f"Included accuracy (full): {includes_in_bare/len(df)}")
print(f"Included accuracy (full): {flex_acc}")
print(f"Trivia accuracy: {triviaacc}")
print(f"Wiki Trivia accuracy: {wiki_triviaacc}")

assert len(df) == len(extracted)
assert len(df) == len(correct_answers)


print("Eval preds:")
print(f"List accuracy: {list_acc}")
print(f"List accuracy (included): {list_acc_incl}")

Index(['questions', 'correct_answers', 'responses', 'extracted',
       'tool_histories', 'first_numbers_arr', 'first_numbers_end'],
      dtype='object')
2305
11313
2305


Correct Ans :  Extracted Ans  : Response


NameError: name 'calculate_accuracy' is not defined

In [45]:
for i in range(len(df)):
    assert extracted[i] == df['model_answers'][i]
    print("Extracted  / Extracted in Eval ")
    print(f"{extracted[i]}  / {df['model_answers'][i]}")
    print

KeyError: 'model_answers'

In [19]:

print(len(df))

calc_calls = 0
wiki_calls = 0
calend_calls = 0

bad_calls = 0

print(df.columns)
for i, row in df.iterrows():
    if "[Calculator" in row.responses:
        calc_calls += 1
    if "[WikiSearch" in row.responses:
        wiki_calls += 1
    if "[Calendar" in row.responses:
        calend_calls += 1
    if "[Calculator|]" in row.responses or "[Calculator|→]" in row.responses or "→ None" in row.responses:
        bad_calls += 1
        print(row.responses)
        print("LOL")
        
    print(f"QUESTION: {row.questions}")
    print(f"ANSWER {i}: {row.correct_answers}")
    print("FIRST NUM arr: ", row.first_numbers_arr)
    print("FIRST NUM end: ", row.first_numbers_end)
    print(f"RESULT: {results_asdiv_a[i]}")
    print("RESPONSE: ",row.responses.replace("\n", "\n--"))
    print("_________________________________________")

print(f"Calculator calls: {calc_calls}")
print(f"Wiki calls: {wiki_calls}")
print(f"Calendar calls: {calend_calls}")
print(f"Bad calls: {bad_calls}")

2305
Index(['questions', 'correct_answers', 'responses', 'extracted',
       'tool_histories', 'first_numbers_arr', 'first_numbers_end'],
      dtype='object')
QUESTION: Seven red apples and two green apples are in the basket. How many apples are in the basket?
ANSWER 0: 9
FIRST NUM arr:  9.0
FIRST NUM end:  9.0
RESULT: True
RESPONSE:  The answer is 7+2=9 apples. [WikiSearch|7+2→ As I was going to St Ives. A similar problem is found in the Rhind Mathematical Papyrus (Problem 79), dated to around 1650 BC. The papyrus is translated as follows: The problem appears to be an illustration of an algorithm for multiplying numbers. The sequence 7, 7, 7, 7, 7 appears in the right-hand column, and the terms 2,801, 2×2,801, 4×2,801 appear in the left; the sum]
_________________________________________
QUESTION: Ellen has six more balls than Marin. Marin has nine balls. How many balls does Ellen have?
ANSWER 1: 15
FIRST NUM arr:  15.0
FIRST NUM end:  15.0
RESULT: True
RESPONSE:  The balls are in a 

In [21]:
# Open the csv file: /vol/bitbucket/jg2619/augmenting_llms/augmented_data_pipeline/data/train/definite_horizon/GPTJ/train3_tagged.csv
from transformers import AutoTokenizer
import pandas as pd

import ast


cache_dir = "/vol/bitbucket/jg2619/augmenting_llms/augmented_data_pipeline/toolformer/cache"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", cache_dir=cache_dir)

tokenizer.add_tokens([" <TOOL>", "</TOOL>", "[PAD]"])

df2 = pd.read_csv('/vol/bitbucket/jg2619/augmenting_llms/augmented_data_pipeline/data/train/definite_horizon/GPTJ/train3_tagged.csv')


for index, row in df2.iterrows():
    if row["tool_name"] == "Calculator":
        print(tokenizer.decode(ast.literal_eval(row["tokenized_text"])))


  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 