# RAG Vaccine Information System

This notebook implements a Retrieval-Augmented Generation (RAG) system for vaccine information using Google's Gemini AI. The system:
- Ingests documents from PDFs and websites
- Creates a searchable vector database using embeddings
- Provides an AI agent that answers questions using retrieved context

## 1. Setup and Configuration

Import required libraries and configure API access.

In [None]:
from pathlib import Path
import sys
import numpy as np

from google.adk.agents import Agent
from google.adk.models.google_llm import Gemini
from google.adk.tools import FunctionTool
from google.genai import types
from google.adk.apps.app import App
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner

# Reached only if every import above resolved without raising.
print("✅ Imports loaded")

### Project Path Configuration

Add the project root to Python's import path to access custom modules.

In [None]:
# Ensure the project root (the parent of the "src" directory) is on sys.path
# so that "import src.model" finds the src package under the project root.
# NOTE(review): this assumes the notebook's working directory is a direct
# child of the project root (e.g. a "notebooks/" folder) — confirm.
project_root = Path.cwd().parent
# NOTE(review): src_dir is not referenced by the later cells in this notebook
# (CONFIG rebuilds the same path from project_root directly).
src_dir = project_root / "src"

project_root_path = str(project_root.resolve())
# Prepend rather than append so the local "src" package takes precedence;
# the membership test keeps re-running this cell from adding duplicates.
if project_root_path not in sys.path:
    sys.path.insert(0, project_root_path)

from src.model import Intensity, SentimentOutput
from src.model.rag_output import RagOutput
from src.config import load_env_variables, get_env_variable

In [None]:
# Populate the process environment via the project's config helper
# (presumably from a .env file — see src.config to confirm).
load_env_variables()

GOOGLE_API_KEY = get_env_variable("GOOGLE_API_KEY")
# Only a fixed message is printed; the key itself is never echoed.
# NOTE(review): printed unconditionally — whether get_env_variable raises or
# returns None for a missing key is defined in src.config; confirm.
print(f"✅ API key loaded")

### API Configuration

Configure the retry policy for Gemini API calls (the API key was already loaded from the environment in the previous cell).

In [None]:
# Retry policy handed to the Gemini model in the agent configuration.
# Only rate limiting (429) and transient server errors (500/503/504) are
# retried; other HTTP errors fail immediately.
# NOTE(review): attempts=3 is believed to count the original request, and
# initial_delay to be in seconds — confirm against google.genai docs.
retry_config = types.HttpRetryOptions(
    attempts=3,
    initial_delay=1,
    http_status_codes=[429, 500, 503, 504]
)

## 2. Data Structures

Define the core data structures used throughout the RAG pipeline.

In [None]:
import os
import time
import pickle
import heapq
import numpy as np
import requests
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pypdf import PdfReader
from dataclasses import dataclass
from pathlib import Path as FilePath
from google import genai

@dataclass
class DocumentChunk:
    """Represents a chunk of text from a document with metadata.

    Instances are pickled to the index cache by ``save_index_to_cache``;
    changing the fields may make existing cache files stale or incompatible.
    """
    id: int          # sequential id; PDF and web chunks share one numbering per build
    content: str     # chunk text (words joined by single spaces)
    source: str      # URL or file path
    doc_type: str    # "web" or "pdf"

## 3. Document Processing

Functions to load, parse, and chunk documents from various sources.

