# Using Siamese BioBERT to create IR system

In [None]:
import pandas as pd
import glob
import os
from elasticsearch import Elasticsearch, helpers
import requests
import numpy as np
import csv
import sys

sys.path.append('../')

from utils.uniprot_loader import *
from utils.annoy_helper import *

In [11]:
# Device string forwarded as `device=` to the `model.encode(...)` calls and to
# `get_model` in the evaluation loop at the end of the notebook.
device = 'cpu'

In [12]:
# Elasticsearch address used to build the `_search` URLs in the query helpers.
ES_HOST = '172.17.0.3'
ES_PORT = '9200'
# Index searched by the query helpers; the evaluation loop at the end of the
# notebook reassigns this global once per model.
EVAL_INDEX = 'disease_bert_no_train'

This setting is specific to the Novartis machine: in this example we need to connect to a local host, so we blank out all proxy environment variables for this notebook.

In [13]:
# Blank out every proxy-related environment variable (lower- and upper-case
# spellings) for the lifetime of this notebook process.
for _proxy_var in ('http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY', 'NO_PROXY', 'no_proxy'):
    os.environ[_proxy_var] = ""

This function is only an example of how to construct the input / output for table-like evaluation.
If you only want to query Elasticsearch or Annoy, please refer to the implementation in the following sections.

In [14]:
def create_table_like_input():
    """Write the single-column query table and its ground-truth table.

    Reads ``../data/devset/disease_eval_list_2.csv``, keeps the rows whose
    ``ambiguos`` column equals 0, writes their ``name`` column to
    ``../data/devset/disease_table/disease.tsv`` and writes one ground-truth
    record per kept row (always in column 0) to
    ``../data/devset/truth_disease_table/disease.tsv``.

    Returns:
        pandas.DataFrame: the ground-truth frame with the columns ``row``,
        ``column``, ``kg_id`` and ``kg_label``.
    """
    source = pd.read_csv('../data/devset/disease_eval_list_2.csv')
    unambiguous = source[source['ambiguos'] == 0]
    unambiguous['name'].to_csv('../data/devset/disease_table/disease.tsv', sep='\t', index=False)

    # Row positions are renumbered from 0 so that they line up with the rows
    # of the query table written just above.
    labels = list(unambiguous['name'])
    ids = list(unambiguous['kgid'])
    truth_df = pd.DataFrame({
        'row': list(range(len(labels))),
        'column': [0] * len(labels),
        'kg_id': ids,
        'kg_label': labels})
    truth_df.to_csv('../data/devset/truth_disease_table/disease.tsv', sep='\t', index=False)
    return truth_df

In [15]:
# Writes the query and truth TSV files as a side effect; `truth_df` is
# reassigned later by the statistics cell.
truth_df = create_table_like_input()

In [16]:
# NOTE: despite its name, `data_file` is the directory holding the query tables.
data_file = '../data/devset/disease_table'
# Directory of ground-truth tables; `evaluate` reads files from it by base name.
truth_path = '../data/devset/truth_disease_table'
# Every entry in the query-table directory is one table to predict on.
filenames = glob.glob(data_file + '/*')

In [17]:
def get_prediction(filenames, generate_candidate_func, **kwargs):
    """Run a candidate generator over every cell of every table file.

    Args:
        filenames: iterable of paths to tab-separated table files.
        generate_candidate_func: callable invoked as ``func(cell, **kwargs)``;
            each returned hit is indexed with ``hit[1]`` to get the candidate id.
        **kwargs: forwarded unchanged to ``generate_candidate_func``.

    Returns:
        dict: maps each file's base name to a ``rows x columns`` nested list
        whose entries are the candidate ids for that cell, or an empty list
        for cells that were skipped.
    """
    prediction = {}
    for path in filenames:
        table = pd.read_csv(path, delimiter='\t')
        width = len(table.columns)
        grid = [[[] for _ in range(width)] for _ in range(len(table))]

        for ridx, record in table.iterrows():
            for cidx, cell in enumerate(record):
                # Missing values and float cells are never sent to the generator.
                if pd.isna(cell) or isinstance(cell, float):
                    continue
                grid[ridx][cidx] = [hit[1] for hit in generate_candidate_func(cell, **kwargs)]
        prediction[os.path.basename(path)] = grid
    return prediction

