# LangGraph with Amazon Bedrock Knowledge Bases 

This notebook shows you how to use [LangGraph](https://python.langchain.com/docs/langgraph) with [Amazon Bedrock Knowledge Bases](https://aws.amazon.com/bedrock/knowledge-bases/) as a RAG source for retrieving relevant documents.


In [71]:
import boto3
from botocore.config import Config
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_community.embeddings import BedrockEmbeddings
from langchain_aws import ChatBedrock
from langgraph.graph import END, StateGraph
from typing import Dict, TypedDict
import os


# setup boto3 config to allow for retrying
my_region = "us-west-2"
my_config = Config(
    region_name = my_region,
    signature_version = 'v4',
    retries = {
        'max_attempts': 3,
        'mode': 'standard'
    }
)

# setup bedrock runtime client 
bedrock_rt = boto3.client("bedrock-runtime", config = my_config)
# setup bedrock agent runtime client
bedrock_agent_rt = boto3.client("bedrock-agent-runtime", config = my_config)
# setup S3 client
s3 = boto3.client("s3", config = my_config)
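
The `retries` setting above asks botocore to retry failed calls for you. To illustrate the general idea only, here is a minimal sketch of jittered exponential backoff in plain Python. It is not botocore's implementation: the helper `call_with_retries`, its parameters, and the `flaky` function are hypothetical, and the choice that `max_attempts` counts the first call is an assumption made for this sketch.

```python
import random
import time


def call_with_retries(fn, max_attempts=3, base_delay=0.5, max_delay=20.0, sleep=time.sleep):
    """Call fn(); on failure wait with jittered exponential backoff and try again.

    In this sketch max_attempts counts the first call, so 3 means at most 2 retries.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
            # Random wait between 0 and the capped exponential delay
            delay = random.uniform(0, min(max_delay, base_delay * 2 ** (attempt - 1)))
            sleep(delay)


# Hypothetical flaky call that fails twice, then succeeds
calls = {"count": 0}


def flaky():
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("throttled")
    return "ok"


result = call_with_retries(flaky, sleep=lambda _: None)  # skip real sleeping in the demo
```

With the boto3 clients above you do not need such a wrapper; the sketch only shows what a retry policy does conceptually.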

## Initialize Bedrock
In this notebook, we use Anthropic's Claude 3 Sonnet model and the Amazon Titan embeddings model. If you would like to use a different model, all the model IDs are available [here](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html).


In [72]:
sonnet_model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
model_kwargs =  { 
    "max_tokens": 2048,
    "temperature": 0.0,
    "top_k": 250,
    "top_p": 1,
    "stop_sequences": ["Human"],
}

sonnet_llm = ChatBedrock(
    client=bedrock_rt,
    model_id=sonnet_model_id,
    model_kwargs=model_kwargs,
)

embeddings_model_id = "amazon.titan-embed-text-v1"
embedding_llm = BedrockEmbeddings(client = bedrock_rt, model_id = embeddings_model_id)

### Set your Bedrock knowledge base ID here

In [70]:
bedrock_retriever = AmazonKnowledgeBasesRetriever(
    knowledge_base_id="<Bedrock KB ID>",
    region_name = my_region,
    retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}},
)
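
Rather than editing the placeholder in the cell above by hand, the knowledge base ID could be read from an environment variable. The sketch below assumes a variable named `BEDROCK_KB_ID`; that name and the helper `resolve_kb_id` are hypothetical and not part of any AWS or LangChain API.

```python
import os


def resolve_kb_id(env_var="BEDROCK_KB_ID", placeholder="<Bedrock KB ID>"):
    """Return the knowledge base ID from the environment, or fail loudly."""
    kb_id = os.environ.get(env_var, "").strip()
    if not kb_id or kb_id == placeholder:
        raise ValueError(f"Set the {env_var} environment variable to your knowledge base ID")
    return kb_id


# The value below is a made-up ID for demonstration only
os.environ["BEDROCK_KB_ID"] = "EXAMPLEKB01"
knowledge_base_id = resolve_kb_id()
```

The returned value could then be passed as `knowledge_base_id` when constructing the retriever.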

## Create router
We will now define the router, which decides which RAG source or function to use for a given question. This example includes only one RAG source plus a mock fixed source, but feel free to extend this section with your own tools.

In [84]:
data_sources = """rag: Get data about Amazon and Amazon Web Services
fixed: Get data about the IRS or taxes"""


system_msg = f"""
<instructions>
You are an assistant that helps users pull data from the correct data source. You pick the correct data source based on the question asked
<tools>
{data_sources}
</tools>

You only output the name of the tool
</instructions>
"""

input_qn_template = """
<input question>
{question}
</input question>"""

route_template = [
    ("system", system_msg),
    ("user", input_qn_template),
]


route_prompt_template = ChatPromptTemplate.from_messages(route_template)

router_chain = route_prompt_template | sonnet_llm | StrOutputParser()



In [74]:
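from typing import Any


class GraphState(TypedDict):
    # Single dict that carries the question, route, documents, and generation between nodes
    keys: Dict[str, Any]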
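
The node functions in this notebook pass `question` along in every return value. A likely reason, assuming the graph replaces the whole `keys` dict with whatever a node returns (no merge function is declared for that field), is that anything not copied forward would be lost. The plain-Python sketch below imitates that overwrite behavior; `apply_update` is a hypothetical helper and not a LangGraph function.

```python
def apply_update(state, update):
    """Mimic an overwrite-style state update: each returned key replaces the old value."""
    new_state = dict(state)
    new_state.update(update)
    return new_state


state = {"keys": {"question": "Tell me about taxes"}}

# A node that forgets to copy "question" forward loses it
lossy = apply_update(state, {"keys": {"route": "fixed"}})

# A node that copies it forward keeps it available for the next node
kept = apply_update(
    state, {"keys": {"route": "fixed", "question": state["keys"]["question"]}}
)
```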

In [75]:
def router_invoke(state):
    print("Router invoked. ")
    # Invoke the router to pick the correct data source 
    state_dict = state["keys"]
    question = state_dict["question"]
    # run the router chain on the question
    route_out = router_chain.invoke({"question":question})
    return {"keys": {"route": route_out, "question": question}}


In [76]:

def generate(state):
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    print("Generator invoked")

    # build the prompt and run the generation chain
    template = [
        ("system", "You are a knowledgeable and helpful QnA bot. Your task is to provide accurate and relevant answers to questions asked by users. Use the provided context to answer the questions, and if the context does not contain enough information to answer a question, politely indicate that you do not have enough information."),
        ("user", """Use the following pieces of retrieved context to answer the question. 
         <context> {context} </context>
         <question> {question} </question> 
         Keep the answer concise """),
    ]

    prompt = ChatPromptTemplate.from_messages(template)
    rag_chain = prompt | sonnet_llm | StrOutputParser()

    generation = rag_chain.invoke({"context": documents, "question": question})


    return {"keys": {"generation": generation, "question": question}}

In [77]:

#  retrieve only
def retrieve(state):
    print("Retriever invoked")

    state_dict = state["keys"]
    question = state_dict["question"]
    docs = bedrock_retriever.invoke(question)
    documents = []
    for doc in docs:
        print(doc.page_content)
        documents.append(doc.page_content)

    print(documents)
    return {"keys": {"documents": documents , "question": question}}

In [78]:
# Mock fixed data source: returns a canned passage instead of querying a retriever
def fixed_data(state):
    print("Fixed data source invoked")

    state_dict = state["keys"]
    question = state_dict["question"]

    documents = "The IRS, or Internal Revenue Service, is the federal agency responsible for administering and enforcing tax laws in the United States. It is a bureau of the Department of the Treasury and is responsible for collecting taxes and processing tax returns. The IRS plays a crucial role in the federal government's revenue collection efforts, ensuring that individuals and businesses pay their fair share of taxes to fund various public services and programs."

    return {"keys": {"documents": documents , "question": question}}
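
Note that `retrieve` returns a list of strings while `fixed_data` returns a single string, and both end up in the `{context}` slot of the generation prompt. One way to give the model a consistent context block is a small formatting helper. This is a sketch only; `format_context` and its numbering scheme are assumptions, not something the notebook defines.

```python
def format_context(documents):
    """Turn a string or a list of strings into one numbered context block."""
    if isinstance(documents, str):
        documents = [documents]
    # Trim whitespace and drop empty chunks
    chunks = [doc.strip() for doc in documents if doc and doc.strip()]
    return "\n\n".join(f"[{i}] {chunk}" for i, chunk in enumerate(chunks, start=1))


from_list = format_context(
    ["Amazon Bedrock is a managed service.", "  ", "AWS offers many services."]
)
from_string = format_context("The IRS collects taxes.")
```

The result could be passed as `context` in place of the raw `documents` value.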

In [85]:
def rag_router(state):
    state_dict = state["keys"]
    route = state_dict["route"]
    # Normalize whitespace and case before comparing the model's answer
    if route.strip().lower() == "fixed":
        return "fixed"
    else:
        return "rag"
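
The router's reply is free text, so an exact string comparison may miss answers that carry extra words or wrapping tags. A more tolerant parser could look like the sketch below. `normalize_route` and `VALID_ROUTES` are hypothetical names, and falling back to `"rag"` on ambiguous output mirrors the `else` branch above rather than any library behavior.

```python
import re

VALID_ROUTES = ("rag", "fixed")  # tool names listed in the router prompt


def normalize_route(raw, default="rag"):
    """Map raw model output to a known route, falling back to the default."""
    cleaned = raw.strip().lower()
    if cleaned in VALID_ROUTES:
        return cleaned
    # Tolerate extra words or tags, e.g. "<tool>fixed</tool>" or "Tool: fixed"
    tokens = re.findall(r"[a-z]+", cleaned)
    matches = {token for token in tokens if token in VALID_ROUTES}
    # Accept the answer only if exactly one known route is mentioned
    return matches.pop() if len(matches) == 1 else default


examples = {
    raw: normalize_route(raw)
    for raw in ["Fixed\n", "<tool>fixed</tool>", "rag or fixed", ""]
}
```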

## Graph Assembly

This graph routes each query to one of two data sources and then generates an answer from whatever that source returns. The router picks the source based on the input query: questions about Amazon go to the knowledge base, and questions about the IRS or taxes go to the fixed output string.

In [87]:

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("data_source_router", router_invoke)  
workflow.add_node("retrieve", retrieve)  
workflow.add_node("generate", generate)  
workflow.add_node("fixed_data", fixed_data)  

# Build graph
workflow.set_entry_point("data_source_router")
# After the router node runs, rag_router reads its output and returns
# the name of the branch ("rag" or "fixed") that should run next
workflow.add_conditional_edges(
    "data_source_router",
    rag_router,
    {
        "rag": "retrieve",
        "fixed": "fixed_data",
    },
)
workflow.add_edge("retrieve", "generate")  # both data sources feed into the generate node
workflow.add_edge("fixed_data", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()
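
Before calling Bedrock, it can help to check the wiring with stand-ins. The sketch below reproduces the same control flow (route, fetch documents, generate) in plain Python with stub functions, so it runs without AWS credentials. Every name in it (`run_pipeline`, `stub_router`, `stub_sources`, `stub_generate`) is hypothetical, and the keyword-based router is only a placeholder for the LLM call.

```python
def run_pipeline(question, route_fn, sources, generate_fn):
    """Route a question to one data source, then generate an answer from its documents."""
    route = route_fn(question)
    # Mirror the conditional edge: anything that is not "fixed" goes to "rag"
    source_name = "fixed" if route == "fixed" else "rag"
    documents = sources[source_name](question)
    return {"route": source_name, "generation": generate_fn(question, documents)}


# Stubs standing in for the LLM router, the two data sources, and the generator
def stub_router(question):
    return "fixed" if "tax" in question.lower() else "rag"


stub_sources = {
    "rag": lambda q: ["stub knowledge base passage"],
    "fixed": lambda q: ["stub IRS passage"],
}


def stub_generate(question, documents):
    return f"Answer based on {len(documents)} document(s): {documents[0]}"


tax_result = run_pipeline("Tell me about taxes", stub_router, stub_sources, stub_generate)
aws_result = run_pipeline(
    "Tell me about Amazon's work in Generative AI", stub_router, stub_sources, stub_generate
)
```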

In [None]:
# Test case to invoke the KB
question = "Tell me about Amazon's work in Generative AI"
inputs = { "keys":{"question":question}}
graph_out = app.invoke(inputs)
print(graph_out["keys"]["generation"])

In [None]:
# Test case to invoke the fixed response
question = "Tell me about taxes"
inputs = { "keys":{"question":question}}
graph_out = app.invoke(inputs)
print(graph_out["keys"]["generation"])