### Tutorial notebook from: https://docs.databricks.com/aws/en/generative-ai/tutorials/agent-framework-notebook#notebook

- Runs fine on Databricks Serverless compute.

In [0]:
# Install or upgrade (-U) the notebook's dependencies; only langgraph is pinned to an exact version
%pip install -U -qqqq mlflow langchain langgraph==0.3.4 databricks-langchain pydantic databricks-agents unitycatalog-langchain[databricks] 

#anyio==3.7.1

In [0]:
# Restart the Python process so the packages installed by %pip in the previous cell are the ones imported below.
dbutils.library.restartPython()

In [0]:
from databricks_langchain import ChatDatabricks
import mlflow
# Turn on MLflow autologging for LangChain before any model or agent object is created.
mlflow.langchain.autolog()
import pandas as pd

In [0]:

# Name of the serving endpoint passed to ChatDatabricks for the chat model.
# NOTE(review): a later cell's baseline_config uses "databricks-meta-llama-3-3-70b-instruct"
# instead of this value — confirm which endpoint is intended.
LLM_ENDPOINT = "llama-instruct-3-3-70b"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT)

In [0]:

# Download the chunked Databricks documentation (JSON Lines: one record per line, hence lines=True)
# into a DataFrame; later cells read its "content" and "doc_uri" columns.
databricks_docs_url = "https://raw.githubusercontent.com/databricks/genai-cookbook/refs/heads/main/quick_start_demo/chunked_databricks_docs_filtered.jsonl"
parsed_docs_df = pd.read_json(databricks_docs_url, lines=True)

In [0]:
from databricks_langchain.uc_ai import (
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)

# Create the Unity Catalog function client and register it with set_uc_function_client
# (presumably as the default client picked up by UCFunctionToolkit — TODO confirm).
uc_client = DatabricksFunctionClient()
set_uc_function_client(uc_client)

def tfidf_keywords(text: str) -> list[str]:
    """
    Extracts keywords from the provided text using TF-IDF.

    Args:
        text (string): Input text.
    Returns:
        list[str]: Up to five extracted keywords in descending order of importance (most important first).
    """
    # Imported inside the function so the body is self-contained when it is
    # registered as a Unity Catalog function.
    from sklearn.feature_extraction.text import TfidfVectorizer

    def extract_keywords(text, top_n=5):
        """Returns the top_n highest-scoring terms of a TF-IDF vectorizer fitted on this text alone"""
        keyword_vectorizer = TfidfVectorizer(
            stop_words="english"
        )  # Fresh vectorizer per call; English stop words are ignored
        query_tfidf = keyword_vectorizer.fit_transform([text])  # Fit on this text only
        scores = query_tfidf.toarray()[0]
        # argsort is ascending, so take the last top_n positions and reverse them:
        # the result is ordered from highest score to lowest.
        indices = scores.argsort()[-top_n:][::-1]
        # Look the vocabulary up once instead of once per selected keyword.
        feature_names = keyword_vectorizer.get_feature_names_out()
        return [feature_names[i] for i in indices if scores[i] > 0]

    return extract_keywords(text)

# TODO fill in your catalog and schema name
catalog = "main"
schema = "default"

# Fail fast if either name was left empty.
assert (catalog and schema)

# Create the function within the Unity Catalog catalog and schema specified
function_info = uc_client.create_python_function(
    func=tfidf_keywords,
    catalog=catalog,
    schema=schema,
    replace=True,  # Set to True to overwrite if the function already exists
)

# Wrap the registered function as a tool via UCFunctionToolkit.
# Assumes the UC function is registered under the Python function's name (tfidf_keywords) — TODO confirm.
uc_tool_names = [f"{catalog}.{schema}.tfidf_keywords"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)

In [0]:
# Sanity check: show the requested tool names and the toolkit built from them.
print(uc_tool_names, uc_toolkit)

In [0]:
# Inspect the first tool, then call it directly with a sample sentence.
# The input dict is keyed by the function's parameter name, `text`.
print(uc_toolkit.tools[0])
uc_toolkit.tools[0].invoke({"text": "The quick brown fox jumped over the lazy brown dog."})

In [0]:
from typing import Any

import mlflow
from langchain_core.tools import tool
from sklearn.feature_extraction.text import TfidfVectorizer

# Build the in-memory TF-IDF index once: fit the vectorizer on every row's "content"
# and keep the resulting document-term matrix for find_relevant_documents to query.
documents = parsed_docs_df
doc_vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = doc_vectorizer.fit_transform(documents["content"])

