In [None]:
from gensim import models
import os
import sys
import glob
sys.path.append("../../")
from load_umls import UMLS
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
# from data_util import load
import tqdm
import pickle

batch_size = 128  # phrases per forward pass in get_bert_embed
# Device used for encoding. Set the CODER_DEVICE environment variable to override;
# falls back to CPU when CUDA is not available instead of failing at `.to(device)`.
device = os.environ.get("CODER_DEVICE", "cuda:7" if torch.cuda.is_available() else "cpu")

In [None]:
import pandas as pd
import numpy as np
from nltk.corpus import stopwords

In [None]:
def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False,
                   batch_size=None, device=None):
    """Encode phrases with a BERT-style model and return one embedding row per phrase.

    Each phrase is tokenized to exactly 32 ids (truncated / padded to max length)
    and the ids are fed to ``m`` in batches of ``batch_size``.

    Args:
        phrase_list: sequence of strings to embed; output rows follow this order.
        m: model such that ``m(input_ids)[0]`` is the per-token hidden states and
            ``m(input_ids)[1]`` the pooled output (as for HuggingFace BERT models).
        tok: tokenizer exposing ``encode_plus``.
        normalize: if True, L2-normalize each embedding (norm clamped at 1e-12).
        summary_method: "CLS" uses the pooled output; "MEAN" averages the hidden
            states over all 32 positions (padding positions included).
        tqdm_bar: show a tqdm progress bar.
        batch_size: phrases per forward pass; defaults to the notebook-level
            ``batch_size`` global.
        device: device the input ids are moved to; defaults to the notebook-level
            ``device`` global.

    Returns:
        Tensor with one row per phrase. For an empty ``phrase_list`` an empty
        ``(0, m.config.hidden_size)`` tensor is returned.

    Raises:
        ValueError: if ``summary_method`` is not "CLS" or "MEAN".
    """
    if summary_method not in ("CLS", "MEAN"):
        raise ValueError(f"summary_method must be 'CLS' or 'MEAN', got {summary_method!r}")
    if batch_size is None:
        batch_size = globals()["batch_size"]
    if device is None:
        device = globals()["device"]

    # padding="max_length" is the non-deprecated spelling of pad_to_max_length=True.
    input_ids = [
        tok.encode_plus(phrase, max_length=32, add_special_tokens=True,
                        truncation=True, padding="max_length")["input_ids"]
        for phrase in phrase_list
    ]
    m.eval()
    if not input_ids:
        return torch.empty((0, m.config.hidden_size), device=device)

    # NOTE(review): no attention_mask is passed, so padded positions take part in
    # attention and in MEAN pooling. Kept as-is so results match the original code.
    chunks = []
    pbar = tqdm.tqdm(total=len(input_ids)) if tqdm_bar else None
    with torch.no_grad():
        for start in range(0, len(input_ids), batch_size):
            batch = torch.LongTensor(input_ids[start:start + batch_size]).to(device)
            outputs = m(batch)
            if summary_method == "CLS":
                embed = outputs[1]
            else:  # "MEAN"
                embed = torch.mean(outputs[0], dim=1)
            if normalize:
                # x / max(||x||_2, 1e-12), identical to the manual norm + clamp.
                embed = torch.nn.functional.normalize(embed, p=2, dim=1, eps=1e-12)
            chunks.append(embed)
            if pbar is not None:
                pbar.update(batch.size(0))
    if pbar is not None:
        pbar.close()
    # Concatenate once instead of re-copying the growing output every batch.
    return torch.cat(chunks, dim=0)

In [None]:
# CODER_all checkpoint directory. Set the CODER_CHECKPOINT environment variable to
# point elsewhere instead of editing this absolute, user-specific path.
model_checkpoint = os.environ.get("CODER_CHECKPOINT", '/export/home/cse200093/coder_all')
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint).to(device)

# Load UMLS, embed all of its French terms and save the results (run only once)

In [None]:
# Load the French UMLS subset. source_range / lang_range are passed straight to the
# project UMLS loader (presumably SAB and LAT filters — TODO confirm in load_umls).
umls = UMLS(
    "/export/home/cse200093/deep_mlg_normalization/resources/umls/2021AB",
    source_range=[
        'BI', 'CHV', 'CSP', 'CST', 'CVX', 'DRUGBANK', 'HPO',
        'ICD10', 'ICD10CM', 'ICPC2P', 'ICPCFRE',
        'LNC', 'LNC-FR-FR', 'MDR', 'MDRFRE', 'MEDCIN', 'MMX',
        'MSH', 'MSHFRE', 'MTHICD9', 'MTHMSTFRE',
        'NCBI', 'NCI', 'NCI_CDISC', 'NCI_CTRP', 'NDDF', 'OMIM',
        'PDQ', 'RCD', 'RXNORM', 'SNMI', 'SNOMEDCT_US',
        'SRC', 'WHO', 'WHOFRE',
    ],
    lang_range=['FRE'],
)

In [None]:
# Flatten umls.cui2str into two parallel lists with one entry per (CUI, string) pair:
# umls_label[i] is the CUI of the string umls_des[i].
umls_label = []
umls_label_set = set()
umls_des = []

for cui in tqdm.tqdm(umls.cui2str):
    if cui in umls_label_set:
        continue  # each CUI is expanded only once
    strings_for_cui = list(umls.cui2str[cui])
    umls_des += strings_for_cui
    umls_label += [cui] * len(strings_for_cui)
    umls_label_set.add(cui)
print(len(umls_des))

In [None]:
# Embed every UMLS string (the expensive step); row i corresponds to umls_des[i].
umls_embedding = get_bert_embed(phrase_list=umls_des, m=model, tok=tokenizer, tqdm_bar=True)

In [None]:
# Save the UMLS embeddings and the aligned string (umls_des) / CUI (umls_label)
# lists to the working directory. `with` guarantees the files are closed even if
# pickling fails part-way.
torch.save(umls_embedding, 'umls_embedding_fr_coder_all.pt')
with open('umls_des_fr_coder_all.pkl', "wb") as open_file:
    pickle.dump(umls_des, open_file)
with open('umls_label_fr_coder_all.pkl', "wb") as open_file:
    pickle.dump(umls_label, open_file)

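# Find synonyms
Use CODER_all to expand every query with (1) all `umls_des` strings whose cosine similarity to the query embedding exceeds `limit`, and (2) all UMLS strings that share the query's CUI.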

In [None]:
# Load queries for 4 types. Each line is "<type>:<query1>;<query2>;...".
dic_type_query = {}
with open("./resources/query_4types.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.rstrip('\r\n')
        if not line.strip():
            continue  # skip blank lines (e.g. a trailing empty line at EOF)
        # Split on the first ':' only, so queries may themselves contain ':'.
        type_name, query_str = line.split(':', 1)
        dic_type_query[type_name] = query_str.split(';')
dic_type_query

In [None]:
# Load queries from phenotype_to_extract_large. Each line is "<type>:<query1>;<query2>;...".
# NOTE: this rebinds dic_type_query, replacing whatever the previous cell loaded.
dic_type_query = {}
with open("./resources/phenotype_to_extract_large.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.rstrip('\r\n')
        if not line.strip():
            continue  # skip blank lines (e.g. a trailing empty line at EOF)
        # Split on the first ':' only, so queries may themselves contain ':'.
        type_name, query_str = line.split(':', 1)
        dic_type_query[type_name] = query_str.split(';')
dic_type_query

In [None]:
# step 1: find out all terms with the same cui as query
stop_words = list(set(stopwords.words('french')))
stop_words.append('sai')

# CUI -> every UMLS string with that CUI, built from the aligned umls_label / umls_des
# lists. The previous version referenced an undefined `dic_cui_term`; the NameError was
# swallowed by a bare `except`, so every query silently got [].
dic_cui_term = {}
for cui, term in zip(umls_label, umls_des):
    dic_cui_term.setdefault(cui, []).append(term)

# UMLS string -> CUI of its first occurrence: same result as umls_des.index(...)
# but an O(1) lookup instead of a linear scan per query.
term_to_cui = {}
for cui, term in zip(umls_label, umls_des):
    term_to_cui.setdefault(term, cui)

dic_query_same_cui = {}
for type_name in dic_type_query:
    for query in dic_type_query[type_name]:
        cui = term_to_cui.get(query.lower().replace('-', ' '))
        if cui is None:
            # The normalised query is not itself a UMLS string: no same-CUI expansion.
            dic_query_same_cui[query] = []
        else:
            # Trailing space kept for consistency with the other synonym lists.
            dic_query_same_cui[query] = sorted(
                {x + ' ' for x in dic_cui_term[cui] if x not in stop_words})

dic_query_same_cui

In [None]:
# Flatten the per-type query lists (dict order, then list order) and embed them with
# the default CLS pooling + L2 normalisation, the same settings used for umls_embedding.
queries = [q for type_queries in dic_type_query.values() for q in type_queries]
queries_embedding = get_bert_embed(queries, model, tokenizer)

In [None]:
# Keep every UMLS string whose similarity to a query exceeds `limit`. Both embedding
# sets come from get_bert_embed with its default normalize=True, so the dot product
# is the cosine similarity.
limit = 0.8
candidate_terms = {}
sim = torch.matmul(queries_embedding, umls_embedding.t())  # one row per query

for query, sim_row in zip(queries, sim):
    # Bring the matching indices to the CPU once instead of indexing the Python list
    # with one (possibly CUDA) scalar tensor per hit.
    hit_idx = torch.nonzero(sim_row > limit).flatten().tolist()
    terms_selected = {umls_des[i].lower() for i in hit_idx}
    # Filter stop words *after* lowercasing so capitalised variants are removed too;
    # sort for a deterministic order across runs.
    candidate_terms[query] = [t + ' ' for t in sorted(terms_selected) if t not in stop_words]

candidate_terms

In [None]:
# Replace every query string by its synonym list: the normalised query itself, the
# embedding-based candidates and the same-CUI terms (all with a trailing space).
# Sorted so the saved pickle is identical across runs.
# NOTE: this overwrites dic_type_query (strings -> lists), so running the cell twice
# fails on `word.replace`; re-run the query-loading cell first.
dic_type_query = {
    type_name: [
        sorted(set([word.replace('-', ' ').lower() + ' ']
                   + candidate_terms[word]
                   + dic_query_same_cui[word]))
        for word in words
    ]
    for type_name, words in dic_type_query.items()
}
dic_type_query

In [None]:
# Save the expanded dic_type_query; the similarity threshold is encoded in the file
# name. `with` closes the file even if pickling fails.
with open(f'dic_type_query_synonym_limit{limit}_same_cui.pkl', "wb") as open_file:
    pickle.dump(dic_type_query, open_file)