# Sadiri Hard Mix Batches

In [None]:
import os
import time
import random
import faiss
import numpy as np
import pandas as pd
from random import Random
from transformers import TrainerCallback
epoch_tracker = {}
epoch_tracker['epoch'] = 0
class EpochTrackerCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, **kwargs):
        epoch_tracker['epoch'] = int(state.epoch)

    def on_epoch_end(self, args, state, control, **kwargs):
        epoch_tracker['epoch'] = int(state.epoch)

path = "/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data/anchor_pos_train_sadiri_luarmudsbertamllv2.jsonl"
file_df = pd.read_json(path, lines=True)
file_df.reset_index(drop=True, inplace=True)


In [None]:

def get_anchorpos_hardmix_batches_optimized(file_df, split, args):
    """
    Optimized version that:
     1) Builds and keeps a single FAISS index
     2) Uses a mask to "remove" docs from availability
     3) Writes each batch to JSONL with a 'batch_id' instead of storing all in memory
    """
    print(f"Called get_anchorpos_hardmix_batches_optimized with split: {split}")
    print(f"Reading file: {path}")
    # ---------------------
    # 1. Load & sample data
    # ---------------------

    print(f"Finished loading file: {path} with shape: {file_df.shape}")

    # ---------------------
    # 2. Normalize embeddings once
    # ---------------------
    def normalize(vecs):
        vecs_ = np.vstack(vecs).astype(np.float32)
        norms = np.linalg.norm(vecs_, axis=1, keepdims=True)
        norms = np.clip(norms, 1e-8, np.inf)
        return vecs_ / norms

    all_embeddings = normalize(file_df["doc_luarmud_embedding_anchor"].values)
    ids_array = file_df.index.to_numpy()

    # ---------------------
    # 3. Build a single FAISS index
    # ---------------------
    dim = all_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)       # inner-product index
    index = faiss.IndexIDMap(index)
    index.add_with_ids(all_embeddings, ids_array.astype(np.int64))

    # Boolean mask for availability
    available_mask = np.ones(len(file_df), dtype=bool)

    # For convenience
    sbert_sim = file_df["sbertamllv2_embedding_similarity_score"].values
    author_ids = file_df["authorID"].values
    doc_ids = file_df["documentID_anchor"].values  # or unique ID

    # ---------------------
    # Batch parameters
    # ---------------------
    batch_query_needed_size = int(0.3 * args.batch_size)
    batch_candidates_needed_size = int(0.3 * args.batch_size)
    print(
        f"Generating batches with batch_query_needed_size={batch_query_needed_size} "
        f"and batch_candidates_needed_size={batch_candidates_needed_size}"
    )

    # We'll count how many batches we produce & how many rows total
    batch_id = 0
    total_rows = 0

    # File path to write out batches
    timestamp = int(time.time())
    output_path_mixedbatches = os.path.join(args.model_op_dir, f"train_sadiri_hardmixbatches_{timestamp}.jsonl")

    # We'll do a random object with a seed based on epoch
    epoch_seed = getattr(epoch_tracker, "epoch", 0)
    rng = Random(epoch_seed)

    while True:
        # ----------------------------------------------
        # 4.1. Pick queries from docs with sbert_sim < threshold & still available
        # ----------------------------------------------
        valid_query_mask = (sbert_sim < args.sem_sim_threshold) & (available_mask)
        valid_queries = np.nonzero(valid_query_mask)[0]  # array of doc indices

        if len(valid_queries) < batch_query_needed_size:
            # Not enough queries to form a batch => break
            break

        # Sample the queries
        chosen_query_indices = rng.sample(list(valid_queries), batch_query_needed_size)
        # Mark them unavailable
        for idx in chosen_query_indices:
            available_mask[idx] = False

        # Collect their authors
        chosen_query_authors = set(author_ids[idx] for idx in chosen_query_indices)

        # ----------------------------------------------
        # 4.2. For each query, search top-K in FAISS
        # ----------------------------------------------
        chosen_query_embeddings = all_embeddings[chosen_query_indices, :]
        # Instead of searching entire corpus_size, limit to a top-K
        top_k = 500
        distances, neighbors = index.search(chosen_query_embeddings, top_k)

        # We'll gather candidate IDs from these searches
        candidate_indices_collector = []
        max_cands_per_query = max(1, batch_candidates_needed_size // batch_query_needed_size)

        used_authors_in_batch = set(chosen_query_authors)

        for i, q_idx in enumerate(chosen_query_indices):
            q_author = author_ids[q_idx]
            picks_for_this_query = 0

            for rank, cand_id in enumerate(neighbors[i]):
                if cand_id == -1:
                    # FAISS sometimes returns -1 if no neighbor
                    continue

                # Already used in a prior batch?
                if not available_mask[cand_id]:
                    continue

                # Check similarity threshold
                sim_score = distances[i][rank]
                if sim_score < args.sty_sim_threshold:
                    continue

                cand_author = author_ids[cand_id]
                # Must differ from query author & not used in this batch
                if cand_author == q_author or cand_author in used_authors_in_batch:
                    continue

                # If it passes all checks, we pick it
                candidate_indices_collector.append(cand_id)
                used_authors_in_batch.add(cand_author)
                picks_for_this_query += 1

                if picks_for_this_query >= max_cands_per_query:
                    break  # done for this query

        # We might have more than needed => sample down
        if len(candidate_indices_collector) > batch_candidates_needed_size:
            candidate_indices_collector = rng.sample(candidate_indices_collector, batch_candidates_needed_size)

        # Mark them unavailable
        for idx in candidate_indices_collector:
            available_mask[idx] = False

        # ----------------------------------------------
        # 4.3. Random docs to fill remainder of the batch
        # ----------------------------------------------
        batch_query_size = len(chosen_query_indices)
        batch_cand_size = len(candidate_indices_collector)
        random_needed = args.batch_size - batch_query_size - batch_cand_size

        random_indices_chosen = []
        if random_needed > 0:
            # Sample from available_mask
            avails = np.nonzero(available_mask)[0]
            if len(avails) < random_needed:
                break  # Not enough docs left

            # Get authors to exclude (from query + candidate docs)
            used_authors_in_batch = set(author_ids[np.concatenate([chosen_query_indices, candidate_indices_collector])])

            # Shuffle & pick random docs whose authors haven't been used yet
            for idx in rng.sample(list(avails), len(avails)):
                if author_ids[idx] in used_authors_in_batch:
                    continue
                random_indices_chosen.append(idx)
                used_authors_in_batch.add(author_ids[idx])  # keep author uniqueness
                if len(random_indices_chosen) == random_needed:
                    break

            if len(random_indices_chosen) < random_needed:
                print(f"Insufficient random docs with unique authors for batch {batch_id}")
                break

            for idx in random_indices_chosen:
                available_mask[idx] = False


        # ----------------------------------------------
        # 4.4. Construct the DataFrame for the batch & write it
        # ----------------------------------------------
        all_batch_indices = chosen_query_indices + candidate_indices_collector + random_indices_chosen

        # If we can't form a full batch, break
        if len(all_batch_indices) < args.batch_size:
            print(f"Breaking at batch_id={batch_id} due to insufficient rows.")
            break

        # Build the batch DF
        batch_df = file_df.iloc[all_batch_indices].copy()
        batch_df["batch_id"] = batch_id

        # Write this batch to the JSONL file in append mode
        # Each row has a field "batch_id"
        # batch_df.to_json(output_path_mixedbatches, orient="records", lines=True, mode="a")
        with open(output_path_mixedbatches, "a") as f:
            batch_df.to_json(f, orient="records", lines=True)


        batch_id += 1
        total_rows += len(batch_df)

        print(f"\r[Progress] Generated batch_id={batch_id}, total_rows={total_rows}, left={available_mask.sum()} ", end="", flush=True)

    print("\nDone generating batches.")
    print(f"Total batches: {batch_id}  | Total rows: {total_rows}")

## Create args namespace with necessary attributes
class Args:
    def __init__(self):
        self.batch_size = 32
        self.sem_sim_threshold = 0.1
        self.sty_sim_threshold = 0.3
        self.model_op_dir = "/data/araghavan/HIATUS/datadreamer-ta2/data/ta2_jan_2025_trian_data"

args = Args()
get_anchorpos_hardmix_batches_optimized(file_df, "train", args)

# Biber Anchor - Pos Pairs

In [1]:
## To generate luar and sbert embeddings for biber data, refer to ~/datadreamer-ta2/src/generate_luar_sbert_genre_emb_biber.py
## Once done, read the embeddings file to create hard batches
import pandas as pd
biber_data_genre_emb_df = pd.read_json("/data/araghavan/HIATUS/datadreamer-ta2/data/biber/genre_data/biber_train_luar_sbert_genre_embeddings_v001.jsonl", lines=True)
biber_data_genre_emb_df.rename(columns={"predicted_genre": "biber_pred_genre"}, inplace=True)
biber_data_genre_emb_df.reset_index(drop=True, inplace=True)


In [2]:
biber_data_genre_emb_df.head()

Unnamed: 0,documentID,authorIDs,fullText,spanAttribution,collectionNum,source,dateCollected,publiclyAvailable,deidentified,languages,...,sourceSpecific,gender,age,topic,sign,date,biber_pred_genre,doc_luarmud_embedding,doc_sbertamllv2_embedding,doc_xrbmtgc_genre
0,bdfad251aa714c409edcc84579a217e0,[002b041f3b0848edb40005d51f2826c8],"“You have done some amazing work, Kurt.” Was a...",[{'authorID': '002b041f3b0848edb40005d51f2826c...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,5,"[0.050881873800000005, 0.3831697106, -0.137104...","[-0.035311203400000005, 0.058806069200000005, ...",Prose/Lyrical
1,f408b9535f7d46e597c8d19b356b5141,[002b041f3b0848edb40005d51f2826c8],Will gave Blaine a key to the teachers bathroo...,[{'authorID': '002b041f3b0848edb40005d51f2826c...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,5,"[0.0454084352, 0.31967493890000004, 0.15773031...","[-0.0980639383, 0.0466547422, 0.04821296040000...",Prose/Lyrical
2,48f3efd4362c42b883d45a8671b75744,[002b041f3b0848edb40005d51f2826c8],\n1. Finding Out\nMercedes and Sam had gone to...,[{'authorID': '002b041f3b0848edb40005d51f2826c...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,5,"[0.157700792, 0.2678265274, -0.0553099811, 0.3...","[-0.0577936843, 0.030992426000000003, 0.061523...",Prose/Lyrical
3,65a57bca1a1444cc87a7e7c368cf6606,[002b041f3b0848edb40005d51f2826c8],"“He is,” Blaine said grabbing the boy’s hand. ...",[{'authorID': '002b041f3b0848edb40005d51f2826c...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,5,"[0.3562221229, 0.0158033203, 0.1227550954, 0.4...","[-0.1175161228, 0.0711242929, 0.0032115169, 0....",Prose/Lyrical
4,f54c730ce1fa49f0a44998d1e6376ca8,[002b041f3b0848edb40005d51f2826c8],"They talked for the next half an hour, and dis...",[{'authorID': '002b041f3b0848edb40005d51f2826c...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,5,"[0.5024769306, 0.1137882546, -0.0644055903, 0....","[-0.1070570871, 0.037984032200000004, 0.024115...",Prose/Lyrical


In [3]:
## Do outer join of biber_data_genre_emb_df with itself on authorID column, have suffixes as _anchor, _positive
## Filter out docs where the biber_pred_genre_anchor and biber_pred_genre_positive are same
## generate batchwise cosine similarity scores between doc_sbertamllv2_embedding_anchor and doc_sbertamllv2_embedding_positive and store in a new column sbertamllv2_embedding_similarity_score
## create random batches of size 32 with below conditions
## 1. 30% of the batch: query docs: docs from authors with same biber_pred_genre_anchor and sbertamllv2_embedding_similarity_score < sem_sim_threshold
## 2. 30% of the batch: candidate docs: docs from different authors with same biber_pred_genre_anchor and FAISS vectorized search on doc_luarmud_embedding score > sem_sim_threshold
## 3. 40% of the batch: random docs: docs from different authors with same biber_pred_genre_anchor

## Repeat the above process for 3 epochs, with corpus same as the anchor-pos pairs every epoch while having different random seed and different random batches but continued batch_id, starting from the last batch_id of the previous epoch and incrementing it by 1 and storing the batch_id in the jsonl file but with corpus dataframe as the whole dataframe for every epoch



In [4]:
print(biber_data_genre_emb_df.shape)
biber_data_genre_emb_anchorpos_df = biber_data_genre_emb_df.merge(biber_data_genre_emb_df, how="outer", left_on="authorID", right_on="authorID", suffixes=('_anchor', '_positive'))
biber_data_genre_emb_anchorpos_df = biber_data_genre_emb_anchorpos_df.loc[
    (biber_data_genre_emb_anchorpos_df.documentID_anchor != biber_data_genre_emb_anchorpos_df.documentID_positive) 
    & (biber_data_genre_emb_anchorpos_df.biber_pred_genre_anchor != biber_data_genre_emb_anchorpos_df.biber_pred_genre_positive)
]
print(biber_data_genre_emb_anchorpos_df.shape)
display(biber_data_genre_emb_anchorpos_df.head())

(144628, 51)
(807442, 101)


Unnamed: 0,documentID_anchor,authorIDs_anchor,fullText_anchor,spanAttribution_anchor,collectionNum_anchor,source_anchor,dateCollected_anchor,publiclyAvailable_anchor,deidentified_anchor,languages_anchor,...,sourceSpecific_positive,gender_positive,age_positive,topic_positive,sign_positive,date_positive,biber_pred_genre_positive,doc_luarmud_embedding_positive,doc_sbertamllv2_embedding_positive,doc_xrbmtgc_genre_positive
799,e73a55df06db47f28dfcd15626e345ee,[017457df53294886a3593e5a2977eb6d],Mirchandani had made Ronon Head of Offworld Li...,[{'authorID': '017457df53294886a3593e5a2977eb6...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,14,"[0.3409982622, 0.0741266087, 0.2813294232, 0.2...","[-0.0947817564, 0.0683237612, 0.0600385629, -0...",Prose/Lyrical
802,e73a55df06db47f28dfcd15626e345ee,[017457df53294886a3593e5a2977eb6d],Mirchandani had made Ronon Head of Offworld Li...,[{'authorID': '017457df53294886a3593e5a2977eb6...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,14,"[-0.19908979540000002, 0.26586478950000003, 0....","[-0.0569949374, 0.0539704487, 0.0345009565, -0...",Prose/Lyrical
812,591fa218125e4f5bad8cd3fe26d7d7c0,[017457df53294886a3593e5a2977eb6d],"‘Right, I’m going out to get you some clothes ...",[{'authorID': '017457df53294886a3593e5a2977eb6...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,14,"[0.3409982622, 0.0741266087, 0.2813294232, 0.2...","[-0.0947817564, 0.0683237612, 0.0600385629, -0...",Prose/Lyrical
815,591fa218125e4f5bad8cd3fe26d7d7c0,[017457df53294886a3593e5a2977eb6d],"‘Right, I’m going out to get you some clothes ...",[{'authorID': '017457df53294886a3593e5a2977eb6...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,14,"[-0.19908979540000002, 0.26586478950000003, 0....","[-0.0569949374, 0.0539704487, 0.0345009565, -0...",Prose/Lyrical
825,141e32047f914ac3b17b3d45bf6d8929,[017457df53294886a3593e5a2977eb6d],Rodney had read the report and been impressed ...,[{'authorID': '017457df53294886a3593e5a2977eb6...,HRS 1,https://archiveofourown.org/,2023-09-20,1.0,1.0,"[en, un, un]",...,,,,,,,14,"[0.3409982622, 0.0741266087, 0.2813294232, 0.2...","[-0.0947817564, 0.0683237612, 0.0600385629, -0...",Prose/Lyrical


In [5]:
# -----------------------------
# Compute Cosine Similarity (Vectorized)
# -----------------------------
import numpy as np
from tqdm import tqdm
def batched_cosine_similarity(a_embeddings, b_embeddings, batch_size=10000):
    scores = []
    for i in tqdm(range(0, len(a_embeddings), batch_size), desc="Cosine Similarity"):
        a_batch = a_embeddings[i:i+batch_size]
        b_batch = b_embeddings[i:i+batch_size]
        sim = np.sum(a_batch * b_batch, axis=1) / (
            np.linalg.norm(a_batch, axis=1) * np.linalg.norm(b_batch, axis=1)
        )
        scores.extend(sim)
    return np.array(scores)
anchor_embeddings = np.stack(biber_data_genre_emb_anchorpos_df["doc_sbertamllv2_embedding_anchor"].values)
positive_embeddings = np.stack(biber_data_genre_emb_anchorpos_df["doc_sbertamllv2_embedding_positive"].values)
biber_data_genre_emb_anchorpos_df["sbertamllv2_embedding_similarity_score"] = batched_cosine_similarity(anchor_embeddings, positive_embeddings, batch_size=10000)

Cosine Similarity: 100%|██████████| 81/81 [00:00<00:00, 106.46it/s]


In [6]:
biber_data_genre_emb_anchorpos_df.reset_index(drop=True, inplace=True)

In [None]:
op_file_path = "/data/araghavan/HIATUS/datadreamer-ta2/data/biber/genre_data/biber_train_anchorpos_raw_v001.jsonl"
biber_data_genre_emb_anchorpos_df.to_json(op_file_path, orient="records", lines=True)

# Biber 3 Epochs Cached Hard Mix Batches

In [None]:
import os
import time
import random
import faiss
import numpy as np
import pandas as pd
from random import Random
from transformers import TrainerCallback
epoch_tracker = {}
epoch_tracker['epoch'] = 0
class EpochTrackerCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, **kwargs):
        epoch_tracker['epoch'] = int(state.epoch)

    def on_epoch_end(self, args, state, control, **kwargs):
        epoch_tracker['epoch'] = int(state.epoch)

def get_anchorpos_biber_hardmix_batches_optimized(file_df, split, args, uniq_biber_genres):
    """
    Optimized version that:
     1) Builds and keeps a single FAISS index
     2) Uses a mask to "remove" docs from availability
     3) Writes each batch to JSONL with a 'batch_id' instead of storing all in memory
    """
    print(f"Called get_anchorpos_biber_hardmix_batches_optimized with split: {split}")
    # ---------------------
    # 1. Load & sample data
    # ---------------------

    print(f"Passed df with shape: {file_df.shape}")

    # ---------------------
    # 2. Normalize embeddings once
    # ---------------------
    def normalize(vecs):
        vecs_ = np.vstack(vecs).astype(np.float32)
        norms = np.linalg.norm(vecs_, axis=1, keepdims=True)
        norms = np.clip(norms, 1e-8, np.inf)
        return vecs_ / norms

    all_embeddings = normalize(file_df["doc_luarmud_embedding_anchor"].values)
    ids_array = file_df.index.to_numpy()

    # ---------------------
    # 3. Build a single FAISS index
    # ---------------------
    dim = all_embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)       # inner-product index
    index = faiss.IndexIDMap(index)
    index.add_with_ids(all_embeddings, ids_array.astype(np.int64))

    # Boolean mask for availability
    available_mask = np.ones(len(file_df), dtype=bool)

    # For convenience
    sbert_sim = file_df["sbertamllv2_embedding_similarity_score"].values
    author_ids = file_df["authorID"].values
    biber_genre = file_df["biber_pred_genre_anchor"].values
    doc_ids = file_df["documentID_anchor"].values  # or unique ID

    # ---------------------
    # Batch parameters
    # ---------------------
    batch_query_needed_size = int(0.3 * args.batch_size)
    batch_candidates_needed_size = int(0.3 * args.batch_size)
    print(
        f"Generating batches with batch_query_needed_size={batch_query_needed_size} "
        f"and batch_candidates_needed_size={batch_candidates_needed_size}"
    )

    # We'll count how many batches we produce & how many rows total
    batch_id = args.batch_id
    batch_id_ctr = 0
    total_rows = 0

    # File path to write out batches
    timestamp = int(time.time())
    output_path_mixedbatches = os.path.join(args.model_op_dir, args.op_file_name)

    # We'll do a random object with a seed based on epoch
    epoch_seed = getattr(epoch_tracker, "epoch", 0)
    rng = Random(epoch_seed)

    while True:
        # ----------------------------------------------
        # 4.1. Pick queries from docs with sbert_sim < threshold & still available
        # ----------------------------------------------

        # Distribution of biber_pred_genre is not the same, there are some genres with very few samples, while some with more. While selecting queries, for a particular genre, if the number of samples remaining is less than the batch_query_needed_size, then we can skip that genre and move to the next one, until we dont find for multiple batches
        # We can also use a random seed to ensure we get different genres each time
        local_genres = list(uniq_biber_genres)  # make a copy so global list isn't modified
        selected_biber_genre = None
        while local_genres:
            genre_try = rng.choice(local_genres)
            genre_mask = (biber_genre == genre_try)
            sim_mask = (sbert_sim < args.sem_sim_threshold)
            genre_count = np.sum(genre_mask & available_mask & sim_mask)
            if genre_count >= args.batch_size:
                selected_biber_genre = genre_try
                break
            else:
                local_genres.remove(genre_try)

        if selected_biber_genre is None:
            print(f"Breaking at batch_id_ctr={batch_id_ctr}, batch_id={batch_id} due to no valid genres with enough available docs.")
            break

        # Now we have a selected genre with enough samples
        # Filter the valid queries based on the selected genre
        valid_query_mask = (sbert_sim < args.sem_sim_threshold) & (available_mask) & (biber_genre == selected_biber_genre)
        valid_queries = np.nonzero(valid_query_mask)[0]  # array of doc indices

        if len(valid_queries) < batch_query_needed_size:
            print(f"Not enough queries for genre %s, available count: %d", selected_biber_genre, len(valid_queries))
            # Not enough queries to form a batch => break
            break

        # Sample the queries
        chosen_query_indices = rng.sample(list(valid_queries), batch_query_needed_size)
        # Mark them unavailable
        for idx in chosen_query_indices:
            available_mask[idx] = False

        # Collect their authors
        chosen_query_authors = set(author_ids[idx] for idx in chosen_query_indices)

        # ----------------------------------------------
        # 4.2. For each query, search top-K in FAISS
        # ----------------------------------------------
        chosen_query_embeddings = all_embeddings[chosen_query_indices, :]
        # Instead of searching entire corpus_size, limit to a top-K
        top_k = 1000
        distances, neighbors = index.search(chosen_query_embeddings, top_k)

        # We'll gather candidate IDs from these searches
        candidate_indices_collector = []
        max_cands_per_query = max(1, batch_candidates_needed_size // batch_query_needed_size)

        used_authors_in_batch = set(chosen_query_authors)

        for i, q_idx in enumerate(chosen_query_indices):
            q_author = author_ids[q_idx]
            picks_for_this_query = 0

            for rank, cand_id in enumerate(neighbors[i]):
                if cand_id == -1:
                    # FAISS sometimes returns -1 if no neighbor
                    continue

                # Already used in a prior batch?
                if not available_mask[cand_id]:
                    continue

                # Check similarity threshold
                sim_score = distances[i][rank]
                if sim_score < args.sty_sim_threshold:
                    continue

                cand_author = author_ids[cand_id]
                # Must differ from query author & not used in this batch
                if cand_author == q_author or cand_author in used_authors_in_batch:
                    continue

                # Check if the candidate belongs to the same biber genre
                if biber_genre[cand_id] != selected_biber_genre:
                    continue

                # If it passes all checks, we pick it
                candidate_indices_collector.append(cand_id)
                used_authors_in_batch.add(cand_author)
                picks_for_this_query += 1

                if picks_for_this_query >= max_cands_per_query:
                    break  # done for this query

        # We might have more than needed => sample down
        if len(candidate_indices_collector) > batch_candidates_needed_size:
            candidate_indices_collector = rng.sample(candidate_indices_collector, batch_candidates_needed_size)

        # Mark them unavailable
        for idx in candidate_indices_collector:
            available_mask[idx] = False

        # ----------------------------------------------
        # 4.3. Random docs to fill remainder of the batch
        # ----------------------------------------------
        batch_query_size = len(chosen_query_indices)
        batch_cand_size = len(candidate_indices_collector)
        random_needed = args.batch_size - batch_query_size - batch_cand_size

        random_indices_chosen = []
        if random_needed > 0:
            genre_avail_mask = available_mask & (biber_genre == selected_biber_genre)
            # Restrict available docs to same genre
            avails = np.nonzero(genre_avail_mask)[0]

            if len(avails) < random_needed:
                print(f"Not enough random docs for genre %s, available count: %d", selected_biber_genre, len(avails))
                # Not enough docs => break
                break

        # used_authors_in_batch = set(author_ids[np.concatenate([chosen_query_indices, candidate_indices_collector])])
        used_authors_in_batch = set(author_ids[np.concatenate(
            [
                np.array(chosen_query_indices, dtype=int), 
                np.array(candidate_indices_collector, dtype=int)
            ]
        )])


        for idx in rng.sample(list(avails), len(avails)):
            if author_ids[idx] in used_authors_in_batch:
                continue
            random_indices_chosen.append(idx)
            used_authors_in_batch.add(author_ids[idx])
            if len(random_indices_chosen) == random_needed:
                break


        all_batch_indices = chosen_query_indices + candidate_indices_collector + random_indices_chosen
        # If we can't form a full batch, break
        if len(all_batch_indices) < args.batch_size:
            print(f"Breaking at batch_id_ctr={batch_id_ctr}, batch_id={batch_id} due to insufficient rows.")
            break


        for idx in random_indices_chosen:
            available_mask[idx] = False

        # ----------------------------------------------
        # 4.4. Construct the DataFrame for the batch & write it
        # ----------------------------------------------
        # Build the batch DF
        batch_df = file_df.iloc[all_batch_indices].copy()
        batch_df["batch_id"] = batch_id

        # Write this batch to the JSONL file in append mode
        # Each row has a field "batch_id"
        # batch_df.to_json(output_path_mixedbatches, orient="records", lines=True, mode="a")
        with open(output_path_mixedbatches, "a") as f:
            batch_df.to_json(f, orient="records", lines=True)


        batch_id += 1
        batch_id_ctr += 1
        total_rows += len(batch_df)

        print(f"\r[Progress] Generated batch_id_ctr={batch_id_ctr}, batch_id={batch_id}, total_rows={total_rows}, left={available_mask.sum()} ", end="", flush=True)

    print("\nDone generating batches.")
    args.batch_id = batch_id
    print(f"Total batches: {batch_id_ctr}, batch_id: {batch_id}, args.batch_id: {args.batch_id}  | Total rows: {total_rows}")

## Create args namespace with necessary attributes
class Args:
    def __init__(self):
        self.batch_size = 32
        self.sem_sim_threshold = 0.2
        self.sty_sim_threshold = 0.3
        self.model_op_dir = "/data/araghavan/HIATUS/datadreamer-ta2/data/biber/genre_data/"
        self.op_file_name = f"biber_train_sadiri_anchorpos_cachedhardmixbatches_epochs3_{int(time.time())}.jsonl"
        self.batch_id = 0

args = Args()
file_df = biber_data_genre_emb_anchorpos_df.copy()
uniq_biber_genres = biber_data_genre_emb_anchorpos_df.biber_pred_genre_anchor.unique()
get_anchorpos_biber_hardmix_batches_optimized(file_df, "train", args, uniq_biber_genres)

Called get_anchorpos_biber_hardmix_batches_optimized with split: train
Passed df with shape: (807442, 102)
Generating batches with batch_query_needed_size=9 and batch_candidates_needed_size=9
[Progress] Generated batch_id_ctr=22469, batch_id=22469, total_rows=719008, left=88434  Breaking at batch_id_ctr=22469, batch_id=22469 due to no valid genres with enough available docs.

Done generating batches.
Total batches: 22469, batch_id: 22469, args.batch_id: 22469  | Total rows: 719008


In [23]:
args.batch_id

36194

In [21]:
epoch_tracker['epoch'] = 42

In [22]:
args.sem_sim_threshold = 0.05
args.sty_sim_threshold = 0.4
get_anchorpos_biber_hardmix_batches_optimized(file_df, "train", args, uniq_biber_genres)

Called get_anchorpos_biber_hardmix_batches_optimized with split: train
Passed df with shape: (807442, 102)
Generating batches with batch_query_needed_size=9 and batch_candidates_needed_size=9
[Progress] Generated batch_id_ctr=13725, batch_id=36194, total_rows=439200, left=368242 Breaking at batch_id_ctr=13725, batch_id=36194 due to no valid genres with enough available docs.

Done generating batches.
Total batches: 13725, batch_id: 36194, args.batch_id: 36194  | Total rows: 439200


In [24]:
args.sem_sim_threshold = 0
args.sty_sim_threshold = 0.5
epoch_tracker['epoch'] = 69
get_anchorpos_biber_hardmix_batches_optimized(file_df, "train", args, uniq_biber_genres)

Called get_anchorpos_biber_hardmix_batches_optimized with split: train
Passed df with shape: (807442, 102)
Generating batches with batch_query_needed_size=9 and batch_candidates_needed_size=9
[Progress] Generated batch_id_ctr=8394, batch_id=44588, total_rows=268608, left=538834 Breaking at batch_id_ctr=8394, batch_id=44588 due to no valid genres with enough available docs.

Done generating batches.
Total batches: 8394, batch_id: 44588, args.batch_id: 44588  | Total rows: 268608
