In [None]:
%%capture
!pip install faiss-cpu
!pip install rank_bm25

In [None]:
import csv
import re
import json
import random
import numpy as np
import torch

from collections import defaultdict
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
import faiss

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

## Datu sagatavošana

In [None]:
# Load the "train" split of the "squad" dataset through the `datasets` library.
dataset = load_dataset("squad", split="train")

In [None]:
# Distinct values of the "title" column; the bare len() is shown as the cell output.
titles = dataset.unique("title")
len(titles)

In [None]:
SEED = 42  # seed for the random.Random instance that picks the evaluation titles
N_SAMPLES = 50  # number of titles sampled; one example is kept per title
K_VALUES = [1, 3, 5, 10]  # cut-offs k used for Recall@k / Precision@k

In [None]:
# Bucket every example of the split under its article title.
by_title = defaultdict(list)
for example in dataset:
    by_title[example["title"]].append(example)

# Sorting the titles first makes the seeded draw independent of dataset order.
rng = random.Random(SEED)
titles = sorted(by_title)
selected_titles = rng.sample(titles, N_SAMPLES)

# Keep the first example of each selected title as its evaluation sample.
eval_samples = [by_title[title][0] for title in selected_titles]

In [None]:
# Persist the sampled evaluation set so later runs can inspect exactly what was evaluated.
with open("KD-RAG-eval.json", "w") as eval_file:
    json.dump(eval_samples, fp=eval_file, indent=2)

In [None]:
# Context passage of each evaluation sample, in eval_samples order.
contexts = [ex["context"] for ex in eval_samples]

