## Installing transformer-rankers and dependencies

---



In [None]:
# Install transformer-rankers from its GitHub master branch, then the requirements.txt from the same branch.
# NOTE(review): both point at `master`, so installed versions can drift from the ones this notebook was run with.
!pip install git+https://github.com/Guzpenha/transformer_rankers.git
!wget https://raw.githubusercontent.com/Guzpenha/transformer_rankers/master/requirements.txt
!pip install -r requirements.txt

In [None]:
# Install Anserini, which transformer-rankers needs for part of its functionality (the BM25 negative samplers).
# Anserini is built from source with Maven; tests and javadoc are skipped to shorten the build.
!apt-get install maven -qq
!git clone --recurse-submodules https://github.com/castorini/anserini.git
!cd anserini; mvn clean package appassembler:assemble -DskipTests -Dmaven.javadoc.skip=true
# List the generated launcher scripts to confirm the build produced them.
!ls anserini/target/appassembler/bin/

In [None]:
# On Google Colab, torch 1.5 does not see the GPU, so force-reinstall torch 1.4.0.
!pip install -I torch==1.4.0
import torch
torch.cuda.get_device_name(0)  # This should output a GPU device name; an error here means no usable GPU is visible.

## Downloading ClariQ data

In [None]:
!mkdir data
!mkdir data/clariq
!cd data/clariq; wget https://github.com/aliannejadi/ClariQ/raw/master/data/dev.tsv
!cd data/clariq; wget https://github.com/aliannejadi/ClariQ/raw/master/data/train.tsv
!cd data/clariq; wget https://github.com/aliannejadi/ClariQ/raw/master/data/question_bank.tsv
!mv data/clariq/train.tsv data/clariq/train_original.tsv

## Preprocess ClariQ for transformer-rankers

In [None]:
import pandas as pd
data_path = "./data/"


def _to_query_question_pairs(clariq_df):
    """Reduce a raw ClariQ split to the (query, clarifying_question) pairs transformer-rankers expects.

    Keeps only the `initial_request` and `question` columns, renames them to
    `query` / `clarifying_question`, and drops rows without a clarifying question.
    """
    pairs = clariq_df[["initial_request", "question"]].rename(
        columns={"initial_request": "query", "question": "clarifying_question"})
    return pairs[pairs["clarifying_question"].notna()]


train = _to_query_question_pairs(pd.read_csv(data_path + "clariq/train_original.tsv", sep="\t"))
valid = _to_query_question_pairs(pd.read_csv(data_path + "clariq/dev.tsv", sep="\t"))

train.to_csv(data_path + "clariq/train.tsv", sep="\t", index=False)
valid.to_csv(data_path + "clariq/valid.tsv", sep="\t", index=False)

In [None]:
# For transformer-rankers we only need a pandas DF with query (here the initial request)
# and relevant documents (here the clarifying questions).
# A query is repeated once per relevant clarifying question.
train.head()

Unnamed: 0,query,clarifying_question
0,Tell me about Obama family tree.,are you interested in seeing barack obamas family
1,Tell me about Obama family tree.,would you like to know barack obamas geneology
2,Tell me about Obama family tree.,would you like to know about obamas ancestors
3,Tell me about Obama family tree.,would you like to know who is currently alive ...
4,Tell me about Obama family tree.,are you looking for biological information on ...


In [None]:
# We will sample negative samples for training using the question bank.
# The first row (Q00001) has an empty question, which is the "no clarifying question" entry.
# Later cells therefore skip it with `.values[1:]`.
question_bank = pd.read_csv(data_path+"clariq/question_bank.tsv", sep="\t")
question_bank.head()

Unnamed: 0,question_id,question
0,Q00001,
1,Q00002,a total cholesterol of 180 to 200 mgdl 10 to 1...
2,Q00003,about how many years experience do you want th...
3,Q00004,according to anima the bible or what other source
4,Q00005,ae you looking for examples of septic system d...


## Training a transformer-ranker for ClariQ (RQ2)

The problem is to retrieve the most relevant clarifying question for a given query. We will train a BERT-ranker using transformer-rankers.

