In [None]:
import getpass
import os

# LangSmith tracing configuration (non-secret values).
os.environ['LANGCHAIN_TRACING_V2'] = "true"
os.environ["LANGCHAIN_PROJECT"] = "ResearchPro2"
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'

# SECURITY: API keys must never be stored in the notebook source. The keys that
# were previously hardcoded in this cell have been exposed and should be
# revoked/rotated with their providers.
# Each secret is taken from the process environment when it is already set;
# otherwise the user is prompted without echoing the value.
# NOTE(review): no code visible in this notebook calls OpenAI (the LLM is served
# by Ollama) - confirm whether OPENAI_API_KEY is still needed.
for _secret_name in ("LANGCHAIN_API_KEY", "OPENAI_API_KEY"):
    if not os.environ.get(_secret_name):
        os.environ[_secret_name] = getpass.getpass(f"Enter {_secret_name}: ")


In [None]:
# Import necessary packages (they must already be installed in the kernel's environment)
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_ollama import OllamaLLM
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from sentence_transformers import SentenceTransformer, util
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain


from typing import List, Literal
import re
import os


In [None]:
# --- Filesystem locations -------------------------------------------------
PDF_DIR = "./arxiv_data"
DB_DIR = "./arxiv_vector1_db"

# --- Vector store ---------------------------------------------------------
# The embedding function handed to Chroma, which is opened on DB_DIR.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
vector_db = Chroma(
    persist_directory=DB_DIR,
    embedding_function=embedding_model,
)

# --- Scoring model --------------------------------------------------------
# A SentenceTransformer instance loaded a single time, intended for
# similarity scoring.
semantic_model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Language model -------------------------------------------------------
llm = OllamaLLM(model="llama3.2")

In [None]:
import os
import re
from typing import Dict, List, Tuple, Optional
import torch
from sentence_transformers import SentenceTransformer, util
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, BaseOutputParser
from langchain.schema import Document, HumanMessage, AIMessage
from langchain.vectorstores import VectorStore
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.schema.runnable import Runnable
from langchain.memory import ConversationBufferMemory

