In [1]:
from agential.prompting.standard.prompting import Standard
import os
import json
import pickle

from agential.prompting.standard.prompts import (
    STANDARD_INSTRUCTION_AMBIGNQ, 
    STANDARD_INSTRUCTION_FEVER, 
    STANDARD_INSTRUCTION_GSM8K,  
    STANDARD_INSTRUCTION_HOTPOTQA, 
    STANDARD_INSTRUCTION_SVAMP, 
    STANDARD_INSTRUCTION_TRIVIAQA,
    STANDARD_INSTRUCTION_TABMWP,
    STANDARD_INSTRUCTION_HUMANEVAL,
    STANDARD_INSTRUCTION_MBPP,
)
from agential.core.fewshots.ambignq import AMBIGNQ_FEWSHOT_EXAMPLES_DIRECT
from agential.core.fewshots.fever import FEVER_FEWSHOT_EXAMPLES_DIRECT
from agential.core.fewshots.gsm8k import GSM8K_FEWSHOT_EXAMPLES_POT
from agential.core.fewshots.hotpotqa import HOTPOTQA_FEWSHOT_EXAMPLES_DIRECT
from agential.core.fewshots.svamp import SVAMP_FEWSHOT_EXAMPLES_POT
from agential.core.fewshots.triviaqa import TRIVIAQA_FEWSHOT_EXAMPLES_DIRECT
from agential.core.fewshots.tabmwp import TABMWP_FEWSHOT_EXAMPLES_POT
from agential.core.fewshots.humaneval import HUMANEVAL_FEWSHOT_EXAMPLES_POT
from agential.core.fewshots.mbpp import MBPP_FEWSHOT_EXAMPLES_POT

import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this blanket filter also hides deprecation/runtime warnings
# raised by the libraries used below — consider narrowing it to a category.
warnings.filterwarnings('ignore')

from dotenv import load_dotenv
# Load variables from a .env file into the process environment; the
# experiment cell later reads OPENAI_ORGANIZATION via os.getenv.
# Presumably the API credentials are supplied the same way — TODO confirm.
load_dotenv()

from agential.core.llm import LLM

import wandb
# Log in to Weights & Biases before any run is created.
wandb.login()

