In [1]:
import os

# Enumerate GPUs in PCI bus order (see issue #152) and restrict the visible
# devices to index 0. This is the first cell, so it runs before torch is imported.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

In [2]:
import pickle
import numpy as np
import torch
import pandas as pd
from glob import glob
import json
from tqdm import tqdm
import ipdb

In [None]:
import faiss

In [None]:
from transformers import AutoTokenizer, PreTrainedTokenizer, BertTokenizer

# Tokenizer of the PubMedBERT checkpoint. Further down it is used to decode
# token-id tensors back into strings in the re-ranking evaluation.
pubmed_bert_tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract')

In [None]:
from sklearn.metrics import recall_score,f1_score,accuracy_score

In [None]:
import os
import sys
# Put the absolute path of ../src/ at the front of the module search path.
sys.path.insert(0, os.path.abspath('../src/'))

In [None]:
# Project modules (presumably located under ../src/ -- confirm).
from UMLS import UMLS
from RetrievalModule import RetrievalModule
# IPython autoreload in mode 2: re-import changed modules before running a cell.
%load_ext autoreload
%autoreload 2

In [None]:
# UMLS release identifiers. `umls_version` selects the data directory used below;
# both values appear in the name of the new-AUI list file.
past_version = '2020AA'
umls_version = '2020AB'

In [None]:
# Input file locations, built from the UMLS version identifiers.
mrconso_file = f'../data/{umls_version}-ACTIVE/META_DL_V2/MRCONSO_MASTER.RRF'
new_aui_set_file = f'../data/AAAC_{past_version}_vs_{umls_version}_AUIsList_TEST.txt'
rba_filename = f'../data/{umls_version}-ACTIVE/RBA_V2/AUI_COLOR.PICKLE'
# Pre-computed retrieval candidates (fixed path, independent of the version).
sort_filename = '../output/0/cambridgeltl_SapBERT-from-PubMedBERT-fulltext_candidates.p'

In [None]:
# Build the project's UMLS helper from the MRCONSO file.
umls = UMLS(mrconso_file)
# Read the new-AUI list; the return value is discarded.
# Presumably this populates `umls.new_auis`, which later cells read -- confirm in UMLS.py.
_ = umls.get_new_aui_set(new_aui_set_file)

In [None]:
# Sanity check: number of original AUIs and number of new AUIs.
len(umls.original_auis),len(umls.new_auis)

In [None]:
# Load the retrieval candidates. Later cells iterate it as a mapping from a
# query AUI to a (candidate AUIs, candidate distances) pair.
# NOTE: pickle can execute arbitrary code on load; only use trusted files.
# The `with` block closes the file handle as soon as loading finishes.
with open(sort_filename, 'rb') as candidates_file:
    candidates = pickle.load(candidates_file)

In [None]:
# Number of query AUIs that have retrieval candidates.
len(candidates)

## Load RBA

In [None]:
# Read the MRCONSO file and map each row id (first '|'-separated field, as int)
# to its AUI (fifth field). The `with` block closes the file after reading.
with open(mrconso_file, 'r') as mrconso_handle:
    mrconso = mrconso_handle.readlines()

mrconso_dict = {}

for line in mrconso:
    fields = line.split('|')
    mrconso_dict[int(fields[0])] = fields[4]

# RBA output: a mapping from MRCONSO row id to a "color" label.
# NOTE: pickle can execute arbitrary code on load; only use trusted files.
with open(rba_filename, 'rb') as rba_handle:
    colors = pickle.load(rba_handle)

# Index the colors both ways: color -> set of AUIs, and AUI -> color.
color2aui = {}
aui2color = {}

for idn, color in tqdm(colors.items()):
    aui = mrconso_dict[idn]

    color2aui.setdefault(color, set()).add(aui)
    aui2color[aui] = color

In [None]:
# For every new AUI, collect the CUIs of the AUIs that share its RBA color,
# excluding the query itself and every other new AUI.
rba_predicted_synonyms = {}

# Build the membership structure once so the inner filter is O(1) per candidate.
new_aui_lookup = set(umls.new_auis)

for new_aui in tqdm(umls.new_auis):

    color = aui2color[new_aui]
    predicted_auis = color2aui[color]

    # Compare against this iteration's query AUI (`new_aui`), not a name left
    # over from another cell.
    filtered_preds = {
        pred
        for pred in predicted_auis
        if pred != new_aui and pred not in new_aui_lookup
    }

    # One CUI per remaining AUI; the same CUI can therefore appear several times.
    rba_predicted_synonyms[new_aui] = [umls.aui2cui[pred] for pred in filtered_preds]

In [None]:
# Number of new AUIs that received an RBA prediction entry.
len(rba_predicted_synonyms)

