<a href="https://colab.research.google.com/github/RL370/JSExperiment2150/blob/main/RAG_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook setup cell: installs the third-party packages imported by the cells below.
!# Install all required packages
! pip install torch transformers sentence-transformers faiss-cpu rank_bm25 matplotlib seaborn bitsandbytes accelerate datasets tqdm
# NOTE(review): no cell visible here imports `rag_pipeline` — confirm this install is actually needed.
! pip install rag_pipeline



In [None]:
"""
Full-Scale Hybrid RAG Pipeline with Fine-tuning
Uses complete HotpotQA dataset with training capabilities
"""

import json, re, csv, time, logging, os
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch
import faiss
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from sentence_transformers import SentenceTransformer, CrossEncoder, InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from rank_bm25 import BM25Okapi
from transformers import (
    AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM,
    BitsAndBytesConfig, TrainingArguments, Trainer, DefaultDataCollator
)
from datasets import Dataset
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ======================================================================
# Configuration
# ======================================================================

@dataclass
class RAGConfig:
    """Settings shared by the retriever, the QA models and the pipeline in this file."""

    # Data
    data_dir: str = "data"  # load_and_process_data reads/writes the HotpotQA JSON files here
    output_dir: str = "outputs"  # NOTE(review): not referenced by the code visible here
    cache_dir: str = "cache"  # NOTE(review): not referenced by the code visible here

    # Models (Hugging Face model identifiers)
    dense_model: str = "sentence-transformers/all-mpnet-base-v2"  # SentenceTransformer bi-encoder
    rerank_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # CrossEncoder used as reranker and answer verifier
    extractive_model: str = "deepset/roberta-base-squad2"  # AutoModelForQuestionAnswering
    generative_model: str = "microsoft/phi-2"  # AutoModelForCausalLM

    # Training
    use_full_dataset: bool = True  # NOTE(review): not referenced by the code visible here
    train_retriever: bool = True  # gates the retriever fine-tuning step in the pipeline's train()
    train_extractive: bool = True  # gates the extractive QA fine-tuning step in the pipeline's train()
    num_train_examples: int = -1  # values > 0 keep only the first N examples; anything else keeps all
    num_val_examples: int = -1    # values > 0 keep only the first N examples; anything else keeps all

    # Retrieval
    retrieval_k: int = 10  # k passed to the retriever during evaluation
    alpha: float = 0.7  # fusion weight: alpha on the dense score, (1 - alpha) on the BM25 score

    # Training hyperparameters
    learning_rate: float = 2e-5  # used by the extractive QA TrainingArguments
    num_epochs: int = 3  # used by both the retriever fit() and the extractive Trainer
    batch_size: int = 8  # used by both the retriever DataLoader and the extractive Trainer
    warmup_steps: int = 500  # used by both the retriever fit() and the extractive Trainer

    # Hardware
    use_gpu: bool = True  # only honoured when torch.cuda.is_available() is also true
    use_4bit_quant: bool = True  # 4-bit load of the generative model; only applied when the GPU is in use
    max_length: int = 384  # tokenizer max_length for the extractive QA model

    # Checkpointing
    save_checkpoints: bool = True  # gates saving the final extractive QA model
    checkpoint_dir: str = "checkpoints"  # parent directory for "dense_retriever" and "extractive_qa"

# ======================================================================
# Data structures
# ======================================================================

@dataclass
class Document:
    """One passage of the retrieval corpus (a titled list of sentences)."""

    id: str  # the pipeline builds this as "<example _id>_<position in that example's context>"
    title: str  # used as the lookup key in the title -> document maps
    content: str  # the pipeline builds this as the sentences joined with single spaces
    sentences: List[str]  # the passage's individual sentences

@dataclass
class RAGOutput:
    """Per-question prediction record; the pipeline serialises it as one JSON line."""

    question_id: str  # the example's "_id"
    question: str
    answer: str  # predicted answer; empty string when no answer was produced
    answer_type: str  # "extractive", "generative" or "none", as set by the QA system
    gold_answer: str  # the example's reference answer
    retrieved_passages: List[Dict[str, Any]]  # dicts with passage_id, title, content, hybrid_score, rerank_score
    confidence_score: float  # confidence reported by whichever QA model supplied the answer
    prompt_tokens: int
    output_tokens: int
    total_tokens: int  # prompt_tokens + output_tokens
    retrieval_scores: List[float]  # hybrid (fused) scores of the retrieved passages
    supporting_facts: List[List]  # gold [title, sentence id] pairs copied from the example
    retrieval_recall: float  # fraction of gold titles found among the retrieved passages


In [None]:
# ======================================================================
# Enhanced Hybrid Retriever with Training
# ======================================================================