In [None]:
def chunk_text(text, chunk_size=80, overlap=20):
    """Split ``text`` into overlapping windows of whitespace-separated words.

    Args:
        text: String to split; words are obtained with ``str.split()``.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared by consecutive chunks
            (must satisfy ``0 <= overlap < chunk_size``).

    Returns:
        A list of dicts with the keys ``chunk_id`` (0-based position in the
        returned list), ``text`` (the chunk's words joined by single spaces),
        and ``word_start`` / ``word_end`` (half-open word-index range of the
        chunk within ``text.split()``). Empty for text without words.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` would make the window
            stand still or move backwards.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        # A step of zero or less would never advance past the first window.
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    words = text.split()
    step = chunk_size - overlap
    chunks = []

    for cid, start in enumerate(range(0, len(words), step)):
        end = min(start + chunk_size, len(words))
        chunks.append({
            "chunk_id": cid,
            "text": " ".join(words[start:end]),
            "word_start": start,
            "word_end": end
        })
        if end == len(words):
            # The last word is covered; another window would only repeat
            # the tail of this chunk.
            break

    return chunks

def char_to_word_span(text, char_start, char_end):
    """Map a character range of ``text`` onto indices of its whitespace-split words.

    Note: a later cell of this notebook redefines this function; this version
    returns the same results as that one.

    Args:
        text: The string the character offsets refer to.
        char_start: Inclusive start offset of the range.
        char_end: Exclusive end offset of the range.

    Returns:
        A tuple ``(ws, we)``: ``ws`` is the index of the first word that ends
        after ``char_start`` and ``we`` is one past the index of the last word
        that starts before ``char_end``. Either element is ``None`` when no
        word satisfies its condition (for example, for empty text).
    """
    ws = we = None
    pos = 0

    for i, w in enumerate(text.split()):
        # Search from the end of the previous word so repeated words resolve
        # to their own occurrence rather than an earlier or later one.
        w_start = text.find(w, pos)
        w_end = w_start + len(w)
        pos = w_end

        if ws is None and w_end > char_start:
            ws = i
        if w_start < char_end:
            we = i + 1
        elif ws is not None:
            # Word starts only increase, so both bounds are final.
            break

    return ws, we

def overlaps(chunk, ans_ws, ans_we):
    """Tell whether the word range ``[ans_ws, ans_we)`` intersects ``chunk``.

    ``chunk`` is a dict carrying ``word_start`` and ``word_end``; both ranges
    are treated as half-open, so ranges that merely touch do not overlap.
    """
    # Answer ends at or before the chunk begins.
    if ans_we <= chunk["word_start"]:
        return False
    # Answer begins at or after the chunk ends.
    if ans_ws >= chunk["word_end"]:
        return False
    return True


In [None]:
# Word-count distribution of the selected contexts (whitespace tokenisation).
lengths = np.array([len(ctx.split()) for ctx in contexts])

# Each statistic is computed right before it is printed, one line per entry.
length_report = (
    ("N contexts:", lambda arr: len(arr)),
    ("mean:", lambda arr: arr.mean()),
    ("median:", lambda arr: np.median(arr)),
    ("p75:", lambda arr: np.percentile(arr, 75)),
    ("p90:", lambda arr: np.percentile(arr, 90)),
    ("max:", lambda arr: arr.max()),
)
for label, compute in length_report:
    print(label, compute(lengths))


In [None]:
# all_chunks holds every chunk of every context in one flat list; chunks_mapped
# is the parallel list recording which context each chunk came from.
all_chunks = []
chunks_mapped = []
for ctx_id, ctx in enumerate(contexts):
    pieces = chunk_text(ctx)
    all_chunks.extend(pieces)
    chunks_mapped.extend({"context_id": ctx_id} for _ in pieces)

In [None]:
# The two lists are indexed in parallel; stop the run if they ever diverge
# instead of merely displaying a boolean.
assert len(all_chunks) == len(chunks_mapped), (
    f"{len(all_chunks)} chunks but {len(chunks_mapped)} metadata entries"
)

In [None]:
# One record per evaluation sample; context_id is the sample's position in
# eval_samples, which is also the position of its passage in contexts.
question_data = [
    {
        "question": sample["question"],
        "answers": sample["answers"],
        "context_id": idx,
        "context": sample["context"]
    }
    for idx, sample in enumerate(eval_samples)
]

In [None]:
# Reverse lookup: context id -> positions (in all_chunks) of that context's chunks.
chunks_by_context = defaultdict(list)
for chunk_pos, chunk_meta in enumerate(chunks_mapped):
    owner = chunk_meta["context_id"]
    chunks_by_context[owner].append(chunk_pos)

In [None]:
def char_to_word_span(text, char_start, char_end):
    """Translate a character range of ``text`` into whitespace-word indices.

    Returns ``(ws, we)``: ``ws`` is the index of the first word ending after
    ``char_start`` and ``we`` is one past the index of the last word starting
    before ``char_end``. Either value is ``None`` if no word qualifies.
    """
    # Locate each word's (start, end) character offsets, scanning left to right
    # so repeated words resolve to successive occurrences.
    bounds = []
    cursor = 0
    for token in text.split():
        begin = text.find(token, cursor)
        cursor = begin + len(token)
        bounds.append((begin, cursor))

    first = next(
        (idx for idx, (_, stop) in enumerate(bounds) if stop > char_start),
        None,
    )
    starting_before_end = [
        idx for idx, (begin, _) in enumerate(bounds) if begin < char_end
    ]
    last = starting_before_end[-1] + 1 if starting_before_end else None

    return first, last

def overlaps(chunk, ans_ws, ans_we):
    """Return True when the half-open word range ``[ans_ws, ans_we)`` shares
    at least one word with the chunk's ``[word_start, word_end)`` range."""
    disjoint = ans_we <= chunk["word_start"] or ans_ws >= chunk["word_end"]
    return not disjoint

In [None]:
# Ground truth for retrieval: question text -> sorted positions (in all_chunks)
# of the chunks that contain at least part of a gold answer span.
question_to_chunks = {}

for item in question_data:
    gold = item["answers"]
    relevant = set()

    for ans_text, ans_start in zip(gold["text"], gold["answer_start"]):
        # Convert the answer's character offsets into word indices of its context.
        ws, we = char_to_word_span(
            item["context"], ans_start, ans_start + len(ans_text)
        )

        # Only chunks cut from the question's own context can contain the span.
        relevant.update(
            chunk_idx
            for chunk_idx in chunks_by_context[item["context_id"]]
            if overlaps(all_chunks[chunk_idx], ws, we)
        )

    question_to_chunks[item["question"]] = sorted(relevant)


## Izgūšanas komponentes bāzlīnijas izvērtēšana

#### Leksiskās izgūšanas metodes izvērtēšana

In [None]:
# Tokenise each chunk by lower-casing and splitting on whitespace (the same
# tokenisation bm25_retrieve applies to queries), then build the BM25 index.
bm25_corpus = [ch["text"].lower().split() for ch in all_chunks]
bm25 = BM25Okapi(bm25_corpus)

