In [1]:
import boto3
import json
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader
from langchain.chains import RetrievalQA
from langchain.llms import SagemakerEndpoint
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.vectorstores import FAISS
from typing import Any, Dict, List, Optional
import os


In [2]:
# Names of the deployed SageMaker endpoints; replace them if your endpoints differ.
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = "jumpstart-dft-hf-sentencesimilarity-20250325-230336"
TEXT_GENERATION_MODEL_ENDPOINT_NAME = "jumpstart-dft-hf-llm-mistral-7b-ins-20250325-230300"

# Region is taken from a default boto3 session.
_default_session = boto3.session.Session()
REGION_NAME = _default_session.region_name

In [4]:
import boto3
import json
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader
from langchain.chains import RetrievalQA
from langchain.llms import SagemakerEndpoint
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.vectorstores import FAISS
from typing import Any, Dict, List, Optional
import os

# Setup
# Names of the deployed SageMaker endpoints; replace them if your endpoints differ.
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-sentencesimilarity-20250325-230336' #INSERT EMBEDDING ENDPOINT NAME IF DIFFERENT
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-llm-mistral-7b-ins-20250325-230300' #INSERT TEXT GENERATION ENDPOINT NAME IF DIFFERENT
REGION_NAME = boto3.session.Session().region_name

# Without a configured region the session reports no region name; assigning that
# to os.environ would fail with an opaque TypeError, so stop with a clear message.
if not REGION_NAME:
    raise RuntimeError(
        "No AWS region is configured. Set AWS_DEFAULT_REGION or configure a default "
        "region before running this notebook."
    )

# Ensure correct naming: export the resolved region for the rest of this kernel.
os.environ["AWS_DEFAULT_REGION"] = REGION_NAME


# Embedding Setup
class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    """SagemakerEndpointEmbeddings variant that embeds documents in small batches."""

    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed. An empty list yields an empty
                result without calling the endpoint.
            chunk_size: Maximum number of input texts grouped together into
                one request. Must be at least 1.

        Returns:
            List of embeddings, collected batch by batch in input order.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        results: List[List[float]] = []
        # Slicing past the end of the list is safe, so the final (or only)
        # batch simply ends up shorter than chunk_size.
        for start in range(0, len(texts), chunk_size):
            results.extend(self._embedding_func(texts[start : start + chunk_size]))
        return results


