# Building a Persistent Memory System for AI Agents with Databricks Lakebase

In this tutorial, we'll build a persistent memory system for AI agents using **Databricks Lakebase Provisioned** - a fully managed PostgreSQL database integrated with the Databricks platform.

## What We'll Build

A conversation memory system that stores:
* User messages and agent responses
* Session tracking
* Timestamps
* Metadata (intents, entities, etc.)

## Why Lakebase?

* **Fully Managed PostgreSQL**: No infrastructure management
* **Databricks Integration**: Seamless authentication and governance
* **ACID Transactions**: Perfect for conversational state management
* **Scalable**: Start small (1 CU) and scale as needed

## Prerequisites

* Databricks workspace (AWS)
* Python environment with databricks-sdk
* Basic familiarity with PostgreSQL

Let's get started!

## Step 0: Install Dependencies

First, we'll install the required packages:
* `databricks-sdk`: For Lakebase API access
* `psycopg2-binary`: PostgreSQL driver for Python
* Additional packages for building AI agents (optional)

In [0]:
%pip install --upgrade databricks-sdk 
%pip install psycopg2-binary
%pip install mlflow langgraph langchain-openai databricks-langchain langgraph-checkpoint-postgres psycopg[binary] pydantic
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
Collecting mlflow
  Downloading mlflow-3.8.1-py3-none-any.whl.metadata (31 kB)
Collecting langgraph
  Downloading langgraph-1.0.6-py3-none-any.whl.metadata (7.4 kB)
Collecting langchain-openai
  Downloading langchain_openai-1.1.7-py3-none-any.whl.metadata (2.6 kB)
Collecting databricks-langchain
  Downloading databricks_langchain-0.13.0-py3-none-any.whl.metadata (2.6 kB)
Collecting langgraph-checkpoint-postgres
  Downloading langgraph_checkpoint_postgres-3.0.3-py3-none-any.whl.metadata (4.8 kB)
Collecting psycopg[binary]
  Downloading psycopg-3.3.2-py3-none-any.whl.metadata (4.3 kB)
Collecting mlflow-skinny==3.8.1 (from mlflow)
  Downloading mlflow_skinny-3.8.1-py3-none-any.whl.metadata (31 kB)
Collecting mlflow-tracing==3.8.1

## Step 1: Create a Lakebase Instance

A **Lakebase instance** is a managed PostgreSQL server. We'll:
1. Initialize the Databricks SDK
2. Create an instance with 1 Compute Unit (CU)
3. Handle the case where the instance already exists

The instance takes a few minutes to provision and will be in `STARTING` or `AVAILABLE` state.

## Step 2: Create a Custom Database

Now we'll:
1. Generate OAuth credentials for the instance
2. Connect using psycopg2
3. Create a custom database called `travel_assistant_db`

**Note**: We use OAuth token authentication with your Databricks identity.

## Step 3: Create the Conversation Memory Table

We'll create a table with:
* `id`: Auto-incrementing primary key
* `session_id`: Track conversation sessions
* `user_message` & `agent_response`: Store the conversation
* `timestamp`: When the interaction occurred
* `metadata`: JSONB field for flexible data (intents, entities, etc.)

## Step 4: Query the Conversation History

Let's verify everything works by querying the data we just inserted.

## Step 5: Create a Reusable Connection Helper

For production use, we'll create a helper function that:
* Generates fresh credentials automatically
* Returns a ready-to-use connection
* Can be called from anywhere in your code

This is useful for AI agent applications that need to persist state across conversations.

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.database import DatabaseInstance
from databricks.sdk.errors import InvalidParameterValue

# Initialize the workspace client
w = WorkspaceClient()

# Configuration
LAKEBASE_INSTANCE_NAME = "travel-agent-memory"

try:
    # Try to create the Lakebase instance
    instance = w.database.create_database_instance(
        DatabaseInstance(
            name=LAKEBASE_INSTANCE_NAME,
            capacity="CU_1"  # Start with 1 Compute Unit
        )
    )
    print(f"✓ Created Lakebase instance: {instance.name}")
except InvalidParameterValue as e:
    if "not unique" in str(e):
        # Instance already exists, retrieve it
        instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
        print(f"✓ Lakebase instance already exists: {instance.name}")
    else:
        raise

# Display instance details
print(f"\nInstance Details:")
print(f"  Name: {instance.name}")
print(f"  State: {instance.state}")
print(f"  Endpoint: {instance.read_write_dns}")

✓ Lakebase instance already exists: travel-agent-memory
Connection endpoint: instance-756cf329-0f27-4995-9dff-2b6bb01d3a29.database.cloud.databricks.com


In [0]:
import psycopg2

# Helper function to get connection details
def get_connection_details():
    """Get Lakebase connection details with fresh credentials."""
    credential = w.database.generate_database_credential(
        instance_names=[LAKEBASE_INSTANCE_NAME]
    )
    instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
    current_user = w.current_user.me()
    
    return {
        "host": instance.read_write_dns,
        "user": current_user.user_name,
        "password": credential.token
    }

conn_details = get_connection_details()
print(f"✓ Generated credentials for: {conn_details['user']}")

# Connect to the default postgres database
conn = psycopg2.connect(
    host=conn_details['host'],
    port=5432,
    database="postgres",
    user=conn_details['user'],
    password=conn_details['password'],
    sslmode="require"
)
conn.autocommit = True
cursor = conn.cursor()

print(f"✓ Connected to Lakebase instance")

# Create a custom database
try:
    cursor.execute("CREATE DATABASE travel_assistant_db")
    print(f"✓ Created database: travel_assistant_db")
except psycopg2.errors.DuplicateDatabase:
    print(f"✓ Database already exists: travel_assistant_db")

cursor.close()
conn.close()

# Connect to the new database and create table
conn = psycopg2.connect(
    host=conn_details['host'],
    port=5432,
    database="travel_assistant_db",
    user=conn_details['user'],
    password=conn_details['password'],
    sslmode="require"
)
cursor = conn.cursor()

# Create table for conversation memory
cursor.execute("""
    CREATE TABLE IF NOT EXISTS conversation_memory (
        id SERIAL PRIMARY KEY,
        session_id VARCHAR(255) NOT NULL,
        user_message TEXT,
        agent_response TEXT,
        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        metadata JSONB
    )
""")
conn.commit()
print(f"✓ Created table: conversation_memory")

# Insert sample conversation
cursor.execute("""
    INSERT INTO conversation_memory (session_id, user_message, agent_response, metadata)
    VALUES (%s, %s, %s, %s)
""", (
    "demo-session-001",
    "I want to book a flight to Paris",
    "I can help you with that. When would you like to travel?",
    '{"intent": "flight_booking", "destination": "Paris"}'
))
conn.commit()
print(f"✓ Inserted sample conversation")

cursor.close()
conn.close()

✓ Generated credentials for: travel-agent-memory
  Hostname: instance-756cf329-0f27-4995-9dff-2b6bb01d3a29.database.cloud.databricks.com
  Username: ankit.yadav@databricks.com

✓ Connected to Lakebase instance
✓ Database already exists: travel_assistant_db


In [0]:
import psycopg2

# Get fresh connection details
conn_details = get_connection_details()

conn = psycopg2.connect(
    host=conn_details['host'],
    port=5432,
    database="travel_assistant_db",
    user=conn_details['user'],
    password=conn_details['password'],
    sslmode="require"
)
cursor = conn.cursor()

# Query conversation history
cursor.execute("SELECT * FROM conversation_memory ORDER BY timestamp DESC LIMIT 10")
rows = cursor.fetchall()

print(f"Recent Conversations ({len(rows)} records):")
print("=" * 80)
for row in rows:
    print(f"Session: {row[1]}")
    print(f"User: {row[2]}")
    print(f"Agent: {row[3]}")
    print(f"Time: {row[4]}")
    print("-" * 80)

cursor.close()
conn.close()

Recent conversations:
--------------------------------------------------------------------------------
Session: sample-session-001
User: I want to book a flight to Paris
Agent: I can help you with that. When would you like to travel?
Time: 2026-01-16 11:01:53.401011
--------------------------------------------------------------------------------
Session: sample-session-001
User: I want to book a flight to Paris
Agent: I can help you with that. When would you like to travel?
Time: 2026-01-16 10:19:40.395614
--------------------------------------------------------------------------------

✓ Successfully queried 2 records directly from Lakebase


