# Unit 1: 6 - Creating a Basic ReAct Agent with SmolLM2 from Scratch

**Collaborators**:
* Roberto Rodriguez ([@Cyb3rWard0g](https://x.com/Cyb3rWard0g))

## Introduction to ReAct Agents
A ReAct agent follows a structured `Thought-Action-Observation` loop, allowing it to reason about its next steps, take actions, and adjust its approach based on observations. Unlike a simple ToolCallingAgent, which directly invokes tools based on a query, a ReAct agent thinks step by step before acting and refines its response through multiple iterations.

In the previous notebooks, we saw how easy it is to build a `ToolCallingAgent`, where the model directly produces tool calls. Now, let's explore how we can replicate the ReAct pattern with the `SmolLM2` model, as described in [this notebook shared in the course](https://github.com/huggingface/agents-course/blob/main/notebooks/unit1/dummy_agent_library.ipynb).

### Install Required Libraries

In [None]:
# !pip install transformers torch

## Define LM Client

In [1]:
from typing import List, Dict
from transformers import StoppingCriteria, StoppingCriteriaList

class LMClient:
    """
    Handles communication with SmolLM2 for generating responses.
    """

    def __init__(self, model, tokenizer, device, max_new_tokens=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_new_tokens = max_new_tokens

    def generate(self, messages: List[Dict], stop_sequences: List[str] = None) -> str:
        """
        Generates a response from SmolLM2 given a conversation history, stopping at 'Observation:' if provided.

        Args:
            messages (List[Dict]): The list of messages in the chat.
            stop_sequences (List[str], optional): Sequences that signal generation should stop.

        Returns:
            str: The generated response.
        """
        input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        encoded_input = self.tokenizer(input_text, return_tensors="pt").to(self.device)
        input_ids = encoded_input["input_ids"]
        attention_mask = encoded_input["attention_mask"]

        stopping_criteria = None
        if stop_sequences:
            stopping_criteria = self.make_stopping_criteria(stop_sequences)

        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

        generated_tokens = outputs[0][input_ids.shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
        """
        Creates a stopping criterion that halts generation when any of the stop sequences are reached.

        Args:
            stop_sequences (List[str]): A list of stop sequences.

        Returns:
            StoppingCriteriaList: Custom stopping criteria.
        """

        class StopOnStrings(StoppingCriteria):
            def __init__(self, stop_strings, tokenizer):
                self.stop_strings = stop_strings
                self.tokenizer = tokenizer
                self.generated_text = ""  # Store generated text stream

            def __call__(self, input_ids, scores, **kwargs):
                """Stop generation when any stop sequence is found in the accumulated output."""
                self.generated_text += self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)

                # Stop if any stop sequence appears in the generated text
                return any(stop_seq in self.generated_text for stop_seq in self.stop_strings)

        return StoppingCriteriaList([StopOnStrings(stop_sequences, self.tokenizer)])


## Define Tool Class and Decorator

In [33]:
import inspect

class Tool:
    """
    Represents an AI-registered tool.
    """

    def __init__(self, name: str, description: str, func: callable):
        self.name = name
        self.description = description
        self.func = func
        self.arguments = inspect.signature(func).parameters
        self.outputs = inspect.signature(func).return_annotation
    
    def to_string(self) -> str:
        """
        Returns a structured representation of the tool.
        """
        args_str = ", ".join([f"{arg}: {param.annotation}" for arg, param in self.arguments.items()])
        return f"Tool Name: {self.name}, Description: {self.description}, Arguments: {args_str}, Outputs: {self.outputs}"

    def __call__(self, *args, **kwargs):
        """Invoke the tool."""
        return self.func(*args, **kwargs)

def tool(func):
    """
    Decorator to register a function as a tool.
    """
    return Tool(func.__name__, func.__doc__, func)


## Define Agent Class

### Define Logging

In [3]:
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

### Define System Prompt Template

In [21]:
from jinja2 import Template

SYSTEM_PROMPT = Template("""
Answer the following questions as best you can. You have access to the following tools:

{{ tools }}

The way you use the tools is by specifying a json blob.
Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).

### ReAct Thought-Action-Observation Cycle

1. **Thought**: First, analyze the query and determine what action to take.
2. **Action**: Call a tool if needed, using a structured JSON format.
3. **Observation**: Wait for the tool's response before making the next decision.

#### Example Format:
Question: the input question you must answer
Thought: reason about what to do next
Action:

{
  "action": "tool_name",
  "action_input": { "param1": "value1" }
}

Observation: result from the tool execution
... (this Thought/Action/Observation sequence can repeat N times)

You must always end your output with:

Thought: I now know the final answer
Final Answer: <your answer here>

Now begin! Reminder to ALWAYS use the exact characters `Final Answer:` when you provide a definitive answer.
""")

### Defining Tool Calling Workflow

In [22]:
from typing import List, Dict, Tuple, Optional
import json
import re
import logging

logger = logging.getLogger(__name__)

class ReActAgent:
    """
    Implements a ReAct-style agent that follows the Thought-Action-Observation cycle.
    """

    def __init__(self, model: LMClient):
        self.model = model
        self.system_prompt = SYSTEM_PROMPT
        self.tools = {}

    def register_tool(self, tool: Tool):
        """
        Registers a tool for function calling.

        Args:
            tool (Tool): The tool instance.
        """
        if not isinstance(tool, Tool):
            raise TypeError(f"Expected Tool instance, got {type(tool)}")
        self.tools[tool.name] = tool
        logger.info(f"Registered tool: {tool.name}")

    def prepare_messages(self, query: str) -> List[Dict[str, str]]:
        """
        Prepares structured messages including system instructions.

        Args:
            query (str): The user query.

        Returns:
            List[Dict[str, str]]: Formatted conversation messages.
        """
        tool_descriptions = "\n".join([t.to_string() for t in self.tools.values()])
        rendered_prompt = self.system_prompt.render(tools=tool_descriptions)

        system_message = {"role": "system", "content": rendered_prompt}
        user_message = {"role": "user", "content": f"Question: {query}\n"}  # Start accumulating

        return [system_message, user_message]

    def parse_response(self, text: str) -> Tuple[Optional[str], Optional[dict], Optional[str]]:
        """
        Extracts the thought, action, and final answer (if present) from the model response.

        Args:
            text (str): The model-generated response.

        Returns:
            Tuple: (thought content, action dictionary if present, final answer if present)
        """
        logger.info(f"Received response from model: {text}")

        # Define regex patterns
        action_regex = re.compile(r'Action:\s*({.*?})', re.DOTALL)  # Extract JSON action
        final_answer_regex = re.compile(r'Final Answer:\s*(.*)', re.DOTALL)  # Extract final answer
        thought_regex = re.compile(r'Thought:\s*(.*?)$', re.DOTALL | re.MULTILINE)  # Extract thought

        thought, action, final_answer = None, None, None

        # Extract thought
        thought_match = thought_regex.search(text)
        if thought_match:
            thought = thought_match.group(1).strip()

        # Extract action JSON
        action_match = action_regex.search(text)
        if action_match:
            try:
                action = json.loads(action_match.group(1))
            except json.JSONDecodeError:
                logger.error(f"Invalid action JSON: {action_match.group(1)}")
                raise ValueError(f"Failed to parse action JSON: {action_match.group(1)}")

        # Extract final answer
        final_answer_match = final_answer_regex.search(text)
        if final_answer_match:
            final_answer = final_answer_match.group(1).strip()

        return thought, action, final_answer

    def _execute_tool_calls(self, action: dict) -> str:
        """
        Executes the requested tool function and returns the observation.

        Args:
            action (dict): The tool call JSON.

        Returns:
            str: Observation string containing the tool execution result.
        """
        tool_name = action.get("action")
        tool_args = action.get("action_input", {})

        if tool_name not in self.tools:
            observation = f"Error: Unknown tool {tool_name}"
            logger.error(observation)
            return observation

        logger.info(f"Executing tool: {tool_name} with arguments {tool_args}")
        tool_result = self.tools[tool_name](**tool_args)

        return f"Observation: {tool_result}\n"

    def run(self, query: str) -> str:
        """
        Processes a user query using the Thought-Action-Observation cycle.

        Args:
            query (str): User query.

        Returns:
            str: The final natural response from the assistant.
        """
        logger.info(f"User query received: {query}")
        messages = self.prepare_messages(query)

        while True:  # Loop until "Final Answer" is found
            response_text = self.model.generate(messages, stop_sequences=["Observation:"])

            thought, action, final_answer = self.parse_response(response_text)

            # Log current cycle
            logger.info(f"Thought: {thought}")
            logger.info(f"Action: {action}")

            if final_answer:
                logger.info(f"Final response from assistant: {final_answer}")
                return final_answer

            if action:
                observation = self._execute_tool_calls(action)

                # Update user message content instead of adding new messages
                messages[1]["content"] += f"Thought: {thought}\n"
                messages[1]["content"] += f"Action:\n{json.dumps(action, indent=2)}\n"
                messages[1]["content"] += f"{observation}\n"

                logger.info(f"Continuing cycle with observation: {observation}")

            else:
                logger.warning("No action detected, breaking loop.")
                return response_text

## Initializing SmolLM2 Agent

### Loading SmolLM2 Efficiently

To avoid downloading the model every time (**~3.42 GB**), we first check if it exists locally before loading:

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os

MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
MODEL_DIR = "data/smollm2"

def load_model():
    if os.path.exists(MODEL_DIR):
        print("Loading model from local directory.")
        model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
    else:
        print("Downloading model...")
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model.save_pretrained(MODEL_DIR)
    return model

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = load_model().to(device)

Loading model from local directory.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Initializing Language Model Client

In [7]:
model = LMClient(model, tokenizer, device)

### Defining Tools

Tools allow the model to execute external functions when needed. We define them as Python functions and convert them into structured format for SmolLM2 to understand their purpose.

In [34]:
import datetime
import random

@tool
def get_current_time() -> str:
    """Returns the current time in HH:MM:SS format."""
    return datetime.datetime.now().strftime("%H:%M:%S")
@tool
def get_random_number(min: int, max: int) -> int:
    """Returns a random number between min and max."""
    return random.randint(min, max)

In [35]:
get_current_time.to_string()

"Tool Name: get_current_time, Description: Returns the current time in HH:MM:SS format., Arguments: , Outputs: <class 'str'>"

In [36]:
get_random_number.to_string()

"Tool Name: get_random_number, Description: Returns a random number between min and max., Arguments: min: <class 'int'>, max: <class 'int'>, Outputs: <class 'int'>"

### Initializing Agent

In [26]:
agent = ReActAgent(model=model)

### Registering Tools

In [27]:
agent.register_tool(get_current_time)
agent.register_tool(get_random_number)

2025-02-17 20:37:14,561 - INFO - Registered tool: get_current_time
2025-02-17 20:37:14,561 - INFO - Registered tool: get_random_number


In [28]:
agent.tools

{'get_current_time': <__main__.Tool at 0x754a75d10>,
 'get_random_number': <__main__.Tool at 0x14fcaf820>}

### Basic One-Step Examples

In [29]:
response = agent.run("What is the current time?")
response


2025-02-17 20:37:15,460 - INFO - User query received: What is the current time?
2025-02-17 20:37:30,306 - INFO - Received response from model: Thought: I will use the get_current_time tool to get the current time.

Action:
{
  "action": "get_current_time"
}

Observation:
2025-02-17 20:37:30,309 - INFO - Thought: I will use the get_current_time tool to get the current time.
2025-02-17 20:37:30,309 - INFO - Action: {'action': 'get_current_time'}
2025-02-17 20:37:30,310 - INFO - Executing tool: get_current_time with arguments {}
2025-02-17 20:37:30,312 - INFO - Continuing cycle with observation: Observation: 20:37:30

2025-02-17 20:37:43,551 - INFO - Received response from model: Thought: I now know the current time is 20:37:30.
Final Answer: 20:37:30
2025-02-17 20:37:43,552 - INFO - Thought: I now know the current time is 20:37:30.
2025-02-17 20:37:43,554 - INFO - Action: None
2025-02-17 20:37:43,554 - INFO - Final response from assistant: 20:37:30


'20:37:30'

In [30]:
response = agent.run("Give me a random number between 1 and 10.")
response

2025-02-17 20:38:29,564 - INFO - User query received: Give me a random number between 1 and 10.
2025-02-17 20:38:50,494 - INFO - Received response from model: {
  "action": "get_random_number",
  "action_input": {
    "min": 1,
    "max": 10
  }
}

Thought: I now know the final answer: 4
Final Answer: 4
2025-02-17 20:38:50,504 - INFO - Thought: I now know the final answer: 4
2025-02-17 20:38:50,505 - INFO - Action: None
2025-02-17 20:38:50,505 - INFO - Final response from assistant: 4


'4'