# BERT Evaluation Script

To evaluate a BERT model we follow the same procedure as in training: retrieve the top-k documents for each query with Anserini (BM25, k = 1000 below) and re-rank them with BERT.

In [1]:
import os
from os.path import expanduser
home = expanduser("~")
os.environ["JAVA_HOME"] = f"{home}/.sdkman/candidates/java/11.0.7.hs-adpt"  #Set right JAVA version
data_home = "/ssd2/arthur/MsMarcoTREC/"
def path(x):
    return os.path.join(data_home, x)

try:
    import pyserini
except:
    !pip install pyserini==0.9.2.0 # install pyserini
    import pyserini
try:
    import tqdm
except:
    !pip install tqdm # Good for progress bars!
    import tqdm

In [2]:
import jnius_config
# JVM options (here: a 32 GB maximum heap) only take effect if registered before the JVM starts,
# which presumably happens when pyserini's Java bindings are imported on the next line - keep this order.
jnius_config.add_options('-Xmx32G') # Adjust to your machine.
from pyserini.search import pysearch
import subprocess
from tqdm.auto import tqdm
import random
import pickle
import sys
import unicodedata
import string
import re
import os
from collections import defaultdict
import math

## `pytrec_eval` setup

In [106]:
try:
    import pytrec_eval
except:
    !pip install pytrec_eval
    import pytrec_eval
    
qrel = defaultdict(lambda:dict())
qrel_path = path("qrels/2019qrels-docs.txt")
for line in open(qrel_path):
    query_id, _, doc_id, rel = line.split()
    qrel[query_id][doc_id] = int(rel)
qrel = dict(qrel)
metrics = {"map", "ndcg", "recip_rank"}
evaluator = pytrec_eval.RelevanceEvaluator(qrel, metrics)

In [4]:
index_path = path("lucene-index.msmarco-doc.pos+docvectors+rawdocs")
searcher = pysearch.SimpleSearcher(index_path)

## Extract Anserini top-K

In [5]:
# Collapse runs of punctuation/symbols (anything that is neither a word character nor whitespace)
# and underscores into a single space. Raw string: "\s"/"\w" are invalid escapes in a plain literal.
pattern = re.compile(r'([^\s\w]|_)+')

searcher.set_bm25_similarity(0.9, 0.4)  # BM25 parameters k1=0.9, b=0.4
pairs = []
threads = 42  # Number of threads to use when retrieving
k = 1000      # Number of documents to retrieve per query

query_texts = dict()
file_path = path("queries/msmarco-test2019-queries.tsv")
# Look for the cache at the same location it is written to below (the old check used a relative path,
# so it never matched and a hit would have overwritten the cache with empty results).
cache_path = path("test_triples.pkl")
run_search = not os.path.isfile(cache_path)

# Query texts are always needed (the dataset looks them up), whether or not we search again.
queries = []
query_ids = []
with open(file_path, encoding="utf-8") as query_file:
    for line in query_file:
        query_id, query = line.strip().split("\t")
        query_ids.append(query_id)
        # NFKD decomposes compatibility characters (ligatures, full-width forms, ...) before cleanup.
        query = unicodedata.normalize("NFKD", query)
        query = pattern.sub(' ', query)
        query_texts[query_id] = query
        queries.append(query)
number_of_queries = len(query_ids)

# {query_id: {doc_id: BM25 score}}
scores = defaultdict(lambda: dict())
if run_search:
    print(f"Running retrieval step for {number_of_queries} queries, on top-{k}")
    results = searcher.batch_search(queries, query_ids, k=k, threads=threads)
    for qid, hits in results.items():
        for hit in hits:
            scores[qid][hit.docid] = hit.score
    print(f"dumping results...")
    with open(cache_path, 'wb') as cache_file:
        pickle.dump(dict(scores), cache_file)
else:
    print(f"Already found file {cache_path}. Cowardly refusing to run this again. Loading cached BM25 scores.")
    # NOTE(review): pickle.load executes code from the file; only use caches this notebook wrote.
    with open(cache_path, 'rb') as cache_file:
        pairs = pickle.load(cache_file)  # same {qid: {docid: score}} structure dumped above
    for qid, doc_scores in pairs.items():
        scores[qid].update(doc_scores)