In [None]:
def chunk_text(text: str,
               source: str,
               doc_type: str,
               chunk_size: int = 800,
               chunk_overlap: int = 200,
               start_id: int = 0) -> list[DocumentChunk]:
    """
    Split text into overlapping word-based chunks for RAG.

    Words come from ``str.split()``, so any run of whitespace (including
    newlines) becomes a single space in the chunk content. Consecutive
    chunks start ``chunk_size - chunk_overlap`` words apart. Once a chunk
    reaches the end of the text, no further chunk is emitted. This avoids a
    trailing chunk that would be fully contained in the previous one.

    Args:
        text: The text to chunk
        source: Source identifier (URL or file path)
        doc_type: Type of document ("web" or "pdf")
        chunk_size: Number of words per chunk (must be > 0)
        chunk_overlap: Number of overlapping words between chunks
            (must satisfy 0 <= chunk_overlap < chunk_size)
        start_id: ID of the first chunk; later chunks get consecutive IDs

    Returns:
        List of DocumentChunk objects (empty if ``text`` contains no words)

    Raises:
        ValueError: If ``chunk_size``/``chunk_overlap`` would stop the window
            from advancing (which previously caused an infinite loop).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError(
            f"chunk_overlap must be in [0, chunk_size), got {chunk_overlap} "
            f"with chunk_size={chunk_size}"
        )

    words = text.split()
    step = chunk_size - chunk_overlap
    chunks: list[DocumentChunk] = []

    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(
            DocumentChunk(
                id=start_id + len(chunks),
                content=" ".join(words[start:end]),
                source=source,
                doc_type=doc_type,
            )
        )
        # This window already covers the last word; any later window would
        # be a strict subset of it and only duplicate content in the index.
        if end >= len(words):
            break

    return chunks

### PDF Processing

Load and chunk PDF documents from a folder.

In [None]:
def load_pdfs_from_folder(folder_path: str) -> list[DocumentChunk]:
    """
    Read every PDF in a folder and convert its text into chunks.

    Files are processed in sorted filename order so that chunk IDs are
    reproducible across runs and platforms (``os.listdir`` order is
    arbitrary). A PDF that fails to parse is reported and skipped, and the
    remaining files are still processed.

    Args:
        folder_path: Path to the folder containing PDF files. Files are
            matched by a case-insensitive ``.pdf`` extension, and
            subfolders are not searched.

    Returns:
        List of DocumentChunk objects from all PDFs, with consecutive IDs
        starting at 0. Empty if ``folder_path`` is not an existing directory.
    """
    all_chunks: list[DocumentChunk] = []

    # isdir (not exists): a regular file at this path would make listdir raise.
    if not os.path.isdir(folder_path):
        print(f"[PDF] Folder not found: {folder_path}")
        return []

    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(".pdf"):
            continue

        full_path = os.path.join(folder_path, filename)
        print(f"[PDF] Loading: {full_path}")

        try:
            reader = PdfReader(full_path)
            # `or ""` keeps the join safe for pages that yield no text.
            full_text = "\n".join(page.extract_text() or "" for page in reader.pages)
            chunks = chunk_text(
                text=full_text,
                source=full_path,
                doc_type="pdf",
                # IDs are contiguous from 0, so the next free ID is the count so far.
                start_id=len(all_chunks),
            )
        except Exception as e:
            # Best-effort ingestion: one unreadable PDF must not abort the build.
            print(f"[PDF] ERROR on {full_path}: {e}")
            continue

        all_chunks.extend(chunks)

    print(f"[PDF] Total chunks from PDFs: {len(all_chunks)}")
    return all_chunks

### Web Scraping

Crawl websites and extract text content.

In [None]:
def is_same_domain(url: str, root_netloc: str) -> bool:
    """
    Check whether ``url`` has the same network location as the crawl root.

    The whole netloc (host plus any port/userinfo) is compared
    case-insensitively, because host names are case-insensitive
    (RFC 3986 §3.2.2). For example, ``https://Example.org/a`` matches the root
    ``example.org``.

    Args:
        url: Absolute URL to test.
        root_netloc: ``urlparse(root_url).netloc`` of the crawl root.

    Returns:
        True if the netlocs match. False if they differ or if ``url`` cannot
        be parsed (e.g. a malformed IPv6 literal).
    """
    try:
        return urlparse(url).netloc.lower() == root_netloc.lower()
    except ValueError:
        # urlparse raises ValueError for malformed URLs; treat them as foreign.
        return False


def extract_text_from_html(html: str) -> str:
    """
    Return the readable text of an HTML document as a single line.

    ``<script>``, ``<style>`` and ``<noscript>`` elements are removed before
    the text is extracted. Every run of whitespace in the result is collapsed
    to a single space.
    """
    document = BeautifulSoup(html, "html.parser")

    non_content_tags = ["script", "style", "noscript"]
    for element in document.find_all(non_content_tags):
        element.decompose()

    raw_text = document.get_text(separator=" ")
    return " ".join(raw_text.split())


def _extract_same_domain_links(html: str, base_url: str, root_netloc: str) -> list[str]:
    """Return absolute, fragment-free http(s) links in ``html`` that stay on ``root_netloc``."""
    from urllib.parse import urldefrag

    soup = BeautifulSoup(html, "html.parser")
    links: list[str] = []
    for a in soup.find_all("a", href=True):
        try:
            # Drop "#fragment": it points inside the same document, so
            # "page" and "page#top" must not be downloaded twice.
            full_url = urldefrag(urljoin(base_url, a["href"])).url
            scheme = urlparse(full_url).scheme
        except ValueError:
            # Malformed href (e.g. broken IPv6 literal): skip only this link.
            continue

        # Keep only HTTP/HTTPS links from same domain
        if scheme not in ("http", "https"):
            continue
        if not is_same_domain(full_url, root_netloc):
            continue
        links.append(full_url)
    return links


def crawl_website(root_url: str,
                  max_pages: int = 100,
                  max_depth: int = 3,
                  start_id: int = 0) -> list[DocumentChunk]:
    """
    Crawl a website and extract text chunks from all pages.

    Uses BFS to follow internal links up to a maximum depth. Only pages with
    the same netloc as the root URL are crawled. URL fragments are ignored,
    so ``page`` and ``page#section`` are fetched once. HTTP error responses
    (4xx/5xx) are reported and not indexed. Links beyond ``max_depth`` are
    never queued, so they do not consume the ``max_pages`` budget.

    Args:
        root_url: Starting URL for the crawl
        max_pages: Maximum number of pages to request (including ones that
            fail or are not HTML)
        max_depth: Maximum link depth from root (root is depth 0)
        start_id: Starting ID for chunks

    Returns:
        List of DocumentChunk objects from all crawled pages, with
        consecutive IDs starting at ``start_id``.
    """
    from collections import deque
    from urllib.parse import urldefrag

    root_url = urldefrag(root_url).url
    root_netloc = urlparse(root_url).netloc

    visited: set[str] = set()          # URLs actually requested
    queued: set[str] = {root_url}      # URLs ever enqueued, so each is queued once
    to_visit: deque[tuple[str, int]] = deque([(root_url, 0)])
    all_chunks: list[DocumentChunk] = []

    # Context manager closes the session's pooled connections when done.
    with requests.Session() as session:
        session.headers.update({"User-Agent": "rag-bot/1.0"})

        while to_visit and len(visited) < max_pages:
            url, depth = to_visit.popleft()
            if depth > max_depth:
                # Only possible for the root when max_depth < 0; children
                # are never queued past max_depth.
                continue
            visited.add(url)

            try:
                print(f"[WEB] Downloading ({depth}): {url}")
                resp = session.get(url, timeout=10)
                # Never index 404/500 error pages as knowledge-base content.
                resp.raise_for_status()
                if "text/html" not in resp.headers.get("Content-Type", ""):
                    continue

                text = extract_text_from_html(resp.text)
                if text.strip():
                    chunks = chunk_text(
                        text=text,
                        source=url,
                        doc_type="web",
                        start_id=start_id + len(all_chunks),
                    )
                    all_chunks.extend(chunks)

                if depth < max_depth:
                    # resp.url is the post-redirect URL: the correct base for
                    # resolving this page's relative links.
                    for link in _extract_same_domain_links(resp.text, resp.url, root_netloc):
                        if link not in queued:
                            queued.add(link)
                            to_visit.append((link, depth + 1))

            except Exception as e:
                # Best-effort crawl: report and continue with the next URL.
                print(f"[WEB] ERROR on {url}: {e}")
            finally:
                # Polite crawling delay after every request, including
                # skipped non-HTML responses and failures.
                time.sleep(0.2)

    print(f"[WEB] Total pages visited: {len(visited)}")
    print(f"[WEB] Total chunks from site: {len(all_chunks)}")
    return all_chunks

## 4. Embedding and Retrieval

Functions for creating embeddings and retrieving similar chunks.

In [None]:
# Initialize Gemini Client for embeddings
# (this module-level client is what embed_texts calls).
client = genai.Client(api_key=GOOGLE_API_KEY)

# Cache configuration
# NOTE(review): "../cache" is resolved against the current working
# directory, not project_root — running from a different directory reads
# and writes a different cache.
CACHE_DIR = FilePath("../cache")
EMBEDDINGS_CACHE_FILE = CACHE_DIR / "embeddings.pkl"
CHUNKS_CACHE_FILE = CACHE_DIR / "chunks.pkl"

def save_index_to_cache(embeddings: np.ndarray, chunks: list[DocumentChunk]):
    """
    Persist the vector index (embeddings plus chunks) to the on-disk cache.

    Each object is pickled to its own file under ``CACHE_DIR``. The directory
    is created first if it does not exist yet.

    Args:
        embeddings: Numpy array of embedding vectors
        chunks: List of DocumentChunk objects
    """
    CACHE_DIR.mkdir(exist_ok=True)

    targets = ((EMBEDDINGS_CACHE_FILE, embeddings), (CHUNKS_CACHE_FILE, chunks))
    for cache_file, payload in targets:
        with open(cache_file, 'wb') as handle:
            pickle.dump(payload, handle)

    print(f"✅ Index cached to {CACHE_DIR}")

def load_index_from_cache() -> tuple[np.ndarray, list[DocumentChunk]]:
    """
    Load embeddings and chunks from disk cache.

    The cache is accepted only if it is internally consistent: the
    embeddings must be a 2-D numpy array with exactly one row per chunk.
    Anything else is treated as a cache miss, so the caller rebuilds the
    index instead of silently pairing vectors with the wrong chunks. This
    covers missing files, unreadable pickles, and mismatched sizes such as
    after an interrupted save.

    Security: ``pickle.load`` can execute arbitrary code. These files are
    expected to be written by ``save_index_to_cache``; never point the cache
    paths at untrusted data.

    Returns:
        Tuple of (embeddings array, chunks list) or (empty array, empty list) if not cached
    """
    if not EMBEDDINGS_CACHE_FILE.exists() or not CHUNKS_CACHE_FILE.exists():
        return np.array([]), []

    try:
        with open(EMBEDDINGS_CACHE_FILE, 'rb') as f:
            embeddings = pickle.load(f)

        with open(CHUNKS_CACHE_FILE, 'rb') as f:
            chunks = pickle.load(f)
    except Exception as e:
        print(f"⚠️ Error loading cache: {e}")
        return np.array([]), []

    if not (isinstance(embeddings, np.ndarray)
            and isinstance(chunks, list)
            and embeddings.ndim == 2
            and embeddings.shape[0] == len(chunks)):
        print("⚠️ Cache is inconsistent (embeddings do not match chunks); ignoring it")
        return np.array([]), []

    print(f"✅ Index loaded from cache: {len(chunks)} chunks")
    return embeddings, chunks

def embed_texts(texts: list[str],
                batch_size: int = 64,
                model: str = "text-embedding-004") -> np.ndarray:
    """
    Generate embeddings for a list of texts using Google Gemini API.

    Texts are sent in batches of ``batch_size`` through the module-level
    ``client``. Output rows are in the same order as ``texts``.

    Args:
        texts: List of text strings to embed
        batch_size: Number of texts to embed per API call (must be > 0)
        model: Gemini embedding model name. Documents and queries must be
            embedded with the same model for their similarities to be
            comparable.

    Returns:
        Float32 numpy array of shape (n, d), where n is the number of texts
        and d is the embedding dimension. Shape is (0, 0) for empty input.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the API returns
            a different number of embeddings than texts sent.
        Exception: Any error from the API call is printed and re-raised.
    """
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_vectors = []

    try:
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            print(f"[EMB] Batch {i}–{i+len(batch)-1} of {len(texts)}")

            # Gemini embedding call
            response = client.models.embed_content(
                model=model,
                contents=batch
            )
            # Extract embedding vectors from response
            vectors = [item.values for item in response.embeddings]
            # A short or long response would silently misalign the rows with
            # the chunks they are supposed to represent.
            if len(vectors) != len(batch):
                raise ValueError(
                    f"Expected {len(batch)} embeddings, got {len(vectors)}"
                )
            all_vectors.extend(vectors)

        return np.array(all_vectors, dtype=np.float32)

    except Exception as e:
        print(f"❌ Error during embedding: {e}")
        raise

def build_vector_index(chunks: list[DocumentChunk]) -> tuple[np.ndarray, list[DocumentChunk]]:
    """
    Embed the content of every chunk to produce the searchable index.

    Args:
        chunks: List of DocumentChunk objects

    Returns:
        Tuple of (embeddings array, chunks list). The array comes from
        ``embed_texts`` over the chunk contents in list order. The list is
        the same object that was passed in.
    """
    print("[INDEX] Calculating embeddings for all chunks...")
    embeddings = embed_texts([chunk.content for chunk in chunks])
    print(f"[INDEX] Embeddings shape: {embeddings.shape}")
    return embeddings, chunks

In [None]:
def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """
    Calculate cosine similarity between the rows of ``a`` and one or more query vectors.

    A small epsilon (1e-10) is added to every norm, so an all-zero vector
    scores 0 instead of producing a division-by-zero NaN.

    Args:
        a: Matrix of shape (n, d) containing n vectors of dimension d
        b: Either a single vector of shape (d,) or a matrix of shape (m, d)
            holding m query vectors

    Returns:
        Array of shape (n,) when ``b`` is 1-D, or (n, m) when ``b`` is 2-D,
        holding similarity scores. An empty ``a`` yields ``np.array([])``.
    """
    if a.size == 0:
        return np.array([])

    b = np.asarray(b)
    a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-10)
    if b.ndim == 1:
        b_norm = b / (np.linalg.norm(b) + 1e-10)
    else:
        # Normalize each query row independently.
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-10)
    # For 1-D b, .T is a no-op, so this is the plain (n, d) · (d,) product.
    return np.dot(a_norm, b_norm.T)


def retrieve_top_k(query: str,
                   embeddings: np.ndarray,
                   chunks: list[DocumentChunk],
                   k: int = 5) -> list[DocumentChunk]:
    """
    Return the ``k`` chunks whose embeddings are most similar to ``query``.

    The query is embedded with ``embed_texts`` and scored against every row
    of ``embeddings`` by cosine similarity. Equal scores keep index order.

    Args:
        query: Query string
        embeddings: Matrix of chunk embeddings (n, d)
        chunks: List of DocumentChunk objects corresponding to embeddings
        k: Number of top results to return (capped at len(chunks))

    Returns:
        Up to k DocumentChunk objects, most similar first. Returns an empty
        list when the index is empty.
    """
    index_is_empty = embeddings.size == 0 or not chunks
    if index_is_empty:
        return []

    query_vector = embed_texts([query])[0]
    scores = cosine_similarity_matrix(embeddings, query_vector)

    limit = min(k, len(chunks))
    best_positions = heapq.nlargest(limit, range(len(scores)), key=scores.__getitem__)
    return [chunks[position] for position in best_positions]

## 5. Knowledge Base Initialization

Build or load the vector index from document sources.

In [None]:
# Configuration
CONFIG = {
    "pdf_folder": str((project_root / "src" / "Doc_vaccini").resolve()),
    "root_url": "https://www.serviziterritoriali-asstmilano.it/servizi/vaccinazioni/",
    "max_pages": 10,
    "max_depth": 2,
    "use_cache": True
}

# Initialize global index
global_embeddings = np.array([])
global_chunks = []

# Load from cache or build fresh
if CONFIG["use_cache"]:
    global_embeddings, global_chunks = load_index_from_cache()

if global_embeddings.size == 0:
    print("Building knowledge base...")
    
    # Load and process documents
    pdf_chunks = load_pdfs_from_folder(CONFIG["pdf_folder"])
    web_start_id = pdf_chunks[-1].id + 1 if pdf_chunks else 0
    web_chunks = crawl_website(
        CONFIG["root_url"], 
        max_pages=CONFIG["max_pages"], 
        max_depth=CONFIG["max_depth"], 
        start_id=web_start_id
    )
    
    # Build index
    all_chunks = pdf_chunks + web_chunks
    if all_chunks:
        global_embeddings, global_chunks = build_vector_index(all_chunks)
        save_index_to_cache(global_embeddings, global_chunks)
        print(f"✅ Index built: {len(global_chunks)} chunks")
    else:
        print("⚠️ No content found")
else:
    print(f"✅ Loaded from cache: {len(global_chunks)} chunks")

In [None]:
def retrieve_vaccine_info(query: str) -> str:
    """
    Retrieves vaccine information from the knowledge base.

    Searches the indexed official documents and web pages for the passages
    most relevant to the query. This function is exposed to the AI agent as
    a tool.

    Args:
        query: User's question about vaccines

    Returns:
        Up to five relevant passages, each preceded by a "[SOURCE: ...]"
        citation line and separated by "---". If nothing can be returned,
        a short message explaining why.
    """
    if not global_chunks:
        return "Error: Knowledge base not initialized. Run ingestion cell first."

    try:
        hits = retrieve_top_k(query, global_embeddings, global_chunks, k=5)
        if not hits:
            return "No relevant information found."

        def _cite(chunk):
            # PDFs are cited by file name only; web pages by their full URL.
            label = FilePath(chunk.source).name if chunk.doc_type == "pdf" else chunk.source
            return f"[SOURCE: {label}]\n{chunk.content}"

        return "\n\n---\n\n".join(_cite(chunk) for chunk in hits)
    except Exception as e:
        # Errors are returned as text so the agent can report them instead
        # of the tool call crashing.
        return f"Error: {str(e)}"

# Create tool wrapper for the agent
# NOTE(review): ADK's FunctionTool is understood to derive the tool's name,
# parameters and description from the function signature and docstring —
# keep retrieve_vaccine_info's docstring meaningful to the model.
rag_tool = FunctionTool(retrieve_vaccine_info)
# Status only: the tool is created either way, but returns an error string
# while global_chunks is empty.
print(f"✅ Tool ready" if global_chunks else "⚠️ Run ingestion first")

## 6. RAG Tool

The retrieval function `retrieve_vaccine_info`, defined and wrapped as a `FunctionTool` in the cell above, is the tool the agent uses to access the knowledge base.

## 7. Agent Configuration

Configure the RAG agent with the retrieval tool. The structured output schema (`RagOutput`) is currently disabled, so the agent replies in free text.

In [None]:
# System instruction for the agent. Step 3 spells out the fallback answer
# explicitly because no structured output format is configured for this
# agent (output_schema is not set below).
prompt = """You are a helpful assistant for vaccine information.
You have access to a knowledge base containing official documents and web pages about vaccinations.