In [None]:
# The pre-computed split file is loaded only when the MRCONSO path refers to
# the 2020AB release; otherwise no split information is available.
if '2020AB' in mrconso_file:
    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open('../data/UMLS2020AB_SAPBERT_Source_Info_Official_Split_Basic.p', 'rb') as split_file:
        dataset_splits = pickle.load(split_file)
    # Reduce the table to a plain {aui: split} dictionary.
    dataset_splits = dataset_splits[['auis','split']].set_index('auis').to_dict()
    dataset_splits = dataset_splits['split']
else:
    dataset_splits = None

In [None]:
#Create dataframe with aui, true_cand (null included), true CUI

In [None]:
# CUIs of all AUIs listed in `umls.original_auis`.
old_cuis = {umls.aui2cui[known_aui] for known_aui in tqdm(umls.original_auis)}

In [None]:
# CUIs of new AUIs that do not occur among the old CUIs.
new_cuis = {
    cui
    for cui in (umls.aui2cui[fresh_aui] for fresh_aui in umls.new_auis)
    if cui not in old_cuis
}

len(new_cuis)

In [None]:
# Assemble one record per query AUI: its string, gold CUI and semantic group,
# the retrieved candidates (AUIs, strings, distances, CUIs), whether the gold
# CUI is absent from the old CUIs, the RBA predictions and the split label.
records = []

for aui, (cand_auis, cand_dists) in tqdm(candidates.items()):
    true_cui = umls.aui2cui[aui]

    records.append((
        aui,
        umls.aui2str[aui],
        true_cui,
        umls.cui2sg[true_cui],
        cand_auis,
        [umls.aui2str[cand] for cand in cand_auis],
        cand_dists,
        [umls.aui2cui[cand] for cand in cand_auis],
        # True when the gold CUI was not seen among the old CUIs ("null" case).
        true_cui not in old_cuis,
        rba_predicted_synonyms[aui],
        # Falls back to 'dev' when no split table was loaded.
        dataset_splits[aui] if dataset_splits is not None else 'dev',
    ))

df = pd.DataFrame(
    records,
    columns=[
        'aui',
        'query_str',
        'true_cui',
        'sem_group',
        'cand_auis',
        'cand_strs',
        'cand_dists',
        'cand_cuis',
        'null_or_cui',
        'rba_cands',
        'split',
    ],
)

In [None]:
# Re-split at the CUI level, separately within each semantic group, and
# overwrite df['split'] with the result ('test' or 'dev').
cui_sem_group_cui = df[['true_cui','sem_group']].drop_duplicates()
validation = []
testing = []

# Fraction of each semantic group's CUIs assigned to the test side.
test = 0.50

for _, group in cui_sem_group_cui.groupby('sem_group'):

    # Shuffle the group's CUIs with a fixed seed so the split is reproducible.
    perm_g = group.sample(len(group),random_state=np.random.RandomState(42)).true_cui.values

    # The last int(len * test) CUIs go to test. For small or odd-sized groups
    # the remainder stays in validation, so a one-CUI group adds nothing to test.
    cut = len(group) - int(len(group)*test)

    validation.extend(perm_g[:cut])
    testing.extend(perm_g[cut:])

validation = set(validation)
testing = set(testing)

# A CUI must never land on both sides. Checking once on the full sets also
# works when a group contributes no test CUIs.
assert validation.isdisjoint(testing)

splits = ['test' if cui in testing else 'dev' for cui in df.true_cui]

df['split'] = splits

display(df.groupby('split').count())
display(df.groupby('null_or_cui').count())
display(df.groupby(['split','null_or_cui']).count())

In [None]:
# Evaluation frame. Despite the name, it holds the rows whose split is 'test'.
# .copy() makes it independent of `df`, so the column assignments in later
# cells do not trigger pandas' SettingWithCopyWarning.
dev_df = df[df['split'] == 'test'].copy()

In [None]:
# Collects one tuple per evaluated system:
# (name, overall R@1, NULL recall, NULL precision, non-NULL R@1).
all_performance_metrics = []

## RBA Only Performance

In [None]:
# Score the RBA on its own. With no RBA synonyms the prediction is 'null' and
# is correct when the row is a null case. Otherwise the prediction is 'cui' and
# scores 1/len(rba_cuis) when the gold CUI is among the RBA CUIs.
# NOTE(review): rba_cuis holds one entry per RBA AUI, so repeated CUIs inflate
# len(rba_cuis) -- confirm this weighting is intended.
rba_null_correct = []
rba_preds = []

for _, row in tqdm(dev_df.iterrows()):
    rba_cuis = rba_predicted_synonyms[row.aui]

    if not rba_cuis:
        rba_preds.append('null')
        rba_null_correct.append(1 if row.null_or_cui else 0)
        continue

    rba_preds.append('cui')
    rba_null_correct.append(1/len(rba_cuis) if row.true_cui in rba_cuis else 0)