# (query_id, doc_id) pairs to re-rank, in retrieval order - rebuilt on cache hits as well.
triples = [(qid, docid) for qid, doc_scores in scores.items() for docid in doc_scores]

Running retrieval step for 200 queries, on top-1000
dumping results...


In [6]:
import numpy as np
# Evaluate the BM25 run once and read every metric from the same per-query results.
anserini_eval = evaluator.evaluate(scores)
avg_map = np.mean([x["map"] for x in anserini_eval.values()])
avg_ndcg = np.mean([x["ndcg"] for x in anserini_eval.values()])
print(f"Anserini MAP: {avg_map} nDCG: {avg_ndcg}")

Anserini MAP: 0.33099106298835235 nDCG: 0.5942674768376021


In [47]:
# This is a copy from the MsMarcoDataset Class.
from torch.utils.data import Dataset
import torch

# This is our main Dataset class.
class MsMarcoDataset(Dataset):
    """Pre-tokenized (query, document) pairs for BERT re-ranking, cached on disk.

    The first construction for a given ``split`` tokenizes every sample and writes one line per sample
    to ``{split}_msmarco_samples.tsv`` (under ``data_home``). Two pickles are written next to it:
    ``{split}_msmarco_offset.pkl`` (sample_id -> position of its line) and
    ``{split}_msmarco_index.pkl`` (integer index -> sample_id). Later constructions only load the
    pickles, and ``__getitem__`` seeks straight to the requested line.
    """

    # BERT input length: every sample is truncated or padded to exactly this many tokens.
    max_seq_len = 512

    def __init__(self,
                 samples,
                 tokenizer,
                 searcher,
                 split,
                 tokenizer_batch=8000):
        '''Initialize a Dataset object.
        Arguments:
            samples: A list of samples. Each sample should be a tuple with (query_id, doc_id, <label>),
                where label is only read for the "train" and "dev" splits.
            tokenizer: A tokenizer object from Hugging Face's Tokenizers lib (needs encode_batch() and token_to_id()).
            searcher: A Pyserini SimpleSearcher object. Should implement the .doc() method.
            split: A string indicating if we are in a train, dev or test dataset. Also names the cache files.
            tokenizer_batch: How many samples to be tokenized at once by the tokenizer object.
            The biggest bottleneck is the searcher, not the tokenizer.
        NOTE(review): query texts come from the module-level ``query_texts`` dict built by the
        retrieval cell; they are not passed in.
        '''
        self.searcher = searcher
        self.split = split
        self.tokenizer = tokenizer
        samples_path = path(f"{split}_msmarco_samples.tsv")
        offset_path = path(f"{split}_msmarco_offset.pkl")
        index_file_path = path(f"{split}_msmarco_index.pkl")
        # If we already have the data pre-computed, we shouldn't need to re-compute it.
        if all(os.path.isfile(p) for p in (samples_path, offset_path, index_file_path)):
            print("Already found every meaningful file. Cowardly refusing to re-compute.")
            with open(offset_path, 'rb') as f:
                self.samples_offset_dict = pickle.load(f)
            with open(index_file_path, 'rb') as f:
                self.index_dict = pickle.load(f)
            return
        print("Loading and tokenizing dataset...")
        has_labels = split == "train" or split == "dev"
        self.samples_offset_dict = dict()
        self.index_dict = dict()
        self.processed_samples = 0
        query_batch = []
        doc_batch = []
        sample_ids_batch = []
        labels_batch = []
        # True division: `len(samples) // tokenizer_batch` floored first, which made ceil() a no-op and
        # left the final partial batch out of the progress-bar total.
        number_of_batches = math.ceil(len(samples) / tokenizer_batch)
        # A progress bar to display how far we are.
        batch_pbar = tqdm(total=number_of_batches, desc="Tokenizer batches")
        self.samples_file = open(samples_path, 'w', encoding="utf-8")
        try:
            for i, sample in enumerate(samples):
                if has_labels:
                    labels_batch.append(sample[2])
                query_batch.append(query_texts[sample[0]])
                doc_batch.append(self._get_document_content_from_id(sample[1]))
                sample_ids_batch.append(f"{sample[0]}_{sample[1]}")
                # Flush when we hit the number of samples for this batch OR this is the last sample.
                if len(query_batch) == tokenizer_batch or i == len(samples) - 1:
                    self._tokenize_and_dump_batch(doc_batch, query_batch, labels_batch, sample_ids_batch)
                    batch_pbar.update()
                    query_batch = []
                    doc_batch = []
                    sample_ids_batch = []
                    labels_batch = []
        finally:
            batch_pbar.close()
            self.samples_file.close()
        # Dump the lookup tables to disk, so we don't need to go over the data again.
        with open(index_file_path, 'wb') as f:
            pickle.dump(self.index_dict, f)
        with open(offset_path, 'wb') as f:
            pickle.dump(self.samples_offset_dict, f)

    def _tokenize_and_dump_batch(self, doc_batch, query_batch, labels_batch,
                                 sample_ids_batch):
        '''Tokenizes the current batch and appends one line per sample to self.samples_file.
        Each sample's starting position in the file is stored in samples_offset_dict, and its running
        integer index in index_dict.
        '''
        tokens = self.tokenizer.encode_batch(list(zip(query_batch, doc_batch)))
        has_labels = self.split == "train" or self.split == "dev"
        max_len = self.max_seq_len
        for idx, (sample_id, token) in enumerate(zip(sample_ids_batch, tokens)):
            if len(token.ids) >= max_len:
                # Too long for BERT: cut from the end (the document side of the (query, document) pair)
                # and re-terminate with [SEP], looked up on this instance's tokenizer rather than a global.
                token_ids = token.ids[:max_len - 1] + [self.tokenizer.token_to_id("[SEP]")]
                segment_ids = token.type_ids[:max_len]
            else:
                # Shorter inputs are right-padded with id 0 up to max_len.
                padding = [0] * (max_len - len(token.ids))
                token_ids = token.ids + padding
                segment_ids = token.type_ids + padding
            # How far in the file are we? This is where we need to go to find the sample later.
            file_location = self.samples_file.tell()
            if has_labels:
                self.samples_file.write(f"{sample_id}\t{token_ids}\t{segment_ids}\t{labels_batch[idx]}\n")
            else:
                self.samples_file.write(f"{sample_id}\t{token_ids}\t{segment_ids}\n")
            self.samples_offset_dict[sample_id] = file_location
            self.index_dict[self.processed_samples] = sample_id
            self.processed_samples += 1

    def _get_document_content_from_id(self, doc_id):
        '''Get the raw text value stored for doc_id in the index.
        NOTE(review): the [7:-8] slice presumably strips a "<TEXT>" + newline prefix (7 chars) and a
        newline + "</TEXT>" suffix (8 chars) from the stored raw field - confirm against the index.
        '''
        doc_text = self.searcher.doc(doc_id).lucene_document().getField("raw").stringValue()
        return doc_text[7:-8]

    def __getitem__(self, idx):
        '''Returns the sample with index idx (an int position, or a "qid_docid" sample id).
        Returns (input_ids, attention_mask, token_type_ids, [label,] qid, did); the label tensor is only
        present for the "train" and "dev" splits.
        DistilBERT does not take into account segment_ids (indicator if the token comes from the query or
        the document). However, for the sake of completeness, we are including it here, together with the
        attention mask. position_ids, with the positional encoder, is not needed. It's created inside the model.
        '''
        import json  # local import: json is not part of the notebook's shared import cell

        if isinstance(idx, int):
            idx = self.index_dict[idx]
        with open(path(f"{self.split}_msmarco_samples.tsv"), 'r', encoding="utf-8") as inf:
            inf.seek(self.samples_offset_dict[idx])
            line = inf.readline().split("\t")
        try:
            qid, did = line[0].split("_", 1)
            # The lists were written as Python reprs of int lists, which are valid JSON; parsing them
            # as data is faster than eval() and never executes file contents.
            input_ids = json.loads(line[1])
            token_type_ids = json.loads(line[2])
        except (IndexError, ValueError) as err:
            raise IndexError(f"Malformed sample line for {idx!r}: {line!r}") from err
        # NOTE(review): the mask is all ones, so the model also attends to the id-0 padding positions.
        # This mirrors the training-time copy of this class; changing it only here would create a
        # train/eval mismatch. TODO confirm and consider masking padding in both places.
        input_mask = [1] * self.max_seq_len
        tensors = (torch.tensor(input_ids, dtype=torch.long),
                   torch.tensor(input_mask, dtype=torch.long),
                   torch.tensor(token_type_ids, dtype=torch.long))
        # If it's a training/dev dataset, we also have a label column.
        if self.split == "train" or self.split == "dev":
            label = int(line[3])
            return tensors + (torch.tensor([label], dtype=torch.long), qid, did)
        return tensors + (qid, did)

    def __len__(self):
        return len(self.samples_offset_dict)