class EnhancedHybridRetriever:
    def __init__(self, config: RAGConfig):
        """Load the dense encoder and the reranker; indexes stay empty until build_index().

        Args:
            config: Pipeline settings; ``dense_model``, ``rerank_model`` and
                ``use_gpu`` are read here.
        """
        self.config = config
        self.use_gpu = config.use_gpu and torch.cuda.is_available()

        # Index state, populated later by build_index().
        self.documents = []
        self.doc_map = {}  # title -> document
        self.tokenized_docs = []
        self.bm25 = None
        self.dense_index = None
        self.document_embeddings = None

        if self.use_gpu:
            target_device = "cuda"
        else:
            target_device = "cpu"

        logger.info(f"Loading retrieval models on {target_device}...")
        self.dense_encoder = SentenceTransformer(config.dense_model, device=target_device)
        self.reranker = CrossEncoder(config.rerank_model, device=target_device)

    def build_index(self, documents: List[Document]):
        """Index ``documents`` for both dense (FAISS) and sparse (BM25) search.

        Replaces any previously stored documents, embeddings and indexes.

        Args:
            documents: Passages to index. The dense side encodes
                "<title>: <content>"; the BM25 side tokenises the lowercased
                content on whitespace.
        """
        logger.info(f"Building indexes for {len(documents)} documents...")
        self.documents = documents
        self.doc_map = {doc.title: doc for doc in documents}

        # --- Dense index -------------------------------------------------
        passages = [f"{doc.title}: {doc.content}" for doc in documents]
        encode_batch = 32 if self.use_gpu else 16

        logger.info("Encoding documents...")
        self.document_embeddings = self.dense_encoder.encode(
            passages,
            show_progress_bar=True,
            batch_size=encode_batch,
            convert_to_numpy=True,
        )

        embedding_dim = self.document_embeddings.shape[1]
        self.dense_index = faiss.IndexFlatIP(embedding_dim)
        # The return value is not used: the stored embeddings are normalised
        # before being added to the inner-product index.
        faiss.normalize_L2(self.document_embeddings)
        self.dense_index.add(self.document_embeddings.astype("float32"))

        # --- Sparse index ------------------------------------------------
        logger.info("Building BM25 index...")
        self.tokenized_docs = [doc.content.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)

        logger.info("✓ Indexes built successfully")

    def prepare_training_data(self, examples: List[Dict]) -> List[InputExample]:
        """Build (question, passage) pairs for fine-tuning the dense retriever.

        For every example, each supporting-fact title found in the index
        yields a positive pair (label 1.0), and up to two randomly chosen
        non-supporting documents yield negative pairs (label 0.0).

        Args:
            examples: Dicts with a "question" string and "supporting_facts"
                given as (title, sentence id) pairs.

        Returns:
            The list of ``InputExample`` training pairs.
        """
        logger.info("Preparing retriever training data...")
        train_samples = []

        # Computed once: the previous version rebuilt the list of every
        # non-positive title for each example, i.e. work proportional to
        # (number of examples) x (corpus size).
        all_titles = [d.title for d in self.documents]
        title_counts = Counter(all_titles)
        num_docs = len(all_titles)

        for ex in tqdm(examples, desc="Creating training pairs"):
            question = ex["question"]

            # Positive passages (supporting facts)
            positive_titles = set(title for title, _ in ex["supporting_facts"])

            for title in positive_titles:
                pos_doc = self.doc_map.get(title)
                if pos_doc is not None:
                    pos_text = f"{pos_doc.title}: {pos_doc.content}"
                    train_samples.append(InputExample(texts=[question, pos_text], label=1.0))

            # Negative passages (random non-supporting)
            # Number of indexed documents whose title is NOT a positive one.
            num_available = num_docs - sum(title_counts[t] for t in positive_titles)
            num_negatives = min(2, num_available)

            # Rejection sampling over document positions: uniform over the
            # non-positive documents, without replacement. It terminates
            # because at least `num_negatives` valid positions exist.
            picked = []
            while len(picked) < num_negatives:
                idx = int(np.random.randint(num_docs))
                if idx in picked or all_titles[idx] in positive_titles:
                    continue
                picked.append(idx)

            for idx in picked:
                neg_doc = self.doc_map.get(all_titles[idx])
                if neg_doc is not None:
                    neg_text = f"{neg_doc.title}: {neg_doc.content}"
                    train_samples.append(InputExample(texts=[question, neg_text], label=0.0))

        logger.info(f"Created {len(train_samples)} training pairs")
        return train_samples

    def fine_tune(self, train_examples: List[Dict], val_examples: List[Dict]):
        """Fine-tune the dense retriever, then rebuild the index with it.

        Args:
            train_examples: Examples passed to prepare_training_data().
            val_examples: Accepted for interface symmetry; not used here.

        Does nothing (beyond logging a warning) when no training pairs can
        be built, mirroring the guard in the extractive QA fine-tuning.
        """
        logger.info("="*60)
        logger.info("FINE-TUNING DENSE RETRIEVER")
        logger.info("="*60)

        train_samples = self.prepare_training_data(train_examples)

        # A shuffled DataLoader over an empty dataset raises, so bail out
        # early instead of crashing the whole training phase.
        if not train_samples:
            logger.warning("No retriever training pairs available, skipping fine-tuning")
            return

        # Create dataloader
        train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=self.config.batch_size)

        # Loss function
        train_loss = losses.CosineSimilarityLoss(self.dense_encoder)

        # Training
        output_path = Path(self.config.checkpoint_dir) / "dense_retriever"
        output_path.mkdir(parents=True, exist_ok=True)

        # Disable all logging/tracking
        os.environ['WANDB_DISABLED'] = 'true'
        os.environ['WANDB_MODE'] = 'disabled'

        self.dense_encoder.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=self.config.num_epochs,
            warmup_steps=self.config.warmup_steps,
            output_path=str(output_path),
            save_best_model=True,
            show_progress_bar=True,
            use_amp=self.use_gpu
        )

        logger.info("✓ Retriever fine-tuning complete")

        # Rebuild index with fine-tuned model
        logger.info("Rebuilding index with fine-tuned embeddings...")
        self.build_index(self.documents)

    def retrieve(self, query: str, k: int = 5, alpha: float = 0.7) -> Tuple[List[Document], List[float], List[float]]:
        """Retrieve up to ``k`` documents by dense + BM25 fusion, then rerank them.

        Args:
            query: Question text.
            k: Number of documents to return.
            alpha: Weight of the normalised dense score; ``1 - alpha`` is
                applied to the normalised BM25 score.

        Returns:
            ``(documents, hybrid_scores, rerank_scores)``, all ordered by
            descending reranker score. All three are empty when the corpus
            is empty or ``k`` is not positive.
        """
        # Never request more candidates than the corpus holds: FAISS pads
        # missing results with index -1, which would otherwise be read as
        # "the last document" by Python's negative indexing.
        pool_size = min(k * 2, len(self.documents))
        if pool_size <= 0:
            return [], [], []

        # Dense retrieval
        q_embed = self.dense_encoder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_embed)
        dense_scores, dense_idx = self.dense_index.search(q_embed.astype("float32"), pool_size)
        dense_hits = [
            (int(idx), float(score))
            for idx, score in zip(dense_idx[0].tolist(), dense_scores[0].tolist())
            if idx >= 0  # defensive: drop any padding entries
        ]

        # Sparse retrieval
        query_tokens = query.lower().split()
        sparse_scores = self.bm25.get_scores(query_tokens)
        sparse_top_idx = np.argsort(sparse_scores)[-pool_size:][::-1].tolist()

        # Fusion: scale each side by its maximum. A non-positive maximum
        # would divide by zero or invert the ordering, so fall back to 1.0.
        max_dense = max((score for _, score in dense_hits), default=0.0)
        if max_dense <= 0:
            max_dense = 1.0
        max_sparse = float(max(sparse_scores))
        if max_sparse <= 0:
            max_sparse = 1.0

        doc_scores = {}
        for idx, score in dense_hits:
            doc_scores[idx] = alpha * (score / max_dense)

        for idx in sparse_top_idx:
            current = doc_scores.get(idx, 0.0)
            doc_scores[idx] = current + (1 - alpha) * float(sparse_scores[idx] / max_sparse)

        # Top-k fusion scores
        top_fused = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)[:k]
        candidate_docs = [self.documents[i] for i, _ in top_fused]
        candidate_scores = [s for _, s in top_fused]

        # Reranking
        pairs = [[query, f"{d.title}: {d.content}"] for d in candidate_docs]
        rerank_scores = self.reranker.predict(pairs, show_progress_bar=False)

        # Sort by rerank scores
        sorted_indices = np.argsort(rerank_scores)[::-1]

        final_docs = [candidate_docs[i] for i in sorted_indices]
        final_hybrid_scores = [candidate_scores[i] for i in sorted_indices]
        final_rerank_scores = [float(rerank_scores[i]) for i in sorted_indices]

        return final_docs, final_hybrid_scores, final_rerank_scores

    def calculate_retrieval_recall(self, retrieved_docs: List[Document],
                                   supporting_facts: List[List]) -> float:
        """Calculate recall@k for retrieved documents"""
        if not supporting_facts:
            return 0.0

        retrieved_titles = set(d.title for d in retrieved_docs)
        gold_titles = set(title for title, _ in supporting_facts)

        if not gold_titles:
            return 0.0

        overlap = retrieved_titles & gold_titles
        return len(overlap) / len(gold_titles)

