In [4]:
import json
from collections import defaultdict
from nir.utils import create_filter_query_function, change_bm25_parameters
from elasticsearch import Elasticsearch, helpers
from mmnrm.utils import set_random_seed
from mmnrm.dataset import TestCollectionV2, TrainPairwiseCollection
from mmnrm.text import TREC_goldstandard_transform, TREC_queries_transform, TREC_results_transform
from mmnrm.evaluation import TREC_Evaluator
import pandas as pd
import numpy as np
import random

import os 
import sys

from utils import collection_iterator

# Fix the random seed up front so the stochastic steps of this notebook are repeatable.
# NOTE(review): presumably this also seeds Python's `random` module, which the
# train/test topic split relies on — confirm against mmnrm.utils.set_random_seed.
set_random_seed(42)


In [13]:
def load_TREC_qrels(q_rels_file, test_ids):
    """Load the TREC relevance judgements of the training topics.

    Every line of the file is expected to hold at least four
    whitespace-separated fields: the first is the topic id, the third the
    document id and the fourth the relevance label. Judgements whose topic id
    is contained in ``test_ids`` are left out.

    Args:
        q_rels_file: path of the qrels text file.
        test_ids: container of topic ids (as strings) to exclude.

    Returns:
        The result of ``TREC_goldstandard_transform`` applied to the mapping
        ``topic id -> [(document id, relevance label), ...]``.
    """
    goldstandard = defaultdict(list)

    with open(q_rels_file) as f:
        for raw_line in f:
            # split() without an argument copes with tabs and repeated blanks,
            # which split(" ") would turn into malformed records
            fields = raw_line.split()

            if not fields:
                # blank line: nothing to report
                continue

            topic_id = fields[0]
            if topic_id in test_ids:
                continue

            if len(fields) < 4:
                # best effort: report the malformed line and keep going
                print(fields)
                continue

            goldstandard[topic_id].append((fields[2], fields[3]))

    return TREC_goldstandard_transform(goldstandard)


import xmltodict
import ctypes
from collections import defaultdict

def load_TREC_topics(topics_file, test_ids):
    """Read the TREC topics XML and split it into train and test queries.

    A topic goes to the test split when its ``number`` attribute is contained
    in ``test_ids``; every other topic goes to the train split. The query text
    is built from the ``disease`` and ``gene`` fields only; ``demographic`` is
    carried along in the topic record but is not part of the query text.

    Args:
        topics_file: path of the topics XML file (root ``<topics>``, children
            ``<topic number="...">``).
        test_ids: container of topic numbers (as strings) held out for testing.

    Returns:
        A ``(train, test)`` tuple, each the result of
        ``TREC_queries_transform`` for the respective split.
    """
    with open(topics_file) as f:
        topics = xmltodict.parse(f.read())["topics"]["topic"]

    # xmltodict yields a bare mapping instead of a list when the file holds a
    # single <topic>; normalise so the loop below always sees topic records
    if isinstance(topics, dict):
        topics = [topics]

    def build_query(topic):
        return "What is the treatment for " + topic["disease"]+" "+topic["gene"]

    topics_train_json = []
    topics_test_json = []

    for topic in topics:
        record = {"id": topic["@number"],
                  "disease": topic["disease"],
                  "gene": topic["gene"],
                  "demographic": topic["demographic"]}
        if record["id"] in test_ids:
            topics_test_json.append(record)
        else:
            topics_train_json.append(record)

    train = TREC_queries_transform(topics_train_json, number_parameter="id", fn=build_query)
    test = TREC_queries_transform(topics_test_json, number_parameter="id", fn=build_query)
    return train, test

In [20]:
# Hold out 10 of the topic numbers 1..40 as the test split; ids are kept as
# strings so they compare equal to the ids read from the topic and qrels files.
test_ids = {str(topic_number) for topic_number in random.sample(list(range(1,41)), k=10)}

# Location of the compressed article collection read later in the notebook.
zipped_collection = "/backup/TREC-PM/Corpus/collection-json.tar.gz"

