# Retrieval Augmented Generation (RAG) inference

***This notebook works best with the `conda_python3` kernel on an `ml.t3.large` instance***.

---

At this point our slide deck data has been ingested into an Amazon OpenSearch Serverless collection. We are now ready to talk to our slide deck using a large multimodal model. For this purpose we use [Anthropic’s Claude 3 Sonnet foundation model](https://aws.amazon.com/about-aws/whats-new/2024/03/anthropics-claude-3-sonnet-model-amazon-bedrock/).

## Step 1. Setup

Install the required Python packages and import the required libraries.

In [None]:
# Install the notebook's Python dependencies into the same interpreter that runs this kernel
# (sys.executable), so the packages are importable from the cells below.
import sys
!{sys.executable} -m pip install -r requirements.txt

In [None]:
# import necessary libraries to run this notebook
import os
import io
import sys
import json
import yaml
import glob
import boto3
import base64
import logging
import requests
import botocore
import sagemaker
import opensearchpy
import numpy as np
import pandas as pd
import globals as g
from PIL import Image
from pathlib import Path
from IPython.display import Image
from urllib.parse import urlparse
from botocore.auth import SigV4Auth
from pandas.core.series import Series
from sagemaker import get_execution_role
from botocore.awsrequest import AWSRequest
from typing import List, Dict, Tuple, Optional
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from utils import get_cfn_outputs, get_text_embedding, get_llm_response, get_question_entities

In [None]:
# Configure root logging at INFO level (with process id and source location in every
# record) and create the logger used throughout this notebook.
_LOG_FORMAT = '[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

In [None]:
# Global constants
CONFIG_FILE_PATH = "config.yaml"

# Load the YAML configuration used by the rest of the notebook and log its contents.
fpath = CONFIG_FILE_PATH
config = yaml.safe_load(Path(fpath).read_text())
logger.info(f"config read from {fpath} -> {json.dumps(config, indent=2)}")

In [None]:
# Print, with syntax highlighting, the constants module imported above as `g`
# (AWS region, OpenSearch service name, Bedrock endpoint URL, image directory, ...).
!pygmentize globals.py

## Step 2. Create two OpenSearch clients for images and texts separately

We create two OpenSearch clients, one for the image index and one for the text index, so that we can query the vector database for embeddings of our `PDF file` that are similar to the questions we want to ask about it.

Get the OpenSearch Serverless collection endpoint and the text and image index names from the CloudFormation stack outputs.

In [None]:
# Look up the collection endpoint and both index names in the CloudFormation stack outputs.
outputs = get_cfn_outputs(config['aws']['cfn_stack_name'])
collection_endpoint = outputs['MultimodalCollectionEndpoint']
# Drop the URL scheme ("https://") to get the bare host name the OpenSearch client expects.
host = collection_endpoint.split('//')[1]
text_index_name, img_index_name = outputs['OpenSearchTextIndexName'], outputs['OpenSearchImgIndexName']
logger.info(f"opensearchhost={host}, text index={text_index_name}, image index={img_index_name}")

In [None]:
# SigV4 request signing for the OpenSearch Serverless collection, using the current
# boto3 session's credentials.
session = boto3.Session()
credentials = session.get_credentials()
auth = AWSV4SignerAuth(credentials, g.AWS_REGION, g.OS_SERVICE)


def _create_os_client() -> OpenSearch:
    """Build a SigV4-authenticated HTTPS OpenSearch client for the collection at `host`."""
    return OpenSearch(
        hosts=[{'host': host, 'port': 443}],
        http_auth=auth,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection,
        pool_maxsize=20,
    )


# OpenSearch client used to query the image index
img_os_client = _create_os_client()

# OpenSearch client used to query the text index
text_os_client = _create_os_client()

## Step 3. Ready for RAG

We now have all the pieces for RAG. Here is how we _talk to our slide deck_.

1. Convert the user question into embeddings using the Titan Text Embeddings model.

1. Find the most similar slide (image) corresponding to the embeddings (for the user question) from the vector database (OpenSearch Serverless).

1. Now ask Claude 3 to answer the user question using the retrieved image description for the most similar slide.

In [None]:
# Bedrock runtime client passed to get_question_entities, get_text_embedding,
# response_from_text_extracted and get_llm_response below. The region is set explicitly
# (as get_nearest_img_search_response does), so requests are signed for g.AWS_REGION
# even when no default region is configured in the environment.
bedrock = boto3.client(service_name="bedrock-runtime", region_name=g.AWS_REGION, endpoint_url=g.TITAN_URL)

A helper function for entity-filtered similarity search in the vector database.

In [None]:
def find_similar_data_with_string_filter(text_embedding, size, os_client, index_name, filter_string):
    """
    Run a k-NN search on `index_name`, pre-filtered to documents whose `metadata.entities`
    field matches at least one of the comma-separated entities in `filter_string`.

    :param text_embedding: embedding of the user question, compared against the
        `vector_embedding` field of the index
    :param size: number of hits to return (also used as the k-NN `k`)
    :param os_client: OpenSearch client used to run the query
    :param index_name: name of the index to search
    :param filter_string: comma-separated entities extracted from the question. When it is
        None/empty, or contains only blank entries, no entity filter is applied and a plain
        k-NN search is run.
    :return: the raw OpenSearch search response, or None if the query failed
    """
    logger.info(f"filter_string: {filter_string}")
    # The wildcard clauses are case-insensitive, so one clause per distinct entity is enough.
    # Blank entries (e.g. from a trailing comma) are skipped: they would become the pattern
    # "**", which matches any entity value and silently disables the filter.
    entities = list(dict.fromkeys(
        entity.strip() for entity in (filter_string or "").split(",") if entity.strip()
    ))
    logger.info(f"Entities extracted: {entities}")

    bool_query = {
        "must": {
            "knn": {
                "vector_embedding": {
                    "vector": text_embedding,
                    "k": size
                }
            }
        }
    }
    if entities:
        bool_query["filter"] = {
            "bool": {
                "should": [
                    {
                        "wildcard": {
                            "metadata.entities": {
                                # match the entity anywhere inside the indexed entities value
                                "value": f"*{entity}*",
                                "case_insensitive": True
                            }
                        }
                    }
                    for entity in entities
                ],
                # a document must match at least one entity from the question
                "minimum_should_match": 1
            }
        }
    query = {"size": size, "query": {"bool": bool_query}}

    try:
        content_based_search = os_client.search(body=query, index=index_name)
    except Exception as e:
        logger.error(f"error occured while querying OpenSearch index={index_name}, exception={e}")
        content_based_search = None
    return content_based_search


In [None]:
def get_nearest_img_search_response(nearest_image_path: str, prompt: str, modelId: str) -> Optional[str]:
    """
    Ask the multimodal model `modelId` to answer `prompt` using the slide image retrieved
    for the question.

    The image itself is not fetched from `nearest_image_path`. Instead, a pre-encoded base64
    copy is read from `<g.IMAGE_DIR>/b64_images/`, whose file name is the base name of
    `nearest_image_path` with a trailing `.jpg` replaced by `.b64`. That data is sent to
    Bedrock as `image/jpeg` together with the prompt.

    :param nearest_image_path: path of the most similar image (the S3 location stored in
        the index); only its base name is used
    :param prompt: text prompt sent alongside the image
    :param modelId: Bedrock model id to invoke
    :return: the model's text answer with double quotes replaced by single quotes, or None
        if the base64 image could not be read or the model call/response parsing failed
    """
    bedrock = boto3.client(service_name="bedrock-runtime", region_name=g.AWS_REGION, endpoint_url=g.TITAN_URL)
    filename: str = os.path.basename(nearest_image_path)
    # Swap only the ".jpg" suffix for ".b64"; ".jpg" appearing elsewhere in the name is kept.
    stem, ext = os.path.splitext(filename)
    b64_filename = f"{stem}.b64" if ext == ".jpg" else filename
    local_image_path = os.path.join(g.IMAGE_DIR, 'b64_images', b64_filename)
    logger.info(f"reading base64 image from {local_image_path}")
    # NOTE(review): max supported image size assumed to be 2048 * 2048 pixels — confirm against the model docs
    try:
        with open(local_image_path, "rb") as image_file:
            # strip a possible trailing newline so the payload is pure base64
            input_image_b64 = image_file.read().decode('utf-8').strip()
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"Error reading base64 image from local directory: {e}")
        return None

    body = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 2000,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/jpeg",
                                "data": input_image_b64
                            },
                        },
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
        }
    )

    # Failures are logged and reported as None, the same contract as response_from_text_extracted.
    try:
        response = bedrock.invoke_model(
            modelId=modelId,
            body=body
        )
        resp_body = json.loads(response['body'].read().decode("utf-8"))
        resp_text = resp_body['content'][0]['text'].replace('"', "'")
    except Exception as e:
        logger.error(f"error invoking model={modelId} for image={local_image_path}, exception={e}")
        return None
    return resp_text