In [0]:
# Test the connection helper function we created
conn = get_lakebase_connection()
print("✓ Connection helper function works successfully")

# Quick test query
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM conversation_memory")
count = cursor.fetchone()[0]
print(f"✓ Can access conversation_memory table ({count} rows)")

cursor.close()
conn.close()

print(f"\n✓ Ready to integrate with LangGraph!")

✓ Connection function created successfully

Usage:
  conn = get_lakebase_connection()
  cursor = conn.cursor()
  cursor.execute('SELECT * FROM conversation_memory')
  rows = cursor.fetchall()


## Next Steps

You now have a fully functional persistent memory system for AI agents! Here's what you can do next:

### 1. Integrate with LangGraph
```python
from langgraph.checkpoint.postgres import PostgresSaver

conn = get_lakebase_connection()
checkpointer = PostgresSaver(conn)
# Use with your LangGraph agent
```

### 2. Add More Tables
Create additional tables for:
* User preferences
* Conversation summaries
* Tool call history
* Agent performance metrics

### 3. Scale Your Instance
As your application grows, scale up:
```python
w.database.update_database_instance(
    name=LAKEBASE_INSTANCE_NAME,
    database_instance=DatabaseInstance(
        name=LAKEBASE_INSTANCE_NAME,
        capacity="CU_2"  # Scale to 2 CUs
    ),
    update_mask="capacity"
)
```

### 4. Register in Unity Catalog (Optional)
For data governance and SQL access:
```python
from databricks.sdk.service.database import DatabaseCatalog

w.database.create_database_catalog(
    DatabaseCatalog(
        name="travel_assistant_catalog",
        database_instance_name=LAKEBASE_INSTANCE_NAME,
        database_name="travel_assistant_db"
    )
)
```

## Key Takeaways

* **Lakebase Provisioned** provides managed PostgreSQL with Databricks integration
* **OAuth authentication** uses your Databricks identity - no password management
* **psycopg2** gives you full PostgreSQL capabilities for complex queries
* **Perfect for AI agents** that need persistent, transactional state management

## Resources