# Queries for both splits, and the relevance judgements of the training topics only.
train_topics, test_topics = load_TREC_topics(topics_file="topics2019.xml", test_ids=test_ids)
goldstandard = load_TREC_qrels(q_rels_file="qrels-treceval-abstracts.2019.txt", test_ids=test_ids)


In [21]:
# Display the held-out test queries produced by the split.
test_topics

[{'id': '3',
  'query': 'What is the treatment for prostate cancer ATM deletion'},
 {'id': '7',
  'query': 'What is the treatment for non-small cell lung cancer EGFR (T790M)'},
 {'id': '10',
  'query': 'What is the treatment for mucosal melanoma KIT (L576P), KIT amplification'},
 {'id': '12',
  'query': 'What is the treatment for inflammatory myofibroblastic tumor RANBP2-ALK fusion'},
 {'id': '16',
  'query': 'What is the treatment for papillary thyroid carcinoma BRAF (V600E)'},
 {'id': '18',
  'query': 'What is the treatment for lung adenocarcinoma SND1-BRAF fusion'},
 {'id': '19',
  'query': 'What is the treatment for colon cancer ERBB2 amplification'},
 {'id': '20', 'query': 'What is the treatment for pancreatic cancer BRCA2'},
 {'id': '27',
  'query': 'What is the treatment for non-small cell lung cancer KRAS (G12C)'},
 {'id': '30', 'query': 'What is the treatment for endometrial cancer PIK3R1'}]

In [4]:
# Build an in-memory index of the collection: article id -> {id, text, title}.
articles = {}
unique_pmid = set()  # ids already stored, so only the first occurrence is kept
skipped = 0          # number of articles dropped because their abstract is empty

for batch in collection_iterator(zipped_collection):
    for doc in batch:
        if doc["abstract"]=="":
            # articles without an abstract are counted and left out
            skipped += 1
        elif doc["id"] not in unique_pmid:
            unique_pmid.add(doc["id"])
            # the searchable text is the title followed by the abstract
            articles[doc["id"]] = dict(id=doc["id"],
                                       text=doc["title"]+" "+doc["abstract"],
                                       title=doc["title"])

[CORPORA] Openning tar file /backup/TREC-PM/Corpus/collection-json.tar.gz
[CORPORA] Openning tar file tmp/tmpvcyemaof/TREC-PM-baseline-00000000-to-06000000
[CORPORA] Openning tar file tmp/tmpvcyemaof/TREC-PM-baseline-06000000-to-12000000
[CORPORA] Openning tar file tmp/tmpvcyemaof/TREC-PM-baseline-12000000-to-18000000
[CORPORA] Openning tar file tmp/tmpvcyemaof/TREC-PM-baseline-18000000-to-24000000
[CORPORA] Openning tar file tmp/tmpvcyemaof/TREC-PM-baseline-24000000-to-29138919


In [22]:
# Clean the goldstandard: keep only judged documents that are present in
# `articles` (e.g. articles skipped for having an empty abstract are removed).
for q_id in goldstandard.keys():
    judgements = goldstandard[q_id]
    for r in judgements.keys():
        judgements[r] = [doc_id for doc_id in judgements[r] if doc_id in articles]

In [23]:
# Build the pairwise training collection from the training topics, the cleaned
# goldstandard and the loaded articles, configured with a batch size of 100.
train_collection = TrainPairwiseCollection(train_topics, goldstandard, articles).batch_size(100)

In [24]:
from collections import Counter
# Show the query sampling strategy currently in effect.
print(train_collection.query_sampling_strategy)

# Switch to strategy 2 and print it again to confirm the change took effect.
train_collection.set_query_sampling_strategy(2)
print(train_collection.query_sampling_strategy)
# Draw one very large batch (100000) and count the occurrences of each value in
# its first element.
# NOTE(review): presumably the first element of a batch holds the sampled query
# ids, so `c` measures how evenly queries are sampled — confirm against mmnrm.
train_collection.batch_size(100000)
c = Counter(next(train_collection.generator())[0])

# Smallest and largest count observed in that batch.
print(min(c.values()), max(c.values()))

0
2
2293 5453