@tool
@mlflow.trace(name="LittleIndex", span_type=mlflow.entities.SpanType.RETRIEVER)
def find_relevant_documents(query: str, top_n: int = 5) -> list[dict[str, Any]]:
    """gets relevant documents for the query"""
    # The docstring above doubles as the tool description, so it is kept verbatim.
    #
    # Returns at most `top_n` entries, best match first; each entry carries the
    # document text under "page_content" and its "doc_uri" / "score" under "metadata".

    # Vectorize the query with the vocabulary fitted on the corpus, then multiply it
    # against the document-term matrix to get one score per document.
    query_vector = doc_vectorizer.transform([query])
    scores = (tfidf_matrix @ query_vector.T).toarray().flatten()

    # Highest score first; sorted() is stable, so tied documents keep their corpus order.
    best_first = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)

    # Pair each selected row with its score, in ranking order.
    top_rows = (
        (documents.iloc[position], doc_score)
        for position, doc_score in best_first[:top_n]
    )
    return [
        {
            "page_content": row["content"],
            "metadata": {
                "doc_uri": row["doc_uri"],
                "score": doc_score,
            },
        }
        for row, doc_score in top_rows
    ]

In [0]:
from typing import Optional, Sequence, Union

from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode

def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Build and compile a two-node LangGraph loop: "agent" calls the model, "tools" runs tools.

    Args:
        model: Chat model; ``tools`` are bound to it via ``bind_tools``.
        tools: Tools offered to the model and executed by the "tools" node.
        agent_prompt: Optional system prompt placed in front of the message
            history on every model call.

    Returns:
        The compiled graph. It starts at "agent", moves to "tools" (and back)
        while the newest message carries tool calls, and ends otherwise.
    """
    tool_bound_model = model.bind_tools(tools)

    # Step that turns the graph state into the message list sent to the model.
    if not agent_prompt:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    else:
        system_message = {"role": "system", "content": agent_prompt}
        preprocessor = RunnableLambda(
            lambda state: [system_message] + state["messages"]
        )
    model_runnable = preprocessor | tool_bound_model

    def routing_logic(state: ChatAgentState):
        # Keep looping through the tools node while the newest message requests a tool call.
        newest = state["messages"][-1]
        return "continue" if newest.get("tool_calls") else "end"

    def call_model(state: ChatAgentState, config: RunnableConfig):
        # Return the model's reply as a one-element "messages" update for the graph state.
        return {"messages": [model_runnable.invoke(state, config)]}

    graph = StateGraph(ChatAgentState)
    graph.add_node("agent", RunnableLambda(call_model))
    graph.add_node("tools", ChatAgentToolNode(tools))
    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        routing_logic,
        {"continue": "tools", "end": END},
    )
    # After the tools ran, always hand control back to the model.
    graph.add_edge("tools", "agent")

    return graph.compile()

In [0]:
import mlflow

# Repeated here; autologging was already enabled in an earlier cell.
mlflow.langchain.autolog()

# Build the agent from the UC toolkit's tools plus the local retriever tool, then smoke-test it with one question.
agent = create_tool_calling_agent(llm, tools=[*uc_toolkit.tools, find_relevant_documents])
agent.invoke({"messages": [{"role": "user", "content":"What are the keywords for the sentence: 'the quick brown fox jumped over the lazy brown dog'?"}]})

In [0]:
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from typing import Any, Optional

class DocsAgent(ChatAgent):
    """Minimal ChatAgent wrapper around an already-built agent graph.

    Args:
        agent: Object whose ``invoke`` accepts ``{"messages": [...]}`` and
            returns a mapping that can be expanded into ``ChatAgentResponse``.
    """

    def __init__(self, agent):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the wrapped agent on ``messages`` and return its response.

        Args:
            messages: Conversation so far.
            context: Accepted for the ChatAgent signature; not used here.
            custom_inputs: Accepted for the ChatAgent signature; not used here.

        Returns:
            A ``ChatAgentResponse`` built from the agent's output.
        """
        # ChatAgent has a built-in helper method to help convert framework-specific messages, like langchain BaseMessage to a python dictionary
        request = {"messages": self._convert_messages_to_dict(messages)}

        # Call the agent stored on this instance, not the notebook-level `agent`
        # variable: that global is rebound by a later cell, which would silently
        # change (or break) what this class invokes.
        output = self.agent.invoke(request)
        # The output is expanded into a new ChatAgentResponse to make the return type explicit.
        return ChatAgentResponse(**output)

