# Retrieval Augmented Generation (RAG) Pipeline

In [None]:
# Suppress every warning category for the rest of the session; this is set
# before the third-party imports run.
import warnings
warnings.filterwarnings('ignore')

import wikipedia
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter  
from langchain_core.documents import Document  
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from openai import OpenAI

In [None]:
# OpenAI API key - read from the OPENAI_API_KEY environment variable so the
# secret does not have to be saved inside the notebook. The placeholder is
# only a fallback: replace it (or export the variable) before running.
import os

OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY') or 'your-openai-api-key-here'
client = OpenAI(api_key=OPENAI_API_KEY)

In [None]:
# Function to load Wikipedia documents based on a list of topics
def load_wikipedia_docs(topics):
    """Fetch one Wikipedia article per topic and wrap each in a ``Document``.

    Loading is best-effort: a topic that cannot be fetched is reported on
    stdout and skipped, so one bad topic never aborts the whole batch.

    Args:
        topics: Iterable of Wikipedia page titles.

    Returns:
        A list of ``Document`` objects whose ``page_content`` is the article
        text and whose metadata carries the page ``source`` URL and ``title``.
    """
    documents = []

    for topic in topics:
        try:
            try:
                page = wikipedia.page(topic, auto_suggest=False)
            except wikipedia.exceptions.DisambiguationError as ambiguous:
                # Fall back to the first suggested page. This retry lives
                # inside the outer ``try`` so that a failure here (or an empty
                # option list) is reported below instead of crashing the loop.
                if not ambiguous.options:
                    raise
                page = wikipedia.page(ambiguous.options[0], auto_suggest=False)

            doc = Document(
                page_content=page.content,
                metadata={"source": page.url, "title": page.title}
            )
        except Exception as err:
            print(f"Could not load {topic}: {err}")
            continue

        documents.append(doc)
        print(f"Loaded: {page.title}")

    return documents

In [None]:
# Health and medicine page titles to pull from Wikipedia
wiki_topics = [
    "Vaccination", "Antibiotic", "Vitamin D", "Heart disease",
    "Immune system", "Covid-19", "Cancer", "Tuberculosis",
    "Obesity", "Cholesterol", "Stroke", "Arthritis",
]

# Fetch the articles and report how many came back
print("Loading Wikipedia articles...\n")
wiki_docs = load_wikipedia_docs(wiki_topics)
print(f"\nLoaded {len(wiki_docs)} Wikipedia articles")

In [None]:
# Fetch a list of web pages through LangChain's WebBaseLoader
def load_web_docs(urls):
    """Load every URL with ``WebBaseLoader`` and return the combined documents.

    A URL that fails to load is reported on stdout and skipped; the remaining
    URLs are still processed.
    """
    collected = []

    for link in urls:
        try:
            collected += WebBaseLoader(link).load()
            print(f"Loaded: {link}")
        except Exception as err:
            print(f"Could not load {link}: {err}")

    return collected

In [None]:
# WHO fact sheets and health-topic pages to load from the web
web_urls = [
    "https://www.who.int/news-room/fact-sheets/detail/malaria",
    "https://www.who.int/news-room/fact-sheets/detail/diabetes",
    "https://www.who.int/news-room/fact-sheets/detail/hypertension",
    "https://www.who.int/health-topics/hepatitis#tab=tab_1",
    "https://www.who.int/health-topics/ebola#tab=tab_1",
    "https://www.who.int/health-topics/nutrition#tab=tab_1",
    "https://www.who.int/health-topics/physical-activity#tab=tab_1",
    "https://www.who.int/health-topics/self-care#tab=tab_1",
]

print("Loading web articles...\n")
web_docs = load_web_docs(web_urls)
print(f"\nLoaded {len(web_docs)} web articles")

# Merge the Wikipedia and web documents into one corpus
all_documents = [*wiki_docs, *web_docs]
print(f"Total documents loaded: {len(all_documents)}")

## Chunking

In [None]:
# Break the documents into 500-character chunks (100 characters of overlap)
# so each piece is small enough for embedding and retrieval
text_splitter = RecursiveCharacterTextSplitter(
    # Separators are tried in order: paragraphs, lines, sentences, words, characters
    separators=["\n\n", "\n", ". ", " ", ""],
    chunk_size=500,
    chunk_overlap=100,
)

chunks = text_splitter.split_documents(all_documents)
print(f"Split {len(all_documents)} documents into {len(chunks)} chunks")

## Creating Embeddings and Building the FAISS Index


In [None]:
# Sentence-transformer model used to embed the text
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Plain text of every chunk, in chunk order
chunk_texts = [piece.page_content for piece in chunks]

# One embedding vector per chunk
embeddings = embedding_model.encode(
    chunk_texts,
    show_progress_bar=False,
)
print(f"Created {len(embeddings)} embeddings of dimension {embeddings.shape[1]}")

