In [None]:
!pip install langchain --quiet
!pip install unstructured --quiet
!pip install "unstructured[md]" --quiet
!pip install chromadb --quiet
!pip install -U boto3 --quiet 

In [None]:
import boto3
import os

from langchain.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, S3DirectoryLoader
from langchain.text_splitter import MarkdownTextSplitter
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma

In [None]:
# CloudFormation stack whose outputs (bucket name, embedding endpoint, region) configure this notebook.
STACK = "LLMStack"

In [None]:
def get_cf_config(stackname: str, cf_client=None) -> dict:
    """Return a CloudFormation stack's outputs as a ``{OutputKey: OutputValue}`` dict.

    Args:
        stackname: Name of the stack to describe.
        cf_client: Optional boto3 CloudFormation client to use (e.g. one bound to
            a specific region or profile). A default client is created when omitted.

    Returns:
        Mapping of each stack output's ``OutputKey`` to its ``OutputValue``;
        an empty dict if the stack reports no outputs.
    """
    if cf_client is None:
        cf_client = boto3.client('cloudformation')

    response = cf_client.describe_stacks(StackName=stackname)
    # "Outputs" is optional in the DescribeStacks response (absent when the
    # stack declares no outputs), so don't index it unconditionally.
    outputs = response["Stacks"][0].get("Outputs", [])
    return {output['OutputKey']: output['OutputValue'] for output in outputs}

In [None]:
# Resolve all deployment-specific settings (bucket, endpoint, region) from the stack outputs.
stack_config = get_cf_config(stackname=STACK)

In [None]:
def download_files_from_s3(bucket_name, dir_name: str = 'input_data', s3_resource=None):
    """Download every object under ``<dir_name>/`` in ``bucket_name`` to the same local path.

    The S3 key is used verbatim as the local file path, so ``dir_name`` acts as
    both the key prefix in the bucket and the local target directory (relative
    to the current working directory). Nested "sub-folders" are recreated locally.

    Args:
        bucket_name: Name of the S3 bucket to read from.
        dir_name: Key prefix / local directory to mirror. Only keys strictly
            inside ``<dir_name>/`` are downloaded (not e.g. ``<dir_name>_old/``).
        s3_resource: Optional boto3 S3 resource; a default one is created when omitted.

    Returns:
        List of the keys downloaded (each is also the local path it was written to).
    """
    os.makedirs(dir_name, exist_ok=True)

    if s3_resource is None:
        s3_resource = boto3.resource('s3')
    bucket = s3_resource.Bucket(bucket_name)

    # Trailing slash so "input_data" does not also match "input_data_old/...".
    prefix = dir_name.rstrip('/') + '/'

    downloaded = []
    for obj in bucket.objects.filter(Prefix=prefix):
        key = obj.key
        if key.endswith('/'):
            # Zero-byte "folder" placeholder object; there is no file to download.
            continue
        # download_file does not create intermediate directories for nested keys.
        parent = os.path.dirname(key)
        if parent:
            os.makedirs(parent, exist_ok=True)
        bucket.download_file(key, key)
        downloaded.append(key)
    return downloaded

In [None]:
# Mirror the stack bucket's input_data/ prefix into ./input_data (trailing ';' keeps the cell output clean).
download_files_from_s3(bucket_name=stack_config['S3BucketName']);

In [None]:
# Parse the downloaded files with the Unstructured markdown parser.
# DirectoryLoader.load() is eager: every matched document is read into memory here.
# NOTE(review): DirectoryLoader's default glob matches all (non-hidden) files;
# pass glob='**/*.md' if input_data/ may contain non-markdown files.
loader = DirectoryLoader(
    'input_data',
    loader_cls=UnstructuredMarkdownLoader,
    show_progress=True,
    use_multithreading=True,
)

docs = loader.load()
assert docs, 'No documents were loaded from input_data/ — check the S3 download step.'
print(f'Loaded {len(docs)} docs.')

In [None]:
# First pass: split each loaded markdown document on Markdown-aware separators,
# using the splitter's default settings.
md_splitter = MarkdownTextSplitter()
texts = md_splitter.split_documents(docs)
print('There are', len(texts), 'texts after splitting.')

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Second pass: cut the markdown sections into bounded chunks for embedding.
# Sizes are measured by the splitter's length function (characters by default);
# neighbouring chunks share up to `chunk_overlap` of text.
chunk_size, chunk_overlap = 500, 100
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap)

splits = text_splitter.split_documents(texts)
print('There are', len(splits), 'texts after splitting.')

In [None]:
import json
from typing import Dict, List
from langchain.embeddings import SagemakerEndpointEmbeddings

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

