# Simplified Research Agent

This notebook implements a simplified version of the research agent. It relies on LangChain/LangGraph's built-in tool-calling logic to select and run tools, while still tracking which tools were used and collecting the evidence they return.

In [9]:
from pprint import pprint
import json
from IPython.display import Image, display
from dotenv import load_dotenv
import importlib
import os
import sys
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage, SystemMessage
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing import Annotated, TypedDict, List, Dict, Any

# All paths in this notebook are resolved relative to the current working directory.
notebook_dir = os.getcwd()

# The project root is taken to be two directory levels above the notebook
# directory; it is put first on sys.path so the `core` package can be imported.
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Must come after the sys.path adjustment above.
import core.tools.registry

# Load environment variables from the .env file in the parent of the notebook
# directory, overriding any values already set. load_dotenv's return value is
# left as the cell's last expression so it is shown as the cell output.
dotenv_path = os.path.join(os.path.dirname(notebook_dir), '.env')
load_dotenv(dotenv_path, override=True)

False

## Configuration and Tool Setup

In [10]:
# Chat model name and sampling temperature handed to ChatOllama by the assistant node.
MODEL = "mistral-nemo"
TEMPERATURE = 0

# Maps an importable module path to the names of the functions in that module
# that should be exposed to the agent as tools.
TOOL_REGISTRY = {
    "core.tools.builtins.calculator": [
        "multiply",
        "add",
        "divide",
    ],
    "core.tools.builtins.wikipedia": [
        "query",
    ],
}

def import_function(module_name, function_name):
    """Look up ``function_name`` in the module ``module_name`` at runtime.

    Args:
        module_name: Dotted path of the module to import.
        function_name: Name of the attribute to fetch from that module.

    Returns:
        The requested attribute, or ``None`` when the module cannot be imported
        or does not define the attribute (a diagnostic is printed in that case).
    """
    try:
        return getattr(importlib.import_module(module_name), function_name)
    except (ImportError, AttributeError) as exc:
        # Best effort: report the failure and let the caller decide what to do
        # with the missing entry.
        print(f"Error: Could not import function '{function_name}' from module '{module_name}'.")
        print(f"Exception: {exc}")
    return None

# Resolve every (module, function) pair declared in TOOL_REGISTRY.
imported_tools = [
    import_function(module, function)
    for module, functions in TOOL_REGISTRY.items()
    for function in functions
]

# import_function returns None (after printing a diagnostic) when a lookup
# fails. Drop those entries so a single broken tool does not raise
# AttributeError on `tool.__name__` below or put None into the tool list.
all_tools = [tool for tool in imported_tools if tool is not None]
print(f"Tools: {[tool.__name__ for tool in all_tools]}")

Tools: ['multiply', 'add', 'divide', 'query']


## State Definition and System Prompt

In [11]:
class State(TypedDict):
    """State schema shared by the nodes of the research-agent graph."""
    # Conversation history; annotated with add_messages (LangGraph's message
    # reducer) so updates to this key are merged rather than overwritten.
    messages: Annotated[list[BaseMessage], add_messages]
    # The claim being researched; preprocessing wraps it in a HumanMessage.
    claim: str
    # Names of the tools the assistant has requested, each listed once.
    used_tools: List[str]
    # One entry per tool call that has a matching tool result:
    # {'name': tool name, 'args': {kwargs}, 'result': str}
    evidence: list[dict]

# Load the agent's system prompt from the prompts/ folder under the notebook
# directory. The encoding is passed explicitly so the prompt decodes the same
# way on every platform instead of depending on the locale default.
# NOTE(review): assumes the prompt file is stored as UTF-8 (or plain ASCII) — confirm.
prompt_path = os.path.join(notebook_dir, 'prompts/research_agent_system_prompt.txt')
with open(prompt_path, 'r', encoding='utf-8') as f:
    sys_msg = SystemMessage(content=f.read())

## Node Functions

In [12]:
def preprocessing(state: State):
    """
    Reset the state for a new claim before the assistant runs.

    The message history becomes the system prompt followed by the claim as a
    human message, and the tool/evidence bookkeeping lists are emptied. The
    incoming state dict is updated in place and returned.
    """
    state.update(
        messages=[sys_msg, HumanMessage(content=state['claim'])],
        used_tools=[],
        evidence=[],
    )
    return state

def assistant(state: State) -> State:
    """
    Ask the tool-enabled chat model for the next step and record requested tools.

    The model is bound to ``all_tools`` and invoked on the accumulated message
    history, so tool selection is left to LangChain/LangGraph's built-in tool
    calling. Every tool name found in the response's ``tool_calls`` is added
    (once) to the running ``used_tools`` list.

    Args:
        state: Current graph state. ``messages`` is read; ``used_tools`` is
            read when present.

    Returns:
        A partial state update: the model response under ``messages`` and the
        updated list of tool names under ``used_tools``.
    """
    llm_with_tools = ChatOllama(
        model=MODEL,
        temperature=TEMPERATURE,
    ).bind_tools(all_tools)

    response = llm_with_tools.invoke(state['messages'])

    # Build a fresh list and return it as an explicit update. Appending to
    # state['used_tools'] in place only reaches the graph state as a side
    # effect of object sharing; it is not part of the returned update.
    used_tools = list(state.get('used_tools') or [])
    for tool_call in getattr(response, 'tool_calls', None) or []:
        tool_name = tool_call['name']
        if tool_name not in used_tools:
            used_tools.append(tool_name)

    return {"messages": response, "used_tools": used_tools}