In [None]:
# Store the per-row RBA score and prediction type ('null' / 'cui') on the frame.
dev_df['rba_correct'] = rba_null_correct
dev_df['rba_pred'] = rba_preds

In [None]:
# Aggregate the RBA-only scores: overall mean, mean over null rows (recall),
# mean over rows predicted 'null' (precision) and mean over non-null rows.
rba_scores = dev_df.rba_correct
is_null_row = dev_df.null_or_cui

overall_recall_at_1 = rba_scores.mean()
null_recall = rba_scores[is_null_row].mean()
null_precision = rba_scores[dev_df['rba_pred'] == 'null'].mean()
non_null_recall_at_1 = rba_scores[is_null_row == False].mean()

print('Full R@1: {}'.format(overall_recall_at_1))
print('NULL Recall: {}'.format(null_recall))
print('NULL Precision: {}'.format(null_precision))
print('Non-NULL Recall @1: {}'.format(non_null_recall_at_1))

all_performance_metrics.append(('RBA Only',overall_recall_at_1, null_recall, null_precision, non_null_recall_at_1))

## RBA + Ordering Performance

In [None]:
# RBA + ranking: the RBA decides null vs. non-null. For non-null rows the
# prediction is the first ranked candidate CUI that the RBA also proposed.
correct = []
preds = []

for _, row in tqdm(dev_df.iterrows()):
    rba_cuis = rba_predicted_synonyms[row.aui]

    if not rba_cuis:
        preds.append('null')
        correct.append(1 if row.null_or_cui else 0)
        continue

    # First candidate (in ranked order) that is also an RBA CUI, else None.
    chosen_cui = next((cui for cui in row.cand_cuis if cui in rba_cuis), None)

    # 'C-None' marks rows where no ranked candidate matched any RBA CUI.
    preds.append('C-None' if chosen_cui is None else chosen_cui)
    correct.append(1 if row.true_cui == chosen_cui else 0)

In [None]:
# Store the RBA + ranking prediction and its 0/1 correctness on the frame.
dev_df['rba_sorted_pred'] = preds
dev_df['rba_sorted_correct'] = correct

In [None]:
# Metrics for the RBA + ranking system, from its own prediction/correctness columns.
overall_recall_at_1 = dev_df.rba_sorted_correct.mean()
print('Full R@1: {}'.format(overall_recall_at_1))
null_recall = dev_df[dev_df.null_or_cui].rba_sorted_correct.mean()
# NULL precision is the mean correctness over rows this system predicted as
# 'null', so it filters on this system's own prediction column.
null_precision = dev_df[dev_df['rba_sorted_pred'] == 'null'].rba_sorted_correct.mean()
print('NULL Recall: {}'.format(null_recall))
print('NULL Precision: {}'.format(null_precision))
non_null_recall_at_1 = dev_df[dev_df.null_or_cui == False].rba_sorted_correct.mean()
print('Non-NULL Recall @1: {}'.format(non_null_recall_at_1))

all_performance_metrics.append(('RBA Plus Ranking', overall_recall_at_1, null_recall, null_precision, non_null_recall_at_1))

## Ordering Performance

In [None]:
#Obtained from 2020AB Insertion Set Training Set
threshold = 0.9558532655239105 #SAPBERT
# threshold = 0.9999997615814209 #PubMedBERT

print('Threshold: {}'.format(threshold))

# Ranking only: predict 'null' when the top candidate's score is below the
# threshold, otherwise predict the top candidate's CUI.
correct = []
preds = []

for _, row in tqdm(dev_df.iterrows()):
    top_cui = row.cand_cuis[0]
    top_score = row.cand_dists[0]

    if top_score < threshold:
        preds.append('null')
        correct.append(1 if row.null_or_cui else 0)
    else:
        preds.append(top_cui)
        correct.append(1 if row.true_cui == top_cui else 0)

dev_df['sorted_pred'] = preds
dev_df['sorted_correct'] = correct

ranking_scores = dev_df.sorted_correct
is_null_row = dev_df.null_or_cui

overall_recall_at_1 = ranking_scores.mean()
null_recall = ranking_scores[is_null_row].mean()
null_precision = ranking_scores[dev_df['sorted_pred'] == 'null'].mean()
non_null_recall_at_1 = ranking_scores[is_null_row == False].mean()

print('Full R@1: {}'.format(overall_recall_at_1))
print('NULL Recall: {}'.format(null_recall))
print('NULL Precision: {}'.format(null_precision))
print('Non-NULL Recall @1: {}'.format(non_null_recall_at_1))
print('=='*20)

all_performance_metrics.append(('Ranking Only', overall_recall_at_1, null_recall, null_precision, non_null_recall_at_1))