In [None]:
# ======================================================================
# Enhanced Extractive QA with Fine-tuning
# ======================================================================

class EnhancedExtractiveQA:
    def __init__(self, config: RAGConfig):
        """Load the extractive QA tokenizer and model and put the model in eval mode.

        Args:
            config: Pipeline settings; ``extractive_model`` and ``use_gpu``
                are read here.
        """
        self.config = config
        self.use_gpu = config.use_gpu and torch.cuda.is_available()
        if self.use_gpu:
            self.device = "cuda"
        else:
            self.device = "cpu"

        logger.info(f"Loading extractive QA model on {self.device}...")
        model_name = config.extractive_model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        self.model = qa_model.to(self.device)
        self.model.eval()

    def prepare_training_data(self, examples: List[Dict], documents: Dict[str, Document]) -> Dataset:
        """Build span-extraction training samples from supporting documents.

        For each example, every supporting document whose content contains
        the gold answer (case-insensitively) yields one sample.

        Args:
            examples: Dicts with "question", "answer" and "supporting_facts"
                given as (title, sentence id) pairs.
            documents: Title -> document lookup.

        Returns:
            A ``Dataset`` with columns "question", "context", "answer_start"
            (character offset into "context") and "answer_text" (the exact
            context substring that matched).
        """
        logger.info("Preparing extractive QA training data...")

        training_samples = []

        for ex in tqdm(examples, desc="Creating QA pairs"):
            question = ex["question"]
            answer = ex["answer"]

            # An empty answer would "match" at offset 0 of every context and
            # produce a meaningless zero-length span.
            if not answer:
                continue

            # Search the ORIGINAL context case-insensitively. Searching a
            # lowercased copy can return offsets that do not line up with the
            # original text when lowercasing changes the string length.
            answer_pattern = re.compile(re.escape(answer), re.IGNORECASE)

            # Get supporting documents
            supporting_titles = set(title for title, _ in ex["supporting_facts"])

            for title in supporting_titles:
                doc = documents.get(title)
                if doc is None:
                    continue

                context = doc.content
                match = answer_pattern.search(context)
                if match is None:
                    continue

                training_samples.append({
                    "question": question,
                    "context": context,
                    "answer_start": match.start(),
                    # Store the span as it appears in the context so its
                    # length always agrees with the character offsets.
                    "answer_text": match.group(0)
                })

        logger.info(f"Created {len(training_samples)} training samples")
        return Dataset.from_list(training_samples)

    def preprocess_function(self, examples):
        """Tokenize question/context pairs and convert character spans to token labels.

        Args:
            examples: A batch (dict of lists) with "question", "context",
                "answer_start" (character offset) and "answer_text".

        Returns:
            The tokenizer output with "start_positions" and "end_positions"
            added. When the answer does not lie fully inside the (possibly
            truncated) context window, both labels are set to 0, the first
            token of the sequence, as in the Hugging Face question-answering
            recipe; presumably that is the model's "no answer" position —
            confirm for this checkpoint.
        """
        questions = examples["question"]
        contexts = examples["context"]

        tokenized = self.tokenizer(
            questions,
            contexts,
            max_length=self.config.max_length,
            truncation="only_second",
            padding="max_length",
            return_offsets_mapping=True
        )

        offset_mapping = tokenized.pop("offset_mapping")
        start_positions = []
        end_positions = []

        for i, (answer_start, answer_text) in enumerate(zip(examples["answer_start"], examples["answer_text"])):
            answer_end = answer_start + len(answer_text)
            offsets = offset_mapping[i]

            # Locate the first and last tokens that belong to the context
            # (sequence id 1); everything else is question or special tokens.
            sequence_ids = tokenized.sequence_ids(i)
            context_start = 0
            while sequence_ids[context_start] != 1:
                context_start += 1
            context_end = len(sequence_ids) - 1
            while sequence_ids[context_end] != 1:
                context_end -= 1

            # Truncation may have cut the answer out of the window. Without
            # this check the scans below would stop at the window boundary
            # and label an unrelated token span.
            if offsets[context_start][0] > answer_start or offsets[context_end][1] < answer_end:
                start_positions.append(0)
                end_positions.append(0)
                continue

            # Find start and end token positions
            token_start = context_start
            while token_start <= context_end and offsets[token_start][0] <= answer_start:
                token_start += 1
            start_positions.append(token_start - 1)

            token_end = context_end
            while token_end >= context_start and offsets[token_end][1] >= answer_end:
                token_end -= 1
            end_positions.append(token_end + 1)

        tokenized["start_positions"] = start_positions
        tokenized["end_positions"] = end_positions

        return tokenized

    def fine_tune(self, train_examples: List[Dict], val_examples: List[Dict], doc_map: Dict[str, Document]):
        """Fine-tune the extractive QA model on spans found in supporting documents.

        Args:
            train_examples: Examples passed to prepare_training_data().
            val_examples: Accepted for interface symmetry; not used here.
            doc_map: Title -> document lookup.

        The model is returned to eval mode afterwards so that subsequent
        extract_answer() calls do not run with dropout enabled.
        """
        logger.info("="*60)
        logger.info("FINE-TUNING EXTRACTIVE QA MODEL")
        logger.info("="*60)

        # Prepare datasets
        train_dataset = self.prepare_training_data(train_examples, doc_map)

        if len(train_dataset) == 0:
            logger.warning("No training data available, skipping fine-tuning")
            return

        # Tokenize
        tokenized_train = train_dataset.map(
            self.preprocess_function,
            batched=True,
            remove_columns=train_dataset.column_names
        )

        # Training arguments
        output_dir = Path(self.config.checkpoint_dir) / "extractive_qa"
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            learning_rate=self.config.learning_rate,
            per_device_train_batch_size=self.config.batch_size,
            num_train_epochs=self.config.num_epochs,
            warmup_steps=self.config.warmup_steps,
            logging_steps=100,
            save_strategy="epoch",
            fp16=self.use_gpu,
            report_to="none",
            use_cpu=(not self.use_gpu)
        )

        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_train,
            data_collator=DefaultDataCollator(),
            tokenizer=self.tokenizer
        )

        # Train
        trainer.train()

        # Training switches the module to train mode; __init__ put it in eval
        # mode for inference, so restore that before anyone calls
        # extract_answer() again.
        self.model.eval()

        # Save
        if self.config.save_checkpoints:
            trainer.save_model(str(output_dir / "final"))
            logger.info(f"✓ Model saved to {output_dir / 'final'}")

        logger.info("✓ Extractive QA fine-tuning complete")

    def extract_answer(self, question: str, context: str) -> Tuple[str, float]:
        """Extract the most likely answer span for ``question`` from ``context``.

        Args:
            question: Question text.
            context: Passage text; truncated to ``config.max_length`` tokens
                together with the question.

        Returns:
            ``(answer, confidence)``. ``("", 0.0)`` is returned when the best
            span is inverted, longer than 30 tokens, not fully inside the
            context segment (for example a prediction on the leading special
            token), or decodes to an empty string. Confidence is the sigmoid
            of the mean of the chosen start and end logits.
        """
        encoded = self.tokenizer(
            question, context,
            max_length=self.config.max_length,
            truncation=True,
            return_tensors="pt",
            padding=True
        )
        # Per-token segment ids: 1 marks context tokens, anything else is a
        # question or special token.
        sequence_ids = encoded.sequence_ids(0)
        inputs = encoded.to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            start_logits = outputs.start_logits[0]
            end_logits = outputs.end_logits[0]

        start_idx = torch.argmax(start_logits).item()
        end_idx = torch.argmax(end_logits).item()

        if end_idx < start_idx or (end_idx - start_idx) > 30:
            return "", 0.0

        # A span that touches question or special tokens is not an answer
        # from the passage; previously such tokens (e.g. "<s>") were returned
        # verbatim as the answer text.
        if sequence_ids[start_idx] != 1 or sequence_ids[end_idx] != 1:
            return "", 0.0

        answer_ids = inputs["input_ids"][0][start_idx:end_idx + 1]
        answer = self.tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
        if not answer:
            return "", 0.0

        confidence = (start_logits[start_idx] + end_logits[end_idx]).item() / 2
        return answer, float(torch.sigmoid(torch.tensor(confidence)).item())


