In [1]:
import os
import json
import random
import numpy as np
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch

# --- Paths -------------------------------------------------------------------
DATA_INTERIM_PATH = "../data/interim"
DATA_PROCESSED_PATH = "../data/processed"
NQ_QUESTIONS_FILE = "nq_questions_1000.jsonl"
WIKIPEDIA_CHUNKS_FILE = "wikipedia_chunks_bge_base.jsonl"
OUTPUT_TRAINING_FILE = "retriever_training_data.jsonl"

# --- Mining hyper-parameters -------------------------------------------------
MODEL_TEACHER_NAME = "BAAI/bge-large-en-v1.5"  # teacher encoder for questions, answers and passages
K_QUERY = 50  # number of most query-similar passages examined per question
THRESHOLD_ANSWER_SIMILARITY = 0.5  # passage-answer cosine similarity required for a positive
MIN_THRESHOLD_ANSWER_SIMILARITY = 0.25  # fallback threshold when no passage reaches the one above
NUM_POSITIVES_PER_QUERY = 2
NUM_HARD_NEGATIVES_PER_QUERY = 5
NUM_EASY_NEGATIVES_PER_QUERY = 1

SEED = 42

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Seed every RNG library the notebook imports so Restart & Run All reproduces the dataset.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

nq_questions_path = os.path.join(DATA_INTERIM_PATH, NQ_QUESTIONS_FILE)
wikipedia_chunks_path = os.path.join(DATA_INTERIM_PATH, WIKIPEDIA_CHUNKS_FILE)
output_training_path = os.path.join(DATA_PROCESSED_PATH, OUTPUT_TRAINING_FILE)

# Fail fast before loading the (large) teacher model if an input file is missing.
if not os.path.exists(nq_questions_path):
    raise FileNotFoundError(f"NQ file not found: {nq_questions_path}")
if not os.path.exists(wikipedia_chunks_path):
    raise FileNotFoundError(f"Wiki corupus file not found: {wikipedia_chunks_path}")


Using device: cuda


