# Lab 03. Build a Retrieval Augmented Generation System using Amazon EMR Spark Distributed Processing and OpenSearch Vector Database

In this notebook we demonstrate how you can build a Retrieval Augmented Generation (RAG) system using the following components:
1. Embedding Model: `all-MiniLM-L6-v2`
2. Text Generation Model: `meta-llama/Llama-2-7b-chat`
3. Vector Database: OpenSearch, used as the vector database to store embeddings
4. Streamlit UI: A chat interface to talk to your documents

<div style="background-color: #FFDDDD; border-left: 5px solid red; padding: 10px; color: black;">
    <strong>Kernel:</strong> Spark Analytics 2.0 [SparkMagic PySpark] <strong>Instance Type:</strong> ml.t3.medium
</div>

## Connect to an Existing EMR Cluster

In [None]:
%load_ext sagemaker_studio_analytics_extension.magics

In [None]:
%%help

In [None]:
%%local
!echo "Your EMR Cluster ID ---> $(aws emr list-clusters | jq '.Clusters[0].Id')"

In [None]:
%sm_analytics emr connect --verify-certificate False --cluster-id j-3FIXE21RQG8VM --auth-type None --language python  

## Upload Files from Local to S3

In [None]:
%%local
!python3 -m pip install setuptools

In [None]:
%%local
!python3 -m pip install sagemaker==2.192.0

In [None]:
%%local
import os
import glob
import boto3
import sagemaker
from tqdm import tqdm

In [None]:
%%local
sess = sagemaker.Session()
default_bucket = sess.default_bucket()
s3_client = boto3.client("s3")
print(f"Using default bucket ---> {default_bucket}")

A few sample PDF files are available under the `./AWSGuides/` directory; these are the sample documents we'll use to build our RAG application.

In [None]:
%%local
def upload_raw_pdf_files_to_bucket(destination_bucket, destination_prefix, raw_pdf_files):
    """
    Upload local PDF files to S3 and return the S3 URIs they were written to.

    Each object key is ``destination_prefix`` joined with the file's base name,
    after commas are removed from the name and spaces are replaced by hyphens.
    Uploads go through the module-level ``s3_client`` and progress is shown
    with ``tqdm``.
    """
    file_count = len(raw_pdf_files)
    print(f"Uploading ---> {file_count} files!")

    s3_uris = []
    for local_path in tqdm(raw_pdf_files, total=file_count):
        # Sanitize the file name before using it as part of the object key.
        clean_name = os.path.basename(local_path)
        clean_name = clean_name.replace(",", "")
        clean_name = clean_name.replace(" ", "-")

        object_key = os.path.join(destination_prefix, clean_name)
        s3_client.upload_file(local_path, destination_bucket, object_key)

        s3_uris.append(f"s3://{destination_bucket}/{object_key}")

    return s3_uris

pdf_files_to_upload = glob.glob("./AWSGuides/*.pdf")

destination_prefix = "Lab03/raw-pdfs"

files_paths_in_s3 = upload_raw_pdf_files_to_bucket(
    destination_bucket=default_bucket, 
    destination_prefix=destination_prefix,
    raw_pdf_files=pdf_files_to_upload
)

print(f"Uploaded files to ---> {files_paths_in_s3}")

Let's send these variables from our local instance to the PySpark primary node using the

`%%send_to_spark` magic command.

In [None]:
%%send_to_spark -i destination_prefix -t str -n SRC_FILE_PREFIX

In [None]:
%%send_to_spark -i default_bucket -t str -n SRC_BUCKET_NAME

## Let's Convert PDFs into Text

In [None]:
import os
import boto3
import json
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, PyPDFDirectoryLoader
import io

In [None]:
print(f"Source bucket and prefix to read pdf files ---> {SRC_BUCKET_NAME} {SRC_FILE_PREFIX}")

