In [1]:
import openai

In [5]:
import os
import openai
import requests
# Module-wide OpenAI client; the key is looked up in the OPENAI_API_KEY
# environment variable (the lookup yields None when the variable is unset).
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
import re
import json
# Matches any run of backticks, optionally followed by the word "json"
# (letter case is ignored).
_quad_fence = re.compile(r'`+(?:json)?', re.IGNORECASE)


def remove_json_fence(text: str) -> str:
    """
    Strip Markdown code-fence markers from ``text``.

    Every run of backticks, together with a directly following ``json``
    tag (any letter case), is deleted. All other characters, including
    newlines, are returned unchanged.
    """
    return _quad_fence.sub('', text)


def llm(history, model_name, stop=["\n"]):
    """
    Send a chat conversation to the OpenAI API and return the reply text.

    Args:
        history: Chat messages, passed through unchanged as ``messages``.
        model_name: Model identifier, passed through as ``model``.
        stop: Accepted for signature compatibility but not used; it is
            not forwarded to the API call.

    Returns:
        The content of the first choice's message with leading and
        trailing whitespace stripped.
    """
    completion = client.chat.completions.create(model=model_name, messages=history)
    first_message = completion.choices[0].message
    # ``.strip()`` is applied directly, so a reply without text content
    # (``content`` being None) raises AttributeError here.
    return first_message.content.strip()


In [8]:

import toml
import random
import time

from rich.console import Console
from rich.markup import escape

from hotpotqa.hotpotqa_env import create_hotpot_env,create_fever_env, step



def run_hotpotqa(
                 question_index: int,
                 split: str = 'train',
                 ):
    """
    Answer one claim with a single few-shot LLM call and score it.

    Despite its name, this function builds its environment with
    ``create_fever_env``. The text returned by ``env.reset`` is appended to a
    fixed three-example prompt, the model's reply is reduced to whatever
    follows the last ``Answer:`` marker, and that text is submitted to the
    environment as a ``finish[...]`` action.

    Args:
        question_index: Index passed to ``env.reset(idx=...)``.
        split: Dataset split passed to ``create_fever_env``.

    Returns:
        A ``(done, steps, info)`` tuple. ``steps`` is always 0. ``done`` is
        the flag from the first ``finish`` action and is not refreshed by
        the fallback action. ``info`` is the dict returned by the last
        ``step`` call. If the LLM call or the first ``step`` raises, the
        error is printed and ``(False, 0, {"answer": 'Error', "em": False,
        "gt_answer": ' '})`` is returned instead.
    """
    env = create_fever_env(split=split)
    question = env.reset(idx=question_index)

    # No intermediate environment actions are taken, so the count stays 0.
    steps = 0

    # NOTE(review): the literal opens with four quote characters, so the
    # prompt text itself starts with a stray '"', and it ends with '"' plus
    # a newline. This looks like a copy/paste artifact, but the text is kept
    # byte-for-byte so results stay comparable with earlier runs.
    prompt = """"Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nThought: Nikolaj William Coster-Waldau appeared in the 2009 Fox television film Virtuality, so he has worked with the Fox Broadcasting Company.\nAnswer: SUPPORTS\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nThought: Stranger Things is in the fictional town of Hawkins, Indiana, not in Bloomington, Indiana.\nAnswer:REFUTES\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nThought: The song peaked at number two on the Billboard Hot 100 in the United States, but not sure if it was in 2003.\nAnswer: NOT ENOUGH INFO\n"
"""
    prompt += f"\n\nClaim: {question}\nThought:"

    try:
        response = llm([{"role": "user", "content": prompt}], model_name="gpt-4o")
        # Keep only the text after the last "Answer:" marker; when the reply
        # has no marker, the whole reply is submitted as the answer.
        ans = response.split("Answer:")[-1].strip()

        _, _, done, info = step(env, f"finish[{ans}]")
    except Exception as e:
        # Best-effort: report the failure and count this item as incorrect
        # rather than aborting the whole evaluation run.
        print(f"Error: {e}")
        return False, steps, {"answer": 'Error', "em": False, "gt_answer": ' '}

    # Fall back to an empty answer when the episode did not finish or no
    # answer was recorded. ``.get`` avoids a KeyError (raised outside the
    # try block) if ``info`` has no 'answer' entry.
    if not done or not info or not info.get('answer'):
        _, _, _, info = step(env, "finish[]")

    return done, steps, info


def main():
    """
    Evaluate ``run_hotpotqa`` on a seeded random sample of question indices.

    Reads ``./conf/config.toml`` and uses its ``seed`` (default 42),
    ``test_count`` (default 10) and ``split`` (default "train") entries.
    After every question it prints the number correct, the number
    evaluated, the running accuracy and the mean seconds per question,
    followed by a separator line. A final summary line is printed at the
    end; when nothing was evaluated the accuracy is reported as 0.0
    instead of raising ZeroDivisionError.
    """
    config_path = './conf/config.toml'
    config = toml.load(config_path)

    seed = config.get("seed", 42)
    test_count = config.get("test_count", 10)
    split = config.get("split", "train")

    # Shuffle a fixed index range with a dedicated seeded RNG so the same
    # sample is drawn on every run.
    # NOTE(review): 7405 is hard-coded; confirm it matches the number of
    # items in the selected split.
    idxs = list(range(7405))
    random.Random(seed).shuffle(idxs)

    rs = []  # one entry per evaluated question: 1 if exact match, else 0
    old_time = time.time()

    for i in idxs[0:test_count]:
        _, _, info = run_hotpotqa(
            question_index=i,
            split=split,
        )

        rs.append(1 if info['em'] else 0)

        print(sum(rs), len(rs), sum(rs) / len(rs), (time.time() - old_time) / len(rs))
        print('-----------')
        print()

    # Guard the division: test_count may be 0, leaving ``rs`` empty.
    accuracy = sum(rs) / len(rs) if rs else 0.0
    print(f'Final result: {sum(rs)}/{len(rs)} = {accuracy}')


# Entry point: run the evaluation only when this code executes as the main
# module, not when it is imported from elsewhere.
if __name__ == "__main__":
    main()


0 1 0.0 1.5507700443267822
-----------

1 2 0.5 1.4352731704711914
-----------

1 3 0.3333333333333333 1.305526336034139
-----------

2 4 0.5 1.243513584136963
-----------

2 5 0.4 1.2729425430297852
-----------

3 6 0.5 1.2816389401753743
-----------

4 7 0.5714285714285714 1.2536962372916085
-----------

4 8 0.5 1.304094135761261
-----------

4 9 0.4444444444444444 1.3811787499321833
-----------

5 10 0.5 1.3662223100662232
-----------

6 11 0.5454545454545454 1.3685469410636208
-----------

7 12 0.5833333333333334 1.3647592465082805
-----------

7 13 0.5384615384615384 1.442833056816688
-----------

8 14 0.5714285714285714 1.4409760917936052
-----------

9 15 0.6 1.4603159745534262
-----------

10 16 0.625 1.5319738388061523
-----------

11 17 0.6470588235294118 1.502484447815839
-----------

12 18 0.6666666666666666 1.486719184451633
-----------

13 19 0.6842105263157895 1.4682654330604954
-----------

14 20 0.7 1.4296919465065003
-----------

15 21 0.7142857142857143 1.43297078495