class ContentHandler(EmbeddingsContentHandler):  # Inherit from EmbeddingsContentHandler
    """Serialize embedding requests to JSON and parse the JSON response."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: List[str], model_kwargs: Optional[Dict[str, Any]] = None) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            prompt: The texts to embed. The embeddings class in this notebook
                passes a list slice per request.
            model_kwargs: Extra top-level fields merged into the payload; they
                are applied last, so they can override the defaults.

        Returns:
            The JSON payload encoded as UTF-8 bytes.
        """
        # Use "embedding" mode for both documents and queries.
        payload = {"text_inputs": prompt, "mode": "embedding", **(model_kwargs or {})}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: Any) -> List[List[float]]:
        """Extract the embeddings from the endpoint response.

        Args:
            output: Readable stream holding the UTF-8 JSON response body
                (it is consumed with ``.read()``).

        Returns:
            The value stored under the ``"embedding"`` key of the response.
            NOTE(review): presumably one vector per input text, since the
            caller extends its result list with it; confirm against the endpoint.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]


# Create the JSON content handler and hand it to the batching embeddings client.
content_handler = ContentHandler()

_embedding_client_options = dict(
    endpoint_name=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler,
)
sagemakerEndpointEmbeddingsJumpStart = SagemakerEndpointEmbeddingsJumpStart(
    **_embedding_client_options
)


# Load Data and Split
# (the sample-embedding check at the end needs sagemakerEndpointEmbeddingsJumpStart to exist)
loader = TextLoader("Pathology_Robbins.txt")  # Replace with your data file
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = text_splitter.split_documents(documents)

print(f"Number of chunks: {len(docs)}")

# Indexing docs[0] below would raise a bare IndexError if splitting produced
# nothing, so fail early with a message that points at the input file.
if not docs:
    raise ValueError("No text chunks were produced; check the contents of the data file.")

# Embed the first chunk as a quick check that the embedding endpoint responds.
sample_embedding = np.array(sagemakerEndpointEmbeddingsJumpStart.embed_query(docs[0].page_content))
print("Sample embedding of a document chunk: ", sample_embedding)
print("Size of the embedding: ", sample_embedding.shape)


Number of chunks: 5171
Sample embedding of a document chunk:  [-2.68548373e-02 -5.14841871e-03  8.22765604e-02  2.51330547e-02
 -5.26014669e-03 -5.87520450e-02  7.25993514e-02  5.83153265e-03
  1.96829159e-03 -1.03373043e-02  2.18124315e-02 -7.02288449e-02
  3.04869432e-02 -2.42926478e-02  1.07806809e-02  1.18084755e-02
 -4.04602120e-04  4.70750406e-02 -4.19485047e-02  2.68056672e-02
 -1.20089045e-02 -4.66680191e-02  7.61874812e-03 -7.36146793e-02
  3.83745246e-02  5.43740168e-02 -7.34584220e-03 -4.35480615e-03
 -3.49703655e-02 -1.43278882e-01  5.94927871e-04 -7.77929882e-03
  3.37314117e-03  3.28561850e-02  1.15412083e-02 -2.63297465e-02
 -1.76403560e-02 -4.29618685e-03 -3.88793573e-02  5.76247927e-03
  5.42547517e-02  8.73222575e-03 -4.60612774e-02  6.76298980e-03
  3.03594731e-02 -4.08621617e-02  3.95240774e-03 -2.54773833e-02
  4.18353826e-02 -4.78082784e-02 -4.71962839e-02 -1.01800738e-02
  1.61627885e-02  4.23433036e-02  4.70918603e-03 -1.95064899e-02
  5.22197708e-02  3.16964127

In [5]:
# FAISS indexing: build the vector store from the chunks using the embeddings client.
db = FAISS.from_documents(
    documents=docs,
    embedding=sagemakerEndpointEmbeddingsJumpStart,
)


In [7]:
import boto3
import json
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from typing import Any, Dict, List, Optional
import os


# Setup
# Names of the deployed SageMaker endpoints; replace them if your endpoints differ.
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-sentencesimilarity-20250325-230336' #INSERT EMBEDDING ENDPOINT NAME IF DIFFERENT
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-llm-mistral-7b-ins-20250325-230300' #INSERT TEXT GENERATION ENDPOINT NAME IF DIFFERENT
REGION_NAME = boto3.session.Session().region_name

# Without a configured region the session reports no region name; assigning that
# to os.environ would fail with an opaque TypeError, so stop with a clear message.
if not REGION_NAME:
    raise RuntimeError(
        "No AWS region is configured. Set AWS_DEFAULT_REGION or configure a default "
        "region before running this notebook."
    )

# Ensure correct naming: export the resolved region for the rest of this kernel.
os.environ["AWS_DEFAULT_REGION"] = REGION_NAME
# Text generation content handler.
# NOTE(review): the class is named for Falcon, but the endpoint name configured in
# this notebook refers to "mistral-7b-ins"; confirm the prompt template below
# matches the deployed model.
class ContentHandlerFalcon(LLMContentHandler):
    """Build JSON requests for the text-generation endpoint and parse its replies."""

    content_type = "application/json"
    # NOTE(review): transform_output parses the body as JSON although the request
    # asks for text/plain; the recorded run returned JSON anyway. Confirm whether
    # this should be "application/json".
    accepts = "text/plain"

    # Markers wrapped around the user prompt by transform_input.
    _PROMPT_PREFIX = "<|prompter|>"
    _PROMPT_SUFFIX = "<|endoftext|><|assistant|>"

    def transform_input(self, prompt: str, model_kwargs: Optional[Dict[str, Any]] = None) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            prompt: The raw prompt text; it is wrapped in the prompter/assistant
                markers before being sent.
            model_kwargs: Generation parameters that override the defaults
                below (for example the ``max_new_tokens`` configured on the LLM).

        Returns:
            The JSON payload encoded as UTF-8 bytes.
        """
        # Format the prompt according to Falcon-Instruct expectations
        formatted_prompt = f"{self._PROMPT_PREFIX}{prompt}{self._PROMPT_SUFFIX}"
        # Defaults first, caller-supplied kwargs last so they take precedence.
        parameters = {
            "max_new_tokens": 500,
            "temperature": 0.7,
            "top_p": 0.95,
            "do_sample": True,
            **(model_kwargs or {}),
        }
        input_str = json.dumps({"inputs": formatted_prompt, "parameters": parameters})
        return input_str.encode("utf-8")

    def _strip_prompt_echo(self, text: Any) -> Any:
        """Drop the echoed prompt when the model returns prompt + completion."""
        # The recorded response starts with the formatted prompt, so keep only
        # what follows the first assistant marker in that situation.
        if isinstance(text, str) and text.startswith(self._PROMPT_PREFIX):
            _, marker, completion = text.partition(self._PROMPT_SUFFIX)
            if marker:
                return completion
        return text

    def transform_output(self, output: Any) -> str:
        """Extract the generated text from the endpoint response.

        Args:
            output: Readable stream holding the UTF-8 JSON response body
                (it is consumed with ``.read()``).

        Returns:
            The ``generated_text`` of the response (taken from the first
            element when the response is a list) with any echoed prompt
            removed, or ``""`` when the response cannot be interpreted.
        """
        try:
            raw_response = output.read().decode("utf-8")
            print(f"Raw response from model: {raw_response}")  # Debug print
            response_json = json.loads(raw_response)

            # A list response carries the result in its first element; a dict
            # response carries it directly.
            candidate = response_json
            if isinstance(response_json, list) and len(response_json) > 0:
                candidate = response_json[0]
            if isinstance(candidate, dict) and "generated_text" in candidate:
                return self._strip_prompt_echo(candidate["generated_text"])

            print(f"Unexpected response format: {response_json}")  # Debug print
            return ""
        except Exception as e:
            # Best effort: report the problem and hand back an empty answer.
            print(f"Error processing output: {str(e)}")
            return ""

