<a href="https://colab.research.google.com/github/ShamsRupak/ai-doc-processing-suite/blob/main/Full_RAG_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
# Install necessary libraries
!pip install -q langchain
!pip install -q langchain-community
!pip install -q chromadb
!pip install -q sentence-transformers
!pip install -q pypdf
!pip install -q transformers
!pip install -q accelerate
!pip install -q bitsandbytes
!pip install -q tiktoken
!pip install -q rank-bm25

In [16]:
import os
import torch
from typing import List, Dict, Any
import numpy as np
from tqdm import tqdm

# Document processing
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

# Embeddings and Vector Store
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma, FAISS
from langchain.retrievers import BM25Retriever, EnsembleRetriever

# LLM and Chain
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: Tesla T4


In [17]:
# Define PDF paths
pdf_paths = [
    "/content/LenderFeesWorksheetNew.pdf",
    "/content/appraisal_report.pdf",
    "/content/payslip_sample_image.pdf",
    "/content/sample_bank_statement.pdf",
    "/content/sample_contract.pdf"
]

# Load all PDFs
all_documents = []
for pdf_path in pdf_paths:
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        # Add metadata
        for doc in documents:
            doc.metadata['source'] = os.path.basename(pdf_path)
        all_documents.extend(documents)
        print(f"Loaded {len(documents)} pages from {os.path.basename(pdf_path)}")
    except Exception as e:
        print(f"Error loading {pdf_path}: {e}")

print(f"\nTotal documents loaded: {len(all_documents)}")

Loaded 1 pages from LenderFeesWorksheetNew.pdf
Loaded 10 pages from appraisal_report.pdf
Loaded 1 pages from payslip_sample_image.pdf
Loaded 1 pages from sample_bank_statement.pdf
Loaded 2 pages from sample_contract.pdf

Total documents loaded: 15


In [18]:
# Advanced text splitter with overlap for better context preservation
def create_smart_chunks(documents: List[Document],
                       chunk_size: int = 512,
                       chunk_overlap: int = 128) -> List[Document]:
    """
    Create optimized chunks with semantic awareness
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        is_separator_regex=False
    )

    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_documents([doc])
        chunks.extend(doc_chunks)

    return chunks

# Create chunks
chunks = create_smart_chunks(all_documents, chunk_size=512, chunk_overlap=128)
print(f"Created {len(chunks)} chunks from {len(all_documents)} documents")

Created 128 chunks from 15 documents


In [19]:
# Initialize embeddings model (using a smaller, efficient model)
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': 'cuda'},
    encode_kwargs={'normalize_embeddings': True}
)

# Create FAISS vector store for faster similarity search
print("Creating vector store...")
vectorstore = FAISS.from_documents(
    documents=chunks,
    embedding=embeddings
)
print("Vector store created successfully")

# Create BM25 retriever for keyword-based search
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 3

Creating vector store...
Vector store created successfully


In [20]:
# Create hybrid retriever combining dense and sparse retrieval
faiss_retriever = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3}
)

# Ensemble retriever for better performance
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever],
    weights=[0.3, 0.7]  # Weighted combination
)

print("Hybrid retrieval system ready")

Hybrid retrieval system ready


In [21]:
# Load TinyLlama model (1.1B parameters, efficient for T4 GPU)
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading LLM...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    load_in_8bit=True  # 8-bit quantization for memory efficiency
)

# Create pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15
)

llm = HuggingFacePipeline(pipeline=pipe)
print("LLM loaded successfully")

Loading LLM...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Device set to use cuda:0


LLM loaded successfully


In [22]:
# Custom prompt template for better responses
prompt_template = """<|system|>
You are a helpful assistant analyzing financial and legal documents. Use the following context to answer the question accurately and concisely.
</s>
<|user|>
Context: {context}

