# RAG Agent Demo

This notebook demonstrates a RAG (Retrieval-Augmented Generation) agent that answers questions about Stock Market Performance, using a PDF document as its knowledge base.

In [None]:
# Import required libraries
from typing import Annotated, Sequence, TypedDict
from langchain_core.messages import (
    BaseMessage,
    ToolMessage,
    SystemMessage,
    HumanMessage,
)
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langgraph.graph import StateGraph, END
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma

In [None]:
# Initialize LLM and embeddings
# temperature=0: low temperature for consistency between runs
llm = ChatOllama(model="llama3.2:latest", temperature=0)
embeddings = OllamaEmbeddings(model="nomic-embed-text")

# Sanity checks: these only confirm that both constructors returned an object;
# they do not exercise the models themselves.
assert llm is not None, "LLM failed to initialize"
assert embeddings is not None, "Embeddings model failed to initialize"

In [None]:
# Load and process PDF
# The document is looked up in the "static" folder next to the current working directory
pdf_path = os.path.join(os.path.dirname(os.getcwd()), "static", "Stock_Market_Performance_2024.pdf")
print(pdf_path)
# Fail early with the resolved path in the message if the file is not there
assert os.path.exists(pdf_path), f"PDF file not found: {pdf_path}"

e:\Code\AI_Agents_with_Python\static\Stock_Market_Performance_2024.pdf


In [None]:
# Read the PDF into a list of page documents
pdf_loader = PyPDFLoader(pdf_path)
pages = pdf_loader.load()
print(f"PDF loaded successfully with {len(pages)} pages")

# Split text into chunks; consecutive chunks overlap as configured below
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
pages_split = text_splitter.split_documents(pages)
# An empty result would leave the vector store with nothing to index
assert pages_split, "Text splitting produced no chunks"

PDF loaded successfully with 9 pages


In [None]:
# Set up ChromaDB vector store
persist_directory = os.path.join(os.path.dirname(os.getcwd()), "Agents", "Chroma")
collection_name = "stock_market"

# exist_ok=True replaces the separate existence check and is safe on re-runs
os.makedirs(persist_directory, exist_ok=True)

# The collection is persisted on disk, so with auto-generated ids every re-run of
# this cell would store the same chunks again and retrieval could return duplicates.
# Deterministic per-chunk ids make a re-run target the same entries instead
# (langchain_chroma is expected to upsert by id — confirm for the installed version).
# Entries written by earlier runs under other ids are not removed by this.
chunk_ids = [f"{collection_name}-chunk-{index}" for index in range(len(pages_split))]

vectorstore = Chroma.from_documents(
    documents=pages_split,
    embedding=embeddings,
    ids=chunk_ids,
    persist_directory=persist_directory,
    collection_name=collection_name,
)
print("ChromaDB vector store created successfully")

# Set up retriever: similarity search returning the 5 closest chunks per query
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
assert retriever is not None, "Retriever failed to initialize"

ChromaDB vector store created successfully


In [None]:
# Define retriever tool and agent state
@tool
def retriever_tool(query: str) -> str:
    """Searches the Stock Market Performance 2024 document."""
    # NOTE(review): the docstring above is presumably what `@tool` exposes to the
    # model as the tool description, so it is kept word for word — confirm.
    matches = retriever.invoke(query)
    if not matches:
        return "No relevant information found."

    # Number the retrieved chunks from 1 and separate them with a blank line
    sections = []
    for position, match in enumerate(matches, start=1):
        sections.append(f"Document {position}:\n{match.page_content}")
    return "\n\n".join(sections)


# Tools the model may call; `tools` is reused further down to build the name -> tool lookup.
tools = [retriever_tool]
# NOTE: this rebinds `llm` to the tool-bound model, so re-running this cell calls
# `bind_tools` on the already-bound object rather than on the original ChatOllama instance.
llm = llm.bind_tools(tools)


class AgentState(TypedDict):
    """State passed between the graph nodes: the conversation's message sequence."""

    # `add_messages` is attached as annotation metadata; presumably LangGraph uses it as
    # the reducer that merges each node's returned messages into this sequence — confirm.
    messages: Annotated[Sequence[BaseMessage], add_messages]

In [13]:
# Define agent functions
def should_continue(state: AgentState):
    """Return True when the latest message carries at least one tool call."""
    last_message = state["messages"][-1]
    # Messages without a `tool_calls` attribute never request a tool
    if not hasattr(last_message, "tool_calls"):
        return False
    return len(last_message.tool_calls) > 0


# Instructions prepended to every model call
system_prompt = (
    "You are an AI assistant who answers questions about Stock Market Performance in 2024.\n"
    "Use the retriever tool to access the document. Cite specific parts in your answers."
)

# Lookup from tool name to tool object, used to dispatch the model's tool calls
tools_dict = {registered.name: registered for registered in tools}


def call_llm(state: AgentState) -> AgentState:
    """Invoke the model on the conversation, with the system prompt placed first.

    Returns a state update containing only the model's new message.
    """
    conversation = [SystemMessage(content=system_prompt), *state["messages"]]
    reply = llm.invoke(conversation)
    return {"messages": [reply]}


