
# Building a Multi-stage AI System

In this lab, you will construct a multi-stage reasoning system using Databricks features and LangChain.

You will start by building the first chain, which searches a dataset of Databricks documentation pages and answers a question from the retrieved context. Following that, you will create the second chain, which refines the generated answer. Finally, you will integrate these chains, together with an HTML-generation step, to form a complete multi-stage AI system.

In [0]:
%pip install -U -qq databricks-sdk databricks-vectorsearch langchain-databricks langchain==0.3.7 langchain-community==0.3.7

dbutils.library.restartPython()

In [0]:
%run ../Includes/Classroom-Setup-02LAB

In [0]:
print(f"Username:          {DA.username}")
print(f"Catalog Name:      {DA.catalog_name}")
print(f"Schema Name:       {DA.schema_name}")
print(f"Working Directory: {DA.paths.working_dir}")
print(f"Dataset Location:  {DA.paths.datasets}")

## Load Dataset

Before you start building the AI chain, you need to load and prepare the dataset and save it as a Delta table.  
For this lab, we will use the Databricks Documentation Dataset available from the Databricks Marketplace.

This dataset contains documentation pages with associated `id`, `url`, and `content`.  
We will format the data to create a single unified `document` field combining the URL and content, which will then be used to build a Vector Store.

The table will be created for you in the next code block.

In [0]:
## Create the docs table in Unity Catalog
vs_source_table_fullname = f"{DA.catalog_name}.{DA.schema_name}.docs"
create_docs_table(vs_source_table_fullname)
## Display a sample of the data
display(spark.sql(f"SELECT * FROM {vs_source_table_fullname}"))
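
The unified `document` field described above can be pictured with a small sketch. The exact layout that `create_docs_table` produces is not shown in this lab, so the `build_document` helper, the `URL:` prefix, and the sample rows below are assumptions for illustration only.

```python
def build_document(url: str, content: str) -> str:
    """Combine a page's URL and body into one text field for embedding."""
    # Assumed layout: URL on the first line, page content below it.
    return f"URL: {url.strip()}\n\n{content.strip()}"


# Hypothetical rows shaped like the docs table (id, url, content).
rows = [
    {"id": 1, "url": "https://example.com/delta", "content": "  Delta tables store data in Parquet files. "},
    {"id": 2, "url": "https://example.com/vector-search", "content": "Vector Search indexes embeddings."},
]

# Add the combined field to every row without mutating the originals.
docs = [{**row, "document": build_document(row["url"], row["content"])} for row in rows]
print(docs[0]["document"])
```

Keeping the URL inside the embedded text means a retrieved chunk still carries its source, which the model can cite in its answer.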

## Create a Vector Store

In this step, you will compute embeddings for the documentation dataset and store them in a Vector Search index using Databricks Vector Search.

**🚨IMPORTANT: Vector Search endpoints must be created before running the rest of the lab. They are already created for you in the Databricks lab environment.**


In [0]:
## Assign Vector Search endpoint by username
vs_endpoint_prefix = "vs_endpoint_"
vs_endpoint_name = vs_endpoint_prefix + str(get_fixed_integer(DA.unique_name("_")))
print(f"Assigned Vector Search endpoint name: {vs_endpoint_name}.")

In [0]:
## Index table name
vs_index_table_fullname = f"{DA.catalog_name}.{DA.schema_name}.doc_embeddings"

## Store embeddings in vector store
## NOTE: we're using 'document' as the embedding column
create_vs_index(vs_endpoint_name, vs_index_table_fullname, vs_source_table_fullname, "document")
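
To build intuition for what the index does at query time, here is a toy similarity search in plain Python. It is only a sketch: the word-count `embed` function stands in for a real embedding model, and `search` is a brute-force scan rather than the approximate nearest-neighbor lookup a production index would typically use. All names and sample texts are hypothetical.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding: lowercase word counts (a stand-in for a real model)."""
    return Counter(text.lower().split())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[token] * b[token] for token in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def search(query: str, documents: list[str], k: int = 3) -> list[str]:
    """Return the k documents most similar to the query."""
    query_vec = embed(query)
    ranked = sorted(documents, key=lambda d: cosine(query_vec, embed(d)), reverse=True)
    return ranked[:k]


documents = [
    "create a delta table with sql",
    "configure cluster autoscaling",
    "delta table time travel",
]
top = search("how to create a delta table", documents, k=2)
print(top)
```

The `k` argument plays the same role as `search_kwargs={"k": 3}` on the retriever: it caps how many documents are handed to the next stage.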

## Step 1: Build the First Chain (Vector Store Search)

In this step, you will create the first chain, which retrieves relevant documentation from the Vector Store and uses it to answer a question.

**Instructions:**
   - Configure components for the first chain to perform a search using the Vector Store.
   - Define a prompt template that combines the retrieved context with the user's question.
   - Set up retrieval to return the most relevant documents and pass them to the model.
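
The retrieve-then-answer pattern in these instructions can be sketched without any model or index. The snippet below assumes the "stuff" strategy simply joins the retrieved texts with a blank line and drops them into the `{context}` placeholder; the template, `stuff_documents`, `build_prompt`, and the sample texts are illustrative and not part of the lab code.

```python
PROMPT = """Answer using only the context below.

<context>
{context}
</context>

Question: {input}"""


def stuff_documents(docs: list[str], separator: str = "\n\n") -> str:
    """Join retrieved document texts into one context string."""
    return separator.join(docs)


def build_prompt(question: str, retrieved: list[str]) -> str:
    """Fill the template the way a 'stuff' chain does before calling the LLM."""
    return PROMPT.format(context=stuff_documents(retrieved), input=question)


# Hypothetical retrieval results for one question.
retrieved = [
    "Delta tables are created with CREATE TABLE.",
    "Delta is the default table format.",
]
prompt_text = build_prompt("How do I create a Delta table?", retrieved)
print(prompt_text)
```

Because every retrieved document is pasted into a single prompt, the number of documents returned by the retriever directly affects prompt length and cost.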


In [0]:
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import PromptTemplate
from langchain_databricks import ChatDatabricks, DatabricksVectorSearch
from langchain_core.output_parsers import StrOutputParser
from databricks.vector_search.client import VectorSearchClient

## Define the Databricks Chat model: Llama 3.3 70B Instruct
llm_llama = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct", max_tokens=1000)

## Define the prompt template for answering a question from the retrieved context
prompt_template_vs = PromptTemplate.from_template(
    """
    You are a documentation assistant. Based on the following context from a technical document, generate a concise summary or relevant content snippet for answering the user’s question.

    Write a response that is aligned with the tone and format of technical documentation and helps the user understand or resolve their query.

    Maximum 300 words.

    Use the following document snippet and context as example;

    <context>
    {context}
    </context>

    Question: {input}
    """
)

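## Build a retriever backed by the Vector Search index
## (persist_dir is unused here; it lets this function double as the MLflow loader_fn in Step 4)
def get_retriever(persist_dir=None):
    vsc = VectorSearchClient(disable_notice=True)
    ## Look up the index; the returned handle is not needed by DatabricksVectorSearch below
    vs_index = vsc.get_index(vs_endpoint_name, vs_index_table_fullname)
    vectorstore = DatabricksVectorSearch(vs_index_table_fullname)
    ## Return the top 3 matching documents for each query
    return vectorstore.as_retriever(search_kwargs={"k": 3})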

## Construct the chain for question-answering
question_answer_chain = create_stuff_documents_chain(llm_llama, prompt_template_vs)
chain1 = create_retrieval_chain(get_retriever(), question_answer_chain)

## Invoke the chain with an example query   
response = chain1.invoke({"input": "How do I create a Delta table?"})
print(response['answer'])

## Step 2: Build the Second Chain (Optimization)

In this step, you will create a second chain to refine the answer generated by the first chain. This optimization step aims to make the text clearer, more concise, and consistent with the tone of Databricks documentation. In a real-world scenario, this model could be trained on your internal data or fine-tuned to align with your specific business objectives.

**Instructions:**

- Define a second chain using `llama-3-3-70b-instruct`.

- Create a prompt to optimize the generated documentation snippet. For example:  
  *"You are a technical writer. Improve the following documentation snippet to make it clearer, concise, and aligned with the tone used in Databricks documentation."*

- Use `doc_snippet` as the parameter to be passed into the prompt.

- Implement the chain and test it with a sample input.


In [0]:
## Define the Databricks Chat model using llama-3-3-70b-instruct
llm_llama3 = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct", max_tokens=1000)

## Define the prompt template for refining documentation output
doc_optimization_prompt = PromptTemplate.from_template(
    """
    You are a technical writer. Improve the following documentation snippet to make it clearer, concise, and aligned with the tone used in Databricks documentation.

    Documentation snippet: {doc_snippet}

    Return only the revised documentation content.
    """
)

## Define chain 2
chain2 = doc_optimization_prompt | llm_llama3 | StrOutputParser()

## Test the chain
chain2.invoke({"doc_snippet": "Query testing product with mobile app control"})

## Step 3: Integrate Chains into a Multi-chain System

In this step, you will link the individual chains created in Step 1 and Step 2 to form a multi-chain system that can handle multi-stage reasoning.

**Instructions:**

- Use the Databricks **`Llama Chat model`**, defined above in Step 1, for processing text inputs.

- Create a prompt template to generate an **`HTML page`** for displaying the optimized documentation content.

- Construct the **`Multi-Chain System`** by combining the outputs of the previous chains. **Important**: You will need to rename the outputs of the first and second chains when passing them to the next stage. The sequential chain should be: **chain3 = chain1 > (`doc_snippet`) > chain2 > (`optimized_doc`) > prompt3**.

- Invoke the multi-chain system with the input data to generate the HTML page for the sample query.
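
The renaming step is the part that most often trips people up, so here is a model-free sketch of it. Plain functions stand in for the chains, and `pipeline` is a hypothetical helper that mimics left-to-right composition; none of these names come from LangChain. The sketch assumes the first chain returns a dictionary with an `answer` key and the second chain returns a plain string.

```python
from functools import reduce
from typing import Any, Callable

Stage = Callable[[Any], Any]


def pipeline(*stages: Stage) -> Stage:
    """Compose stages left to right, similar in spirit to chaining with `|`."""
    return lambda value: reduce(lambda acc, stage: stage(acc), stages, value)


# Stand-ins for the real chains: each returns a canned value instead of calling an LLM.
def fake_chain1(inputs: dict) -> dict:
    return {"input": inputs["input"], "context": [], "answer": f"Draft answer to: {inputs['input']}"}


def fake_chain2(inputs: dict) -> str:
    return inputs["doc_snippet"].replace("Draft", "Polished")


def fake_prompt3(inputs: dict) -> str:
    return f"<section><p>{inputs['optimized_doc']}</p></section>"


fake_chain3 = pipeline(
    fake_chain1,
    lambda x: {"doc_snippet": x["answer"]},  # rename chain 1's output for chain 2
    fake_chain2,
    lambda x: {"optimized_doc": x},          # wrap chain 2's string for the final prompt
    fake_prompt3,
)

result = fake_chain3({"input": "How do I create a Delta table?"})
print(result)
```

Each adapter exists only to match the placeholder name the next prompt expects; if a key is misspelled, the failure surfaces at the following stage rather than at the adapter.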


In [0]:
from langchain.schema.runnable import RunnablePassthrough, RunnableMap
from langchain_core.output_parsers import StrOutputParser
from IPython.display import display, HTML

## Define the prompt template for generating the HTML page
prompt_template_3 = PromptTemplate.from_template(
    """Create an HTML section for the following technical documentation snippet:
    
    Content: {optimized_doc}

    Return valid HTML (no head/body tags).
    """
)

## Construct multi-stage chain
chain3 = (
    chain1
    | RunnableMap({"doc_snippet": lambda x: x["answer"]})
    | chain2
    | RunnableMap({"optimized_doc": lambda x: x})
    | prompt_template_3
    | llm_llama
    | StrOutputParser()
)

## Sample query
query = {
    "input": "How do I create a Delta table in Databricks?"
}

output_html = chain3.invoke(query)

## Display the generated HTML output
display(HTML(output_html))
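
Model output does not always follow the "no head/body tags" instruction, and it is sometimes wrapped in a Markdown code fence. The sketch below shows one way to tidy and check a fragment before rendering it, using only the standard library. The helper names are hypothetical, and the check looks only at tag names; it does not validate or sanitize the HTML.

```python
import re
from html.parser import HTMLParser

FENCE = "`" * 3  # Markdown code-fence marker, built here to keep this listing simple
FORBIDDEN_TAGS = {"html", "head", "body"}


class TagCollector(HTMLParser):
    """Record the name of every start tag seen in a fragment."""

    def __init__(self) -> None:
        super().__init__()
        self.tags: list[str] = []

    def handle_starttag(self, tag, attrs):
        self.tags.append(tag)


def strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence if the model added one."""
    pattern = rf"\s*{FENCE}(?:html)?\s*\n(.*?)\n?{FENCE}\s*"
    match = re.fullmatch(pattern, text, flags=re.DOTALL)
    return match.group(1).strip() if match else text.strip()


def forbidden_tags(fragment: str) -> set[str]:
    """Return any page-level tags that should not appear in a fragment."""
    collector = TagCollector()
    collector.feed(fragment)
    collector.close()
    return FORBIDDEN_TAGS & set(collector.tags)


# Hypothetical model output wrapped in a fence.
raw = FENCE + "html\n<section><h2>Delta tables</h2></section>\n" + FENCE
clean = strip_code_fence(raw)
print(clean)
print(forbidden_tags(clean))
```

In the notebook, a check like this could run on `output_html` before `display(HTML(...))` is called.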

## Step 4: Save the Chain to Model Registry in UC

In this step, you will save the multi-stage chain to Unity Catalog.

**Instructions:**

- Set the model registry to UC and define the model name.

- Log and register the final multi-chain system.

- To test the registered model, load it back from the model registry and query it using a sample query.

After registering the chain, you can view the chain and models in the **Catalog Explorer**.

In [0]:
from mlflow.models import infer_signature
import mlflow

## Set model registry to UC
mlflow.set_registry_uri("databricks-uc")
model_name = f"{DA.catalog_name}.{DA.schema_name}.multi_stage_doc_chain"

## Log the model
with mlflow.start_run(run_name="multi_stage_doc_chain") as run:
    signature = infer_signature(query, output_html)
    model_info = mlflow.langchain.log_model(
        chain3,
        loader_fn=get_retriever, 
        artifact_path="chain",
        registered_model_name=model_name,
        input_example=query,
        signature=signature
    )

## Load and test the model
model_uri = f"models:/{model_name}/{model_info.registered_model_version}"
model = mlflow.langchain.load_model(model_uri)

output_html = model.invoke(query)
display(HTML(output_html))