In [0]:
%pip install -U langchain langgraph databricks-sdk databricks-vectorsearch
%pip install databricks-sdk[openai]
%pip install grandalf
%pip install pyppeteer
%pip install -U databricks-agents>=0.16.0 mlflow>=2.20.2 databricks-langchain databricks-openai
%pip install --upgrade "mlflow-skinny[databricks]"
dbutils.library.restartPython()

In [0]:
%load_ext autoreload
%autoreload 2
# Reload imported modules (such as chat_agent.py, written by the %%writefile cell below) before
# each cell runs, so edits to the file are picked up without restarting Python.
# Learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

# Predictive Maintenance Agent

In [0]:
%%writefile chat_agent.py
from typing import Any, Generator, Optional, Sequence, Union
from langchain.chat_models import init_chat_model
from langchain.chat_models.base import BaseChatModel
from langchain_core.tools import tool
from typing import TypedDict, Annotated
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_community.chat_models.databricks import ChatDatabricks
from langgraph.checkpoint.memory import InMemorySaver
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from databricks.vector_search.client import VectorSearchClient
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from mlflow.langchain.chat_agent_langgraph import parse_message
from langgraph.graph.state import CompiledStateGraph

import mlflow

# Enable MLflow autologging for LangChain and use the Unity Catalog model registry
# (required to resolve the three-level model name in MODEL_URI below).
mlflow.langchain.autolog()
mlflow.set_registry_uri("databricks-uc")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
CATALOG = "workspace"
SCHEMA = "genai_demo"

# Anomaly-detection model registered in Unity Catalog
MODEL_NAME = "isolation_forest_pm_model"
MODEL_NAME_FULL = ".".join((CATALOG, SCHEMA, MODEL_NAME))
MODEL_VERSION = 1
MODEL_URI = f"models:/{MODEL_NAME_FULL}/{MODEL_VERSION}"

# Vector Search index with the maintenance document chunks
INDEX_NAME = "maintenance_docs_index"
INDEX_NAME_FULL = ".".join((CATALOG, SCHEMA, INDEX_NAME))

# Chat model, addressed by its Databricks serving-endpoint name
# (alternative endpoint: "databricks-llama-4-maverick")
LLM_MODEL = "gpt-41"
TEMPERATURE = 0.1

# ---------------------------------------------------------------------------
# Resources, loaded once when this module is imported
# ---------------------------------------------------------------------------
ad_model = mlflow.sklearn.load_model(MODEL_URI)

vsc = VectorSearchClient()
index = vsc.get_index(index_name=INDEX_NAME_FULL)

llm = ChatDatabricks(
    target_uri="databricks",
    endpoint=LLM_MODEL,
    temperature=TEMPERATURE,
)


# Define tools
@tool
def anomaly_detector(vibration: float, pressure: float, temperature: float) -> str:
    """
    Detects anomalies in equipment behavior using vibration, pressure, and temperature.
    """
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    # One sample row, features in the order vibration, pressure, temperature
    # (assumed to match the model's training order -- TODO confirm).
    sample = [[vibration, pressure, temperature]]
    try:
        # The model labels anomalous samples with -1.
        is_anomaly = ad_model.predict(sample)[0] == -1
        return "Anomaly Detection Result: " + ("Anomalous" if is_anomaly else "Normal")
    except Exception as err:
        # Report failures back to the LLM as text instead of aborting the graph run.
        return f"Error: {str(err)}"
    

@tool
def vector_search(query: str) -> str:
    """
    Searches the vector index for machine manual documents."""
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    #
    # Returns the matching chunk texts joined by blank lines. Returns an explicit
    # "no results" message when nothing matches, so the LLM does not receive an
    # empty string or a misleading KeyError text. Any other failure is reported
    # back as text rather than raised.
    try:
        res = index.similarity_search(
            query_text=query,
            columns=["chunk_text"],
            num_results=1,
            query_type="hybrid",
        )
        # An empty result set may omit "data_array" entirely, so read it defensively.
        rows = (res.get("result") or {}).get("data_array") or []
        if not rows:
            return "No matching maintenance documents found."
        # row[0] is the first requested column, i.e. chunk_text.
        return "\n\n".join(row[0] for row in rows)
    except Exception as e:
        return f"Vector search error: {str(e)}"


# Tools the LLM may call
tools = [anomaly_detector, vector_search]

