# Retriever notebook :

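In this notebook, we will build an evaluation question dataset, store it in a Unity Catalog Delta table, and then define, log, and locally test a RAG chain whose retriever is backed by a Databricks Vector Search index.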

In [0]:
%pip install --quiet -U databricks-sdk==0.64.0 "databricks-langchain>=0.4.0"  "mlflow[databricks]==3.4.0" langchain==0.3.25 langchain_core==0.3.59 databricks-vectorsearch==0.57 pydantic==2.10.1 zeroentropy
# Restart the Python process after the install so that the following cells
# import the package versions installed above.
dbutils.library.restartPython()

Dataset source:
https://github.com/mlflow/mlflow/blob/master/examples/llms/RAG/question_answer_source.csv

## 0- Init of the mlflow tracking

In this cell, the mlflow experiment is set to "rag_demo"

In [0]:
%run ../_config/config_rag

## 1- Create eval dataset

In this cell, some questions are extracted from the MLflow repository on GitHub to create an evaluation dataset.

"https://raw.githubusercontent.com/mlflow/mlflow/master/examples/llms/RAG/static_evaluation_dataset.csv"

In [0]:
import csv
import pandas as pd

# Static evaluation dataset published in the MLflow repository.
QUESTIONS_SOURCE = "https://raw.githubusercontent.com/mlflow/mlflow/master/examples/llms/RAG/static_evaluation_dataset.csv"
eval_data = pd.read_csv(QUESTIONS_SOURCE)

# Select with a list (`[["question"]]`) so pandas returns a one-column
# DataFrame rather than a Series: `spark.createDataFrame` documents pandas
# DataFrames as input, and the resulting Spark column is already named
# "question", so no rename from an auto-generated "_1" column is needed.
question_serie = spark.createDataFrame(eval_data[["question"]])

A Delta table is created in the working Unity Catalog schema (`demo.demo` by default) to store the questions.
 

In [0]:
# Notebook parameters are exposed as text widgets: (widget name, default value).
# Declare the three widgets first (catalog, schema, table name) ...
dbutils.widgets.text("catalog", "demo")
dbutils.widgets.text("schema", "demo")
dbutils.widgets.text("table_name", "databricks_mlflow_questions")

# ... then read their current values into the variables used by later cells.
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
table_name = dbutils.widgets.get("table_name")

In [0]:
%sql
-- Select the catalog and schema given by the `catalog` and `schema` parameters.
USE CATALOG IDENTIFIER(:catalog);
USE SCHEMA IDENTIFIER(:schema);

-- (Re)create the questions table named by the `table_name` parameter.
-- CREATE OR REPLACE replaces any existing table of that name.
--   id         : identity column, always generated by the table
--   question   : the question text
--   created_at : defaults to the current timestamp
CREATE OR REPLACE TABLE  IDENTIFIER(:table_name) (
  id BIGINT GENERATED ALWAYS AS IDENTITY,
  question STRING,
  created_at TIMESTAMP DEFAULT current_timestamp
)
USING DELTA
-- Change Data Feed is enabled, and the column-defaults table feature is
-- switched on for the DEFAULT clause on created_at.
TBLPROPERTIES (
  'delta.enableChangeDataFeed' = 'true',
  'delta.feature.allowColumnDefaults' = 'enabled'
);

-- NOTE(review): the disabled statement below targets a different table
-- (databricks_document_docling); confirm whether it can be removed.
--ALTER TABLE databricks_document_docling ALTER COLUMN created_at SET DEFAULT CURRENT_TIMESTAMP();

In [0]:
# Fully qualified Unity Catalog name: <catalog>.<schema>.<table>
table_full_name = f"{catalog}.{schema}.{table_name}"

# Append the questions to the Delta table under that name.
(
    question_serie.write
    .format("delta")
    .mode("append")
    .saveAsTable(table_full_name)
)

In [0]:
%sql
-- Preview two rows of the questions table.
SELECT *
FROM IDENTIFIER(:table_name)
LIMIT 2;

## 2- Create rag chain

No serving endpoint is created here: the workspace has already reached its limit of serving endpoints.

In [0]:
import yaml


# Configuration of the RAG chain. It is written to `rag_chain_config.yaml`
# so that `rag_chain.py` can load it through `mlflow.models.ModelConfig`,
# and it is also passed as `model_config` when the chain is logged.
rag_chain_config = {
    # Names of the Databricks resources the chain talks to.
    "databricks_resources": {
        "llm_endpoint_name": "chat_gpt_4o_mini",
        "vector_search_endpoint_name": "databricks_document_vs_endpoint_docling",
    },
    # Example request, used as the model input example and for the local test.
    "input_example": {
        "messages": [{"content": "What are  the differences between MLflow, managed and open source, on databricks ?", "role": "user"}]
    },
    # Generation settings: LLM parameters and the prompt template with its variables.
    "llm_config": {
        "llm_parameters": {"max_tokens": 1500, "temperature": 0.01},
        "llm_prompt_template": "You are a trusted AI assistant that helps answer questions based only on the provided information. If you do not know the answer to a question, you truthfully say you do not know. Here is the history of the current conversation you are having with your user: {chat_history}. And here is some context which may or may not help you answer the following question: {context}.  Answer directly, do not repeat the question, do not start with something like: the answer to the question, do not add AI in front of your answer, do not say: here is the answer, do not mention the context or the question. Based on this context, answer this question: {question}",
        "llm_prompt_template_variables": ["context", "chat_history", "question"],
    },
    # Retrieval settings: index, source table, column mapping and search parameters.
    "retriever_config": {
        "chunk_template": "Passage: {chunk_text}\n",
        "data_pipeline_tag": "poc",
        "parameters": {"k": 3, "query_type": "ann"},
        "schema": {"raw_text": "content", "chunk_text": "contextualize_content", "document_uri": "url", "primary_key": "id"},
        "vector_search_index": f"{catalog}.{schema}.databricks_document_docling_vs_index",
        "source_table":f"{catalog}.{schema}.databricks_document_docling"
    },
}

# Writing the YAML file is best-effort: only a failed file write is tolerated
# (OSError). Any other exception is a real problem and is allowed to surface
# instead of being hidden by a bare `except`.
try:
    with open('rag_chain_config.yaml', 'w') as f:
        yaml.dump(rag_chain_config, f)
except OSError:
    print('pass to work on build job')

# Load the configuration back from the YAML file for use in this notebook.
model_config = mlflow.models.ModelConfig(development_config='rag_chain_config.yaml')

In [0]:
%%writefile rag_chain.py
import os
import mlflow
from operator import itemgetter
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

from databricks_langchain import DatabricksVectorSearch, ChatDatabricks
from datetime import datetime

## Enable MLflow Tracing
mlflow.langchain.autolog()

# Load the chain configuration from the YAML file.
model_config = mlflow.models.ModelConfig(development_config='rag_chain_config.yaml')

# Pull out the three top-level configuration sections used below.
databricks_resources, llm_config, retriever_config = (
    model_config.get(section)
    for section in ("databricks_resources", "llm_config", "retriever_config")
)

# Column-name mapping of the vector search index.
vector_search_schema = retriever_config.get("schema")
def extract_user_query_string(chat_messages_array):
    """Return the "content" of the last message in the list (the current user query)."""
    latest_message = chat_messages_array[-1]
    return latest_message["content"]

#@mlflow.trace(name="extract_previous_messages")
def extract_previous_messages(chat_messages_array):
    """Render every message except the last one as "<role>: <content>" lines.

    The result always starts with a newline and each rendered message ends
    with one, so a list with zero or one message yields just "\\n".
    """
    earlier_messages = chat_messages_array[:-1]
    rendered = "".join(
        entry["role"] + ": " + entry["content"] + "\n"
        for entry in earlier_messages
    )
    return "\n" + rendered

#@mlflow.trace(name="combine_all_messages_for_vector_search")
def combine_all_messages_for_vector_search(chat_messages_array):
    """Build the vector search query: the rendered history followed by the latest message content."""
    history = extract_previous_messages(chat_messages_array)
    latest_query = extract_user_query_string(chat_messages_array)
    return history + latest_query


# Wrap the Vector Search index in a LangChain vector store that returns the
# primary key, chunk text and document URI columns ...
_vector_store = DatabricksVectorSearch(
    endpoint=databricks_resources.get("vector_search_endpoint_name"),
    index_name=retriever_config.get("vector_search_index"),
    columns=[
        vector_search_schema.get(field)
        for field in ("primary_key", "chunk_text", "document_uri")
    ],
)

# ... and expose it as a retriever using the configured search parameters.
vector_search_as_retriever = _vector_store.as_retriever(
    search_kwargs=retriever_config.get("parameters")
)

@mlflow.trace(name="vector_search_retrieval", span_type="RETRIEVER", attributes={
    "index_name": retriever_config.get("vector_search_index"),
    "endpoint": databricks_resources.get("vector_search_endpoint_name")
})
def retrieve_documents(query: str):
    """Run the vector search retriever inside a traced RETRIEVER span.

    Args:
        query: Text sent to the retriever.

    Returns:
        Whatever `vector_search_as_retriever.invoke(query)` returns.
    """
    # Query the index exactly once: a second identical call would double the
    # vector search latency and cost for every request.
    results = vector_search_as_retriever.invoke(query)

    # Record the number of results on the active span, when there is one.
    current_span = mlflow.get_current_active_span()
    if current_span:
        current_span.set_attribute("num_results", len(results))
    return results

# Declare which index columns hold the primary key, the chunk text and the
# document URI. Required to:
# 1. Enable the RAG Studio Review App to properly display retrieved chunks
# 2. Enable evaluation suite to measure the retriever
_primary_key_column = vector_search_schema.get("primary_key")
_text_column = vector_search_schema.get("chunk_text")
_doc_uri_column = vector_search_schema.get("document_uri")
mlflow.models.set_retriever_schema(
    primary_key=_primary_key_column,
    text_column=_text_column,
    doc_uri=_doc_uri_column,
)


# Method to format the docs returned by the retriever into the prompt
#@mlflow.trace(name="format_context")
def format_context(docs):
    """Concatenate the retrieved documents into one context string.

    Each document contributes "Document : " followed by its `page_content`;
    the pieces are joined without any separator. The `chunk_template` entry
    of the retriever configuration is not applied here.
    """
    context = ""
    for doc in docs:
        context += f"Document : {doc.page_content}"
    return context


# Generation prompt, built from the template and variable names in the LLM config.
_prompt_template_text = llm_config.get("llm_prompt_template")
_prompt_variables = llm_config.get("llm_prompt_template_variables")
prompt = PromptTemplate(
    template=_prompt_template_text,
    input_variables=_prompt_variables,
)

# Chat model served by the configured endpoint, with the configured LLM parameters.
_llm_endpoint = databricks_resources.get("llm_endpoint_name")
_llm_parameters = llm_config.get("llm_parameters")
model = ChatDatabricks(
    endpoint=_llm_endpoint,
    extra_params=_llm_parameters,
)

# RAG Chain
# Every branch starts from the "messages" list of the incoming request.

# "question": content of the last message (the current user query).
_question_branch = itemgetter("messages") | RunnableLambda(extract_user_query_string)

# "context": history + query -> vector search -> formatted documents.
_context_branch = (
    itemgetter("messages")
    | RunnableLambda(combine_all_messages_for_vector_search)
    | retrieve_documents
    | RunnableLambda(format_context)
)

# "chat_history": every message before the last one, rendered as text.
_history_branch = itemgetter("messages") | RunnableLambda(extract_previous_messages)

# The three branches fill the prompt variables; the prompt feeds the model and
# the model output is parsed into a plain string.
chain = (
    {
        "question": _question_branch,
        "context": _context_branch,
        "chat_history": _history_branch,
    }
    | prompt
    | model
    | StrOutputParser()
)

# Tell MLflow logging where to find your chain.
mlflow.models.set_model(model=chain)

# COMMAND ----------

#print(chain.invoke(model_config.get("input_example")))

In [0]:
from mlflow.models.resources import DatabricksVectorSearchIndex, DatabricksServingEndpoint

# Log the chain to MLflow as code (the `rag_chain.py` file) together with its
# configuration, an input example and the Databricks resources it depends on.
with mlflow.start_run(run_name="demo_rag_chain"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model="rag_chain.py",  # Chain code file e.g., /path/to/the/chain.py
        model_config=rag_chain_config,  # Chain configuration
        # `name` replaces `artifact_path`, which is deprecated in MLflow 3
        # (this notebook pins mlflow 3.4.0).
        name="rag_chain",
        input_example=model_config.get("input_example"),  # Save the chain's input schema
        resources=[
            DatabricksVectorSearchIndex(index_name=model_config.get("retriever_config").get("vector_search_index")),
            DatabricksServingEndpoint(endpoint_name=model_config.get("databricks_resources").get("llm_endpoint_name"))
        ],
        extra_pip_requirements=["databricks-connect"]
    )

# Test the chain locally: reload the logged model and invoke it on the input example.
rag_chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
rag_chain.invoke(model_config.get("input_example"))