In [1]:
import os

import dotenv
import joblib
import numpy as np

from langchain.chat_models import ChatOpenAI
from discussion_agents.cog.agent.reflexion import ReflexionCoTAgent

In [2]:
hotpot = joblib.load('hotpot-qa-distractor-sample.joblib').reset_index(drop=True)


def build_supporting_paragraphs(row) -> str:
    """Return the full paragraphs of the articles cited as supporting facts.

    Each title in ``row['supporting_facts']['title']`` is looked up in
    ``row['context']['title']``; that article's sentences are concatenated
    into one paragraph, and paragraphs are separated by a blank line.

    The supporting-fact title list can name the same article more than once
    (presumably one entry per supporting sentence -- confirm against the
    dataset schema), so each article is emitted only once, in order of first
    mention, to avoid repeating whole paragraphs in the prompt context.
    """
    articles = row['context']['title']
    sentences = row['context']['sentences']
    # dict.fromkeys de-duplicates while keeping first-seen order.
    cited_titles = dict.fromkeys(row['supporting_facts']['title'])
    paragraphs = [
        # Assumes `articles` is a NumPy array so `==` compares element-wise;
        # the first matching article's sentences are concatenated as-is.
        ''.join(sentences[np.where(articles == title)][0])
        for title in cited_titles
    ]
    return '\n\n'.join(paragraphs)


hotpot['supporting_paragraphs'] = [
    build_supporting_paragraphs(row) for _, row in hotpot.iterrows()
]

In [3]:
row = hotpot.iloc[0]

# Load environment variables (including OPENAI_API_KEY) from ../../.env,
# resolved relative to the kernel's current working directory.
dotenv.load_dotenv("../../.env")
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    # Fail fast with an actionable message rather than deep inside the LLM client.
    raise RuntimeError(
        "OPENAI_API_KEY is not set; add it to ../../.env or export it in the environment."
    )

# A single chat model (temperature 0) shared by the action and self-reflection roles.
llm = ChatOpenAI(
    temperature=0,
    max_tokens=250,
    model_name="gpt-3.5-turbo",
    model_kwargs={"stop": "\n"},  # generation stops at the first newline
    openai_api_key=openai_api_key,
)

# Reflexion chain-of-thought agent for the selected question: the supporting
# paragraphs built above are the context, and the gold answer is the key.
agent = ReflexionCoTAgent(
    self_reflect_llm=llm,
    action_llm=llm,
    question=row['question'],
    context=row['supporting_paragraphs'],
    key=row['answer'],
)

In [4]:
# Show the question posed to the agent (label-based lookup on the row Series).
row.loc["question"]

"VIVA Media AG changed it's name in 2004. What does their new acronym stand for?"

In [5]:
# Show the gold answer used as the agent's key.
row.loc["answer"]

'Gesellschaft mit beschränkter Haftung'

In [6]:
# Number of times to call agent.generate on this question.
n = 1
for _ in range(n):
    # Use the same supporting paragraphs the agent was constructed with;
    # row["context"] is the raw {'title', 'sentences'} structure, not prompt text.
    agent.generate(
        context=row['supporting_paragraphs'],
        question=row['question'],
        key=row['answer'],
        strategy=None,
    )
    print(f'Answer: {agent.answer}')

Thought: The question is asking for the acronym that VIVA Media AG changed its name to in 2004. To find the answer, I need to look for information about VIVA Media AG's name change in the provided context. Let me search for that information.
Action: Finish[VIVA Media GmbH]
Answer: VIVA Media GmbH


In [8]:
# Re-display the gold answer for comparison with the agent's printed answer.
row.loc["answer"]

'Gesellschaft mit beschränkter Haftung'