In [None]:
!pip install pytrec_eval
!pip install sentence_transformers

# Large parts of the code under here are taken from https://github.com/arian-askari/RelevanceCAT/tree/main

In [None]:
import pytrec_eval
import json
import tqdm
import numpy as np
import math
import sys
from datetime import datetime
import gzip
import shutil
import os
import tarfile
import logging
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers import InputExample
from transformers import AutoTokenizer
import gdown
import logging
logger = logging.getLogger()
# INFO (rather than DEBUG) keeps this notebook's own messages and CrossEncoder.predict's
# progress bars (shown at INFO or DEBUG) while hiding the per-request urllib3 / filelock
# DEBUG chatter that otherwise floods the cell output.
logger.setLevel(logging.INFO)
"""# Initializing variables

"""
base_write_path = ""
base_path = "msmarco-data"

os.makedirs(base_path, exist_ok=True)
num_epochs = 1
pos_neg_ratio = 4  # forwarded to the train-triples reader, which does not use it
max_train_samples = 0  # 0 = no cap on the number of training examples (the reader also applies a line cap)
valid_max_queries = 1000  # number of validation queries used for evaluation
valid_max_negatives_per_query = 400  # maximum number of negatives kept per validation query
model_name = 'microsoft/MiniLM-L12-H384-uncased'


def _download_and_extract_tar(extracted_path, archive_name, url, target_dir):
    """Ensure `extracted_path` exists.

    If it is missing, download `url` to `target_dir/archive_name` (unless the archive
    is already there) and extract the .tar.gz archive into `target_dir`.
    """
    if os.path.exists(extracted_path):
        return
    tar_filepath = os.path.join(target_dir, archive_name)
    if not os.path.exists(tar_filepath):
        logging.info("Download " + archive_name)
        util.http_get(url, tar_filepath)
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=target_dir)


queries_path = os.path.join(base_path, "queries.train.tsv")
_download_and_extract_tar(queries_path, 'queries.tar.gz',
                          'https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', base_path)

corpus_path = os.path.join(base_path, "collection.tsv")
_download_and_extract_tar(corpus_path, 'collection.tar.gz',
                          'https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', base_path)

# Training triples with teacher scores (per the file name, from a BERT-CAT ensemble).
# The download URL carries a "?download=1" query string, which must not end up in the
# local file name ('?' is not even a valid file-name character on Windows).
train_triples_url = 'https://zenodo.org/record/4068216/files/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv?download=1'
triples_train_path = os.path.join(base_path, "bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv")
legacy_triples_train_path = triples_train_path + "?download=1"
if not os.path.exists(triples_train_path) and os.path.exists(legacy_triples_train_path):
    # Reuse a copy saved under the old (query-string) name instead of re-downloading ~2.4 GB.
    triples_train_path = legacy_triples_train_path
elif not os.path.exists(triples_train_path):
    util.http_get(train_triples_url, triples_train_path)

triples_validation_path = os.path.join(base_path, "msmarco-qidpidtriples.rnd-shuf.train-eval.tsv")
if not os.path.exists(triples_validation_path):
    gz_filepath = triples_validation_path + ".gz"
    if not os.path.exists(gz_filepath):
        logging.info("Download " + os.path.basename(triples_validation_path))
        util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', gz_filepath)

    # Decompress into a temporary file and rename it only when complete, so an interrupted
    # run cannot leave a truncated .tsv that the existence check above would accept.
    tmp_validation_path = triples_validation_path + ".part"
    with gzip.open(gz_filepath, "rb") as f_in, open(tmp_validation_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.replace(tmp_validation_path, triples_validation_path)

model_save_path = os.path.join(base_write_path,'finetuned_CEs/train_all-MiniLM-L12-v3-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
# training parameters
max_length_query = 30
max_length_passage = 200
# 3 special tokens for [CLS] query [SEP] passage [SEP], plus 3 tokens reserved for text
# injected into the input (in the RelevanceCAT setup a BM25 score normally takes two
# tokens, plus one extra [SEP]).
model_max_length = max_length_query + max_length_passage + 3 + 3

train_batch_size = 32
accumulation_steps = 1
evaluation_steps = 11000
print("model_name {} | model_max_length {} | batch_size {} | accumulation_steps {} ".format(model_name, model_max_length, train_batch_size, accumulation_steps))

INFO:root:Download queries.tar.gz
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): msmarco.blob.core.windows.net:443
DEBUG:urllib3.connectionpool:https://msmarco.blob.core.windows.net:443 "GET /msmarcoranking/queries.tar.gz HTTP/1.1" 200 18882551


  0%|          | 0.00/18.9M [00:00<?, ?B/s]

INFO:root:Download collection.tar.gz
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): msmarco.blob.core.windows.net:443
DEBUG:urllib3.connectionpool:https://msmarco.blob.core.windows.net:443 "GET /msmarcoranking/collection.tar.gz HTTP/1.1" 200 1035009698


  0%|          | 0.00/1.04G [00:00<?, ?B/s]

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): zenodo.org:443
DEBUG:urllib3.connectionpool:https://zenodo.org:443 "GET /record/4068216/files/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv?download=1 HTTP/1.1" 301 339
DEBUG:urllib3.connectionpool:https://zenodo.org:443 "GET /records/4068216/files/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv HTTP/1.1" 200 2389143631


  0%|          | 0.00/2.39G [00:00<?, ?B/s]

