# Retrieval Augmented Generation (RAG) Query to LLM

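In this notebook, we show how to query a large language model (LLM) using Retrieval Augmented Generation (RAG). Passages similar to the question are retrieved from an OpenSearch vector index and passed to the LLM as context for its answer.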


### Fill in parameters
Replace the placeholder parameters below with your own values, retrieved from the CDK deployment: the SageMaker JumpStart model endpoint names (text generation and embeddings), the OpenSearch domain endpoint, and the prompt Firehose delivery stream name.

In [None]:
# Replace parameters with your own values; the defaults below are example values.
# Each one can also be overridden without editing the notebook by setting the
# matching RAG_* environment variable before starting the kernel.
import os

text_model_endpoint = os.environ.get("RAG_TEXT_MODEL_ENDPOINT", "jumpstart-dft-hf-llm-falcon-40b-bf16")
embed_model_endpoint = os.environ.get("RAG_EMBED_MODEL_ENDPOINT", "jumpstart-dft-hf-textembedding-gpt-j-6b-fp16")
opensearch_domain_endpoint = os.environ.get(
    "RAG_OPENSEARCH_DOMAIN_ENDPOINT",
    "vpc-opensearchdomai-dtvvqhrhsqtc-avpib3sgtuvbynuwyqgwutrya4.us-east-1.es.amazonaws.com",
)
fh_prompt_name = os.environ.get("RAG_FH_PROMPT_NAME", "BackendStack-FirehosePrompts-m7cR7FRnyw3O")

### Install prerequisites

In [None]:
!pip3 install boto3==1.28.17
!pip3 install streamlit==1.27.2
!pip3 install langchain==0.0.317
!pip3 install opensearch-py==2.4.2

### Import required libraries (boto3, LangChain, etc.)

In [None]:
import json
import logging
from typing import Dict, Optional

import boto3
from boto3.dynamodb.conditions import Key
from botocore.exceptions import ClientError
from langchain import SagemakerEndpoint, PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.vectorstores import OpenSearchVectorSearch
from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth

### Configure boto3 clients and logging

In [None]:
# Root logger at INFO so the helpers' retrieval / prompt / answer messages are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Service clients: CloudWatch for similarity metrics, Firehose for prompt records.
cwclient = boto3.client('cloudwatch')
fhclient = boto3.client('firehose')

# Credentials and region resolved from one default boto3 session.
boto_session = boto3.Session()
credentials = boto_session.get_credentials()
region = boto_session.region_name

### Define CloudWatch metric function (DynamoDB helpers are commented out)

In [None]:
# def put_ddb_item(item):
#     try:
#         table.put_item(Item=item)
#     except ClientError as err:
#         logger.error(err.response['Error']['Code'], err.response['Error']['Message'])
#         raise

# def get_ddb_item(id):
#     try:
#         items = table.query(KeyConditionExpression=Key('id').eq(id))['Items'][0]
#         return items
#     except ClientError as err:
#         logger.error(err.response['Error']['Code'], err.response['Error']['Message'])
#         raise

def put_cw_metric(cwclient, score, namespace='rag', metric_name='similarity'):
    """Publish a single score as a high-resolution CloudWatch custom metric.

    Args:
        cwclient: boto3 CloudWatch client used to call ``put_metric_data``.
        score: Numeric value to record (this notebook passes vector-search
            similarity scores).
        namespace: CloudWatch namespace for the metric (default ``'rag'``).
        metric_name: Name of the metric (default ``'similarity'``).

    Raises:
        botocore.exceptions.ClientError: re-raised after logging when the
            CloudWatch API call fails.
    """
    try:
        cwclient.put_metric_data(
            Namespace=namespace,
            MetricData=[
                {
                    'MetricName': metric_name,
                    'Value': score,
                    'Unit': 'None',
                    # High-resolution (1-second) storage; the standard resolution is 60.
                    'StorageResolution': 1,
                },
            ],
        )
    except ClientError as err:
        # Logger.error treats extra positional args as %-format arguments for the
        # first one, so passing (code, message) directly produced a logging error
        # instead of the failure details. Use an explicit format string.
        logger.error(
            "CloudWatch put_metric_data failed: %s: %s",
            err.response['Error']['Code'],
            err.response['Error']['Message'],
        )
        raise

### Define LangChain input and output handlers for SageMaker endpoints

In [None]:
class TextContentHandler(LLMContentHandler):
    """
    Serialize prompts for, and parse replies from, the text-generation endpoint.

    Request body: UTF-8 JSON ``{"inputs": <prompt>, **model_kwargs}``.
    Reply: JSON list whose first element holds the text under ``"generated_text"``.
    """
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Optional[Dict] = None) -> bytes:
        """Encode ``prompt`` plus any extra model parameters as a JSON request body."""
        # None instead of a shared mutable ``{}`` default; None means "no extras".
        input_str = json.dumps({"inputs": prompt, **(model_kwargs or {})})
        return input_str.encode("utf-8")

    def transform_output(self, output) -> str:
        """Read the response body and return the first generated text.

        ``output`` must be a readable stream (it is ``.read()``, so plain
        ``bytes`` would not work).
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]['generated_text']

class EmbeddingsContentHandler(EmbeddingsContentHandler):
    """
    Content handler for the embedding endpoint: sends a batch of strings as
    UTF-8 JSON under ``"text_inputs"`` and returns the reply's ``"embedding"``
    field.

    NOTE(review): this subclass reuses the name of the imported langchain base
    class, so re-running this cell makes it inherit from its own previous
    definition rather than from langchain's class.
    """
    accepts = "application/json"
    content_type = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        payload = {"text_inputs": inputs}
        # Extra model parameters sit alongside the inputs at the top level.
        payload.update(model_kwargs)
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes):
        # ``output`` is read as a stream, then decoded and parsed as JSON.
        raw_body = output.read()
        return json.loads(raw_body.decode("utf-8"))["embedding"]

### Define helper functions for embedding, retrieval, and text generation

In [None]:
def create_sagemaker_embeddings(endpoint_name):
    """Build a LangChain embeddings client backed by a SageMaker endpoint.

    Args:
        endpoint_name: Name of the deployed SageMaker embedding endpoint.

    Returns:
        A ``SagemakerEndpointEmbeddings`` instance that uses
        ``EmbeddingsContentHandler`` to (de)serialize requests and targets the
        notebook-level ``region``.
    """
    return SagemakerEndpointEmbeddings(
        endpoint_name=endpoint_name,
        region_name=region,
        content_handler=EmbeddingsContentHandler(),
    )

def get_context_from_opensearch(query, endpoint_name, opensearch_domain_endpoint, opensearch_index,
                                k=3, vector_field="embedding", text_field="passage"):
    """Vector-search OpenSearch for passages similar to ``query``.

    The query is embedded via the SageMaker embedding endpoint, the ``k``
    nearest documents are fetched, and each match's score is published to
    CloudWatch with ``put_cw_metric``.

    Args:
        query: Question text to search for.
        endpoint_name: SageMaker endpoint used to embed ``query``.
        opensearch_domain_endpoint: Domain host name, without the ``https://``.
        opensearch_index: Name of the index holding the passage embeddings.
        k: Number of documents to retrieve (default 3).
        vector_field: Index field that stores the embedding vectors.
        text_field: Index field that stores the passage text.

    Returns:
        list of LangChain ``Document`` objects for the matches.
    """
    session = boto3.Session()
    # Sign requests to the "es" (OpenSearch) service with AWS SigV4 credentials.
    auth = AWSV4SignerAuth(session.get_credentials(), session.region_name, "es")
    docsearch = OpenSearchVectorSearch(
        index_name=opensearch_index,
        embedding_function=create_sagemaker_embeddings(endpoint_name),
        opensearch_url=f"https://{opensearch_domain_endpoint}",
        http_auth=auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        is_aoss=False,
    )

    docs_with_scores = docsearch.similarity_search_with_score(
        query, k=k, vector_field=vector_field, text_field=text_field
    )
    # One CloudWatch datapoint per retrieved document's similarity score.
    for _, score in docs_with_scores:
        put_cw_metric(cwclient, score)
    docs = [doc for doc, _ in docs_with_scores]
    logger.info("docs received from opensearch:\n%s", docs)
    return docs

def call_sm_text_generation_model(query, context, endpoint_name,
                                  prompt_template="Answer based on context:\n\n{context}\n\n{question}",
                                  model_kwargs=None):
    """Answer ``query`` with the SageMaker text-generation model, grounded in ``context``.

    LangChain's question-answering chain (``load_qa_chain``, default "stuff"
    chain type) inserts the documents into ``prompt_template`` and sends the
    result to the endpoint.

    Args:
        query: The user's question; fills ``{question}`` in the template.
        context: list of LangChain ``Document`` objects (e.g. from
            ``get_context_from_opensearch``); fills ``{context}``.
        endpoint_name: Name of the SageMaker text-generation endpoint.
        prompt_template: Template with ``{context}`` and ``{question}``
            placeholders. Defaults to the original RAG prompt.
        model_kwargs: Optional extra request fields merged into the JSON body by
            ``TextContentHandler``, e.g. ``{"parameters": {"max_new_tokens": 256}}``
            (assumed Hugging Face TGI-style schema; confirm for your model).

    Returns:
        The generated answer text.
    """
    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=region,
        content_handler=TextContentHandler(),
        model_kwargs=model_kwargs,
        # Presumably for JumpStart models gated behind a EULA; confirm whether
        # the deployed model needs it.
        endpoint_kwargs={'CustomAttributes': 'accept_eula=true'},
    )
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    # Log the template text; the original logged the PromptTemplate object's repr.
    logger.info("prompt template sent to llm = \"%s\"", prompt.template)
    chain = load_qa_chain(llm=llm, prompt=prompt)
    answer = chain({"input_documents": context, "question": query}, return_only_outputs=True)['output_text']
    logger.info("answer received from llm,\nquestion: \"%s\"\nanswer: \"%s\"", query, answer)
    return answer

### Enter your query

In [None]:
# Type your question between the quotes, then run this cell.
query: str = ""

In [None]:
# Fail fast instead of embedding and searching the empty placeholder query.
if not query.strip():
    raise ValueError("`query` is empty; set it in the previous cell before running this one.")

# Retrieve similar passages, then ask the LLM to answer using them as context.
context = get_context_from_opensearch(query, embed_model_endpoint, opensearch_domain_endpoint, "embeddings")
print(f"Found {len(context)} similar documents")

answer = call_sm_text_generation_model(query, context, text_model_endpoint)
print(answer)

### Send prompt data to Firehose for further analysis

In [None]:
# Embed the query and send the vector to the prompt Firehose stream as one
# newline-terminated record for further analysis.
embedding = create_sagemaker_embeddings(embed_model_endpoint).embed_query(query)
fh_stream_records = [{'Data': (str(embedding) + "\n").encode('utf-8')}]
fh_response = fhclient.put_record_batch(DeliveryStreamName=fh_prompt_name, Records=fh_stream_records)

# put_record_batch can reject individual records without raising; surface them.
if fh_response.get('FailedPutCount', 0):
    failed_records = [r for r in fh_response.get('RequestResponses', []) if 'ErrorCode' in r]
    logger.error("Firehose rejected %d record(s): %s", fh_response['FailedPutCount'], failed_records)

fh_response