In [6]:
from helpers.Data import load_queries
from config import GlobalConfig
import torch
from transformers import AutoTokenizer


cfg = GlobalConfig()

tokenize_batch_size = cfg.tokenize.tokenize_batch_size
dictionary_max_length = cfg.tokenize.dictionary_max_length
queries_max_length = cfg.tokenize.queries_max_length
dictionary_max_chars_length = cfg.tokenize.dictionary_max_chars_length

mention_start_special_token = cfg.tokenize.special_tokens_dict["mention_start"]
mention_end_special_token = cfg.tokenize.special_tokens_dict["mention_end"]

use_cuda = torch.cuda.is_available()
device = "cuda"    if use_cuda else "cpu"

tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_name, use_fast=True)
tokenizer.add_special_tokens(cfg.tokenize.special_tokens)


train_queries = load_queries(
            data_dir=cfg.paths.queries_raw_dir,
            queries_max_length=queries_max_length,
            special_token_start=mention_start_special_token ,
            special_token_end=mention_end_special_token,
            tokenizer=tokenizer)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 691/691 [00:20<00:00, 33.52it/s]

annotation_skipped: 13





In [11]:
from helpers.Data import load_dictionary
dictionary_max_chars_length = cfg.tokenize.dictionary_max_chars_length

_, dictionary_cuis, dictionary_names_annotated = load_dictionary(cfg.paths.dictionary_raw_path, 
                                     special_token_start=mention_start_special_token, 
                                     special_token_end=mention_end_special_token,
                                    dictionary_max_chars_length=dictionary_max_chars_length,
                                    add_synonyms=True
                                     )


pre process dictionary: 100%|██████████| 90599/90599 [00:00<00:00, 786641.47it/s]
annotating dictionary: 100%|██████████| 90510/90510 [00:00<00:00, 191998.07it/s]


In [13]:
dictionary_names_annotated[20:25]

['[MS] 16p11.2 deletion syndrome [ME] ',
 '[MS] 17,20-lyase deficiency, isolated [ME]  ; 17-alpha-hydroxylase-17,20-lyase deficiency, combined complete ; 17-alpha-hydroxylase-17,20-lyase deficiency, combined partial',
 '[MS] 17-alpha-hydroxylase-17,20-lyase deficiency, combined complete [ME]  ; 17,20-lyase deficiency, isolated ; 17-alpha-hydroxylase-17,20-lyase deficiency, combined partial',
 '[MS] 17-alpha-hydroxylase-17,20-lyase deficiency, combined partial [ME]  ; 17,20-lyase deficiency, isolated ; 17-alpha-hydroxylase-17,20-lyase deficiency, combined complete',
 '[MS] 17-hydroxysteroid dehydrogenase deficiency [ME]  ; 17-ketosteroid reductase deficiency of testis ; 17-beta hydroxysteroid dehydrogenase 3 deficiency ; neutral 17-beta-hydroxysteroid oxidoreductase deficiency ; pseudohermaphroditism, male, with gynecomastia polycystic ovary syndrome due to 17-ketosteroid reductase deficiency, included ; 17 alpha ketosteroid reductase deficiency of testis']

In [111]:

query_idx = 1
hard_positives_num = 2
hard_negatives_num = 1
query_cui = queries_cuis[query_idx]
current_query_candidates_idxs = new_cands[query_idx].tolist()
current_candidates_cuis = dictionary_cuis[current_query_candidates_idxs]


positive_positions = np.where(current_candidates_cuis == query_cui)[0]
candidates_idxs_available = list(set(range(topk))  - set(positive_positions)  )
positive_candidates_indexes = dictionary_cui_to_idx.get(query_cui, [])
available_positives = list(set(positive_candidates_indexes) - set(current_query_candidates_idxs))


