# FAISS + Fusion

# Notebook Summary: Multi-Chunk Fusion RAG Evaluation (Setup 6)

This notebook implements a centralized RAG pipeline that fuses multiple top-ranked context chunks into a single, dense prompt for improved factual grounding in telecom QA.

### Key Features:

1. **FAISS Retrieval (Top-k = 6)**  
   Retrieves six relevant document chunks using MiniLM embeddings and cosine similarity.

2. **Fusion Prompt Construction**  
   Combines all retrieved chunks into a single prompt, each annotated with its source filename. This enables the model to reason over richer context spans.

3. **Model Inference with LLaMA-2**  
   Uses a LoRA-fine-tuned LLaMA-2 model to generate extractive answers from the fused prompt.

4. **Answer Cleaning & Evaluation**  
   Applies standard postprocessing and evaluates using:
   - **SQuAD (Exact Match, F1)**
   - **ROUGE-L**
   - **BLEU**

This variant tests whether high-context fusion improves precision in telecom QA without complex reranking or decomposition logic.

In [1]:
# Imports
import re
import pickle
import faiss
import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load FAISS index and chunks
index_path = "/mnt/data/RAG/3gpp_index.faiss"
chunks_path = "/mnt/data/RAG/3gpp_chunks.pkl"

index = faiss.read_index(index_path)
with open(chunks_path, "rb") as f:
    documents = pickle.load(f)

In [2]:
# Load embedding model
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Retrieve top-k context chunks from FAISS
def retrieve_context(query, top_k=6):
    query_emb = embedding_model.encode([query], normalize_embeddings=True)
    D, I = index.search(query_emb.astype("float32"), top_k)
    return [documents[i] for i in I[0]]

In [3]:
# Define the multi-chunk fusion prompt
SYSTEM_PROMPT = (
    "You are a precise assistant. Extract the exact answer span from the context. "
    "Do not paraphrase, summarize, or add extra information. "
    "The answer must appear exactly in the context. "
    "If the context lists multiple conditions, actions, or branches, include them all as written. "
    "Do not summarize or paraphrase — copy the exact text from the context, line by line."
)


In [4]:
def build_fusion_prompt(context_chunks, question):
    context_lines = []
    for chunk in context_chunks:
        source = chunk.get("source", "unknown").split("/")[-1]
        context_lines.append(f"[Source: {source}]\n-----\n{chunk['content'].strip()}")
    fused_context = "\n\n".join(context_lines)

    user_prompt = (
        f"Context:\n{fused_context}\n\n"
        f"Question: {question}\n"
        f"Answer from the context only:"
    )
    return f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n{user_prompt} [/INST]"

# Clean LLaMA-2 output
def clean_prediction(raw_text):
    answer = raw_text.split("[/INST]")[-1].strip()
    answer = re.sub(r"[^\w\s\-.,:/()]", "", answer)
    answer = re.sub(r'(\b.+?:)(\s*\1)+', r'\1', answer)

    tokens = answer.split()
    for i in range(1, len(tokens) // 2):
        if tokens[:i] == tokens[i:2*i]:
            answer = " ".join(tokens[:i])
            break

    sentence_end = re.search(r'[.?!]', answer)
    if sentence_end:
        answer = answer[:sentence_end.end()]
    return answer.strip()

In [5]:
import torch
# Load LLaMA-2 model + tokenizer
model_path = "/mnt/data/llama2_qa_lora_output5/final"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

qa_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [6]:
# Full RAG Inference Function
def answer_with_fusion_rag(question, top_k=6, verbose=False):
    chunks = retrieve_context(question, top_k=top_k)
    prompt = build_fusion_prompt(chunks, question)
    
    output = qa_pipeline(
        prompt,
        max_new_tokens=160,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )[0]["generated_text"]

    answer = clean_prediction(output)

    if verbose:
        print("📌 Prompt (truncated):\n", prompt[:500], "...\n")
        print("🧾 Raw Output:\n", output)
        print("✅ Final Answer:\n", answer)
        for i, c in enumerate(chunks):
            print(f"\n--- Context {i+1} ---\n{c['content'][:300]}...\n")

    return answer, chunks

In [7]:
import json
from tqdm import tqdm
from evaluate import load

# Load QA pairs
def load_qa_pairs(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

qa_pairs = load_qa_pairs("3gpp_qa_100_pairs.jsonl")

# Load metrics
squad_metric = load("squad")
rouge = load("rouge")
bleu = load("bleu")

bleu_predictions = []
bleu_references = []
results = []

for sample in tqdm(qa_pairs):
    question = sample["question"]
    reference = sample["answer"]

    try:
        prediction, _ = answer_with_fusion_rag(question)
    except Exception as e:
        print(f"⚠️ Error on: {question}\n{e}")
        prediction = ""

    # Add to metrics
    squad_metric.add(
        prediction={"id": str(hash(question)), "prediction_text": prediction},
        reference={"id": str(hash(question)), "answers": {"text": [reference], "answer_start": [0]}}
    )
    rouge.add(prediction=prediction, reference=reference)
    bleu_predictions.append(prediction)
    bleu_references.append([reference])
    results.append({
        "question": question,
        "reference": reference,
        "prediction": prediction
    })

# Compute final scores
squad_scores = squad_metric.compute()
rouge_scores = rouge.compute()
bleu_score = bleu.compute(predictions=bleu_predictions, references=bleu_references)["bleu"]

# Print results
print("\n📊 Final Evaluation Results (Setup 6 — Multi-Chunk Fusion Prompt):")
print(f"Exact Match (EM): {squad_scores['exact_match']:.2f}")
print(f"F1 Score        : {squad_scores['f1']:.2f}")
print(f"ROUGE-L         : {rouge_scores['rougeL']:.4f}")
print(f"BLEU            : {bleu_score:.4f}")

  0%|                                                   | 0/100 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|▍                                          | 1/100 [00:02<04:30,  2.73s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  2%|▊                                          | 2/100 [00:11<09:50,  6.03s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  3%|█▎                                         | 3/100 [00:19<11:32,  7.14s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  4%|█▋                                         | 4/100 [00:27<12:08,  7.59s/it]The following generation fla


📊 Final Evaluation Results (Setup 6 — Multi-Chunk Fusion Prompt):
Exact Match (EM): 0.00
F1 Score        : 19.37
ROUGE-L         : 0.2081
BLEU            : 0.0192



