Skip to content

Commit

Permalink
Merge pull request #8 from aws-samples/may-updates
Browse files Browse the repository at this point in the history
May updates
  • Loading branch information
omerh committed May 5, 2024
2 parents a3e26bb + 3fd7a33 commit e3d7e1d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
59 changes: 27 additions & 32 deletions ask-bedrock-with-rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from utils import opensearch, secret
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms.bedrock import Bedrock
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import BedrockChat
import boto3
from loguru import logger
import sys
Expand All @@ -18,10 +19,10 @@

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--ask", type=str, default="What is <3?")
parser.add_argument("--ask", type=str, default="What is the meaning of <3?")
parser.add_argument("--index", type=str, default="rag")
parser.add_argument("--region", type=str, default="us-east-1")
parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-v2")
parser.add_argument("--bedrock-model-id", type=str, default="anthropic.claude-3-sonnet-20240229-v1:0")
parser.add_argument("--bedrock-embedding-model-id", type=str, default="amazon.titan-embed-text-v1")

return parser.parse_known_args()
Expand Down Expand Up @@ -51,7 +52,7 @@ def create_opensearch_vector_search_client(index_name, opensearch_password, bedr


def create_bedrock_llm(bedrock_client, model_version_id):
bedrock_llm = Bedrock(
bedrock_llm = BedrockChat(
model_id=model_version_id,
client=bedrock_client,
model_kwargs={'temperature': 0}
Expand All @@ -60,12 +61,14 @@ def create_bedrock_llm(bedrock_client, model_version_id):


def main():
logger.info("Starting")
logger.info("Starting...")
args, _ = parse_args()
region = args.region
index_name = args.index
bedrock_model_id = args.bedrock_model_id
bedrock_embedding_model_id = args.bedrock_embedding_model_id
question = args.ask
logger.info(f"Question provided: {question}")

# Creating all clients for chain
bedrock_client = get_bedrock_client(region)
Expand All @@ -76,39 +79,31 @@ def main():
opensearch_vector_search_client = create_opensearch_vector_search_client(index_name, opensearch_password, bedrock_embeddings_client, opensearch_endpoint)

# LangChain prompt template
if len(args.ask) > 0:
question = args.ask
else:
question = "what is the meaning of <3?"
logger.info(f"No question provided, using default question {question}")

prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
prompt = ChatPromptTemplate.from_template("""Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. don't include harmful content
{context}
Question: {question}
Answer:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
Question: {input}
Answer:""")

logger.info(f"Starting the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
qa = RetrievalQA.from_chain_type(llm=bedrock_llm,
chain_type="stuff",
retriever=opensearch_vector_search_client.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt": PROMPT, "verbose": True},
verbose=True)
docs_chain = create_stuff_documents_chain(bedrock_llm, prompt)
retrieval_chain = create_retrieval_chain(
retriever=opensearch_vector_search_client.as_retriever(),
combine_docs_chain = docs_chain
)

response = qa.invoke(question, return_only_outputs=False)
logger.info(f"Invoking the chain with KNN similarity using OpenSearch, Bedrock FM {bedrock_model_id}, and Bedrock embeddings with {bedrock_embedding_model_id}")
response = retrieval_chain.invoke({"input": question})

logger.info("This are the similar documents from OpenSearch based on the provided query")
source_documents = response.get('source_documents')
print("")
logger.info("These are the similar documents from OpenSearch based on the provided query:")
source_documents = response.get('context')
for d in source_documents:
logger.info(f"With the following similar content from OpenSearch:\n{d.page_content}\n")
logger.info(f"Text: {d.metadata['text']}")
print("")
logger.info(f"Text: {d.page_content}")

logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('result')}")
print("")
logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('answer')}")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
boto3>=1.34.79
langchain==0.1.14
langchain-community==0.0.31
langchain-core==0.1.50
coloredlogs>=15.0.1
jq==1.7.0
opensearch-py==2.5.0
Expand Down

0 comments on commit e3d7e1d

Please sign in to comment.