In [1]:
from biosyn.dataloader import load_dictionary, load_queries
from transformers import AutoModel, AutoTokenizer
import torch
TRAIN_DICT_PATH = "./data/data-ncbi-fair/train_dictionary.txt"
TRAIN_DIR = "./data/data-ncbi-fair/traindev"

train_dictionary  = load_dictionary(dict_path=TRAIN_DICT_PATH)
train_queries  = load_queries(data_dir=TRAIN_DIR, filter_composite=False, filter_duplicates=False, filter_cuiless=True)

train_dictionary = train_dictionary[:50]
train_queries = train_queries[:10]

tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
encoder = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.1')


max_length = 25


query_names = [row[0] for row in train_queries]
dict_names = [row[0] for row in train_dictionary]
topk = 4

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 90599/90599 [00:00<00:00, 1436449.35it/s]
100%|██████████| 691/691 [00:00<00:00, 6252.94it/s]


In [4]:
import numpy as np
from torch.utils.data import DataLoader
from transformers import default_data_collator

class NamesDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self,idx):
        return {key: torch.tensor(val[idx]) for key,val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)


def embed_dense(names):
        encoder.eval()
        batch_size = 128
        dense_embeds = []
        if isinstance(names, np.ndarray):
            names = names.tolist()
        name_encodings = tokenizer(names, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")

        name_dataset = NamesDataset(name_encodings)
        name_dataloader = DataLoader(name_dataset, shuffle=False, collate_fn=default_data_collator, batch_size=batch_size)

        with torch.no_grad():
            for batch in name_dataloader:
                outputs = encoder(**batch)
                batch_dense_embeds = outputs[0][:,0].cpu().detach().numpy() # [CLS] representations
                dense_embeds.append(batch_dense_embeds)
        dense_embeds = np.concatenate(dense_embeds, axis=0)
        return dense_embeds



In [5]:
dict_embs = embed_dense(names=dict_names).astype("float32")



  return {key: torch.tensor(val[idx]) for key,val in self.encodings.items()}


In [26]:
topk_cand_idxs = [0,10,5,3]
cands_embs_1 = dict_embs[topk_cand_idxs]
cands_embs_1 = torch.from_numpy( cands_embs_1.astype(np.float32, copy=False) )
cands_embs_1.shape


torch.Size([4, 768])

In [None]:

        # candidate_embeds = self.dict_embs[topk_cand_idxs]
        B, K  = topk_cand_idxs.shape
        flat = topk_cand_idxs.reshape(-1)
        candidate_embs = self.dict_embs.index_select(0, flat).reshape(B, K, -1)


In [13]:
all_dict_names_tokens= tokenizer(dict_names, max_length=max_length,padding='max_length', truncation=True, return_tensors='pt')



cand_idxs_tensor = torch.as_tensor(topk_cand_idxs, dtype=torch.long)
cand_tokens = {
    k: v.index_select(0, cand_idxs_tensor)
    for k, v in all_dict_names_tokens.items()
    if isinstance(v, torch.Tensor)
}

candidate_embeds = encoder(
            input_ids=cand_tokens['input_ids'].reshape(-1, max_length),
            attention_mask=cand_tokens['attention_mask'].reshape(-1, max_length)
        )
cand_embs_2 = candidate_embeds[0][:,0].reshape(topk, -1) # [topk, hidden]
cand_embs_2.shape

torch.Size([4, 768])

In [27]:
are_equal_elementwise = torch.all(torch.tensor(cands_embs_1) == cand_embs_2)
are_equal_elementwise

  are_equal_elementwise = torch.all(torch.tensor(cands_embs_1) == cand_embs_2)


tensor(False)

In [28]:
import torch
a = torch.tensor(cands_embs_1)
b= cand_embs_2

# a, b: shape (4, 768)

# 1) Sanity checks
print(a.shape == b.shape, a.dtype, b.dtype)

# 2) Elementwise difference mask (exact equality for non-floats)
if torch.is_floating_point(a) and torch.is_floating_point(b):
    # treat NaNs at same spots as equal; flag others via isclose
    same_nan = torch.isnan(a) & torch.isnan(b)
    close = torch.isclose(a, b, rtol=0.0, atol=0.0)  # exact values for floats
    diff_mask = ~(close | same_nan)
else:
    diff_mask = a != b

num_diff = diff_mask.sum().item()
print(f"Number of differing elements: {num_diff}")

# 3) Indices of differences
idx = diff_mask.nonzero(as_tuple=False)  # shape (num_diff, 2), columns: [row, col]
print("First few differing indices (row, col):")
print(idx[:10])

# 4) Inspect values at those spots
for r, c in idx[:5]:  # show first 5
    r = r.item(); c = c.item()
    print(f"[{r}, {c}]  a={a[r, c].item()}   b={b[r, c].item()}")

# 5) Per-row counts (handy for (4, 768))
per_row = diff_mask.sum(dim=1)
print("Differences per row:", per_row.tolist())

# 6) Magnitude summary (useful for floats)
if torch.is_floating_point(a) and torch.is_floating_point(b):
    max_abs = (a - b).abs().max().item()
    print("Max |a-b|:", max_abs)


True torch.float32 torch.float32
Number of differing elements: 2826
First few differing indices (row, col):
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [0, 5],
        [0, 6],
        [0, 7],
        [0, 8],
        [0, 9]])
[0, 0]  a=0.37157875299453735   b=0.3715788722038269
[0, 1]  a=0.23373164236545563   b=0.23373158276081085
[0, 2]  a=-0.2861679196357727   b=-0.2861679494380951
[0, 3]  a=-0.04448281601071358   b=-0.044482868164777756
[0, 4]  a=-0.05088842287659645   b=-0.050888314843177795
Differences per row: [719, 724, 673, 710]
Max |a-b|: 9.5367431640625e-07


  a = torch.tensor(cands_embs_1)
