# Building a RAG chatbot with LangChain, Hugging Face, Amazon SageMaker and Amazon OpenSearch Serverless

In [None]:
%%sh
# Besides the SDKs used directly, later cells need:
#   transformers          - AutoConfig (imported in the next cell)
#   sentence-transformers - presumably required by LangChain's HuggingFaceEmbeddings
#   datasets              - presumably required by LangChain's HuggingFaceDatasetLoader
pip install sagemaker langchain opensearch-py transformers sentence-transformers datasets -qU

In [None]:
import boto3, json, sagemaker

from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from transformers import AutoConfig
from typing import Dict

from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth

from langchain import LLMChain
from langchain.chains import RetrievalQA
from langchain.document_loaders import HuggingFaceDatasetLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import OpenSearchVectorSearch

## Deploy our LLM to a SageMaker endpoint

In [None]:
role = sagemaker.get_execution_role()

# Container environment variables: the Hub model ID to serve, and SM_NUM_GPUS
# (presumably the number of GPUs the inference server uses).
hub = dict(
    HF_MODEL_ID="mistralai/Mistral-7B-Instruct-v0.1",
    SM_NUM_GPUS="1",
)

# Model object backed by the Hugging Face LLM inference image, version 1.1.0.
huggingface_model = HuggingFaceModel(
    role=role,
    env=hub,
    image_uri=get_huggingface_llm_image_uri("huggingface", version="1.1.0"),
)

# Start the deployment on a single ml.g5.2xlarge instance without blocking;
# a later cell waits for the endpoint to be in service before querying it.
predictor = huggingface_model.deploy(
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    container_startup_health_check_timeout=300,
    wait=False,
)

## Configure the LangChain input and output handlers for our LLM

In [None]:
model_kwargs = {"max_new_tokens": 512, "top_p": 0.8, "temperature": 0.8}