In [None]:
# Show the metrics collected so far as a table (columns are positional).
pd.DataFrame(all_performance_metrics)

## Two-Step Re-Ranking Performance

In [None]:
# Evaluate the saved re-ranker outputs twice: once for the model trained with
# RBA information and once without. Each pass reads the stored logits/labels,
# derives the arg-max prediction per query, restricts the rows to the AUIs in
# `dev_df` and appends the resulting metrics to `all_performance_metrics`.
directory = 'models'

for add_rba_info in [True, False]:
    max_len = 64
    num_candidates = 50
    subsets = {'train':10,'valid':None,'test':None}

    data_name = 'uva_ins_eval_{}_max-len-{}_num-cands-{}_train-size-{}'.format(umls_version, max_len, num_candidates, subsets['train'])

    if add_rba_info:
        data_name += '_rba_info'

    data_name =  directory+'/'+data_name

    best_checkpoint = data_name+'/epoch_init'

    # NOTE: pickle can execute arbitrary code on load; only use trusted files.
    with open('{}/eval_results.p'.format(best_checkpoint),'rb') as results_file:
        logits,labels = pickle.load(results_file)
    dev_set = torch.load('{}/top50_candidates/valid.t7'.format(data_name))
    # NOTE(review): `query` is not read anywhere in the visible cells.
    query = pubmed_bert_tokenizer.batch_decode(dev_set['context_vecs'])

    with open('{}/cui_based_datasets.p'.format(data_name),'rb') as datasets_file:
        split_datasets = pickle.load(datasets_file)
    cui_dev_set = split_datasets['valid']

    # The stored results are nested (presumably one list per batch -- confirm);
    # flatten them to one entry per query.
    logits_flat = [entry for chunk in logits for entry in chunk]
    labels_flat = [entry for chunk in labels for entry in chunk]

    assert len(cui_dev_set) == len(logits_flat)

    cui_dev_set['re_ranked_labels'] = labels_flat
    cui_dev_set['re_ranked_logits'] = logits_flat
    cui_dev_set['re_ranked_preds'] = np.argmax(np.array(logits_flat),axis=1)

    preds = []

    # The loop variable is deliberately not called `candidates`, so the
    # notebook-level `candidates` mapping loaded earlier is left intact.
    for pred, candidate_vecs in tqdm(zip(cui_dev_set['re_ranked_preds'], dev_set['candidate_vecs']),total=len(cui_dev_set)):

        # Decode the predicted candidate and strip the tokenizer's special tokens.
        pred_str = pubmed_bert_tokenizer.decode(candidate_vecs[pred])
        pred_str = pred_str.replace('[PAD]','').replace('[CLS]','').replace('[SEP]','').strip()

        preds.append(pred_str)

    cui_dev_set['re_ranked_pred_strs'] = preds

    # A row is correct when the arg-max index equals the stored label index.
    correct = [
        1 if label == pred else 0
        for label, pred in zip(cui_dev_set['re_ranked_labels'], cui_dev_set['re_ranked_preds'])
    ]

    cui_dev_set['re_ranked_correct'] = correct

    # Keep only the rows whose AUI is part of the evaluation frame.
    cui_dev_set = dev_df[['aui']].merge(cui_dev_set,on='aui',how='inner')

    overall_recall_at_1 = cui_dev_set.re_ranked_correct.mean()
    print('Full R@1: {}'.format(overall_recall_at_1))
    null_recall = cui_dev_set[cui_dev_set.null_or_cui].re_ranked_correct.mean()
    null_precision = cui_dev_set[cui_dev_set['re_ranked_pred_strs'] == 'null'].re_ranked_correct.mean()
    print('NULL Recall: {}'.format(null_recall))
    print('NULL Precision: {}'.format(null_precision))
    non_null_recall_at_1 = cui_dev_set[cui_dev_set.null_or_cui == False].re_ranked_correct.mean()
    print('Non-NULL Recall @1: {}'.format(non_null_recall_at_1))
    print('=='*20)

    all_performance_metrics.append((add_rba_info, overall_recall_at_1, null_recall, null_precision, non_null_recall_at_1))

In [None]:
# Turn the collected tuples into a table. Columns are positional: 2 is NULL
# recall and 3 is NULL precision; 'f1' is their harmonic mean.
all_performance_metrics = pd.DataFrame(all_performance_metrics)
null_recall_col = all_performance_metrics[2]
null_precision_col = all_performance_metrics[3]
all_performance_metrics['f1'] = 2*null_recall_col*null_precision_col/(null_recall_col + null_precision_col)
all_performance_metrics

In [None]:
# Row counts per semantic group in the evaluation frame, largest group first.
dev_df.groupby('sem_group').count().sort_values('aui',ascending=False)