# Mosaic AI Agent Framework: Author and deploy a stateful agent with thread-scoped memory using Databricks Lakebase and LangGraph
This notebook demonstrates how to build a stateful agent using the Mosaic AI Agent Framework and LangGraph, with Lakebase as the agent’s durable memory and checkpoint store. Threads let you store conversational state in Lakebase, so you can pass a thread ID to your agent instead of sending the full conversation history.
In this notebook, you will:
1. Author a stateful agent graph with LangGraph that uses Lakebase (the managed Postgres database in Databricks) to manage state by thread ID
2. Wrap the LangGraph agent with `ResponsesAgent` interface to ensure compatibility with Databricks features
3. Test the agent's behavior locally
4. Log the agent with MLflow, register the model to Unity Catalog, and deploy the agent for use in apps and AI Playground

We use [PostgresSaver in LangGraph](https://api.python.langchain.com/en/latest/checkpoint/langchain_postgres.checkpoint.PostgresSaver.html) to checkpoint agent state over a connection to the Lakebase Postgres database.

## Why use Lakebase?
Stateful agents need a place to persist, resume, and inspect their work. Lakebase provides a managed, UC-governed store for agent state:
- Durable, resumable state. Automatically capture threads, intermediate checkpoints, tool outputs, and node state after each graph step so you can resume, branch, or replay any point in time.
- Queryable & observable. Because state lands in the Lakehouse, you can use SQL (or notebooks) to audit conversations and build on other Databricks functionality, such as dashboards.
- Governed by Unity Catalog. Apply data permissions, lineage, and auditing to AI state, just like any other table.

## What are Stateful Agents?
Unlike stateless LLM calls, a stateful agent keeps and reuses context across steps and sessions. Each new conversation is tracked with a thread ID, which represents the logical task or dialogue stream. Pick up an existing thread at any time to continue the conversation without having to pass in the entire conversation history.
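
To make the idea of thread-scoped memory concrete, here is a minimal sketch that keeps each conversation under its own thread ID. It is only an illustration: an in-memory SQLite database stands in for Lakebase Postgres, and the `messages` table and the `append_message`, `load_history`, and `chat` helpers are hypothetical names, not part of LangGraph or Databricks.

```python
import sqlite3

# In-memory SQLite stands in for Lakebase Postgres (assumption for illustration only).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE messages ("
    "thread_id TEXT, seq INTEGER PRIMARY KEY AUTOINCREMENT, role TEXT, content TEXT)"
)


def append_message(thread_id: str, role: str, content: str) -> None:
    """Persist one message under its thread ID."""
    conn.execute(
        "INSERT INTO messages (thread_id, role, content) VALUES (?, ?, ?)",
        (thread_id, role, content),
    )
    conn.commit()


def load_history(thread_id: str) -> list[tuple[str, str]]:
    """Return the stored (role, content) pairs for a thread, oldest first."""
    rows = conn.execute(
        "SELECT role, content FROM messages WHERE thread_id = ? ORDER BY seq",
        (thread_id,),
    )
    return rows.fetchall()


def chat(thread_id: str, user_text: str) -> list[tuple[str, str]]:
    """Append the new user turn and return the full context a model would see."""
    append_message(thread_id, "user", user_text)
    return load_history(thread_id)


# The caller sends only the new message; earlier turns come from the store.
chat("thread-a", "I am working on stateful agents")
context_a = chat("thread-a", "What am I working on?")
context_b = chat("thread-b", "What am I working on?")
```

Here `context_a` contains both turns of `thread-a`, while `context_b` contains only the single message sent on the new thread. Because the state lives in a table, it can also be inspected with ordinary SQL.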

## Prerequisites
- Create a Lakebase instance, see Databricks documentation ([AWS](https://docs.databricks.com/aws/en/oltp/create/) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/oltp/create/)). 
- You can create a Lakebase instance by going to SQL Warehouses -> Lakebase Postgres -> Create database instance. Retrieve the values from the "Connection details" section of your Lakebase instance to fill out this notebook.
- Complete all the "TODO"s throughout this notebook.

### Install dependencies

In [0]:
%pip install -U -qqqq databricks-langchain langgraph==0.5.3 uv databricks-agents mlflow-skinny[databricks] \
  langgraph-checkpoint-postgres==2.0.21 psycopg[binary,pool]
dbutils.library.restartPython()

## First time setup only: Set up checkpointer with your Lakebase instance

In [0]:
import os
import uuid
from databricks.sdk import WorkspaceClient
from psycopg_pool import ConnectionPool
import psycopg
from langgraph.checkpoint.postgres import PostgresSaver

# TODO: Fill in your Lakebase instance details here. For the username, create a
# service principal and grant it the databricks_superuser role on the Lakebase instance.
# See the service principal documentation for more information:
# https://docs.databricks.com/en/admin/users-groups/service-principals
# Use the service principal's client ID and secret as SP_CLIENT_ID/SP_CLIENT_SECRET.
# These values are used to initialize the checkpointer.
DB_INSTANCE_NAME = "lakebase-erico-silva"  
DB_NAME          = "databricks_postgres"
SP_CLIENT_ID      = os.getenv("DATABRICKS_CLIENT_ID")
SP_CLIENT_SECRET      = os.getenv("DATABRICKS_CLIENT_SECRET")
SSL_MODE         = "require"
DB_HOST = "instance-9c1487ee-b26a-45c9-8ef1-55d6e9dd531b.database.cloud.databricks.com"
DB_PORT = 5432
WORKSPACE_HOST = "https://e2-demo-field-eng.cloud.databricks.com"

w = WorkspaceClient(
  host = WORKSPACE_HOST,
  client_id = SP_CLIENT_ID,
  client_secret = SP_CLIENT_SECRET
)

def db_password_provider() -> str:
    """
    Ask Databricks to mint a fresh DB credential for this instance.
    """
    cred = w.database.generate_database_credential(
        request_id=str(uuid.uuid4()),
        instance_names=[DB_INSTANCE_NAME],
    )
    return cred.token

class CustomConnection(psycopg.Connection):
    """
    A psycopg Connection subclass that injects a fresh password
    *at connection time* (only when the pool creates a new connection).
    """
    @classmethod
    def connect(cls, conninfo="", **kwargs):
        # Append the new password to kwargs
        kwargs["password"] = db_password_provider()
        # Call the superclass's connect method with updated kwargs
        return super().connect(conninfo, **kwargs)

pool = ConnectionPool(
    conninfo=f"dbname={DB_NAME} user={SP_CLIENT_ID} host={DB_HOST} port={DB_PORT} sslmode={SSL_MODE}",
    connection_class=CustomConnection,
    min_size=1,
    max_size=10,
    open=True,
)

# Use the pool to initialize your checkpoint tables
with pool.connection() as conn:
    conn.autocommit = True   # disable transaction wrapping
    checkpointer = PostgresSaver(conn)
    checkpointer.setup()
    conn.autocommit = False  # restore default if you want transactions later

    with conn.cursor() as cur:
        cur.execute("select 1")
    print("✅ Pool connected and checkpoint tables are ready.")

# Define the agent in code

## Write agent code to file agent.py
Define the agent code in a single cell below. This lets you write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.

## Wrap the LangGraph agent using the ResponsesAgent interface
For compatibility with Databricks AI features, the `LangGraphResponsesAgent` class implements the `ResponsesAgent` interface to wrap the LangGraph agent.

Databricks recommends using `ResponsesAgent` as it simplifies authoring multi-turn conversational agents using an open source standard. See MLflow's [ResponsesAgent documentation](https://www.mlflow.org/docs/latest/llms/responses-agent-intro/).
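
The agent in the next cell implements `predict()` by draining `predict_stream()` and keeping only the completed output items, so the streaming and non-streaming paths share one implementation. The sketch below shows that pattern in isolation with plain dictionaries. The toy echo logic and the simplified item shape are assumptions for illustration; the real class uses MLflow's `ResponsesAgentStreamEvent` objects.

```python
from typing import Any, Iterator


def predict_stream(prompt: str) -> Iterator[dict[str, Any]]:
    """Toy stream: emit text deltas, then one 'done' event with the full item."""
    words = f"echo: {prompt}".split()
    for word in words:
        yield {"type": "response.output_text.delta", "delta": word + " "}
    yield {
        "type": "response.output_item.done",
        "item": {"role": "assistant", "text": " ".join(words)},
    }


def predict(prompt: str) -> list[dict[str, Any]]:
    """Non-streaming call: drain the stream and keep only completed items."""
    return [
        event["item"]
        for event in predict_stream(prompt)
        if event["type"] == "response.output_item.done"
    ]


result = predict("hello agent")
```

Delta events are useful for rendering text as it arrives, but only the `response.output_item.done` events carry complete items, which is why the non-streaming path filters on that type.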

In [0]:
%%writefile agent.py
import json
import logging
import os
import time
import urllib.parse
import uuid
from threading import Lock
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    DatabricksFunctionClient,
    UCFunctionToolkit,
)
from databricks.sdk import WorkspaceClient
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
import psycopg
from psycopg_pool import ConnectionPool
from psycopg.rows import dict_row
from contextlib import contextmanager

logger = logging.getLogger(__name__)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))


############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"

# TODO: Update with your system prompt
SYSTEM_PROMPT = "You are a helpful assistant. Use the available tools to answer questions."

# TODO: Fill in Lakebase config values here
LAKEBASE_CONFIG = {
    "instance_name": "lakebase-erico-silva",
    "conn_host": "instance-9c1487ee-b26a-45c9-8ef1-55d6e9dd531b.database.cloud.databricks.com",
    "conn_db_name": "databricks_postgres",
    "conn_ssl_mode": "require",
}

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################

tools = []

# Example UC tools; add your own as needed
UC_TOOL_NAMES: list[str] = []
if UC_TOOL_NAMES:
    uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
    tools.extend(uc_toolkit.tools)

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge
# List to store vector search tool instances for unstructured retrieval.
VECTOR_SEARCH_TOOLS = []

# To add vector search retriever tools, import VectorSearchRetrieverTool
# from databricks_langchain, create an instance,
# then append it to VECTOR_SEARCH_TOOLS.
# Example:
# VECTOR_SEARCH_TOOLS.append(
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# )

tools.extend(VECTOR_SEARCH_TOOLS)

#####################
## Define agent logic
#####################


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    custom_inputs: Optional[dict[str, Any]]
    custom_outputs: Optional[dict[str, Any]]


class CredentialConnection(psycopg.Connection):
    """Custom connection class that generates fresh OAuth tokens with caching."""
    
    workspace_client = None
    instance_name = None
    
    # Cache attributes
    _cached_credential = None
    _cache_timestamp = None
    _cache_duration = 3000  # 50 minutes in seconds (50 * 60)
    _cache_lock = Lock()
    
    @classmethod
    def connect(cls, conninfo='', **kwargs):
        """Override connect to inject OAuth token with 50-minute caching"""
        if cls.workspace_client is None or cls.instance_name is None:
            raise ValueError("workspace_client and instance_name must be set on CredentialConnection class")
        
        # Get cached or fresh credential and append the new password to kwargs
        credential_token = cls._get_cached_credential()
        kwargs['password'] = credential_token
        
        # Call the superclass's connect method with updated kwargs
        return super().connect(conninfo, **kwargs)
    
    @classmethod
    def _get_cached_credential(cls):
        """Get credential from cache or generate a new one if cache is expired"""
        with cls._cache_lock:
            current_time = time.time()
            
            # Check if we have a valid cached credential
            if (cls._cached_credential is not None and 
                cls._cache_timestamp is not None and 
                current_time - cls._cache_timestamp < cls._cache_duration):
                return cls._cached_credential
            
            # Generate new credential
            credential = cls.workspace_client.database.generate_database_credential(
                request_id=str(uuid.uuid4()),
                instance_names=[cls.instance_name]
            )
            
            # Cache the new credential
            cls._cached_credential = credential.token
            cls._cache_timestamp = current_time
            
            return cls._cached_credential


class LangGraphResponsesAgent(ResponsesAgent):
    """Stateful agent using ResponsesAgent with Lakebase PostgreSQL checkpointing.
    
    Features:
    - Connection pooling with credential rotation and caching
    - Thread-based conversation state persistence
    - Tool support with UC functions
    """

    def __init__(self, lakebase_config: dict[str, Any]):
        self.lakebase_config = lakebase_config
        self.workspace_client = WorkspaceClient()
        
        # Model and tools
        self.model = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)
        self.system_prompt = SYSTEM_PROMPT
        self.model_with_tools = self.model.bind_tools(tools) if tools else self.model
        
        # Connection pool configuration
        self.pool_min_size = int(os.getenv("DB_POOL_MIN_SIZE", "1"))
        self.pool_max_size = int(os.getenv("DB_POOL_MAX_SIZE", "10"))
        self.pool_timeout = float(os.getenv("DB_POOL_TIMEOUT", "30.0"))
        
        # Token cache duration (in minutes, can be overridden via env var)
        cache_duration_minutes = int(os.getenv("DB_TOKEN_CACHE_MINUTES", "50"))
        CredentialConnection._cache_duration = cache_duration_minutes * 60
        
        # Initialize the connection pool with rotating credentials
        self._connection_pool = self._create_rotating_pool()
        
        mlflow.langchain.autolog()

    def _get_username(self) -> str:
        """Get the username for database connection"""
        try:
            sp = self.workspace_client.current_service_principal.me()
            return sp.application_id
        except Exception:
            user = self.workspace_client.current_user.me()
            return user.user_name

    def _create_rotating_pool(self) -> ConnectionPool:
        """Create a connection pool that automatically rotates credentials with caching"""
        # Set the workspace client and instance name on the custom connection class
        CredentialConnection.workspace_client = self.workspace_client
        CredentialConnection.instance_name = self.lakebase_config["instance_name"]
        
        username = self._get_username()
        host = self.lakebase_config["conn_host"]
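        database = self.lakebase_config.get("conn_db_name", "databricks_postgres")
        # Honor the configured SSL mode instead of hardcoding it
        ssl_mode = self.lakebase_config.get("conn_ssl_mode", "require")
        
        # Create pool with custom connection class
        pool = ConnectionPool(
            conninfo=f"dbname={database} user={username} host={host} sslmode={ssl_mode}",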
            connection_class=CredentialConnection,
            min_size=self.pool_min_size,
            max_size=self.pool_max_size,
            timeout=self.pool_timeout,
            open=True,
            kwargs={
                "autocommit": True, # Required for the .setup() method to properly commit the checkpoint tables to the database
                "row_factory": dict_row, # Required because the PostgresSaver implementation accesses database rows using dictionary-style syntax
                "keepalives": 1,
                "keepalives_idle": 30,
                "keepalives_interval": 10,
                "keepalives_count": 5,
            }
        )
        
        # Test the pool
        try:
            with pool.connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
            logger.info(
                f"Connection pool with rotating credentials created successfully "
                f"(min={self.pool_min_size}, max={self.pool_max_size}, "
                f"token_cache={CredentialConnection._cache_duration / 60:.0f} minutes)"
            )
        except Exception as e:
            pool.close()
            raise ConnectionError(f"Failed to create connection pool: {e}") from e
        
        return pool
    
    @contextmanager
    def get_connection(self):
        """Context manager to get a connection from the pool"""
        with self._connection_pool.connection() as conn:
            yield conn
    
    def _langchain_to_responses(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
        """Convert from LangChain messages to Responses API format"""
        responses = []
        for message in messages:
            message_dict = message.model_dump()
            msg_type = message_dict["type"]
            
            if msg_type == "ai":
                if tool_calls := message_dict.get("tool_calls"):
                    for tool_call in tool_calls:
                        responses.append(
                            self.create_function_call_item(
                                id=message_dict.get("id") or str(uuid.uuid4()),
                                call_id=tool_call["id"],
                                name=tool_call["name"],
                                arguments=json.dumps(tool_call["args"]),
                            )
                        )
                else:
                    responses.append(
                        self.create_text_output_item(
                            text=message_dict.get("content", ""),
                            id=message_dict.get("id") or str(uuid.uuid4()),
                        )
                    )
            elif msg_type == "tool":
                responses.append(
                    self.create_function_call_output_item(
                        call_id=message_dict["tool_call_id"],
                        output=message_dict["content"],
                    )
                )
            elif msg_type == "human":
                responses.append({
                    "role": "user",
                    "content": message_dict.get("content", "")
                })
        
        return responses
    
    def _create_graph(self, checkpointer: PostgresSaver):
        """Create the LangGraph workflow"""
        def should_continue(state: AgentState):
            messages = state["messages"]
            last_message = messages[-1]
            if isinstance(last_message, AIMessage) and last_message.tool_calls:
                return "continue"
            return "end"
        
        if self.system_prompt:
            preprocessor = RunnableLambda(
                lambda state: [{"role": "system", "content": self.system_prompt}] + state["messages"]
            )
        else:
            preprocessor = RunnableLambda(lambda state: state["messages"])
        
        model_runnable = preprocessor | self.model_with_tools
        
        def call_model(state: AgentState, config: RunnableConfig):
            response = model_runnable.invoke(state, config)
            return {"messages": [response]}
        
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", RunnableLambda(call_model))
        
        if tools:
            workflow.add_node("tools", ToolNode(tools))
            workflow.add_conditional_edges(
                "agent",
                should_continue,
                {"continue": "tools", "end": END}
            )
            workflow.add_edge("tools", "agent")
        else:
            workflow.add_edge("agent", END)
        
        workflow.set_entry_point("agent")
        
        return workflow.compile(checkpointer=checkpointer)
    
    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Non-streaming prediction"""
        # Resolve thread_id once here so predict_stream() and custom_outputs use the same value
        ci = dict(request.custom_inputs or {})
        if "thread_id" not in ci:
            ci["thread_id"] = str(uuid.uuid4())
        request.custom_inputs = ci

        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs={"thread_id": ci["thread_id"]})
    
    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Streaming prediction with PostgreSQL checkpointing"""
        # Get thread ID from custom inputs or generate new one
        thread_id = (request.custom_inputs or {}).get("thread_id", str(uuid.uuid4()))
        
        # Convert incoming Responses messages to ChatCompletions format
        # LangChain will automatically convert from ChatCompletions to LangChain format
        cc_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])
        langchain_msgs = cc_msgs
        
        checkpoint_config = {"configurable": {"thread_id": thread_id}}
        
        # Use connection from pool
        with self.get_connection() as conn:            
            # Create checkpointer and graph
            checkpointer = PostgresSaver(conn)
            graph = self._create_graph(checkpointer)
            
            # Stream the graph execution
            for event in graph.stream(
                {"messages": langchain_msgs},
                checkpoint_config,
                stream_mode=["updates", "messages"]
            ):
                if event[0] == "updates":
                    for node_data in event[1].values():
                        for item in self._langchain_to_responses(node_data["messages"]):
                            yield ResponsesAgentStreamEvent(
                                type="response.output_item.done",
                                item=item
                            )
                # Stream message chunks for real-time text generation
                elif event[0] == "messages":
                    try:
                        chunk = event[1][0]
                        if isinstance(chunk, AIMessageChunk) and chunk.content:
                            yield ResponsesAgentStreamEvent(
                                **self.create_text_delta(
                                    delta=chunk.content,
                                    item_id=chunk.id
                                ),
                            )
                    except Exception as e:
                        logger.error(f"Error streaming chunk: {e}")


# ----- Export model -----
AGENT = LangGraphResponsesAgent(LAKEBASE_CONFIG)
mlflow.models.set_model(AGENT)
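
The `CredentialConnection` class above caches the database token for 50 minutes and refreshes it on the first connection attempt after that window, presumably to stay under the token lifetime. The following sketch isolates that caching logic so it can be exercised without a workspace. `TokenCache`, `fake_fetch`, and the manual clock are hypothetical stand-ins; in the agent, the fetch step is the call to `generate_database_credential`.

```python
import threading
import time
from typing import Callable, Optional


class TokenCache:
    """Cache a token for ttl_seconds; refresh it on the first call after expiry."""

    def __init__(
        self,
        fetch_token: Callable[[], str],
        ttl_seconds: float,
        clock: Callable[[], float] = time.time,
    ):
        self._fetch_token = fetch_token
        self._ttl = ttl_seconds
        self._clock = clock
        self._lock = threading.Lock()
        self._token: Optional[str] = None
        self._fetched_at: Optional[float] = None

    def get(self) -> str:
        # The lock ensures concurrent callers trigger at most one refresh.
        with self._lock:
            now = self._clock()
            if self._token is not None and now - self._fetched_at < self._ttl:
                return self._token
            self._token = self._fetch_token()
            self._fetched_at = now
            return self._token


# Hypothetical stand-ins: a counting token source and a manually advanced clock.
calls = []
fake_now = [0.0]


def fake_fetch() -> str:
    calls.append(1)
    return f"token-{len(calls)}"


cache = TokenCache(fake_fetch, ttl_seconds=3000, clock=lambda: fake_now[0])
first = cache.get()    # cache is empty, so this fetches a token
second = cache.get()   # still inside the TTL, served from the cache
fake_now[0] = 3000.0   # advance the clock to the end of the TTL window
third = cache.get()    # TTL elapsed, so this fetches a new token
```

Injecting the clock keeps the expiry logic testable without sleeping. The token is only consulted when the pool opens a new connection, so existing connections are unaffected by a refresh.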

# Test the Agent locally

In [0]:
dbutils.library.restartPython()

In [0]:
from agent import AGENT
# Message 1, don't include thread_id (creates new thread)
result = AGENT.predict({
    "input": [{"role": "user", "content": "I am working on stateful agents"}]
})
print(result.model_dump(exclude_none=True))
thread_id = result.custom_outputs["thread_id"]

In [0]:
# Message 2, include thread ID and notice how agent remembers context from previous predict message
response2 = AGENT.predict({
    "input": [{"role": "user", "content": "What am I working on?"}],
    "custom_inputs": {"thread_id": thread_id}
})
print("Response 2:", response2.model_dump(exclude_none=True))

In [0]:
# Example calling agent without passing in thread id - notice it does not retain the memory
response3 = AGENT.predict({
    "input": [{"role": "user", "content": "What am I working on?"}],
})
print("Response 3 No thread id passed:", response3.model_dump(exclude_none=True))

In [0]:
# predict stream example
for chunk in AGENT.predict_stream({
    "input": [{"role": "user", "content": "What am I working on?"}],
    "custom_inputs": {"thread_id": thread_id}
}):
    print("Chunk:", chunk.model_dump(exclude_none=True))

# Log the agent as an MLflow model
Log the agent as code from the agent.py file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

## Enable automatic authentication for Databricks resources
For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.

**TODO:** 
- Add Lakebase as a resource type.
- If your Unity Catalog tool queries a [vector search index](https://docs.databricks.com/docs%20link) or leverages [external functions](https://docs.databricks.com/docs%20link), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#resources)).

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import tools, LLM_ENDPOINT_NAME, LAKEBASE_CONFIG
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint, DatabricksLakebase
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool
# importlib.metadata replaces the deprecated pkg_resources API for reading installed versions
from importlib.metadata import version

resources = [DatabricksServingEndpoint(LLM_ENDPOINT_NAME), DatabricksLakebase(database_instance_name=LAKEBASE_CONFIG["instance_name"])]

for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "input": [
        {
            "role": "user",
            "content": "What is an LLM agent?"
        }
    ],
    "custom_inputs": {"thread_id": "example-thread-123"},
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        # Pin the versions installed in this notebook so the serving environment matches
        pip_requirements=[
            "databricks-langchain",
            f"databricks-connect=={version('databricks-connect')}",
            f"langgraph=={version('langgraph')}",
            f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
            "psycopg[binary,pool]",
            f"pydantic=={version('pydantic')}",
        ]
    )

# Evaluate the agent with Agent Evaluation
Use Mosaic AI Agent Evaluation to evaluate the agent's responses against expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, and use MLflow to track the computed quality metrics. See Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-evaluation) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-evaluation/)).

To evaluate your tool calls, add custom metrics. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-evaluation/custom-metrics.html#evaluating-tool-calls) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-evaluation/custom-metrics#evaluating-tool-calls)).
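
As a starting point for a tool-call metric, the sketch below checks whether an expected tool appears among the `function_call` items in an agent's output. It works on plain dictionaries and is an illustration under assumptions: the helper names and the sample output are hypothetical, and wiring the check into `mlflow.genai.evaluate` would still require wrapping it as a custom scorer as described in the linked documentation.

```python
from typing import Any


def called_tool_names(output_items: list[dict[str, Any]]) -> list[str]:
    """Collect the names of tools invoked in a list of Responses-style output items."""
    return [
        item["name"]
        for item in output_items
        if item.get("type") == "function_call"
    ]


def used_expected_tool(output_items: list[dict[str, Any]], expected_tool: str) -> bool:
    """Pass/fail check: did the agent call the expected tool at least once?"""
    return expected_tool in called_tool_names(output_items)


# Hypothetical agent output, made up for illustration.
sample_output = [
    {
        "type": "function_call",
        "call_id": "c1",
        "name": "system.ai.python_exec",
        "arguments": '{"code": "print(610)"}',
    },
    {"type": "function_call_output", "call_id": "c1", "output": "610"},
    {
        "type": "message",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "The 15th Fibonacci number is 610."}],
    },
]

tool_names = called_tool_names(sample_output)
passed = used_expected_tool(sample_output, "system.ai.python_exec")
```

A check like this complements response-quality scorers: an answer can be correct while skipping the tool you expected the agent to use, or the reverse.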

In [0]:
import mlflow
from mlflow.genai.scorers import RelevanceToQuery, RetrievalGroundedness, RetrievalRelevance, Safety

eval_dataset = [
    {
        "inputs": {"input": [{"role": "user", "content": "Calculate the 15th Fibonacci number"}]},
        "expected_response": "The 15th Fibonacci number is 610.",
    }
]

eval_results = mlflow.genai.evaluate(
    data=eval_dataset,
    predict_fn=lambda input: AGENT.predict({"input": input}),
    scorers=[RelevanceToQuery(), Safety()],  # add more scorers here if they're applicable
)

# Review the evaluation results in the MLflow UI (see console output)

# Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the `mlflow.models.predict()` API.

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"input": [{"role": "user", "content": "I am working on stateful agents"}]},
    env_manager="uv",
)

# Register the model to Unity Catalog
Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "catalog"
schema = "schema"
model_name = "stateful-agent-threads-example"

UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

# Deploy the agent

In [0]:
from databricks import agents
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags = {"endpointSource": "docs"})

# Next steps
Deployment takes around 15 minutes. After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application.

Now, with your stateful agent, you can pick up past threads and continue the conversation.
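
When you continue a thread, only the new message and the thread ID need to be sent. The helper below is a small sketch that builds request dictionaries in the same shape this notebook passes to `AGENT.predict()`; the function name `build_agent_request` is hypothetical.

```python
from typing import Any, Optional


def build_agent_request(user_text: str, thread_id: Optional[str] = None) -> dict[str, Any]:
    """Build a request dict in the shape this notebook passes to AGENT.predict()."""
    request: dict[str, Any] = {"input": [{"role": "user", "content": user_text}]}
    if thread_id is not None:
        # Only the new message is sent; prior turns are loaded from the checkpoint store.
        request["custom_inputs"] = {"thread_id": thread_id}
    return request


# Omitting thread_id starts a new thread; passing one resumes an existing thread.
new_thread_request = build_agent_request("I am working on stateful agents")
follow_up_request = build_agent_request(
    "What am I working on?", thread_id="example-thread-123"
)
```

The thread ID to reuse is the one returned in `custom_outputs["thread_id"]` of the first response.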

You can query your Lakebase instance to see a record of your conversations across threads and checkpoints. Here is a basic query that returns 10 checkpoints:
```
-- Show 10 checkpoint rows with all of their columns
SELECT 
    *
FROM checkpoints
LIMIT 10;
```

To see the most recently logged checkpoints:
```
SELECT
    c.*,
    (c.checkpoint::json->>'ts')::timestamptz AS ts
FROM checkpoints c
ORDER BY ts DESC
LIMIT 10;
```
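
If you want to run the same kind of query from Python, for example to inspect a single thread, the sketch below builds a parameterized version of the query above. It assumes the `checkpoints` table has `thread_id` and `checkpoint_id` columns, which is my understanding of the schema created by `checkpointer.setup()`; verify this against your instance. The function name `recent_checkpoints_query` is hypothetical.

```python
from typing import Optional


def recent_checkpoints_query(
    thread_id: Optional[str] = None, limit: int = 10
) -> tuple[str, tuple]:
    """Return (sql, params) for the newest checkpoints, optionally for one thread."""
    sql = (
        "SELECT c.thread_id, c.checkpoint_id, "
        "(c.checkpoint::json->>'ts')::timestamptz AS ts "
        "FROM checkpoints c "
    )
    params: tuple = ()
    if thread_id is not None:
        # %s is the psycopg placeholder style; values are bound, never interpolated.
        sql += "WHERE c.thread_id = %s "
        params = (thread_id,)
    sql += "ORDER BY ts DESC LIMIT %s"
    return sql, params + (limit,)


sql, params = recent_checkpoints_query(thread_id="example-thread-123", limit=5)
# With a live psycopg connection pool, this could be executed as:
#     with pool.connection() as conn:
#         rows = conn.execute(sql, params).fetchall()
```

Binding the thread ID as a parameter rather than formatting it into the string avoids SQL injection when thread IDs come from user requests.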