In [2]:
%pip install sentence-transformers faiss-cpu rapidfuzz pandas tqdm

Collecting sentence-transformers
  Using cached sentence_transformers-5.0.0-py3-none-any.whl.metadata (16 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Downloading sentence_transformers-5.0.0-py3-none-any.whl (470 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 470.2/470.2 kB 12.3 MB/s eta 0:00:00
Downloading faiss_cpu-1.11.0.post1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 31.3/31.3 MB 91.2 MB/s eta 0:00:00
Installing collected packages: faiss-cpu, sentence-transformers
Successfully installed faiss-cpu-1.11.0.post1 sentence-transformers-5.0.0

[notice] A new release of pip is available: 24.0 -> 25.2
[1m[[0m[3

In [1]:
import pandas as pd
import numpy as np
import json
import torch
from sentence_transformers import SentenceTransformer
import faiss
from rapidfuzz import fuzz
from tqdm.auto import tqdm
from pathlib import Path
from datasets import load_dataset
import warnings

# Silence the torch FutureWarning about `encoder_attention_mask` being deprecated
warnings.filterwarnings(action="ignore", message="`encoder_attention_mask` is deprecated")


# --- Compute device: the GPU when CUDA is available, otherwise the CPU ---
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Truncate long definition strings to 80 characters in DataFrame displays
pd.options.display.max_colwidth = 80

Using device: cuda
GPU: NVIDIA A100 80GB PCIe


In [2]:
# --- Load Dataset from Hugging Face ---
# Dataset configuration (matches the training script)
DATASET_NAME = "RobbedoesHF/dutch-definitions"

# The Hugging Face dataset uses different column names than the pipeline expects,
# so they are renamed right after loading.
column_mapping = {
    'Lemma': 'lemma',
    'DefinitionShort': 'short_definition',
    'DefinitionFull': 'long_definition'
}
expected_cols = {'lemma', 'short_definition', 'long_definition'}

try:
    print(f"Loading dataset '{DATASET_NAME}' from Hugging Face Hub...")
    # Load both splits at once and rename the columns in every split
    dataset = load_dataset(DATASET_NAME).rename_columns(column_mapping)

    # Convert the train and test splits to pandas DataFrames
    train_df = dataset['train'].to_pandas()
    test_df = dataset['test'].to_pandas()
except Exception as e:
    print(f"An error occurred: {e}")
    print(f"Please ensure the dataset '{DATASET_NAME}' is accessible and has the expected format.")
    # Re-raise: every later cell needs train_df/test_df, so continuing here would
    # only resurface as a confusing NameError further down the notebook.
    raise

# --- Verification ---
for split_label, split_df in (("training", train_df), ("test", test_df)):
    if not expected_cols.issubset(split_df.columns):
        raise ValueError(
            f"Required columns are missing from the {split_label} set. Found: {split_df.columns.tolist()}"
        )

print(f"Successfully loaded and processed '{DATASET_NAME}'.")
print(f"   Training examples: {len(train_df)}")
print(f"   Test examples: {len(test_df)}")

Loading dataset 'RobbedoesHF/dutch-definitions' from Hugging Face Hub...
Successfully loaded and processed 'RobbedoesHF/dutch-definitions'.
   Training examples: 27880
   Test examples: 3450


In [3]:
class PromptSelector:
    """Semantic-plus-fuzzy selector for few-shot demonstrations.

    Training short definitions are embedded with a SentenceTransformer and stored in
    a FAISS inner-product index. Embeddings are L2-normalised, so the inner product is
    cosine similarity. Retrieved candidates are then re-ranked with a weighted mix of
    that semantic score and a fuzzy lexical score between lemmas.
    """

    def __init__(self, model_name='paraphrase-multilingual-MiniLM-L12-v2', device='cpu'):
        """Initializes the selector with a model and compute device.

        Args:
            model_name: SentenceTransformer model used to embed short definitions.
            device: Device the model runs on (e.g. 'cpu' or 'cuda').
        """
        self.model = SentenceTransformer(model_name, device=device)
        self.device = device
        self.index = None     # FAISS index, created by build_index()
        self.train_df = None  # training rows; row position == FAISS vector id

    def build_index(self, train_df: pd.DataFrame) -> None:
        """Encodes short definitions and builds a FAISS inner-product index.

        Args:
            train_df: Training rows; must contain 'short_definition' and 'lemma'.
                The index is reset so row positions match FAISS ids.
        """
        print("\nBuilding FAISS index from training data...")
        self.train_df = train_df.reset_index(drop=True)
        short_defs = self.train_df['short_definition'].tolist()

        embeddings = self.model.encode(
            short_defs,
            convert_to_numpy=True,
            show_progress_bar=True,
            normalize_embeddings=True,  # makes inner product == cosine similarity
            device=self.device
        )

        dim = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dim)
        self.index.add(embeddings.astype('float32'))
        print(f"CPU-based FAISS index built successfully with {self.index.ntotal} vectors.")

    def retrieve_candidates(self, lemma: str, short_def: str, *, k: int) -> pd.DataFrame:
        """
        Retrieves top-k candidates and calculates their individual semantic and lexical scores.
        This is the slow, computationally expensive part.

        Args:
            lemma: Lemma of the query item (used for the lexical score).
            short_def: Short definition of the query item (used for the semantic search).
            k: Number of nearest neighbours to request from the index.

        Returns:
            A copy of the matching training rows with two added columns:
            'semantic_score' (cosine similarity) and 'lexical_score'
            (rapidfuzz token_set_ratio between lemmas, scaled to [0, 1]).

        Raises:
            RuntimeError: if build_index() has not been called.
        """
        if self.index is None:
            raise RuntimeError("Index not built. Call build_index() first.")

        # 1. Retrieve top-k semantic neighbours from FAISS
        query_emb = self.model.encode(
            [short_def], convert_to_numpy=True, normalize_embeddings=True, device=self.device
        )
        scores, idx = self.index.search(query_emb.astype('float32'), k=k)

        # FAISS pads with id -1 when fewer than k vectors are available; drop those.
        scores, idx = scores[0], idx[0]
        valid_indices = idx != -1
        scores, idx = scores[valid_indices], idx[valid_indices]

        candidates = self.train_df.iloc[idx].copy()
        candidates['semantic_score'] = scores

        # 2. Calculate lexical (fuzzy) scores
        candidates['lexical_score'] = np.array(
            [fuzz.token_set_ratio(lemma, cand_lemma) / 100.0 for cand_lemma in candidates['lemma']]
        )

        return candidates

    def rerank_and_select(self, candidates_df: pd.DataFrame, *, n: int, alpha: float) -> pd.DataFrame:
        """
        Quickly re-ranks pre-computed candidates using a new alpha and selects the top-n.
        This is the fast part.

        The input frame is NOT modified, so the same Stage 1 candidates can be
        re-ranked safely with several alphas.

        Args:
            candidates_df: Output of retrieve_candidates().
            n: Maximum number of demonstrations to return.
            alpha: Weight of the semantic score; the lexical score gets (1 - alpha).

        Returns:
            Up to n rows with distinct lemmas, sorted by descending 'combined_score'.
        """
        # Calculate the combined score on a new frame (no in-place mutation of the caller's data)
        scored = candidates_df.assign(
            combined_score=alpha * candidates_df['semantic_score']
            + (1 - alpha) * candidates_df['lexical_score']
        )

        # Stable sort: ties keep retrieval (semantic-rank) order, so the selection is deterministic.
        # Keeping the first row per lemma therefore keeps its best-scoring sense.
        unique_candidates = scored.sort_values(
            by='combined_score', ascending=False, kind='stable'
        ).drop_duplicates('lemma', keep='first')

        return unique_candidates.head(n)

In [4]:
def build_prompt(test_item: pd.Series, demonstrations: pd.DataFrame) -> str:
    """Formats the final prompt, ordering demos from least to most similar.

    Args:
        test_item: Row with 'lemma' and 'short_definition' for the query.
        demonstrations: Few-shot rows with 'lemma', 'short_definition',
            'long_definition' and 'combined_score'. May be empty.

    Returns:
        The demonstrations and the query joined by '###' separators. The query's
        'Long Definition:' is left open for the model. With no demonstrations,
        the prompt is just the query (no leading separator).
    """
    # Least similar first, so the most similar demonstration sits right before the query
    demos_ordered = demonstrations.sort_values(by='combined_score', ascending=True, kind='stable')

    prompt_parts = [
        f"Lemma: {demo['lemma']}\n"
        f"Short Definition: {demo['short_definition']}\n"
        f"Long Definition: {demo['long_definition']}"
        for _, demo in demos_ordered.iterrows()
    ]

    query_part = (
        f"Lemma: {test_item['lemma']}\n"
        f"Short Definition: {test_item['short_definition']}\n"
        f"Long Definition:"
    )
    return "\n###\n".join(prompt_parts + [query_part])

In [5]:
# --- Parameters ---
K = 32  # neighbours pulled from the FAISS index per test item (candidate pool size)

# --- Initialization ---
selector = PromptSelector(device=device)
selector.build_index(train_df)

# --- STAGE 1: expensive retrieval + scoring, done once and reused for every alpha ---
print("\n--- Stage 1: Retrieving all candidates ---")
intermediate_candidates = [
    selector.retrieve_candidates(query_lemma, query_short_def, k=K)
    for query_lemma, query_short_def in tqdm(
        zip(test_df['lemma'], test_df['short_definition']),
        total=len(test_df),
        desc="Retrieving Candidates",
    )
]

print(f"\nRetrieved candidate sets for all {len(intermediate_candidates)} test items.")


Building FAISS index from training data...


Batches:   0%|          | 0/872 [00:00<?, ?it/s]

CPU-based FAISS index built successfully with 27880 vectors.

--- Stage 1: Retrieving all candidates ---


Retrieving Candidates:   0%|          | 0/3450 [00:00<?, ?it/s]


Retrieved candidate sets for all 3450 test items.


In [12]:
def generate_final_prompts(candidates_list: list, alpha_value: float, n_demos: int,
                           test_items=None, prompt_selector=None) -> list:
    """
    Generates structured JSON records for each test item using a given alpha.

    Args:
        candidates_list: Per-item candidate DataFrames from Stage 1, in the same
            positional order as the rows of the test set.
        alpha_value: Weight of the semantic score; the lexical score gets (1 - alpha).
        n_demos: Number of few-shot demonstrations to keep per test item.
        test_items: Test DataFrame to build records for; defaults to the global `test_df`.
        prompt_selector: Object providing `rerank_and_select`; defaults to the global `selector`.

    Returns:
        One dict per test item with 'lemma', 'short_definition',
        'gold_long_definition' and 'few_shot_example_1'..'few_shot_example_<n>',
        where example 1 has the highest combined score.

    Raises:
        ValueError: if the number of candidate sets differs from the number of test items.
    """
    if test_items is None:
        test_items = test_df
    if prompt_selector is None:
        prompt_selector = selector
    if len(candidates_list) != len(test_items):
        raise ValueError(
            f"Got {len(candidates_list)} candidate sets for {len(test_items)} test items; "
            "re-run Stage 1 so they line up."
        )

    print(f"\n--- Stage 2: Generating prompts with ALPHA = {alpha_value} ---")
    results = []

    # Pair rows with candidate sets positionally (Stage 1 built them in row order), so this
    # does not depend on the test frame having a 0..n-1 RangeIndex.
    rows_with_candidates = zip(test_items.iterrows(), candidates_list)
    for (_, test_row), candidates_df in tqdm(rows_with_candidates, total=len(test_items),
                                             desc=f"Building Prompts (alpha={alpha_value})"):
        # Perform the FAST re-ranking and selection
        selected_demos = prompt_selector.rerank_and_select(
            candidates_df,
            n=n_demos,
            alpha=alpha_value
        )

        # Start with the base information for the test item
        result_record = {
            'lemma': test_row['lemma'],
            'short_definition': test_row['short_definition'],
            'gold_long_definition': test_row['long_definition']
        }

        # example_1 is always the *most* similar demonstration
        demos_list = selected_demos.sort_values(
            'combined_score', ascending=False, kind='stable'
        ).to_dict('records')
        for rank, demo in enumerate(demos_list, start=1):
            result_record[f'few_shot_example_{rank}'] = demo

        results.append(result_record)

    print("\n✅ Successfully generated prompts.")
    return results

# --- Stage 2: re-rank the pre-computed Stage 1 candidates for several alphas ---
# Each run is fast because the expensive retrieval was done once in Stage 1.
ALPHAS = (0.8, 0.6, 0.3)  # weight on the semantic score; (1 - alpha) goes to the lexical score
N_DEMOS = 3               # few-shot demonstrations per test item

results_by_alpha = {
    alpha: generate_final_prompts(intermediate_candidates, alpha_value=alpha, n_demos=N_DEMOS)
    for alpha in ALPHAS
}

# Named handles used by the inspection and saving cells
results_for_alpha_0_8 = results_by_alpha[0.8]
results_for_alpha_0_6 = results_by_alpha[0.6]
results_for_alpha_0_3 = results_by_alpha[0.3]


--- Stage 2: Generating prompts with ALPHA = 0.8 ---


Building Prompts (alpha=0.8):   0%|          | 0/3450 [00:00<?, ?it/s]


✅ Successfully generated prompts.

--- Stage 2: Generating prompts with ALPHA = 0.6 ---


Building Prompts (alpha=0.6):   0%|          | 0/3450 [00:00<?, ?it/s]


✅ Successfully generated prompts.

--- Stage 2: Generating prompts with ALPHA = 0.3 ---


Building Prompts (alpha=0.3):   0%|          | 0/3450 [00:00<?, ?it/s]


✅ Successfully generated prompts.


In [13]:
entry = 10  # position of the test item to inspect in each result list

# Pretty-print the same test item's structured record for every alpha run
for alpha_shown, records_shown in (
    (0.8, results_for_alpha_0_8),
    (0.6, results_for_alpha_0_6),
    (0.3, results_for_alpha_0_3),
):
    print(f"\n--- Example of a structured record with alpha = {alpha_shown} ---")
    print(json.dumps(records_shown[entry], indent=2, ensure_ascii=False))


--- Example of a structured record with alpha = 0.8 ---
{
  "lemma": "aal",
  "short_definition": "paling aal als voedsel",
  "gold_long_definition": "lange, slangachtige vis met een slijmerige, gladde huid, kleine borstvinnen, een lange rugvin die de gehele staart omzoomt en tot de aarsvin reikt, kleine kieuwopeningen en een bovenkaak die korter is dan de onderkaak",
  "few_shot_example_1": {
    "lemma": "baars",
    "POS": "substantief",
    "MeaningNumber": "1.1",
    "LemmaID": 11700,
    "MeaningID": 193777,
    "long_definition": "baars als voedsel, als ingrediënt voor een gerecht",
    "short_definition": "baars als voedsel",
    "semantic_score": 0.9345700740814209,
    "lexical_score": 0.5,
    "combined_score": 0.8476560473442077
  },
  "few_shot_example_2": {
    "lemma": "schaap",
    "POS": "substantief",
    "MeaningNumber": "1.2",
    "LemmaID": 136651,
    "MeaningID": 216440,
    "long_definition": "vlees van schapen, gebruikt als voedsel",
    "short_definition": "s

In [14]:
def save_and_display_results(results_list: list, alpha_value: float, output_dir='.'):
    """
    Displays results in a DataFrame and saves them to a unique JSONL file.

    The output filename is based on the alpha value (e.g. 0.8 -> 'prompts_alpha_0_8.jsonl')
    to prevent overwriting the files of other alpha runs.

    Args:
        results_list: Structured records (dicts), one per test item.
        alpha_value: Alpha that produced the records; used in messages and the filename.
        output_dir: Directory for the JSONL file (created if missing). Defaults to the
            current working directory.

    Returns:
        The Path of the written file, or None when `results_list` is empty.
    """
    if not results_list:
        print(f"No results to process for alpha = {alpha_value}.")
        return None

    # Display the first few results for this alpha value
    results_df = pd.DataFrame(results_list)
    print(f"\n--- Generated Prompts for ALPHA = {alpha_value} (DataFrame) ---")
    display(results_df.head())

    # --- Save to JSONL with a dynamic filename ---
    filename = f'prompts_alpha_{str(alpha_value).replace(".", "_")}.jsonl'
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / filename

    with output_path.open('w', encoding='utf-8') as f:
        # One JSON object per line; the records are already JSON-ready dicts
        for record in results_list:
            json.dump(record, f, ensure_ascii=False)
            f.write('\n')

    print(f"\nAll prompts for alpha={alpha_value} saved to {output_path}")
    return output_path

In [15]:
# Save and display results for the first alpha value
save_and_display_results(results_for_alpha_0_8, alpha_value=0.8)

# Save and display results for the second alpha value
save_and_display_results(results_for_alpha_0_6, alpha_value=0.6)

# Save and display results for the third alpha value
save_and_display_results(results_for_alpha_0_3, alpha_value=0.3)


--- Generated Prompts for ALPHA = 0.8 (DataFrame) ---


Unnamed: 0,lemma,short_definition,gold_long_definition,few_shot_example_1,few_shot_example_2,few_shot_example_3
0,"1,5 metermaatschappij",maatschappij waarin fysieke afstand nodig is,maatschappij waarin mensen die niet tot hetzelfde huishouden behoren in de p...,"{'lemma': 'anderhalvemetermaatschappij', 'POS': 'substantief', 'MeaningNumbe...","{'lemma': '1,5 metersamenleving', 'POS': 'substantief', 'MeaningNumber': '1....","{'lemma': 'anderhalvemetersamenleving', 'POS': 'substantief', 'MeaningNumber..."
1,112-centrale,alarmcentrale,alarmcentrale die bereikbaar is onder het telefoonnummer 112,"{'lemma': '112-alarmcentrale', 'POS': 'substantief', 'MeaningNumber': '1.0',...","{'lemma': 'alarmkreet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma...","{'lemma': '112', 'POS': 'telwoord', 'MeaningNumber': '2.0', 'LemmaID': 240, ..."
2,200 eurobiljet,biljet van 200 euro,bankbiljet dat de waarde van 200 euro vertegenwoordigt,"{'lemma': '20 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Le...","{'lemma': '100 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'L...","{'lemma': '500 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'L..."
3,3D-monitor,monitor met 3D-beeld,monitor die driedimensionaal beeld weergeeft,"{'lemma': '3D-scanner', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma...","{'lemma': '3D-scan', 'POS': 'substantief', 'MeaningNumber': '1.0', 'LemmaID'...","{'lemma': '3D-techniek', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemm..."
4,45 kilometerwagen,autootje met bromfietsstatus,motorvoertuig met de officiële status van een bromfiets en het voorkomen van...,"{'lemma': '45 kilometerwagentje', 'POS': 'substantief', 'MeaningNumber': '1....","{'lemma': '45 kilometerauto', 'POS': 'substantief', 'MeaningNumber': '1.0', ...","{'lemma': '45km-wagen', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma..."



All prompts for alpha=0.8 saved to prompts_alpha_0_8.jsonl

--- Generated Prompts for ALPHA = 0.6 (DataFrame) ---


Unnamed: 0,lemma,short_definition,gold_long_definition,few_shot_example_1,few_shot_example_2,few_shot_example_3
0,"1,5 metermaatschappij",maatschappij waarin fysieke afstand nodig is,maatschappij waarin mensen die niet tot hetzelfde huishouden behoren in de p...,"{'lemma': 'anderhalvemetermaatschappij', 'POS': 'substantief', 'MeaningNumbe...","{'lemma': '1,5 metersamenleving', 'POS': 'substantief', 'MeaningNumber': '1....","{'lemma': 'anderhalvemetersamenleving', 'POS': 'substantief', 'MeaningNumber..."
1,112-centrale,alarmcentrale,alarmcentrale die bereikbaar is onder het telefoonnummer 112,"{'lemma': '112-alarmcentrale', 'POS': 'substantief', 'MeaningNumber': '1.0',...","{'lemma': '112', 'POS': 'telwoord', 'MeaningNumber': '2.0', 'LemmaID': 240, ...","{'lemma': 'alarmkreet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma..."
2,200 eurobiljet,biljet van 200 euro,bankbiljet dat de waarde van 200 euro vertegenwoordigt,"{'lemma': '20 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Le...","{'lemma': '100 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'L...","{'lemma': '500 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'L..."
3,3D-monitor,monitor met 3D-beeld,monitor die driedimensionaal beeld weergeeft,"{'lemma': '3D-scanner', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma...","{'lemma': '3D-scan', 'POS': 'substantief', 'MeaningNumber': '1.0', 'LemmaID'...","{'lemma': '3D-techniek', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemm..."
4,45 kilometerwagen,autootje met bromfietsstatus,motorvoertuig met de officiële status van een bromfiets en het voorkomen van...,"{'lemma': '45 kilometerwagentje', 'POS': 'substantief', 'MeaningNumber': '1....","{'lemma': '45 kilometerauto', 'POS': 'substantief', 'MeaningNumber': '1.0', ...","{'lemma': '45km-wagen', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma..."



All prompts for alpha=0.6 saved to prompts_alpha_0_6.jsonl

--- Generated Prompts for ALPHA = 0.3 (DataFrame) ---


Unnamed: 0,lemma,short_definition,gold_long_definition,few_shot_example_1,few_shot_example_2,few_shot_example_3
0,"1,5 metermaatschappij",maatschappij waarin fysieke afstand nodig is,maatschappij waarin mensen die niet tot hetzelfde huishouden behoren in de p...,"{'lemma': 'anderhalvemetermaatschappij', 'POS': 'substantief', 'MeaningNumbe...","{'lemma': '1,5 metersamenleving', 'POS': 'substantief', 'MeaningNumber': '1....","{'lemma': 'maatschappijgericht', 'POS': 'adjectief', 'MeaningNumber': '1.0',..."
1,112-centrale,alarmcentrale,alarmcentrale die bereikbaar is onder het telefoonnummer 112,"{'lemma': '112-alarmcentrale', 'POS': 'substantief', 'MeaningNumber': '1.0',...","{'lemma': 'controleur', 'POS': 'substantief', 'MeaningNumber': '1.2', 'Lemma...","{'lemma': '112', 'POS': 'telwoord', 'MeaningNumber': '2.0', 'LemmaID': 240, ..."
2,200 eurobiljet,biljet van 200 euro,bankbiljet dat de waarde van 200 euro vertegenwoordigt,"{'lemma': '20 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Le...","{'lemma': '100 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'L...","{'lemma': '500 eurobiljet', 'POS': 'substantief', 'MeaningNumber': '1.0', 'L..."
3,3D-monitor,monitor met 3D-beeld,monitor die driedimensionaal beeld weergeeft,"{'lemma': '3D-scanner', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma...","{'lemma': '3D-scan', 'POS': 'substantief', 'MeaningNumber': '1.0', 'LemmaID'...","{'lemma': '3D-techniek', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemm..."
4,45 kilometerwagen,autootje met bromfietsstatus,motorvoertuig met de officiële status van een bromfiets en het voorkomen van...,"{'lemma': '45 kilometerwagentje', 'POS': 'substantief', 'MeaningNumber': '1....","{'lemma': '45 kilometerauto', 'POS': 'substantief', 'MeaningNumber': '1.0', ...","{'lemma': '45km-wagen', 'POS': 'substantief', 'MeaningNumber': '1.0', 'Lemma..."



All prompts for alpha=0.3 saved to prompts_alpha_0_3.jsonl


In [None]:
# `results` only exists inside generate_final_prompts, so use an explicit run here.
results = results_for_alpha_0_8
results_df = pd.DataFrame(results)

# The records no longer carry a pre-rendered prompt; rebuild it from each record's
# few-shot examples so the 'generated_prompt' column is ready for inference.
results_df['generated_prompt'] = [
    build_prompt(
        pd.Series(record),
        pd.DataFrame([value for key, value in record.items() if key.startswith('few_shot_example_')]),
    )
    for record in results
]

print("--- Generated Prompts (DataFrame) ---")
display(results_df[['lemma', 'short_definition', 'generated_prompt']].head())

In [None]:
# --- Save to JSONL ---
# `results` only exists inside generate_final_prompts, so save an explicit run (alpha = 0.8).
results = results_for_alpha_0_8
output_path = Path('prompts_for_testing.jsonl')
with output_path.open('w', encoding='utf-8') as f:
    for record in results:
        json.dump(record, f, ensure_ascii=False)
        f.write('\n')

print(f"All prompts saved to {output_path}")
if results:  # guard: results[0] would raise on an empty run
    print("\nExample record from the file:")
    print(json.dumps(results[0], indent=2, ensure_ascii=False))