Question: {question}
</s>
<|assistant|>
Based on the provided context, """

PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

In [23]:
# Create RAG chain with hybrid retriever
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=ensemble_retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": PROMPT}
)

print("RAG pipeline ready for queries")

RAG pipeline ready for queries


In [24]:
def expand_query(query: str, num_expansions: int = 2) -> List[str]:
    """
    Expand query with synonyms and related terms for better retrieval
    """
    # Simple query expansion (in production, use WordNet or LLM-based expansion)
    expansions = [query]

    # Add common variations
    if "fee" in query.lower():
        expansions.append(query.replace("fee", "charge"))
        expansions.append(query.replace("fee", "cost"))
    if "payment" in query.lower():
        expansions.append(query.replace("payment", "amount"))
    if "property" in query.lower():
        expansions.append(query.replace("property", "real estate"))

    return list(set(expansions))[:num_expansions + 1]

In [25]:
def enhanced_query(question: str, use_expansion: bool = True) -> Dict[str, Any]:
    """
    Enhanced query with optional query expansion
    """
    if use_expansion:
        expanded_queries = expand_query(question)
        all_docs = []

        for query in expanded_queries:
            docs = ensemble_retriever.get_relevant_documents(query)
            all_docs.extend(docs)

        # Remove duplicates based on content
        unique_docs = []
        seen_content = set()
        for doc in all_docs:
            if doc.page_content not in seen_content:
                unique_docs.append(doc)
                seen_content.add(doc.page_content)

        # Use unique docs for context
        context = "\n\n".join([doc.page_content for doc in unique_docs[:5]])

        # Generate response
        response = llm(prompt_template.format(context=context, question=question))

        return {
            "question": question,
            "answer": response,
            "source_documents": unique_docs[:5],
            "expanded_queries": expanded_queries
        }
    else:
        result = qa_chain({"query": question})
        return result

In [26]:
# Define test queries covering different document types
test_queries = [
    # Financial queries
    "What are the total fees listed in the lender fees worksheet?",
    "What is the interest rate mentioned in the loan documents?",

    # Property queries
    "What is the appraised value of the property?",
    "Describe the property features mentioned in the appraisal report",

    # Employment/Income queries
    "What is the employee's net pay from the payslip?",
    "What deductions are shown on the payslip?",

    # Banking queries
    "What is the ending balance in the bank statement?",
    "List the major deposits shown in the bank statement",

    # Contract queries
    "What is the refund policy mentioned in the service agreement?",
    "What are the payment terms in the contract?"
]

# Test each query
results = []
for query in test_queries:
    print(f"\n{'='*80}")
    print(f"Query: {query}")
    print(f"{'='*80}")

    try:
        result = enhanced_query(query, use_expansion=True)

        print(f"\nAnswer: {result['answer']}")
        print(f"\nExpanded queries: {result.get('expanded_queries', [query])}")
        print(f"\nSource documents:")
        for i, doc in enumerate(result.get('source_documents', [])[:3]):
            print(f"\n{i+1}. From {doc.metadata.get('source', 'Unknown')}:")
            print(f"   {doc.page_content[:200]}...")

        results.append({
            "query": query,
            "answer": result['answer'],
            "sources": [doc.metadata.get('source', 'Unknown') for doc in result.get('source_documents', [])]
        })
    except Exception as e:
        print(f"Error processing query: {e}")
        results.append({
            "query": query,
            "answer": f"Error: {str(e)}",
            "sources": []
        })


Query: What are the total fees listed in the lender fees worksheet?

Answer: <|system|>
You are a helpful assistant analyzing financial and legal documents. Use the following context to answer the question accurately and concisely.
</s>
<|user|>
Context: ** B S Br L TP C= Borrower = Seller = Broker = Lender = Third Party = Correspondent
Calyx Form - feews.frm (09/2015)
FEES WORKSHEET
John Q. Smith / Mary A. Smith samplesmith
10/05/2015
30 YEAR FIXED -Purchase
XYZ Lender
$ 380,000 4.250 % 360 / 360 mths
475,000.00
1,121.53
4,520.00
380,000.00
Cash Deposit 5,000.00
needed to close 95,641.53
1,869.37
39.58
400.00
2,308.95
ORIGINATION CHARGES
Underwriting Fee XYZ Lender Borrower $ 550.00
Wire Transfer Fee XYZ Lender Borrower $ 75.00

400.00
2,308.95
ORIGINATION CHARGES
Underwriting Fee XYZ Lender Borrower $ 550.00
Wire Transfer Fee XYZ Lender Borrower $ 75.00
Administration Fee XYZ Lender Borrower $ 445.00
OTHER CHARGES
Appraisal Fee XYZ Lender Borrower $ 525.00
Credit Report Fee XYZ Lend

In [27]:
# Evaluate retrieval performance
def evaluate_retrieval_quality(results: List[Dict[str, Any]]) -> None:
    """
    Basic evaluation of retrieval quality
    """
    print("\n" + "="*80)
    print("RETRIEVAL PERFORMANCE SUMMARY")
    print("="*80)

    total_queries = len(results)
    successful_queries = sum(1 for r in results if "Error" not in r['answer'])

    print(f"\nTotal queries: {total_queries}")
    print(f"Successful queries: {successful_queries}")
    print(f"Success rate: {successful_queries/total_queries*100:.1f}%")

    # Source distribution
    all_sources = []
    for r in results:
        all_sources.extend(r['sources'])

    from collections import Counter
    source_counts = Counter(all_sources)

    print("\nDocument usage distribution:")
    for source, count in source_counts.most_common():
        print(f"  - {source}: {count} times")

# Run evaluation
evaluate_retrieval_quality(results)


RETRIEVAL PERFORMANCE SUMMARY

Total queries: 10
Successful queries: 10
Success rate: 100.0%

Document usage distribution:
  - appraisal_report.pdf: 24 times
  - LenderFeesWorksheetNew.pdf: 10 times
  - sample_contract.pdf: 6 times
  - sample_bank_statement.pdf: 6 times
  - payslip_sample_image.pdf: 4 times


In [28]:
# Save results for analysis
import json
from datetime import datetime

# Prepare results for saving
save_data = {
    "timestamp": datetime.now().isoformat(),
    "model": model_id,
    "total_chunks": len(chunks),
    "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    "results": results
}

# Save to JSON
with open("/content/rag_results.json", "w") as f:
    json.dump(save_data, f, indent=2)

print("\nResults saved to /content/rag_results.json")


Results saved to /content/rag_results.json