INFO:root:Download msmarco-qidpidtriples.rnd-shuf.train-eval.tsv
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): sbert.net:443
DEBUG:urllib3.connectionpool:https://sbert.net:443 "GET /datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz HTTP/1.1" 301 None
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): public.ukp.informatik.tu-darmstadt.de:443
DEBUG:urllib3.connectionpool:https://public.ukp.informatik.tu-darmstadt.de:443 "GET /reimers/sentence-transformers/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz HTTP/1.1" 200 2313734


  0%|          | 0.00/2.31M [00:00<?, ?B/s]

model_name microsoft/MiniLM-L12-H384-uncased | model_max_length 236 | batch_size 32 | accumulation_steps 1 


In [None]:
# almost all of this is taken from https://github.com/arian-askari/RelevanceCAT/blob/main/train/train_ms-marco-MiniLM-L-12_v3_bm25.py

"""# CrossEncoder Class

"""

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import numpy as np
import logging
import os
from typing import Dict, Type, Callable, List
import transformers
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.evaluation import SentenceEvaluator
class CrossEncoder():
    def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, tokenizer_args:Dict = None,
                 default_activation_function = None):
        """
        A CrossEncoder takes exactly two sentences / texts as input and either predicts
        a score or label for this sentence pair. It can for example predict the similarity of the sentence pair
        on a scale of 0 ... 1.

        It does not yield a sentence embedding and does not work for individual sentences.

        :param model_name: Any model name from Huggingface Models Repository that can be loaded with AutoModel. We provide several pre-trained CrossEncoder models that can be used for common tasks
        :param num_labels: Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continuous score 0...1. If > 1, it outputs several scores that can be soft-maxed to get probability scores for the different classes.
        :param max_length: Max length for input sequences. Longer sequences will be truncated. If None, max length of the model will be used
        :param device: Device that should be used for the model. If None, it will use CUDA if available.
        :param tokenizer_args: Keyword arguments passed to AutoTokenizer.from_pretrained. None means no extra arguments.
        :param default_activation_function: Callable (like nn.Sigmoid) about the default activation function that should be used on-top of model.predict(). If None. nn.Sigmoid() will be used if num_labels=1, else nn.Identity()
        """
        # None instead of a shared mutable {} default.
        if tokenizer_args is None:
            tokenizer_args = {}

        self.config = AutoConfig.from_pretrained(model_name)
        classifier_trained = True
        if self.config.architectures is not None:
            classifier_trained = any(arch.endswith('ForSequenceClassification') for arch in self.config.architectures)

        # A checkpoint without a trained classification head defaults to a single regression output.
        if num_labels is None and not classifier_trained:
            num_labels = 1

        if num_labels is not None:
            self.config.num_labels = num_labels

        # ignore_mismatched_sizes=True for transfer learning: a checkpoint whose classification head has a
        # different number of labels (e.g. post-training first, then binary classification) can still be loaded.
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes = True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
        self.max_length = max_length

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Use pytorch device: {}".format(device))

        self._target_device = torch.device(device)

        if default_activation_function is not None:
            self.default_activation_function = default_activation_function
            try:
                # Record the activation in the config so it is stored together with the model.
                self.config.sbert_ce_default_activation_function = util.fullname(self.default_activation_function)
            except Exception as e:
                logger.warning("Was not able to update config about the default_activation_function: {}".format(str(e)) )
        elif hasattr(self.config, 'sbert_ce_default_activation_function') and self.config.sbert_ce_default_activation_function is not None:
            self.default_activation_function = util.import_from_string(self.config.sbert_ce_default_activation_function)()
        else:
            self.default_activation_function = nn.Sigmoid() if self.config.num_labels == 1 else nn.Identity()

    def smart_batching_collate(self, batch):
        """Collate InputExamples into tokenized tensors and a label tensor, both on the target device."""
        texts = [[] for _ in range(len(batch[0].texts))]
        labels = []

        for example in batch:
            for idx, text in enumerate(example.texts):
                texts[idx].append(text.strip())

            labels.append(example.label)

        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)
        # float labels for regression (num_labels == 1), class indices otherwise.
        labels = torch.tensor(labels, dtype=torch.float if self.config.num_labels == 1 else torch.long).to(self._target_device)

        for name in tokenized:
            tokenized[name] = tokenized[name].to(self._target_device)

        return tokenized, labels

    def smart_batching_collate_text_only(self, batch):
        """Collate a batch of text pairs (no labels) into tokenized tensors on the target device."""
        texts = [[] for _ in range(len(batch[0]))]

        for example in batch:
            for idx, text in enumerate(example):
                texts[idx].append(text.strip())

        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)

        for name in tokenized:
            tokenized[name] = tokenized[name].to(self._target_device)

        return tokenized

    def fit(self,
            train_dataloader: DataLoader,
            evaluator: SentenceEvaluator = None,
            epochs: int = 1,
            loss_fct = None,
            activation_fct = nn.Identity(),
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            accumulation_steps: int = 1,
            optimizer_class: Type[Optimizer] = torch.optim.AdamW,
            optimizer_params: Dict[str, object] = None,
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1,
            use_amp: bool = False,
            callback: Callable[[float, int, int], None] = None,
            ):
        """
        Train the model with the given training objective.

        :param train_dataloader: DataLoader with training InputExamples
        :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
        :param epochs: Number of epochs for training
        :param loss_fct: Which loss function to use for training. If None, will use nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()
        :param activation_fct: Activation function applied on top of logits output of model.
        :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate. After these many optimizer steps, the learning rate is decreased linearly back to zero.
        :param accumulation_steps: Number of batches whose gradients are accumulated before each optimizer (and scheduler) step
        :param optimizer_class: Optimizer class (default torch.optim.AdamW; transformers.AdamW is deprecated/removed)
        :param optimizer_params: Optimizer parameters. None means {'lr': 2e-5}.
        :param weight_decay: Weight decay for model parameters (not applied to biases and LayerNorm weights)
        :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training batches
        :param output_path: Storage path for the model and evaluation files
        :param save_best_model: If true, the best model (according to evaluator) is stored at output_path
        :param max_grad_norm: Used for gradient normalization.
        :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
        :param callback: Callback function that is invoked after each evaluation.
                It must accept the following three parameters in this order:
                `score`, `epoch`, `steps`
        """
        train_dataloader.collate_fn = self.smart_batching_collate

        # None instead of a shared mutable dict default.
        if optimizer_params is None:
            optimizer_params = {'lr': 2e-5}

        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()

        self.model.to(self._target_device)

        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)

        self.best_score = float('-inf')
        # The scheduler is stepped once per optimizer update (every `accumulation_steps`
        # batches), so its horizon must be counted in updates, not in batches.
        num_train_steps = int((len(train_dataloader) // accumulation_steps) * epochs)

        # Prepare optimizers
        param_optimizer = list(self.model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)

        if isinstance(scheduler, str):
            scheduler = SentenceTransformer._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)

        if loss_fct is None:
            loss_fct = nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()

        skip_scheduler = False
        for epoch in tqdm.trange(epochs, desc="Epoch"):
            training_steps = 0
            self.model.zero_grad()
            self.model.train()
            # The loop runs once per batch, so the progress total is the number of batches.
            for i, (features, labels) in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc="Iteration", smoothing=0.05):
                if use_amp:
                    with autocast():
                        model_predictions = self.model(**features, return_dict=True)
                        logits = activation_fct(model_predictions.logits)
                        if self.config.num_labels == 1:
                            logits = logits.view(-1)
                        loss_value = loss_fct(logits, labels)
                        loss_value /= accumulation_steps

                    scale_before_step = scaler.get_scale()
                    scaler.scale(loss_value).backward()
                    if (i + 1) % accumulation_steps == 0:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                    # A changed scale means the optimizer step was skipped (inf/NaN grads): skip the scheduler too.
                    skip_scheduler = scaler.get_scale() != scale_before_step
                else:
                    model_predictions = self.model(**features, return_dict=True)
                    logits = activation_fct(model_predictions.logits)
                    if self.config.num_labels == 1:
                        logits = logits.view(-1)
                    loss_value = loss_fct(logits, labels)
                    loss_value /= accumulation_steps
                    loss_value.backward()
                    if (i + 1) % accumulation_steps == 0:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                        optimizer.step()
                        optimizer.zero_grad()

                if not skip_scheduler and (i + 1) % accumulation_steps == 0:
                    scheduler.step()

                training_steps += 1

                if evaluator is not None and evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps, callback)

                    # Evaluation runs predict() under torch.no_grad(), so no gradients were produced;
                    # do not zero_grad here, which would discard a partially accumulated update.
                    self.model.train()

            if evaluator is not None:
                self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)

    def predict(self, sentences: List[List[str]],
               batch_size: int = 32,
               show_progress_bar: bool = None,
               num_workers: int = 0,
               activation_fct = None,
               apply_softmax = False,
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False
               ):
        """
        Performs predicts with the CrossEncoder on the given sentence pairs.

        :param sentences: A list of sentence pairs [[Sent1, Sent2], [Sent3, Sent4]], or a single pair [Sent1, Sent2]
        :param batch_size: Batch size for encoding
        :param show_progress_bar: Output progress bar. If None, shown when the logger level is INFO or DEBUG
        :param num_workers: Number of workers for tokenization
        :param activation_fct: Activation function applied on the logits output of the CrossEncoder. If None, the default activation function is used
        :param apply_softmax: If the model outputs more than one score per pair and apply_softmax=True, applies softmax over those scores
        :param convert_to_numpy: Convert the output to a numpy array.
        :param convert_to_tensor: Convert the output to a tensor (takes precedence over convert_to_numpy).
        :return: Predictions for the passed sentence pairs (a single prediction if a single pair was passed)
        """
        if len(sentences) == 0:
            # Nothing to score; return an empty container of the requested kind.
            if convert_to_tensor:
                return torch.empty(0)
            return np.empty(0, dtype=np.float32) if convert_to_numpy else []

        input_was_string = False
        if isinstance(sentences[0], str):  # Cast an individual sentence pair to a list with length 1
            sentences = [sentences]
            input_was_string = True

        inp_dataloader = DataLoader(sentences, batch_size=batch_size, collate_fn=self.smart_batching_collate_text_only, num_workers=num_workers, shuffle=False)

        if show_progress_bar is None:
            show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)

        iterator = inp_dataloader
        if show_progress_bar:
            iterator = tqdm.tqdm(inp_dataloader, desc="Batches")

        if activation_fct is None:
            activation_fct = self.default_activation_function

        pred_scores = []
        self.model.eval()
        self.model.to(self._target_device)
        with torch.no_grad():
            for features in iterator:
                model_predictions = self.model(**features, return_dict=True)
                logits = activation_fct(model_predictions.logits)

                if apply_softmax and len(logits[0]) > 1:
                    logits = torch.nn.functional.softmax(logits, dim=1)
                pred_scores.extend(logits)

        if self.config.num_labels == 1:
            pred_scores = [score[0] for score in pred_scores]

        if convert_to_tensor:
            pred_scores = torch.stack(pred_scores)
        elif convert_to_numpy:
            # One device->host copy for the whole result instead of one per score.
            pred_scores = torch.stack(pred_scores).detach().cpu().numpy()

        if input_was_string:
            pred_scores = pred_scores[0]

        return pred_scores

    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback):
        """Runs evaluation during the training and saves the model when the score improves."""
        if evaluator is not None:
            score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
            if callback is not None:
                callback(score, epoch, steps)
            if score > self.best_score:
                self.best_score = score
                if save_best_model:
                    self.save(output_path)

    def save(self, path):
        """
        Saves the model and tokenizer to path (no-op when path is None).
        """
        if path is None:
            return

        logger.info("Save model to {}".format(path))
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def save_pretrained(self, path):
        """
        Same function as save
        """
        return self.save(path)


