In [0]:
%pip install -U -q -r ../requirements.txt
dbutils.library.restartPython()

In [0]:
%run ../00-init

In [0]:
%run ../sql_query

In [0]:
%run ../Main_Config

In [0]:
import yaml
import os
import logging

# Define the logger used below explicitly so this cell does not depend on one
# leaking in from another notebook/cell.
logger = logging.getLogger(__name__)

# RAG chain configuration. It is written to rag_chain_config.yaml below, which
# chain.py reads and mlflow.langchain.log_model attaches to the logged model.
# NOTE(review): the sp_client_id / sp_client_secret values resolved from the
# secret scope end up in plain text in that yaml (and therefore in the logged
# model artifacts) — consider resolving secrets at serving time instead.
rag_chain_config = {
    "databricks_resources": {
        "llm_endpoint_name": "databricks-meta-llama-3-3-70b-instruct",
        "vector_search_endpoint_name": VECTOR_SEARCH_ENDPOINT_NAME,
        "server_hostname": dbutils.secrets.get(scope = "DATABRICKS_CREDENTIALS", key = "server_hostname"),
        "http_path": dbutils.secrets.get(scope = "DATABRICKS_CREDENTIALS", key = "http_path"),
        # "access_token": dbutils.secrets.get(scope = "DATABRICKS_CREDENTIALS", key = "access_token"),
        "sp_client_id": dbutils.secrets.get(scope = "DATABRICKS_CREDENTIALS", key = "sp_client_id"),
        "sp_client_secret": dbutils.secrets.get(scope = "DATABRICKS_CREDENTIALS", key = "sp_client_secret")
    },
    "input_example": {
        "messages": [{"content": "What is spire?", "role": "user"}],
        "extra_params": {"user_id": "unknown.person@abc.com"}
    },
    "llm_config": {
        "llm_parameters": {"max_tokens": 1500, "temperature": 0.01},
        "llm_prompt_template": """You are an expert AI assistant. Your responses must follow these rules:
General Behavior
• As an AI assistant, rely only on the provided context and conversation history—do not hallucinate or introduce external information.
• If the answer is not known from the context, just say "I couldn't find enough relevant context—could you please provide more details so I can give you an accurate response".
• You must handle both casual conversation and domain-specific queries gracefully.
• When the user's question is vague, ambiguous, or lacks enough detail to generate a helpful response, ask a clear and concise follow-up question to clarify their intent. Your goal is to help the user refine their question so you can provide an accurate and relevant answer.

If the question is clear: answer it directly using the retrieved context.

If the question is unclear or incomplete: respond with a short clarification like:

"Could you clarify what you're referring to with [ambiguous term]?"

"Can you provide more context or specify what you're looking for?"

Conversation Handling
• For casual messages like “Hello” or “How are you?”, respond warmly and naturally, like a friendly human.
• If the user’s message is vague or ambiguous, ask a clear follow-up question rather than guessing.

• If the user asks a follow-up question, follow these strict instructions:
    1. The system will first attempt to retrieve context based only on the user's most recent question.
    2. If no relevant information is found, it will then attempt retrieval using the full conversation history.
    
– If it’s clear, answer it directly.
– If it’s unclear, ask a clarifying question before responding.
• If the user asks a factual or task-based question, answer it clearly and concisely, based only on the provided context.
• Never invent or assume information outside the context.

Response Style
Be Direct: Answer the question without restating it.

No Fillers: Avoid phrases like “Here’s the answer” or “According to the context.”

No Meta-Commentary: Don’t mention the prompt, context, or chat structure.

Be Concise: Prefer one paragraph; only use more if necessary.

Be Helpful: Always move the conversation forward with useful answers or clarifying questions.

**Last Question:**
{last_question}

**Conversation History:**  
{chat_history}

**Context:**  
{context}

**Question:**  
{question}
""",
    "llm_prompt_template_variables": ["context", "chat_history", "question","last_question"],
    },
    "retriever_config": {
        "chunk_template": "Passage: {chunk_text}\n",
        "data_pipeline_tag": "poc",
        "parameters": {
            "k": 5,
            "query_type": "ann"
        },
        "schema": {
            "chunk_text": "content",
            "primary_key": "id",
            "file_id": "file_id"
        },
        "vector_search_index": "spire_catalog.spire_schema.source_data_index"
    },
    "vector_embd_columns": ["content", "id", "file_id"],
    "user_list_name": ["spire_demo_grp"],
    "model_name": "spire_rag_final",
    "catalog_name": "spire_catalog",
    "schema_name": "spire_schema",
    "source_data_table": "source_data_index",
    "user_permission_table": "sharepoint_permissions"
}