In [0]:
# Wrap the compiled graph in the ChatAgent interface and try a sample question.
AGENT = DocsAgent(agent=agent)
AGENT.predict({"messages": [{"role": "user", "content": "What is DLT in Databricks?"}]})

In [0]:
from mlflow.models import ModelConfig

# Development-time settings that the DocsAgent class below reads through ModelConfig.
# NOTE(review): "endpoint_name" differs from the LLM_ENDPOINT value set near the top of the
# notebook ("llama-instruct-3-3-70b") — confirm which serving endpoint is intended.
# NOTE(review): "tool_list" uses a `catalog.schema.*` pattern, which looks like it selects every
# function in that schema rather than only tfidf_keywords — confirm that is intended.
baseline_config = {
   "endpoint_name": "databricks-meta-llama-3-3-70b-instruct",
   "temperature": 0.01,
   "max_tokens": 1000,
   "system_prompt": """You are a helpful assistant that answers questions about Databricks. Questions unrelated to Databricks are irrelevant.

    You answer questions using a set of tools. If needed, you ask the user follow-up questions to clarify their request.
    """,
   "tool_list": [f"{catalog}.{schema}.*"],
}

class DocsAgent(ChatAgent):
    """ChatAgent that builds its own tool-calling agent from ``baseline_config``."""

    def __init__(self):
        self.config = ModelConfig(development_config=baseline_config)
        self.agent = self._build_agent_from_config()

    def _build_agent_from_config(self):
        """Create the chat model, load the UC tools and assemble the agent graph."""
        # Read every setting up front, in a fixed order.
        setting_names = (
            "temperature",
            "max_tokens",
            "system_prompt",
            "endpoint_name",
            "tool_list",
        )
        settings = {name: self.config.get(name) for name in setting_names}

        chat_model = ChatDatabricks(
            endpoint=settings["endpoint_name"],
            temperature=settings["temperature"],
            max_tokens=settings["max_tokens"],
        )
        uc_tools = UCFunctionToolkit(function_names=settings["tool_list"]).tools

        # UC tools first, then the local retriever tool.
        return create_tool_calling_agent(
            chat_model,
            tools=[*uc_tools, find_relevant_documents],
            agent_prompt=settings["system_prompt"],
        )

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the agent on ``messages``; ``context`` and ``custom_inputs`` are not used."""
        # _convert_messages_to_dict is the ChatAgent helper that turns the incoming
        # messages into plain dictionaries for the graph.
        graph_input = {"messages": self._convert_messages_to_dict(messages)}

        # Expand the graph output into a fresh ChatAgentResponse to make the return type explicit.
        return ChatAgentResponse(**self.agent.invoke(graph_input))
    
# Instantiate the config-driven agent and try a sample question.
# NOTE(review): this rebinds `agent`, which until now held the compiled graph created in an earlier cell.
agent = DocsAgent()
agent.predict({"messages": [{"role": "user", "content": "What is DLT"}]})

In [0]:
%%writefile getting_started_agent.py
from typing import Any, Optional, Sequence, Union

import mlflow
import pandas as pd
from databricks_langchain import ChatDatabricks
from databricks_langchain.uc_ai import (
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool, tool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.models import ModelConfig
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from sklearn.feature_extraction.text import TfidfVectorizer

# Download the chunked Databricks documentation (JSON Lines: one record per line, hence lines=True).
databricks_docs_url = "https://raw.githubusercontent.com/databricks/genai-cookbook/refs/heads/main/quick_start_demo/chunked_databricks_docs_filtered.jsonl"
parsed_docs_df = pd.read_json(databricks_docs_url, lines=True)

# Build the in-memory TF-IDF index once at import time: fit the vectorizer on every row's
# "content" and keep the resulting document-term matrix for find_relevant_documents to query.
documents = parsed_docs_df
doc_vectorizer = TfidfVectorizer(stop_words="english")
tfidf_matrix = doc_vectorizer.fit_transform(documents["content"])

@tool
@mlflow.trace(name="LittleIndex", span_type=mlflow.entities.SpanType.RETRIEVER)
def find_relevant_documents(query: str, top_n: int = 5) -> list[dict[str, Any]]:
    """gets relevant documents for the query"""
    # The docstring above doubles as the tool description, so it is kept verbatim.
    #
    # Returns at most `top_n` entries, best match first; each entry carries the
    # document text under "page_content" and its "doc_uri" / "score" under "metadata".

    # Vectorize the query with the vocabulary fitted on the corpus, then multiply it
    # against the document-term matrix to get one score per document.
    query_vector = doc_vectorizer.transform([query])
    scores = (tfidf_matrix @ query_vector.T).toarray().flatten()

    # Highest score first; sorted() is stable, so tied documents keep their corpus order.
    best_first = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)

    # Pair each selected row with its score, in ranking order.
    top_rows = (
        (documents.iloc[position], doc_score)
        for position, doc_score in best_first[:top_n]
    )
    return [
        {
            "page_content": row["content"],
            "metadata": {
                "doc_uri": row["doc_uri"],
                "score": doc_score,
            },
        }
        for row, doc_score in top_rows
    ]

def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Build and compile a two-node LangGraph loop: "agent" calls the model, "tools" runs tools.

    Args:
        model: Chat model; ``tools`` are bound to it via ``bind_tools``.
        tools: Tools offered to the model and executed by the "tools" node.
        agent_prompt: Optional system prompt placed in front of the message
            history on every model call.

    Returns:
        The compiled graph. It starts at "agent", moves to "tools" (and back)
        while the newest message carries tool calls, and ends otherwise.
    """
    model = model.bind_tools(tools)

    def routing_logic(state: ChatAgentState):
        # Keep looping through the tools node while the newest message requests a tool call.
        last_message = state["messages"][-1]
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    # Step that turns the graph state into the message list sent to the model.
    if agent_prompt:
        system_message = {"role": "system", "content": agent_prompt}
        preprocessor = RunnableLambda(
            lambda state: [system_message] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        # Both statements must sit at the same indentation level; a mismatch here
        # makes the written file fail to import with an IndentationError.
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        routing_logic,
        {
            "continue": "tools",
            "end": END,
        },
    )
    # After the tools ran, always hand control back to the model.
    workflow.add_edge("tools", "agent")

    return workflow.compile()