In [None]:
# Build a flat L2 FAISS index sized to the embedding width and load every
# chunk embedding into it as float32
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.astype(np.float32))

print(f"FAISS index built with {index.ntotal} vectors")

## Building the RAG Pipeline

In [None]:
# function to retrieve relevant chunks based on a query
def retrieve(query, k=3):
    """Return up to ``k`` indexed chunks closest to ``query``.

    Args:
        query: Question text to embed and search for.
        k: Maximum number of chunks to return.

    Returns:
        A list of dicts, in the order FAISS returns them, each with the keys
        ``content`` (chunk text), ``source`` (chunk metadata ``source`` or
        ``"Unknown"``) and ``distance`` (index distance as a Python float).
        The list is shorter than ``k`` when the index holds fewer vectors and
        empty when ``k`` is not positive or the index is empty.
    """
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with index -1, which would silently select chunks[-1].
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    # Embed the query
    query_embedding = embedding_model.encode([query])

    # Search the FAISS index
    distances, indices = index.search(query_embedding.astype('float32'), k)

    # Get the relevant chunks, skipping any -1 padding entries defensively
    return [
        {
            "content": chunks[idx].page_content,
            "source": chunks[idx].metadata.get("source", "Unknown"),
            "distance": float(dist),
        }
        for dist, idx in zip(distances[0], indices[0])
        if idx >= 0
    ]


# Build the LLM prompt from the question and the retrieved chunks
def enhance_prompt(query, retrieved_chunks):
    """Return a prompt that embeds the retrieved chunk texts as context.

    The ``content`` of each chunk is joined with blank lines and placed under
    a ``CONTEXT:`` header, followed by the question and an ``ANSWER:`` cue.
    """
    instruction = "Use the following context to answer the question and if the context doesn't contain enough information to answer, say so."
    context_block = "\n\n".join(item["content"] for item in retrieved_chunks)

    sections = [
        instruction,
        "",
        "CONTEXT:",
        context_block,
        f"QUESTION: {query}",
        "ANSWER:",
    ]
    return "\n".join(sections)

# Function to generate an answer using OpenAI
def generate(prompt, model="gpt-4.1", temperature=0.3):
    """Send ``prompt`` to the OpenAI chat completions API and return the reply.

    Args:
        prompt: Full user prompt (context plus question).
        model: Chat model name; defaults to the model this notebook was
            written against.
        temperature: Sampling temperature passed to the API.

    Returns:
        The text of the first choice, or an empty string when the API returns
        a message without text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful medical information assistant. Answer questions based only on the provided context."},
            {"role": "user", "content": prompt}],
        temperature=temperature)
    # ``content`` is optional in the API response; normalise None to "" so
    # callers always receive a string.
    return response.choices[0].message.content or ""


# Main RAG query function
def rag_query(question, k=3, show_sources=True):
    """Answer ``question`` with retrieval-augmented generation and print the run.

    Args:
        question: The user's question.
        k: Number of chunks to retrieve as context.
        show_sources: When true, print each retrieved chunk's source and
            distance before the answer.

    Returns:
        The generated answer text.
    """
    print(f"Question: {question}\n")

    # Step 1: Retrieve the relevant chunks
    retrieved = retrieve(question, k=k)

    if show_sources:
        print("Retrieved Sources:")
        for i, chunk in enumerate(retrieved, 1):
            source = str(chunk['source'])
            # Only add the ellipsis when the source really was cut off, so a
            # short URL is not shown as if it had been truncated.
            shown = source if len(source) <= 80 else f"{source[:80]}..."
            print(f"  [{i}] {shown} (distance: {chunk['distance']:.3f})")
        print()

    # Step 2: Augment the prompt with context
    enhanced_prompt = enhance_prompt(question, retrieved)

    # Step 3: Generate answer
    answer = generate(enhanced_prompt)

    print(f"Answer: {answer}\n")

    return answer

## Testing the RAG System


In [None]:
# Questions from the wiki topics and web link topics for testing the RAG system.
# "Covid-19" uses a plain ASCII hyphen: the previous text was a mis-encoded
# non-breaking hyphen, which sent garbage characters to the embedder and the LLM.
test_questions = [
    "Why is vaccination important for public health?",
    "What are the common side effects of antibiotics?",
    "How does vitamin D deficiency affect the body?",
    "What are the major risk factors associated with heart disease?",
    "How does Covid-19 primarily spread between people?",
    "What are the main symptoms of malaria?",
    "What is the recommended diet for preventing type 2 diabetes?",
    "How can high blood pressure be controlled through lifestyle?",
    "What are the typical signs of hepatitis infection?",
    "Why is regular physical activity beneficial for overall health?"
]

In [None]:
# Banner followed by a blank line, then run every test question through the pipeline
print("TESTING THE RAG SYSTEM - HEALTH & MEDICINE Q&A", end="\n\n")

for question in test_questions:
    rag_query(question, k=3)