# Use CODER to normalize concepts

In [None]:
from gensim import models
import os
import sys
import glob
# path to load_umls
sys.path.append("../../")
from load_umls import UMLS

import pandas as pd
import numpy as np

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
import tqdm
import pickle

# Number of phrases sent through the model per forward pass in get_bert_embed.
batch_size = 128
# Preferred GPU is index 6; fall back to CPU when that device index does not
# exist so the notebook still runs on machines with fewer (or no) GPUs.
device = "cuda:6" if torch.cuda.device_count() > 6 else "cpu"

In [None]:
def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False):
    """Encode phrases with a BERT-style model and return one vector per phrase.

    Every phrase is tokenized to exactly 32 token ids (truncated or padded),
    then the ids are run through ``m`` in chunks of the module-level
    ``batch_size`` on the module-level ``device`` with gradients disabled.

    NOTE(review): only ``input_ids`` are passed to the model (no attention
    mask), so padding positions take part in the "MEAN" average. This is kept
    as-is because the precomputed UMLS embeddings were presumably produced
    the same way -- confirm before changing.

    Args:
        phrase_list: Non-empty iterable of strings to embed.
        m: Model called as ``m(input_ids)``; element 1 of its output is used
            for "CLS" and element 0 for "MEAN".
        tok: Tokenizer exposing ``encode_plus``.
        normalize: If true, divide each output row by its L2 norm (clamped
            to at least 1e-12 to avoid division by zero).
        summary_method: "CLS" takes output element 1 as the embedding;
            "MEAN" averages output element 0 over dim 1.
        tqdm_bar: If true, show a tqdm progress bar over the phrases.

    Returns:
        A tensor with one row per phrase, in input order.

    Raises:
        ValueError: If ``summary_method`` is not "CLS" or "MEAN", or if
            ``phrase_list`` is empty.
    """
    if summary_method not in ("CLS", "MEAN"):
        raise ValueError(
            "summary_method must be 'CLS' or 'MEAN', got {!r}".format(summary_method))

    input_ids = [
        tok.encode_plus(
            phrase, max_length=32, add_special_tokens=True,
            truncation=True, padding="max_length")['input_ids']
        for phrase in phrase_list
    ]
    if not input_ids:
        raise ValueError("phrase_list must contain at least one phrase")

    m.eval()
    count = len(input_ids)
    chunks = []
    pbar = tqdm.tqdm(total=count) if tqdm_bar else None
    try:
        with torch.no_grad():
            for begin in range(0, count, batch_size):
                batch = torch.LongTensor(
                    input_ids[begin:begin + batch_size]).to(device)
                if summary_method == "CLS":
                    embed = m(batch)[1]
                else:
                    embed = torch.mean(m(batch)[0], dim=1)
                if normalize:
                    embed_norm = torch.norm(
                        embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                    embed = embed / embed_norm
                chunks.append(embed)
                if pbar is not None:
                    pbar.update(batch.size(0))
    finally:
        # Close the bar even if the model call fails part-way through.
        if pbar is not None:
            pbar.close()
    # Concatenate once instead of growing the result batch by batch.
    return torch.cat(chunks, dim=0)

In [None]:
# load umls_embedding
# Precomputed UMLS term embeddings, mapped straight onto `device`; predict()
# multiplies phrase embeddings against this matrix.
# NOTE(review): torch.load unpickles the file, so only load trusted files.
umls_embedding = torch.load('umls_embedding_en_fr_coder_eng.pt', map_location=device) # CODER_en & en_fr UMLS

In [None]:
# load terms, cuis and semantic groups
# Each file is opened in a context manager so the handle is closed even if
# unpickling fails.
# NOTE(review): pickle.load can execute arbitrary code; these files must come
# from a trusted source.

# umls_label[i] is looked up by embedding row index in predict().
with open('umls_label_en_fr_coder_eng.pkl', "rb") as open_file:
    umls_label = pickle.load(open_file)

# presumably the term strings matching umls_label -- not used in this view.
with open('umls_des_en_fr_coder_eng.pkl', "rb") as open_file:
    umls_des = pickle.load(open_file)

# Passed to predict() as cui2sty (CUI -> semantic type lookup).
with open('umls_en_fr_cui2sty_coder_eng.pkl', "rb") as open_file:
    umls_cui2sty = pickle.load(open_file)

In [None]:
# load CODER model
# Tokenizer and encoder are read from a local checkpoint directory; the model
# is moved to `device`, the same device get_bert_embed sends its input ids to.
model_checkpoint = '/export/home/cse200093/coder_eng'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint).to(device)

