<a href="https://colab.research.google.com/github/Amirosimani/ReWOO-Gemini/blob/main/ReWOO_gemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


|||
|----------|-------------|
| Author(s)   | amirimani@ |
| Last updated | 24/01/2025 |
<br><br>


In [None]:
# !pip install --quiet datasets
# !pip install --quiet langchain langchain_community
# !pip install --quiet langchain_google_genai langchain_google_community
# !pip install --quiet wikipedia
# !pip install --quiet tiktoken

In [1]:
from google.colab import userdata

In [None]:
# --- Configuration ---
# Gemini model used by the baseline `generate` helper and, by default, by the ReAct agent.
MODEL = "gemini-1.5-flash"
# Default limits handed to the ReAct AgentExecutor
# (MAX_EXECUTION_TIME is presumably in seconds -- TODO confirm against LangChain docs).
MAX_ITERATIONS = 10
MAX_EXECUTION_TIME = 90

# Sampling parameters shared by the baseline call and the ReAct agent.
GENERATION_CONFIG = dict(
    temperature=0.8,
    top_p=0.95,
    top_k=20,
    candidate_count=1,
    max_output_tokens=8192,
    stop_sequences=["STOP!"],
)

# Safety settings are currently disabled; kept here for reference.
# SAFETY_SETTINGS = {
#     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
# }

# Credentials are read from Colab's secret store rather than hard-coded.
GEMINI_API_KEY = userdata.get('GEMINI')
GOOGLE_API_KEY = userdata.get('GOOGLE-API')
CSE_ID = userdata.get("CSE-ID")

## Data

[**StrategyQA**](https://paperswithcode.com/dataset/strategyqa) is a question answering benchmark where the required reasoning steps are implicit in the question, and should be inferred using a strategy. It includes 2,780 examples, each consisting of a strategy question, its decomposition, and evidence paragraphs. Questions in StrategyQA are short, topic-diverse, and cover a wide range of strategies.



In [None]:
from datasets import load_dataset

# Load the "ChilleD/StrategyQA" dataset through `datasets.load_dataset`.
ds = load_dataset("ChilleD/StrategyQA")

In [None]:
# Keep a handle on the training split; the evaluation loops below read their
# questions from `train_ds`.
train_ds = ds["train"]

In [None]:
# Display the first training record to inspect the available fields.
train_ds[0]

# LLM

## Baseline

In [None]:
import time
from google import genai
from google.genai.types import GenerateContentConfig

In [None]:
# Gemini API client authenticated with the key loaded into GEMINI_API_KEY.
client = genai.Client(api_key=GEMINI_API_KEY)

In [None]:
def generate(prompt, model=MODEL, config=GENERATION_CONFIG):
    """Send a single prompt to Gemini and time the call.

    Args:
        prompt: Text passed as `contents` to `generate_content`.
        model: Name of the Gemini model to call.
        config: Mapping of generation parameters. It is expanded into
            `GenerateContentConfig`; the mapping itself is never modified.

    Returns:
        A dict with two keys:
            'gemini_response': the object returned by `generate_content`.
            'wall_time': elapsed wall-clock seconds for the API call.
    """
    client = genai.Client(api_key=GEMINI_API_KEY)

    # Build a per-call copy: the default `config` is the shared module-level
    # GENERATION_CONFIG, and writing the system instruction into it would leak
    # into every other consumer of that dict.
    request_config = {
        **config,
        "system_instruction": "Once you are done finding the answer, only return Yes or No",
    }

    start_time = time.time()
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=GenerateContentConfig(**request_config),
    )
    wall_time = time.time() - start_time

    return {
        'gemini_response': response,
        'wall_time': wall_time,
    }

In [None]:
import json

output_filename = "./gemini_responses.jsonl"
all_responses = []

# Number of questions to send through the baseline; capped by the dataset size.
num_baseline_questions = min(3, train_ds.shape[0])

for i in range(num_baseline_questions):
    response = generate(prompt=train_ds[i]['question'])

    try:
        serializable_response = {
            "text": response['gemini_response'].text.strip(),
            "total_token": response['gemini_response'].usage_metadata.total_token_count,
            "wall_time": response['wall_time']
        }
    except (AttributeError, KeyError, TypeError):
        # The response lacked usable text or usage metadata; `generate` returns
        # a plain dict, so fall back to its string form to keep a record.
        serializable_response = str(response)

    all_responses.append(serializable_response)

    # Append mode: one JSON document per line. Re-running this cell adds new
    # lines to the existing file rather than replacing it.
    with open(output_filename, "a", encoding="utf-8") as f:
        json.dump(serializable_response, f, ensure_ascii=False)
        f.write("\n")

    time.sleep(1)  # pause between requests (presumably to stay under API rate limits)
    print(f"(Iteration {i+1} of {num_baseline_questions})")

