# Use deep_mlg_norm (deep_multilingual_normalization) to normalize n2c2 2019 test mentions to UMLS CUIs and measure accuracy

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import pickle as pkl

from sklearn.metrics import precision_recall_fscore_support

import string
import logging

from nlstruct.utils import torch_clone
from nlstruct.utils import torch_global as tg
from nlstruct.dataloaders import load_from_brat

# path to deep_multilingual_normalization
sys.path.insert(0,'./deep_multilingual_normalization')

from deep_multilingual_normalization.preprocess import preprocess, load_quaero
from deep_multilingual_normalization.train import train_step1, train_step2, clear
from deep_multilingual_normalization.eval import predict

from notebook_utils import *

from tqdm import tqdm
import pickle

print('done')  # marks that the import/setup cell ran to completion

In [None]:
# Load the CUI -> semantic type (STY) mapping used by type_info() to find the
# semantic group of each candidate CUI.
# `with` closes the file even if unpickling fails (the previous manual
# open/close leaked the handle on error).
# NOTE(review): pickle.load can execute arbitrary code; only load pickles you
# produced yourself or otherwise trust.
with open('/export/home/cse200093/Expe_Translation/english_Norm/umls_en_fr_cui2sty.pkl', "rb") as open_file:
    umls_cui2sty = pickle.load(open_file)

In [None]:
# Map each UMLS semantic type (STY) name to the abbreviation of the UMLS
# semantic group it belongs to (ACTI, ANAT, CHEM, ...). The table is written
# group-first for readability and then inverted into a flat STY -> group dict;
# key order follows the listing below.
sty2type = {
    sty: group
    for group, stys in (
        ('ACTI', (
            'Activity',
            'Behavior',
            'Daily or Recreational Activity',
            'Event',
            'Governmental or Regulatory Activity',
            'Individual Behavior',
            'Machine Activity',
            'Occupational Activity',
            'Social Behavior',
        )),
        ('ANAT', (
            'Anatomical Structure',
            'Body Location or Region',
            'Body Part, Organ, or Organ Component',
            'Body Space or Junction',
            'Body Substance',
            'Body System',
            'Cell',
            'Cell Component',
            'Embryonic Structure',
            'Fully Formed Anatomical Structure',
            'Tissue',
        )),
        ('CHEM', (
            'Amino Acid, Peptide, or Protein',
            'Antibiotic',
            'Biologically Active Substance',
            'Biomedical or Dental Material',
            'Chemical',
            'Chemical Viewed Functionally',
            'Chemical Viewed Structurally',
            'Clinical Drug',
            'Element, Ion, or Isotope',
            'Enzyme',
            'Hazardous or Poisonous Substance',
            'Hormone',
            'Immunologic Factor',
            'Indicator, Reagent, or Diagnostic Aid',
            'Inorganic Chemical',
            'Nucleic Acid, Nucleoside, or Nucleotide',
            'Organic Chemical',
            'Pharmacologic Substance',
            'Receptor',
            'Vitamin',
        )),
        ('CONC', (
            'Classification',
            'Conceptual Entity',
            'Functional Concept',
            'Group Attribute',
            'Idea or Concept',
            'Intellectual Product',
            'Language',
            'Qualitative Concept',
            'Quantitative Concept',
            'Regulation or Law',
            'Spatial Concept',
            'Temporal Concept',
        )),
        ('DEVI', (
            'Drug Delivery Device',
            'Medical Device',
            'Research Device',
        )),
        ('DISO', (
            'Acquired Abnormality',
            'Anatomical Abnormality',
            'Cell or Molecular Dysfunction',
            'Congenital Abnormality',
            'Disease or Syndrome',
            'Experimental Model of Disease',
            'Finding',
            'Injury or Poisoning',
            'Mental or Behavioral Dysfunction',
            'Neoplastic Process',
            'Pathologic Function',
            'Sign or Symptom',
        )),
        ('GENE', (
            'Amino Acid Sequence',
            'Carbohydrate Sequence',
            'Gene or Genome',
            'Molecular Sequence',
            'Nucleotide Sequence',
        )),
        ('GEOG', (
            'Geographic Area',
        )),
        ('LIVB', (
            'Age Group',
            'Amphibian',
            'Animal',
            'Archaeon',
            'Bacterium',
            'Bird',
            'Eukaryote',
            'Family Group',
            'Fish',
            'Fungus',
            'Group',
            'Human',
            'Mammal',
            'Organism',
            'Patient or Disabled Group',
            'Plant',
            'Population Group',
            'Professional or Occupational Group',
            'Reptile',
            'Vertebrate',
            'Virus',
        )),
        ('OBJC', (
            'Entity',
            'Food',
            'Manufactured Object',
            'Physical Object',
            'Substance',
        )),
        ('OCCU', (
            'Biomedical Occupation or Discipline',
            'Occupation or Discipline',
        )),
        ('ORGA', (
            'Health Care Related Organization',
            'Organization',
            'Professional Society',
            'Self-help or Relief Organization',
        )),
        ('PHEN', (
            'Biologic Function',
            'Environmental Effect of Humans',
            'Human-caused Phenomenon or Process',
            'Laboratory or Test Result',
            'Natural Phenomenon or Process',
            'Phenomenon or Process',
        )),
        ('PHYS', (
            'Cell Function',
            'Clinical Attribute',
            'Genetic Function',
            'Mental Process',
            'Molecular Function',
            'Organism Attribute',
            'Organism Function',
            'Organ or Tissue Function',
            'Physiologic Function',
        )),
        ('PROC', (
            'Diagnostic Procedure',
            'Educational Activity',
            'Health Care Activity',
            'Laboratory Procedure',
            'Molecular Biology Research Technique',
            'Research Activity',
            'Therapeutic or Preventive Procedure',
        )),
    )
    for sty in stys
}