In [None]:
# predict data path
# Directory holding the .ann files to normalize. It is concatenated directly
# with file names when the files are opened, so keep the trailing '/'.
data_path = '/export/home/cse200093/brat_data/n2c2_2019/test_restrict_proc_chem_devi_diso/'

In [None]:
# select all ann files
# NOTE: this changes the working directory for the rest of the session.
os.chdir(data_path)
# glob's ordering depends on the file system; sort so the file order (and
# therefore which phrases fall inside a [start:end] evaluation window) is the
# same on every run.
my_files = sorted(glob.glob('*.ann'))
len(my_files)

In [None]:
# Parallel output lists; they are filled later from `dic`.
phrases = []
cuis = []
types = []
# Maps "<file>.<phrase>.<T id>" to 'no cui' or to (cui, type_name).
# necessary when len(phrases) != len(cuis), i.e. when some T annotations
# have no '#' note carrying a CUI.
dic = {}
for file in tqdm.tqdm(my_files):
    phrase = None
    code = None
    type_name = None
    # Context manager so each file is closed as soon as it has been read.
    with open(data_path+file, "r") as f:
        for line in f:
            line = line.rstrip('\n')
            fields = line.split('\t')
            if line.startswith('T'):
                # Entity line: "<T id>\t<type> <offsets>\t<text>".
                code = fields[0]
                type_name = fields[1].split(' ')[0]
                phrase = fields[2]
                dic[file+'.'+phrase+'.'+code] = 'no cui'
            elif line.startswith('#'):
                # Note line: its third field is the CUI. It is attached to the
                # most recently read T line, so notes are assumed to directly
                # follow the entity they describe.
                cui = fields[2]
                dic[file+'.'+phrase+'.'+code] = (cui,type_name)
                # For Mantra the CUI needs extra cleaning:
                #cui = cui.split(',')[0][1:].rstrip('\"')

In [None]:
# necessary when len(phrases) != len(cuis)
# Turn `dic` into the parallel lists phrases / cuis / types, and collect the
# keys of annotations that never received a CUI in no_cuis.
print(len(dic))
no_cuis = []
# phrases_eng = []
for key, value in tqdm.tqdm(dic.items()):
    if value == 'no cui':
        no_cuis.append(key)
        continue
    # key is "<file>.ann.<phrase>.<T id>". Drop the trailing id and the file
    # name instead of splitting on every '.', so phrases that themselves
    # contain dots (e.g. "0.5 mg") are kept whole.
    phrase = key.rsplit('.', 1)[0].partition('.ann.')[2]
    phrases.append(phrase)
    # necessary with translation
    # phrases_eng.append(translator.translate(phrase).text)
    cuis.append(value[0])
    types.append(value[1])

print(len(cuis),len(phrases),len(types))

In [None]:
# One row per annotation that has a CUI: its text, the CUI read from its
# note line, and its entity type, built from the three parallel lists.
df = pd.DataFrame({'phrases':phrases, 'cuis':cuis, 'types':types})
df

In [None]:
# filter 4 types
# Keep only rows whose type is one of the four evaluated groups, then rebuild
# the parallel lists from the filtered frame so they stay aligned.
df = df.loc[df['types'].isin(['PROC', 'DEVI', 'DISO', 'CHEM'])]
phrases, cuis, types = list(df['phrases']), list(df['cuis']), list(df['types'])

len(phrases)

In [None]:
# get embedding for predict data
# Uses get_bert_embed's defaults: normalize=True and summary_method="CLS".
text_embedding = get_bert_embed(phrases, model, tokenizer)

