In [1]:
import sys, os, re, json, pickle, ijson,json
from elasticsearch import Elasticsearch
from tqdm import tqdm
from pprint import pprint

In [5]:

# Tokenizers in the style of the BioASQ "bioclean" helper.
# The character class below contains the range ':'-'[' (written ":-\["), so besides the
# explicitly listed punctuation it also strips ':;<=>?@[' (uppercase A-Z fall in that range
# too, but the text is lower-cased before the substitution). The pattern is a raw string so
# the '\[' / '\]' escapes reach the regex engine unchanged without triggering Python's
# invalid-escape-sequence warning; the compiled pattern is identical to the previous one.
_BIOCLEAN_PUNCT = r'[.,?;*!%^&_+():-\[\]{}]'


def bioclean_mod(t):
    """Modified bioclean: like ``bioclean`` but also splits on dashes.

    Works better for retrieval with galago.
    """
    t = t.replace('"', '').replace('/', '').replace('\\', '').replace("'", '').replace("-", ' ').strip().lower()
    return re.sub(_BIOCLEAN_PUNCT, '', t).split()


def bioclean(t):
    """Lower-case ``t``, drop quotes/slashes/punctuation and split on whitespace (dashes kept)."""
    t = t.replace('"', '').replace('/', '').replace('\\', '').replace("'", '').strip().lower()
    return re.sub(_BIOCLEAN_PUNCT, '', t).split()


doc_index = 'pubmed_abstracts_index_0_1'
# NOTE(review): this name shadows the builtin `map` for the rest of the notebook; kept as-is
# because other cells may reference it.
map         = "pubmed_abstracts_mapping_0_1"
def idf_val(w, idf, max_idf):
    """Return the IDF weight for token ``w``.

    Args:
        w: the token to look up.
        idf: mapping token -> IDF weight.
        max_idf: fallback weight used when ``w`` is absent from ``idf``.
    """
    return idf[w] if w in idf else max_idf

def tokenize(x):
    """Split ``x`` into lower-cased tokens using the dash-preserving ``bioclean``."""
    tokens = bioclean(x)
    return tokens

def GetWords(data, doc_text, words):
    """Record every token seen in the queries and their retrieved documents.

    Args:
        data: run dict with a ``'queries'`` list; each query has ``'query_text'`` and a
            ``'retrieved_documents'`` list whose entries carry a ``'doc_id'``.
        doc_text: mapping doc_id -> document dict with ``'title'`` and ``'abstractText'``.
        words: dict updated in place; every token becomes a key mapped to 1.

    A document's text is ``title + ' <title> ' + abstractText``. ``bioclean`` strips ``<``
    and ``>``, so the separator itself contributes the token ``title``. MeSH terms are not
    included (a ``get_the_mesh`` variant existed but was disabled).
    """
    for query in data['queries']:
        for token in tokenize(query['query_text']):
            words[token] = 1
        for hit in query['retrieved_documents']:
            doc = doc_text[hit['doc_id']]
            for token in tokenize(doc['title'] + ' <title> ' + doc['abstractText']):
                words[token] = 1

def load_idfs(idf_path, words):
    """Load a pickled IDF table and restrict it to a vocabulary.

    Args:
        idf_path: path to a pickle holding a dict token -> IDF weight.
        words: iterable of tokens to keep.

    Returns:
        ``(restricted, max_idf)``. ``restricted`` maps each token in ``words`` that also appears
        in the table to its weight. ``max_idf`` is the largest positive weight in the FULL
        table, or 0.0 if there is none.
    """
    print('Loading IDF tables')
    # NOTE(security): pickle.load can execute arbitrary code -- only load trusted IDF files.
    with open(idf_path, 'rb') as handle:
        full_table = pickle.load(handle)
    restricted = {w: full_table[w] for w in words if w in full_table}
    # Equivalent to a running max seeded with 0.0 (non-positive and NaN values never win).
    max_idf = max((v for v in full_table.values() if v > 0.0), default=0.0)
    # Drop the reference to the (potentially large) full table as soon as possible.
    del full_table
    print('Loaded idf tables with max idf {}'.format(max_idf))
    return restricted, max_idf

