## Medical Question Answering with the Retrieval Augmented Generation Design Pattern
Use the Python 3 (Data Science 3.0) kernel image and an `ml.m5.2xlarge` instance for this notebook.

This pattern involves generating embeddings for all existing documents and indexing them in a vector store. Then, for every user query, we generate an embedding of the query and search the store based on embedding distance. The search results act as context for the LLM to generate an output.

Challenges:
* How to manage large document(s) that exceed the token limit
* How to find the document(s) relevant to the question being asked
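
A common way to approach the first challenge is to split a long document into overlapping chunks that each fit within the model's limit. The sketch below is a simplified, character-based illustration of that idea. The function name `split_with_overlap` and the sample text are hypothetical; the notebook itself uses LangChain's `RecursiveCharacterTextSplitter`, which also tries to split on natural separators.

```python
from typing import List


def split_with_overlap(text: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> List[str]:
    """Split text into fixed-size character windows that overlap their neighbors."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


sample = "abcdefghij" * 3  # 30 characters of placeholder text
chunks = split_with_overlap(sample, chunk_size=12, chunk_overlap=4)
print([len(c) for c in chunks])
```

The overlap means that a sentence cut at a chunk boundary still appears in full in at least one of the neighboring chunks, as long as it is shorter than the overlap.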

## Key components

LLM (Large Language Model): Mistral-7b-instruct, available through Amazon SageMaker. This model will be used to understand the document chunks and provide an answer in a human-friendly manner.

Embeddings Model: GPT-J 6B available through Amazon SageMaker. This model will be used to generate a numerical representation of the textual documents.

Vector Store: FAISS, available through LangChain. In this notebook we use this in-memory vector store to hold both the embeddings and the documents. In an enterprise context this could be replaced with a persistent store such as Amazon OpenSearch Service, RDS Postgres with pgvector, ChromaDB, Pinecone or Weaviate.

Index: VectorIndex. The index compares the input embedding with the document embeddings to find the relevant documents.
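
To make the role of the index concrete, here is a minimal sketch of what a flat (brute-force) index does, assuming Euclidean distance as the measure of closeness. The `FlatIndex` class and the two-dimensional vectors are hypothetical illustrations, not the FAISS API; a library such as FAISS provides optimized implementations of the same idea.

```python
import math
from typing import List, Tuple


class FlatIndex:
    """Brute-force index: stores every vector and scans all of them per query."""

    def __init__(self) -> None:
        self.vectors: List[List[float]] = []
        self.payloads: List[str] = []

    def add(self, vector: List[float], payload: str) -> None:
        self.vectors.append(vector)
        self.payloads.append(payload)

    def search(self, query: List[float], k: int = 2) -> List[Tuple[float, str]]:
        # Euclidean (L2) distance: a smaller value means a closer match.
        scored = [(math.dist(query, v), p) for v, p in zip(self.vectors, self.payloads)]
        return sorted(scored, key=lambda pair: pair[0])[:k]


index = FlatIndex()
index.add([1.0, 0.0], "chunk about kidneys")
index.add([0.0, 1.0], "chunk about the heart")
index.add([0.9, 0.1], "chunk about renal failure")

results = index.search([1.0, 0.05], k=2)
for distance, payload in results:
    print(f"{distance:.3f}  {payload}")
```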

### Dataset
To explain this architecture pattern we are using the documents from MedQA. These documents include medical textbooks such as:
Pathology, Anatomy, Pharmacology and others. 

Download textbooks that are part of Q&A dataset MedQA released as part of Jin, Di, et al. "What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams." arXiv preprint arXiv:2009.13081 (2020). 

More details are available here https://github.com/jind11/MedQA

* Data source : @article{jin2020disease,
  title={What Disease does this Patient Have? A Large-scale Open Domain Question Answering Dataset from Medical Exams},
  author={Jin, Di and Pan, Eileen and Oufattole, Nassim and Weng, Wei-Hung and Fang, Hanyi and Szolovits, Peter},
  journal={arXiv preprint arXiv:2009.13081},
  year={2020} }
  
  

### Data preparation

> **NOTICE**: "This link leads to a Third-Party Dataset. AWS does not own, nor does it have any control over the Third-Party Dataset. You should perform your own independent assessment, and take measures to ensure that you comply with your own specific quality control practices and standards, and the local rules, laws, regulations, licenses and terms of use that apply to you, your content, and the Third-Party Dataset. AWS does not make any representations or warranties that the Third-Party Dataset is secure, virus-free, accurate, operational, or compatible with your own environment and standards. AWS does not make any representations, warranties or guarantees that any information in the Third-Party Dataset will result in a particular outcome or result."

1. The full dataset can be downloaded from https://drive.google.com/file/d/1ImYUSLk9JbgHXOemfvyiDiirluZHPeQw/view?usp=sharing. You can read more about the dataset at https://github.com/jind11/MedQA#data.
2. To speed up loading for this lab, a smaller version of the dataset is already downloaded - https://d2qrbbbqnxtln.cloudfront.net/Pathology_Robbins.txt
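
If the text file is not already present in the working directory, it can be fetched before the loading step. The following is a minimal sketch using only the standard library; the helper names `local_filename` and `download_if_missing` are hypothetical, and the download call is left commented out because it requires network access.

```python
import os
import urllib.request
from urllib.parse import urlparse

DATA_URL = "https://d2qrbbbqnxtln.cloudfront.net/Pathology_Robbins.txt"


def local_filename(url: str) -> str:
    """Derive the local file name from the last path segment of the URL."""
    return os.path.basename(urlparse(url).path)


def download_if_missing(url: str, dest_dir: str = ".") -> str:
    """Download the file only when it is not already on disk; return its path."""
    dest = os.path.join(dest_dir, local_filename(url))
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
    return dest


print(local_filename(DATA_URL))
# Usage (requires network access):
# path = download_if_missing(DATA_URL)
```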

##### Prerequisites

In [1]:
%pip install faiss-cpu==1.7.4 --quiet

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-ai 2.29.1 requires faiss-cpu!=1.8.0.post0,<2.0.0,>=1.8.0, but you have faiss-cpu 1.7.4 which is incompatible.
Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install langchain==0.0.222 --quiet

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-ai 2.29.1 requires faiss-cpu!=1.8.0.post0,<2.0.0,>=1.8.0, but you have faiss-cpu 1.7.4 which is incompatible.
jupyter-ai 2.29.1 requires pydantic~=2.0, but you have pydantic 1.10.21 which is incompatible.
jupyter-ai-magics 2.29.1 requires langchain<0.4.0,>=0.3.0, but you have langchain 0.0.222 which is incompatible.
jupyter-ai-magics 2.29.1 requires pydantic~=2.0, but you have pydantic 1.10.21 which is incompatible.
langchain-aws 0.2.10 requires pydantic<3,>=2, but you have pydantic 1.10.21 which is incompatible.
langchain-community 0.3.19 requires langchain<1.0.0,>=0.3.20, but you have langchain 0.0.222 which is incompatible.
langchain-core 0.3.41 requires pydantic<3.0.0,>=2.5.2; python_full_version < "3.12.4", but you have pydantic 1.10.21 which is incompatible.
pydantic-settings 2.8.1 requires pyda

In [3]:
%%capture 

!pip install PyYAML

In [4]:
!pip install --upgrade pydantic langchain


Collecting pydantic
  Downloading pydantic-2.10.6-py3-none-any.whl.metadata (30 kB)
Collecting langchain
  Downloading langchain-0.3.21-py3-none-any.whl.metadata (7.8 kB)
Collecting langchain-core<1.0.0,>=0.3.45 (from langchain)
  Downloading langchain_core-0.3.48-py3-none-any.whl.metadata (5.9 kB)
Collecting langchain-text-splitters<1.0.0,>=0.3.7 (from langchain)
  Downloading langchain_text_splitters-0.3.7-py3-none-any.whl.metadata (1.9 kB)
Downloading pydantic-2.10.6-py3-none-any.whl (431 kB)
Downloading langchain-0.3.21-py3-none-any.whl (1.0 MB)
Downloading langchain_core-0.3.48-py3-none-any.whl (418 kB)
Downloading langchain_text_splitters-0.3.7-py3-none-any.whl (32 kB)
Installing collected packages: pydantic, langchain-core, langchain-text-splitters, langchain
  Attempting uninstall: pydantic
    Found existing installation: pydantic 1.10.21
    Uninstalling 

#### Imports

In [5]:
import requests
import logging 
import boto3
import yaml
import json

##### Setup logging

In [6]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [7]:
logger.info(f'Using requests=={requests.__version__}')
logger.info(f'Using pyyaml=={yaml.__version__}')

Using requests==2.32.3
Using pyyaml==6.0.2


#### Setup essentials

In [8]:
TEXT_EMBEDDING_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-sentencesimilarity-20250325-230336' #INSERT EMBEDDING ENDPOINT NAME IF DIFFERENT
TEXT_GENERATION_MODEL_ENDPOINT_NAME = 'jumpstart-dft-hf-llm-mistral-7b-ins-20250325-230300' #INSERT TEXT GENERATION ENDPOINT NAME IF DIFFERENT

REGION_NAME = boto3.session.Session().region_name

#### Encode passages (chunks) using JumpStart's GPT-J text embedding model

We use only 1 of the 20 textbooks from the dataset. It takes about 6 minutes to generate embeddings for one textbook (for example, Pathology). You can index more textbooks if you allow enough additional execution time.

To implement the RAG approach, this notebook uses the LangChain framework, which provides integrations with different services and tools for building patterns such as RAG efficiently.

In [9]:
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import DirectoryLoader, TextLoader

loader = DirectoryLoader("./", glob="**/Pathology*.txt", loader_cls=TextLoader)

documents = loader.load()
# In our testing, a character-based split works better with this data set
text_splitter = RecursiveCharacterTextSplitter(
    # Chunks of at most 1000 characters, with 100 characters shared between neighbors
    chunk_size=1000,
    chunk_overlap=100,
)
docs = text_splitter.split_documents(documents)

In [10]:
print(docs[0])

page_content='Plasma Membrane: Protection and Nutrient Acquisition

Biosynthetic Machinery: Endoplasmic Reticulum and Golgi Apparatus

Waste Disposal: Lysosomes and Proteasomes

Modular Signaling Proteins, Hubs, and

Components of the Extracellular Matrix

Proliferation and the Cell Cycle

Pathology literally translates to the study of suffering (Greek pathos = suffering, logos = study); as applied to modern medicine, it is the study of disease. Virchow was certainly correct in asserting that disease originates at the cellular level, but we now realize that cellular disturbances arise from alterations in molecules (genes, proteins, and others) that influence the survival and behavior of cells. Thus, the foundation of modern pathology is understanding the cellular and molecular abnormalities that give rise to diseases. It is helpful to consider these abnormalities in the context of normal cellular structure and function, which is the theme of this introductory chapter.' metadata={'sourc

In [11]:
avg_doc_length = lambda documents: sum([len(doc.page_content) for doc in documents])//len(documents)
avg_char_count_pre = avg_doc_length(documents)
avg_char_count_post = avg_doc_length(docs)
print(f'Average length among {len(documents)} documents loaded is {avg_char_count_pre} characters.')
print(f'After the split we have {len(docs)} documents more than the original {len(documents)}.')
print(f'Average length among {len(docs)} documents (after split) is {avg_char_count_post} characters.')

Average length among 1 documents loaded is 3784898 characters.
After the split we have 5171 documents more than the original 1.
Average length among 5171 documents (after split) is 744 characters.


In [12]:
# Embedding Setup
from typing import List

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler


class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
    def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: How many input texts are grouped together
                in a single request to the endpoint.

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        # Never request more texts than we have, and avoid a zero step for empty input
        _chunk_size = max(1, min(chunk_size, len(texts)))

        for i in range(0, len(texts), _chunk_size):
            response = self._embedding_func(texts[i : i + _chunk_size])
            results.extend(response)
        return results


class ContentHandler(EmbeddingsContentHandler):  # Inherit from EmbeddingsContentHandler
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: List[str], model_kwargs={}) -> bytes:
        # Use "embedding" mode for both documents and queries.
        input_str = json.dumps({"text_inputs": prompt, "mode": "embedding", **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        # The endpoint returns one embedding (a list of floats) per input text
        response_json = json.loads(output.read().decode("utf-8"))
        embeddings = response_json["embedding"]
        return embeddings


content_handler = ContentHandler()

sagemakerEndpointEmbeddingsJumpStart = SagemakerEndpointEmbeddingsJumpStart(
    endpoint_name=TEXT_EMBEDDING_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler
)


# Load data and split it into chunks (called after sagemakerEndpointEmbeddingsJumpStart is initialized)
loader = TextLoader("Pathology_Robbins.txt")  # Replace with your data file
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
docs = text_splitter.split_documents(documents)

print(f"Number of chunks: {len(docs)}")

In [20]:
print(docs[0].page_content)

Plasma Membrane: Protection and Nutrient Acquisition

Biosynthetic Machinery: Endoplasmic Reticulum and Golgi Apparatus

Waste Disposal: Lysosomes and Proteasomes

Modular Signaling Proteins, Hubs, and

Components of the Extracellular Matrix

Proliferation and the Cell Cycle

Pathology literally translates to the study of suffering (Greek pathos = suffering, logos = study); as applied to modern medicine, it is the study of disease. Virchow was certainly correct in asserting that disease originates at the cellular level, but we now realize that cellular disturbances arise from alterations in molecules (genes, proteins, and others) that influence the survival and behavior of cells. Thus, the foundation of modern pathology is understanding the cellular and molecular abnormalities that give rise to diseases. It is helpful to consider these abnormalities in the context of normal cellular structure and function, which is the theme of this introductory chapter.


---
## Semantic Similarity with Amazon SageMaker JumpStart Embedding Models

Semantic search refers to searching for information based on the meaning and concepts of words and phrases, rather than just matching keywords. Embedding models like Amazon Titan Embeddings allow semantic search by representing words and sentences as dense vectors that encode their semantic meaning.

Semantic matching is extremely helpful for RAG because it returns results that are conceptually related to the user's query, even if they don't contain the exact keywords. This leads to more relevant and useful search results which can be injected into our LLM's prompts.
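
As a small illustration of comparing meaning through vectors, the sketch below computes cosine similarity, one commonly used measure, between hand-made toy vectors. The vectors and their values are hypothetical and chosen only for illustration; real embeddings come from the embedding model and have far more dimensions.

```python
import math
from typing import List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors; 1.0 means the same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # similarity is undefined for a zero vector; treat as unrelated
    return dot / (norm_a * norm_b)


# Hand-made toy vectors (hypothetical values, not produced by a model)
toy_embeddings = {
    "renal failure": [0.9, 0.1, 0.0],
    "kidney injury": [0.8, 0.2, 0.1],
    "cardiac arrest": [0.1, 0.9, 0.2],
}

query_vector = toy_embeddings["kidney injury"]
scores = {
    phrase: cosine_similarity(query_vector, vector)
    for phrase, vector in toy_embeddings.items()
    if phrase != "kidney injury"
}
for phrase, score in sorted(scores.items(), key=lambda item: item[1], reverse=True):
    print(f"{phrase}: {score:.3f}")
```

In this toy setup the two kidney-related phrases share no keywords, yet their vectors point in nearly the same direction, which is the property semantic search relies on.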

First, let's look at a sample embedding.

In [21]:
sample_embedding = np.array(sagemakerEndpointEmbeddingsJumpStart.embed_query(docs[0].page_content))
print("Sample embedding of a document chunk: ", sample_embedding)
print("Size of the embedding: ", sample_embedding.shape)

ValueError: Error raised by inference endpoint: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
  "code": 400,
  "type": "InternalServerException",
  "message": "The input payload must contain mode"
}
". See https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logEventViewer:group=/aws/sagemaker/Endpoints/jumpstart-dft-hf-sentencesimilarity-20250324-203000 in account 941545645943 for more information.

Now create embeddings for the entire document set. Note for a single medical textbook, it takes about 6 minutes.
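
Embedding a whole document set means many endpoint requests, so texts are usually sent in small batches, which is what the `chunk_size` argument of `embed_documents` controls. The sketch below shows that batching logic in isolation. `batched`, `embed_all` and `fake_embed_batch` are hypothetical names, and `fake_embed_batch` is a stand-in for the real endpoint call.

```python
from typing import Callable, Iterator, List


def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_all(
    texts: List[str],
    embed_batch: Callable[[List[str]], List[List[float]]],
    batch_size: int = 5,
) -> List[List[float]]:
    """Send the texts batch by batch and collect one embedding per text."""
    results: List[List[float]] = []
    for batch in batched(texts, batch_size):
        results.extend(embed_batch(batch))
    return results


calls: List[int] = []  # records the size of every request, for inspection


def fake_embed_batch(batch: List[str]) -> List[List[float]]:
    # Stand-in for the endpoint: a one-dimensional "embedding" per text
    calls.append(len(batch))
    return [[float(len(text))] for text in batch]


vectors = embed_all([f"chunk {i}" for i in range(12)], fake_embed_batch, batch_size=5)
print(len(vectors), calls)
```

With the 5171 chunks produced earlier and a batch size of 5, this approach would issue 1035 requests.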

In [None]:
# FAISS indexing: embed every chunk and build the vector store in one step
from langchain.vectorstores import FAISS

db = FAISS.from_documents(docs, sagemakerEndpointEmbeddingsJumpStart)
db.save_local("faiss_index")

In [None]:
from tqdm.contrib.concurrent import process_map
from multiprocessing import cpu_count

def generate_embeddings(x):
    return (x, sagemakerEndpointEmbeddingsJumpStart.embed_query(x))
    
workers = 1 * cpu_count()

texts = [i.page_content for i in docs]

In [None]:
workers

In [None]:
data = process_map(generate_embeddings, texts, max_workers=workers, chunksize=100)

Next, we insert the embeddings into the FAISS vector store.

In [None]:
from langchain.vectorstores import FAISS

# `data` is a list of (text, embedding) pairs, so the index can be built from it directly
faiss = FAISS.from_embeddings(data, sagemakerEndpointEmbeddingsJumpStart)
faiss.save_local("faiss_index")

Next, we create a user query and combine vector search with the LLM to produce a response.

In [None]:
query = "What is acute kidney injury?"

In [None]:
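# Embed the query with the same model that was used for the document chunks
query_embedding = sagemakerEndpointEmbeddingsJumpStart.embed_query(query)
np.array(query_embedding)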

In [None]:
relevant_documents = faiss.similarity_search_by_vector(query_embedding)
context = ""
print(f'{len(relevant_documents)} documents are fetched which are relevant to the query.')
print('----')
for i, rel_doc in enumerate(relevant_documents):
    print(f'## Document {i+1}: {rel_doc.page_content}.......')
    print('---')
    context += rel_doc.page_content + " "  # keep a space so adjacent chunks do not run together
context = context.replace("\n", " ")

Now create a prompt template that passes the context from the vector search to the model. We specifically instruct the model to answer using only the context provided.
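
Because the retrieved chunks and the question must fit within the model's input limit together, it can help to cap how much context goes into the prompt. The following is a minimal sketch under that assumption. `build_prompt`, `TEMPLATE` and the `max_context_chars` budget are hypothetical, and the chunk texts are placeholders rather than real retrieval results.

```python
from typing import List

TEMPLATE = (
    "Please answer the following question using the context provided.\n"
    "If you don't know the answer, just say that you don't know.\n\n"
    "CONTEXT:\n{context}\n=========\nQUESTION: {question}\nANSWER: "
)


def build_prompt(question: str, chunks: List[str], max_context_chars: int = 2000) -> str:
    """Join retrieved chunks, in rank order, until the character budget is used up."""
    selected: List[str] = []
    used = 0
    for chunk in chunks:
        cleaned = " ".join(chunk.split())  # collapse newlines and repeated spaces
        if used + len(cleaned) > max_context_chars:
            break  # stop at the first chunk that no longer fits
        selected.append(cleaned)
        used += len(cleaned)
    return TEMPLATE.format(context="\n".join(selected), question=question)


example_prompt = build_prompt(
    "What is acute kidney injury?",
    ["Placeholder chunk one\nwith a line break.", "x" * 5000],
    max_context_chars=200,
)
print(example_prompt)
```

A character budget is only a rough proxy for a token limit; a tokenizer-based count would be more precise.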

In [None]:
template = """
        You are a helpful, polite, fact-based agent.
        If you don't know the answer, just say that you don't know.
        Please answer the following question using the context provided. 

        CONTEXT: 
        {context}
        =========
        QUESTION: {question} 
        ANSWER: """


In [None]:
prompt = template.format(context=context, question=query)
print(prompt)

Invoke the endpoint to generate a response from the LLM
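
The request and response handling can be tried out without calling the endpoint. The sketch below builds the JSON request body and parses a simulated response, assuming the same payload shape this notebook uses for its text generation endpoint (`inputs` plus `parameters`, and a list containing `generated_text` in the reply). The helper names and the simulated response content are hypothetical, and other models may expect a different schema.

```python
import io
import json


def build_request_body(prompt: str, max_new_tokens: int = 500) -> str:
    """Serialize the prompt and generation parameters as a JSON request body."""
    return json.dumps({"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}})


def parse_generated_text(body) -> str:
    """Read a file-like response body and return the first generated text."""
    payload = json.loads(body.read())
    return payload[0]["generated_text"]


# Simulated response body (placeholder content) standing in for a real endpoint reply
fake_body = io.BytesIO(json.dumps([{"generated_text": "placeholder answer"}]).encode("utf-8"))

request_body = build_request_body("QUESTION: ... ANSWER: ", max_new_tokens=64)
print(request_body)
print(parse_generated_text(fake_body))
```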

In [None]:
smr_client = boto3.client("sagemaker-runtime")

In [None]:
response_model = smr_client.invoke_endpoint(
    EndpointName=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
    Body=json.dumps(
        {"inputs": prompt, "parameters": {"max_new_tokens": 500}}
    ),
    ContentType="application/json",
)
response = json.loads(response_model["Body"].read())


In [None]:
print(response[0]["generated_text"])

---
## Create RAG with LangChain and an LLM hosted on SageMaker


In [None]:
import json
import os
from typing import Dict

from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler

# Setup: reuse TEXT_GENERATION_MODEL_ENDPOINT_NAME and REGION_NAME from "Setup essentials".
# The LLM must point at the text generation endpoint, not the embedding endpoint.

# Make the region visible to clients created by LangChain
os.environ["AWS_DEFAULT_REGION"] = REGION_NAME


# Text generation with the Mistral-7b-instruct endpoint
class ContentHandlerMistral(LLMContentHandler):  # Inherit from LLMContentHandler
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Same payload shape as the direct invoke_endpoint call: "inputs" plus "parameters"
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        # The response may be a list of generations or a single object
        if isinstance(response_json, list):
            response_json = response_json[0] if response_json else {}
        return response_json.get("generated_text", "")


content_handler_mistral = ContentHandlerMistral()

llm = SagemakerEndpoint(
    endpoint_name=TEXT_GENERATION_MODEL_ENDPOINT_NAME,
    region_name=REGION_NAME,
    content_handler=content_handler_mistral,
    model_kwargs={"max_new_tokens": 500},
)


In [None]:
# RAG QA Chain
from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",  # "stuff" is a simple way to pass all context
    retriever=db.as_retriever(search_kwargs={"k": 3}),  # Number of chunks to retrieve
    return_source_documents=True
)


In [None]:
# Example Question
query = "What are the main causes of heart failure?"
result = qa_chain({"query": query})
print("Question:", query)
print("Answer:", result["result"])
#print("Source Documents:", result["source_documents"]) #uncomment for more info

