In [19]:
from biosyn.dataloader import load_dictionary, load_queries
from transformers import AutoModel, AutoTokenizer
import torch
TRAIN_DICT_PATH = "./data/data-ncbi-fair/train_dictionary.txt"
TRAIN_DIR = "./data/data-ncbi-fair/traindev"

train_dictionary  = load_dictionary(dict_path=TRAIN_DICT_PATH)
train_queries  = load_queries(data_dir=TRAIN_DIR, filter_composite=False, filter_duplicates=False, filter_cuiless=True)

train_dictionary = train_dictionary[:100]
train_queries = train_queries[:10]

tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
encoder = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.1')


max_length = 25


query_names = [row[0] for row in train_queries]
dict_names = [row[0] for row in train_dictionary]
topk = 4

100%|██████████| 90599/90599 [00:00<00:00, 1542375.78it/s]
100%|██████████| 691/691 [00:00<00:00, 5135.14it/s]


In [25]:
import numpy as np
class CandidateDataset(torch.utils.data.Dataset):
    def __init__(self, queries, dicts, tokenizer, max_length, topk, pre_tokenize):
        """
        Retrieve top-k candidates based on dense embedding
        Parameters
        ----------
        queries : list
            A list of tuples (name, id)
        dicts : list
            A list of tuples (name, id)
        tokenizer : BertTokenizer
            A BERT tokenizer for dense embedding
        topk : int
            The number of candidates
        """
        self.query_names, self.query_ids = [row[0] for row in queries], [row[1] for row in queries]
        self.dict_names, self.dict_ids = [row[0] for row in dicts], [row[1] for row in dicts]
        self.tokenizer= tokenizer
        self.max_length = max_length
        self.topk = topk
        self.d_cand_idxs = None
        self.pre_tokenize = pre_tokenize
        if pre_tokenize:
            all_query_names_tokens = self.tokenizer(self.query_names, max_length=max_length,padding='max_length', truncation=True, return_tensors='pt' )
            self.all_query_names_tokens = [
                {
                    "input_ids": all_query_names_tokens["input_ids"][idx],
                    "attention_mask": all_query_names_tokens["attention_mask"][idx],
                } for  idx in range(len(all_query_names_tokens["input_ids"]))]

            self.all_dict_names_tokens= self.tokenizer(self.dict_names, max_length=max_length,padding='max_length', truncation=True, return_tensors='pt')

    def set_dense_candidate_idxs(self, d_cand_idxs):
        self.d_cand_idxs = d_cand_idxs

    def __getitem__(self, query_idx):
        """
            Return (query_tokens, cand_tokens), labels
            query_tokens: tokenized the query_name (query_name is query_names[query_idx] the specific mention)
            cand_tokens: 
        """
        assert (self.d_cand_idxs is not None)

        if self.pre_tokenize:
            query_tokens = self.all_query_names_tokens[query_idx]
        else:
            query_name = self.query_names[query_idx]
            query_tokens = self.tokenizer(query_name, max_length=self.max_length,padding='max_length', truncation=True, return_tensors='pt' )

        d_cand_idxs = self.d_cand_idxs[query_idx]
        topk_candidate_idx = np.array(d_cand_idxs)

        assert len(topk_candidate_idx) == self.topk
        assert len(topk_candidate_idx) == len(set(topk_candidate_idx))


        if self.pre_tokenize:
            cand_idxs_tensor = torch.as_tensor(topk_candidate_idx, dtype=torch.long)
            cand_tokens = {
                k: v.index_select(0, cand_idxs_tensor)
                for k, v in self.all_dict_names_tokens.items()
                if isinstance(v, torch.Tensor)
            }
        else:
            cand_names = [self.dict_names[cand_idx] for cand_idx in topk_candidate_idx]
            cand_tokens = self.tokenizer(cand_names, max_length=self.max_length, padding="max_length" , truncation=True, return_tensors="pt")

        labels = self.get_labels(query_idx, topk_candidate_idx).astype(np.float32)

        return (query_tokens, cand_tokens), labels


    def __len__(self):
        return len(self.query_names)

    def check_label(self, query_id, candidate_id_set):
        """
            check if all q_id in query_id.split("|") exists in candidate_id_set 
        """
        label = 0
        query_ids = query_id.split("|")
        for q_id in query_ids:
            if q_id in candidate_id_set:
                label = 1
                continue
            else:
                label = 0
                break
        return label
    
    def get_labels(self, query_idx, candidate_idxs):
        labels = np.array([])
        query_id = self.query_ids[query_idx]
        candidate_ids = np.array(self.dict_ids)[candidate_idxs]
        for candidate_id in candidate_ids:
            label = self.check_label(query_id, candidate_id)
            labels = np.append(labels, label)
        return labels


