# 1) Qdrant for persistent, scalable vector retrieval
# 2) Cross-Encoder to rerank retrieved chunks
# 3) Your fine-tuned LLaMA-2 model (via Hugging Face pipeline) for precise extractive QA

## Notebook Summary: Qdrant + Cross-Encoder RAG Evaluation (Setup 5)

This notebook evaluates a Qdrant-powered RAG pipeline for telecom-specific extractive QA using a fine-tuned LLaMA-2 model.

### Pipeline Overview:

1. **Qdrant Retrieval**  
   Retrieves top-k chunks from a persistent Qdrant vector store using MiniLM embeddings and cosine similarity.

2. **Cross-Encoder Reranking**  
   Applies `cross-encoder/ms-marco-MiniLM-L-6-v2` to rerank retrieved chunks for better alignment with the query intent.

3. **Compound Question Support**  
   Splits multi-part queries into subquestions and answers them independently. Combines sub-answers into a final output.

4. **Prompting and Inference**  
   Uses an extractive prompt format compatible with LLaMA-2 and performs generation using a LoRA-fine-tuned QA model on GPU.

5. **Evaluation Metrics**  
   Computes:
   - **Exact Match (EM)** and **F1** (SQuAD)
   - **ROUGE-L**
   - **BLEU**

This setup demonstrates a high-precision, production-ready RAG architecture using Qdrant and cross-encoder reranking for improved factual grounding in telecom QA tasks.

In [1]:
import torch
torch.cuda.empty_cache()

In [2]:
from qdrant_client import QdrantClient
from qdrant_client.models import Distance
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch, re

# Qdrant connection
client = QdrantClient(host="localhost", port=6333)

# Embedding model (same as used for indexing)
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Cross-Encoder for reranking
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Load fine-tuned LLaMA-2 QA model
model_path = "/mnt/data/llama2_qa_lora_output5/final"
tokenizer = AutoTokenizer.from_pretrained(model_path)
qa_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

