Libraries

In [None]:
import sys
import os
import numpy as np
import pandas as pd
import json
from pathlib import Path

from concurrent.futures import ThreadPoolExecutor

from sklearn.metrics.pairwise import cosine_similarity

# add path 
sys.path.append(os.path.abspath("../0. Helpers"))
sys.path.append(os.path.abspath("../2. Data Processing/_dataset_entities"))

from datasets import load_dataset, load_from_disk
from datasetProcessing import tokens_to_sentence, tokens_to_entities, join_datasets, recursive_fix

Process whole dataset

In [None]:
# Dataset selector. The next cell recognises "lener", "neuralshift", "ener",
# "multinerd_en" and "multinerd_pt"; any other value loads CrossNER
# (presumably naming a CrossNER domain, e.g. "music" — confirm).
topic: str = "music"

In [None]:
# Load the dataset and its label inventory for the chosen topic.
# Every branch defines `entity_names` / `entity_names_parsed` (from the topic's
# entity module), `dataset` and `lang`; an unrecognised topic falls through to
# CrossNER.
if topic == "lener":
    from entities_leNER import entity_names, entity_names_parsed
    dataset = load_from_disk("...")
    lang = "portuguese"

elif topic == "neuralshift":
    from entities_neuralshift import entity_names, entity_names_parsed
    dataset = load_from_disk("...")
    lang = "portuguese"

elif topic == "ener":
    from entities_eNER import entity_names, entity_names_parsed
    dataset = load_from_disk("...")
    lang = "english"

elif topic == "multinerd_en":
    from entities_multinerd_en import entity_names, entity_names_parsed
    dataset = load_from_disk("...")
    lang = "english"

elif topic == "multinerd_pt":
    from entities_multinerd_pt import entity_names, entity_names_parsed
    dataset = load_from_disk("...")
    lang = "portuguese"

else:
    from entities_crossNER import entity_names, entity_names_parsed
    dataset = load_dataset("...")
    lang = "english"

# train / test splits
train_data = dataset["train"]
test_data = dataset["test"]

# Label indices whose tag opens an entity span ("B-" or "U-" prefix)
start_of_entity_indices = [
    i for i, name in enumerate(entity_names) if name.startswith(("B-", "U-"))
]

# Label index -> bare entity type, e.g. "B-PER" -> "PER".
# Split on the FIRST hyphen only so hyphenated type names stay intact, and map
# every "O" label to "O" wherever it sits in the label list instead of assuming
# it is index 0.
# NOTE(review): assumes the outside tag is spelled exactly "O" in entity_names.
entity_index_to_name = {
    i: "O" if name == "O" else name.split("-", 1)[1]
    for i, name in enumerate(entity_names)
}

Dataset split sizes (every train and test instance is processed)

In [None]:
# Number of examples in each split; these bound the train-loading and test loops.
train_len, test_len = (len(dataset[split]) for split in ("train", "test"))

Select the top-n in-context demonstrations per test sentence with Maximal Marginal Relevance (MMR)

In [None]:
# MMR trade-off weight used in the selection loop:
#   score = lambda_mmr * query_similarity - (1 - lambda_mmr) * redundancy
# Higher values favour relevance to the query; lower values favour diversity.
lambda_mmr: float = 1 / 2

In [None]:
# Demonstration-set sizes to generate; each size gets its own output folder.
all_n = [5]

# Create in_context/<topic>/test/mmr<n>/qwen for every size (no-op if present)
for n in all_n:
    Path(f"in_context/{topic}/test/mmr{n}/qwen").mkdir(parents=True, exist_ok=True)

In [None]:
# Load train embeddings and sentences
train_embeddings, train_files = [], []