In [None]:
def response_from_text_extracted(bedrock: "botocore.client.BaseClient",
                                 prompt: str,
                                 model_id: Optional[str] = None) -> Optional[str]:
    """
    Send a text-only `prompt` to a Claude model on Bedrock and return its answer.

    The prompts used with this function ask the model to answer from the supplied text
    context or reply "not found".

    :param bedrock: Bedrock runtime client
    :param prompt: fully formatted prompt text
    :param model_id: Bedrock model id; defaults to
        `config['bedrock_model_info']['claude_sonnet_model_id']`
    :return: the model's text answer with double quotes replaced by single quotes, or None
        if the invocation or response parsing failed
    """
    if model_id is None:
        model_id = config['bedrock_model_info']['claude_sonnet_model_id']
    body = json.dumps(
    {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    })

    try:
        response = bedrock.invoke_model(
            modelId=model_id,
            body=body)
        response_body = json.loads(response['body'].read().decode("utf-8"))
        llm_response = response_body['content'][0]['text'].replace('"', "'")
    except Exception as e:
        logger.error(f"exception={e}")
        llm_response = None
    return llm_response

In [None]:
def sanitize_llm_response(llm_response: Optional[str]) -> Optional[str]:
    """
    Collapse an LLM answer that signals a missing answer into the sentinel "not found".

    If `llm_response` contains "not found" in any letter case, exactly "not found" is
    returned. Otherwise the response is returned unchanged: lowercasing is used only for
    the comparison, so names and acronyms keep their casing when the answer is reused as
    context.

    :param llm_response: raw model answer; None (e.g. a failed model call) is passed through
    :return: "not found", the unchanged response, or None for a None/non-string input
    """
    string_to_match: str = "not found"
    if llm_response is None:
        return None
    try:
        # Normalize the case for comparison only
        normalized_response = llm_response.lower()
    except AttributeError as e:
        logger.error(f"The LLM response cannot be sanitized: {e}")
        return None
    if string_to_match in normalized_response:
        return string_to_match
    return llm_response

