<a href="https://colab.research.google.com/github/Steve-Falkovsky/Hypencoder-Entity-Linking/blob/main/notebooks/BC5CDR_nameonly_hard_negative_mining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import importlib.util

REPO_NAME = "Hypencoder-Entity-Linking"
GIT_URL = f"https://github.com/Steve-Falkovsky/{REPO_NAME}.git"
BRANCH_NAME = "main"

if not os.path.exists(REPO_NAME):
    !git clone -b {BRANCH_NAME} --single-branch {GIT_URL}

    # Move into the downloaded repo (The Root)
    os.chdir(REPO_NAME)


%pip install -q -e "./hypencoder-paper"

os.chdir("hypencoder-paper")

print(f"üìç Working Directory is now: {os.getcwd()}")
print("‚úÖ Environment Ready!")

In [None]:
from datasets import load_dataset

# These are all "positive" pairs. Later cells read the "id", "mention" and
# "entity" columns of each split.
dataset = load_dataset("Stevenf232/BC5CDR_MeSH2015_nameonly")

### Load the model

In [None]:
# Core Hypencoder model for outputting dense vector representations
from hypencoder_cb.modeling.hypencoder import Hypencoder, HypencoderDualEncoder, TextEncoder
from transformers import AutoTokenizer

# Checkpoint name passed to both `from_pretrained` calls below, so the model
# and the tokenizer are loaded from the same source.
model_name = "Stevenf232/SapBERT_freeze_hypencoder"

dual_encoder = HypencoderDualEncoder.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Pull the two halves out of the dual encoder; later cells use them
# separately (one for mentions/queries, one for entity names/passages).
query_encoder: Hypencoder = dual_encoder.query_encoder
passage_encoder: TextEncoder = dual_encoder.passage_encoder

### Move the model to the GPU

In [None]:
import torch

# Set up the device: use CUDA when it is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")  # This should say 'cuda'

# Move both encoders to the selected device (the GPU when one was found).
passage_encoder.to(device)
query_encoder.to(device)

### Load datasets and tokenise

# Passage Encodings


In [None]:
from tqdm import tqdm
import torch
from torch.amp import autocast

def batch_encode_passages(encoder, passages, batch_size=256):
    """Encode tokenised passages in mini-batches and collect their representations.

    Args:
        encoder: Callable accepting ``input_ids`` and ``attention_mask``
            keyword arguments and returning an object with a
            ``representation`` tensor attribute.
        passages: Mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors whose first dimension indexes the passages.
        batch_size: Number of passages sent through the encoder per step.

    Returns:
        A CPU tensor with the per-batch representations concatenated along
        dimension 0.
    """
    input_ids = passages["input_ids"]
    attention_mask = passages["attention_mask"]
    num_passages = input_ids.shape[0]

    # Only ask for CUDA autocast when the work actually runs on CUDA.
    use_amp = torch.device(device).type == "cuda"

    entity_name_features = []

    # Inference only: skipping gradient tracking saves memory.
    with torch.no_grad():
        for start in tqdm(range(0, num_passages, batch_size), desc="Extracting features"):
            stop = start + batch_size

            # Autocast runs eligible ops in reduced precision to save memory
            # and time; the precision loss is expected to be negligible here.
            with autocast("cuda", enabled=use_amp):
                features = encoder(
                    input_ids=input_ids[start:stop].to(device),
                    attention_mask=attention_mask[start:stop].to(device),
                ).representation

            # Park each batch on the CPU so accelerator memory does not grow
            # with the number of passages.
            entity_name_features.append(features.detach().cpu())

    return torch.cat(entity_name_features, dim=0)

# Q-nets take **a lot** of memory.

Instead of creating all of them and then doing the similarity calculation, we will create batches and calculate similarities for just those q-nets, then discard those q-nets and move on to the next batch.

