In [0]:
!pip install requirements.txt

In [0]:
# Restart the Python interpreter via the Databricks `dbutils` helper. This runs right
# after the install cell so that freshly installed packages can be imported below.
dbutils.library.restartPython()

In [0]:
import os

import mlflow
from langchain_databricks.vectorstores import DatabricksVectorSearch

### Write the chain to a .py file to avoid serialization issues
When working with LangChain and complex LLM-based workflows, you might need to write the LLM chain to a companion file to avoid serialization issues for several reasons:
1. Serialization limitations: Some LLM chain components, particularly custom tools or complex objects, may not be directly JSON serializable. This can cause errors when trying to save or transmit the entire chain as a single object.
2. Preservation of functionality: Certain elements of an LLM chain, such as custom functions or methods, cannot be fully serialized without losing their functionality. Writing these components separately allows you to maintain their integrity.
3. Flexibility and modularity: By separating the LLM chain definition from its serialized state, you can more easily update or modify parts of the chain without affecting the entire workflow.
4. Compatibility across environments: Serialization issues can arise when moving LangChain objects between different environments or platforms. Writing the chain separately ensures better compatibility and portability.
5. Security considerations: Serializing entire LLM chains might inadvertently expose sensitive information or credentials. Separating the chain definition allows for better control over what is serialized and transmitted.
6. Performance optimization: For large or complex chains, serializing only the necessary components can improve performance and reduce memory usage.
7. Version control and maintenance: Keeping the LLM chain definition in a separate file makes it easier to track changes, manage versions, and maintain the code over time.

In [0]:
%%writefile chain.py
import os
import mlflow
from operator import itemgetter
from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Enable MLflow Tracing: switch on LangChain autologging before the chain below is built
mlflow.langchain.autolog()

def extract_user_query_string(chat_messages_array):
    """Return the text of the most recent chat message.

    Args:
        chat_messages_array: Sequence of message mappings, each with a "content" entry.

    Returns:
        The "content" value of the last element in the sequence.
    """
    latest_message = chat_messages_array[-1]
    return latest_message["content"]

def extract_previous_messages(chat_messages_array):
    """Render every message except the last one as "role: content" lines.

    Args:
        chat_messages_array: Sequence of message mappings with "role" and "content" entries.

    Returns:
        A string that starts with a newline, followed by one "role: content" line
        (each terminated by a newline) per earlier message.
    """
    earlier_lines = [
        turn["role"] + ": " + turn["content"] + "\n"
        for turn in chat_messages_array[:-1]
    ]
    return "\n" + "".join(earlier_lines)

def combine_all_messages_for_vector_search(chat_messages_array):
    """Build the retrieval query text: formatted earlier messages followed by the latest message.

    Args:
        chat_messages_array: Sequence of message mappings with "role" and "content" entries.

    Returns:
        The output of `extract_previous_messages` concatenated with the output of
        `extract_user_query_string`.
    """
    history_text = extract_previous_messages(chat_messages_array)
    latest_query = extract_user_query_string(chat_messages_array)
    return history_text + latest_query

# Load the chain configuration from the local YAML config file
model_config = mlflow.models.ModelConfig(development_config='rag_model_config.yaml')

# Top-level configuration sections used throughout the chain
databricks_resources = model_config.get("databricks_resources")
retriever_config = model_config.get("retriever_config")
llm_config = model_config.get("llm_config")

# Open a handle on the Vector Search index
vs_client = VectorSearchClient(disable_notice=True)
vs_index = vs_client.get_index(
    endpoint_name=databricks_resources.get("vector_search_endpoint_name"),
    index_name=retriever_config.get("vector_search_index"),
)

# Column names of the index, read once and reused below
vector_search_schema = retriever_config.get("schema")
primary_key_column = vector_search_schema.get("primary_key")
chunk_text_column = vector_search_schema.get("chunk_text")

# Wrap the index as a LangChain vector store, then expose it as a retriever.
# The `document_uri` column is currently not requested; add
# vector_search_schema.get("document_uri") to `columns` to include it.
vector_store = DatabricksVectorSearch(
    vs_index,
    text_column=chunk_text_column,
    columns=[primary_key_column, chunk_text_column],
)
vector_search_as_retriever = vector_store.as_retriever(
    search_kwargs=retriever_config.get("parameters")
)

# Register the retriever schema with MLflow. Required to:
# 1. Enable the RAG Studio Review App to properly display retrieved chunks
# 2. Enable evaluation suite to measure the retriever
# (`doc_uri` is not passed, matching the columns requested above.)
mlflow.models.set_retriever_schema(
    primary_key=primary_key_column,
    text_column=chunk_text_column,
)

def format_context(docs):
    """Format the documents returned by the retriever into a single prompt-context string.

    Each document's `page_content` is substituted into the configured
    "chunk_template" (as `chunk_text`) and the rendered chunks are concatenated
    with no separator.

    Args:
        docs: Iterable of retrieved documents exposing a `page_content` attribute.

    Returns:
        The concatenation of all rendered chunks.
    """
    chunk_template = retriever_config.get("chunk_template")
    rendered_chunks = []
    for doc in docs:
        rendered_chunks.append(chunk_template.format(chunk_text=doc.page_content))
    return "".join(rendered_chunks)

# Prompt used for generation; template text and variable names come from the LLM config
prompt = PromptTemplate(
    template=llm_config.get("llm_prompt_template"),
    input_variables=llm_config.get("llm_prompt_template_variables"),
)

# Chat model served from the configured Databricks endpoint
model = ChatDatabricks(
    endpoint=databricks_resources.get("llm_endpoint_name"),
    extra_params=llm_config.get("llm_parameters"),
)

# Every branch starts from the request's "messages" list.
# question: text of the latest message
question_branch = itemgetter("messages") | RunnableLambda(extract_user_query_string)
# context: all messages combined -> retriever -> formatted chunks
context_branch = (
    itemgetter("messages")
    | RunnableLambda(combine_all_messages_for_vector_search)
    | vector_search_as_retriever
    | RunnableLambda(format_context)
)
# chat_history: every message except the latest, rendered as text
chat_history_branch = itemgetter("messages") | RunnableLambda(extract_previous_messages)

# RAG Chain: prompt variables -> prompt -> LLM -> plain string
prompt_inputs = {
    "question": question_branch,
    "context": context_branch,
    "chat_history": chat_history_branch,
}
chain = prompt_inputs | prompt | model | StrOutputParser()

# Tell MLflow logging where to find your chain.
mlflow.models.set_model(model=chain)

In [0]:
# Log the model to MLflow from code: MLflow records chain.py together with its config file.
# The config is loaded here because chain.py was only written to disk above (not executed),
# so `model_config` does not otherwise exist in the notebook namespace.
model_config = mlflow.models.ModelConfig(development_config='rag_model_config.yaml')

with mlflow.start_run(run_name="call_transcript_rag"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model=os.path.join(os.getcwd(), 'chain.py'),  # Chain code file e.g., /path/to/the/chain.py
        model_config='rag_model_config.yaml',  # Chain configuration
        artifact_path="chain",  # Required by MLflow
        input_example=model_config.get("input_example"),  # Save the chain's input schema.  MLflow will execute the chain before logging & capture its output schema.
    )

In [0]:
# Smoke-test the logged chain: reload it from MLflow using the URI returned by
# `log_model`, then invoke it on the example request stored in the model config.
chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
sample_request = model_config.get("input_example")
chain.invoke(sample_request)