In [0]:
# When True, the chain cell replaces the literals below with values read from
# names.yaml and rag_chain_config.yaml.
USE_YAML = False

# Catalog and schema that prefix the fully qualified table and model names.
CATALOG = "perdomo_demos"
SCHEMA = "rag_chatbot"

TABLE_NAME = ".".join((CATALOG, SCHEMA, "databricks_documentation"))
MODEL_NAME = ".".join((CATALOG, SCHEMA, "databricks_docs_rag_chatbot"))

VECTOR_SEARCH_ENDPOINT_NAME = "rag_chatbot_endpoint"
# The index name is the table name followed by a fixed suffix.
VS_INDEX_FULLNAME = TABLE_NAME + "_vs_index"

# All resource names gathered in one mapping.
NAMES = dict(
    table_name=TABLE_NAME,
    model_name=MODEL_NAME,
    vector_search_endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
    vs_index_fullname=VS_INDEX_FULLNAME,
)

# Prompt text; its placeholders are the names listed under
# "llm_prompt_template_variables". Adjacent literals are concatenated verbatim,
# so the trailing spaces inside each piece are significant.
_LLM_PROMPT_TEMPLATE = (
    "You are a trusted AI assistant that helps answer questions based only on the provided information. "
    "If you do not know the answer to a question, you truthfully say you do not know. "
    "Here is the history of the current conversation you are having with your user: {chat_history}. "
    "And here is some context which may or may not help you answer the following question: {context}.  "
    "Answer directly, do not repeat the question, do not start with something like: the answer to the question, "
    "do not add AI in front of your answer, do not say: here is the answer, do not mention the context or the question. "
    "Based on this context, answer this question: {question}"
)

# Serving endpoints the chain talks to.
_DATABRICKS_RESOURCES = {
    "llm_endpoint_name": "databricks-meta-llama-3-3-70b-instruct",
    "vector_search_endpoint_name": VECTOR_SEARCH_ENDPOINT_NAME,
}

# Sample request in the chat-messages shape the chain consumes.
_INPUT_EXAMPLE = {
    "messages": [{"content": "How can I disable serverless?", "role": "user"}]
}

# Generation parameters and prompt definition for the chat model.
_LLM_CONFIG = {
    "llm_parameters": {"max_tokens": 1500, "temperature": 0.01},
    "llm_prompt_template": _LLM_PROMPT_TEMPLATE,
    "llm_prompt_template_variables": ["context", "chat_history", "question"],
}

# Vector search index, its column mapping and the search parameters.
_RETRIEVER_CONFIG = {
    "chunk_template": "Passage: {chunk_text}\n",
    "data_pipeline_tag": "poc",
    "parameters": {"k": 5},  # , "query_type": "ann"
    "schema": {"chunk_text": "content", "document_uri": "url", "primary_key": "id"},
    "vector_search_index": VS_INDEX_FULLNAME,
}

# Complete chain configuration, assembled section by section.
CHAIN_CONFIG = {
    "databricks_resources": _DATABRICKS_RESOURCES,
    "input_example": _INPUT_EXAMPLE,
    "llm_config": _LLM_CONFIG,
    "retriever_config": _RETRIEVER_CONFIG,
}

In [0]:
#%pip install -r requirements.txt # Original had -U
#dbutils.library.restartPython()

import mlflow
import yaml

from operator import itemgetter

from langchain.schema.runnable import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

from databricks_langchain.vectorstores import DatabricksVectorSearch
from databricks_langchain.chat_models import ChatDatabricks

# In YAML mode the resource names and the chain configuration come from files
# next to the notebook instead of the inline literals.
if USE_YAML:
    with open("names.yaml", "r") as file:
        names = yaml.safe_load(file)

    # Missing keys yield None because the lookups go through dict.get.
    TABLE_NAME, VECTOR_SEARCH_ENDPOINT_NAME, VS_INDEX_FULLNAME = (
        names.get(key)
        for key in ("table_name", "vector_search_endpoint_name", "vs_index_fullname")
    )

    # From here on CHAIN_CONFIG is an MLflow ModelConfig backed by the YAML file.
    CHAIN_CONFIG = mlflow.models.ModelConfig(
        development_config="rag_chain_config.yaml"
    )