In [70]:
def evaluate(file, result, topk=3, truth_dir=None):
    """Score the top-``topk`` candidates of one table against its ground truth.

    Args:
        file: base name of the table; used as the key into ``result`` and as
            the file name inside the ground-truth directory.
        result: mapping from base name to a ``rows x columns`` nested list of
            ranked candidate ids (the structure built by ``get_prediction``).
        topk: number of leading candidates per cell that count as the answer.
        truth_dir: directory holding the ground-truth TSV files (columns
            ``row``, ``column``, ``kg_id`` and optionally ``kg_label``).
            Defaults to the module-level ``truth_path``.

    Returns:
        tuple: ``(stat, gt_table)``. ``stat`` holds the counters ``tp``,
        ``fp``, ``fn``, ``tn``, ``topk`` (equal to ``tp``), ``total`` (number
        of cells that have a ground-truth id) and ``prediction_log``, a list
        of ``[kg_label, answer_set]`` entries for every cell counted as fp or
        fn (``kg_label`` is ``None`` when the cell has no ground truth).
        ``gt_table`` is the ``rows x columns`` nested list of ground-truth ids.
    """
    if truth_dir is None:
        truth_dir = truth_path
    pred = result[file]
    nrow = len(pred)
    # An empty prediction table has no first row to measure.
    ncol = len(pred[0]) if pred else 0

    table = pd.read_csv(os.path.join(truth_dir, file), delimiter='\t')
    gt_table = [[[] for _ in range(ncol)] for _ in range(nrow)]
    non_empty_gt_idx = set()
    cell_label = {}
    for _, row in table.iterrows():
        cell_answer = row['kg_id']
        if pd.isna(cell_answer):
            continue
        ridx = int(row['row'])
        cidx = int(row['column'])
        gt_table[ridx][cidx].append(cell_answer)
        non_empty_gt_idx.add((ridx, cidx))
        # Remember the label so mispredictions can be logged readably.
        # (The truth frame cannot be indexed as table[ridx][cidx]: its
        # columns are named, not positional.)
        cell_label.setdefault((ridx, cidx), row.get('kg_label'))

    tp = 0
    fp = 0
    tn = 0
    fn = 0
    stat = dict()
    stat['prediction_log'] = []
    for ridx in range(nrow):
        for cidx in range(ncol):
            gt_set = set(gt_table[ridx][cidx])
            answer_set = set(pred[ridx][cidx][:topk])
            if answer_set and answer_set.intersection(gt_set):
                tp += 1
                continue
            if not answer_set and not gt_set:
                tn += 1
                continue
            # Anything left is a miss: a wrong or unwanted answer (fp) or a
            # missing answer (fn).
            if answer_set:
                fp += 1
            else:
                fn += 1
            stat['prediction_log'].append([cell_label.get((ridx, cidx)), answer_set])
    stat['tp'] = tp
    stat['fp'] = fp
    stat['fn'] = fn
    stat['tn'] = tn
    stat['topk'] = tp
    stat['total'] = len(non_empty_gt_idx)
    print("Top{}: {}/{}".format(topk, tp, (len(non_empty_gt_idx))))
    return stat, gt_table

In [32]:
def predict_and_eval(filenames, generate_candidates_query_func, **kwargs):
    """Generate candidates for every table and score them at several cut-offs.

    Cut-offs 1 and 10 are always evaluated. A caller-supplied ``top_k`` is
    evaluated as an additional cut-off; when it is absent, ``top_k=10`` is
    handed to the candidate generator.

    Args:
        filenames: paths of the query tables.
        generate_candidates_query_func: candidate generator passed on to
            ``get_prediction``.
        **kwargs: forwarded to the candidate generator (``top_k`` is filled
            in when missing).

    Returns:
        dict: maps ``'top_<k>'`` to the list of per-file stat dicts returned
        by ``evaluate``, in the order of ``filenames``.
    """
    cutoffs = [1, 10]
    if 'top_k' in kwargs:
        cutoffs.append(kwargs['top_k'])
    else:
        kwargs['top_k'] = 10
    predictions = get_prediction(filenames, generate_candidates_query_func, **kwargs)

    stat_topk = {}
    for cutoff in cutoffs:
        per_file = []
        for path in filenames:
            file_stat, _ = evaluate(os.path.basename(path), predictions, cutoff)
            per_file.append(file_stat)
        stat_topk['top_' + str(cutoff)] = per_file
    return stat_topk

