# LegalAI RAG Pipeline

This notebook implements the RAG pipeline using:
- **Pinecone Vector DB** (Scalable storage)
- **NER Redaction** (Privacy protection)
- **Legal-Specific Chunking** (Regex-based)
- **GPU Acceleration** (Google Colab optimized)
- **Comprehensive Metrics** (M1-M24 including BERTScore, FCD, Bias)

In [None]:
# --- 1. SETUP & DEPENDENCIES ---
import sys
import os

# True when the google.colab package has already been loaded into this
# interpreter, which is used here as the signal for "running inside Colab".
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab. Installing dependencies...")
    # IPython shell magics: these lines only work inside a notebook kernel.
    # PyMuPDF is the only package pinned to an exact version here.
    # NOTE(review): pypdf, which a later cell imports as the PDF fallback,
    # is not in this install list — confirm whether that is intentional.
    !pip install -q PyMuPDF==1.23.26 sentence-transformers pinecone psutil nltk rouge-score bert-score tiktoken transformers torch scikit-learn
    !pip install -q -U google-generativeai
else:
    # Nothing is installed automatically outside Colab.
    print("Running Locally. Ensure 'requirements.txt' are installed.")

In [None]:
# --- 2. IMPORTS & CONFIG ---
import re
import time
import psutil
import warnings
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from pathlib import Path

import torch
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore", message="Some weights of the model")

# NLTK
import nltk
# Make sure the NLTK resources used by the metrics are available locally.
# Each resource lives under its own NLTK data category: looking up 'punkt'
# under 'corpora/' can never succeed, so the original check triggered a
# download attempt for it on every run.
for res, category in (('wordnet', 'corpora'), ('omw-1.4', 'corpora'), ('punkt', 'tokenizers')):
    try:
        nltk.data.find(f'{category}/{res}')
    except LookupError:
        nltk.download(res, quiet=True)

from sentence_transformers import SentenceTransformer
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from bert_score import score as bert_score
from sklearn.metrics.pairwise import cosine_similarity
import tiktoken

# --- API KEY INPUTS (SECURE) ---
import google.generativeai as genai
if IN_COLAB:
    from google.colab import userdata
else:
    userdata = None

def get_key(name):
    """Return the secret called *name*, or None when it is not configured.

    The Colab secret store (``userdata``) is consulted first; the process
    environment is the fallback.

    Args:
        name: Name of the secret / environment variable.

    Returns:
        The secret value as a string, or None if neither source has it.
    """
    try:
        # Outside Colab ``userdata`` is None, so this raises AttributeError;
        # inside Colab a missing or inaccessible secret raises as well.
        # Either way we fall through to the environment.
        value = userdata.get(name)
    except Exception:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        value = None
    # An empty secret is treated like a missing one so the environment
    # variable still gets a chance.
    return value or os.getenv(name)

