# LangGraph RAG Agent with Databricks Vector Search

This notebook demonstrates:
1. Creating a LangGraph agent with retrieval capabilities
2. Testing the agent with sample queries
3. Registering the agent with `mlflow.models.set_model()` so the notebook can be logged as an MLflow model via Models from Code


## 1. Setup and Imports


In [None]:
%pip install langgraph langchain databricks_agents databricks_langchain databricks-vectorsearch "mlflow>=3.6" pandas matplotlib
%restart_python


In [None]:
import mlflow
import uuid
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from databricks_langchain import ChatDatabricks
from databricks.vector_search.client import VectorSearchClient

print(f"MLflow version: {mlflow.__version__}")
# Compare (major, minor) numerically. A plain string comparison is lexicographic,
# so "3.10.0" >= "3.6.0" evaluates to False and would reject newer releases.
assert tuple(int(part) for part in mlflow.__version__.split(".")[:2]) >= (3, 6), \
    "MLflow 3.6+ required for session/user tracking"


## 2. Configuration


In [None]:
# --- Vector Search (values come from the pdf_to_vector_search notebook) ---
# The three name parts are joined later as "<catalog>.<schema>.<index>".
CATALOG_NAME: str = "brian_ml_dev"
SCHEMA_NAME: str = "eval_testing"
VECTOR_INDEX_NAME: str = "annual_report_index"
VECTOR_SEARCH_ENDPOINT: str = "one-env-shared-endpoint-13"

# --- LLM serving endpoint ---
LLM_ENDPOINT: str = "databricks-gpt-oss-120b"

# --- Agent behaviour ---
# TOP_K_RESULTS is the number of hits requested per similarity search.
TOP_K_RESULTS: int = 3
# NOTE(review): MAX_ITERATIONS is not referenced by the code visible in this notebook.
MAX_ITERATIONS: int = 5

# --- MLflow ---
# NOTE(review): neither name below is used by the code visible in this notebook;
# presumably they are consumed by the logging/evaluation step - confirm.
EXPERIMENT_NAME: str = "/Users/brian.law@databricks.com/langgraph_rag_agent"
MODEL_NAME: str = "langgraph_rag_agent"


## 3. Initialize Components


In [None]:
# Initialize Vector Search client
vsc = VectorSearchClient()
full_index_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.{VECTOR_INDEX_NAME}"

# Get vector search index
vector_index = vsc.get_index(
    endpoint_name=VECTOR_SEARCH_ENDPOINT,
    index_name=full_index_name
)

print(f"✓ Vector Search Index loaded: {full_index_name}")


In [None]:
# Initialize LLM
llm = ChatDatabricks(
    endpoint=LLM_ENDPOINT,
    temperature=0.1,
    max_tokens=500
)

print(f"✓ LLM initialized: {LLM_ENDPOINT}")


## 4. Define Agent State and Graph Nodes


In [None]:
class AgentState(TypedDict):
    """State of the agent graph.

    One instance of this mapping flows through the graph; each node reads the
    keys it needs and returns updates for the keys it changes.
    """
    # Message log. The `operator.add` metadata is the reducer LangGraph applies
    # to this key: the list a node returns is concatenated onto the current
    # value, so a node should return only the messages it adds.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # The user question; read by both the retrieve and generate nodes.
    question: str
    # Retrieved documents formatted as one string; written by retrieve_documents
    # and read by generate_answer.
    context: str
    # Final answer text; written by generate_answer.
    answer: str
    # Counter incremented by generate_answer on each call.
    iterations: int


In [None]:
def retrieve_documents(state: AgentState) -> AgentState:
    """Retrieve relevant documents from vector search.

    Runs a similarity search for ``state["question"]`` against the module-level
    ``vector_index`` and formats the hits into a single context string.

    Args:
        state: Current graph state; only ``question`` is read.

    Returns:
        A partial state update with two keys:

        * ``context`` - the retrieved documents joined as
          ``"Document <n>:\\n<text>"`` sections, or ``""`` when there are no hits.
        * ``messages`` - a one-element list with a status message.

        Only the new message is returned. The ``messages`` key is declared with
        an ``operator.add`` reducer, so LangGraph concatenates the returned list
        onto the existing history; returning the whole history would duplicate
        every earlier message. Keys that are not returned keep their value.
    """
    question = state["question"]

    # Perform similarity search
    results = vector_index.similarity_search(
        query_text=question,
        columns=["text", "page"],
        num_results=TOP_K_RESULTS,
    )

    # The result may be a plain dict or an object. In both cases treat a missing
    # or null ``data_array`` as "no hits" rather than failing on iteration.
    if isinstance(results, dict):
        data_array = (results.get("result") or {}).get("data_array") or []
    else:
        data_array = getattr(results, "data_array", None) or []

    # Rows are expected to list values in the requested column order, so the
    # first element is the "text" column; non-sequence rows are used as-is.
    context_parts = []
    for i, row in enumerate(data_array, 1):
        text = row[0] if isinstance(row, (list, tuple)) else row
        context_parts.append(f"Document {i}:\n{text}")

    return {
        "context": "\n\n".join(context_parts),
        "messages": [SystemMessage(content=f"Retrieved {len(data_array)} relevant documents")],
    }