In [20]:
def write_report(filename, stat, extra_topk=None, file_list=None, upper_bound_stat=None):
    """Write a per-file and overall hit report as CSV.

    Files whose upper bound is 0 are left out. Every ``percent`` column,
    including the ones in the final ``Total`` row, is expressed on a 0-100
    scale relative to the upper bound.

    Args:
        filename: path of the CSV file to write.
        stat: mapping with the keys ``'top_1'`` and ``'top_10'`` (and
            ``'top_<extra_topk>'`` when ``extra_topk`` is given); each value
            is a list of per-file dicts whose ``'topk'`` entry is the hit count.
        extra_topk: optional additional cut-off to report in the ``topk``
            columns; those columns are 0 when it is omitted.
        file_list: table paths, aligned with the stat lists. Defaults to the
            module-level ``filenames``.
        upper_bound_stat: per-file upper-bound dicts (``'topk'`` entry).
            Defaults to the module-level ``UPPER_BOUND_STAT``.
    """
    if file_list is None:
        file_list = filenames
    if upper_bound_stat is None:
        upper_bound_stat = UPPER_BOUND_STAT

    top1_stat = stat['top_1']
    top10_stat = stat['top_10']
    topk_stat = stat['top_{}'.format(extra_topk)] if extra_topk else None

    def percent(hits, upper):
        # An empty report (upper bound 0) yields 0.0 instead of dividing by zero.
        return hits * 100 / upper if upper else 0.0

    # Running sums, in column order: upperbound, top1, top10, topk.
    totals = [0, 0, 0, 0]
    # newline='' is what the csv module requires to avoid blank lines on Windows.
    with open(filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['file', 'upperbound', 'top1', 'top10', 'topk', 'percent top1', 'percent top10', 'percent topk'])
        for i, file in enumerate(file_list):
            upper = upper_bound_stat[i]['topk']
            if upper == 0:
                continue
            counts = [
                upper,
                top1_stat[i]['topk'],
                top10_stat[i]['topk'],
                topk_stat[i]['topk'] if topk_stat else 0,
            ]
            writer.writerow([os.path.basename(file)] + counts + [percent(c, upper) for c in counts[1:]])
            totals = [t + c for t, c in zip(totals, counts)]
        writer.writerow(["Total"] + totals + [percent(c, totals[0]) for c in totals[1:]])

In [21]:
# Reference corpus of disease names/aliases with their ids. The path is a
# fixed literal: it contains no placeholder, so it is not formatted.
disease_df = pd.read_csv('../uniprot_data_prep/disease_alias_label.tsv', delimiter='\t')
# (name, id) pairs in file order; built column-wise instead of row by row
# because the frame has tens of thousands of rows.
corpus_disease = list(zip(disease_df['name'], disease_df['id']))

def get_upper_bound_stat():
    """Compute, per table file, the best score any candidate generator could reach.

    Each cell is "answered" with the whole reference corpus and evaluated with
    ``topk`` equal to the corpus size, so a cell counts as a hit whenever its
    ground-truth id occurs anywhere in ``corpus_disease``.

    Returns:
        list: one stat dict (as returned by ``evaluate``) per entry of the
        module-level ``filenames``, in the same order.
    """
    def whole_corpus(cell):
        # The cell content is ignored: every cell receives the full corpus.
        return corpus_disease

    stats = []
    for path in filenames:
        everything = get_prediction([path], whole_corpus)
        file_stat, _ = evaluate(os.path.basename(path), everything, topk=len(corpus_disease))
        stats.append(file_stat)
    return stats

# Computed once when this cell runs; `write_report` and the statistics cell
# read this global, indexed in the same order as `filenames`.
UPPER_BOUND_STAT = get_upper_bound_stat()

Top53619: 154/162


In [22]:
# Set of all ids in the reference corpus, used for membership tests in the
# statistics cell.
disease_ids = set(disease_df['id'])

In [54]:
# Count, per table and per column, the distinct truth cells whose id exists in
# the reference corpus. `sum_main` adds up the best-populated column of each
# table; `count` adds up all such cells.
count = 0
sum_main = 0
for i, file in enumerate(filenames):
    basename = os.path.basename(file)
    columns = {}
    # Tables without a single answerable cell are not reported.
    if UPPER_BOUND_STAT[i]['topk'] == 0:
        continue
    print(basename)
    truth_df = pd.read_csv(os.path.join(truth_path, basename), delimiter='\t')
    seen = set()
    for _, row in truth_df.iterrows():
        kg_id = row['kg_id']
        if kg_id not in disease_ids or pd.isna(kg_id):
            continue
        cell = (row['row'], row['column'])
        # A cell listed with several ids is counted only once.
        if cell in seen:
            continue
        seen.add(cell)
        count += 1
        columns[row['column']] = columns.get(row['column'], 0) + 1
    ranked = sorted(columns.items(), key=lambda x: x[1], reverse=True)
    if ranked:
        sum_main += ranked[0][1]
    for col, total in ranked:
        print("Column: {} Count: {}".format(col, total))
