# Conversational RAG - Step-by-Step Notebook

This notebook reproduces the logic of `src/document_chat/retrieval.py` as a sequence of procedural, debuggable cells.

## What This Does
Implements a **Conversational Retrieval-Augmented Generation (RAG)** pipeline that:
1. Rewrites user questions using conversation context
2. Retrieves relevant document chunks from a FAISS index
3. Generates answers using an LLM

## Prerequisites
- FAISS index already created (via document ingestion)
- API keys set: `GOOGLE_API_KEY` and/or `GROQ_API_KEY`
- Config file at `config/config.yaml`

---
## Cell 1: Configuration Placeholders

**Purpose:** Define all configurable paths and parameters at the top for easy modification.

**External Dependencies:**
- `FAISS_INDEX_PATH`: Directory containing your FAISS index files (`index.faiss`, `index.pkl`)
- Environment variables: `GOOGLE_API_KEY`, `GROQ_API_KEY`, `LLM_PROVIDER`

In [None]:
# ============================================================
# CONFIGURATION PLACEHOLDERS - MODIFY THESE BEFORE RUNNING
# ============================================================

# Path to your FAISS index directory
FAISS_INDEX_PATH = "data/single_document_chat/your_session_id"  # <-- CHANGE THIS

# Name of the FAISS index files (without extension)
FAISS_INDEX_NAME = "index"

# Number of documents to retrieve per query
RETRIEVAL_K = 5

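# Search type: "similarity" (nearest neighbors by the index's distance metric, L2 by default)
# or "mmr" (Maximal Marginal Relevance: balances relevance against diversity)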
SEARCH_TYPE = "similarity"

# Session identifier for logging
SESSION_ID = "notebook_debug_session"

print("Configuration loaded:")
print(f"  FAISS Index: {FAISS_INDEX_PATH}")
print(f"  Retrieval K: {RETRIEVAL_K}")
print(f"  Search Type: {SEARCH_TYPE}")
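
The `"similarity"` option ranks chunks by the FAISS index's own distance metric. LangChain's `FAISS` wrapper uses Euclidean (L2) distance unless the index was built with a different `distance_strategy`, so "similarity" is not cosine by default. For unit-length embeddings the two rankings coincide, because ‖a − b‖² = 2 − 2·cos(a, b). Whether your embeddings are normalized depends on the embedding model, so treat that equivalence as conditional. The sketch below checks the identity on made-up vectors; `normalize`, `l2_distance` and `cosine_similarity` are illustrative helpers, not project code.

```python
import math
from typing import List, Sequence

# Illustrative helpers (hypothetical names), plain Python instead of numpy.
def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (assumes v is not all zeros)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

# Made-up vectors standing in for embeddings, normalized to unit length.
query = normalize([0.9, 0.1, 0.3])
chunks = [normalize(v) for v in ([0.8, 0.2, 0.1], [0.1, 0.9, 0.4], [0.5, 0.5, 0.5])]

# Rank by smallest L2 distance ...
by_l2 = sorted(range(len(chunks)), key=lambda i: l2_distance(query, chunks[i]))
# ... and by largest cosine similarity; for unit vectors the orders match.
by_cosine = sorted(
    range(len(chunks)), key=lambda i: cosine_similarity(query, chunks[i]), reverse=True
)

for i, c in enumerate(chunks):
    d = l2_distance(query, c)
    cos = cosine_similarity(query, c)
    print(f"chunk {i}: L2^2={d * d:.4f}  2-2cos={2 - 2 * cos:.4f}")

print("rank by L2:    ", by_l2)
print("rank by cosine:", by_cosine)
```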

---
## Cell 2: Imports

**Purpose:** Import all required libraries and modules.

**Key Dependencies:**
- `langchain_core`: Prompt templates, output parsers, message types
- `langchain_community.vectorstores.FAISS`: Vector similarity search
- `utils.model_loader.ModelLoader`: Loads LLM and embedding models
- `prompt.prompt_library.PROMPT_REGISTRY`: Central prompt storage

In [None]:
import sys
import os
from operator import itemgetter
from typing import List, Optional, Dict, Any

# LangChain imports
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS

# Project imports
from utils.model_loader import ModelLoader
from exception.custom_exception import DocumentPortalException
from logger import GLOBAL_LOGGER as log
from prompt.prompt_library import PROMPT_REGISTRY
from model.models import PromptType

print("All imports successful!")
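
The project imports (`utils`, `exception`, `logger`, `prompt`, `model`) only resolve when the repository root is on `sys.path`. If the notebook is started from a subdirectory, a small helper run before the imports can locate the root. This is a sketch under one assumption: the root is the first ancestor directory that contains `config/config.yaml`, as listed in the Prerequisites. `find_project_root` and `add_project_root_to_path` are hypothetical names, not part of the project.

```python
import sys
from pathlib import Path
from typing import Optional

# Hypothetical helpers: assume the project root is marked by config/config.yaml.
def find_project_root(start: Path, marker: str = "config/config.yaml") -> Optional[Path]:
    """Walk up from `start` and return the first directory containing `marker`."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_file():
            return candidate
    return None

def add_project_root_to_path(start: Optional[Path] = None) -> Optional[Path]:
    """Prepend the project root to sys.path (once) so project modules can be imported."""
    root = find_project_root(start or Path.cwd())
    if root is not None and str(root) not in sys.path:
        sys.path.insert(0, str(root))
    return root

root = add_project_root_to_path()
print(f"Project root: {root}")
```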

---
## Cell 3: Load LLM Model

**Purpose:** Initialize the Language Model that will:
1. Rewrite questions with conversation context
2. Generate answers from retrieved documents

**What happens:**
- `ModelLoader` reads `config/config.yaml` for model settings
- Checks `LLM_PROVIDER` env var (default: "google")
- Returns either `ChatGoogleGenerativeAI` or `ChatGroq`

In [None]:
# Load the LLM
try:
    llm = ModelLoader().load_llm()
    if not llm:
        raise ValueError("LLM could not be loaded")
    log.info("LLM loaded successfully", session_id=SESSION_ID)
    print(f"LLM loaded: {type(llm).__name__}")
except Exception as e:
    print(f"ERROR loading LLM: {e}")
    raise DocumentPortalException("LLM loading error", sys)
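
When this cell fails, a common cause is a mismatch between `LLM_PROVIDER` and the API key that is actually set. The function below is a hypothetical pre-flight check, not `ModelLoader`'s actual code. It assumes the accepted provider values are `"google"` and `"groq"` and that each one needs the key named in the Prerequisites; adjust the mapping if your `config/config.yaml` says otherwise.

```python
import os
from typing import Mapping

# Assumed mapping from LLM_PROVIDER value to the API key it needs (hypothetical).
PROVIDER_API_KEYS = {"google": "GOOGLE_API_KEY", "groq": "GROQ_API_KEY"}

def check_llm_env(env: Mapping[str, str] = os.environ) -> str:
    """Return the selected provider, or raise if it is unknown or its key is unset."""
    provider = env.get("LLM_PROVIDER", "google").strip().lower()
    if provider not in PROVIDER_API_KEYS:
        raise ValueError(
            f"Unsupported LLM_PROVIDER {provider!r}; expected one of {sorted(PROVIDER_API_KEYS)}"
        )
    key_name = PROVIDER_API_KEYS[provider]
    if not env.get(key_name):
        raise EnvironmentError(f"LLM_PROVIDER={provider!r} but {key_name} is not set")
    return provider

try:
    print(f"LLM provider looks usable: {check_llm_env()}")
except (ValueError, EnvironmentError) as err:
    print(f"Pre-flight check failed: {err}")
```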

---
## Cell 4: Load Prompt Templates

**Purpose:** Load the two prompt templates from the registry:

1. **contextualize_prompt**: Rewrites user questions to be standalone
   - Input: User question + chat history
   - Output: Standalone question (no context dependencies)

2. **qa_prompt**: Generates answers from context
   - Input: Retrieved documents + question + chat history
   - Output: Concise answer (max 3 sentences)

In [None]:
# Load prompt templates from registry
contextualize_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
    PromptType.CONTEXTUALIZE_QUESTION.value
]
qa_prompt: ChatPromptTemplate = PROMPT_REGISTRY[
    PromptType.CONTEXT_QA.value
]

print("Prompts loaded successfully!")
print("\n--- Contextualize Prompt ---")
print(contextualize_prompt.messages[0].prompt.template[:200] + "...")
print("\n--- QA Prompt ---")
print(qa_prompt.messages[0].prompt.template[:200] + "...")
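
Each later stage fills these prompts from a dict, so the keys the chains pass must match the `{placeholders}` in the templates: `input` and `chat_history` for the contextualize prompt, `context`, `input` and `chat_history` for the QA prompt. LangChain exposes the full list as `ChatPromptTemplate.input_variables`. For a single raw template string, the standard library's `string.Formatter` can list the f-string placeholders, as in this sketch. `template_variables` is a hypothetical helper and the template text is invented; the real prompts live in `PROMPT_REGISTRY`. If `chat_history` is supplied through a `MessagesPlaceholder`, it sits outside the template strings and will not appear in this output.

```python
from string import Formatter
from typing import Set

# Hypothetical helper: list f-string placeholders in one template string.
def template_variables(template: str) -> Set[str]:
    """Return the placeholder names used in a prompt template string."""
    names = set()
    for _literal, field_name, _spec, _conversion in Formatter().parse(template):
        if field_name:  # None for literal text, including escaped '{{' and '}}'
            names.add(field_name)
    return names

# Invented example text, not the project's actual QA prompt.
example_qa_system = (
    "Answer using only the context below. Use at most three sentences.\n\n"
    "Context:\n{context}"
)
missing = {"context"} - template_variables(example_qa_system)
print("placeholders:", sorted(template_variables(example_qa_system)))
print("missing:", sorted(missing) or "none")
```

To check the loaded prompts, you could pass `qa_prompt.messages[0].prompt.template` (the same attribute printed above) to `template_variables`.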

---
## Cell 5: Load FAISS Retriever

**Purpose:** Load the pre-built FAISS vector index from disk and create a retriever.

**What happens:**
1. Validates that `FAISS_INDEX_PATH` exists
2. Loads embedding model (must match what was used during indexing)
3. Loads FAISS index files (`index.faiss`, `index.pkl`)
4. Creates a retriever that returns `RETRIEVAL_K` documents per query using `SEARCH_TYPE`

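**Note:** `allow_dangerous_deserialization=True` is required because `index.pkl` (the docstore and ID mapping saved next to `index.faiss`) is a Python pickle, and unpickling untrusted data can execute arbitrary code. Only load indexes you created yourself.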

In [None]:
# Load FAISS retriever
try:
    # Validate path exists
    if not os.path.isdir(FAISS_INDEX_PATH):
        raise FileNotFoundError(f"FAISS index directory not found: {FAISS_INDEX_PATH}")
    
    # Load embeddings (must match what was used during indexing)
    embeddings = ModelLoader().load_embeddings()
    print(f"Embeddings loaded: {type(embeddings).__name__}")
    
    # Load FAISS vectorstore from disk
    vectorstore = FAISS.load_local(
        FAISS_INDEX_PATH,
        embeddings,
        index_name=FAISS_INDEX_NAME,
        allow_dangerous_deserialization=True,  # Trust our own indexes
    )
    print(f"FAISS vectorstore loaded from: {FAISS_INDEX_PATH}")
    
    # Create retriever
    search_kwargs = {"k": RETRIEVAL_K}
    retriever = vectorstore.as_retriever(
        search_type=SEARCH_TYPE, 
        search_kwargs=search_kwargs
    )
    
    log.info(
        "FAISS retriever loaded successfully",
        index_path=FAISS_INDEX_PATH,
        index_name=FAISS_INDEX_NAME,
        k=RETRIEVAL_K,
        session_id=SESSION_ID,
    )
    print(f"Retriever ready! Will return top {RETRIEVAL_K} documents per query.")

except FileNotFoundError as e:
    print(f"ERROR: {e}")
    print("\nMake sure to:")
    print("1. Update FAISS_INDEX_PATH in Cell 1")
    print("2. Run document ingestion first to create the index")
    raise
except Exception as e:
    print(f"ERROR loading retriever: {e}")
    raise DocumentPortalException("Loading error in retriever", sys)
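
With `SEARCH_TYPE = "mmr"`, the retriever uses Maximal Marginal Relevance: it fetches a larger candidate pool (`fetch_k`, 20 by default in LangChain) and then greedily picks `k` chunks, each time maximizing `lambda_mult * relevance - (1 - lambda_mult) * redundancy`, where redundancy is the highest similarity to a chunk already picked. Both knobs can be passed through `search_kwargs`, for example `{"k": RETRIEVAL_K, "fetch_k": 20, "lambda_mult": 0.5}`. The pure-Python sketch below illustrates the greedy selection on made-up 2-D vectors with cosine similarity; `mmr_select` is a simplified model of the idea, not LangChain's implementation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]

def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

# Hypothetical, simplified MMR selection (no fetch_k pool, cosine similarity only).
def mmr_select(query: Vector, candidates: List[Vector], k: int, lambda_mult: float = 0.5) -> List[int]:
    """Greedy MMR: return indices of up to k candidates balancing relevance and diversity."""
    selected: List[int] = []
    remaining = list(range(len(candidates)))

    def score(i: int) -> float:
        relevance = cosine(query, candidates[i])
        # Redundancy: similarity to the closest already-selected candidate.
        redundancy = max((cosine(candidates[i], candidates[j]) for j in selected), default=0.0)
        return lambda_mult * relevance - (1 - lambda_mult) * redundancy

    while remaining and len(selected) < k:
        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected

query = [1.0, 0.0]
candidates = [
    [0.8, 0.6],   # relevant
    [0.8, 0.6],   # exact duplicate of candidate 0
    [0.6, -0.8],  # slightly less relevant, but points in a different direction
]
top_k = sorted(range(len(candidates)), key=lambda i: cosine(query, candidates[i]), reverse=True)[:2]
print("plain similarity top-2:", top_k)
print("MMR top-2:", mmr_select(query, candidates, k=2))
```

Plain similarity returns the duplicate pair, while MMR swaps the duplicate for the chunk that adds new information. Setting `lambda_mult=1.0` turns MMR back into pure relevance ranking.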

---
## Cell 6: Helper Function - Format Documents

**Purpose:** Convert retrieved document objects into a single text string for the LLM context.

**Input:** List of Document objects with `page_content` attribute

**Output:** All documents joined with double newlines

In [None]:
def format_docs(docs) -> str:
    """
    Format retrieved documents into a single context string.
    
    Args:
        docs: List of Document objects from retriever
        
    Returns:
        str: All document contents joined with '\n\n'
    """
    return "\n\n".join(getattr(d, "page_content", str(d)) for d in docs)

print("format_docs function defined.")
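
`format_docs` only needs objects with a `page_content` attribute and falls back to `str(d)` for anything else, so it can be exercised without a vector store. The sketch uses a stand-in `FakeDoc` dataclass (hypothetical; LangChain's `Document` also carries a `metadata` dict) and re-declares `format_docs` with the same logic so the block runs on its own. It also shows a hypothetical variant, `format_docs_with_sources`, that tags each chunk with `metadata["source"]`. The `"source"` key is an assumption about how your ingestion step labels chunks, so check your metadata before relying on it. Labeled chunks make it easier to see which chunk an answer came from while debugging.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable

@dataclass
class FakeDoc:
    """Hypothetical stand-in for a LangChain Document."""
    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)

def format_docs(docs: Iterable[Any]) -> str:
    # Same logic as Cell 6: join page_content (or str()) with blank lines.
    return "\n\n".join(getattr(d, "page_content", str(d)) for d in docs)

def format_docs_with_sources(docs: Iterable[Any]) -> str:
    """Hypothetical variant: tag each chunk with its position and assumed 'source' metadata."""
    parts = []
    for i, d in enumerate(docs, 1):
        source = getattr(d, "metadata", {}).get("source", "unknown")
        text = getattr(d, "page_content", str(d))
        parts.append(f"[{i}] (source: {source})\n{text}")
    return "\n\n".join(parts)

docs = [
    FakeDoc("Payment is due within 30 days.", {"source": "contract.pdf"}),
    FakeDoc("Late payments incur a 2% fee."),
    "A plain string also works via str().",
]
print(format_docs(docs))
print("-" * 40)
print(format_docs_with_sources(docs))
```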

---
## Cell 7: Build LCEL Chain - Question Rewriter

**Purpose:** Create the first stage of the pipeline that rewrites user questions.

**Flow:**
```
{"input": question, "chat_history": [...]} 
    → contextualize_prompt (template)
    → llm (rewrites question)
    → StrOutputParser (extracts text)
    → "Standalone question"
```

**Example:**
- Input: "What happens if they're late?" + history about payments
- Output: "What happens if payments are late?"

In [None]:
# Stage 1: Question Rewriter Chain
question_rewriter = (
    {"input": itemgetter("input"), "chat_history": itemgetter("chat_history")}
    | contextualize_prompt
    | llm
    | StrOutputParser()
)

print("Question rewriter chain built.")
print("\nThis chain takes:")
print("  - 'input': The user's question")
print("  - 'chat_history': List of previous messages")
print("And returns: A standalone question string")
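
On the first turn `chat_history` is empty, so there is nothing to resolve and the rewrite call only adds latency and LLM cost. A common refinement, similar in spirit to the branch in LangChain's `create_history_aware_retriever` helper, is to pass the question through unchanged when there is no history. `standalone_question` below is a hypothetical plain-Python sketch of that branch, with `fake_rewrite` standing in for `question_rewriter.invoke`. In LCEL the same idea could be expressed with a `RunnableBranch` or a `RunnableLambda`.

```python
from typing import Any, Callable, Dict

# Hypothetical helper: only call the (expensive) rewriter when history exists.
def standalone_question(payload: Dict[str, Any], rewrite: Callable[[Dict[str, Any]], str]) -> str:
    """Return the question to search with, calling `rewrite` only when history exists."""
    if not payload.get("chat_history"):
        return payload["input"]  # First turn: nothing to contextualize.
    return rewrite(payload)

calls = []

def fake_rewrite(payload: Dict[str, Any]) -> str:
    # Stand-in for question_rewriter.invoke; records that it was called.
    calls.append(payload["input"])
    return "What happens if payments are late?"

print(standalone_question({"input": "What are the payment terms?", "chat_history": []}, fake_rewrite))
# The list item below is a placeholder for real HumanMessage/AIMessage objects.
print(standalone_question({"input": "What happens if they're late?", "chat_history": ["earlier turn"]}, fake_rewrite))
print("rewrite calls:", calls)
```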

---
## Cell 8: Build LCEL Chain - Document Retriever

**Purpose:** Chain the question rewriter to the retriever and document formatter.

**Flow:**
```
question_rewriter output (standalone question)
    → retriever (FAISS similarity search)
    → format_docs (join into single string)
    → "Doc1 content\n\nDoc2 content\n\n..."
```

In [None]:
# Stage 2: Document Retrieval Chain
retrieve_docs = question_rewriter | retriever | format_docs

print("Document retrieval chain built.")
print("\nThis chain:")
print("  1. Takes rewritten question from Stage 1")
print("  2. Searches FAISS for similar documents")
print(f"  3. Returns top {RETRIEVAL_K} documents as a single string")
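
The `|` operator is what makes `question_rewriter | retriever | format_docs` work: each step's output becomes the next step's input, and LCEL wraps plain functions such as `format_docs` in a `RunnableLambda` automatically. The toy `Pipe` class below is a hypothetical, much-simplified model of that composition (no batching, streaming, or async), useful only for reasoning about what each stage receives. The three stage functions are stand-ins with no LLM or FAISS involved.

```python
from typing import Any, Callable, List

class Pipe:
    """Toy model of LCEL composition: (a | b).invoke(x) == b(a(x))."""

    def __init__(self, func: Callable[[Any], Any]):
        self.func = func

    def invoke(self, value: Any) -> Any:
        return self.func(value)

    def __or__(self, other: Any) -> "Pipe":
        # Plain callables are wrapped, mirroring how LCEL coerces functions.
        nxt = other if isinstance(other, Pipe) else Pipe(other)
        return Pipe(lambda value: nxt.invoke(self.invoke(value)))

# Hypothetical stand-ins for the three stages.
def fake_rewrite(payload: dict) -> str:
    return payload["input"].replace("they", "payments")

def fake_search(question: str) -> List[str]:
    return [f"chunk about: {question}", "another chunk"]

def join_chunks(chunks: List[str]) -> str:
    return "\n\n".join(chunks)

toy_retrieve_docs = Pipe(fake_rewrite) | fake_search | join_chunks
print(toy_retrieve_docs.invoke({"input": "What if they are late?", "chat_history": []}))
```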

---
## Cell 9: Build LCEL Chain - Full RAG Pipeline

**Purpose:** Create the complete end-to-end RAG chain.

**Flow:**
```
{"input": question, "chat_history": [...]}
    ↓
{
  "context": retrieve_docs output,
  "input": original question,
  "chat_history": original history
}
    → qa_prompt (answer template)
    → llm (generates answer)
    → StrOutputParser (extracts text)
    → "The answer is..."
```

In [None]:
# Stage 3: Full RAG Chain
rag_chain = (
    {
        "context": retrieve_docs,
        "input": itemgetter("input"),
        "chat_history": itemgetter("chat_history"),
    }
    | qa_prompt
    | llm
    | StrOutputParser()
)

log.info("LCEL graph built successfully", session_id=SESSION_ID)
print("Full RAG chain built successfully!")
print("\nThe complete pipeline:")
print("  User Question → Rewrite → Retrieve → Generate Answer")
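
Note what the QA prompt receives: `context` comes from the retrieval branch (which searched with the rewritten question), while `input` is the user's original wording and `chat_history` is passed through unchanged. The dict at the head of `rag_chain` runs every branch on the same input payload. The sketch below mimics that fan-out in plain Python so the resulting payload can be inspected without calling an LLM; `fan_out` and `fake_retrieve_docs` are hypothetical stand-ins. LangChain may run the branches concurrently, whereas this sketch runs them one after another.

```python
from operator import itemgetter
from typing import Any, Callable, Dict

# Hypothetical helper modeling the dict step at the head of rag_chain.
def fan_out(branches: Dict[str, Callable[[Dict[str, Any]], Any]], payload: Dict[str, Any]) -> Dict[str, Any]:
    """Call every branch on the same payload and collect the results by key."""
    return {key: branch(payload) for key, branch in branches.items()}

def fake_retrieve_docs(payload: Dict[str, Any]) -> str:
    # Stand-in for retrieve_docs: pretend the rewritten question found one chunk.
    return "Late payments incur a fee."

branches = {
    "context": fake_retrieve_docs,
    "input": itemgetter("input"),
    "chat_history": itemgetter("chat_history"),
}
payload = {"input": "What happens if they're late?", "chat_history": ["(earlier turns)"]}
qa_inputs = fan_out(branches, payload)
print(qa_inputs)
```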

---
## Cell 10: Invoke Function

**Purpose:** Create a wrapper function to invoke the RAG chain with proper error handling.

**Parameters:**
- `user_input`: The user's question
- `chat_history`: List of previous messages (optional)

**Returns:** The generated answer string

In [None]:
def invoke_rag(user_input: str, chat_history: Optional[List[BaseMessage]] = None) -> str:
    """
    Invoke the RAG pipeline to answer a question.
    
    Args:
        user_input: The user's question
        chat_history: Previous conversation messages (optional)
        
    Returns:
        str: The generated answer
    """
    try:
        chat_history = chat_history or []
        payload = {"input": user_input, "chat_history": chat_history}
        
        answer = rag_chain.invoke(payload)
        
        if not answer:
            log.warning(
                "No answer generated", 
                user_input=user_input, 
                session_id=SESSION_ID
            )
            return "No answer generated."
        
        log.info(
            "Chain invoked successfully",
            session_id=SESSION_ID,
            user_input=user_input,
            answer_preview=str(answer)[:150],
        )
        return answer
        
    except Exception as e:
        print(f"ERROR during invocation: {e}")
        log.error("Failed to invoke RAG", error=str(e), session_id=SESSION_ID)
        raise DocumentPortalException("Invocation error in RAG", sys) from e

print("invoke_rag function defined.")
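
`invoke_rag` defaults `chat_history` to `None` and substitutes a fresh list inside the function (`chat_history or []`). A `[]` default would be created once, when the function is defined, and shared by every call, so any code that appended to it would leak history between unrelated conversations. The toy functions below (hypothetical names) show the difference.

```python
from typing import List, Optional

def remember_bad(message: str, history: List[str] = []) -> List[str]:
    # Anti-pattern: the default list is created once and reused on every call.
    history.append(message)
    return history

def remember_good(message: str, history: Optional[List[str]] = None) -> List[str]:
    # A new list per call unless the caller supplies one (same idea as invoke_rag).
    history = history if history is not None else []
    history.append(message)
    return history

remember_bad("first")
print("bad: ", remember_bad("second"))
remember_good("first")
print("good:", remember_good("second"))
```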

---
## Cell 11: Test - Single Question

**Purpose:** Test the RAG pipeline with a single question (no history).

**Modify the question below to match your indexed documents.**

In [None]:
# ============================================================
# TEST: Single Question (No History)
# ============================================================

test_question = "What is the main topic of this document?"  # <-- CHANGE THIS

print(f"Question: {test_question}")
print("="*60)

answer = invoke_rag(test_question, chat_history=[])

print(f"\nAnswer:\n{answer}")

---
## Cell 12: Test - Multi-Turn Conversation

**Purpose:** Test the conversational behavior: follow-up questions that refer back to earlier exchanges.

**This demonstrates:**
- Question rewriting with context
- Chat history management

In [None]:
# ============================================================
# TEST: Multi-Turn Conversation
# ============================================================

# Initialize conversation history
conversation_history = []

# --- Turn 1 ---
q1 = "What are the key points discussed?"  # <-- CHANGE THIS
print(f"Q1: {q1}")

a1 = invoke_rag(q1, chat_history=conversation_history)
print(f"A1: {a1}")

# Add to history
conversation_history.extend([
    HumanMessage(content=q1),
    AIMessage(content=a1)
])

print("\n" + "="*60 + "\n")

# --- Turn 2 (Follow-up) ---
q2 = "Can you tell me more about that?"  # <-- CHANGE THIS (refers to previous answer)
print(f"Q2: {q2}")

a2 = invoke_rag(q2, chat_history=conversation_history)
print(f"A2: {a2}")

# Add to history
conversation_history.extend([
    HumanMessage(content=q2),
    AIMessage(content=a2)
])

print("\n" + "="*60 + "\n")

# --- Turn 3 (Another follow-up) ---
q3 = "Are there any exceptions?"  # <-- CHANGE THIS
print(f"Q3: {q3}")

a3 = invoke_rag(q3, chat_history=conversation_history)
print(f"A3: {a3}")
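
Each turn adds two messages, and the whole `conversation_history` is sent to both the rewrite prompt and the QA prompt, so long sessions grow the prompt without bound. One simple policy is to keep only the last N exchanges. `trim_history` is a hypothetical helper that assumes the history alternates human/AI messages, as it does in this cell. It is shown with `(role, text)` tuples standing in for `HumanMessage`/`AIMessage` so it runs without LangChain; since it only slices the list, it works on message objects unchanged. `langchain_core.messages` also provides a `trim_messages` utility with token-based options.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")

# Hypothetical helper: assumes strictly alternating human/AI messages.
def trim_history(history: Sequence[T], max_turns: int) -> List[T]:
    """Keep the last `max_turns` human/AI exchanges (2 messages per turn)."""
    if max_turns <= 0:
        return []  # guard: history[-0:] would return the whole list
    return list(history[-2 * max_turns:])

Message = Tuple[str, str]  # (role, text) stand-in for HumanMessage / AIMessage
history: List[Message] = []
for turn in range(1, 5):
    history.extend([("human", f"question {turn}"), ("ai", f"answer {turn}")])

recent = trim_history(history, max_turns=2)
print(len(history), "messages in full history")
print(recent)
```

In this cell it could be applied as `invoke_rag(q3, chat_history=trim_history(conversation_history, max_turns=3))`.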

---
## Cell 13: Debug - Inspect Intermediate Steps

**Purpose:** Step-by-step debugging to see what happens at each stage.

Useful for:
- Verifying question rewriting works correctly
- Checking what documents are being retrieved
- Understanding why an answer might be wrong

In [None]:
# ============================================================
# DEBUG: Inspect Each Stage
# ============================================================

debug_question = "What happens if they're late?"  # <-- CHANGE THIS
debug_history = [
    HumanMessage(content="What are the payment terms?"),
    AIMessage(content="Payment is due within 30 days.")
]

print("Original Question:", debug_question)
print("Chat History:", [(m.type, m.content[:50]) for m in debug_history])
print("\n" + "="*60)

# Step 1: See how question gets rewritten
print("\n[STAGE 1] Question Rewriting")
rewritten = question_rewriter.invoke({
    "input": debug_question,
    "chat_history": debug_history
})
print(f"Rewritten Question: {rewritten}")

# Step 2: See what documents are retrieved
print("\n" + "="*60)
print("\n[STAGE 2] Document Retrieval")
retrieved_docs = retriever.invoke(rewritten)
print(f"Retrieved {len(retrieved_docs)} documents:")
for i, doc in enumerate(retrieved_docs, 1):
    preview = doc.page_content[:150].replace('\n', ' ')
    print(f"\n  Doc {i}: {preview}...")

# Step 3: See the formatted context
print("\n" + "="*60)
print("\n[STAGE 3] Formatted Context (first 500 chars)")
context = format_docs(retrieved_docs)
print(context[:500] + "...")

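# Step 4: Generate the answer from the context inspected above.
# Calling invoke_rag() here would rewrite and retrieve again; the LLM may
# phrase the rewrite differently, so the answer could rest on other documents.
print("\n" + "="*60)
print("\n[STAGE 4] Final Answer")
answer_chain = qa_prompt | llm | StrOutputParser()
final_answer = answer_chain.invoke({
    "context": context,
    "input": debug_question,
    "chat_history": debug_history,
})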
print(final_answer)
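
When an answer looks wrong, a quick first check is whether the retrieved chunks share any vocabulary with the rewritten question. `term_overlap` below is a rough, hypothetical heuristic (word overlap after dropping words shorter than four letters), not a relevance metric; it only flags which chunks to read first. In this cell it could be called as `term_overlap(rewritten, [d.page_content for d in retrieved_docs])`. For the actual FAISS scores, `vectorstore.similarity_search_with_score(rewritten, k=RETRIEVAL_K)` returns `(document, score)` pairs; with the default L2 distance, lower scores mean closer matches.

```python
import re
from typing import List, Set

# Hypothetical heuristic helpers for eyeballing retrieval quality.
def content_terms(text: str, min_len: int = 4) -> Set[str]:
    """Lower-cased alphabetic words of at least `min_len` characters."""
    return {w for w in re.findall(r"[a-z]+", text.lower()) if len(w) >= min_len}

def term_overlap(question: str, chunks: List[str]) -> List[Set[str]]:
    """For each chunk, the question terms it shares (a rough signal only)."""
    q_terms = content_terms(question)
    return [q_terms & content_terms(chunk) for chunk in chunks]

question = "What happens if payments are late?"
chunks = ["Late payments incur a fee.", "The office is closed on Sundays."]
for i, shared in enumerate(term_overlap(question, chunks), 1):
    print(f"Doc {i}: {len(shared)} shared terms {sorted(shared)}")
```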

---
## Cell 14: Interactive Chat Loop

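**Purpose:** Interactive testing: keep asking questions until you type 'quit' (type 'clear' to reset the history).

**Note:** Run this cell and type into the input box that appears below it. The loop blocks the kernel until you type 'quit'.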

In [None]:
# ============================================================
# INTERACTIVE: Chat Loop
# ============================================================

print("Interactive RAG Chat")
print("Type 'quit' to exit, 'clear' to reset history")
print("="*60 + "\n")

interactive_history = []

while True:
    user_input = input("You: ").strip()
    
    if user_input.lower() == 'quit':
        print("Goodbye!")
        break
    
    if user_input.lower() == 'clear':
        interactive_history = []
        print("History cleared.\n")
        continue
    
    if not user_input:
        continue
    
    try:
        response = invoke_rag(user_input, interactive_history)
        print(f"\nAssistant: {response}\n")
        
        # Update history
        interactive_history.extend([
            HumanMessage(content=user_input),
            AIMessage(content=response)
        ])
    except Exception as e:
        print(f"Error: {e}\n")
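
Because the loop calls `input()` and `print()` directly, it can only be exercised by hand. A hypothetical refactor, `chat_loop`, takes the answer function and the I/O functions as parameters, so a scripted conversation can be replayed (or the loop reused outside Jupyter). The sketch keeps the same 'quit'/'clear' commands and stores history as `(role, text)` tuples standing in for `HumanMessage`/`AIMessage`; in the notebook you would pass `ask=invoke_rag` and build real message objects instead.

```python
from typing import Callable, List, Tuple

Message = Tuple[str, str]  # stand-in for HumanMessage / AIMessage

# Hypothetical refactor of Cell 14 with injectable answer and I/O functions.
def chat_loop(
    ask: Callable[[str, List[Message]], str],
    read: Callable[[str], str] = input,
    write: Callable[[str], None] = print,
) -> List[Message]:
    """Run the quit/clear chat loop and return the final history."""
    history: List[Message] = []
    while True:
        user_input = read("You: ").strip()
        if user_input.lower() == "quit":
            write("Goodbye!")
            return history
        if user_input.lower() == "clear":
            history = []
            write("History cleared.")
            continue
        if not user_input:
            continue
        try:
            response = ask(user_input, history)
        except Exception as e:  # keep the loop alive on errors, as in Cell 14
            write(f"Error: {e}")
            continue
        write(f"Assistant: {response}")
        history.extend([("human", user_input), ("ai", response)])

# Replay a scripted session with a fake answer function (no LLM involved).
script = iter(["hello", "clear", "", "how many turns?", "quit"])
output: List[str] = []
final_history = chat_loop(
    ask=lambda q, h: f"echo {q} (history={len(h)})",
    read=lambda prompt: next(script),
    write=output.append,
)
print(output)
print(final_history)
```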

---
## Summary

### Variables Persisting Between Cells

| Variable | Type | Description |
|----------|------|-------------|
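| `llm` | `ChatGoogleGenerativeAI` or `ChatGroq` | Language model instance |
| `embeddings` | Embeddings | Embedding model used to load the FAISS index |
| `vectorstore` | `FAISS` | Vector store loaded from `FAISS_INDEX_PATH` |
| `retriever` | `VectorStoreRetriever` | FAISS-based document retriever |
| `contextualize_prompt` | `ChatPromptTemplate` | Question rewriting prompt |
| `qa_prompt` | `ChatPromptTemplate` | Answer generation prompt |
| `question_rewriter` | LCEL Runnable | Stage 1: rewrites the question |
| `retrieve_docs` | LCEL Runnable | Stage 2: retrieves and formats documents |
| `rag_chain` | LCEL Runnable | Stage 3: complete RAG pipeline |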

### Functions

| Function | Purpose |
|----------|--------|
| `format_docs(docs)` | Convert document list to string |
| `invoke_rag(user_input, chat_history)` | Execute the RAG pipeline |