In [None]:
def list_files_in_s3_bucket_prefix(bucket_name, prefix):
    """
    Return the keys of all objects stored under ``prefix`` in ``bucket_name``.

    Results are read page by page through the ``list_objects_v2`` paginator.
    Keys whose base name is empty (i.e. keys ending in ``/``) are skipped.
    """
    pages = boto3.client('s3').get_paginator('list_objects_v2').paginate(
        Bucket=bucket_name,
        Prefix=prefix,
    )

    object_keys = []
    for page in pages:
        # A page without "Contents" carries no objects.
        if "Contents" not in page:
            continue
        object_keys.extend(
            entry["Key"] for entry in page["Contents"] if os.path.basename(entry["Key"])
        )

    return object_keys

all_pdf_files = list_files_in_s3_bucket_prefix(
    bucket_name=SRC_BUCKET_NAME, 
    prefix=SRC_FILE_PREFIX
)
print(f"Found {len(all_pdf_files)} files ---> {all_pdf_files}")

Let's prep a list to process files along with bucket names 

In [None]:
all_pdf_files = [(SRC_BUCKET_NAME, fpath) for fpath in all_pdf_files]
type(all_pdf_files)

Let's convert our list to a spark RDD for parallelization of our list

In [None]:
pdfs_rdd = spark.sparkContext.parallelize(all_pdf_files)
type(pdfs_rdd)

Each core node fetches a PDF file from our list, downloads it into memory, and returns a PyPDF2 `PdfReader` reference for downstream workloads.

![EMR Read PDFs into Memory](media/EMR-Doc-Read.jpg)

In [None]:
def load_pdf_from_s3_into_memory(row):
    """
    Load a PDF file from an S3 bucket directly into memory.

    Args:
        row: A ``(bucket_name, object_key)`` pair identifying the PDF in S3.

    Returns:
        A ``(object_key, pdf_reader, page_count)`` tuple. If the object cannot
        be downloaded or parsed, the failure is logged and
        ``(object_key, None, 0)`` is returned, so the result always has three
        fields and a failed file simply contributes zero pages.
    """
    import logging

    # Unpack before the ``try`` so the key is always bound inside the handler.
    src_bucket_name, src_file_key = row

    try:
        s3 = boto3.client('s3')
        pdf_file = io.BytesIO()
        s3.download_fileobj(src_bucket_name, src_file_key, pdf_file)
        # Rewind so the reader starts at the beginning of the downloaded bytes.
        pdf_file.seek(0)
        pdf_reader = PdfReader(pdf_file)
        return (src_file_key, pdf_reader, len(pdf_reader.pages))

    except Exception as e:
        # Best effort: one unreadable file must not abort the whole job, but the
        # result keeps the same 3-field shape that the consumers unpack.
        logging.getLogger(__name__).warning(
            "Could not load PDF %s: %s", src_file_key, e
        )
        return (src_file_key, None, 0)

Let's concurrently load pdf files into memory using rdd map and collect

In [None]:
pdfs_in_memory = pdfs_rdd.map(load_pdf_from_s3_into_memory).collect()

In [None]:
print(f"all pdfs combined there are ---> {sum([pg_num for _, _, pg_num in pdfs_in_memory])} pages to process!")

In [None]:
class CustomDocument:
    """
    Lightweight page container exposing ``page_content`` and ``metadata``.

    ``metadata`` holds the originating path under ``'source'`` and the page
    number under ``'page'``.
    """

    def __init__(self, text, path, number):
        self.page_content = text
        self.metadata = dict(source=path, page=number)

    def __repr__(self):
        # Debug-oriented form showing the raw content and the metadata dict.
        return "Document(page_content='{}', metadata={})".format(
            self.page_content, self.metadata
        )

    def __str__(self):
        # Reader-friendly form: one labelled line per field.
        labelled_lines = (
            f"Page Content: {self.page_content}",
            f"Source: {self.metadata['source']}",
            f"Page Number: {self.metadata['page']}",
        )
        return "\n".join(labelled_lines)
    
