# Advanced Multi-Chain Routing with LangChain and Mistral Models

*This notebook should work well with the **`Data Science 3.0`** or **`Python 3 (ipykernel)`** kernel in SageMaker Studio.*

In generative AI applications, answering a user query comprehensively and accurately often requires drawing on multiple data sources and capabilities. LangChain, a framework for developing applications with large language models (LLMs), offers a flexible and modular way to build such systems.

One of the key advantages of LangChain is its ability to combine multiple chains, each specialized in a specific task, into a single pipeline. This multi-chain routing approach allows for the seamless integration of different language models, data sources, and processing capabilities, enabling the creation of sophisticated and tailored solutions.

## Flexibility in Choosing LLMs with Multi-Chain Routing

By leveraging the multi-chain routing capability of LangChain, developers can incorporate multiple language models into a single pipeline, allowing each model to contribute its strengths and expertise to the overall solution. This approach enables the creation of more robust and accurate systems that can handle complex queries and tasks.

## Mistral AI in Amazon Bedrock

At the time of writing, there were [four Mistral models available on Amazon Bedrock](https://aws.amazon.com/bedrock/mistral/), offering flexibility to developers:

1. **Mistral Large**: Mistral AI’s most advanced large language model, a cutting-edge text generation model with top-tier reasoning capabilities. Its precise instruction following enables application development and tech stack modernization at scale.
2. **Mistral 7B**: A 7B dense Transformer that is fast to deploy and easy to customize. Small, yet powerful for a variety of use cases.
3. **Mixtral 8x7B**: A sparse Mixture-of-Experts model with stronger capabilities than Mistral 7B. It uses 12B active parameters out of 45B total.
4. **Mistral Small**: Mistral Small is perfectly suited for straightforward tasks that can be performed in bulk, such as classification, customer support, or text generation.

## Use Case: Financial Services Industry (FSI)

To demonstrate the power of multi-chain routing with LangChain and Mistral models, we will explore a use case in the Financial Services Industry (FSI). In this scenario, a user wants to:

1. **Check Investment**: Determine if they have invested in a particular stock by querying a SQL database.
2. **Check Financial Reports**: Retrieve and analyze public financial reports and shareholder letters related to the stock using a Retrieval-Augmented Generation (RAG) chain.
3. **Check News**: Search for and retrieve relevant news articles about the stock or company using a search chain.

By combining these three capabilities into a single pipeline, the user can obtain a comprehensive overview of their investment, the company's performance, and the latest news and developments, all through a single query.

Throughout this notebook, we will walk through the process of setting up the individual chains, defining the routing logic, and integrating the Mistral models to power the multi-chain routing system.



## [LangChain Expression Language (LCEL)](https://python.langchain.com/v0.1/docs/expression_language/)

LangChain Expression Language (LCEL) is a declarative way to compose chains. In this notebook, we use LCEL to implement the individual chains and the orchestration workflow. Our custom steps are wrapped with the [Runnable interface](https://python.langchain.com/v0.1/docs/expression_language/interface/), so every component exposes the same `invoke` method and can be chained with the `|` operator.
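To make the `|` idea concrete, here is a toy, dependency-free sketch of pipe-style composition. It is an illustration only, not LangChain's implementation: the `Step` class, `fan_out`, and the upper-casing stand-in for an LLM are all hypothetical. It shows the two patterns used throughout this notebook: a dict whose branches all receive the same input, and `|`, which feeds each step's output into the next.

```python
# A toy, dependency-free illustration of pipe-style composition.
# This is NOT LangChain's implementation; it only mimics how data flows through "|".
class Step:
    def __init__(self, func):
        self.func = func

    def invoke(self, value):
        return self.func(value)

    def __or__(self, other):
        # Compose: the output of this step becomes the input of the next one
        nxt = other if isinstance(other, Step) else Step(other)
        return Step(lambda value: nxt.invoke(self.invoke(value)))


def fan_out(mapping):
    # Mimics a dict of runnables: every branch receives the same input
    return Step(lambda value: {key: step.invoke(value) for key, step in mapping.items()})


toy_prompt = Step(lambda d: f"Question: {d['question']}")
toy_llm = Step(lambda text: text.upper())  # hypothetical stand-in for an LLM call
toy_chain = fan_out({"question": Step(lambda d: d["question"])}) | toy_prompt | toy_llm

print(toy_chain.invoke({"question": "which stock?"}))  # QUESTION: WHICH STOCK?
```

In real LCEL code, the dict, prompt template, and model play the roles of `fan_out`, `toy_prompt`, and `toy_llm`.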


---
## Setup

---

In [None]:
!pip install --upgrade --quiet langchain langchain-community langchain-aws faiss-cpu duckduckgo-search

In [None]:
mistral_large_model_id = "mistral.mistral-large-2402-v1:0"
mistral_8x7b_model_id = "mistral.mixtral-8x7b-instruct-v0:1"
mistral_7b_model_id = "mistral.mistral-7b-instruct-v0:2"
mistral_small_model_id = "mistral.mistral-small-2402-v1:0"

# modify to the region of your choice
aws_region = "us-east-1"

In [None]:
import boto3
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_aws import BedrockLLM

bedrock_runtime = boto3.client(
    service_name="bedrock-runtime",
    region_name=aws_region,
)

In [None]:
mistral_large_llm = BedrockLLM(
    client=bedrock_runtime,
    model_id=mistral_large_model_id,
    model_kwargs={"temperature": 0.1},
)

mistral_8x7b_llm = BedrockLLM(
    client=bedrock_runtime,
    model_id=mistral_8x7b_model_id,
    model_kwargs={"temperature": 0.1},
)

mistral_7b_llm = BedrockLLM(
    client=bedrock_runtime,
    model_id=mistral_7b_model_id,
    model_kwargs={"temperature": 0.1},
)

mistral_small_llm = BedrockLLM(
    client=bedrock_runtime,
    model_id=mistral_small_model_id,
    model_kwargs={"temperature": 0.1},
)

In [None]:
import os

# Create a local work directory to store data (no error if it already exists)
workdir = "workspace"
os.makedirs(workdir, exist_ok=True)

---
## SQL Chain: Text-to-SQL Translation with LLMs

One of the key components of our multi-chain routing system is the SQL Chain, which lets users query structured data in a database using natural language. It relies on the text-to-SQL capability of large language models (LLMs) such as Mistral.

LLMs have demonstrated remarkable performance in understanding and generating natural language, and this ability extends to translating natural language queries into structured SQL queries. By leveraging the language understanding and generation capabilities of LLMs, we can bridge the gap between human-friendly natural language and the structured query language used by databases.

Because the generated SQL must be accurate and executable, the SQL chain uses Mistral Large.

---

Populate stock data

In [None]:
stock_data = [
    ["AAPL", "Apple Inc.", "Technology", 20, "NASDAQ", "USD"],
    ["MSFT", "Microsoft Corporation", "Technology", 18, "NASDAQ", "USD"],
    ["AMZN", "Amazon.com, Inc.", "Consumer Cyclical", 99, "NASDAQ", "USD"],
    ["NVDA", "NVIDIA Corporation", "Technology", 12, "NASDAQ", "USD"],
    ["TSLA", "Tesla, Inc.", "Consumer Cyclical", 10, "NASDAQ", "USD"],
    ["JPM", "JPMorgan Chase & Co.", "Finance", 20, "NYSE", "USD"],
    ["JNJ", "Johnson & Johnson", "Healthcare", 41, "NYSE", "USD"],
    ["XOM", "Exxon Mobil Corporation", "Energy", 33, "NYSE", "USD"],
    ["WMT", "Walmart Inc.", "Consumer Defensive", 29, "NYSE", "USD"],
    ["PG", "Procter & Gamble Company", "Consumer Defensive", 35, "NYSE", "USD"]
]

In [None]:
import pandas as pd
df = pd.DataFrame(stock_data, columns=["ticker", "name", "sector", "shares", "exchange", "currency"])
df

Now import the data into a database table. SQLite is a lightweight, self-contained relational database that stores data in a single file. It allows applications to manage structured data without requiring a separate database server.
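Before handing the table to an LLM, it helps to know what a correct answer looks like. The sketch below recreates a subset of the stock rows in an in-memory SQLite database (`:memory:`, so no file is written) and runs the query one would expect for "which stock do I have the most shares of?". The row subset and the hand-written query are illustrative assumptions, not output from the model.

```python
import sqlite3

# Same schema as the notebook's `stocks` table, but in an in-memory database
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE stocks (ticker TEXT, name TEXT, sector TEXT, "
    "shares INTEGER, exchange TEXT, currency TEXT)"
)
rows = [
    ("AAPL", "Apple Inc.", "Technology", 20, "NASDAQ", "USD"),
    ("AMZN", "Amazon.com, Inc.", "Consumer Cyclical", 99, "NASDAQ", "USD"),
    ("JNJ", "Johnson & Johnson", "Healthcare", 41, "NYSE", "USD"),
]
# Parameterized insert: one "?" placeholder per column
conn.executemany("INSERT INTO stocks VALUES (?, ?, ?, ?, ?, ?)", rows)
conn.commit()

# The kind of query the text-to-SQL step should produce for "most shares"
top = conn.execute(
    "SELECT ticker, shares FROM stocks ORDER BY shares DESC LIMIT 1"
).fetchone()
print(top)  # ('AMZN', 99)
conn.close()
```

Comparing the chain's generated SQL and result against a hand-written query like this is a cheap way to spot-check the text-to-SQL step.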

In [None]:
db_file = f"{workdir}/investment.db"
ddl = '''CREATE TABLE IF NOT EXISTS stocks
(ticker TEXT, name TEXT, sector TEXT, shares INTEGER, exchange TEXT, currency TEXT)'''

import sqlite3
conn = sqlite3.connect(db_file)
cur = conn.cursor()
cur.execute(ddl)

# Write the DataFrame to the table; if_exists='replace' drops and recreates it, so reruns don't duplicate rows
df.to_sql('stocks', conn, if_exists='replace', index=False)

# Commit the changes and close the connection
conn.commit()
conn.close()

In [None]:
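def execute_sql(sql_query):
    # Run a SQL query against the local SQLite database and return all rows
    conn = sqlite3.connect(db_file)
    try:
        cur = conn.cursor()
        cur.execute(sql_query)
        result = cur.fetchall()
    finally:
        conn.close()  # close the connection even if the query fails
    return result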

The SQL chain can be broken down into two tasks:

1. Text-to-SQL
2. Generate answer
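One practical wrinkle in the text-to-SQL step: instruction-tuned models sometimes wrap the query in a Markdown code fence even when told not to, and SQLite cannot execute the fence characters. Below is a minimal sketch of a hypothetical `clean_sql` helper (not part of LangChain) that extracts the SQL from a fence if one is present and otherwise just strips whitespace. The sample model outputs are made up for illustration.

```python
import re


def clean_sql(llm_output: str) -> str:
    """Hypothetical helper: return the SQL inside a ``` or ```sql fence, else the stripped text."""
    match = re.search(r"```(?:sql)?\s*(.*?)```", llm_output, re.DOTALL | re.IGNORECASE)
    sql = match.group(1) if match else llm_output
    return sql.strip()


# Example outputs a model might return for the same question (illustrative only)
fenced = "```sql\nSELECT ticker FROM stocks ORDER BY shares DESC LIMIT 1;\n```"
plain = "  SELECT ticker FROM stocks ORDER BY shares DESC LIMIT 1;\n"

print(clean_sql(fenced))
print(clean_sql(plain))

# In the notebook, the helper could sit between the parser and execute_sql, e.g.:
# text_to_sql_prompt | mistral_large_llm | StrOutputParser()
#     | RunnableLambda(clean_sql) | RunnableLambda(execute_sql)
```

Both calls print the same bare `SELECT` statement, so the downstream `execute_sql` step sees valid SQL either way.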

See the [LangChain Runnable interface](https://python.langchain.com/v0.1/docs/expression_language/interface/) documentation for the input and output types of the components used below: Prompt, OutputParser, Retriever, and Tool.

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.callbacks.tracers import ConsoleCallbackHandler

SQL_PROMPT_TEMPLATE = f'''<s>[INST]Based on the provided SQL table schema below, write a SQL query that would answer the question.
<schema> {ddl} </schema>
<question> {{question}} </question>
Just generate the SQL query without explanations
[/INST]
'''

text_to_sql_prompt = PromptTemplate.from_template(SQL_PROMPT_TEMPLATE)
text_to_sql = text_to_sql_prompt | mistral_large_llm | StrOutputParser() | RunnableLambda(func=execute_sql) 

In [None]:
FINAL_ANSWER_PROMPT_TEMPLATE = '''<s>[INST]Given the SQL query result: 
{result}
Produce a final response to the original question: 
{question}
[/INST]
'''

final_answer_prompt = PromptTemplate.from_template(FINAL_ANSWER_PROMPT_TEMPLATE)

sql_chain = (
    {
        "result": text_to_sql,
        # Pass only the question text (not the whole input dict) to the final prompt
        "question": lambda x: x["question"],
    }
    | final_answer_prompt
    | mistral_large_llm
)

In [None]:
# ConsoleCallbackHandler() will print information about the chain's execution to the console, helping you see what's happening under the hood.
sql_chain.invoke(
    {"question": "Which stock I have the most shares?"}, 
    config={'callbacks': [ConsoleCallbackHandler()]}
)

---

## RAG Chain: Augmenting Language Models with Retrieval

In addition to the SQL Chain for querying structured data, our multi-chain routing system incorporates a Retrieval-Augmented Generation (RAG) chain for retrieving and processing unstructured data, such as financial reports and shareholder letters.

The RAG chain uses Mixtral 8x7B.

---

In [None]:
!pip install --upgrade pypdf --quiet

In [None]:
from urllib.request import urlretrieve
from tqdm import tqdm  # For progress bar

url_filename_map = {
    "https://s2.q4cdn.com/299287126/files/doc_financials/2024/ar/Amazon-com-Inc-2023-Shareholder-Letter.pdf": "Amazon-com-Inc-2023-Shareholder-Letter.pdf"
}

# Download files with progress bar
for url, filename in tqdm(url_filename_map.items(), unit="file"):
    urlretrieve(url, os.path.join(workdir, filename))

In [None]:
from langchain_community.document_loaders import PyPDFLoader

loader = PyPDFLoader(f"{workdir}/Amazon-com-Inc-2023-Shareholder-Letter.pdf")
pages = loader.load_and_split()

To set up the RAG chain, we use an embedding model to convert the text into embeddings and store them in a vector database (FAISS). We use the [Amazon Titan Text Embeddings V2 model](https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html), which accepts up to 8,192 tokens of input and outputs a 1,024-dimensional vector by default (256 and 512 dimensions are also supported). It supports more than 100 languages and is optimized for text retrieval, but it can also be used for semantic similarity and clustering.
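Conceptually, the retriever configured below with `search_kwargs={"k": 3}` embeds the question and returns the stored chunks whose vectors are nearest to it. LangChain's FAISS wrapper builds a flat index ranked by Euclidean (L2) distance by default. The sketch below performs the same brute-force nearest-neighbour search in plain Python; the documents and 3-dimensional vectors are made-up toys, not real Titan embeddings.

```python
import math


def l2(a, b):
    # Euclidean distance between two equal-length vectors
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def top_k(query_vec, doc_vecs, k=3):
    # Brute-force search: score every vector and keep the k closest (smallest distance)
    ranked = sorted(range(len(doc_vecs)), key=lambda i: l2(query_vec, doc_vecs[i]))
    return ranked[:k]


# Toy 3-dimensional "embeddings" (real Titan V2 vectors have 1,024 dimensions by default)
toy_docs = ["AWS revenue grew", "Advertising update", "Letter greeting", "Generative AI opportunity"]
doc_vecs = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.0, 0.1, 0.9], [0.7, 0.3, 0.1]]
query_vec = [0.85, 0.2, 0.05]

for i in top_k(query_vec, doc_vecs, k=2):
    print(toy_docs[i])  # "AWS revenue grew", then "Generative AI opportunity"
```

A flat index scores every vector exactly like this; FAISS simply does it much faster, which matters once a document has hundreds of chunks.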

In [None]:
from langchain_aws import BedrockEmbeddings
from langchain_community.vectorstores import FAISS

# Reuse the Bedrock runtime client created earlier (same region as the LLMs)
bedrock_embeddings = BedrockEmbeddings(
    model_id="amazon.titan-embed-text-v2:0", client=bedrock_runtime
)

# Embed the pages (already split by load_and_split) and build a FAISS index
vectorstore_faiss = FAISS.from_documents(
    pages,
    bedrock_embeddings,
)
vectorstore_faiss.save_local(f"{workdir}/faiss_index")

In [None]:
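from langchain_core.messages import BaseMessage


def format_docs(docs):
    # Join the retrieved documents into a single context string
    return "\n\n".join(doc.page_content for doc in docs)


def get_question(input):
    # Normalize chain input (str, dict with 'question', or message) to the question text
    if not input:
        return None
    elif isinstance(input, str):
        return input
    elif isinstance(input, dict) and "question" in input:
        return input["question"]
    elif isinstance(input, BaseMessage):
        return input.content
    else:
        raise TypeError("string or dict with 'question' key expected as RAG chain input.")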


context_template = """
<s>[INST]Use the given context to answer the question. 
If you don't know the answer, respond "I don't know".
Keep your response as precise as possible and limit it to a few words. 

Here is the context:
{context}

Here is the question: 
{question}
[/INST]
"""

rag_prompt = PromptTemplate.from_template(context_template)

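rag_chain = (
    {
        # Retrieve the 3 most similar chunks and join them into one context string
        "context": RunnableLambda(get_question)
        | vectorstore_faiss.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        | format_docs,
        # Pass only the question text to the prompt, whatever the input shape
        "question": RunnableLambda(get_question),
    }
    | rag_prompt
    | mistral_8x7b_llm
)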

In [None]:
rag_chain.invoke(
    {"question": "what's the biggest opportunity for Amazon?"},
    config={'callbacks': [ConsoleCallbackHandler()]}
)

---

## Web Search Chain: Retrieving News

In addition to querying structured data and analyzing unstructured documents, our multi-chain routing system incorporates a Web Search Chain to retrieve relevant news articles from the internet. This capability is essential in the financial services industry, where staying up-to-date with the latest news and developments can significantly impact investment decisions and financial analysis.

LangChain supports many [search tools](https://python.langchain.com/v0.1/docs/integrations/tools/search_tools/), including DuckDuckGo, Google Search, and Bing Search. Here we use DuckDuckGo as an example.


The search chain uses Mistral 7B.

---

In [None]:
from langchain_community.tools import DuckDuckGoSearchResults

search_prompt_template = '''
<s>[INST]Use the given search result to answer the question. 
If you don't know the answer, respond "I don't know".
Keep your response as precise as possible and limit it to a few words. 

Here are the search results:
{search_result}

Here is the question: 
{question}
[/INST]
'''

search = DuckDuckGoSearchResults(backend="news")
search_prompt = PromptTemplate.from_template(search_prompt_template)

search_chain = (
    {
        "search_result": RunnableLambda(get_question) | search.run,
        # Pass only the question text to the prompt, whatever the input shape
        "question": RunnableLambda(get_question),
    }
    | search_prompt
    | mistral_7b_llm
)

In [None]:
search_chain.invoke(
    {"question":"what's new for stock market today?"}
)

---

## Dynamic multi-chain routing


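Now we combine the three chains into a single pipeline. A classifier chain powered by Mistral Large labels each question as `my-investment`, `company-financial-reports`, or `news`, and a routing function sends the question to the matching chain.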

---

In [None]:
prompt_template='''
<s>[INST]Given the user question below, classify it as either being about `my-investment`, `company-financial-reports`, or `news`.

Do not respond with more than one word.

<question>
{question}
</question>
[/INST]
'''
chain_prompt = PromptTemplate.from_template(prompt_template)

chain = (
    chain_prompt
    | mistral_large_llm
    | StrOutputParser()
)

In [None]:
def route(info):
    if "my-investment" in info["topic"].lower():
        return sql_chain
    elif "company-financial-reports" in info["topic"].lower():
        return rag_chain
    else:
        return search_chain
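Because the classifier is an LLM, its label may arrive with extra whitespace, different casing, or surrounding punctuation. The substring matching in `route` tolerates that, and anything unrecognized falls through to the search chain. The sketch below exercises the same logic offline; `route_offline` and the string stand-ins for the three chains are hypothetical, chosen so the check runs without calling Bedrock.

```python
def route_offline(info, chains):
    # Same substring logic as the notebook's route(), with the chains passed in
    topic = info["topic"].lower()
    if "my-investment" in topic:
        return chains["sql"]
    elif "company-financial-reports" in topic:
        return chains["rag"]
    else:
        return chains["search"]


# Hypothetical stand-ins for sql_chain, rag_chain and search_chain
chains = {"sql": "sql_chain", "rag": "rag_chain", "search": "search_chain"}

# Classifier outputs may carry extra whitespace, casing or punctuation
samples = [" My-Investment\n", "`company-financial-reports`", "news", "I am not sure"]
for label in samples:
    print(repr(label), "->", route_offline({"topic": label, "question": "..."}, chains))
```

The last sample shows the design choice: an unexpected label is treated as a news question rather than raising an error.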

In [None]:
full_chain = {
    "topic": chain, "question": lambda x: x["question"]
} | RunnableLambda(route)

In [None]:
full_chain.invoke({"question": "Which company do I have the most shares?"})

In [None]:
full_chain.invoke({"question": "What's amazon's biggest opportunity in its shareholder letter?"})

In [None]:
full_chain.invoke({"question": "What's new about stock market today?"})

## Concurrent Execution of Multiple Chains


In our Financial Services Industry use case, we often need to gather information from several sources at once. `RunnableParallel` runs multiple runnables on the same input concurrently and returns their outputs as a dictionary keyed by branch name, so the total latency is closer to that of the slowest branch than to the sum of all branches.
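Since each branch spends most of its time waiting on a network call to Bedrock or a search engine, threads are enough to overlap that waiting. The sketch below reproduces the fan-out/collect pattern with the standard library's `ThreadPoolExecutor`. It is a simplified stand-in, not `RunnableParallel` itself; `slow_branch` and the sleep durations are hypothetical placeholders for the three analysis chains.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_branch(name, seconds):
    # Hypothetical stand-in for a chain that mostly waits on a network call
    def run(company_name):
        time.sleep(seconds)
        return f"{name} for {company_name}"
    return run


branches = {
    "investment_status": slow_branch("investment", 0.3),
    "financial_analysis": slow_branch("financial", 0.3),
    "news_summary": slow_branch("news", 0.3),
}


def run_parallel(branches, company_name):
    # Submit every branch with the same input, then collect results by key
    with ThreadPoolExecutor(max_workers=len(branches)) as pool:
        futures = {key: pool.submit(fn, company_name) for key, fn in branches.items()}
        return {key: fut.result() for key, fut in futures.items()}


start = time.perf_counter()
parallel_out = run_parallel(branches, "Amazon.com, Inc.")
elapsed = time.perf_counter() - start
print(parallel_out)
print(f"elapsed: {elapsed:.2f}s (running the branches one after another would take about 0.9s)")
```

The elapsed time lands near 0.3s, the duration of a single branch, which is the effect `RunnableParallel` provides for the three analysis chains below.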

In [None]:
from langchain_core.runnables import RunnableParallel

In [None]:
# Define all prompt templates
SQL_QUERY_TEMPLATE = """<s>[INST]Write a SQL query to check all the information related to {company_name} from the stocks table:
Table schema: (ticker TEXT, name TEXT, sector TEXT, shares INTEGER, exchange TEXT, currency TEXT)
Table name: stocks
Just generate the SQL query without explanations.
[/INST]"""

INVESTMENT_TEMPLATE = """<s>[INST]Check if there are any investments in {company_name} in the portfolio:
{sql_result}

Table schema: (ticker TEXT, name TEXT, sector TEXT, shares INTEGER, exchange TEXT, currency TEXT)
Provide a concise summary of the investment position, including the number of shares.
[/INST]"""

FINANCIAL_TEMPLATE = """<s>[INST]Analyze the following financial report excerpt for {company_name}:
{report_content}

Provide key financial metrics and trends. Only use two sentences.
[/INST]"""

NEWS_TEMPLATE = """<s>[INST]Summarize the recent news about {company_name}:
{news_articles}

Highlight major developments and market sentiment. Only use two sentences.
[/INST]"""

In [None]:
# Create prompt templates
sql_prompt = PromptTemplate.from_template(SQL_QUERY_TEMPLATE)
investment_prompt = PromptTemplate.from_template(INVESTMENT_TEMPLATE)
financial_report_prompt = PromptTemplate.from_template(FINANCIAL_TEMPLATE)
news_prompt = PromptTemplate.from_template(NEWS_TEMPLATE)

In [None]:
# Create base chains for data retrieval
sql_query_chain = (
    # Extract the company name string from the input dict
    {"company_name": lambda x: x["company_name"]} |
    sql_prompt |
    mistral_7b_llm |
    StrOutputParser() |
    RunnableLambda(execute_sql)
)

In [None]:
# Create analysis chains
investment_chain = (
    {
        # Pass the company name string, not the whole input dict
        "company_name": lambda x: x["company_name"],
        "sql_result": sql_query_chain
    } |
    investment_prompt |
    mistral_7b_llm |
    StrOutputParser()
)

financial_report_chain = (
    {
        "company_name": lambda x: x["company_name"],
        "report_content": lambda x: rag_chain.invoke(x["company_name"])
    } |
    financial_report_prompt |
    mistral_large_llm |
    StrOutputParser()
)

news_chain = (
    {
        "company_name": lambda x: x["company_name"],
        "news_articles": lambda x: search_chain.invoke(x["company_name"])
    } |
    news_prompt |
    mistral_8x7b_llm |
    StrOutputParser()
)

In [None]:
# Combine chains using RunnableParallel
stock_analysis_parallel = RunnableParallel(
    {
        "investment_status": investment_chain,
        "financial_analysis": financial_report_chain,
        "news_summary": news_chain
    }
)

def analyze_stock(company_name: str):
    """Analyze a stock by processing investment data, financial reports, and news articles in parallel."""
    return stock_analysis_parallel.invoke({"company_name": company_name})

In [None]:
results = analyze_stock("Amazon.com, Inc.")
    
print("\nInvestment Status:")
print(results["investment_status"])
print("\nFinancial Analysis:")
print(results["financial_analysis"])
print("\nNews Summary:")
print(results["news_summary"])