In [2]:
import json
import gzip
import os

In [3]:
def load_jsonl(file_path, encoding="utf-8"):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path to a file containing one JSON value per line.
        encoding: Text encoding used to read the file. Defaults to UTF-8 rather
            than the platform-dependent locale encoding.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, "r", encoding=encoding) as file:
        # Skip whitespace-only lines (e.g. a stray blank line or trailing newline),
        # which json.loads would otherwise reject with JSONDecodeError.
        return [json.loads(line) for line in file if line.strip()]

def load_json(file_path, encoding="utf-8"):
    """Load a single JSON document from a file.

    Args:
        file_path: Path to a file containing exactly one JSON document.
        encoding: Text encoding used to read the file. Defaults to UTF-8 rather
            than the platform-dependent locale encoding.

    Returns:
        The parsed JSON value (dict, list, str, number, bool or None).

    Raises:
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return json.load(file)

def load_partial_json(file_path, num_lines=10, encoding="utf-8"):
    """Parse up to the first ``num_lines`` lines of a JSON Lines file.

    Meant for peeking at large files. It is best-effort: lines that fail to parse
    are reported with ``print`` and skipped rather than raising.

    Args:
        file_path: Path to a file containing one JSON value per line.
        num_lines: Maximum number of lines to read. Blank and malformed lines count
            toward this limit. ``None`` reads the whole file; zero or a negative
            value reads nothing.
        encoding: Text encoding used to read the file (UTF-8 by default).

    Returns:
        The successfully parsed values, in file order.
    """
    data = []
    with open(file_path, "r", encoding=encoding) as file:
        for line_number, line in enumerate(file):
            if num_lines is not None and line_number >= num_lines:
                break
            if not line.strip():
                # Blank lines carry no record; don't report them as decode errors.
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Best-effort preview: report the offending line and keep going.
                print(f"Error decoding JSON on line {line_number + 1}: {e}")
    return data



In [4]:
# collection of texts
# Build the document collection ({"id", "text": "title abstract"}) once and cache it
# on disk. Documents come from the PubMed dump. Documents cited by the training
# questions but absent from the dump are appended from the training file itself.
if not os.path.exists("datasets/collection_with_texts.jsonl"):
    collection_with_texts = []

    with gzip.open("datasets/pubmed_2022_tiny.jsonl.gz", 'rb') as gz_file:
        for raw_line in gz_file:
            record = json.loads(raw_line.decode('utf-8'))
            collection_with_texts.append(
                {"id": record["pmid"], "text": record["title"] + " " + record["abstract"]}
            )

    # Set membership is O(1); list lookups made the merge below O(n * m).
    seen_ids = {doc["id"] for doc in collection_with_texts}

    positive = load_jsonl("datasets/RI_2023_training_data_wContents.jsonl")

    for question in positive:
        for doc in question["documents"]:
            # Deduplicate by document id (not by the whole dict), so each id appears
            # once in the collection even when several questions cite it.
            # NOTE(review): assumes PubMed "pmid" values and training doc ids share a
            # type (both str or both int); otherwise no id ever matches. Confirm.
            if doc["id"] not in seen_ids:
                seen_ids.add(doc["id"])
                collection_with_texts.append({"id": doc["id"], "text": doc["text"]})

    with open("datasets/collection_with_texts.jsonl", "w") as jsonl_file:
        for entry in collection_with_texts:
            # JSON Lines: one serialized object per line.
            jsonl_file.write(json.dumps(entry) + '\n')

In [5]:
# train_positive_docs
# Reduce each training question to {"body", "documents": [doc ids], "id"} and save it
# as JSON Lines for the filtering step.
positive = load_jsonl("datasets/RI_2023_training_data_wContents.jsonl")

dataset_positive = [
    {
        "body": question["body"],
        "documents": [doc["id"] for doc in question["documents"]],
        "id": question["id"],
    }
    for question in positive
]

with open("datasets/train_dataset_positive.jsonl", "w") as jsonl_file:
    jsonl_file.writelines(json.dumps(entry) + '\n' for entry in dataset_positive)


In [6]:
import jsonlines

# Pair each positive entry with the negative entry at the same index, remove that
# question's positive documents from its negatives, and keep only the pairs that
# still have more negatives than positives.
# NOTE(review): pairing is by position; this assumes both files list the same
# questions in the same order. Confirm, e.g. by comparing their ids.

negative = load_jsonl("datasets/train_dataset_negative.jsonl")
positive = load_jsonl("datasets/train_dataset_positive.jsonl")


def _doc_key(doc):
    """Return a document's id whether it is stored as a bare id or as an {"id": ...} dict."""
    return doc["id"] if isinstance(doc, dict) else doc


