## Notebook Summary: RAG + LLaMA-2 Evaluation for Telecom QA

This notebook evaluates a fine-tuned LLaMA-2 model integrated with a centralized FAISS-based RAG system for extractive telecom question answering.

### Key Steps:

1. **Context Retrieval**  
   Uses FAISS to retrieve top-k chunks for each query from embedded 3GPP documents. Chunks are reranked using a hybrid score combining lexical overlap, TF-IDF, and FAISS recall rank.

2. **Prompt Construction**  
   Assembles a prompt from the system instructions and the selected context chunks, using the `[INST]` / `<<SYS>>` chat format that LLaMA-2 chat models expect.

3. **Model Inference**  
   Generates extractive answers using a LoRA-fine-tuned LLaMA-2 model loaded on GPU, optimized for telecom QA tasks.

4. **Prediction Cleaning**  
   Applies regex-based filters to remove formatting noise, loops, and low-confidence patterns from model output.

5. **Evaluation**  
   Compares model predictions to 100 ground-truth QA pairs using:
   - **Exact Match** and **F1** (SQuAD)
   - **ROUGE-L** for overlap quality
   - **BLEU** for n-gram precision against the reference

This notebook measures how well centralized RAG grounds telecom-specific LLM outputs and serves as the baseline for later federated enhancements in the thesis workflow.
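
To make the Exact Match and F1 numbers easier to interpret, here is a minimal pure-Python sketch of the SQuAD v1.1 style scoring (lowercase, strip punctuation and articles, then compare tokens). It is an illustration of the metric definition, not the code path used in this notebook, which relies on `evaluate.load("squad")`; the helper names `normalize_answer`, `exact_match`, and `token_f1` and the example strings are chosen here for illustration only.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    """1.0 only if the normalized strings are identical."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# Made-up example: the prediction contains the reference plus extra words
reference = "the Access and Mobility Management Function"
prediction = "The AMF is the Access and Mobility Management Function."

em = exact_match(prediction, reference)  # 0.0, the strings differ
f1 = token_f1(prediction, reference)     # 5 shared tokens, precision 5/7, recall 1
print(f"EM={em:.2f}  F1={f1:.4f}")
```

A prediction that wraps the reference span in a full sentence scores zero on Exact Match but still earns partial F1, which is one way a run can report EM 0.00 alongside a non-zero F1. The `evaluate` SQuAD metric reports both values on a 0 to 100 scale, whereas the sketch above returns fractions.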

| Setup | Retrieval Backend | Context Strategy | Enhancements | Safety / Noise Control | Key Aim |
|-------|-------------------|------------------|--------------|------------------------|---------|
| **1** *(Baseline)* | FAISS only | Single chunk prompt | None | None | Simple baseline QA |
| **2** *(Heuristic rerank)* | FAISS | Single chunk prompt | TF-IDF + lexical overlap reranking | None | Improve chunk precision |
| **3** *(Compound QA)* | FAISS | Per-question reranked chunks | Compound question splitting | Basic containment check | Handle multi-part queries |
| **4** *(Procedural routing)* | FAISS + cross-encoder | Per-question reranked chunks | Cross-encoder rerank, compound handling, regex-based procedural span extraction | Context containment check | Target procedural queries |
| **5** *(Qdrant backend)* | Qdrant + cross-encoder | Per-question reranked chunks | Cross-encoder rerank | None | Persistent vector DB retrieval |
| **6** *(Fusion prompting)* | FAISS | Fused top-k chunks into one prompt | Source-tagged multi-chunk fusion | None | More complete context in one go |
| **7** *(Fusion + cross-encoder)* | FAISS + cross-encoder | Fused top-k reranked chunks | Source-tagged fusion | None | Combine precision retrieval with dense prompt coverage |
| **8** *(Fusion + cross-encoder + truncation)* | FAISS + cross-encoder | Fusion of top spans | Sliding-window truncation + TF-IDF + lexical scoring | Fuzzy grounding validation | Limit prompt length, ensure grounding |
| **9** *(Fusion + cross-encoder + saliency)* | FAISS + cross-encoder | Fusion of salient sentences | TF-IDF saliency extraction + semantic deduplication | Domain-aware sentence tokenization | Reduce noise and redundancy in prompts |
| **10** *(Confidence-weighted semantic fusion)* | FAISS + cross-encoder | Per-chunk saliency rewrite → per-chunk QA → semantic fusion | Answer clustering + best-answer selection | Deduplication before prompting | Ensemble-like fusion of multiple perspectives |
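
The table names "source-tagged multi-chunk fusion" (Setups 6 and 7) and a "context containment check" (Setups 3 and 4) without showing them. The sketch below is one plausible reading of both ideas, not the implementation used in those setups: `fuse_chunks`, `is_contained`, and the `"source"` metadata key are hypothetical names, and only the `"content"` key matches the chunk dictionaries used later in this notebook.

```python
import re


def fuse_chunks(chunks, max_chunks=3):
    """Join the first max_chunks chunks, prefixing each with a source tag."""
    parts = []
    for i, chunk in enumerate(chunks[:max_chunks], start=1):
        source = chunk.get("source", "unknown")  # hypothetical metadata key
        parts.append(f"[Source {i}: {source}]\n{chunk['content'].strip()}")
    return "\n\n".join(parts)


def _squash(text):
    """Lowercase and collapse whitespace so comparisons ignore layout."""
    return re.sub(r"\s+", " ", text.lower()).strip()


def is_contained(answer, context):
    """True if the non-empty answer appears verbatim in the context."""
    answer = _squash(answer)
    return bool(answer) and answer in _squash(context)


# Toy chunks with made-up text and source labels
chunks = [
    {"content": "The AMF handles registration management.", "source": "doc_a"},
    {"content": "The SMF handles session management.", "source": "doc_b"},
]
fused = fuse_chunks(chunks)
print(fused)
print(is_contained("handles  Registration management", fused))  # True
print(is_contained("handles mobility", fused))                  # False
```

Tagging each chunk lets a reviewer trace an answer back to its origin, and the containment test gives a cheap guard for a strictly extractive task: an answer that fails it was not copied from the prompt context.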

In [1]:
from pathlib import Path
import faiss
import pickle
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer

# Load FAISS index and chunks
index_path = "/mnt/data/RAG/3gpp_index.faiss"
chunks_path = "/mnt/data/RAG/3gpp_chunks.pkl"

index = faiss.read_index(index_path)
with open(chunks_path, "rb") as f:
    documents = pickle.load(f)

In [2]:
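# Load the embedding model used for indexing; queries must be encoded
# with the same model that produced the vectors stored in the index.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

def retrieve_context(query, top_k=3):
    # Encode the query as a unit-length vector and fetch the top_k nearest chunks
    query_emb = embedding_model.encode([query], normalize_embeddings=True)
    D, I = index.search(query_emb.astype("float32"), top_k)
    # FAISS pads with -1 when fewer than top_k neighbours exist; skip those slots
    return [documents[i] for i in I[0] if i != -1]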

SYSTEM_PROMPT = (
    "You are a precise assistant. Extract the exact answer span from the context. "
    "Do not paraphrase, summarize, or add extra information. "
    "The answer must appear exactly in the context."
)
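
`retrieve_context` encodes the query with `normalize_embeddings=True`. The notebook does not show how the index was built, so the following is a sketch under the assumption that the indexed chunk vectors were L2-normalized as well. In that case an inner-product index scores by cosine similarity, and a squared-L2 index produces the same ranking, because for unit vectors the squared distance equals 2 - 2·cos. The vectors and helper names below are toy values for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec]


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Toy 3-dimensional "embeddings"
query = [1.0, 2.0, 2.0]
chunk = [2.0, 0.0, 1.0]
q, c = l2_normalize(query), l2_normalize(chunk)

cos_sim = cosine(query, chunk)
print(f"cosine similarity      : {cos_sim:.4f}")
print(f"dot of unit vectors    : {dot(q, c):.4f}")
print(f"squared L2 distance    : {squared_l2(q, c):.4f}")
print(f"2 - 2 * cosine         : {2 - 2 * cos_sim:.4f}")
```

If the index had been built from unnormalized vectors, this equivalence would not hold and normalizing only the query could distort the ranking, so it is worth confirming the indexing step.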

In [3]:
def build_rag_prompt(context_chunks, question):
    combined_context = "\n\n".join([chunk['content'] for chunk in context_chunks])
    user_prompt = (
        f"Context: {combined_context}\n\n"
        f"Question: {question}\n"
        f"Answer from the context only:"
    )
    return f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n{user_prompt} [/INST]"

model_path = "/mnt/data/llama2_qa_lora_output6/final"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

qa_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [4]:
def clean_prediction(raw_text):
    # Keep only the text generated after the final [/INST] tag
    answer = raw_text.split("[/INST]")[-1].strip()

    # Remove strange characters
    answer = re.sub(r"[^\w\s\-.,:/()]", "", answer)

    # Remove repeating phrases like "The key is... The key is... The key is..."
    answer = re.sub(r'(\b.+?:)(\s*\1)+', r'\1', answer)

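    # Trim word loops: if the answer opens with a phrase that is immediately
    # repeated, keep a single copy (e.g., "structured as follows" x 5)
    tokens = answer.split()
    for i in range(1, len(tokens) // 2 + 1):  # +1 so an exact two-copy answer is caught
        if tokens[:i] == tokens[i:2*i]:
            answer = " ".join(tokens[:i])
            break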

    # Truncate at the first '.', '?' or '!'
    # (note: this also cuts at a period inside a number such as 23.501)
    sentence_end = re.search(r'[.?!]', answer)
    if sentence_end:
        answer = answer[:sentence_end.end()]

    return answer.strip()
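
The final step of `clean_prediction` cuts the answer at the first `.`, `?` or `!`. For 3GPP text, which is full of dotted specification and clause numbers, that can truncate an answer in the middle of a number. The sketch below reproduces the behaviour on a made-up string and shows one possible alternative (an assumption, not part of the evaluated pipeline): only treat punctuation as a sentence end when it is followed by whitespace or the end of the string.

```python
import re


def truncate_first_punct(answer):
    """Same rule as clean_prediction: cut at the first '.', '?' or '!'."""
    m = re.search(r"[.?!]", answer)
    return answer[: m.end()] if m else answer


def truncate_sentence_end(answer):
    """Alternative: cut only where punctuation is followed by whitespace or end of text."""
    m = re.search(r"[.?!](?=\s|$)", answer)
    return answer[: m.end()] if m else answer


# Made-up model output containing dotted numbers
text = "The procedure is described in TS 23.502 clause 4.2.2. It starts with a request."

print(truncate_first_punct(text))   # The procedure is described in TS 23.
print(truncate_sentence_end(text))  # The procedure is described in TS 23.502 clause 4.2.2.
```

The lookahead variant still splits after abbreviations such as "e.g. ", so it reduces the problem rather than eliminating it.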

In [5]:
import re
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import nltk
from nltk.corpus import stopwords

# The stopword list must be available locally; fetch it once if it is missing
nltk.download("stopwords", quiet=True)
STOPWORDS = set(stopwords.words("english"))

def normalize(text):
    return re.sub(r'\W+', ' ', text.lower())

def lexical_overlap(query, chunk):
    q_tokens = set(normalize(query).split()) - STOPWORDS
    c_tokens = set(normalize(chunk).split()) - STOPWORDS
    return len(q_tokens & c_tokens) / (len(q_tokens | c_tokens) + 1e-5)

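def tfidf_score(query, chunk, vectorizer=None):
    # Cosine similarity of the two TF-IDF vectors
    # (TfidfVectorizer L2-normalizes each row by default, so the dot product suffices)
    docs = [query, chunk]
    if vectorizer is None:
        vectorizer = TfidfVectorizer().fit(docs)
    vecs = vectorizer.transform(docs)
    # .toarray() replaces the deprecated sparse-matrix .A shortcut
    return (vecs[0] @ vecs[1].T).toarray()[0][0]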

def rerank_chunks(chunks, query, alpha_overlap=0.7, beta_faiss=0.3, top_k=3):
    vectorizer = TfidfVectorizer().fit([query] + [c["content"] for c in chunks])
    reranked = []

    for idx, c in enumerate(chunks):
        overlap = lexical_overlap(query, c["content"])
        tfidf_sim = tfidf_score(query, c["content"], vectorizer)
        faiss_rank_bonus = (len(chunks) - idx) / len(chunks)

        # Final rerank score: overlap and TF-IDF are blended by alpha_overlap,
        # then the FAISS rank bonus is added on top with weight beta_faiss
        score = alpha_overlap * overlap + (1 - alpha_overlap) * tfidf_sim + beta_faiss * faiss_rank_bonus

        reranked.append((score, c))

    reranked.sort(reverse=True, key=lambda x: x[0])
    return [c for _, c in reranked[:top_k]]
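
A short worked example of the score used in `rerank_chunks` may help when tuning the weights. The formula and default weights below are copied from the cell above; the overlap and TF-IDF values are made-up inputs, and `hybrid_score` is a standalone helper introduced only for this illustration.

```python
def hybrid_score(overlap, tfidf_sim, idx, n_chunks, alpha_overlap=0.7, beta_faiss=0.3):
    """Reproduce the rerank score for one chunk at FAISS position idx (0 = best)."""
    faiss_rank_bonus = (n_chunks - idx) / n_chunks
    return (
        alpha_overlap * overlap
        + (1 - alpha_overlap) * tfidf_sim
        + beta_faiss * faiss_rank_bonus
    )


# Chunk ranked first by FAISS but with weak lexical evidence:
# 0.7*0.10 + 0.3*0.20 + 0.3*(10/10) = 0.07 + 0.06 + 0.30 = 0.43
top = hybrid_score(overlap=0.10, tfidf_sim=0.20, idx=0, n_chunks=10)

# Chunk ranked last by FAISS but with stronger lexical evidence:
# 0.7*0.30 + 0.3*0.40 + 0.3*(1/10) = 0.21 + 0.12 + 0.03 = 0.36
last = hybrid_score(overlap=0.30, tfidf_sim=0.40, idx=9, n_chunks=10)

print(f"top-ranked chunk : {top:.2f}")
print(f"last-ranked chunk: {last:.2f}")
```

With the default weights the three coefficients sum to 1.3 rather than 1, and the rank bonus alone contributes up to 0.3. Because Jaccard-style overlap between a short query and a long chunk is usually small, the original FAISS order can dominate the reranking unless `beta_faiss` is lowered.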

In [6]:
def answer_with_rag_llama(question, top_k=5, verbose=False):
    initial_chunks = retrieve_context(question, top_k=10)  
    retrieved = rerank_chunks(initial_chunks, question, top_k=top_k)
    prompt = build_rag_prompt(retrieved, question)
    
    # Inference
    raw_output = qa_pipeline(
        prompt, 
        max_new_tokens=160, 
        do_sample=False, 
        eos_token_id=tokenizer.eos_token_id, 
        pad_token_id=tokenizer.eos_token_id
    )[0]["generated_text"]

    answer = clean_prediction(raw_output)

    # Sanity check on the answer
    if len(answer.split()) < 2 or len(answer.split()) > 40:
        print("⚠️ Warning: Possibly bad output. Check content or retrieval.")

    if verbose:
        print("📌 Prompt:\n", prompt)
        print("\n🧾 Raw Output:\n", raw_output)
        print("\n✅ Cleaned Answer:", answer)
        for i, chunk in enumerate(retrieved):
            print(f"\n--- Context {i+1} ---")
            print(chunk["content"])

    return answer, retrieved

In [7]:
import json
from tqdm import tqdm
from evaluate import load

# Load QA pairs
def load_qa_pairs(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

qa_pairs = load_qa_pairs("3gpp_qa_100_pairs.jsonl")

# Load metrics
squad_metric = load("squad")
rouge = load("rouge")
bleu = load("bleu")

# Store all predictions and references for BLEU batch compute
bleu_predictions = []
bleu_references = []

# Store results
results = []

for sample in tqdm(qa_pairs):
    question = sample["question"]
    reference = sample["answer"]

    try:
        prediction, _ = answer_with_rag_llama(question)
    except Exception as e:
        print(f"⚠️ Error on: {question}\n{e}")
        prediction = ""

    # SQuAD metric
    squad_metric.add(
        prediction={"id": str(hash(question)), "prediction_text": prediction},
        reference={"id": str(hash(question)), "answers": {"text": [reference], "answer_start": [0]}}
    )

    # ROUGE
    rouge.add(prediction=prediction, reference=reference)

    # BLEU 
    bleu_predictions.append(prediction)
    bleu_references.append([reference])  # list of references per prediction

    results.append({
        "question": question,
        "reference": reference,
        "prediction": prediction
    })

# Final scores
squad_scores = squad_metric.compute()
rouge_scores = rouge.compute()
bleu_score = bleu.compute(predictions=bleu_predictions, references=bleu_references)["bleu"]

# Display
print("\n📊 Final Evaluation Results:")
print(f"Exact Match (EM): {squad_scores['exact_match']:.2f}")
print(f"F1 Score        : {squad_scores['f1']:.2f}")
print(f"ROUGE-L         : {rouge_scores['rougeL']:.4f}")
print(f"BLEU            : {bleu_score:.4f}")

  0%|                                                   | 0/100 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.

📊 Final Evaluation Results:
Exact Match (EM): 0.00
F1 Score        : 19.67
ROUGE-L         : 0.1991
BLEU            : 0.0227





In [8]:
import pandas as pd

# Persist per-question predictions for later error analysis
output_path = "/mnt/data/P1.csv"
df = pd.DataFrame(results)
df.to_csv(output_path, index=False)
print(f"✅ Saved to {output_path}")

✅ Saved to /mnt/data/P1.csv