class ContentHandler(LLMContentHandler):
    """Serialize prompts for, and parse responses from, the Mistral endpoint.

    Requests are JSON bodies of the form ``{"inputs": ..., "parameters": ...}``.
    Responses are parsed as a JSON list whose first element carries a
    ``"generated_text"`` key.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Wrap *prompt* in the Mistral instruction format and encode it as JSON.

        Args:
            prompt: The fully rendered prompt text.
            model_kwargs: Generation parameters, sent as ``"parameters"``.

        Returns:
            The UTF-8 encoded JSON request body.
        """
        input_str = json.dumps(
            # Mistral prompt, see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
            {"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}}
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the model's answer from an endpoint response body.

        Args:
            output: The response body, either a file-like object exposing
                ``read()`` or raw bytes.

        Returns:
            The text following the first ``[/INST]`` marker, stripped of
            surrounding whitespace. If the response does not echo the prompt
            (no marker present), the whole generated text is returned instead.
        """
        raw = output.read() if hasattr(output, "read") else output
        response_json = json.loads(raw.decode("utf-8"))
        generated_text = response_json[0]["generated_text"]
        # The endpoint may echo the prompt before the completion. Keep everything after
        # the first closing tag. Unlike split("[/INST] ")[1], partition() does not
        # raise IndexError when the marker is missing or not followed by a space. It also
        # does not cut the answer short if the answer itself contains the marker.
        _, marker, answer = generated_text.partition("[/INST]")
        return (answer if marker else generated_text).strip()


content_handler = ContentHandler()

In [None]:
# SageMaker control-plane client; used further down to wait for the endpoint.
sm_client = boto3.client('sagemaker')
# SageMaker runtime client, handed to SagemakerEndpoint as the client it invokes
# the endpoint through.
smrt_client = boto3.client("sagemaker-runtime")

# LangChain LLM wrapper around the endpoint, which may still be deploying at this
# point (it was deployed with wait=False).
llm = SagemakerEndpoint(
    client=smrt_client,
    content_handler=content_handler,
    model_kwargs=model_kwargs,
    endpoint_name=predictor.endpoint_name,
)

## Load the Reuters news dataset from the Hugging Face Hub

In [None]:
# Load the "ModLewis" configuration of the reuters21578 dataset; the "text"
# column becomes each document's page content.
loader = HuggingFaceDatasetLoader(
    "reuters21578",
    name="ModLewis",
    page_content_column="text",
)
data = loader.load()
len(data)

In [None]:
# Inspect the first document returned by the loader.
data[0]

## Split the news articles into chunks

In [None]:
# Re-imported here so this cell can be run on its own.
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Small, non-overlapping chunks. NOTE(review): chunk_size is presumably measured
# in characters with the splitter's default length function; confirm.
splitter = RecursiveCharacterTextSplitter(
    chunk_overlap=0,
    chunk_size=128,
)

In [None]:
%%time
docs = splitter.split_documents(data)
len(docs)

In [None]:
# Inspect the first chunk produced by the splitter.
docs[0]

## Configure our embedding model

In [None]:
# Embedding model picked from the MTEB leaderboard:
# https://huggingface.co/spaces/mteb/leaderboard
embedding_model_id = "BAAI/bge-small-en-v1.5"

# Read the model's hidden size from its Hugging Face config.
# NOTE(review): embedding_dimensions is not used by any later cell; the vector
# store appears to work out the dimension itself. Confirm before removing.
config = AutoConfig.from_pretrained(embedding_model_id)
embedding_dimensions = config.hidden_size

# Embedding model handed to the vector store when indexing the chunks.
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_id)

## Define the connection settings and credentials for Amazon OpenSearch Serverless

In [None]:
import os

# Connection settings for the OpenSearch Serverless collection. Each value can be
# overridden through an environment variable, so the notebook can target another
# collection without editing this cell. The defaults are the original demo values.
# If you override AOSS_HOST with a collection in another region, set AOSS_REGION too.
host = os.environ.get("AOSS_HOST", 's35vkyhnago99udpmal0.us-east-1.aoss.amazonaws.com')
index_name = os.environ.get("AOSS_INDEX_NAME", 'julsimon-index-demo')
region = os.environ.get("AOSS_REGION", 'us-east-1')

# Sign OpenSearch requests for the "aoss" service with the credentials resolved
# by the default boto3 session.
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region, "aoss")

## Embed and index the chunks

In [None]:
# Group the chunks into consecutive batches of at most 100 for indexing.
docs_100 = [docs[start:start + 100] for start in range(0, len(docs), 100)]

In [None]:
%%time

for docs in docs_100:
    oss = OpenSearchVectorSearch.from_documents(
        docs,
        embeddings,
        opensearch_url=f'https://{host}:443',
        http_auth=auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        index_name=index_name,
        timeout=60,
    )
    print(".", end="")

## Configure the RAG chain

In [None]:
# Retriever over the OpenSearch index that returns k=10 chunks per query.
retriever = oss.as_retriever(search_kwargs=dict(k=10))

In [None]:
# Prompt for the "stuff" QA chain: {question} takes the user's query and
# {context} is expected to receive the retrieved chunks.
prompt_template = """
As a helpful news agent, please answer the question using only the context below.
If you don't know, say you don't know.
Cite the title of the articles you used to build your answer.

question: {question}

context: {context}
"""

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

In [None]:
# The prompt asks the model to cite article titles. The "stuff" chain's default
# per-document formatting inserts only each chunk's page_content, so titles would
# never reach the model. This template prepends the "title" metadata to every chunk.
# NOTE(review): assumes every indexed chunk carries a "title" metadata key
# (presumably copied from the dataset's "title" column by the loader); confirm.
# A chunk without that key would make the chain raise at query time.
document_prompt = PromptTemplate(
    template="Title: {title}\n{page_content}",
    input_variables=["title", "page_content"],
)

chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt, "document_prompt": document_prompt},
)

In [None]:
# The endpoint was deployed with wait=False: block here until it is in service,
# because the chain cannot answer questions before then.
waiter = sm_client.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=predictor.endpoint_name)

## Ask a question

In [None]:
question = "What are the worst storms in recent news?"
# RetrievalQA takes the user's question under the "query" input key.
answer = chain.run(query=question)
print(answer)

In [None]:
# Clean up the model and endpoint created by the deployment cell.
# NOTE(review): keep this order. Predictor.delete_model() appears to look up the
# model names through the endpoint's configuration, which delete_endpoint() removes
# by default. Confirm against the installed SageMaker SDK version.
predictor.delete_model()
predictor.delete_endpoint()