# Healthcare RAG System Lab
## Overview

In this lab, you'll take on the role of a junior data scientist at a healthcare technology company that specializes in creating educational resources for patients. Your team has been tasked with developing a system that can automatically generate informative responses to common patient questions about medical conditions, treatments, and wellness practices.

The challenge is to ensure these responses are both accurate and grounded in authoritative medical information. Your specific assignment is to implement a Retrieval-Augmented Generation (RAG) system that can:
1. Understand patient questions about various health topics
2. Retrieve relevant information from a trusted knowledge base
3. Generate helpful, accurate responses based on that information
4. Avoid "hallucinated" content that could potentially misinform patients

This lab follows the generative AI implementation process we've studied, with particular focus on:
- Data Strategy and Knowledge Foundation
- Model Selection and Generation Control
- Evaluation Framework Development

## Setup

First, let's import the necessary libraries:

In [1]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cpu


## Part 1: Knowledge Base Setup

Let's create a sample medical knowledge base with information about common health conditions, treatments, and wellness practices:

In [2]:
# Create a sample medical knowledge base
knowledge_base = pd.DataFrame({
    'content': [
        "Diabetes is a chronic condition that affects how your body turns food into energy. There are three main types: Type 1, Type 2, and gestational diabetes. Type 2 diabetes is the most common form, accounting for about 90-95% of diabetes cases.",
        "Type 1 diabetes is an autoimmune reaction that stops your body from making insulin. Symptoms include increased thirst, frequent urination, hunger, fatigue, and blurred vision. It's usually diagnosed in children, teens, and young adults.",
        "Type 2 diabetes occurs when your body becomes resistant to insulin or doesn't make enough insulin. Risk factors include being overweight, being 45 years or older, having a parent or sibling with type 2 diabetes, and being physically active less than 3 times a week.",
        "Managing diabetes involves monitoring blood sugar levels, taking medications as prescribed, eating a healthy diet, maintaining a healthy weight, and getting regular physical activity. It's important to work with healthcare providers to develop a management plan.",
        "Hypertension, or high blood pressure, is when the force of blood pushing against the walls of your arteries is consistently too high. It's often called the 'silent killer' because it typically has no symptoms but significantly increases the risk of heart disease and stroke.",
        "Blood pressure is measured using two numbers: systolic (top number) and diastolic (bottom number). Normal blood pressure is less than 120/80 mm Hg. Hypertension is diagnosed when readings are consistently 130/80 mm Hg or higher.",
        "Lifestyle changes to manage hypertension include reducing sodium in your diet, getting regular physical activity, maintaining a healthy weight, limiting alcohol, quitting smoking, and managing stress. Medications may also be prescribed if lifestyle changes aren't enough.",
        "Regular physical activity offers numerous health benefits, including weight management, reduced risk of heart disease, strengthened bones and muscles, improved mental health, and enhanced ability to perform daily activities. Adults should aim for at least 150 minutes of moderate-intensity activity per week.",
        "A balanced diet should include a variety of fruits, vegetables, whole grains, lean proteins, and healthy fats. It's recommended to limit intake of added sugars, sodium, saturated fats, and processed foods. Proper nutrition helps prevent chronic diseases and supports overall health.",
        "Vaccination is one of the most effective ways to prevent infectious diseases. Vaccines work by helping the body recognize and fight specific pathogens. Common adult vaccines include influenza (flu), Tdap (tetanus, diphtheria, pertussis), shingles, and pneumococcal vaccines."
    ],
    'metadata': [
        {'topic': 'diabetes', 'subtopic': 'overview', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'type1', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'type2', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'diabetes', 'subtopic': 'management', 'source': 'medical_guidelines', 'last_updated': '2023-06-10'},
        {'topic': 'hypertension', 'subtopic': 'overview', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'hypertension', 'subtopic': 'diagnosis', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'hypertension', 'subtopic': 'management', 'source': 'medical_guidelines', 'last_updated': '2023-07-22'},
        {'topic': 'wellness', 'subtopic': 'physical_activity', 'source': 'health_promotion', 'last_updated': '2023-05-15'},
        {'topic': 'wellness', 'subtopic': 'nutrition', 'source': 'health_promotion', 'last_updated': '2023-05-15'},
        {'topic': 'prevention', 'subtopic': 'vaccination', 'source': 'medical_guidelines', 'last_updated': '2023-08-05'}
    ]
})

print(f"Knowledge base loaded with {len(knowledge_base)} entries")
knowledge_base.head(2)

Knowledge base loaded with 10 entries


| | content | metadata |
|---|---|---|
| 0 | Diabetes is a chronic condition that affects h... | {'topic': 'diabetes', 'subtopic': 'overview', ... |
| 1 | Type 1 diabetes is an autoimmune reaction that... | {'topic': 'diabetes', 'subtopic': 'type1', 'so... |


### Task 1: Create Document Embeddings

Complete the function below to create embeddings for each document in the knowledge base. These embeddings will be used to find relevant documents based on patient queries.

In [3]:
def create_document_embeddings(documents):
    """
    Create embeddings for a list of documents.
    
    Args:
        documents: List of text documents to embed
        
    Returns:
        Numpy array of document embeddings
    """
    # Initialize a sentence transformer model
    # Recommended: 'sentence-transformers/all-mpnet-base-v2' or similar
    embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    
    # Generate embeddings for all documents
    # Hint: Use the model.encode() method
    document_embeddings = embedding_model.encode(
        documents,
        show_progress_bar=True,
        normalize_embeddings=True
    ) 
    
    return document_embeddings

# Extract document content
documents = knowledge_base['content'].tolist()

# Create document embeddings
document_embeddings = create_document_embeddings(documents)

# Verify the shape of embeddings
if document_embeddings is not None:
    print(f"Generated embeddings with shape: {document_embeddings.shape}")
else:
    print("Embeddings not created yet.")

Generated embeddings with shape: (10, 768)
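
Because `normalize_embeddings=True` scales every embedding to unit length, the cosine similarity used for retrieval in Part 2 reduces to a plain dot product. The sketch below illustrates this with random stand-in vectors (an assumption for illustration, not real model embeddings); the helper `l2_normalize` is hypothetical and not part of the lab code.

```python
import numpy as np

# Random stand-in vectors; real embeddings would come from the sentence transformer
rng = np.random.default_rng(0)
docs = rng.normal(size=(4, 8))    # 4 "documents", 8 dimensions
query = rng.normal(size=(1, 8))   # 1 "query"

def l2_normalize(x):
    """Scale each row to unit L2 norm (hypothetical helper)."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Dot product of unit-length vectors
dot_scores = (l2_normalize(query) @ l2_normalize(docs).T).flatten()

# Cosine similarity computed from the raw vectors
cos_scores = (query @ docs.T).flatten() / (
    np.linalg.norm(query) * np.linalg.norm(docs, axis=1)
)

print(np.allclose(dot_scores, cos_scores))
```

With unit-length embeddings, `cosine_similarity` and a matrix product should rank documents identically, so either can be used for scoring.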


## Part 2: Implementing the Retrieval Component

Now, let's implement the function to retrieve relevant documents based on a patient query.

In [5]:
def retrieve_documents(query, embeddings, contents, metadata, top_k=3, threshold=0.3):
    """
    Retrieve the most relevant documents for a given query.
    
    Args:
        query: The patient's question
        embeddings: The precomputed document embeddings
        contents: The text content of the documents
        metadata: The metadata for each document
        top_k: Maximum number of documents to retrieve
        threshold: Minimum similarity score to include a document
        
    Returns:
        List of (content, metadata, similarity_score) tuples
    """
    # Get or initialize the embedding model (same as in create_document_embeddings)
    embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    
    # Embed the query
    query_embedding = embedding_model.encode(
        [query],
        normalize_embeddings=True
    )
    
    # Calculate similarity scores between query and all documents
    # Hint: Use cosine_similarity
    similarities = cosine_similarity(query_embedding, embeddings).flatten() 
    
    # Filter by threshold and get top k results
    # Hint: Use list comprehension, sorting, and slicing
    candidates = [
        (contents[i], metadata[i], float(similarities[i]))
        for i in range(len(contents))
        if similarities[i] >= threshold
    ]

    # Sort by similarity score (descending) and take top_k
    candidates.sort(key=lambda x: x[2], reverse=True)
    
    # Return the top documents with their metadata and scores
    results = candidates[:top_k]
    
    return results

# Test the retrieval function with a sample query
if document_embeddings is not None:
    sample_query = "What are the symptoms of Type 1 diabetes?"
    retrieved_docs = retrieve_documents(
        query=sample_query,
        embeddings=document_embeddings,
        contents=documents,
        metadata=knowledge_base['metadata'].tolist(),
        top_k=2
    )
    
    print(f"Query: {sample_query}")
    print("\nRetrieved Documents:")
    for i, (content, meta, score) in enumerate(retrieved_docs):
        print(f"{i+1}. [{score:.4f}] {content[:100]}...")
        print(f"   Topic: {meta['topic']}, Subtopic: {meta['subtopic']}")
else:
    print("Cannot test retrieval without document embeddings.")

Query: What are the symptoms of Type 1 diabetes?

Retrieved Documents:
1. [0.7585] Type 1 diabetes is an autoimmune reaction that stops your body from making insulin. Symptoms include...
   Topic: diabetes, Subtopic: type1
2. [0.4625] Diabetes is a chronic condition that affects how your body turns food into energy. There are three m...
   Topic: diabetes, Subtopic: overview
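
Note that `retrieve_documents` constructs a new `SentenceTransformer` on every call, which reloads the model weights for each query. One possible way to avoid this is to cache the loader. The sketch below uses a stand-in loader so it runs without downloading anything; `make_cached_loader` and `fake_loader` are hypothetical names introduced only for illustration.

```python
from functools import lru_cache

def make_cached_loader(load_fn):
    """Wrap a model constructor so each model name is loaded only once."""
    @lru_cache(maxsize=None)
    def get_model(model_name):
        return load_fn(model_name)
    return get_model

# Stand-in loader that records how often it is called
load_calls = []

def fake_loader(model_name):
    load_calls.append(model_name)
    return {"name": model_name}

get_model = make_cached_loader(fake_loader)
first = get_model("sentence-transformers/all-mpnet-base-v2")
second = get_model("sentence-transformers/all-mpnet-base-v2")

# The second call returns the cached object without invoking the loader again
print(first is second, len(load_calls))
```

In the lab itself, the same idea could be applied by wrapping `SentenceTransformer` with `make_cached_loader` and calling the cached loader inside `retrieve_documents` instead of constructing the model directly.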


## Part 3: Building the Generation Component

Now, let's implement the generation component that will use the retrieved documents to create informative responses.

In [6]:
# Initialize the generative model
def initialize_generator(model_name="gpt2"):
    """
    Initialize the generative model and tokenizer.
    
    Args:
        model_name: Name of the pretrained model to use
        
    Returns:
        Tokenizer and model objects
    """
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    
    # Set the padding token to the EOS token if the tokenizer does not define one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

    return tokenizer, model

# Initialize the generator
tokenizer, model = initialize_generator()
if tokenizer and model:
    print(f"Initialized {model.config._name_or_path} with {model.num_parameters()} parameters")

Initialized gpt2 with 124439808 parameters


In [9]:
def generate_rag_response(query, contents, metadata, document_embeddings, tokenizer, model, max_length=100):
    """
    Generate a response using Retrieval-Augmented Generation.
    
    Args:
        query: The patient's question
        contents: List of document contents
        metadata: List of document metadata
        document_embeddings: Precomputed embeddings for the documents
        tokenizer: The tokenizer for the language model
        model: The language model for generation
        max_length: Maximum number of new tokens to generate
        
    Returns:
        Dictionary with the generated response and the retrieved documents
    """
    # Retrieve relevant documents for the query
    retrieved_docs = retrieve_documents(
        query=query,
        embeddings=document_embeddings,
        contents=contents,
        metadata=metadata,
        top_k=2
    )
  
    
    # Format prompt with retrieved context
    # Hint: If no relevant documents are found, generate without context
    # Otherwise, include retrieved documents as context
    if retrieved_docs:
        context = "\n".join(
            [f"- {doc}" for doc, _, _ in retrieved_docs]
        )
        prompt = (
            "You are a healthcare assistant providing accurate, patient-friendly information.\n\n"
            "Use the following trusted medical information to answer the question.\n\n"
            f"Context:\n{context}\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )
    else:
        prompt = (
            "You are a healthcare assistant providing general information.\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )

    # Tokenize the prompt
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )
    
    # Generate the response
    # Sampling is enabled with do_sample=True, temperature, and top_p; without do_sample, generate() uses greedy decoding
    output_sequences = model.generate(
        **inputs,
        max_new_tokens=max_length,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        no_repeat_ngram_size=3,
        repetition_penalty=1.15,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )
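    # Decode only the newly generated tokens so the prompt is not echoed back
    # (splitting on "Answer:" would break if the model generated that string itself)
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(
        output_sequences[0][prompt_length:],
        skip_special_tokens=True
    )

    response = generated_text.strip()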
    
    # Return the results
    return {
        "query": query,
        "response": response,
        "retrieved_documents": retrieved_docs
    }

# Test the RAG system with several queries
if document_embeddings is not None and tokenizer and model:
    test_queries = [
        "What are the different types of diabetes?",
        "How can I manage my high blood pressure through lifestyle changes?",
        "Why is regular physical activity important for health?",
        "What vaccines should adults consider getting?"
    ]
    
    for query in test_queries:
        print(f"\nQuery: {query}")
        result = generate_rag_response(
            query=query,
            contents=documents,
            metadata=knowledge_base['metadata'].tolist(),
            document_embeddings=document_embeddings,
            tokenizer=tokenizer,
            model=model
        )
        
        print("\nRetrieved Documents:")
        for i, (doc, meta, score) in enumerate(result["retrieved_documents"]):
            print(f"{i+1}. [{score:.4f}] Topic: {meta['topic']}, Subtopic: {meta['subtopic']}")
        
        print(f"\nGenerated Response:\n{result['response']}")
        print("-" * 80)
else:
    print("Cannot test generation without embeddings or model.")


Query: What are the different types of diabetes?

Retrieved Documents:
1. [0.7130] Topic: diabetes, Subtopic: overview
2. [0.6430] Topic: diabetes, Subtopic: type1

Generated Response:
The type I diabetic was born with at birth has two major causes (Type II) – A genetic predisposition or "dysphoria," which can cause you not having enough blood sugar control during pregnancy; also called malignancy syndrome because it results when glucose levels fall below normal but still keep rising as well as overfeeding may affect fetal development.[1] People who have this disease typically develop multiple heart problems including high cholesterol/HDL ratio,[2][3], kidney failure, liver damage
--------------------------------------------------------------------------------

Query: How can I manage my high blood pressure through lifestyle changes?

Retrieved Documents:
1. [0.7775] Topic: hypertension, Subtopic: management
2. [0.4690] Topic: hypertension, Subtopic: overview

Generated Response:
You 

## Part 4: Evaluation and Analysis

Let's implement a basic evaluation function to assess the quality of our generated responses.

In [11]:
import re

def evaluate_response(response_data):
    """
    Evaluate the quality of a generated response based on various criteria.
    
    Args:
        response_data: Dictionary containing the query, response, and retrieved docs
        
    Returns:
        Evaluation metrics
    """
    # Implement at least two evaluation metrics
    # Suggestions:
    # 1. Content relevance - Check if response mentions terms from retrieved docs
    # 2. Response length appropriateness - Check if length is suitable for query
    # 3. Medical terminology usage - Check if appropriate medical terms are included

    query = response_data.get("query", "")
    response = response_data.get("response", "") or ""
    retrieved_docs = response_data.get("retrieved_documents", []) or []

    retrieved_text = " ".join([doc for doc, _, _ in retrieved_docs]).lower()
    response_text = response.lower()

    # Basic tokenization 
    def tokenize(text):
        return re.findall(r"[a-zA-Z']+", text.lower())

    response_tokens = set(tokenize(response))
    retrieved_tokens = set(tokenize(retrieved_text))

    # Content relevance / grounding: token overlap ratio
    overlap = response_tokens.intersection(retrieved_tokens)
    grounding_overlap_ratio = (len(overlap) / max(len(response_tokens), 1))

    # Response length appropriateness
    word_count = len(response.split())
    length_ok = 30 <= word_count <= 185

    # Medical terminology usage: find known terms that appear in the response
    medical_terms = [
        "diabetes", "insulin", "glucose", "hypertension", "blood pressure",
        "systolic", "diastolic", "cardiovascular", "cholesterol", "nutrition",
        "obesity", "physical activity", "vaccination", "immune", "prevention"
    ]

    med_terms_found = [t for t in medical_terms if t in response_text]
    medical_terms_count = len(med_terms_found)

    metrics = {
        "word_count": word_count,
        "length_ok": length_ok,
        "grounding_overlap_ratio": round(grounding_overlap_ratio, 3),
        "medical_terms_count": medical_terms_count,
        "medical_terms_found": med_terms_found[:6],
        "used_retrieval": len(retrieved_docs) > 0        
    }
    
    return metrics

# Evaluate the responses for our test queries
if 'test_queries' in locals() and document_embeddings is not None and tokenizer and model:
    for query in test_queries:
        result = generate_rag_response(
            query=query,
            contents=documents,
            metadata=knowledge_base['metadata'].tolist(),
            document_embeddings=document_embeddings,
            tokenizer=tokenizer,
            model=model
        )
        
        metrics = evaluate_response(result)
        print(f"Query: {query}")
        print(f"Evaluation Metrics: {metrics}")
        print("-" * 80)
else:
    print("Cannot evaluate without test queries or necessary components.")

Query: What are the different types of diabetes?
Evaluation Metrics: {'word_count': 84, 'length_ok': True, 'grounding_overlap_ratio': 0.06, 'medical_terms_count': 0, 'medical_terms_found': [], 'used_retrieval': True}
--------------------------------------------------------------------------------
Query: How can I manage my high blood pressure through lifestyle changes?
Evaluation Metrics: {'word_count': 81, 'length_ok': True, 'grounding_overlap_ratio': 0.025, 'medical_terms_count': 1, 'medical_terms_found': ['diabetes'], 'used_retrieval': True}
--------------------------------------------------------------------------------
Query: Why is regular physical activity important for health?
Evaluation Metrics: {'word_count': 82, 'length_ok': True, 'grounding_overlap_ratio': 0.012, 'medical_terms_count': 0, 'medical_terms_found': [], 'used_retrieval': True}
--------------------------------------------------------------------------------
Query: What vaccines should adults consider getting?
Eva

## Reflection Questions

Answer the following questions about your RAG implementation and its potential applications in healthcare:

### How does the RAG approach improve factual accuracy compared to regular generation?

The RAG approach improves factual accuracy by grounding the model's responses in a stable, verified, and trusted external knowledge base rather than relying entirely on the model's training data. It retrieves relevant documents before formulating a response and conditions that response on the retrieved information. This reduces the likelihood of hallucinations.

### What are potential challenges or limitations of your current implementation?

In my current implementation, I could not get the model to stay reliably grounded in the retrieved documents, even when relevant matching information was provided. This shows in the very low grounding overlap scores and in hallucinations that a quick human review easily detects. Another limitation is that the knowledge base I am working with here is small and simple, so the amount of authoritative context is limited.
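
The grounding overlap score above counts every token, including common words such as "is" and "the", which can blur the signal. One possible refinement is to compare only content words. The sketch below assumes a small illustrative stopword list; `STOPWORDS`, `content_tokens`, and `content_overlap_ratio` are hypothetical names, not part of the lab code.

```python
import re

# Small illustrative stopword list (an assumption; a real list would be larger)
STOPWORDS = {
    "the", "a", "an", "is", "are", "of", "from", "that", "your", "and",
    "to", "in", "it", "for", "with", "or", "be", "as", "on",
}

def content_tokens(text):
    """Lowercase word tokens with stopwords and very short tokens removed."""
    tokens = re.findall(r"[a-z']+", text.lower())
    return {t for t in tokens if t not in STOPWORDS and len(t) > 2}

def content_overlap_ratio(response, retrieved_texts):
    """Share of the response's content words that also appear in the sources."""
    response_tokens = content_tokens(response)
    if not response_tokens:
        return 0.0
    source_tokens = content_tokens(" ".join(retrieved_texts))
    return len(response_tokens & source_tokens) / len(response_tokens)

source = ["Type 1 diabetes is an autoimmune reaction that stops your body from making insulin."]
grounded = "Type 1 diabetes stops the body from making insulin."
ungrounded = "The disease is a kind of malignancy syndrome."

print(content_overlap_ratio(grounded, source))
print(content_overlap_ratio(ungrounded, source))
```

Token overlap remains a rough proxy: a response can reuse source vocabulary and still state something false, so a metric like this would complement human review rather than replace it.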

### How might you enhance this system for a production healthcare environment?

For production use, we could enhance the system with a more capable, instruction-tuned language model. We could also expand and curate the medical knowledge base with more sources of authoritative context. Another improvement would be enforcing stricter prompt constraints that limit responses to retrieved content only. Lastly, further improvements could include stronger screening for hallucinations, human-in-the-loop review of responses, and more stringent metrics for measuring medical correctness and safety.
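
As a rough illustration of the stricter prompt constraints mentioned above, the sketch below builds a prompt that restricts the answer to numbered sources and returns `None` when nothing sufficiently relevant was retrieved. The function `build_grounded_prompt`, the `min_score` cutoff of 0.5, and the refusal wording are all assumptions for illustration, and the approach presumes a model that actually follows instructions.

```python
def build_grounded_prompt(query, retrieved_docs, min_score=0.5):
    """
    Build a source-restricted prompt from (content, metadata, score) tuples,
    the same shape returned by retrieve_documents.
    Returns None when no document meets the (hypothetical) min_score cutoff.
    """
    trusted = [(c, m, s) for c, m, s in retrieved_docs if s >= min_score]
    if not trusted:
        return None

    # Number the sources so an answer can refer back to them
    context = "\n".join(
        f"[{i}] {content}" for i, (content, _, _) in enumerate(trusted, start=1)
    )
    return (
        "Answer the patient's question using ONLY the numbered sources below.\n"
        "If the sources do not contain the answer, reply exactly: "
        "\"I don't have enough information to answer that.\"\n\n"
        f"Sources:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )

docs = [
    ("Normal blood pressure is less than 120/80 mm Hg.", {"topic": "hypertension"}, 0.78),
    ("Vaccination is one of the most effective ways to prevent infectious diseases.", {"topic": "prevention"}, 0.31),
]
prompt = build_grounded_prompt("What is a normal blood pressure reading?", docs)
print(prompt)
```

When the function returns `None`, the caller could show a fixed fallback message instead of generating without context, which avoids the ungrounded path entirely.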

### What ethical considerations are particularly important for healthcare content generation?

Ethical considerations are critical in healthcare content generation because the output can have a direct and significant effect on patients who rely on the model's responses. If left unchecked, hallucinations and other model limitations can lead to underdiagnosis, misdiagnosis, or even medical advice that harms the patient. Responses in this context must be safeguarded as rigorously as possible because medical misinformation can have deadly consequences.