# Supercharge your LLMs with RAG at scale using AWS Glue for Apache Spark

This sample notebook demonstrates how to build a Retrieval Augmented Generation (RAG) pipeline using AWS Glue for Apache Spark. In this sample, we use LangChain together with AWS Glue for Apache Spark, Amazon SageMaker JumpStart, and Amazon OpenSearch Serverless. To make this solution scalable and customizable, we rely on Apache Spark’s distributed processing and PySpark’s flexible scripting capabilities. We use OpenSearch Serverless as a sample vector store, and use the Llama 3.1 model provided in SageMaker JumpStart.

In [None]:
%session_id_prefix langchain-spark-rag-jumpstart-
%glue_version 4.0
%idle_timeout 2880
%worker_type G.1X
%number_of_workers 5

In [None]:
%additional_python_modules langchain==0.2.12, langchain-community==0.2.11, langchain-aws==0.1.16, typing-extensions, typing-inspect, beautifulsoup4, lxml, markdownify, pydantic, boto3, botocore, opensearch-py, sagemaker

## Vectorstore setup

In [None]:
import json
import os
# Placeholders: replace both values before running the notebook.
# The role ARN is reused below as the OpenSearch Serverless data-access
# principal and as the role passed to the SageMaker JumpStart models.
iam_role_arn = "<your-iam-role-arn>"
# The region is kept in the process environment; later cells read it back
# with os.environ.get("AWS_DEFAULT_REGION").
os.environ["AWS_DEFAULT_REGION"] = "<region>"

In [None]:
import boto3
import time
# Names of the OpenSearch Serverless collection and of the index inside it.
vector_store_name = "langchain-spark-rag"
index_name = "langchain-spark-rag-index"

# Names of the three policies attached to the collection
# (encryption, network, data access).
encryption_policy_name = "langchain-spark-rag-sp"
network_policy_name = "langchain-spark-rag-np"
access_policy_name = "langchain-spark-rag-ap"

# Principal that the data access policy is granted to.
identity = iam_role_arn

aoss_client = boto3.client("opensearchserverless")

# Both security policies apply to the same single collection resource.
collection_rule = {
    "Resource": ["collection/" + vector_store_name],
    "ResourceType": "collection",
}

# Encryption policy: the collection is encrypted with an AWS-owned key.
encryption_policy_document = json.dumps(
    {
        "Rules": [collection_rule],
        "AWSOwnedKey": True,
    }
)
security_policy = aoss_client.create_security_policy(
    name=encryption_policy_name,
    type="encryption",
    policy=encryption_policy_document,
)

# Network policy: the collection is reachable from public networks.
network_policy_document = json.dumps(
    [
        {
            "Rules": [collection_rule],
            "AllowFromPublic": True,
        }
    ]
)
network_policy = aoss_client.create_security_policy(
    name=network_policy_name,
    type="network",
    policy=network_policy_document,
)

# With both security policies in place, request the vector search collection.
collection = aoss_client.create_collection(
    name=vector_store_name,
    type="VECTORSEARCH",
)

# Poll every 10 seconds until the collection reports a terminal status.
while True:
    summaries = aoss_client.list_collections(
        collectionFilters={'name': vector_store_name}
    )['collectionSummaries']
    # An empty listing is treated as "not visible yet" and polled again,
    # instead of failing with an IndexError.
    status = summaries[0]['status'] if summaries else None
    if status == 'ACTIVE':
        break
    if status == 'FAILED':
        # Stop here: every later cell needs a working collection, and
        # continuing would only surface as confusing errors further down.
        raise RuntimeError(
            f"OpenSearch Serverless collection '{vector_store_name}' "
            "entered FAILED status during creation"
        )
    time.sleep(10)

