In [1]:
# for every visited = False hard positive pair: X(author Xa, anchor Xda, positive Xdp)
#     query -> Xda [anchor of pair X]
#     mark -> pair X as visited = True
#     candidates -> pairs author != Xa & visited = False
#     faiss search (query, candidates) -> search_results
#     drop duplicates by authors in search_results -> unique_filtered_search_results
#     negatives -> random select batch_size - 1 from unique_filtered_search_results
#     mark -> negatives pairs as visited = True
#     mark -> batch for X & negatives as batch_num

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import faiss
import random
import json
from tqdm import tqdm

def load_jsonl_into_dataframe(jsonl_path):
    """
    Reads a single .jsonl file (or you could adapt to read multiple).
    Returns a pandas DataFrame.
    """
    df = pd.read_json(jsonl_path, orient="records", lines=True)
    return df

def prepare_dataframe(
    df, 
    embedding_col="doc_luarmud_embedding_anchor", 
    new_embedding_col="doc_luarmud_embedding_anchor_nparr"
):
    """
    1) Checks that 'documentID_anchor' exists (unique ID).
    2) Extracts 'authorID' from 'authorIDs_anchor' (first element).
    3) Converts anchor embeddings to float32 arrays (new_embedding_col).
    4) Drops rows missing embeddings.
    """
    df = df.copy()

    # Ensure 'documentID_anchor' is present
    if "documentID_anchor" not in df.columns:
        raise ValueError("Data must contain a 'documentID_anchor' column for unique identification.")

    # If you always have 'authorIDs_anchor' as a list with at least 1 ID,
    # create a single 'authorID' column:
    if "authorID" not in df.columns:
        raise ValueError("Data must contain 'authorID' for extracting the single authorID.")

    # Convert anchor embeddings to np.float32
    if embedding_col not in df.columns:
        raise ValueError(f"Column '{embedding_col}' not found. Check your data.")

    def to_array(x):
        return np.array(x, dtype=np.float32) if isinstance(x, list) else None

    df[new_embedding_col] = df[embedding_col].apply(to_array)
    
    # Drop rows missing anchor embeddings
    df = df.dropna(subset=[new_embedding_col]).reset_index(drop=True)

    return df

def build_faiss_index(df, embedding_col="doc_luarmud_embedding_anchor_nparr", metric="l2"):
    """
    Builds a FAISS index over df[embedding_col].
    Returns (index_id_map, id_map).
    """
    embeddings = np.vstack(df[embedding_col].values)  # shape: (N, dim)
    dim = embeddings.shape[1]

    if metric == "l2":
        index = faiss.IndexFlatL2(dim)
    elif metric == "ip":
        index = faiss.IndexFlatIP(dim)
    else:
        raise ValueError("Invalid metric. Use 'l2' or 'ip'.")

    # ID map: 0..N-1 -> row index
    id_map = np.arange(len(df), dtype=np.int64)

    index_id_map = faiss.IndexIDMap2(index)
    index_id_map.add_with_ids(embeddings, id_map)
    return index_id_map, id_map

def create_batches_iterative(
    df, 
    faiss_index, 
    batch_size=32, 
    embedding_col="doc_luarmud_embedding_anchor_nparr",
    author_col="authorID",  
    random_seed=42,
    out_jsonl_path="iterative_batches_output.jsonl"
):
    """
    Iteratively forms batches using anchor embeddings:
      1) Pick an unvisited seed
      2) FAISS search for neighbors among unvisited
      3) Exclude same-author
      4) Randomly pick up to batch_size-1
      5) Mark used as visited
      6) Write each row (seed + negatives) to out_jsonl_path with "batch_id"
    """
    random.seed(random_seed)
    
    # row index -> docID_anchor
    row_to_documentID = df["documentID_anchor"].to_dict()  # {row_idx: docID}

    # Ensure the author_col is present
    if author_col not in df.columns:
        raise ValueError(f"Data must contain a column '{author_col}' for negative filtering.")
    
    # Get the array of authors
    author_ids = df[author_col].values  # shape: (N,)

    # Set of row indices that are unvisited
    unvisited_rows = set(range(len(df)))

    with open(out_jsonl_path, "w", encoding="utf-8") as out_file:
        batches = []
        batch_num = 0

        pbar = tqdm(total=len(df), desc="Forming batches")
        while unvisited_rows:
            # pick a seed
            seed_row = next(iter(unvisited_rows))
            seed_author = author_ids[seed_row]

            # search in FAISS
            top_k = len(unvisited_rows)  # you can reduce this if it's too large
            seed_emb = df.at[seed_row, embedding_col].reshape(1, -1)
            distances, neighbors = faiss_index.search(seed_emb, top_k)
            neighbors = neighbors[0]

            # Filter out seed, visited, same-author
            valid_negatives = []
            for nbr_row in neighbors:
                if nbr_row == seed_row:
                    continue
                if nbr_row not in unvisited_rows:
                    continue
                if author_ids[nbr_row] == seed_author:
                    continue
                valid_negatives.append(nbr_row)

            # One negative per unique author (if desired)
            unique_neg_by_author = {}
            for nr in valid_negatives:
                a = author_ids[nr]
                if a not in unique_neg_by_author:
                    unique_neg_by_author[a] = nr
            final_candidates = list(unique_neg_by_author.values())

            # random pick up to (batch_size - 1)
            needed = batch_size - 1
            if len(final_candidates) > needed:
                final_candidates = random.sample(final_candidates, needed)

            # form the batch
            batch_rows = [seed_row] + final_candidates

            # mark them visited
            for r in batch_rows:
                unvisited_rows.remove(r)

            # write each row to JSONL
            doc_ids_in_batch = []
            for r in batch_rows:
                row_data = df.iloc[r].to_dict()
                row_data["batch_id"] = batch_num
                # remove the numpy array so JSON dumping won't fail
                row_data.pop(embedding_col, None)
                row_data.pop("doc_luarmud_embedding_anchor", None)
                row_data.pop("doc_luarmud_embedding_positive", None)

                out_file.write(json.dumps(row_data, ensure_ascii=False) + "\n")
                doc_ids_in_batch.append(row_to_documentID[r])

            # keep track in memory (optional)
            batches.append({
                "batch_id": batch_num,
                "doc_ids": doc_ids_in_batch
            })
            batch_num += 1
            pbar.update(len(batch_rows))

        pbar.close()
    return batches


