In [0]:
%pip install -U langchain langgraph databricks-sdk databricks-vectorsearch
%pip install databricks-sdk[openai]
%pip install grandalf
%pip install pyppeteer
%pip install -U databricks-agents>=0.16.0 mlflow>=2.20.2 databricks-langchain databricks-openai
dbutils.library.restartPython()

In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload (mode 2), so edits to imported modules are picked up without restarting Python.
# Learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run `%autoreload 0`.

# Predictive Maintenance Agent

In [0]:
# Import dependencies
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel
from datetime import datetime
import mlflow
from databricks.vector_search.client import VectorSearchClient
from databricks_langchain import ChatDatabricks

# Turn on MLflow's LangChain autologging up front, before the LLM client and graph below are created.
mlflow.langchain.autolog()

# Configure
# Catalog and schema under which both the registered model and the vector index live.
CATALOG, SCHEMA = "workspace", "genai_demo"

# Registered anomaly-detection model, addressed as models:/<catalog>.<schema>.<name>/<version>.
MODEL_NAME, MODEL_VERSION = "isolation_forest_pm_model", 1
MODEL_NAME_FULL = "models:/" + ".".join((CATALOG, SCHEMA, MODEL_NAME)) + "/" + str(MODEL_VERSION)

# Vector search index, addressed as <catalog>.<schema>.<name>.
INDEX_NAME = "maintenance_docs_index"
INDEX_NAME_FULL = ".".join((CATALOG, SCHEMA, INDEX_NAME))

# Chat model endpoint name and sampling temperature.
# Alternative endpoint kept for reference: "databricks-llama-4-maverick"
LLM_MODEL = "gpt-41"
TEMPERATURE = 0.1


# Define State Schema
class AgentState(BaseModel):
    """State carried through the predictive-maintenance graph.

    Holds one sensor reading (the required fields), the normal operating
    ranges used when describing the reading to the LLM, and the fields the
    graph nodes fill in as they run.
    """

    # Sensor reading supplied by the caller (no defaults, so all are required)
    timestamp: datetime
    machine_id: int
    temperature: float
    vibration: float
    pressure: float
    # Normal operating ranges as (low, high) pairs; units are not specified in this file
    normal_temp: tuple[float, float] = (20, 36)
    normal_vibration: tuple[float, float] = (1, 2.2)
    normal_pressure: tuple[float, float] = (2, 4.5)
    # RCA logs
    is_anomaly: bool = False  # set by detect_anomaly
    query: str = ""  # LLM-rewritten search query, set by the rca node
    context: str = ""  # retrieved document text, set by the rca node
    suggestion: str = ""  # final message, set by the rca or normal node


# Load trained model from MLflow (URI is MODEL_NAME_FULL from the configuration above)
ad_model = mlflow.sklearn.load_model(MODEL_NAME_FULL)

# Initialize vector retriever: client plus a handle on the index named by INDEX_NAME_FULL
vsc = VectorSearchClient()
index = vsc.get_index(index_name=INDEX_NAME_FULL)

# Initialize the chat LLM against the endpoint named by LLM_MODEL
# (currently "gpt-41"; the Maverick endpoint is only the commented-out alternative)
llm = ChatDatabricks(
    target_uri="databricks",
    endpoint=LLM_MODEL,
    temperature=TEMPERATURE,
)


# Define Nodes
def detect_anomaly(state: AgentState) -> dict:
    """Classify the current sensor reading with the loaded anomaly-detection model.

    Args:
        state: Graph state; ``temperature``, ``vibration`` and ``pressure`` are
            passed to the model as a single-row feature matrix, in that order.

    Returns:
        ``{"is_anomaly": bool}``: True when the model's prediction for the row
        is ``-1``. ``state.is_anomaly`` is also set to the same value.
    """
    features = [[state.temperature, state.vibration, state.pressure]]
    prediction = ad_model.predict(features)[0]
    # Comparing a NumPy scalar produces numpy.bool_; cast to a builtin bool so the
    # value matches the `bool` field type and stays plain-Python in graph state.
    is_anomaly = bool(prediction == -1)
    state.is_anomaly = is_anomaly
    return {"is_anomaly": is_anomaly}