In [None]:
# Inputs: the directory holding the pickled vocabularies, and the BRAT corpus
# whose mentions are normalized and scored in the cells below.
VOCAB_PATH, PREDICT_DATA_PATH = (
    '/export/home/cse200093/deep_mlg_normalization',
    '/export/home/cse200093/brat_data/n2c2_2019/test_restrict_proc_chem_devi_diso',
)

In [None]:
# Load the trained normalization model and switch it to inference mode.
# NOTE(review): `torch` is not imported explicitly in this file; it presumably
# comes from `from notebook_utils import *` — confirm.
# NOTE(review): the result is used directly as a module (`.eval()`), so the
# checkpoint presumably stores a whole pickled model rather than a state_dict.
# Unpickling it needs the deep_multilingual_normalization classes importable,
# runs arbitrary code (trusted files only), and on recent PyTorch versions
# (where torch.load defaults to weights_only=True) may need weights_only=False.
# NOTE(review): no map_location is given, so tensors are restored onto the
# device they were saved from; tg.set_device('cuda:0') only runs in a later
# cell — confirm the model and the batches end up on the same device.
model = torch.load('/export/home/cse200093/deep_mlg_normalization/models/UMLS2021AB_without_quaero.pt')
model.eval()

In [None]:
import os

# Local multilingual BERT checkpoint, handed to preprocess_train below.
bert_name = "/export/home/cse200093/deep_mlg_normalization/bert-base-multilingual-uncased"

# Pickled vocabularies stored next to the model. Only vocabularies2 (file name
# suggests the UMLS 2021AB version) is passed to preprocess_train; vocabularies1
# is loaded but not used by any cell in this notebook.
with open(os.path.join(VOCAB_PATH, 'vocab1.pkl'), 'rb') as vocab_file:
    vocabularies1 = pkl.load(vocab_file)
with open(os.path.join(VOCAB_PATH, 'vocab2021AB_v2.pkl'), 'rb') as vocab_file:
    vocabularies2 = pkl.load(vocab_file)

In [None]:
# Load the BRAT corpus to normalize; its 'mentions' table is what gets
# normalized, and its 'comments' table is used later as the gold CUI reference.
dataset = load_from_brat(PREDICT_DATA_PATH)