## Functions to get responses from the indexes

In [None]:
# Prompt for answering from a retrieved slide image: {context} is the stored image
# description, {question} is the user question. The image itself is sent alongside the
# prompt, and the model is told to reply "not found" when neither contains the answer.
direct_image_answer_retrieval_prompt: str = """
Human: Your role is to give the answer to the question in the <question></question> tags. If the image description in the <img_text_desc></img_text_desc> tags does not contain the answer to the question, then search the "data" for it. 
If the img_text_desc or the actual image in the "data", both do not contain the answer to the question, then respond with two words only - "not found".

Refer to the img_text_desc and question below:

<img_text_desc>
{context}
</img_text_desc>

<question>
{question}
</question>

If the answer to the question is not in the image description, search the actual image provided in "data" and retrieve the answer directly from the image provided. Search for the question directly in the image in this case and retrieve the most accurate answer. Do not make up an answer.

A: Based on the question and context, if there is an answer, here is my response in 1 sentence. I will ONLY respond with the two words "not found" in lower case if no answer to the question is found:
"""

# Prompt for answering from retrieved text: {context} is the extracted text and {question}
# is the user question. The model is told to reply "not found" when the text lacks the answer.
direct_text_answer_retrieval_prompt: str = """
Human: Your role is to give the answer to the question in the <question></question> tags. If the text description in the <text_desc></text_desc> tags does not contain the answer to the question, then respond with two words only - "not found".

Refer to the text_desc and question below:

<text_desc>
{context}
</text_desc>
<question>
{question}
</question>

If the answer to the question is not in the text description, then respond with two words only: "not found"

If the answer to the question is mentioned in the text_desc, then give the answer.

A: Based on the question and context, if there is an answer, here is my response in 1 sentence. I will ONLY respond with the two words "not found" in lower case if no answer to the question is found:
"""

In [None]:
def _display_s3_image(s3_uri: str) -> None:
    """Download the S3 object `s3_uri` to the working directory and show it inline (best effort)."""
    try:
        parsed = urlparse(s3_uri)
        local_img_path = os.path.basename(s3_uri)
        boto3.client("s3").download_file(parsed.netloc, parsed.path.lstrip("/"), local_img_path)
        display(Image(filename=local_img_path))
    except Exception as e:
        # Showing the slide is only a visual aid; a failure here must not abort answering.
        logger.warning(f"could not download/display image {s3_uri}: {e}")


