In [1]:
import json

import pandas as pd
from tqdm.auto import tqdm

In [2]:
# Dataset from https://allenai.org/data/scifact
# `dataset_dir` is the folder suffix used by every path in this notebook:
# inputs are read from, and outputs written to, data/data_{dataset_dir}/.
dataset_dir = "scifact"

## Loading the data

In [3]:
# Queries are stored one JSON object per line (JSON Lines), hence lines=True.
query_df = pd.read_json(f"data/data_{dataset_dir}/queries.jsonl", lines=True)
query_df

Unnamed: 0,_id,text,metadata
0,0,0-dimensional biomaterials lack inductive prop...,{}
1,2,1 in 5 million in UK have abnormal PrP positiv...,"{'13734012': [{'sentences': [4], 'label': 'CON..."
2,4,1-1% of colorectal cancer patients are diagnos...,{}
3,6,10% of sudden infant death syndrome (SIDS) dea...,{}
4,9,32% of liver transplantation programs required...,"{'44265107': [{'sentences': [15], 'label': 'SU..."
...,...,...,...
1104,1379,Women with a higher birth weight are more like...,"{'16322674': [{'sentences': [5], 'label': 'SUP..."
1105,1382,aPKCz causes tumour enhancement by affecting g...,"{'17755060': [{'sentences': [3], 'label': 'CON..."
1106,1385,cSMAC formation enhances weak ligand signalling.,"{'306006': [{'sentences': [4], 'label': 'SUPPO..."
1107,1389,mTORC2 regulates intracellular cysteine levels...,"{'23895668': [{'sentences': [2, 3], 'label': '..."


In [4]:
# The corpus is also JSON Lines; each row carries an id, a title and the passage text.
corpus_df = pd.read_json(f"data/data_{dataset_dir}/corpus.jsonl", lines=True)
corpus_df