# Text-generation client: wrap the generation endpoint with its content handler.
content_handler_falcon = ContentHandlerFalcon()

_generation_kwargs = {"max_new_tokens": 500}  # Simplify parameters
llm = SagemakerEndpoint(
    content_handler=content_handler_falcon,
    endpoint_name=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
    model_kwargs=_generation_kwargs,
    region_name=REGION_NAME,
)

    
# RAG QA Chain
# The retriever queries the vector store with k=3; the chain also returns the
# source documents alongside the answer.
_top_k_retriever = db.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=_top_k_retriever,
    return_source_documents=True,
    verbose=True,  # Add this to see the chain's operation
)

# Example Question
# Any failure while querying or printing is reported instead of raised.
query = "What are the main causes of heart failure?"
try:
    result = qa_chain({"query": query})
    print("Question:", query)
    for label, key in (("Answer:", "result"), ("Source Documents:", "source_documents")):
        print(label, result[key])
except Exception as e:
    print(f"Error during query: {str(e)}")




[1m> Entering new RetrievalQA chain...[0m
Raw response from model: [{"generated_text":"<|prompter|>Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\nHeart failure may result from systolic or diastolic dysfunction. Systolic dysfunction results from inadequate myocardial contractile function, usually as a consequence of ischemic heart disease or hypertension. Diastolic dysfunction refers to an inability of the heart to adequately relax and fill, which may be a consequence of massive left ventricular hypertrophy, myocardial fibrosis, amyloid deposition, or constrictive pericarditis. Approximately one half of CHF cases are attributable to diastolic dysfunction, with a greater frequency seen in older adults, diabetic patients, and women. Heart failure may also be caused by valve dysfunction (e.g., due to endocarditis), or may occur following rapid increases in blood volume