# Retrieval-Augmented Generation: Question Answering based on Custom Dataset with Open-sourced [LangChain](https://python.langchain.com/en/latest/index.html) Library


---

This notebook has been tested in us-east-1 with **Data Science 2.0** kernel

---


Many use cases such as building a chatbot require text (text2text) generation models like **[BloomZ 7B1](https://huggingface.co/bigscience/bloomz-7b1)**, **[Flan T5 XXL](https://huggingface.co/google/flan-t5-xxl)**, and **[Flan T5 UL2](https://huggingface.co/google/flan-ul2)** to respond to user questions with insightful answers. The **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** models have picked up a lot of general knowledge in training, but we often need to ingest and use a large library of more specific information.

In this notebook we demonstrate how to answer questions with an LLM using a library of documents as a reference, by using document embeddings and retrieval. The cells below use the **Falcon 40B Instruct** endpoint listed under Model Deployment; the embeddings are generated by the **GPT-J-6B** embedding model.

**This notebook serves as a template: you can easily replace the example dataset with your own to build a custom question answering application.**

## SageMaker Studio Notebook

- instance: ml.g4dn.xlarge
- kernel: python 3
- pyTorch 1.13 Python 3.9 GPU Optimised

## Model Deployment

Deploy from SageMaker JumpStart

- textembedding-gpt-j-6b (ml.g5.12xlarge)
- jumpstart-dft-falcon-40b-instruct-bf16 (ml.g5.48xlarge)

In [2]:
!pip install --upgrade pip
!pip install --upgrade langchain --quiet
!pip install transformers faiss-gpu --quiet
# !pip install bs4 --quiet


In [3]:
import time
import sagemaker, boto3, json
from sagemaker.session import Session
from sagemaker.model import Model
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base
from typing import Any, Dict, List, Optional
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sm_client = boto3.client("sagemaker", aws_region)
sess = sagemaker.Session()
model_version = "*"

In [4]:
def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type="application/json"):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json
    )
    return response

def parse_response_model(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    return [gen["generated_text"] for gen in model_predictions]
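
To make the response shape concrete, here is a small offline sketch. It assumes the endpoint returns a JSON list of objects with a `generated_text` key (which is what `parse_response_model` indexes into) and uses an in-memory `io.BytesIO` as a stand-in for the streaming `Body`. The `fake_response` payload is invented for illustration and is not real endpoint output.

```python
import io
import json


def parse_response_model(query_response):
    # Same logic as the helper above: read the body and collect each generated text.
    model_predictions = json.loads(query_response["Body"].read())
    return [gen["generated_text"] for gen in model_predictions]


# Hypothetical payload standing in for a real invoke_endpoint response.
fake_body = json.dumps([{"generated_text": "Saturday"}]).encode("utf-8")
fake_response = {"Body": io.BytesIO(fake_body)}

generated_texts = parse_response_model(fake_response)
print(generated_texts)
```

Testing the parser this way lets you check the parsing logic without invoking (and paying for) a live endpoint.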

In [5]:
_MODEL_CONFIG_ = {
    
    "jumpstart-dft-hf-textembedding-gpt-j-6b": {
        "aws_region": "us-east-1",
        "endpoint_name": "jumpstart-dft-hf-textembedding-gpt-j-6b",
    },
    
    "jumpstart-dft-hf-llm-falcon-40b-instruct-bf16" : {
        "aws_region": "us-east-1",
        "endpoint_name": "jumpstart-dft-hf-llm-falcon-40b-instruct-bf16",
        "parse_function": parse_response_model,
        "prompt": """{context}\n\nGiven the above context, answer the following question:\n{question}\nAnswer: """,
    },
}


## Step 2. Ask a question to LLM without providing the context

To illustrate why we need a retrieval-augmented generation (RAG) based approach to solve the question answering problem, let's directly ask the model a question and see how it responds.

In [6]:
question = "Which instances can I use with Managed Spot Training in SageMaker?"

In [7]:
# best_of: None, 
# temperature: None, 
# repetition_penalty: None, 
# top_k: None, 
# top_p: None, 
# typical_p: None, 
# do_sample: false, 
# max_new_tokens: 20, 
# return_full_text: Some(false), 
# stop: [], 
# truncate: None, 
# watermark: false, 
# details: false, 
# seed: None 

In [8]:
payload = {
    "inputs": question,
    "parameters":{
        "max_new_tokens": 100,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": False,
        "return_full_text": True,
        "temperature": 0.2
    }
}

list_of_LLMs = list(_MODEL_CONFIG_.keys())
list_of_LLMs.remove("jumpstart-dft-hf-textembedding-gpt-j-6b")  # remove the embedding model


for model_id in list_of_LLMs:
    endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = _MODEL_CONFIG_[model_id]["parse_function"](query_response)
    print(f"For model: {model_id}, the generated output is: {generated_texts[0]}\n")

For model: jumpstart-dft-hf-llm-falcon-40b-instruct-bf16, the generated output is: Which instances can I use with Managed Spot Training in SageMaker?
You can use Managed Spot Training in SageMaker with the following instances:

1. ml.t3.medium
2. ml.t3.large
3. ml.t3.xlarge
4. ml.p3.2xlarge
5. ml.p3.8xlarge
6. ml.p3.16xlarge
7. ml.p3.24xlarge
8. ml.p3.2xlarge




As you can see, the generated answer is incorrect: the model lists only a handful of instance types and even repeats `ml.p3.2xlarge`.

## Step 3. Improve the answer to the same question using **prompt engineering** with insightful context


To answer the question better, we provide extra contextual information, combine it with a prompt, and send it to the model together with the question. Below is an example.

In [9]:
context = """Managed Spot Training can be used with all instances supported in Amazon SageMaker. 
Managed Spot Training is supported in all AWS Regions where Amazon SageMaker is currently available."""

In [10]:

parameters ={
        "max_new_tokens": 100,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": False,
        "return_full_text": True,
        "temperature": 0.2
    }


for model_id in list_of_LLMs:
    endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]
    prompt = _MODEL_CONFIG_[model_id]["prompt"]

    text_input = prompt.replace("{context}", context)
    text_input = text_input.replace("{question}", question)
    payload = {"inputs": text_input, "parameters":parameters}

    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = _MODEL_CONFIG_[model_id]["parse_function"](query_response)
    print(
        f"For model: {model_id}, the generated output is: {generated_texts[0]}"
    )

For model: jumpstart-dft-hf-llm-falcon-40b-instruct-bf16, the generated output is: Managed Spot Training can be used with all instances supported in Amazon SageMaker. 
Managed Spot Training is supported in all AWS Regions where Amazon SageMaker is currently available.

Given the above context, answer the following question:
Which instances can I use with Managed Spot Training in SageMaker?
Answer: 
All instances supported in Amazon SageMaker can be used with Managed Spot Training.


The output from Step 3 shows that the chance of getting a correct response depends heavily on the quality of the context you send to the LLM.

**<span style="color:red">Now the question becomes: where can we find insightful context for a given user query? The answer is to use a pre-stored knowledge base with retrieval-augmented generation, as shown in Step 4 below</span>.**

## Step 4. Use RAG based approach with [LangChain](https://python.langchain.com/en/latest/index.html) and SageMaker endpoints to build a simplified question and answering application.


We plan to use document embeddings to fetch the most relevant documents in our document knowledge library and combine them with the prompt that we provide to the LLM.

To achieve that, we will do the following.

1. **Generate embeddings for each document in the knowledge library with the SageMaker GPT-J-6B embedding model.**
2. **Identify the top K most relevant documents based on the user query.**
    - 2.1 **For a query of your interest, generate the embedding of the query using the same embedding model.**
    - 2.2 **Search the indexes of the top K most relevant documents in the embedding space using in-memory Faiss search.**
    - 2.3 **Use the indexes to retrieve the corresponding documents.**
3. **Combine the retrieved documents with the prompt and question and send them to the SageMaker LLM.**
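
The retrieval part of these steps can be sketched without any endpoint. The example below is a simplified illustration, not what LangChain or Faiss does internally: it uses brute-force cosine similarity in pure Python instead of a Faiss index (whose distance metric may differ), and the three-dimensional vectors and document names are made up.

```python
import math
from typing import List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec: List[float], doc_vecs: List[List[float]], k: int) -> List[Tuple[int, float]]:
    # Score every document against the query, then keep the k best (index, score) pairs.
    scored = [(i, cosine_similarity(query_vec, v)) for i, v in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 3-dimensional "embeddings" (invented; real embedding vectors are much longer).
toy_docs = ["spot training doc", "pricing doc", "edge manager doc"]
toy_doc_embeddings = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.0, 0.1, 0.9]]
toy_query_embedding = [1.0, 0.2, 0.0]

hits = top_k(toy_query_embedding, toy_doc_embeddings, k=2)
retrieved = [toy_docs[i] for i, _ in hits]  # step 2.3: map indexes back to documents
print(retrieved)
```

In the real pipeline both the document vectors and the query vector come from the same embedding endpoint, which is why step 2.1 insists on using the same model for the query.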



Note: The retrieved document/text should be large enough to contain the information needed to answer a question, but small enough to fit into the LLM prompt, which has a maximum sequence length of 1024 tokens.
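
One way to respect such a limit is to stop adding retrieved chunks once a budget is used up. The sketch below is only a rough illustration: `approx_token_count` and `fit_to_budget` are hypothetical helpers, and counting whitespace-separated words is an assumed proxy that typically undercounts what the model's real tokenizer would produce.

```python
from typing import List


def approx_token_count(text: str) -> int:
    # Crude proxy: whitespace-separated words. A real tokenizer usually yields more tokens.
    return len(text.split())


def fit_to_budget(chunks: List[str], max_tokens: int) -> List[str]:
    # Chunks are assumed to be sorted with the most relevant first.
    kept, used = [], 0
    for chunk in chunks:
        cost = approx_token_count(chunk)
        if used + cost > max_tokens:
            break
        kept.append(chunk)
        used += cost
    return kept


chunks = ["one two three", "four five", "six seven eight nine"]
print(fit_to_budget(chunks, max_tokens=5))
```

For a real application you would likely count tokens with the model's own tokenizer and leave headroom for the prompt text, the question, and the generated answer.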

---
To build a simplified QA application with LangChain, we need to:
1. Wrap our SageMaker endpoints for the embedding model and the LLM in `langchain.embeddings.SagemakerEndpointEmbeddings` and `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. This requires a small customization to make `SagemakerEndpointEmbeddings` compatible with the SageMaker embedding model.
2. Prepare the dataset to build the knowledge base.

---

Wrap the SageMaker endpoint for the embedding model in `langchain.embeddings.SagemakerEndpointEmbeddings`. This requires a custom content handler that translates between LangChain and the request and response format of the SageMaker embedding model.

### Embedding

In [11]:
import langchain
langchain.__version__

'0.0.202'

In [12]:
from typing import Dict, List
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
import json


# class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    
#     def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
#         """Compute doc embeddings using a SageMaker Inference Endpoint.

#         Args:
#             texts: The list of texts to embed.
#             chunk_size: The chunk size defines how many input texts will
#                 be grouped together as request. If None, will use the
#                 chunk size specified by the class.

#         Returns:
#             List of embeddings, one for each text.
#         """
#         results = []
#         _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size

#         for i in range(0, len(texts), _chunk_size):
#             response = self._embedding_func(texts[i : i + _chunk_size])
#             results.extend(response)
#         return results


class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]

emb_content_handler = ContentHandler()


sm_llm_embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=_MODEL_CONFIG_["jumpstart-dft-hf-textembedding-gpt-j-6b"]["endpoint_name"],
    region_name=_MODEL_CONFIG_["jumpstart-dft-hf-textembedding-gpt-j-6b"]["aws_region"],
    content_handler=emb_content_handler
)

In [13]:
sm_llm_embeddings.embed_documents(["Hello World"])[0][:5]

[0.0054481481201946735,
 -0.006502704694867134,
 0.003976280800998211,
 -0.019499264657497406,
 0.007071854546666145]
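
To see what the content handler sends and expects, the following offline sketch mirrors its two methods as plain functions. The `text_inputs` request key and `embedding` response key are taken from the handler above; the fake response values are invented, and `io.BytesIO` stands in for the streaming body that the endpoint returns (note that `output` is read like a stream, not used as raw bytes).

```python
import io
import json
from typing import Dict, List


def transform_input(inputs: List[str], model_kwargs: Dict) -> bytes:
    # Serialize the texts (plus any extra model arguments) into the request body.
    return json.dumps({"text_inputs": inputs, **model_kwargs}).encode("utf-8")


def transform_output(output) -> List[List[float]]:
    # `output` is a file-like stream, so it must be read before decoding.
    response_json = json.loads(output.read().decode("utf-8"))
    return response_json["embedding"]


request_body = transform_input(["Hello World"], {})
print(request_body)

# Hypothetical response body; real vectors have far more dimensions.
fake_stream = io.BytesIO(json.dumps({"embedding": [[0.1, -0.2, 0.3]]}).encode("utf-8"))
vectors = transform_output(fake_stream)
print(vectors)
```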

### LLM

Next, we wrap up our SageMaker endpoints for LLM into `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. 

In [14]:
import json
from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint

parameters ={
        "max_new_tokens": 100,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": False,
        "return_full_text": False,
        "temperature": 0.2
    }

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()

sm_llm_falcon_instruct = SagemakerEndpoint(
    endpoint_name=_MODEL_CONFIG_["jumpstart-dft-hf-llm-falcon-40b-instruct-bf16"]["endpoint_name"],
    region_name=_MODEL_CONFIG_["jumpstart-dft-hf-llm-falcon-40b-instruct-bf16"]["aws_region"],
    model_kwargs=parameters,
    content_handler=content_handler,
)

In [15]:
sm_llm_falcon_instruct("Which day comes after Friday?")

'\nSaturday'

## Data

Now, let's download the example data and prepare it for demonstration. We will use [Amazon SageMaker FAQs](https://aws.amazon.com/sagemaker/faqs/) as knowledge library. The data are formatted in a CSV file with two columns Question and Answer. We use the Answer column as the documents of knowledge library, from which relevant documents are retrieved based on a query. 

**For your own purposes, you can replace the example dataset with your own to build a custom question answering application.**

In [16]:
s3_path = f"s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv"

In [17]:
# Downloading the Database
!mkdir -p rag_data
!aws s3 cp $s3_path rag_data/Amazon_SageMaker_FAQs.csv

download: s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv to rag_data/Amazon_SageMaker_FAQs.csv


If your data is saved in multiple subsets, you can read all files that end with `.csv` and concatenate them together; please ensure each `csv` file has the same format. The cell below loads the single example file and keeps only its second column as `Answer`.

In [18]:
import pandas as pd

df_knowledge = pd.read_csv("rag_data/Amazon_SageMaker_FAQs.csv", header=None, usecols=[1], names=["Answer"])
df_knowledge.head(6)

| | Answer |
|---|---|
| 0 | Amazon SageMaker is a fully managed service to... |
| 1 | For a list of the supported Amazon SageMaker A... |
| 2 | Amazon SageMaker is designed for high availabi... |
| 3 | Amazon SageMaker stores code in ML storage vol... |
| 4 | Amazon SageMaker ensures that ML model artifac... |
| 5 | Amazon SageMaker does not use or share custome... |
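
If the knowledge library were spread over several CSV files with the same two-column layout, the answers could be gathered along these lines. This is a minimal sketch that uses the standard-library `csv` module rather than pandas to stay dependency-free; `load_answers` is a hypothetical helper, and it assumes every file has no header row and stores the answer in the second column.

```python
import csv
import glob
import os
from typing import List


def load_answers(directory: str) -> List[str]:
    """Collect the second column (Answer) from every .csv file in `directory`."""
    answers = []
    for path in sorted(glob.glob(os.path.join(directory, "*.csv"))):
        with open(path, newline="", encoding="utf-8") as handle:
            for row in csv.reader(handle):
                if len(row) > 1:  # skip blank or single-column rows
                    answers.append(row[1])
    return answers


# Returns an empty list if the directory does not exist or has no CSV files.
all_answers = load_answers("rag_data")
print(len(all_answers))
```

Sorting the file paths keeps the row order reproducible between runs, which matters if you later refer to documents by row number.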


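Only the `Answer` column is loaded above; the `Question` column is dropped because it is not used in this demonstration. The processed data is then saved to a new CSV file.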

In [26]:
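# Save the Answer column to a new CSV file for CSVLoader.
# Note: because the file is written without a header row, CSVLoader treats the
# first answer as the column name and prefixes it to every loaded document
# (visible in the `documents[:3]` output further down).
df_knowledge.to_csv("rag_data/processed.csv", header=False, index=False)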

In [27]:
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import Chroma, AtlasDB, FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders.csv_loader import CSVLoader

Use LangChain to read the `csv` data. LangChain has multiple built-in loaders for different file formats such as `txt`, `html`, and `pdf`. For details, see [LangChain document loaders](https://python.langchain.com/en/latest/modules/indexes/document_loaders.html).

In [28]:
loader = CSVLoader(file_path="rag_data/processed.csv")
documents = loader.load()
# If you load a plain text file with langchain.document_loaders.TextLoader instead,
# uncomment the following two lines to split the text:
# text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
# texts = text_splitter.split_documents(documents)
documents[:3]

[Document(page_content='Amazon SageMaker is a fully managed service to prepare data and build, train, and deploy machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows.: For a list of the supported Amazon SageMaker AWS Regions, please visit the\xa0AWS Regional Services page. Also, for more information, see\xa0Regional endpoints\xa0in the AWS general reference guide.', metadata={'source': 'rag_data/processed.csv', 'row': 0}),
 Document(page_content='Amazon SageMaker is a fully managed service to prepare data and build, train, and deploy machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows.: Amazon SageMaker is designed for high availability. There are no maintenance windows or scheduled downtimes. SageMaker APIs run in Amazon’s proven, high-availability data centers, with service stack replication configured across three facilities in each AWS Region to provide fault tolerance in the event of

**Now, we can build a QA application. <span style="color:red">LangChain makes it extremely simple with the following few lines of code</span>.**

For the question below, we can achieve the points in Step 4 with just a few lines of code.

In [29]:
question

'Which instances can I use with Managed Spot Training in SageMaker?'

In [30]:
index_creator = VectorstoreIndexCreator(
    vectorstore_cls=FAISS,
    embedding=sm_llm_embeddings,
    # text_splitter=CharacterTextSplitter(chunk_size=300, chunk_overlap=0),
)

In [31]:
index = index_creator.from_loaders([loader])

In [34]:
sm_llm_falcon_instruct.model_kwargs = {
        "max_new_tokens": 50,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": False,
        "return_full_text": False,
        "temperature": 0.1
}
index.query(question=question, llm=sm_llm_falcon_instruct)

' You can use any instance type that is available in the AWS Spot Market. Spot instances are available at a discount compared to On-Demand instances, and you can use them to train your models at a lower cost.'

## Step 5. Customize the QA application above with a different prompt.

We have now seen how simple it is to build a question answering application with LangChain in just a few lines of code. Let's break down the `VectorstoreIndexCreator` above and see what happens under the hood. We will also see how to incorporate a customized prompt instead of the default prompt used by `VectorstoreIndexCreator`.

First, we **generate embeddings for each document in the knowledge library with the SageMaker GPT-J-6B-FP16 embedding model.**

In [61]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Get your splitter ready
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=5)

# Split your docs into texts
texts = text_splitter.split_documents(documents)

# Reuse the embedding wrapper (sm_llm_embeddings) defined earlier

# Embed your texts and build the FAISS index
docsearch = FAISS.from_documents(texts, sm_llm_embeddings)
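
To build intuition for `chunk_size` and `chunk_overlap`, here is a deliberately simplified sliding-window splitter. It is not the algorithm used by `RecursiveCharacterTextSplitter`, which first tries to split on separators such as paragraph and line breaks; `split_fixed` is a hypothetical helper that only shows how consecutive chunks share `chunk_overlap` characters.

```python
from typing import List


def split_fixed(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    # Stop once the previous chunk has already reached the end of the text.
    last_start = max(len(text) - chunk_overlap, 1)
    return [text[i : i + chunk_size] for i in range(0, last_start, step)]


sample = "abcdefghijklmnopqrstuvwxyz"
chunks = split_fixed(sample, chunk_size=10, chunk_overlap=2)
print(chunks)
```

Each chunk starts with the last two characters of the previous one. With `chunk_size=100` characters, as in the cell above, chunks are short and can cut sentences in half, which is why some retrieved passages further down look truncated.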

In [62]:
question

'Which instances can I use with Managed Spot Training in SageMaker?'

Based on the question above, we then **identify top K most relevant documents based on user query, where K = 3 in this setup**.

In [95]:
docs = docsearch.similarity_search(question, k=3)
docs

[Document(page_content='Managed Spot Training can be used with all instances supported in Amazon SageMaker.', metadata={'source': 'rag_data/processed.csv', 'row': 89}),
 Document(page_content='You enable the Managed Spot Training option when submitting your training jobs and you also specify', metadata={'source': 'rag_data/processed.csv', 'row': 84}),
 Document(page_content='have more granular control over model features can use Amazon SageMaker Edge Manager. SageMaker', metadata={'source': 'rag_data/processed.csv', 'row': 139})]

Finally, we **combine the retrieved documents with prompt and question and send them into SageMaker LLM.** 

We define a customized prompt as below.

In [100]:
prompt_template = """{context}\n\nGiven the above context, answer the following question:\n{question}\n\nAnswer:"""

PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
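
Conceptually, the chain fills this template by joining the retrieved passages into `{context}` and inserting the question. The sketch below reproduces that with plain string formatting; it assumes the passages are joined with a blank line between them (consistent with the printed result further down) and is not LangChain's actual implementation. The `sketch_*` names are hypothetical and chosen so they do not overwrite the notebook's variables.

```python
sketch_template = (
    "{context}\n\nGiven the above context, answer the following question:\n"
    "{question}\n\nAnswer:"
)

# Passages copied from the retrieval output above.
sketch_passages = [
    "Managed Spot Training can be used with all instances supported in Amazon SageMaker.",
    "You enable the Managed Spot Training option when submitting your training jobs and you also specify",
]
sketch_question = "Which instances can I use with Managed Spot Training in SageMaker?"

# Assumption: retrieved passages are separated by a blank line.
stuffed_context = "\n\n".join(sketch_passages)
final_prompt = sketch_template.format(context=stuffed_context, question=sketch_question)
print(final_prompt)
```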

In [104]:
sm_llm_falcon_instruct.model_kwargs = {
        "max_new_tokens": 50,
        "num_return_sequences": 1,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": False,
        "return_full_text": True,
        "temperature": 0.1,
}
chain = load_qa_chain(llm=sm_llm_falcon_instruct, prompt=PROMPT)

Send the top 3 most relevant documents and the question to the LLM to get an answer.

In [105]:
result = chain({"input_documents": docs, "question": question}, return_only_outputs=True)[
    "output_text"
]

Print the final answer from the LLM, which is accurate. Because `return_full_text` is set to `True`, the output echoes the full prompt before the answer.

In [106]:
print(result)

Managed Spot Training can be used with all instances supported in Amazon SageMaker.

You enable the Managed Spot Training option when submitting your training jobs and you also specify

have more granular control over model features can use Amazon SageMaker Edge Manager. SageMaker

Given the above context, answer the following question:
Which instances can I use with Managed Spot Training in SageMaker?

Answer:

You can use all instances supported in Amazon SageMaker with Managed Spot Training.
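
If you only want the answer and not the echoed prompt, one option is to set `return_full_text` to `False`, as in the earlier LLM cell. Another is to post-process the string, as in this small sketch. `extract_answer` is a hypothetical helper that assumes the prompt ends with the `Answer:` marker used in the template; it would misbehave if the generated answer itself contained that marker.

```python
def extract_answer(full_text: str, marker: str = "Answer:") -> str:
    # Keep only what follows the last occurrence of the marker.
    _, found, tail = full_text.rpartition(marker)
    return tail.strip() if found else full_text.strip()


full_text = (
    "Given the above context, answer the following question:\n"
    "Which instances can I use with Managed Spot Training in SageMaker?\n\n"
    "Answer:\n\n"
    "You can use all instances supported in Amazon SageMaker with Managed Spot Training."
)
answer_only = extract_answer(full_text)
print(answer_only)
```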


### Clean up

In [None]:
# delete the endpoints created for testing
for model_id in _MODEL_CONFIG_:
    endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]
    sagemaker_session.delete_endpoint(endpoint_name)