# Medical RAG Pipeline - Healthcare Q&A System

This notebook demonstrates how to build a Retrieval-Augmented Generation (RAG) system for medical and healthcare questions. We'll build a small sample medical knowledge base, index it in a vector store, and answer questions from the retrieved passages while tracking their sources.

## Features:
- Medical document preprocessing
- Document embeddings (general-purpose by default, swappable for biomedical models)
- Context-aware retrieval
- Safe medical response generation
- Citation and source tracking

## 1. Setup and Imports

In [None]:
!pip install -q transformers langchain langchain-community faiss-cpu sentence-transformers datasets pandas numpy

In [None]:
import os
import pandas as pd
import numpy as np
from typing import List, Dict, Any
import warnings
warnings.filterwarnings('ignore')

# LangChain imports (vector store, embedding, and LLM integrations live in langchain-community in recent releases)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# Transformers imports
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

print("Libraries imported successfully!")

## 2. Medical Knowledge Base Setup

We'll create a sample medical knowledge base. In practice, you would use real medical databases like PubMed, medical textbooks, or clinical guidelines.

In [None]:
# Sample medical documents (in practice, load from real medical databases)
medical_documents = [
    {
        "title": "Diabetes Mellitus Type 2 - Overview",
        "content": """Type 2 diabetes mellitus is a chronic metabolic disorder characterized by insulin resistance and relative insulin deficiency. 
        It affects how the body metabolizes glucose, leading to elevated blood sugar levels. Risk factors include obesity, sedentary lifestyle, 
        genetic predisposition, and age. Management involves lifestyle modifications, blood glucose monitoring, and often medication including 
        metformin, sulfonylureas, or insulin therapy. Complications can include cardiovascular disease, nephropathy, retinopathy, and neuropathy.""",
        "source": "Medical Textbook - Endocrinology",
        "category": "Endocrinology"
    },
    {
        "title": "Hypertension - Diagnosis and Treatment",
        "content": """Hypertension, or high blood pressure, is defined as systolic blood pressure ≥140 mmHg or diastolic blood pressure ≥90 mmHg. 
        It's often called the 'silent killer' because it typically has no symptoms. Primary hypertension has no identifiable cause, 
        while secondary hypertension results from underlying conditions. Treatment includes lifestyle modifications (diet, exercise, 
        weight loss, sodium restriction) and antihypertensive medications (ACE inhibitors, ARBs, diuretics, calcium channel blockers).""",
        "source": "Cardiology Guidelines 2023",
        "category": "Cardiology"
    },
    {
        "title": "Pneumonia - Clinical Presentation and Management",
        "content": """Pneumonia is an infection that inflames air sacs in one or both lungs, which may fill with fluid or pus. 
        Symptoms include cough with phlegm, fever, chills, and difficulty breathing. Common causes include bacteria (Streptococcus pneumoniae), 
        viruses, and fungi. Diagnosis involves chest X-ray, blood tests, and sputum culture. Treatment depends on the causative organism: 
        bacterial pneumonia is treated with antibiotics, while viral pneumonia may require supportive care or antiviral medications.""",
        "source": "Infectious Disease Handbook",
        "category": "Pulmonology"
    },
    {
        "title": "Migraine Headaches - Pathophysiology and Treatment",
        "content": """Migraine is a primary headache disorder characterized by recurrent attacks of moderate to severe headache, often unilateral 
        and pulsating. Associated symptoms may include nausea, vomiting, photophobia, and phonophobia. Triggers can include stress, 
        certain foods, hormonal changes, and sleep disturbances. Treatment involves acute therapy (triptans, NSAIDs) for active attacks 
        and preventive therapy (beta-blockers, anticonvulsants, CGRP antagonists) for frequent migraines.""",
        "source": "Neurology Clinical Practice",
        "category": "Neurology"
    },
    {
        "title": "Myocardial Infarction - Emergency Management",
        "content": """Myocardial infarction (heart attack) occurs when blood flow to part of the heart muscle is blocked, usually by a blood clot. 
        Symptoms include chest pain, shortness of breath, nausea, and sweating. Immediate treatment is crucial and includes aspirin, 
        oxygen therapy if the patient is hypoxemic, nitroglycerin, and emergency reperfusion therapy (PCI or thrombolytics). Long-term management involves 
        antiplatelet therapy, beta-blockers, ACE inhibitors, and statins to prevent future cardiac events.""",
        "source": "Emergency Medicine Protocols",
        "category": "Emergency Medicine"
    }
]

print(f"Loaded {len(medical_documents)} medical documents")
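To move beyond the hard-coded list, documents can be loaded from a file that carries the same four fields. The sketch below assumes a hypothetical CSV export with `title`, `content`, `source`, and `category` columns; `load_documents_from_csv` is an illustrative helper, not part of any library, and it uses only the standard-library `csv` module. Rows with empty content are skipped so they do not produce empty chunks later.

```python
import csv
import io
from typing import Dict, List, TextIO

# Assumed column layout of the hypothetical CSV export
REQUIRED_FIELDS = ("title", "content", "source", "category")

def load_documents_from_csv(handle: TextIO) -> List[Dict[str, str]]:
    """Read CSV rows into the same dict format as medical_documents."""
    reader = csv.DictReader(handle)
    missing = [f for f in REQUIRED_FIELDS if f not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"CSV is missing columns: {missing}")
    docs = []
    for row in reader:
        # DictReader fills short rows with None, so guard before stripping
        cleaned = {field: (row[field] or "").strip() for field in REQUIRED_FIELDS}
        if cleaned["content"]:  # skip rows with no text to embed
            docs.append(cleaned)
    return docs

# In-memory stand-in for a real file
sample_csv = io.StringIO(
    "title,content,source,category\n"
    "Asthma - Overview,Asthma is a chronic inflammatory disease of the airways.,Example Source,Pulmonology\n"
    "Placeholder,,Example Source,Other\n"
)
docs = load_documents_from_csv(sample_csv)
print(len(docs))  # 1
```

With a real file, open it with `open(path, newline="", encoding="utf-8")` and pass the handle in place of `sample_csv`; `path` here is a placeholder for your own export.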

## 3. Document Processing and Embedding Creation

In [None]:
# Initialize the embedding model
# all-MiniLM-L6-v2 is a small general-purpose sentence-embedding model, not a biomedical one.
# For better medical performance, swap in a biomedical encoder (e.g. "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract");
# plain BERT checkpoints are not trained for sentence similarity, so re-evaluate retrieval quality after swapping.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
)

print("Embedding model initialized")

In [None]:
# Text splitter for processing documents
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
)

# Process documents
processed_docs = []
for doc in medical_documents:
    # Split the content into chunks
    chunks = text_splitter.split_text(doc["content"])
    
    for i, chunk in enumerate(chunks):
        processed_docs.append({
            "content": chunk,
            "metadata": {
                "title": doc["title"],
                "source": doc["source"],
                "category": doc["category"],
                "chunk_id": i
            }
        })

print(f"Processed {len(processed_docs)} document chunks")
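`RecursiveCharacterTextSplitter` tries the listed separators in order and only falls back to finer ones when a piece is still larger than `chunk_size`. The interaction of `chunk_size` and `chunk_overlap` is easiest to see with a simplified fixed-window splitter. The sketch below is not LangChain's recursive algorithm: it ignores separators entirely, and `sliding_window_chunks` is a hypothetical helper used only to illustrate how overlap carries context across chunk boundaries.

```python
from typing import List

def sliding_window_chunks(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Split text into fixed-size character windows; neighbours share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks

text = "".join(chr(97 + i % 26) for i in range(1000))  # 1,000 characters
parts = sliding_window_chunks(text, chunk_size=500, chunk_overlap=50)
print([len(p) for p in parts])  # [500, 500, 100]
```

Each chunk after the first starts 450 characters after the previous one, so a sentence cut at a boundary still appears whole in at least one chunk if it is shorter than the overlap.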

In [None]:
# Create vector database
texts = [doc["content"] for doc in processed_docs]
metadatas = [doc["metadata"] for doc in processed_docs]

# Create FAISS vector store
vectorstore = FAISS.from_texts(
    texts=texts,
    embedding=embedding_model,
    metadatas=metadatas
)

print("Vector database created successfully")
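In current langchain-community releases, `FAISS.from_texts` builds an exact (flat) index that, by default, ranks stored vectors by squared Euclidean distance to the query. Conceptually this is brute-force nearest-neighbour search, sketched below in pure Python with toy 2-D vectors standing in for the 384-dimensional MiniLM embeddings. The helper names are hypothetical. Assuming the embeddings are unit-length, squared distance satisfies d² = 2 − 2·cos, so ranking by L2 distance gives the same order as ranking by cosine similarity.

```python
import math
from typing import List, Sequence, Tuple

def normalize(vector: Sequence[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]

def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance: the score a flat L2 index ranks by (lower is closer)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def exact_top_k(query: Sequence[float], vectors: List[Sequence[float]], k: int = 3) -> List[Tuple[int, float]]:
    """Score every stored vector against the query and keep the k closest."""
    scored = [(i, squared_l2(query, v)) for i, v in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[1])
    return scored[:k]

# Toy unit-length "embeddings"
vectors = [normalize(v) for v in ([1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0])]
query = normalize([1.0, 0.05])
print([index for index, _ in exact_top_k(query, vectors, k=2)])  # [0, 1]
```

A flat index is exact but scans every vector per query; that is fine for a few hundred chunks, while large corpora usually move to approximate FAISS indexes.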

## 4. Medical LLM Setup

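We'll use a small general-purpose text-generation model for this demo. Note that BioBERT and ClinicalBERT are encoder models suited to embeddings and classification, not answer generation; for generation, consider a biomedical generative model such as BioGPT or a stronger instruction-tuned model.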

In [None]:
# Initialize language model (a small conversational model for the demo; replace with a stronger or medical-specific model in production)
model_name = "microsoft/DialoGPT-medium"  # e.g. "microsoft/biogpt" for a biomedical generative model

# Create text generation pipeline
text_generator = pipeline(
    "text-generation",
    model=model_name,
    tokenizer=model_name,
    max_new_tokens=256,      # Limit only the generated tokens; max_length would also count the long RAG prompt
    return_full_text=False,  # Return only the generated answer, not the prompt echoed back
    temperature=0.1,  # Lower temperature for more factual responses
    do_sample=True,
    device=0 if torch.cuda.is_available() else -1
)

# Wrap in LangChain
llm = HuggingFacePipeline(pipeline=text_generator)

print("Language model initialized")

## 5. Medical RAG Pipeline Creation

In [None]:
# Create medical-specific prompt template
medical_prompt_template = """
You are a knowledgeable medical AI assistant. Based on the provided medical context, answer the question accurately and professionally.

IMPORTANT GUIDELINES:
1. Only provide information based on the given context
2. If the context doesn't contain enough information, state this clearly
3. Always recommend consulting healthcare professionals for medical decisions
4. Include relevant citations from the provided sources
5. Be precise and avoid speculation

Context:
{context}

Question: {question}

Medical Response:
"""

medical_prompt = PromptTemplate(
    template=medical_prompt_template,
    input_variables=["context", "question"]
)

print("Medical prompt template created")
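`PromptTemplate` uses Python f-string-style placeholders by default, so the `{...}` fields in the template must match `input_variables`, and any literal braces in the guidelines would need to be doubled (`{{` and `}}`). A quick sanity check can be done with the standard library's `string.Formatter`; `template_fields` is a hypothetical helper, shown here on a shortened copy of the template.

```python
from string import Formatter
from typing import Set

def template_fields(template: str) -> Set[str]:
    """Collect the named {field} placeholders a str.format-style template expects."""
    return {field for _, field, _, _ in Formatter().parse(template) if field}

short_template = "Context:\n{context}\n\nQuestion: {question}\n\nMedical Response:\n"
print(sorted(template_fields(short_template)))  # ['context', 'question']

# Doubled braces are literal text, not placeholders
print(sorted(template_fields("Use {{braces}} literally: {question}")))  # ['question']
```

Running the same check on `medical_prompt_template` and comparing the result with `set(medical_prompt.input_variables)` catches typos such as `{contxt}` before the chain is built.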

In [None]:
# Create retrieval QA chain
medical_qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3}  # Retrieve top 3 most relevant chunks
    ),
    chain_type_kwargs={"prompt": medical_prompt},
    return_source_documents=True
)

print("Medical RAG pipeline created successfully")
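With `chain_type="stuff"`, all `k` retrieved chunks are concatenated into `{context}` in a single prompt. The stuff chain's default document prompt inserts only each chunk's `page_content`, so the model never sees source names, which makes the prompt's citation guideline hard to follow. The sketch below mimics the stuffing step in plain Python with numbered source labels; `stuff_documents` is hypothetical and the `[n] (source)` label format is an assumption, operating on dicts shaped like `processed_docs`.

```python
from typing import Dict, List

def stuff_documents(docs: List[Dict], separator: str = "\n\n") -> str:
    """Label each chunk with a citation number and its source, then join into one context string."""
    formatted = [
        f"[{i}] ({doc['metadata'].get('source', 'Unknown')}) {doc['content']}"
        for i, doc in enumerate(docs, start=1)
    ]
    return separator.join(formatted)

retrieved = [
    {"content": "Hypertension is often called the 'silent killer'.",
     "metadata": {"source": "Cardiology Guidelines 2023"}},
    {"content": "Treatment includes lifestyle modifications.", "metadata": {}},
]
print(stuff_documents(retrieved))
```

The trade-off of "stuff" is that every chunk must fit into the model's context window alongside the instructions and question; with small `k` and 500-character chunks that is manageable.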

## 6. Medical Question Answering Interface

In [None]:
class MedicalRAGSystem:
    def __init__(self, qa_chain):
        self.qa_chain = qa_chain
    
    def ask_medical_question(self, question: str) -> Dict[str, Any]:
        """
        Process a medical question and return a comprehensive answer with sources.
        """
        try:
            # Get response from RAG chain
            result = self.qa_chain.invoke({"query": question})
            
            # Extract answer and sources
            answer = result["result"]
            source_docs = result["source_documents"]
            
            # Format sources
            sources = []
            for doc in source_docs:
                sources.append({
                    "title": doc.metadata.get("title", "Unknown"),
                    "source": doc.metadata.get("source", "Unknown"),
                    "category": doc.metadata.get("category", "Unknown"),
                    # Add an ellipsis only when the preview is actually truncated
                    "content_preview": doc.page_content[:150] + ("..." if len(doc.page_content) > 150 else "")
                })
            
            return {
                "question": question,
                "answer": answer,
                "sources": sources,
                "disclaimer": "This information is for educational purposes only. Always consult with healthcare professionals for medical advice."
            }
            
        except Exception as e:
            return {
                "question": question,
                "answer": f"Error processing question: {str(e)}",
                "sources": [],
                "disclaimer": "Please consult healthcare professionals for medical advice."
            }
    
    def display_response(self, response: Dict[str, Any]):
        """
        Display the medical response in a formatted manner.
        """
        print("=" * 80)
        print(f"QUESTION: {response['question']}")
        print("=" * 80)
        print(f"ANSWER: {response['answer']}")
        print("\n" + "-" * 40 + " SOURCES " + "-" * 40)
        
        for i, source in enumerate(response['sources'], 1):
            print(f"\n[{i}] {source['title']}")
            print(f"    Source: {source['source']}")
            print(f"    Category: {source['category']}")
            print(f"    Preview: {source['content_preview']}")
        
        print("\n" + "!" * 80)
        print(f"DISCLAIMER: {response['disclaimer']}")
        print("!" * 80)

# Initialize the medical RAG system
medical_rag = MedicalRAGSystem(medical_qa_chain)
print("Medical RAG System initialized and ready!")
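Because documents are split into chunks, several of the top-k hits can come from the same document, so the SOURCES list may repeat a title. A small post-processing step can keep only the first (highest-ranked) entry per document. `dedupe_sources` is a hypothetical helper; keying on `(title, source)` is an assumption that works for this knowledge base, where each title appears once.

```python
from typing import Dict, List

def dedupe_sources(sources: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Keep the first entry per (title, source) pair, preserving retrieval order."""
    seen = set()
    unique = []
    for entry in sources:
        key = (entry.get("title"), entry.get("source"))
        if key not in seen:
            seen.add(key)
            unique.append(entry)
    return unique

# Illustrative input shaped like the "sources" list built in ask_medical_question
sources = [
    {"title": "Hypertension - Diagnosis and Treatment", "source": "Cardiology Guidelines 2023"},
    {"title": "Hypertension - Diagnosis and Treatment", "source": "Cardiology Guidelines 2023"},
    {"title": "Myocardial Infarction - Emergency Management", "source": "Emergency Medicine Protocols"},
]
print(len(dedupe_sources(sources)))  # 2
```

It could be applied as `response["sources"] = dedupe_sources(response["sources"])` before calling `display_response`.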

## 7. Testing the Medical RAG System

In [None]:
# Test questions
test_questions = [
    "What are the symptoms of diabetes type 2?",
    "How is hypertension diagnosed and treated?",
    "What should I do if someone is having a heart attack?",
    "What are the common triggers for migraine headaches?",
    "How is pneumonia treated?"
]

# Test each question
for question in test_questions:
    response = medical_rag.ask_medical_question(question)
    medical_rag.display_response(response)
    print("\n" + "=" * 100 + "\n")

## 8. Interactive Medical Q&A Session

In [None]:
# Interactive session
def interactive_medical_qa():
    print("Welcome to the Medical RAG Q&A System!")
    print("Ask medical questions and get evidence-based answers.")
    print("Type 'quit' to exit.\n")
    
    while True:
        question = input("\nEnter your medical question: ")
        
        if question.strip().lower() in ['quit', 'exit', 'q']:
            print("Thank you for using the Medical RAG System. Stay healthy!")
            break
        
        if question.strip():
            response = medical_rag.ask_medical_question(question)
            medical_rag.display_response(response)
        else:
            print("Please enter a valid question.")

# Uncomment the line below to start an interactive session
# interactive_medical_qa()

## 9. Performance Evaluation

In [None]:
def evaluate_retrieval_performance():
    """
    Evaluate the retrieval performance of the medical RAG system.
    """
    evaluation_questions = [
        {
            "question": "What are diabetes symptoms?",
            "expected_category": "Endocrinology"
        },
        {
            "question": "How to treat high blood pressure?",
            "expected_category": "Cardiology"
        },
        {
            "question": "What causes headaches?",
            "expected_category": "Neurology"
        }
    ]
    
    correct_retrievals = 0
    total_questions = len(evaluation_questions)
    
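    for eval_item in evaluation_questions:
        # Query the vector store directly: retrieval quality does not depend on the LLM,
        # so generating a full answer for every evaluation question is unnecessary
        docs = vectorstore.similarity_search(eval_item["question"], k=3)
        
        # Check if any of the top-3 retrieved chunks comes from the expected category
        categories_found = [doc.metadata.get("category", "Unknown") for doc in docs]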
        
        if eval_item["expected_category"] in categories_found:
            correct_retrievals += 1
            print(f"✅ Correct retrieval for: {eval_item['question']}")
        else:
            print(f"❌ Incorrect retrieval for: {eval_item['question']}")
            print(f"   Expected: {eval_item['expected_category']}, Found: {categories_found}")
    
    accuracy = correct_retrievals / total_questions
    print(f"\nRetrieval Accuracy: {accuracy:.2%} ({correct_retrievals}/{total_questions})")

# Run evaluation
evaluate_retrieval_performance()
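The evaluation above counts a question as correct when the expected category appears anywhere in the top 3 retrieved chunks, i.e. a hit rate at k = 3, which ignores rank. Mean reciprocal rank (MRR) additionally rewards placing the right category first. The sketch below computes both over lists of retrieved categories; the function names are hypothetical and the category lists are illustrative inputs, not results produced by this notebook.

```python
from typing import List, Sequence

def hit_rate_at_k(retrieved: List[Sequence[str]], expected: List[str], k: int) -> float:
    """Fraction of questions whose expected label appears among the top-k retrieved labels."""
    if not expected:
        return 0.0
    hits = sum(1 for labels, target in zip(retrieved, expected) if target in list(labels)[:k])
    return hits / len(expected)

def mean_reciprocal_rank(retrieved: List[Sequence[str]], expected: List[str]) -> float:
    """Average of 1/rank of the first correct label (0 when it is never retrieved)."""
    if not expected:
        return 0.0
    total = 0.0
    for labels, target in zip(retrieved, expected):
        for rank, label in enumerate(labels, start=1):
            if label == target:
                total += 1.0 / rank
                break
    return total / len(expected)

retrieved = [
    ["Endocrinology", "Cardiology", "Endocrinology"],
    ["Emergency Medicine", "Cardiology", "Cardiology"],
    ["Pulmonology", "Cardiology", "Endocrinology"],
]
expected = ["Endocrinology", "Cardiology", "Neurology"]
print(hit_rate_at_k(retrieved, expected, k=3))   # 2/3
print(mean_reciprocal_rank(retrieved, expected))  # 0.5
```

Real inputs can be built with `[[d.metadata["category"] for d in vectorstore.similarity_search(q, k=3)] for q in questions]`. With only three questions, either metric is a smoke test rather than a reliable estimate.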

## 10. Advanced Features and Improvements

This section outlines potential improvements for a production medical RAG system.

In [None]:
# Advanced features that could be implemented:

class AdvancedMedicalRAG:
    """
    Advanced medical RAG system with additional features for production use.
    """
    
    def __init__(self):
        self.confidence_threshold = 0.7
        self.safety_keywords = ["emergency", "urgent", "severe", "critical"]
    
    def detect_emergency(self, question: str) -> bool:
        """
        Detect if a question indicates a medical emergency.
        """
        question_lower = question.lower()
        emergency_indicators = [
            "chest pain", "heart attack", "stroke", "difficulty breathing",
            "severe bleeding", "unconscious", "poisoning", "overdose"
        ]
        
        return any(indicator in question_lower for indicator in emergency_indicators)
    
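    def calculate_confidence(self, retrieval_scores: List[float]) -> float:
        """
        Calculate a confidence score from retrieval similarities.

        Expects similarity scores where higher is better (e.g. cosine similarity),
        not raw FAISS L2 distances, where lower is better.
        """
        if not retrieval_scores:
            return 0.0
        
        # Simple confidence calculation (average similarity); can be improved
        return float(np.mean(retrieval_scores))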
    
    def add_safety_warnings(self, question: str, answer: str) -> str:
        """
        Add appropriate safety warnings to medical responses.
        """
        notices = []
        
        if self.detect_emergency(question):
            notices.append("🚨 EMERGENCY: If this is a medical emergency, call emergency services immediately (911 in the US).")
        
        if any(keyword in question.lower() for keyword in ["medication", "drug", "dosage"]):
            notices.append("💊 MEDICATION WARNING: Never change medication dosages without consulting your healthcare provider.")
        
        if notices:
            return "\n".join(notices) + "\n\n" + answer
        
        return answer
    
    def cite_sources_properly(self, sources: List[Dict]) -> str:
        """
        Generate proper medical citations.
        """
        citations = []
        for i, source in enumerate(sources, 1):
            citation = f"[{i}] {source['title']} - {source['source']}"
            citations.append(citation)
        
        return "\n".join(citations)

print("Advanced Medical RAG features outlined")
print("\nProduction improvements would include:")
print("- Medical emergency detection")
print("- Confidence scoring")
print("- Safety warnings")
print("- Proper medical citations")
print("- Integration with medical databases (PubMed, UpToDate)")
print("- Multi-language support")
print("- User authentication and logging")
print("- Regular knowledge base updates")
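`calculate_confidence` expects similarities, but LangChain's FAISS `similarity_search_with_score` returns the index's raw scores, which for the default L2 index are distances (lower is better). Assuming unit-length embeddings and squared L2 distances, the identity ‖a − b‖² = 2 − 2·cos(a, b) converts each distance to a cosine similarity that can be compared with a threshold such as `confidence_threshold = 0.7`. The helpers below are hypothetical, and using the best hit rather than the mean is a design choice: one strongly matching chunk is often enough to answer.

```python
from typing import List

def cosine_from_squared_l2(squared_distance: float) -> float:
    """For unit vectors, ||a - b||^2 = 2 - 2*cos(a, b), so cos = 1 - d^2 / 2."""
    return 1.0 - squared_distance / 2.0

def retrieval_confidence(squared_distances: List[float]) -> float:
    """Use the closest hit's cosine similarity as a simple confidence score."""
    if not squared_distances:
        return 0.0
    return max(cosine_from_squared_l2(d) for d in squared_distances)

scores = [0.6, 0.2, 1.0]  # illustrative squared L2 distances for the top-3 chunks
confidence = retrieval_confidence(scores)
print(round(confidence, 3), confidence >= 0.7)  # 0.9 True
```

If the embeddings are not normalized, this conversion does not hold, and `similarity_search_with_relevance_scores` or a cosine/inner-product index is the safer route.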

## 11. Saving and Loading the System

In [None]:
# Save the vector database for later use
vectorstore.save_local("medical_vectorstore")
print("Vector database saved successfully")

# To load the vector database later:
# loaded_vectorstore = FAISS.load_local("medical_vectorstore", embedding_model, allow_dangerous_deserialization=True)  # loads a pickle: only use with files you created
# print("Vector database loaded successfully")

## Conclusion

This medical RAG pipeline demonstrates:

1. **Specialized medical knowledge processing** - Handling medical documents with appropriate chunking
2. **Swappable embeddings** - A general-purpose embedding model that can be replaced with a biomedical encoder
3. **Safe response generation** - Including disclaimers and safety warnings
4. **Source attribution** - Providing clear citations for medical information
5. **Evaluation framework** - Testing retrieval accuracy

### Next Steps for Production:

1. **Use medical-specific models**: biomedical encoders such as BioBERT or ClinicalBERT for embeddings, and a biomedical generative model such as BioGPT for answers
2. **Integrate real medical databases** such as PubMed, UpToDate, or medical textbooks
3. **Implement robust safety measures** including emergency detection and appropriate warnings
4. **Add human oversight** for sensitive medical queries
5. **Ensure compliance** with medical regulations and privacy laws (HIPAA)
6. **Regular updates** of the knowledge base with latest medical research

**Important Disclaimer:** This system is for educational and research purposes only. Always consult qualified healthcare professionals for medical advice, diagnosis, or treatment.