In [None]:
def bm25_retrieve(query, k):
    """Return the positions of the ``k`` highest-scoring chunks for ``query``.

    The query is lower-cased and split on whitespace before scoring against the
    module-level ``bm25`` index; positions are ordered from best to worst score.
    """
    query_tokens = query.lower().split()
    best_first = np.argsort(bm25.get_scores(query_tokens))[::-1]
    return list(best_first[:k])

In [None]:
def recall_at_k(retrieved, relevant):
    """Fraction of the ``relevant`` ids that appear in ``retrieved``.

    Returns 0.0 when ``relevant`` is empty.
    """
    if not relevant:
        return 0.0
    hits = set(retrieved).intersection(relevant)
    return len(hits) / len(relevant)

def precision_at_k(retrieved, relevant, k):
    """Number of retrieved ids that are relevant, divided by the cut-off ``k``.

    Returns 0.0 when ``relevant`` is empty.
    """
    if relevant:
        hit_count = len(set(relevant).intersection(retrieved))
        return hit_count / k
    return 0.0

In [None]:
# Per-question Recall@k / Precision@k for BM25, collected per cut-off k.
bm25_recall = {k: [] for k in K_VALUES}
bm25_precision = {k: [] for k in K_VALUES}
bm25_max_k = max(K_VALUES, default=0)

for q in question_data:
    query = q["question"]
    relevant = question_to_chunks[query]

    # Score the corpus once per question: bm25_retrieve returns the first k
    # entries of the score-sorted ranking, so every smaller cut-off is a
    # prefix of the ranking retrieved at the largest cut-off.
    ranked = bm25_retrieve(query, bm25_max_k)

    for k in K_VALUES:
        retrieved = ranked[:k]
        bm25_recall[k].append(recall_at_k(retrieved, relevant))
        bm25_precision[k].append(precision_at_k(retrieved, relevant, k))

#### Semantiskās, blīvās izgūšanas bāzlīnijas izvērtēšana

In [None]:
# Sentence-embedding model used for the dense retrieval baseline.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
# Raw chunk strings, aligned index-for-index with all_chunks.
chunk_texts = [ch["text"] for ch in all_chunks]

In [None]:
# Embed every chunk; row i of the result corresponds to chunk_texts[i].
chunk_embeddings = embedder.encode(
    chunk_texts,
    convert_to_numpy=True,  # request a NumPy array, as needed by faiss below
    show_progress_bar=True
)

In [None]:
# Sanity check: print the shape of the embedding matrix.
print(chunk_embeddings.shape)

In [None]:
# Normalise the chunk embeddings to unit length (faiss.normalize_L2 modifies
# the array in place) so the inner-product index ranks by cosine similarity.
faiss.normalize_L2(chunk_embeddings)
dim = chunk_embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(chunk_embeddings)

In [None]:
def dense_retrieve(query, k):
    """Return the positions of up to ``k`` chunks closest to ``query`` in embedding space.

    The query is embedded with the module-level ``embedder``, L2-normalised in
    the same way as the indexed chunk embeddings, and searched against the
    module-level FAISS ``index``.

    Args:
        query: Question text.
        k: Number of neighbours requested.

    Returns:
        List of chunk positions in the order returned by the index. It can be
        shorter than ``k`` when the index holds fewer than ``k`` vectors.
    """
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    _scores, indices = index.search(q_emb, k)
    # FAISS fills missing neighbours with -1; used as a list index, -1 would
    # silently select the last chunk, so drop those placeholders.
    return [idx for idx in indices[0] if idx >= 0]

In [None]:
# Per-question Recall@k / Precision@k for dense retrieval, collected per cut-off k.
dense_recall = {k: [] for k in K_VALUES}
dense_precision = {k: [] for k in K_VALUES}
dense_max_k = max(K_VALUES, default=0)

for q in question_data:
    query = q["question"]
    relevant = question_to_chunks[query]

    # Embed and search once per question instead of once per cut-off.
    # NOTE(review): relies on the index returning hits best-first, so that the
    # top-k list is a prefix of the top-max-k list; exact score ties could in
    # principle be ordered differently between the two searches.
    ranked = dense_retrieve(query, dense_max_k)

    for k in K_VALUES:
        retrieved = ranked[:k]
        dense_recall[k].append(recall_at_k(retrieved, relevant))
        dense_precision[k].append(precision_at_k(retrieved, relevant, k))