In [None]:
# Unchanged taken from https://github.com/arian-askari/RelevanceCAT/blob/main/evaluation/eval_trec19-MiniLM-L-12-v3.py

"""# Evaluator Class"""

import numpy as np
import os
import csv
import pytrec_eval
import tqdm
from sentence_transformers import LoggingHandler, util
class CERerankingEvaluator:
    """
    Evaluates a CrossEncoder model on re-ranking.

    For every query, the model scores all [query, doc] pairs of the query's positive and
    negative documents. The resulting run is scored with pytrec_eval against binary qrels
    built from the positive document ids, and each metric is averaged over the queries.

    :param samples: list (or dict, whose values are used) of dicts with keys 'qid', 'query',
        'positive', 'negative', 'positive_ids', 'negative_ids'. The *_ids lists are expected
        to be aligned with the corresponding text lists.
    :param all_metrics: pytrec_eval measure names, e.g. {"ndcg_cut.10", "recall.10"}.
    :param name: dataset name used in log messages and in the CSV file name.
    :param write_csv: append one row of per-metric means per evaluation to a CSV in output_path.
    :param show_progress_bar: forwarded to model.predict.
    """
    def __init__(self, samples, all_metrics: set = {"recall.1"}, name: str = '', write_csv: bool = True, show_progress_bar: bool = False):
        self.samples = samples
        self.name = name
        # Copy so neither the shared default set nor the caller's set is aliased.
        self.all_metrics = set(all_metrics)
        # Fixed, sorted column order: iteration order of a set of strings changes between
        # interpreter runs (hash randomization), which made the CSV columns unstable.
        self.metric_names = sorted(self.all_metrics)

        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps"] + self.metric_names
        self.write_csv = write_csv
        self.mean_metrics = {}
        self.show_progress_bar = show_progress_bar

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Evaluate `model` and return the sum of the mean metrics (0.0 if evaluation failed)."""
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

        # Reset first so a failed evaluation reports 0 instead of the previous call's metrics.
        self.mean_metrics = {}
        num_queries = 0
        num_positives = []
        num_negatives = []
        run = {}
        qrel = {}
        print("len: self.samples: " + str(len(self.samples)))
        try:
            for instance in tqdm.tqdm(self.samples):
                qid = instance['qid']
                query = instance['query']
                positive = list(instance['positive'])
                negative = list(instance['negative'])
                positive_pids = list(instance['positive_ids'])
                negative_pids = list(instance['negative_ids'])

                # Re-ranking needs both relevant and irrelevant candidates; such queries are
                # left out of qrel/run entirely so they cannot skew the averaged metrics.
                if len(positive) == 0 or len(negative) == 0:
                    continue

                num_queries += 1
                num_positives.append(len(positive))
                num_negatives.append(len(negative))

                qrel[qid] = {pid: 1 for pid in positive_pids}

                docs = negative + positive
                docs_ids = negative_pids + positive_pids
                model_input = [[query, doc] for doc in docs]
                if model.config.num_labels > 1:  # Cross-Encoder that predicts more than 1 score: softmax, then use the score of label 1
                    pred_scores = model.predict(model_input, apply_softmax=True, batch_size=16, show_progress_bar=self.show_progress_bar)[:, 1].tolist()
                else:
                    pred_scores = model.predict(model_input, batch_size=16, show_progress_bar=self.show_progress_bar).tolist()
                run[qid] = {did: float(score) for score, did in zip(pred_scores, docs_ids)}

            if num_queries == 0:
                logger.warning("CERerankingEvaluator: no query has both positive and negative documents; nothing to evaluate")
            else:
                trec_evaluator = pytrec_eval.RelevanceEvaluator(qrel, self.all_metrics)
                scores = trec_evaluator.evaluate(run)
                metrics_string = ""
                for metric in self.metric_names:
                    # pytrec_eval reports e.g. "ndcg_cut.10" under the key "ndcg_cut_10".
                    self.mean_metrics[metric] = np.mean([ele[metric.replace(".", "_")] for ele in scores.values()])
                    metrics_string += "{}: {} | ".format(metric, self.mean_metrics[metric])
                print("metrics eval: ", metrics_string)
                logger.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
        except Exception as e:
            # Best-effort: a failing evaluation must not abort training, but keep the traceback.
            logger.exception("error: %s", e)

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # Early stopping can be built on this file: read it back, find the best step, and stop
            # once too many evaluations have passed without improvement.
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8", newline="") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)
                # One value per metric column, in header order.
                writer.writerow([epoch, steps] + [self.mean_metrics.get(metric, "") for metric in self.metric_names])
        return float(sum(self.mean_metrics.values()))

In [None]:
"""#Data