# Make mention ids unique across documents ("<doc_id>.<mention_id>") so they
# can serve as the join key with the comments table later on.
dataset['mentions']['mention_id'] = dataset['mentions']['doc_id'] +'.'+ dataset['mentions']['mention_id'].astype(str)

# Encode the mentions with the BERT tokenizer and the model's label vocabularies.
# NOTE(review): only `preprocess` is imported from
# deep_multilingual_normalization.preprocess; `preprocess_train` presumably
# comes from `from notebook_utils import *` — confirm. The returned
# `mention_ids` is not used by any cell in this notebook.
batcher, vocs, mention_ids = preprocess_train(
    dataset,
    vocabularies=vocabularies2,
    bert_name=bert_name,
)

# NOTE(review): batch_size and with_tqdm are not used below — predict() is
# called with a literal batch_size=64.
batch_size = len(batcher)
with_tqdm = True

tg.set_device('cuda:0') # sets the device held by nlstruct's torch_global (read back as tg.device)
device = tg.device

# topk: number of ranked candidate CUIs returned per mention; also read as a
# global by choose()/type_info() further down.
nb = 5
pred_batcher = predict(batcher, model, batch_size=64, return_loss=False, topk=nb)

print('pred_batcher', pred_batcher)
print('batcher', batcher)

In [None]:
# Lookup tables from preprocess_train: token ids -> WordPiece strings, and
# label ids -> CUIs.
voc_tokens, voc_labels = vocs['token'], vocs['label']
def replace_fn(s):
    return s.replace(' ##', '').replace('[SEP]', '').replace('[PAD]', '').replace('[CLS]', '').strip()
# Attach each mention's token ids to its top-k predicted labels.
merged_batcher = batcher['mention', ['mention_id', 'token']].merge(pred_batcher)

# Decode every mention's WordPiece ids back to readable text...
final_tokens = [
    replace_fn(' '.join(voc_tokens[tok_id] for tok_id in tok_ids))
    for tok_ids in merged_batcher['mention']['token'].toarray()
]
# ...and every row of predicted label ids back to candidate CUIs
# (presumably ranked best-first).
final_labels = [voc_labels[label_ids] for label_ids in merged_batcher['mention']['label']]

# (mention text, candidate CUIs) pairs. Position i is assumed to match row i of
# dataset['mentions'] — type_info and df_mentions rely on this positional
# alignment. TODO confirm that the merge above preserves the mention order.
final_couples = list(zip(final_tokens, final_labels))
final_couples

In [None]:
# Add type information: among a mention's ranked candidates, output the first
# one whose semantic group matches the mention's gold type.
def choose(l, stand):
    """Return the index of the first entry of ``l`` equal to ``stand``.

    Args:
        l: semantic-group labels of a mention's ranked candidates (in this
            notebook, ``nb`` entries produced by type_info).
        stand: the expected (gold) semantic group.

    Returns:
        The position of the first match, or 0 (the top-ranked candidate)
        when nothing matches or ``l`` is empty.
    """
    # Scan every candidate, including the last one. The former while-loop
    # stopped at index nb-1 and then treated "reached nb-1" as "no match",
    # so a correct candidate in the last position was never chosen.
    for idx, candidate_type in enumerate(l):
        if candidate_type == stand:
            return idx
    return 0
    
def type_info(final_couples):
    """Pick one CUI per mention using the mention's gold semantic group.

    For every ``(mention text, candidate CUIs)`` pair, each candidate's group
    is looked up through ``umls_cui2sty`` and ``sty2type`` (``'unknown_cui'``
    when the CUI is absent from ``umls_cui2sty``). ``choose`` then selects a
    candidate against the gold label of the mention at the same position in
    ``dataset['mentions']``.

    Prints the total number of unknown candidate CUIs.

    Returns:
        ``(new_final_couples, unknown_cuis)``: a list of
        ``(mention text, chosen CUI)`` pairs, and the unknown CUIs found among
        each mention's first ``nb`` candidates.
    """
    gold_types = list(dataset['mentions']['label'])
    kept_pairs = []
    unknown_total = 0
    unknown_list = []
    for pos in tqdm(range(len(final_couples))):
        candidates = final_couples[pos][1]

        cand_types = []
        for cui in candidates:
            if cui in umls_cui2sty.keys():
                cand_types.append(sty2type[umls_cui2sty[cui]])
            else:
                cand_types.append('unknown_cui')

        unknown_total += cand_types.count('unknown_cui')
        unknown_list.extend(
            candidates[k] for k in range(nb) if cand_types[k] == 'unknown_cui'
        )

        picked = choose(cand_types, gold_types[pos])
        kept_pairs.append((final_couples[pos][0], candidates[picked]))

    print(unknown_total)
    return (kept_pairs, unknown_list)