def extract_text_from_pdf_reader(row):
    """
    Extract the text of a single page of an in-memory PDF.

    Args:
        row: A ``(doc_path, page_num)`` pair; ``doc_path`` is looked up in the
            module-level ``global_pdfs_in_mem_dict``.

    Returns:
        A ``(page_text, doc_path, page_num)`` tuple. If the page cannot be
        read, the failure is logged and ``page_text`` is the empty string, so
        an error message is never passed downstream as document content.
    """
    import logging

    # Unpack before the ``try`` so both names are bound inside the handler.
    doc_path, page_num = row

    try:
        page_text = global_pdfs_in_mem_dict[doc_path].pages[page_num].extract_text()
    except Exception as e:
        logging.getLogger(__name__).warning(
            "Could not extract text from %s page %s: %s", doc_path, page_num, e
        )
        return "", doc_path, page_num

    return page_text, doc_path, page_num

In [None]:
global_pdfs_in_mem_dict = {_key: pdf_reader for _key, pdf_reader, _ in pdfs_in_memory}

In [None]:
docs_instances = []
for (file_src, _, page_count) in pdfs_in_memory:
    for pg_num in range(page_count):
        docs_instances.append((file_src, pg_num))
print(f"Created {len(docs_instances)} parallel instances to process!")

In [None]:
docs_instances_rdd = spark.sparkContext.parallelize(docs_instances)

Every PDF document has 'n' pages to process; this task can be executed in parallel using Spark.

Each document is split page by page, and each page is read from a global reference to the in-memory PDFs.

![PageLevelProcessingEMRPDFtoTxt](media/PageLevelProcessingEMRPDFtoTxt.jpg)

In [None]:
documents = docs_instances_rdd.map(extract_text_from_pdf_reader).collect()
documents_custom = [
    CustomDocument(text=text, path=doc_source, number=page_num) 
    for text, doc_source, page_num in documents
]

In [None]:
documents_custom[121]

We split pages using a reference chunk size; the chunk size is an experimental value. To learn more about chunk size and how RecursiveCharacterTextSplitter works, see: https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter

In [None]:
global_text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=750,
    chunk_overlap=10
)

In [None]:
docs = global_text_splitter.split_documents(documents_custom)
print(f"Total number of docs after split {len(docs)}")

In [None]:
print(docs[2695])

In [None]:
def generate_embeddings(input_text_sample):
    """
    Embed one text string by invoking the ``invokeEmbeddingEndpoint`` Lambda.

    Args:
        input_text_sample: The text to embed; must be a single ``str``.

    Returns:
        A ``(status_code, embeddings)`` tuple, where ``status_code`` is the
        ``statusCode`` of the Lambda response as an ``int`` and ``embeddings``
        is the JSON-decoded response ``body``.

    Raises:
        AssertionError: If ``input_text_sample`` is not a ``str``.
    """
    # Report the offending type so a bad call is easy to diagnose.
    assert isinstance(input_text_sample, str), (
        f"Input must be a single string but found {type(input_text_sample).__name__}"
    )

    lambda_client = boto3.client('lambda', region_name='us-west-2')

    # Prepare the data to send to the Lambda function
    data = {
        "input": input_text_sample
    }

    # Invoke the Lambda function synchronously
    response = lambda_client.invoke(
        FunctionName="invokeEmbeddingEndpoint",
        InvocationType="RequestResponse",
        Payload=json.dumps(data)
    )

    # Decode and load the response payload
    response_payload = json.loads(response['Payload'].read().decode("utf-8"))

    # The body is itself a JSON string, hence the second decode
    status_code = int(response_payload['statusCode'])
    embeddings = json.loads(response_payload['body'])

    return status_code, embeddings
    