def load_all_data(dataloc, idf_pickle_path):
    """Load the BioASQ 7 questions, the BM25 top-100 train/dev runs and the IDF table.

    Args:
        dataloc: directory holding the BioASQ data files. Paths are built with
            ``os.path.join``, so a trailing slash is optional.
        idf_pickle_path: path to the pickled IDF table (see ``load_idfs``).

    Returns:
        ``(dev_data, dev_docs, train_data, train_docs, idf, max_idf, bioasq7_data)``.
        ``bioasq7_data`` maps question id -> question dict. ``idf`` is restricted to the tokens
        of the train/dev queries and their retrieved documents.
    """
    def _load_pickle(filename):
        # Load one pickle file from ``dataloc``.
        # NOTE(security): pickle.load can execute arbitrary code -- only load trusted files.
        with open(os.path.join(dataloc, filename), 'rb') as f:
            return pickle.load(f)

    print('loading pickle data')
    # 'trainining7b.json' (sic) is kept verbatim; presumably it matches the file on disk --
    # verify before renaming.
    with open(os.path.join(dataloc, 'trainining7b.json'), 'r') as f:
        bioasq7_raw = json.load(f)
    bioasq7_data = {q['id']: q for q in bioasq7_raw['questions']}
    #
    dev_data   = _load_pickle('bioasq7_bm25_top100.dev.pkl')
    dev_docs   = _load_pickle('bioasq7_bm25_docset_top100.dev.pkl')
    train_data = _load_pickle('bioasq7_bm25_top100.train.pkl')
    train_docs = _load_pickle('bioasq7_bm25_docset_top100.train.pkl')
    print('loading words')
    # Vocabulary of every token in the train/dev queries and their retrieved documents.
    words = {}
    GetWords(train_data, train_docs, words)
    GetWords(dev_data,   dev_docs,   words)
    #
    print('loading idfs')
    idf, max_idf = load_idfs(idf_pickle_path, words)
    return dev_data, dev_docs, train_data, train_docs, idf, max_idf, bioasq7_data

# recall: 0.3883 (measured before the terms clauses were fixed to cover SupplMeshList.text
# instead of repeating joint_text)
def get_first_n_20(qtext, n, max_year=2019):
    """Retrieve the top ``n`` documents with IDF-boosted per-token clauses over text and MeSH fields.

    For every non-stopword question token, a full-text ``match`` clause and a term-level
    ``terms`` clause are added on each field in ``fields``, both boosted by the token's IDF.
    A whole-question ``match`` on ``joint_text`` is boosted by the summed IDF. When there is
    more than one token, an unordered ``span_near`` (slop 5) is added on ``joint_text``. Results
    are restricted to documents whose ``DateCompleted`` falls between 1800 and ``max_year``.

    Uses the globals ``stopwords``, ``idf``, ``max_idf``, ``es`` and ``doc_index``. Prints the
    request body and the raw response.

    Returns:
        the ``hits.hits`` list of the Elasticsearch response.
    """
    tokenized_body  = bioclean_mod(qtext)
    question_tokens = [t for t in tokenized_body if t not in stopwords]
    idf_scores      = [idf_val(w, idf, max_idf) for w in question_tokens]
    question        = ' '.join(question_tokens)
    # The match and terms clauses target the same four fields.
    fields = ("joint_text", "Chemicals.NameOfSubstance", "MeshHeadings.text", "SupplMeshList.text")
    #
    the_shoulds = []
    for q_tok, idf_score in zip(question_tokens, idf_scores):
        the_shoulds.extend({"match": {field: {"query": q_tok, "boost": idf_score}}} for field in fields)
        the_shoulds.extend({"terms": {field: [q_tok], "boost": idf_score}} for field in fields)
    ################################################
    if len(question_tokens) > 1:
        the_shoulds.append({"span_near": {"clauses": [{"span_term": {"joint_text": w}} for w in question_tokens], "slop": 5, "in_order": False}})
    ################################################
    bod = {
        "size": n,
        "query": {
            "bool": {
                "must": [{"range": {"DateCompleted": {"gte": "1800", "lte": str(max_year), "format": "dd/MM/yyyy||yyyy"}}}],
                "should": [
                    {"match": {"joint_text": {"query": question, "boost": sum(idf_scores)}}},
                ] + the_shoulds,
                "minimum_should_match": 1,
            }
        }
    }
    print(json.dumps(bod))
    res = es.search(index=doc_index, body=bod, request_timeout=120)
    print(res)
    return res['hits']['hits']