In [None]:
# ======================================================================
# Enhanced Generative QA
# ======================================================================

class EnhancedGenerativeQA:
    def __init__(self, config: RAGConfig):
        """Load the generative tokenizer and causal LM and put the model in eval mode.

        With a GPU in use the model is loaded in float16 with
        ``device_map="auto"``, and additionally 4-bit quantised when
        ``config.use_4bit_quant`` is set. Otherwise it is loaded in float32
        with no device map.

        Args:
            config: Pipeline settings; ``generative_model``, ``use_gpu`` and
                ``use_4bit_quant`` are read here.
        """
        self.config = config
        self.use_gpu = config.use_gpu and torch.cuda.is_available()

        logger.info("Loading generative model...")
        self.tokenizer = AutoTokenizer.from_pretrained(config.generative_model, trust_remote_code=True)

        # Arguments common to both loading modes.
        load_kwargs = {
            "device_map": "auto" if self.use_gpu else None,
            "trust_remote_code": True,
            "torch_dtype": torch.float16 if self.use_gpu else torch.float32,
        }

        # Quantisation is only requested when the GPU is actually in use.
        if self.use_gpu and config.use_4bit_quant:
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )

        self.model = AutoModelForCausalLM.from_pretrained(config.generative_model, **load_kwargs)

        # Fall back to EOS as the padding token when the tokenizer has none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()

    def generate_answer(self, question: str, context: str) -> Tuple[str, float]:
        """Generate answer from context"""
        prompt = f"""Based on the context, answer the question concisely.

Context: {context}

Question: {question}

Answer:"""

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=1536,
            truncation=True
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.3,
                do_sample=True,
                top_p=0.95,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Calculate confidence
        scores = torch.stack(outputs.scores)
        probs = torch.softmax(scores, dim=-1)
        max_probs = probs.max(dim=-1).values
        confidence = max_probs.mean().item()

        answer = self.tokenizer.decode(
            outputs.sequences[0][len(inputs["input_ids"][0]):],
            skip_special_tokens=True
        ).strip()

        # Clean answer
        for prefix in ["Answer:", "A:", "The answer is", "answer:"]:
            if answer.lower().startswith(prefix.lower()):
                answer = answer[len(prefix):].strip()

        # Reject invalid answers
        invalid = ["unknown", "i don't know", "cannot determine", "not enough information", "unclear"]
        if any(inv in answer.lower() for inv in invalid) or len(answer.split()) < 2:
            return "", 0.0

        return answer, confidence


In [None]:
# ======================================================================
# Full Hybrid QA System
# ======================================================================

class FullHybridQASystem:
    def __init__(self, config: RAGConfig):
        """Create the extractive and generative QA models plus the answer verifier.

        Args:
            config: Pipeline settings; forwarded to both QA models, and its
                ``rerank_model`` / ``use_gpu`` select the verifier.
        """
        self.config = config
        logger.info("Initializing Full Hybrid QA System...")

        self.extractive = EnhancedExtractiveQA(config)
        self.generative = EnhancedGenerativeQA(config)

        # The verifier is a cross-encoder loaded from the reranker checkpoint.
        if config.use_gpu and torch.cuda.is_available():
            verifier_device = "cuda"
        else:
            verifier_device = "cpu"
        self.verifier = CrossEncoder(config.rerank_model, device=verifier_device)

    def _is_multi_hop(self, question: str) -> bool:
        """Detect multi-hop questions"""
        multi_hop_patterns = [
            r'\b(who|what|when|where|which)\b.*\b(who|what|when|where|which)\b',
            r'\b(and|also|both|between)\b',
            r'\b(compare|contrast|relate)\b'
        ]
        q_lower = question.lower()
        return any(re.search(pattern, q_lower) for pattern in multi_hop_patterns)

    def answer(self, question: str, docs: List[Document]) -> Tuple[str, str, float, int, int, int]:
        """Answer ``question`` from ``docs`` by routing between both QA models.

        Both models are run on the first five documents (each truncated to
        800 characters), and each non-empty answer is scored by the verifier
        against the question. Multi-hop questions prefer the generative
        answer when its verifier score exceeds 0.5; other questions prefer
        the extractive answer when its verifier score beats the generative
        one.

        Returns:
            ``(answer, answer_type, confidence, prompt_tokens,
            output_tokens, total_tokens)`` where ``answer_type`` is
            "extractive", "generative" or "none".
        """
        # Combine contexts
        context = "\n\n".join(f"{d.title}: {d.content[:800]}" for d in docs[:5])

        # Run both answerers on the same context.
        ext_answer, ext_conf = self.extractive.extract_answer(question, context)
        gen_answer, gen_conf = self.generative.generate_answer(question, context)

        # Verifier scores; an empty answer scores 0.0 without calling the model.
        if ext_answer:
            ext_verify = self.verifier.predict([[question, ext_answer]])[0]
        else:
            ext_verify = 0.0
        if gen_answer:
            gen_verify = self.verifier.predict([[question, gen_answer]])[0]
        else:
            gen_verify = 0.0

        use_extractive = (ext_answer, "extractive", ext_conf)
        use_generative = (gen_answer, "generative", gen_conf)
        use_nothing = ("", "none", 0.0)

        # Route decision
        if self._is_multi_hop(question):
            if gen_answer and gen_verify > 0.5:
                chosen = use_generative
            elif ext_answer:
                chosen = use_extractive
            else:
                chosen = use_nothing
        elif ext_answer and ext_verify > gen_verify:
            chosen = use_extractive
        elif gen_answer:
            chosen = use_generative
        elif ext_answer:
            chosen = use_extractive
        else:
            chosen = use_nothing

        answer, ans_type, conf = chosen

        # Token counts: the prompt count is taken over the question plus the
        # first 500 characters of the context.
        count_tokens = self.generative.tokenizer.encode
        p_tok = len(count_tokens(question + context[:500]))
        o_tok = len(count_tokens(answer)) if answer else 0

        return answer, ans_type, conf, p_tok, o_tok, p_tok + o_tok