class EmbeddingsGenerator:
    """
    Embedding adapter exposing ``embed_documents`` and ``embed_query``.

    Both methods are static and delegate to ``generate_embeddings``;
    ``embed_documents`` fans the calls out over Spark.
    """

    @staticmethod
    def _unwrap_embeddings(results):
        """Return the embeddings of ``(status_code, embedding)`` pairs in order.

        Raises:
            RuntimeError: If any pair has a status code other than 200.
                Dropping failures instead would leave fewer embeddings than
                input texts, so the two lists would no longer line up.
        """
        failed_positions = [
            position
            for position, (s_code, _) in enumerate(results)
            if s_code != 200
        ]
        if failed_positions:
            raise RuntimeError(
                f"Embedding generation failed for {len(failed_positions)} of "
                f"{len(results)} inputs (positions: {failed_positions})"
            )
        return [embeddings for _, embeddings in results]

    @staticmethod
    def embed_documents(input_text, normalize=True):
        """
        Generate embeddings for the provided texts, invoking a Lambda function.

        Args:
            input_text: List of strings to embed.
            normalize: Accepted for interface compatibility; not used here.

        Returns:
            One embedding per input text, in the same order as ``input_text``.

        Raises:
            AssertionError: If ``input_text`` is not a list.
            RuntimeError: If any text could not be embedded.
        """
        assert isinstance(input_text, list), "Input type must me list to embed_documents function"

        input_text_rdd = spark.sparkContext.parallelize(input_text)

        embeddings_generated = input_text_rdd.map(generate_embeddings).collect()

        return EmbeddingsGenerator._unwrap_embeddings(embeddings_generated)

    @staticmethod
    def embed_query(input_text):
        """
        Generate the embedding for a single query string.

        Returns:
            The embedding when the Lambda call reports status 200, else ``None``.
        """
        status_code, embedding = generate_embeddings(input_text)
        if status_code == 200:
            return embedding
        return None

In [None]:
response_code, sample_sentence_embedding = generate_embeddings(docs[1000].page_content)
print(f"Status {response_code}, Embedding size of the document --->", len(sample_sentence_embedding))

In [None]:
%%local
INDEX_NAME_OSE = "amz-guides-index"

In [None]:
%%send_to_spark -i INDEX_NAME_OSE -t str -n INDEX_NAME_OSE

In [None]:
def get_secret(secret_name, region_name="us-west-2"):
    """
    Fetch a username/password pair from AWS Secrets Manager.

    The secret's ``SecretString`` is parsed as JSON and its ``username`` and
    ``password`` entries are returned as a ``(user, pwd)`` tuple.
    """
    secrets_client = boto3.session.Session().client(
        service_name='secretsmanager',
        region_name=region_name,
    )
    secret_response = secrets_client.get_secret_value(SecretId=secret_name)

    # The secret payload is a JSON document stored as a string.
    credentials = json.loads(secret_response['SecretString'])
    return credentials['username'], credentials['password']

# Use the function
my_secret_name = "OpenSearchSecret-workshop-studio-cfn"  # Replace with your secret name
my_region_name = "us-west-2"     # Replace with your AWS region
user, pwd = get_secret(my_secret_name, my_region_name)
# Never echo the password: cell output is stored with the notebook and is easily shared.
print(f"Session user ---> {user} (password retrieved)")

In [None]:
resp = EmbeddingsGenerator.embed_documents([d.page_content for d in docs[:100]])

In [None]:
len(resp)

In [None]:
import time
from langchain.vectorstores import OpenSearchVectorSearch


# Time the ingestion of the first 100 chunks into the OpenSearch index.
start = time.time()
docsearch = OpenSearchVectorSearch.from_documents(
    docs[:100],
    EmbeddingsGenerator,
    opensearch_url="https://search-opensearchservi-ol9kboy2sp4p-o62szzg7yeeufr3mi26nanxfke.us-west-2.es.amazonaws.com",
    bulk_size=len(docs),
    # Use the credentials fetched from Secrets Manager via get_secret() rather
    # than embedding the username/password in the notebook source.
    http_auth=(user, pwd),
    index_name=INDEX_NAME_OSE,
    engine="faiss"
)