## utils

### read collections
"""

def smart_truncate(content, length):
    """Shorten `content` to at most `length` characters without splitting a word.

    Text that already fits is returned unchanged. Otherwise the first `length`
    characters are kept and everything after their last space is dropped. That
    dropped part may be a complete word if the cut landed on a word boundary. If
    the kept prefix contains no space, the raw prefix is returned.
    """
    if len(content) > length:
        prefix = content[:length]
        return prefix.rsplit(' ', 1)[0]
    return content

def read_collection(f_path, truncate_length):
  """Read a tab-separated ``id<TAB>text`` file into a dict mapping id -> text.

  Each text is truncated to at most `truncate_length` characters on a word boundary
  (via smart_truncate) and wrapped in literal "[CLS] " / " [SEP]" markers.

  :param f_path: path of the TSV file (e.g. MS MARCO queries or passage collection)
  :param truncate_length: maximum number of characters kept from each text
  :return: dict {id (str): wrapped, truncated text (str)}
  """
  corpus = {}
  # MS MARCO files are UTF-8; do not depend on the platform's default encoding.
  with open(f_path, "r", encoding="utf-8") as fp:
    for line in tqdm.tqdm(fp, desc="reading {}".format(f_path)):
      line = line.strip()
      if not line:
        continue  # tolerate blank lines (e.g. a trailing empty line)
      # partition on the first tab only, so a tab inside the text cannot break unpacking
      did, _, dtext = line.partition("\t")
      # NOTE(review): the tokenizer adds its own [CLS]/[SEP] special tokens as well, so these
      # literal markers appear twice in the model input — confirm this is intended.
      corpus[did] = "[CLS] " + smart_truncate(dtext, truncate_length) + " [SEP]"
  return corpus

def read_train_triples_logits_and_score_injected(f_path, queries, corpus, max_instances = 0 , pos_neg_ratio = 4, dev_qids = None, max_lines = 500000):
  """Read scored training triples into InputExamples.

  Each line of `f_path` is ``pos_score<TAB>neg_score<TAB>qid<TAB>pos_pid<TAB>neg_pid`` and
  yields two examples: (query, positive passage) labelled pos_score and
  (query, negative passage) labelled neg_score.
  NOTE(review): the labels are the raw scores from the file, presumably teacher logits.
  They are probably not in [0, 1], so pair them with a regression-style loss, not the
  BCEWithLogitsLoss default of CrossEncoder.fit. Verify.

  :param queries: dict qid -> query text
  :param corpus: dict pid -> passage text
  :param max_instances: stop once at least this many examples were collected; 0 means no cap.
  :param pos_neg_ratio: currently unused; kept for interface compatibility.
  :param dev_qids: query ids to skip (held out for validation); anything supporting `in`.
  :param max_lines: stop after this many lines of `f_path` (skipped lines included); 0 means
    no cap. Defaults to 500000, the cap this reader has always applied.
  :return: list of InputExample
  """
  if dev_qids is None:
    dev_qids = set()
  samples = []
  with open(f_path, encoding="utf-8") as fIn:
    for count, line in enumerate(tqdm.tqdm(fIn, desc="reading {}".format(f_path)), start=1):
      if max_lines and count > max_lines:
        break
      if max_instances and len(samples) >= max_instances:
        break
      pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
      if qid in dev_qids: # Skip queries in our dev dataset. The scores file contains the train set as well so we should skip the dev queries.
        continue
      query = queries[qid]
      samples.append(InputExample(texts=[query, corpus[pid1]], label=float(pos_score)))
      samples.append(InputExample(texts=[query, corpus[pid2]], label=float(neg_score)))
  return samples

def read_triples_validation(f_path, queries, corpus, max_queries = 500, max_negatives_per_query = 200): # 200 negatives per query and 500 queries are good enough for evaluating during training!
  """
  Build re-ranking validation samples from a ``qid<TAB>pos_pid<TAB>neg_pid`` triples file.

  :param queries: dict qid -> query text
  :param corpus: dict pid -> passage text
  :param max_queries: keep only the first `max_queries` distinct query ids encountered; 0 means all.
  :param max_negatives_per_query: keep at most this many negatives per query (duplicates are not
    removed); 0 means all.
  :return: dict qid -> {'qid', 'query', 'positive', 'negative', 'positive_ids', 'negative_ids'},
    where each *_ids list is aligned with the corresponding passage-text list.
  """
  samples = {}
  with open(f_path, "r", encoding="utf-8") as fp:
    for line in tqdm.tqdm(fp, desc="reading {}".format(f_path)):
      qid, pos_id, neg_id = line.strip().split("\t")
      if qid not in samples:
        if max_queries != 0 and len(samples) >= max_queries:
          continue  # query budget exhausted: ignore lines of unselected queries
        # Query text is only looked up for queries that are actually kept.
        samples[qid] = {'qid': qid, 'query': queries[qid], 'positive': list(), 'negative': list(), "positive_ids": list(), "negative_ids": list()}
      sample = samples[qid]
      # Positives repeat across the triples of a query, so store each one once.
      if pos_id not in sample['positive_ids']:
        sample['positive'].append(corpus[pos_id])
        sample['positive_ids'].append(pos_id)
      if max_negatives_per_query == 0 or len(sample['negative']) < max_negatives_per_query:
        sample['negative'].append(corpus[neg_id])
        sample['negative_ids'].append(neg_id)
  return samples

In [None]:
"""## Reading data

