# RAG Chain - MLflow Logging and Loading

In this notebook, we will log our RAG (Retrieval Augmented Generation) chain to MLflow, then load it back and test it.

The RAG chain combines vector search retrieval with LLM generation to answer questions based on your organization's documents.

## Install Required Packages

Install the necessary libraries for:
- MLflow integration (`mlflow[databricks]`)
- LangChain framework (`langchain`, `langchain_core`)
- Databricks integrations (`databricks-sdk`, `databricks-langchain`, `databricks-vectorsearch`)

In [0]:
%pip install --quiet -U databricks-sdk==0.64.0 "databricks-langchain>=0.4.0"  "mlflow[databricks]==3.4.0" langchain==0.3.25 langchain_core==0.3.59 databricks-vectorsearch==0.57


In [0]:
# Restart Python kernel to load newly installed packages
dbutils.library.restartPython()

## 0- Initialization

Before logging our RAG chain, we need to initialize the MLflow experiment and Unity Catalog configuration.

### 0-1 Initialize MLflow Tracking

This configuration script sets up the MLflow experiment where all runs, metrics, and models will be tracked.
The experiment is named "rag_chain_demo" and will be used to organize all related model versions and evaluations.

In [0]:
%run ../_config/config_rag

### 0-2 Initialize Unity Catalog

Load the Unity Catalog configuration (catalog, schema) that defines where our vector search indexes and models are stored.
This ensures consistent namespace usage across all notebooks in the project.

In [0]:
%run "../_config/config_unity_catalog"

## 1- Log and Load RAG Chain with MLflow

In this section, we:
1. Load the RAG chain configuration from YAML
2. Log the chain to MLflow with all dependencies
3. Test the logged model locally

**Note:** This notebook does not create any serving endpoints, to avoid overloading the workspace with them.

### 1-1 Load the YAML Configuration

The YAML configuration file (`rag_chain_config.yaml`) contains:
- Databricks resource names (LLM endpoint, vector search index)
- LLM parameters (temperature, max_tokens, prompt template)
- Retriever configuration (number of results, search parameters)
- Input examples for testing

Using YAML configuration allows us to change model behavior without modifying code.

In [0]:
import yaml

# Load YAML config as dict
with open('rag_chain_config.yaml', 'r') as f:
    rag_chain_config_dict = yaml.safe_load(f)  # The LLM endpoint has evolved to claude_45 since the first version
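
The logging step later in this notebook reads three nested keys from this dictionary: `input_example`, `retriever_config.vector_search_index`, and `databricks_resources.llm_endpoint_name`. If one of them is missing from the YAML file, the error only shows up in the middle of the logging call. Below is a minimal sketch of an up-front check, assuming those three paths are the only required ones. The helper name `find_missing_paths` and the sample values are hypothetical and not part of the project.

```python
from collections.abc import Mapping
from typing import Any

# Dotted key paths that the logging cell of this notebook reads from the config.
REQUIRED_PATHS = (
    "input_example",
    "retriever_config.vector_search_index",
    "databricks_resources.llm_endpoint_name",
)


def find_missing_paths(config: Mapping, paths=REQUIRED_PATHS) -> list:
    """Return the dotted paths that are absent (or set to None) in config."""
    missing = []
    for path in paths:
        node: Any = config
        for key in path.split("."):
            # Stop at the first level where the key cannot be resolved.
            if not isinstance(node, Mapping) or node.get(key) is None:
                missing.append(path)
                break
            node = node[key]
    return missing


# Hypothetical config values, for illustration only (no input_example on purpose).
sample_config = {
    "databricks_resources": {"llm_endpoint_name": "my-llm-endpoint"},
    "retriever_config": {"vector_search_index": "main.demo.docs_index"},
}

print(find_missing_paths(sample_config))  # prints ['input_example']
```

In the notebook, the same check could be run on `rag_chain_config_dict` before the logging cell, raising an error if the returned list is not empty.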

### 1-2 Log Model Using MLflow LangChain Flavor

This cell logs the RAG chain to MLflow with:
- **lc_model**: Path to the chain code file (`rag_chain.py`)
- **model_config**: Configuration loaded from YAML
- **input_example**: Sample input for schema inference and testing
- **resources**: Databricks resources the model depends on (Vector Search index, LLM endpoint)

MLflow's LangChain flavor handles:
- Dependency tracking
- Model serialization
- Schema inference
- Resource declaration for governance

In [0]:
import mlflow
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksVectorSearchIndex

# Log the RAG chain to MLflow
with mlflow.start_run(run_name="demo_rag_chain"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model="rag_chain.py",  # Chain code file path
        model_config=rag_chain_config_dict,  # Pass dict config
        name="rag_chain",  # Model name (MLflow 3 uses `name`; `artifact_path` is deprecated)
        input_example=rag_chain_config_dict.get("input_example"),  # Sample input for schema inference
        resources=[
            # Declare Vector Search index dependency
            DatabricksVectorSearchIndex(index_name=rag_chain_config_dict.get("retriever_config").get("vector_search_index")),
            # Declare LLM endpoint dependency
            DatabricksServingEndpoint(endpoint_name=rag_chain_config_dict.get("databricks_resources").get("llm_endpoint_name"))
        ],
        extra_pip_requirements=["databricks-connect"]  # Additional dependencies
    )


### 1-3 Load and Test the Model

After logging the model, we load it back from MLflow and test it with an example input.

This step verifies:
- The model was logged correctly
- All dependencies are properly serialized
- The model can be loaded and invoked successfully
- The output format matches expectations

This is a critical validation step before deploying to production.

In [0]:
# Get the input example from configuration
input_example = rag_chain_config_dict.get("input_example")

# Load the logged chain from MLflow
rag_chain = mlflow.langchain.load_model(logged_chain_info.model_uri)

# Test the chain with the example input
response = rag_chain.invoke(input_example)

# Display results
print(f"### Input Example: {input_example}")
print(f"### Model Response: {response}")
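
Printing the response confirms that the chain runs, but it does not check that the output has the expected shape. Below is a minimal sketch of such a check, assuming the chain returns either a plain string or a LangChain message object whose text is in a `content` attribute. The actual output type depends on how `rag_chain.py` is written, and the helper names `extract_answer_text` and `check_answer` are hypothetical.

```python
def extract_answer_text(response) -> str:
    """Best-effort extraction of the answer text from a chain response."""
    if isinstance(response, str):
        return response
    # LangChain message objects expose their text on a `content` attribute.
    content = getattr(response, "content", None)
    if isinstance(content, str):
        return content
    raise TypeError(f"Unexpected response type: {type(response).__name__}")


def check_answer(response) -> str:
    """Return the stripped answer text, or raise if it is empty."""
    text = extract_answer_text(response).strip()
    if not text:
        raise ValueError("The chain returned an empty answer")
    return text


# Hypothetical response value, for illustration only.
print(check_answer("  MLflow tracks runs, metrics, and models.  "))
```

In the notebook, `check_answer(response)` could be called right after `rag_chain.invoke(input_example)` so that an empty or malformed answer fails the cell instead of passing silently.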