end = time.time()
print(f"Total Time for ingestion: {round(end - start, 2)} secs")

In [None]:
query = "What is a SageMaker"
sample_responses = docsearch.similarity_search(
    query, 
    k=5, 
    space_type="cosineSimilarity", 
    search_type="painless_scripting"
)

In [None]:
sample_responses[-1].page_content

## Putting it All Together

In [None]:
%%local
!python3 -m pip install -q opensearch-py==2.3.2 langchain==0.0.310 typing_extensions==4.7.1

In [None]:
%%local
import boto3
from langchain.vectorstores import OpenSearchVectorSearch


class EmbeddingGenerator:
    """Query embedder that calls the ``InvokeEndpoint`` Lambda function."""

    def __init__(self):
        self.lambda_client = boto3.client('lambda', region_name='us-west-2')

    def embed_query(self, input_text_sample):
        """Generate embeddings for the input text.

        Args:
            input_text_sample: The text to embed; it is sent to the Lambda as a
                one-element list under the ``"input"`` key.

        Returns:
            The JSON-decoded ``body`` of the Lambda response.

        Raises:
            RuntimeError: If the Lambda response reports a ``statusCode`` other
                than 200, instead of handing back an error body as if it were
                an embedding.
        """
        # Imported here so the method does not depend on a later cell having
        # imported ``json`` into the kernel namespace.
        import json

        # Prepare the data to send to the Lambda function.
        data = {"input": [input_text_sample]}

        # Invoke the Lambda function.
        response = self.lambda_client.invoke(
            FunctionName="InvokeEndpoint",
            InvocationType="RequestResponse",
            Payload=json.dumps(data)
        )

        # Decode and load the response payload.
        response_payload = json.loads(response['Payload'].read().decode("utf-8"))

        # Check the status before decoding the body.
        status_code = int(response_payload['statusCode'])
        if status_code != 200:
            raise RuntimeError(
                f"Embedding Lambda returned status {status_code}: {response_payload.get('body')}"
            )

        return json.loads(response_payload['body'])


# Local-kernel embedder instance passed to the vector store below.
embedding_generator = EmbeddingGenerator()

# Vector-store handle for INDEX_NAME_OSE, using the local embedder.
# NOTE(review): the OpenSearch username/password are hard-coded in http_auth.
# The Spark-side cells define get_secret() to read them from Secrets Manager;
# consider doing the same in this local cell rather than keeping credentials in source.
docsearch = OpenSearchVectorSearch(
    index_name=INDEX_NAME_OSE,
    embedding_function=embedding_generator,
    opensearch_url="https://search-opensearchservi-ol9kboy2sp4p-o62szzg7yeeufr3mi26nanxfke.us-west-2.es.amazonaws.com",
    http_auth=("admin", "Admin123-"),
    engine="faiss"
)

In [None]:
%%local
import re
import json
from typing import Dict
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.llms import SagemakerEndpoint
from langchain.chains.question_answering import load_qa_chain
from langchain.chains import RetrievalQA


