In [1]:
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import torch


data_path = Path("/home/yazici/playground/new-prompts/output/gpt-4o-2024-08-06_event2_newest_report_20241107-035313.json")
anchor_file_path = Path("/mnt/datasets/dop-position-mining/wiki-anchor/anchor_target_counts.csv")
eval_output_path = Path("/home/yazici/playground/new-prompts/output/eval_output.json")
eval_output_path_new = Path("/home/yazici/playground/new-prompts/output/eval_output_new.json")
# read eval_output.json (if it exists, otherwise create it)
device = "cuda" if torch.cuda.is_available() else "cpu"

df = pd.read_json(data_path)
df_report = df.explode("positions").reset_index(drop=True)
# Normalize the 'positions' field into a separate dataframe
df_positions = pd.json_normalize(df_report["positions"])
# drop rows that have the targets as empty lists
df_positions = df_positions[df_positions["targets"].apply(len) > 0]
# --- Stakeholder Clustering ---
print(
    f"{len(df_positions['stakeholder'].unique())} stakeholders"
    " before clustering..."
)
stakeholders = df_positions["stakeholder"].tolist()
stakeholders = [stakeholder.lower().strip() for stakeholder in stakeholders]
# does df_max_views.parquet exist?
anchor_file_path_dir = Path(anchor_file_path).parent
if (anchor_file_path_dir / "df_max_views.parquet").exists():
    print("Anchor file found.")
    df_max_views = pd.read_parquet(anchor_file_path_dir / "df_max_views.parquet")
else:
    raise FileNotFoundError("Anchor file not found.")


3836 stakeholders before clustering...
Anchor file found.


In [2]:
# create all possible combinations of the following items:
# 1. ("hg", "arkohut/jina-embeddings-v3")
# 2. ("st", "all-mpnet-base-v2")
# 3. ("hg", "Alibaba-NLP/gte-multilingual-base")
# 4. ("hg", "dunzhang/stella_en_1.5B_v5"),
# 5. ("hg", "intfloat/multilingual-e5-large-instruct"),
# 6. ("st", "paraphrase-multilingual-mpnet-base-v2")

import itertools

# Define the list of items
items = [
    ("hg", "arkohut/jina-embeddings-v3"),
    ("st", "all-mpnet-base-v2"),
    ("st", "Alibaba-NLP/gte-multilingual-base"),
    ("st", "dunzhang/stella_en_1.5B_v5"),
    ("st", "intfloat/multilingual-e5-large-instruct"),
    ("st", "paraphrase-multilingual-mpnet-base-v2")
]

# Generate all combinations of sizes 1 to 6
all_combinations = []
for r in range(3, len(items) + 1):
    combinations = itertools.combinations(items, r)
    all_combinations.extend(combinations)

# Print the combinations
for combination in all_combinations:
    print(combination)

# If you want the total count of combinations
print(f"Total number of combinations: {len(all_combinations)}")

