In [3]:
# Install dependencies (if using Colab/Jupyter)
# %pip install --quiet --upgrade langchain-text-splitters langchain-community langgraph langchain-openai faiss-cpu pypdf sentence-transformers torch rich

import os
import getpass
import logging
import time
from typing import List, Dict, Any
from typing_extensions import TypedDict
from dataclasses import dataclass
from pydantic import BaseModel, Field
from concurrent.futures import ThreadPoolExecutor

from langchain.chat_models import init_chat_model
from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain import hub
from langgraph.graph import StateGraph, END, START
from rich import print as rprint  # For better console formatting

# Credentials: ask interactively (input hidden) for any key that is unset or empty.
if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API Key: ")
if not os.getenv("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangChain API Key: ")
# The tracing flag is written unconditionally, overwriting any value already present.
os.environ.update(LANGSMITH_TRACING="true")

# Logging: INFO-level basic configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class PerformanceMetrics:
    """Per-question timing and retrieval counters; every field defaults to zero.

    In this cell an instance is only created by ``ask_question_optimized`` and
    placed in the graph state under ``"metrics"``; none of the pipeline nodes
    shown here writes to it.

    NOTE(review): units are not established by the visible code — presumably
    the ``*_time`` fields are seconds; confirm before relying on them.
    """
    query_time: float = 0.0  # presumably time spent in query analysis — TODO confirm
    retrieval_time: float = 0.0  # presumably time spent in retrieval — TODO confirm
    generation_time: float = 0.0  # presumably time spent generating the answer — TODO confirm
    total_time: float = 0.0  # presumably end-to-end time for one question — TODO confirm
    chunks_retrieved: int = 0  # presumably number of chunks handed to generation — TODO confirm

# Structured form of a user question. enhanced_query_analysis passes this model to
# llm.with_structured_output(...) so the LLM's reply is returned as an instance of it.
# NOTE(review): documented with comments rather than a class docstring on purpose —
# as far as I know the docstring of a pydantic model can be surfaced in the schema sent
# to the model, which would change runtime behaviour; confirm before converting.
class EnhancedSearch(BaseModel):
    # Text used by hybrid_retrieval as the similarity-search query.
    query: str = Field(description="The actual user query")
    # Filled by the LLM when mentioned; not read by any code in this cell.
    section: str = Field(default="", description="Section reference")
    # Used by hybrid_retrieval as a case-insensitive substring filter on chunk text
    # and by parallel_logic_evaluation as the topic to evaluate relevance against.
    clause_type: str = Field(default="", description="Clause type")
    # Filled by the LLM; not read by any code in this cell.
    keywords: List[str] = Field(default_factory=list, description="Key terms")
    # Filled by the LLM; not read by any code in this cell.
    priority: str = Field(default="standard", description="Priority")

class EnhancedState(TypedDict):
    """State mapping used as the schema of the LangGraph pipeline built in this cell.

    The caller supplies ``question`` and ``metrics``; each pipeline node returns
    a dict containing only the key it produces.
    """
    question: str  # raw user question, supplied by ask_question_optimized
    query: EnhancedSearch  # produced by OptimizedRAGSystem.enhanced_query_analysis
    context: List[Document]  # produced by OptimizedRAGSystem.hybrid_retrieval
    logic_result: str  # produced by OptimizedRAGSystem.parallel_logic_evaluation
    answer: str  # produced by OptimizedRAGSystem.optimized_generation
    metrics: PerformanceMetrics  # supplied by ask_question_optimized; not updated by any node shown here
    intermediate_results: Dict[str, Any]  # not written by any code visible in this cell