# Data access policy: grants `identity` the permissions below on the
# collection and on every index inside it.
collection_permissions = [
    "aoss:CreateCollectionItems",
    "aoss:DeleteCollectionItems",
    "aoss:UpdateCollectionItems",
    "aoss:DescribeCollectionItems",
]
index_permissions = [
    "aoss:CreateIndex",
    "aoss:DeleteIndex",
    "aoss:UpdateIndex",
    "aoss:DescribeIndex",
    "aoss:ReadDocument",
    "aoss:WriteDocument",
]
data_policy_document = [
    {
        "Rules": [
            {
                "Resource": [f"collection/{vector_store_name}"],
                "Permission": collection_permissions,
                "ResourceType": "collection",
            },
            {
                "Resource": [f"index/{vector_store_name}/*"],
                "Permission": index_permissions,
                "ResourceType": "index",
            },
        ],
        "Principal": [identity],
        "Description": "Easy data policy",
    }
]

access_policy = aoss_client.create_access_policy(
    name=access_policy_name,
    type="data",
    policy=json.dumps(data_policy_document),
)

# Collection endpoint URL: https://<collection id>.<region>.aoss.amazonaws.com
host = "".join(
    (
        "https://",
        collection["createCollectionDetail"]["id"],
        ".",
        os.environ.get("AWS_DEFAULT_REGION", None),
        ".aoss.amazonaws.com",
    )
)


The cell above takes approximately 3 minutes to complete.

## Sample document download

In [None]:
import requests

# Download the sample article. The request is bounded by a timeout and
# checked for an HTTP error status so that a hung connection or an error
# page is never silently uploaded as the sample document.
content = requests.get(
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    timeout=30,
)
content.raise_for_status()

s3_resource = boto3.resource("s3")

# The bucket name embeds the account id and region of the caller.
account_id = boto3.client('sts').get_caller_identity()['Account']
region = os.environ.get("AWS_DEFAULT_REGION", None)
bucket_name = f'langchain-spark-rag-jumpstart-{account_id}-{region}'

bucket = s3_resource.Bucket(bucket_name)
# us-east-1 buckets are created without a location constraint; for any
# other region the region name is passed as the LocationConstraint.
if region == 'us-east-1':
    bucket.create()
else:
    bucket.create(CreateBucketConfiguration={'LocationConstraint': region})

# Upload the downloaded HTML; the Spark job reads it back from this key.
key = "data/langchain-spark-rag/content.html"
obj = s3_resource.Object(bucket_name, key)
obj.put(Body=content.text)

## Document preparation

#### Read HTML file

In [None]:
# wholetext=True loads the whole file as a single row instead of one row per line.
html_s3_uri = f"s3://{bucket_name}/{key}"
df_html = spark.read.text(html_s3_uri, wholetext=True)
df_html.show()

#### Parse and clean up HTML

In [None]:
from bs4 import BeautifulSoup
from markdownify import MarkdownConverter

def parse_html(html):
    """Convert an HTML string to cleaned-up Markdown text.

    The HTML is parsed with the lxml parser, ``smooth()`` is applied to the
    parsed tree, the tree is converted to Markdown with ATX-style headings,
    and the result is passed through ``format_md``.

    Args:
        html: HTML document as a string.

    Returns:
        The value returned by ``format_md`` for the converted Markdown.
    """
    document = BeautifulSoup(html, "lxml")
    document.smooth()
    converter = MarkdownConverter(heading_style="ATX")
    return format_md(converter.convert_soup(document))


In [None]:
import re

def format_md(md):
    """Flatten Markdown text into a single whitespace-normalized line.

    The substitutions are applied in order:

    1. strip leading spaces from every line,
    2. turn each newline into a space,
    3. collapse any run of two or more whitespace characters into one space,
    4. replace an underscore preceded by one or more backslashes with a
       plain underscore.

    Args:
        md: Markdown text.

    Returns:
        The normalized text.
    """
    # (pattern, replacement, flags), applied top to bottom.
    substitutions = (
        (r"^ +", "", re.MULTILINE),
        (r"\n", " ", 0),
        (r"\s\s+", " ", 0),
        (r"\\+_", "_", 0),
    )
    for pattern, replacement, flags in substitutions:
        md = re.sub(pattern, replacement, md, flags=flags)
    return md

#### Chunking HTML

In [None]:
from langchain.text_splitter import MarkdownTextSplitter