In [None]:
# ======================================================================
# Full Pipeline
# ======================================================================

class FullHybridRAGPipeline:
    def __init__(self, config: RAGConfig):
        """Create the retriever and QA system; the document store starts empty.

        Args:
            config: Pipeline settings, forwarded to both components.
        """
        self.config = config

        banner = "="*60
        for line in (banner, "INITIALIZING FULL-SCALE HYBRID RAG PIPELINE", banner):
            logger.info(line)

        self.retriever = EnhancedHybridRetriever(config)
        self.qa_system = FullHybridQASystem(config)

        # Filled in by build_document_index().
        self.all_documents = []
        self.doc_map = {}

    def load_and_process_data(self):
        """Load the HotpotQA train/validation examples, downloading them if needed.

        The converted data is cached as JSON under ``config.data_dir``. If
        either cached file is missing, both splits are downloaded with
        ``datasets.load_dataset`` and rewritten.

        Returns:
            ``(train_data, val_data)`` as lists of dicts with the keys
            "_id", "question", "answer", "type", "level", "context" and
            "supporting_facts". Each list is cut to the first
            ``num_train_examples`` / ``num_val_examples`` entries when that
            setting is greater than zero.
        """
        logger.info("Loading HotpotQA dataset...")

        data_dir = Path(self.config.data_dir)
        data_path = data_dir / "hotpot_dev_distractor_v1.json"
        train_path = data_dir / "hotpot_train.json"

        # Both files are read below, so both must exist; checking only the
        # validation file would crash on a missing train file.
        if not (data_path.exists() and train_path.exists()):
            logger.info("Downloading HotpotQA...")
            from datasets import load_dataset

            # Load full dataset
            train_dset = load_dataset("hotpot_qa", "distractor", split="train", trust_remote_code=True)
            val_dset = load_dataset("hotpot_qa", "distractor", split="validation", trust_remote_code=True)

            def convert_item(item):
                # Reshape the column-oriented dataset record into paired
                # [title, sentences] and [title, sentence id] lists.
                return {
                    "_id": item["id"],
                    "question": item["question"],
                    "answer": item["answer"],
                    "type": item["type"],
                    "level": item["level"],
                    "context": [[t, s] for t, s in zip(item["context"]["title"], item["context"]["sentences"])],
                    "supporting_facts": [[t, sid] for t, sid in zip(item["supporting_facts"]["title"], item["supporting_facts"]["sent_id"])]
                }

            logger.info("Converting training data...")
            train_data = [convert_item(item) for item in tqdm(train_dset)]

            logger.info("Converting validation data...")
            val_data = [convert_item(item) for item in tqdm(val_dset)]

            # Save (parents=True so a nested data_dir also works)
            data_dir.mkdir(parents=True, exist_ok=True)
            with open(train_path, "w", encoding="utf-8") as f:
                json.dump(train_data, f, indent=2)
            with open(data_path, "w", encoding="utf-8") as f:
                json.dump(val_data, f, indent=2)

            logger.info(f"✓ Saved {len(train_data)} training examples")
            logger.info(f"✓ Saved {len(val_data)} validation examples")

        # Load data
        with open(train_path, encoding="utf-8") as f:
            train_data = json.load(f)
        with open(data_path, encoding="utf-8") as f:
            val_data = json.load(f)

        # Limit if specified
        if self.config.num_train_examples > 0:
            train_data = train_data[:self.config.num_train_examples]
        if self.config.num_val_examples > 0:
            val_data = val_data[:self.config.num_val_examples]

        logger.info(f"✓ Loaded {len(train_data)} train, {len(val_data)} val examples")

        return train_data, val_data

    def build_document_index(self, examples: List[Dict]):
        """Collect the context passages of ``examples`` and index them.

        Passages are de-duplicated by title: when a title occurs more than
        once, the last occurrence is kept, in the position of the first.

        Args:
            examples: Dicts with "_id" and "context" given as
                (title, sentences) pairs.
        """
        logger.info("Extracting documents from examples...")

        by_title = {}
        for ex in tqdm(examples, desc="Processing"):
            for position, (title, sentences) in enumerate(ex["context"]):
                # Re-assigning an existing key keeps its original slot but
                # replaces the stored document.
                by_title[title] = Document(
                    id=f"{ex['_id']}_{position}",
                    title=title,
                    content=" ".join(sentences),
                    sentences=sentences,
                )

        self.all_documents = list(by_title.values())
        self.doc_map = dict(by_title)

        logger.info(f"✓ Extracted {len(self.all_documents)} unique documents")

        # Build index
        self.retriever.build_index(self.all_documents)

    def train(self, train_data: List[Dict], val_data: List[Dict]):
        """Fine-tune whichever components the config enables.

        Args:
            train_data: Training examples.
            val_data: Validation examples, forwarded to each fine-tuning call.
        """
        separator = "="*60
        for line in (separator, "TRAINING PHASE", separator):
            logger.info(line)

        settings = self.config

        # Dense retriever first, then the extractive reader.
        if settings.train_retriever:
            self.retriever.fine_tune(train_data, val_data)

        if settings.train_extractive:
            reader = self.qa_system.extractive
            reader.fine_tune(train_data, val_data, self.doc_map)

        logger.info("✓ Training complete!")

    def evaluate(self, val_data: List[Dict], output_file: str):
        """Run retrieval and QA over ``val_data`` and write one JSON line per example.

        Args:
            val_data: Examples with "_id", "question", "answer" and
                optionally "supporting_facts".
            output_file: Path of the JSONL file to (over)write.

        Returns:
            The list of ``RAGOutput`` records, in input order.
        """
        separator = "="*60
        for line in (separator, "EVALUATION PHASE", separator):
            logger.info(line)

        collected = []

        with open(output_file, "w") as sink:
            for ex in tqdm(val_data, desc="Evaluating"):
                question = ex["question"]

                # Retrieve
                docs, h_scores, r_scores = self.retriever.retrieve(
                    question,
                    k=self.config.retrieval_k,
                    alpha=self.config.alpha,
                )

                # Recall of the gold supporting titles among the retrieved docs.
                retrieval_recall = self.retriever.calculate_retrieval_recall(
                    docs, ex.get("supporting_facts", [])
                )

                # Answer
                answer, ans_type, conf, p_tok, o_tok, t_tok = self.qa_system.answer(question, docs)

                passages = []
                for doc, hybrid, rerank in zip(docs, h_scores, r_scores):
                    passages.append({
                        "passage_id": doc.id,
                        "title": doc.title,
                        "content": doc.content,
                        "hybrid_score": float(hybrid),
                        "rerank_score": float(rerank),
                    })

                record = RAGOutput(
                    question_id=ex["_id"],
                    question=question,
                    answer=answer,
                    answer_type=ans_type,
                    gold_answer=ex["answer"],
                    retrieved_passages=passages,
                    confidence_score=conf,
                    prompt_tokens=p_tok,
                    output_tokens=o_tok,
                    total_tokens=t_tok,
                    retrieval_scores=h_scores,
                    supporting_facts=ex.get("supporting_facts", []),
                    retrieval_recall=retrieval_recall,
                )

                collected.append(record)
                sink.write(json.dumps(asdict(record)) + "\n")

        logger.info(f"✓ Saved predictions to {output_file}")
        return collected

