# Cross-Encoder + Fusion

Combining the best of previous setups into a hybrid architecture. It integrates:

✅ Cross-encoder reranking for high-precision chunk scoring (like Setup 4 & 5),

✅ Multi-chunk fusion prompting for richer contextual grounding (like Setup 6),

🚫 No compound splitting or procedural routing — the goal is high-quality fused prompts from the top semantically matched chunks.

## Notebook Summary: Fusion + Cross-Encoder RAG Evaluation (Setup 7)

This notebook evaluates an enhanced extractive QA pipeline for telecom documents using a hybrid RAG setup that combines cross-encoder reranking with multi-chunk fusion prompting.

### Key Features:

1. **FAISS + Cross-Encoder Retrieval**  
   Chunks are first retrieved using FAISS and then reranked with `cross-encoder/ms-marco-MiniLM-L-6-v2` for semantic alignment.

2. **Multi-Chunk Fusion Prompt**  
   Top-k (default 5) chunks are merged into a single dense prompt with `[Source: filename]` tags. This allows the LLaMA-2 model to extract answers grounded across multiple related sections.

3. **Clean Extraction with LoRA-Fine-Tuned LLaMA-2**  
   The QA model uses extractive prompting with instruction tokens to ensure precise span output.

4. **Evaluation**  
   Model outputs are compared to gold answers using:
   - **SQuAD (EM, F1)**
   - **ROUGE-L**
   - **BLEU**

This setup balances high retrieval precision and dense context coverage, representing the strongest centralized RAG configuration in this project so far.

In [1]:
# Imports
import re
import faiss
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [2]:
import torch
# Load FAISS index and chunked docs
index = faiss.read_index("/mnt/data/RAG/3gpp_index.faiss")
with open("/mnt/data/RAG/3gpp_chunks.pkl", "rb") as f:
    documents = pickle.load(f)

# Load embedding + cross-encoder models
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# FAISS + Cross-Encoder Retrieval
def retrieve_with_rerank(query, top_k=5):
    query_vec = embedding_model.encode(query, normalize_embeddings=True)
    query_vec = np.array(query_vec).reshape(1, -1).astype("float32")

    D, I = index.search(query_vec, top_k * 2)

    initial_results = [documents[i] for i in I[0]]
    pairs = [(query, doc["content"]) for doc in initial_results]

    scores = reranker.predict(pairs)
    reranked = sorted(zip(scores, initial_results), key=lambda x: x[0], reverse=True)[:top_k]

    return [doc for _, doc in reranked]

# Multi-Chunk Fusion Prompt Builder
SYSTEM_PROMPT = (
    "You are a precise assistant. Extract the exact answer span from the context. "
    "Do not paraphrase, summarize, or add extra information. "
    "The answer must appear exactly in the context. "
    "If the context lists multiple conditions, actions, or branches, include them all as written. "
    "Do not summarize or paraphrase — copy the exact text from the context, line by line."
)

def build_fusion_prompt(context_chunks, question):
    context_lines = []
    for chunk in context_chunks:
        source = chunk.get("source", "unknown").split("/")[-1]
        context_lines.append(f"[Source: {source}]\n-----\n{chunk['content'].strip()}")
    fused_context = "\n\n".join(context_lines)

    user_prompt = (
        f"Context:\n{fused_context}\n\n"
        f"Question: {question}\n"
        f"Answer from the context only:"
    )

    return f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n{user_prompt} [/INST]"

# Output Cleaning
def clean_prediction(raw_text):
    answer = raw_text.split("[/INST]")[-1].strip()
    answer = re.sub(r"[^\w\s\-.,:/()]", "", answer)
    answer = re.sub(r'(\b.+?:)(\s*\1)+', r'\1', answer)

    tokens = answer.split()
    for i in range(1, len(tokens) // 2):
        if tokens[:i] == tokens[i:2*i]:
            answer = " ".join(tokens[:i])
            break

    sentence_end = re.search(r'[.?!]', answer)
    if sentence_end:
        answer = answer[:sentence_end.end()]
    return answer.strip()

# Load Fine-Tuned LLaMA-2 + Pipeline
model_path = "/mnt/data/llama2_qa_lora_output5/final"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

qa_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# End-to-End RAG QA Function
def answer_with_fusion_cross_rag(question, top_k=5, verbose=False):
    chunks = retrieve_with_rerank(question, top_k=top_k)
    prompt = build_fusion_prompt(chunks, question)

    output = qa_pipeline(
        prompt,
        max_new_tokens=160,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )[0]["generated_text"]

    answer = clean_prediction(output)

    if verbose:
        print("📌 Prompt (truncated):\n", prompt[:500], "...\n")
        print("🧾 Raw Output:\n", output)
        print("✅ Final Answer:\n", answer)
        for i, c in enumerate(chunks):
            print(f"\n--- Context {i+1} ---\n{c['content'][:300]}...\n")

    return answer, chunks

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [3]:
import json
from tqdm import tqdm
from evaluate import load

# Load QA pairs
def load_qa_pairs(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

qa_pairs = load_qa_pairs("3gpp_qa_100_pairs.jsonl")

# Load metrics
squad_metric = load("squad")
rouge = load("rouge")
bleu = load("bleu")

bleu_predictions = []
bleu_references = []
results = []

for sample in tqdm(qa_pairs):
    question = sample["question"]
    reference = sample["answer"]

    try:
        prediction, _ = answer_with_fusion_cross_rag(question)
    except Exception as e:
        print(f"⚠️ Error on: {question}\n{e}")
        prediction = ""

    # Add to metrics
    squad_metric.add(
        prediction={"id": str(hash(question)), "prediction_text": prediction},
        reference={"id": str(hash(question)), "answers": {"text": [reference], "answer_start": [0]}}
    )
    rouge.add(prediction=prediction, reference=reference)
    bleu_predictions.append(prediction)
    bleu_references.append([reference])
    results.append({
        "question": question,
        "reference": reference,
        "prediction": prediction
    })

# Compute final scores
squad_scores = squad_metric.compute()
rouge_scores = rouge.compute()
bleu_score = bleu.compute(predictions=bleu_predictions, references=bleu_references)["bleu"]

# Print results
print("\n📊 Final Evaluation Results (Setup 7 — Cross-Encoder + Multi-Chunk Fusion):")
print(f"Exact Match (EM): {squad_scores['exact_match']:.2f}")
print(f"F1 Score        : {squad_scores['f1']:.2f}")
print(f"ROUGE-L         : {rouge_scores['rougeL']:.4f}")
print(f"BLEU            : {bleu_score:.4f}")

  0%|                                                   | 0/100 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|▍                                          | 1/100 [00:03<06:02,  3.66s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  2%|▊                                          | 2/100 [00:04<03:43,  2.28s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  3%|█▎                                         | 3/100 [00:12<07:52,  4.88s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  4%|█▋                                         | 4/100 [00:14<05:24,  3.38s/it]The following generation fla


📊 Final Evaluation Results (Setup 7 — Cross-Encoder + Multi-Chunk Fusion):
Exact Match (EM): 0.00
F1 Score        : 20.61
ROUGE-L         : 0.2166
BLEU            : 0.0291