Unnamed: 0,_id,title,text,metadata
0,4983,Microstructural development of human newborn c...,Alterations of the architecture of cerebral wh...,{}
1,5836,Induction of myelodysplasia by myeloid-derived...,Myelodysplastic syndromes (MDS) are age-depend...,{}
2,7912,"BC1 RNA, the transcript from a master gene for...",ID elements are short interspersed elements (S...,{}
3,18670,The DNA Methylome of Human Peripheral Blood Mo...,DNA methylation plays an important role in bio...,{}
4,19238,The human myelin basic protein gene is include...,Two human Golli (for gene expressed in the oli...,{}
...,...,...,...,...
5178,195689316,Body-mass index and cause-specific mortality i...,BACKGROUND The main associations of body-mass ...,{}
5179,195689757,Targeting metabolic remodeling in glioblastoma...,A key aberrant biological difference between t...,{}
5180,196664003,Signaling architectures that transmit unidirec...,A signaling pathway transmits information from...,{}
5181,198133135,"Association between pre-diabetes, type 2 diabe...",AIMS Trabecular bone score (TBS) is a surrogat...,{}


In [5]:
# Relevance judgements for the test split: tab-separated rows of
# (query-id, corpus-id, score) linking a query to a relevant corpus entry.
test_rel_df = pd.read_csv(f"data/data_{dataset_dir}/qrels/test.tsv", sep="\t")
test_rel_df

Unnamed: 0,query-id,corpus-id,score
0,1,31715818,1
1,3,14717500,1
2,5,13734012,1
3,13,1606628,1
4,36,5152028,1
...,...,...,...
334,1379,17450673,1
335,1382,17755060,1
336,1385,306006,1
337,1389,23895668,1


## Converting to the Simple Transformers format

For evaluation with Simple Transformers, we are going to generate 3 files.

### TSV file with queries

This file contains:
- query_text
- title
- gold_passage

### TSV file with the passage collection

This file only needs a `passages` column

### JSON file with the same number of entries as the queries (Optional)

This optional file contains a list of lists, where each item in the outer list is a list of relevant documents for each query.

E.g.:

```python
relevant_docs = [
    ["relevant doc 1a", "relevant doc 1b"],
    ["relevant doc 2a", "relevant doc 2b"],
]
```

If this file is provided, the `eval_model()` method will check the retrieved documents against each of the relevant documents in the file. If it is not provided, `eval_model()` will only check against the `gold_passage` column in the *queries* file.

In [6]:
# Normalise the id column names to `query_id` / `corpus_id`, which are the
# names the helper functions below use to match rows across the three frames.
query_df = query_df.rename(columns={"_id": "query_id"})
test_rel_df = test_rel_df.rename(
    columns={"query-id": "query_id", "corpus-id": "corpus_id"}
)
corpus_df = corpus_df.rename(columns={"_id": "corpus_id"})

In [7]:
def get_relevant_title_and_text(query_id, corpus_id, corpus_df, query_df):
    """Look up a query's text and the title/text of one relevant passage.

    Args:
        query_id: Value matched against ``query_df["query_id"]``.
        corpus_id: Value matched against ``corpus_df["corpus_id"]``.
        corpus_df: DataFrame with ``corpus_id``, ``title`` and ``text`` columns.
        query_df: DataFrame with ``query_id`` and ``text`` columns.

    Returns:
        A ``(query_text, title, passage_text)`` tuple taken from the first
        matching row of each frame, or ``(None, None, None)`` when the query
        id is not found, the corpus id is not found, or the passage text is
        falsy (e.g. an empty string).

    Raises:
        KeyError: If one of the expected columns is missing. A malformed
            frame is a setup error and is surfaced rather than being
            reported as "not found".
    """
    not_found = (None, None, None)

    # Explicit emptiness checks replace a bare `except:` that also swallowed
    # KeyboardInterrupt and unrelated errors.
    query_matches = query_df.loc[query_df["query_id"] == query_id, "text"]
    if query_matches.empty:
        return not_found

    # An unknown corpus id is reported the same way as an unknown query id
    # instead of leaking an IndexError from `.values[0]`.
    passage_matches = corpus_df[corpus_df["corpus_id"] == corpus_id]
    if passage_matches.empty:
        return not_found

    passage_text = passage_matches["text"].values[0]
    if not passage_text:
        return not_found

    return query_matches.values[0], passage_matches["title"].values[0], passage_text


def get_test_and_relevance_data(df):
    """Group the relevant passages of each query into a list of lists.

    Args:
        df: DataFrame with ``query_id``, ``title`` and ``gold_passage``
            columns, one row per (query, relevant passage) pair.

    Returns:
        A ``(df, relevant_docs)`` tuple. ``df`` is a new frame with the same
        rows as the input and no ``combined_passage`` column.
        ``relevant_docs`` holds one list per unique ``query_id``, in order of
        first appearance, containing that query's passages as strings. Each
        string is ``title + " " + gold_passage`` when the first row has a
        truthy title, otherwise just ``gold_passage``.

    The input frame is not modified. An empty frame yields an empty
    ``relevant_docs`` list instead of raising.
    """
    # `drop` always returns a new frame; errors="ignore" covers the normal
    # case where the caller's frame has no `combined_passage` column.
    result_df = df.drop(columns=["combined_passage"], errors="ignore")

    if df.empty:
        # `df.iloc[0]` below would raise IndexError on an empty frame.
        return result_df, []

    # The title decision is made once, from the first row, for the whole frame.
    # The passages are kept in a local Series so the caller's frame does not
    # gain a temporary column.
    if df.iloc[0]["title"]:
        combined_passage = df["title"] + " " + df["gold_passage"]
    else:
        combined_passage = df["gold_passage"]

    relevant_docs = []
    for query_id in tqdm(df["query_id"].unique()):
        relevant_docs.append(combined_passage[df["query_id"] == query_id].tolist())

    return result_df, relevant_docs


def build_st_format_data(query_df, rel_df, corpus_df):
    """Build the query frame and relevant-passage lists from relevance pairs.

    For every ``(query_id, corpus_id)`` pair in ``rel_df`` the query text and
    the passage title/text are looked up with
    ``get_relevant_title_and_text``. Rows with any missing value are dropped
    before the result is passed to ``get_test_and_relevance_data``.

    Args:
        query_df: DataFrame with ``query_id`` and ``text`` columns.
        rel_df: DataFrame with ``query_id`` and ``corpus_id`` columns.
        corpus_df: DataFrame with ``corpus_id``, ``title`` and ``text`` columns.

    Returns:
        The ``(df, relevant_docs)`` tuple produced by
        ``get_test_and_relevance_data``; ``df`` has the columns
        ``query_text``, ``title``, ``gold_passage`` and ``query_id``.

    Warns:
        UserWarning: If any pair was skipped because its lookup raised. The
            message reports how many were skipped and the first failure.
    """
    import warnings

    query_dict = {
        "query_text": [],
        "title": [],
        "gold_passage": [],
        "query_id": [],
    }
    # (query_id, corpus_id, error) for each pair whose lookup raised.
    skipped = []

    for query_id, corpus_id in tqdm(zip(rel_df["query_id"], rel_df["corpus_id"]), total=len(rel_df)):
        try:
            query_text, title, gold_passage = get_relevant_title_and_text(
                query_id,
                corpus_id,
                corpus_df,
                query_df,
            )
        except Exception as e:
            # Best-effort: one bad pair must not abort the conversion, but it
            # is recorded so the loss is visible rather than silent.
            skipped.append((query_id, corpus_id, repr(e)))
            continue
        query_dict["query_text"].append(query_text)
        query_dict["title"].append(title)
        query_dict["gold_passage"].append(gold_passage)
        query_dict["query_id"].append(query_id)

    if skipped:
        first_query_id, first_corpus_id, first_error = skipped[0]
        warnings.warn(
            f"Skipped {len(skipped)} of {len(rel_df)} relevance pairs whose lookup failed; "
            f"first failure: query_id={first_query_id}, corpus_id={first_corpus_id}, "
            f"error={first_error}",
            stacklevel=2,
        )

    # Lookups that found nothing return None values; dropna removes those rows.
    df = pd.DataFrame(query_dict).dropna()

    return get_test_and_relevance_data(df)


In [8]:
# Build the evaluation frame (one row per relevance pair that could be resolved)
# and the per-query lists of relevant passages.
test_df, relevant_docs = build_st_format_data(query_df, test_rel_df, corpus_df)

  0%|          | 0/339 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

In [9]:
# Keep a single row per query. `relevant_docs` was built with one entry per
# unique query_id, so after this step the two have the same number of entries.
test_df = test_df.drop_duplicates(subset=["query_id"])

In [10]:
# Preview the evaluation queries after deduplication.
test_df

Unnamed: 0,query_text,title,gold_passage,query_id
0,0-dimensional biomaterials show inductive prop...,New opportunities: the use of nanotechnologies...,Nanotechnologies are emerging platforms that c...,1
1,"1,000 genomes project enables mapping of genet...",Rare Variants Create Synthetic Genome-Wide Ass...,Genome-wide association studies (GWAS) have no...,3
2,1/2000 in UK have abnormal PrP positivity.,Prevalent abnormal prion protein in human appe...,OBJECTIVES To carry out a further survey of ar...,5
3,5% of perinatal mortality is due to low birth ...,Estimates of global prevalence of childhood un...,CONTEXT One key target of the United Nations M...,13
4,A deficiency of vitamin B12 increases blood le...,Folic acid improves endothelial function in co...,BACKGROUND Homocysteine is a risk factor for c...,36
...,...,...,...,...
331,Women with a higher birth weight are more like...,Birth Size and Breast Cancer Risk: Re-analysis...,"BACKGROUND Birth size, perhaps a proxy for pre...",1379
335,aPKCz causes tumour enhancement by affecting g...,Control of Nutrient Stress-Induced Metabolic R...,Tumor cells have high-energetic and anabolic n...,1382
336,cSMAC formation enhances weak ligand signalling.,The stimulatory potency of T cell antigens is ...,T cell activation is predicated on the interac...,1385
337,mTORC2 regulates intracellular cysteine levels...,mTORC2 Regulates Amino Acid Metabolism in Canc...,Mutations in cancer reprogram amino acid metab...,1389


In [10]:
# Inspect the corpus before title and text are merged into one column.
corpus_df

Unnamed: 0,corpus_id,title,text,metadata
0,4983,Microstructural development of human newborn c...,Alterations of the architecture of cerebral wh...,{}
1,5836,Induction of myelodysplasia by myeloid-derived...,Myelodysplastic syndromes (MDS) are age-depend...,{}
2,7912,"BC1 RNA, the transcript from a master gene for...",ID elements are short interspersed elements (S...,{}
3,18670,The DNA Methylome of Human Peripheral Blood Mo...,DNA methylation plays an important role in bio...,{}
4,19238,The human myelin basic protein gene is include...,Two human Golli (for gene expressed in the oli...,{}
...,...,...,...,...
5178,195689316,Body-mass index and cause-specific mortality i...,BACKGROUND The main associations of body-mass ...,{}
5179,195689757,Targeting metabolic remodeling in glioblastoma...,A key aberrant biological difference between t...,{}
5180,196664003,Signaling architectures that transmit unidirec...,A signaling pathway transmits information from...,{}
5181,198133135,"Association between pre-diabetes, type 2 diabe...",AIMS Trabecular bone score (TBS) is a surrogat...,{}


In [11]:
# Collapse each corpus entry into a single `passages` string ("<title> <text>"),
# drop exact duplicate passages, and keep only `corpus_id` and `passages`.
# The guard makes the cell safe to re-run: once `title`/`text` have been
# dropped, running the transformation again would raise a KeyError.
if "passages" not in corpus_df.columns:
    corpus_df = (
        corpus_df.assign(passages=corpus_df["title"] + " " + corpus_df["text"])
        .drop_duplicates(subset=["passages"])
        .drop(["title", "text", "metadata"], axis=1)
    )
corpus_df

Unnamed: 0,corpus_id,passages
0,4983,Microstructural development of human newborn c...
1,5836,Induction of myelodysplasia by myeloid-derived...
2,7912,"BC1 RNA, the transcript from a master gene for..."
3,18670,The DNA Methylome of Human Peripheral Blood Mo...
4,19238,The human myelin basic protein gene is include...
...,...,...
5178,195689316,Body-mass index and cause-specific mortality i...
5179,195689757,Targeting metabolic remodeling in glioblastoma...
5180,196664003,Signaling architectures that transmit unidirec...
5181,198133135,"Association between pre-diabetes, type 2 diabe..."


### Write to disk

In [12]:
# Write the three artefacts next to the source data:
#   test.tsv           - evaluation queries with their gold passage
#   corpus.tsv         - passage collection
#   relevant_docs.json - list of relevant-passage lists, one per query
# Both TSVs are tab-separated and written without the DataFrame index.
test_df.to_csv(f"data/data_{dataset_dir}/test.tsv", sep="\t", index=False)
corpus_df.to_csv(f"data/data_{dataset_dir}/corpus.tsv", sep="\t", index=False)

with open(f"data/data_{dataset_dir}/relevant_docs.json", "w") as f:
    json.dump(relevant_docs, f)


# Evaluate on the new dataset/corpus

In [13]:
import logging
import json
import os
from pprint import pprint

from simpletransformers.retrieval import RetrievalModel, RetrievalArgs


# Show INFO-level messages by default, but restrict the `transformers` logger
# to warnings and above.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

2022-08-23 17:49:56.718734: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [14]:
# Paths of the three files written in the "Write to disk" step.
eval_data = f"data/data_{dataset_dir}/test.tsv"  # evaluation queries
index_path = f"data/data_{dataset_dir}/corpus.tsv"  # passage collection to index
relevant_docs_path = f"data/data_{dataset_dir}/relevant_docs.json"  # per-query relevant passages

In [15]:
# Load a model from the Huggingface Hub
model_type = "dpr"
# Left unset: the context and question encoders are named individually below.
model_name = None
context_name = "facebook/dpr-ctx_encoder-multiset-base"
question_name = "facebook/dpr-question_encoder-multiset-base"

model_args = RetrievalArgs()
# corpus.tsv has no separate title column: the SciFact titles were already
# merged into the `passages` text when the corpus file was built.
model_args.include_title_in_knowledge_dataset = False
# 100 is also the largest k passed in `top_k_values` at evaluation time.
model_args.retrieve_n_docs = 100
model_args.output_dir = f"data/data_{dataset_dir}"

In [16]:
# Build the retriever. As the log output of this cell shows, passing
# `prediction_passages` makes it prepare the corpus, generate passage
# embeddings and add a FAISS index.
model = RetrievalModel(
    model_type=model_type,
    model_name=model_name,
    context_encoder_name=context_name,
    query_encoder_name=question_name,
    args=model_args,
    prediction_passages=index_path,
)

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenize

Downloading and preparing dataset csv/default to /deep_learning/.cache/huggingface/datasets/csv/default-b2ce6e3e028b1269/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /deep_learning/.cache/huggingface/datasets/csv/default-b2ce6e3e028b1269/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

INFO:simpletransformers.retrieval.retrieval_utils:Preparing prediction passages completed
INFO:simpletransformers.retrieval.retrieval_utils:Generating embeddings for prediction passages started


  0%|          | 0/41 [00:00<?, ?ba/s]

INFO:simpletransformers.retrieval.retrieval_utils:Generating embeddings for prediction passages completed
INFO:simpletransformers.retrieval.retrieval_utils:Adding FAISS index to prediction passages


  0%|          | 0/6 [00:00<?, ?it/s]

INFO:simpletransformers.retrieval.retrieval_utils:Adding FAISS index to prediction passages completed


In [17]:
# Only the first returned element (the metrics dict) is kept; the rest is
# discarded. Metrics are reported for each k in `top_k_values`.
results, *_ = model.eval_model(
    eval_data,
    top_k_values=[1, 2, 3, 5, 10, 20, 100],
    relevant_docs=relevant_docs_path,
)



Downloading and preparing dataset csv/default to /deep_learning/.cache/huggingface/datasets/csv/default-7d9ece215f39d7c7/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset csv downloaded and prepared to /deep_learning/.cache/huggingface/datasets/csv/default-7d9ece215f39d7c7/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?ex/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Running Evaluation:   0%|          | 0/38 [00:00<?, ?it/s]

  (max_idxs == torch.tensor(labels)).sum().cpu().detach().numpy().item()


Retrieving docs:   0%|          | 0/1 [00:00<?, ?it/s]

INFO:simpletransformers.retrieval.retrieval_model:{'eval_loss': 14.01247501373291, 'mrr_at_1': 0.19666666666666666, 'mrr_at_2': 0.235, 'mrr_at_3': 0.24722222222222223, 'mrr_at_5': 0.26072222222222224, 'mrr_at_10': 0.2709378306878307, 'mrr_at_20': 0.27644655752743985, 'mrr_at_100': 0.2806517790657296, 'top_1_accuracy': 0.19666666666666666, 'top_2_accuracy': 0.2733333333333333, 'top_3_accuracy': 0.31, 'top_5_accuracy': 0.37, 'top_10_accuracy': 0.45, 'top_20_accuracy': 0.5233333333333333, 'top_100_accuracy': 0.7033333333333334, 'recall_at_1': 0.18916666666666668, 'recall_at_2': 0.25816666666666666, 'recall_at_3': 0.29483333333333334, 'recall_at_5': 0.35283333333333333, 'recall_at_10': 0.43322222222222223, 'recall_at_20': 0.5057222222222222, 'recall_at_100': 0.6902222222222222}


In [18]:
# Display the metrics dict (eval loss plus MRR, top-k accuracy and recall per k).
results

{'eval_loss': 14.01247501373291,
 'mrr_at_1': 0.19666666666666666,
 'mrr_at_2': 0.235,
 'mrr_at_3': 0.24722222222222223,
 'mrr_at_5': 0.26072222222222224,
 'mrr_at_10': 0.2709378306878307,
 'mrr_at_20': 0.27644655752743985,
 'mrr_at_100': 0.2806517790657296,
 'top_1_accuracy': 0.19666666666666666,
 'top_2_accuracy': 0.2733333333333333,
 'top_3_accuracy': 0.31,
 'top_5_accuracy': 0.37,
 'top_10_accuracy': 0.45,
 'top_20_accuracy': 0.5233333333333333,
 'top_100_accuracy': 0.7033333333333334,
 'recall_at_1': 0.18916666666666668,
 'recall_at_2': 0.25816666666666666,
 'recall_at_3': 0.29483333333333334,
 'recall_at_5': 0.35283333333333333,
 'recall_at_10': 0.43322222222222223,
 'recall_at_20': 0.5057222222222222,
 'recall_at_100': 0.6902222222222222}