# RAG Pipeline - Semantic Search Interface

**Notebook ID:** `05_rag_pipeline_v1`  
**Description:** Interactive semantic search interface using Gradio for querying ChromaDB vector store

This notebook provides:
- Simple semantic search interface using Gradio
- Query ChromaDB instance created in `04_vector_store_v1.ipynb`
- Display results with content preview (500 characters) and similarity scores
- Configurable number of results (3-10 chunks)

---


In [8]:
# CELL_ID: 05_rag_pipeline_v1_imports
# ============================================================================
# IMPORT DEPENDENCIES
# ============================================================================

# %pip install gradio chromadb requests --quiet

import json
import chromadb
from pathlib import Path
from typing import List, Dict, Any, Optional
from chromadb.config import Settings


In [9]:
# CELL_ID: 05_rag_pipeline_v1_jina_embedding
# ============================================================================
# JINA EMBEDDING FUNCTION (REUSED FROM 04_vector_store_v1)
# ============================================================================

import requests
import time
from typing import List

class JinaEmbeddingFunction:
    """
    Custom embedding function for ChromaDB using Jina API.
    Implements the interface expected by ChromaDB's embedding_function parameter.
    """
    
    def __init__(
        self,
        api_key: str = "jina_dc47aa711aa944799688c5c6f82215595xhVihTxOWOBR7-ZK0LXOM3g3oxY",
        model: str = "jina-embeddings-v4",
        task: str = "text-matching",
        api_url: str = "https://api.jina.ai/v1/embeddings",
        batch_size: int = 10,
        max_retries: int = 3
    ):
        self.api_key = api_key
        self.model = model
        self.task = task
        self.api_url = api_url
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {api_key}'
        }
    
    def name(self) -> str:
        """
        Return the name of the embedding function.
        Required by ChromaDB for embedding function validation.
        """
        return "jina-embeddings-v4"
    
    def __call__(self, input):
        """Generate embeddings for input text(s)."""
        if isinstance(input, str):
            texts = [input]
        else:
            texts = input
        
        if not texts:
            return []
        
        all_embeddings = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            batch_embeddings = self._embed_batch(batch)
            all_embeddings.extend(batch_embeddings)
        
        return all_embeddings
    
    def _embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts using Jina API."""
        data = {
            "model": self.model,
            "task": self.task,
            "input": [{"text": text} for text in texts]
        }
        
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=data,
                    timeout=60
                )
                response.raise_for_status()
                
                result = response.json()
                embeddings = []
                if 'data' in result:
                    for item in result['data']:
                        if 'embedding' in item:
                            embeddings.append(item['embedding'])
                    return embeddings
                else:
                    raise ValueError(f"Unexpected API response format: {result}")
                    
            except requests.exceptions.RequestException as e:
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"‚ö† API request failed (attempt {attempt + 1}/{self.max_retries}), retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    raise Exception(f"Failed to get embeddings after {self.max_retries} attempts: {e}")
        
        return []

print("‚úì JinaEmbeddingFunction class loaded")


‚úì JinaEmbeddingFunction class loaded


In [10]:
# CELL_ID: 05_rag_pipeline_v1_chromadb_reader
# ============================================================================
# CHROMADB READER CLASS (REUSED FROM 04_vector_store_v1)
# ============================================================================

class ChromaDBReader:
    """
    Handles reading/searching from Chroma DB with Jina embedding function.
    Adapted from ChromaDBWriter in 04_vector_store_v1.ipynb.
    """
    
    def __init__(
        self,
        chroma_db_path: str = "./chroma_db",
        collection_name: str = "diabetes_guidelines_v1",
        embedding_function = None
    ):
        self.chroma_db_path = Path(chroma_db_path)
        self.collection_name = collection_name
        self.embedding_function = embedding_function
        self.client = None
        self.collection = None
    
    def initialize(self):
        """Initialize ChromaDB client and collection."""
        if self.client is None:
            self.client = chromadb.PersistentClient(
                path=str(self.chroma_db_path),
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True
                )
            )
            print(f"‚úì ChromaDB client initialized: {self.chroma_db_path}")
        
        # Get collection
        # For existing collections, ChromaDB uses the stored embedding function
        # We only need to pass it if the collection doesn't exist yet
        try:
            # First, try to get the collection without embedding function
            # (ChromaDB will use the stored configuration)
            self.collection = self.client.get_collection(name=self.collection_name)
            print(f"‚úì Loaded collection: {self.collection_name}")
            print(f"  ‚Ä¢ Total chunks: {self.collection.count()}")
        except Exception as e:
            # If collection doesn't exist, try with embedding function
            # (though this shouldn't happen if 04_vector_store_v1 was run)
            try:
                if self.embedding_function:
                    self.collection = self.client.get_collection(
                        name=self.collection_name,
                        embedding_function=self.embedding_function
                    )
                    print(f"‚úì Loaded collection: {self.collection_name}")
                    print(f"  ‚Ä¢ Total chunks: {self.collection.count()}")
                else:
                    raise Exception(f"Collection '{self.collection_name}' not found and no embedding function provided.")
            except Exception as e2:
                raise Exception(f"Failed to load collection '{self.collection_name}': {e2}. Make sure you've run 04_vector_store_v1.ipynb first.")
    
    def _unflatten_metadata(self, flat_metadata: Dict) -> Dict:
        """Unflatten metadata (parse JSON strings back to objects)."""
        unflattened = {}
        for key, value in flat_metadata.items():
            try:
                if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
                    unflattened[key] = json.loads(value)
                else:
                    unflattened[key] = value
            except:
                unflattened[key] = value
        return unflattened
    
    def search(self, query: str, n_results: int = 5, where: Dict = None) -> List[Dict]:
        """
        Search the collection with semantic search.
        
        Args:
            query: Search query text
            n_results: Number of results to return
            where: Optional metadata filter
            
        Returns:
            List of result dictionaries with content, metadata, and relevance score
        """
        if not self.collection:
            self.initialize()
        
        results = self.collection.query(
            query_texts=[query],
            n_results=n_results,
            where=where,
            include=['documents', 'metadatas', 'distances']
        )
        
        # Format results
        formatted_results = []
        seen_chunk_ids = set()
        
        for i in range(len(results['ids'][0])):
            chunk_id = results['ids'][0][i]
            
            # Deduplicate
            if chunk_id in seen_chunk_ids:
                continue
            
            chunk_data = {
                'chunk_id': chunk_id,
                'content': results['documents'][0][i],
                'metadata': self._unflatten_metadata(results['metadatas'][0][i]),
                'relevance_score': 1 - results['distances'][0][i],  # Convert distance to similarity
                'distance': results['distances'][0][i]
            }
            formatted_results.append(chunk_data)
            seen_chunk_ids.add(chunk_id)
        
        return formatted_results

print("‚úì ChromaDBReader class loaded")


‚úì ChromaDBReader class loaded


In [11]:
# CELL_ID: 05_rag_pipeline_v1_initialize
# ============================================================================
# INITIALIZE CHROMADB READER
# ============================================================================

print("=" * 60)
print("INITIALIZING CHROMADB READER")
print("=" * 60)

# Initialize Jina embedding function
jina_embedding_fn = JinaEmbeddingFunction()
print("‚úì Jina embedding function ready")

# Initialize ChromaDB reader
chroma_reader = ChromaDBReader(
    chroma_db_path="./chroma_db",
    collection_name="diabetes_guidelines_v1",
    embedding_function=jina_embedding_fn
)
chroma_reader.initialize()

print("=" * 60)


INITIALIZING CHROMADB READER
‚úì Jina embedding function ready
‚úì ChromaDB client initialized: chroma_db
‚úì Loaded collection: diabetes_guidelines_v1
  ‚Ä¢ Total chunks: 78


In [12]:
# CELL_ID: 05_rag_pipeline_v1_search_function
# ============================================================================
# SEMANTIC SEARCH FUNCTION
# ============================================================================

def format_search_results(results: List[Dict], content_length: int = 500) -> str:
    """
    Format search results for display.
    
    Args:
        results: List of search result dictionaries
        content_length: Maximum length of content preview in characters
        
    Returns:
        Formatted string with results
    """
    if not results:
        return "No results found."
    
    output = []
    output.append(f"Found {len(results)} result(s):\n")
    output.append("=" * 60)
    
    for i, result in enumerate(results, 1):
        metadata = result['metadata']
        content = result['content']
        relevance_score = result['relevance_score']
        
        # Format content preview
        content_preview = content[:content_length]
        if len(content) > content_length:
            content_preview += "..."
        
        # Get title and hierarchy info
        title = metadata.get('title', 'N/A')
        level = metadata.get('level', 'N/A')
        url = metadata.get('url', 'N/A')
        parent_title = metadata.get('parent_title', '')
        
        # Build output
        output.append(f"\n[{i}] {title}")
        output.append(f"    Similarity Score: {relevance_score:.3f} ({relevance_score*100:.1f}%)")
        output.append(f"    Level: {level}")
        if parent_title:
            output.append(f"    Parent: {parent_title[:60]}")
        output.append(f"    URL: {url}")
        output.append(f"    Content Preview ({len(content_preview)} chars):")
        output.append(f"    {'-' * 56}")
        # Indent content
        content_lines = content_preview.split('\n')
        for line in content_lines[:10]:  # Limit to 10 lines
            output.append(f"    {line}")
        if len(content_lines) > 10:
            output.append(f"    ... ({len(content_lines) - 10} more lines)")
        output.append("")
    
    return "\n".join(output)

def semantic_search(query: str, num_results: int = 5) -> str:
    """
    Perform semantic search and return formatted results.
    
    Args:
        query: Search query text
        num_results: Number of results to return (3-10)
        
    Returns:
        Formatted string with search results
    """
    if not query or not query.strip():
        return "Please enter a search query."
    
    # Clamp num_results between 3 and 10
    num_results = max(3, min(10, num_results))
    
    try:
        results = chroma_reader.search(query=query, n_results=num_results)
        return format_search_results(results, content_length=500)
    except Exception as e:
        return f"Error during search: {str(e)}"

print("‚úì Search function defined")


‚úì Search function defined


In [13]:
# CELL_ID: 05_rag_pipeline_v1_gradio_interface
# ============================================================================
# GRADIO INTERACTIVE INTERFACE
# ============================================================================

import gradio as gr

# Create Gradio interface
def create_search_interface():
    """
    Create and launch Gradio interface for semantic search.
    """
    with gr.Blocks(title="Diabetes Guidelines Semantic Search", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # üîç Diabetes Guidelines Semantic Search
        
        Search through the Kenya National Clinical Guidelines for the Management of Diabetes using semantic search.
        
        **How to use:**
        1. Enter your search query in the text box
        2. Adjust the number of results (3-10 chunks)
        3. Click "Search" or press Enter
        4. View results with similarity scores and content previews
        
        **Similarity Score:** Higher values (closer to 1.0) indicate better matches. Scores are calculated as 1 - cosine distance.
        """)
        
        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(
                    label="Search Query",
                    placeholder="e.g., 'insulin treatment for type 1 diabetes'",
                    lines=2
                )
            with gr.Column(scale=1):
                num_results = gr.Slider(
                    label="Number of Results",
                    minimum=3,
                    maximum=10,
                    value=5,
                    step=1,
                    info="Select 3-10 chunks"
                )
        
        search_btn = gr.Button("üîç Search", variant="primary", size="lg")
        
        gr.Markdown("### Results")
        results_output = gr.Textbox(
            label="Search Results",
            lines=20,
            max_lines=30,
            show_copy_button=True
        )
        
        # Examples
        gr.Markdown("### Example Queries")
        examples = gr.Examples(
            examples=[
                ["insulin treatment for type 1 diabetes", 5],
                ["diabetes management during pregnancy", 5],
                ["hypoglycemia symptoms and treatment", 5],
                ["blood glucose monitoring guidelines", 5],
                ["diabetic ketoacidosis management", 5],
                ["nutritional management for diabetes", 5],
            ],
            inputs=[query_input, num_results]
        )
        
        # Connect inputs to search function
        search_btn.click(
            fn=semantic_search,
            inputs=[query_input, num_results],
            outputs=results_output
        )
        
        query_input.submit(
            fn=semantic_search,
            inputs=[query_input, num_results],
            outputs=results_output
        )
        
        # Add footer
        gr.Markdown("""
        ---
        **Note:** This search uses semantic embeddings (Jina AI) to find relevant content based on meaning, not just keywords.
        """)
    
    return demo

