# Training

## Installation

In [None]:
!pip install sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 KB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m59.9 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m53.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1

In [None]:
!pip install pytrec_eval

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytrec_eval
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp39-cp39-linux_x86_64.whl size=293193 sha256=8d3f8d0179a601ae0f5c68526f93b5514b331973b73ee3a1fae4d5d006402e20
  Stored in directory: /root/.cache/pip/wheels/e9/91/35/6059501bca98e27e0b4f91ecaaff86c95ca7f4919ff22f0d54
Successfully built pytrec_eval
Installing collected packages: pytrec_eval
Successfully installed pytrec_eval-0.5


## Imports

In [None]:
"""
This example shows how to train a Cross-Encoder for the MS Marco dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).

The query and the passage are passed simultaneously to a Transformer network. The network then returns
a score between 0 and 1 indicating how relevant the passage is for a given query.

The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You sort the results then according to the output of the CrossEncoder.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.
"""
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
from sentence_transformers import InputExample
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
import logging
from collections import defaultdict
import numpy as np
import sys
import pytrec_eval
# Configure the root logger so that progress messages are printed with a timestamp.
# force=True is required because notebook environments such as Colab may already have a
# handler installed on the root logger, in which case a plain basicConfig() call is
# silently ignored (the recorded outputs of this notebook show the default
# "LEVEL:name:message" layout instead of the format requested here).
# The level is INFO rather than DEBUG: the download/training/evaluation messages are
# logged at INFO, while DEBUG additionally prints one line per HTTP request and file lock.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    force=True)
# Handle on the root logger, kept under the same name as before.
logger = logging.getLogger()

## Training preparation

### Initialize hyperparameters (e.g., batch size, etc)

#### To avoid losing the trained model if you get disconnected from Google Colab, we suggest storing the trained model on your Google Drive. Below we do that by mounting the drive via `google.colab` and setting the output path.


In [None]:
# Mount Google Drive so that everything written below base_path is stored on the drive.
from google.colab import drive

drive_mount_point = '/content/gdrive'
drive.mount(drive_mount_point)

# Derive base_path from the (absolute) mount point instead of the relative './gdrive/...',
# so the location no longer depends on the notebook's current working directory.
# The trailing slash is kept on purpose: the save path of the model is built from
# base_path by plain string concatenation.
base_path = drive_mount_point + "/MyDrive/cross-encoder-reranker-ir-course-2023/"

Mounted at /content/gdrive


In [None]:
# Create the output folder (including missing parents; no error if it already exists).
# os.makedirs is used instead of `!mkdir -p $base_path` because the unquoted shell
# interpolation splits the path at spaces, whereas makedirs takes the string verbatim.
os.makedirs(base_path, exist_ok=True)

In [None]:
# Hyperparameters of the fine-tuning run

# Upper bound on the number of training samples we collect (alternative value: 2e7)
max_train_samples = 5e6

# The network is trained on a binary labelling task:
# given [query, passage], is the label 0 = irrelevant or 1 = relevant?
# Positive-to-negative ratio: for every positive sample (label 1) we include
# 4 negative samples (label 0). The negatives come from the triplets provided by
# MS Marco, which specify (query, positive sample, negative sample).
pos_neg_ration = 4

# Number of passes over the training samples and samples per batch
num_epochs = 1
train_batch_size = 32

## Load model (cross-encoder/ms-marco-MiniLM-L-2-v2)

In [None]:
# Checkpoint to fine-tune (alternatives are kept below for reference)
model_name = 'cross-encoder/ms-marco-MiniLM-L-2-v2'
# model_name = 'cross-encoder/ms-marco-TinyBERT-L-2-v2'
# model_name = 'distilroberta-base'

# We set num_labels=1, which predicts a continuous score between 0 and 1
model = CrossEncoder(model_name, num_labels=1, max_length=512)

# Output location: <base_path>finetuned_models/cross-encoder-<model name, "/" -> "-">-<timestamp>
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = f"{base_path}finetuned_models/cross-encoder-{model_name.replace('/', '-')}-{run_timestamp}"

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140060639496704 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/080681a8d63930d920d45b6763dc48090f080f79.lock
DEBUG:filelock:Lock 140060639496704 acquired on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/080681a8d63930d920d45b6763dc48090f080f79.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/config.json HTTP/1.1" 200 794


Downloading (…)lve/main/config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140060639496704 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/080681a8d63930d920d45b6763dc48090f080f79.lock
DEBUG:filelock:Lock 140060639496704 released on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/080681a8d63930d920d45b6763dc48090f080f79.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 140056901734896 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/e92ec9f854a5d8651f86db03375c55f4e4f893b7518177d4f2c8e31e3b9013a1.lock
DEBUG:filelock:Lock 140056901734896 acquired on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/e92ec9f854a5d8651f86db03375c55f4e4f893b7518177d4f2c8e31e3b9013a1.l