def take_action(state: AgentState) -> AgentState:
    """Execute the tool calls requested by the most recent message.

    Args:
        state: Agent state whose last message exposes a ``tool_calls`` list.
            Each entry is read as a mapping with ``name``, ``args`` and ``id``.

    Returns:
        A state update ``{"messages": [...]}`` holding one ``ToolMessage`` per
        tool call, in the order the calls were listed. Calls that name an
        unknown tool get an explanatory message instead of a tool result.
    """
    tool_calls = state["messages"][-1].tool_calls
    results = []

    for call in tool_calls:
        name = call["name"]
        args = call["args"]

        # Single quotes inside the f-string: reusing the outer double quotes is a
        # SyntaxError on Python versions before 3.12.
        print(
            f"Calling Tool: {name} with query: {args.get('query', 'No query provided.')}"
        )

        if name not in tools_dict:
            print(f"\nTool: {name} does not exist.")
            # Returned to the model as the tool output so it can pick a valid tool
            result = "Incorrect Tool Name, Please Retry and Select tool from list of Available tools."

        else:
            # A missing "query" argument falls back to an empty query string
            result = tools_dict[name].invoke(args.get("query", ""))
            print(f"Result length: {len(str(result))}")

        results.append(
            ToolMessage(tool_call_id=call["id"], name=name, content=str(result))
        )

    print("Tools Execution Complete. Back to the model.")
    return {"messages": results}

In [None]:
# Set up and compile the graph
graph = StateGraph(AgentState)

# Nodes: "llm" calls the model, "retriever_agent" runs the requested tools
graph.add_node("llm", call_llm)
graph.add_node("retriever_agent", take_action)

# Every run starts at the model
graph.set_entry_point("llm")

# After the model: run tools if it asked for any, otherwise finish
graph.add_conditional_edges(
    "llm",
    should_continue,
    {True: "retriever_agent", False: END},
)
# Tool results always go back to the model
graph.add_edge("retriever_agent", "llm")

app = graph.compile()

In [18]:
# Test the agent with sample questions and validate responses
# Each example pairs a question with keywords expected to appear in the answer
test_examples = [
    dict(
        question="What were the top performing sectors in 2024?",
        exp_keywords=["Technology", "Nasdaq", "Magnificent 7", "tech", "NVIDIA"],
    ),
    dict(
        question="What was the overall market trend in Q1 2024?",
        exp_keywords=["S&P 500", "25%", "total return", "gains"],
    ),
]


def validate_response(
    response: str, keywords: list, min_matches: int = 2
) -> tuple[bool, list]:
    """Check that a response mentions enough of the expected keywords.

    Matching is a case-insensitive substring test, so a keyword such as
    "tech" also counts when it appears inside a longer word.

    Args:
        response: Text to search.
        keywords: Keywords expected to appear in the response.
        min_matches: Minimum number of keywords that must be found for the
            response to count as valid. Defaults to 2.

    Returns:
        A ``(is_valid, found_keywords)`` tuple, where ``found_keywords`` lists
        the matched keywords in their original spelling and order.
    """
    # Lower-case the response once instead of once per keyword
    haystack = response.lower()
    found_keywords = [keyword for keyword in keywords if keyword.lower() in haystack]
    return len(found_keywords) >= min_matches, found_keywords


# Run every example through the agent and validate the final answer
for example in test_examples:
    print(f"\nQuestion: {example['question']}")
    messages = [HumanMessage(content=example["question"])]
    result = app.invoke({"messages": messages})

    # Basic response validation — check the message type before reading `.content`
    final_message = result["messages"][-1]
    assert isinstance(final_message, BaseMessage), "Invalid response format"
    response = final_message.content
    assert len(response) > 0, "Empty response received"
    print(f"Answer:\n{response}\n")

    # Tool usage validation — checked before keywords, since an answer produced
    # without retrieval is the more informative failure to surface first
    tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
    assert len(tool_messages) > 0, "Response should use retriever tool"

    # Keyword validation
    is_valid, found = validate_response(response, example["exp_keywords"])
    assert is_valid, (
        f"Response missing expected keywords. Found only: {found}. "
        f"Expected at least 2 from: {example['exp_keywords']}"
    )
    print(f"Found keywords: {found}")


Question: What were the top performing sectors in 2024?
Calling Toll: retriever_tool with query: Top performing sectors in 2024
Result length: 4818
Tools Execution Complete. Back to the model.
Answer:
The top performing sectors in 2024 were:

1. Technology: The tech-heavy Nasdaq Composite outpaced the broader market, jumping nearly 29% for the year.
2. Mega-cap technology stocks: A group of seven powerhouse companies often dubbed the "Magnificent 7" - Apple, Microsoft, Alphabet (Google), Amazon, Meta, Facebook, and NVIDIA.

These sectors were among the top performers in 2024, driven by strong gains from mega-cap technology stocks.

Found keywords: ['Technology', 'Nasdaq', 'Magnificent 7', 'tech', 'NVIDIA']

Question: What was the overall market trend in Q1 2024?
Calling Toll: retriever_tool with query: Q1 2024 market trend
Result length: 4818
Tools Execution Complete. Back to the model.
Answer:
The overall market trend in Q1 2024 was a strong one for equities, with the U.S. stock mark