qa_pipeline = pipeline("text-generation", model=qa_model, tokenizer=tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cuda:0


In [3]:
def retrieve_with_qdrant_rerank(question, top_k=5):
    query_vec = embedding_model.encode(question, normalize_embeddings=True).tolist()

    # Search wider in Qdrant (top_k * 2)
    results = client.search(
        collection_name="3gpp_chunks",
        query_vector=query_vec,
        limit=top_k * 2
    )

    initial = [{
        "content": r.payload["content"],
        "source": r.payload["source"]
    } for r in results]

    # Rerank using Cross-Encoder
    pairs = [(question, doc["content"]) for doc in initial]
    scores = reranker.predict(pairs)
    reranked = sorted(zip(scores, initial), key=lambda x: x[0], reverse=True)[:top_k]

    return [doc for _, doc in reranked]

In [4]:
SYSTEM_PROMPT = (
    "You are a precise assistant. Extract the exact answer span from the context. "
    "Do not paraphrase, summarize, or add extra information. "
    "The answer must appear exactly in the context. "
    "If the context lists multiple conditions, actions, or branches, include them all as written."
)

def build_rag_prompt(context_chunks, question):
    combined_context = "\n\n".join([chunk['content'] for chunk in context_chunks])
    user_prompt = (
        f"Context: {combined_context}\n\n"
        f"Question: {question}\n"
        f"Answer from the context only:"
    )
    return f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n{user_prompt} [/INST]"

In [5]:
def clean_prediction(raw_text):
    answer = raw_text.split("[/INST]")[-1].strip()
    answer = re.sub(r"[^\w\s\-.,:/()]", "", answer)
    answer = re.sub(r'(\b.+?:)(\s*\1)+', r'\1', answer)

    tokens = answer.split()
    for i in range(1, len(tokens) // 2):
        if tokens[:i] == tokens[i:2*i]:
            return " ".join(tokens[:i])

    sentence_end = re.search(r'[.?!]', answer)
    if sentence_end:
        answer = answer[:sentence_end.end()]

    return answer.strip()

In [6]:
def split_compound_question(q):
    parts = re.split(r"\band\b|\bor\b|[,;]", q)
    return [p.strip() for p in parts if len(p.strip().split()) > 3]

def answer_with_qdrant_llama(question, top_k=5, verbose=False):
    retrieved = retrieve_with_qdrant_rerank(question, top_k=top_k)
    sub_qs = split_compound_question(question)

    if len(sub_qs) > 1:
        answers = []
        for sq in sub_qs:
            sub_prompt = build_rag_prompt(retrieved, sq)
            raw = qa_pipeline(sub_prompt, max_new_tokens=160, do_sample=False,
                              eos_token_id=tokenizer.eos_token_id,
                              pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
            ans = clean_prediction(raw)
            answers.append(f"→ {sq}: {ans}")

        final = "\n".join(answers)
        if verbose:
            print("\n".join([f"Context {i+1}:\n{c['content']}" for i, c in enumerate(retrieved)]))
        return final, retrieved

    # Simple case
    prompt = build_rag_prompt(retrieved, question)
    raw = qa_pipeline(prompt, max_new_tokens=160, do_sample=False,
                      eos_token_id=tokenizer.eos_token_id,
                      pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]
    answer = clean_prediction(raw)

    if verbose:
        print("📌 Prompt:\n", prompt)
        print("\n🧾 Raw Output:\n", raw)
        print("\n✅ Cleaned Answer:", answer)
        for i, chunk in enumerate(retrieved):
            print(f"\n--- Context {i+1} ---\n{chunk['content']}")

    return answer, retrieved

In [7]:
import json
from tqdm import tqdm
from evaluate import load

# Load QA pairs
def load_qa_pairs(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

qa_pairs = load_qa_pairs("3gpp_qa_100_pairs.jsonl")

# Load metrics
squad_metric = load("squad")
rouge = load("rouge")
bleu = load("bleu")

bleu_predictions = []
bleu_references = []
results = []

for sample in tqdm(qa_pairs):
    question = sample["question"]
    reference = sample["answer"]

    try:
        prediction, _ = answer_with_qdrant_llama(question)
    except Exception as e:
        print(f"⚠️ Error on: {question}\n{e}")
        prediction = ""

    # Add to metrics
    squad_metric.add(
        prediction={"id": str(hash(question)), "prediction_text": prediction},
        reference={"id": str(hash(question)), "answers": {"text": [reference], "answer_start": [0]}}
    )
    rouge.add(prediction=prediction, reference=reference)
    bleu_predictions.append(prediction)
    bleu_references.append([reference])
    results.append({
        "question": question,
        "reference": reference,
        "prediction": prediction
    })

# Compute final scores
squad_scores = squad_metric.compute()
rouge_scores = rouge.compute()
bleu_score = bleu.compute(predictions=bleu_predictions, references=bleu_references)["bleu"]

# Print results
print("\n📊 Final Evaluation Results (Setup 5 — Qdrant + Cross-Encoder + Compound):")
print(f"Exact Match (EM): {squad_scores['exact_match']:.2f}")
print(f"F1 Score        : {squad_scores['f1']:.2f}")
print(f"ROUGE-L         : {rouge_scores['rougeL']:.4f}")
print(f"BLEU            : {bleu_score:.4f}")

  results = client.search(
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|▍                                          | 1/100 [00:08<13:27,  8.16s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  2%|▊                                          | 2/100 [00:15<12:54,  7.91s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  3%|█▎                                         | 3/100 [00:23<12:37,  7.81s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  4%|█▋                                         | 4/100 [00:31<12:26,  7.78s/it]The following generation flags are not valid and may be ignored: ['temperature', 


📊 Final Evaluation Results (Setup 5 — Qdrant + Cross-Encoder + Compound):
Exact Match (EM): 1.00
F1 Score        : 19.21
ROUGE-L         : 0.2118
BLEU            : 0.0419



