# Build a RAG agent with LangChain

Derived from the [LangChain docs](https://docs.langchain.com/oss/python/langchain/rag).

### Objective

You will start a **vLLM engine** to serve a local LLM and implement a RAG (Retrieval-Augmented Generation) agent that queries a blog post. 

The documents will be vectorized using sentence embeddings and stored in **Milvus**, a vector database, so the agent can retrieve relevant context for answering user queries.

### Instructions to Start the vLLM Inference Engine

NOTE: Run this notebook on a GPU node (A40 or above). Before starting, `cd` into your project folder, because the steps below write log files (`vllm.out`, `vllm.err`) and a local Milvus database file (`milvus_example.db`) to the working directory.

1. **SSH into the compute node** that is running this Jupyter server.
2. **Start the vLLM inference engine** using the following command:
    ```bash
    apptainer exec /mimer/NOBACKUP/groups/llm-workshop/containers/rag/rag.sif vllm serve /mimer/NOBACKUP/Datasets/LLM/huggingface/hub/models--HuggingFaceTB--SmolLM3-3B/snapshots/a07cc9a04f16550a088caea529712d1d335b0ac1 --port=$(find_ports) --gpu-memory-utilization 0.6 --enable_auto_tool_choice --tool_call_parser=hermes > vllm.out 2> vllm.err &
    ```
3. **Monitor the logs**:
    - `vllm.out` for standard output.
    - `vllm.err` for error logs.
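
Loading the model can take a while, so it helps to confirm that the server is answering before running the cells below. The following is a minimal sketch, assuming the server exposes the OpenAI-compatible `/v1/models` endpoint on `localhost`. The helper names `parse_model_ids` and `wait_for_vllm` and the timeout values are illustrative and are not part of vLLM or LangChain.

```python
import json
import time
import urllib.request


def parse_model_ids(payload: dict) -> list[str]:
    """Extract model IDs from an OpenAI-style /v1/models response body."""
    return [entry["id"] for entry in payload["data"]]


def wait_for_vllm(port: int, timeout_s: float = 600.0, poll_s: float = 5.0) -> list[str]:
    """Poll /v1/models until the server answers, then return the model IDs.

    Raises TimeoutError if no answer arrives within timeout_s seconds.
    """
    url = f"http://localhost:{port}/v1/models"
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                return parse_model_ids(json.load(resp))
        except OSError as exc:
            # URLError, refused connections, and socket timeouts are all OSError.
            if time.monotonic() >= deadline:
                raise TimeoutError(f"No response from {url} after {timeout_s} s") from exc
            time.sleep(poll_s)
```

Calling `wait_for_vllm(<your port>)` from a notebook cell blocks until the engine is up. With the command above, the reported ID should be the model path that was passed to `vllm serve`.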

## Setup

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings

# Initialize local sentence-transformers embeddings model
embeddings = HuggingFaceEmbeddings(model_name="/mimer/NOBACKUP/Datasets/LLM/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/c9745ed1d9f207416be6d2e6f8de32d1f16199bf")

In [None]:
from langchain_milvus import Milvus

URI = "./milvus_example.db"

# Initialize Milvus vector store with local embeddings
vector_store = Milvus(
    embedding_function=embeddings,
    connection_args={"uri": URI},
    index_params={"index_type": "FLAT", "metric_type": "L2"},
)
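
The `FLAT` index type with the `L2` metric means that every stored vector is compared against the query vector and the entries with the smallest Euclidean distance are returned. The toy sketch below illustrates that idea in plain Python. The three-dimensional vectors, the chunk labels, and the `flat_search` helper are made up for illustration; real embeddings from this model have several hundred dimensions, and Milvus performs the search internally.

```python
import math


def l2_distance(a: list[float], b: list[float]) -> float:
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def flat_search(query_vec, stored, k=2):
    """Exact brute-force search: rank every stored vector by distance to the query."""
    ranked = sorted(stored, key=lambda item: l2_distance(query_vec, item[1]))
    return [text for text, _ in ranked[:k]]


# Hypothetical (text, embedding) pairs standing in for indexed chunks.
store = [
    ("chunk about task decomposition", [0.9, 0.1, 0.0]),
    ("chunk about memory", [0.1, 0.8, 0.1]),
    ("chunk about tool use", [0.0, 0.2, 0.9]),
]

query_vec = [0.8, 0.2, 0.1]  # hypothetical embedding of the user query
top = flat_search(query_vec, store, k=2)
print(top)
```

A smaller distance means a closer match, so the first entry in `top` is the chunk whose vector lies nearest to the query.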

In [None]:
from langchain.chat_models import init_chat_model

VLLM_PORT = 46219  # Replace with the port your vLLM server is listening on
LLM_MODEL = "/mimer/NOBACKUP/Datasets/LLM/huggingface/hub/models--HuggingFaceTB--SmolLM3-3B/snapshots/a07cc9a04f16550a088caea529712d1d335b0ac1"

# Initialize local vLLM chat model
model = init_chat_model(
    model=LLM_MODEL,
    model_provider="openai",
    base_url=f"http://localhost:{VLLM_PORT}/v1",
    api_key="none",
)

## Indexing

### Loading documents

In [None]:
import bs4
from langchain_community.document_loaders import WebBaseLoader

# Only keep post title, headers, and content from the full HTML.
bs4_strainer = bs4.SoupStrainer(class_=("post-title", "post-header", "post-content"))

# Load a blog post from Lilian Weng's blog
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs={"parse_only": bs4_strainer},
)
docs = loader.load()

assert len(docs) == 1
print(f"Total characters: {len(docs[0].page_content)}")

In [None]:
# Print the first 500 characters of the loaded document
print(docs[0].page_content[:500])

### Splitting documents

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Split document into smaller chunks for indexing
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,  # chunk size (characters)
    chunk_overlap=200,  # chunk overlap (characters)
    add_start_index=True,  # track index in original document
)
all_splits = text_splitter.split_documents(docs)

print(f"Split blog post into {len(all_splits)} sub-documents.")
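
To make the roles of `chunk_size`, `chunk_overlap`, and the tracked start index concrete, here is a simplified sketch of overlapping chunking. It uses a fixed-width sliding window, which is an assumption made for clarity: `RecursiveCharacterTextSplitter` instead tries to break on natural separators such as paragraphs and lines, and treats `chunk_size` as an upper bound. The function name `split_with_overlap` is hypothetical.

```python
def split_with_overlap(text: str, chunk_size: int, chunk_overlap: int) -> list[dict]:
    """Cut text into fixed-width windows that share chunk_overlap characters."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append({"start_index": start, "text": text[start:start + chunk_size]})
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


demo = split_with_overlap("abcdefghijklmnopqrstuvwxyz", chunk_size=10, chunk_overlap=4)
for chunk in demo:
    print(chunk["start_index"], chunk["text"])
```

Each window repeats the last four characters of the previous one, so a sentence that straddles a boundary still appears whole in at least one chunk.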

### Storing documents

In [None]:
# Add document chunks to Milvus vector store
document_ids = vector_store.add_documents(documents=all_splits)

print(document_ids[:3])

## Retrieval and Generation

Retrieval from a vector database can itself be exposed to the agent as a tool. This section wraps the similarity search in a tool, mainly to demonstrate tool usage.
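
Conceptually, tool usage means the model emits a structured request naming a tool and its arguments, and the surrounding runtime executes it and returns the result. The sketch below shows that dispatch step in isolation. It assumes the Hermes-style convention in which the model wraps a JSON object in `<tool_call>` tags; the exact format is model-specific, and in this notebook the parser selected by `--tool_call_parser=hermes` and the LangChain agent handle this step. `fake_retrieve_context`, `TOOLS`, and `run_tool_calls` are hypothetical stand-ins.

```python
import json
import re

# Matches one <tool_call>...</tool_call> block and captures its JSON payload.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def fake_retrieve_context(query: str) -> str:
    """Stand-in for a real retrieval tool; returns a placeholder string."""
    return f"(retrieved chunks for: {query})"


# Registry mapping tool names to callables.
TOOLS = {"retrieve_context": fake_retrieve_context}


def run_tool_calls(model_output: str) -> list[str]:
    """Find every tool call in the model output and execute it."""
    results = []
    for match in TOOL_CALL_RE.finditer(model_output):
        call = json.loads(match.group(1))
        tool = TOOLS[call["name"]]
        results.append(tool(**call["arguments"]))
    return results


model_output = (
    "<tool_call>\n"
    '{"name": "retrieve_context", "arguments": {"query": "task decomposition"}}\n'
    "</tool_call>"
)
print(run_tool_calls(model_output))
```

In a full agent loop, each result would be appended to the conversation as a tool message and the model would be called again to produce the final answer.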

### Tool usage:

In [None]:
from langchain.tools import tool

# Define a retrieval tool that fetches relevant documents from the vector store
@tool(response_format="content_and_artifact")
def retrieve_context(query: str):
    """Retrieve information to help answer a query."""
    retrieved_docs = vector_store.similarity_search(query, k=2)
    serialized = "\n\n".join(
        (f"Source: {doc.metadata}\nContent: {doc.page_content}")
        for doc in retrieved_docs
    )
    return serialized, retrieved_docs

In [None]:
from langchain.agents import create_agent

# Create an agent with the retrieval tool
tools = [retrieve_context]
# NOTE: the /no_think prefix disables the model's reasoning (thinking) mode so that it goes straight to tool calls
prompt = (
    "/no_think You have access to a tool that retrieves context from a blog post. "
    "Use the tool to help answer user queries."
)
agent = create_agent(model, tools, system_prompt=prompt)

In [None]:
# Example query to the agent
# NOTE: We construct a multi-part query to demonstrate retrieval and follow-up
query = (
    "What is the standard method for Task Decomposition?\n\n"
    "Once you get the answer, look up common extensions of that method."
)

for event in agent.stream(
    {"messages": [
        {"role": "system", "content": "/no_think"},
        {"role": "user", "content": query}
    ]},
    stream_mode="values",
):
    event["messages"][-1].pretty_print()

### No tool usage:

Instead of using a tool, we can retrieve context for the original query ourselves and send it directly to the LLM. In this case the agent has no tools; middleware injects the retrieved context into the system prompt.

In [None]:
from langchain.agents.middleware import dynamic_prompt, ModelRequest

# Define a dynamic prompt middleware to inject context
@dynamic_prompt
def prompt_with_context(request: ModelRequest) -> str:
    """Inject context into state messages."""
    last_query = request.state["messages"][-1].text
    retrieved_docs = vector_store.similarity_search(last_query)

    docs_content = "\n\n".join(doc.page_content for doc in retrieved_docs)

    system_message = (
        "/no_think You are a helpful assistant. Use the following context in your response:"
        f"\n\n{docs_content}"
    )
    print(system_message)
    return system_message


agent = create_agent(model, tools=[], middleware=[prompt_with_context])

In [None]:
query = "What is task decomposition?"
for step in agent.stream(
    {"messages": [{"role": "user", "content": query}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()