class DocsAgent(ChatAgent):
    """ChatAgent that assembles a tool-calling agent from a config mapping and a tool list."""

    def __init__(self, config, tools):
        self.tools = tools
        # Load config
        # When this agent is deployed to Model Serving, the configuration loaded here is replaced with the config passed to mlflow.pyfunc.log_model(model_config=...)
        self.config = ModelConfig(development_config=config)
        self.agent = self._build_agent_from_config()

    def _build_agent_from_config(self):
        """Create the chat model from the loaded config and wire it to ``self.tools``."""
        read = self.config.get
        chat_model = ChatDatabricks(
            endpoint=read("endpoint_name"),
            temperature=read("temperature"),
            max_tokens=read("max_tokens"),
        )
        return create_tool_calling_agent(
            chat_model,
            tools=self.tools,
            agent_prompt=read("system_prompt"),
        )

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the agent on ``messages``; ``context`` and ``custom_inputs`` are not used."""
        # _convert_messages_to_dict is the ChatAgent helper that turns the incoming
        # messages into plain dictionaries for the graph.
        graph_input = {"messages": self._convert_messages_to_dict(messages)}

        # Expand the graph output into a fresh ChatAgentResponse to make the return type explicit.
        return ChatAgentResponse(**self.agent.invoke(graph_input))
    
# Unity Catalog location whose functions are exposed to the agent as tools.
catalog = "main"
schema = "default"

# Must be a literal: %%writefile stores this cell as getting_started_agent.py, and when that
# file is loaded as a standalone module no notebook variables exist, so a self-referencing
# assignment would raise NameError. Keep this value in sync with LLM_ENDPOINT in the notebook.
LLM_ENDPOINT = "llama-instruct-3-3-70b"

# Development-time settings; DocsAgent reads them through ModelConfig.
baseline_config = {
    "endpoint_name": LLM_ENDPOINT,
    "temperature": 0.01,
    "max_tokens": 1000,
    "system_prompt": """You are a helpful assistant that answers questions about Databricks. Questions unrelated to Databricks are irrelevant.

    You answer questions using a set of tools. If needed, you ask the user follow-up questions to clarify their request.
    """,
}

# Local retriever tool first, then every UC function matched by the `catalog.schema.*` pattern.
tools = [find_relevant_documents]
uc_client = DatabricksFunctionClient()
set_uc_function_client(uc_client)
uc_toolkit = UCFunctionToolkit(function_names=[f"{catalog}.{schema}.*"])
tools.extend(uc_toolkit.tools)


# Hand the agent instance to MLflow as this file's model object.
AGENT = DocsAgent(baseline_config, tools)
mlflow.models.set_model(AGENT)