# Methods to format the docs returned by the retriever into the prompt (keep only the text from chunks)
def format_context(docs, chunk_template="Passage: {chunk_text}\n"):
    """Render retrieved documents as a single context string for the prompt.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute; only
            that text is kept.
        chunk_template: ``str.format`` template applied to every document. It
            must contain a ``{chunk_text}`` placeholder. The default is the
            layout that used to be hard-coded here and equals the
            ``chunk_template`` entry of the retriever config, so callers that
            pass only ``docs`` get unchanged output.

    Returns:
        The rendered chunks concatenated in input order; an empty string when
        ``docs`` is empty.
    """
    # Only the template is parsed by str.format, so braces inside the chunk
    # text itself are emitted literally.
    return "".join(
        chunk_template.format(chunk_text=doc.page_content) for doc in docs
    )
  
def extract_user_query_string(chat_messages_array: list) -> str:
    """Return the text of the most recent chat message.

    Args:
        chat_messages_array: Ordered chat messages, each a mapping with a
            ``"content"`` key. The parameter was previously annotated as
            ``str``, which contradicted the indexing done below.

    Returns:
        The ``"content"`` value of the last message.

    Raises:
        IndexError: If ``chat_messages_array`` is empty.
        KeyError: If the last message has no ``"content"`` key.
    """
    return chat_messages_array[-1]["content"]

def extract_previous_messages(chat_messages_array):
    """Render every message except the last one as ``role: content`` lines.

    The result always starts with a newline and each rendered message ends
    with one, so an input with zero or one message yields just ``"\\n"``.
    """
    earlier_turns = chat_messages_array[:-1]
    rendered = [
        turn["role"] + ": " + turn["content"] + "\n" for turn in earlier_turns
    ]
    return "\n" + "".join(rendered)

def combine_all_messages_for_vector_search(chat_messages_array):
    """Build the retrieval query text: the rendered earlier messages followed by
    the content of the latest message."""
    history_text = extract_previous_messages(chat_messages_array)
    latest_query = extract_user_query_string(chat_messages_array)
    return history_text + latest_query

# Switch on MLflow tracing for LangChain before the chain components are built.
mlflow.langchain.autolog()

# Pull the three configuration sections used below out of CHAIN_CONFIG.
databricks_resources, llm_config, retriever_config = (
    CHAIN_CONFIG.get(section)
    for section in ("databricks_resources", "llm_config", "retriever_config")
)

# Column mapping of the index, shared by the retriever and the MLflow retriever schema.
retriever_schema = retriever_config.get("schema")

# Turn the Vector Search index into a LangChain retriever.
# Search options such as `k` have to be passed through `search_kwargs`: keyword
# arguments given directly to `as_retriever` are not forwarded to the similarity
# search, so the former `as_retriever(k=5, query_type="ann")` did not apply them.
# The options are read from retriever_config["parameters"] so the configuration
# is the single place to tune retrieval. `query_type` is not set here; add it to
# that mapping if a specific query type is required.
vector_search_as_retriever = DatabricksVectorSearch(
    endpoint=databricks_resources.get("vector_search_endpoint_name"),
    index_name=retriever_config.get("vector_search_index"),
    columns=[
        retriever_schema.get("primary_key"),
        retriever_schema.get("chunk_text"),
        retriever_schema.get("document_uri"),
    ],
).as_retriever(search_kwargs=retriever_config.get("parameters") or {})

# Tell MLflow which returned columns hold the key, the chunk text and the source URI.
mlflow.models.set_retriever_schema(
    primary_key=retriever_schema.get("primary_key"),
    text_column=retriever_schema.get("chunk_text"),
    doc_uri=retriever_schema.get("document_uri"),
)

# Prompt: template text and its declared variables both come from llm_config.
prompt = PromptTemplate(
    input_variables=llm_config.get("llm_prompt_template_variables"),
    template=llm_config.get("llm_prompt_template"),
)

# Chat model: served by the configured LLM endpoint, with the generation
# parameters from llm_config forwarded as extra request parameters.
model = ChatDatabricks(
    extra_params=llm_config.get("llm_parameters"),
    endpoint=databricks_resources.get("llm_endpoint_name"),
)

# RAG chain, assembled from named branches that each start from the request's
# "messages" list.
messages_of = itemgetter("messages")

# Latest message text -> {question}
question_branch = messages_of | RunnableLambda(extract_user_query_string)

# Whole conversation -> vector search -> formatted passages -> {context}
context_branch = (
    messages_of
    | RunnableLambda(combine_all_messages_for_vector_search)
    | vector_search_as_retriever
    | RunnableLambda(format_context)
)

# Earlier messages rendered as text -> {chat_history}
history_branch = messages_of | RunnableLambda(extract_previous_messages)

chain = (
    {
        "question": question_branch,
        "context": context_branch,
        "chat_history": history_branch,
    }
    | prompt
    | model
    | StrOutputParser()
)

# Register the chain as the model object MLflow should log.
mlflow.models.set_model(model=chain)