In [None]:
from transformer_rankers.trainers import transformer_trainer
from transformer_rankers.datasets import dataset
from transformer_rankers.negative_samplers import negative_sampling
from transformer_rankers.eval import results_analyses_tools

from transformers import BertTokenizer, BertForSequenceClassification

import logging
import os
import sys
import torch
import random
import numpy as np

np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
# manual_seed_all also covers runtimes with several GPUs (manual_seed only seeds the current device).
torch.cuda.manual_seed_all(42)

logging.basicConfig(
  level=logging.INFO,
  format="%(asctime)s [%(levelname)s] %(message)s",
  handlers=[
      logging.StreamHandler(sys.stdout)
  ]
)

# A query plus a clarifying question is short, so 50 word pieces are enough.
max_seq_len = 50

# Use an almost balanced number of positive and negative samples during training.
# The number of negatives per query is the average number of relevant questions per query.
average_relevant_per_query = train.groupby("query").count().mean().values[0]
num_negatives = int(average_relevant_per_query)

# Candidate pool for negative sampling: every question-bank entry except the first, empty one.
candidate_questions = question_bank["question"].values[1:]

# Anserini index directory, derived from data_path instead of a hard-coded Colab path.
# With the working directory at /content (Colab), this is /content/data/clariq/anserini_train/.
# The trailing "" keeps the trailing slash of the original path.
anserini_index_path = os.path.join(os.path.abspath(data_path), "clariq", "anserini_train", "")

# Instantiate BM25 negative samplers. Train and validation use the same candidate pool and index path.
# Each sampler gets its own list copy, as before.
ns_train = negative_sampling.BM25NegativeSamplerPyserini(list(candidate_questions), num_negatives,
                    anserini_index_path, -1, "./anserini/")
ns_val = negative_sampling.BM25NegativeSamplerPyserini(list(candidate_questions), num_negatives,
                    anserini_index_path, -1, "./anserini/")

# We could also use random sampling which does not require Anserini.
# ns_train = negative_sampling.RandomNegativeSampler(list(candidate_questions), num_negatives)
# ns_val = negative_sampling.RandomNegativeSampler(list(candidate_questions), num_negatives)

# Create the loaders for the dataset, with the respective negative samplers.
# No separate test split is used: the ClariQ dev split serves as both validation and test data.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataloader = dataset.QueryDocumentDataLoader(train_df=train,
                    val_df=valid, test_df=valid,
                    tokenizer=tokenizer, negative_sampler_train=ns_train,
                    negative_sampler_val=ns_val, task_type='classification',
                    train_batch_size=12, val_batch_size=12, max_seq_len=max_seq_len,
                    sample_data=-1, cache_path="./data/clariq/")

train_loader, val_loader, test_loader = dataloader.get_pytorch_dataloaders()

# Use BERT (any model that has a SequenceClassification class in HuggingFace would work here).
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Instantiate the trainer that handles fitting.
trainer = transformer_trainer.TransformerTrainer(model=model,
  train_loader=train_loader,
  val_loader=val_loader, test_loader=test_loader,
  num_ns_eval=num_negatives, task_type="classification", tokenizer=tokenizer,
  validate_every_epochs=1, num_validation_instances=-1,
  num_epochs=1, lr=5e-7, sacred_ex=None)

# Train (our validation eval uses the NS sampling procedure).
trainer.fit()

2020-08-07 17:03:23,222 [INFO] Lock 139694630808936 acquired on /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084.lock
2020-08-07 17:03:23,224 [INFO] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt not found in cache or force_download set to True, downloading to /root/.cache/torch/transformers/tmpc1wij842


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…


2020-08-07 17:03:23,520 [INFO] storing https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt in cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
2020-08-07 17:03:23,522 [INFO] creating metadata file for /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
2020-08-07 17:03:23,527 [INFO] Lock 139694630808936 released on /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084.lock
2020-08-07 17:03:23,529 [INFO] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a5595

100%|██████████| 187/187 [00:01<00:00, 97.69it/s] 

