# Find Medical Specialties - RAG Interface

This notebook creates a user interface for searching and finding medical specialties using RAG (Retrieval Augmented Generation) with MongoDB as the vector store.

Based on day5.ipynb structure, but using MongoDB from LoadDB.ipynb as the data source.


In [None]:
import os
import numpy as np
from typing import List, Dict, Any, Optional
from pymongo import MongoClient
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from pydantic import ConfigDict
import gradio as gr
from dotenv import load_dotenv

load_dotenv()


True

## Database Connection


In [49]:
def getDBConnection() -> MongoClient:
    """
    Creates and returns a MongoDB client using the connection string stored
    in the MONGO_connectionString environment variable.
    """
    conn_str = os.getenv("MONGO_connectionString")
    
    if not conn_str:
        raise EnvironmentError(
            "Environment variable 'MONGO_connectionString' is not set."
        )
    
    try:
        client = MongoClient(conn_str)
        client.admin.command("ping")
        print("DB client successfully created")
        return client
    except Exception as e:
        raise ConnectionError(f"Failed to connect to MongoDB: {e}") from e

# Initialize database connection
db_client = getDBConnection()


DB client successfully created


## MongoDB Vector Retriever

Create a custom retriever that works with MongoDB vector store


In [50]:
class MongoDBVectorRetriever(BaseRetriever):
    """Custom retriever for MongoDB vector store."""
    
    model_config = ConfigDict(extra='allow', arbitrary_types_allowed=True)
    
    client: MongoClient
    db_name: str
    collection_name: str
    embeddings: OpenAIEmbeddings
    k: int = 5
    
    def __init__(
        self,
        client: MongoClient,
        db_name: str,
        collection_name: str,
        embeddings: OpenAIEmbeddings,
        k: int = 5
    ):
        # Initialize with proper field assignment
        super().__init__(client=client, db_name=db_name, collection_name=collection_name, embeddings=embeddings, k=k)
        # Set non-field attributes after initialization
        self.db = client[db_name]
        self.collection = self.db[collection_name]
    
    def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Calculate cosine similarity between two vectors."""
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return dot_product / (norm1 * norm2)
    
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from MongoDB."""
        # Embed the query
        query_embedding = self.embeddings.embed_query(query)
        query_vec = np.array(query_embedding)
        
        # Get all documents from the collection
        all_docs = list(self.collection.find({}))
        
        if not all_docs:
            return []
        
        # Calculate similarities
        scored_docs = []
        for doc in all_docs:
            doc_embedding = np.array(doc.get("embedding", []))
            if len(doc_embedding) == 0:
                continue
            
            similarity = self.cosine_similarity(query_vec, doc_embedding)
            scored_docs.append((similarity, doc))
        
        # Sort by similarity and get top k
        scored_docs.sort(key=lambda x: x[0], reverse=True)
        top_docs = scored_docs[:self.k]
        
        # Convert to LangChain Documents
        documents = []
        for similarity, doc in top_docs:
            metadata = doc.get("metadata", {})
            # Add Code to metadata for easy access
            code = doc.get("Code", "") or metadata.get("Code", "")
            if code:
                metadata["Code"] = code
            
            documents.append(
                Document(
                    page_content=doc.get("text", ""),
                    metadata=metadata
                )
            )
        
        return documents


In [51]:
# Use the same embedding model as LoadDB.ipynb
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

# Create MongoDB retriever
retriever = MongoDBVectorRetriever(
    client=db_client,
    db_name="PublicHealthData",
    collection_name="specialtyMetaDataVectors",
    embeddings=embeddings,
    k=5  # Retrieve top 5 most relevant documents
)

# Test that retriever works
try:
    test_docs = retriever.get_relevant_documents("test")
    print(f"Retriever initialized successfully. Found {len(test_docs)} test documents.")
except Exception as e:
    print(f"Warning: Retriever test failed: {e}")
    print("Continuing anyway...")


  test_docs = retriever.get_relevant_documents("test")


Retriever initialized successfully. Found 0 test documents.


In [52]:
# Create LLM
MODEL = "gpt-4o-mini"
llm = ChatOpenAI(temperature=0.7, model=MODEL)

# Set up conversation memory
memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)

# Create conversational retrieval chain
conversation_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory
)

print("Conversation chain initialized")


Conversation chain initialized


  memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)


In [53]:
def chat(message, history):
    """Chat function for Gradio interface."""
    result = conversation_chain.invoke({"question": message})
    return result["answer"]


In [54]:
# Create and launch Gradio interface
view = gr.ChatInterface(
    chat,
    title="Find Medical Specialties",
    description="Ask questions about medical specialties. The system will search the specialty database and provide answers based on the retrieved information.",
    examples=[
        "What specialties are available for heart disease?",
        "Tell me about pediatric care specialties",
        "What specialties deal with mental health?",
        "Find specialties related to cancer treatment"
    ],
    type="messages"
).launch(inbrowser=True)


* Running on local URL:  http://127.0.0.1:7868
* To create a public link, set `share=True` in `launch()`.
