In [0]:
%pip install -U langchain langgraph databricks-sdk databricks-vectorsearch
%pip install databricks-sdk[openai]
%pip install grandalf
%pip install pyppeteer
%pip install -U databricks-agents>=0.16.0 mlflow>=2.20.2 databricks-langchain databricks-openai
dbutils.library.restartPython()

In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload for imported Python modules. Learn more at
# https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run: %autoreload 0

# Predictive Maintenance Agent

In [0]:
%%writefile flow_agent.py
# Import dependencies
from typing import Any, Optional
import uuid
import json
from datetime import datetime

from pydantic import BaseModel
import mlflow
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from databricks.vector_search.client import VectorSearchClient
from databricks_langchain import ChatDatabricks

mlflow.langchain.autolog()

# Configure
# Unity Catalog location shared by the anomaly model and the vector index.
CATALOG = "workspace"
SCHEMA = "genai_demo"
# Registered anomaly-detection model; loaded further down with mlflow.sklearn.load_model.
MODEL_NAME = "isolation_forest_pm_model"
MODEL_VERSION = 1
# Despite the name, this is a complete `models:/` URI rather than a bare model name.
MODEL_NAME_FULL = f"models:/{CATALOG}.{SCHEMA}.{MODEL_NAME}/{MODEL_VERSION}"
# Vector Search index queried by the vector_search node.
INDEX_NAME = "maintenance_docs_index"
INDEX_NAME_FULL = f"{CATALOG}.{SCHEMA}.{INDEX_NAME}"
# Serving endpoint handed to ChatDatabricks. Alternative endpoint: "databricks-llama-4-maverick".
LLM_MODEL = "gpt-41"
TEMPERATURE = 0.1


# Define State Schema
# Shared LangGraph state: the sensor reading supplied by the caller plus the fields that
# the graph nodes fill in while the workflow runs.
class AgentState(BaseModel):
    # Sensor reading (required; no defaults)
    timestamp: datetime
    machine_id: int
    temperature: float
    vibration: float
    pressure: float
    # Normal operating ranges as (low, high). In this file they are only interpolated
    # into the text of the search query.
    # NOTE(review): the sample payloads in this notebook use pressure 25-30, far outside
    # the normal_pressure default below -- confirm the units/ranges.
    normal_temp: tuple[float, float] = (20, 36)
    normal_vibration: tuple[float, float] = (1, 2.2)
    normal_pressure: tuple[float, float] = (2, 4.5)
    # RCA logs
    is_anomaly: bool = False  # set by detect_anomaly
    raw_query: str = ""  # plain-text description of the reading, built by query_otimization
    query: str = ""  # LLM-rewritten search query, consumed by vector_search
    context: str = ""  # retrieved document chunks joined by blank lines
    suggestion: str = ""  # final answer: RCA text from rca, or the all-clear from normal


# Load trained model from MLflow
# NOTE: this and the two initialisations below run at module import time.
ad_model = mlflow.sklearn.load_model(MODEL_NAME_FULL)

# Initialize vector retriever
vsc = VectorSearchClient()
index = vsc.get_index(index_name=INDEX_NAME_FULL)

# Initialize LLM (serving endpoint is selected by LLM_MODEL)
llm = ChatDatabricks(
    target_uri="databricks",
    endpoint=LLM_MODEL,
    temperature=TEMPERATURE,
)


# Define Nodes
def detect_anomaly(state: AgentState) -> dict:
    """Flag the current sensor reading as anomalous or not.

    Passes ``[temperature, vibration, pressure]`` as a single row to ``ad_model``
    and treats a prediction of ``-1`` as an anomaly.

    Args:
        state: Current graph state holding the sensor reading.

    Returns:
        State update ``{"is_anomaly": <bool>}``.
    """
    X = [[state.temperature, state.vibration, state.pressure]]
    # Comparing an element of a NumPy prediction array yields a NumPy boolean; cast it
    # to a plain Python bool so the value matches the `bool` declared on AgentState.
    is_anomaly = bool(ad_model.predict(X)[0] == -1)
    # Kept for backward compatibility: the passed-in state object is updated as well.
    state.is_anomaly = is_anomaly
    return {"is_anomaly": is_anomaly}