In [25]:
# Persist the configured training collection under the name "train_data_2019".
train_collection.save("train_data_2019")

## Test set

In [1]:
import tempfile
import sys
import subprocess
import shutil
from elasticsearch import Elasticsearch, helpers
from nir.utils import create_filter_query_function, change_bm25_parameters

import os
import sys
from collections import defaultdict
import pickle
import math
import pandas as pd
from os.path import join

import tensorflow as tf
from tensorflow.keras import backend as K

from mmnrm.evaluation import TREC_Evaluator
from mmnrm.utils import set_random_seed, load_model_weights, load_model
from mmnrm.dataset import TestCollectionV2, sentence_splitter_builderV2
from mmnrm.evaluation import BioASQ_Evaluator
from mmnrm.modelsv2 import deep_rank
from mmnrm.text import TREC_goldstandard_transform, TREC_queries_transform, TREC_results_transform


import numpy as np

# Elasticsearch client used by the retrieval step.
# NOTE(review): the endpoint is hard-coded; consider reading it from configuration.
es = Elasticsearch(["http://193.136.175.98:8125"])

# Name of the index queried by execute_queries.
# NOTE(review): the index name says 2020 while the topics/qrels files are from 2019 — confirm this is intended.
index_name = "trec-pm-2020-synonym"

def execute_queries(queries, top_k=1000):
    """Retrieve the top documents of every query from Elasticsearch.

    Each query text is first passed through the function returned by
    ``create_filter_query_function`` and then sent as a ``query_string`` query
    over the ``text`` field (``english`` analyzer) of the index named by the
    module-level ``index_name``, using the module-level ``es`` client.

    Args:
        queries: iterable of dicts with the keys ``"id"`` and ``"query"``.
        top_k: maximum number of hits requested per query.

    Returns:
        dict mapping each query id to the list of its hits, in the order
        returned by Elasticsearch, each hit being a dict with the keys
        ``"id"``, ``"text"``, ``"title"`` and ``"score"``.

    Raises:
        AssertionError: if the hit scores of a query are not non-increasing.
    """
    def is_non_increasing(scores):
        # True when every score is >= the one that follows it
        return all(current >= following for current, following in zip(scores, scores[1:]))

    filter_query_string = create_filter_query_function()
    documents = {}

    for position, query_data in enumerate(queries):
        # lightweight progress indicator, refreshed every 10 queries
        if position % 10 == 0:
            print(position, end="\r")

        cleaned_query = filter_query_string(query_data["query"])

        query_string_clause = {
            "query_string": {
                "query": cleaned_query,
                "analyzer": "english",
                "fields": ["text"],
            }
        }
        request_body = {
            "query": {
                "bool": {
                    "must": [query_string_clause],
                    "filter": [],
                    "should": [],
                    "must_not": [],
                }
            }
        }

        response = es.search(index=index_name, body=request_body, size=top_k, request_timeout=200)
        hits = response['hits']['hits']

        documents[query_data["id"]] = [
            {
                "id": hit["_source"]["id"],
                "text": hit["_source"]["text"],
                "title": hit["_source"]["title"],
                "score": hit['_score'],
            }
            for hit in hits
        ]

        # just to ensure the Elasticsearch order is maintained
        assert is_non_increasing([hit['_score'] for hit in hits])

    return documents


In [8]:
# Retrieve the candidate documents (default top_k=1000) for every held-out test query.
# NOTE(review): `test_topics` comes from the training-data part of the notebook;
# it must be defined in the current kernel before this cell runs.
retrieved = execute_queries(test_topics)

0

In [10]:
# Evaluator built from the 2019 abstracts qrels file and the path of a local trec_eval binary.
trec_evaluator = TREC_Evaluator("qrels-treceval-abstracts.2019.txt", '/backup/TREC/TestSet/trec_eval-9.0.7/trec_eval')
# Bundle the test queries with their retrieved documents and the evaluator, with a batch size of 100.
test_collection = TestCollectionV2(test_topics, retrieved, trec_evaluator).batch_size(100)

In [11]:
# Persist the test collection under the name "test_data_2019".
test_collection.save("test_data_2019")