### reading corpus and queries and truncate it
"""

# Maximum number of characters kept per query / passage (cut on a word boundary).
query_max_chars = 150
passage_max_chars = 1000

queries = read_collection(queries_path, query_max_chars)
corpus = read_collection(corpus_path, passage_max_chars)

dev_samples = read_triples_validation(
    triples_validation_path, queries, corpus,
    max_queries=valid_max_queries,
    max_negatives_per_query=valid_max_negatives_per_query,
)

# Validation query ids are excluded from the training examples.
train_samples = read_train_triples_logits_and_score_injected(
    triples_train_path, queries, corpus,
    max_instances=max_train_samples,
    pos_neg_ratio=pos_neg_ratio,
    dev_qids=dev_samples.keys(),
)

reading msmarco-data/queries.train.tsv: 808731it [00:01, 533067.00it/s]
reading msmarco-data/collection.tsv: 8841823it [00:33, 266559.66it/s]
reading msmarco-data/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv: 542646it [00:02, 205804.54it/s]
reading msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv?download=1: 500000it [00:09, 51154.08it/s]


In [None]:
# The id -> text dicts are not needed once the samples are built; drop them to release memory.
del queries, corpus

In [None]:
# Optional: uncomment to drop a previously created `model` and release PyTorch's cached
# GPU memory before re-running the training cell.
# del model
# torch.cuda.empty_cache()

In [None]:
# train_batch_size and accumulation_steps come from the configuration cell; they are not
# redefined here, so that cell stays the single place to change them.

"""# Training"""
# We set num_labels=1 and the default activation function to Identity, so that predict()
# returns the raw logits (without it, a single-label CrossEncoder applies a Sigmoid).
model = CrossEncoder(model_name, num_labels=1, max_length=model_max_length,
                     default_activation_function=nn.Identity())