def filter_training_pairs(positive_entries, negative_entries):
    """Strip positives from the negatives and drop pairs with too few negatives.

    Args:
        positive_entries: Records with a "documents" list of positive doc ids.
        negative_entries: Records with a "neg_docs" list, index-aligned with
            ``positive_entries``.

    Returns:
        ``(kept_positive, kept_negative)``: index-aligned lists holding only the
        pairs whose filtered "neg_docs" is strictly longer than "documents".
        Negative records are copied, not mutated, and "neg_docs" keeps its order.
    """
    if len(positive_entries) != len(negative_entries):
        print(
            f"Warning: {len(positive_entries)} positive vs {len(negative_entries)} "
            "negative entries; unmatched trailing entries are dropped."
        )

    kept_positive, kept_negative = [], []
    for pos, neg in zip(positive_entries, negative_entries):
        positive_ids = {_doc_key(doc) for doc in pos["documents"]}
        # Compare per document id. Testing the whole positive list for membership
        # in neg_docs never matches individual ids, so nothing would be removed.
        neg_docs = [doc for doc in neg["neg_docs"] if _doc_key(doc) not in positive_ids]
        if len(neg_docs) > len(pos["documents"]):
            kept_positive.append(pos)
            kept_negative.append({**neg, "neg_docs": neg_docs})
    return kept_positive, kept_negative


positive, negative = filter_training_pairs(positive, negative)

with jsonlines.open('datasets/positive_filtered.jsonl', 'w') as writer:
    writer.write_all(positive)
with jsonlines.open('datasets/negative_filtered.jsonl', 'w') as writer:
    writer.write_all(negative)

In [7]:
from data import (
    BioASQPointwiseIterator,
    InferenceDataset,
    InferenceRankingIterator,
    create_training_dataset,
    get_qrels,
)
from sampler import BasicSampler
from transformers import AutoTokenizer

# Pretrained checkpoint shared by the tokenizer and the classifier model.
model_checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Cap the tokenizer's default maximum sequence length at 512 tokens.
tokenizer.model_max_length = 512

# Training inputs, produced by the preceding preprocessing cells:
#   positives:  per question: "body"/"question", "id", list of positive doc ids (~60K)
#   negatives:  per question: negative docs, possibly {"id", "score"} records (TODO confirm)
#   collection: doc id -> text ("title abstract")
train_positive_path = "datasets/positive_filtered.jsonl"
train_negative_path = "datasets/negative_filtered.jsonl"
train_collection_path = "datasets/collection_with_texts.jsonl"

# Returns an iterable torch.utils.data.Dataset of tokenized pointwise samples.
train_ds = create_training_dataset(
    train_positive_path,
    train_negative_path,
    train_collection_path,
    tokenizer=tokenizer,
    iterator_class=BioASQPointwiseIterator[BasicSampler],
    # max_questions=500,  # uncomment for a quick debug run
)

In [8]:
# Rename the gold-standard fields to the schema used by the training files:
# query_id -> id, query_text -> body, documents_pmid -> documents.
# Any other field raises KeyError, which surfaces unexpected input early.
gs = load_jsonl("datasets/question_E8B1_gs.jsonl")

key_mapping = {"query_id": "id", "query_text": "body", "documents_pmid": "documents"}

dev_gs = [
    {key_mapping[field]: value for field, value in record.items()}
    for record in gs
]

with open("datasets/dev_gs.jsonl", "w") as jsonl_file:
    for entry in dev_gs:
        jsonl_file.write(json.dumps(entry) + '\n')