# Load the HotpotQA evaluation sample; the experiment cell iterates over
# `data` and reads the "question" and "answer" keys of each entry.
# JSON is UTF-8 by specification, so the encoding is pinned explicitly instead
# of relying on the platform's locale default (e.g. cp1252 on Windows), which
# could corrupt or reject non-ASCII text in the dataset.
with open('../data/hotpotqa/hotpot_dev_v1_simplified_s42_sample500.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvincenttu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from agential.eval.metrics.classification import EM, f1, precision, recall


# Experiment configuration. These values identify the run: they are recorded
# in the W&B config/tags and (num_retries, warming) are forwarded to
# method.generate for every instance.
seed = 42
root_dir = "output"
method_name = "standard"
benchmark_name = "hotpotqa"
num_retries = 1
warming = [1.0]

# Pickled outputs are written under <root_dir>/<method_name>/<benchmark_name>.
output_path = os.path.join(root_dir, method_name, benchmark_name)

# exist_ok=True keeps the cell re-runnable and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(output_path, exist_ok=True)

# LLM wrapper shared by every query; the organization is taken from the
# OPENAI_ORGANIZATION environment variable and the seed from the config above.
llm = LLM(
    "gpt-3.5-turbo",
    organization=os.getenv("OPENAI_ORGANIZATION"),
    seed=seed,
)

# Prompting method under evaluation, bound to the selected benchmark.
method = Standard(llm=llm, benchmark=benchmark_name)

# Hyperparameters stored in the W&B run config.
run_config = {
    "seed": seed,
    "num_retries": num_retries,
    "warming": warming,
}

# Tags attached to the run; they mirror the config values as key=value strings.
run_tags = [
    f"method={method_name}",
    f"seed={seed}",
    f"num_retries={num_retries}",
    f"warming={warming}",
    "base",
]

# One W&B run per (benchmark, method) execution: the project is the benchmark
# and runs are grouped by method name.
run = wandb.init(
    project=benchmark_name,
    entity="agential",
    config=run_config,
    group=method_name,
    tags=run_tags,
)

# Accumulators filled once per dataset instance and consumed after the loop:
# table rows for W&B, per-metric score histories, and the raw method outputs.
eval_table_data, perf_table_data = [], []
em_scores, precision_scores, recall_scores, f1_scores = [], [], [], []
outputs = []

# Attributes of the method output that make up one performance-table row,
# in column order.
perf_fields = (
    "total_prompt_tokens",
    "total_completion_tokens",
    "total_tokens",
    "total_prompt_cost",
    "total_completion_cost",
    "total_cost",
    "total_prompt_time",
    "total_time",
)

for instance in data:
    question, answer = instance["question"], instance["answer"]

    # Inference: the gold answer is passed to the method as `key`.
    out = method.generate(
        question=question,
        key=answer,
        num_retries=num_retries,
        warming=warming,
    )
    predicted = out.answer

    # Per-instance metrics, keyed exactly as they are logged to W&B.
    # Dict insertion order (em, precision, recall, f1) is relied on below.
    scores = {
        "em": int(EM(predicted, answer)),
        "precision": precision(predicted, answer),
        "recall": recall(predicted, answer),
        "f1": f1(predicted, answer),
    }

    # Append each metric to its own running history.
    histories = (em_scores, precision_scores, recall_scores, f1_scores)
    for history, value in zip(histories, scores.values()):
        history.append(value)

    # One row per instance in each table.
    eval_table_data.append([question, answer, predicted, *scores.values()])
    perf_table_data.append([getattr(out, field) for field in perf_fields])

    # Keep the full output object so it can be pickled after the loop.
    outputs.append(out)

    # Stream the per-instance metrics to the active run.
    run.log(scores)

# Aggregate metrics: arithmetic mean of each per-instance score list.
# max(len, 1) keeps the denominator non-zero, so an empty dataset yields 0.0
# instead of raising ZeroDivisionError before run.finish() is reached (which
# would leave the W&B run open). Results for non-empty lists are unchanged.
total_em = sum(em_scores) / max(len(em_scores), 1)
total_precision = sum(precision_scores) / max(len(precision_scores), 1)
total_recall = sum(recall_scores) / max(len(recall_scores), 1)
total_f1 = sum(f1_scores) / max(len(f1_scores), 1)

# Per-instance tables: one for answer quality, one for token/cost/time usage.
eval_table = wandb.Table(
    data=eval_table_data,
    columns=["question", "answer", "predicted_answer", "EM", "precision", "recall", "f1"],
)
perf_table = wandb.Table(
    data=perf_table_data,
    columns=[
        "total_prompt_tokens",
        "total_completion_tokens",
        "total_tokens",
        "total_prompt_cost",
        "total_completion_cost",
        "total_cost",
        "total_prompt_time",
        "total_time",
    ],
)

# Persist the raw method outputs locally, named after the W&B run.
outputs_save_path = os.path.join(output_path, f"{run.name}.pkl")
with open(outputs_save_path, 'wb') as f:
    pickle.dump(outputs, f)

# Upload the pickle as a W&B artifact; it is stored as "outputs.pkl" inside
# an artifact that shares the run's name.
artifact = wandb.Artifact(name=run.name, type="output")
artifact.add_file(local_path=outputs_save_path, name="outputs.pkl")
artifact.save()

# Tables and aggregate scores are logged in two separate calls, as before.
run.log({
    f"{run.name}_eval": eval_table,
    f"{run.name}_perf": perf_table
})

run.log({
    "total_em": total_em,
    "total_precision": total_precision,
    "total_recall": total_recall,
    "total_f1": total_f1,
})

# Close the run explicitly; in a notebook it otherwise stays open.
run.finish()

[34m[1mwandb[0m: Currently logged in as: [33mvincenttu[0m ([33magential[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
em,▁▁▁█▁▁█▁▁▁█▁▁▁▁█▁▁▁█▁█▁█▁██▁▁▁▁█▁▁▁▁▁█▁█
f1,▁▁▁█▁▁█▃▁▁█▁▁▁▁█▅▁▄█▅█▁█▂██▂▁▁▅█▁▃▁▁▁█▅█
precision,▁▁▁█▁▁█▃▁▁█▁▁▁▁█▅▁▅███▁█▂██▁▁▁▅█▁▂▁▁▁█▅█
recall,▁▁▁█▁▁█▂▁▁█▁▁▁▁█▅▁▄█▃█▁█▅██▅▁▁▅█▁▃▁▁▁█▆█
total_em,▁
total_f1,▁
total_precision,▁
total_recall,▁

0,1
em,0.0
f1,0.25
precision,0.25
recall,0.25
total_em,0.282
total_f1,0.35906
total_precision,0.37254
total_recall,0.38965