# Setting config.gradient_checkpointing after the model is built has no effect; disable it on the model itself.
model.model.gradient_checkpointing_disable()

"""## Fit"""
# We create a DataLoader to load our train samples
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size, drop_last=True)


# We add an evaluator, which evaluates the performance during training.
# It re-ranks each validation query's candidates and measures "ndcg_cut.10", "map_cut.1000" and "recall.10";
# their sum is the score used to select the best-performing checkpoint on the validation set.
evaluator = CERerankingEvaluator(dev_samples, name='train-eval',  all_metrics={"ndcg_cut.10", "map_cut.1000", "recall.10"})#https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/cross_encoder/evaluation/CERerankingEvaluator.py

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 132960458371904 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/d17c3208194971b8d2bb226e3ebd07a2aa477d48.lock
DEBUG:filelock:Lock 132960458371904 acquired on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/d17c3208194971b8d2bb226e3ebd07a2aa477d48.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /microsoft/MiniLM-L12-H384-uncased/resolve/main/config.json HTTP/1.1" 200 385


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 132960458371904 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/d17c3208194971b8d2bb226e3ebd07a2aa477d48.lock
DEBUG:filelock:Lock 132960458371904 released on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/d17c3208194971b8d2bb226e3ebd07a2aa477d48.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/model.safetensors HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/model.safetensors.index.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 132960360818352 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/bcc275fe3e183a68e629c55355be77ef51c2ee780990173ca8f8edc651

