In [None]:
# Shell escape: prints the version of whichever `python` executable the shell resolves.
# NOTE(review): this may not be the interpreter backing the kernel — confirm if versions matter.
!python --version

In [None]:
import sagemaker, boto3, json
from sagemaker.session import Session

# Build a single SageMaker session and reuse it. `sagemaker.Session` is the same
# class as `sagemaker.session.Session`, so constructing it twice only duplicates
# the underlying AWS client setup.
sagemaker_session = Session()

# ARN of the identity the session is running as.
aws_role = sagemaker_session.get_caller_identity_arn()

# Region taken from the default boto3 configuration; used below to locate the endpoint.
aws_region = boto3.Session().region_name

# Short alias kept for cells that refer to the session as `sess`.
sess = sagemaker_session

In [None]:
from typing import Dict, List
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.llms.sagemaker_endpoint import ContentHandlerBase
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
import json


class ContentHandler(EmbeddingsContentHandler):
    """Serialize embedding requests to, and parse responses from, a JSON endpoint.

    Requests are sent as a JSON object of the form
    ``{"text_inputs": [...], **model_kwargs}`` and the embedding vectors are
    read from the ``"embedding"`` key of the JSON response.
    """

    # MIME types used for the request body and the expected response.
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        """Encode the texts to embed as a UTF-8 JSON request body.

        Args:
            inputs: Texts to embed.
            model_kwargs: Extra key/value pairs merged into the top level of the
                payload. Because they are unpacked last, a ``"text_inputs"``
                key here would override ``inputs``.

        Returns:
            The JSON payload encoded as UTF-8 bytes.
        """
        input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Decode the endpoint response and return its embedding vectors.

        Args:
            output: The response body. Accepts either a stream-like object
                exposing ``read()`` or the raw bytes themselves.

        Returns:
            The value stored under the ``"embedding"`` key of the response.

        Raises:
            KeyError: If the response JSON has no ``"embedding"`` key.
        """
        # The annotation says bytes, but the body may arrive as a readable stream;
        # handle both so the handler works in either case.
        raw = output.read() if hasattr(output, "read") else output
        response_json = json.loads(raw.decode("utf-8"))
        return response_json["embedding"]


# Name of the SageMaker endpoint that serves the embedding model.
# NOTE(review): presumably already deployed in `aws_region` — confirm before running.
EMBEDDING_ENDPOINT_NAME = "hf-textembedding-all-minilm-l6-v2"

# Handler that converts between text lists and the endpoint's JSON payloads.
content_handler = ContentHandler()

# Embeddings client that routes requests to the endpoint through the handler above.
embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=EMBEDDING_ENDPOINT_NAME,
    region_name=aws_region,
    content_handler=content_handler,
)

In [None]:
from langchain.vectorstores import FAISS

# Relative path (resolved against the current working directory) of the saved FAISS index.
VECTOR_DB_DIR = "vector-db"
# Load the index from disk, handing it the embeddings client defined earlier.
# NOTE(review): FAISS.load_local is understood to unpickle part of the saved store —
# only load directories from a trusted source; confirm against the installed LangChain version.
vector_db = FAISS.load_local(VECTOR_DB_DIR, embeddings)

In [None]:
# Free-text question to look up in the vector store.
query = "What is this document about"
docs = vector_db.similarity_search(query)

# Print the number of hits, then each hit's text followed by blank lines.
print(len(docs))
for doc in docs:
    print(doc.page_content, '\n', sep='\n')