Here is the sketch we want to try to implement:
- For each document we need to compute: a semantic embedding using sbert (all-distilroberta-v1), a style embedding using luar-mud, and a genre embedding using xlm-roberta
- Construct each batch such that:
	1. Each anchor-positive pair has semantic similarity below threshold-x (0.4)
	2. Each anchor-positive pair has genre similarity below threshold-z (0.7)
	3. Inside each batch, XX% of the authors have style similarity higher than threshold-y (0.3)
- Ideas on how to construct these batches:
  - 20% of the instances in the batch are from condition 1
  - 20% of the instances in the batch are from condition 2
  - 20% of the instances would follow condition 3
  - 30% randomly selected instances
  OR
  Each batch has a few examples that are hard positives in terms of semantics and genre, and a few that contain hard negatives.

In [None]:
import pandas as pd
# Read the training split from a JSON Lines file (one record per line).
sadiri_luarsbert_emb_train_path = "/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data/train_sadiri_processed_with_luarsbertembeddings_wo_ao3_filtered.jsonl"
df_luarsbert_emb_train = pd.read_json(
    path_or_buf=sadiri_luarsbert_emb_train_path,
    lines=True,
)

In [None]:
# Preview the first five rows of the loaded training frame.
df_luarsbert_emb_train.head(n=5)

In [None]:
import numpy as np
from tqdm import tqdm

In [None]:
# -----------------------------
# Build candidate anchor/positive pairs via a self-join on author
# -----------------------------
# Each document is paired with every document sharing its authorID, so
# self-pairs and both orderings of a pair are still present at this point.
df_anchorpos_pairs_raw = pd.merge(
    df_luarsbert_emb_train,
    df_luarsbert_emb_train,
    how="outer",
    on="authorID",
    suffixes=("_anchor", "_positive"),
)
df_anchorpos_pairs_raw

In [None]:
# -----------------------------
# Drop self-pairs and pairs whose two documents share a genre label
# -----------------------------
# A row survives only when BOTH the document IDs and the genre labels differ
# between the anchor side and the positive side.
df_anchorpos_pairs_prep = df_anchorpos_pairs_raw[
    df_anchorpos_pairs_raw["documentID_anchor"].ne(
        df_anchorpos_pairs_raw["documentID_positive"]
    )
    & df_anchorpos_pairs_raw["doc_xrbmtgc_genre_anchor"].ne(
        df_anchorpos_pairs_raw["doc_xrbmtgc_genre_positive"]
    )
].copy()

In [None]:
# -----------------------------
# Keep a single row per unordered document pair within an author
# -----------------------------
# Putting the two document IDs in (min, max) order makes (A, B) and (B, A)
# produce the same key, so only the first occurrence of each pair is kept.
doc_min = np.minimum(
    df_anchorpos_pairs_prep["documentID_anchor"],
    df_anchorpos_pairs_prep["documentID_positive"],
)
doc_max = np.maximum(
    df_anchorpos_pairs_prep["documentID_anchor"],
    df_anchorpos_pairs_prep["documentID_positive"],
)
# Temporary helper column; removed again right after de-duplication.
df_anchorpos_pairs_prep["pair_key"] = (
    df_anchorpos_pairs_prep["authorID"].astype(str)
    + "__" + doc_min.astype(str)
    + "__" + doc_max.astype(str)
)
df_anchorpos_pairs_prep = df_anchorpos_pairs_prep.drop_duplicates(
    subset=["pair_key"], keep="first"
).drop(columns=["pair_key"])