class ContentHandler(LLMContentHandler):
    """
    Translate between prompt strings and the endpoint's JSON chat payload.

    Requests are sent and responses accepted as ``application/json``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """
        Build the UTF-8 encoded JSON request body for ``prompt``.

        If the prompt contains a ``QUESTION:\\n...\\n\\n`` block, the question
        text becomes the ``user`` message and the prompt with that block
        removed becomes the ``system`` message. If there is no such block, the
        whole prompt is sent as a single ``user`` message rather than failing
        on a missing match.

        Args:
            prompt: The fully rendered prompt.
            model_kwargs: Passed through unchanged as ``"parameters"``.
        """
        pattern = r"(QUESTION:\n)(.*?)(\n\n)"

        match = re.search(pattern, prompt, re.DOTALL)

        if match is None:
            # No QUESTION block to split on: send the prompt as the user turn.
            messages = [
                {
                    "role": "user",
                    "content": prompt
                },
            ]
        else:
            query_only = match.group(2)
            modified_prompt = re.sub(pattern, '', prompt, flags=re.DOTALL)
            messages = [
                {
                    "role": "system",
                    "content": modified_prompt
                },
                {
                    "role": "user",
                    "content": query_only
                },
            ]

        body = {
            "inputs": [messages],
            "parameters": model_kwargs
        }
        input_str = json.dumps(body)
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """
        Return the stripped generated text from the endpoint response.

        ``output`` is read with ``.read()``, decoded as UTF-8 JSON, and the
        ``['generation']['content']`` of the first element is returned.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        system_response = response_json[0]['generation']['content']
        return system_response.strip()

# Handler instance that converts prompts and responses for the endpoint below.
content_handler = ContentHandler()


# LLM wrapper around the SageMaker endpoint.
# NOTE(review): region_name is "us-east-1", while the Lambda and Secrets Manager
# clients elsewhere in this notebook use "us-west-2" — confirm the endpoint's region.
llm_sm_ep = SagemakerEndpoint(
    endpoint_name="jumpstart-dft-meta-textgeneration-llama-2-7b-f", 
    region_name="us-east-1",
    model_kwargs={
        "max_new_tokens": 512, 
        "top_p": 1.0, 
        "temperature": 0.1, 
        "return_full_text": False
    },
    content_handler=content_handler,
    endpoint_kwargs={"CustomAttributes": 'accept_eula=true'}
)

In [None]:
%%local
from langchain import PromptTemplate

template = """
Answer the following QUESTION based on the CONTEXT
given. If you do not know the answer and the CONTEXT doesn't
contain the answer truthfully say "I don't know".

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""
prompt_template = PromptTemplate(
    template=template, 
    input_variables=['context', 'question']
)

# Retrieval-augmented QA chain: fetch the 10 most similar chunks and "stuff"
# them into the prompt template's {context} slot.
llm_qa_smep_chain = RetrievalQA.from_chain_type(
    llm=llm_sm_ep,
    chain_type='stuff',
    retriever=docsearch.as_retriever(
        search_kwargs={
            "k": 10,
            "space_type": "cosineSimilarity",
            # This was a second "space_type" key, which overwrote the value
            # above and left the search type unset.
            "search_type": "painless_scripting",
        }
    ),
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt_template}
)


def pretty_print(chain_op):
    """
    Print a chain result as a question / answer / sources report.

    ``chain_op`` must provide ``'query'``, ``'result'`` and
    ``'source_documents'``. Each source is listed by the last ``/``-separated
    segment of its ``metadata['source']`` together with its
    ``metadata['page']``.
    """
    asked = chain_op['query']
    answered = chain_op['result']

    source_lines = []
    for src in chain_op['source_documents']:
        file_name = src.metadata['source'].split('/')[-1]
        source_lines.append(f"-{file_name} (page: {src.metadata['page']})")

    divider = "================"
    report_lines = [
        "Question:",
        f"> {asked}",
        "",
        divider,
        "System:",
        f"> {answered}",
        "",
        divider,
        "Sources:",
        "\n".join(source_lines),
        "    ",  # the report ends with an indented blank line
    ]
    print("\n".join(report_lines))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("What is a SageMaker Training job and how do you run it?"))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("What types of instances are supported for Training Job?"))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("How to install packages on EC2 instances using Command line?"))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("How to Create a Training Job using Boto3 SDK?"))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("How can I deploy a model to SageMaker Hosting service?"))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("How do I validate a model using boto3 sdk and visualize results using matplotlib library?"))

In [None]:
%%local
pretty_print(llm_qa_smep_chain("How can I use the console to add a git repository to my SageMaker account?"))