In [2]:
from biosyn.dataloader import load_dictionary, load_queries
from transformers import AutoModel, AutoTokenizer
import torch
TRAIN_DICT_PATH = "./data/data-ncbi-fair/train_dictionary.txt"
TRAIN_DIR = "./data/data-ncbi-fair/traindev"

train_dictionary  = load_dictionary(dict_path=TRAIN_DICT_PATH)
train_queries  = load_queries(data_dir=TRAIN_DIR, filter_composite=False, filter_duplicates=False, filter_cuiless=True)

# train_dictionary = train_dictionary[:50]
train_queries = train_queries[:10]

tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.1')
encoder = AutoModel.from_pretrained('dmis-lab/biobert-base-cased-v1.1')


max_length = 25


query_names, query_ids = [row[0] for row in train_queries], [row[1] for row in train_queries]
dict_names, dict_ids = [row[0] for row in train_dictionary], [row[1] for row in train_dictionary]
topk = 4

  0%|          | 0/90599 [00:00<?, ?it/s]

100%|██████████| 90599/90599 [00:00<00:00, 1439213.08it/s]
100%|██████████| 691/691 [00:00<00:00, 12735.26it/s]


In [3]:
import numpy as np
m = len(query_names)
n = len(dict_names)
cand_idxs = np.random.randint(0,n, size=(m, topk))
for c_idx, q_cands in enumerate(cand_idxs):
    correct_dict_idxs = [idx  for idx, id in enumerate(dict_ids) if query_ids[c_idx] in id]
    if len(correct_dict_idxs) > 0:
        q_cands[0] = correct_dict_idxs[0]

cand_idxs

array([[ 1975,  6319, 20703, 72395],
       [ 1975, 12987, 18556, 21810],
       [18118, 67536,  4804, 20365],
       [18118, 74007, 20594, 13752],
       [59360, 64499, 45507, 56247],
       [18244, 78950, 29542, 32328],
       [18192,  6740, 83380, 10547],
       [18244,  3215, 15565, 58657],
       [18244, 45008, 31817, 88936],
       [18192, 62700, 28704, 28204]], dtype=int32)

In [4]:
dict_id_sets = [set(s.split("|")) if isinstance(s, str) else set(s) for s in dict_ids]
query_id_tokens = [tuple(q.split("|")) if isinstance(q, str) else tuple(q) for q in query_ids]

labels_per_query = []
for q_idx, q_cand_idxs in enumerate(cand_idxs):
    q_id_tokens = query_id_tokens[q_idx]
    labels = np.fromiter(
        (1.0 if all(tok in dict_id_sets[i] for tok in q_id_tokens) else 0.0 for i in q_cand_idxs),
        dtype=np.float32,
        count=len(q_cand_idxs)
    )
    labels_per_query.append(labels)


In [5]:
labels_per_query

[array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32)]

In [8]:
def check_label(query_id, candidate_id_set):
    """
        check if all q_id in query_id.split("|") exists in candidate_id_set 
    """
    label = 0
    query_ids = query_id.split("|")
    for q_id in query_ids:
        if q_id in candidate_id_set:
            label = 1
            continue
        else:
            label = 0
            break
    return label

def get_labels(query_idx, candidate_idxs):
    labels = np.array([])
    query_id = query_ids[query_idx]
    candidate_ids = np.array(dict_ids)[candidate_idxs]
    for candidate_id in candidate_ids:
        label = check_label(query_id, candidate_id)
        labels = np.append(labels, label)
    return labels


In [10]:
all_query_labels = []
for query_idx, topk_ids in enumerate(query_ids):
    d_cand_idxs = cand_idxs[query_idx]
    topk_candidate_idx = np.array(d_cand_idxs)
    labels = get_labels(query_idx, topk_candidate_idx).astype(np.float32)
    all_query_labels.append(labels)

In [11]:
all_query_labels

[array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32),
 array([1., 0., 0., 0.], dtype=float32)]