def _answer_from_hit(question: str, content_path: str, extracted_text: str, index_name: str) -> Optional[str]:
    """Ask the LLM to answer `question` from one retrieved hit; return the answer, or None if none was found."""
    if index_name == outputs['OpenSearchImgIndexName']:
        _display_s3_image(content_path)
        # The image prompt lets the model fall back to the image itself when the stored
        # description does not contain the answer.
        prompt: str = direct_image_answer_retrieval_prompt.format(context=extracted_text, question=question)
        direct_response = get_nearest_img_search_response(content_path, prompt, config['bedrock_model_info']['claude_sonnet_model_id'])
        source_label = "image"
    elif index_name == outputs['OpenSearchTextIndexName']:
        prompt = direct_text_answer_retrieval_prompt.format(context=extracted_text, question=question)
        direct_response = response_from_text_extracted(bedrock, prompt)
        source_label = "text"
    else:
        logger.warning(f"hit from unknown index={index_name} skipped")
        return None
    sanitized_response = sanitize_llm_response(direct_response)
    logger.info(f"sanitized response: {sanitized_response}")
    logger.info(f"response from the {source_label} index: {direct_response}")
    # None (failed call) and "not found" both mean this hit contributed no answer.
    if not sanitized_response or sanitized_response == "not found":
        return None
    logger.info(f"Response found from the {source_label} index")
    return sanitized_response


def get_index_response(question: str,
                       size: int,
                       index_clients: List[Tuple[opensearchpy.client.OpenSearch, str]]) -> Dict:
    """
    Answer `question` from data retrieved from the image index, the text index, or both.

    For every (client, index) pair, the top `size` entity-filtered k-NN hits are retrieved.
    Each hit is asked for an answer separately. The per-hit answers are then passed as
    context to a final LLM call.

    :param question: question that a user asks about the content
    :param size: number of hits ('k') retrieved from each index
    :param index_clients: list of (OpenSearch client, index name) tuples to query
    :return: dict with 'source' (the per-hit answers, newline-separated, used as context)
        and 'response' (the final LLM answer); both are None if an unexpected error occurred
    """
    index_llm_response_and_context: Dict = {'source': ''}
    try:
        model_id: str = config['bedrock_model_info']['claude_sonnet_model_id']
        # (content path, extracted text, index name) for every hit across all indexes
        all_hits = []
        if index_clients:
            # Entities and the embedding depend only on the question: compute them once,
            # not once per index.
            question_entities = get_question_entities(bedrock, question)
            logger.info(f"question entities: {question_entities}")
            text_embedding = get_text_embedding(bedrock, question)
        for os_client, index_name in index_clients:
            vector_db_response = find_similar_data_with_string_filter(text_embedding, size, os_client, index_name, question_entities)
            if not vector_db_response:
                continue
            for hit in vector_db_response.get('hits', {}).get('hits', []):
                hit_source = hit.get('_source') or {}
                all_hits.append((hit_source.get('file_path'), hit_source.get('file_text'), index_name))

        answers: List[str] = []
        for content_path, extracted_text, index_name in all_hits:
            logger.info("Iterating through all relevant hits to search for an answer....")
            logger.info(f"Going to answer the question: \"{question}\" using the context: \"{content_path}\"")
            answer = _answer_from_hit(question, content_path, extracted_text, index_name)
            if answer:
                answers.append(answer)
        # Separate the per-hit answers so they do not run together in the final context.
        index_llm_response_and_context['source'] = "\n".join(answers)

        logger.info(f"Summary provided to the llm: {index_llm_response_and_context['source']}")
        logger.info(f"question provided to the llm: {question}")
        index_llm_response = get_llm_response(bedrock, question, index_llm_response_and_context['source'], model_id)
        index_llm_response_and_context['response'] = index_llm_response
        logger.info(f"response from the llm: {index_llm_response}")
    except Exception as e:
        logger.error(f"Could not get a response: {e}")
        index_llm_response_and_context['response'] = None
        index_llm_response_and_context['source'] = None
    return index_llm_response_and_context


### Question 3 - Combined Response (Both Image and Text Indexes)

#### Query both the text and image indexes and combine the responses

In [None]:
# Template cell: put your own question below and uncomment the lines to query both the
# text and image indexes (or swap in the commented single-index list) with k=1.
# question: str = "<enter your example question here>"
# # index_clients: List[Tuple] = [(img_os_client, img_index_name)]
# index_clients: List[Tuple] = [(text_os_client, text_index_name), (img_os_client, img_index_name)]
# get_index_response(question, 1, index_clients)