# Save configuration to file with error handling
try:
    with open('rag_chain_config.yaml', 'w') as f:
        yaml.dump(rag_chain_config, f)
    logger.info("Configuration saved successfully")
except Exception as e:
    logger.warning(f"Failed to save configuration to file: {str(e)}")
    logger.info("Continuing with in-memory configuration for build job")

# Build the ModelConfig from the in-memory dict rather than re-reading the yaml,
# so this cell actually "continues with in-memory configuration" when the
# write above failed.
model_config = mlflow.models.ModelConfig(development_config=rag_chain_config)

In [0]:
%%writefile chain.py
import os
import time
import logging
import requests
import pandas as pd
import mlflow
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional
from operator import itemgetter

from databricks.vector_search.client import VectorSearchClient
from databricks import sql, sdk
from databricks.sdk.core import Config, oauth_service_principal

from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

from mlflow.models.rag_signatures import ChatCompletionRequest

# ----------------------------
#  Logging
# ----------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
logger = logging.getLogger(__name__)

mlflow.langchain.autolog()

# ----------------------------
#  Configuration (rag_chain_config.yaml written by the notebook)
# ----------------------------
model_config         = mlflow.models.ModelConfig(development_config="rag_chain_config.yaml")
databricks_resources = model_config.get("databricks_resources")
retriever_config     = model_config.get("retriever_config")
llm_config           = model_config.get("llm_config")
server_host_name    = databricks_resources["server_hostname"]
http_path           = databricks_resources["http_path"]
# "access_token" is commented out of the config (auth uses the service principal),
# so read it optionally; indexing it directly raised KeyError when importing this module.
access_token_value  = databricks_resources.get("access_token")
llm_endpoint = databricks_resources["llm_endpoint_name"]

input_example     = model_config.get("input_example")

# ----------------------------
#  Databricks SQL connection (service-principal OAuth)
# ----------------------------
def connect_to_databricks(server_hostname: str, http_path: str, sp_client_id: str, sp_client_secret:str):
    """Open a Databricks SQL connection authenticated as a service principal.

    Args:
        server_hostname: Workspace host name, without the ``https://`` scheme.
        http_path: HTTP path of the SQL warehouse / cluster.
        sp_client_id: Service principal OAuth client id.
        sp_client_secret: Service principal OAuth client secret.

    Returns:
        The connection object returned by ``databricks.sql.connect``.
    """
    workspace_url = f"https://{server_hostname}"

    def _sp_credentials():
        # Invoked by the SQL connector to obtain OAuth credentials for the SP.
        sp_config = Config(
            host=workspace_url,
            client_id=sp_client_id,
            client_secret=sp_client_secret,
        )
        return oauth_service_principal(sp_config)

    return sql.connect(
        server_hostname=server_hostname,
        http_path=http_path,
        credentials_provider=_sp_credentials,
    )

def extract_file_id(user_id: str) -> List[str]:
    """Look up the file ids the given user is allowed to read.

    Queries ``spire_catalog.spire_schema.sharepoint_permissions`` over a
    service-principal SQL connection built from ``databricks_resources``.

    Args:
        user_id: User identifier as stored in the permissions table (the input
            examples pass an email address — presumably that is the key format).

    Returns:
        Distinct ``file_id`` values for the user. An empty list (with a warning
        logged) when the user has no permissions.
    """
    # user_id arrives in the request payload, so it is bound as a query parameter
    # rather than interpolated into the SQL text (which allowed SQL injection).
    query = (
        "SELECT DISTINCT file_id "
        "FROM spire_catalog.spire_schema.sharepoint_permissions "
        "WHERE user_id = :user_id"
    )

    # Context managers close the cursor and connection even if the query fails.
    with connect_to_databricks(
        databricks_resources["server_hostname"],
        databricks_resources["http_path"],
        databricks_resources["sp_client_id"],
        databricks_resources["sp_client_secret"],
    ) as conn, conn.cursor() as cursor:
        cursor.execute(query, {"user_id": user_id})
        rows = cursor.fetchall()

    ids = [row[0] for row in rows]
    if not ids:
        logger.warning(f"No file permissions for user {user_id}")
    return ids

# ----------------------------
#  Message helpers
# ----------------------------
def combine_all_messages(chat_messages: List[Dict[str, str]]) -> str:
    """Build the retrieval query for a conversation.

    Despite the name, only the latest message is used: this simply delegates
    to ``extract_user_query_string``.
    """
    latest_query = extract_user_query_string(chat_messages)
    return latest_query