class OptimizedRAGSystem:
    """Question-answering pipeline over a single PDF.

    The stages are: structured query analysis, FAISS similarity retrieval, an
    LLM relevance pass over the retrieved chunks, and final answer generation.
    Every method that takes ``state`` reads keys from that mapping and returns
    a dict holding only the key it produces, so the methods can be used
    directly as graph nodes.
    """

    def __init__(self, model_name: str = "gpt-4o-mini"):
        """Create the chat model, the embedding client and empty caches.

        Args:
            model_name: Chat model identifier passed to ``init_chat_model``
                with the OpenAI provider.
        """
        self.llm = init_chat_model(model_name, model_provider="openai")
        self.embeddings = OpenAIEmbeddings(model="text-embedding-3-large", chunk_size=1000)
        # Populated by load_and_process_document(); None until a document is indexed.
        self.vector_store = None
        # Pool used to evaluate the two halves of the context concurrently.
        self.executor = ThreadPoolExecutor(max_workers=4)
        # question text -> EnhancedSearch, so a repeated question skips the LLM call.
        self.query_cache = {}
        # Compiled graph; attached after construction by whoever wires the pipeline.
        # Declared here so the attribute always exists.
        self.graph = None

    def load_and_process_document(self, file_path: str) -> None:
        """Load the PDF, split it into overlapping chunks and build the FAISS index.

        Args:
            file_path: Path of the PDF handed to ``PyPDFLoader``.
        """
        loader = PyPDFLoader(file_path)
        pages = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800, chunk_overlap=150, add_start_index=True,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
        )
        all_splits = text_splitter.split_documents(pages)

        self.vector_store = FAISS.from_documents(
            all_splits, self.embeddings, normalize_L2=True
        )
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Processed %d chunks.", len(all_splits))

    def enhanced_query_analysis(self, state: "EnhancedState"):
        """Turn ``state["question"]`` into a structured ``EnhancedSearch``.

        Results are memoised per exact question string in ``self.query_cache``.

        Returns:
            ``{"query": <EnhancedSearch>}``
        """
        question = state["question"]
        if question in self.query_cache:
            return {"query": self.query_cache[question]}

        prompt = f"""
        Analyze and extract structured info:
        Question: {question}
        Return clause_type, keywords, and section if mentioned.
        """
        structured_llm = self.llm.with_structured_output(EnhancedSearch)
        result = structured_llm.invoke(prompt)
        self.query_cache[question] = result
        return {"query": result}

    def hybrid_retrieval(self, state: "EnhancedState"):
        """Retrieve up to six chunks for ``state["query"]``.

        Steps: similarity search, drop hits scoring 0.8 or more, optionally
        narrow to chunks whose text contains the clause type (falling back to
        the unfiltered hits when that would leave nothing), then keep the six
        lowest-scoring hits.

        Returns:
            ``{"context": [<Document>, ...]}`` (possibly empty).

        Raises:
            RuntimeError: if no document has been indexed yet.
        """
        if self.vector_store is None:
            # Fail with an actionable message instead of an AttributeError on None.
            raise RuntimeError(
                "No document has been indexed; call load_and_process_document() first."
            )

        query = state["query"]
        semantic_results = self.vector_store.similarity_search_with_score(query.query, k=15, fetch_k=30)
        # Scores are treated as lower-is-better (presumably a distance — confirm
        # against the vector store's configured metric).
        filtered = [(doc, score) for doc, score in semantic_results if score < 0.8]

        if query.clause_type:
            needle = query.clause_type.lower()
            # `or filtered`: keep the unfiltered hits if the keyword match removes everything.
            filtered = [
                (doc, score) for doc, score in filtered
                if needle in doc.page_content.lower()
            ] or filtered

        filtered.sort(key=lambda pair: pair[1])
        final_docs = [doc for doc, _ in filtered[:6]]
        return {"context": final_docs}

    def parallel_logic_evaluation(self, state: "EnhancedState"):
        """Ask the LLM how relevant the retrieved context is to the clause type.

        With more than three documents the context is split in half and both
        halves are evaluated concurrently on ``self.executor``; the two replies
        are joined with a blank line, first half first.

        Returns:
            ``{"logic_result": <str>}``; a fixed message when there is no context.
        """
        docs = state["context"]
        clause = state["query"].clause_type or "relevant info"
        if not docs:
            return {"logic_result": "No relevant context found"}

        def evaluate_chunk(docs_chunk):
            context = "\n\n".join(doc.page_content for doc in docs_chunk)
            prompt = f"Context:\n{context}\n\nEvaluate relevance to: '{clause}'"
            return self.llm.invoke(prompt).content

        if len(docs) > 3:
            mid = len(docs) // 2
            # Submit BOTH halves before waiting on either. Calling .result()
            # right after the first submit() would block until it finished and
            # make the two LLM calls run one after the other.
            first_half = self.executor.submit(evaluate_chunk, docs[:mid])
            second_half = self.executor.submit(evaluate_chunk, docs[mid:])
            result1 = first_half.result()
            result2 = second_half.result()
            return {"logic_result": f"{result1}\n\n{result2}"}

        return {"logic_result": evaluate_chunk(docs)}

    def optimized_generation(self, state: "EnhancedState"):
        """Generate the final answer from the question, context and prior analysis.

        Returns:
            ``{"answer": <str>}`` — the content of the LLM response.
        """
        context = "\n\n".join(doc.page_content for doc in state["context"])
        prompt = f"""
        You are an expert document QA agent. Answer clearly.

        Question: {state['question']}

        Context:
        {context}

        Prior Analysis:
        {state.get('logic_result', '')}

        Instructions:
        - Answer factually using only context.
        - Include specific amounts, exclusions, or rules.
        - Say what's missing if info is incomplete.
        """
        response = self.llm.invoke(prompt)
        return {"answer": response.content}

