In [None]:
from agential.agents.lats.agent import LATS
from agential.utils.docstore import DocstoreExplorer
from agential.core.llm import LLM

from langchain_community.docstore.wikipedia import Wikipedia

from agential.core.fewshots.hotpotqa import (
    HOTPOTQA_FEWSHOT_EXAMPLES_REACT,
)
from agential.core.fewshots.fever import (
    FEVER_FEWSHOT_EXAMPLES_REACT,
)
from agential.core.fewshots.triviaqa import (
    TRIVIAQA_FEWSHOT_EXAMPLES_REACT,
) 
from agential.core.fewshots.ambignq import (
    AMBIGNQ_FEWSHOT_EXAMPLES_REACT,
)
from agential.core.fewshots.gsm8k import (
    GSM8K_FEWSHOT_EXAMPLES_REACT
)
from agential.core.fewshots.svamp import (
    SVAMP_FEWSHOT_EXAMPLES_REACT
)
from agential.core.fewshots.tabmwp import (
    TABMWP_FEWSHOT_EXAMPLES_REACT
)
from agential.core.fewshots.humaneval import (
    HUMANEVAL_FEWSHOT_EXAMPLES_REACT
)
from agential.core.fewshots.mbpp import (
    MBPP_FEWSHOT_EXAMPLES_REACT
)
from agential.agents.lats.prompts import (
    HOTPOTQA_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_HOTPOTQA,
    LATS_REFLECT_INSTRUCTION_HOTPOTQA,
    HOTPOTQA_FEWSHOT_EXAMPLES_LATS_VALUE, 
    LATS_VALUE_INSTRUCTION_HOTPOTQA,

    AMBIGNQ_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_AMBIGNQ,
    LATS_REFLECT_INSTRUCTION_AMBIGNQ,
    AMBIGNQ_FEWSHOT_EXAMPLES_LATS_VALUE, 
    LATS_VALUE_INSTRUCTION_AMBIGNQ,

    FEVER_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_FEVER,
    LATS_REFLECT_INSTRUCTION_FEVER,
    FEVER_FEWSHOT_EXAMPLES_LATS_VALUE, 
    LATS_VALUE_INSTRUCTION_FEVER,

    TRIVIAQA_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_TRIVIAQA,
    LATS_REFLECT_INSTRUCTION_TRIVIAQA,
    TRIVIAQA_FEWSHOT_EXAMPLES_LATS_VALUE, 
    LATS_VALUE_INSTRUCTION_TRIVIAQA,

    GSM8K_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_GSM8K,
    LATS_REFLECT_INSTRUCTION_GSM8K,
    GSM8K_FEWSHOT_EXAMPLES_LATS_VALUE,
    LATS_VALUE_INSTRUCTION_GSM8K,

    SVAMP_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_SVAMP,
    LATS_REFLECT_INSTRUCTION_SVAMP,
    SVAMP_FEWSHOT_EXAMPLES_LATS_VALUE,
    LATS_VALUE_INSTRUCTION_SVAMP,

    TABMWP_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_TABMWP,
    LATS_REFLECT_INSTRUCTION_TABMWP,
    TABMWP_FEWSHOT_EXAMPLES_LATS_VALUE,
    LATS_VALUE_INSTRUCTION_TABMWP,

    MBPP_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_MBPP,
    LATS_REFLECT_INSTRUCTION_MBPP,
    MBPP_FEWSHOT_EXAMPLES_LATS_VALUE,
    LATS_VALUE_INSTRUCTION_MBPP,

    HUMANEVAL_FEWSHOT_EXAMPLES_LATS_REFLECT,
    LATS_INSTRUCTION_HUMANEVAL,
    LATS_REFLECT_INSTRUCTION_HUMANEVAL,
    HUMANEVAL_FEWSHOT_EXAMPLES_LATS_VALUE,
    LATS_VALUE_INSTRUCTION_HUMANEVAL,
)


import warnings

import dotenv

# Load environment variables via python-dotenv before any client is built.
dotenv.load_dotenv()

# Suppress all warnings for the remainder of the kernel session.
warnings.filterwarnings("ignore")

# Shared model handle; passed to the agent as `llm=llm` further down.
llm = LLM("gpt-3.5-turbo")

In [None]:
import json

# Load the evaluation records from the JSON file on disk.
# The encoding is pinned to UTF-8 so that decoding does not depend on the
# platform's locale default (which is not UTF-8 on every OS) and non-ASCII
# text in questions/answers is read consistently.
with open('../data/hotpot_dev_v1_simplified.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Preview only the first five records to keep the cell output short.
print(data[:5])

In [None]:
from agential.eval.metrics.classification import EM, precision, recall, f1


# Build the LATS agent for the "hotpotqa" benchmark, backed by a Wikipedia
# docstore wrapped in DocstoreExplorer.
agent = LATS(
    llm=llm,
    benchmark="hotpotqa",
    docstore=DocstoreExplorer(Wikipedia()),
    n_samples=2,
    max_reflections=4,
    depth_limit=5,
    max_unique=5,
    cache_values=True,
)

num_correct = 0
samples = 20

# Slice once so the loop and the final report agree on how many examples were
# actually evaluated (the dataset may hold fewer than `samples` records).
eval_set = data[:samples]

# Per-example metric values, kept so they can be averaged after the loop
# instead of being overwritten on every iteration.
precision_scores = []
recall_scores = []
f1_scores = []

for example in eval_set:

    question = example["question"]
    answer = example["answer"]

    # One agent run per question; `reset=True` is passed on every call.
    out = agent.generate(
        question=question,
        key=answer,
        examples=HOTPOTQA_FEWSHOT_EXAMPLES_REACT,
        reflect_examples=HOTPOTQA_FEWSHOT_EXAMPLES_LATS_REFLECT,
        value_examples=HOTPOTQA_FEWSHOT_EXAMPLES_LATS_VALUE,
        prompt=LATS_INSTRUCTION_HOTPOTQA,
        reflect_prompt=LATS_REFLECT_INSTRUCTION_HOTPOTQA,
        value_prompt=LATS_VALUE_INSTRUCTION_HOTPOTQA,
        additional_keys={},
        reflect_additional_keys={},
        value_additional_keys={},
        max_iterations=1,
        reset=True
    )
    is_correct = EM(out.answer, answer)
    print(question)
    print(answer, "\t\t", out.answer, "\t\t", is_correct, end="\n\n")

    num_correct += int(is_correct)

    precision_out = precision(out.answer, answer, True)
    recall_out = recall(out.answer, answer, True)
    f1_out = f1(out.answer, answer, True)

    precision_scores.append(precision_out)
    recall_scores.append(recall_out)
    f1_scores.append(f1_out)

num_evaluated = len(eval_set)
print(f"{num_correct}/{num_evaluated}")

# Report the mean of each metric over the evaluated examples.
# NOTE(review): assumes precision/recall/f1 return numeric scores - confirm
# against agential.eval.metrics.classification.
if num_evaluated:
    print(f"precision: {sum(precision_scores) / num_evaluated:.3f}")
    print(f"recall: {sum(recall_scores) / num_evaluated:.3f}")
    print(f"f1: {sum(f1_scores) / num_evaluated:.3f}")