In [0]:
pip install --upgrade "mlflow[databricks]>=3.1.0" openai "databricks-connect>=16.1"

In [0]:
%%writefile rag_agent.py

from openai import OpenAI
from mlflow.pyfunc import PythonModel
import mlflow

class RAGAgent(PythonModel):
    """Toy retrieval-augmented agent backed by a Databricks chat endpoint.

    Each question is paired with a canned "retrieved" context string and sent
    to a chat-completions endpoint through the OpenAI-compatible client.
    """

    # Chat endpoint and sampling settings used for every completion request.
    LLM_ENDPOINT = "databricks-llama-4-maverick"
    TEMPERATURE = 0.7
    MAX_TOKENS = 150

    def __init__(self):
        """Create an OpenAI client aimed at the workspace's serving endpoints."""
        # Import the submodule explicitly rather than relying on
        # `mlflow.utils.databricks_utils` happening to be loaded by `import mlflow`.
        from mlflow.utils.databricks_utils import get_databricks_host_creds

        mlflow_creds = get_databricks_host_creds()
        self.client = OpenAI(
            api_key=mlflow_creds.token,
            base_url=f"{mlflow_creds.host}/serving-endpoints",
        )

    def predict(self, context, model_input: list[str]) -> list[str]:
        """Answer every question in ``model_input``.

        Args:
            context: Pyfunc context supplied by MLflow; not used here.
            model_input: Batch of user questions.

        Returns:
            One answer per question, in the same order as the input.
        """
        return [self._generate_response(question) for question in model_input]

    @mlflow.trace
    def _generate_response(self, user_question: str) -> str:
        """Retrieve context for one question and ask the chat endpoint to answer it.

        Args:
            user_question: The question to answer.

        Returns:
            The text of the first completion choice, or an empty string when
            the endpoint returns no text content.
        """
        context = self._retrieve_context(user_question)
        response = self.client.chat.completions.create(
            model=self.LLM_ENDPOINT,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": f"Context: {context}\n\nQuestion: {user_question}"},
            ],
            temperature=self.TEMPERATURE,
            max_tokens=self.MAX_TOKENS,
        )
        content = response.choices[0].message.content
        # `content` may be None (e.g. a response without text); keep the
        # declared `str` return type instead of leaking None to callers.
        return content if content is not None else ""

    @mlflow.trace
    def _retrieve_context(self, query: str) -> str:
        """Return a canned context string chosen by a keyword match on ``query``.

        Args:
            query: The user question used to pick the context.

        Returns:
            An MLflow blurb when the query mentions "mlflow" (case-insensitive),
            otherwise a generic machine-learning blurb.
        """
        if "mlflow" in query.lower():
            return "MLflow is an open-source platform for managing ML workflows."
        return "General info about machine learning."


from mlflow.models import set_model

# Build the agent once at module level and hand that instance to MLflow via
# `set_model`, so this file can be logged with `python_model="rag_agent.py"`.
AGENT = RAGAgent()
set_model(AGENT)



In [0]:
# `%%writefile` only saves the cell above to rag_agent.py without executing it,
# so import AGENT from that module to make this cell work on a fresh kernel.
from rag_agent import AGENT

AGENT.predict(None, ["What is MLflow?", "What is machine learning?"])

In [0]:
# A single question still has to be wrapped in a list: `predict` is declared for
# `list[str]` and iterates its input, so a bare string does not satisfy that
# contract (iterating it would yield single characters).
AGENT.predict(None, ["what is mlflow?"])

In [0]:
import mlflow
from mlflow.models.resources import DatabricksServingEndpoint
from mlflow.models.signature import infer_signature

# Declare the serving endpoint that rag_agent.py calls so it is recorded with
# the logged model as a dependent resource.
resources = [DatabricksServingEndpoint(endpoint_name="databricks-llama-4-maverick")]