def chunk_md(md, chunk_size=1000, chunk_overlap=100):
    """Split Markdown text into overlapping chunks.

    Args:
        md: Markdown text to split.
        chunk_size: Chunk size passed to ``MarkdownTextSplitter``
            (default 1000, the value this notebook has always used).
        chunk_overlap: Overlap between consecutive chunks passed to
            ``MarkdownTextSplitter`` (default 100).

    Returns:
        The list of chunks produced by ``MarkdownTextSplitter.split_text``.
    """
    splitter = MarkdownTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_text(md)

#### Embedding

In [None]:
# Deployment settings per SageMaker JumpStart model id. Both models use the
# same instance type and container environment; each entry gets its own
# dict objects, and the deployment cell later adds an "endpoint_name" key.
_MODEL_CONFIG_ = {
    jumpstart_model_id: {
        "instance type": "ml.g5.2xlarge",
        "env": {
            "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
            "TS_DEFAULT_WORKERS_PER_MODEL": "1",
        },
    }
    for jumpstart_model_id in (
        "huggingface-sentencesimilarity-all-MiniLM-L6-v2",
        "meta-textgeneration-llama-3-1-8b-instruct",
    )
}

In [None]:
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.utils import name_from_base

newline, bold, unbold = "\n", "\033[1m", "\033[0m"

# Models whose deployment failed, mapped to the exception that was raised.
failed_deployments = {}

for model_id, model_config in _MODEL_CONFIG_.items():
    endpoint_name = name_from_base(f"jumpstart-example-raglc-{model_id}")
    inference_instance_type = model_config["instance type"]

    try:
        # Create and deploy the JumpStart model
        model = JumpStartModel(
            model_id=model_id,
            role=iam_role_arn,
            instance_type=inference_instance_type,
            predictor_cls=None,  # Use default predictor
            env=model_config["env"]
        )

        predictor = model.deploy(
            initial_instance_count=1,
            instance_type=inference_instance_type,
            endpoint_name=endpoint_name,
            accept_eula=True
        )
    except Exception as e:
        # Keep going so the remaining models are still attempted; the
        # failure is reported in aggregate once the loop has finished.
        print(f"Error deploying model {model_id}: {str(e)}")
        failed_deployments[model_id] = e
        continue

    print(f"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}")
    # Later cells look the endpoint up under this key.
    model_config["endpoint_name"] = endpoint_name

# Later cells read "endpoint_name" for every model, so a failed deployment
# must stop the notebook here rather than surface later as a KeyError.
if failed_deployments:
    raise RuntimeError(
        "Deployment failed for model(s): " + ", ".join(failed_deployments)
    )

In [None]:
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain_community.vectorstores import OpenSearchVectorSearch
from opensearchpy import RequestsHttpConnection, AWSV4SignerAuth

# Region and service name handed to AWSV4SignerAuth when signing
# OpenSearch Serverless requests.
region = os.environ.get("AWS_DEFAULT_REGION")
service = "aoss"

def process_batch(chunks):
    """Embed text chunks and index them into the OpenSearch vector store.

    Credentials, the request signer, the embeddings client and the vector
    store are all created inside the function; it is called from within a
    Spark UDF (via ``process_html``).

    Args:
        chunks: Text chunks to embed and index.

    Returns:
        dict: ``{"num_success": <number of chunks indexed>}``.
    """
    from typing import List, Dict

    chunks = list(chunks)
    # Nothing to embed or index; also avoids calling the embedding endpoint
    # and the vector store with an empty batch.
    if not chunks:
        return {"num_success": 0}

    credentials = boto3.Session().get_credentials()
    awsauth = AWSV4SignerAuth(credentials, region, service)

    class ContentHandler(EmbeddingsContentHandler):
        """Request/response (de)serialization for the embedding endpoint."""

        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
            # JSON request body: the texts to embed plus any extra model kwargs.
            input_str = json.dumps({"text_inputs": prompts, "mode": "embedding", **model_kwargs})
            return input_str.encode('utf-8')

        def transform_output(self, output: bytes) -> List[List[float]]:
            # `output` is read like a stream; the vectors are under "embedding".
            response_json = json.loads(output.read().decode("utf-8"))
            return response_json["embedding"]

    embeddings = SagemakerEndpointEmbeddings(
        endpoint_name=_MODEL_CONFIG_["huggingface-sentencesimilarity-all-MiniLM-L6-v2"]["endpoint_name"],
        region_name=region,
        content_handler=ContentHandler(),
    )

    # from_texts embeds the chunks and bulk-ingests them into the index.
    # A follow-up add_texts() with the same chunks would embed and index
    # every chunk a second time, so the chunks are written exactly once.
    OpenSearchVectorSearch.from_texts(
        chunks,
        embeddings,
        index_name=index_name,
        opensearch_url=host,
        http_auth=awsauth,
        timeout=100,
        use_ssl=True,
        verify_certs=True,
        connection_class=RequestsHttpConnection
    )
    return {"num_success": len(chunks)}