In [2]:
# Load NQ questions, keeping only records whose answer is a non-empty string.
nq_data = []
with open(nq_questions_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline at EOF
        record = json.loads(line)
        # .get(): a record without an answer is skipped like any other unusable answer.
        answer_text = record.get('answer')
        if isinstance(answer_text, str) and answer_text.strip():
            nq_data.append({
                "id": record.get("id", f"nq_unknown_{len(nq_data)}"),
                "query": record["query"],
                "answer": answer_text.strip()
            })


print(f"Loaded {len(nq_data)} NQ quesions with answers")
if not nq_data:
    raise ValueError("No NQ question loaded")

# Load Wikipedia chunks. passages_data, passage_texts_for_embedding and passage_id_map
# share one index: the position of a kept chunk. That position is also its row in the
# passage embedding matrix built from passage_texts_for_embedding.
passages_data = []
passage_texts_for_embedding = []
passage_id_map = {}  # position in passages_data -> passage_id

with open(wikipedia_chunks_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        record = json.loads(line)
        passage_text = record.get('passage_text')
        if 'passage_id' in record and isinstance(passage_text, str) and passage_text.strip():
            # Key by position in the kept list, NOT by file line number. Otherwise any
            # skipped line would shift every later id away from its text/embedding.
            passage_id_map[len(passages_data)] = record['passage_id']
            passages_data.append(record)
            passage_texts_for_embedding.append(passage_text)


print(f"Loaded {len(passages_data)} Wikipedia chunks.")
if not passages_data:
    raise ValueError("No Wikipedia chunks loaded")

query_texts_for_embedding = [item['query'] for item in nq_data]
answer_texts_for_embedding = [item['answer'] for item in nq_data]

Loaded 1000 NQ quesions with answers
Loaded 36508 Wikipedia chunks.


In [3]:
teacher_model = SentenceTransformer(MODEL_TEACHER_NAME, device=DEVICE)
print(f"Loaded teacher retriver: {MODEL_TEACHER_NAME}")


def _encode_texts(texts, start_message, shape_label):
    """Encode `texts` with the teacher model and return the embeddings as a CPU tensor.

    Prints `start_message` before encoding and the resulting shape afterwards.
    """
    print(start_message)
    embeddings = teacher_model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
    print(f"{shape_label} embeddins shape: {embeddings.shape}")
    # .cpu() is a no-op for tensors already on the CPU, so every embedding ends up on the
    # CPU whatever DEVICE is (cpu, cuda, cuda:N, mps, ...).
    return embeddings.cpu()


# NOTE(review): the BGE v1.5 model card suggests prefixing short retrieval queries with
# "Represent this sentence for searching relevant passages: ". Queries are encoded without
# it here; confirm that is intended.
query_embeddings = _encode_texts(
    query_texts_for_embedding, "Generating embeddings for NQ questions", "Questions")
answer_embeddings = _encode_texts(
    answer_texts_for_embedding, "Generating embeddings for NQ answers", "Anwsers")
passage_embeddings = _encode_texts(
    passage_texts_for_embedding, "Generating embeddings for Wikipedia chunks...", "Wikipedia chunks")

Loaded teacher retriver: BAAI/bge-large-en-v1.5
Generating embeddings for NQ questions


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Questions embeddins shape: torch.Size([1000, 1024])
Generating embeddings for NQ answers


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Anwsers embeddins shape: torch.Size([1000, 1024])
Generating embeddings for Wikipedia chunks...


Batches:   0%|          | 0/1141 [00:00<?, ?it/s]

Wikipedia chunks embeddins shape: torch.Size([36508, 1024])


In [4]:
# Mine (query, passage, label) training pairs from the teacher embeddings. For each question:
#   * positives      - passages among its K_QUERY most query-similar ones that are close to the answer;
#   * hard negatives - the rest of those K_QUERY passages (below the applied answer threshold),
#                      preferring ones close to the question but far from the answer;
#   * easy negatives - random passages from outside the K_QUERY most query-similar ones.


def _score_top_k_candidates(query_index):
    """Score the K_QUERY passages most similar to question `query_index`.

    Returns a list of candidate dicts, ordered by descending query similarity. Each dict holds
    the passage position, id and text, plus its cosine similarity to the question and to the
    question's answer.
    """
    query_record = nq_data[query_index]
    query_sims = cos_sim(query_embeddings[query_index], passage_embeddings)[0]
    # topk avoids a full argsort of the corpus; k is capped so tiny corpora do not raise.
    k = min(K_QUERY, query_sims.shape[0])
    top_sims, top_indices = torch.topk(query_sims, k=k)
    # One batched similarity call for all k passages instead of one call per passage.
    answer_sims = cos_sim(answer_embeddings[query_index], passage_embeddings[top_indices])[0]

    candidates = []
    for passage_idx, sim_to_query, sim_to_answer in zip(
        top_indices.tolist(), top_sims.tolist(), answer_sims.tolist()
    ):
        candidates.append({
            "query_id": query_record['id'],
            "query_text": query_record['query'],
            "passage_idx": passage_idx,
            # passages_data is appended in lockstep with passage_texts_for_embedding, so this
            # id always belongs to the same row as the text and the embedding.
            "passage_id": passages_data[passage_idx]['passage_id'],
            "passage_text": passage_texts_for_embedding[passage_idx],
            "sim_to_query": sim_to_query,
            "sim_to_answer": sim_to_answer,
        })
    return candidates


def _split_positives_and_hard_negatives(candidates):
    """Split candidates into (positives, hard_negatives) by similarity to the answer.

    THRESHOLD_ANSWER_SIMILARITY is applied when at least one candidate reaches it; otherwise
    the fallback MIN_THRESHOLD_ANSWER_SIMILARITY is applied. Hard negatives are exactly the
    candidates below the applied threshold, so a passage that qualifies as a positive is never
    offered as a negative. Positives are sorted by answer similarity (descending). Hard
    negatives are sorted by (query similarity - answer similarity) (descending).
    """
    if any(cand['sim_to_answer'] >= THRESHOLD_ANSWER_SIMILARITY for cand in candidates):
        threshold = THRESHOLD_ANSWER_SIMILARITY
    else:
        threshold = MIN_THRESHOLD_ANSWER_SIMILARITY
    positives = [cand for cand in candidates if cand['sim_to_answer'] >= threshold]
    hard_negatives = [cand for cand in candidates if cand['sim_to_answer'] < threshold]
    positives.sort(key=lambda cand: cand['sim_to_answer'], reverse=True)
    hard_negatives.sort(key=lambda cand: cand['sim_to_query'] - cand['sim_to_answer'], reverse=True)
    return positives, hard_negatives


def _take_samples(pool, quota, label, query_text, seen_passage_ids):
    """Build up to `quota` samples from `pool` (in order), skipping passage_ids already in
    `seen_passage_ids`. Every passage_id used is added to that set (mutated in place)."""
    samples = []
    for cand in pool:
        if len(samples) >= quota:
            break
        if cand['passage_id'] in seen_passage_ids:
            continue
        samples.append({
            "query": query_text,
            "passage": cand['passage_text'],
            "label": label,
        })
        seen_passage_ids.add(cand['passage_id'])
    return samples


def _sample_easy_negative_indices(excluded_indices, seen_passage_ids, num_wanted):
    """Randomly pick up to `num_wanted` passage positions that are not in `excluded_indices`
    and whose passage_id is not in `seen_passage_ids`. Picked ids are added to that set.

    Uses rejection sampling over range(len(passages)) instead of materialising the list of
    allowed positions for every query. At most num_wanted * 10 draws from the allowed
    positions are made; draws that land in `excluded_indices` are redrawn without counting.
    """
    num_passages = len(passage_texts_for_embedding)
    picked = []
    if num_passages <= len(excluded_indices):
        return picked  # no passage is left outside the excluded set
    attempts = 0
    while len(picked) < num_wanted and attempts < num_wanted * 10:
        passage_idx = random.randrange(num_passages)
        if passage_idx in excluded_indices:
            continue  # at least one position is allowed, so this terminates with probability 1
        attempts += 1
        passage_id = passages_data[passage_idx]['passage_id']
        if passage_id in seen_passage_ids:
            continue
        seen_passage_ids.add(passage_id)
        picked.append(passage_idx)
    return picked


retriever_training_samples = []
processed_passage_ids_for_query = {}  # query_id -> passage_ids already used for that query (avoids duplicates)

for i in tqdm(range(len(nq_data)), desc="Processing NQ questions"):
    query_id = nq_data[i]['id']
    query_text = nq_data[i]['query']
    seen_passage_ids = set()
    processed_passage_ids_for_query[query_id] = seen_passage_ids

    candidates = _score_top_k_candidates(i)
    positives, hard_negatives = _split_positives_and_hard_negatives(candidates)

    retriever_training_samples.extend(
        _take_samples(positives, NUM_POSITIVES_PER_QUERY, 1, query_text, seen_passage_ids))
    retriever_training_samples.extend(
        _take_samples(hard_negatives, NUM_HARD_NEGATIVES_PER_QUERY, 0, query_text, seen_passage_ids))

    top_k_indices = {cand['passage_idx'] for cand in candidates}
    for passage_idx in _sample_easy_negative_indices(
        top_k_indices, seen_passage_ids, NUM_EASY_NEGATIVES_PER_QUERY
    ):
        retriever_training_samples.append({
            "query": query_text,
            "passage": passage_texts_for_embedding[passage_idx],
            "label": 0,
        })

print(f"\nCreated {len(retriever_training_samples)} training indicies.")
if retriever_training_samples:
    labels = [sample['label'] for sample in retriever_training_samples]
    print(f"Class distribution: Positive (1): {labels.count(1)}, Negative (0): {labels.count(0)}")

Processing NQ questions:   0%|          | 0/1000 [00:00<?, ?it/s]


Created 5604 training indicies.
Class distribution: Positive (1): 1998, Negative (0): 3606


In [None]:
# Write the mined samples as JSONL (one JSON object per line).
# Create the output directory first: open(..., 'w') does not create missing parents.
os.makedirs(os.path.dirname(output_training_path) or ".", exist_ok=True)
with open(output_training_path, 'w', encoding='utf-8') as f_out:
    for sample in tqdm(retriever_training_samples, desc="Saving training dataset"):
        f_out.write(json.dumps(sample) + '\n')
print(f"Saved retriver's training dataset to: {output_training_path}")

Saving training dataset:   0%|          | 0/5604 [00:00<?, ?it/s]

Saved retriver's training dataset to: ../data/processed/retriever_training_data.jsonl