2020-08-07 17:03:25,562 [INFO] Encoding examples using tokenizer.batch_encode_plus().





2020-08-07 17:03:33,169 [INFO] Transforming examples to instances format.
2020-08-07 17:03:33,383 [INFO] Set train Instance 0 query 

Child support in Indiana?[...]

2020-08-07 17:03:33,384 [INFO] Set train Instance 0 document 

are you interested in indiana child support

2020-08-07 17:03:33,385 [INFO] Set train Instance 0 features 

InputFeatures(input_ids=[101, 2775, 2490, 1999, 5242, 1029, 102, 2024, 2017, 4699, 1999, 5242, 2775, 2490, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)

2020-08-07 17:03:33,388 [INFO] Set train Instance 1 query 

Child support in Indiana?[...]

2020-08-0

100%|██████████| 50/50 [00:00<00:00, 150.22it/s]

2020-08-07 17:03:33,881 [INFO] Encoding examples using tokenizer.batch_encode_plus().





2020-08-07 17:03:35,882 [INFO] Transforming examples to instances format.
2020-08-07 17:03:35,915 [INFO] Set val Instance 0 query 

Find Brooks Brothers clearance.[...]

2020-08-07 17:03:35,916 [INFO] Set val Instance 0 document 

are you interested in brooks brothers clearance shirts

2020-08-07 17:03:35,919 [INFO] Set val Instance 0 features 

InputFeatures(input_ids=[101, 2424, 8379, 3428, 14860, 1012, 102, 2024, 2017, 4699, 1999, 8379, 3428, 14860, 11344, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)

2020-08-07 17:03:35,922 [INFO] Set val Instance 1 query 

Find Brooks Brothers clear

100%|██████████| 50/50 [00:00<00:00, 142.68it/s]

2020-08-07 17:03:36,342 [INFO] Encoding examples using tokenizer.batch_encode_plus().





2020-08-07 17:03:38,301 [INFO] Transforming examples to instances format.
2020-08-07 17:03:38,479 [INFO] Set test Instance 0 query 

Find Brooks Brothers clearance.[...]

2020-08-07 17:03:38,480 [INFO] Set test Instance 0 document 

are you interested in brooks brothers clearance shirts

2020-08-07 17:03:38,483 [INFO] Set test Instance 0 features 

InputFeatures(input_ids=[101, 2424, 8379, 3428, 14860, 1012, 102, 2024, 2017, 4699, 1999, 8379, 3428, 14860, 11344, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)

2020-08-07 17:03:38,485 [INFO] Set test Instance 1 query 

Find Brooks Brothers c

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…


2020-08-07 17:03:38,986 [INFO] storing https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json in cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
2020-08-07 17:03:38,988 [INFO] creating metadata file for /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
2020-08-07 17:03:38,990 [INFO] Lock 139694630845128 released on /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517.lock
2020-08-07 17:03:38,992 [INFO] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.71

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…


2020-08-07 17:04:21,308 [INFO] storing https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin in cache at /root/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
2020-08-07 17:04:21,310 [INFO] creating metadata file for /root/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157
2020-08-07 17:04:21,312 [INFO] Lock 139694572620656 released on /root/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157.lock
2020-08-07 17:04:21,313 [INFO] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at /root/.cache/torch/transformers/f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2e

Epoch 0: 100%|██████████| 1416/1416 [07:12<00:00,  3.28it/s]
100%|██████████| 368/368 [00:27<00:00, 13.54it/s]

2020-08-07 17:12:09,138 [INFO] Epoch 1 val nDCG@10 0.951





## Evaluating with ClariQ evaluation scripts
The above code uses transformer-rankers' built-in evaluation. For each query, it ranks only a small candidate set: the relevant clarifying questions plus K sampled negatives (K = int(average_relevant_per_query) in the example). This is a re-ranking scenario in which all the positive examples are in the candidate list. RQ2 of ClariQ instead requires us to rank the entire question_bank. Additionally, it evaluates whether the clarifying questions help document retrieval.

In [None]:
! git clone https://github.com/aliannejadi/ClariQ.git ClariQ-repo
! pip install rank_bm25

Cloning into 'ClariQ-repo'...
remote: Enumerating objects: 108, done.[K
remote: Counting objects: 100% (108/108), done.[K
remote: Compressing objects: 100% (74/74), done.[K
remote: Total 120 (delta 56), reused 77 (delta 33), pack-reused 12[K
Receiving objects: 100% (120/120), 24.07 MiB | 30.13 MiB/s, done.
Resolving deltas: 100% (58/58), done.
Collecting rank_bm25
  Downloading https://files.pythonhosted.org/packages/16/5a/23ed3132063a0684ea66fb410260c71c4ffda3b99f8f1c021d1e245401b5/rank_bm25-0.2.1-py3-none-any.whl
Installing collected packages: rank-bm25
Successfully installed rank-bm25-0.2.1


### Re-rank BM25 with the fine-tuned BERT-ranker

Let's first use the BM25 [example](https://colab.research.google.com/drive/1g_Sc9j5fYT1hiOxif6BVH5NHNt-icxtT#scrollTo=_7_2LTXoXqK7) from Mohammad to retrieve the top `rerank_top_k` candidate questions for each dev query. These candidates form the dataset we re-rank.

In [None]:
# Number of BM25 candidates per query that the BERT model re-ranks; also the depth of the re-ranked run file.
rerank_top_k = 30

In [None]:
# Imports required packages and defines the stem & tokenize function.
import pandas as pd
from rank_bm25 import BM25Okapi
import nltk
from nltk.stem.porter import PorterStemmer

nltk.download('punkt')
nltk.download('stopwords')

# Build the stopword set and stemmer once.
# Tokenizing the whole question bank otherwise re-reads the stopword list for every token.
_ENGLISH_STOPWORDS = set(nltk.corpus.stopwords.words('english'))
_STEMMER = PorterStemmer()

def stem_tokenize(text, remove_stopwords=True):
  """Split `text` into words, optionally drop English stopwords, and Porter-stem the rest.

  Args:
    text: raw string to tokenize.
    remove_stopwords: when True (the default), tokens in NLTK's English stopword list
      are dropped before stemming; pass False to keep them.

  Returns:
    List of stemmed tokens, in their original order.
  """
  tokens = [word for sent in nltk.sent_tokenize(text)
                 for word in nltk.word_tokenize(sent)]
  if remove_stopwords:
    # Matching is case-sensitive, so capitalized stopwords (e.g. a sentence-initial "The") are kept.
    tokens = [word for word in tokens if word not in _ENGLISH_STOPWORDS]
  return [_STEMMER.stem(word) for word in tokens]

import numpy as np

# File paths
request_file_path = './ClariQ-repo/data/dev.tsv'
question_bank_path = './ClariQ-repo/data/question_bank.tsv'
run_file_path = './ClariQ-repo/sample_runs/dev_bm25'

# Read the files and build the BM25 corpus (index).
dev = pd.read_csv(request_file_path, sep='\t')
question_bank = pd.read_csv(question_bank_path, sep='\t').fillna('')
question_bank['tokenized_question_list'] = question_bank['question'].map(stem_tokenize)
bm25_corpus = question_bank['tokenized_question_list'].tolist()
bm25 = BM25Okapi(bm25_corpus)

# Run BM25 for every dev topic. For each topic this:
#   - writes the top rerank_top_k questions to run_file_path (the file used to be truncated and left empty),
#   - records their ids in all_preds_bm25,
#   - adds (query, question) pairs to `examples` for the BERT re-ranker.
examples = []
all_preds_bm25 = []
with open(run_file_path, 'w') as fo:
  for tid in dev['topic_id'].unique():
    query = dev.loc[dev['topic_id']==tid, 'initial_request'].tolist()[0]
    scores = bm25.get_scores(stem_tokenize(query, True))
    # Same ordering as BM25Okapi.get_top_n, but this keeps row positions.
    # Looking questions up by their stemmed string returned every question sharing that string,
    # which duplicated and shifted candidates.
    top_rows = np.argsort(scores)[::-1][:rerank_top_k]
    docs = question_bank['question'].iloc[top_rows].tolist()
    preds = question_bank['question_id'].iloc[top_rows].tolist()
    all_preds_bm25.append(preds)
    for rank, (qid, doc) in enumerate(zip(preds, docs)):
      examples.append((query, doc))
      fo.write('{} 0 {} {} {} bm25\n'.format(tid, qid, rank, len(preds) - rank))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


Now we transform this dataset into the format required by our model.

In [None]:
from transformers.data.data_collator import DefaultDataCollator
from transformers.data.processors.utils import InputFeatures
from torch.utils.data import Dataset, DataLoader
from transformer_rankers.utils import utils

class SimpleDataset(Dataset):
    """Map-style dataset that serves the items of a pre-built list by index."""

    def __init__(self, features):
        # Items are stored by reference and returned unchanged.
        self.features = features

    def __getitem__(self, index):
        return self.features[index]

    def __len__(self):
        return len(self.features)

# Tokenize every (query, candidate question) pair as BERT sentence-pair input.
batch_encoding = tokenizer.batch_encode_plus(examples,
                max_length=max_seq_len, pad_to_max_length=True)
# Placeholder label 0: only the model's scores are used downstream.
features = [InputFeatures(**{k: batch_encoding[k][i] for k in batch_encoding}, label=0)
            for i in range(len(examples))]

# Named rerank_dataset so it does not shadow the `dataset` module imported from transformer_rankers.datasets.
rerank_dataset = SimpleDataset(features)
collator = DefaultDataCollator()
# shuffle=False keeps predictions in the same order as `examples`, which the per-query grouping relies on.
dataloader = DataLoader(rerank_dataset, batch_size=16, shuffle=False, collate_fn=collator.collate_batch)

Now we can run the trained model on this dataset and save the re-ranked predictions to a run file.

In [None]:
# Score all BM25 candidates with the fine-tuned model.
# The flat score list is then grouped per dev topic, rerank_top_k scores per topic, in `examples` order.
# NOTE(review): softmax_output[0] is assumed to hold one relevance score per pair, and acumulate_list
# is assumed to chunk consecutively. Confirm both against transformer_rankers.
logits, _, softmax_output = trainer.predict(dataloader)
softmax_output_by_query = utils.acumulate_list(softmax_output[0], rerank_top_k)

100%|██████████| 94/94 [00:10<00:00,  9.37it/s]


In [None]:
import numpy as np
run_file_path = './ClariQ-repo/sample_runs/dev_BERT-reranker'
topic_ids = dev['topic_id'].unique()
with open(run_file_path, 'w') as run_file:
  for topic_idx, topic_id in enumerate(topic_ids):
    # Reorder this topic's BM25 candidates by descending model score.
    scores = np.array(softmax_output_by_query[topic_idx])
    order = np.argsort(-scores)[:rerank_top_k]
    ranked_qids = np.array(all_preds_bm25[topic_idx])[order]
    n_ranked = len(ranked_qids)
    # Each line has the form: topic 0 question_id rank score tag (the score decreases with rank).
    for rank, qid in enumerate(ranked_qids):
      run_file.write('{} 0 {} {} {} BERT-reranker\n'.format(topic_id, qid, rank, n_ranked - rank))

In [None]:
# Report question relevance performance (Recall@k of the clarifying-question ranking) for the run in
# run_file_path; the scores are also written to <run_file_path>_question_relevance.eval.
! python ./ClariQ-repo/src/clariq_eval_tool.py  --eval_task question_relevance\
                                                --data_dir ./ClariQ-repo/data/ \
                                                --experiment_type dev \
                                                --run_file {run_file_path} \
                                                --out_file {run_file_path}_question_relevance.eval

Recall5: 0.3474806038474769
Recall10: 0.6136149763549145
Recall20: 0.6912818698329535
Recall30: 0.6912818698329535


In [None]:
# Report document relevance performance (NDCG / P / MRR of document retrieval with the ranked
# clarifying questions) for the run in run_file_path; the scores are also written to <run_file_path>.eval.
! python ./ClariQ-repo/src/clariq_eval_tool.py  --eval_task document_relevance\
                                                --data_dir ./ClariQ-repo/data/ \
                                                --experiment_type dev \
                                                --run_file {run_file_path} \
                                                --out_file {run_file_path}.eval

NDCG1: 0.18958333333333333
NDCG3: 0.17431329825264302
NDCG5: 0.16796281956102732
NDCG10: 0.1658691210524936
NDCG20: 0.1527795302714777
P1: 0.2375
P3: 0.20416666666666666
P5: 0.19
P10: 0.176875
P20: 0.1384375
MRR100: 0.33301824879980596


### Full retrieval
Here we skip BM25 and let the BERT model rank the whole question bank directly. First, we generate a dataset
containing all combinations of dev queries and question_bank questions.

In [None]:
from transformers.data.data_collator import DefaultDataCollator
from transformers.data.processors.utils import InputFeatures
from torch.utils.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    """Map-style dataset that serves the items of a pre-built list by index."""

    def __init__(self, features):
        # Items are stored by reference and returned unchanged.
        self.features = features

    def __getitem__(self, index):
        return self.features[index]

    def __len__(self):
        return len(self.features)

# Leave out the first question-bank entry: it is the null document for "no question".
all_documents = list(question_bank["question"].values[1:])
# First initial_request of each dev topic, in order of first appearance.
dev_queries = [dev.loc[dev['topic_id'] == topic_id, 'initial_request'].tolist()[0]
               for topic_id in dev['topic_id'].unique()]
# Every (query, question) combination, query-major: all documents for the first query, then the next, ...
examples = [(query, doc) for query in dev_queries for doc in all_documents]

batch_encoding = tokenizer.batch_encode_plus(examples,
                max_length=max_seq_len, pad_to_max_length=True)
# Placeholder label 0: only the model's scores are used downstream.
features = [InputFeatures(**{key: batch_encoding[key][idx] for key in batch_encoding}, label=0)
            for idx in range(len(examples))]

dataset = SimpleDataset(features)
collator = DefaultDataCollator()

Now we make the predictions and group the scores by query, using the number of candidate documents (`len(all_documents)`) as the chunk size.

In [None]:
from transformer_rankers.utils import utils

# shuffle=False keeps predictions aligned with the query-major order of `examples`.
dataloader = DataLoader(dataset, batch_size=16, shuffle=False,
                        collate_fn=collator.collate_batch)
# Score every (query, question) pair, then group the flat score list into one chunk of
# len(all_documents) scores per dev topic (acumulate_list presumably chunks consecutively).
logits, _, softmax_output = trainer.predict(dataloader)
softmax_output_by_query = utils.acumulate_list(softmax_output[0], len(all_documents))

100%|██████████| 12313/12313 [21:48<00:00,  9.41it/s]


In [None]:
import numpy as np
run_file_path = './ClariQ-repo/sample_runs/dev_BERT-ranker'
# Number of questions written per topic; the same depth as the re-ranked run (rerank_top_k = 30).
run_depth = 30
# Question ids aligned with all_documents (the null first entry is skipped in both).
all_doc_ids = np.array(question_bank["question_id"].values[1:])
with open(run_file_path, 'w') as fo:
  for tid_idx, tid in enumerate(dev['topic_id'].unique()):
    all_documents_scores = np.array(softmax_output_by_query[tid_idx])
    # Indices of the run_depth highest-scoring questions, best first.
    top_scores_idx = (-all_documents_scores).argsort()[:run_depth]
    preds = all_doc_ids[top_scores_idx]
    # Each line has the form: topic 0 question_id rank score tag (the score decreases with rank).
    for i, qid in enumerate(preds):
      fo.write('{} 0 {} {} {} BERT-ranker\n'.format(tid, qid, i, len(preds)-i))

In [None]:
# Report question relevance performance (Recall@k of the clarifying-question ranking) for the run in
# run_file_path; the scores are also written to <run_file_path>_question_relevance.eval.
! python ./ClariQ-repo/src/clariq_eval_tool.py  --eval_task question_relevance\
                                                --data_dir ./ClariQ-repo/data/ \
                                                --experiment_type dev \
                                                --run_file {run_file_path} \
                                                --out_file {run_file_path}_question_relevance.eval

Recall5: 0.35055278656671846
Recall10: 0.6154512724117988
Recall20: 0.7253078340648
Recall30: 0.7529370626793227


In [None]:
# Report document relevance performance (NDCG / P / MRR of document retrieval with the ranked
# clarifying questions) for the run in run_file_path; the scores are also written to <run_file_path>.eval.
! python ./ClariQ-repo/src/clariq_eval_tool.py  --eval_task document_relevance\
                                                --data_dir ./ClariQ-repo/data/ \
                                                --experiment_type dev \
                                                --run_file {run_file_path} \
                                                --out_file {run_file_path}.eval

NDCG1: 0.18958333333333333
NDCG3: 0.17431329825264302
NDCG5: 0.16796281956102732
NDCG10: 0.1658691210524936
NDCG20: 0.1527795302714777
P1: 0.2375
P3: 0.20416666666666666
P5: 0.19
P10: 0.176875
P20: 0.1384375
MRR100: 0.33301824879980596


## Results comparison

In [None]:
import json

models = ["bm25", "BERT-reranker", "BERT-ranker"]
results = []
# Loop over `model_name`, not `model`, so this cell does not overwrite the fine-tuned BERT `model`.
for model_name in models:
  with open('./ClariQ-repo/sample_runs/dev_{}_question_relevance.eval'.format(model_name)) as f:
    res = json.load(f)
  # Each metric maps to a dict of per-item scores (presumably per topic); report their mean.
  for metric_name, per_item_scores in res.items():
    results.append([model_name, metric_name, np.mean(list(per_item_scores.values()))])
res_df = pd.DataFrame(results, columns = ["model", "metric", "value"])

pd.set_option("display.precision", 4)
res_df = res_df.set_index(["model", "metric"]).unstack()
# unstack() orders the RecallN columns as strings (Recall10, Recall20, Recall30, Recall5).
# Order them explicitly by their numeric cutoff instead.
cols = sorted(res_df.columns.tolist(), key=lambda col: int(col[1][len("Recall"):]))
res_df.sort_values([("value","Recall10")])[cols]

Unnamed: 0_level_0,value,value,value,value
metric,Recall5,Recall10,Recall20,Recall30
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
bm25,0.3246,0.5638,0.6675,0.6913
BERT-reranker,0.3475,0.6136,0.6913,0.6913
BERT-ranker,0.3506,0.6155,0.7253,0.7529


In [None]:
import json

models = ["bm25", "BERT-reranker", "BERT-ranker"]
results = []
# Loop over `model_name`, not `model`, so this cell does not overwrite the fine-tuned BERT `model`.
for model_name in models:
  with open('./ClariQ-repo/sample_runs/dev_{}.eval'.format(model_name)) as f:
    res = json.load(f)
  # Each metric maps to a dict of per-item scores (presumably per topic); report their mean.
  for metric_name, per_item_scores in res.items():
    results.append([model_name, metric_name, np.mean(list(per_item_scores.values()))])
res_df = pd.DataFrame(results, columns = ["model", "metric", "value"])

pd.set_option("display.precision", 4)
res_df = res_df.set_index(["model", "metric"]).unstack()
res_df.sort_values(("value", "MRR100"))

Unnamed: 0_level_0,value,value,value,value,value,value,value,value,value,value,value
metric,MRR100,NDCG1,NDCG10,NDCG20,NDCG3,NDCG5,P1,P10,P20,P3,P5
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2
bm25,0.3096,0.1859,0.1363,0.1285,0.1608,0.153,0.2313,0.1406,0.1181,0.1896,0.175
BERT-ranker,0.333,0.1896,0.1659,0.1528,0.1743,0.168,0.2375,0.1769,0.1384,0.2042,0.19
BERT-reranker,0.333,0.1896,0.1659,0.1528,0.1743,0.168,0.2375,0.1769,0.1384,0.2042,0.19