print("Total for Main: ", sum_main)
print("Total: ", count)

Human_blood_group_systems_dbpv=2020-02_nif=table_ref=3_2_order=0.tsv
Column: 3 Count: 2
Column: 1 Count: 1
Keratin_disease_dbpv=2020-02_nif=table_ref=1_2_order=0.tsv
Column: 0 Count: 7
Major_facilitator_superfamily_dbpv=2020-02_nif=table_ref=6.1_2_order=0.tsv
Column: 3 Count: 10
Potassium_channel_dbpv=2020-02_nif=table_ref=2_2_order=0.tsv
Column: 2 Count: 1
Cancer_syndrome_dbpv=2020-02_nif=table_ref=4_2_order=0.tsv
Column: 0 Count: 8
Distal_hereditary_motor_neuronopathies_dbpv=2020-02_nif=table_ref=1_2_order=0.tsv
Column: 5 Count: 6
Ciliopathy_dbpv=2020-02_nif=table_ref=4.1_2_order=0.tsv
Column: 0 Count: 12
List_of_therapeutic_monoclonal_antibodies_dbpv=2020-02_nif=table_ref=0_1_order=0.tsv
Column: 6 Count: 76
Column: 4 Count: 7
Ciliopathy_dbpv=2020-02_nif=table_ref=4.3_2_order=0.tsv
Column: 0 Count: 50
Total for Main:  172
Total:  180


## Elastic Search Indexing

In [23]:
# Connect to the node addressed by ES_HOST / ES_PORT, the same settings the
# query helpers use, so indexing and querying cannot drift apart.
es = Elasticsearch([{'host': ES_HOST, 'port': int(ES_PORT)}])

In [24]:
# Preview the reference corpus frame.
disease_df

Unnamed: 0,index,id,name
0,0,Q1001150,fibrillation
1,1,Q100165995,acute pulmonary hypertension
2,2,Q1001920,hallucinogen persisting perception disorder
3,3,Q1002195,autosomal recessive limb-girdle muscular dystr...
4,4,Q100270830,Benadryl challenge
...,...,...,...
53614,41068,Q998273,"Polydactyly, Sex Reversal, Renal Hypoplasia, a..."
53615,41069,Q998273,Smith Lemli Opitz syndrome
53616,41070,Q998273,SMITH-LEMLI-OPITZ SYNDROME; SLOS
53617,41071,Q998273,SLOS