if available_positives:
    # how many positives we will inject, in case available are less than the one in config
    positive_n = min(hard_positives_num, len(available_positives))
    #  random positive candidates, to choose from available positives (index of dictionary_cui)
    positive_candidates = np.array(available_positives[:positive_n])
    # random indexes in candidate list to be replaced
    candidates_idxs_to_be_replaced = candidates_idxs_available[:positive_n]
    new_cands_2[query_idx, candidates_idxs_to_be_replaced] = torch.from_numpy(positive_candidates)



new_cands_2[query_idx]

tensor([ 3, 11,  6,  1, 12])

In [112]:
import torch

query_cui = queries_cuis[query_idx]
current_idxs = np.array(new_cands[query_idx])
current_cuis = dictionary_cuis[current_idxs]

negative_mask = (current_cuis != query_cui)
available_positions = np.flatnonzero(negative_mask)

pos_dict_idxs = dictionary_cui_to_idx.get(query_cui, [])
available_pos_dict = np.setdiff1d(pos_dict_idxs, current_idxs, assume_unique=False)

pos_n = min(hard_positives_num,
            len(available_pos_dict),
            len(available_positions))
chosen_pos_dict = available_pos_dict[:pos_n]
chosen_slots = available_positions[:pos_n]
new_cands_1[query_idx, chosen_slots] = torch.from_numpy(chosen_pos_dict)
available_positions = np.setdiff1d(available_positions, chosen_slots, assume_unique=False)


new_cands_1[query_idx]

tensor([ 3, 11,  6,  1, 12])

In [74]:
import torch

query_idx = 1
hard_positives_num = 2
hard_negatives_num = 1
query_cui = queries_cuis[query_idx]
current_query_candidates_idxs = new_cands[query_idx].tolist()
current_candidates_cuis = dictionary_cuis[current_query_candidates_idxs]
positive_positions = np.where(current_candidates_cuis == query_cui)[0]

candidates_idxs_to_be_replaced = np.array([])
candidates_idxs_available = list(set(range(topk))  - set(positive_positions)  )

print(f"current candidates are: {new_cands[query_idx]}")

positive_candidates_indexes = dictionary_cui_to_idx.get(query_cui, [])
if len(positive_candidates_indexes) > 0:
    available_positives = list(set(positive_candidates_indexes) - set(current_query_candidates_idxs))
    if available_positives:
        positive_n = min(hard_positives_num, len(available_positives))
        positive_candidates = np.random.choice(available_positives, size=positive_n, replace=False)

        candidates_available_idxs = candidates_idxs_available
        candidates_idxs_to_be_replaced = np.random.choice(candidates_idxs_available, size=positive_n, replace=False)
        new_cands[query_idx, candidates_idxs_to_be_replaced] = torch.from_numpy(positive_candidates)

print(f"after injecting {hard_positives_num} positive are: {new_cands[query_idx]}")
candidates_idxs_available = list(set(candidates_idxs_available) - set(candidates_idxs_to_be_replaced))
prev_cands_idxs = previous_epoch_candidates[query_idx]
prev_dictionary_cuis = dictionary_cuis[prev_cands_idxs]
neg_mask = prev_dictionary_cuis != query_cui
hard_negative_indexes = prev_cands_idxs[neg_mask]

if len(hard_negative_indexes) > 0:
    negatives_n = min(hard_negatives_num, len(hard_negative_indexes))
    hard_negative_candidates = np.random.choice(hard_negative_indexes, size=negatives_n, replace=False)
    # candidates_to_replace_positive
    candidates_idxs_to_be_replaced = np.random.choice(candidates_idxs_available, size=negatives_n, replace=False)

    new_cands[query_idx, candidates_idxs_to_be_replaced] = torch.from_numpy(hard_negative_candidates)

print(f"after injecting {hard_negatives_num} negatives are: {new_cands[query_idx]}")

new_cands[query_idx]

current candidates are: [ 0 11  6  1 12]
after injecting 2 positive are: [ 0 11  6  1 12]
after injecting 1 negatives are: [ 3 11  6  1 12]


array([ 3, 11,  6,  1, 12])