In [None]:
def batch_encode_queries(encoder, queries, passage_embeddings, batch_size=8):
    """Score every query against every passage, building q-nets batch by batch.

    Q-nets are memory hungry, so only ``batch_size`` of them exist at a time:
    each batch is created, applied to all passage embeddings, and its scores
    are moved to the CPU before the next batch is built.

    Args:
        encoder: Callable accepting ``input_ids`` and ``attention_mask``
            keyword arguments and returning an object whose
            ``representation`` is a callable q-net batch exposing
            ``num_queries``.
        queries: Mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors whose first dimension indexes the queries.
        passage_embeddings: 2-D tensor of passage representations (one row
            per passage).
        batch_size: Number of queries turned into q-nets per step.

    Returns:
        A CPU tensor with the per-batch score tensors concatenated along
        dimension 0.
    """
    input_ids = queries["input_ids"]
    attention_mask = queries["attention_mask"]
    num_queries = input_ids.shape[0]

    # Only ask for CUDA autocast when the work actually runs on CUDA.
    use_amp = torch.device(device).type == "cuda"

    # The passage matrix is the same for every batch, so transfer it once
    # instead of on every iteration.
    passages_gpu = passage_embeddings.to(device)

    similarity_scores = []

    with torch.no_grad():
        for start in tqdm(
            range(0, num_queries, batch_size),
            desc="Creating q-nets and calculating similarity scores",
        ):
            stop = start + batch_size

            # Create the q-nets for this slice of queries.
            with autocast("cuda", enabled=use_amp):
                q_nets = encoder(
                    input_ids=input_ids[start:stop].to(device),
                    attention_mask=attention_mask[start:stop].to(device),
                ).representation

            # Use q_nets.num_queries rather than batch_size: the final batch
            # may be smaller when the total is not a multiple of batch_size.
            passages_batch = passages_gpu.unsqueeze(0).expand(q_nets.num_queries, -1, -1)

            # Apply each q-net to all passages, then move the scores off the
            # accelerator so the q-nets can be released.
            batch_scores = q_nets(passages_batch)
            similarity_scores.append(batch_scores.detach().cpu())

    return torch.cat(similarity_scores, dim=0)


## Create a dataset of Hard Negatives based on "Hard Negative Mining"
For each query, we take the highest-scoring *incorrect* items as its negatives.

In [None]:
"""
Desired format for each line in the JSONL file:
{
  "query": {
    "id": query ID,
    "content": query text,
  },
  "items": [
    {
      "id": passage ID,
      "content": passage text,
      "score": Optional teacher score,
      "type": Sometimes used to specify type of item,
    },
    {
        # another item
    },
  ]
}

Contrastive Loss with Hard Negatives: The positive must be the first item, all following items
will be treated as negative
"""

# Perform Hard Negative Mining

In [None]:
import torch
import numpy as np

def _collect_unique_negatives(row_scores, entity_ids, query_entity_id, num_negatives):
    """Return up to ``num_negatives`` column indices for one query, best score first.

    At most one index is kept per distinct entity ID, and columns sharing the
    query's own entity ID are never returned.
    """
    total = row_scores.shape[0]
    if num_negatives <= 0 or total == 0:
        return []

    # Duplicated entities can crowd the top of the ranking, so start with a
    # buffer of 4x the requested count plus 32 and widen it only if needed.
    k = min(num_negatives * 4 + 32, total)

    while True:
        _, candidates = torch.topk(row_scores, k=k, dim=0, largest=True)

        chosen = []
        # Seeding with the query's own entity guarantees a positive is never
        # emitted as a negative, even once the search reaches masked columns.
        seen_entity_ids = {query_entity_id}

        for idx in candidates.tolist():
            entity_id = entity_ids[idx]
            if entity_id in seen_entity_ids:
                continue
            seen_entity_ids.add(entity_id)
            chosen.append(idx)
            if len(chosen) >= num_negatives:
                return chosen

        if k >= total:
            # Every column was inspected; there are no more unique negatives.
            return chosen

        # Not enough unique negatives yet: double the search window.
        k = min(k * 2, total)