#### Pre-process HTML document

In [None]:
from pyspark.sql.functions import col, udf

def process_html(html):
    """Clean one HTML document, index its chunks, and return the cleaned text.

    The HTML is converted to Markdown with ``parse_html``, split with
    ``chunk_md``, and the chunks are handed to ``process_batch``, whose
    return value is not used.

    Args:
        html: HTML document as a string.

    Returns:
        The cleaned Markdown text produced by ``parse_html``.
    """
    markdown_text = parse_html(html)
    process_batch(chunk_md(markdown_text))
    return markdown_text


# Spark UDF wrapper so the pipeline can be applied to a DataFrame column.
process_html_udf = udf(lambda html_doc: process_html(html_doc))

In [None]:
# Apply the UDF to the raw HTML column and expose the result as "text".
text_column = process_html_udf(col("value")).alias("text")
df_html_processed = df_html.select(text_column)
df_html_processed.show()

## Question answering

In [None]:
import logging
from typing import List, Dict, Any
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler

# Runtime client handed to the SagemakerEndpoint LLM wrapper below.
sagemaker_client = boto3.client("sagemaker-runtime")

# Request signer for the OpenSearch Serverless queries issued from this session.
credentials = boto3.Session().get_credentials()
awsauth = AWSV4SignerAuth(credentials, region, service)