print("Loop finished.")

## Native tool call

## ReAct Agent

In [None]:
import os
from typing import Any, Callable, Dict, List, Union

import tiktoken
from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, BaseMessage
from langchain.tools import Tool
from langchain_community.utilities import GoogleSearchAPIWrapper, WikipediaAPIWrapper
from langchain_google_genai import ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory

In [None]:
class ReActAgentExecutor:
    """
    Builds and runs a LangChain ReAct agent that uses Gemini as the LLM and
    Google Search as its only tool.

    Construction wires up the LLM, the tool list and the AgentExecutor;
    `run` then answers one question per call and reports token usage.
    """

    def __init__(
        self,
        model: str = MODEL,
        generation_config: Dict = GENERATION_CONFIG,
        # safety_settings: Dict = SAFETY_SETTINGS,
        max_iterations: int = MAX_ITERATIONS,
        max_execution_time: int = MAX_EXECUTION_TIME,
        google_api_key: str = GOOGLE_API_KEY,
        cse_id: str = CSE_ID,
    ):
        """
        Args:
            model: Gemini model name passed to ChatGoogleGenerativeAI.
            generation_config: Generation parameters forwarded to the LLM.
            max_iterations: Forwarded to AgentExecutor as `max_iterations`.
            max_execution_time: Forwarded to AgentExecutor as `max_execution_time`.
            google_api_key: API key for the Google Search wrapper.
            cse_id: Custom Search Engine id for the Google Search wrapper.

        Raises:
            ValueError: If GEMINI_API_KEY is unset or still a placeholder.
        """
        self.model = model
        # Store a private copy so later changes to the shared module-level dict
        # (the default argument) cannot silently alter this agent's settings.
        self.generation_config = (
            dict(generation_config) if generation_config is not None else None
        )
        # self.safety_settings = safety_settings
        self.max_iterations = max_iterations
        self.max_execution_time = max_execution_time
        self.google_api_key = google_api_key
        self.cse_id = cse_id
        self.llm = None
        self.tools = None
        self.agent = None
        self.agent_executor = None
        self.token_callback = None

        self._setup_llm()
        self._setup_tools()
        self._setup_agent()

    def _setup_llm(self):
        """Initializes the language model."""
        if not GEMINI_API_KEY or GEMINI_API_KEY == "your_gemini_api_key":
            raise ValueError("GEMINI_API_KEY must be set to a valid API key.")
        self.llm = ChatGoogleGenerativeAI(
            model=self.model,
            google_api_key=GEMINI_API_KEY,
            generation_config=self.generation_config,
            # safety_settings=self.safety_settings,
        )

    def _setup_tools(self):
        """Sets up the tools for the agent (currently Google Search only)."""
        search = GoogleSearchAPIWrapper(
            google_api_key=self.google_api_key, google_cse_id=self.cse_id
        )
        # wikipedia = WikipediaAPIWrapper(top_k_results=3, doc_content_chars_max=1000)

        self.tools = [
            Tool(
                name="Google Search",
                func=search.run,
                description="Useful for finding information on current events, comparisons, or diverse perspectives.",
            ),
            # Tool(
            #     name="Wikipedia",
            #     func=wikipedia.run,
            #     description="Useful for getting definitions and summaries of topics from Wikipedia.",
            # ),
        ]

    def _setup_agent(self):
        """Sets up the ReAct agent and executor."""
        prompt = hub.pull("hwchase17/react")
        self.agent = create_react_agent(self.llm, self.tools, prompt)

        self.token_callback = TokenCountingCallbackHandler(self.model)
        # NOTE(review): callbacks given to a constructor may be scoped to the
        # executor itself and not reach the nested LLM calls -- confirm that
        # on_llm_start/on_llm_end actually fire with this wiring.
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            verbose=False,
            handle_parsing_errors=True,
            max_iterations=self.max_iterations,
            max_execution_time=self.max_execution_time,
            callbacks=[self.token_callback],
        )

    def run(self, input_data: Union[Dict, str]) -> Dict:
        """
        Runs the agent with the given input data.

        Args:
            input_data: Either a dictionary or a string representing the input
                for the agent. A string is wrapped as {"input": <string>}.

        Returns:
            The executor's result dict with an added "token_usage" entry, or
            {"error": <message>} if the run raised an exception.
        """
        if isinstance(input_data, str):
            input_data = {"input": input_data}

        try:
            result = self.agent_executor.invoke(input_data)
            # Include token usage information in the result
            result["token_usage"] = {
                "total_tokens": self.token_callback.total_tokens,
                "prompt_tokens": self.token_callback.prompt_tokens,
                "completion_tokens": self.token_callback.completion_tokens,
            }
            return result
        except Exception as e:
            print(f"An error occurred: {e}")
            return {"error": str(e)}
        finally:
            # Always clear the counters -- including after a failure -- so tokens
            # from this run are never attributed to the next question.
            self.token_callback.reset()