In [26]:
def index_to_es_embedding(corpus_df, index='', model=None):
    """Index every row of ``corpus_df`` into Elasticsearch, optionally with an embedding.

    Each document gets the fields ``name`` (stripped) and ``qid``; when a
    model is given, ``word_embedding`` holds the encoded name as a plain list
    of floats. The row's index label is used as the document id. Indexing is
    best effort: a failure for one document is printed and the loop continues.

    Args:
        corpus_df: frame with the columns ``name`` and ``id``; assumes an
            integer index, since the label feeds the progress arithmetic.
        index: name of the target Elasticsearch index.
        model: optional encoder exposing ``encode(text, device=...)``.
    """
    total = len(corpus_df)
    # max(1, ...) keeps the modulo valid for corpora with fewer than 10 rows.
    progress_step = max(1, total // 10)
    for idx, row in corpus_df.iterrows():
        if idx % progress_step == 0 or idx == total - 1:
            print(idx * 100 / total)

        name = row['name']
        if not isinstance(name, str):
            # e.g. NaN read from an empty cell: nothing to strip or encode.
            continue

        payload = {'name': name.strip(), 'qid': row['id']}
        if model:
            # Store a plain list of floats, the same float conversion the
            # query helpers apply to the query vector.
            payload['word_embedding'] = model.encode(name, device=device).astype(float).tolist()

        try:
            es.create(index=index, body=payload, id=str(idx))
        except Exception as e:
            print(e)

## Query ElasticSearch with Fuzzy

In [27]:
def query_elastic_search_fuzzy(keyword, **kwargs):
    """Query the evaluation index with a fuzzy ``match`` on the ``name`` field.

    Args:
        keyword: text to search for.
        **kwargs: ``top_k`` (default 10) sets the number of hits requested
            and returned; other keys are ignored.

    Returns:
        list: ``(name, qid)`` tuples in the order returned by Elasticsearch,
        or an empty list when the request or response parsing fails (the
        exception is printed).
    """
    # Must be resolved before the payload is built, which reads it for "size".
    top_k = kwargs.get('top_k', 10)
    url = "http://{}:{}/{}/_search".format(ES_HOST, ES_PORT, EVAL_INDEX)
    payload = {
        "size": top_k,
        "query": {
            "match" : {
                "name" : {
                    "query" : keyword,
                    "fuzziness": "auto"
                }
            }
        }
    }
    try:
        r = requests.get(url, json = payload)
        hits = r.json()['hits']['hits']
        return [(doc['_source']['name'], doc['_source']['qid']) for doc in hits[:top_k]]
    except Exception as e:
        print(e)
        return []

## Query Elasticsearch with ElasticSearch Vector Field (Linear Scan) 

In [28]:
def query_elastic_search_with_embed(keyword, **kwargs):
    """Rank all documents of the evaluation index by cosine similarity to the keyword.

    The keyword is encoded with the supplied model and sent as the query
    vector of a ``script_score`` query wrapped around ``match_all``.

    Args:
        keyword: value to search for; converted with ``str`` before encoding.
        **kwargs: ``model`` (required, exposes ``encode(text, device=...)``)
            and ``top_k`` (default 10).

    Returns:
        list: ``(name, qid)`` tuples, or an empty list when no model was
        passed or the request fails (the exception is printed).
    """
    if 'model' not in kwargs:
        print("No Model Found, returning without querying ES")
        return []
    encoder = kwargs['model']
    top_k = kwargs.get('top_k', 10)
    endpoint = "http://{}:{}/{}/_search".format(ES_HOST, ES_PORT, EVAL_INDEX)
    query_vector = list(encoder.encode(str(keyword), device=device).astype(float))

    # +1.0 shifts the cosine similarity into a non-negative range.
    scoring_script = {
        "source": "cosineSimilarity(params.query_vector, 'word_embedding') + 1.0",
        "params": {"query_vector": query_vector},
    }
    body = {
        "size": top_k,
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": scoring_script,
            },
        },
    }
    try:
        response = requests.get(endpoint, json=body)
        documents = response.json()['hits']['hits'][:top_k]
        return [(d['_source']['name'], d['_source']['qid']) for d in documents]
    except Exception as e:
        print(e)
        return []

## Query Elasticsearch with Vector and Fuzzy Match

This function will filter the search result using Fuzzy Match and then rerank them with Embedded vector

In [29]:
def query_elastic_search_with_fuzzy_and_embed(keyword, **kwargs):
    """Filter with a fuzzy name match, then rank the survivors by embedding similarity.

    A ``script_score`` query wraps a fuzzy ``match`` on ``name``; the matching
    documents are scored by cosine similarity between their ``word_embedding``
    and the encoded keyword.

    Args:
        keyword: value to search for; sent as-is to the fuzzy match and
            converted with ``str`` before encoding.
        **kwargs: ``model`` (required, exposes ``encode(text, device=...)``)
            and ``top_k`` (default 10).

    Returns:
        list: ``(name, qid)`` tuples, or an empty list when no model was
        passed or the request fails (the exception is printed).
    """
    if 'model' not in kwargs:
        print("No Model Found, returning without querying ES")
        return []
    encoder = kwargs['model']
    top_k = kwargs.get('top_k', 10)
    endpoint = "http://{}:{}/{}/_search".format(ES_HOST, ES_PORT, EVAL_INDEX)
    query_vector = list(encoder.encode(str(keyword), device=device).astype(float))

    fuzzy_filter = {
        "match": {
            "name": {
                "query": keyword,
                "fuzziness": "auto",
            }
        }
    }
    # +1.0 shifts the cosine similarity into a non-negative range.
    scoring_script = {
        "source": "cosineSimilarity(params.query_vector, 'word_embedding') + 1.0",
        "params": {"query_vector": query_vector},
    }
    body = {
        "size": top_k,
        "query": {
            "script_score": {
                "query": fuzzy_filter,
                "script": scoring_script,
            },
        },
    }
    try:
        response = requests.get(endpoint, json=body)
        documents = response.json()['hits']['hits'][:top_k]
        return [(d['_source']['name'], d['_source']['qid']) for d in documents]
    except Exception as e:
        print(e)
        return []