In [26]:
# cand_idxs = np.array([
#     np.random.choice(len(train_dictionary), size=topk, replace=False  )
#     for _ in range(len(train_queries))
# ], dtype=np.int64)

cand_idxs = np.array([[ 6, 58, 40, 55],
       [60, 23, 21, 56],
       [25, 85,  3, 74],
       [64, 32, 14, 86],
       [14, 99, 90, 10],
       [72, 37, 56, 83],
       [14, 13, 44, 20],
       [83, 11, 42, 41],
       [11, 61, 54, 68],
       [62, 32, 84, 90]])



In [27]:
train_set = CandidateDataset(
        queries = train_queries, 
        dicts = train_dictionary, 
        tokenizer = tokenizer, 
        max_length = max_length, 
        topk= topk,
        pre_tokenize = True
    )



train_set.set_dense_candidate_idxs(cand_idxs)

In [28]:
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=10, shuffle=False
)

In [32]:
def model(x):
    

In [34]:

for i, data in enumerate(train_loader):
    x, batch_y = data
    query_token, candidate_tokens = x
    batch_size, topk, max_length = candidate_tokens['input_ids'].shape
    query_embed = encoder(
            input_ids=query_token["input_ids"].squeeze(1),
            attention_mask=query_token["attention_mask"].squeeze(1),
        )
    query_embed = query_embed[0][:,0].unsqueeze(1) #(B, 1, H) ??????? WHY [0] last hidden state


    candidate_embeds = encoder(
        input_ids=candidate_tokens['input_ids'].reshape(-1, max_length),
        attention_mask=candidate_tokens['attention_mask'].reshape(-1, max_length)
    )
    candidate_embeds = candidate_embeds[0][:,0].reshape(batch_size, topk, -1) # [batch_size, topk, hidden]


    score_1 = torch.bmm(query_embed, candidate_embeds.permute(0,2,1)).squeeze(1)
    print(f"score: {score_1} " )


score: tensor([[144.6959, 144.7281, 153.0147, 152.2294],
        [148.9545, 148.6415, 153.7524, 156.2342],
        [142.7202, 148.5990, 134.1674, 144.6824],
        [152.7010, 152.6938, 143.5929, 148.1381],
        [141.6751, 152.5371, 150.6783, 141.5227],
        [158.3967, 151.0699, 151.5319, 156.7922],
        [143.5973, 144.2186, 149.0260, 142.3616],
        [144.0762, 157.9371, 151.7619, 154.5725],
        [145.2944, 139.2736, 142.4796, 151.8815],
        [140.9750, 143.5386, 144.1806, 143.8584]], grad_fn=<SqueezeBackward1>) 


In [40]:

for i, data in enumerate(train_loader):
    x, batch_y = data
    query_token, candidate_tokens = x
    batch_size, topk, max_length = candidate_tokens['input_ids'].shape
    query_embed = encoder(**query_token, return_dict=True)
    query_embed = query_embed.last_hidden_state[:, 0, :].unsqueeze(1) #(B, 1, H


    candidate_embeds = encoder(
        input_ids=candidate_tokens['input_ids'].reshape(-1, max_length),
        attention_mask=candidate_tokens['attention_mask'].reshape(-1, max_length)
    )
    candidate_embeds = candidate_embeds[0][:,0].reshape(batch_size, topk, -1) # [batch_size, topk, hidden]


    score_2 = torch.bmm(query_embed, candidate_embeds.permute(0,2,1)).squeeze(1)
    print(f"score: {score_2} " )


score: tensor([[144.6959, 144.7281, 153.0147, 152.2294],
        [148.9545, 148.6415, 153.7524, 156.2342],
        [142.7202, 148.5990, 134.1674, 144.6824],
        [152.7010, 152.6938, 143.5929, 148.1381],
        [141.6751, 152.5371, 150.6783, 141.5227],
        [158.3967, 151.0699, 151.5319, 156.7922],
        [143.5973, 144.2186, 149.0260, 142.3616],
        [144.0762, 157.9371, 151.7619, 154.5725],
        [145.2944, 139.2736, 142.4796, 151.8815],
        [140.9750, 143.5386, 144.1806, 143.8584]], grad_fn=<SqueezeBackward1>) 


In [41]:
torch.eq(score_1, score_2)

tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])