def query_otimization(state: AgentState):
    """Describe the anomalous reading and have the LLM condense it into a search query.

    The function name (including its spelling) is kept because the graph wiring
    references it.

    Args:
        state: Current graph state holding the sensor reading and normal ranges.

    Returns:
        State update with ``raw_query`` (the plain-text description of the reading)
        and ``query`` (the LLM's rewritten search query).
    """
    temp_low, temp_high = state.normal_temp
    vib_low, vib_high = state.normal_vibration
    pres_low, pres_high = state.normal_pressure

    # Each range is rendered as "low - high". The temperature range used to be written
    # as a single placeholder that subtracted the bounds and printed a number like -16.
    raw_query = (
        f"Machine {state.machine_id} anomaly: "
        f"T={state.temperature} [normal {temp_low} - {temp_high}], "
        f"V={state.vibration} [normal {vib_low} - {vib_high}], "
        f"P={state.pressure} [normal {pres_low} - {pres_high}]"
    )

    q_opt_msg = [
        {"role":"system","content":"Rewrite the following to a concise, technical search query focusing on deviation from normal operation."},
        {"role":"user","content":raw_query}
    ]

    q_opt = llm.invoke(q_opt_msg).content
    return {"raw_query": raw_query, "query": q_opt}

def vector_search(state: AgentState):
    """Retrieve document chunks for the optimized query from the vector index.

    Runs a hybrid similarity search for ``state.query`` requesting only the
    ``chunk_text`` column and at most two results.

    Args:
        state: Current graph state; ``state.query`` is the search text.

    Returns:
        State update ``{"context": <str>}`` with the retrieved chunks separated by
        blank lines, or an empty string when nothing matched.
    """
    hits = index.similarity_search(
        query_text=state.query,
        columns=["chunk_text"],
        num_results=2,
        query_type="hybrid",
    )
    # A response without matches may lack "data_array" (or "result"); treat that as
    # zero rows instead of raising KeyError.
    rows = (hits.get("result") or {}).get("data_array") or []
    # The first element of each row is read as the chunk text.
    context = "\n\n".join(row[0] for row in rows)
    return {"context": context}

def rca(state: AgentState):
    """Ask the LLM for a root cause and maintenance actions for the anomaly.

    The prompt combines the plain-text anomaly description (``state.raw_query``)
    with the retrieved document context (``state.context``).

    Returns:
        State update ``{"suggestion": <LLM answer text>}``.
    """
    system_message = {
        "role": "system",
        "content": "You're an engineer analyzing machinery anomalies.",
    }
    user_message = {
        "role": "user",
        "content": (
            f"Anomaly details:\n{state.raw_query}\n\n"
            f"Context:\n{state.context}\n\n"
            "Provide root cause and maintenance actions."
        ),
    }
    answer = llm.invoke([system_message, user_message])
    return {"suggestion": answer.content}



def normal(state: AgentState) -> dict:
    """Produce the all-clear message for a reading that was not flagged as anomalous.

    The message is also written onto the passed-in state object.

    Returns:
        State update ``{"suggestion": <all-clear message>}``.
    """
    state.suggestion = "✅ Machine is operating properly."
    return {"suggestion": state.suggestion}

def create_agent():
    """Assemble and compile the predictive-maintenance workflow.

    Flow: START -> detect_anomaly, then either
      * normal -> END (reading not flagged), or
      * query_optimization -> vector_search -> rca -> END (reading flagged).

    Returns:
        The compiled graph.
    """
    graph = StateGraph(AgentState)

    for node_name, node_fn in (
        ("detect_anomaly", detect_anomaly),
        ("query_optimization", query_otimization),
        ("vector_search", vector_search),
        ("rca", rca),
    ):
        graph.add_node(node_name, node_fn)

    graph.add_edge(START, "detect_anomaly")
    # Branch on the anomaly flag; the mapping ties each returned label to its node.
    graph.add_conditional_edges(
        "detect_anomaly",
        lambda s: "query_optimization" if s.is_anomaly else "normal",
        {label: label for label in ("query_optimization", "normal")},
    )
    graph.add_node("normal", normal)

    for source, target in (
        ("normal", END),
        ("query_optimization", "vector_search"),
        ("vector_search", "rca"),
        ("rca", END),
    ):
        graph.add_edge(source, target)

    return graph.compile()


