In [None]:
### Requirements
# pip install "sycamore-ai[opensearch,local-inference]"

In [None]:
import os

# The Aryn partitioner used below reads its API key from the ARYN_API_KEY
# environment variable. Only fall back to the placeholder when the variable is
# not already set, so a key exported in the shell is not overwritten.
# Replace the placeholder (or export ARYN_API_KEY) before running the pipeline.
os.environ.setdefault("ARYN_API_KEY", "YOUR-ARYN-API-KEY")

In [None]:
import sycamore
from sycamore.context import ExecMode
from sycamore.transforms.partition import ArynPartitioner
from sycamore.transforms.extract_schema import LLMPropertyExtractor
from sycamore.transforms.summarize_images import SummarizeImages, LLMImageSummarizer
from sycamore.transforms.standardizer import USStateStandardizer, DateTimeStandardizer, ignore_errors
from sycamore.transforms.merge_elements import GreedySectionMerger
from sycamore.functions.tokenizer import HuggingFaceTokenizer
from sycamore.transforms.embed import SentenceTransformerEmbedder
from sycamore.llms import Bedrock, BedrockModels

import pyarrow.fs

# LLM shared by property extraction, image summarization and answer generation.
llm = Bedrock(BedrockModels.CLAUDE_3_SONNET)

# Source PDFs (NTSB accident reports).
paths = ["s3://aryn-public/ntsb/"]

# Start Sycamore. To run without Ray, call sycamore.init(exec_mode=ExecMode.LOCAL).
context = sycamore.init()

# Read the raw PDFs and materialize them under ./opensearch-tutorial.
# MATERIALIZE_USE_STORED lets stored output at that path be reused instead of
# recomputed.
docset = context.read.binary(paths=paths, binary_format="pdf").materialize(
    path="./opensearch-tutorial/downloaded-docset",
    source_mode=sycamore.MATERIALIZE_USE_STORED,
)

# Partition each PDF with the Aryn partitioner, extracting table structure and
# images. The partitioner needs the Aryn API key in the ARYN_API_KEY
# environment variable.
partitioned_docset = docset.partition(
    partitioner=ArynPartitioner(extract_table_structure=True, extract_images=True)
).materialize(
    path="./opensearch-tutorial/partitioned-docset",
    source_mode=sycamore.MATERIALIZE_USE_STORED,
)

# Run the pipeline defined above.
partitioned_docset.execute()

In [None]:
# JSON-style schema describing the properties the LLM extracts from each report.
schema = {
    "type": "object",
    "properties": {
        "accidentNumber": {"type": "string"},
        "dateAndTime": {"type": "date"},
        "location": {"type": "string", "description": "US State where the incident occurred"},
        "aircraft": {"type": "string"},
        "aircraftDamage": {"type": "string"},
        "injuries": {"type": "string"},
        "definingEvent": {"type": "string"},
    },
    "required": ["accidentNumber", "dateAndTime", "location", "aircraft"],
}

schema_name = "FlightAccidentReport"
# num_of_elements limits how many elements of each document the extractor
# sends to the LLM (presumably the first 20; confirm in the Sycamore docs).
property_extractor = LLMPropertyExtractor(
    llm=llm, num_of_elements=20, schema_name=schema_name, schema=schema
)

enriched_docset = (
    partitioned_docset
    # Extract the properties defined by the schema above.
    .extract_properties(property_extractor=property_extractor)
    # Summarize the extracted images with the LLM.
    .transform(SummarizeImages, summarizer=LLMImageSummarizer(llm=llm))
)

# Normalize the extracted values. Each standardizer is wrapped in ignore_errors,
# presumably so a value that cannot be standardized does not fail the pipeline.
formatted_docset = (
    enriched_docset
    # Convert state abbreviations to full state names.
    .map(lambda doc: ignore_errors(doc, USStateStandardizer, ["properties", "entity", "location"]))
    # Convert date/time values to a common format.
    .map(lambda doc: ignore_errors(doc, DateTimeStandardizer, ["properties", "entity", "dateAndTime"]))
)

# Embedding model. Its tokenizer is also used to size chunks, so the 512-token
# chunk budget is counted in the same tokens the embedder sees.
model_name = "thenlper/gte-small"

merger = GreedySectionMerger(tokenizer=HuggingFaceTokenizer(model_name), max_tokens=512)
chunked_docset = formatted_docset.merge(merger=merger)

