In [4]:
import openai


In [20]:
import os
import openai
import requests
import re
import json

client = openai.OpenAI(api_key= os.environ.get("OPENAI_API_KEY"))
_quad_fence = re.compile(r'`+(?:json)?', re.IGNORECASE)


def remove_json_fence(text: str) -> str:
    """
    Remove the JSON fence from the text.
    """
    # Remove the JSON fence
    text = _quad_fence.sub('', text)
    return text


def llm(history, model_name, stop=["\n"]):
    response = client.chat.completions.create(
                model=model_name,
                messages=history,
            )
   # print(response)
    response = response.choices[0].message.content.strip()

    #response = json.loads(remove_json_fence(response))
    return response


In [22]:

import toml
import random
import time

from rich.console import Console
from rich.markup import escape

from hotpotqa.hotpotqa_env import create_hotpot_env,create_fever_env, step



def run_hotpotqa(
                 question_index: int,
                 split: str = 'train',
                 ):

    env = create_fever_env(split=split)
    question = env.reset(idx=question_index)

    steps = 0
    info = None

    prompt = "Determine if there is Observation that SUPPORTS or REFUTES a Claim, or if there is NOT ENOUGH INFORMATION. \nClaim: Nikolaj Coster-Waldau worked with the Fox Broadcasting Company.\nAnswer: SUPPORTS\n\nClaim: Stranger Things is set in Bloomington, Indiana.\nAnswer:REFUTES\n\nClaim: Beautiful reached number two on the Billboard Hot 100 in 2003.?\nAnswer: NOT ENOUGH INFO\n"
    prompt += f"Claim: {question}\nAnswer:"
    try:
        ans = llm([{"role": "user", "content": prompt}], model_name="gpt-4o")


        obs, r, done, info = step(env, f"finish[{ans}]")
    except Exception as e:
        print(f"Error: {e}")
        return False, steps, {"answer": 'Error', "em": False, "gt_answer": ' '}

    if not done or not info or not info['answer']:
        action = "finish[]"
        obs, r, d, info = step(env, "finish[]")



    return done, steps, info


def main():
    config_path = './conf/config.toml'
    config = toml.load(config_path)


    seed = config.get("seed", 42)
    test_count = config.get("test_count", 10)
    max_retry = config.get("max_retry", 5)
    split = config.get("split", "train")


    idxs = list(range(7405))
    random.Random(seed).shuffle(idxs)

    rs = []
    infos = []
    old_time = time.time()
    correct_count = 0

    for i in idxs[0:test_count]:
        done, steps, info = run_hotpotqa(
            question_index=i,
            split=split,
        )

        agent_ans = info['answer'].strip()
        exact_match = info['em']
        correct_answer = info['gt_answer'].strip()

        # console = Console()
        # console.print(f'[red]GT: {escape(correct_answer)}[/red]')
        # console.print(f'[green]Agent: {escape(agent_ans)}[/green]')


        #check if corect, ignore case

        if exact_match:
            correct = True
            correct_count += 1
        else:
            correct = False

        rs.append(1 if correct else 0)
        infos.append(info)


        print(sum(rs), len(rs), sum(rs) / len(rs), (time.time() - old_time) / len(rs))
        print('-----------')
        print()


    print(f'Final result: {sum(rs)}/{len(rs)} = {sum(rs) / len(rs)}')


if __name__ == "__main__":
    main()


0 1 0.0 0.4970104694366455
-----------

1 2 0.5 0.48537957668304443
-----------

1 3 0.3333333333333333 0.46449923515319824
-----------

2 4 0.5 0.4644436836242676
-----------

2 5 0.4 0.4636528491973877
-----------

3 6 0.5 0.4553142786026001
-----------

4 7 0.5714285714285714 0.4450115476335798
-----------

4 8 0.5 0.4384791851043701
-----------

4 9 0.4444444444444444 0.44088493453131783
-----------

5 10 0.5 0.45329985618591306
-----------

6 11 0.5454545454545454 0.4504156546159224
-----------

7 12 0.5833333333333334 0.44589346647262573
-----------

7 13 0.5384615384615384 0.45751023292541504
-----------

8 14 0.5714285714285714 0.4597775936126709
-----------

8 15 0.5333333333333333 0.459812863667806
-----------

8 16 0.5 0.519484668970108
-----------

9 17 0.5294117647058824 0.51766178187202
-----------

10 18 0.5555555555555556 0.5115523073408339
-----------

11 19 0.5789473684210527 0.5081454703682348
-----------

12 20 0.6 0.5031385064125061
-----------

12 21 0.57142857142