In [None]:
# -----------------------------
# Compute Cosine Similarity (Vectorized)
# -----------------------------
def batched_cosine_similarity(a_embeddings, b_embeddings, batch_size=10000):
    """Return the row-wise cosine similarity between two embedding matrices.

    Element ``i`` of the result is the cosine similarity between
    ``a_embeddings[i]`` and ``b_embeddings[i]``. Rows are processed
    ``batch_size`` at a time so the temporary element-wise product stays
    bounded in size.

    Args:
        a_embeddings: Array-like of shape ``(n, d)``.
        b_embeddings: Array-like with exactly the same shape as
            ``a_embeddings``.
        batch_size: Number of rows handled per iteration; must be >= 1.

    Returns:
        A 1-D ``numpy.ndarray`` with ``n`` scores. A row whose norm is zero
        in either input yields ``nan`` (0 / 0), exactly as plain NumPy
        division does. Empty input yields an empty array.

    Raises:
        ValueError: If the two inputs differ in shape, or if ``batch_size``
            is smaller than 1.
    """
    a_embeddings = np.asarray(a_embeddings)
    b_embeddings = np.asarray(b_embeddings)

    # Reject mismatched inputs up front: a (1, d) matrix would otherwise be
    # broadcast silently against every row of the other one.
    if a_embeddings.shape != b_embeddings.shape:
        raise ValueError(
            "a_embeddings and b_embeddings must have the same shape, got "
            f"{a_embeddings.shape} and {b_embeddings.shape}"
        )
    # A negative step makes range() empty, which used to return an empty
    # result without any error.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_rows = len(a_embeddings)
    if n_rows == 0:
        return np.array([])

    # Collect one array per batch and concatenate once, instead of unpacking
    # every score into a Python list element.
    chunks = []
    for start in tqdm(range(0, n_rows, batch_size), desc="Cosine Similarity"):
        a_batch = a_embeddings[start:start + batch_size]
        b_batch = b_embeddings[start:start + batch_size]
        dots = np.sum(a_batch * b_batch, axis=1)
        norms = np.linalg.norm(a_batch, axis=1) * np.linalg.norm(b_batch, axis=1)
        chunks.append(dots / norms)
    return np.concatenate(chunks)

In [None]:
# Style similarity per pair from the LUAR-MUD embedding columns.
# np.stack assumes every cell holds a vector of the same length — TODO confirm.
anchor_embeddings_luar = np.stack(
    df_anchorpos_pairs_prep["doc_luarmud_embedding_anchor"].to_numpy()
)
positive_embeddings_luar = np.stack(
    df_anchorpos_pairs_prep["doc_luarmud_embedding_positive"].to_numpy()
)
df_anchorpos_pairs_prep["luarmud_embedding_similarity_score"] = batched_cosine_similarity(
    a_embeddings=anchor_embeddings_luar,
    b_embeddings=positive_embeddings_luar,
    batch_size=10000,
)

In [None]:
# Semantic similarity per pair from the `doc_sbertamllv2_embedding_*` columns.
# np.stack assumes every cell holds a vector of the same length — TODO confirm.
anchor_embeddings_sbert = np.stack(
    df_anchorpos_pairs_prep["doc_sbertamllv2_embedding_anchor"].to_numpy()
)
positive_embeddings_sbert = np.stack(
    df_anchorpos_pairs_prep["doc_sbertamllv2_embedding_positive"].to_numpy()
)
df_anchorpos_pairs_prep["sbertamllv2_embedding_similarity_score"] = batched_cosine_similarity(
    a_embeddings=anchor_embeddings_sbert,
    b_embeddings=positive_embeddings_sbert,
    batch_size=10000,
)

In [None]:
# Preview the first five anchor/positive pairs with their similarity scores.
df_anchorpos_pairs_prep.head(n=5)

In [None]:
# Display the full pair table (notebook rich output) for a visual sanity check.
df_anchorpos_pairs_prep

In [None]:
# Persist the scored pairs as JSON Lines (one record per line); the path is
# relative to the notebook's working directory.
df_anchorpos_pairs_prep.to_json(
    "../data/ta2_jan_2025_trian_data/anchor_pos_train_sadiri_luarmudsbertamllv2.jsonl",
    orient="records",
    lines=True,
)