(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'all-mpnet-base-v2'), ('st', 'Alibaba-NLP/gte-multilingual-base'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'all-mpnet-base-v2'), ('st', 'dunzhang/stella_en_1.5B_v5'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'all-mpnet-base-v2'), ('st', 'intfloat/multilingual-e5-large-instruct'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'all-mpnet-base-v2'), ('st', 'paraphrase-multilingual-mpnet-base-v2'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'Alibaba-NLP/gte-multilingual-base'), ('st', 'dunzhang/stella_en_1.5B_v5'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'Alibaba-NLP/gte-multilingual-base'), ('st', 'intfloat/multilingual-e5-large-instruct'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'Alibaba-NLP/gte-multilingual-base'), ('st', 'paraphrase-multilingual-mpnet-base-v2'))
(('hg', 'arkohut/jina-embeddings-v3'), ('st', 'dunzhang/stella_en_1.5B_v5'), ('st', 'intfloat/multilingual-e5-large-instruct'))
(('hg', 'arkohut/jina-embed

In [3]:
def get_param_combinations(params):
    keys, values = zip(*params.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
    return combinations

In [4]:
search_grid = {
    "threshold": [0.05, 0.1, 0.15, 0.2, 0.25, 0.3],
    "model_names": all_combinations,
    "voting": ["all", "majority", "any"]
}

combinations = get_param_combinations(search_grid)
print(f"Total number of combinations: {len(combinations)}")

Total number of combinations: 756


In [5]:
combinations[217]

{'threshold': 0.1,
 'model_names': (('st', 'all-mpnet-base-v2'),
  ('st', 'Alibaba-NLP/gte-multilingual-base'),
  ('st', 'dunzhang/stella_en_1.5B_v5'),
  ('st', 'intfloat/multilingual-e5-large-instruct')),
 'voting': 'majority'}

In [6]:
model_names = [
    ("st", "arkohut/jina-embeddings-v3", "cuda:0"), ("st", "all-mpnet-base-v2", "cuda:1")
]

In [7]:
import post_processing_multimodel

model_names = combinations[217]["model_names"]
threshold = combinations[217]["threshold"]
model_names

  from tqdm.autonotebook import tqdm, trange


(('st', 'all-mpnet-base-v2'),
 ('st', 'Alibaba-NLP/gte-multilingual-base'),
 ('st', 'dunzhang/stella_en_1.5B_v5'),
 ('st', 'intfloat/multilingual-e5-large-instruct'))

In [8]:
models = post_processing_multimodel.get_models(model_names)

Some weights of the model checkpoint at Alibaba-NLP/gte-multilingual-base were not used when initializing NewModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
df_embeddings = []
for model_name in model_names:
    print(f"Loading embeds for model: {model_name}")
    df_embeddings.append(torch.load(post_processing_multimodel.device_to_embed_map[model_name[1]]))

Loading embeds for model: ('st', 'all-mpnet-base-v2')


  df_embeddings.append(torch.load(post_processing_multimodel.device_to_embed_map[model_name[1]]))


Loading embeds for model: ('st', 'Alibaba-NLP/gte-multilingual-base')
Loading embeds for model: ('st', 'dunzhang/stella_en_1.5B_v5')
Loading embeds for model: ('st', 'intfloat/multilingual-e5-large-instruct')


In [21]:
from importlib import reload
reload(post_processing_multimodel)

<module 'post_processing_multimodel' from '/home/yazici/playground/new-prompts/post_processing_multimodel.py'>

In [22]:
models[2]

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: Qwen2Model 
  (1): Pooling({'word_embedding_dimension': 1536, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Dense({'in_features': 1536, 'out_features': 1024, 'bias': True, 'activation_function': 'torch.nn.modules.linear.Identity'})
)

In [27]:
temp_file_path = Path("/home/yazici/playground/new-prompts/temp-files")

In [23]:
new_stakeholders, stakeholder_index_to_wiki_id, stakeholder_replacement = post_processing_multimodel.wiki_anchor(
    stakeholders=stakeholders,
    df_embeddings=df_embeddings,
    device="cuda",
    df_max_views=df_max_views,
    models=models,
    voting=combinations[217]["voting"],
    output_dir=temp_file_path,
    threshold=threshold,
    clustering_method="fast",
    event_name="event2_multimodel",
)


In else | Encoding stakeholders using sentence-transformers/all-mpnet-base-v2


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In else | Encoding stakeholders using Alibaba-NLP/gte-multilingual-base


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In If | Encoding stakeholders using dunzhang/stella_en_1.5B_v5


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In Elif | Encoding stakeholders using intfloat/multilingual-e5-large-instruct


Batches:   0%|          | 0/120 [00:00<?, ?it/s]

Query chunks: 100%|██████████| 39/39 [02:37<00:00,  4.04s/it]
Query chunks: 100%|██████████| 39/39 [02:53<00:00,  4.44s/it]
Query chunks: 100%|██████████| 39/39 [04:26<00:00,  6.82s/it]
Query chunks: 100%|██████████| 39/39 [04:13<00:00,  6.50s/it]


Hit count: 17 | Hit percentage before wiki: 0.12%


Fetching wiki info: 100%|██████████| 7/7 [00:01<00:00,  3.56it/s]


Length of wiki corpus: 207
Length of missing stakeholders: 3811

In else | Encoding stakeholders using sentence-transformers/all-mpnet-base-v2


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In else | Encoding stakeholders using Alibaba-NLP/gte-multilingual-base


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In If | Encoding stakeholders using dunzhang/stella_en_1.5B_v5


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In Elif | Encoding stakeholders using intfloat/multilingual-e5-large-instruct


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In else | Encoding stakeholders using sentence-transformers/all-mpnet-base-v2


Batches:   0%|          | 0/7 [00:00<?, ?it/s]


In else | Encoding stakeholders using Alibaba-NLP/gte-multilingual-base


Batches:   0%|          | 0/7 [00:00<?, ?it/s]


In If | Encoding stakeholders using dunzhang/stella_en_1.5B_v5


Batches:   0%|          | 0/7 [00:00<?, ?it/s]


In Elif | Encoding stakeholders using intfloat/multilingual-e5-large-instruct


Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Query chunks: 100%|██████████| 39/39 [00:00<00:00, 687.74it/s]
Query chunks: 100%|██████████| 39/39 [00:00<00:00, 1074.59it/s]
Query chunks: 100%|██████████| 39/39 [00:00<00:00, 787.42it/s]
Query chunks: 100%|██████████| 39/39 [00:00<00:00, 1182.07it/s]


Hit count: 41 | Hit percentage after wiki: 0.29%


Augmenting stakeholders: 100%|██████████| 3819/3819 [00:00<00:00, 6772.84it/s]



In else | Encoding stakeholders using sentence-transformers/all-mpnet-base-v2


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In else | Encoding stakeholders using Alibaba-NLP/gte-multilingual-base


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In If | Encoding stakeholders using dunzhang/stella_en_1.5B_v5


Batches:   0%|          | 0/120 [00:00<?, ?it/s]


In Elif | Encoding stakeholders using intfloat/multilingual-e5-large-instruct


Batches:   0%|          | 0/120 [00:00<?, ?it/s]

Clustering stakeholders/targets threshold 0.1...
Fast clustering start


Finding clusters: 100%|██████████| 4/4 [00:00<00:00, 10.31it/s]


Clustering done after 0.40 sec
Hit count: 41 | Hit percentage after clustering: 0.29%


In [28]:
import pickle

# read all_results_eventname.json
with open(temp_file_path / "all_results_event2_multimodel.pkl", "rb") as f:
    all_results = pickle.load(f)

In [33]:
model_names

(('st', 'all-mpnet-base-v2'),
 ('st', 'Alibaba-NLP/gte-multilingual-base'),
 ('st', 'dunzhang/stella_en_1.5B_v5'),
 ('st', 'intfloat/multilingual-e5-large-instruct'))

In [32]:
all_results["zelenskyy"]

[{'score': 0.9107993245124817,
  'corpus_id': 4671886,
  'anchor_text': 'zelensky',
  'model_idx': 0},
 {'score': 0.8802839517593384,
  'corpus_id': 4671876,
  'anchor_text': 'zelenodolsky',
  'model_idx': 0},
 {'score': 0.8370147347450256,
  'corpus_id': 2479776,
  'anchor_text': 'lev zeleny',
  'model_idx': 0},
 {'score': 0.8345829248428345,
  'corpus_id': 4671855,
  'anchor_text': 'zelenchukskaya',
  'model_idx': 0},
 {'score': 0.8317021131515503,
  'corpus_id': 4667176,
  'anchor_text': 'zalischyky',
  'model_idx': 0},
 {'score': 0.9817342162132263,
  'corpus_id': 4671886,
  'anchor_text': 'zelensky',
  'model_idx': 1},
 {'score': 0.9017746448516846,
  'corpus_id': 4671852,
  'anchor_text': 'zelenay',
  'model_idx': 1},
 {'score': 0.8885848522186279,
  'corpus_id': 4671853,
  'anchor_text': 'zelenaši',
  'model_idx': 1},
 {'score': 0.8780649304389954,
  'corpus_id': 4470569,
  'anchor_text': 'vladimir zelensky',
  'model_idx': 1},
 {'score': 0.8705978393554688,
  'corpus_id': 46718

In [24]:
import json
from collections import defaultdict
stakeholder_eval_set_answers = ""
with open("stakeholder_eval_set_answers.txt", "r") as f:
    for line in f:
        stakeholder_eval_set_answers += line.strip()

stakeholder_eval_set_answers = json.loads(stakeholder_eval_set_answers)

# Positive samples (pairs of items within the same list)
positive_samples = []
for sublist in stakeholder_eval_set_answers:
    for i in range(len(sublist)):
        for j in range(i + 1, len(sublist)):
            positive_samples.append((sublist[i], sublist[j]))

# Negative samples (pairs of items from different sublists)
negative_samples = []
for i in range(len(stakeholder_eval_set_answers)):
    for j in range(i + 1, len(stakeholder_eval_set_answers)):
        # Create all possible pairs between sublist[i] and sublist[j]
        for element1 in stakeholder_eval_set_answers[i]:
            for element2 in stakeholder_eval_set_answers[j]:
                negative_samples.append((element1, element2))

print(f"Number of positive samples: {len(positive_samples)}")
print(f"Number of negative samples: {len(negative_samples)}")

Number of positive samples: 398
Number of negative samples: 41797


In [25]:
stakeholder_replacement_grouped = defaultdict(list)

for k,v in stakeholder_replacement.items():
    stakeholder_replacement_grouped[v].append(k)

stakeholder_clusters_final = [
    [stakeholder for stakeholder in cluster] for cluster in stakeholder_replacement_grouped.values()
]

positive_results = []
for sublist in stakeholder_clusters_final:
    for i in range(len(sublist)):
        for j in range(i + 1, len(sublist)):
            positive_results.append((sublist[i], sublist[j]))

positive_samples_set = set(positive_samples)
negative_samples_set = set(negative_samples)

# Convert the samples to sorted tuples (to handle unordered pairs)
positive_samples_set = {tuple(sorted(sample)) for sample in positive_samples}
negative_samples_set = {tuple(sorted(sample)) for sample in negative_samples}

# Calculate true positives (TP): positive samples that exist in positive_results
true_positives_results = [sample for sample in positive_samples_set if tuple(sorted(sample)) in positive_results]

# Calculate false negatives (FN): positive samples that do not exist in positive_results
false_negatives_results = [sample for sample in positive_samples_set if tuple(sorted(sample)) not in positive_results]

# Calculate false positives (FP): negative samples that exist in positive_results
false_positives_results = [sample for sample in negative_samples_set if tuple(sorted(sample)) in positive_results]

# Calculate true negatives (TN): negative samples that do not exist in positive_results
true_negatives_results = [sample for sample in negative_samples_set if tuple(sorted(sample)) not in positive_results]


true_positives = len(true_positives_results)
false_negatives = len(false_negatives_results)
false_positives = len(false_positives_results)
true_negatives = len(true_negatives_results)

# Output the results
print("True Positives (TP):", true_positives)
print("False Negatives (FN):", false_negatives)
print("False Positives (FP):", false_positives)
print("True Negatives (TN):", true_negatives)

# Calculate the metrics based on TP, FP, TN, FN
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) != 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) != 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
accuracy = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives)
fpr = false_positives / (false_positives + true_negatives) if (false_positives + true_negatives) != 0 else 0
specificity = true_negatives / (true_negatives + false_positives) if (true_negatives + false_positives) != 0 else 0

# Output the results
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"False Positive Rate (FPR): {fpr:.4f}")
print(f"Specificity: {specificity:.4f}")

True Positives (TP): 153
False Negatives (FN): 245
False Positives (FP): 6
True Negatives (TN): 41515
Precision: 0.9623
Recall: 0.3844
F1 Score: 0.5494
Accuracy: 0.9940
False Positive Rate (FPR): 0.0001
Specificity: 0.9999


In [26]:
false_positives_results

[('armed forces of ukraine', "ukraine's defense forces"),
 ("ukraine's armed forces", "ukraine's defense forces"),
 ('ukraine military', "ukraine's defense forces"),
 ("ukraine's defense forces", 'ukrainian troops'),
 ('defense forces of ukraine', "ukraine's defense forces"),
 ('alexei shevtsov', 'dmitry shevtsov')]