## Evaluation Script

In [8]:
from transformers import DistilBertForSequenceClassification
from torch.utils.data import DataLoader
from tokenizers import BertWordPieceTokenizer

# WordPiece vocabulary of bert-base-uncased; lowercasing matches the uncased vocabulary.
tokenizer = BertWordPieceTokenizer(
    "/ssd2/arthur/bert-axioms/tokenizer/bert-base-uncased-vocab.txt",
    lowercase=True,
)

In [48]:
# Build (or load from the on-disk cache) the tokenized test pairs produced by the retrieval step.
eval_dataset = MsMarcoDataset(samples=triples, tokenizer=tokenizer, searcher=searcher, split="test")

Already found every meaningful file. Cowardly refusing to re-compute.


In [49]:
# Sanity check: the first test sample is (input_ids, attention_mask, token_type_ids, qid, did).
eval_dataset[0]

(tensor([  101, 12098,  7877,  6210,   102, 16770,  1024,  1013,  1013, 29393,
         15007, 12782,  1012,  4012,  1013, 25022, 11365,  1013,  3193, 14399,
          2594,  1012, 25718,  1029,  1056,  1027, 19843,  2692, 14142, 22166,
         15072,  1024,  2024,  2027,  4276,  9343,  1029, 10682,  2332, 29393,
         10975,  6679, 25969,  2271,  5068,  1024, 11703,  2538,  1010,  2456,
         19894,  2015,  1024, 21211,  2575, 19894,  2098,  1024, 12256, 12022,
          2324,  1010,  2289,  1016,  1024,  2459,  2572,  1045,  1005,  1049,
          2349,  2005,  1037,  2047,  3940,  1997, 20422,  7877,  1998,  1045,
          1005,  2310,  2042,  4994,  2070,  2204,  2477,  2055,  6653, 15072,
          1012,  2129,  2024,  2027,  1029,  2157,  2085,  1045,  2031, 20422,
          7877,  1998, 20422, 17072,  2021,  2049,  2428,  1037,  3255,  2000,
          4287,  2105,  2048,  7689,  1997,  7877,  2043,  1045,  2175,  2000,
          1037,  4323,  2380,  1006,  2073,  1045,  

In [10]:
# In a real-world scenario, this would be in a separate file, and you would just import this.
import torch
from torch import nn
from transformers import DistilBertModel, BertModel

class BertRelevanceRanker(nn.Module):
    """Two-class relevance classifier on top of a (Distil)BERT encoder's [CLS] representation."""

    def __init__(self, model="distilbert-base-uncased"):
        """Creates an instance of Bert Relevance Ranker.
        It feeds two sentences into a pre-trained BERT model, extracts the [CLS] token and feeds it into a
        small feed-forward head (Linear -> ReLU -> Dropout -> Linear) producing two logits.

        Arguments:
            model: Hugging Face model name or path. Names containing "distil" load a DistilBertModel,
                anything else loads a BertModel.
        """
        super().__init__()
        self.distil = "distil" in model
        self.loss_fct = nn.CrossEntropyLoss()
        if self.distil:
            self.bert = DistilBertModel.from_pretrained(model)
        else:
            self.bert = BertModel.from_pretrained(model)
        self.config = self.bert.config
        # `hidden_size` exists on both BertConfig and DistilBertConfig (where it maps to `dim`);
        # `dim` alone is DistilBERT-only and made the BertModel branch fail here.
        hidden = self.config.hidden_size
        self.linear1 = nn.Linear(hidden, hidden)
        self.linear2 = nn.Linear(hidden, 2)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Score a batch of (query, document) inputs.

        Returns (logits,), or (loss, logits) when labels are given; logits have two columns.
        Raises ValueError when a non-distil model is called without token_type_ids.
        """
        if not self.distil and token_type_ids is None:
            raise ValueError("Model is not distilBERT and it did not received token_type_ids!")
        if self.distil:
            # DistilBERT has no segment embeddings, so token_type_ids is not passed.
            hidden_states = self.bert(input_ids, attention_mask=attention_mask)[0]
        else:
            # Previously this branch never assigned pooled_output and crashed below.
            hidden_states = self.bert(input_ids, attention_mask=attention_mask,
                                      token_type_ids=token_type_ids)[0]
        pooled_output = hidden_states[:, 0]  # representation of the first ([CLS]) position
        pooled_output = nn.functional.relu(self.linear1(pooled_output))
        pooled_output = self.dropout(pooled_output)
        logits = self.linear2(pooled_output)
        outputs = (logits,)
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, 2), labels.view(-1))
            outputs = (loss,) + outputs
        return outputs
    

In [66]:
import torch
# Release any model left over from a previous run of this cell before building a new one.
try:
    del model
except NameError:
    pass  # First run in this kernel: nothing to release.
else:
    torch.cuda.empty_cache() # Make sure we have a clean slate. Useful in a Notebook.

GPUS_TO_USE = [0, 1] # If you have multiple GPUs, pick the ones you want to use.
number_of_cpus = 24 # Number of DataLoader worker processes to use when loading your dataset.
model = BertRelevanceRanker()
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only machine;
# the model is moved to `device` below either way.
state_dict = torch.load(path("models/distilBERT-2020-05-21/pytorch_model.bin"), map_location="cpu") #load last model saved
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference

if torch.cuda.is_available():
    model = torch.nn.DataParallel(model, device_ids=GPUS_TO_USE)
    device = torch.device(f"cuda:{GPUS_TO_USE[0]}")
    model.to(device)
    batch_size = len(GPUS_TO_USE) * 128 #Eval is WAAY smaller than train. We can load a fairly large batch here.
    print(f"running on {len(GPUS_TO_USE)} GPUS, on {batch_size}-sized batches")
else:
    print("Are you sure about it? We will try to run this in CPU, but it's a BAD idea...")
    device = torch.device("cpu")
    batch_size = 16
    model.to(device)

data_loader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=number_of_cpus, shuffle=False) # Don't shuffle!

running on 2 GPUS, on 256-sized batches


In [67]:
import warnings
# {query_id: {doc_id: BERT probability of class 1}}
bert_scores = defaultdict(lambda: dict())
with torch.no_grad():
    for batch in tqdm(data_loader, desc="Batches", total=len(data_loader)):
        # Test-split batches are (input_ids, attention_mask, token_type_ids, qids, dids).
        inputs = {
            'input_ids': batch[0].to(device),
            'attention_mask': batch[1].to(device),
            # Required by non-distil BERT checkpoints; the DistilBERT branch ignores it.
            'token_type_ids': batch[2].to(device),
        }
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            outputs = model(**inputs)
        # Softmax over the two logits; column 1 is presumably the "relevant" label used in training.
        predict_scores = torch.softmax(outputs[0], dim=1)[:, 1].cpu().numpy().flatten()
        for qid, did, score in zip(batch[3], batch[4], predict_scores):
            bert_scores[qid][did] = score

HBox(children=(FloatProgress(value=0.0, description='Batches', max=782.0, style=ProgressStyle(description_widt…




In [119]:
# Combine BM25 and BERT: per query, min-max normalise both score sets and interpolate with weight alpha on BM25.
# NOTE(review): alpha is selected on the same qrels that the results are reported on, so the "best"
# numbers are optimistic; tune alpha on a held-out query set for a fair comparison.
alphas = np.arange(0.0, 1.05, 0.05)
best_alpha = -1
best_ndcg = 0.0
best_map = 0.0
best_mrr = 0.0
top_k = 1000


def _min_max_normalize(doc_scores):
    """Rescale a {doc_id: score} dict into [0, 1] with min-max normalisation.

    When every score is identical (e.g. a query with a single candidate) the range is zero; all documents
    then get 0.0 instead of raising ZeroDivisionError / producing NaN.
    """
    high = max(doc_scores.values())
    low = min(doc_scores.values())
    spread = high - low
    if spread == 0:
        return {doc_id: 0.0 for doc_id in doc_scores}
    return {doc_id: (value - low) / spread for doc_id, value in doc_scores.items()}


# The normalisation does not depend on alpha, so compute it once per query rather than once per alpha.
normalized_anserini = {qid: _min_max_normalize(scores[qid]) for qid in bert_scores}
normalized_bert = {qid: _min_max_normalize(bert_scores[qid]) for qid in bert_scores}

for alpha in alphas:
    final_scores = defaultdict(lambda: dict())
    for qid in bert_scores:
        anserini_q = normalized_anserini[qid]
        bert_q = normalized_bert[qid]
        combined = {did: alpha * anserini_q[did] + (1 - alpha) * bert_q[did] for did in anserini_q}
        final_scores[qid] = dict(sorted(combined.items(), key=lambda x: x[1], reverse=True)[:top_k])
    # Evaluate once per alpha and read every metric from the same per-query results.
    per_query = list(evaluator.evaluate(final_scores).values())
    avg_map = np.mean([x["map"] for x in per_query])
    avg_ndcg = np.mean([x["ndcg"] for x in per_query])
    avg_mrr = np.mean([x["recip_rank"] for x in per_query])
    # The best alpha is chosen by MRR.
    if avg_mrr > best_mrr:
        best_ndcg = avg_ndcg
        best_alpha = alpha
        best_map = avg_map
        best_mrr = avg_mrr
    print(f"{alpha:5.2} MAP@{top_k}: {avg_map:.5} nDCG@{top_k}: {avg_ndcg:.5} MRR@{top_k}: {avg_mrr:.5}")
print(f"****Best alpha is {best_alpha:5.2}, with nDCG@{top_k}: {best_ndcg:.5} MAP@{top_k}: {best_map:.5f} MRR@{top_k}: {best_mrr:.5f}****")

  0.0 MAP@1000: 0.16087 nDCG@1000: 0.45354 MRR@1000: 0.35727
 0.05 MAP@1000: 0.18849 nDCG@1000: 0.47773 MRR@1000: 0.56659
  0.1 MAP@1000: 0.20094 nDCG@1000: 0.48724 MRR@1000: 0.57985
 0.15 MAP@1000: 0.21243 nDCG@1000: 0.50081 MRR@1000: 0.63296
  0.2 MAP@1000: 0.22177 nDCG@1000: 0.50929 MRR@1000: 0.67885
 0.25 MAP@1000: 0.23293 nDCG@1000: 0.52206 MRR@1000: 0.72644
  0.3 MAP@1000: 0.24816 nDCG@1000: 0.54091 MRR@1000: 0.78575
 0.35 MAP@1000: 0.26332 nDCG@1000: 0.55697 MRR@1000: 0.82193
  0.4 MAP@1000: 0.27895 nDCG@1000: 0.57372 MRR@1000: 0.87016
 0.45 MAP@1000: 0.29465 nDCG@1000: 0.58715 MRR@1000: 0.91805
  0.5 MAP@1000: 0.30808 nDCG@1000: 0.59671 MRR@1000: 0.91938
 0.55 MAP@1000: 0.3215 nDCG@1000: 0.60492 MRR@1000: 0.9186
  0.6 MAP@1000: 0.33528 nDCG@1000: 0.61219 MRR@1000: 0.9186
 0.65 MAP@1000: 0.34472 nDCG@1000: 0.61728 MRR@1000: 0.89903
  0.7 MAP@1000: 0.35086 nDCG@1000: 0.61895 MRR@1000: 0.88566
 0.75 MAP@1000: 0.35315 nDCG@1000: 0.6192 MRR@1000: 0.88218
  0.8 MAP@1000: 0.35119 nDCG