# Pipeline setup
def create_optimized_rag_system(file_path: str) -> OptimizedRAGSystem:
    """Index the PDF at ``file_path`` and wire the four-stage pipeline graph.

    The stages run strictly in sequence: analyze_query -> retrieve ->
    evaluate -> generate. The compiled graph is stored on the returned
    system as ``system.graph``.
    """
    system = OptimizedRAGSystem()
    system.load_and_process_document(file_path)

    # Stage name -> handler, listed in execution order.
    stages = (
        ("analyze_query", system.enhanced_query_analysis),
        ("retrieve", system.hybrid_retrieval),
        ("evaluate", system.parallel_logic_evaluation),
        ("generate", system.optimized_generation),
    )

    builder = StateGraph(EnhancedState)
    for stage_name, handler in stages:
        builder.add_node(stage_name, handler)

    # Linear route: START, every stage in order, END; connect each neighbouring pair.
    route = [START] + [stage_name for stage_name, _ in stages] + [END]
    for upstream, downstream in zip(route, route[1:]):
        builder.add_edge(upstream, downstream)

    system.graph = builder.compile()
    return system

def ask_question_optimized(system: OptimizedRAGSystem, question: str, verbose: bool = True):
    """Run one question through the compiled pipeline and return the final state.

    Args:
        system: A system whose ``graph`` attribute holds the compiled pipeline.
        question: The user's question.
        verbose: When true, print a heading and the answer with rich.

    Returns:
        The state mapping returned by ``system.graph.invoke``; the generated
        text is under ``"answer"``.
    """
    result = system.graph.invoke({"question": question, "metrics": PerformanceMetrics()})

    if verbose:
        # Imported locally so this function brings its own dependency into scope.
        from rich.markup import escape

        rprint("\n[bold cyan]🧠 Final Answer[/bold cyan]\n")
        answer = result["answer"]
        # The answer is model output, not markup we wrote. Unescaped, rich would
        # treat bracketed text such as "[a]" or "[/x]" as style tags and either
        # drop it from the printed answer or raise a markup error.
        rprint(escape(answer) if isinstance(answer, str) else answer)
    return result




In [5]:
# Build the pipeline over one policy PDF and ask a single question.
# NOTE(review): absolute Google Drive path — presumably requires Drive to be mounted
# in Colab; point this at a local copy when running elsewhere.
file_path = "/content/drive/MyDrive/BAJHLIP23020V012223.pdf"
rag_system = create_optimized_rag_system(file_path)
# Last expression of the cell: in addition to the rich-printed answer, the notebook
# echoes the returned state dict (question, query, context, ...) as the cell output.
ask_question_optimized(rag_system, "Does the policy cover hospitalization for mental illness?")


{'question': 'Does the policy cover hospitalization for mental illness?',
 'query': EnhancedSearch(query='Does the policy cover hospitalization for mental illness?', section='', clause_type='Policy Coverage', keywords=['hospitalization', 'mental illness', 'policy cover'], priority='high'),
 'context': [Document(id='c2d7c053-adeb-47a7-8754-7fc72f99c14b', metadata={'producer': 'Microsoft® Word 2016', 'creator': 'Microsoft® Word 2016', 'creationdate': '2022-06-16T20:06:13+05:30', 'author': 'Vinay Dhanokar/Head Office Pune/Corporate Communication/General', 'moddate': '2022-06-16T20:06:13+05:30', 'source': '/content/drive/MyDrive/BAJHLIP23020V012223.pdf', 'total_pages': 49, 'page': 11, 'page_label': '12', 'start_index': 1272}, page_content='specified in the Policy Schedule. \nThe above coverage is subject to fulfilment of following conditions:  \na. Mental Illness treatment is only covered where patient is diagnosed and treated by a psychiatrist, clinical \npsychologist or licensed psychoth