for train_index in range(train_len):

    db_file_path = f"classification/{topic}/train/data/{train_index}.json"
    db_file = json.load(open(db_file_path, "r", encoding="utf-8"))

    # Get embeddings from folder
    with open(f"embeddings/{topic}/train/{train_index}.json", "r", encoding="utf-8") as f:
        train_embedding_data = json.load(f)
        db_embedding_qwen = train_embedding_data['embedding_qwen']

    # Get the entities
    true_entities = tokens_to_entities(db_file['tokens'], db_file['ner_tags'], entity_names_parsed, start_of_entity_indices, entity_index_to_name)

    # Append to train objects
    train_embeddings.append(db_embedding_qwen)

    train_files.append({
        'index': train_index,
        'embedding_qwen': db_embedding_qwen,
        "sentence": db_file['sentence'],
        "true_entities": [entity.to_dict() for entity in true_entities],
    })

train_embeddings = np.array(train_embeddings)
print(f"✅ Cached {len(train_embeddings)} training embeddings")

# Precompute cosine similarity matrix once
train_sim_matrix = cosine_similarity(train_embeddings)
print(f"✅ Computed {train_sim_matrix.shape[0]} x {train_sim_matrix.shape[1]} training embeddings")

In [None]:
# Loop test
for test_index in range(test_len):

    print(f"\rtest {test_index+1}/{test_len}", end='', flush=True)

    # check if mmr already computed
    mmr_n_done = []
    for n in all_n:
        output_file = f"in_context/{topic}/test/mmr{str(n)}/qwen/{test_index}.txt"
        mmr_n_done.append(os.path.exists(output_file))

    if all(mmr_n_done):
        print(f"MMR for test example {test_index} already computed, skipping...")
        continue

    test_example = dataset["test"][test_index]
    test_sentence = tokens_to_sentence(test_example['tokens'])

    # Get test embeddings from folder
    with open(f"embeddings/{topic}/test/{test_index}.json", "r", encoding="utf-8") as f:
        test_embedding_data = json.load(f)
        test_embedding_qwen = test_embedding_data['embedding_qwen']

    # Calculate cosine similarity of documents with the query
    query_similarities = cosine_similarity([test_embedding_qwen], train_embeddings)[0]

    # Initialize variables
    selected_indices = []
    remaining_indices = np.arange(len(train_embeddings))

    # MMR selection process (select top n)
    for _ in range(max(all_n)):

        if not selected_indices:
            mmr_scores = query_similarities[remaining_indices]
        else:
            diversity_scores = np.max(train_sim_matrix[np.ix_(remaining_indices, selected_indices)], axis=1)
            mmr_scores = (lambda_mmr * query_similarities[remaining_indices] - (1 - lambda_mmr) * diversity_scores)

        best_idx_local = np.argmax(mmr_scores)
        selected_indices.append(remaining_indices[best_idx_local])
        remaining_indices = np.delete(remaining_indices, best_idx_local)

        # mmr_scores = []
        
        # for i in remaining_indices:

        #     # Calculate diversity term
        #     diversity_score = max(
        #         cosine_similarity([train_embeddings[i]], [train_embeddings[j]])[0][0]
        #         for j in selected_indices
        #     ) if selected_indices else 0
            
        #     # MMR formula
        #     mmr_score = lambda_mmr * query_similarities[i] - (1 - lambda_mmr) * diversity_score
        #     mmr_scores.append((i, mmr_score))
        
        # # Select instance with highest MMR score
        # best_train_instance = max(mmr_scores, key = lambda x: x[1])
        # selected_indices.append(best_train_instance[0])
        # remaining_indices.remove(best_train_instance[0])

    for n in all_n:

        qwen_example_txt = ""
        for i in range(n):
            idx = selected_indices[i]
            db_file = train_files[idx]
            qwen_example_txt += f"Example #{i+1}: {db_file['sentence']}\n"
            qwen_example_txt += f"Expected output: 'entities: {db_file['true_entities']}'\n\n"

        output_file = f"in_context/{topic}/test/mmr{str(n)}/qwen/{test_index}.txt"
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(qwen_example_txt)