embedded_docset = (
    chunked_docset
    # Spread the "entity" and "path" properties to child elements so each chunk
    # keeps them after explode (the filtered query later matches on
    # properties.entity.location).
    .spread_properties(["entity", "path"])
    # Turn each element into its own document (chunk).
    .explode()
    .embed(embedder=SentenceTransformerEmbedder(batch_size=10_000, model_name=model_name))
    .materialize(
        path="./opensearch-tutorial/embedded-docset",
        source_mode=sycamore.MATERIALIZE_USE_STORED,
    )
)
embedded_docset.execute()

In [None]:
# Set the OpenSearch Service configuration for the connector.
#
# Connection details are read from environment variables when present, so that
# credentials do not have to be written into the notebook. Otherwise the
# placeholder values below are used. Replace them or export:
#   OPENSEARCH_HOST, OPENSEARCH_PORT, OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD
#
# If you're running against localhost, or directly against a cluster, use port
# 9200 instead of 443.
import os

openSearch_client_args = {
    "hosts": [
        {
            "host": os.environ.get("OPENSEARCH_HOST", "YOUR-DOMAIN-ENDPOINT"),
            "port": int(os.environ.get("OPENSEARCH_PORT", "443")),
        }
    ],
    "http_compress": True,
    "http_auth": (
        os.environ.get("OPENSEARCH_USERNAME", "YOUR-OPENSEARCH-USERNAME"),
        os.environ.get("OPENSEARCH_PASSWORD", "YOUR-OPENSEARCH-PASSWORD"),
    ),
    "use_ssl": True,
    # NOTE: certificate and hostname verification are disabled (and the
    # warning about it is silenced), which allows self-signed local clusters
    # but also accepts any server certificate. For a managed domain with a
    # valid certificate, set verify_certs and ssl_assert_hostname to True.
    "verify_certs": False,
    "ssl_assert_hostname": False,
    "ssl_show_warn": False,
    "timeout": 120,
}

# Index definition: k-NN enabled, with an HNSW/faiss vector field sized for the
# 384-dimensional embeddings produced by thenlper/gte-small.
index_settings = {
    "body": {
        "settings": {
            "index.knn": True,
            "number_of_shards": 2,
            "number_of_replicas": 1,
        },
        "mappings": {
            "properties": {
                "embedding": {
                    "type": "knn_vector",
                    "dimension": 384,
                    "method": {"name": "hnsw", "engine": "faiss"},
                },
                "text": {"type": "text"},
            }
        },
    }
}

In [None]:
# Load the embedded chunks into the "aryn-rag-demo" index, passing the
# connection arguments and the index settings/mappings configured earlier.
embedded_docset.write.opensearch(
    index_name="aryn-rag-demo",
    index_settings=index_settings,
    os_client_args=openSearch_client_args,
)

In [None]:
from sycamore.query.execution.operations import summarize_data
from sycamore.connectors.opensearch.utils import get_knn_query

# Run the first RAG pipeline: retrieve the chunks nearest to the question with a
# k-NN query, then have the LLM answer from them.

# Number of retrieved chunks passed to the LLM.
context_size = 20
question = "What was common with incidents in Texas, and how does that differ from incidents in California?"

# The query must be embedded with the same model used to embed the indexed chunks.
text_embedder = SentenceTransformerEmbedder(batch_size=10_000, model_name=model_name)
os_client_args = openSearch_client_args

os_query = get_knn_query(query_phrase=question, context=context, text_embedder=text_embedder)

# Must match the index_name used when the embedded docset was written.
docset = context.read.opensearch(index_name="aryn-rag-demo", query=os_query, os_client_args=os_client_args)
docset = docset.limit(context_size)

answer = summarize_data(
    question=question,
    result_description="Documents returned that can answer a question about flight incidents",
    result_data=[docset],
    context=context,
    llm=llm,
)
print(answer)

In [None]:
# Run the second RAG pipeline, this time restricting the k-NN retrieval to
# chunks whose extracted location is California.
# Reuses context_size, text_embedder and os_client_args from the previous cell.

question = "What incidents occurred in California?"

os_query = get_knn_query(query_phrase=question, context=context, text_embedder=text_embedder)
# Attach a filter to the k-NN clause so only California chunks are candidates.
os_query["query"]["knn"]["embedding"]["filter"] = {"match_phrase": {"properties.entity.location": "California"}}

# Must match the index_name used when the embedded docset was written.
docset = context.read.opensearch(index_name="aryn-rag-demo", query=os_query, os_client_args=os_client_args)
docset = docset.limit(context_size)

answer = summarize_data(
    question=question,
    result_description="Documents returned that can answer a question about flight incidents",
    result_data=[docset],
    context=context,
    llm=llm,
)
print(answer)