In [None]:
# prediction considering type information
def predict(umls_label, text_embedding, umls_embedding, cui2sty, start, end, gold_type):
    """Pick a CUI for each phrase in ``[start:end]`` using its gold type.

    For every phrase embedding in the window, the ``nb`` UMLS rows with the
    highest dot-product similarity are taken as candidates. Each candidate's
    CUI is mapped to a type through ``cui2sty`` and ``sty2type``, and
    ``choose(candidate_types, gold)`` returns the index of the candidate to
    keep.

    ``nb``, ``sty2type`` and ``choose`` are module-level names that must be
    defined before this function is called.

    Args:
        umls_label: Sequence giving the CUI for each row of umls_embedding.
        text_embedding: Tensor of phrase embeddings, one row per phrase.
        umls_embedding: Tensor of UMLS term embeddings, one row per term.
        cui2sty: Mapping from CUI to semantic type.
        start: First phrase index (inclusive) of the window to predict.
        end: Last phrase index (exclusive) of the window to predict.
        gold_type: Gold type per phrase, indexed like text_embedding.

    Returns:
        ``(pred, ks)``: the chosen CUI and the chosen candidate index for
        each phrase in the window.
    """
    sim = torch.matmul(text_embedding[start:end], umls_embedding.t())
    # Convert once to plain Python ints; cheaper than indexing with 0-d tensors.
    candidates = torch.topk(sim, k=nb, dim=1, sorted=True).indices.tolist()
    # Slice the gold types once rather than on every iteration.
    gold_slice = gold_type[start:end]

    pred = []
    ks = []
    for i, candidate in enumerate(candidates):
        candidate_cuis = [umls_label[idx] for idx in candidate]
        candidate_types = [sty2type[cui2sty[cui]] for cui in candidate_cuis]
        # Call choose() once so pred and ks always describe the same candidate.
        k = choose(candidate_types, gold_slice[i])
        pred.append(candidate_cuis[k])
        ks.append(k)
    return (pred, ks)

In [None]:
def accuracy(pred, stand, types, phrases, ks=None):
    """Compare predicted CUIs with the reference and collect the errors.

    Prints ``acc`` and the number of hits.

    Args:
        pred: Predicted CUI per phrase.
        stand: Reference CUI per phrase, indexed like pred.
        types: Type per phrase, indexed like pred.
        phrases: Text per phrase, indexed like pred.
        ks: Optional chosen candidate index per phrase; when given it is
            reported in an extra 'k' column.

    Returns:
        ``(df_err, acc)``: a DataFrame with one row per wrong prediction
        (columns cui_res, cui_stand, type, text and, if ks was given, k),
        and the fraction of predictions equal to the reference. acc is 0.0
        when pred is empty.
    """
    errors = {'cui_res': [], 'cui_stand': [], 'type': [], 'text': []}
    if ks is not None:
        errors['k'] = []

    hit = 0
    for i, predicted in enumerate(pred):
        if predicted == stand[i]:
            hit += 1
            continue
        errors['cui_res'].append(predicted)
        errors['cui_stand'].append(stand[i])
        errors['type'].append(types[i])
        errors['text'].append(phrases[i])
        if ks is not None:
            errors['k'].append(ks[i])

    # Guard the division so an empty prediction list reports 0.0 instead of
    # raising ZeroDivisionError.
    total = len(pred)
    acc = hit / total if total > 0 else 0.0
    df_err = pd.DataFrame(errors, columns=list(errors))
    print(acc, hit)
    return (df_err, acc)

In [None]:
# when too many terms to predict, choose start and end to normalize only part of them
# start is inclusive and end exclusive; both are used as slice bounds inside
# predict(), so an end past the last phrase is simply clamped.
start = 0
end = 500

# pred1: chosen CUI per phrase in the window; ks: index of the chosen
# candidate for each of those phrases.
pred1,ks = predict(umls_label, text_embedding, umls_embedding, umls_cui2sty, start, end, types)

In [None]:
# calculate overall accuracy and give errors
# Slice the reference data with the same [start:end] window that was passed
# to predict(), so every row lines up with pred1 and ks.
df_err,acc = accuracy(pred1,cuis[start:end],types[start:end],phrases[start:end],ks)