When the user asks a question:
1. Use the `retrieve_vaccine_info` tool to find relevant information.
2. Answer the question based ONLY on the information returned by the tool.
3. If the tool returns no information, returns an error, or the information is not pertinent, clearly say that the answer could not be found in the knowledge base instead of guessing.
4. Always cite the sources provided in the tool output.
5. Be concise but thorough in your responses.
"""

rag_agent = Agent(
    name="RAG_Vaccine_Informer",
    model=Gemini(
        model="gemini-2.5-flash-lite",
        retry_options=retry_config
    ),
    instruction=prompt,
    tools=[rag_tool],
    # The agent's final response text is saved to session state under this key.
    output_key="rag_output",
    # NOTE: structured output (output_schema=RagOutput) is deliberately not
    # set. In ADK, an agent with output_schema cannot use tools, which would
    # disable retrieve_vaccine_info. To obtain a RagOutput, chain a second,
    # tool-less agent that formats this agent's answer.
)

print("✅ RAG Agent configured")

## 8. Application Setup

Create the application and runner instances.

In [None]:
# Create session service for managing conversation state
session_service = InMemorySessionService()

# Create application with RAG agent as root
application = App(
    name="VaccineInfoRAG",
    root_agent=rag_agent
)

# Create runner to execute queries
runner = Runner(
    app=application, 
    session_service=session_service
)

print("✅ Application ready")

## 9. Testing

Test the RAG system with sample queries.

### Single Query Test

In [None]:
response = await runner.run_debug("Tell me the policy for pregnant woment vaccination")