Downloading pytorch_model.bin:   0%|          | 0.00/62.5M [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140056901734896 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/e92ec9f854a5d8651f86db03375c55f4e4f893b7518177d4f2c8e31e3b9013a1.lock
DEBUG:filelock:Lock 140056901734896 released on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/e92ec9f854a5d8651f86db03375c55f4e4f893b7518177d4f2c8e31e3b9013a1.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140056846630288 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/2fd98132fd4620f90908272d9a9e6b2626e83491.lock
DEBUG:filelock:Lock 140056846630288 acquired on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/2fd98132fd4620f90908272d9a9e6b2626e834

Downloading (…)okenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140056846630288 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/2fd98132fd4620f90908272d9a9e6b2626e83491.lock
DEBUG:filelock:Lock 140056846630288 released on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/2fd98132fd4620f90908272d9a9e6b2626e83491.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/vocab.txt HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140056858295072 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:filelock:Lock 140056858295072 acquired on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS conn

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140056858295072 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:filelock:Lock 140056858295072 released on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/fb140275c155a9c7c5a3b3e0e77a9e839594a938.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/tokenizer.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-v2/resolve/main/added_tokens.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /cross-encoder/ms-marco-MiniLM-L-2-

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

DEBUG:filelock:Attempting to release lock 140056857437952 on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/e7b0375001f109a6b8873d756ad4f7bbb15fbaa5.lock
DEBUG:filelock:Lock 140056857437952 released on /root/.cache/huggingface/hub/models--cross-encoder--ms-marco-MiniLM-L-2-v2/blobs/e7b0375001f109a6b8873d756ad4f7bbb15fbaa5.lock
INFO:sentence_transformers.cross_encoder.CrossEncoder:Use pytorch device: cuda


## Download MSMARCO data + BM25 initial ranking run file

In [None]:
### Local folder that holds all MS Marco files
data_folder = 'msmarco-data'
os.makedirs(data_folder, exist_ok=True)


#### Passages: fill the corpus dict (passage id -> passage text) from collection.tsv
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    # collection.tsv is missing: fetch the archive unless it is already on disk, then unpack it
    collection_archive = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(collection_archive):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', collection_archive)

    with tarfile.open(collection_archive, mode="r:gz") as archive:
        archive.extractall(data_folder)

# Each row of collection.tsv is "<passage id>\t<passage text>"
with open(collection_filepath, encoding='utf8') as collection_file:
    for row in collection_file:
        passage_id, passage_text = row.strip().split("\t")
        corpus[passage_id] = passage_text


### Train queries: fill the queries dict (query id -> query text) from queries.train.tsv
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    # queries.train.tsv is missing: fetch the archive unless it is already on disk, then unpack it
    queries_archive = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(queries_archive):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', queries_archive)

    with tarfile.open(queries_archive, mode="r:gz") as archive:
        archive.extractall(data_folder)


# Each row of queries.train.tsv is "<query id>\t<query text>"
with open(queries_filepath, encoding='utf8') as queries_file:
    for row in queries_file:
        query_id, query_text = row.strip().split("\t")
        queries[query_id] = query_text



### Containers for the training & dev data
train_samples = []
dev_samples = {}

# Evaluation during training uses the first 200 distinct queries of the (shuffled) train-eval file.
# Each of them gets at least one relevant and up to 200 irrelevant (negative) passages.
num_dev_queries = 200
num_max_dev_negatives = 200

# msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz and msmarco-qidpidtriples.rnd-shuf.train.tsv.gz are
# randomly shuffled versions of qidpidtriples.train.full.2.tsv.gz from the MS Marco website.
# The train-eval split holds 500 random queries that can be used for evaluation during training.
train_eval_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz')
if not os.path.exists(train_eval_filepath):
    logging.info(f"Download {os.path.basename(train_eval_filepath)}")
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz', train_eval_filepath)

with gzip.open(train_eval_filepath, 'rt') as triples_file:
    for triple in triples_file:
        qid, pos_id, neg_id = triple.strip().split()

        dev_entry = dev_samples.get(qid)
        if dev_entry is None:
            # Unseen query: register it only while the dev set still has room
            if len(dev_samples) >= num_dev_queries:
                continue
            dev_entry = dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        dev_entry['positive'].add(corpus[pos_id])

        # Negatives are capped per query; positives are not
        if len(dev_entry['negative']) < num_max_dev_negatives:
            dev_entry['negative'].add(corpus[neg_id])


# Training triples (query id, positive passage id, negative passage id)
train_filepath = os.path.join(data_folder, 'msmarco-qidpidtriples.rnd-shuf.train.tsv.gz')
if not os.path.exists(train_filepath):
    logging.info(f"Download {os.path.basename(train_filepath)}")
    util.http_get('https://sbert.net/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz', train_filepath)

# cnt counts the samples added so far; it also decides which passage of a triple is used
cnt = 0
with gzip.open(train_filepath, 'rt') as triples_file:
    for triple in tqdm.tqdm(triples_file, unit_scale=True):
        qid, pos_id, neg_id = triple.strip().split()

        # Queries reserved for the dev set must not leak into the training data
        if qid in dev_samples:
            continue

        query = queries[qid]
        # Every (pos_neg_ration+1)-th sample takes the positive passage of its triple,
        # all other samples take the negative one
        use_positive = cnt % (pos_neg_ration + 1) == 0
        passage = corpus[pos_id] if use_positive else corpus[neg_id]
        label = 1 if use_positive else 0

        train_samples.append(InputExample(texts=[query, passage], label=label))
        cnt += 1

        # Stop reading once the configured number of samples has been collected
        if cnt >= max_train_samples:
            break

INFO:root:Download collection.tar.gz
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): msmarco.blob.core.windows.net:443
DEBUG:urllib3.connectionpool:https://msmarco.blob.core.windows.net:443 "GET /msmarcoranking/collection.tar.gz HTTP/1.1" 200 1035009698


  0%|          | 0.00/1.04G [00:00<?, ?B/s]

