# 1. Configurar Ambiente

# 1.1 Variables

In [0]:
# Read the notebook parameters from Databricks widgets. Every value below is
# used by later cells as a module-level global.

# Combined with `vs_index`, `volume` and `model_name` below to build
# `catalog.schema.<name>` identifiers and the artifact volume path.
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
# Model identifier handed to `dspy.LM(...)`.
llm = dbutils.widgets.get("llm")
# NOTE: the widget is called "index" but is stored as `vs_index`.
vs_index = dbutils.widgets.get("index")
# Endpoint name passed to the Vector Search client when fetching the index.
vs_endpoint = dbutils.widgets.get("vs_endpoint")
# Volume name used in the MLflow artifact location.
volume = dbutils.widgets.get("volume")
# Name of the MLflow experiment that is created/selected below.
exp_name = dbutils.widgets.get("exp_name")
# Name used for the registered model and for the serving endpoint.
model_name = dbutils.widgets.get("model_name")

# 1.2 DSPy

In [0]:
import dspy

# Configure the LLM that DSPy uses: `llm` is the model identifier read from
# the "llm" widget, and the resulting LM is installed as DSPy's default.
lm = dspy.LM(llm)
dspy.configure(lm=lm)

# 1.3 MLflow

In [0]:
import mlflow  # Storing artifacts in a volume requires MLflow 2.15.0 or above

# Artifact location for the experiment; can be a managed or external volume.
ARTIFACT_PATH = f"dbfs:/Volumes/{catalog}/{schema}/{volume}/DSPy RAG Artifacts/"

# Track runs in the Databricks workspace and register models in Unity Catalog.
mlflow.set_tracking_uri("databricks")
mlflow.set_registry_uri("databricks-uc")

# Create the experiment explicitly on first use so that its artifacts are
# stored under ARTIFACT_PATH; `set_experiment` alone would not set that
# location. The previous hard-coded
# "/Workspace/Shared/MLflow Experiments/DSPy RAG" experiment was always
# overridden by this one, so it is no longer selected first.
if mlflow.get_experiment_by_name(exp_name) is None:
    mlflow.create_experiment(name=exp_name, artifact_location=ARTIFACT_PATH)
mlflow.set_experiment(exp_name)

# Enable DSPy autologging, including compile and evaluation calls.
mlflow.dspy.autolog(log_compiles=True, log_evals=True)

# 2. Configurar RAG

# 2.1 Definir Signature

In [0]:
# DSPy signature for the recipe recommender: two inputs (`question`,
# `context`) and one output (`recommendation`), all annotated as `str`.
# NOTE: DSPy is documented to use the class docstring and the `desc` strings
# as prompt text, so treat them as runtime behavior rather than as comments;
# edit them only when a prompt change is intended.
class RAG_SIGNATURE(dspy.Signature):
  """Receive user's request, search for context in index, and recommend a cooking recipe."""
  # The user's question, passed through unchanged by `Recommendation.forward`.
  question: str = dspy.InputField(desc="Customer question on cooking recipe.")
  # Retrieved context; `Recommendation.forward` fills this with the result of
  # its vector search call.
  context: str = dspy.InputField(desc="Vector Search Index of cooking recipes.")
  # The generated answer; callers read it as `.recommendation`.
  recommendation: str = dspy.OutputField(desc="Recommended cooking recipe.")

# 2.2 Definir Lógica de RAG

In [0]:
class Recommendation(dspy.Module):
  """Retrieval-augmented recipe recommender built as a DSPy module.

  A question is used as the query for a similarity search against a
  Databricks Vector Search index; the search result is then handed, together
  with the question, to a `dspy.Predict(RAG_SIGNATURE)` predictor.
  """

  def __init__(self, index=None, endpoint=None, for_mosaic_agent=True):
    """Build the predictor and connect to the Vector Search index.

    Args:
      index: Full index name. When None, falls back to
        `f"{catalog}.{schema}.{vs_index}"` built from the notebook globals.
      endpoint: Vector Search endpoint name. When None, falls back to the
        notebook global `vs_endpoint`.
      for_mosaic_agent: Currently unused; kept so existing callers that pass
        it keep working.
    """
    # Imported lazily so the dependency is only needed when the module is
    # actually instantiated.
    from databricks.vector_search.client import VectorSearchClient
    from databricks.vector_search.utils import CredentialStrategy

    # Let dspy.Module set up its own state before sub-modules are attached.
    super().__init__()

    self.response_generator = dspy.Predict(RAG_SIGNATURE)
    self.vs_index = index if index is not None else f"{catalog}.{schema}.{vs_index}"
    self.vs_endpoint = endpoint if endpoint is not None else vs_endpoint

    # Define the Vector Search client.
    # NOTE(review): MODEL_SERVING_USER_CREDENTIALS presumably targets the
    # model-serving environment — confirm it also resolves credentials when
    # this class is instantiated interactively in the notebook.
    client = VectorSearchClient(
      credential_strategy=CredentialStrategy.MODEL_SERVING_USER_CREDENTIALS,
      disable_notice=True,
    )
    self.retriever = client.get_index(
      endpoint_name=self.vs_endpoint,
      index_name=self.vs_index,
    )

  def documentation_vector_search(self, text):
    """Run a similarity search for `text` and return the raw search result.

    Requests the `title`, `ingredients` and `directions` columns for the top
    3 matches. The result is returned exactly as the index client produced
    it; no post-processing is applied.
    """
    return self.retriever.similarity_search(
      query_text=text,
      columns=["title", "ingredients", "directions"],
      num_results=3,
      disable_notice=True,
    )

  def forward(self, question):
    """Retrieve context for `question` and return the predictor's response.

    The returned object is whatever `dspy.Predict(RAG_SIGNATURE)` yields;
    callers in this notebook read its `recommendation` attribute.
    """
    # NOTE(review): the raw search result is passed as `context`, although
    # the signature annotates that field as `str` — confirm the serialized
    # form is what the prompt should contain.
    context = self.documentation_vector_search(question)
    return self.response_generator(context=context, question=question)