class TokenCountingCallbackHandler(BaseCallbackHandler):
    """Callback handler for counting tokens used by the language model.

    Counts are estimated with tiktoken's "cl100k_base" encoding.
    NOTE(review): that is not Gemini's own tokenizer, so treat the numbers as
    approximations -- confirm whether exact usage metadata is needed.

    Attributes (all reset by `reset`):
        prompt_tokens: tokens in prompts seen by `on_llm_start`.
        completion_tokens: tokens in generations seen by `on_llm_end`.
        total_tokens: prompt + completion + agent action/finish log tokens.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.encoding = tiktoken.get_encoding("cl100k_base")
        self.reset()

    def _count(self, text: str) -> int:
        """Return the number of tokens `text` encodes to."""
        return len(self.encoding.encode(text))

    def _refresh_total(self) -> None:
        """Recompute `total_tokens` from its components.

        Deriving the total (instead of incrementing it) keeps it correct even
        when `on_chain_end` fires more than once before `reset`.
        """
        self.total_tokens = (
            self._agent_log_tokens + self.prompt_tokens + self.completion_tokens
        )

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs
    ) -> None:
        """Collect prompt tokens when LLM starts."""
        for prompt in prompts:
            self.prompt_tokens += self._count(prompt)

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Collect completion tokens when LLM finishes generating."""
        if response.generations:
            for generation_list in response.generations:
                for generation in generation_list:
                    if generation.text:
                        self.completion_tokens += self._count(generation.text)

    def on_agent_action(self, action: AgentAction, **kwargs) -> None:
        """Increment token count on agent action."""
        if action.log:
            self._agent_log_tokens += self._count(action.log)
            self._refresh_total()

    def on_agent_finish(self, finish: AgentFinish, **kwargs) -> None:
        """Increment token count on agent finish."""
        if finish.log:
            self._agent_log_tokens += self._count(finish.log)
            self._refresh_total()

    def on_chain_end(self, outputs, **kwargs) -> None:
        """Print the total tokens used when the chain finishes."""
        self._refresh_total()
        print(f"Prompt tokens: {self.prompt_tokens}")
        print(f"Completion tokens: {self.completion_tokens}")
        print(f"Total tokens used in this chain: {self.total_tokens}")

    def reset(self):
        """Reset the counters for the next chain run."""
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        # Tokens from agent action/finish logs, tracked separately so the total
        # can be recomputed without double counting.
        self._agent_log_tokens = 0

In [None]:
# agent_executor = ReActAgentExecutor()
# result = agent_executor.run(train_ds[0]["question"])


In [None]:
import json

output_filename="agent_results.jsonl"
agent_executor = ReActAgentExecutor()

all_results = []

# Open the output once (truncating any previous run) and append one JSON line
# per question. The earlier approach rewrote every accumulated result on each
# iteration, which is quadratic in the number of questions.
with open(output_filename, "w") as f:
    for i, question in enumerate(train_ds["question"]):
        print(f"Running agent for question {i+1}: {question}")
        result = agent_executor.run(question)
        print(f"Result for question {i+1}: {result}")

        record = {"question": question, "result": result}
        all_results.append(record)

        f.write(json.dumps(record) + "\n")
        # Flush so results already produced survive an interrupted run.
        f.flush()

        print(f"Results for question {i+1} saved to {output_filename}")

## ReWOO