In [None]:
# ======================================================================
# Enhanced Evaluator
# ======================================================================

class EnhancedRAGEvaluator:
    def normalize_answer(self, text: str) -> str:
        """Normalize answer text"""
        text = text.lower().strip()
        text = re.sub(r'\b(a|an|the)\b', ' ', text)
        text = re.sub(r'[^\w\s]', '', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def exact_match(self, predicted: str, gold: str) -> float:
        """Calculate exact match"""
        return 1.0 if self.normalize_answer(predicted) == self.normalize_answer(gold) else 0.0

    def f1_score(self, predicted: str, gold: str) -> float:
        """Calculate token-level F1 score"""
        pred_tokens = set(self.normalize_answer(predicted).split())
        gold_tokens = set(self.normalize_answer(gold).split())

        if not pred_tokens or not gold_tokens:
            return 0.0

        common = pred_tokens & gold_tokens
        if not common:
            return 0.0

        precision = len(common) / len(pred_tokens)
        recall = len(common) / len(gold_tokens)

        return 2 * (precision * recall) / (precision + recall)

    def evaluate_results(self, results_jsonl: str, original_data: List[Dict]) -> Dict:
        """Comprehensive evaluation"""
        logger.info("Evaluating results...")

        predictions = [json.loads(line) for line in open(results_jsonl)]
        gold_map = {ex["_id"]: ex for ex in original_data}

        metrics = {
            "overall": {"em": [], "f1": [], "retrieval_recall": []},
            "by_type": {},
            "by_question_type": {},
            "by_level": {}
        }

        for pred in predictions:
            qid = pred["question_id"]
            if qid not in gold_map:
                continue

            gold_ex = gold_map[qid]
            gold_answer = gold_ex["answer"]

            # Calculate metrics
            em = self.exact_match(pred["answer"], gold_answer)
            f1 = self.f1_score(pred["answer"], gold_answer)
            ret_recall = pred.get("retrieval_recall", 0.0)

            # Overall
            metrics["overall"]["em"].append(em)
            metrics["overall"]["f1"].append(f1)
            metrics["overall"]["retrieval_recall"].append(ret_recall)

            # By answer type
            ans_type = pred.get("answer_type", "unknown")
            if ans_type not in metrics["by_type"]:
                metrics["by_type"][ans_type] = {"em": [], "f1": [], "count": 0}
            metrics["by_type"][ans_type]["em"].append(em)
            metrics["by_type"][ans_type]["f1"].append(f1)
            metrics["by_type"][ans_type]["count"] += 1

            # By question type
            q_type = gold_ex.get("type", "unknown")
            if q_type not in metrics["by_question_type"]:
                metrics["by_question_type"][q_type] = {"em": [], "f1": []}
            metrics["by_question_type"][q_type]["em"].append(em)
            metrics["by_question_type"][q_type]["f1"].append(f1)

            # By difficulty level
            level = gold_ex.get("level", "unknown")
            if level not in metrics["by_level"]:
                metrics["by_level"][level] = {"em": [], "f1": []}
            metrics["by_level"][level]["em"].append(em)
            metrics["by_level"][level]["f1"].append(f1)

        # Aggregate
        results = {
            "overall": {
                "exact_match": float(np.mean(metrics["overall"]["em"])),
                "f1_score": float(np.mean(metrics["overall"]["f1"])),
                "retrieval_recall": float(np.mean(metrics["overall"]["retrieval_recall"])),
                "num_examples": len(metrics["overall"]["em"])
            },
            "by_type": {
                k: {
                    "exact_match": float(np.mean(v["em"])),
                    "f1_score": float(np.mean(v["f1"])),
                    "count": v["count"]
                } for k, v in metrics["by_type"].items()
            },
            "by_question_type": {
                k: {
                    "exact_match": float(np.mean(v["em"])),
                    "f1_score": float(np.mean(v["f1"])),
                    "count": len(v["em"])
                } for k, v in metrics["by_question_type"].items()
            },
            "by_level": {
                k: {
                    "exact_match": float(np.mean(v["em"])),
                    "f1_score": float(np.mean(v["f1"])),
                    "count": len(v["em"])
                } for k, v in metrics["by_level"].items()
            },
            "raw_scores": metrics["overall"]
        }

        return results


In [None]:
# ======================================================================
# Enhanced Visualizer
# ======================================================================

class EnhancedRAGVisualizer:
    """Render RAG evaluation metrics as matplotlib figures.

    Figures are written into ``output_dir``, which is created (including any
    missing parents) when the visualizer is constructed.
    """

    def __init__(self, output_dir="visualizations"):
        """Prepare the output directory and the seaborn theme.

        Args:
            output_dir: Directory that receives the saved figures.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested paths (e.g. "outputs/visualizations") work
        # even when the parent directory has not been created yet.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        sns.set_style("whitegrid")
        sns.set_palette("husl")

    @staticmethod
    def _count_exact_matches(em_scores) -> Tuple[int, int]:
        """Return ``(incorrect, correct)`` counts for per-example EM scores.

        A score equal to zero counts as incorrect, anything else as correct.
        Both counts are always returned, so the caller can draw both bars
        even when every example falls into a single bucket.
        """
        scores = list(em_scores)
        correct = sum(1 for score in scores if score != 0)
        return len(scores) - correct, correct

    @staticmethod
    def _plot_grouped_bars(ax, breakdown: Dict, title: str, width: float = 0.35) -> None:
        """Draw side-by-side EM / F1 bars for every group in ``breakdown``.

        ``breakdown`` maps a group name to a dict holding ``"exact_match"``
        and ``"f1_score"``. The axes are left untouched when it is empty.
        """
        groups = list(breakdown.keys())
        if not groups:
            return
        positions = np.arange(len(groups))
        em_vals = [breakdown[g]["exact_match"] for g in groups]
        f1_vals = [breakdown[g]["f1_score"] for g in groups]
        ax.bar(positions - width / 2, em_vals, width, label="EM", alpha=0.8)
        ax.bar(positions + width / 2, f1_vals, width, label="F1", alpha=0.8)
        ax.set_xticks(positions)
        ax.set_xticklabels(groups, rotation=45, ha="right")
        ax.legend()
        ax.set_title(title, fontweight="bold", fontsize=12)
        ax.set_ylim(0, 1)

    @staticmethod
    def _plot_score_histogram(ax, scores, color: str, xlabel: str, title: str) -> None:
        """Draw a 30-bin histogram of ``scores`` with a dashed mean marker."""
        ax.hist(scores, bins=30, color=color, alpha=0.7, edgecolor="black")
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Frequency")
        ax.set_title(title, fontweight="bold", fontsize=12)
        # The mean of an empty list is NaN; skip the marker in that case.
        if len(scores) > 0:
            mean = np.mean(scores)
            ax.axvline(mean, color="red", linestyle="--", linewidth=2,
                       label=f"Mean: {mean:.3f}")
            ax.legend()

    def _draw_panels(self, fig, results: Dict) -> None:
        """Populate the 3x3 grid of ``fig`` from ``results``."""
        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
        overall = results["overall"]
        raw_scores = results["raw_scores"]

        # Overall metrics
        ax1 = fig.add_subplot(gs[0, 0])
        labels = ["EM", "F1", "Ret.\nRecall"]
        values = [overall["exact_match"], overall["f1_score"], overall["retrieval_recall"]]
        bars = ax1.bar(labels, values, color=["#2ecc71", "#3498db", "#9b59b6"], alpha=0.8)
        ax1.set_ylim(0, 1)
        ax1.set_title("Overall Performance", fontweight="bold", fontsize=12)
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{height:.3f}', ha='center', va='bottom', fontsize=10)

        # EM / F1 broken down by answer type, question type and difficulty
        self._plot_grouped_bars(fig.add_subplot(gs[0, 1]), results["by_type"],
                                "Performance by Answer Type")
        self._plot_grouped_bars(fig.add_subplot(gs[0, 2]), results["by_question_type"],
                                "Performance by Question Type")
        self._plot_grouped_bars(fig.add_subplot(gs[1, 0]), results["by_level"],
                                "Performance by Difficulty")

        # F1 distribution
        self._plot_score_histogram(fig.add_subplot(gs[1, 1]), raw_scores["f1"],
                                   "#3498db", "F1 Score", "F1 Score Distribution")

        # EM distribution. Counts are computed per bucket explicitly so the
        # "Incorrect"/"Correct" labels always line up with the right bar,
        # including when all examples are correct or all are incorrect.
        ax6 = fig.add_subplot(gs[1, 2])
        em_counts = list(self._count_exact_matches(raw_scores["em"]))
        total = sum(em_counts)
        ax6.bar(["Incorrect", "Correct"], em_counts, color=["#e74c3c", "#2ecc71"],
                alpha=0.8, edgecolor="black")
        ax6.set_ylabel("Count")
        ax6.set_title("Exact Match Distribution", fontweight="bold", fontsize=12)
        for i, count in enumerate(em_counts):
            share = count / total * 100 if total else 0.0
            ax6.text(i, count, f'{count}\n({share:.1f}%)',
                     ha='center', va='bottom', fontsize=10)

        # Retrieval recall distribution
        self._plot_score_histogram(fig.add_subplot(gs[2, 0]), raw_scores["retrieval_recall"],
                                   "#9b59b6", "Retrieval Recall",
                                   "Retrieval Recall Distribution")

        # Answer type distribution
        ax8 = fig.add_subplot(gs[2, 1])
        if results["by_type"]:
            type_names = list(results["by_type"].keys())
            type_counts = [results["by_type"][t]["count"] for t in type_names]
            colors_pie = plt.cm.Set3(range(len(type_names)))
            ax8.pie(type_counts, labels=type_names, autopct='%1.1f%%',
                    colors=colors_pie, startangle=90)
            ax8.set_title("Answer Type Distribution", fontweight="bold", fontsize=12)

        # Performance summary table
        ax9 = fig.add_subplot(gs[2, 2])
        ax9.axis('tight')
        ax9.axis('off')

        table_data = [
            ["Metric", "Value"],
            ["Overall EM", f"{overall['exact_match']:.3f}"],
            ["Overall F1", f"{overall['f1_score']:.3f}"],
            ["Retrieval Recall", f"{overall['retrieval_recall']:.3f}"],
            ["Total Examples", f"{overall['num_examples']}"]
        ]

        table = ax9.table(cellText=table_data, cellLoc='left', loc='center',
                          colWidths=[0.5, 0.5])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 2)

        # Style header row
        for col in range(2):
            table[(0, col)].set_facecolor('#3498db')
            table[(0, col)].set_text_props(weight='bold', color='white')

        ax9.set_title("Performance Summary", fontweight="bold", fontsize=12, pad=20)

        fig.suptitle("Full-Scale Hybrid RAG Performance Analysis",
                     fontsize=16, fontweight="bold", y=0.98)

    def plot_comprehensive_metrics(self, results: Dict, filename="full_results.png"):
        """Create the nine-panel summary figure and save it to disk.

        Args:
            results: Aggregated metrics with the keys ``"overall"``,
                ``"by_type"``, ``"by_question_type"``, ``"by_level"`` and
                ``"raw_scores"`` (per-example ``"em"``, ``"f1"`` and
                ``"retrieval_recall"`` lists).
            filename: File name of the figure inside ``self.output_dir``.
        """
        fig = plt.figure(figsize=(20, 12))
        try:
            self._draw_panels(fig, results)
            save_path = self.output_dir / filename
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)
        logger.info(f"✓ Saved visualization: {save_path}")

    def plot_training_progress(self, train_metrics: Dict, filename="training_progress.png"):
        """Plot training progress if available.

        Not implemented yet: this is a placeholder that does nothing.
        """
        # Placeholder for training metrics visualization
        pass


In [None]:
# ======================================================================
# Main Execution
# ======================================================================

def main():
    """Run the hybrid RAG pipeline end to end and report the results.

    Steps: build the configuration, load the train/validation data, index the
    documents of both splits, optionally train, evaluate on the validation
    split, write the metrics as JSON, print a summary and save the figure.
    """
    print("\n" + "="*70)
    print(" "*15 + "FULL-SCALE HYBRID RAG PIPELINE")
    print(" "*10 + "Extractive + Generative with Fine-tuning")
    print("="*70 + "\n")

    # Configuration
    config = RAGConfig(
        use_full_dataset=True,
        train_retriever=True,
        train_extractive=True,
        num_train_examples=1000,  # Training subset size (adjust based on resources)
        num_val_examples=100,     # Validation subset size
        use_gpu=torch.cuda.is_available(),
        save_checkpoints=True
    )

    # Display configuration (example counts <= 0 are reported as 'ALL')
    print("Configuration:")
    print(f"  GPU Available: {config.use_gpu}")
    if config.use_gpu:
        print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Train Examples: {config.num_train_examples if config.num_train_examples > 0 else 'ALL'}")
    print(f"  Val Examples: {config.num_val_examples if config.num_val_examples > 0 else 'ALL'}")
    print(f"  Train Retriever: {config.train_retriever}")
    print(f"  Train Extractive QA: {config.train_extractive}")
    print()

    # Make sure the output directory exists before anything is written to it.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize pipeline
    pipeline = FullHybridRAGPipeline(config)

    # Load data
    train_data, val_data = pipeline.load_and_process_data()

    # Build document index over both splits
    pipeline.build_document_index(train_data + val_data)

    # Training phase
    if config.train_retriever or config.train_extractive:
        pipeline.train(train_data, val_data)
    else:
        logger.info("Skipping training (training disabled in config)")

    # Evaluation phase: predictions are written to output_file and scored
    # from that file below, so the return value is not needed here.
    output_file = output_dir / "full_predictions.jsonl"
    pipeline.evaluate(val_data, str(output_file))

    # Evaluate metrics
    evaluator = EnhancedRAGEvaluator()
    metrics = evaluator.evaluate_results(str(output_file), val_data)

    # Save metrics
    metrics_file = output_dir / "evaluation_metrics.json"
    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    logger.info(f"✓ Saved metrics to {metrics_file}")

    # Print results
    print("\n" + "="*70)
    print(" "*25 + "FINAL RESULTS")
    print("="*70)
    print("\nOverall Performance:")
    print(f"  Exact Match (EM):     {metrics['overall']['exact_match']:.4f}")
    print(f"  F1 Score:             {metrics['overall']['f1_score']:.4f}")
    print(f"  Retrieval Recall:     {metrics['overall']['retrieval_recall']:.4f}")
    print(f"  Total Examples:       {metrics['overall']['num_examples']}")

    breakdowns = (
        ("Performance by Answer Type:", metrics["by_type"]),
        ("Performance by Question Type:", metrics["by_question_type"]),
        ("Performance by Difficulty Level:", metrics["by_level"]),
    )
    for heading, groups in breakdowns:
        print(f"\n{heading}")
        for name, scores in groups.items():
            # str() keeps the summary printable if a group key is not a string.
            print(f"  {str(name).capitalize():15} - EM: {scores['exact_match']:.4f}, "
                  f"F1: {scores['f1_score']:.4f}, Count: {scores['count']}")

    # Visualize
    visualizer = EnhancedRAGVisualizer(output_dir=config.output_dir)
    visualizer.plot_comprehensive_metrics(metrics)

    print("\n" + "="*70)
    print(" "*20 + "✓ PIPELINE COMPLETE!")
    print("="*70)
    print(f"\nOutputs saved to: {config.output_dir}/")
    print(f"  - Predictions: {output_file}")
    print(f"  - Metrics: {metrics_file}")
    # Report the directory the visualizer actually wrote to.
    print(f"  - Visualizations: {visualizer.output_dir}/")
    if config.save_checkpoints:
        print(f"  - Checkpoints: {config.checkpoint_dir}/")
    print()

# Entry point: run the pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()


               FULL-SCALE HYBRID RAG PIPELINE
          Extractive + Generative with Fine-tuning

Configuration:
  GPU Available: True
  GPU: Tesla T4
  Train Examples: 1000
  Val Examples: 100
  Train Retriever: True
  Train Extractive QA: True



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Processing: 100%|██████████| 1100/1100 [00:00<00:00, 46947.66it/s]


Batches:   0%|          | 0/336 [00:00<?, ?it/s]

Creating training pairs: 100%|██████████| 1000/1000 [00:02<00:00, 334.05it/s]
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0712
1000,0.0414
1500,0.0228


Batches:   0%|          | 0/336 [00:00<?, ?it/s]

Creating QA pairs: 100%|██████████| 1000/1000 [00:00<00:00, 104437.24it/s]


Map:   0%|          | 0/1254 [00:00<?, ? examples/s]

Step,Training Loss
100,1.5035
200,0.9107
300,0.7635
400,0.565


Evaluating: 100%|██████████| 100/100 [10:02<00:00,  6.02s/it]



                         FINAL RESULTS

Overall Performance:
  Exact Match (EM):     0.0800
  F1 Score:             0.1447
  Retrieval Recall:     0.6400
  Total Examples:       100

Performance by Answer Type:
  Generative      - EM: 0.0247, F1: 0.0947, Count: 81
  Extractive      - EM: 0.3333, F1: 0.3778, Count: 18
  None            - EM: 0.0000, F1: 0.0000, Count: 1

Performance by Question Type:
  Comparison      - EM: 0.0000, F1: 0.0383, Count: 21
  Bridge          - EM: 0.1013, F1: 0.1730, Count: 79

Performance by Difficulty Level:
  Hard            - EM: 0.0800, F1: 0.1447, Count: 100

                    ✓ PIPELINE COMPLETE!

Outputs saved to: outputs/
  - Predictions: outputs/full_predictions.jsonl
  - Metrics: outputs/evaluation_metrics.json
  - Visualizations: outputs/visualizations/
  - Checkpoints: checkpoints/

