In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload. Learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

# Predictive Maintenance Agent

In [0]:
%pip install -U langchain langgraph databricks-sdk databricks-vectorsearch
%pip install databricks-sdk[openai]
%pip install grandalf
%pip install pyppeteer
%pip install -U databricks-agents>=0.16.0 mlflow>=2.20.2 databricks-langchain databricks-openai
dbutils.library.restartPython()

In [0]:
# Switch on MLflow's LangChain autologging before any graph code is defined.
import mlflow
import mlflow.langchain

mlflow.langchain.autolog()

In [0]:
# Import dependencies
from langgraph.graph import StateGraph, START, END
from typing import Literal
from pydantic import BaseModel
from databricks.vector_search.client import VectorSearchClient
from databricks.sdk import WorkspaceClient
import pickle
from sklearn.ensemble import IsolationForest
import mlflow
from datetime import datetime

In [0]:
# Configure: Unity Catalog location, model version and endpoint names.
catalog = "workspace"
schema = "genai_demo"
model_name = "isolation_forest_pm_model"
model_version = 4

# URI of the registered anomaly-detection model:
# models:/<catalog>.<schema>.<model_name>/<model_version>
AD_MODEL = f"models:/{catalog}.{schema}.{model_name}/{model_version}"
# Vector search index queried by the RCA node. It is derived from the same
# catalog/schema as the model so that changing them above keeps both in sync.
VECTOR_INDEX = f"{catalog}.{schema}.maintenance_docs_index"
EMBEDDING_MODEL = "databricks-gte-large-en"  # not referenced by the cells of this notebook
LLM_MODEL = "databricks-llama-4-maverick"  # model name passed to the chat completions client


In [0]:
# Define State Schema
class AgentState(BaseModel):
    """State object shared by the nodes of the predictive-maintenance graph.

    The sensor-reading fields have no defaults and must be supplied when the
    graph is invoked; every other field has a default value.
    """
    # Sensor reading under evaluation (required inputs)
    timestamp: datetime
    machine_id: int
    temperature: float
    vibration: float
    pressure: float
    # Normal operating ranges as (low, high); they are interpolated into the RCA query text
    normal_temp: tuple[float, float] = (20, 36)
    normal_vibration: tuple[float, float] = (1, 2.2)
    normal_pressure: tuple[float, float] = (2, 4.5)
    # RCA logs
    is_anomaly: bool = False  # result of detect_anomaly
    query: str = ""  # LLM-optimized search query built in the RCA node
    context: str = ""  # retrieved document text used in the RCA prompt
    suggestion: str = ""  # final message; this is what callers read from the result

In [0]:
# Load resources for model, retriever, LLM

# Anomaly-detection model, fetched from the MLflow registry by its URI.
ad_model = mlflow.sklearn.load_model(AD_MODEL)

# Vector search index used for document retrieval (name comes from the
# configuration cell).
vsc = VectorSearchClient()
index = vsc.get_index(
    index_name=VECTOR_INDEX,
)

# OpenAI-compatible chat client backed by the workspace's serving endpoints.
ws = WorkspaceClient()
chat_client = ws.serving_endpoints.get_open_ai_client()

In [0]:
# Define Nodes
def detect_anomaly(state: AgentState) -> dict:
    """Score the current sensor reading with the anomaly-detection model.

    Args:
        state: Current graph state. ``temperature``, ``vibration`` and
            ``pressure`` are passed to the model, in that order, as a single
            feature row.

    Returns:
        The state update ``{"is_anomaly": <bool>}``; the flag is ``True`` when
        the model predicts ``-1`` for the reading.
    """
    features = [[state.temperature, state.vibration, state.pressure]]
    prediction = ad_model.predict(features)[0]
    # Cast to a built-in bool: comparing a NumPy scalar yields numpy.bool_,
    # which is not a Python bool and may not validate/serialize cleanly as the
    # `is_anomaly: bool` state field.
    # The update is communicated through the return value only; the incoming
    # state object is deliberately left untouched.
    return {"is_anomaly": bool(prediction == -1)}