# Create and launch the interface
print("=" * 60)
print("CREATING GRADIO INTERFACE")
print("=" * 60)

demo = create_search_interface()

print("‚úì Gradio interface created")
print("\n" + "=" * 60)
print("LAUNCHING INTERFACE")
print("=" * 60)
print("\nThe interface will open in your browser.")
print("You can also access it via the local URL shown below.")
print("=" * 60)

# Launch the interface
demo.launch(share=False, server_name="127.0.0.1", server_port=7860)


  from .autonotebook import tqdm as notebook_tqdm


CREATING GRADIO INTERFACE
‚úì Gradio interface created

LAUNCHING INTERFACE

The interface will open in your browser.
You can also access it via the local URL shown below.
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




In [14]:
# CELL_ID: 05_rag_pipeline_v1_test_search
# ============================================================================
# TEST SEARCH FUNCTION (OPTIONAL - FOR TESTING WITHOUT GRADIO)
# ============================================================================

# Uncomment to test search function directly
"""
print("=" * 60)
print("TESTING SEARCH FUNCTION")
print("=" * 60)

test_query = "insulin treatment for type 1 diabetes"
num_results = 5

print(f"Query: '{test_query}'")
print(f"Number of results: {num_results}")
print("\n" + "=" * 60)

results = semantic_search(test_query, num_results)
print(results)
"""


'\nprint("=" * 60)\nprint("TESTING SEARCH FUNCTION")\nprint("=" * 60)\n\ntest_query = "insulin treatment for type 1 diabetes"\nnum_results = 5\n\nprint(f"Query: \'{test_query}\'")\nprint(f"Number of results: {num_results}")\nprint("\n" + "=" * 60)\n\nresults = semantic_search(test_query, num_results)\nprint(results)\n'