def write_hardneg_contrastive_jsonl_masked(
    pairs,
    similarity_scores: torch.Tensor,
    output_jsonl_path: str,
    num_negatives: int = 8,
):
    """Write one JSONL line per query: its positive followed by mined hard negatives.

    Row ``i`` of ``similarity_scores`` belongs to ``pairs["mention"][i]`` and
    column ``j`` to ``pairs["entity"][j]``. Columns whose ID equals the
    query's ID are treated as positives and excluded; the remaining columns
    are ranked by score and the best ``num_negatives`` with distinct IDs are
    written as negatives.

    Args:
        pairs: Mapping-like object with equally long ``"id"``, ``"mention"``
            and ``"entity"`` columns.
        similarity_scores: Score tensor of shape ``(N, N)`` or ``(N, N, 1)``,
            where ``N`` is the number of pairs. It is not modified.
        output_jsonl_path: Destination file; missing parent directories are
            created.
        num_negatives: Maximum number of negatives written per query. Fewer
            are written (with a printed warning) when not enough distinct
            negative entities exist.

    Raises:
        ValueError: If the score tensor's shape does not match the pairs.
    """
    # Imported locally because the notebook does not import these names
    # before the cell that first calls this function.
    import json
    from pathlib import Path

    output_jsonl_path = Path(output_jsonl_path)
    output_jsonl_path.parent.mkdir(parents=True, exist_ok=True)

    # Fetch each column once instead of re-reading it for every row.
    ids = list(pairs["id"])
    mentions = list(pairs["mention"])
    entities = list(pairs["entity"])

    scores = similarity_scores.detach().cpu()
    if scores.dim() == 3:
        scores = scores.squeeze(-1)  # drop the trailing singleton dimension
    if scores.dim() != 2:
        raise ValueError(
            f"similarity_scores must be 2-D after squeezing, got shape {tuple(scores.shape)}"
        )
    N, M = scores.shape
    if N != len(ids) or M != len(ids):
        raise ValueError(
            f"similarity_scores has shape ({N}, {M}) but pairs has {len(ids)} rows"
        )

    # Work on a private copy: the masking below is in place and must not leak
    # into the caller's tensor.
    scores = scores.clone()

    # Map string IDs to integers so equality checks are int-vs-int.
    # np.unique's inverse gives, for each row, the index of its ID among the
    # unique IDs.
    _, inverse_indices = np.unique(ids, return_inverse=True)
    inverse_indices = np.asarray(inverse_indices).reshape(-1)
    id_tensor = torch.from_numpy(inverse_indices)  # shape (N,)
    entity_ids = inverse_indices.tolist()

    # (N, 1) == (1, N) -> (N, N): True wherever query and passage share an ID.
    same_id_mask = id_tensor.unsqueeze(1) == id_tensor.unsqueeze(0)

    # Push positives (the query's own row and every other row with the same
    # ID) to the bottom of the ranking.
    scores.masked_fill_(same_id_mask, -float("inf"))

    with output_jsonl_path.open("w", encoding="utf-8") as f:
        for i, (q_id, mention, entity) in enumerate(zip(ids, mentions, entities)):
            neg_indices = _collect_unique_negatives(
                scores[i], entity_ids, entity_ids[i], num_negatives
            )

            # Warn if the dataset is too small or repetitive.
            if len(neg_indices) < num_negatives:
                print(f"Warning: Query {i} found only {len(neg_indices)} unique negatives.")

            # The positive must come first; every following item is a negative.
            items = [{"id": q_id, "content": entity, "score": None, "type": None}]
            items.extend(
                {"id": ids[idx], "content": entities[idx], "score": None, "type": None}
                for idx in neg_indices
            )

            entry = {
                "query": {"id": q_id, "content": mention},
                "items": items,
            }
            json.dump(entry, f, ensure_ascii=False)
            f.write("\n")

    print(f"Wrote {N} lines to {output_jsonl_path}")

In [None]:
# Generate the contrastive-loss JSONL file for each of the train/val/test splits.

data_splits = ("train", "validation", "test")
# Keep only the splits this dataset actually provides, in the order above.
splits = [s for s in data_splits if s in dataset]

for split in splits:
    print(f"Starting {split} split")
    pairs = dataset[split]

    # Build the query/passage lists for this split: mentions are the queries,
    # entity names are the passages (same length, aligned row by row).
    queries = list(pairs["mention"])
    passages = list(pairs["entity"])

    # Tokenize
    query_inputs = tokenizer(queries, return_tensors="pt", padding=True, truncation=True)
    passage_inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True)

    # Encode the passages once, then score every query against all of them.
    passage_embeddings = batch_encode_passages(passage_encoder, passage_inputs)
    similarity_scores = batch_encode_queries(query_encoder, query_inputs, passage_embeddings)

    # Write the JSONL file for this split.
    write_hardneg_contrastive_jsonl_masked(
        pairs=pairs,
        similarity_scores=similarity_scores,
        output_jsonl_path=f"bc5cdr_{split}_hypencoder_contrastive.jsonl",
        num_negatives=8,
    )

    print("\n\n")

## Upload dataset to HuggingFace

In [None]:
from datasets import load_dataset
from pathlib import Path

# Destination repository on the Hub.
repo_id = "Stevenf232/BC5CDR_nameonly_hard_negative_mining"

# Local JSONL file for each split, keyed by the split name it is loaded under.
data_files = {
    "train": "bc5cdr_train_hypencoder_contrastive.jsonl",
    "validation": "bc5cdr_validation_hypencoder_contrastive.jsonl",
    "test": "bc5cdr_test_hypencoder_contrastive.jsonl",
}

# Read all files through the "json" loader into a single DatasetDict, then
# publish it as a public (private=False) dataset.
ds = load_dataset("json", data_files=data_files)
ds.push_to_hub(repo_id=repo_id, private=False)