def rca_with_query_optimization(state: AgentState):
    """Run root-cause analysis for an anomalous reading.

    Steps:
      1. Describe the reading and its normal ranges, and ask the LLM to rewrite
         that description as a concise search query.
      2. Run a hybrid similarity search on the vector index with the rewritten
         query and join the returned ``chunk_text`` values into one context string.
      3. Ask the LLM for a root cause and maintenance actions given the reading
         description and the retrieved context.

    Args:
        state: Graph state holding the sensor reading and normal ranges.

    Returns:
        Dict with ``suggestion`` (LLM answer), ``context`` (retrieved text, empty
        when the search returns no rows) and ``query`` (rewritten search query).
        The same three values are also written onto ``state``.
    """
    # 1. Optimize query
    # Each range is rendered as "low - high"; the two bounds are interpolated
    # separately so the prompt shows the range rather than their difference.
    raw_query = (
        f"Machine {state.machine_id} anomaly: "
        f"T={state.temperature} [normal {state.normal_temp[0]} - {state.normal_temp[1]}], "
        f"V={state.vibration} [normal {state.normal_vibration[0]} - {state.normal_vibration[1]}], "
        f"P={state.pressure} [normal {state.normal_pressure[0]} - {state.normal_pressure[1]}]"
    )
    q_opt_msg = [
        {"role":"system","content":"Rewrite the following to a concise, technical search query focusing on deviation from normal operation."},
        {"role":"user","content":raw_query}
    ]
    q_opt = llm.invoke(q_opt_msg).content
    state.query = q_opt

    # 2. Retrieve relevant documents
    hits = index.similarity_search(query_text=q_opt, columns=["chunk_text"], num_results=2, query_type="hybrid")
    # A response with zero matches may omit "data_array"; treat that as no context
    # instead of raising KeyError.
    rows = (hits.get("result") or {}).get("data_array") or []
    # Only "chunk_text" was requested, so it is the first element of each row.
    context = "\n\n".join(row[0] for row in rows)
    state.context = context

    # 3. Generate root cause & action
    prompt = [
        {"role":"system","content":"You're an engineer analyzing machinery anomalies."},
        {"role":"user","content":
         f"Anomaly details:\n{raw_query}\n\nContext:\n{context}\n\nProvide root cause and maintenance actions."}
    ]
    response = llm.invoke(prompt)
    suggestion = response.content
    state.suggestion = suggestion
    return {"suggestion": suggestion, "context": context, "query": q_opt}


def normal(state: AgentState) -> dict:
    """Terminal node for non-anomalous readings.

    Stores a fixed "all clear" message on ``state.suggestion`` and returns it
    as the ``suggestion`` state update.
    """
    all_clear = "✅ Machine is operating properly."
    state.suggestion = all_clear
    return dict(suggestion=all_clear)


# Build the graph: START -> detect_anomaly -> (rca | normal) -> END
workflow = StateGraph(AgentState)
workflow.add_node("detect_anomaly", detect_anomaly)
workflow.add_node("rca", rca_with_query_optimization)
workflow.add_edge(START, "detect_anomaly")
# Route on the flag written by detect_anomaly; the mapping translates the
# router's return value to the name of the node to run next.
workflow.add_conditional_edges("detect_anomaly",
    lambda s: "rca" if s.is_anomaly else "normal",
    {"rca":"rca", "normal": "normal"}
)
# The "normal" node is registered after the conditional edge that targets it.
# NOTE(review): this relies on edge targets being resolved at compile() — confirm.
workflow.add_node("normal", normal)
# Both branches finish the run.
workflow.add_edge("normal", END)
workflow.add_edge("rca", END)

# Compile into the runnable agent used by the cells below.
agent = workflow.compile()

In [0]:
# Print an ASCII rendering of the compiled agent graph
agent_graph = agent.get_graph()
print(agent_graph.draw_ascii())

In [0]:
# Normal Case
# Every reading below lies within the default normal_* ranges declared on AgentState.

sensor_data = dict(
    timestamp=datetime.now(),
    machine_id=1,
    temperature=20,
    vibration=1.5,
    pressure=3,
)
response = agent.invoke(sensor_data)
print(response["suggestion"])


In [0]:
# Anomalous Case
# Every reading below lies outside the default normal_* ranges declared on AgentState.

sensor_data = dict(
    timestamp=datetime.now(),
    machine_id=1,
    temperature=65,
    vibration=3.5,
    pressure=1,
)

response = agent.invoke(sensor_data)

# Show the suggestion, a divider, then the retrieved context, one item per line.
print(
    response["suggestion"],
    '\n----------------------------\n',
    response["context"],
    sep="\n",
)