def rca_with_query_optimization(state: AgentState) -> dict:
    """Run root-cause analysis for an anomalous reading.

    Steps:
        1. Describe the reading next to its normal ranges and ask the LLM to
           rewrite that description as a concise search query.
        2. Retrieve the two best-matching ``chunk_text`` entries from the
           vector index using hybrid search.
        3. Ask the LLM for a root cause and maintenance actions, given the
           reading and the retrieved context.

    Args:
        state: Current graph state holding the reading and the normal ranges.

    Returns:
        A state update with ``query`` (the optimized search query),
        ``context`` (the retrieved text) and ``suggestion`` (the LLM answer).
        All three are returned because only returned keys are written back to
        the graph state; assigning attributes on ``state`` is not enough.
    """
    # 1. Optimize query
    # Each range is rendered as "low - high". The two bounds must be separate
    # replacement fields, otherwise the f-string evaluates a subtraction.
    raw_query = (
        f"Machine {state.machine_id} anomaly: "
        f"T={state.temperature} [normal {state.normal_temp[0]} - {state.normal_temp[1]}], "
        f"V={state.vibration} [normal {state.normal_vibration[0]} - {state.normal_vibration[1]}], "
        f"P={state.pressure} [normal {state.normal_pressure[0]} - {state.normal_pressure[1]}]"
    )
    q_opt_msg = [
        {"role":"system","content":"Rewrite the following to a concise, technical search query focusing on deviation from normal operation."},
        {"role":"user","content":raw_query}
    ]
    rewritten = chat_client.chat.completions.create(model=LLM_MODEL, messages=q_opt_msg).choices[0].message.content
    # Fall back to the raw description if the LLM returns no usable text.
    query = (rewritten or "").strip() or raw_query

    # 2. Retrieve relevant documents
    hits = index.similarity_search(query_text=query, columns=["chunk_text"], num_results=2, query_type="hybrid")
    # Be tolerant of a response without rows (presumably "data_array" can be
    # absent when nothing matches); in that case the context is simply empty.
    rows = (hits.get("result") or {}).get("data_array") or []
    context = "\n\n".join(row[0] for row in rows if row and row[0] is not None)

    # 3. Generate root cause & action
    prompt = [
        {"role":"system","content":"You're an engineer analyzing machinery anomalies."},
        {"role":"user","content":
         f"Anomaly details:\n{raw_query}\n\nContext:\n{context}\n\nProvide root cause and maintenance actions."}
    ]
    response = chat_client.chat.completions.create(model=LLM_MODEL, messages=prompt)
    # `suggestion` is a str field, so never hand back None.
    suggestion = response.choices[0].message.content or ""
    return {"query": query, "context": context, "suggestion": suggestion}


def normal(state: AgentState) -> dict:
    """Produce the fixed all-clear message for a reading that is not anomalous.

    Args:
        state: Current graph state (not read; present to satisfy the node
            signature).

    Returns:
        The state update ``{"suggestion": <all-clear message>}``.
    """
    # The message literal appears exactly once, and the update is expressed
    # through the return value only; the incoming state is not modified.
    return {"suggestion": "✅ Machine is operating properly."}

In [0]:
# Build the LangGraph
# Topology: START -> detect_anomaly -> (rca | normal) -> END
workflow = StateGraph(AgentState)

# Nodes
workflow.add_node("detect_anomaly", detect_anomaly)
workflow.add_node("rca", rca_with_query_optimization)
workflow.add_node("normal", normal)

# Edges: route on the anomaly flag once detection has run.
workflow.add_edge(START, "detect_anomaly")
workflow.add_conditional_edges(
    "detect_anomaly",
    lambda current: "rca" if current.is_anomaly else "normal",
    {"rca": "rca", "normal": "normal"},
)
workflow.add_edge("normal", END)
workflow.add_edge("rca", END)

agent = workflow.compile()

In [0]:
# Print an ASCII rendering of the compiled graph (presumably why grandalf is installed in the setup cell).
print(agent.get_graph().draw_ascii())

In [0]:
# Test: invoke the graph with a reading that lies inside the normal ranges.
# To exercise the RCA branch instead, try e.g.
#   temperature=60.5, vibration=3.7, pressure=27.2
test = dict(
    timestamp=datetime.now(),
    machine_id=2,
    temperature=25,
    vibration=1.5,
    pressure=3,
)
final_state = agent.invoke(test)
print(final_state["suggestion"])

## Wrap the agent

In [0]:
# from mlflow.deployments import get_deploy_client
# import mlflow
# from langchain.schema.runnable import Runnable
# from langgraph.graph import StateGraph

from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import ChatAgentMessage, ChatAgentResponse
from langgraph.graph.state import CompiledStateGraph
import uuid

In [0]:
import json

# Example of a JSON-encoded reading like the one LangGraphChatAgent.predict
# parses from the message content. The keys must match the AgentState field
# names (`timestamp`, not `time_stamp`), and the timestamp is a full ISO-8601
# value so it can be parsed as a datetime.
example_payload = json.dumps(
    {"timestamp": "2015-07-01T00:00:00", "machine_id": 1, "temperature": 20, "vibration": 1.5, "pressure": 3}
)
example_payload