# recall: 0.4140
def get_first_n_1(qtext, n, max_year=2017):
    """Baseline retrieval: a single ``match`` of the stopword-filtered question on ``joint_text``.

    Results are restricted to documents whose ``DateCompleted`` falls between 1800 and
    ``max_year``. Uses the globals ``stopwords``, ``es`` and ``doc_index``. Prints the request
    body and the raw response, and returns the ``hits.hits`` list.
    """
    kept_tokens = [tok for tok in bioclean_mod(qtext) if tok not in stopwords]
    date_filter = {"range": {"DateCompleted": {"gte": "1800", "lte": str(max_year), "format": "dd/MM/yyyy||yyyy"}}}
    text_clause = {"match": {"joint_text": {"query": ' '.join(kept_tokens), "boost": 1}}}
    request = {
        "size": n,
        "query": {
            "bool": {
                "must": [date_filter],
                "should": [text_clause],
                "minimum_should_match": 1,
            }
        },
    }
    print(json.dumps(request))
    response = es.search(index=doc_index, body=request, request_timeout=120)
    print(response)
    return response['hits']['hits']

# recall: 0.4144
def get_first_n_2(qtext, n, max_year=2017):
    """Baseline ``joint_text`` match plus an unordered ``span_near`` proximity clause.

    The proximity clause (slop 5, any order) is added only when the filtered question has
    more than one token. Results are restricted to documents whose ``DateCompleted`` falls
    between 1800 and ``max_year``. Uses the globals ``stopwords``, ``es`` and ``doc_index``.
    Prints the request body and the raw response, and returns the ``hits.hits`` list.
    """
    kept_tokens = [tok for tok in bioclean_mod(qtext) if tok not in stopwords]
    proximity = (
        [{"span_near": {"clauses": [{"span_term": {"joint_text": tok}} for tok in kept_tokens], "slop": 5, "in_order": False}}]
        if len(kept_tokens) > 1
        else []
    )
    should_clauses = [{"match": {"joint_text": {"query": ' '.join(kept_tokens), "boost": 1}}}]
    should_clauses.extend(proximity)
    date_filter = {"range": {"DateCompleted": {"gte": "1800", "lte": str(max_year), "format": "dd/MM/yyyy||yyyy"}}}
    request = {
        "size": n,
        "query": {
            "bool": {
                "must": [date_filter],
                "should": should_clauses,
                "minimum_should_match": 1,
            }
        },
    }
    print(json.dumps(request))
    response = es.search(index=doc_index, body=request, request_timeout=120)
    print(response)
    return response['hits']['hits']

# recall: 0.4150 (as originally written, the MeSH/chemical clauses were built and then
# discarded, so this presumably corresponds to use_mesh_fields=False)
def get_first_n_3(qtext, n, max_year=2020, use_mesh_fields=False):
    """``joint_text`` match plus ``span_near`` proximity, optionally with MeSH/chemical matches.

    Args:
        qtext: raw question text.
        n: number of hits to request.
        max_year: upper bound for the ``DateCompleted`` range filter (lower bound is 1800).
        use_mesh_fields: when True, also add a ``match`` of the whole filtered question on
            ``Chemicals.NameOfSubstance``, ``MeshHeadings.text`` and ``SupplMeshList.text``.
            Defaults to False, which reproduces the previous behaviour: those clauses used
            to be built and then silently thrown away by a ``the_shoulds = []`` reset.

    Uses the globals ``stopwords``, ``es`` and ``doc_index``. Prints the request body and the
    raw response, and returns the ``hits.hits`` list.
    """
    tokenized_body  = bioclean_mod(qtext)
    question_tokens = [t for t in tokenized_body if t not in stopwords]
    question        = ' '.join(question_tokens)
    ################################################
    the_shoulds = []
    if use_mesh_fields:
        for field in ("Chemicals.NameOfSubstance", "MeshHeadings.text", "SupplMeshList.text"):
            the_shoulds.append({"match": {field: {"query": question}}})
    ################################################
    if len(question_tokens) > 1:
        the_shoulds.append({"span_near": {"clauses": [{"span_term": {"joint_text": w}} for w in question_tokens], "slop": 5, "in_order": False}})
    ################################################
    bod = {
        "size": n,
        "query": {
            "bool": {
                "must": [{"range": {"DateCompleted": {"gte": "1800", "lte": str(max_year), "format": "dd/MM/yyyy||yyyy"}}}],
                "should": [{"match": {"joint_text": {"query": question, "boost": 1}}}] + the_shoulds,
                "minimum_should_match": 1,
            }
        }
    }
    print(json.dumps(bod))
    res = es.search(index=doc_index, body=bod, request_timeout=120)
    print(res)
    return res['hits']['hits']

