In [5]:
import os
import json
from sentence_transformers import SentenceTransformer, util
import torch
import numpy as np
from tqdm.auto import tqdm


# --- Input locations -------------------------------------------------------
DATA_PROCESSED_PATH = "../data/processed"
FINETUNED_MODEL_PATH = "../models/retriever_finetuned/best-model"

WIKIPEDIA_CHUNKS_FILE = "wikipedia_chunks.jsonl"
CORPUS_EMBEDDINGS_FILE = "corpus_embeddings_finetuned.npy"

# --- Compute device: use the GPU whenever PyTorch can see one ---------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# --- Derived paths ------------------------------------------------------------
wikipedia_chunks_path, corpus_embeddings_path = (
    os.path.join(DATA_PROCESSED_PATH, file_name)
    for file_name in (WIKIPEDIA_CHUNKS_FILE, CORPUS_EMBEDDINGS_FILE)
)


def _require_path(path, label):
    """Raise ``FileNotFoundError("<label> not found: <path>")`` if ``path`` does not exist."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"{label} not found: {path}")


# Fail fast on the required inputs. The embeddings path is not checked here.
_require_path(wikipedia_chunks_path, "Wikipedia chunks file")
_require_path(FINETUNED_MODEL_PATH, "Fine-tuned retriever")

Using device: cuda


In [6]:
retriever_model = SentenceTransformer(FINETUNED_MODEL_PATH, device=DEVICE)
print("Retriever loaded.")

# Parallel lists: corpus_passages_data[i] is the full JSON record whose text is
# corpus_passage_texts[i]. Row i of the corpus embedding matrix must follow this order.
corpus_passages_data = []
corpus_passage_texts = []

with open(wikipedia_chunks_path, "r", encoding="utf-8") as f:
    for line in tqdm(f, desc="Loading corpus passages"):
        if not line.strip():
            # Blank lines (e.g. a trailing newline) are not JSON; skip instead of crashing.
            continue
        record = json.loads(line)
        # A record is usable only if it is an object with an id and non-blank string text.
        is_valid = (
            isinstance(record, dict)
            and "passage_id" in record
            and isinstance(record.get("passage_text"), str)
            and bool(record["passage_text"].strip())
        )
        if is_valid:
            corpus_passages_data.append(record)
            corpus_passage_texts.append(record["passage_text"])
        else:
            print(f"Warning: Skipping invalid record: {record}")

print(f"Loaded {len(corpus_passages_data)} corpus passages.")
if not corpus_passages_data:
    # Report the file that was actually read, not a hard-coded name.
    raise ValueError(f"No corpus passages loaded. Check the {wikipedia_chunks_path} file.")

Retriever loaded.


Loading corpus passages: 0it [00:00, ?it/s]

Loaded 36508 corpus passages.


In [7]:
# Encode every corpus passage with the fine-tuned retriever, reusing a cached .npy file
# when its shape matches the current corpus (rows) and model (embedding size).
# NOTE: the cache is validated by shape only. A re-fine-tuned retriever with the same
# embedding size is NOT detected; delete the cache file after retraining the model.
ENCODE_BATCH_SIZE = 128

expected_num_rows = len(corpus_passages_data)
# SentenceTransformer may return None here if the size is unknown; then only rows are checked.
expected_dim = retriever_model.get_sentence_embedding_dimension()
corpus_embeddings = None

if os.path.exists(corpus_embeddings_path):
    print(f"Loading existing embeddings from: {corpus_embeddings_path}")
    cached_embeddings = np.load(corpus_embeddings_path)

    cache_shape_ok = (
        cached_embeddings.ndim == 2
        and cached_embeddings.shape[0] == expected_num_rows
        and (expected_dim is None or cached_embeddings.shape[1] == expected_dim)
    )
    if cache_shape_ok:
        corpus_embeddings = cached_embeddings
        print(
            f"Loaded {corpus_embeddings.shape[0]} embeddings, matching the number of loaded passages."
        )
    else:
        print(
            f"WARNING: Cached embeddings have shape {cached_embeddings.shape}, expected "
            f"({expected_num_rows}, {expected_dim}) for {expected_num_rows} loaded passages. "
            f"Embeddings will be recalculated."
        )

if corpus_embeddings is None:
    corpus_embeddings = retriever_model.encode(
        corpus_passage_texts,
        show_progress_bar=True,
        convert_to_numpy=True,
        batch_size=ENCODE_BATCH_SIZE,
    )
    print(f"Calculated embeddings for corpus. Shape: {corpus_embeddings.shape}")

    np.save(corpus_embeddings_path, corpus_embeddings)
    print(f"Embeddings saved to: {corpus_embeddings_path}.")

# Defensive: downstream code expects a NumPy array on the CPU.
if isinstance(corpus_embeddings, torch.Tensor):
    corpus_embeddings = corpus_embeddings.cpu().numpy()

Batches:   0%|          | 0/286 [00:00<?, ?it/s]

Calculated embeddings for corpus. Shape: (36508, 768)
Embeddings saved to: ../data/processed/corpus_embeddings_finetuned.npy.


In [8]:
def search_corpus(query_text, top_k=5, model=None, embeddings=None, passages=None):
    """
    Return the ``top_k`` corpus passages most similar to ``query_text``.

    Similarity is the cosine similarity between the query embedding and each row
    of the corpus embedding matrix (same formula as ``sentence_transformers.util.cos_sim``).

    Args:
        query_text: Question to search for. ``None`` or a blank string yields ``[]``.
        top_k: Maximum number of passages to return. It is clamped to the corpus size;
            values <= 0 yield ``[]``.
        model: Object with a SentenceTransformer-style ``encode(text, convert_to_numpy=True)``.
            Defaults to the notebook-level ``retriever_model``.
        embeddings: 2-D array/tensor with one embedding row per passage.
            Defaults to the notebook-level ``corpus_embeddings``.
        passages: Passage records aligned with the rows of ``embeddings``.
            Defaults to the notebook-level ``corpus_passages_data``.

    Returns:
        A list of dicts sorted by descending score, each with keys ``rank`` (1-based),
        ``score`` (float), ``passage_id``, ``text`` and ``document_title`` ("N/A" if absent).
    """
    if model is None:
        model = retriever_model
    if embeddings is None:
        embeddings = corpus_embeddings
    if passages is None:
        passages = corpus_passages_data

    if not query_text or not query_text.strip():
        print("Empty question")
        return []
    if top_k <= 0:
        return []

    query_embedding = model.encode(query_text, convert_to_numpy=True)
    corpus_matrix = torch.as_tensor(embeddings).to(dtype=torch.float32)
    # atleast_2d turns a single (dim,) vector into a (1, dim) batch of one query.
    query_matrix = torch.as_tensor(np.atleast_2d(query_embedding)).to(
        device=corpus_matrix.device, dtype=torch.float32
    )

    # Cosine similarity = dot product of L2-normalised vectors.
    query_unit = torch.nn.functional.normalize(query_matrix, p=2, dim=1)
    corpus_unit = torch.nn.functional.normalize(corpus_matrix, p=2, dim=1)
    cosine_scores = (query_unit @ corpus_unit.T)[0]

    # torch.topk returns scores already sorted in descending order. Clamp k so a
    # top_k larger than the corpus does not raise.
    k = min(top_k, cosine_scores.shape[0], len(passages))
    top_scores, top_indices = torch.topk(cosine_scores, k)

    results = []
    for rank, (score, idx) in enumerate(zip(top_scores.tolist(), top_indices.tolist()), start=1):
        passage_info = passages[idx]
        results.append(
            {
                "rank": rank,
                "score": float(score),
                "passage_id": passage_info.get("passage_id"),
                "text": passage_info.get("passage_text"),
                "document_title": passage_info.get("document_title", "N/A"),
            }
        )

    return results

In [9]:
def print_found_passages(results):
    """
    Print each search hit as a header line, an indented text fragment and a divider.

    ``results`` is an iterable of dicts with keys ``rank``, ``score``,
    ``document_title``, ``passage_id`` and ``text`` (as produced by ``search_corpus``).
    """
    divider = "-" * 30
    for hit in results:
        rank, score = hit["rank"], hit["score"]
        title, pid = hit["document_title"], hit["passage_id"]
        print(f"  Rank {rank}: Score={score:.4f}, DocTitle='{title}', PassageID='{pid}'")
        print('    Fragment: "' + f"{hit['text']}" + '"')
        print(divider)

In [10]:
# Example retriever usage: run a few sanity-check questions and print the top hits,
# with a separator line between consecutive queries.
EXAMPLE_TOP_K = 3
example_queries = [
    "When did the French Revolution start?",
    "What is the capital of Poland?",
    "Who painted the Mona Lisa?",
    "Who is the biggest exporter of copper?",
]
query_separator = "\n" + "=" * 50 + "\n"

example_results = []
for query_index, query in enumerate(example_queries):
    if query_index > 0:
        print(query_separator)
    hits = search_corpus(query, top_k=EXAMPLE_TOP_K)
    print_found_passages(hits)
    example_results.append(hits)

# Keep the original per-query names so any later cells that reference them still work.
test_query_1, test_query_2, test_query_3, test_query_4 = example_queries
retrieved_docs_1, retrieved_docs_2, retrieved_docs_3, retrieved_docs_4 = example_results

  Rank 1: Score=0.9661, DocTitle='French Revolution', PassageID='wiki_153_chunk_0'
    Fragment: "The French Revolution (French: Révolution française [ʁevɔlysjɔ̃ fʁɑ̃sɛːz]) was a period of political and societal change in France which began with the Estates General of 1789 and ended with the Coup of 18 Brumaire on 9 November 1799. Many of the revolution's ideas are considered fundamental principles of liberal democracy, and its values remain central to modern French political discourse. The causes of the revolution were a combination of social, political, and economic factors which the ancien régime ("old regime") proved unable to manage. A financial crisis and widespread social distress led to the convocation of the Estates General in May 1789, its first meeting since 1614. The representatives of the Third Estate broke away and re-constituted themselves as a National Assembly in June. The Storming of the Bastille in Paris on 14 July was followed by radical measures by the Assembly, am