def postprocessing(state: State) -> State:
    """
    Build the evidence list from the message history.

    For each tool call on an AI message, the first later ToolMessage whose
    ``tool_call_id`` equals the call's id supplies the result. Calls with no
    matching ToolMessage are skipped. The list is stored under ``evidence`` on
    the incoming state, which is then returned.
    """
    history = state['messages']
    collected = []

    for position, msg in enumerate(history):
        if not (isinstance(msg, AIMessage) and hasattr(msg, 'tool_calls')):
            continue

        for call in msg.tool_calls:
            # Only messages after the requesting AI message can answer its calls.
            reply = next(
                (
                    later
                    for later in history[position + 1:]
                    if isinstance(later, ToolMessage) and later.tool_call_id == call['id']
                ),
                None,
            )
            if reply is None:
                continue
            collected.append({
                'name': call['name'],
                'args': call['args'],
                'result': reply.content
            })

    state['evidence'] = collected
    return state

## Graph Construction

In [13]:
builder = StateGraph(State)

# Register the nodes in pipeline order.
for node_name, node_action in (
    ("preprocessing", preprocessing),
    ("assistant", assistant),
    ("tools", ToolNode(all_tools)),
    ("postprocessing", postprocessing),
):
    builder.add_node(node_name, node_action)

# Linear entry: START -> preprocessing -> assistant.
builder.add_edge(START, "preprocessing")
builder.add_edge("preprocessing", "assistant")

# tools_condition picks the branch taken after the assistant; its "__end__"
# outcome is redirected to postprocessing instead of ending the run.
assistant_routes = {
    "tools": "tools",
    "__end__": "postprocessing",
}
builder.add_conditional_edges("assistant", tools_condition, assistant_routes)

# Tool results go back to the assistant; postprocessing finishes the run.
builder.add_edge("tools", "assistant")
builder.add_edge("postprocessing", END)

agent = builder.compile()

# Rendering the diagram is best effort: any failure is reported, not raised.
try:
    graph_png = agent.get_graph().draw_mermaid_png()
    display(Image(graph_png))
except Exception as e:
    print(f"Could not visualize graph: {e}")

Could not visualize graph: HTTPSConnectionPool(host='mermaid.ink', port=443): Read timed out. (read timeout=10)


## Testing with Example Claims

In [14]:
# Factual claim: run the agent and report the tools it used and the evidence gathered.
factual_claim = "Albert Einstein developed the theory of relativity"
factual_result = agent.invoke({"claim": factual_claim})

print("\nFactual Claim Test:")
print(f"Claim: {factual_claim}")

print("Tools Used:")
for tool_name in factual_result.get('used_tools', []):
    print(f"  {tool_name}")

print("Evidence:")
for evidence in factual_result.get('evidence', []):
    print(f"  Tool: {evidence['name']}")
    print(f"  Args: {evidence['args']}")
    # Results longer than 100 characters are cut and marked with an ellipsis.
    if len(evidence['result']) > 100:
        result_line = f"  Result: {evidence['result'][:100]}..."
    else:
        result_line = f"  Result: {evidence['result']}"
    print(result_line)
    print()


Factual Claim Test:
Claim: Albert Einstein developed the theory of relativity
Tools Used:
  query
Evidence:
  Tool: query
  Args: {'query': 'Albert Einstein theory of relativity'}
  Result: No Wikipedia page found for 'Albert Einstein theory of relativity'.



In [15]:
# Arithmetic claim: run the agent and report the tools it used and the evidence gathered.
math_claim = "12 multiplied by 10 equals 120"
math_result = agent.invoke({"claim": math_claim})

print(
    "\nMathematical Claim Test:",
    f"Claim: {math_claim}",
    "Tools Used:",
    sep="\n",
)
for tool_name in math_result.get('used_tools', []):
    print(f"  {tool_name}")

print("Evidence:")
for evidence in math_result.get('evidence', []):
    # The trailing empty string yields the blank line that separates entries.
    print(
        f"  Tool: {evidence['name']}",
        f"  Args: {evidence['args']}",
        f"  Result: {evidence['result']}",
        "",
        sep="\n",
    )


Mathematical Claim Test:
Claim: 12 multiplied by 10 equals 120
Tools Used:
  multiply
Evidence:
  Tool: multiply
  Args: {'a': 12, 'b': 10}
  Result: 120



In [16]:
# Mixed claim (a fact plus arithmetic): run the agent and report tools and evidence.
mixed_claim = "If you multiply the distance from Earth to the Sun (93 million miles) by 2, you get 186 million miles"
mixed_result = agent.invoke({"claim": mixed_claim})

for header_line in ("\nMixed Claim Test:", f"Claim: {mixed_claim}", "Tools Used:"):
    print(header_line)
for tool_name in mixed_result.get('used_tools', []):
    print(f"  {tool_name}")

print("Evidence:")
for evidence in mixed_result.get('evidence', []):
    print(f"  Tool: {evidence['name']}")
    print(f"  Args: {evidence['args']}")
    # Show at most 100 characters of the result; add "..." only when it was cut.
    result_text = evidence['result']
    suffix = "..." if len(result_text) > 100 else ""
    print(f"  Result: {result_text[:100]}{suffix}")
    print()


Mixed Claim Test:
Claim: If you multiply the distance from Earth to the Sun (93 million miles) by 2, you get 186 million miles
Tools Used:
  query
  multiply
Evidence:
  Tool: query
  Args: {'query': 'distance from Earth to Sun'}
  Result: No Wikipedia page found for 'distance from Earth to Sun'.

  Tool: multiply
  Args: {'a': 93000000, 'b': 2}
  Result: 186000000