def extract_user_query_string(chat_messages: List[Dict[str, str]]) -> str:
    """Return the ``content`` of the last message in the conversation."""
    last_message = chat_messages[-1]
    return last_message["content"]


def extract_previous_messages(chat_messages_array, assistant_response=None):
    """Render messages as a plain-text transcript, one ``role: content`` entry per line.

    Args:
        chat_messages_array: Iterable of dicts with ``role`` and ``content`` keys.
        assistant_response: Optional reply appended as a final ``assistant`` entry
            (skipped when falsy).

    Returns:
        The transcript, with every entry preceded by a newline; ``""`` when there
        are no messages and no response.
    """
    entries = [f"{msg['role']}: {msg['content']}" for msg in chat_messages_array]
    if assistant_response:
        # Same "\nrole: content" shape as the other entries (it used to be
        # "  assistant: ...\n", which glued it onto the previous line).
        entries.append(f"assistant: {assistant_response}")
    # join avoids repeated string concatenation in a loop.
    return "".join(f"\n{entry}" for entry in entries)


def store_response_with_history(inputs_and_response: dict):
    """Return the model's answer from the chain's ``{"inputs": ..., "output": ...}`` envelope.

    Nothing is persisted here; conversation history is not kept between requests.
    The transcript that was previously built and discarded is no longer computed.

    Args:
        inputs_and_response: Dict with the original request under ``"inputs"`` and
            the generated answer under ``"output"``.

    Returns:
        The value stored under ``"output"``.
    """
    return inputs_and_response["output"]

def format_context(docs: List[Any]) -> str:
    """Render retrieved documents into the prompt's context section.

    Each document's ``page_content`` is passed through
    ``retriever_config["chunk_template"]`` and the pieces are concatenated.
    Returns a fixed placeholder when ``docs`` is empty or None.
    """
    if docs:
        chunk_template = retriever_config["chunk_template"]
        rendered = [chunk_template.format(chunk_text=doc.page_content) for doc in docs]
        return "".join(rendered)
    return "No relevant documents found."

def most_recent_message(chat_messages_array):
    """Return the user's previous question, fed to the prompt as ``last_question``.

    Searches backwards through every message before the current (last) one and
    returns the first ``user`` message found. Taking ``messages[-2]`` returned
    the assistant's reply in a normal alternating chat. Falls back to the
    current message's content when there is no earlier user message, e.g. for
    a single-turn request.
    """
    for msg in reversed(chat_messages_array[:-1]):
        if msg.get("role") == "user":
            return msg["content"]
    return chat_messages_array[-1]["content"]

def get_vector_search_index(retries: int = 3, delay: float = 2.0):
    """Fetch the configured vector search index, retrying on failure.

    Authenticates with the service principal from ``databricks_resources`` and
    looks up ``retriever_config["vector_search_index"]`` on the configured endpoint.

    Args:
        retries: Maximum number of attempts.
        delay: Seconds to wait between attempts (not after the last one).

    Returns:
        The index object returned by ``VectorSearchClient.get_index``.

    Raises:
        RuntimeError: If every attempt fails; chained to the last underlying error.
    """
    client = VectorSearchClient(
        workspace_url=f"https://{databricks_resources['server_hostname']}",
        service_principal_client_id=databricks_resources['sp_client_id'],
        service_principal_client_secret=databricks_resources['sp_client_secret'],
        disable_notice=True,
    )
    last_error: Optional[Exception] = None
    for attempt in range(1, retries + 1):
        try:
            return client.get_index(
                endpoint_name=databricks_resources["vector_search_endpoint_name"],
                index_name=retriever_config["vector_search_index"],
            )
        except Exception as e:
            last_error = e
            logger.warning(f"Index fetch failed (attempt {attempt}): {e}")
            # No point sleeping once the final attempt has failed.
            if attempt < retries:
                time.sleep(delay)
    raise RuntimeError("Could not retrieve vector search index") from last_error

# Resolved once at import time and shared by every request.
vs_index = get_vector_search_index()
schema   = retriever_config["schema"]

