In [None]:
import os
import importlib.util

# Auto-reload edited modules so that changes made in VS Code are picked up by the notebook.
# Disabled for now: deprecated / needs updating for the new IPython, which breaks Colab.
# %load_ext autoreload
# %autoreload 2



REPO_NAME = "Hypencoder-Entity-Linking"
GIT_URL = f"https://github.com/Steve-Falkovsky/{REPO_NAME}.git"
BRANCH_NAME = "Professional-Structure"



# --- COLAB SETUP ---
is_colab = importlib.util.find_spec("google.colab") is not None
if is_colab:
    print("‚òÅÔ∏è Running in Colab...")
    if not os.path.exists(REPO_NAME):
        !git clone -b {BRANCH_NAME} --single-branch {GIT_URL}

    # Move into the downloaded repo (The Root)
    os.chdir(REPO_NAME)



# --- LOCAL SETUP ---
else:
    print("üíª Running Locally...")
    if os.path.basename(os.getcwd()) == "notebooks":
        os.chdir("..")


%pip install -q -e "./hypencoder-paper"

print(f"üìç Working Directory is now: {os.getcwd()}")
print("‚úÖ Environment Ready!")

In [None]:
from datasets import load_dataset

# These are all "positive" pairs (each row carries a mention and its entity).
# The download output shows three splits: train, validation and test.
dataset = load_dataset("Stevenf232/BC5CDR_MeSH2015_nameonly")

README.md:   0%|          | 0.00/31.0 [00:00<?, ?B/s]

name_only_train.jsonl: 0.00B [00:00, ?B/s]

name_only_val.jsonl: 0.00B [00:00, ?B/s]

name_only_test.jsonl: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/2654 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2559 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2656 [00:00<?, ? examples/s]

In [None]:
# take a sample of entries until I can get the batching fixed
SAMPLE_SIZE = 100

# Slicing a `datasets.Dataset` returns a plain dict that maps each column name
# to a list of values, so `train_pairs` is column-oriented from here on
# (e.g. train_pairs["mention"][i]).
train_pairs = dataset['train'][:SAMPLE_SIZE]

mention_names = train_pairs['mention']
entity_names = train_pairs['entity']

# Show the available columns rather than dumping every sampled value.
print(list(train_pairs.keys()))
# len(train_pairs) would count columns, not rows; report the sampled row count.
print(len(mention_names))
print(mention_names[:3])
print(entity_names[:3])