In [0]:
class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` wrapper around the compiled predictive-maintenance graph.

    The chat message to answer must carry, as its ``content``, a JSON object
    whose keys are the graph's input fields (for example ``timestamp``,
    ``machine_id``, ``temperature``, ``vibration`` and ``pressure``).
    """

    def __init__(self, agent: CompiledStateGraph):
        """Store the compiled graph that ``predict`` will invoke."""
        self.agent = agent

    def predict(self, messages: list[ChatAgentMessage], **kwargs) -> ChatAgentResponse:
        """Run the graph on the reading contained in the latest user message.

        Args:
            messages: Conversation so far. The most recent message with role
                ``"user"`` is used; if none has that role, the last message is.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A response with a single assistant message holding the graph's
            ``suggestion``.

        Raises:
            ValueError: If ``messages`` is empty, or (as
                ``json.JSONDecodeError``) if the content is not valid JSON.
        """
        if not messages:
            raise ValueError("messages must contain at least one message with a JSON-encoded sensor reading")
        message_dicts = self._convert_messages_to_dict(messages)
        # Answer the newest request: in a multi-turn conversation the first
        # message is an earlier reading, not the one being asked about now.
        latest = next(
            (msg for msg in reversed(message_dicts) if msg.get("role") == "user"),
            message_dicts[-1],
        )
        reading = json.loads(latest["content"])
        result = self.agent.invoke(reading)["suggestion"]
        outputs = [ChatAgentMessage(id=str(uuid.uuid4()), role="assistant", content=result)]
        return ChatAgentResponse(messages=outputs)

In [0]:
# Smoke-test the ChatAgent wrapper with a reading inside the normal ranges.
chat_agent = LangGraphChatAgent(agent)

test = dict(
    timestamp=datetime.now(),
    machine_id=2,
    temperature=25,
    vibration=1.5,
    pressure=3,
)

# default=str lets json.dumps serialize the datetime value.
payload = json.dumps(test, default=str)
messages = [ChatAgentMessage(role="user", content=payload)]
output = chat_agent.predict(messages)

In [0]:
# Display the text of the agent's reply.
output.messages[0].content

In [0]:
# AGENT comes from a local agent.py module, the same file that is logged below via python_model="agent.py".
# NOTE(review): none of the cells shown here creates agent.py; it presumably has to exist next to this notebook — confirm.
from agent import AGENT

In [0]:
from datetime import datetime
from mlflow.types.agent import ChatAgentMessage
import json

# Exercise the AGENT imported from agent.py with a reading outside every normal range.
test = dict(
    timestamp=datetime.now(),
    machine_id=2,
    temperature=65,
    vibration=5,
    pressure=1,
)

user_message = ChatAgentMessage(role="user", content=json.dumps(test, default=str))
messages = [user_message]
output = AGENT.predict(messages)
output.messages[0].content

## Log agent

In [0]:
from importlib import metadata
import mlflow

# Log the agent defined in agent.py as an MLflow pyfunc model, pinning the
# listed packages to the versions installed in this environment.
# importlib.metadata (standard library) replaces the deprecated pkg_resources API.

test = {
    "timestamp": datetime.now(),
    "machine_id": 2,
    "temperature": 65,
    "vibration": 5,
    "pressure": 1,
}

# Self-contained request example built from `test` as plain dicts, so it does
# not depend on a `messages` variable left over from an earlier cell.
input_example = {"messages": [{"role": "user", "content": json.dumps(test, default=str)}]}


def _pinned(package: str) -> str:
    """Return ``<package>==<installed version>`` for the current environment."""
    return f"{package}=={metadata.version(package)}"


# NOTE(review): agent.py is not shown here. If it mirrors this notebook it also
# imports databricks-vectorsearch, databricks-sdk[openai] and scikit-learn,
# which would then need to be pinned below as well — confirm.
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="pm_agent",
        python_model="agent.py",
        # resources=resources,
        # input_example=input_example,
        pip_requirements=[
            _pinned("databricks-connect"),
            _pinned("mlflow"),
            _pinned("databricks-langchain"),
            _pinned("langgraph"),
        ],
    )

## Register the model to Unity Catalog

In [0]:
# Use Unity Catalog as the MLflow model registry.
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
# (this rebinds catalog/schema/model_name from the configuration cell)
catalog, schema, model_name = "workspace", "genai_demo", "pm_agent"
UC_MODEL_NAME = ".".join((catalog, schema, model_name))

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    name=UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

In [0]:
# Load the version that was just registered rather than a hard-coded 1, so that
# registering again (which creates a new version) never loads a stale model.
version = uc_registered_model_info.version
# NOTE: from here on `agent` is the loaded MLflow pyfunc model, no longer the
# compiled LangGraph graph created earlier in the notebook.
agent = mlflow.pyfunc.load_model(f"models:/{UC_MODEL_NAME}/{version}")

In [0]:
# Query the model loaded from Unity Catalog with a reading outside every normal range.
test = dict(
    timestamp=datetime.now(),
    machine_id=2,
    temperature=65,
    vibration=5,
    pressure=1,
)

# The pyfunc model takes plain dicts; default=str serializes the datetime.
user_message = {"role": "user", "content": json.dumps(test, default=str)}
input_data = {"messages": [user_message]}

agent.predict(input_data)