In [None]:
# Add type info: keep, for each mention, the candidate selected by choose()
# against the mention's gold semantic group; unknown_cuis collects candidates
# missing from umls_cui2sty. Many candidates can be unknown when the cui2sty
# pickle was built from a restricted UMLS subset (e.g. limited languages or
# source vocabularies).
# NOTE: this rebinds final_couples to (text, single CUI) pairs, so re-running
# this cell without first re-running the decoding cell would iterate over each
# CUI string as if it were a candidate list.
final_couples,unknown_cuis = type_info(final_couples)
final_couples

In [None]:
# Mentions table indexed by "<doc_id>.<mention_id>", with each mention's chosen
# CUI (taken positionally from final_couples) in the cui_res column.
df_mentions = (
    dataset['mentions']
    .copy()
    .assign(cui_res=[couple[1] for couple in final_couples])
    .set_index('mention_id')
)
df_mentions

In [None]:
# Gold CUIs come from the BRAT comments attached to mentions. Build the same
# "<doc_id>.<mention_id>" key as for the mentions table so both can be
# aligned on their index.
df_comments = dataset['comments'].copy()
# astype(str) mirrors how the mentions key is built, so the join key is
# constructed identically (and does not fail) even if mention_id is not
# already a string column.
df_comments['mention_id'] = df_comments['doc_id'] + '.' + df_comments['mention_id'].astype(str)

# for Mantra: the comment presumably packs the CUI and the type together;
# uncomment to split them (NOTE(review): confirm the comment format).
#types = [x.split(',')[-1][1:].rstrip('\"') for x in list(df_comments['comment'])]
#df_comments['comment'] = df_comments['comment'].apply(lambda x: x.split(',')[0][1:].rstrip('\"'))
df_comments.set_index('mention_id',inplace=True)
df_comments

In [None]:
# Align gold comments and predictions on the shared "<doc_id>.<mention_id>"
# index. The inner join drops mentions without a comment and vice versa.
# NOTE(review): pd.concat(axis=1) may fail on non-unique indexes — confirm
# there is at most one gold comment per mention.
df_compare = pd.concat([df_comments, df_mentions], axis=1, join='inner')

# for Mantra: use the `types` list built by the commented-out Mantra lines of
# the comments cell (the previous toggle assigned the builtin `type` by
# mistake). It is positional, so check it still lines up after the join.
#df_compare['label'] = types

# keep only the semantic groups evaluated in this experiment
df_compare = df_compare[df_compare['label'].isin(['PROC', 'DEVI', 'DISO', 'CHEM'])]
df_compare

In [None]:
# Accuracy = share of compared mentions whose predicted CUI (cui_res) equals
# the gold CUI stored in the comment column.
n_compared = len(df_compare)
n_correct = int((df_compare['cui_res'] == df_compare['comment']).sum())
# Report NaN instead of raising ZeroDivisionError when no mention survived
# the join/type filter.
accuracy = n_correct / n_compared if n_compared else float('nan')
print(f'accuracy = {accuracy}')

In [None]:
# Error analysis: every compared mention whose predicted CUI differs from the
# gold CUI.
df_err = df_compare.loc[df_compare['cui_res'] != df_compare['comment']]
df_err

In [None]:
# Number of errors per semantic group (non-null comment_id values per label).
df_err.groupby('label')['comment_id'].count()