<a href="https://colab.research.google.com/github/Amirosimani/ReWOO-Gemini/blob/main/ReWOO_gemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# !pip install --quiet datasets
# !pip install --quiet langchain langchain_community
# !pip install --quiet langchain_google_genai langchain_google_community
# !pip install --quiet wikipedia
# !pip install --quiet tiktoken

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m0.5/1.2 MB[0m [31m15.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.2/1.2 MB[0m [31m24.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [1]:
from google.colab import userdata

In [16]:
# --- Configuration ---
MODEL = "gemini-1.5-flash"
MAX_ITERATIONS = 10
MAX_EXECUTION_TIME = 90

GENERATION_CONFIG = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 20,
    "candidate_count": 1,
    "max_output_tokens": 8192,
    "stop_sequences": ["STOP!"],
}

# SAFETY_SETTINGS = {
#     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
# }


GEMINI_API_KEY = userdata.get('GEMINI')
GOOGLE_API_KEY = userdata.get('GOOGLE-API')
CSE_ID = userdata.get("CSE-ID")

## Data

[**StrategyQA**](https://paperswithcode.com/dataset/strategyqa) is a question answering benchmark where the required reasoning steps are implicit in the question, and should be inferred using a strategy. It includes 2,780 examples, each consisting of a strategy question, its decomposition, and evidence paragraphs. Questions in StrategyQA are short, topic-diverse, and cover a wide range of strategies.



In [3]:
from datasets import load_dataset

ds = load_dataset("ChilleD/StrategyQA")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
# Access the training split
train_ds = ds["train"]

In [5]:
train_ds[0]

{'qid': '4fd64bb6ce5b78ab20b6',
 'term': 'Mixed martial arts',
 'description': 'full contact combat sport',
 'question': 'Is Mixed martial arts totally original from Roman Colosseum games?',
 'answer': False,
 'facts': 'Mixed Martial arts in the UFC takes place in an enclosed structure called The Octagon. The Roman Colosseum games were fought in enclosed arenas where combatants would fight until the last man was standing. Mixed martial arts contests are stopped when one of the combatants is incapacitated. The Roman Colosseum was performed in front of crowds that numbered in the tens of thousands. Over 56,000 people attended UFC 193.'}

# LLM

## Baseline

In [None]:
from google import genai
from google.genai.types import GenerateContentConfig

In [None]:
client = genai.Client(api_key=GEMINI_API_KEY)

In [None]:
def generate(prompt, model=MODEL, config=GENERATION_CONFIG):

    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=GenerateContentConfig(**config)
    )
    return response

In [None]:
generate(prompt=train_ds[0]['question'])

In [None]:
response.usage_metadata.total_token_count

In [None]:
# output_filename = "./output/combined_responses.json"  # Single output file
# all_responses = []  # List to store all responses

# # Create the output directory if it doesn't exist
# os.makedirs("./output", exist_ok=True)

# for i in range(train_ds.shape[0]):
#     response = generate(prompt=train_ds[i]['question'])

#     serializable_response = {}
#     try:
#         serializable_response = {"text": response.text,
#                                  "total_token": response.usage_metadata.total_token_count}
#     except AttributeError:
#         try:
#             serializable_response = response.to_dict()
#         except AttributeError:
#             serializable_response = str(response)  # fallback

#     all_responses.append(serializable_response)  # Add to the list of all responses

#     # Save ALL responses to the single file after EACH iteration
#     with open(output_filename, "w", encoding="utf-8") as f:
#         json.dump(all_responses, f, indent=4, ensure_ascii=False)

#     time.sleep(2)
#     print(f"Saved responses to {output_filename} (Iteration {i+1} of {train_ds.shape[0]})")

# print("Loop finished.")


## Native tool call

In [None]:
from google import genai
from google.genai.types import GenerateContentConfig

In [None]:
# def generate(prompt, model=MODEL, config=GENERATION_CONFIG):
#     response = client.models.generate_content(
#         model=model,
#         contents=prompt,
#         config=GenerateContentConfig(**config),
#         # tools='google_search_retrieval'
#     )
#     return response

In [None]:
generate(prompt=train_ds[0]['question'])

## ReAct Agent

In [8]:
import os
from typing import List, Dict, Union, Callable

import tiktoken
from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent
from langchain_community.utilities import GoogleSearchAPIWrapper, WikipediaAPIWrapper
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
from langchain_google_genai import ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory
from langchain.tools import Tool

In [25]:
class ReActAgentExecutor:
    """
    A class to run the ReAct agent with specified configurations and tools.
    """

    def __init__(
        self,
        model: str = MODEL,
        generation_config: Dict = GENERATION_CONFIG,
        # safety_settings: Dict = SAFETY_SETTINGS,
        max_iterations: int = MAX_ITERATIONS,
        max_execution_time: int = MAX_EXECUTION_TIME,
        google_api_key: str = GOOGLE_API_KEY,
        cse_id: str = CSE_ID,
    ):
        self.model = model
        self.generation_config = generation_config
        # self.safety_settings = safety_settings
        self.max_iterations = max_iterations
        self.max_execution_time = max_execution_time
        self.google_api_key = google_api_key
        self.cse_id = cse_id
        self.llm = None
        self.tools = None
        self.agent = None
        self.agent_executor = None
        self.token_callback = None

        self._setup_llm()
        self._setup_tools()
        self._setup_agent()

    def _setup_llm(self):
      """Initializes the language model."""
      if not GEMINI_API_KEY or GEMINI_API_KEY == "your_gemini_api_key":
          raise ValueError("GEMINI_API_KEY must be set to a valid API key.")
      self.llm = ChatGoogleGenerativeAI(
          model=self.model,
          google_api_key=GEMINI_API_KEY,
          generation_config=self.generation_config,
          # safety_settings=self.safety_settings,
      )

    def _setup_tools(self):
        """Sets up the tools for the agent."""
        search = GoogleSearchAPIWrapper(
            google_api_key=self.google_api_key, google_cse_id=self.cse_id
        )
        wikipedia = WikipediaAPIWrapper(top_k_results=3, doc_content_chars_max=1000)

        self.tools = [
            Tool(
                name="Google Search",
                func=search.run,
                description="Useful for finding information on current events, comparisons, or diverse perspectives.",
            ),
            Tool(
                name="Wikipedia",
                func=wikipedia.run,
                description="Useful for getting definitions and summaries of topics from Wikipedia.",
            ),
        ]

    def _setup_agent(self):
        """Sets up the ReAct agent and executor."""
        prompt = hub.pull("hwchase17/react")
        self.agent = create_react_agent(self.llm, self.tools, prompt)

        self.token_callback = TokenCountingCallbackHandler(self.model)
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            verbose=False,
            handle_parsing_errors=True,
            max_iterations=self.max_iterations,
            max_execution_time=self.max_execution_time,
            callbacks=[self.token_callback],
        )

    def run(self, input_data: Union[Dict, str]) -> Dict:
        """
        Runs the agent with the given input data.

        Args:
            input_data: Either a dictionary or a string representing the input for the agent.

        Returns:
            The output from the agent.
        """
        if isinstance(input_data, str):
            input_data = {"input": input_data}

        try:
            result = self.agent_executor.invoke(input_data)
            # Include token usage information in the result
            result["token_usage"] = {
                "total_tokens": self.token_callback.total_tokens,
                "prompt_tokens": self.token_callback.prompt_tokens,
                "completion_tokens": self.token_callback.completion_tokens,
            }
            self.token_callback.reset()  # Reset after each run
            return result
        except Exception as e:
            print(f"An error occurred: {e}")
            return {"error": str(e)}


class TokenCountingCallbackHandler(BaseCallbackHandler):
    """Callback handler for counting tokens used by the language model."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def on_llm_start(
        self, serialized: Dict[str, any], prompts: List[str], **kwargs
    ) -> None:
        """Collect prompt tokens when LLM starts."""
        for prompt in prompts:
            self.prompt_tokens += len(self.encoding.encode(prompt))

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Collect completion tokens when LLM finishes generating."""
        if response.generations:
            for generation_list in response.generations:
                for generation in generation_list:
                    if generation.text:
                        self.completion_tokens += len(
                            self.encoding.encode(generation.text)
                        )

    def on_agent_action(self, action: AgentAction, **kwargs) -> None:
        """Increment token count on agent action."""
        if action.log:
            self.total_tokens += len(self.encoding.encode(action.log))

    def on_agent_finish(self, finish: AgentFinish, **kwargs) -> None:
        """Increment token count on agent finish."""
        if finish.log:
            self.total_tokens += len(self.encoding.encode(finish.log))

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Print the total tokens used when the chain finishes."""
        self.total_tokens += self.completion_tokens + self.prompt_tokens
        print(f"Prompt tokens: {self.prompt_tokens}")
        print(f"Completion tokens: {self.completion_tokens}")
        print(f"Total tokens used in this chain: {self.total_tokens}")

    def reset(self):
        """Reset the counters for the next chain run."""
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0

In [26]:
# agent_executor = ReActAgentExecutor()
# result = agent_executor.run(train_ds[0]["question"])

In [None]:
import json

output_filename="agent_results.jsonl"
agent_executor = ReActAgentExecutor()

all_results = []

for i, question in enumerate(train_ds["question"]):
    print(f"Running agent for question {i+1}: {question}")
    result = agent_executor.run(question)
    print(f"Result for question {i+1}: {result}")

    all_results.append(
        {"question": question, "result": result}
    )

    # Save the updated list to the file after each iteration
    with open(output_filename, "w") as f:
        for item in all_results:
            f.write(json.dumps(item) + "\n")

    print(f"Results for question {i+1} saved to {output_filename}")



Running agent for question 1: Is Mixed martial arts totally original from Roman Colosseum games?
Prompt tokens: 0
Completion tokens: 0
Total tokens used in this chain: 504
Result for question 1: {'input': 'Is Mixed martial arts totally original from Roman Colosseum games?', 'output': "No, mixed martial arts is not totally original from Roman Colosseum games. While both involved combat, MMA's development is rooted in 20th-century martial arts traditions, not directly connected to the fighting styles of the Roman Colosseum.", 'token_usage': {'total_tokens': 504, 'prompt_tokens': 0, 'completion_tokens': 0}}
Results for question 1 saved to agent_results.jsonl
Running agent for question 2: Is the cuisine of Hawaii suitable for a vegan?




  lis = BeautifulSoup(html).find_all('li')


  lis = BeautifulSoup(html).find_all('li')


Prompt tokens: 0
Completion tokens: 0
Total tokens used in this chain: 297
Result for question 2: {'input': 'Is the cuisine of Hawaii suitable for a vegan?', 'output': 'While traditional Hawaiian cuisine is not inherently vegan, there are many vegan options available, both as adaptations of existing dishes and as entirely new creations.  Many restaurants offer vegan versions of popular Hawaiian foods, and numerous recipes are available online and in cookbooks.  Therefore, Hawaiian cuisine can be suitable for a vegan, with some adjustments or choices.', 'token_usage': {'total_tokens': 297, 'prompt_tokens': 0, 'completion_tokens': 0}}
Results for question 2 saved to agent_results.jsonl
Running agent for question 3: Is capturing giant squid in natural habitat impossible with no gear?


# TODO:

[x] add baseline

[] add calc (optional)

[x] callback/count tokens for react

[] ReWOO

[] add everything to a class, run queries at concurrently