# System instructions; create_agent's assistant node prepends this when the
# conversation history contains no SystemMessage.
system_prompt = SystemMessage(
    content="".join(
        (
            "You are a predictive maintenance engineer. Answer machine maintenance queries using the search index. ",
            "If sensor data is provided, use the anomaly detection tool. ",
            "If the machine is anomalous, ask user whether RCA and resolution is required if user does not suggest anything otherwise continue the task.",
        )
    )
)

# In-memory conversation store. NOTE: the agent built in this module is compiled
# without it, so it does not persist state between calls.
checkpointer = InMemorySaver()

def create_agent(llm, tools, system_prompt, checkpointer=None):
    """Build and compile a tool-calling LangGraph agent.

    The graph runs ``assistant``. Whenever the model's reply requests tools,
    ``tools_condition`` routes to ``tools``, whose results loop back to ``assistant``.
    Otherwise the run ends.

    Args:
        llm: Chat model supporting ``bind_tools``.
        tools: Tools bound to the model and executed by the graph's ToolNode.
        system_prompt: SystemMessage prepended to the model input whenever the
            conversation history contains no SystemMessage.
        checkpointer: Optional LangGraph checkpointer. With the default ``None`` the
            graph is stateless (same as before), and callers must pass the full
            history on every call. Pass one (e.g. ``InMemorySaver()``) to persist
            conversation state per ``thread_id`` in the invoke config.

    Returns:
        The compiled graph.
    """

    class AgentState(TypedDict):
        # Conversation history; the add_messages reducer merges node outputs into it.
        messages: Annotated[list[AnyMessage], add_messages]

    llm_with_tools = llm.bind_tools(tools)

    def assistant_node(state: AgentState) -> dict:
        """Call the tool-bound LLM on the current history and append its reply."""
        msgs = state["messages"]
        # Only the model's reply is written back to state (see return below), so the
        # system prompt never lands in history. It is therefore prepended on every
        # call unless the caller supplied their own SystemMessage.
        if not any(isinstance(m, SystemMessage) for m in msgs):
            msgs = [system_prompt, *msgs]

        response = llm_with_tools.invoke(msgs)
        return {"messages": [response]}

    builder = StateGraph(AgentState)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile(checkpointer=checkpointer)