with mlflow.start_run(run_name="chatbot_v2"):
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        # Models-from-code: MLflow stores this file and re-executes it on load.
        python_model="rag_agent.py",
        pip_requirements=[
            # Match the version floor installed at the top of the notebook.
            "mlflow>=3.1.0",
            "pydantic",
            # rag_agent.py imports `openai`; without it the logged model cannot
            # be loaded in a clean environment.
            "openai",
        ],
        resources=resources,
        signature=infer_signature(
            ["What is MLflow?"],  # Example input (batch)
            ["MLflow is an open-source platform..."]  # Example output
        ),
        input_example=["What is MLflow?"]  # Batch-friendly example
    )


In [0]:
# Reload the model that was just logged and run a two-question batch through it.
loaded_model = mlflow.pyfunc.load_model(logged_agent_info.model_uri)
loaded_model.predict(["What is MLflow?", "What is machine learning?"])

In [0]:
# Register the logged model under a three-part name
# (presumably a Unity Catalog <catalog>.<schema>.<model> path — confirm).
mlflow.register_model(
    model_uri=logged_agent_info.model_uri,
    name="ml_demo.default.agent_x"
)

In [0]:
# Fetch the traces recorded so far (the agent's `@mlflow.trace`-decorated
# methods emit them); called with no filters.
traces = mlflow.search_traces()

In [0]:
# Display the fetched traces for inspection.
traces

In [0]:
import mlflow.genai.datasets

# Table name that backs the evaluation dataset.
evaluation_dataset_table_name = "ml_demo.default.rag"

# Create the dataset and seed it with the traces collected above.
# NOTE(review): re-running this cell may fail if the table already exists —
# confirm, and switch to the commented-out lookup below in that case.
eval_dataset = mlflow.genai.datasets.create_dataset(
    uc_table_name=evaluation_dataset_table_name,
)
eval_dataset.merge_records(traces)

# eval_dataset = mlflow.genai.get_dataset(evaluation_dataset_table_name)

In [0]:
# Show the dataset records as a DataFrame to check what was merged in.
eval_dataset.to_df()

In [0]:
# Reference only (not executed). The original scratch cell ended in an
# unterminated `def predict(` and referenced an undefined `model`, so it could
# never run. It is kept here as a note on how a ChatAgent-style model, whose
# `predict` has the signature
#
#     def predict(
#         self,
#         messages: list[ChatAgentMessage],
#         context: Optional[ChatContext] = None,
#         custom_inputs: Optional[dict[str, Any]] = None,
#     ): ...
#
# would be wrapped for evaluation:
#
#     def predict_fn(messages: list) -> dict:
#         return model.predict({"messages": messages})
#
# The agent in this notebook takes a plain list of question strings instead;
# the `predict_fn` actually used for evaluation is defined in the next cell.

In [0]:
# def predict(self, context, model_input: list[str]) -> list[str]:
#       return [self._generate_response(question) for question in model_input]

# The predict function expects a list, but it has to be serializable.

def predict_fn(user_question) -> dict:
    """Adapt the loaded pyfunc model to the evaluation harness.

    The model's ``predict`` iterates over its input as a batch of question
    strings, so the input is passed as a list. Wrapping it in a dict would
    make the model iterate over the dict's keys instead of the questions.

    Args:
        user_question: A single question string, or a list of question strings.

    Returns:
        ``{"response": answer}`` for a single question, or
        ``{"response": [answers...]}`` when a list of questions was given.
    """
    is_single = isinstance(user_question, str)
    questions = [user_question] if is_single else list(user_question)
    answers = loaded_model.predict(questions)
    return {"response": answers[0] if is_single else list(answers)}



In [0]:
import mlflow
from mlflow.genai.scorers import Correctness

# Evaluate your app against expert expectations
# NOTE(review): the dataset above was seeded only from traces; confirm its
# records carry the expected responses that the Correctness scorer compares
# against, otherwise add them before running this cell.
eval_results = mlflow.genai.evaluate(
    data=eval_dataset,
    predict_fn=predict_fn,  
    scorers=[Correctness()]  # Compares outputs to expected_response
)