In [None]:
import os
import getpass

# LangSmith tracing configuration (not secret).
os.environ['LANGCHAIN_TRACING_V2'] = "true"
os.environ["LANGCHAIN_PROJECT"] = "ResearchPro2"
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'

# Credentials must never be hard-coded in a notebook: they end up in version control and in every
# shared copy. Read them from the environment (exported in the shell, or loaded from an untracked
# .env file) and only prompt, without echoing, when one is missing.
# SECURITY: the keys that used to be pasted into this cell have been exposed and must be revoked/rotated.
for _secret_name in ("LANGCHAIN_API_KEY", "OPENAI_API_KEY"):
    if not os.environ.get(_secret_name):
        os.environ[_secret_name] = getpass.getpass(f"Enter {_secret_name}: ")


In [None]:
# Install and import necessary packages
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_ollama import OllamaLLM
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from sentence_transformers import SentenceTransformer, util
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain


from typing import List, Literal
import re
import os


In [None]:
PDF_DIR = "./arxiv_data"
DB_DIR = "./arxiv_vector1_db"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "llama3.2"

# Embedding & Vector DB
embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
vector_db = Chroma(persist_directory=DB_DIR, embedding_function=embedding_model)

# Similarity model for query/document scoring. HuggingFaceEmbeddings already holds a loaded
# SentenceTransformer for the same checkpoint in its `client` attribute, so reuse it instead of
# loading the weights a second time; fall back to a fresh load if this LangChain version differs.
semantic_model = getattr(embedding_model, "client", None)
if not isinstance(semantic_model, SentenceTransformer):
    semantic_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# LLM served through Ollama
llm = OllamaLLM(model=LLM_MODEL_NAME)

In [None]:
import os
import re
from typing import Dict, List, Tuple, Optional, Any
import torch
from dataclasses import dataclass, field
from sentence_transformers import SentenceTransformer, util
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, BaseOutputParser
from langchain.schema import Document, HumanMessage, AIMessage
from langchain.vectorstores import VectorStore
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.schema.runnable import Runnable
from langchain.memory import ConversationBufferMemory
from pydantic import BaseModel, Field
import logging
from rich.console import Console
from rich.panel import Panel
from rich.markdown import Markdown

# Module-wide logging: timestamped INFO-level records; this module logs under the "SageRAG" name.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name="SageRAG")

# One shared Rich console for all colourised, user-facing output.
console = Console()