class EmbeddingsContentHandler(EmbeddingsContentHandler):
    """Request/response (de)serialization for the embedding endpoint.

    NOTE: the class reuses the name of the imported base class it extends,
    so after this definition the name refers to this subclass.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompts: List[str], model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body for a list of texts."""
        payload = {"text_inputs": prompts, "mode": "embedding"}
        # Extra model kwargs are merged last, so they override the defaults.
        payload.update(model_kwargs)
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Read the response stream and return the value under "embedding"."""
        body = output.read().decode("utf-8")
        return json.loads(body)["embedding"]

class LLMContentHandler(LLMContentHandler):
    """Request/response (de)serialization for the text-generation endpoint.

    NOTE: the class reuses the name of the imported base class it extends,
    so after this definition the name refers to this subclass.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body for one prompt.

        ``max_new_tokens``, ``top_p`` and ``temperature`` are taken from
        ``model_kwargs`` when present; otherwise 256, 0.9 and 0.6 are used.
        """
        parameters = {
            "max_new_tokens": model_kwargs.get("max_new_tokens", 256),
            "top_p": model_kwargs.get("top_p", 0.9),
            "temperature": model_kwargs.get("temperature", 0.6),
        }
        input_str = json.dumps({"inputs": prompt, "parameters": parameters})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        """Read the response stream and return the generated text.

        Any exception is logged with its traceback and the raw payload,
        then re-raised.
        """
        # Keep the raw bytes: the stream can only be read once, so the
        # error path needs them to log what was actually received.
        raw = None
        try:
            raw = output.read()
            response_json = json.loads(raw.decode("utf-8"))
            logging.info("Raw output: %s", response_json)
            return self._extract_generated_text(response_json)
        except Exception as e:
            logging.exception("Error in transform_output: %s", e)
            logging.error("Raw output: %r", output if raw is None else raw)
            raise

    @staticmethod
    def _extract_generated_text(response_json) -> str:
        """Return "generated_text" from a dict or from the first list item.

        Falls back to ``str()`` of the value when the key is missing or the
        value is not a dict; an empty list yields an empty string.
        """
        candidate = response_json
        if isinstance(candidate, list):
            if not candidate:
                return ""
            candidate = candidate[0]
        if isinstance(candidate, dict):
            return candidate.get("generated_text", str(candidate))
        return str(candidate)

# Embeddings client backed by the deployed MiniLM endpoint.
embedding_endpoint_name = _MODEL_CONFIG_[
    "huggingface-sentencesimilarity-all-MiniLM-L6-v2"
]["endpoint_name"]
embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=embedding_endpoint_name,
    region_name=region,
    content_handler=EmbeddingsContentHandler(),
)

# Vector store client bound to the index populated by the Spark job.
opensearch_client_settings = {
    "index_name": index_name,
    "embedding_function": embeddings,
    "opensearch_url": host,
    "http_auth": awsauth,
    "timeout": 100,
    "use_ssl": True,
    "verify_certs": True,
    "connection_class": RequestsHttpConnection,
    "vector_field": "vector_field",
    "engine": "faiss",
}
opensearch_vector_search_client = OpenSearchVectorSearch(**opensearch_client_settings)

# LLM wrapper around the deployed Llama 3.1 endpoint.
llm_endpoint_name = _MODEL_CONFIG_[
    "meta-textgeneration-llama-3-1-8b-instruct"
]["endpoint_name"]
llm = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    client=sagemaker_client,
    model_kwargs={"temperature": 1e-10},
    content_handler=LLMContentHandler(),
)

In [None]:
from langchain.chains import RetrievalQA
# "stuff" question-answering chain over the OpenSearch retriever; the
# documents used for each answer are returned alongside it.
qa_chain_settings = {
    "llm": llm,
    "chain_type": "stuff",
    "retriever": opensearch_vector_search_client.as_retriever(),
    "return_source_documents": True,
}
qa_aoss = RetrievalQA.from_chain_type(**qa_chain_settings)

In [None]:
from langchain.load.dump import dumps
# Inspect what the vector store retrieves for the question on its own,
# before the question is passed through the LLM chain.
query = "What is Task Decomposition?"
results = opensearch_vector_search_client.similarity_search(query)
serialized_results = dumps(results, pretty=True)
print(serialized_results)

In [None]:
# Run the RetrievalQA chain for `query`. Kept as the cell's final bare
# expression so the notebook displays the returned value.
qa_aoss.invoke(query)

## Clean up

In [None]:
## S3
# Delete every object version first, then the bucket itself.
bucket.object_versions.delete()
bucket.delete()

## OpenSearch Serverless
aoss_client.delete_collection(id=collection['createCollectionDetail']['id'])
aoss_client.delete_access_policy(name=access_policy_name, type='data')
aoss_client.delete_security_policy(name=encryption_policy_name, type='encryption')
aoss_client.delete_security_policy(name=network_policy_name, type='network')

## SageMaker
# NOTE: rebinds `sagemaker_client` from the 'sagemaker-runtime' client used
# for inference to the 'sagemaker' client needed for deletion.
sagemaker_client = boto3.client('sagemaker')
for model_id, model_config in _MODEL_CONFIG_.items():
    endpoint_name = model_config.get("endpoint_name")
    if endpoint_name is None:
        # No endpoint was recorded for this model, so there is nothing to
        # delete; skip it instead of aborting the remaining deletions.
        print(f"No endpoint recorded for model {model_id}; skipping.")
        continue

    # Look up the endpoint config and its models before deleting the
    # endpoint, so they can be removed as well instead of being left behind.
    endpoint_config_name = sagemaker_client.describe_endpoint(
        EndpointName=endpoint_name
    )["EndpointConfigName"]
    production_variants = sagemaker_client.describe_endpoint_config(
        EndpointConfigName=endpoint_config_name
    )["ProductionVariants"]
    model_names = [
        variant["ModelName"]
        for variant in production_variants
        if variant.get("ModelName")
    ]

    sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
    sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    for model_name in model_names:
        sagemaker_client.delete_model(ModelName=model_name)