* [Lakebase Documentation](https://docs.databricks.com/aws/en/oltp/instances/about/)
* [Databricks SDK for Python](https://databricks-sdk-py.readthedocs.io/)
* [LangGraph Checkpointing](https://langchain-ai.github.io/langgraph/how-tos/persistence/)

Happy building! 🚀

---

# Part 2: Building a Stateful AI Agent with LangGraph

Now that we have our Lakebase database set up, let's build a **stateful conversational agent** using LangGraph. This agent will:

* **Remember conversations** across multiple turns
* **Track structured information** (destination, dates, budget)
* **Persist state** in Lakebase using PostgreSQL checkpoints
* **Support time-travel** - branch from any point in conversation history

## What is LangGraph?

LangGraph is a framework for building stateful, multi-actor applications with LLMs. Key features:

* **State Management**: Automatically persists conversation state
* **Checkpointing**: Save and restore from any point in the conversation
* **Graph-based**: Define agent logic as a directed graph of nodes
* **Framework Agnostic**: Works with any LLM provider

## Architecture Overview

Our agent consists of:
1. **State Schema**: Defines what information to track
2. **Connection Manager**: Handles Lakebase connections with psycopg3
3. **Agent Graph**: The LLM logic with checkpointing enabled
4. **Conversation Runner**: Executes multi-turn conversations

Let's build each component!

## Step 6: Define the Agent State Schema

The state schema defines what information our agent tracks across conversations. We use TypedDict with LangGraph's `add_messages` annotation to automatically accumulate conversation history.

**Key fields:**
* `messages`: Conversation history (auto-accumulated by LangGraph)
* `destination`, `travel_dates`, `budget`, `trip_type`: Extracted travel preferences
* `preferences`: List of user preferences collected during conversation

In [0]:
from typing import Annotated, Any
from typing_extensions import TypedDict
from langgraph.graph.message import add_messages

class TravelPlannerState(TypedDict):
    """State for the Travel Planning Assistant.
    
    This state is checkpointed after each interaction, allowing the agent
    to maintain context across multiple turns in the conversation.
    """
    # Conversation history - LangGraph handles message accumulation
    messages: Annotated[list, add_messages]
    
    # Structured travel preferences extracted from conversation
    destination: str | None
    travel_dates: str | None
    budget: str | None
    trip_type: str | None  # e.g., "relaxation", "adventure", "cultural"
    
    # Collected preferences as a list
    preferences: list[str]

## Step 7: Create a Connection Manager for LangGraph

The `LakebaseConnectionManager` provides a clean interface for LangGraph's `PostgresSaver` to connect to our Lakebase database.

**Why psycopg3?** LangGraph's PostgreSQL checkpointer requires psycopg (version 3), not psycopg2. Key differences:
* `autocommit=True` can be passed as a connection parameter
* Better async support
* Modern API design

**Important**: This uses OAuth tokens that expire after ~1 hour. For production, you'll need to refresh credentials or use native Postgres roles with passwords.

## Step 8: Build the LangGraph Agent

This creates our conversational agent with:

1. **System Prompt**: Guides the agent's behavior and shows current state
2. **Agent Node**: Processes user input and generates responses using Claude
3. **Information Extraction**: Parses user messages to extract travel preferences
4. **Graph Compilation**: Connects everything with checkpointing enabled

The agent automatically:
* Loads previous conversation state from Lakebase
* Updates structured information (destination, dates, etc.)
* Saves checkpoints after each turn

In [0]:
import os
from contextlib import contextmanager
from typing import Generator
import psycopg

class LakebaseConnectionManager:
    """Manages connections to Lakebase for checkpoint storage."""
    
    def __init__(
        self,
        host: str,
        database: str,
        port: int = 5432,
        user: str = None,
        password: str = None
    ):
        self.host = host
        self.database = database
        self.port = port
        self.user = user
        self.password = password
    
    @contextmanager
    def get_connection(self) -> Generator[psycopg.Connection, None, None]:
        """Get a database connection as a context manager."""
        conn = psycopg.connect(
            host=self.host,
            port=self.port,
            dbname=self.database,
            user=self.user,
            password=self.password,
            autocommit=True,
            sslmode="require"
        )
        try:
            yield conn
        finally:
            conn.close()

# Initialize with your Lakebase instance details
# Generate fresh credentials
credential = w.database.generate_database_credential(
    instance_names=[LAKEBASE_INSTANCE_NAME]
)

instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
hostname = instance.read_write_dns
current_user = w.current_user.me()
username = current_user.user_name

lakebase_manager = LakebaseConnectionManager(
    host=hostname,
    database="travel_assistant_db",
    user=username,
    password=credential.token
)

In [0]:
from databricks_langchain import (
    ChatDatabricks,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.postgres import PostgresSaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

# System prompt that guides the agent's behavior
SYSTEM_PROMPT = """You are a helpful Travel Planning Assistant. Your goal is to help users 
plan their perfect trip by gathering information about their preferences and providing 
personalized recommendations.

When interacting with users:
1. If they haven't specified a destination, ask about it
2. Once you know the destination, ask about travel dates if not mentioned
3. Gather information about their budget and trip preferences (adventure, relaxation, cultural, etc.)
4. Provide helpful suggestions based on what you've learned

Always be conversational and remember what the user has already told you. Don't ask for 
information they've already provided.

Current collected information:
- Destination: {destination}
- Dates: {travel_dates}
- Budget: {budget}
- Trip type: {trip_type}
- Preferences: {preferences}
"""

def create_travel_agent_graph(checkpointer: PostgresSaver):
    """Create the LangGraph agent with checkpointing."""
    
    # Initialize the LLM
    LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"
    llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)
    
    def agent_node(state: TravelPlannerState) -> dict:
        """Main agent node that processes user input and generates responses."""
        
        # Format system prompt with current state
        system_content = SYSTEM_PROMPT.format(
            destination=state.get("destination", "Not specified"),
            travel_dates=state.get("travel_dates", "Not specified"),
            budget=state.get("budget", "Not specified"),
            trip_type=state.get("trip_type", "Not specified"),
            preferences=", ".join(state.get("preferences", [])) or "None collected"
        )
        
        # Build messages for the LLM
        messages = [SystemMessage(content=system_content)] + state["messages"]
        
        # Get response from LLM
        response = llm.invoke(messages)
        
        # Extract any travel information from the conversation
        # In a production agent, you might use a separate extraction step
        updated_state = extract_travel_info(state, response.content)
        updated_state["messages"] = [response]
        
        return updated_state
    
    def extract_travel_info(state: TravelPlannerState, response: str) -> dict:
        """Extract structured travel information from the conversation.
        
        In a production system, you might use an LLM call for this extraction.
        For simplicity, we'll do basic pattern matching here.
        """
        updates = {}
        
        # Get the last user message for context
        user_messages = [m for m in state["messages"] if isinstance(m, HumanMessage)]
        if user_messages:
            last_user_msg = user_messages[-1].content.lower()
            
            # Simple extraction logic (production would use LLM)
            if not state.get("destination"):
                destinations = ["rome", "paris", "tokyo", "new york", "london", "barcelona"]
                for dest in destinations:
                    if dest in last_user_msg:
                        updates["destination"] = dest.title()
                        break
            
            if not state.get("budget"):
                if "budget" in last_user_msg or "$" in last_user_msg:
                    # Extract budget mentions
                    if "luxury" in last_user_msg:
                        updates["budget"] = "Luxury ($500+/day)"
                    elif "moderate" in last_user_msg or "mid" in last_user_msg:
                        updates["budget"] = "Moderate ($150-300/day)"
                    elif "budget" in last_user_msg:
                        updates["budget"] = "Budget ($50-150/day)"
        
        return updates
    
    # Build the graph
    graph = StateGraph(TravelPlannerState)
    graph.add_node("agent", agent_node)
    graph.add_edge(START, "agent")
    graph.add_edge("agent", END)
    
    # Compile with checkpointer for state persistence
    return graph.compile(checkpointer=checkpointer)

## Step 9: Run a Multi-Turn Conversation

Let's test our agent with a realistic conversation flow. The agent will:

1. **Remember context** from previous messages
2. **Extract information** (destination, dates, preferences)
3. **Avoid repetition** - won't ask for info already provided
4. **Persist state** in Lakebase after each turn

Each conversation has a unique `thread_id` that groups related messages together. The checkpointer automatically:
* Saves state after each turn
* Loads previous state before processing new messages
* Enables conversation resumption across sessions

## Step 10: Inspect Checkpoint History

**Checkpoints** are snapshots of conversation state at specific points in time. They enable:

* **Debugging**: See exactly what the agent knew at each step
* **Time-travel**: Resume from any previous point
* **Branching**: Explore alternative conversation paths

Each checkpoint contains:
* `checkpoint_id`: Unique identifier
* `timestamp`: When it was created
* `message_count`: Number of messages so far
* `destination`, `travel_dates`: Extracted information

Let's view the checkpoint history for our conversation:

## Step 11: Branch from a Checkpoint (Time-Travel)

One of LangGraph's most powerful features is **checkpoint branching** - the ability to go back to any point in a conversation and explore a different path.

**Use cases:**
* **A/B testing**: Compare different agent responses
* **Error recovery**: Retry from before a mistake
* **What-if scenarios**: Explore alternative conversation flows

In this example, we'll:
1. Find the checkpoint before the user mentioned "Rome"
2. Branch from that point with "Paris" instead
3. See how the conversation diverges

The original "Rome" conversation remains intact in the database!

In [0]:
import uuid

def run_conversation_turn(
    graph,
    user_message: str,
    thread_id: str,
    lakebase_manager: LakebaseConnectionManager
) -> str:
    """Execute a single turn in the conversation."""
    
    config = {"configurable": {"thread_id": thread_id}}
    
    # Input for this turn
    input_state = {
        "messages": [HumanMessage(content=user_message)],
        "destination": None,
        "travel_dates": None,
        "budget": None,
        "trip_type": None,
        "preferences": []
    }
    
    # Run the graph - checkpointer automatically loads previous state
    with lakebase_manager.get_connection() as conn:
        checkpointer = PostgresSaver(conn)
        
        # Ensure tables exist (only needed once)
        checkpointer.setup()
        
        graph = create_travel_agent_graph(checkpointer)
        result = graph.invoke(input_state, config)
    
    # Return the assistant's response
    return result["messages"][-1].content


# Demonstrate a multi-turn conversation
thread_id = str(uuid.uuid4())
print(f"Starting conversation with thread_id: {thread_id}\n")

# Turn 1
response1 = run_conversation_turn(
    graph=None,  # Created inside function
    user_message="Hi! I'm thinking about planning a trip.",
    thread_id=thread_id,
    lakebase_manager=lakebase_manager
)
print(f"User: Hi! I'm thinking about planning a trip.")
print(f"Agent: {response1}\n")

# Turn 2
response2 = run_conversation_turn(
    graph=None,
    user_message="I've always wanted to visit Rome!",
    thread_id=thread_id,
    lakebase_manager=lakebase_manager
)
print(f"User: I've always wanted to visit Rome!")
print(f"Agent: {response2}\n")

# Turn 3
response3 = run_conversation_turn(
    graph=None,
    user_message="I'm thinking early March, maybe the first week. What's the weather like?",
    thread_id=thread_id,
    lakebase_manager=lakebase_manager
)
print(f"User: I'm thinking early March, maybe the first week. What's the weather like?")
print(f"Agent: {response3}\n")

Starting conversation with thread_id: 0deb3870-d2a4-46cf-b62d-5c4f4801d491

User: Hi! I'm thinking about planning a trip.
Agent: Hello! How exciting that you're planning a trip! 🌍 

I'd love to help you plan something amazing. To get started, do you have a destination in mind, or are you still exploring options? If you're open to suggestions, I can help you find the perfect spot based on what you're looking for!

User: I've always wanted to visit Rome!
Agent: Rome is an absolutely fantastic choice! 🇮🇹 The Eternal City has so much to offer - incredible history, stunning architecture, world-class art, and of course, amazing food!

To help me plan the perfect Roman adventure for you, I have a few questions:

**When are you thinking of going?** Do you have specific dates in mind, or are you flexible? (Different seasons offer different experiences - spring and fall are particularly lovely with mild weather and fewer crowds!)

Also, once we nail down the timing, I'd love to know more about:


In [0]:
from typing import List, Dict

def get_checkpoint_history(
    thread_id: str,
    lakebase_manager: LakebaseConnectionManager,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """Retrieve checkpoint history for a thread.
    
    Returns a list of checkpoints with metadata, ordered from most recent to oldest.
    """
    with lakebase_manager.get_connection() as conn:
        checkpointer = PostgresSaver(conn)
        
        # Create a minimal graph just to access state history
        graph = StateGraph(TravelPlannerState)
        graph.add_node("agent", lambda x: x)
        graph.add_edge(START, "agent")
        graph.add_edge("agent", END)
        compiled = graph.compile(checkpointer=checkpointer)
        
        config = {"configurable": {"thread_id": thread_id}}
        
        history = []
        for state in compiled.get_state_history(config):
            if len(history) >= limit:
                break
            
            # Extract useful information from each checkpoint
            messages = state.values.get("messages", [])
            history.append({
                "checkpoint_id": state.config["configurable"]["checkpoint_id"],
                "thread_id": thread_id,
                "timestamp": state.created_at,
                "next_nodes": state.next,
                "message_count": len(messages),
                "last_message": _get_message_preview(messages),
                "destination": state.values.get("destination"),
                "travel_dates": state.values.get("travel_dates")
            })
        
        return history

def _get_message_preview(messages: list, max_length: int = 100) -> str:
    """Get a preview of the last message for checkpoint identification."""
    if not messages:
        return None
    last_msg = messages[-1]
    content = getattr(last_msg, "content", str(last_msg))
    return content[:max_length] + "..." if len(content) > max_length else content


# View the checkpoint history for our conversation
history = get_checkpoint_history(thread_id, lakebase_manager, limit=20)

print("Checkpoint History:")
print("-" * 80)
for i, checkpoint in enumerate(history):
    print(f"\n[{i}] Checkpoint: {checkpoint['checkpoint_id'][:16]}...")
    print(f"    Timestamp: {checkpoint['timestamp']}")
    print(f"    Messages: {checkpoint['message_count']}")
    print(f"    Destination: {checkpoint['destination']}")
    print(f"    Preview: {checkpoint['last_message']}")

Checkpoint History:
--------------------------------------------------------------------------------

[0] Checkpoint: 1f0f2d63-3499-6c...
    Timestamp: 2026-01-16T12:23:44.587043+00:00
    Messages: 6
    Destination: None
    Preview: Early March in Rome is a lovely time to visit! 🌸

**Weather-wise, you can expect:**
- Temperatures a...

[1] Checkpoint: 1f0f2d62-f5ec-6b...
    Timestamp: 2026-01-16T12:23:38.014984+00:00
    Messages: 5
    Destination: None
    Preview: I'm thinking early March, maybe the first week. What's the weather like?

[2] Checkpoint: 1f0f2d62-f5e8-67...
    Timestamp: 2026-01-16T12:23:38.013256+00:00
    Messages: 4
    Destination: Rome
    Preview: Rome is an absolutely fantastic choice! 🇮🇹 The Eternal City has so much to offer - incredible histor...

[3] Checkpoint: 1f0f2d62-f405-67...
    Timestamp: 2026-01-16T12:23:37.815409+00:00
    Messages: 4
    Destination: Rome
    Preview: Rome is an absolutely fantastic choice! 🇮🇹 The Eternal City has so much to

In [0]:
from typing import Dict, Any

def branch_from_checkpoint(
    thread_id: str,
    checkpoint_id: str,
    new_message: str,
    lakebase_manager: LakebaseConnectionManager
) -> Dict[str, Any]:
    """Branch from a specific checkpoint with a different message.
    
    This creates a new fork in the conversation history, preserving the original.
    """
    config = {
        "configurable": {
            "thread_id": thread_id,
            "checkpoint_id": checkpoint_id  # Resume from this specific point
        }
    }
    
    with lakebase_manager.get_connection() as conn:
        checkpointer = PostgresSaver(conn)
        graph = create_travel_agent_graph(checkpointer)
        
        # Run with the new message, branching from the checkpoint
        input_state = {
            "messages": [HumanMessage(content=new_message)],
            "destination": None,
            "travel_dates": None,
            "budget": None,
            "trip_type": None,
            "preferences": []
        }
        
        result = graph.invoke(input_state, config)
        
        # Get the new checkpoint ID
        new_state = graph.get_state(config)
        
        return {
            "response": result["messages"][-1].content,
            "new_checkpoint_id": new_state.config["configurable"]["checkpoint_id"],
            "parent_checkpoint_id": checkpoint_id,
            "destination": result.get("destination")
        }


# Find the checkpoint right after the user's greeting (before they mentioned Rome)
# This is typically the checkpoint at index 1 or 2 in the history
history = get_checkpoint_history(thread_id, lakebase_manager)

# Find checkpoint before destination was mentioned
branch_point = None
for checkpoint in reversed(history):  # Start from oldest
    if checkpoint["destination"] is None:
        branch_point = checkpoint["checkpoint_id"]
        break

if branch_point:
    print(f"Branching from checkpoint: {branch_point[:16]}...")
    print("Original: User said 'Rome'")
    print("Alternative: User says 'Paris'\n")
    
    result = branch_from_checkpoint(
        thread_id=thread_id,
        checkpoint_id=branch_point,
        new_message="I've always wanted to visit Paris!",
        lakebase_manager=lakebase_manager
    )
    
    print(f"User: I've always wanted to visit Paris!")
    print(f"Agent: {result['response']}")
    print(f"\nNew checkpoint created: {result['new_checkpoint_id'][:16]}...")
    print(f"Branched from: {result['parent_checkpoint_id'][:16]}...")

Branching from checkpoint: 1f0f2d62-a71f-67...
Original: User said 'Rome'
Alternative: User says 'Paris'

User: I've always wanted to visit Paris!
Agent: How wonderful! Paris is an absolutely magical city! 🗼 There's so much to see and do - from the iconic Eiffel Tower and the Louvre to charming cafés in Montmartre and strolls along the Seine.

To help me plan the perfect Parisian adventure for you, when are you thinking of visiting? Do you have specific dates in mind, or are you flexible with timing? 

Also, it would be helpful to know:
- How long are you planning to stay?
- Is this your first time in Paris, or have you been before?

This will help me tailor recommendations that make the most of your time in the City of Light! ✨

New checkpoint created: 1f0f2d62-a71f-67...
Branched from: 1f0f2d62-a71f-67...


---

# Part 3: Production Deployment with MLflow

Now let's package our agent for production deployment using **MLflow's ResponsesAgent** framework. This enables:

* **Model Serving**: Deploy as a REST API endpoint
* **Version Control**: Track agent versions in Unity Catalog
* **Resource Management**: Automatic authentication to Databricks resources
* **Standardized Interface**: Compatible with Databricks AI Gateway

## What is ResponsesAgent?

`ResponsesAgent` is MLflow's framework for deploying conversational AI agents. It provides:

* **Standard API**: Request/response format compatible with OpenAI's API
* **Streaming Support**: Real-time token streaming
* **Custom Outputs**: Return metadata like thread_id, checkpoint_id
* **Framework Agnostic**: Works with LangGraph, LangChain, or custom code

## Deployment Architecture

1. **agent.py**: Production-ready agent class with ResponsesAgent interface
2. **MLflow Logging**: Package agent with dependencies and resources
3. **Unity Catalog**: Register model for governance and versioning
4. **Model Serving**: Deploy as a scalable REST endpoint (optional)

Let's create the production agent!

## Step 12: Create the Production Agent File

We'll create `agent.py` - a self-contained file that:

**Key Components:**

1. **TravelPlanningAgent Class**: Extends `ResponsesAgent` with:
   * `predict()`: Handles synchronous requests
   * `predict_stream()`: Handles streaming requests
   * Connection management with Lakebase
   * Checkpoint history retrieval

2. **Environment Variables**: Connection details set at deployment:
   * `LAKEBASE_HOST`: Database endpoint
   * `LAKEBASE_USER`: OAuth username
   * `LAKEBASE_PASSWORD`: OAuth token
   * `LAKEBASE_DATABASE`: Database name

3. **MLflow Integration**:
   * `mlflow.models.set_model(AGENT)`: Registers the agent instance
   * Enables MLflow to discover and validate the model

**Important**: This file uses `psycopg` (v3), not `psycopg2`, for LangGraph compatibility.

## Step 13: Log and Register the Agent in Unity Catalog

This final step:

1. **Sets environment variables**: So agent.py can connect during validation
2. **Defines resources**: Declares the Claude endpoint the agent needs
3. **Logs the model**: Packages agent.py with dependencies
4. **Registers in Unity Catalog**: Creates version 1 of the model

**Key Parameters:**

* `name="agent"`: Artifact name within the MLflow run
* `python_model="agent.py"`: Path to the agent file
* `resources`: Databricks resources for automatic auth passthrough
* `pip_requirements`: Pinned versions for reproducibility
* `registered_model_name`: Three-level Unity Catalog name

**After registration**, you can:
* Deploy to Model Serving for REST API access
* Query from notebooks using `mlflow.pyfunc.load_model()`
* Track versions and lineage in Unity Catalog

Let's log the model!

In [0]:
%%writefile agent.py
# agent.py - Save this as a separate file for MLflow logging

import os
import uuid
import logging
from typing import Any, Dict, Generator, List, Optional
from contextlib import contextmanager
from databricks_langchain import (
    ChatDatabricks,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
import mlflow
import psycopg
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.postgres import PostgresSaver
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
from typing import Annotated
from typing_extensions import TypedDict

logger = logging.getLogger(__name__)

# Configuration - these will be set via environment variables at deployment time
# For local development, you can set these in your notebook before importing
LAKEBASE_HOST = os.environ.get("LAKEBASE_HOST")
LAKEBASE_DATABASE = os.environ.get("LAKEBASE_DATABASE", "travel_assistant_db")
LAKEBASE_USER = os.environ.get("LAKEBASE_USER")
LAKEBASE_PASSWORD = os.environ.get("LAKEBASE_PASSWORD")


class TravelPlannerState(TypedDict):
    """State schema for the Travel Planning Assistant."""
    messages: Annotated[list, add_messages]
    destination: Optional[str]
    travel_dates: Optional[str]
    budget: Optional[str]
    trip_type: Optional[str]
    preferences: List[str]


SYSTEM_PROMPT = """You are a helpful Travel Planning Assistant. Your goal is to help users 
plan their perfect trip by gathering information about their preferences and providing 
personalized recommendations.

When interacting with users:
1. If they haven't specified a destination, ask about it
2. Once you know the destination, ask about travel dates if not mentioned
3. Gather information about their budget and trip preferences
4. Provide helpful suggestions based on what you've learned

Always be conversational and remember what the user has already told you.

Current collected information:
- Destination: {destination}
- Dates: {travel_dates}
- Budget: {budget}
- Trip type: {trip_type}
- Preferences: {preferences}
"""


class TravelPlanningAgent(ResponsesAgent):
    """Production-ready Travel Planning Agent with Lakebase-backed memory."""
    
    def __init__(self):
        """Initialize the agent.
        
        Note: Heavy initialization should be deferred to first predict call
        to avoid issues in distributed serving environment.
        """
        self._llm = None
    
    @property
    def llm(self):
        """Lazy initialization of the LLM client."""
        if self._llm is None:
            LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"
            self._llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)
        return self._llm
    
    @contextmanager
    def get_connection(self):
        """Get a Lakebase connection as a context manager."""
        # Validate that connection parameters are set
        if not all([LAKEBASE_HOST, LAKEBASE_USER, LAKEBASE_PASSWORD]):
            raise ValueError(
                "Lakebase connection parameters not set. "
                "Please set LAKEBASE_HOST, LAKEBASE_USER, and LAKEBASE_PASSWORD environment variables."
            )
        
        conn = psycopg.connect(
            host=LAKEBASE_HOST,
            port=5432,
            dbname=LAKEBASE_DATABASE,
            user=LAKEBASE_USER,
            password=LAKEBASE_PASSWORD,
            sslmode="require",
            autocommit=True
        )
        try:
            yield conn
        finally:
            conn.close()
    
    def _create_graph(self, checkpointer: PostgresSaver):
        """Create the LangGraph agent graph."""
        
        def agent_node(state: TravelPlannerState) -> dict:
            """Process user input and generate response."""
            system_content = SYSTEM_PROMPT.format(
                destination=state.get("destination") or "Not specified",
                travel_dates=state.get("travel_dates") or "Not specified",
                budget=state.get("budget") or "Not specified",
                trip_type=state.get("trip_type") or "Not specified",
                preferences=", ".join(state.get("preferences") or []) or "None"
            )
            
            messages = [SystemMessage(content=system_content)] + state["messages"]
            response = self.llm.invoke(messages)
            
            return {"messages": [response]}
        
        graph = StateGraph(TravelPlannerState)
        graph.add_node("agent", agent_node)
        graph.add_edge(START, "agent")
        graph.add_edge("agent", END)
        
        return graph.compile(checkpointer=checkpointer)
    
    def _convert_to_langchain_messages(
        self, 
        messages: List[Dict[str, Any]]
    ) -> List[BaseMessage]:
        """Convert ResponsesAgent messages to LangChain format."""
        lc_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            
            if role == "user":
                lc_messages.append(HumanMessage(content=content))
            elif role == "assistant":
                lc_messages.append(AIMessage(content=content))
            # Add handling for other roles as needed
        
        return lc_messages
    
    def _convert_to_response_format(
        self,
        messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """Convert LangChain messages to ResponsesAgent output format."""
        output = []
        for msg in messages:
            if isinstance(msg, AIMessage):
                output.append(
                    self.create_text_output_item(
                        text=msg.content,
                        id=str(uuid.uuid4())
                    )
                )
        return output
    
    def get_checkpoint_history(
        self, 
        thread_id: str, 
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Retrieve checkpoint history for debugging and time-travel."""
        config = {"configurable": {"thread_id": thread_id}}
        
        with self.get_connection() as conn:
            checkpointer = PostgresSaver(conn)
            graph = self._create_graph(checkpointer)
            
            history = []
            for state in graph.get_state_history(config):
                if len(history) >= limit:
                    break
                
                messages = state.values.get("messages", [])
                history.append({
                    "checkpoint_id": state.config["configurable"]["checkpoint_id"],
                    "thread_id": thread_id,
                    "timestamp": state.created_at,
                    "next_nodes": state.next,
                    "message_count": len(messages),
                    "last_message": self._get_message_preview(messages)
                })
            
            return history
    
    def _get_message_preview(self, messages: list, max_length: int = 100) -> str:
        """Get a preview of the last message."""
        if not messages:
            return None
        content = getattr(messages[-1], "content", "")
        return content[:max_length] + "..." if len(content) > max_length else content
    
    def predict(
        self, 
        request: ResponsesAgentRequest
    ) -> ResponsesAgentResponse:
        """Handle a prediction request.
        
        This is the main entry point for non-streaming requests.
        """
        # Extract thread_id from custom_inputs, generate new one if not provided
        custom_inputs = dict(request.custom_inputs or {})
        if "thread_id" not in custom_inputs:
            custom_inputs["thread_id"] = str(uuid.uuid4())
        
        thread_id = custom_inputs["thread_id"]
        checkpoint_id = custom_inputs.get("checkpoint_id")  # Optional for branching
        
        # Build checkpoint configuration
        checkpoint_config = {"configurable": {"thread_id": thread_id}}
        if checkpoint_id:
            checkpoint_config["configurable"]["checkpoint_id"] = checkpoint_id
            logger.info(f"Branching from checkpoint: {checkpoint_id}")
        
        # Convert input messages to LangChain format
        lc_messages = self._convert_to_langchain_messages(
            [msg.model_dump() for msg in request.input]
        )
        
        # Prepare input state
        input_state = {
            "messages": lc_messages,
            "destination": None,
            "travel_dates": None,
            "budget": None,
            "trip_type": None,
            "preferences": []
        }
        
        # Execute the graph
        with self.get_connection() as conn:
            checkpointer = PostgresSaver(conn)
            graph = self._create_graph(checkpointer)
            
            result = graph.invoke(input_state, checkpoint_config)
        
        # Convert output to ResponsesAgent format
        output = self._convert_to_response_format(result["messages"])
        
        # Include thread_id and checkpoint info in custom outputs
        custom_outputs = {"thread_id": thread_id}
        
        try:
            history = self.get_checkpoint_history(thread_id, limit=1)
            if history:
                custom_outputs["checkpoint_id"] = history[0]["checkpoint_id"]
        except Exception as e:
            logger.warning(f"Could not retrieve checkpoint_id: {e}")
        
        if checkpoint_id:
            custom_outputs["parent_checkpoint_id"] = checkpoint_id
        
        return ResponsesAgentResponse(
            output=output,
            custom_outputs=custom_outputs
        )
    
    def predict_stream(
        self,
        request: ResponsesAgentRequest
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Handle a streaming prediction request.
        
        For simplicity, this implementation collects the full response
        and streams it. A production implementation would stream tokens.
        """
        # Get the full response
        response = self.predict(request)
        
        # Stream the output items
        for item in response.output:
            yield ResponsesAgentStreamEvent(
                type="response.output_item.done",
                item=item
            )


# Create the agent instance for MLflow
AGENT = TravelPlanningAgent()

# Register the model with MLflow (required for code-based logging)
mlflow.models.set_model(AGENT)

# For local testing
if __name__ == "__main__":
    from mlflow.types.responses import ResponsesAgentRequest
    
    # Test the agent
    test_request = ResponsesAgentRequest(
        input=[{"role": "user", "content": "I want to plan a trip to Rome!"}]
    )
    
    response = AGENT.predict(test_request)
    print(f"Response: {response.output}")
    print(f"Thread ID: {response.custom_outputs.get('thread_id')}")

Overwriting agent.py


In [0]:
import mlflow
import os
from mlflow.models.resources import DatabricksServingEndpoint
from pkg_resources import get_distribution

# Set environment variables for agent.py to use
# Generate fresh credentials
credential = w.database.generate_database_credential(
    instance_names=[LAKEBASE_INSTANCE_NAME]
)
instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
current_user = w.current_user.me()

os.environ["LAKEBASE_HOST"] = instance.read_write_dns
os.environ["LAKEBASE_DATABASE"] = "travel_assistant_db"
os.environ["LAKEBASE_USER"] = current_user.user_name
os.environ["LAKEBASE_PASSWORD"] = credential.token

# Set the experiment
mlflow.set_experiment("/Users/ankit.yadav@databricks.com/travel-planning-agent")

# Define resources the agent needs access to
resources = [
    DatabricksServingEndpoint(endpoint_name="databricks-claude-sonnet-4-5"),
]

# Log the agent
with mlflow.start_run():
    logged_agent = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        resources=resources,
        pip_requirements=[
            f"mlflow=={get_distribution('mlflow').version}",
            f"databricks-langchain=={get_distribution('databricks-langchain').version}",
            f"langgraph=={get_distribution('langgraph').version}",
            f"langgraph-checkpoint-postgres=={get_distribution('langgraph-checkpoint-postgres').version}",
            "psycopg[binary]",
            f"langchain=={get_distribution('langchain').version}",
            f"langchain-core=={get_distribution('langchain-core').version}",
            "pydantic>=2.0.0",
        ],
        registered_model_name="ankit_yadav.default.travel_planning_agent"
    )

print(f"Model logged to: {logged_agent.model_uri}")

🔗 View Logged Model at: https://fevm-ay-demo-workspace.cloud.databricks.com/ml/experiments/3859159789242077/models/m-5fb17f88ed744d149f9a1ff2b091fe5a?o=7474653873260502
2026/01/16 13:52:31 INFO mlflow.pyfunc: Predicting on input example to validate output
Registered model 'ankit_yadav.default.travel_planning_agent' already exists. Creating a new version of this model...


Uploading artifacts:   0%|          | 0/12 [00:00<?, ?it/s]

🔗 Created version '3' of model 'ankit_yadav.default.travel_planning_agent': https://fevm-ay-demo-workspace.cloud.databricks.com/explore/data/models/ankit_yadav/default/travel_planning_agent/version/3?o=7474653873260502


Model logged to: models:/m-5fb17f88ed744d149f9a1ff2b091fe5a


---

# Part 3: Production Deployment

Now that we have our agent working locally, let's deploy it to production. We'll skip the temporary OAuth token approach and go **straight to production-grade credential management** using:

* **Databricks Secrets**: Secure credential storage
* **Native PostgreSQL Roles**: Long-lived passwords (no expiration)
* **MLflow Model Serving**: Scalable REST API deployment

This approach ensures your agent can run indefinitely without token expiration issues.

---

## Production-Ready Credential Management

Before deploying to production, let's set up **long-lived credentials** that won't expire:

### What We'll Do:

1. **Create a Databricks Secret Scope** to store sensitive credentials
2. **Create a Native PostgreSQL Role** with password authentication (no expiration)
3. **Grant Permissions** to the role for the travel_assistant_db
4. **Store Credentials** in Databricks Secrets
5. **Update Model Registration** to reference secrets
6. **Deploy Endpoint** with secret references

### Why This Approach?

* ✅ **No token expiration** - uses password authentication
* ✅ **Secure storage** - credentials encrypted in Databricks Secrets
* ✅ **Easy rotation** - update secret without redeploying
* ✅ **Audit trail** - track who accesses secrets
* ✅ **Production-grade** - follows security best practices

Let's set it up!

In [0]:
# Create a secret scope for Lakebase credentials
SECRET_SCOPE = "lakebase-prod"

try:
    w.secrets.create_scope(
        scope=SECRET_SCOPE
    )
    print(f"✓ Created secret scope: {SECRET_SCOPE}")
except Exception as e:
    if "already exists" in str(e).lower():
        print(f"✓ Secret scope already exists: {SECRET_SCOPE}")
    else:
        raise

print(f"\nSecret scope '{SECRET_SCOPE}' is ready for storing credentials")

✓ Secret scope already exists: lakebase-prod

Secret scope 'lakebase-prod' is ready for storing credentials


In [0]:
from databricks.sdk.service.database import DatabaseInstance, DatabaseInstanceState
import time

print(f"Enabling native PostgreSQL login on instance: {LAKEBASE_INSTANCE_NAME}\n")

# Wait for instance to be AVAILABLE before updating
print("Checking instance state...")
max_wait = 600  # 10 minutes
start_time = time.time()

while time.time() - start_time < max_wait:
    instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
    state = instance.state
    
    print(f"  Current state: {state}")
    
    if state == DatabaseInstanceState.AVAILABLE:
        print(f"\n✓ Instance is AVAILABLE\n")
        break
    elif state == DatabaseInstanceState.FAILED:
        print(f"\n✗ Instance provisioning failed")
        raise Exception(f"Instance {LAKEBASE_INSTANCE_NAME} is in FAILED state")
    
    time.sleep(10)  # Check every 10 seconds
else:
    print(f"\n⚠ Timeout waiting for instance to be AVAILABLE")
    raise Exception(f"Instance still in {state} state after {max_wait} seconds")

# Update the instance to enable native password authentication
w.database.update_database_instance(
    name=LAKEBASE_INSTANCE_NAME,
    database_instance=DatabaseInstance(
        name=LAKEBASE_INSTANCE_NAME,
        enable_pg_native_login=True
    ),
    update_mask="enable_pg_native_login"
)

print(f"✓ Native PostgreSQL login enabled")
print(f"\nThis allows the instance to accept password-based authentication")
print(f"in addition to OAuth tokens.")

# Verify the setting
time.sleep(2)  # Wait for update to propagate

instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
print(f"\nVerification:")
print(f"  enable_pg_native_login: {instance.effective_enable_pg_native_login}")

if instance.effective_enable_pg_native_login:
    print(f"\n✓ Native login is now enabled!")
else:
    print(f"\n⚠ Setting may take a moment to propagate. Check again in 30 seconds.")

Enabling native PostgreSQL login on instance: travel-agent-memory

Checking instance state...
  Current state: DatabaseInstanceState.AVAILABLE

✓ Instance is AVAILABLE

✓ Native PostgreSQL login enabled

This allows the instance to accept password-based authentication
in addition to OAuth tokens.

Verification:
  enable_pg_native_login: True

✓ Native login is now enabled!


In [0]:
import psycopg2
import secrets
import string

# Generate a strong password for the service account
def generate_password(length=32):
    """Generate a cryptographically strong password."""
    alphabet = string.ascii_letters + string.digits + "!@#$%^&*"
    return ''.join(secrets.choice(alphabet) for _ in range(length))

AGENT_USER = "travel_agent_service"
AGENT_PASSWORD = generate_password()

print(f"Creating PostgreSQL role: {AGENT_USER}\n")

# Connect as admin using OAuth
credential = w.database.generate_database_credential(
    instance_names=[LAKEBASE_INSTANCE_NAME]
)
instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
current_user = w.current_user.me()

conn = psycopg2.connect(
    host=instance.read_write_dns,
    port=5432,
    database="postgres",
    user=current_user.user_name,
    password=credential.token,
    sslmode="require"
)
conn.autocommit = True
cursor = conn.cursor()

# Create the service role
try:
    cursor.execute(f"""
        CREATE ROLE {AGENT_USER} WITH LOGIN PASSWORD %s
    """, (AGENT_PASSWORD,))
    print(f"✓ Created role: {AGENT_USER}")
except psycopg2.errors.DuplicateObject:
    print(f"✓ Role already exists: {AGENT_USER}")
    # Update password for existing role
    cursor.execute(f"""
        ALTER ROLE {AGENT_USER} WITH PASSWORD %s
    """, (AGENT_PASSWORD,))
    print(f"✓ Updated password for: {AGENT_USER}")

cursor.close()
conn.close()

print(f"\n✓ PostgreSQL role created successfully")
print(f"\nNext: Store credentials in Databricks Secrets")

Creating PostgreSQL role: travel_agent_service

✓ Role already exists: travel_agent_service
✓ Updated password for: travel_agent_service

✓ PostgreSQL role created successfully

Next: Store credentials in Databricks Secrets


In [0]:
import psycopg2

print(f"Granting permissions to {AGENT_USER}...\n")

# Connect to postgres database as admin
credential = w.database.generate_database_credential(
    instance_names=[LAKEBASE_INSTANCE_NAME]
)
instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
current_user = w.current_user.me()

conn = psycopg2.connect(
    host=instance.read_write_dns,
    port=5432,
    database="postgres",
    user=current_user.user_name,
    password=credential.token,
    sslmode="require"
)
conn.autocommit = True
cursor = conn.cursor()

# Grant database connection permission
cursor.execute(f"""
    GRANT CONNECT ON DATABASE travel_assistant_db TO {AGENT_USER}
""")
print(f"✓ Granted CONNECT on travel_assistant_db")

cursor.close()
conn.close()

# Connect to travel_assistant_db to grant table permissions
conn = psycopg2.connect(
    host=instance.read_write_dns,
    port=5432,
    database="travel_assistant_db",
    user=current_user.user_name,
    password=credential.token,
    sslmode="require"
)
conn.autocommit = True
cursor = conn.cursor()

# Grant schema usage
cursor.execute(f"""
    GRANT USAGE ON SCHEMA public TO {AGENT_USER}
""")
print(f"✓ Granted USAGE on schema public")

# Grant table permissions (for LangGraph checkpoint tables)
cursor.execute(f"""
    GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO {AGENT_USER}
""")
print(f"✓ Granted table permissions")

# Grant sequence permissions (for auto-increment IDs)
cursor.execute(f"""
    GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO {AGENT_USER}
""")
print(f"✓ Granted sequence permissions")

# Grant default privileges for future tables
cursor.execute(f"""
    ALTER DEFAULT PRIVILEGES IN SCHEMA public 
    GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO {AGENT_USER}
""")
print(f"✓ Granted default privileges for future tables")

cursor.close()
conn.close()

print(f"\n✓ All permissions granted to {AGENT_USER}")

Granting permissions to travel_agent_service...

✓ Granted CONNECT on travel_assistant_db
✓ Granted USAGE on schema public
✓ Granted table permissions
✓ Granted sequence permissions
✓ Granted default privileges for future tables

✓ All permissions granted to travel_agent_service


In [0]:
# Store the credentials in Databricks Secrets
SECRET_SCOPE = "lakebase-prod"

print(f"Storing credentials in secret scope: {SECRET_SCOPE}\n")

# Store each credential
w.secrets.put_secret(
    scope=SECRET_SCOPE,
    key="lakebase-host",
    string_value=instance.read_write_dns
)
print(f"✓ Stored: lakebase-host")

w.secrets.put_secret(
    scope=SECRET_SCOPE,
    key="lakebase-database",
    string_value="travel_assistant_db"
)
print(f"✓ Stored: lakebase-database")

w.secrets.put_secret(
    scope=SECRET_SCOPE,
    key="lakebase-user",
    string_value=AGENT_USER
)
print(f"✓ Stored: lakebase-user")

w.secrets.put_secret(
    scope=SECRET_SCOPE,
    key="lakebase-password",
    string_value=AGENT_PASSWORD
)
print(f"✓ Stored: lakebase-password")

print(f"\n✓ All credentials stored securely in Databricks Secrets")
print(f"\nTo reference in endpoint configuration, use:")
print(f"  {{{{secrets/{SECRET_SCOPE}/lakebase-host}}}}")
print(f"  {{{{secrets/{SECRET_SCOPE}/lakebase-database}}}}")
print(f"  {{{{secrets/{SECRET_SCOPE}/lakebase-user}}}}")
print(f"  {{{{secrets/{SECRET_SCOPE}/lakebase-password}}}}")

Storing credentials in secret scope: lakebase-prod

✓ Stored: lakebase-host
✓ Stored: lakebase-database
✓ Stored: lakebase-user
✓ Stored: lakebase-password

✓ All credentials stored securely in Databricks Secrets

To reference in endpoint configuration, use:
  {{secrets/lakebase-prod/lakebase-host}}
  {{secrets/lakebase-prod/lakebase-database}}
  {{secrets/lakebase-prod/lakebase-user}}
  {{secrets/lakebase-prod/lakebase-password}}


In [0]:
import psycopg2

print(f"Testing connection with service account: {AGENT_USER}\n")

# Test connection using the service account credentials
try:
    test_conn = psycopg2.connect(
        host=instance.read_write_dns,
        port=5432,
        database="travel_assistant_db",
        user=AGENT_USER,
        password=AGENT_PASSWORD,
        sslmode="require"
    )
    
    cursor = test_conn.cursor()
    cursor.execute("SELECT current_user, current_database()")
    result = cursor.fetchone()
    
    print(f"✓ Successfully connected as: {result[0]}")
    print(f"✓ Connected to database: {result[1]}")
    
    # Test table access
    cursor.execute("SELECT COUNT(*) FROM conversation_memory")
    count = cursor.fetchone()[0]
    print(f"✓ Can access conversation_memory table ({count} rows)")
    
    cursor.close()
    test_conn.close()
    
    print(f"\n✓ Service account is working correctly!")
    print(f"\nReady to deploy with production credentials")
    
except Exception as e:
    print(f"✗ Connection failed: {e}")
    print(f"\nTroubleshooting:")
    print(f"  1. Verify native login is enabled (run cell 37)")
    print(f"  2. Wait 30 seconds after enabling native login")
    print(f"  3. Check that role was created successfully")
    print(f"  4. Verify permissions were granted")

Testing connection with service account: travel_agent_service

✓ Successfully connected as: travel_agent_service
✓ Connected to database: travel_assistant_db
✓ Can access conversation_memory table (2 rows)

✓ Service account is working correctly!

Ready to deploy with production credentials


## Step F: Deploy Endpoint with Databricks Secrets

Now we'll create/update the serving endpoint to use the secrets we just created. The endpoint will:

* Reference secrets using `{{secrets/scope/key}}` syntax
* Automatically decrypt secrets at runtime
* Never expose credentials in logs or UI
* Work indefinitely without token expiration

**Note:** If the endpoint already exists, this will update it with the new configuration.

In [0]:
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    ServedEntityInput,
    AiGatewayConfig,
    AiGatewayInferenceTableConfig
)
from databricks.sdk.errors import ResourceAlreadyExists

endpoint_name = "travel-planning-agent-prod"
SECRET_SCOPE = "lakebase-prod"

print(f"Creating production endpoint: {endpoint_name}...")
print("This may take 5-10 minutes...\n")

try:
    endpoint = w.serving_endpoints.create_and_wait(
        name=endpoint_name,
        config=EndpointCoreConfigInput(
            name=endpoint_name,
            served_entities=[
                ServedEntityInput(
                    entity_name="ankit_yadav.default.travel_planning_agent",
                    entity_version="2",
                    scale_to_zero_enabled=True,
                    workload_size="Small",
                    environment_vars={
                        "LAKEBASE_HOST": f"{{{{secrets/{SECRET_SCOPE}/lakebase-host}}}}",
                        "LAKEBASE_DATABASE": f"{{{{secrets/{SECRET_SCOPE}/lakebase-database}}}}",
                        "LAKEBASE_USER": f"{{{{secrets/{SECRET_SCOPE}/lakebase-user}}}}",
                        "LAKEBASE_PASSWORD": f"{{{{secrets/{SECRET_SCOPE}/lakebase-password}}}}"
                    }
                )
            ]
        ),
        ai_gateway=AiGatewayConfig(
            inference_table_config=AiGatewayInferenceTableConfig(
                catalog_name="ankit_yadav",
                schema_name="default",
                table_name_prefix="travel_agent_prod",
                enabled=True
            )
        )
    )
    print(f"✓ Production endpoint created: {endpoint_name}")
    print(f"  State: {endpoint.state.config_update}")
except ResourceAlreadyExists:
    print(f"✓ Endpoint already exists, updating configuration...")
    
    # Update existing endpoint
    w.serving_endpoints.update_config(
        name=endpoint_name,
        served_entities=[
            ServedEntityInput(
                entity_name="ankit_yadav.default.travel_planning_agent",
                entity_version="2",
                scale_to_zero_enabled=True,
                workload_size="Small",
                environment_vars={
                    "LAKEBASE_HOST": f"{{{{secrets/{SECRET_SCOPE}/lakebase-host}}}}",
                    "LAKEBASE_DATABASE": f"{{{{secrets/{SECRET_SCOPE}/lakebase-database}}}}",
                    "LAKEBASE_USER": f"{{{{secrets/{SECRET_SCOPE}/lakebase-user}}}}",
                    "LAKEBASE_PASSWORD": f"{{{{secrets/{SECRET_SCOPE}/lakebase-password}}}}"
                }
            )
        ]
    )
    print(f"✓ Endpoint configuration updated")

print(f"\n✓ Production endpoint ready with secure credentials!")
print(f"\nBenefits:")
print(f"  - No token expiration")
print(f"  - Credentials stored securely")
print(f"  - Easy to rotate (update secret, restart endpoint)")
print(f"  - Audit trail for secret access")

Creating production endpoint: travel-planning-agent-prod...
This may take 5-10 minutes...

✓ Production endpoint created: travel-planning-agent-prod
  State: EndpointStateConfigUpdate.NOT_UPDATING

✓ Production endpoint ready with secure credentials!

Benefits:
  - No token expiration
  - Credentials stored securely
  - Easy to rotate (update secret, restart endpoint)
  - Audit trail for secret access


In [0]:
import os
import requests
import json
from databricks.sdk.service.serving import EndpointStateReady

endpoint_name = "travel-planning-agent-prod"

# Get endpoint status
endpoint = w.serving_endpoints.get(name=endpoint_name)
print(f"Endpoint Status: {endpoint.state.ready}\n")

if endpoint.state.ready != EndpointStateReady.READY:
    print("⚠ Endpoint is not ready yet. Wait a few minutes and try again.")
else:
    print("Testing production endpoint with persistent memory...\n")
    
    # Get authentication token from notebook context
    token = os.environ.get("DATABRICKS_TOKEN")
    if not token:
        # Get token from notebook context
        try:
            token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
            print("✓ Using notebook authentication token\n")
        except:
            print("✗ Could not get authentication token")
            print("Please set DATABRICKS_TOKEN environment variable or create a PAT\n")
            token = None
    
    if token:
        # Build request URL and headers
        url = f'https://fevm-ay-demo-workspace.cloud.databricks.com/serving-endpoints/{endpoint_name}/invocations'
        headers = {
            'Authorization': f'Bearer {token}',
            'Content-Type': 'application/json'
        }
        
        # Make request with correct schema (use 'input' not 'inputs')
        request_data = {
            "input": [{"role": "user", "content": "Hi! I want to plan a trip to Japan."}]
        }
        
        response = requests.post(url, headers=headers, data=json.dumps(request_data))
        
        if response.status_code != 200:
            print(f"Error: Request failed with status {response.status_code}")
            print(f"Response: {response.text}")
            print(f"\nTroubleshooting:")
            print(f"  1. Check endpoint logs in the Serving UI")
            print(f"  2. Verify secrets are configured correctly")
            print(f"  3. Ensure native login is enabled on Lakebase instance")
        else:
            response_data = response.json()
            
            # Extract data from response
            thread_id = response_data.get("custom_outputs", {}).get("thread_id")
            checkpoint_id = response_data.get("custom_outputs", {}).get("checkpoint_id")
            agent_message = response_data['output'][0]['content'][0]['text']
            
            print(f"✓ Started conversation")
            print(f"  Thread ID: {thread_id}")
            print(f"  Checkpoint ID: {checkpoint_id}\n")
            print(f"Agent: {agent_message}")
            
            print(f"\n✓ Production endpoint is working!")
            print(f"\nKey features:")
            print(f"  - Credentials stored in Databricks Secrets")
            print(f"  - No token expiration issues")
            print(f"  - Persistent memory in Lakebase (thread_id: {thread_id})")
            print(f"  - Checkpoint system enabled (checkpoint_id: {checkpoint_id[:16]}...)")
            print(f"  - Ready for production traffic")

Endpoint Status: EndpointStateReady.READY

Testing production endpoint with persistent memory...

✓ Using notebook authentication token

✓ Started conversation
  Thread ID: 5fce1f80-bc2d-4ad7-8b79-eea7a1b49624
  Checkpoint ID: 1f0f3090-de81-6df8-8001-a84238814e79

Agent: Hello! How exciting that you're planning a trip to Japan! 🇯🇵 It's such an amazing destination with incredible culture, food, and experiences.

To help me create the perfect itinerary for you, I'd love to know a bit more:

**When are you thinking of visiting Japan?** The timing can really shape your experience:
- Spring (March-May) is famous for cherry blossoms
- Summer (June-August) has festivals but can be hot and humid
- Fall (September-November) offers beautiful autumn colors
- Winter (December-February) is great for skiing and winter illuminations

Do you have specific dates in mind, or are you flexible with timing?

✓ Production endpoint is working!

Key features:
  - Credentials stored in Databricks Secrets
  - N

---

## 🎉 Congratulations!

You've built a **production-ready AI agent with persistent memory** using Databricks Lakebase and LangGraph!

## What We Built

### Part 1: Lakebase Database Setup
✅ **Managed PostgreSQL Instance**: Zero-ops database with 1 CU  
✅ **Custom Database**: `travel_assistant_db` for agent memory  
✅ **Conversation Table**: Schema for storing chat history  
✅ **Connection Helper**: Reusable function for database access  

### Part 2: Stateful Agent with LangGraph
✅ **State Schema**: Tracks messages, destination, dates, budget, preferences  
✅ **Connection Manager**: psycopg3 integration for checkpointing  
✅ **Agent Graph**: Claude-powered conversational agent  
✅ **Multi-Turn Conversations**: Context-aware dialogue  
✅ **Checkpoint History**: View conversation snapshots  
✅ **Time-Travel Branching**: Explore alternative conversation paths  

### Part 3: Production Deployment
✅ **Native PostgreSQL Role**: `travel_agent_service` with no expiration  
✅ **Databricks Secrets**: Secure credential storage in `lakebase-prod` scope  
✅ **MLflow Registration**: Version 2 in Unity Catalog  
✅ **Production Endpoint**: `travel-planning-agent-prod` with secret references  
✅ **Inference Tables**: Automatic logging for monitoring  

## Key Technical Learnings

* **psycopg3 vs psycopg2**: LangGraph requires psycopg (v3) for PostgresSaver
* **Native Login**: Must enable `enable_pg_native_login=True` on Lakebase instance
* **Secret Syntax**: Use `{{secrets/scope/key}}` in environment variables
* **ResponsesAgent API**: Use `input` (singular), not `inputs` (plural)
* **Response Format**: Access `response['output']` (singular), not `outputs`
* **Authentication**: Use `dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()` for notebook tokens

## Production Architecture

```
[👤 User] 
    ↓ HTTPS Request
[🔒 Model Serving Endpoint]
    ↓ Loads credentials from Databricks Secrets
[🤖 TravelPlanningAgent (ResponsesAgent)]
    ↓ Connects with native Postgres role
[💾 Lakebase PostgreSQL]
    ↓ Stores checkpoints
[LangGraph PostgresSaver]
```

## Deployment Workflow Summary

1. **Create Lakebase instance** with native login enabled
2. **Create native Postgres role** with strong password
3. **Grant permissions** to the role
4. **Store credentials** in Databricks Secrets
5. **Register model** in Unity Catalog with MLflow
6. **Deploy endpoint** referencing secrets
7. **Query endpoint** with persistent memory

## Next Steps for Production

### 1. Add Multi-Turn Conversation Support

The current test uses a single message. For full conversations:

```python
# Continue conversation with thread_id
request_data = {
    "input": [{"role": "user", "content": "I'm thinking March."}],
    "custom_inputs": {"thread_id": "<thread_id_from_previous_response>"}
}
```

### 2. Implement Credential Rotation

```python
# Update secret with new password
w.secrets.put_secret(
    scope="lakebase-prod",
    key="lakebase-password",
    string_value="<new_password>"
)

# Restart endpoint to pick up new credentials
w.serving_endpoints.update_config(name=endpoint_name, ...)
```

### 3. Monitor with Inference Tables

Query the auto-generated inference table:

```sql
SELECT * FROM ankit_yadav.default.travel_agent_prod_payload
ORDER BY request_time DESC
LIMIT 100
```

### 4. Scale the Lakebase Instance

```python
w.database.update_database_instance(
    name=LAKEBASE_INSTANCE_NAME,
    database_instance=DatabaseInstance(
        name=LAKEBASE_INSTANCE_NAME,
        capacity="CU_2"  # Scale to 2 CUs
    ),
    update_mask="capacity"
)
```

## Resources

* **Lakebase Docs**: [docs.databricks.com/oltp](https://docs.databricks.com/aws/en/oltp/instances/about/)
* **LangGraph Docs**: [langchain-ai.github.io/langgraph](https://langchain-ai.github.io/langgraph/)
* **MLflow ResponsesAgent**: [mlflow.org/docs/latest/genai/serving/responses-agent](https://mlflow.org/docs/latest/genai/serving/responses-agent/)
* **Databricks Secrets**: [docs.databricks.com/security/secrets](https://docs.databricks.com/security/secrets/)
* **Model Serving**: [docs.databricks.com/machine-learning/model-serving](https://docs.databricks.com/machine-learning/model-serving/)

## Questions?

Reach out on LinkedIn or check the Databricks Community forums!

Happy building! 🚀