INFO:root:Download queries.tar.gz
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): msmarco.blob.core.windows.net:443
DEBUG:urllib3.connectionpool:https://msmarco.blob.core.windows.net:443 "GET /msmarcoranking/queries.tar.gz HTTP/1.1" 200 18882551


  0%|          | 0.00/18.9M [00:00<?, ?B/s]

INFO:root:Download msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): sbert.net:443
DEBUG:urllib3.connectionpool:https://sbert.net:443 "GET /datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz HTTP/1.1" 301 None
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): public.ukp.informatik.tu-darmstadt.de:443
DEBUG:urllib3.connectionpool:https://public.ukp.informatik.tu-darmstadt.de:443 "GET /reimers/sentence-transformers/datasets/msmarco-qidpidtriples.rnd-shuf.train-eval.tsv.gz HTTP/1.1" 200 2313734


  0%|          | 0.00/2.31M [00:00<?, ?B/s]

INFO:root:Download msmarco-qidpidtriples.rnd-shuf.train.tsv.gz
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): sbert.net:443
DEBUG:urllib3.connectionpool:https://sbert.net:443 "GET /datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz HTTP/1.1" 301 None
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): public.ukp.informatik.tu-darmstadt.de:443
DEBUG:urllib3.connectionpool:https://public.ukp.informatik.tu-darmstadt.de:443 "GET /reimers/sentence-transformers/datasets/msmarco-qidpidtriples.rnd-shuf.train.tsv.gz HTTP/1.1" 200 4414877667


  0%|          | 0.00/4.41G [00:00<?, ?B/s]

20.0Mit [01:53, 177kit/s]


## Initialize dataloader

In [None]:
# Wrap the collected train samples in a DataLoader that yields shuffled batches
train_dataloader = DataLoader(
    dataset=train_samples,
    batch_size=train_batch_size,
    shuffle=True,
)

## Initialize CERerankingEvaluator Class
### The CERerankingEvaluator class evaluates the model after every 1k steps of training on the validation set
### Currently, CERerankingEvaluator computes MRR@10 on the validation set. You need to change MRR@10 to NDCG@10 for Exercise 4.
### For that, you can download the CERerankingEvaluator class ([link](https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/cross_encoder/evaluation/CERerankingEvaluator.py)) and upload the modified implementation to Brightspace.




In [None]:
# We add an evaluator, which evaluates the performance during training.
# It is built from dev_samples (per query: the query text plus its positive and negative passages).
# The recorded training log of this notebook shows it reporting MRR@10 on the 'train-eval' set
# (not F1 / Average Precision).
evaluator = CERerankingEvaluator(dev_samples, name='train-eval')

## Train the model
### You can stop the training after one hour by stopping the run

In [None]:
# Fine-tune the cross-encoder; the result is written to model_save_path
evaluation_steps = 1000  # the evaluator is run every 1000 training steps
warmup_steps = 5000

model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=evaluation_steps,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/625000 [00:00<?, ?it/s]

INFO:sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator:CERerankingEvaluator: Evaluating the model on train-eval dataset in epoch 0 after 1000 steps:
INFO:sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator:Queries: 200 	 Positives: Min 1.0, Mean 1.1, Max 3.0 	 Negatives: Min 100.0, Mean 199.1, Max 200.0
INFO:sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator:MRR@10: 63.25
INFO:sentence_transformers.cross_encoder.CrossEncoder:Save model to ./gdrive/MyDrive/cross-encoder-reranker-ir-course-2023/finetuned_models/cross-encoder-cross-encoder-ms-marco-MiniLM-L-2-v2-2023-04-10_13-13-29
INFO:sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator:CERerankingEvaluator: Evaluating the model on train-eval dataset in epoch 0 after 2000 steps:
INFO:sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator:Queries: 200 	 Positives: Min 1.0, Mean 1.1, Max 3.0 	 Negatives: Min 100.0, Mean 199.1, Max 200.0
INFO:sentence_transforme

KeyboardInterrupt: ignored