In [3]:
input_hard_pos_pairs_path = "/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data/trainsadiri_luarmud_chunks/train_sadiri_luarmud_embeddings_hard_pos_pairs.jsonl"
df_raw = load_jsonl_into_dataframe(input_hard_pos_pairs_path)
print(f"Loaded {len(df_raw)} total rows from JSONL.")


Loaded 401587 total rows from JSONL.


In [4]:
df_raw.columns

Index(['documentID_anchor', 'authorIDs_anchor', 'fullText_anchor',
       'spanAttribution_anchor', 'collectionNum_anchor', 'source_anchor',
       'dateCollected_anchor', 'publiclyAvailable_anchor',
       'deidentified_anchor', 'languages_anchor', 'lengthWords_anchor',
       'dateCreated_anchor', 'timeCreated_anchor', 'isForeground_anchor',
       'authorID', 'genre_anchor', 'doc_luarmud_embedding_anchor',
       'doc_xrbmtgc_genre_anchor', 'documentID_positive', 'authorIDs_positive',
       'fullText_positive', 'spanAttribution_positive',
       'collectionNum_positive', 'source_positive', 'dateCollected_positive',
       'publiclyAvailable_positive', 'deidentified_positive',
       'languages_positive', 'lengthWords_positive', 'dateCreated_positive',
       'timeCreated_positive', 'isForeground_positive', 'genre_positive',
       'doc_luarmud_embedding_positive', 'doc_xrbmtgc_genre_positive',
       'similarity_score'],
      dtype='object')

In [5]:
# 1) Prepare DataFrame
df_prepped = prepare_dataframe(df_raw, embedding_col="doc_luarmud_embedding_anchor", new_embedding_col="doc_luarmud_embedding_anchor_nparr")
print(f"Data after dropping missing embeddings: {len(df_prepped)} rows.")


Data after dropping missing embeddings: 401587 rows.


In [6]:
# 2) Build a FAISS index on anchor embeddings
faiss_index, id_map = build_faiss_index(
    df_prepped, 
    embedding_col="doc_luarmud_embedding_anchor_nparr", 
    metric="l2"
)
print("Built FAISS index with size =", faiss_index.ntotal)


Built FAISS index with size = 401587


In [None]:

# 3) Iterative batching
# output_jsonl = "/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data/trainsadiri_luarmud_chunks/train_sadiri_hard_batches_v001.jsonl"
output_jsonl = "/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data/hard_batching_dataset_luaremb_wo_ao3_filtered/train_sadiri_hard_batches_v001.jsonl"
batch_size = 32
print(f"Creating batches of size {batch_size} ...")

all_batches = create_batches_iterative(
    df=df_prepped,
    faiss_index=faiss_index,
    batch_size=batch_size,
    embedding_col="doc_luarmud_embedding_anchor_nparr",
    author_col="authorID",   # the single ID we extracted
    random_seed=42,
    out_jsonl_path=output_jsonl
)
print(f"Formed {len(all_batches)} batches.")
print(f"Saved each batch record to '{output_jsonl}' in JSON lines format.")

Creating batches of size 32 ...


Forming batches: 100%|██████████| 401587/401587 [45:05<00:00, 148.43it/s] 


Formed 13574 batches.
Saved each batch record to '/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data/trainsadiri_luarmud_chunks/train_sadiri_hard_batches_v001.jsonl' in JSON lines format.