[ReWOO: Decoupling Reasoning from Observations for Efficient Augmented Language Models](https://arxiv.org/pdf/2305.18323) (paper).

The code below is based on the reference implementation [here](https://github.com/billxbf/ReWOO/tree/main).

In [None]:
import time
import re
import tiktoken
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain import LLMChain, PromptTemplate
from langchain_community.utilities import GoogleSearchAPIWrapper

In [None]:
# --- Simplified Worker ---
class GoogleSearchWorker:
    """ReWOO worker that turns a search query into a short evidence string.

    The Google API key and Custom Search Engine id are taken from the
    module-level GOOGLE_API_KEY and CSE_ID.
    """

    def __init__(self, name="Google"):
        self.name = name
        self.google_api_key = GOOGLE_API_KEY
        self.cse_id = CSE_ID
        self.description = (
            "Worker that searches results from Google. Useful when you need to find short "
            "and succinct answers about a specific topic. Input should be a search query."
        )

    def run(self, input):
        """Search Google for `input` and return the concatenated evidence text.

        For each returned result the first available field among "snippet",
        "title" and "body" is used; results with none of them only trigger a
        warning. Returns an empty string when nothing usable came back.
        """
        wrapper = GoogleSearchAPIWrapper(
            google_api_key=self.google_api_key, google_cse_id=self.cse_id
        )
        # Request a single result for the query.
        hits = wrapper.results(input, 1)

        # Debug output: show the raw structure returned by the API.
        print("Results Structure:", hits)

        preferred_fields = ("snippet", "title", "body")
        evidence = ""
        for hit in hits:
            for field in preferred_fields:
                if field in hit:
                    evidence += hit[field]
                    break
            else:
                # None of the preferred fields was present in this result.
                print("Warning: No relevant information found in result:", hit)

        return evidence

# --- LLM Node (Simplified) ---
class LLMNode:
    """Base class for ReWOO nodes that call Gemini through a LangChain LLMChain.

    Every node uses temperature 0. Token counts are estimated with tiktoken's
    "cl100k_base" encoding (NOTE(review): not Gemini's own tokenizer, so the
    counts are approximations).
    """

    def __init__(self, name, model_name, stop=None, input_type=str, output_type=str):
        """
        Args:
            name: Human-readable node name.
            model_name: Gemini model passed to ChatGoogleGenerativeAI.
            stop: Stored on the instance; see `call_llm` for how it is used.
            input_type: Stored on the instance only.
            output_type: Stored on the instance only.
        """
        self.name = name
        self.model_name = model_name
        self.model = model_name
        self.stop = stop
        self.input_type = input_type
        self.output_type = output_type
        self.generation_config = {
            "temperature": 0,
        }
        self.llm = ChatGoogleGenerativeAI(
            model=self.model,
            google_api_key=GEMINI_API_KEY,
            generation_config=self.generation_config,
            # safety_settings=self.safety_settings,
        )
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def call_llm(self, prompt, stop):
        """Run the LLM on `prompt` and report the output with token estimates.

        Args:
            prompt: Either a plain string that is sent as-is, or a two-item
                list `[template, inputs]` where `template` is a prompt template
                with a `{question}` variable and `inputs` is passed to the chain.
            stop: Accepted for interface compatibility; it is not forwarded to
                the model by this method.

        Returns:
            Dict with "output" (stripped model text), "prompt_tokens" and
            "completion_tokens". For the list form, "prompt_tokens" counts the
            unformatted template text.
        """
        if isinstance(prompt, list):
            prompt_text = prompt[0]
            template_text = prompt[0]
            input_variables = ["question"]
            chain_inputs = prompt[1]
        else:
            prompt_text = prompt
            # A plain prompt is literal text (it may embed search evidence or
            # user questions). Escape braces so the template engine does not
            # treat "{...}" as a variable or fail on a stray "{" / "}".
            template_text = prompt.replace("{", "{{").replace("}", "}}")
            input_variables = []
            chain_inputs = {}

        prompt_template = PromptTemplate(
            template=template_text, input_variables=input_variables
        )
        llm_chain = LLMChain(prompt=prompt_template, llm=self.llm, verbose=False)
        response = llm_chain(chain_inputs)
        output = response["text"].strip()

        prompt_tokens = len(self.tokenizer.encode(prompt_text))
        completion_tokens = len(self.tokenizer.encode(output))
        return {
            "output": output,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens
        }

# --- Planner ---
class Planner(LLMNode):
    """LLM node that drafts a step-by-step plan with #E evidence placeholders."""

    def __init__(self, model_name="gemini-pro", fewshot=""):
        super().__init__("Planner", model_name, stop=None, input_type=str, output_type=str)
        # Optional few-shot examples inserted between the tool list and the suffix.
        self.fewshot = fewshot
        # Instructions that open every planner prompt.
        self.prefix = (
            "For the following tasks, make plans that can solve the problem step-by-step. For each plan, "
            "indicate which external tool together with tool input to retrieve evidence. You can store the "
            "evidence into a variable #E that can be called by later tools. (Plan, #E1, Plan, #E2, Plan, ...) \n\n"
        )
        # Description of the tools the planner may reference.
        self.worker_prompt = "Tools can be one of the following:\nGoogle[input]: Worker that searches results from Google. Useful when you need to find short and succinct answers about a specific topic. Input should be a search query.\n\n"
        # Closing instructions placed right before the task text.
        self.suffix = "Begin! Describe your plans with rich details. Each Plan should be followed by only one #E.\n\n"

    def run(self, input, log=False):
        """Ask the LLM for a plan for `input`.

        Returns the full `call_llm` dict when `log` is true, otherwise only
        the generated plan text.
        """
        pieces = [self.prefix, self.worker_prompt, self.fewshot, self.suffix, input, '\n']
        llm_result = self.call_llm("".join(pieces), self.stop)
        return llm_result if log else llm_result["output"]

# --- Solver ---
class Solver(LLMNode):
    """LLM node that answers the task using the plans and collected evidence."""

    def __init__(self, model_name="gemini-pro"):
        super().__init__("Solver", model_name, stop=None, input_type=str, output_type=str)
        # Closing instructions placed after the evidence and before the repeated task.
        self.suffix = "\nNow begin to solve the task or problem. Respond with the answer directly with no extra words.\n\n"
        # Instructions that open every solver prompt.
        self.prefix = "Solve the following task or problem. To assist you, we provide some plans and corresponding evidences that might be helpful. Notice that some of these information contain noise so you should trust them with caution.\n\n"

    def run(self, input, worker_log, log=False):
        """Ask the LLM to solve `input` given the `worker_log` evidence text.

        The task text appears twice in the prompt: once before the evidence
        and once after the closing instructions. Returns the full `call_llm`
        dict when `log` is true, otherwise only the answer text.
        """
        pieces = [self.prefix, input, "\n", worker_log, self.suffix, input, '\n']
        llm_result = self.call_llm("".join(pieces), self.stop)
        return llm_result if log else llm_result["output"]

# --- Main PWS Class ---
class PWS:
    """Planner-Worker-Solver pipeline (ReWOO).

    `run` asks the planner for a plan, executes each `#E<n> = Google[...]`
    step with the search worker, then hands plans and evidence to the solver.
    """

    # Prefix of a tool call the worker knows how to execute.
    _TOOL_PREFIX = "Google["

    def __init__(self, planner_model="gemini-pro", solver_model="gemini-pro", fewshot=""):
        self.worker = GoogleSearchWorker()
        self.planner = Planner(model_name=planner_model, fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def run(self, input):
        """Answer `input` with the plan -> work -> solve pipeline.

        Returns:
            Dict with keys "wall_time", "input", "output", "planner_log",
            "worker_log", "solver_log", "steps", "prompt_tokens",
            "completion_tokens" and "total_tokens". "total_tokens" also
            includes the tokens of the evidence gathered by the worker.
        """
        self._reinitialize()
        result = {}
        st = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]

        # Rebuild the prompt the planner was given so it can be logged.
        planner_input = (
            self.planner.prefix
            + self.planner.worker_prompt
            + self.planner.fewshot
            + self.planner.suffix
            + input
            + '\n'
        )
        planner_log = planner_input + plan

        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work (only when the plan is usable); both values keep their
        # defaults when the worker is skipped.
        worker_log = ""
        total_worker_tokens = 0
        if not self._validate_plan():
            print("Warning: Invalid plan generated. Skipping worker and passing the question to solver.")
        else:
            self._get_worker_evidences()
            for i, plan_line in enumerate(self.plans):
                e = f"#E{i + 1}"
                if e in self.worker_evidences:
                    worker_log += f"{plan_line}\nEvidence:\n{self.worker_evidences[e]}\n"
                    total_worker_tokens += self._count_tokens(self.worker_evidences[e])
                else:
                    worker_log += f"{plan_line}\nEvidence:\nNo evidence found for {e}\n"
                    print(f"Warning: No evidence found for {e} in self.worker_evidences")

        # Solve
        solver_response = self.solver.run(input, worker_log, log=True)
        output = solver_response["output"]

        # Rebuild the prompt the solver was given so it can be logged; with an
        # empty worker_log this is the question-only prompt.
        solver_input = (
            self.solver.prefix + input + "\n" + worker_log + self.solver.suffix + input + '\n'
        )
        solver_log = solver_input + output

        result["wall_time"] = time.time() - st
        result["input"] = input
        result["output"] = output
        result["planner_log"] = planner_log
        result["worker_log"] = worker_log
        result["solver_log"] = solver_log
        result["steps"] = len(self.plans) + 1
        result["prompt_tokens"] = planner_response["prompt_tokens"] + solver_response["prompt_tokens"]
        result["completion_tokens"] = planner_response["completion_tokens"] + solver_response["completion_tokens"]
        result["total_tokens"] = (
            result["prompt_tokens"] + result["completion_tokens"] + total_worker_tokens
        )

        return result

    def _validate_plan(self):
        """
        Validates if the generated plan has the correct #E notation in sequence.

        A plan is valid when it has at least one "Plan:" step and every step
        i has a matching "#E<i>" entry. An empty plan is rejected: the worker
        log is built from the plan steps, so any evidence fetched for it
        could never reach the solver.
        """
        if not self.plans:
            return False
        return all(
            f"#E{i + 1}" in self.planner_evidences for i in range(len(self.plans))
        )

    def _parse_plans(self, response):
        """Return the lines of `response` that start with "Plan:"."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Extract `#E<n> = <tool call>` lines into a {"#E<n>": tool_call} dict.

        Lines that start with "#E<digit>" but contain no "=" are recorded with
        the placeholder "No evidence found" and reported with a warning.
        """
        evidences = {}
        for line in response.splitlines():
            # Length guard: short lines such as "#" or "#E" (e.g. a markdown
            # heading in the model output) must not raise IndexError.
            if len(line) < 3 or not (line.startswith("#E") and line[2].isdigit()):
                continue
            parts = line.split("=", 1)  # Split into at most 2 parts
            if len(parts) == 2:
                evidences[parts[0].strip()] = parts[1].strip()
            else:
                # Handle cases where there's no '=' after #E
                evidences[parts[0].strip()] = "No evidence found"
                print(f"Warning: Invalid planner evidence format: {line}")
        return evidences

    def _get_worker_evidences(self):
        """Execute each planned tool call and store its evidence.

        References to earlier evidence (e.g. "#E1") inside a tool input are
        replaced by "[<that evidence>]" before the search runs. Tool calls
        that are not `Google[...]` are recorded as "No evidence found".
        """
        for e, tool_call in self.planner_evidences.items():
            if not tool_call.startswith(self._TOOL_PREFIX):
                self.worker_evidences[e] = "No evidence found"
                continue

            tool_input = tool_call[len(self._TOOL_PREFIX):]
            # Drop the closing bracket only when it is actually there.
            if tool_input.endswith("]"):
                tool_input = tool_input[:-1]

            def _substitute(match):
                # Whole-token replacement: "#E1" must not rewrite part of "#E10".
                var = match.group(0)
                if var in self.worker_evidences:
                    return "[" + self.worker_evidences[var] + "]"
                return var

            tool_input = re.sub(r"#E\d+", _substitute, tool_input)

            self.worker_evidences[e] = self.worker.run(tool_input)

    def _reinitialize(self):
        """Clear per-question state before a new run."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}

    def _count_tokens(self, text):
        """Return the number of tokens `text` encodes to with the tokenizer."""
        return len(self.tokenizer.encode(text))

In [None]:
import json

pws = PWS()

results_file = "./rewoo_results.json"
rewoo_results = []

# Start from an empty JSON array so the file is valid even before the first
# question finishes.
with open(results_file, "w") as f:
    json.dump(rewoo_results, f)

for question in train_ds["question"]:
    result = pws.run(question)
    print(result)

    rewoo_results.append(result)
    # Rewrite the array from the in-memory list. Opening with "w" truncates the
    # file, so there is no need to re-read and re-parse it on every question
    # or to rely on overwriting in place without truncation.
    with open(results_file, "w") as f:
        json.dump(rewoo_results, f, indent=4)

# TODO:

- [x] Add baseline

- [ ] Add prompt, completion, and total token counts

- [ ] Native Google Search retrieval

- [x] Callback to count tokens for ReAct

- [ ] ReWOO

- [ ] Add wall time to all other functions, so each approach reports token count, time, and accuracy

- [ ] Wrap everything in a class and run queries concurrently