## Query to Annoy
To use Annoy as an index over the reference data, we first need to create the Annoy index file and then issue queries against it.

In [27]:
import annoy

In [28]:
def generate_candidate_with_annoy(keyword, **kwargs):
    """Look up the nearest corpus names in the Annoy index and expand them to ids.

    Args:
        keyword: value to search for; converted with ``str`` before encoding.
        **kwargs:
            model: encoder exposing ``encode(text, device=...)`` (required).
            annoy_object_wrapper: object exposing ``annoy_index`` and
                ``embedding_object.corpus_sentences`` (required).
            name2id: mapping from corpus name to an iterable of ids (required).
            top_k: number of neighbours to fetch (default 10).
            id2name: accepted for backward compatibility; not used.

    Returns:
        list: ``(name, id)`` tuples in neighbour order; a name mapped to
        several ids contributes one tuple per id, and a name missing from
        ``name2id`` contributes none.
    """
    model = kwargs['model']
    annoy_object_wrapper = kwargs['annoy_object_wrapper']
    name2id = kwargs['name2id']
    top_k = kwargs.get('top_k', 10)

    corpus_sentences = annoy_object_wrapper.embedding_object.corpus_sentences
    annoy_index = annoy_object_wrapper.annoy_index

    query_embedding = model.encode(str(keyword), device=device)

    # Only the neighbour order matters to callers, so distances are not requested.
    found_corpus_ids = annoy_index.get_nns_by_vector(query_embedding, top_k)

    return_hits = []
    for corpus_id in found_corpus_ids:
        name = corpus_sentences[corpus_id]
        for _id in name2id.get(name, []):
            return_hits.append((name, _id))

    return return_hits

In [97]:
# Scratch check: ids registered for the corpus sentence at position 9267.
# NOTE(review): `name2id` and `annoy_object_wrapper` are only assigned in cells
# further down this file; this cell appears to have been run out of order.
name2id.get(annoy_object_wrapper.embedding_object.corpus_sentences[9267])

{'Q987664'}

In [95]:
# Scratch check: first key stored in `id2name`.
list(id2name.keys())[0]

'Q1001150'

In [29]:
# Load the encoder used for the Annoy evaluation. `get_model` is not defined in
# this file; presumably it comes from one of the star-imported utils modules.
model = get_model('./siamese-biobert-disease-ep-1-Mar-29-2021-with-heuris-hard-neg-10')
# NOTE(review): the reference dataset path is derived from EVAL_INDEX
# ('<EVAL_INDEX>_alias_label.tsv'), whereas the corpus cell reads the fixed
# file 'disease_alias_label.tsv'. Confirm both point at the intended file.
annoy_object_wrapper = AnnoyObjectWrapper(index_path='./wikidata-disease-embedding-4096-trees.ann', 
                          embedding_path='./wikidata-disease-768-embedding.pkl', 
                          reference_dataset_path='../uniprot_data_prep/{}_alias_label.tsv'.format(EVAL_INDEX), 
                          name2id_path='./wikidata-name2id-embedding-size-1500000', 
                          model=model, n_trees=4096, embedding_size=768, max_corpus_size=1500000)
annoy_object_wrapper.create_embedding_and_index()

Load pre-computed embeddings from disc


In [30]:
def get_id2name(dataset_path):
    """Build a mapping from id to the set of names/aliases listed for it.

    Rows with a missing value in any column are dropped. Ids and names are
    stripped of surrounding whitespace. An id that carries surrounding
    whitespace in the file is additionally reachable under its raw spelling,
    which maps to the same set object.

    Args:
        dataset_path: path to a tab-separated file with the columns ``id``
            and ``name``.

    Returns:
        dict: ``{id: set of stripped names}``.
    """
    all_name = pd.read_csv(dataset_path, delimiter='\t').dropna()

    id2name = {}
    for raw_id, raw_name in zip(all_name['id'], all_name['name']):
        key = raw_id.strip()
        # Test and insert under the same (stripped) key, so entries are never
        # reset or looked up under a key that was not created.
        names = id2name.setdefault(key, set())
        names.add(raw_name.strip())
        if raw_id != key:
            # Keep lookups by the unstripped spelling working.
            id2name.setdefault(raw_id, names)

    return id2name