In [None]:
class ContentHandler(EmbeddingsContentHandler):
    """JSON (de)serializer for the SageMaker embedding endpoint.

    Requests are sent as ``{"text_inputs": [...], **model_kwargs}`` and the
    endpoint's response is expected to carry the vectors under ``"embedding"``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        """Encode a batch of texts (plus any model kwargs) as a UTF-8 JSON request body."""
        input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Decode the endpoint response into one embedding vector per input text.

        NOTE: despite the ``bytes`` annotation, ``output`` is consumed via
        ``.read()``, so a file-like response body is expected here.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]

In [None]:
content_handler = ContentHandler()

# LangChain wrapper around the embedding endpoint deployed by the stack;
# both the endpoint name and its region come from the stack outputs.
embeddings = SagemakerEndpointEmbeddings(
    content_handler=content_handler,
    region_name=stack_config['AWSRegion'],
    endpoint_name=stack_config['SageMakerEndpointEmbeddings'],
)

In [None]:
import chromadb

# Persist the vector store on local disk under ./chroma.
# exist_ok=True and get_or_create_collection keep this cell re-runnable
# (Restart & Run All) instead of raising "already exists" errors on a second run.
# NOTE(review): on a re-run the existing collection is reused; Chroma's add() does
# not overwrite ids that are already present — delete the collection first if the
# splits changed.
os.makedirs('chroma', exist_ok=True)
chroma_client = chromadb.PersistentClient(path="chroma")
collection = chroma_client.get_or_create_collection(name="sagemaker_docs")

In [None]:
# Embed every chunk with the SageMaker endpoint.
# ContentHandler already packs a *list* of texts into "text_inputs", so chunks are
# sent in batches instead of one endpoint round-trip per chunk. If a batch request
# fails, that batch is retried one chunk at a time so a single bad chunk only drops
# itself; each failure is recorded in `errors`.

from tqdm.auto import tqdm

EMBED_BATCH_SIZE = 16  # chunks per request; NOTE(review): tune to the endpoint's payload limit

collection_input_embeddings = []
collection_input_documents = []
collection_input_metadatas = []
collection_input_ids = []
errors = []

for start in tqdm(range(0, len(splits), EMBED_BATCH_SIZE), desc='embedding batches'):
    batch = splits[start:start + EMBED_BATCH_SIZE]
    batch_texts = [doc.page_content for doc in batch]

    try:
        batch_embs = embeddings.embed_documents(batch_texts)
        if len(batch_embs) != len(batch_texts):
            raise ValueError(f'expected {len(batch_texts)} embeddings, got {len(batch_embs)}')
    except Exception:
        # Batch failed: fall back to one request per chunk, recording individual failures.
        batch_embs = []
        for doc in batch:
            try:
                batch_embs.append(embeddings.embed_query(doc.page_content))
            except Exception as e:
                errors.append({
                    'metadata': doc.metadata,
                    'page_content': doc.page_content,
                    'error': e,
                })
                batch_embs.append(None)

    for offset, (doc, emb) in enumerate(zip(batch, batch_embs)):
        if emb is None:
            continue  # already recorded in `errors`
        collection_input_embeddings.append(emb)
        collection_input_documents.append(doc.page_content)
        collection_input_metadatas.append(doc.metadata)
        # id = position of the chunk in `splits`
        collection_input_ids.append(f'id{start + offset}')

print(f'Embedded {len(collection_input_ids)} of {len(splits)} chunks; {len(errors)} failed.')

In [None]:
# Write every successfully embedded chunk into the Chroma collection.
# Chroma rejects a single add() that is too large (the cap depends on the installed
# version/backend) and, in current versions, an empty one — so add in slices.
# NOTE(review): 5000 is a conservative slice size; confirm against the installed
# Chroma's maximum batch size.
CHROMA_ADD_BATCH_SIZE = 5000

for start in range(0, len(collection_input_ids), CHROMA_ADD_BATCH_SIZE):
    stop = start + CHROMA_ADD_BATCH_SIZE
    collection.add(
        embeddings=collection_input_embeddings[start:stop],
        documents=collection_input_documents[start:stop],
        metadatas=collection_input_metadatas[start:stop],
        ids=collection_input_ids[start:stop])

print(f'Collection "{collection.name}" now holds {collection.count()} items.')

In [None]:
# Upload target for the persisted Chroma index: the bucket exposed by the stack
# (the same one input_data/ was downloaded from), rather than a hard-coded,
# deployment-specific bucket name.
bucket_name = stack_config['S3BucketName']

In [None]:
# Low-level S3 client shared by upload_to_s3.
s3 = boto3.client(service_name='s3')

def upload_to_s3(local_file, s3_bucket, s3_object_name):
    """Upload one local file to ``s3_bucket`` under key ``s3_object_name`` and log it.

    Uses the module-level ``s3`` client.
    """
    s3.upload_file(Filename=local_file, Bucket=s3_bucket, Key=s3_object_name)
    print("Uploaded {} to {}/{}".format(local_file, s3_bucket, s3_object_name))

# Mirror the local ./chroma directory into the bucket, keeping the same relative
# layout (object key == local path under chroma/).
uploaded_count = 0
for root, _dirs, files in os.walk('chroma'):
    for file in sorted(files):
        local_path = os.path.join(root, file)
        # S3 keys always use '/', whatever the local OS path separator is.
        s3_key = local_path.replace(os.sep, '/')
        upload_to_s3(local_path, bucket_name, s3_key)
        uploaded_count += 1

print(f'Uploaded {uploaded_count} files from chroma/ to s3://{bucket_name}/chroma/')