class LangGraphChatAgent(ChatAgent):
    """ChatAgent wrapper around the compiled workflow.

    The content of the last chat message must be a JSON object with the sensor
    reading; the reply is a single assistant message carrying the workflow's
    ``suggestion``.
    """

    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the workflow on the sensor reading carried by the last message.

        ``context`` and ``custom_inputs`` are accepted but not used.
        """
        message_dicts = self._convert_messages_to_dict(messages)
        # Only the most recent message is interpreted; its content is the JSON payload.
        sensor_payload = json.loads(message_dicts[-1]["content"])
        final_state = self.agent.invoke(sensor_payload)
        reply = ChatAgentMessage(
            role="assistant",
            content=final_state["suggestion"],
            id=str(uuid.uuid4()),
        )
        return ChatAgentResponse(messages=[reply])

# Build the compiled graph once at import time; the notebook imports it to draw the graph.
flow_agent = create_agent()
# Wrap the graph in the ChatAgent interface and declare that object as the model that
# MLflow serves when this file is logged via python_model="flow_agent.py".
FLOW_AGENT = LangGraphChatAgent(flow_agent)
mlflow.models.set_model(FLOW_AGENT)

In [0]:
from flow_agent import flow_agent
# Print the compiled workflow as ASCII art to sanity-check its nodes and branches.
print(flow_agent.get_graph().draw_ascii())

In [0]:
import json
from datetime import datetime

from flow_agent import FLOW_AGENT

# Reading used as the "normal" example for the in-process agent.
sensor_data_normal = dict(
    timestamp=str(datetime.now()),
    machine_id=1,
    temperature=20,
    vibration=1.5,
    pressure=30,
)

# The agent expects the reading as a JSON string in the content of the last chat message.
request = {
    "messages": [
        {"role": "user", "content": json.dumps(sensor_data_normal)},
    ]
}
response = FLOW_AGENT.predict(request)
print(response.messages[0].content)

In [0]:
# Reading used as the "abnormal" example for the in-process agent.
# `json`, `datetime` and FLOW_AGENT come from the previous cell.
sensor_data_abnormal = dict(
    timestamp=str(datetime.now()),
    machine_id=1,
    temperature=45,
    vibration=3.5,
    pressure=25,
)

request = {
    "messages": [
        {"role": "user", "content": json.dumps(sensor_data_abnormal)},
    ]
}
response = FLOW_AGENT.predict(request)
print(response.messages[0].content)

## Log and Register Agent to Unity Catalog

In [0]:
import os
from importlib.metadata import version

import mlflow


# Example request stored with the logged model: one chat message whose content is the
# JSON-encoded sensor reading. `json` and `datetime` were imported in an earlier cell.
sensor_data = {
    "timestamp": str(datetime.now()),
    "machine_id": 1,
    "temperature": 20,
    "vibration": 1.5,
    "pressure": 30,
}

input_example = {"messages": [{"role": "user", "content": json.dumps(sensor_data)}]}

# Log the agent
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="flow_pm_agent",
        python_model="flow_agent.py",
        input_example=input_example,
        # Pin each package to the version installed in this environment.
        # importlib.metadata replaces the deprecated pkg_resources.get_distribution.
        pip_requirements=[
            f"{package}=={version(package)}"
            for package in (
                "databricks-connect",
                "mlflow",
                "databricks-langchain",
                "langgraph",
            )
        ],
    )

In [0]:
# Register the agent
# Publish the model logged in the previous cell under <catalog>.<schema>.<name>.
CATALOG = "workspace"
SCHEMA = "genai_demo"
AGENT_NAME = "flow_pm_agent"
AGENT_NAME_FULL = ".".join((CATALOG, SCHEMA, AGENT_NAME))
registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri,
    name=AGENT_NAME_FULL,
)

## Test the Registered Agent

In [0]:
import mlflow
# Location of the registered agent (repeats the values used in the registration cell).
CATALOG = "workspace"
SCHEMA = "genai_demo"
AGENT_NAME = 'flow_pm_agent'
# NOTE(review): the version is hard-coded; if the agent has been registered more than
# once this loads the first version, not the one just registered -- consider taking the
# version from the result of the registration step.
AGENT_VERSION = 1
# NOTE(review): this rebinds AGENT_NAME_FULL from the dotted model name used for
# registration to a `models:/` URI.
AGENT_NAME_FULL = f"models:/{CATALOG}.{SCHEMA}.{AGENT_NAME}/{AGENT_VERSION}"
flow_agent_loaded = mlflow.pyfunc.load_model(AGENT_NAME_FULL)

In [0]:
# Send the "abnormal" reading to the agent loaded from the registry.
# `json` and `datetime` come from an earlier cell.
sensor_data_abnormal = dict(
    timestamp=str(datetime.now()),
    machine_id=1,
    temperature=45,
    vibration=3.5,
    pressure=25,
)

request = {
    "messages": [
        {"role": "user", "content": json.dumps(sensor_data_abnormal)},
    ]
}
response = flow_agent_loaded.predict(request)
# The loaded pyfunc model's response is indexed like a dict here, not via attributes.
print(response["messages"][0]["content"])