def get_multi(qtext, n, max_year=2020):
    """Retrieve the top ``n`` documents using a ``most_fields`` multi_match over abstract and title.

    The stopword-filtered question is matched against ``AbstractText`` and ``ArticleTitle``.
    Uses the globals ``stopwords``, ``es`` and ``doc_index``. Returns the ``hits.hits`` list.

    NOTE(review): ``max_year`` is accepted for signature parity with the ``get_first_n_*``
    helpers, but it is currently NOT applied: this query has no ``DateCompleted`` filter.
    """
    kept_tokens = [tok for tok in bioclean_mod(qtext) if tok not in stopwords]
    multi_match = {
        "query": ' '.join(kept_tokens),
        "type": "most_fields",
        "fields": ["AbstractText", "ArticleTitle"],
    }
    request = {"size": n, "query": {"multi_match": multi_match}}
    response = es.search(index=doc_index, body=request, request_timeout=120)
    return response['hits']['hits']
    



In [3]:
# Elasticsearch client, referenced as the global `es` by the get_first_n_* / get_multi helpers.
es = Elasticsearch(
   ['localhost:9200'],
    verify_certs        = True,
    timeout             = 150,
    max_retries         = 10,
    retry_on_timeout    = True
)

# Data locations.
# NOTE(review): w2v_bin_path is not used by any of the cells shown here.
dataloc='./Data/bioasq_data/'
w2v_bin_path                = './Data/PretrainedWeightsAndVectors/pubmed2018_w2v_30D.bin'
idf_pickle_path             = './Data/PretrainedWeightsAndVectors/idf.pkl'
# Load the BioASQ questions, the BM25 top-100 train/dev runs and the vocabulary-restricted IDF
# table. idf / max_idf become the globals read by get_first_n_20.
(dev_data, dev_docs, train_data, train_docs, idf, max_idf, bioasq7_data) = load_all_data(dataloc, idf_pickle_path)

# Stopwords that the retrieval helpers filter out of question tokens (read as the global
# `stopwords`). NOTE(security): pickle.load can execute arbitrary code -- only load trusted files.
with open('./Data/stopwords.pkl', 'rb') as f:
    stopwords = pickle.load(f)





loading pickle data
loading words
loading idfs
Loading IDF tables
Loaded idf tables with max idf 16.5163157103


In [7]:
# Dev-set evaluation. For each dev query, recall = (number of retrieved pmids found in
# relevant_documents) / (number of relevant documents), averaged over the queries.
# Alternative retrievers with the same (qtext, n, max_year) signature that were tried:
#   get_first_n_20(qtext, 100, 2019), get_first_n_1(qtext, 100, 2019),
#   get_first_n_2(qtext, 100, 2019),  get_first_n_3(qtext, 100, 2020)
recalls = []
for q in tqdm(dev_data['queries']):
    qtext    = q['query_text']
    relevant = q['relevant_documents']
    if not relevant:
        # Recall is undefined without relevant documents; skip instead of dividing by zero.
        continue
    results    = get_multi(qtext, 100, 2020)
    retr_pmids = [t['_source']['pmid'] for t in results]
    rel_ret    = sum(1 for t in retr_pmids if t in relevant)
    recall     = float(rel_ret) / float(len(relevant))
    recalls.append(recall)

print('DEV RECALL')
# Guard against an empty dev set, or one where every query was skipped.
print(sum(recalls) / float(len(recalls)) if recalls else float('nan'))



  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

  2%|█▋                                                                                | 2/100 [00:00<00:09, 10.20it/s]

  3%|██▍                                                                               | 3/100 [00:00<00:10,  9.22it/s]

  6%|████▉                                                                             | 6/100 [00:02<00:27,  3.44it/s]

  7%|█████▋                                                                            | 7/100 [00:03<00:38,  2.44it/s]

  8%|██████▌                                                                           | 8/100 [00:03<00:29,  3.16it/s]

  9%|███████▍                                                                          | 9/100 [00:03<00:24,  3.68it/s]

 10%|████████                                                                         | 10/100 [00:03<00:23,  3.87it/s]

 11%|████████▉                

DEV RECALL
0.5099327929645772