In [27]:
# Convert the BM25 run (query id -> doc ids, presumably in rank order) into JSON Lines
# records {"id", "documents": [{"id", "score"}], "question"} for the dev dataset.
scores_bm25 = load_json("datasets/scores_bm25.json")
queries = load_jsonl("datasets/question_E8B1_gs.jsonl")

queries_text = {query["query_id"]: query["query_text"] for query in queries}

# The BM25 scores were not saved when this data was generated, so every document
# gets the same placeholder score of 50.
convert = [
    {
        "id": query_id,
        "documents": [{"id": doc, "score": 50} for doc in scores_bm25[query_id]],
        "question": queries_text[query_id],
    }
    for query_id in scores_bm25
]

with open("datasets/scores_bm25.jsonl", 'w') as jsonl_file:
    jsonl_file.writelines(json.dumps(entry) + '\n' for entry in convert)

In [28]:
# Reload the dev gold standard and the BM25 run from the files written above.
golden = load_jsonl("datasets/dev_gs.jsonl")
bm25 = load_jsonl("datasets/scores_bm25.jsonl")

In [29]:
# Query ids present in the BM25 run and in the gold standard, in file order.
bm25_ids = [entry["id"] for entry in bm25]
golden_ids = [entry["id"] for entry in golden]

In [30]:
# Dev-set inference data (an iterable torch.utils.data.Dataset): BM25 candidates per
# question, texts looked up in the training collection, gold labels from dev_gs.jsonl.
dev_ds = InferenceDataset(
    "datasets/scores_bm25.jsonl",
    train_ds.collection,
    tokenizer,
    at=100,  # maximum number of candidate documents per question
    gs_path="datasets/dev_gs.jsonl",
    iterator_class=InferenceRankingIterator,
    # max_questions=10,  # uncomment for a quick debug run
)

In [31]:
from transformers import AutoModelForSequenceClassification

# Two-class relevance head on top of the pretrained checkpoint (label 1 = relevant).
id2label = {0: "IRRELEVANT", 1: "RELEVANT"}
label2id = {label: index for index, label in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)  # append .to("cuda") to move the model to a GPU

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# Sanity check: display the class of the object returned by create_training_dataset.
type(train_ds)

data.BioASQDataset

In [33]:
# Peek at the first two samples yielded by the training dataset.
_iter = iter(train_ds)
for _sample_number in range(2):
    print(next(_iter))

{'input_ids': [101, 2003, 15578, 4588, 14854, 2594, 5648, 2109, 2005, 3949, 1997, 3078, 12170, 6632, 2854, 16480, 25023, 13706, 1029, 102, 15578, 4588, 14854, 2594, 5648, 2005, 1996, 3949, 1997, 3078, 12170, 6632, 2854, 25022, 12171, 25229, 1012, 4955, 2045, 2003, 3278, 4895, 11368, 2342, 1999, 3078, 12170, 6632, 2854, 16480, 25023, 13706, 1006, 1052, 9818, 1007, 1999, 5022, 2104, 1011, 26651, 2000, 1996, 2069, 4844, 7242, 24471, 6499, 3207, 11636, 17994, 23518, 5648, 1006, 20904, 3540, 1007, 2040, 2024, 2012, 3445, 3891, 1997, 27673, 2000, 2203, 1011, 2754, 11290, 4295, 1012, 15578, 4588, 14854, 2594, 5648, 1006, 1051, 3540, 1007, 2003, 1037, 2521, 5267, 9314, 1060, 10769, 1006, 23292, 2099, 1007, 3283, 26942, 2029, 2038, 2042, 16330, 2004, 1037, 2117, 2240, 7242, 1999, 1052, 9818, 1998, 2038, 3728, 2042, 11172, 2094, 2011, 1996, 17473, 1012, 2752, 3139, 1996, 6887, 27292, 22684, 6483, 1998, 7366, 1997, 1051, 3540, 2004, 2019, 23292, 2099, 3283, 26942, 1998, 2049, 6612, 6666, 1012, 10