# Experiments: retrieval

**Original File:** `src/document_chat/retrieval.py`

## Purpose
This module implements a Conversational RAG (Retrieval-Augmented Generation) system using LangChain Expression Language (LCEL). It enables context-aware question answering over documents stored in a FAISS vector index, with support for maintaining chat history.

## Key Components
- **ConversationalRAG class**: Main RAG pipeline with lazy retriever initialization
  - `load_retriever_from_faiss()`: Load FAISS vectorstore and build retriever
  - `invoke()`: Execute the RAG pipeline with user input and chat history
  - `_build_lcel_chain()`: Build the LCEL processing chain

## Prerequisites
- `langchain`, `langchain-core`, `langchain-community`, `faiss-cpu`, and `python-dotenv` installed
- A FAISS vector index already created (see `data_ingestion.py`)
- Environment variables configured (API keys for the LLM provider)
- Prompt templates available in `prompt/prompt_library.py`

## Instructions & Setup Guide

### Execution Order
1. Run the imports cell
2. Run the `ConversationalRAG` class definition cell, then the cells that attach the public methods
3. Initialize the RAG system with a session ID
4. Load a FAISS index using `load_retriever_from_faiss()`
5. Invoke the RAG with questions and chat history

### Dependencies
```bash
pip install langchain langchain-core langchain-community faiss-cpu python-dotenv
```

### Configuration
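- Ensure the `.env` file contains `GROQ_API_KEY` and/or `GOOGLE_API_KEY`
- The FAISS index must exist at the path passed to `load_retriever_from_faiss()`
- Run from the project root directory so the project imports (`utils`, `exception`, `logger`, `prompt`, `model`) resolve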
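
A quick pre-flight check can catch a missing key before the LLM is loaded. The sketch below is only one way to verify the configuration: `available_llm_keys` is a hypothetical helper, not part of the project. It accepts any mapping, so in the notebook you could pass `os.environ` after calling `load_dotenv()` from `python-dotenv`.

```python
import os
from typing import List, Mapping

# Key names taken from the Configuration list above.
LLM_KEY_NAMES = ("GROQ_API_KEY", "GOOGLE_API_KEY")


def available_llm_keys(env: Mapping[str, str]) -> List[str]:
    """Return the names of the LLM API keys that are set and non-empty."""
    return [name for name in LLM_KEY_NAMES if env.get(name, "").strip()]


# In the notebook: call load_dotenv() first, then inspect os.environ.
found = available_llm_keys(os.environ)
if found:
    print("Configured keys:", ", ".join(found))
else:
    print("No LLM API key found; set GROQ_API_KEY and/or GOOGLE_API_KEY in .env")
```

Because the configuration says "and/or", the helper reports which keys are present rather than requiring both.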

## 1. Imports and Dependencies

Import all required modules for the Conversational RAG system.

In [None]:
import sys
import os
from operator import itemgetter
from typing import List, Optional, Dict, Any

# LangChain imports
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS

# Project imports
from utils.model_loader import ModelLoader
from exception.custom_exception import DocumentPortalException
from logger import GLOBAL_LOGGER as log
from prompt.prompt_library import PROMPT_REGISTRY
from model.models import PromptType

print("All imports successful!")

## 2. ConversationalRAG Class Definition

The main class that orchestrates the RAG pipeline. Key features:
- **Lazy initialization**: The retriever and chain are built only once a retriever is supplied, either through the constructor or by calling `load_retriever_from_faiss()`
- **LCEL-based**: Uses LangChain Expression Language for composable chains
- **Context-aware**: Rewrites questions based on chat history before retrieval
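
To see why rewriting the question before retrieval matters, here is a minimal plain-Python sketch of the same data flow. It does not use LangChain: `run_pipeline` and the `toy_*` functions are hypothetical stand-ins for the LLM, the retriever, and the prompts, and the `(role, text)` tuples stand in for chat messages.

```python
from typing import Callable, List, Sequence, Tuple

History = Sequence[Tuple[str, str]]  # (role, text) pairs; stand-in for chat messages


def run_pipeline(
    question: str,
    history: History,
    rewrite: Callable[[str, History], str],
    retrieve: Callable[[str], List[str]],
    answer: Callable[[str, str, History], str],
) -> str:
    # Step 1: turn the question into a standalone query using the history.
    standalone = rewrite(question, history)
    # Step 2: retrieve with the rewritten query, then join the chunks.
    context = "\n\n".join(retrieve(standalone))
    # Step 3: answer from the context, the ORIGINAL question, and the history.
    return answer(context, question, history)


# Toy stand-ins so the sketch runs without an LLM or a vector index.
def toy_rewrite(question: str, history: History) -> str:
    last_user = next((text for role, text in reversed(history) if role == "human"), "")
    return f"{question} (topic: {last_user})" if last_user else question


def toy_retrieve(query: str) -> List[str]:
    corpus = {
        "faiss": "FAISS stores dense vectors.",
        "lcel": "LCEL composes runnables with the | operator.",
    }
    return [doc for key, doc in corpus.items() if key in query.lower()]


def toy_answer(context: str, question: str, history: History) -> str:
    return f"Based on: {context}"


history = [("human", "Tell me about FAISS."), ("ai", "FAISS is a vector index.")]
question = "How does it store data?"

with_history = run_pipeline(question, history, toy_rewrite, toy_retrieve, toy_answer)
without_history = run_pipeline(question, [], toy_rewrite, toy_retrieve, toy_answer)
print(with_history)
print(without_history)
```

In this toy setup the follow-up question alone contains no retrievable keyword, so retrieval only succeeds once the history has been folded into the query.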

In [None]:
class ConversationalRAG:
    """
    LCEL-based Conversational RAG with lazy retriever initialization.

    Usage:
        rag = ConversationalRAG(session_id="abc")
        rag.load_retriever_from_faiss(index_path="faiss_index/abc", k=5, index_name="index")
        answer = rag.invoke("What is ...?", chat_history=[])
    """

    def __init__(self, session_id: Optional[str], retriever=None):
        try:
            self.session_id = session_id

            # Load LLM and prompts once
            self.llm = self._load_llm()
            self.contextualize_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
                PromptType.CONTEXTUALIZE_QUESTION.value
            ]
            self.qa_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
                PromptType.CONTEXT_QA.value
            ]

            # Lazy pieces
            self.retriever = retriever
            self.chain = None
            if self.retriever is not None:
                self._build_lcel_chain()

            log.info("ConversationalRAG initialized", session_id=self.session_id)
        except Exception as e:
            log.error("Failed to initialize ConversationalRAG", error=str(e))
            raise DocumentPortalException("Initialization error in ConversationalRAG", sys)

    # ---------- Private Methods ----------

    def _load_llm(self):
        """Load the LLM from ModelLoader."""
        try:
            llm = ModelLoader().load_llm()
            if not llm:
                raise ValueError("LLM could not be loaded")
            log.info("LLM loaded successfully", session_id=self.session_id)
            return llm
        except Exception as e:
            log.error("Failed to load LLM", error=str(e))
            raise DocumentPortalException("LLM loading error in ConversationalRAG", sys)

    @staticmethod
    def _format_docs(docs) -> str:
        """Format retrieved documents into a single string."""
        return "\n\n".join(getattr(d, "page_content", str(d)) for d in docs)

    def _build_lcel_chain(self):
        """Build the LCEL chain for conversational RAG."""
        try:
            if self.retriever is None:
                raise DocumentPortalException("No retriever set before building chain", sys)

            # Step 1: Rewrite user question with chat history context
            question_rewriter = (
                {"input": itemgetter("input"), "chat_history": itemgetter("chat_history")}
                | self.contextualize_prompt
                | self.llm
                | StrOutputParser()
            )

            # Step 2: Retrieve docs for rewritten question
            retrieve_docs = question_rewriter | self.retriever | self._format_docs

            # Step 3: Answer using retrieved context + original input + chat history
            self.chain = (
                {
                    "context": retrieve_docs,
                    "input": itemgetter("input"),
                    "chat_history": itemgetter("chat_history"),
                }
                | self.qa_prompt
                | self.llm
                | StrOutputParser()
            )

            log.info("LCEL graph built successfully", session_id=self.session_id)
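        except DocumentPortalException:
            # Already a project exception (e.g. missing retriever); keep its message.
            raise
        except Exception as e:
            log.error("Failed to build LCEL chain", error=str(e), session_id=self.session_id)
            raise DocumentPortalException("Failed to build LCEL chain", sys)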

print("ConversationalRAG class defined successfully!")
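
The `_format_docs` helper joins retrieved chunks with blank lines and falls back to `str()` for items without a `page_content` attribute. The short sketch below reproduces that join logic on its own; `SimpleNamespace` is used here only as an assumed stand-in for a LangChain `Document`.

```python
from types import SimpleNamespace


def format_docs(docs) -> str:
    """Same join logic as ConversationalRAG._format_docs."""
    return "\n\n".join(getattr(d, "page_content", str(d)) for d in docs)


docs = [
    SimpleNamespace(page_content="First chunk."),  # has page_content, like a Document
    "plain string chunk",  # no page_content, so str() is used
]
formatted = format_docs(docs)
print(formatted)
```

An empty result list yields an empty string, so the QA prompt then receives an empty context.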

## 3. Public API Methods

In this notebook the public methods are defined as standalone functions and then attached to the `ConversationalRAG` class:
- `load_retriever_from_faiss()`: Load FAISS index and create retriever
- `invoke()`: Execute the RAG chain with user input
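
The cells below define each method as a plain function and then assign it to the class. The following sketch uses a hypothetical `Greeter` class (not part of the project) to illustrate that pattern: a function assigned to a class attribute becomes a method for every instance, including instances created before the assignment.

```python
class Greeter:
    def __init__(self, name: str):
        self.name = name


early = Greeter("early")  # created BEFORE the method exists


def greet(self, greeting: str = "Hello") -> str:
    """Defined at module level; `self` is bound once it is attached to the class."""
    return f"{greeting}, {self.name}!"


# Attach to class, as the notebook does for load_retriever_from_faiss and invoke.
Greeter.greet = greet

late = Greeter("late")
print(early.greet())
print(late.greet("Hi"))
```

This split is convenient in a notebook because each method can live in its own cell; in a regular module the same functions would normally be written directly inside the class body.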

In [None]:
# Add these methods to ConversationalRAG class

def load_retriever_from_faiss(
    self,
    index_path: str,
    k: int = 5,
    index_name: str = "index",
    search_type: str = "similarity",
    search_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Load FAISS vectorstore from disk and build retriever + LCEL chain.
    
    Args:
        index_path: Path to FAISS index directory
        k: Number of documents to retrieve (default: 5)
        index_name: Name of the index files (default: "index")
        search_type: Type of search - "similarity" or "mmr" (default: "similarity")
        search_kwargs: Additional search parameters
    
    Returns:
        The configured retriever
    """
    try:
        if not os.path.isdir(index_path):
            raise FileNotFoundError(f"FAISS index directory not found: {index_path}")

        embeddings = ModelLoader().load_embeddings()
        vectorstore = FAISS.load_local(
            index_path,
            embeddings,
            index_name=index_name,
            # FAISS.load_local unpickles the stored docstore; only load indexes you created or trust.
            allow_dangerous_deserialization=True,
        )

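        # Start from the default {"k": k} and let caller-supplied kwargs override it,
        # so `k` is not silently dropped when search_kwargs is passed.
        search_kwargs = {"k": k, **(search_kwargs or {})}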

        self.retriever = vectorstore.as_retriever(
            search_type=search_type, search_kwargs=search_kwargs
        )
        self._build_lcel_chain()

        log.info(
            "FAISS retriever loaded successfully",
            index_path=index_path,
            index_name=index_name,
            k=k,
            session_id=self.session_id,
        )
        return self.retriever

    except Exception as e:
        log.error("Failed to load retriever from FAISS", error=str(e))
        raise DocumentPortalException("Loading error in ConversationalRAG", sys)

# Attach to class
ConversationalRAG.load_retriever_from_faiss = load_retriever_from_faiss
print("load_retriever_from_faiss method added!")

In [None]:
def invoke(self, user_input: str, chat_history: Optional[List[BaseMessage]] = None) -> str:
    """
    Invoke the LCEL pipeline with user input and chat history.
    
    Args:
        user_input: The user's question
        chat_history: List of previous messages (HumanMessage/AIMessage)
    
    Returns:
        The generated answer string
    """
    try:
        if self.chain is None:
            raise DocumentPortalException(
                "RAG chain not initialized. Call load_retriever_from_faiss() before invoke().", sys
            )
        chat_history = chat_history or []
        payload = {"input": user_input, "chat_history": chat_history}
        answer = self.chain.invoke(payload)
        if not answer:
            log.warning(
                "No answer generated", user_input=user_input, session_id=self.session_id
            )
            return "No answer generated."
        log.info(
            "Chain invoked successfully",
            session_id=self.session_id,
            user_input=user_input,
            answer_preview=str(answer)[:150],
        )
        return answer
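    except DocumentPortalException:
        # Already a project exception (e.g. chain not initialized); keep its message.
        raise
    except Exception as e:
        log.error("Failed to invoke ConversationalRAG", error=str(e))
        raise DocumentPortalException("Invocation error in ConversationalRAG", sys)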

# Attach to class
ConversationalRAG.invoke = invoke
print("invoke method added!")

## 4. Usage Example

Demonstrate how to use the ConversationalRAG class with a FAISS index.

**Note:** You need to have a FAISS index created first. Use `data_ingestion.py` to create one.

In [None]:
# Example usage - adjust paths as needed

# Initialize the RAG system
session_id = "demo_session"
rag = ConversationalRAG(session_id=session_id)
print(f"RAG initialized with session: {session_id}")

In [None]:
# Load FAISS index (update path to your actual index)
INDEX_PATH = "faiss_index/demo_session"  # Update this path

# Check if index exists before loading
if os.path.isdir(INDEX_PATH):  # same directory check the method performs internally
    retriever = rag.load_retriever_from_faiss(
        index_path=INDEX_PATH,
        k=5,  # Number of documents to retrieve
        search_type="similarity"
    )
    print(f"Retriever loaded from: {INDEX_PATH}")
else:
    print(f"Index not found at: {INDEX_PATH}")
    print("Please create an index first using data_ingestion.py")

In [None]:
# Ask a question (only if retriever was loaded)
if rag.retriever is not None:
    question = "What is the main topic of the document?"
    chat_history = []  # Empty for first question
    
    answer = rag.invoke(question, chat_history=chat_history)
    print(f"Question: {question}")
    print(f"Answer: {answer}")

In [None]:
# Follow-up question with chat history
if rag.retriever is not None:
    # Build chat history from previous interaction
    chat_history = [
        HumanMessage(content="What is the main topic of the document?"),
        AIMessage(content=answer)  # Previous answer
    ]
    
    follow_up = "Can you elaborate on that?"
    follow_up_answer = rag.invoke(follow_up, chat_history=chat_history)
    
    print(f"Follow-up: {follow_up}")
    print(f"Answer: {follow_up_answer}")
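
Over more than two turns, the history passed to `invoke()` grows by one human/AI pair per question. The sketch below shows one way to automate that bookkeeping. `chat_loop` and `fake_invoke` are hypothetical names, and the `(role, text)` tuples are stand-ins for `HumanMessage` and `AIMessage` so the example runs without an index or an LLM.

```python
from typing import Callable, List, Tuple

Turn = Tuple[str, str]  # ("human" | "ai", text); stand-in for HumanMessage / AIMessage


def chat_loop(ask: Callable[[str, List[Turn]], str], questions: List[str]) -> List[Turn]:
    """Ask each question in order, feeding all earlier turns back as history."""
    history: List[Turn] = []
    for question in questions:
        # Pass a copy so the callee cannot mutate the running history.
        reply = ask(question, list(history))
        history.append(("human", question))
        history.append(("ai", reply))
    return history


# Stub in place of rag.invoke so the sketch is self-contained.
def fake_invoke(question: str, chat_history: List[Turn]) -> str:
    return f"answer #{len(chat_history) // 2 + 1} to: {question}"


history = chat_loop(fake_invoke, ["What is the main topic?", "Can you elaborate on that?"])
for role, text in history:
    print(f"{role}: {text}")
```

In the notebook, `ask` could be `lambda q, h: rag.invoke(q, chat_history=h)`, with the tuples replaced by `HumanMessage(content=...)` and `AIMessage(content=...)` objects.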

## Summary & Next Steps

### Key Takeaways
1. **ConversationalRAG** provides a complete RAG pipeline with chat history support
2. The LCEL chain consists of three steps:
   - Question rewriting (contextualizing with chat history)
   - Document retrieval from FAISS
   - Answer generation with retrieved context
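3. Lazy initialization lets you either pass a retriever to the constructor or load one later with `load_retriever_from_faiss()`; the chain is built at that point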

### Possible Extensions
- Add streaming support for real-time responses
- Implement additional retrieval strategies such as hybrid search (MMR is already selectable via `search_type="mmr"`)
- Add source document citations to answers
- Implement conversation memory persistence
- Add re-ranking of retrieved documents
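
As a starting point for experimenting with retrieval strategies, the sketch below builds a `search_kwargs` dictionary for either search type. `build_search_kwargs` is a hypothetical helper and the numeric values are illustrative assumptions. `fetch_k` and `lambda_mult` are the MMR options that LangChain's `as_retriever` accepts in `search_kwargs`.

```python
from typing import Any, Dict


def build_search_kwargs(search_type: str, k: int = 5, **extra: Any) -> Dict[str, Any]:
    """Return search_kwargs for 'similarity' or 'mmr'; values are illustrative."""
    kwargs: Dict[str, Any] = {"k": k}
    if search_type == "mmr":
        # fetch_k: candidates fetched before MMR re-ranking.
        # lambda_mult: 0 favors diversity, 1 favors relevance.
        kwargs.update({"fetch_k": max(4 * k, 20), "lambda_mult": 0.5})
    elif search_type != "similarity":
        raise ValueError(f"Unsupported search_type: {search_type}")
    kwargs.update(extra)  # explicit overrides win
    return kwargs


print(build_search_kwargs("similarity"))
print(build_search_kwargs("mmr", k=10, lambda_mult=0.2))
```

The result can be passed straight through, for example `rag.load_retriever_from_faiss(index_path=INDEX_PATH, search_type="mmr", search_kwargs=build_search_kwargs("mmr", k=5))`.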