# ----------------------------
#  Retriever with embedded user_id
# ----------------------------
def dynamic_retriever_with_user(inputs: Dict[str, Any]) -> List[Any]:
    """Retrieve documents for the latest message, limited to files the user may read.

    Args:
        inputs: Dict with keys
            - "messages": List[{"role": .., "content": ..}]
            - "custom_inputs": {"user_id": ..} (other keys are ignored)

    Returns:
        The documents returned by the vector-search retriever for the query.
    """
    chat_messages = inputs["messages"]
    user_id = inputs["custom_inputs"]["user_id"]

    # Permission lookup: only chunks whose file_id the user can read are searchable.
    file_ids = extract_file_id(user_id)

    # Copy so the per-request filter is not written into the shared config dict.
    params = retriever_config["parameters"].copy()
    # With no permissions, filter on a sentinel id instead of dropping the filter,
    # so the search stays restricted (presumably matching no documents).
    params["filters"] = {"file_id": file_ids if file_ids else ["__no_permission__"]}

    retriever = DatabricksVectorSearch(
        vs_index,
        text_column=schema["chunk_text"],
        columns=[
            schema["primary_key"],
            schema["chunk_text"],
            schema["file_id"],
        ],
    ).as_retriever(search_kwargs=params)

    query = combine_all_messages(chat_messages)
    # invoke() is the supported replacement for the deprecated get_relevant_documents().
    return retriever.invoke(query)

runnable_retriever = RunnableLambda(dynamic_retriever_with_user)

prompt = PromptTemplate(
    template=llm_config["llm_prompt_template"],
    # Take the variable list from the config so it stays in sync with the
    # template's placeholders (the hard-coded list omitted "last_question").
    input_variables=llm_config["llm_prompt_template_variables"],
)

model = ChatDatabricks(
    endpoint=databricks_resources["llm_endpoint_name"],
    extra_params=llm_config["llm_parameters"],
)


def _history_before_last(chat_messages):
    """Transcript of every message except the last one (the current question)."""
    return extract_previous_messages(chat_messages[:-1])


core_chain = (
    {
        "context": (
            {"messages": itemgetter("messages"), "custom_inputs": itemgetter("extra_params")}
            | runnable_retriever
            | RunnableLambda(format_context)
        ),
        # Extract the question string
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        # Chat history is everything except the last message; the current question
        # is already passed separately as "question".
        "chat_history": itemgetter("messages") | RunnableLambda(_history_before_last),
        "last_question": itemgetter("messages") | RunnableLambda(most_recent_message),
    }
    | prompt
    | model
    | StrOutputParser()
)

# Wrap the core chain so the final step sees both the request and the answer.
chain = RunnableLambda(lambda inputs: {"inputs": inputs, "output": core_chain.invoke(inputs)}) | RunnableLambda(store_response_with_history)


mlflow.models.set_retriever_schema(
    primary_key=schema["primary_key"],
    text_column=schema["chunk_text"],
)
mlflow.models.set_model(model=chain)




In [0]:
# Example request stored with the logged model. It is taken from the RAG config
# so no real user's email address is hard-coded in the notebook.
input_example = model_config.get("input_example")

with mlflow.start_run(run_name="demo_rag_quickstart"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model=os.path.join(os.getcwd(), 'chain.py'),  # models-from-code: path to the chain script
        model_config='rag_chain_config.yaml',  # Chain configuration
        artifact_path="chain",  # Required by MLflow
        input_example=input_example,
    )


In [0]:
# Smoke-test the logged model with the same request that was logged alongside it,
# instead of repeating a hard-coded request (and user email) here.
chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
chain.invoke(input_example)

In [0]:
from databricks import agents

# Unity Catalog name for the registered model: <catalog>.<schema>.<model_name>.
# NOTE(review): `catalog` and `db` are not defined in this notebook (presumably
# set by ../00-init or ../Main_Config); the RAG config also carries
# catalog_name / schema_name — confirm which pair is authoritative.
MODEL_NAME = model_config.get("model_name")
MODEL_NAME_FQN = "{}.{}.{}".format(catalog, db, MODEL_NAME)

In [0]:
# Register the logged chain as a new version of the Unity Catalog model.
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_chain_info.model_uri,
    name=MODEL_NAME_FQN,
)

# Deploy that version: enables the Review App and creates a serving endpoint
# (scaled to zero when idle), then block until the endpoint is ready.
deployment_info = agents.deploy(
    model_name=MODEL_NAME_FQN,
    model_version=uc_registered_model_info.version,
    scale_to_zero=True,
)
wait_for_model_serving_endpoint_to_be_ready(deployment_info.endpoint_name)

In [0]:
# Grant query access on the deployed model to the configured principals.
# NOTE(review): `user_list_name` is not defined in this notebook (presumably in
# ../Main_Config); rag_chain_config also has a "user_list_name" entry — confirm
# which one should be used.
user_list = user_list_name

agents.set_permissions(
    model_name=MODEL_NAME_FQN,
    users=user_list,
    permission_level=agents.PermissionLevel.CAN_QUERY,
)