In [None]:
# Query both the text and the image indexes with the same question.
index_clients: List[Tuple] = [
    (text_os_client, text_index_name),
    (img_os_client, img_index_name),
]
question: str = "What is the Cisco recommendation based on the ratings?"
get_index_response(question, config['k_count_retrieval'], index_clients)

In [None]:
# Query both the text and the image indexes with the same question.
index_clients: List[Tuple] = [
    (text_os_client, text_index_name),
    (img_os_client, img_index_name),
]
question: str = "What was the total Debt/Equity ratio for Amazon in 2022 ?"
get_index_response(question, config['k_count_retrieval'], index_clients)

In [None]:
# Query both the text and the image indexes with the same question.
index_clients: List[Tuple] = [
    (text_os_client, text_index_name),
    (img_os_client, img_index_name),
]
question: str = "What is the Total Debt/Capital Ratio for Amazon?"
get_index_response(question, config['k_count_retrieval'], index_clients)

## Eval Dataset Comparison
---

In this section of the notebook we do as follows:

1. Iterate through each question provided in the dataset, and call the `get_index_response` function

1. Record the responses from the text index, from the image index, and from both indexes combined

1. Update the dataframe and store the result in the metrics directory

In [None]:
eval_files: Optional[List[str]] = None
# one dataframe per evaluation CSV, concatenated into eval_df below
eval_content = []
# stays None when no evaluation dataset is configured or found, so later cells can check for it
eval_df: Optional[pd.DataFrame] = None
# if the evaluation dataset is given, load every CSV in the eval directory into one dataframe
if config['eval_qna_dataset_info']['is_given'] is True:
    eval_dir: str = config['eval_qna_dataset_info']['dir_name']
    fpath = os.path.join(eval_dir, "*.csv")
    eval_files = sorted(glob.glob(fpath, recursive=True))
    for eval_file in eval_files:
        logger.info(f"eval files: {eval_file}")
        # NOTE(review): 'Unnamed: 3' presumably is an empty artifact column from the CSV
        # export; it is ignored when a file does not have it.
        eval_content.append(pd.read_csv(eval_file).drop(columns=['Unnamed: 3'], errors='ignore'))
    if eval_content:
        # ignore_index gives a clean 0..n-1 index for the evaluation loop to write results against
        eval_df = pd.concat(eval_content, ignore_index=True)
eval_df.head(10) if eval_df is not None else None

In [None]:
if eval_df is not None:
    index_clients_both = [(text_os_client, text_index_name), (img_os_client, img_index_name)]
    text_index_client = [(text_os_client, text_index_name)]
    img_index_client = [(img_os_client, img_index_name)]
    result_columns = ['combined_response', 'image_and_text_source',
                      'text_response', 'text_source',
                      'img_response', 'img_source']
    # Create the result columns up front as object dtype, so storing strings or None
    # never forces pandas into an incompatible-dtype upcast.
    for col in result_columns:
        if col not in eval_df.columns:
            eval_df[col] = None
        eval_df[col] = eval_df[col].astype(object)
    question_key = config['eval_qna_dataset_info']['question_key']
    k = config['k_count_retrieval']
    # .items() yields index labels, which is what .at expects (correct for any index)
    for row_idx, question in eval_df[question_key].items():
        combined_response = get_index_response(question, k, index_clients_both)
        eval_df.at[row_idx, 'combined_response'] = combined_response['response']
        logger.info(f"combined_response['response']: {combined_response['response']}")
        eval_df.at[row_idx, 'image_and_text_source'] = combined_response['source']
        text_response = get_index_response(question, k, text_index_client)
        eval_df.at[row_idx, 'text_response'] = text_response['response']
        eval_df.at[row_idx, 'text_source'] = text_response['source']
        image_response = get_index_response(question, k, img_index_client)
        eval_df.at[row_idx, 'img_response'] = image_response['response']
        eval_df.at[row_idx, 'img_source'] = image_response['source']
    print(eval_df.head(10))
    metrics_dir = config['metrics_dir']['dir_name']
    os.makedirs(metrics_dir, exist_ok=True)
    side_view_eval_file = os.path.join(metrics_dir, config['eval_qna_dataset_info']['updated_eval_file'])
    eval_df.to_csv(side_view_eval_file, index=True)
else:
    logger.info(f"Evaluation dataset not provided. Provide a data set in the eval directory and try again.")

In [None]:
# Preview the evaluation results; shows nothing when no evaluation dataset was loaded.
eval_df.head() if eval_df is not None else None

## Clean Up