def get_name2id(dataset_path):
    """Build a mapping from name/alias to the set of ids it is listed under.

    Rows with a missing value in any column are dropped. Names and ids are
    stripped of surrounding whitespace. A name that carries surrounding
    whitespace in the file is additionally reachable under its raw spelling,
    which maps to the same set object.

    Args:
        dataset_path: path to a tab-separated file with the columns ``id``
            and ``name``.

    Returns:
        dict: ``{name: set of stripped ids}``.
    """
    all_name = pd.read_csv(dataset_path, delimiter='\t').dropna()

    name2id = {}
    for raw_name, raw_id in zip(all_name['name'], all_name['id']):
        key = raw_name.strip()
        # Test and insert under the same (stripped) key, so entries are never
        # reset or looked up under a key that was not created.
        ids = name2id.setdefault(key, set())
        ids.add(raw_id.strip())
        if raw_name != key:
            # Keep lookups by the unstripped spelling working.
            name2id.setdefault(raw_name, ids)

    return name2id

In [31]:
# Both lookup tables are built from the same file, whose name is derived from
# the current value of EVAL_INDEX.
id2name = get_id2name('../uniprot_data_prep/{}_alias_label.tsv'.format(EVAL_INDEX))
name2id = get_name2id('../uniprot_data_prep/{}_alias_label.tsv'.format(EVAL_INDEX))

In [None]:
# Evaluate the Annoy-backed candidate generator. No `top_k` is passed, so
# `predict_and_eval` uses its default of 10 and scores cut-offs 1 and 10.
stat_topk = predict_and_eval(filenames, generate_candidate_with_annoy, model=model, annoy_object_wrapper=annoy_object_wrapper, id2name=id2name, name2id=name2id)
# Without `extra_topk`, the report's `topk` columns are written as 0.
write_report('./annoy_report.csv', stat_topk)

In [None]:
# Render the CSV report at './annoy_report.csv' as a table.
display(pd.read_csv('./annoy_report.csv'))

## Evaluate the models of interest

This section shows how to evaluate each model with Elasticsearch. You can pass a query function to `predict_and_eval` directly. For example, we pass `query_elastic_search_with_embed` into `predict_and_eval` to evaluate how the model performs using the vectors we indexed in Elasticsearch.

In [None]:
# Fine-tuned encoders to compare; each gets its own Elasticsearch index.
model_list = [
    './siamese-biobert-disease-ep-1-Apr-16-2021-with-heuris-hard-neg-1',
    './siamese-biobert-disease-ep-1-Apr-16-2021-with-heuris-hard-neg-2',
    './siamese-biobert-disease-ep-1-Apr-16-2021-with-heuris-hard-neg-3',
    './siamese-biobert-disease-ep-1-Apr-16-2021-with-heuris-hard-neg-4'
]

# Index mapping: analysed name, exact-match id and a 768-dimensional vector.
es_index_payload = {
    "mappings": {
        "properties": {
            "name": {
                "type": "text"
            },
            "qid": {
                "type": "keyword"
            },
            "word_embedding": {
                "type": "dense_vector",
                "dims": 768
            }
        }
    }
}

all_pred_log = dict()

# Loop-invariant settings.
top_k = 100
report_dir = './es_vector_report'
# write_report opens the file for writing and cannot create missing directories.
os.makedirs(report_dir, exist_ok=True)

for model_name in model_list:
    basefile = os.path.basename(model_name)
    # The query helpers read the global EVAL_INDEX, so it is switched per model.
    EVAL_INDEX = basefile.lower()
    print(EVAL_INDEX)
    # ignore=400 is passed so the client does not raise on an HTTP 400
    # response (e.g. when the index already exists).
    es.indices.create(
        index=EVAL_INDEX,
        body=es_index_payload,
        ignore=400
    )
    model = get_model(model_name,device=device)
    index_to_es_embedding(disease_df, index=EVAL_INDEX, model=model)
    stat_topk = predict_and_eval(filenames, query_elastic_search_with_embed, model=model, top_k=top_k)
    all_pred_log[model_name] = stat_topk
    report_path = '{}/{}_top{}.csv'.format(report_dir, EVAL_INDEX, top_k)
    write_report(report_path, stat_topk, extra_topk=top_k)
    display(pd.read_csv(report_path))