class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` wrapper that serves a compiled LangGraph graph.

    Both entry points convert the incoming ``ChatAgentMessage`` list to dicts and
    run the graph. They emit only the messages produced by graph nodes (assistant
    replies and tool results); the request's own messages are not echoed back.
    """

    def __init__(self, agent: CompiledStateGraph):
        # Compiled graph whose state carries a "messages" list (see create_agent).
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the agent to completion and return the messages it generated.

        Args:
            messages: Conversation history supplied by the client.
            context: Unused; accepted for ChatAgent interface compatibility.
            custom_inputs: Unused; accepted for ChatAgent interface compatibility.

        Returns:
            A ChatAgentResponse containing the new messages, in the order the
            graph produced them. These are the same messages ``predict_stream`` yields.
        """
        # Reuse the streaming path. Returning the final state from ``invoke`` would
        # also include the request's own messages, duplicating the client's history.
        generated = [
            chunk.delta
            for chunk in self.predict_stream(messages, context, custom_inputs)
        ]
        return ChatAgentResponse(messages=generated)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Yield each message as a ChatAgentChunk as soon as a graph node emits it.

        ``context`` and ``custom_inputs`` are unused; they are accepted for
        ChatAgent interface compatibility.
        """
        # _convert_messages_to_dict is inherited from mlflow.pyfunc.ChatAgent.
        request = {"messages": self._convert_messages_to_dict(messages)}
        # stream_mode="updates" yields one {node_name: state_update} dict per node step.
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                for m in node_data.get("messages", []):
                    yield ChatAgentChunk(delta=ChatAgentMessage(**parse_message(m)))


# Build the graph once at import time and register the ChatAgent wrapper as the
# model that MLflow's models-from-code logging picks up from this file.
pm_agent = create_agent(llm=llm, tools=tools, system_prompt=system_prompt)
AGENT = LangGraphChatAgent(agent=pm_agent)
mlflow.models.set_model(AGENT)

In [0]:
# Sanity-check how MLflow's parse_message maps a LangChain message onto the
# ChatAgentMessage dict schema. Self-contained, so it also runs on a fresh kernel
# (the previous version relied on a `chunk` variable that only exists after a later cell).
from langchain_core.messages import AIMessage
from mlflow.langchain.chat_agent_langgraph import parse_message

parse_message(AIMessage(content="hello, how can I assist you?", id="example-message-id"))

In [0]:
# Inspect the final message of the most recent graph update collected in `hist_lc`
# by the streaming cell below. Guarded so it does not raise NameError on a fresh kernel.
stream_updates = globals().get("hist_lc") or []
if stream_updates:
    # Assumes each update is shaped {node_name: {"messages": [...]}} (stream_mode="updates") -- confirm.
    node_data = next(iter(stream_updates[-1].values()))
    display(parse_message(node_data["messages"][-1]))
else:
    print("Run the streaming cell (pm_agent.stream) first to populate `hist_lc`.")

In [0]:
# Number of graph updates collected by the streaming cell below (0 until that cell has run).
len(globals().get("hist_lc", []))

In [0]:
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
from chat_agent import pm_agent

# A short multi-turn conversation given as role/content dicts.
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hello, how can I assist you?"},
    {"role": "user", "content": "My machine is running very hot, could you help check the manual what could be cause and how I can fix it?"},
]
request = {"messages": messages}

# Collect every streamed graph update so individual steps can be inspected afterwards.
hist_lc = list(pm_agent.stream(request))

In [0]:
# NOTE: `thread_id` only takes effect when the graph is compiled with a checkpointer.
# pm_agent (built in chat_agent.py) is compiled without one, so each invoke starts from empty state.
config = {"configurable": {"thread_id": "1"}}

# Alternative prompts to try:
#   "The machine's vibration is 3.2, pressure is 45, temperature is 78."
#   "The machine's vibration is 3.2, pressure is 45, temperature is 78. If the machine is anomalous, please do a detail RCA and suggest a resolution."
messages = [
    HumanMessage(content="My machine is running very hot, I am wondering what the cuase is and how I can solve it.")
]
response = pm_agent.invoke({"messages": messages}, config)

for msg in response["messages"]:
    msg.pretty_print()

In [0]:
# pm_agent has no checkpointer, so `config`'s thread_id does not restore the previous
# turn. Send the earlier conversation explicitly so the agent knows what "yes, please" answers.
# New names keep the first-turn `response` intact, so re-running this cell is idempotent.
followup_messages = response["messages"] + [HumanMessage(content="yes, please")]
followup_response = pm_agent.invoke({"messages": followup_messages}, config)

for m in followup_response["messages"]:
    m.pretty_print()

## Log and Register Agent

In [0]:
from importlib.metadata import version

import mlflow
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex

from chat_agent import INDEX_NAME_FULL, LLM_MODEL, pm_agent

input_example = {"messages": [{"role": "user", "content": "how can you assist?"}]}

# Local smoke test before logging (makes a real LLM call).
output_example = pm_agent.invoke(input_example)

# chat_agent.py is executed in the serving container, so every distribution it imports
# at load time must be pinned here. Previously langchain / langchain-community
# (ChatDatabricks), databricks-vectorsearch and scikit-learn (mlflow.sklearn model) were missing.
pip_requirements = [
    f"{dist}=={version(dist)}"
    for dist in (
        "databricks-connect",
        "mlflow",
        "databricks-langchain",
        "langgraph",
        "langchain",
        "langchain-community",
        "databricks-vectorsearch",
        "scikit-learn",
    )
]

# Declare the Databricks resources the agent calls, so the deployment can be given credentials for them.
# TODO(review): if the index computes embeddings via a separate serving endpoint, add it here too.
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_MODEL),
    DatabricksVectorSearchIndex(index_name=INDEX_NAME_FULL),
]

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="pm_agent",
        python_model="chat_agent.py",
        input_example=input_example,
        resources=resources,
        pip_requirements=pip_requirements,
    )



In [0]:
# Register the logged agent in Unity Catalog. CATALOG/SCHEMA are imported from
# chat_agent.py so the agent is always registered next to the resources it uses.
from chat_agent import CATALOG, SCHEMA

mlflow.set_registry_uri("databricks-uc")

AGENT_NAME = "pm_chat_agent"
AGENT_NAME_FULL = f"{CATALOG}.{SCHEMA}.{AGENT_NAME}"
registered_model_info = mlflow.register_model(model_uri=logged_agent_info.model_uri, name=AGENT_NAME_FULL)

## Deploy the agent

In [0]:
from databricks import agents

# Deploy the version registered above rather than a hard-coded one. Every re-run of the
# registration cell creates a new version, which a fixed `1` would silently ignore.
AGENT_VERSION = registered_model_info.version
agents.deploy(AGENT_NAME_FULL, AGENT_VERSION)