pytorch_model.bin:   0%|          | 0.00/133M [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 132960360818352 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/bcc275fe3e183a68e629c55355be77ef51c2ee780990173ca8f8edc65152fae7.lock
DEBUG:filelock:Lock 132960360818352 released on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/bcc275fe3e183a68e629c55355be77ef51c2ee780990173ca8f8edc65152fae7.lock
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 132960360816720 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-unca

tokenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 132960360816720 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/9e26dfeeb6e641a33dae4961196235bdb965b21b.lock
DEBUG:filelock:Lock 132960360816720 released on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/9e26dfeeb6e641a33dae4961196235bdb965b21b.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 132960360843200 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:filelock:Lock 132960360843200 acquired on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /microsoft/MiniLM-L12-H384-uncased/resolve/main/vocab.txt HTTP/1.1" 200 231508


vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 132960360843200 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:filelock:Lock 132960360843200 released on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/tokenizer.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/added_tokens.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/MiniLM-L12-H384-uncased/resolve/main/special_tokens_map.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 132956954232864 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/e7b0375001f109a6b8873d756ad4f7bbb15fbaa5.lock
DEBUG:filelock:Lock 

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 132956954232864 on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/e7b0375001f109a6b8873d756ad4f7bbb15fbaa5.lock
DEBUG:filelock:Lock 132956954232864 released on /root/.cache/huggingface/hub/.locks/models--microsoft--MiniLM-L12-H384-uncased/e7b0375001f109a6b8873d756ad4f7bbb15fbaa5.lock
INFO:root:Use pytorch device: cuda


In [None]:
# Configure the training
warmup_steps = 1000  # number of learning-rate warm-up steps passed to fit
logging.info("Warmup-steps: %s", warmup_steps)

# Train the model. The effective batch size is
# train_batch_size * accumulation_steps (accumulation_steps is 1 here).
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    evaluation_steps=evaluation_steps,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    # Regress the model's raw scores onto the training labels (presumably the
    # teacher scores read from the BERT-CAT ensemble training file).
    loss_fct=torch.nn.MSELoss(),
    # Learning rate 7e-6 as used by Hofstätter et al.
    optimizer_params={'lr': 7e-6},
    # Decoupled weight decay; 0.01 is the usual AdamW setting (0 would be plain Adam).
    weight_decay=0.01,
    accumulation_steps=accumulation_steps,
    use_amp=True,  # mixed-precision training
    output_path=model_save_path,
)
# Additionally save the weights as they are at the end of training.
model.save(model_save_path + '-latest')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Iteration:  19%|█▉        | 6001/31209 [22:36<1:31:00,  4.62it/s][A
Iteration:  19%|█▉        | 6002/31209 [22:37<1:30:52,  4.62it/s][A
Iteration:  19%|█▉        | 6003/31209 [22:37<1:30:54,  4.62it/s][A
Iteration:  19%|█▉        | 6004/31209 [22:37<1:31:02,  4.61it/s][A
Iteration:  19%|█▉        | 6005/31209 [22:37<1:30:10,  4.66it/s][A
Iteration:  19%|█▉        | 6006/31209 [22:38<1:32:04,  4.56it/s][A
Iteration:  19%|█▉        | 6007/31209 [22:38<1:32:47,  4.53it/s][A
Iteration:  19%|█▉        | 6008/31209 [22:38<1:32:05,  4.56it/s][A
Iteration:  19%|█▉        | 6009/31209 [22:38<1:31:04,  4.61it/s][A
Iteration:  19%|█▉        | 6010/31209 [22:38<1:30:39,  4.63it/s][A
Iteration:  19%|█▉        | 6011/31209 [22:39<1:30:43,  4.63it/s][A
Iteration:  19%|█▉        | 6012/31209 [22:39<1:30:41,  4.63it/s][A
Iteration:  19%|█▉        | 6013/31209 [22:39<1:30:30,  4.64it/s][A
Iteration:  19%|█▉        | 6014/31209

len: self.samples: 500




  0%|          | 0/500 [00:00<?, ?it/s][A[A

  0%|          | 1/500 [00:01<13:44,  1.65s/it][A[A

  0%|          | 2/500 [00:02<11:19,  1.37s/it][A[A

  1%|          | 3/500 [00:04<11:09,  1.35s/it][A[A

  1%|          | 4/500 [00:05<11:10,  1.35s/it][A[A

  1%|          | 5/500 [00:06<10:50,  1.31s/it][A[A

  1%|          | 6/500 [00:07<10:27,  1.27s/it][A[A

  1%|▏         | 7/500 [00:09<10:32,  1.28s/it][A[A

  2%|▏         | 8/500 [00:10<10:41,  1.30s/it][A[A

  2%|▏         | 9/500 [00:11<10:13,  1.25s/it][A[A

  2%|▏         | 10/500 [00:12<09:50,  1.21s/it][A[A

  2%|▏         | 11/500 [00:14<10:20,  1.27s/it][A[A

  2%|▏         | 12/500 [00:15<10:51,  1.34s/it][A[A

  3%|▎         | 13/500 [00:17<11:01,  1.36s/it][A[A

  3%|▎         | 14/500 [00:18<10:56,  1.35s/it][A[A

  3%|▎         | 15/500 [00:20<11:20,  1.40s/it][A[A

  3%|▎         | 16/500 [00:21<11:41,  1.45s/it][A[A

  3%|▎         | 17/500 [00:22<10:39,  1.32s/it][A[A

  4%|▎  