print("\nüîê API CONFIGuration")
GOOGLE_API_KEY = get_key("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    GOOGLE_API_KEY = input("Enter GOOGLE_API_KEY: ").strip()
genai.configure(api_key=GOOGLE_API_KEY)

PINECONE_API_KEY = get_key("PINECONE_API_KEY")
if not PINECONE_API_KEY:
    PINECONE_API_KEY = input("Enter PINECONE_API_KEY: ").strip()

print("‚úÖ Keys configured")

In [None]:
# --- 3. PINECONE SETUP ---
# Pick the client API by what the installed pinecone package exposes: the
# class-based client is labelled "v3", the module-level init() API "v2".
try:
    from pinecone import Pinecone, ServerlessSpec
    pc = Pinecone(api_key=PINECONE_API_KEY)
    PINECONE_VERSION = "v3"
except ImportError:
    import pinecone
    pinecone.init(api_key=PINECONE_API_KEY, environment="us-east-1")
    PINECONE_VERSION = "v2"

INDEX_NAME = "mpnet-index-colab"

# For whichever client is active: list the indexes, create ours if it is
# missing (768 dimensions, cosine metric), then open a handle to it.
if PINECONE_VERSION == "v3":
    existing_indexes = [i.name for i in pc.list_indexes()]
    if INDEX_NAME not in existing_indexes:
        print(f"Creating index '{INDEX_NAME}'...")
        pc.create_index(
            name=INDEX_NAME,
            dimension=768,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-east-1")
        )
    index = pc.Index(INDEX_NAME)
else:
    existing_indexes = pinecone.list_indexes()
    if INDEX_NAME not in existing_indexes:
        print(f"Creating index '{INDEX_NAME}'...")
        pinecone.create_index(name=INDEX_NAME, dimension=768, metric="cosine")
    index = pinecone.Index(INDEX_NAME)

print(f"‚úÖ Connected to Pinecone Index: {INDEX_NAME}")

In [None]:
# --- 4. NER & PDF UTILS ---
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Optional Redaction
# Ask once whether person names should be redacted before embedding.
# The answer is stripped before comparison (as the other prompts in this
# notebook do) so that "y " or " y" is not silently read as "no".
USE_NER_REDACTION = input("Use NER Redaction to hide names? (y/n) [n]: ").strip().lower() == 'y'

# Stays None when redaction is off; run_ner_redact() treats None as "skip".
ner_pipeline = None
if USE_NER_REDACTION:
    print("Loading NER model...")
    ner_model_name = "dslim/bert-base-NER"
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
    # aggregation_strategy="simple" makes the pipeline return grouped
    # entities carrying an 'entity_group' label and a 'word' string.
    ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
    print("‚úÖ NER Ready")

try:
    import fitz
    PDF_LIBRARY = "pymupdf"
except ImportError:
    from pypdf import PdfReader
    PDF_LIBRARY = "pypdf"

def _read_pdf_text(pdf_path):
    """Return the text of every page of *pdf_path*, joined by newlines."""
    if PDF_LIBRARY == "pymupdf":
        # The context manager closes the document even when a page fails to
        # extract; the previous open()/close() pair leaked it in that case.
        with fitz.open(pdf_path) as doc:
            return "\n".join(page.get_text("text") for page in doc)

    reader = PdfReader(pdf_path)
    if reader.is_encrypted:
        try:
            reader.decrypt('')
        except Exception:
            # Best-effort: if the empty password is rejected, page extraction
            # below raises and the caller reports the file as unreadable.
            pass
    return "\n".join(page.extract_text() or "" for page in reader.pages)


def _pack_segments(segments, chunk_size, source):
    """Greedily pack text segments into chunks of fewer than *chunk_size* words."""
    chunks = []
    parts = []
    n_words = 0  # running word count of `parts`, so the buffer is never re-split
    for seg in segments:
        seg_words = len(seg.split())
        if n_words + seg_words < chunk_size:
            parts.append(seg)
            n_words += seg_words
        else:
            if parts:
                chunks.append({'text': " ".join(parts), 'source': source})
            # A segment that is itself >= chunk_size words becomes its own
            # oversized chunk; segments are never split further.
            parts = [seg]
            n_words = seg_words
    if parts:
        chunks.append({'text': " ".join(parts), 'source': source})
    return chunks


def extract_chunks(pdf_path, chunk_size=200):
    """Extract a PDF's text and split it into word-bounded chunks.

    The text is stripped of ``*``/``_`` characters, whitespace-normalised,
    split into sentences, and each sentence is further split around legal
    citations so that a citation is kept as one segment. Segments are then
    packed greedily into chunks.

    Args:
        pdf_path: Path of the PDF to read.
        chunk_size: Word budget; a chunk is closed before it would reach
            this many words.

    Returns:
        A list of ``{'text': str, 'source': str}`` dicts, or an empty list
        if the file could not be processed (the error is printed).
    """
    # Advanced Regex for legal citations
    # NOTE(review): the section patterns require a letter after the digits,
    # so "Section 302A" matches but plain "Section 302" does not — confirm
    # whether `[A-Za-z]?` was intended. Left unchanged to keep chunking stable.
    citation_pattern = re.compile(r"(Section\s\d+[A-Za-z]|Sec\.\s\d+[A-Za-z]|\d+\s?Cr\.?\s?\d+)", re.IGNORECASE)

    try:
        text = _read_pdf_text(pdf_path)
        text = re.sub(r"[*_]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        sentences = re.split(r'(?<=[.?!]) +|\n+', text)

        # The capturing group makes split() return the citations themselves
        # as separate parts; blank parts are dropped.
        segments = []
        for sent in sentences:
            for part in citation_pattern.split(sent):
                seg = part.strip()
                if seg:
                    segments.append(seg)

        return _pack_segments(segments, chunk_size, str(pdf_path))
    except Exception as e:
        print(f"‚ö† Error processing {pdf_path}: {e}")
        return []

def run_ner_redact(text):
    """Replace person names found by the NER pipeline with ``[REDACTED]``.

    Args:
        text: Text to scan.

    Returns:
        The text with every whole-word occurrence of each detected person
        name replaced. The input is returned unchanged when redaction is
        disabled (``ner_pipeline`` is falsy) or when the pipeline fails.
    """
    if not ner_pipeline:
        return text
    try:
        ents = ner_pipeline(text)
        # Strip word-piece markers and ignore names that end up empty: an
        # empty name would produce the pattern `\b\b`, which matches at every
        # word boundary and would scatter [REDACTED] through the whole text.
        names = {
            e['word'].replace("##", "").strip()
            for e in ents
            if e['entity_group'] == 'PER'
        }
        names.discard("")
        # Longest names first, so "John Smith" is replaced as a unit before
        # "John" alone can consume its first half and leave "Smith" behind.
        for name in sorted(names, key=lambda n: (-len(n), n)):
            text = re.sub(r'\b{}\b'.format(re.escape(name)), '[REDACTED]', text)
        return text
    except Exception as e:
        # Still best-effort, but no longer silent: the caller is about to
        # store text that may contain unredacted names.
        print(f"NER redaction failed, text left unredacted: {e}")
        return text

In [None]:
# --- 5. GATHER & EMBED (WITH CACHING, ROBUST IDS & M1 TRACKING) ---
import hashlib

# Directory holding one .npz cache file per embedded PDF.
CACHE_DIR = Path("emb_cache_v2")
CACHE_DIR.mkdir(exist_ok=True, parents=True)

# Global metric holder
M1_AVG_EMBED_TIME = 0.0

# --- 1. FILE FINDING ---
print("--- PDF SELECTION ---")

# Mount Drive Logic
if IN_COLAB:
    from google.colab import drive
    mount_choice = input("Mount Google Drive? (y/n): ").strip().lower()
    if mount_choice == "y":
        try:
            drive.mount("/content/drive", force_remount=True)
            print("‚úÖ Drive mounted.")
        except Exception as e:
            print(f"‚ùå Mount failed: {e}")
            print("üí° TIP: Go to 'Runtime > Disconnect and Delete Runtime' to reset.")

# Any answer other than civil/criminal falls back to "both".
folder_choice = input("Folder type (Civil/Criminal/Both) [both]: ").strip().lower() or "both"
if folder_choice not in ["civil", "criminal", "both"]:
    folder_choice = "both"

# Path Logic
custom_path = input("Custom path? (Enter FILE PATH, not URL): ").strip()
if custom_path.startswith("http"):
    print("‚ùå URL detected. Using search fallback.")
    custom_path = ""

if custom_path:
    search_dir = Path(custom_path)
else:
    # BASE_DIR is not assigned by any earlier cell of this notebook, so
    # referencing it directly raised NameError whenever the custom path was
    # left empty. Use it when it exists, otherwise search the working directory.
    search_dir = Path(globals().get("BASE_DIR", "."))
    if IN_COLAB and not search_dir.exists():
        search_dir = Path(".")

print(f"Searching in: {search_dir.absolute()}")
all_pdfs = list(search_dir.rglob("*.pdf"))
pdf_files = []

if folder_choice == "both":
    pdf_files = all_pdfs
else:
    # Keep only PDFs whose full path mentions the chosen category.
    term = folder_choice
    pdf_files = [p for p in all_pdfs if term in str(p).lower()]

print(f"Found {len(pdf_files)} PDFs.")
if not pdf_files and IN_COLAB:
    print("‚ùå No PDFs. Check your path or upload files.")
# --- 2. CACHING & EMBEDDING CORE ---
def get_file_hash(file_path):
    """Return a cache key for a file derived from its path, size and mtime.

    The file contents are not read: the key is the MD5 hex digest of
    ``"<absolute path>_<size>_<mtime>"``, so it changes whenever the file is
    moved, resized or modified.

    Args:
        file_path: A ``pathlib.Path`` or a path string.

    Returns:
        A 32-character hex digest, or the string ``"unknown"`` when the file
        cannot be stat'ed.
    """
    try:
        # Accept plain strings too; previously a str argument hit
        # AttributeError on .stat() and was silently reported as "unknown".
        path = Path(file_path)
        stat = path.stat()
    except (OSError, TypeError):
        # Only "cannot stat this path" is expected here; anything else is a
        # real bug and should surface instead of being swallowed.
        return "unknown"
    sig = f"{path.absolute()}_{stat.st_size}_{stat.st_mtime}"
    return hashlib.md5(sig.encode()).hexdigest()

def save_cache(cpath, vectors, chunks):
    """Store one file's embeddings and chunk records in a compressed .npz archive.

    Args:
        cpath: Destination path of the archive.
        vectors: Embedding array, saved under the key ``vectors``.
        chunks: Chunk records, saved under the key ``chunks``.
    """
    payload = {"vectors": vectors, "chunks": chunks}
    np.savez_compressed(cpath, **payload)

def load_cache(cpath):
    """Load the embeddings and chunk records written by ``save_cache``.

    Args:
        cpath: Path of the .npz archive.

    Returns:
        A ``(vectors, chunks)`` tuple of NumPy arrays.
    """
    # SECURITY NOTE: allow_pickle=True is needed because the chunk records
    # are stored as a pickled object array. Unpickling executes code from the
    # file, so only load archives this notebook wrote itself.
    # The `with` block closes the archive's file handle, which the previous
    # version left open; both arrays are fully read before it closes.
    with np.load(cpath, allow_pickle=True) as data:
        return data['vectors'], data['chunks']

# Sentence-embedding model used both for indexing chunks and for encoding
# queries; it is moved to the GPU whenever CUDA is available.
model_name = "multi-qa-mpnet-base-cos-v1"
embedder = SentenceTransformer(model_name)
if torch.cuda.is_available():
    embedder.to("cuda")

def process_and_upload(pdf_files, batch_size=64):
    """Embed the given PDFs (reusing cached embeddings) and upsert them to Pinecone.

    For each file the cache is consulted first; files without a usable cache
    entry are chunked, optionally redacted, embedded and cached. Every chunk
    is then upserted with a deterministic id derived from its text and source.

    Side effects: writes cache files under ``CACHE_DIR``, updates the global
    ``M1_AVG_EMBED_TIME`` and writes to the Pinecone ``index``.

    Args:
        pdf_files: Iterable of PDF paths.
        batch_size: Number of vectors sent per upsert request.
    """
    global M1_AVG_EMBED_TIME
    print(f"Processing {len(pdf_files)} files...")

    # Redacted and unredacted embeddings must not share a cache entry:
    # otherwise turning redaction on after a plain run would upload the
    # cached, unredacted text.
    cache_suffix = "_ner" if USE_NER_REDACTION else ""

    # Identify what needs embedding vs loading
    to_embed = []   # (pdf, cache path or None)
    to_upload = []  # (vectors, chunks)

    for pdf in tqdm(pdf_files, desc="Checking Cache"):
        fhash = get_file_hash(pdf)
        if fhash == "unknown":
            # Every un-stat-able file gets the same placeholder hash; caching
            # under it would let unrelated files read each other's data.
            to_embed.append((pdf, None))
            continue

        cpath = CACHE_DIR / f"{fhash}{cache_suffix}.npz"
        if cpath.exists():
            try:
                to_upload.append(load_cache(cpath))
                continue
            except Exception as e:
                # A corrupt or partial cache file is not fatal: re-embed.
                print(f"Cache unreadable for {pdf} ({e}); re-embedding.")
        to_embed.append((pdf, cpath))

    # Process new files
    if to_embed:
        print(f"‚ö° Embedding {len(to_embed)} new files...")
        t_start_all = time.time()

        for pdf, cpath in tqdm(to_embed, desc="Embedding New"):
            chunks = extract_chunks(str(pdf))
            if not chunks:
                continue

            if USE_NER_REDACTION:
                for c in chunks:
                    c['text'] = run_ner_redact(c['text'])

            texts = [c['text'] for c in chunks]

            with torch.no_grad():
                vectors = embedder.encode(texts, show_progress_bar=False, convert_to_numpy=True)

            if cpath is not None:
                save_cache(cpath, vectors, chunks)
            to_upload.append((vectors, chunks))

        total_time = time.time() - t_start_all
        # M1: average wall time per newly embedded file, scaled to 10 files.
        M1_AVG_EMBED_TIME = (total_time / len(to_embed)) * 10
        print(f"üìä M1 (Time for 10 PDFs): {round(M1_AVG_EMBED_TIME, 2)}s")

    # Upload to Pinecone
    print("Syncing with Pinecone...")
    upsert_batch = []

    for vecs, chks in tqdm(to_upload, desc="Prep Upload"):
        for vec, chunk in zip(vecs, chks):
            # Deterministic SHA1 id: re-running the pipeline overwrites the
            # same records instead of inserting duplicates.
            unique_str = f"{chunk['text']}_{chunk['source']}"
            uid = hashlib.sha1(unique_str.encode()).hexdigest()

            meta = {
                'text': chunk['text'],
                'source': chunk['source']
            }
            upsert_batch.append((uid, vec.tolist(), meta))

            if len(upsert_batch) >= batch_size:
                index.upsert(vectors=upsert_batch)
                upsert_batch = []

    # Flush the final partial batch.
    if upsert_batch:
        index.upsert(vectors=upsert_batch)

    print("‚úÖ Pipeline Complete.")

# Run
if pdf_files:
    process_and_upload(pdf_files)

In [None]:
# --- 6. RETRIEVAL & METRICS (FULL SUITE M1-M24) ---
from sklearn.feature_extraction.text import CountVectorizer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score

# M2: Index Size
# Number of vectors currently stored in the Pinecone index. Reading it is
# best-effort, but a failure is now reported instead of silently leaving the
# metric at 0.
M2_INDEX_SIZE = 0
try:
    stats = index.describe_index_stats()
    M2_INDEX_SIZE = stats.get('total_vector_count', 0)
except Exception as e:
    print(f"Could not read Pinecone index stats ({e}); M2 will be reported as 0.")

class LegalEvaluator:
    """Vocabulary-based heuristics for legal terminology, bias terms and citations."""

    def __init__(self):
        # Legal vocabulary counted by the term-precision score.
        self.legal_terms = {
            "plaintiff", "defendant", "petitioner", "respondent", "appellant",
            "writ", "jurisdiction", "affidavit", "statute", "provision", "act",
            "section", "article", "constitution", "bench", "judgement", "decree",
            "bail", "custody", "conviction", "acquittal", "prima facie", "locus standi"
        }
        # Identity-related vocabulary counted by the bias score.
        self.bias_terms = {
            "caste", "religion", "hindu", "muslim", "christian", "sikh",
            "dalit", "brahmin", "shudra", "upper caste", "lower caste",
            "gender", "female", "male", "race", "ethnicity"
        }
        self.citation_pattern = re.compile(r"(v\.|vs\.|versus|AIR \d+|SCC \d+|Section \d+|Article \d+)", re.IGNORECASE)

    @staticmethod
    def _count_term_hits(terms, words, text_lower):
        """Count how many entries of *terms* occur in the text.

        Single-word terms are looked up in the token set. Multi-word terms
        ("prima facie", "upper caste", ...) can never equal a single token,
        so they are matched as whole phrases against the lower-cased text.
        """
        hits = 0
        for term in terms:
            if " " in term:
                phrase = r"\b" + r"\s+".join(re.escape(w) for w in term.split()) + r"\b"
                if re.search(phrase, text_lower):
                    hits += 1
            elif term in words:
                hits += 1
        return hits

    def calculate_legal_scores(self, text):
        """Score *text* for legal terminology, bias vocabulary and citations.

        Args:
            text: The text to analyse.

        Returns:
            A tuple ``(term_precision, bias_score, cit_acc, citation_count)``:
            the first two are matched terms as a percentage of the number of
            distinct word tokens (0.0 for text with no tokens); ``cit_acc``
            is 100.0 if at least one citation was found, else 0.0;
            ``citation_count`` is the number of citation matches.
        """
        text_lower = text.lower()
        words = set(re.findall(r"\w+", text_lower))

        if words:
            legal_hits = self._count_term_hits(self.legal_terms, words, text_lower)
            bias_hits = self._count_term_hits(self.bias_terms, words, text_lower)
            term_precision = legal_hits / len(words) * 100
            bias_score = bias_hits / len(words) * 100
        else:
            term_precision = 0.0
            bias_score = 0.0

        citation_count = len(self.citation_pattern.findall(text))
        cit_acc = 100.0 if citation_count > 0 else 0.0  # Simple proxy for now

        return term_precision, bias_score, cit_acc, citation_count

def retrieve_answer(query, top_k=5, model_name="gemini-1.5-flash"):
    """Retrieve context for *query* from Pinecone and generate an answer from it.

    Args:
        query: The user question.
        top_k: Number of nearest chunks to request from the index.
        model_name: Generative model used for the answer.

    Returns:
        A tuple ``(answer, texts, scores, q_vec, t_retrieval, t_gen)``:
        the answer string, the retrieved chunk texts, their similarity
        scores, the query embedding, and the retrieval and generation times
        in seconds. When no usable context is found the answer is
        ``"No context found."``, both lists are empty and ``t_gen`` is 0.
        Generation failures are returned as an answer of the form
        ``"[Gen Error: ...]"`` rather than raised.
    """
    t0 = time.time()
    q_vec = embedder.encode([query])[0]

    try:
        res = index.query(vector=q_vec.tolist(), top_k=top_k, include_metadata=True)
        matches = res['matches']
        texts = [m['metadata'].get('text', '') for m in matches]
        scores = [m['score'] for m in matches]
    except Exception as e:
        print(f"Retrieval Error: {e}")
        texts = []
        scores = []

    context = "\n\n".join(texts)
    t_retrieval = time.time() - t0

    # strip(): matches whose metadata has no text join into a string of bare
    # separators, which is non-empty but gives the model nothing to answer from.
    if not context.strip():
        return "No context found.", [], [], q_vec, t_retrieval, 0

    t_gen_start = time.time()
    prompt = f"Answer using ONLY the context.\n\nContext:\n{context}\n\nQuery: {query}\nAnswer:"
    model = genai.GenerativeModel(model_name)

    try:
        resp = model.generate_content(prompt)
        answer = resp.text.strip()
    except Exception as e:
        answer = f"[Gen Error: {e}]"

    t_gen = time.time() - t_gen_start
    return answer, texts, scores, q_vec, t_retrieval, t_gen

def compute_detailed_metrics(preds, gts, ret_texts, ret_scores, lat_r, lat_g):
    """Compute the M1-M24 metric row for every evaluated query.

    Args:
        preds: Generated answers, one per query.
        gts: Ground-truth answers, aligned with *preds*.
        ret_texts: For each query, the list of retrieved chunk texts.
        ret_scores: For each query, the list of retrieval similarity scores.
        lat_r: Retrieval latency in seconds per query.
        lat_g: Generation latency in seconds per query.

    Returns:
        A ``pandas.DataFrame`` with one row per query and one column per
        metric (``QID`` plus ``M1_...`` to ``M24_...``).
    """
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    legal_eval = LegalEvaluator()
    smoothie = SmoothingFunction().method4

    # zip() below stops at the shortest input; say so instead of silently
    # dropping queries when the lists are misaligned.
    if len({len(preds), len(gts), len(ret_texts), len(ret_scores)}) > 1:
        print("Warning: metric inputs differ in length; extra items are ignored.")

    results = []
    print("Computing metrics (M1-M24)...")

    for i, (pred, gt, ret, scores) in enumerate(zip(preds, gts, ret_texts, ret_scores)):
        # Joined retrieved context, reused by M5, M9, M13 and M14.
        ctx_text = " ".join(ret)

        # --- M1 & M2 ---
        m1 = globals().get('M1_AVG_EMBED_TIME', 0.0)
        m2 = M2_INDEX_SIZE

        # --- M3 Retrieval Latency ---
        m3 = lat_r[i]

        # --- M4 Cosine Similarity (Avg) ---
        m4 = np.mean(scores) if scores else 0.0

        # --- M5 Top-k Accuracy ---
        # Counted as a hit when more than 30% of the ground-truth words
        # appear somewhere in the retrieved context.
        gt_words = set(gt.lower().split())
        ctx_words = set(ctx_text.lower().split())
        m5 = 100.0 if len(gt_words.intersection(ctx_words)) > (len(gt_words) * 0.3) else 0.0

        # --- M6-M8 ROUGE ---
        r_scores = rouge.score(gt, pred)
        m6 = r_scores['rouge1'].fmeasure
        m7 = r_scores['rouge2'].fmeasure
        m8 = r_scores['rougeL'].fmeasure

        # --- M9 Context Length ---
        m9 = len(ctx_text.split())

        # --- M10-M12 Text Quality ---
        m10 = sentence_bleu([gt.split()], pred.split(), smoothing_function=smoothie)
        m11 = meteor_score([gt.split()], pred.split())
        try:
            P, R, F1 = bert_score([pred], [gt], lang="en", rescale_with_baseline=True)
            m12 = float(F1[0])
        except Exception as e:
            # Keep the run going, but make the zero explainable.
            print(f"Warning: BERTScore failed for QID {i} ({e}); M12 set to 0.")
            m12 = 0.0

        # --- M13 FCD ---
        # 100 * (1 - BERTScore F1 between the answer and the retrieved context).
        m13 = 0.0
        if ret:
            try:
                _, _, F1_ctx = bert_score([pred], [ctx_text], lang="en", rescale_with_baseline=True)
                m13 = 100 * (1 - float(F1_ctx[0]))
            except Exception as e:
                print(f"Warning: FCD BERTScore failed for QID {i} ({e}); M13 set to 0.")

        # --- M14 Faithfulness ---
        # Share of answer words that also occur in the retrieved context.
        ans_words = set(pred.lower().split())
        m14 = (len(ans_words.intersection(ctx_words)) / len(ans_words) * 100) if ans_words else 0.0

        # --- M15 GT Coverage ---
        m15 = (len(gt_words.intersection(ans_words)) / len(gt_words) * 100) if gt_words else 0.0

        # --- M16-M19 System ---
        m16 = lat_r[i] + lat_g[i]
        m17 = 1 / m16 if m16 > 0 else 0
        m18 = psutil.cpu_percent()
        m19 = psutil.virtual_memory().used / (1024**3)

        # --- M20-M24 Legal ---
        # calculate_legal_scores returns (term precision, bias, citation
        # accuracy, citation count), hence the m21, m24, m20 order.
        m21, m24, m20, cit_count = legal_eval.calculate_legal_scores(pred)
        m22 = 100.0 if cit_count > 1 else 0.0
        m23 = m13 # Duplicate

        results.append({
            "QID": i,
            "M1_EmbedTime": round(float(m1), 2), "M2_IndexSize": m2,
            "M3_RetLatency": round(m3, 3), "M4_CosSim": round(m4, 3), "M5_TopK_Acc": m5,
            "M6_R1": round(m6, 3), "M7_R2": round(m7, 3), "M8_RL": round(m8, 3),
            "M9_CtxLen": m9, "M10_BLEU": round(m10, 3), "M11_METEOR": round(m11, 3),
            "M12_BERT_F1": round(m12, 3), "M13_FCD": round(m13, 1),
            "M14_Faithful": round(m14, 1), "M15_GTCov": round(m15, 1),
            "M16_E2ELat": round(m16, 2), "M17_Throughput": round(m17, 2),
            "M18_CPU": m18, "M19_RAM": round(m19, 2),
            "M20_CitAcc": round(m20, 1), "M21_TermPrec": round(m21, 1),
            "M22_PrecCov": m22, "M23_FCD_dup": round(m23, 1), "M24_Bias": round(m24, 1)
        })

    return pd.DataFrame(results)

# Run Evaluation
# Evaluation set: each query is paired by position with its reference answer.
queries = [
    "Explain the 2001 SC/ST promotion case.",
    "Summarize a criminal case discussed in the documents.",
]
gts = [
    "The 2001 case regarding SC/ST reservation in promotion established that Article 16(4A) is an enabling provision...",
    "Criminal case summary placeholder...",
]

# Per-query accumulators: answers, retrieved contexts, similarity scores,
# retrieval latencies and generation latencies.
preds = []
ctxs = []
sco = []
lr = []
lg = []

print("\n--- RUNNING EVALUATION ---")
for q in queries:
    print(f"Processing Q: {q}...")
    a, c, s, qv, tr, tg = retrieve_answer(q)
    # Show only the first 150 characters of the answer as a preview.
    print(f"A: {a[:150]}...\n")
    for bucket, value in ((preds, a), (ctxs, c), (sco, s), (lr, tr), (lg, tg)):
        bucket.append(value)

# Score everything, show the table inline and persist it as CSV.
if preds:
    df_metrics = compute_detailed_metrics(preds, gts, ctxs, sco, lr, lg)
    from IPython.display import display
    display(df_metrics)
    df_metrics.to_csv("metrics_final.csv", index=False)
    print("‚úÖ Full M1-M24 Metrics saved to metrics_final.csv")