In [None]:
# Mean Recall@k / Precision@k over all questions, one section per retriever.
metric_sections = (
    ("\nBM25", bm25_recall, bm25_precision),
    ("\nDense", dense_recall, dense_precision),
)
for header, recall_by_k, precision_by_k in metric_sections:
    print(header)
    for k in K_VALUES:
        mean_recall = np.mean(recall_by_k[k])
        mean_precision = np.mean(precision_by_k[k])
        print(
            f"@{k}: Recall={mean_recall:.3f} | "
            f"Precision={mean_precision:.3f}"
        )


#### Ģenerēšanas komponentes bāzlīnijas izvērtēšana

In [None]:
# Checkpoint name of the seq2seq model used to generate answers.
MODEL = "google/flan-t5-base"
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tok = AutoTokenizer.from_pretrained(MODEL)
llm = AutoModelForSeq2SeqLM.from_pretrained(MODEL).to(device)
# Put the model in evaluation mode; as the cell's last expression, the
# returned module is also displayed as the cell output.
llm.eval()

In [None]:
@torch.inference_mode()
def generate_answer(prompt, max_new_tokens=32):
    """Generate a completion for ``prompt`` with the module-level ``llm``.

    The prompt is tokenised with ``tok`` (truncated to at most 512 tokens) and
    moved to ``device``. Decoding is deterministic: sampling is off and a
    single beam is used. Returns the first generated sequence as text, with
    special tokens removed and surrounding whitespace stripped.
    """
    encoded = tok(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )
    encoded = encoded.to(device)

    generated = llm.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1
    )
    answer = tok.decode(generated[0], skip_special_tokens=True)
    return answer.strip()

In [None]:
def build_prompt_no_rag(question):
    """Build the closed-book prompt: an instruction, the question, and an answer cue."""
    instruction = "Answer the question with a short factual phrase.\n"
    question_line = f"Question: {question}\n"
    return instruction + question_line + "Answer:"

def build_prompt_rag(question, retrieved_chunk_texts):
    """Build the retrieval-augmented prompt.

    The retrieved chunk texts are joined with blank lines into a context block
    placed between the instructions and the question.
    """
    context_block = "\n\n".join(retrieved_chunk_texts)
    segments = [
        "Use ONLY the context provided below to answer with a short factual phrase.\n",
        "If the answer is NOT in the context, say: unknown.\n\n",
        f"Context:\n{context_block}\n\n",
        f"Question: {question}\n",
        "Answer:",
    ]
    return "".join(segments)

In [None]:
K_RAG = 3  # number of retrieved chunks placed in the RAG prompt
results = []

for item in question_data:
    question = item["question"]

    # Closed-book baseline: the model sees only the question.
    pred_no_rag = generate_answer(build_prompt_no_rag(question))

    # RAG variant: prepend the top-K_RAG densely retrieved chunks as context.
    retrieved_chunks = [
        all_chunks[chunk_idx]["text"]
        for chunk_idx in dense_retrieve(question, K_RAG)
    ]
    pred_rag = generate_answer(build_prompt_rag(question, retrieved_chunks))

    record = {
        "question": question,
        "gold_answers": item["answers"]["text"],
        "no_rag_answer": pred_no_rag,
        "rag_answer": pred_rag,
        "retrieved_chunks": retrieved_chunks
    }
    results.append(record)


In [None]:
# Compact CSV for side-by-side reading of gold vs. generated answers.
# (csv and json are already imported in the notebook's import cell.)
with open("rag_eval.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow([
        "question",
        "gold_answers",
        "no_rag_answer",
        "rag_answer"
    ])
    writer.writerows(
        [
            r["question"],
            " | ".join(r["gold_answers"]),  # several gold answers share one cell
            r["no_rag_answer"],
            r["rag_answer"]
        ]
        for r in results
    )


# Full records, including the retrieved chunk texts, as JSON;
# ensure_ascii=False keeps non-ASCII characters readable in the file.
with open("rag_eval_full.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

Eksportētie dati un atbilstošās metrikas tiek analizētas manuāli