# Enhanced output parsers with better error handling
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Parse LLM output containing a numbered list into a list of item strings.

    Recognised item markers (leading whitespace allowed): ``1. item``, ``1) item``,
    ``1: item`` and a bare ``1 item``. Lines without such a marker (preambles, blank
    lines, trailing prose) are ignored, so the result may be an empty list.
    """

    def parse(self, text: str) -> List[str]:
        """Return the stripped, non-empty text of every numbered line in ``text``."""
        # Counter limited to 1-2 digits so prose lines that begin with a year
        # ("2023 study shows ...") are not mistaken for list items.
        pattern = r"^\s*\d{1,2}\s*[.):]?\s+(.+?)\s*$"
        items = re.findall(pattern, text, re.MULTILINE)
        return [item.strip() for item in items if item.strip()]

    @property
    def _type(self) -> str:
        return "line_list"

class QueryRating(BaseModel):
    """Pydantic schema describing the JSON produced when rating a search query."""

    # Both fields are required (``...``), as they were with a description-only Field.
    rating: int = Field(..., description="Rating from 1-5")
    explanation: str = Field(..., description="Explanation for the rating")

@dataclass
class RetrievedDocument:
    """A retrieved chunk together with its score and the query/rank that produced it."""

    document: Document
    score: float
    # Query text that retrieved this chunk; empty when unknown.
    query: str = field(default="")
    # 1-based position within its result list; 0 means "not ranked yet".
    rank: int = field(default=0)

class ResearchAssistant:
    """AI Research Assistant using RAG for scientific paper queries."""
    
    def __init__(
        self, 
        llm: BaseLLM, 
        vector_db: VectorStore, 
        embeddings_model: Optional[SentenceTransformer] = None,
        relevance_threshold: float = 0.6,
        max_docs_per_query: int = 3,
        always_use_fallback: bool = True
    ):
        """Wire up the LLM, vector store, similarity model, conversation memory and parsers.

        Args:
            llm: Language model used for every generation and classification call.
            vector_db: Vector store queried for scientific-paper chunks.
            embeddings_model: SentenceTransformer used for query/document similarity;
                when omitted, 'all-MiniLM-L6-v2' is loaded.
            relevance_threshold: Minimum document score counted as "relevant" when deciding
                whether to fall back to the bare LLM.
            max_docs_per_query: Cap on documents taken from each query formulation when
                building a combined answer.
            always_use_fallback: If True, questions classified as general knowledge are
                answered by the LLM directly, bypassing retrieved documents.
        """
        # Collaborators and tuning knobs.
        self.llm = llm
        self.vector_db = vector_db
        self.relevance_threshold = relevance_threshold
        self.max_docs_per_query = max_docs_per_query
        self.always_use_fallback = always_use_fallback

        # Similarity model: use the caller's, otherwise load the default checkpoint.
        if embeddings_model is not None:
            self.semantic_model = embeddings_model
        else:
            logger.info("Initializing default embedding model")
            self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Conversation memory, exposed as message objects under "chat_history".
        self.memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")

        # Default retriever over the same store, for direct use.
        self.retriever = vector_db.as_retriever()

        # JSON parser for query ratings; numbered-list parser for query rewrites.
        self.json_parser = JsonOutputParser(pydantic_object=QueryRating)
        self.list_parser = LineListOutputParser()

        logger.info("Research Assistant initialized successfully")
    
    def rate_query(self, query: str) -> Dict[str, Any]:
        """Rates the query and gives an explanation.
        
        Args:
            query: The user's query to be evaluated
            
        Returns:
            Dict containing rating and explanation
        """
        prompt = PromptTemplate.from_template("""
        You are an intelligent assistant trained to evaluate search queries for a scientific research database.
        Given the following query: "{query}"
        
        1. Rate the query on a scale of 1 (very poor) to 5 (excellent) based on:
           - Clarity: Is the query clear and unambiguous?
           - Specificity: Does it contain specific technical terms or concepts?
           - Relevance: Is it focused on retrieving scientific content?
           - Retrievability: Will it work well with vector search?
        
        2. Provide a short explanation for the rating (what makes it effective or ineffective).
        
        Respond in JSON format:
        {{
          "rating": <number between 1-5>,
          "explanation": "<your explanation>"
        }}
        """)
        
        chain: Runnable = prompt | self.llm | self.json_parser
        
        try:
            return chain.invoke({"query": query})
        except Exception as e:
            logger.error(f"Error rating query: {e}")
            # Fallback response if parsing fails
            return {
                "rating": 3, 
                "explanation": "Unable to rate query properly. Consider adding more specific terms."
            }
    
    def is_general_knowledge_question(self, query: str) -> bool:
        """Determines if a question is likely a general knowledge question rather than research-specific.
        
        Args:
            query: The user's query to evaluate
            
        Returns:
            Boolean indicating if this is likely a general knowledge question
        """
        prompt = PromptTemplate.from_template("""
        You are an intelligent assistant analyzing a question. Determine if this is a general knowledge question
        that doesn't require specific scientific papers to answer properly.
        
        Question: "{query}"
        
        Examples of general knowledge questions:
        - What is an operating system?
        - Who invented the telephone?
        - How does gravity work?
        - What is the difference between RAM and ROM?
        
        Examples of research-specific questions:
        - What are the latest developments in CRISPR gene editing?
        - How does the transformer architecture improve machine translation accuracy?
        - What methodologies are used to measure quantum entanglement?
        - What were the findings of the 2023 paper on climate change impact on coral reefs?
        
        Is this a general knowledge question that could be answered without specific research papers?
        Answer only YES or NO.
        """)
        
        try:
            result = self.llm.invoke(prompt.format(query=query))
            return "YES" in result.upper()
        except Exception as e:
            logger.error(f"Error determining if general knowledge question: {e}")
            return False  # Default to assuming it's not general knowledge
    
    def suggest_rewrites(self, query: str, chat_history: Optional[List] = None) -> List[str]:
        """Returns rephrased versions of the query optimized for retrieval.
        
        Args:
            query: Original query to rewrite
            chat_history: Optional conversation history for context
            
        Returns:
            List of rewritten queries
        """
        history_context = ""
        if chat_history:
            history_context = "Consider this conversation context when rewriting the query:\n"
            for message in chat_history[-3:]:  # Use only last 3 messages for brevity
                if hasattr(message, "content"):
                    role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                    history_context += f"{role}: {message.content}\n"
        
        prompt = PromptTemplate(
            input_variables=["question", "history_context"],
            template="""You are an AI assistant specializing in scientific research queries.
            
            {history_context}
            
            Rephrase the question in 5 different ways to improve retrieval from a scientific paper database. 
            Focus on:
            1. Using technical terminology for better vector matching
            2. Breaking down complex queries into clearer formulations
            3. Adding relevant synonyms or related concepts
            4. Varying syntax while preserving semantic meaning
            5. Including key entities and relationships from the original query
            
            Number each version starting with 1.
            
            Question: {question}"""
        )
        
        try:
            result = (prompt | self.llm | self.list_parser).invoke({
                "question": query,
                "history_context": history_context
            })
            return result
        except Exception as e:
            logger.error(f"Error suggesting rewrites: {e}")
            # Return a minimal set of rewrites if parsing fails
            return [
                query,  # Original query
                f"research about {query}",
                f"papers discussing {query}"
            ]
    
    def compute_confidence(self, original: str, rewritten: str) -> float:
        """Cosine similarity between the embeddings of two query strings.

        Args:
            original: Original query.
            rewritten: A rewritten version of it.

        Returns:
            Cosine similarity of the two embeddings as a float, or 0.7 if
            encoding or comparison raises.
        """
        try:
            # Encode original first, then rewrite, each as a tensor.
            embeddings = [
                self.semantic_model.encode(text, convert_to_tensor=True)
                for text in (original, rewritten)
            ]
            similarity = util.pytorch_cos_sim(*embeddings)
            return float(similarity.item())
        except Exception as e:
            logger.error(f"Error computing similarity: {e}")
            # Middle-of-the-road score so one failed comparison doesn't sink the rewrite.
            return 0.7
    
    def score_queries(self, original: str, queries: List[str]) -> List[Tuple[str, float]]:
        """Scores and sorts rephrased queries by confidence.
        
        Args:
            original: Original query
            queries: List of rewritten queries
            
        Returns:
            List of (query, score) tuples sorted by score
        """
        scored = [(q, self.compute_confidence(original, q)) for q in queries]
        return sorted(scored, key=lambda x: x[1], reverse=True)
    
    def present_query_options(self, original: str, queries: List[str]) -> List[Tuple[str, float]]:
        """Print the original and rewritten queries with confidence scores.

        Args:
            original: Original query (listed first with confidence 1.0).
            queries: Rewritten queries, shown in descending similarity order.

        Returns:
            ``(label, score)`` tuples: ``"Original: ..."`` first, then ``"Rewrite i: ..."``.
        """
        # Query text comes from the user/LLM; escape it so "[...]" isn't parsed as Rich markup
        # (which can mis-style the text or raise MarkupError on e.g. "[/b]").
        from rich.markup import escape

        scored_queries = [("Original: " + original, 1.0)]
        scored_queries.extend(
            (f"Rewrite {i}: {query}", score)
            for i, (query, score) in enumerate(self.score_queries(original, queries), 1)
        )

        console.print("\n[bold cyan]Available search queries:[/bold cyan]")
        for i, (query, score) in enumerate(scored_queries):
            confidence_color = "green" if score > 0.8 else "yellow" if score > 0.6 else "red"
            console.print(f"[bold]{i}.[/bold] {escape(query)}")
            console.print(f"   [bold {confidence_color}]Confidence: {score:.2f}[/bold {confidence_color}]")

        return scored_queries
    
    def retrieve_documents(
        self, 
        queries: List[str], 
        k: int = 5
    ) -> Tuple[List[RetrievedDocument], Dict[str, List[RetrievedDocument]]]:
        """Run a relevance-scored similarity search for each query and print the hits.

        Args:
            queries: Queries to search with.
            k: Number of results requested per query.

        Returns:
            ``(unique_docs, query_docs_map)``: ``unique_docs`` holds each distinct
            ``page_content`` once (first occurrence wins); ``query_docs_map`` maps each
            query that searched successfully to all of its results in rank order.
            Queries whose search raises are logged and omitted from the map.
        """
        # Paper text and LLM-written queries often contain "[...]" (citations, "x[i]", stray "[/b]").
        # Unescaped, Rich treats these as markup: styling goes wrong, or a MarkupError is raised,
        # which the except below would swallow, silently dropping that query's results.
        from rich.markup import escape

        all_docs: List[RetrievedDocument] = []
        unique_content = set()
        query_docs_map: Dict[str, List[RetrievedDocument]] = {}

        for query in queries:
            console.print(f"\n[bold blue]Query:[/bold blue] '{escape(query)}'")

            try:
                results = self.vector_db.similarity_search_with_relevance_scores(query, k=k)
                docs_for_query: List[RetrievedDocument] = []

                for i, (doc, score) in enumerate(results, 1):
                    retrieved_doc = RetrievedDocument(document=doc, score=score, query=query, rank=i)

                    # Display result info
                    score_color = "green" if score > 0.8 else "yellow" if score > 0.6 else "red"
                    console.print(f"[bold]--- Result {i} ---[/bold]")
                    console.print(f"[bold {score_color}]Score: {score:.4f}[/bold {score_color}]")

                    content = doc.page_content
                    preview = content[:150] + "..." if len(content) > 150 else content
                    console.print(Panel(escape(preview), title="Content Preview", width=100))

                    metadata = getattr(doc, "metadata", None)
                    if metadata:
                        source = metadata.get('source', 'Unknown')
                        console.print(f"[dim]Source: {escape(str(source))}[/dim]")

                    docs_for_query.append(retrieved_doc)

                    # Only the first occurrence of each chunk enters the overall collection.
                    if content not in unique_content:
                        unique_content.add(content)
                        all_docs.append(retrieved_doc)

                query_docs_map[query] = docs_for_query

            except Exception as e:
                logger.error(f"Error retrieving documents for query '{query}': {e}")
                console.print(f"[bold red]Error retrieving documents for query: {escape(query)}[/bold red]")

        console.print(f"\n[bold green]Total unique documents:[/bold green] {len(all_docs)}")
        return all_docs, query_docs_map

    def rerank_docs(self, query: str, docs: List[RetrievedDocument]) -> List[RetrievedDocument]:
        """Reranks documents based on semantic similarity to the query.
        
        Args:
            query: Query to compare documents against
            docs: List of retrieved documents
            
        Returns:
            Reranked list of documents with updated scores
        """
        try:
            query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
            reranked = []
            
            for doc in docs:
                doc_embedding = self.semantic_model.encode(doc.document.page_content, convert_to_tensor=True)
                new_score = float(util.pytorch_cos_sim(query_embedding, doc_embedding).item())
                
                # Create new RetrievedDocument with updated score
                reranked_doc = RetrievedDocument(
                    document=doc.document,
                    score=new_score,
                    query=doc.query,
                    rank=0  # Will be updated after sorting
                )
                reranked.append(reranked_doc)
            
            # Sort by score and update ranks
            reranked.sort(key=lambda x: x.score, reverse=True)
            for i, doc in enumerate(reranked, 1):
                doc.rank = i
                
            return reranked
            
        except Exception as e:
            logger.error(f"Error reranking documents: {e}")
            return docs  # Return original documents if reranking fails
    
    def can_answer_without_retrieval(self, question: str) -> Tuple[bool, Optional[str]]:
        """Determines if a question can be answered directly from memory without retrieval.
        
        Args:
            question: User's question
            
        Returns:
            Tuple of (can_answer, answer_if_available)
        """
        # Get chat history from memory if available
        chat_history = self.memory.load_memory_variables({}).get("chat_history", [])
        
        if not chat_history:
            return False, None
            
        # Create the prompt to check if we can answer directly from the conversation
        prompt = PromptTemplate.from_template("""
        You are an AI assistant helping with a scientific research conversation.
        Given the following conversation history and a new question, determine if the question:
        1. Can be answered directly based ONLY on the conversation history (like "what was my previous question?")
        2. Does NOT require retrieving new information from scientific papers
        
        If both conditions are true, provide the answer. Otherwise, respond with "NEEDS_RETRIEVAL".
        
        Conversation History:
        {chat_history}
        
        New Question: {question}
        
        Your assessment (answer directly or respond with "NEEDS_RETRIEVAL"):
        """)
        
        # Format the chat history for context
        history_str = ""
        for message in chat_history[-5:]:  # Use only last 5 messages for brevity
            if hasattr(message, "content"):
                role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                history_str += f"{role}: {message.content}\n"
        
        try:
            # Ask the LLM if this can be answered without retrieval
            response = self.llm.invoke(
                prompt.format(chat_history=history_str, question=question)
            )
            
            # If the response is "NEEDS_RETRIEVAL", we need to use retrieval
            if "NEEDS_RETRIEVAL" in response:
                return False, None
            else:
                logger.info("Question can be answered from conversation history")
                return True, response
        except Exception as e:
            logger.error(f"Error determining if retrieval needed: {e}")
            return False, None  # Default to retrieval if there's an error
    
    def generate_final_answer(
        self, 
        question: str, 
        docs: List[RetrievedDocument], 
        max_docs: int = 5,
        fallback_to_llm: bool = True
    ) -> str:
        """Answer a question from retrieved documents, falling back to the bare LLM.

        Routing:
          * The LLM answers from its own knowledge when ``always_use_fallback`` is set
            and the question is classified as general knowledge, or when no document
            reaches ``relevance_threshold`` and ``fallback_to_llm`` is True.
          * Otherwise the ``max_docs`` highest-scoring documents are given to the LLM
            as context.

        On success the question/answer pair is appended to ``self.memory``.

        Args:
            question: User's question.
            docs: Retrieved documents (any order).
            max_docs: Maximum number of documents to include as context.
            fallback_to_llm: Whether to fall back to the LLM when no doc is relevant.

        Returns:
            The generated answer, or an apology string if generation raised.
        """
        chat_history = self.memory.load_memory_variables({}).get("chat_history", [])

        # Last 3 messages as plain "Role: text" lines for the prompt.
        chat_context = ""
        if chat_history:
            chat_context = "Previous conversation:\n"
            for message in chat_history[-3:]:
                if hasattr(message, "content"):
                    role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                    chat_context += f"{role}: {message.content}\n"

        has_relevant_docs = any(doc.score >= self.relevance_threshold for doc in docs)
        # The general-knowledge classifier costs an extra LLM round-trip; only make it
        # when its verdict can actually change the routing (always_use_fallback enabled).
        is_general = self.always_use_fallback and self.is_general_knowledge_question(question)

        if is_general or (fallback_to_llm and not has_relevant_docs):
            if is_general:
                console.print("[yellow]This appears to be a general knowledge question. Using LLM knowledge.[/yellow]")
            else:
                console.print("[yellow]No relevant documents found. Falling back to LLM knowledge.[/yellow]")

            fallback_prompt = PromptTemplate(
                input_variables=["question", "chat_history"],
                template="""You are a helpful scientific research assistant.
                
                The user has asked a question that either:
                1. Could not be answered using our research paper database, or
                2. Is a general knowledge question that doesn't require specific papers
                
                Please answer based on your general knowledge.
                
                {chat_history}
                
                Question: {question}
                
                Provide a helpful, informative answer. If this is a scientific or technical question,
                explain concepts clearly and accurately. Don't apologize for not using specific papers 
                - just provide your best answer directly.
                
                Answer:"""
            )

            try:
                answer = (fallback_prompt | self.llm).invoke({
                    "question": question,
                    "chat_history": chat_context
                })
                self.memory.chat_memory.add_messages([
                    HumanMessage(content=question),
                    AIMessage(content=answer)
                ])
                return answer
            except Exception as e:
                logger.error(f"Error generating fallback answer: {e}")
                return "I'm sorry, I don't have enough information to answer your question accurately."

        # Take the most relevant documents rather than whichever came first (callers may pass
        # un-reranked, multi-query results). sorted() is stable, so ties keep input order.
        top_docs = sorted(docs, key=lambda d: d.score, reverse=True)[:max_docs]
        context_pieces = []
        for i, doc in enumerate(top_docs, 1):
            chunk = f"Document {i} [Relevance: {doc.score:.2f}]:\n{doc.document.page_content}"
            metadata = getattr(doc.document, "metadata", None)
            if metadata:
                source = metadata.get('source', 'Unknown')
                chunk += f"\nSource: {source}"
            context_pieces.append(chunk)

        # Avoid sending a blank "Context:" section when there are no documents at all.
        context = "\n\n---\n\n".join(context_pieces) or "No documents were retrieved for this question."

        prompt = PromptTemplate(
            input_variables=["question", "context", "chat_history"],
            template="""You are a helpful scientific research assistant analyzing scientific papers. 
            Use the following retrieved context chunks to answer the user's question thoroughly.
            
            Guidelines:
            - Focus on providing accurate information from the papers
            - Synthesize information across documents when appropriate
            - Cite the sources of information in your answer (e.g., "According to Document 1...")
            - If the information isn't sufficient in the context, supplement with your general knowledge,
              clearly indicating which parts of your answer come from the papers and which parts are from your general knowledge
            - If the context contains conflicting information, highlight the disagreement and possible reasons
            - Maintain scientific accuracy above all else
            - Don't apologize for using general knowledge - just be clear about what comes from papers vs. your knowledge
            
            {chat_history}
            
            Context:
            {context}
            
            Question: {question}
            
            Answer:"""
        )

        try:
            answer = (prompt | self.llm).invoke({
                "question": question, 
                "context": context,
                "chat_history": chat_context
            })
            self.memory.chat_memory.add_messages([
                HumanMessage(content=question),
                AIMessage(content=answer)
            ])
            return answer
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return "I'm sorry, I encountered an error while generating your answer. Please try rephrasing your question."
    
    def generate_combined_answer(
        self, 
        question: str, 
        query_results: Dict[str, List[RetrievedDocument]],
        fallback_to_llm: bool = True
    ) -> str:
        """Answer a question by synthesizing documents retrieved for several query formulations.

        Routing matches ``generate_final_answer``: the bare LLM answers when
        ``always_use_fallback`` is set and the question is general knowledge, or when
        no document across all queries reaches ``relevance_threshold`` and
        ``fallback_to_llm`` is True. Otherwise up to ``max_docs_per_query`` documents
        per query become context, numbered once across all queries and with duplicate
        chunks shown only under the first query that returned them.

        On success the question/answer pair is appended to ``self.memory``.

        Args:
            question: User's question.
            query_results: Mapping of query text to its retrieved documents.
            fallback_to_llm: Whether to fall back to the LLM when no doc is relevant.

        Returns:
            The generated answer, or an apology string if generation raised.
        """
        chat_history = self.memory.load_memory_variables({}).get("chat_history", [])

        # Last 3 messages as plain "Role: text" lines for the prompt.
        chat_context = ""
        if chat_history:
            chat_context = "Previous conversation:\n"
            for message in chat_history[-3:]:
                if hasattr(message, "content"):
                    role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                    chat_context += f"{role}: {message.content}\n"

        has_relevant_docs = any(
            doc.score >= self.relevance_threshold
            for docs_list in query_results.values()
            for doc in docs_list
        )
        # Only spend the extra classification LLM call when it can change the routing.
        is_general = self.always_use_fallback and self.is_general_knowledge_question(question)

        if is_general or (fallback_to_llm and not has_relevant_docs):
            if is_general:
                console.print("[yellow]This appears to be a general knowledge question. Using LLM knowledge.[/yellow]")
            else:
                console.print("[yellow]No relevant documents found across all queries. Falling back to LLM knowledge.[/yellow]")

            fallback_prompt = PromptTemplate(
                input_variables=["question", "chat_history"],
                template="""You are a helpful scientific research assistant.
                
                The user has asked a question that either:
                1. Could not be answered using our research paper database, or
                2. Is a general knowledge question that doesn't require specific papers
                
                Please answer based on your general knowledge.
                
                {chat_history}
                
                Question: {question}
                
                Provide a helpful, informative answer. If this is a scientific or technical question,
                explain concepts clearly and accurately. Don't apologize for not using specific papers 
                - just provide your best answer directly.
                
                Answer:"""
            )

            try:
                answer = (fallback_prompt | self.llm).invoke({
                    "question": question,
                    "chat_history": chat_context
                })
                self.memory.chat_memory.add_messages([
                    HumanMessage(content=question),
                    AIMessage(content=answer)
                ])
                return answer
            except Exception as e:
                logger.error(f"Error generating fallback answer: {e}")
                return "I'm sorry, I don't have enough information to answer your question accurately."

        # Build the context. Numbering is global across queries so a citation such as
        # "Document 3" (which the prompt asks for) is unambiguous, and a chunk returned by
        # several formulations is included only once.
        all_context_sections = []
        seen_content = set()
        doc_number = 0
        for query, docs in query_results.items():
            sections_for_query = []
            for doc in docs[:self.max_docs_per_query]:
                content = doc.document.page_content
                if content in seen_content:
                    continue
                seen_content.add(content)
                doc_number += 1
                section = f"Document {doc_number} [Relevance: {doc.score:.2f}]:\n{content}"
                metadata = getattr(doc.document, "metadata", None)
                if metadata:
                    source = metadata.get('source', 'Unknown')
                    section += f"\nSource: {source}"
                sections_for_query.append(section)
            if sections_for_query:
                all_context_sections.append(f"Results for query: '{query}'")
                all_context_sections.extend(sections_for_query)

        combined_context = "\n\n---\n\n".join(all_context_sections)

        prompt = PromptTemplate(
            input_variables=["question", "context", "chat_history"],
            template="""You are a helpful scientific research assistant analyzing scientific papers. 
            The user's question has been reformulated in several ways, and each formulation returned different documents.
            
            Guidelines:
            - Synthesize information across ALL retrieved documents to provide a comprehensive answer
            - Compare and contrast findings from different sources
            - Highlight the most relevant information from each source
            - Cite the specific documents you're referencing (e.g., "According to the paper in Document 3...")
            - If information conflicts across sources, explain the different perspectives
            - If the information is insufficient, supplement with your general knowledge,
              clearly indicating which parts of your answer come from the papers and which are from your general knowledge
            - Maintain scientific accuracy above all else
            - Don't apologize for using general knowledge - just be clear about what information comes from papers vs. your knowledge
            
            {chat_history}
            
            Context from multiple query formulations:
            {context}
            
            Original Question: {question}
            
            Answer:"""
        )

        try:
            answer = (prompt | self.llm).invoke({
                "question": question, 
                "context": combined_context,
                "chat_history": chat_context
            })
            self.memory.chat_memory.add_messages([
                HumanMessage(content=question),
                AIMessage(content=answer)
            ])
            return answer
        except Exception as e:
            logger.error(f"Error generating combined answer: {e}")
            return "I'm sorry, I encountered an error while generating your answer. Please try rephrasing your question."
    
    def answer_query(self, query: str, use_retrieval: bool = True) -> str:
        """Answer a question from chat history, directly with the LLM, or via retrieval.

        Routing, in order:
        1. If ``can_answer_without_retrieval`` reports that the conversation history
           already answers ``query``, that answer is returned unchanged.
        2. If ``use_retrieval`` is False, ``self.always_use_fallback`` is truthy, and
           ``is_general_knowledge_question`` classifies ``query`` as general knowledge,
           the LLM answers directly (no vector-store lookup) with the last three
           chat-history messages as context, and the exchange is added to
           ``self.memory``.
        3. Otherwise the interactive pipeline ``search_with_query_feedback`` runs.

        Args:
            query: User's question.
            use_retrieval: When False, a general-knowledge question may skip retrieval
                (step 2). A question NOT classified as general knowledge still goes
                through the retrieval pipeline even when this is False.

        Returns:
            The answer text. If the direct LLM call in step 2 fails, a fixed apology
            string is returned instead.
        """
        # Local import: piping through StrOutputParser yields plain text whether
        # self.llm is a completion model (returns str, e.g. OllamaLLM) or a chat
        # model (returns a message object). Without it, a chat model's message would
        # be wrapped in AIMessage(content=...) and handed back to callers expecting str.
        from langchain_core.output_parsers import StrOutputParser

        # First check if we can answer directly from conversation history
        can_answer, direct_answer = self.can_answer_without_retrieval(query)
        if can_answer:
            console.print("[green]Answering from conversation history[/green]")
            return direct_answer

        # Evaluate the cheap flags first so the general-knowledge classifier only runs
        # when its result can change the routing (search_with_query_feedback performs
        # its own classification anyway).
        use_direct_llm = (
            not use_retrieval
            and self.always_use_fallback
            and self.is_general_knowledge_question(query)
        )

        if use_direct_llm:
            console.print("[yellow]General knowledge question detected. Using LLM directly.[/yellow]")
            # Build a short "Previous conversation" block from the memory buffer.
            chat_history = self.memory.load_memory_variables({}).get("chat_history", [])
            chat_context = ""
            if chat_history:
                chat_context = "Previous conversation:\n"
                for message in chat_history[-3:]:  # Use last 3 messages for context
                    if hasattr(message, "content"):
                        role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                        chat_context += f"{role}: {message.content}\n"

            fallback_prompt = PromptTemplate(
                input_variables=["question", "chat_history"],
                template="""You are a helpful scientific research assistant.
                
                The user has asked a general knowledge question that doesn't require specific papers.
                Please answer based on your general knowledge.
                
                {chat_history}
                
                Question: {question}
                
                Provide a helpful, informative answer. If this is a scientific or technical question,
                explain concepts clearly and accurately.
                
                Answer:"""
            )

            try:
                answer = (fallback_prompt | self.llm | StrOutputParser()).invoke({
                    "question": query,
                    "chat_history": chat_context
                })

                # Save the QA pair to memory
                self.memory.chat_memory.add_messages([
                    HumanMessage(content=query),
                    AIMessage(content=answer)
                ])

                return answer
            except Exception as e:
                logger.error(f"Error generating direct answer: {e}")
                return "I'm sorry, I encountered an error while generating your answer."

        # Otherwise, proceed with full retrieval pipeline
        return self.search_with_query_feedback(query)
    
    def search_with_query_feedback(self, query: str, num_results: int = 5) -> str:
        """Interactive retrieval pipeline: rate, rewrite, select, retrieve, answer.

        Steps:
        1. Rate ``query`` with ``rate_query`` and print the rating and explanation.
        2. If the rating is below 5, ask ``suggest_rewrites`` for variations, passing
           the chat history held in ``self.memory``.
        3. Show the original and rewritten queries via ``present_query_options``.
        4. Read a comma-separated selection of option numbers from stdin. Empty
           input selects every option; unparsable input falls back to option 0.
        5. Retrieve documents for the selected queries (``k=num_results``).
        6. With several queries, answer via ``generate_combined_answer``. With one,
           rerank against the original query, keep documents scoring at least
           ``self.relevance_threshold`` (or all of them if none do), and answer via
           ``generate_final_answer``.

        If any step raises an ``Exception``, it is logged and the LLM answers
        directly from general knowledge. If that also fails, a fixed apology string
        is returned.

        Args:
            query: User's question.
            num_results: Number of documents to retrieve (passed as ``k``).

        Returns:
            The generated answer text.
        """
        # Local import: normalizes completion-model (str) and chat-model (message)
        # output to plain text in the emergency fallback below.
        from langchain_core.output_parsers import StrOutputParser

        try:
            # Step 1: Rate the original query
            console.print("\n[bold cyan]Evaluating your query...[/bold cyan]")
            rating_result = self.rate_query(query)

            rating_color = "green" if rating_result['rating'] >= 4 else "yellow" if rating_result['rating'] >= 3 else "red"
            console.print(f"[bold {rating_color}]Query Rating: {rating_result['rating']}/5[/bold {rating_color}]")
            console.print(f"Explanation: {rating_result['explanation']}")
            console.print("\n" + "-"*50 + "\n")

            # Only affects the message printed here; the retrieval steps below still run.
            is_general = self.is_general_knowledge_question(query)
            if is_general and self.always_use_fallback:
                console.print("[yellow]This appears to be a general knowledge question. Proceeding with simplified search.[/yellow]")

            # Step 2: Suggest rewrites if the rating is less than perfect
            rewritten_queries = []
            if rating_result["rating"] < 5:
                console.print("[bold cyan]Generating improved query variations...[/bold cyan]")
                # Get chat history for context-aware rewrites
                chat_history = self.memory.load_memory_variables({}).get("chat_history", [])
                rewritten_queries = self.suggest_rewrites(query, chat_history)
                console.print("\n" + "-"*50 + "\n")

            # Step 3: Present options to the user
            query_options = self.present_query_options(query, rewritten_queries)
            console.print("\n" + "-"*50 + "\n")

            # Step 4: Get user selection
            console.print("[bold cyan]Select queries to use:[/bold cyan]")
            console.print("Enter the numbers of the queries you want to use (comma-separated, e.g., '0,2,3')")
            console.print("Or press Enter to use all queries")
            selected_indices_input = input("> ")

            if selected_indices_input.strip() == "":
                # Use all queries if no selection
                selected_indices = list(range(len(query_options)))
            else:
                try:
                    # Skip empty tokens so inputs like "0,2," or "0, ,3" still parse
                    # (int() itself tolerates surrounding whitespace).
                    parsed_indices = [
                        int(token) for token in selected_indices_input.split(",") if token.strip()
                    ]
                    # Drop repeated numbers, keeping first-seen order, so the same
                    # query is not retrieved (and answered) twice.
                    selected_indices = list(dict.fromkeys(parsed_indices))
                except ValueError:
                    console.print("[bold red]Invalid input. Using original query only.[/bold red]")
                    selected_indices = [0]  # Default to original query

            # Get the selected queries (without the prefix and score)
            selected_queries = []
            for idx in selected_indices:
                if not 0 <= idx < len(query_options):
                    console.print(f"[yellow]Ignoring out-of-range selection: {idx}[/yellow]")
                    continue
                query_text = query_options[idx][0]
                # Remove the "Original: " or "Rewrite N: " label
                if "Original: " in query_text:
                    # Strip only the label, not later occurrences inside the query text.
                    query_text = query_text.replace("Original: ", "", 1)
                elif "Rewrite " in query_text:
                    query_text = query_text.split(": ", 1)[1] if ": " in query_text else query_text
                selected_queries.append(query_text)

            if not selected_queries:
                console.print("[bold red]No valid queries selected. Using original query.[/bold red]")
                selected_queries = [query]

            console.print(f"\n[bold green]Selected {len(selected_queries)} queries for retrieval.[/bold green]")
            console.print("\n" + "-"*50 + "\n")

            # Step 5: Retrieve documents
            console.print("[bold cyan]Retrieving relevant documents...[/bold cyan]")
            docs, query_docs_map = self.retrieve_documents(selected_queries, k=num_results)

            # Step 6: Generate answer
            if len(selected_queries) > 1:
                console.print("\n[bold cyan]Generating combined answer based on multiple query results...[/bold cyan]")
                final_answer = self.generate_combined_answer(
                    query,
                    query_docs_map,
                    fallback_to_llm=True  # Always enable fallback
                )
            else:
                # For single query, rerank and use the traditional approach
                console.print("\n[bold cyan]Reranking documents based on relevance to the original query...[/bold cyan]")
                reranked_docs = self.rerank_docs(query, docs)

                # Apply relevance threshold
                filtered_docs = [
                    doc for doc in reranked_docs if doc.score >= self.relevance_threshold
                ]

                # Fall back to every reranked document when none clears the threshold
                if not filtered_docs:
                    console.print("[yellow]No documents met the relevance threshold. Using all retrieved documents but may fall back to LLM knowledge.[/yellow]")
                    filtered_docs = reranked_docs

                # Generate the final answer with fallback enabled
                console.print("\n[bold cyan]Generating answer...[/bold cyan]")
                final_answer = self.generate_final_answer(
                    query,
                    filtered_docs,
                    fallback_to_llm=True  # Always enable fallback
                )

            return final_answer

        except Exception as e:
            logger.error(f"Error in search pipeline: {e}")
            # Fall back to direct LLM answer if the pipeline fails
            console.print("[bold red]Error in search pipeline. Falling back to LLM.[/bold red]")

            # Create a direct LLM response as emergency fallback
            emergency_prompt = PromptTemplate(
                input_variables=["question", "chat_history"],
                template="""You are a helpful scientific research assistant.
                
                There was an error retrieving information from our research database.
                Please answer the question based on your general knowledge.
                
                {chat_history}
                
                Question: {question}
                
                Answer:"""
            )

            try:
                # The memory read lives inside this try: if memory is what failed,
                # we still reach the apology below instead of raising out of the handler.
                chat_history = self.memory.load_memory_variables({}).get("chat_history", [])
                chat_context = ""
                if chat_history:
                    chat_context = "Previous conversation:\n"
                    for message in chat_history[-3:]:
                        if hasattr(message, "content"):
                            role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                            chat_context += f"{role}: {message.content}\n"

                answer = (emergency_prompt | self.llm | StrOutputParser()).invoke({
                    "question": query,
                    "chat_history": chat_context
                })
                return answer
            except Exception as fallback_error:
                # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
                logger.error(f"Emergency LLM fallback failed: {fallback_error}")
                return "I apologize, but I'm having technical difficulties. Please try again with a different question."

# Convenience factory: builds a ResearchAssistant with the settings used in this notebook
def create_research_assistant(
    llm,
    vector_db,
    embeddings_model=None,
    relevance_threshold: float = 0.5,
    max_docs_per_query: int = 3,
    always_use_fallback: bool = True,
):
    """Build a ``ResearchAssistant`` with this notebook's default tuning.

    The tuning values used to be hard-coded; they are now keyword parameters whose
    defaults reproduce the previous configuration exactly.

    Args:
        llm: Language model forwarded to ``ResearchAssistant``.
        vector_db: Vector store forwarded to ``ResearchAssistant``.
        embeddings_model: Optional embeddings model, forwarded unchanged (default None).
        relevance_threshold: Minimum rerank score a document needs to be kept on the
            single-query answer path. The default 0.5 is deliberately lenient.
        max_docs_per_query: Number of top documents per query included when answers
            are combined across several queries.
        always_use_fallback: Forwarded as ``ResearchAssistant.always_use_fallback``.
            The flag is consulted when a question is classified as general knowledge.

    Returns:
        The configured ``ResearchAssistant`` instance.
    """
    return ResearchAssistant(
        llm=llm,
        vector_db=vector_db,
        embeddings_model=embeddings_model,
        relevance_threshold=relevance_threshold,
        max_docs_per_query=max_docs_per_query,
        always_use_fallback=always_use_fallback,
    )

# Interactive command-line interface
def run_cli(assistant):
    """Run an interactive prompt loop around ``assistant.answer_query``.

    Each non-empty line the user types (surrounding whitespace stripped) is passed to
    ``assistant.answer_query``. The result is rendered as Markdown inside a ``Panel``.
    The loop ends on ``exit``/``quit``/``q`` (case-insensitive), on end-of-input
    (Ctrl-D), or on Ctrl-C at the prompt. Errors raised while answering are printed,
    and the loop continues.

    Args:
        assistant: Object exposing ``answer_query(question)``, e.g. the value returned
            by ``create_research_assistant``.
    """
    console.print("[bold green]Research Assistant CLI[/bold green]")
    console.print("Type 'exit' to quit\n")

    while True:
        try:
            # console.input renders the markup in the prompt; the builtin input()
            # would print the literal "[bold blue]" tags.
            question = console.input("\n[bold blue]Ask a question:[/bold blue] ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leave cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            console.print()
            break

        if not question:
            continue  # Don't send blank lines through the retrieval pipeline
        if question.lower() in ('exit', 'quit', 'q'):
            break

        try:
            answer = assistant.answer_query(question)
            console.print("\n[bold green]Answer:[/bold green]")
            console.print(Panel(Markdown(answer), width=100))
        except Exception as e:
            console.print(f"[bold red]Error:[/bold red] {e}")

# Usage example:
if __name__ == "__main__":
    # NOTE(review): inside a Jupyter/IPython kernel __name__ is "__main__", so this
    # branch also runs (and prints) whenever the cell is executed.
    #
    # To use the assistant, build the components first (or reuse the `llm` and
    # `vector_db` created in the setup cell), e.g.:
    # from langchain_openai import ChatOpenAI
    # from langchain_community.vectorstores import Chroma
    #
    # llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    # vector_db = Chroma(embedding_function=...)
    #
    # assistant = create_research_assistant(llm, vector_db)
    # run_cli(assistant)
    print("Import this module to use the ResearchAssistant class")