In [None]:
def _message_text(content) -> str:
    """Flatten chat-message content into plain text.

    LangChain message content is either a string or a list of blocks (strings
    or dicts). String blocks and dict blocks of type ``"text"`` are
    concatenated. If a list holds no text block at all, its ``str()`` form is
    returned so that nothing is silently dropped.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(str(block.get("text", "")))
        if parts:
            return "".join(parts)
    return str(content)


def generate_answer(state: AgentState) -> AgentState:
    """Generate answer using LLM with retrieved context.

    Args:
        state: Current graph state; ``question`` and ``context`` are read, and
            ``iterations`` is read if present.

    Returns:
        A partial state update with three keys:

        * ``answer`` - the model reply as plain text.
        * ``messages`` - a one-element list holding the reply as an ``AIMessage``.
        * ``iterations`` - the previous count (0 if absent) plus one.

        Only the new message is returned. The ``messages`` key is declared with
        an ``operator.add`` reducer, so LangGraph concatenates the returned list
        onto the existing history; returning the whole history would duplicate
        every earlier message.
    """
    question = state["question"]
    context = state["context"]

    # Create RAG prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant that answers questions based on the provided context. "
                   "If the answer cannot be found in the context, say so."),
        ("human", "Context:\n{context}\n\nQuestion: {question}\n\nAnswer:")
    ])

    # Generate response
    chain = prompt | llm
    response = chain.invoke({"context": context, "question": question})

    # ``content`` can be a list of blocks rather than a string; normalise it so
    # ``answer`` matches its declared ``str`` type.
    answer = _message_text(getattr(response, "content", response))

    return {
        "answer": answer,
        "messages": [AIMessage(content=answer)],
        "iterations": state.get("iterations", 0) + 1,
    }


## 5. Build LangGraph Agent


In [None]:
# Create the graph
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("retrieve", retrieve_documents)
workflow.add_node("generate", generate_answer)

# Set entry point
workflow.set_entry_point("retrieve")

# Add edges
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile the graph
agent = workflow.compile()

print("✓ LangGraph agent created successfully")


## 6. Test the Agent


In [None]:
# Session and user tracking for chat/conversation functionality (MLflow 3.6+)
# One session ID is generated per run of this cell, so all test questions are
# grouped into the same session.
session_id = str(uuid.uuid4())
user_id = "test_user_001"  # In production, use actual user identifier

# Test questions
test_questions = [
    "What are the key financial highlights from the annual report?",
    "What are the main risks mentioned in the report?",
    "What is the company's growth strategy?"
]

print("Testing the agent with session/user tracking...\n" + "="*80)
print(f"Session ID: {session_id}")
print(f"User ID: {user_id}")
print("="*80)

# Enable MLflow autologging for LangChain
mlflow.langchain.autolog()

for i, question in enumerate(test_questions, 1):
    print(f"\nTest {i}: {question}")
    print("-" * 80)

    # Open the trace explicitly and tag it while it is still active.
    # update_current_trace() only affects a trace that is active at call time,
    # and the trace autologging creates for agent.invoke() has already ended
    # once invoke() returns, so calling it afterwards attaches nothing.
    with mlflow.start_span(name="rag_agent_test") as span:
        # Add session and user metadata to the current trace (MLflow 3.6+).
        # Metadata values are stored as strings, so the index is converted.
        mlflow.update_current_trace(
            metadata={
                "mlflow.trace.session": session_id,
                "mlflow.trace.user": user_id,
                "question_index": str(i),
            }
        )
        span.set_inputs({"question": question})

        # Run the agent; the autologged spans nest under the span opened above
        result = agent.invoke({
            "messages": [HumanMessage(content=question)],
            "question": question,
            "context": "",
            "answer": "",
            "iterations": 0
        })

        span.set_outputs({"answer": result["answer"]})

    print(f"\nAnswer: {result['answer']}")
    print("\n" + "="*80)


## 7. Register the Agent for MLflow Models from Code


In [None]:
# Declare the compiled graph as this notebook's model object. With MLflow's
# models-from-code logging, the object passed to set_model() is the one loaded
# when this file is logged as a model.
mlflow.models.set_model(agent)

## 8. Summary


In [None]:
# Emit the whole recap in one write. Entries that start with "\n" produce the
# blank separator line in front of each section.
print(
    "\n".join(
        [
            "=" * 80,
            "SUMMARY",
            "=" * 80,
            "\n✓ LangGraph RAG agent created",
            "✓ Agent tested with sample queries",
            "✓ Session/user tracking enabled (MLflow 3.6+)",
            "✓ Model set with mlflow.models.set_model()",
            "\nAgent Configuration:",
            f"  - Vector Index: {full_index_name}",
            f"  - LLM Endpoint: {LLM_ENDPOINT}",
            f"  - Top K Results: {TOP_K_RESULTS}",
            "\nSession Tracking:",
            f"  - Session ID: {session_id}",
            f"  - User ID: {user_id}",
            "\nNext step: Use this notebook code within an MLflow run to log the model",
            "          Then use model_evaluation.ipynb to evaluate it",
            "=" * 80,
        ]
    )
)
