# Load UMLS restricted to the chosen source vocabularies and languages, then embed all terms with CODER

In [None]:
from gensim import models
import os
import sys
sys.path.append("../../")
from load_umls import UMLS
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
import tqdm
import pickle

# Number of phrases sent to the encoder per forward pass.
batch_size = 128
# Use the first GPU when one is available; fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of crashing on .to().
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS",
                   tqdm_bar=False, max_length=32):
    """Embed a list of phrases with an encoder model, in mini-batches.

    Each phrase is tokenized to exactly ``max_length`` token ids (truncated or
    padded). The ids are fed to ``m`` in chunks of the module-level
    ``batch_size`` after being moved to the module-level ``device``, and one
    vector per phrase is returned, in input order.

    Args:
        phrase_list: Iterable of strings to embed.
        m: Encoder model. ``m(ids)[1]`` is used for "CLS" (presumably the
            pooler output of a Hugging Face BERT model) and ``m(ids)[0]`` for
            "MEAN" (presumably the token-level hidden states,
            (batch, seq, hidden) -- confirm for other model types).
        tok: Tokenizer exposing ``encode_plus``.
        normalize: If True, L2-normalize each embedding (norm clamped at 1e-12
            to avoid division by zero).
        summary_method: "CLS" or "MEAN" (mean over all ``max_length``
            positions, padding positions included).
        tqdm_bar: If True, show a tqdm progress bar over phrases.
        max_length: Number of token ids per phrase (default 32).

    Returns:
        A tensor with one row per phrase, concatenated along dim 0 and left on
        the device where the model produced it.

    Raises:
        ValueError: If ``summary_method`` is not "CLS"/"MEAN", or if
            ``phrase_list`` is empty.
    """
    if summary_method not in ("CLS", "MEAN"):
        raise ValueError(
            f"summary_method must be 'CLS' or 'MEAN', got {summary_method!r}")

    input_ids = [
        tok.encode_plus(phrase, max_length=max_length, add_special_tokens=True,
                        truncation=True, padding="max_length")['input_ids']
        for phrase in phrase_list
    ]
    if not input_ids:
        raise ValueError("phrase_list is empty; nothing to embed")
    m.eval()

    count = len(input_ids)
    # Collect per-batch results and concatenate once at the end; repeatedly
    # concatenating onto a growing tensor would copy it on every batch.
    chunks = []
    pbar = tqdm.tqdm(total=count) if tqdm_bar else None
    try:
        with torch.no_grad():
            for start in range(0, count, batch_size):
                end = min(start + batch_size, count)
                batch = torch.LongTensor(input_ids[start:end]).to(device)
                # NOTE(review): no attention_mask is passed, so padding tokens
                # are attended to (and averaged in "MEAN"). Kept as-is so the
                # embeddings stay comparable with ones computed previously.
                outputs = m(batch)
                if summary_method == "CLS":
                    embed = outputs[1]
                else:
                    embed = torch.mean(outputs[0], dim=1)
                if normalize:
                    embed_norm = torch.norm(
                        embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                    embed = embed / embed_norm
                chunks.append(embed)
                if pbar is not None:
                    pbar.update(end - start)
    finally:
        # Close the bar even if the model raises mid-way.
        if pbar is not None:
            pbar.close()
    return torch.cat(chunks, dim=0)

In [None]:
# Path to the CODER checkpoint (tokenizer + encoder weights). It can be
# overridden with the CODER_CHECKPOINT environment variable so the notebook
# runs on machines other than the original author's; the previous hard-coded
# path remains the default.
model_checkpoint = os.environ.get('CODER_CHECKPOINT', '/export/home/cse200093/coder_eng')
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint).to(device)

In [None]:
# Source vocabularies to keep when loading UMLS; edit this list (and
# lang_range below) to change which terms are loaded.
# NOTE(review): several sources here look French-language (e.g. MSHFRE,
# MDRFRE, WHOFRE) and the output files are named "en_fr", yet lang_range only
# keeps 'ENG' -- confirm whether 'FRE' should also be listed.
umls_sources = [
    'BI', 'CHV', 'CSP', 'CST', 'CVX', 'DRUGBANK', 'HPO', 'ICD10', 'ICD10CM',
    'ICPC2P', 'ICPCFRE', 'LNC', 'LNC-FR-FR', 'MDR', 'MDRFRE', 'MEDCIN', 'MMX',
    'MSH', 'MSHFRE', 'MTHICD9', 'MTHMSTFRE', 'NCBI', 'NCI', 'NCI_CDISC',
    'NCI_CTRP', 'NDDF', 'OMIM', 'PDQ', 'RCD', 'RXNORM', 'SNMI', 'SNOMEDCT_US',
    'SRC', 'WHO', 'WHOFRE',
]
umls = UMLS(
    "../../deep_mlg_normalization/resources/umls/2021AB",
    source_range=umls_sources,
    lang_range=['ENG'],
)

In [None]:
# Flatten umls.cui2str into two parallel lists: umls_des[i] is a term string
# and umls_label[i] is the CUI it was listed under.
umls_label = []  # CUI for each entry of umls_des
umls_label_set = set()  # every CUI that contributed terms
umls_des = []  # term strings

# cui2str is indexed by CUI (presumably a dict), so iterating it visits each
# CUI exactly once; no "already seen" check is needed.
# If its values are sets, the per-CUI term order may differ between runs,
# but the two lists always stay aligned with each other.
for cui in tqdm.tqdm(umls.cui2str):
    terms = list(umls.cui2str[cui])
    umls_label.extend([cui] * len(terms))
    umls_des.extend(terms)
    umls_label_set.add(cui)
print(len(umls_des))

In [None]:
# Persist the aligned term/CUI lists and umls.cui2sty so later steps can reuse
# them without reloading UMLS. Context managers guarantee each file is closed
# even if pickling fails part-way.
# NOTE(review): file names say "en_fr" while the UMLS loader is configured
# with lang_range=['ENG'] -- confirm the naming is intended.
with open('umls_des_en_fr_coder_eng.pkl', "wb") as open_file:
    pickle.dump(umls_des, open_file)

with open('umls_label_en_fr_coder_eng.pkl', "wb") as open_file:
    pickle.dump(umls_label, open_file)

# save umls.cui2sty
with open('umls_en_fr_cui2sty_coder_eng.pkl', "wb") as open_file:
    pickle.dump(umls.cui2sty, open_file)

In [None]:
# Encode every UMLS term string (rows follow the order of umls_des, and hence
# of umls_label) using the default CLS summary with L2 normalization, and show
# a progress bar.
umls_embedding = get_bert_embed(umls_des, m=model, tok=tokenizer, tqdm_bar=True)

In [None]:
# Write the term-embedding tensor to disk; row i corresponds to umls_des[i].
torch.save(obj=umls_embedding, f='umls_embedding_en_fr_coder_eng.pt')