{'mention': ['human immunodeficiency virus', "non-Hodgkin's lymphoma", 'renal cell carsinom', 'Duchenne dystrophy', 'Acute psychosis', 'urethane', 'bronchitis', 'spasticity', 'PG-9', 'neurotoxic', 'sinus bradycardia', 'endophthalmitis', 'itchiness', 'Srl', 'Coronary aneurysm', 'gamma-carboxyglutamate', 'lymphadenopathy', 'thienodiazepine', 'sodium', 'chronic pain', 'intravascular coagulation', 'choanal atresia', 'docetaxel', 'erythroderma', 'renal and kidney disease', 'Adriamycinol', 'hepatic injury', 'syncope', 'Diabetic', 'cancer of the ureter', 'raloxifene', 'a decrease in MAP', 'isoflurane', 'MFL regimen', 'VPA', 'Urinary bladder cancer', 'myotonia congenita', 'xerostomia', 'K', 'hypothyroidism', "Kaposi's sarcoma", 'citrate', 'Ximelagatran', 'arthrogryposis', 'VPU', 'beta-carboline', 'STR', 'albuminuria', 'hoarding', 'lupus', 'ischemic stroke', 'des Arg10 HOE 140', 'Angiotension II', 'aplastic anemia', 'decreased thymus (P < 0.001) and bodyweights', 'degeneration of myelin', 'hear

In [None]:
# Core Hypencoder model for outputting dense vector representations
from hypencoder_cb.modeling.hypencoder import Hypencoder, HypencoderDualEncoder, TextEncoder
from transformers import AutoTokenizer

# Hub id of the checkpoint to load. The first assignment is immediately
# overridden by the second, so only "Stevenf232/hypencoder_BC5CDR" is loaded;
# the first line is kept as a quick way to switch checkpoints.
model_name = "jfkback/hypencoder.2_layer"
model_name = "Stevenf232/hypencoder_BC5CDR"

# The dual encoder and its tokenizer are loaded from the same checkpoint id.
dual_encoder = HypencoderDualEncoder.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Split the dual encoder into its two halves. Later cells use the query encoder
# to produce q-nets and the passage encoder to produce embedding tensors.
query_encoder: Hypencoder = dual_encoder.query_encoder
passage_encoder: TextEncoder = dual_encoder.passage_encoder

config.json:   0%|          | 0.00/940 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/483M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [None]:
# Copy the sampled columns into fresh Python lists for the tokenizer.
queries = list(mention_names)
passages = list(entity_names)

# Without an explicit max_length, truncation=True is a no-op for this
# tokenizer: it warns that the model has no predefined maximum length and
# falls back to no truncation.
# NOTE(review): 512 assumes a BERT-style encoder with 512 positions - confirm
# against the model config.
MAX_SEQ_LENGTH = 512

# the output of the tokenizer contains 3 fields:
# input_ids, token_type_ids, and attention_mask
# each is a tensor of shape (number of texts, max number of tokens)
query_inputs = tokenizer(
    queries,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=MAX_SEQ_LENGTH,
)
passage_inputs = tokenizer(
    passages,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=MAX_SEQ_LENGTH,
)


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [None]:
# Show both tokenizer outputs in one call; sep="\n" reproduces the line breaks
# that three separate print() calls would emit.
print(
    f"query_inputs:\n{query_inputs}",
    "\n\n\n",
    f"passage_inputs:\n{passage_inputs}",
    sep="\n",
)

query_inputs:
{'input_ids': tensor([[    2,  2616, 13141,  ...,     0,     0,     0],
        [    2,  2447,    17,  ...,     0,     0,     0],
        [    2,  4604,  2024,  ...,     0,     0,     0],
        ...,
        [    2,  7409, 13569,  ...,     0,     0,     0],
        [    2, 14948,  2960,  ...,     0,     0,     0],
        [    2, 15924,  2912,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}




passage_inputs:
{'input_ids': tensor([[    2,  3525,  5024,  ...,     0,     0,     0],
        [    2,  8636,    16,  ...,     0, 

In [None]:
import torch
from hypencoder_cb.modeling.q_net import NoTorchSequential, NoTorchLinear, NoTorchDenseBlock

def batched_encode_qnets(encoder, inputs, batch_size=2, device="cuda"):
    """
    Encode queries into q-nets batch by batch and merge them into one q-net.

    The encoder is run on slices of ``inputs``. On the first batch, one CPU
    tensor per layer is pre-allocated with room for every query; each batch's
    weights and biases are then copied into their slice. This avoids keeping a
    list of per-batch chunks and concatenating them at the end.

    Args:
        encoder: Callable model exposing ``.to(device)`` and ``.eval()``. Called
            with ``input_ids`` and ``attention_mask``, it must return an object
            whose ``.representation`` has ``.layers`` and ``.num_queries``.
        inputs: Mapping of tokenizer outputs. Only ``input_ids`` and
            ``attention_mask`` are forwarded to the encoder.
        batch_size: Number of queries encoded per forward pass.
        device: Device the encoder and each batch are moved to.

    Returns:
        A ``NoTorchSequential`` whose layers hold the weights and biases of all
        queries, stored on the CPU.

    Raises:
        TypeError: If a q-net layer is neither ``NoTorchLinear`` nor
            ``NoTorchDenseBlock``.
    """
    valid_keys = ("input_ids", "attention_mask")
    num_total_queries = len(inputs["input_ids"])

    # Three parallel lists with exactly one entry per q-net layer. Indexing all
    # of them by layer position keeps weights, biases and configs aligned even
    # when some layers have no bias.
    layer_configs = []
    final_weights = []
    final_biases = []  # None marks a layer without bias

    encoder.to(device)
    encoder.eval()

    with torch.no_grad():
        for start_idx in range(0, num_total_queries, batch_size):
            # Only slice and transfer the tensors the encoder actually consumes.
            batch = {
                k: v[start_idx:start_idx + batch_size].to(device)
                for k, v in inputs.items()
                if k in valid_keys
            }
            q_net_batch = encoder(**batch).representation

            # 1. On the FIRST batch, allocate the full-sized CPU tensors using
            #    the shapes and dtypes observed in that batch.
            if start_idx == 0:
                for layer in q_net_batch.layers:
                    linear_part = _linear_part_of(layer)
                    layer_configs.append(_describe_layer(layer))
                    final_weights.append(
                        _allocate_for_all_queries(linear_part.weight, num_total_queries)
                    )
                    final_biases.append(
                        None
                        if linear_part.bias is None
                        else _allocate_for_all_queries(linear_part.bias, num_total_queries)
                    )

            # 2. Copy this batch into its slice. num_queries comes from the
            #    q-net itself, so a short final batch is handled correctly.
            end_idx = start_idx + q_net_batch.num_queries
            for layer_idx, layer in enumerate(q_net_batch.layers):
                linear_part = _linear_part_of(layer)
                final_weights[layer_idx][start_idx:end_idx] = linear_part.weight.cpu()
                if final_biases[layer_idx] is not None:
                    final_biases[layer_idx][start_idx:end_idx] = linear_part.bias.cpu()

            # Drop references to device tensors before the next batch.
            del batch, q_net_batch
            torch.cuda.empty_cache()

    # 3. Rebuild a single q-net from the filled tensors (no torch.cat needed).
    final_layers = []
    for config, weight, bias in zip(layer_configs, final_weights, final_biases):
        if config['type'] == 'linear':
            # Pass the collected bias through instead of always dropping it.
            final_layers.append(NoTorchLinear(weight=weight, bias=bias))
        else:
            final_layers.append(
                NoTorchDenseBlock(
                    weight=weight,
                    bias=bias,
                    activation=config['activation'],
                    do_layer_norm=config['do_layer_norm'],
                    do_residual=config['do_residual'],
                    do_dropout=False,
                    dropout_prob=config['dropout_prob'],
                    layer_norm_before_residual=config['layer_norm_before_residual'],
                )
            )

    return NoTorchSequential(final_layers, num_queries=num_total_queries)


def _linear_part_of(layer):
    """Return the object carrying ``weight``/``bias`` for a supported q-net layer."""
    if isinstance(layer, NoTorchLinear):
        return layer
    if isinstance(layer, NoTorchDenseBlock):
        return layer.linear
    raise TypeError(f"Unsupported q-net layer type: {type(layer).__name__}")


def _describe_layer(layer):
    """Collect the non-tensor settings needed to rebuild ``layer`` later."""
    if isinstance(layer, NoTorchLinear):
        return {'type': 'linear'}
    return {
        'type': 'dense',
        'activation': layer.activation,
        'do_layer_norm': layer.do_layer_norm,
        'do_residual': layer.do_residual,
        'dropout_prob': layer.dropout_prob,
        'layer_norm_before_residual': layer.layer_norm_before_residual,
    }


def _allocate_for_all_queries(batch_tensor, num_total_queries):
    """Allocate an uninitialised CPU tensor shaped like ``batch_tensor`` but with
    its first dimension widened to ``num_total_queries``."""
    return torch.empty(
        (num_total_queries, *batch_tensor.shape[1:]),
        dtype=batch_tensor.dtype,
        device='cpu',
    )

In [None]:
# Encode every tokenized mention into a q-net, using the function's defaults
# (batch_size=2, device="cuda").
q_nets = batched_encode_qnets(query_encoder, query_inputs)

In [None]:
# batch_size=2
# for i in range(0, len(queries), batch_size):
#     batch = query_inputs[i:i+batch_size]
#     q_nets = query_encoder(input_ids=query_inputs["input_ids"], attention_mask=query_inputs["attention_mask"]).representation


In [None]:
import torch

def batched_encode_passages(encoder, inputs, batch_size=8, device="cuda"):
    """
    Encode tokenized passages in fixed-size batches and return one tensor.

    Each slice of ``inputs`` is moved to ``device`` and reduced to the
    ``input_ids`` / ``attention_mask`` entries before being passed to
    ``encoder``. The ``.representation`` of every forward pass is copied to the
    CPU, and the pieces are concatenated along dimension 0.
    """
    model_input_names = ["input_ids", "attention_mask"]
    embedding_chunks = []

    encoder.to(device)
    encoder.eval()
    with torch.no_grad():
        for offset in range(0, len(inputs["input_ids"]), batch_size):
            window = slice(offset, offset + batch_size)

            # Slice every field and move it to the target device ...
            on_device = {name: tensor[window].to(device) for name, tensor in inputs.items()}
            # ... then keep only the fields the encoder is called with.
            model_kwargs = {
                name: tensor
                for name, tensor in on_device.items()
                if name in model_input_names
            }

            # The encoder output exposes the embeddings as `.representation`.
            chunk = encoder(**model_kwargs).representation
            embedding_chunks.append(chunk.cpu())

            # Release device tensors before the next batch.
            del on_device, model_kwargs, chunk
            torch.cuda.empty_cache()

    return torch.cat(embedding_chunks, dim=0)

In [None]:
# from tqdm import tqdm

# def batched_encode_passages(encoder ,passages):
#   batch_size=16
#   # input all names to model so it will create dense vectors of all the names (of the mentions and entities)
#   entity_name_features = []

#   for i in tqdm(range(0, len(passages), batch_size), desc="Extracting features"):
#       # extract entity features
#       print(passages[0])
#       batch = passages[i:i + batch_size]
#       print(batch[0])
#       features = encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).representation
#       entity_name_features.extend(features)

#   return entity_name_features

In [None]:
# Encode every tokenized entity name into an embedding, using the function's
# defaults (batch_size=8, device="cuda"). The result is a single CPU tensor.
passage_embeddings = batched_encode_passages(passage_encoder, passage_inputs)

In [None]:
# unbatched encoding (bad for RAM!)

# creates a q_net per query (each query is transformed into a separate q_net)
# q_nets = query_encoder(input_ids=query_inputs["input_ids"], attention_mask=query_inputs["attention_mask"]).representation
# passage_embeddings = passage_encoder(input_ids=passage_inputs["input_ids"], attention_mask=passage_inputs["attention_mask"]).representation


In [None]:
# Sanity-check both encodings before scoring.
print(q_nets) # a NoTorchSequential object holding the q-nets (no tensor repr, just the object)
print(passage_embeddings.shape) # this is a tensor; the output below shows (100, 768)
print(passage_embeddings)

<hypencoder_cb.modeling.q_net.NoTorchSequential object at 0x7b3259fce060>
torch.Size([100, 768])
tensor([[ 0.0039, -0.1667, -0.5780,  ...,  0.4420,  0.1888,  0.0720],
        [-0.0852,  0.0449, -0.2686,  ...,  0.2474,  0.1347,  0.0272],
        [ 0.0132,  0.0376, -0.1731,  ...,  0.2352,  0.1907,  0.1369],
        ...,
        [ 0.0086,  0.0079, -0.2292,  ...,  0.1822,  0.1295,  0.1334],
        [ 0.0096,  0.1095, -0.1781,  ...,  0.2321,  0.0855,  0.2054],
        [-0.1522,  0.0057, -0.2576,  ...,  0.1721,  0.1481,  0.1078]])


In [None]:
# Sizes used to shape the q-net input:
#   N - number of queries (mentions)
#   M - number of passages (entities)
#   H - hidden dimension (e.g., 768)
N = query_inputs["input_ids"].size(0)
M, H = passage_embeddings.size(0), passage_embeddings.size(-1)

# q_nets expect an input of shape (N, M, H), while passage_embeddings is
# currently (M, H), so it has to be reshaped before scoring.

In [None]:
# Case 1 - comparing a query to its respective passage

# In the simple case where each q_net only takes one passage, we can just
# reshape the passage_embeddings to (N, 1, H).
# passage_embeddings_single = passage_embeddings.unsqueeze(1)
# print(f"passage_embeddings shape: {passage_embeddings_single.shape}")
# giving the neural network the passage_embeddings as input
# the output provides the relevance score of query 1 against passage 1, query 2 against passage 2, etc.
# scores = q_nets(passage_embeddings_single)
# print(f"scores: {scores}")

In [None]:
# Case 2 - comparing a query to all passages

# The case where each q_net takes multiple passages
# meaning multiple passages are now associated with each of the queries
# repeat(N, 1) stacks N copies of the (M, H) matrix into (N*M, H); reshape then
# groups them so that slice [i] holds all M passage embeddings for query i.
# This materialises N*M*H values in memory.
passage_embeddings_multi = passage_embeddings.repeat(N, 1).reshape(N, M, H)
print(f"passage_embeddings shape: {passage_embeddings_multi.shape}")
# Calling the q-nets on the (N, M, H) input yields one score per
# (query, passage) pair; the printed shape below is (N, M, 1).
similarity_scores = q_nets(passage_embeddings_multi)
print(f"similarity_scores shape: {similarity_scores.shape}")
#print(f"similarity_scores: {similarity_scores}")


passage_embeddings shape: torch.Size([100, 100, 768])
similarity_scores shape: torch.Size([100, 100, 1])


In [None]:
# Display the raw score tensor.
# NOTE(review): every value visible in the summarised output below is 0., yet
# evaluate() later reports some top matches that are not passage 0, so the
# scores cannot all be identical. Inspect the score distribution (e.g. min /
# max / count of non-zero entries) before trusting the accuracy - TODO confirm.
similarity_scores

tensor([[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        ...,

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]]])

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def evaluate(train_pairs, scores=None):
    """
    Print top-1 predictions and accuracy for mention -> entity matching.

    For each mention i, the predicted entity is the passage with the highest
    score in row i. A prediction counts as correct when the id at the predicted
    position equals the id at position i.

    Args:
        train_pairs: Column-oriented mapping with "id", "mention" and "entity"
            sequences of equal length.
        scores: Score tensor with one row per mention and the passages along
            dimension 1 (the notebook produces shape (N, M, 1)). Defaults to the
            notebook-level ``similarity_scores`` when omitted.

    Raises:
        ValueError: If the number of score rows differs from the number of
            mentions in ``train_pairs``.
    """
    if scores is None:
        # Backward-compatible fallback to the notebook global.
        scores = similarity_scores

    # Read each column once instead of re-indexing train_pairs per iteration.
    ids = train_pairs["id"]
    mentions = train_pairs["mention"]
    entities = train_pairs["entity"]
    total = len(mentions)

    # argmax over the passage dimension; tolist() yields plain Python ints, so
    # no per-element int() conversion of tensor indices is needed.
    top_idxs = torch.argmax(scores, dim=1).flatten().tolist()
    if len(top_idxs) != total:
        raise ValueError(
            f"scores has {len(top_idxs)} rows but train_pairs has {total} mentions"
        )

    correct_count = 0
    for i, top_idx in enumerate(top_idxs):
        if ids[top_idx] == ids[i]:
            correct_count += 1

        mention_name = mentions[i]
        top_match = entities[top_idx]
        correct_name = entities[i]
        print(f"mention_name: {mention_name}\ncorrect entity name: {correct_name}\ntop_match: {top_match}\n")

    # Avoid ZeroDivisionError when there is nothing to evaluate.
    accuracy = correct_count / total if total else 0.0

    print(f"total comparisons: {total}")
    print(f"correct comparisons: {correct_count}")
    print(f"accuracy: {accuracy}")

In [None]:
# Print the top-1 match for every sampled mention, followed by the accuracy.
evaluate(train_pairs)

mention_name: human immunodeficiency virus
correct entity name: HIV Infections
top_match: HIV Infections

mention_name: non-Hodgkin's lymphoma
correct entity name: Lymphoma, Non-Hodgkin
top_match: HIV Infections

mention_name: renal cell carsinom
correct entity name: Carcinoma, Renal Cell
top_match: HIV Infections

mention_name: Duchenne dystrophy
correct entity name: Muscular Dystrophy, Duchenne
top_match: N-methyltropan-3-yl 2-(4-bromophenyl)propionate

mention_name: Acute psychosis
correct entity name: Psychoses, Substance-Induced
top_match: HIV Infections

mention_name: urethane
correct entity name: Urethane
top_match: Urethane

mention_name: bronchitis
correct entity name: Bronchitis
top_match: HIV Infections

mention_name: spasticity
correct entity name: Muscle Spasticity
top_match: HIV Infections

mention_name: PG-9
correct entity name: N-methyltropan-3-yl 2-(4-bromophenyl)propionate
top_match: HIV Infections

mention_name: neurotoxic
correct entity name: Neurotoxicity Syndromes

In [None]:
# more evaluation methods
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
# print(f"{accuracy_score(train_labels, predicted_labels)=:.3f}")
# print(f"{recall_score(train_labels, predicted_labels)=:.3f}")
# print(f"{precision_score(train_labels, predicted_labels)=:.3f}")
# print(f"{f1_score(train_labels, predicted_labels)=:.3f}")