## Train SBERT with In-batch Negatives

### Reference: https://github.com/beir-cellar/beir/blob/main/examples/retrieval/training/train_sbert.py

In [None]:
# Install the beir PyPI package
!pip install beir

In [2]:
from beir import util, LoggingHandler

import logging
import pathlib, os

#### Configure root logging: timestamped INFO-level records handled by BEIR's LoggingHandler
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[LoggingHandler()],
)

  from tqdm.autonotebook import tqdm


In [3]:
import pathlib, os
from beir import util

# Name of the BEIR dataset to fetch; it is substituted into the archive URL below.
dataset = "scifact"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"

# Download into ./datasets (relative to the current working directory) and unzip there.
out_dir = os.path.join(os.getcwd(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
print(f"Dataset downloaded here: {data_path}")

/content/datasets/scifact.zip:   0%|          | 0.00/2.69M [00:00<?, ?iB/s]

Dataset downloaded here: /content/datasets/scifact


In [4]:
!ls datasets/scifact/

corpus.jsonl  qrels  queries.jsonl


In [12]:
#https://github.com/beir-cellar/beir/blob/c80a554314439293a7afdf577e46125c9fa976f5/beir/datasets/data_loader.py#L10
from typing import Dict, Tuple
from tqdm.autonotebook import tqdm
import json
import os
import logging
import csv

logger = logging.getLogger(__name__)

class GenericDataLoader:
    """Read a retrieval dataset (corpus, queries, relevance judgements) from disk.

    Three inputs are read:

    * a JSON-lines corpus file whose records provide ``_id``, ``title`` and ``text``;
    * a JSON-lines query file whose records provide ``_id`` and ``text``;
    * a tab-separated qrels file: one header row, then rows of
      ``query id <TAB> corpus id <TAB> integer score``.

    Results are cached on the instance in ``corpus``, ``queries`` and ``qrels``.
    After ``load`` / ``load_custom`` the ``queries`` attribute only contains the
    queries that appear in the qrels file that was just read.
    """

    def __init__(self, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
                 qrels_folder: str = "qrels", qrels_file: str = ""):
        """Resolve the file locations; nothing is read from disk here.

        Args:
            data_folder: Directory that the other paths are joined onto. When
                omitted, the given names are used as-is.
            prefix: Optional prefix; ``"<prefix>-"`` is prepended to the query
                file name and to the qrels folder name.
            corpus_file: Corpus file name (JSON lines).
            query_file: Query file name (JSON lines).
            qrels_folder: Folder containing one ``<split>.tsv`` per split.
            qrels_file: Explicit qrels file, used by ``load_custom``.
        """
        self.corpus = {}
        self.queries = {}
        self.qrels = {}
        # Complete, unfiltered query set. ``self.queries`` is narrowed to the
        # current qrels after every load, so the full set is kept separately to
        # allow loading another split from the same instance.
        self._all_queries = {}

        if prefix:
            query_file = prefix + "-" + query_file
            qrels_folder = prefix + "-" + qrels_folder

        self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
        self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
        # Fall back to the bare folder name (like the two paths above) instead of
        # None, which made ``load`` fail inside os.path.join with a TypeError.
        self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else qrels_folder
        self.qrels_file = qrels_file

    @staticmethod
    def check(fIn: str, ext: str):
        """Validate that ``fIn`` exists and that its name ends with ``ext``.

        Raises:
            ValueError: If the path does not exist or has a different suffix.
        """
        if not os.path.exists(fIn):
            raise ValueError("File {} not present! Please provide accurate file.".format(fIn))

        if not fIn.endswith(ext):
            raise ValueError("File {} must be present with extension {}".format(fIn, ext))

    def load_custom(self) -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load corpus, queries and the qrels file given to the constructor.

        Returns:
            ``(corpus, queries, qrels)`` where ``queries`` is restricted to the
            query ids present in ``qrels``.

        Raises:
            ValueError: If any of the three files is missing or has the wrong
                extension.
        """
        self.check(fIn=self.corpus_file, ext="jsonl")
        self.check(fIn=self.query_file, ext="jsonl")
        self.check(fIn=self.qrels_file, ext="tsv")

        self._ensure_corpus()
        self._ensure_queries()
        self._apply_qrels()

        return self.corpus, self.queries, self.qrels

    def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
        """Load corpus, queries and ``<qrels_folder>/<split>.tsv``.

        Args:
            split: Name of the qrels file (without ``.tsv``) to read.

        Returns:
            ``(corpus, queries, qrels)`` where ``queries`` is restricted to the
            query ids present in the split's qrels.

        Raises:
            ValueError: If any of the three files is missing or has the wrong
                extension.
        """
        self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
        self.check(fIn=self.corpus_file, ext="jsonl")
        self.check(fIn=self.query_file, ext="jsonl")
        self.check(fIn=self.qrels_file, ext="tsv")

        self._ensure_corpus(split)
        self._ensure_queries()
        self._apply_qrels(split)

        return self.corpus, self.queries, self.qrels

    def load_corpus(self) -> Dict[str, Dict[str, str]]:
        """Load (if not already cached) and return only the corpus.

        Raises:
            ValueError: If the corpus file is missing or has the wrong extension.
        """
        self.check(fIn=self.corpus_file, ext="jsonl")
        self._ensure_corpus()
        return self.corpus

    def _ensure_corpus(self, split: str = None):
        """Read the corpus once; ``split`` only affects the log message."""
        if self.corpus:
            return

        logger.info("Loading Corpus...")
        self._load_corpus()
        if split is None:
            logger.info("Loaded %d Documents.", len(self.corpus))
        else:
            logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
        # Guard the example line: indexing an empty corpus raised IndexError.
        if self.corpus:
            logger.info("Doc Example: %s", next(iter(self.corpus.values())))

    def _ensure_queries(self):
        """Read the query file once and remember the complete query set."""
        if self._all_queries:
            return

        if not self.queries:
            logger.info("Loading Queries...")
            self._load_queries()
        self._all_queries = dict(self.queries)

    def _apply_qrels(self, split: str = None):
        """Read ``self.qrels_file`` and narrow ``self.queries`` to its query ids.

        ``self.qrels`` is rebuilt from scratch so that judgements from a
        previously loaded split are not mixed into the current one (and so that
        a dict returned by an earlier call is not mutated).

        Raises:
            KeyError: If the qrels reference a query id missing from the query file.
        """
        self.qrels = {}
        self._load_qrels()
        self.queries = {qid: self._all_queries[qid] for qid in self.qrels}

        if split is None:
            logger.info("Loaded %d Queries.", len(self.queries))
        else:
            logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
        if self.queries:
            logger.info("Query Example: %s", next(iter(self.queries.values())))

    def _load_corpus(self):
        """Parse the corpus file into ``self.corpus`` (``_id`` -> text/title dict)."""
        # Count lines first so the progress bar has a total; the handle is
        # closed explicitly instead of being left to the garbage collector.
        with open(self.corpus_file, 'rb') as raw:
            num_lines = sum(1 for _ in raw)

        with open(self.corpus_file, encoding='utf8') as fIn:
            for line in tqdm(fIn, total=num_lines):
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                record = json.loads(line)
                self.corpus[record.get("_id")] = {
                    "text": record.get("text"),
                    "title": record.get("title"),
                }

    def _load_queries(self):
        """Parse the query file into ``self.queries`` (``_id`` -> query text)."""
        with open(self.query_file, encoding='utf8') as fIn:
            for line in fIn:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline)
                record = json.loads(line)
                self.queries[record.get("_id")] = record.get("text")

    def _load_qrels(self):
        """Parse ``self.qrels_file`` into ``self.qrels`` (query id -> {corpus id: score})."""
        # newline="" is what the csv module documents for files handed to csv.reader.
        with open(self.qrels_file, encoding="utf-8", newline="") as fIn:
            reader = csv.reader(fIn, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
            next(reader, None)  # skip the header row; an empty file yields no rows

            for row in reader:
                if not row:
                    continue  # csv.reader returns [] for blank lines
                query_id, corpus_id, score = row[0], row[1], int(row[2])
                self.qrels.setdefault(query_id, {})[corpus_id] = score


In [14]:

data_path = "datasets/scifact"
corpus, queries, qrels = GenericDataLoader(data_path).load(split="train") # or split = "test" or "dev"

  0%|          | 0/5183 [00:00<?, ?it/s]

In [24]:
#### Evaluation data: note that the "test" split is loaded here and bound to the dev_* names.
#### Not all datasets contain a dev split; use split="dev" only when the dataset provides one.
dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="test")

  0%|          | 0/5183 [00:00<?, ?it/s]

#### Sentence Transformers Training without any mined hard negatives or triplets.

In [None]:
!pip install sentence-transformers

In [17]:
#### Provide any sentence-transformers or HF model
from sentence_transformers import losses, models, SentenceTransformer
model_name = "distilbert-base-uncased" 
# Transformer module for the chosen checkpoint; max_seq_length=350 is passed as its sequence-length setting
word_embedding_model = models.Transformer(model_name, max_seq_length=350)
# Pooling module sized to the transformer's word-embedding dimension
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
# Chain both modules (transformer first, pooling second) into one SentenceTransformer
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [19]:
import random

In [20]:
###https://github.com/beir-cellar/beir/blob/c80a554314439293a7afdf577e46125c9fa976f5/beir/retrieval/train.py#L16
from sentence_transformers import SentenceTransformer, SentencesDataset, models, datasets
from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator, InformationRetrievalEvaluator
from sentence_transformers.readers import InputExample
from transformers import AdamW
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from tqdm.autonotebook import trange
from typing import Dict, Type, List, Callable, Iterable, Tuple
import logging
import time
import difflib

logger = logging.getLogger(__name__)

class TrainRetriever:
    """Helper around a SentenceTransformer model for building training data and training it.

    The class turns corpus / queries / qrels dictionaries into ``InputExample``
    lists, wraps them in data loaders, builds evaluators and forwards training
    to ``self.model.fit``.
    """

    def __init__(self, model: SentenceTransformer, batch_size: int = 64):
        """Store the model and the batch size used by the loaders below.

        Args:
            model: The model whose ``fit`` method is called by ``fit``.
            batch_size: Batch size for the data loaders; also the chunk size
                used when iterating over queries / triplets.
        """
        self.model = model
        self.batch_size = batch_size

    @staticmethod
    def _doc_to_text(doc: Dict[str, str]) -> str:
        """Return ``"<title> <text>"`` for a corpus entry.

        A missing (``None``) title or text is treated as an empty string so the
        concatenation cannot fail with a TypeError.
        """
        return (doc.get("title") or "") + " " + (doc.get("text") or "")

    def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
                   qrels: Dict[str, Dict[str, int]]) -> List[InputExample]:
        """Build one positive (query, document) example per judgement with score >= 1.

        Args:
            corpus: corpus id -> ``{"title": ..., "text": ...}``.
            queries: query id -> query text.
            qrels: query id -> ``{corpus id: score}``.

        Returns:
            ``InputExample`` objects with ``texts=[query, "<title> <text>"]`` and
            ``label=1``. ``guid`` is the index of the query chunk the example
            came from. Queries without judgements contribute nothing, and
            judgements pointing at an unknown corpus id are logged and skipped.
        """
        query_ids = list(queries.keys())
        train_samples = []

        # Queries are walked in chunks of batch_size so that trange can show progress.
        for idx, start_idx in enumerate(trange(0, len(query_ids), self.batch_size, desc='Adding Input Examples')):
            for query_id in query_ids[start_idx:start_idx + self.batch_size]:
                for corpus_id, score in qrels.get(query_id, {}).items():
                    if score < 1:  # if score = 0, we don't consider for training
                        continue
                    doc = corpus.get(corpus_id)
                    if doc is None:
                        logger.error("Error: Key {} not present in corpus!".format(corpus_id))
                        continue
                    texts = [queries[query_id], self._doc_to_text(doc)]
                    train_samples.append(InputExample(guid=idx, texts=texts, label=1))

        logger.info("Loaded %d training pairs.", len(train_samples))
        return train_samples

    def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[InputExample]:
        """Wrap every triplet in an ``InputExample`` (``guid=None``, ``texts=triplet``).

        Args:
            triplets: Sequence of 3-tuples of strings.

        Returns:
            One ``InputExample`` per triplet, in input order.
        """
        train_samples = []

        # Chunked iteration only exists to drive the progress bar.
        for start_idx in trange(0, len(triplets), self.batch_size, desc='Adding Input Examples'):
            for triplet in triplets[start_idx:start_idx + self.batch_size]:
                train_samples.append(InputExample(guid=None, texts=triplet))

        logger.info("Loaded %d training pairs.", len(train_samples))
        return train_samples

    def prepare_train(self, train_dataset: List[InputExample], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:
        """Return a ``DataLoader`` over the training examples.

        Args:
            train_dataset: Examples to train on, or an already built dataset
                when ``dataset_present`` is true.
            shuffle: Passed to the ``DataLoader``.
            dataset_present: When false, the examples are first wrapped in a
                ``SentencesDataset`` bound to ``self.model``.
        """
        if not dataset_present:
            train_dataset = SentencesDataset(train_dataset, model=self.model)

        return DataLoader(train_dataset, shuffle=shuffle, batch_size=self.batch_size)

    def prepare_train_triplets(self, train_dataset: List[InputExample]) -> DataLoader:
        """Return a ``datasets.NoDuplicatesDataLoader`` over the given examples."""
        return datasets.NoDuplicatesDataLoader(train_dataset, batch_size=self.batch_size)

    def load_ir_evaluator(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
                 qrels: Dict[str, Dict[str, int]], max_corpus_size: int = None, name: str = "eval") -> SentenceEvaluator:
        """Build an ``InformationRetrievalEvaluator`` from corpus, queries and qrels.

        Args:
            corpus: corpus id -> ``{"title": ..., "text": ...}``. Not modified.
            queries: query id -> query text.
            qrels: query id -> ``{corpus id: score}``; scores >= 1 count as relevant.
            max_corpus_size: Optional cap on the evaluation corpus. All relevant
                documents are kept and the remainder is filled with a random
                sample of the other documents.
            name: Evaluator name.

        Raises:
            ValueError: If ``queries`` is empty, or if ``max_corpus_size`` is
                smaller than the number of relevant documents.
        """
        # Local import: the file-level ``import random`` lives in a separate notebook cell.
        import random

        if not queries:
            raise ValueError("Dev Set Empty!, Cannot evaluate on Dev set.")

        rel_docs = {}
        corpus_ids = set()

        # need to convert corpus to cid => doc (a new dict, the caller's corpus is untouched)
        corpus = {cid: self._doc_to_text(doc) for cid, doc in corpus.items()}

        # need to convert dev_qrels to qid => Set[cid]
        for query_id, metadata in qrels.items():
            relevant = {corpus_id for corpus_id, score in metadata.items() if score >= 1}
            rel_docs[query_id] = relevant
            corpus_ids.update(relevant)

        if max_corpus_size:
            # check if length of corpus_ids > max_corpus_size
            if len(corpus_ids) > max_corpus_size:
                raise ValueError("Your maximum corpus size should atleast contain {} corpus ids".format(len(corpus_ids)))

            # Add mandatory corpus documents (ids missing from the corpus are skipped)
            new_corpus = {cid: corpus[cid] for cid in corpus_ids if cid in corpus}

            # Sample randomly from the remaining documents. The sample size is
            # capped at what is available, since random.sample raises when asked
            # for more items than the population holds.
            remaining = [cid for cid in corpus if cid not in corpus_ids]
            n_extra = min(max_corpus_size - len(new_corpus), len(remaining))
            for corpus_id in random.sample(remaining, n_extra):
                new_corpus[corpus_id] = corpus[corpus_id]

            corpus = new_corpus

        logger.info("{} set contains {} documents and {} queries".format(name, len(corpus), len(queries)))
        return InformationRetrievalEvaluator(queries, corpus, rel_docs, name=name)

    def load_dummy_evaluator(self) -> SentenceEvaluator:
        """Return an empty ``SequentialEvaluator`` whose main score is the current time."""
        return SequentialEvaluator([], main_score_function=lambda x: time.time())

    def fit(self,
            train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
            evaluator: SentenceEvaluator = None,
            epochs: int = 1,
            steps_per_epoch = None,
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            optimizer_class: Type[Optimizer] = AdamW,
            optimizer_params: Dict[str, object] = None,
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1,
            use_amp: bool = False,
            callback: Callable[[float, int, int], None] = None,
            **kwargs):
        """Train ``self.model`` by forwarding every argument to ``self.model.fit``.

        Args:
            train_objectives: ``(DataLoader, loss module)`` pairs.
            optimizer_params: Keyword arguments for the optimizer. When omitted,
                ``{'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False}`` is used; a
                fresh dict is created per call instead of sharing one mutable
                default between calls.
            **kwargs: Extra keyword arguments passed through unchanged.

        All other parameters are handed to ``self.model.fit`` under the same name.

        NOTE(review): the default ``optimizer_class`` is the ``AdamW`` imported
        from ``transformers`` at file level; newer transformers releases
        reportedly deprecate it - confirm against the installed version.
        """
        if optimizer_params is None:
            optimizer_params = {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False}

        # Train the model
        logger.info("Starting to Train...")

        self.model.fit(train_objectives=train_objectives,
                evaluator=evaluator,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                warmup_steps=warmup_steps,
                optimizer_class=optimizer_class,
                scheduler=scheduler,
                optimizer_params=optimizer_params,
                weight_decay=weight_decay,
                output_path=output_path,
                evaluation_steps=evaluation_steps,
                save_best_model=save_best_model,
                max_grad_norm=max_grad_norm,
                use_amp=use_amp,
                callback=callback, **kwargs)

In [21]:
# batch_size=16 is used by TrainRetriever both to chunk the queries and as the DataLoader batch size
retriever = TrainRetriever(model=model, batch_size=16)

#### Prepare training samples
# One InputExample per (query, document) judgement with score >= 1, then a shuffled DataLoader over them
train_samples = retriever.load_train(corpus, queries, qrels)
train_dataloader = retriever.prepare_train(train_samples, shuffle=True)

Adding Input Examples:   0%|          | 0/51 [00:00<?, ?it/s]

In [25]:

#### Training SBERT with cosine similarity (no similarity_fct override is passed to the loss)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)
#### Training SBERT with dot-product
# NOTE(review): `util` in this notebook is imported from beir, so `util.dot_score` presumably needs
# `from sentence_transformers import util` first - confirm before enabling the line below.
# train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model, similarity_fct=util.dot_score)

#### Prepare dev evaluator
# NOTE: dev_corpus / dev_queries / dev_qrels were loaded with split="test" earlier in this notebook.
ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)


In [28]:
import os
# exist_ok=True keeps this cell re-runnable; without it a second run raises FileExistsError.
os.makedirs('models/', exist_ok=True)

In [30]:
#### Provide model save path
# A fixed relative directory is used here. Commented-out alternative that derives the path from __file__
# (presumably taken from the referenced script - confirm):
#   os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v1-{}".format(model_name, dataset))
model_save_path = 'models/'
os.makedirs(model_save_path, exist_ok=True)

In [None]:
#### Configure Train params
num_epochs = 1
evaluation_steps = 50
# Warm-up length: 10% of (number of training samples * epochs / batch size), truncated to an int
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

# Single training objective (dataloader + loss); every argument is forwarded by TrainRetriever.fit to model.fit.
# NOTE(review): use_amp=True presumably requires a CUDA device - confirm for the target runtime.
retriever.fit(train_objectives=[(train_dataloader, train_loss)], 
                evaluator=ir_evaluator, 
                epochs=num_epochs,
                output_path=model_save_path,
                warmup_steps=warmup_steps,
                evaluation_steps=evaluation_steps,
                use_amp=True)