# 3. Probar agente

In [0]:
import contextlib
import io

# Smoke-test the agent inside an MLflow run and keep the prompt that was sent
# to the LLM alongside the run.
with mlflow.start_run(run_name="recommender") as run:
    recommender = Recommendation()
    print(recommender("What recipes can I prepare with chicken and rice?").recommendation)

    # `inspect_history()` prints the prompt history; depending on the DSPy
    # version it may return the text or return None. Capture stdout so the
    # text is available either way.
    history_buffer = io.StringIO()
    with contextlib.redirect_stdout(history_buffer):
        history = recommender.inspect_history()
    generated_prompt = history if isinstance(history, str) else history_buffer.getvalue()

    # Keep the history visible in the notebook output, as before.
    print(generated_prompt)

    # `mlflow.log_param` takes (key, value); the previous call passed the run
    # id as the key. A full prompt is also too long for a param, so store it
    # as a text artifact on the active run instead.
    mlflow.log_text(generated_prompt, "generated_prompt.txt")

# 4. Guardar agente

In [0]:
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema
from mlflow.models.resources import DatabricksVectorSearchIndex

# Model signature: one unnamed string column in, one unnamed string column out.
input_schema = Schema([ColSpec("string")])
output_schema = Schema([ColSpec("string")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# Sample question reused for the logged input example and the validation call.
input_example = "What recipes can I prepare with chicken and rice?"

# Declare the Vector Search index as a resource of the logged model.
# NOTE(review): only the index is declared; confirm whether the LLM endpoint
# behind `llm` must be declared as a resource as well.
resources = [
    DatabricksVectorSearchIndex(index_name=f"{catalog}.{schema}.{vs_index}")
]

# Log the agent, register it as `catalog.schema.model_name`, then validate
# the logged model with a prediction.
with mlflow.start_run(run_name="Log Model"):
    # NOTE(review): `model_config["index"]` is the short widget value, while
    # `Recommendation` builds `catalog.schema.index` when no index is given —
    # confirm how this config is consumed at load time.
    # NOTE(review): the input example is wrapped in a list here but passed as
    # a bare string to `predict` below — confirm the expected input shape.
    model_info = mlflow.dspy.log_model(
        recommender,
        registered_model_name=f"{catalog}.{schema}.{model_name}",
        model_config={"index": vs_index, "endpoint": vs_endpoint},
        input_example=[input_example],
        pip_requirements="requirements.txt",
        signature=signature,
        resources=resources
    )

    # Run one prediction against the logged model using the "uv" environment
    # manager.
    mlflow.models.predict(
        model_uri=model_info.model_uri,
        input_data=input_example,
        env_manager="uv",
    )

In [0]:
from mlflow.deployments import get_deploy_client

# Deployment client bound to the Databricks workspace.
client = get_deploy_client("databricks")

# The served entity and its traffic route must use the same name, so build
# it once from the model name and the version that was just registered.
registered_version = model_info.registered_model_version
served_entity_name = f"{model_name}-{registered_version}"

# The Unity Catalog model version to serve.
served_entity = {
    "name": served_entity_name,
    "entity_name": f"{catalog}.{schema}.{model_name}",
    "entity_version": registered_version,
    "workload_size": "Small",
    "scale_to_zero_enabled": True,
}

# Send all traffic to that single served entity.
traffic_route = {
    "served_model_name": served_entity_name,
    "traffic_percentage": 100,
}

# Create the serving endpoint, named after the model.
endpoint = client.create_endpoint(
    name=model_name,
    config={
        "served_entities": [served_entity],
        "traffic_config": {"routes": [traffic_route]},
    },
)