# Output Parser
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Extracts the items of a numbered list from raw LLM text.

    A line counts as a list item when it starts (after optional spaces/tabs)
    with an integer followed by ``.`` or ``)`` and at least one whitespace
    character, e.g. ``1. foo``, ``2) bar`` or ``   3. baz``. All other lines
    (preambles, blank lines, closing remarks) are ignored.
    """

    def parse(self, text: str) -> List[str]:
        """Return the stripped text of every numbered item, in order of appearance.

        Items whose text is empty after stripping are dropped. Returns an empty
        list when ``text`` contains no numbered lines.
        """
        # `[ \t]*` (rather than `\s*`) keeps the match on a single line under MULTILINE.
        items = re.findall(r"^[ \t]*\d+[.)]\s+(.*)", text, re.MULTILINE)
        return [item.strip() for item in items if item.strip()]

class ResearchAssistant:
    """Interactive retrieval-augmented question answering over a vector store.

    The assistant rates a user query with the LLM, proposes rewrites, lets the
    user choose which formulations to search with (via ``input()``), retrieves
    documents from ``vector_db`` and asks the LLM to answer from them. Every
    question/answer pair is appended to an in-memory conversation buffer that is
    fed back into later prompts.
    """

    # Words/phrases that, when present as whole words, mark a query as a likely follow-up.
    _FOLLOWUP_INDICATORS = (
        "it", "this", "that", "they", "them", "these", "those",
        "previous", "earlier", "above", "mentioned",
        "what about", "how about", "tell me more", "elaborate",
        "why", "how does", "can you explain",
    )
    # Word-boundary anchored so that e.g. "it" does not match inside "architecture".
    _FOLLOWUP_RE = re.compile(
        r"\b(?:" + "|".join(map(re.escape, _FOLLOWUP_INDICATORS)) + r")\b"
    )

    _SEPARATOR = "\n" + "-"*50 + "\n"

    def __init__(self, llm: BaseLLM, vector_db: VectorStore, embeddings_model: Optional[SentenceTransformer] = None):
        """Initialize the Research Assistant with LLM and vector database.

        Args:
            llm: Language model used for rating, rewriting and answering.
            vector_db: Vector store that documents are retrieved from.
            embeddings_model: SentenceTransformer used for similarity scoring.
                When omitted, 'all-MiniLM-L6-v2' is loaded.
        """
        self.llm = llm
        self.vector_db = vector_db

        # Initialize sentence transformer model for semantic similarity if not provided
        if embeddings_model is None:
            self.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
        else:
            self.semantic_model = embeddings_model

        # Initialize memory - we'll manage it manually instead of using ConversationalRetrievalChain
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )

        # Create a retriever for direct use
        self.retriever = vector_db.as_retriever()

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _coerce_rating(value, default: int = 1) -> int:
        """Turn an LLM-supplied rating (int, float, "4", "4/5", ...) into an int in 1..5.

        The first run of digits in ``str(value)`` is used; ``default`` is returned
        when there is none (e.g. ``None`` or free text).
        """
        found = re.search(r"\d+", str(value))
        if found is None:
            return default
        return max(1, min(5, int(found.group())))

    @staticmethod
    def _parse_selection(raw: str, num_options: int) -> List[int]:
        """Parse a comma-separated index list typed by the user.

        Blank tokens are skipped, indices outside ``0..num_options-1`` (including
        negative ones) are ignored, and repeats are dropped while keeping first-seen
        order.

        Raises:
            ValueError: if a non-blank token is not an integer.
        """
        chosen: List[int] = []
        for token in raw.split(","):
            token = token.strip()
            if not token:
                continue
            index = int(token)  # ValueError propagates to the caller
            if 0 <= index < num_options and index not in chosen:
                chosen.append(index)
        return chosen

    def _chat_messages(self) -> list:
        """Return the messages currently stored in conversation memory (may be empty)."""
        return self.memory.load_memory_variables({}).get("chat_history", [])

    @staticmethod
    def _render_history(messages, header: str = "") -> str:
        """Format messages as 'Human: ...' / 'Assistant: ...' lines, prefixed by ``header``.

        Returns an empty string when there are no messages.
        """
        if not messages:
            return ""
        rendered = header
        for message in messages:
            if hasattr(message, "content"):
                role = "Human" if isinstance(message, HumanMessage) else "Assistant"
                rendered += f"{role}: {message.content}\n"
        return rendered

    def _remember(self, question: str, answer: str) -> None:
        """Append a question/answer pair to conversation memory."""
        self.memory.chat_memory.add_messages([
            HumanMessage(content=question),
            AIMessage(content=answer)
        ])

    def _display_query_options(self, original: str, ranked: List[Tuple[str, float]]) -> List[Tuple[str, float]]:
        """Print the original query (index 0) followed by ``ranked`` rewrites; return the labelled list."""
        scored_queries = [("Original: " + original, 1.0)]
        for position, (text, score) in enumerate(ranked, 1):
            scored_queries.append((f"Rewrite {position}: {text}", score))

        print("Available search queries (with confidence scores):")
        for position, (label, score) in enumerate(scored_queries):
            print(f"{position}. {label} [Confidence: {score:.2f}]")

        return scored_queries

    def _rate_and_select_queries(self, query: str, banner: str) -> Tuple[List[str], Optional[str]]:
        """Run the shared rate -> rewrite -> present -> select steps.

        Returns ``(selected_queries, None)`` on success, or ``([], message)`` when
        the user's selection is invalid or empty; ``message`` is the text the
        public pipeline methods return in that case.
        """
        # Step 1: rate the query
        print(banner)
        rating_result = self.rate_query(query)
        print(f"Query Rating: {rating_result['rating']}/5")
        print(f"Explanation: {rating_result['explanation']}")
        print(self._SEPARATOR)

        # Step 2: suggest rewrites if the rating is less than perfect
        rewritten_queries: List[str] = []
        if rating_result["rating"] < 5:
            print("Generating improved query variations...")
            rewritten_queries = self.suggest_rewrites(query)
            print(self._SEPARATOR)

        # Step 3: present options. The rewrites are shown sorted by confidence, so
        # the selectable candidates must be kept in that same order - indexing the
        # unsorted rewrite list would return a different query than the one shown.
        ranked = self.score_queries(query, rewritten_queries)
        self._display_query_options(query, ranked)
        print(self._SEPARATOR)
        candidates = [query] + [text for text, _ in ranked]

        # Step 4: get the user's selection
        raw_selection = input("Enter the numbers of the queries you want to use (comma-separated, e.g., '0,2,3'): ")
        try:
            indices = self._parse_selection(raw_selection, len(candidates))
        except ValueError:
            print("Invalid input. Please enter numbers separated by commas.")
            return [], "Invalid query selection. Please try again."

        # dict.fromkeys drops identical query texts while preserving order.
        selected_queries = list(dict.fromkeys(candidates[i] for i in indices))
        if not selected_queries:
            print("No valid queries selected.")
            return [], "No valid queries were selected. Please try again."

        print(f"\nSelected {len(selected_queries)} queries for retrieval.")
        print(self._SEPARATOR)
        return selected_queries, None

    # ------------------------------------------------------------------
    # Query analysis
    # ------------------------------------------------------------------
    def rate_query(self, query: str) -> Dict[str, str | int]:
        """Rates the query and gives an explanation.

        Returns a dict that always has an integer ``"rating"`` in 1..5 and an
        ``"explanation"`` entry. If the LLM's JSON lacks a usable rating, 1 is used
        so that rewrites are still offered.
        """
        prompt = PromptTemplate.from_template("""
        You are an intelligent assistant trained to evaluate search queries.
        Given the following query: "{query}"
        1. Rate the query on a scale of 1 (very poor) to 5 (excellent) based on:
           - Clarity
           - Specificity
           - Relevance
           - Retrievability
        2. Provide a short explanation for the rating (if it is vague, incomplete, or inefficient).
        Respond in JSON format:
        {{
          "rating": 3,
          "explanation": "Too vague, lacks keywords."
        }}
        """)
        parser = JsonOutputParser()
        chain: Runnable = prompt | self.llm | parser
        raw_result = chain.invoke({"query": query})

        # The model may return the rating as a string, omit keys, or return non-dict JSON.
        result = dict(raw_result) if isinstance(raw_result, dict) else {}
        result["rating"] = self._coerce_rating(result.get("rating"))
        result.setdefault("explanation", "No explanation provided.")
        return result

    def suggest_rewrites(self, query: str) -> List[str]:
        """Returns 5 rephrased versions of the query (as many as the LLM output yields)."""
        prompt = PromptTemplate(
            input_variables=["question"],
            template="""You are an AI assistant. Rephrase the question in 5 different ways to improve retrieval from a document store. Number each version starting with 1.\nQuestion: {question}"""
        )
        return (prompt | self.llm | LineListOutputParser()).invoke({"question": query})

    def compute_confidence(self, original: str, rewritten: str) -> float:
        """Computes semantic similarity between original and rewritten queries."""
        vec_orig = self.semantic_model.encode(original, convert_to_tensor=True)
        vec_rewrite = self.semantic_model.encode(rewritten, convert_to_tensor=True)
        return util.pytorch_cos_sim(vec_orig, vec_rewrite).item()

    def score_queries(self, original: str, queries: List[str]) -> List[Tuple[str, float]]:
        """Scores and sorts rephrased queries by confidence (highest first)."""
        scored = [(q, self.compute_confidence(original, q)) for q in queries]
        scored.sort(key=lambda x: x[1], reverse=True)
        return scored

    def present_query_options(self, original: str, queries: List[str]) -> List[Tuple[str, float]]:
        """Presents the original and rewritten queries with confidence scores.

        Index 0 is the original query; rewrites follow sorted by descending
        confidence. Returns the printed ``(label, score)`` pairs.
        """
        return self._display_query_options(original, self.score_queries(original, queries))

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------
    def retrieve_documents(self, queries: List[str], k: int = 5) -> Tuple[List[Document], Dict[str, List[Document]]]:
        """
        Retrieves documents using multiple queries, preserving which query returned which documents.
        Shows relevance scores for better transparency.

        Returns:
            (all_docs, query_docs_map): ``all_docs`` holds each distinct
            ``page_content`` once, in first-seen order; ``query_docs_map`` maps every
            query to its own (possibly overlapping) result list.
        """
        all_docs = []
        seen_contents = set()
        query_docs_map = {}

        # Retrieve docs for each query
        for query in queries:
            print(f"\nQuery: '{query}'")
            results = self.vector_db.similarity_search_with_relevance_scores(query, k=k)
            docs = []

            for i, (doc, score) in enumerate(results, 1):
                print(f"--- Result {i} ---")
                print(f"Score: {score:.4f}")
                print(f"Chunk:\n{doc.page_content[:150]}...")
                if hasattr(doc, 'metadata') and doc.metadata:
                    source = doc.metadata.get('source', 'Unknown')
                    print(f"Source: {source}")
                print()

                docs.append(doc)

                # Add to our overall collection with deduplication
                if doc.page_content not in seen_contents:
                    seen_contents.add(doc.page_content)
                    all_docs.append(doc)

            query_docs_map[query] = docs

        print(f"\nTotal unique documents: {len(all_docs)}")
        return all_docs, query_docs_map

    def rerank_docs(self, query: str, docs: List[Document]) -> List[Tuple[Document, float]]:
        """Reranks documents based on semantic similarity to the query (highest first)."""
        query_embedding = self.semantic_model.encode(query, convert_to_tensor=True)
        reranked = []
        for doc in docs:
            doc_embedding = self.semantic_model.encode(doc.page_content, convert_to_tensor=True)
            score = util.pytorch_cos_sim(query_embedding, doc_embedding).item()
            reranked.append((doc, score))
        return sorted(reranked, key=lambda x: x[1], reverse=True)

    # ------------------------------------------------------------------
    # Answer generation
    # ------------------------------------------------------------------
    def can_answer_without_retrieval(self, question: str) -> Tuple[bool, Optional[str]]:
        """
        Determines if a question can be answered directly from memory without retrieval.
        Returns a tuple: (can_answer, answer_if_available)

        With an empty conversation memory this returns ``(False, None)`` without
        calling the LLM.
        """
        chat_history = self._chat_messages()
        if not chat_history:
            return False, None

        # Create the prompt to check if we can answer directly from the conversation
        prompt = PromptTemplate.from_template("""
        You are an AI assistant helping with a conversation.
        Given the following conversation history and a new question, determine if the question:
        1. Can be answered directly based ONLY on the conversation history (like "what was my previous question?")
        2. Does NOT require retrieving new information from documents
        
        If both conditions are true, provide the answer. Otherwise, respond with "NEEDS_RETRIEVAL".
        
        Conversation History:
        {chat_history}
        
        New Question: {question}
        
        Your assessment (answer directly or respond with "NEEDS_RETRIEVAL"):
        """)

        history_str = self._render_history(chat_history)

        # Ask the LLM if this can be answered without retrieval
        response = self.llm.invoke(
            prompt.format(chat_history=history_str, question=question)
        )

        # The sentinel anywhere in the reply means retrieval is required.
        if "NEEDS_RETRIEVAL" in response:
            return False, None
        return True, response

    def generate_final_answer(self, question: str, docs: List[Document], max_docs: int = 5) -> str:
        """Generates the final answer from retrieved documents.

        Only the first ``max_docs`` documents are placed in the prompt. The
        question/answer pair is appended to conversation memory before returning.
        """
        chat_context = self._render_history(self._chat_messages(), header="Previous conversation:\n")

        prompt = PromptTemplate(
            input_variables=["question", "context", "chat_history"],
            template="""You are a helpful research assistant. 
            Use the following retrieved context chunks to answer the user's question thoroughly.
            If the information isn't in the context, indicate what's missing rather than making up information.
            If the context contains conflicting or uncertain information, highlight the disagreement. 
            Do not fabricate any facts not grounded in the provided context.
            
            {chat_history}
            
            Context:
            {context}
            
            Question: {question}
            
            Answer:"""
        )

        context = "\n\n---\n\n".join([f"Document {i+1}:\n{doc.page_content}"
                                     for i, doc in enumerate(docs[:max_docs])])

        answer = (prompt | self.llm).invoke({
            "question": question,
            "context": context,
            "chat_history": chat_context
        })

        self._remember(question, answer)
        return answer

    def generate_combined_answer(self, question: str, query_results: Dict[str, List[Document]], max_docs_per_query: int = 3) -> str:
        """Generates a combined answer from multiple query results.

        For each query the first ``max_docs_per_query`` documents are included
        under a per-query heading. The question/answer pair is appended to
        conversation memory before returning.
        """
        chat_context = self._render_history(self._chat_messages(), header="Previous conversation:\n")

        # Combine all relevant documents from different queries
        all_context_sections = []
        for query, docs in query_results.items():
            # Use only the top N documents for each query
            docs_for_query = docs[:max_docs_per_query]
            if docs_for_query:
                all_context_sections.append(f"Results for query: '{query}'")
                for i, doc in enumerate(docs_for_query, 1):
                    all_context_sections.append(f"Document {i}:\n{doc.page_content}")

        # Join all context sections
        combined_context = "\n\n---\n\n".join(all_context_sections)

        # Create a prompt that emphasizes synthesizing information across different query results
        prompt = PromptTemplate(
            input_variables=["question", "context", "chat_history"],
            template="""You are a helpful research assistant. 
            The user's question has been reformulated in several ways, and each formulation returned different documents.
            Use the following retrieved context chunks from ALL query variations to comprehensively answer the user's question.
            Synthesize information across all retrieved documents to provide the most complete answer possible.
            If the information isn't in the context, indicate what's missing rather than making up information.
            If the context contains conflicting information from different query results, highlight the disagreement.
            Do not fabricate any facts not grounded in the provided context.
            
            {chat_history}
            
            Context from multiple query formulations:
            {context}
            
            Original Question: {question}
            
            Answer:"""
        )

        answer = (prompt | self.llm).invoke({
            "question": question,
            "context": combined_context,
            "chat_history": chat_context
        })

        self._remember(question, answer)
        return answer

    # ------------------------------------------------------------------
    # Pipelines
    # ------------------------------------------------------------------
    def search_with_query_feedback(self, query: str, num_results: int = 5, relevance_threshold: float = 0.6) -> str:
        """Main pipeline that processes a query and returns an answer.

        Args:
            query: The user's question.
            num_results: Documents fetched per selected query, and the cap on
                documents used for a single-query answer.
            relevance_threshold: Minimum reranking similarity a document needs in
                the single-query path; if nothing reaches it, all retrieved
                documents are used.

        Returns:
            The generated answer, or a short message when the user's query
            selection was invalid or empty.
        """
        # Steps 1-4: rate, rewrite, present and select
        selected_queries, error_message = self._rate_and_select_queries(query, "Evaluating your query...")
        if error_message is not None:
            return error_message

        # Step 5: Retrieve documents
        print("Retrieving relevant documents...")
        docs, query_docs_map = self.retrieve_documents(selected_queries, k=num_results)

        # Step 6: Generate answer
        # If using multiple queries, use the combined answer approach
        if len(selected_queries) > 1:
            print("\nGenerating combined answer based on multiple query results...")
            return self.generate_combined_answer(query, query_docs_map, max_docs_per_query=3)

        # For single query, rerank and use the traditional approach
        print("\nReranking documents based on relevance to the original query...")
        reranked_docs_with_scores = self.rerank_docs(query, docs)

        # Apply relevance threshold (order from rerank_docs is preserved by the filter)
        filtered_docs_with_scores = [
            (doc, score) for doc, score in reranked_docs_with_scores if score >= relevance_threshold
        ]

        # If no documents meet the threshold, use all documents
        if not filtered_docs_with_scores:
            print("No documents met the relevance threshold. Using all retrieved documents.")
            filtered_docs_with_scores = reranked_docs_with_scores

        top_docs_with_scores = filtered_docs_with_scores[:num_results]

        # Display top reranked documents
        print("\nTop reranked documents:")
        for i, (doc, score) in enumerate(top_docs_with_scores, 1):
            print(f"{i}. Score: {score:.4f}")
            print(f"   Preview: {doc.page_content[:150]}...")
            if hasattr(doc, 'metadata') and doc.metadata:
                print(f"   Source: {doc.metadata.get('source', 'Unknown')}")

        # Generate the final answer; max_docs follows num_results so that a
        # caller asking for more than 5 documents is not silently capped at 5.
        print("\nGenerating your answer based on retrieved documents...")
        return self.generate_final_answer(
            query, [doc for doc, _ in top_docs_with_scores], max_docs=num_results
        )

    def ask_with_memory(self, question: str, num_results: int = 5) -> str:
        """
        Uses conversation memory to process follow-up questions.
        Now includes query improvement and document retrieval.

        If the LLM judges the question answerable from the conversation alone,
        that answer is stored and returned without retrieval. In the
        single-query path ``num_results`` caps how many retrieved documents are
        previewed and used; how many are fetched is governed by ``self.retriever``.
        """
        # First, check if this is a question we can answer directly from memory
        can_direct_answer, direct_answer = self.can_answer_without_retrieval(question)
        if can_direct_answer:
            print("Question can be answered directly from conversation history...")
            self._remember(question, direct_answer)
            return direct_answer

        # Otherwise, use the full query improvement pipeline (steps 1-4)
        selected_queries, error_message = self._rate_and_select_queries(
            question, "Evaluating your query (with memory context)..."
        )
        if error_message is not None:
            return error_message

        # Step 5: Retrieve documents
        print("Retrieving relevant documents...")
        if len(selected_queries) > 1:
            # If multiple queries were selected, use the multi-query approach
            _, query_docs_map = self.retrieve_documents(selected_queries, k=num_results)

            # Generate a combined answer
            print("\nGenerating combined answer with conversation context...")
            return self.generate_combined_answer(question, query_docs_map)

        # If only one query was selected, use the standard approach.
        # `invoke` is the Runnable entry point that supersedes the deprecated
        # `get_relevant_documents`.
        docs = self.retriever.invoke(selected_queries[0])

        print(f"Retrieved {len(docs)} documents")

        # Show previews of the retrieved documents
        print("\nTop retrieved documents:")
        for i, doc in enumerate(docs[:num_results], 1):
            print(f"{i}. Preview: {doc.page_content[:150]}...")
            if hasattr(doc, 'metadata') and doc.metadata:
                print(f"   Source: {doc.metadata.get('source', 'Unknown')}")

        # Generate answer with memory context
        print("\nGenerating answer with conversation context...")
        return self.generate_final_answer(question, docs, max_docs=num_results)

    def process_query(self, query: str, use_memory: bool = False, num_results: int = 5) -> str:
        """Main entry point that decides whether to use memory or the full pipeline.

        The memory path is taken when ``use_memory`` is true, or when there is
        conversation history and the query looks like a follow-up.
        """
        chat_history = self._chat_messages()

        if use_memory or (chat_history and self._is_likely_followup(query)):
            print("Using conversation memory to process this query...")
            return self.ask_with_memory(query, num_results)

        print("Using full query improvement pipeline...")
        return self.search_with_query_feedback(query, num_results)

    def _is_likely_followup(self, query: str) -> bool:
        """Heuristically determines if a query is likely a follow-up question.

        True when the lower-cased query contains one of the indicator words or
        phrases as a whole word/phrase (not as a fragment of a longer word).
        """
        return self._FOLLOWUP_RE.search(query.lower()) is not None

# Example usage
def main():
    """Run an interactive console loop on top of ``ResearchAssistant``.

    Reads queries with ``input()`` until the user types 'exit' (case-insensitive,
    surrounding whitespace ignored) or ends the input stream / interrupts.
    Blank queries are skipped. Uses the module-level ``llm``, ``vector_db`` and
    ``semantic_model`` objects.
    """
    # Reuse the already-loaded scoring model instead of letting the assistant
    # load a second copy of the same SentenceTransformer.
    assistant = ResearchAssistant(llm=llm, vector_db=vector_db, embeddings_model=semantic_model)

    # Main interaction loop
    while True:
        try:
            query = input("\nEnter your query (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted kernel ends the session cleanly.
            break

        if query.lower() == 'exit':
            break
        if not query:
            # Nothing to search for; ask again rather than sending an empty query to the LLM.
            continue

        # Ask the user whether to use memory mode or full pipeline
        use_memory = input("Use conversation memory? (y/n): ").strip().lower() == 'y'

        # Process the query
        answer = assistant.process_query(query, use_memory=use_memory)
        # Only print the final answer once
        print("\nFinal Answer:", answer)

if __name__ == "__main__":
    main()