# CrossEncoder + Fusion + Saliency Rewriting

## Notebook Summary: Saliency-Based Fusion QA with Cross-Encoder Reranking (Setup 9)

This notebook evaluates a retrieval-augmented (RAG) QA pipeline for 3GPP telecom standards. On top of cross-encoder reranking, it adds saliency-based sentence rewriting and semantic deduplication. The aim is to improve factual grounding and shorten the prompt for extractive QA.

### Key Features:

1. **Cross-Encoder Reranking**  
   Uses `cross-encoder/ms-marco-MiniLM-L-6-v2` to rescore the top `2 * top_k` FAISS candidates (10 by default) against the question, then keeps the best `top_k` (5).

2. **Saliency Sentence Rewriting**  
   For each chunk with more than 3 sentences, keeps the 3 sentences with the highest TF-IDF cosine similarity to the query, in their original order. Shorter chunks pass through unchanged.

3. **Semantic Deduplication**  
   Embeds each rewritten chunk with `all-MiniLM-L6-v2` and drops it if its cosine similarity to any already-kept chunk is ≥ 0.95. This removes redundant content and shortens the prompt.

4. **Custom Sentence Splitting**  
   Uses NLTK's Punkt tokenizer with domain-specific abbreviations (loaded from `abbreviation_master_map.json`), so that abbreviations ending in a period do not trigger false sentence breaks.

5. **Multi-Chunk Fusion Prompt**  
   Concatenates the rewritten, deduplicated chunks into a single LLaMA-2 chat-format prompt, tagging each with its source file. The prompt is passed to the LoRA-fine-tuned LLaMA-2 QA model.

6. **Evaluation Metrics**  
   - **Exact Match (EM)** and **F1** from the SQuAD metric (0-100 scale)  
   - **ROUGE-L** for longest-common-subsequence overlap (0-1)  
   - **BLEU** for n-gram precision against the reference (0-1)

This setup (Setup 9) prioritizes saliency, non-redundancy, and contextual clarity. The goal is to ground answers in minimal yet relevant evidence.

In [1]:
# Load FAISS index and chunked docs
import faiss, pickle, torch, re
import numpy as np
from pathlib import Path

index = faiss.read_index("/mnt/data/RAG/3gpp_index.faiss")
with open("/mnt/data/RAG/3gpp_chunks.pkl", "rb") as f:
    documents = pickle.load(f)

In [2]:
# Load embedding + cross-encoder models
from sentence_transformers import SentenceTransformer, CrossEncoder
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [3]:
# FAISS + Cross-Encoder Retrieval
def retrieve_with_rerank(query, top_k=5):
    query_vec = embedding_model.encode(query, normalize_embeddings=True)
    query_vec = np.array(query_vec).reshape(1, -1).astype("float32")

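    D, I = index.search(query_vec, top_k * 2)
    # FAISS pads missing results with -1; skip them instead of silently reading documents[-1]
    initial_results = [documents[i] for i in I[0] if i != -1]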
    pairs = [(query, doc["content"]) for doc in initial_results]

    scores = reranker.predict(pairs)
    reranked = sorted(zip(scores, initial_results), key=lambda x: x[0], reverse=True)[:top_k]
    return [doc for _, doc in reranked]
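The two-stage pattern in `retrieve_with_rerank` can be sketched without any models. It over-fetches `2 * top_k` candidates, rescores each query-passage pair, and keeps the best `top_k`. In this sketch, `overlap_score` is a hypothetical word-overlap stand-in for the cross-encoder. `candidate_ids` stands in for FAISS's `I[0]`, and the `-1` filter reflects FAISS's convention of padding missing results with `-1`.

```python
from typing import Callable, Dict, List, Sequence

Doc = Dict[str, str]


def retrieve_then_rerank(
    query: str,
    candidate_ids: Sequence[int],
    documents: List[Doc],
    rerank_score: Callable[[str, str], float],
    top_k: int = 5,
) -> List[Doc]:
    """Rescore first-stage candidates and keep the best top_k."""
    # FAISS pads missing results with -1; documents[-1] would silently
    # return the last chunk, so drop those ids first.
    candidates = [documents[i] for i in candidate_ids if i != -1]
    # Score each (query, passage) pair jointly, as a cross-encoder does.
    scored = [(rerank_score(query, doc["content"]), doc) for doc in candidates]
    # Sort on the score only, so tied scores never compare dicts.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in scored[:top_k]]


def overlap_score(query: str, passage: str) -> float:
    """Hypothetical cross-encoder stand-in: share of query words found in the passage."""
    q_words = set(query.lower().split())
    p_words = set(passage.lower().split())
    return len(q_words & p_words) / max(len(q_words), 1)


docs = [
    {"content": "T310 is started upon detecting physical layer problems"},
    {"content": "The UE performs cell reselection"},
    {"content": "T310 expiry triggers radio link failure"},
]
# candidate_ids plays the role of I[0] from index.search(query_vec, 2 * top_k).
top = retrieve_then_rerank(
    "when does T310 expiry trigger radio link failure",
    candidate_ids=[1, 0, 2, -1],
    documents=docs,
    rerank_score=overlap_score,
    top_k=2,
)
print([d["content"] for d in top])
```

Over-fetching gives the more expensive pairwise scorer a chance to promote passages that the bi-encoder ranked lower.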

In [4]:
import json
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

# Load abbreviation map
with open("abbreviation_master_map.json", "r") as f:
    abbrev_dict = json.load(f)

# Lowercase keys and strip leading/trailing dots: Punkt stores abbreviations without the final period
abbrevs = set(k.lower().strip(".") for k in abbrev_dict.keys())

In [5]:
# Setup tokenizer with custom abbreviations
punkt_param = PunktParameters()
punkt_param.abbrev_types = abbrevs

sentence_splitter = PunktSentenceTokenizer(punkt_param)

def sent_tokenize(text):
    return sentence_splitter.tokenize(text.strip())
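The abbreviation set only matters for tokens that end in a period. In NLTK's source, Punkt treats such a token as an abbreviation when the token, minus its final period and lowercased, appears in `abbrev_types`. To show the effect without NLTK, the sketch below pairs the notebook's key normalization with a deliberately simple regex splitter. `naive_split` is hypothetical and is not Punkt's algorithm. `normalize_abbrevs` additionally drops keys that become empty after stripping.

```python
import re
from typing import Iterable, List, Set


def normalize_abbrevs(keys: Iterable[str]) -> Set[str]:
    # Same preprocessing as the notebook: lowercase and strip leading/trailing
    # dots, so "e.g." is stored as "e.g"; keys that become empty are dropped.
    return {k.lower().strip(".") for k in keys if k.strip(".")}


def naive_split(text: str, abbrevs: Set[str]) -> List[str]:
    # Hypothetical regex splitter, for illustration only (NOT Punkt):
    # break after '.', '?' or '!' when whitespace follows, unless the token
    # ending with that punctuation is a known abbreviation.
    sentences, start = [], 0
    for m in re.finditer(r"[.?!](?=\s)", text):
        last_token = text[start:m.end()].split()[-1]
        if last_token[:-1].lower() in abbrevs:
            continue
        sentences.append(text[start:m.end()].strip())
        start = m.end()
    tail = text[start:].strip()
    if tail:
        sentences.append(tail)
    return sentences


abbrevs_demo = normalize_abbrevs(["e.g.", "Approx.", "."])
text = "The UE applies a timer, e.g. T310. Then RLF is declared."
print(naive_split(text, abbrevs_demo))  # abbreviation-aware
print(naive_split(text, set()))         # splits wrongly after "e.g."
```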

In [6]:
# Saliency Rewriting (reuses embedding_model from the earlier cell for deduplication)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def extract_salient_sentences(chunk_text, query, max_sentences=3):
    sentences = sent_tokenize(chunk_text)
    if len(sentences) <= max_sentences:
        return chunk_text.strip()

    vectorizer = TfidfVectorizer().fit([query] + sentences)
    query_vec = vectorizer.transform([query])
    sentence_vecs = vectorizer.transform(sentences)

    scores = (sentence_vecs @ query_vec.T).toarray().flatten()
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:max_sentences]
    salient = [sentences[i] for i in sorted(top_indices)]
    return " ".join(salient)
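`extract_salient_sentences` fits TF-IDF on the query plus the chunk's sentences and ranks sentences by their dot product with the query vector. Because `TfidfVectorizer` L2-normalizes rows by default, that dot product is cosine similarity. The plain-Python sketch below reproduces the scikit-learn defaults assumed here (lowercasing, `token_pattern=r"(?u)\b\w\w+\b"`, `smooth_idf=True`, `norm="l2"`) to make the scoring explicit. It is an illustration rather than a drop-in replacement.

```python
import math
import re
from collections import Counter
from typing import Dict, List

TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")  # scikit-learn's default token_pattern


def tokenize(text: str) -> List[str]:
    return TOKEN_RE.findall(text.lower())


def tfidf_vectors(docs: List[str]) -> List[Dict[str, float]]:
    # Smoothed idf: ln((1 + n) / (1 + df)) + 1, then L2-normalize each row.
    tokenized = [tokenize(d) for d in docs]
    n = len(docs)
    df = Counter(term for toks in tokenized for term in set(toks))
    idf = {t: math.log((1 + n) / (1 + c)) + 1 for t, c in df.items()}
    vectors = []
    for toks in tokenized:
        weights = {t: tf * idf[t] for t, tf in Counter(toks).items()}
        norm = math.sqrt(sum(w * w for w in weights.values())) or 1.0
        vectors.append({t: w / norm for t, w in weights.items()})
    return vectors


def top_salient(sentences: List[str], query: str, k: int = 3) -> List[str]:
    if len(sentences) <= k:
        return sentences
    # Fit idf on the query plus all sentences, as the notebook does.
    q_vec, *s_vecs = tfidf_vectors([query] + sentences)
    # Dot product of L2-normalized vectors equals cosine similarity.
    scores = [sum(w * q_vec.get(t, 0.0) for t, w in v.items()) for v in s_vecs]
    best = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:k]
    # Restore document order so the excerpt still reads naturally.
    return [sentences[i] for i in sorted(best)]


sentences = [
    "The UE starts timer T310 on detecting problems.",
    "Cell reselection is described elsewhere.",
    "On T310 expiry the UE declares radio link failure.",
    "Paging occasions depend on the DRX cycle.",
]
query = "What happens on T310 expiry?"
print(top_salient(sentences, query, k=2))
```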

def rewrite_chunks_for_saliency(chunks, query, max_sentences=3, sim_threshold=0.95):
    seen_embeddings = []
    filtered_chunks = []

    for chunk in chunks:
        rewritten = extract_salient_sentences(chunk["content"], query, max_sentences).strip()
        if not rewritten:
            continue

        # Embed rewritten chunk
        emb = embedding_model.encode([rewritten])[0]

        # Check similarity with all previous embeddings
        is_duplicate = any(
            cosine_similarity([emb], [prev_emb])[0][0] >= sim_threshold
            for prev_emb in seen_embeddings
        )

        if not is_duplicate:
            seen_embeddings.append(emb)
            filtered_chunks.append({
                "content": rewritten,
                "source": chunk.get("source", "unknown")
            })

    return filtered_chunks
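The deduplication in `rewrite_chunks_for_saliency` is a greedy filter. A chunk is kept only if its cosine similarity to every previously kept chunk is below the threshold. Because chunks arrive in rerank order, the higher-ranked copy of a near-duplicate survives. The sketch below isolates that rule, using hypothetical 2-D vectors in place of `all-MiniLM-L6-v2` embeddings.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def greedy_dedup(texts: List[str], embeddings: List[Sequence[float]],
                 threshold: float = 0.95) -> List[str]:
    # Compare only against items already kept, in input (rerank) order.
    kept_texts, kept_vecs = [], []
    for text, vec in zip(texts, embeddings):
        if all(cosine(vec, prev) < threshold for prev in kept_vecs):
            kept_texts.append(text)
            kept_vecs.append(vec)
    return kept_texts


# Hypothetical toy embeddings: "b" is almost parallel to "a", "c" is orthogonal.
print(greedy_dedup(["a", "b", "c"], [[1.0, 0.0], [0.99, 0.05], [0.0, 1.0]]))
```

With `top_k=5` the quadratic number of comparisons is negligible. For much larger candidate sets, stacking the kept embeddings into one matrix would be the natural next step.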

In [7]:
# Prompt Builder
SYSTEM_PROMPT = (
    "You are a precise assistant. Extract the exact answer span from the context. "
    "Do not paraphrase, summarize, or add extra information. "
    "The answer must appear exactly in the context. "
    "If the context lists multiple conditions, actions, or branches, include them all as written. "
    "Do not summarize or paraphrase — copy the exact text from the context, line by line."
)

def build_fusion_prompt(context_chunks, question):
    context_lines = []
    for chunk in context_chunks:
        source = chunk.get("source", "unknown").split("/")[-1]
        context_lines.append(f"[Source: {source}]\n-----\n{chunk['content'].strip()}")
    fused_context = "\n\n".join(context_lines)

    user_prompt = (
        f"Context:\n{fused_context}\n\n"
        f"Question: {question}\n"
        f"Answer from the context only:"
    )

    return f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n{user_prompt} [/INST]"

In [8]:
# Output Cleaning
def clean_prediction(raw_text):
    answer = raw_text.split("[/INST]")[-1].strip()
    answer = re.sub(r"[^\w\s\-.,:/()]", "", answer)
    answer = re.sub(r'(\b.+?:)(\s*\1)+', r'\1', answer)

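    # Collapse outputs whose first i tokens immediately repeat (bound includes len(tokens) // 2)
    tokens = answer.split()
    for i in range(1, len(tokens) // 2 + 1):
        if tokens[:i] == tokens[i:2*i]:
            answer = " ".join(tokens[:i])
            break

    # Keep only the first sentence. "?" and "!" were already stripped above, so in
    # practice this cuts at the first ".", including inside identifiers such as "38.331".
    sentence_end = re.search(r'[.?!]', answer)
    if sentence_end:
        answer = answer[:sentence_end.end()]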
    return answer.strip()
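Two steps in `clean_prediction` are easy to check in isolation. First, truncating at the first match of `[.?!]` also cuts decimal specification numbers, which are common in 3GPP answers. Second, a repeat-collapsing loop needs its upper bound to reach `len(tokens) // 2` to catch an exactly doubled answer. The sketch compares the notebook's pattern with a hypothetical alternative that only treats punctuation followed by whitespace or end-of-string as a sentence end. It is a suggestion, not the pattern used for the reported scores.

```python
import re
from typing import List

FIRST_PUNCT = r"[.?!]"                # pattern used in clean_prediction
PUNCT_THEN_SPACE = r"[.?!](?=\s|$)"   # hypothetical alternative


def truncate_first_sentence(answer: str, pattern: str) -> str:
    m = re.search(pattern, answer)
    return answer[:m.end()] if m else answer


def collapse_leading_repeat(tokens: List[str]) -> List[str]:
    # Return tokens[:i] if the first i tokens are immediately repeated.
    # The range must include len(tokens) // 2 so "UE UE" collapses to "UE".
    for i in range(1, len(tokens) // 2 + 1):
        if tokens[:i] == tokens[i:2 * i]:
            return tokens[:i]
    return tokens


ANSWER = "The RRC procedures are specified in TS 38.331. Further text."
print(truncate_first_sentence(ANSWER, FIRST_PUNCT))       # cut inside "38.331"
print(truncate_first_sentence(ANSWER, PUNCT_THEN_SPACE))  # keeps "38.331"
print(collapse_leading_repeat("UE UE".split()))
```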

In [9]:
# Load Fine-Tuned LLaMA-2 + Pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_path = "/mnt/data/llama2_qa_lora_output5/final"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

qa_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

Device set to use cuda:0


In [10]:
# End-to-End Function
def answer_with_saliency_cross_rag(question, top_k=5, max_sentences=3, verbose=False):
    chunks = retrieve_with_rerank(question, top_k=top_k)
    salient_chunks = rewrite_chunks_for_saliency(chunks, question, max_sentences=max_sentences)
    prompt = build_fusion_prompt(salient_chunks, question)

    output = qa_pipeline(
        prompt,
        max_new_tokens=220,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )[0]["generated_text"]

    answer = clean_prediction(output)

    if verbose:
        print("📌 Prompt (truncated):\n", prompt[:600], "...\n")
        print("🧾 Raw Output:\n", output)
        print("✅ Final Answer:\n", answer)
        for i, c in enumerate(salient_chunks):
            print(f"\n--- Context {i+1} ---\n{c['content'][:300]}...\n")

    return answer, salient_chunks

In [11]:
import json
from tqdm import tqdm
from evaluate import load

# Load QA pairs
def load_qa_pairs(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

qa_pairs = load_qa_pairs("3gpp_qa_100_pairs.jsonl")

# Load metrics
squad_metric = load("squad")
rouge = load("rouge")
bleu = load("bleu")

bleu_predictions = []
bleu_references = []
results = []

for sample in tqdm(qa_pairs):
    question = sample["question"]
    reference = sample["answer"]

    try:
        prediction, _ = answer_with_saliency_cross_rag(question)
    except Exception as e:
        print(f"⚠️ Error on: {question}\n{e}")
        prediction = ""

    # Add to metrics
    squad_metric.add(
        prediction={"id": str(hash(question)), "prediction_text": prediction},
        reference={"id": str(hash(question)), "answers": {"text": [reference], "answer_start": [0]}}
    )
    rouge.add(prediction=prediction, reference=reference)
    bleu_predictions.append(prediction)
    bleu_references.append([reference])
    results.append({
        "question": question,
        "reference": reference,
        "prediction": prediction
    })

# Compute final scores (SQuAD EM/F1 are on a 0-100 scale; ROUGE-L and BLEU on 0-1)
squad_scores = squad_metric.compute()
rouge_scores = rouge.compute()
bleu_score = bleu.compute(predictions=bleu_predictions, references=bleu_references)["bleu"]

# Print results
print("\n📊 Final Evaluation Results (Setup 9 — Cross-Encoder + Fusion + Saliency):")
print(f"Exact Match (EM): {squad_scores['exact_match']:.2f}")
print(f"F1 Score        : {squad_scores['f1']:.2f}")
print(f"ROUGE-L         : {rouge_scores['rougeL']:.4f}")
print(f"BLEU            : {bleu_score:.4f}")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


📊 Final Evaluation Results (Setup 9 — Cross-Encoder + Fusion + Saliency):
Exact Match (EM): 1.00
F1 Score        : 20.00
ROUGE-L         : 0.2079
BLEU            : 0.0218
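The SQuAD metric reports EM and F1 on a 0-100 scale, while ROUGE-L and BLEU are on 0-1. Over the 100 questions, EM 1.00 therefore corresponds to one exact match, and F1 20.00 to an average token overlap of 0.20. To inspect individual questions, the sketch below re-implements the SQuAD v1.1 answer normalization and token F1 in plain Python. It assumes `evaluate`'s `squad` metric follows that official script, and it takes one reference per question, as in this notebook.

```python
import re
import string
from collections import Counter
from typing import Dict, List

PUNCT = set(string.punctuation)


def normalize_answer(s: str) -> str:
    # SQuAD v1.1 order: lowercase, drop punctuation, drop a/an/the, collapse whitespace.
    s = "".join(ch for ch in s.lower() if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(pred: str, ref: str) -> float:
    return float(normalize_answer(pred) == normalize_answer(ref))


def token_f1(pred: str, ref: str) -> float:
    p, r = normalize_answer(pred).split(), normalize_answer(ref).split()
    overlap = sum((Counter(p) & Counter(r)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(r)
    return 2 * precision * recall / (precision + recall)


def squad_scores(preds: List[str], refs: List[str]) -> Dict[str, float]:
    n = len(preds)
    return {
        "exact_match": 100 * sum(exact_match(p, r) for p, r in zip(preds, refs)) / n,
        "f1": 100 * sum(token_f1(p, r) for p, r in zip(preds, refs)) / n,
    }


# A prediction truncated inside "38.331" loses the key token after normalization.
print(token_f1("The RRC procedures are specified in TS 38.", "TS 38.331"))
print(squad_scores(["T310", "radio link failure"], ["t310", "cell reselection"]))
```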



