# Download and loading the data

In [1]:
%%capture
!wget https://www.dropbox.com/scl/fi/md1qj0dz07wi66lan1aah/afternoon_session_files.zip?rlkey=ficet9hbs55wxs7e11u6yi9w0&dl=0
!unzip afternoon_session_files.zip?rlkey=ficet9hbs55wxs7e11u6yi9w0&dl=0

## Install libraries and load utils

In [4]:
!pip install pytrec_eval
!pip install sentence_transformers
import pytrec_eval
import json
import tqdm
import numpy as np
from google.colab import drive
drive.mount('/content/gdrive')
COLAB_RUN = True
base_path = "./"

Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytrec_eval
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp310-cp310-linux_x86_64.whl size=308203 sha256=d407a0b3061616ca3cbe5d44794fd5e1aa6b0984606c64d9fceae12ede6e1897
  Stored in directory: /root/.cache/pip/wheels/51/3a/cd/dcc1ddfc763987d5cb237165d8ac249aa98a23ab90f67317a8
Successfully built pytrec_eval
Installing collected packages: pytrec_eval
Successfully installed pytrec_eval-0.5
Collecting sentence_transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting transformers<5.0.0,>=4.6.0 (from sentence_transformers)
  Downloading transformers-4.32.1-p

In [5]:
import math
import sys
from datetime import datetime
import gzip
import os
import tarfile
import logging

In [6]:
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers import InputExample
from transformers import AutoTokenizer

In [7]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# CrossEncoder Class


In [8]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import numpy as np
import logging
import os
from typing import Dict, Type, Callable, List
import transformers
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.evaluation import SentenceEvaluator
class CrossEncoder():
    def __init__(self, model_name:str, num_labels:int = None, max_length:int = None, device:str = None, tokenizer_args:Dict = {},
                 default_activation_function = None):
        """
        A CrossEncoder takes exactly two sentences / texts as input and either predicts
        a score or label for this sentence pair. It can for example predict the similarity of the sentence pair
        on a scale of 0 ... 1.

        It does not yield a sentence embedding and does not work for individually sentences.

        :param model_name: Any model name from Huggingface Models Repository that can be loaded with AutoModel. We provide several pre-trained CrossEncoder models that can be used for common tasks
        :param num_labels: Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continous score 0...1. If > 1, it output several scores that can be soft-maxed to get probability scores for the different classes.
        :param max_length: Max length for input sequences. Longer sequences will be truncated. If None, max length of the model will be used
        :param device: Device that should be used for the model. If None, it will use CUDA if available.
        :param tokenizer_args: Arguments passed to AutoTokenizer
        :param default_activation_function: Callable (like nn.Sigmoid) about the default activation function that should be used on-top of model.predict(). If None. nn.Sigmoid() will be used if num_labels=1, else nn.Identity()
        """

        self.config = AutoConfig.from_pretrained(model_name)
        classifier_trained = True
        if self.config.architectures is not None:
            classifier_trained = any([arch.endswith('ForSequenceClassification') for arch in self.config.architectures])

        if num_labels is None and not classifier_trained:
            num_labels = 1

        if num_labels is not None:
            self.config.num_labels = num_labels

        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config, ignore_mismatched_sizes = True) # ignore_mismatched_sizes = True for transfer learning. first post_training, then using it for binary classification
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
        self.max_length = max_length

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Use pytorch device: {}".format(device))

        self._target_device = torch.device(device)

        if default_activation_function is not None:
            self.default_activation_function = default_activation_function
            try:
                self.config.sbert_ce_default_activation_function = util.fullname(self.default_activation_function)
            except Exception as e:
                logger.warning("Was not able to update config about the default_activation_function: {}".format(str(e)) )
        elif hasattr(self.config, 'sbert_ce_default_activation_function') and self.config.sbert_ce_default_activation_function is not None:
            self.default_activation_function = util.import_from_string(self.config.sbert_ce_default_activation_function)()
        else:
            self.default_activation_function = nn.Sigmoid() if self.config.num_labels == 1 else nn.Identity()

    def smart_batching_collate(self, batch):
        texts = [[] for _ in range(len(batch[0].texts))]
        labels = []

        for example in batch:
            for idx, text in enumerate(example.texts):
                texts[idx].append(text.strip())

            labels.append(example.label)

        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)
        labels = torch.tensor(labels, dtype=torch.float if self.config.num_labels == 1 else torch.long).to(self._target_device)

        for name in tokenized:
            tokenized[name] = tokenized[name].to(self._target_device)

        return tokenized, labels

    def smart_batching_collate_text_only(self, batch):
        texts = [[] for _ in range(len(batch[0]))]

        for example in batch:
            for idx, text in enumerate(example):
                texts[idx].append(text.strip())

        tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)

        for name in tokenized:
            tokenized[name] = tokenized[name].to(self._target_device)

        return tokenized

    def fit(self,
            train_dataloader: DataLoader,
            evaluator: SentenceEvaluator = None,
            epochs: int = 1,
            loss_fct = None,
            activation_fct = nn.Identity(),
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            accumulation_steps: int = 1,
            optimizer_class: Type[Optimizer] = transformers.AdamW,
            optimizer_params: Dict[str, object] = {'lr': 2e-5},
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1,
            use_amp: bool = False,
            callback: Callable[[float, int, int], None] = None,
            ):
        """
        Train the model with the given training objective
        Each training objective is sampled in turn for one batch.
        We sample only as many batches from each objective as there are in the smallest one
        to make sure of equal training with each dataset.

        :param train_dataloader: DataLoader with training InputExamples
        :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
        :param epochs: Number of epochs for training
        :param loss_fct: Which loss function to use for training. If None, will use nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()
        :param activation_fct: Activation function applied on top of logits output of model.
        :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from o up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero.
        :param accumulation_steps: Number of steps to accumulate before performing a backward pass
        :param optimizer_class: Optimizer
        :param optimizer_params: Optimizer parameters
        :param weight_decay: Weight decay for model parameters
        :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training steps
        :param output_path: Storage path for the model and evaluation files
        :param save_best_model: If true, the best model (according to evaluator) is stored at output_path
        :param max_grad_norm: Used for gradient normalization.
        :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
        :param callback: Callback function that is invoked after each evaluation.
                It must accept the following three parameters in this order:
                `score`, `epoch`, `steps`
        """
        train_dataloader.collate_fn = self.smart_batching_collate

        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()

        self.model.to(self._target_device)

        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)

        self.best_score = -9999999
        num_train_steps = int(len(train_dataloader) * epochs)

        # Prepare optimizers
        param_optimizer = list(self.model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)

        if isinstance(scheduler, str):
            scheduler = SentenceTransformer._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)

        if loss_fct is None:
            loss_fct = nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()


        skip_scheduler = False
        for epoch in tqdm.trange(epochs, desc="Epoch"):
            training_steps = 0
            self.model.zero_grad()
            self.model.train()
            for i, (features, labels) in tqdm.tqdm(enumerate(train_dataloader), total=(len(train_dataloader) // accumulation_steps), desc="Iteration", smoothing=0.05):
                if use_amp:
                    with autocast():
                        model_predictions = self.model(**features, return_dict=True)
                        logits = activation_fct(model_predictions.logits)
                        if self.config.num_labels == 1:
                            logits = logits.view(-1)
                        loss_value = loss_fct(logits, labels)
                        loss_value /= accumulation_steps

                    scale_before_step = scaler.get_scale()
                    scaler.scale(loss_value).backward()
                    if (i + 1) % accumulation_steps == 0:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                    skip_scheduler = scaler.get_scale() != scale_before_step
                else:
                    model_predictions = self.model(**features, return_dict=True)
                    logits = activation_fct(model_predictions.logits)
                    if self.config.num_labels == 1:
                        logits = logits.view(-1)
                    loss_value = loss_fct(logits, labels)
                    loss_value /= accumulation_steps
                    loss_value.backward()
                    if (i + 1) % accumulation_steps == 0:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
                        optimizer.step()
                        optimizer.zero_grad()

                if not skip_scheduler and (i + 1) % accumulation_steps == 0:
                    scheduler.step()

                training_steps += 1

                if evaluator is not None and evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path, save_best_model, epoch, training_steps, callback)

                    self.model.zero_grad()
                    self.model.train()

            if evaluator is not None:
                self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1, callback)



    def predict(self, sentences: List[List[str]],
               batch_size: int = 32,
               show_progress_bar: bool = None,
               num_workers: int = 0,
               activation_fct = None,
               apply_softmax = False,
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False
               ):
        """
        Performs predicts with the CrossEncoder on the given sentence pairs.

        :param sentences: A list of sentence pairs [[Sent1, Sent2], [Sent3, Sent4]]
        :param batch_size: Batch size for encoding
        :param show_progress_bar: Output progress bar
        :param num_workers: Number of workers for tokenization
        :param activation_fct: Activation function applied on the logits output of the CrossEncoder. If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity
        :param convert_to_numpy: Convert the output to a numpy matrix.
        :param apply_softmax: If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output
        :param convert_to_tensor:  Conver the output to a tensor.
        :return: Predictions for the passed sentence pairs
        """
        input_was_string = False
        if isinstance(sentences[0], str):  # Cast an individual sentence to a list with length 1
            sentences = [sentences]
            input_was_string = True

        inp_dataloader = DataLoader(sentences, batch_size=batch_size, collate_fn=self.smart_batching_collate_text_only, num_workers=num_workers, shuffle=False)

        if show_progress_bar is None:
            show_progress_bar = (logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG)

        iterator = inp_dataloader
        if show_progress_bar:
            iterator = tqdm.tqdm(inp_dataloader, desc="Batches")

        if activation_fct is None:
            activation_fct = self.default_activation_function

        pred_scores = []
        self.model.eval()
        self.model.to(self._target_device)
        with torch.no_grad():
            for features in iterator:
                model_predictions = self.model(**features, return_dict=True)
                logits = activation_fct(model_predictions.logits)

                if apply_softmax and len(logits[0]) > 1:
                    logits = torch.nn.functional.softmax(logits, dim=1)
                pred_scores.extend(logits)

        if self.config.num_labels == 1:
            pred_scores = [score[0] for score in pred_scores]

        if convert_to_tensor:
            pred_scores = torch.stack(pred_scores)
        elif convert_to_numpy:
            pred_scores = np.asarray([score.cpu().detach().numpy() for score in pred_scores])

        if input_was_string:
            pred_scores = pred_scores[0]

        return pred_scores


    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback):
        """Runs evaluation during the training"""
        if evaluator is not None:
            score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
            if callback is not None:
                callback(score, epoch, steps)
            if score > self.best_score:
                self.best_score = score
                if save_best_model:
                    self.save(output_path)

    def save(self, path):
        """
        Saves all model and tokenizer to path
        """
        if path is None:
            return

        logger.info("Save model to {}".format(path))
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

    def save_pretrained(self, path):
        """
        Same function as save
        """
        return self.save(path)

DEBUG:tensorflow:Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.
DEBUG:h5py._conv:Creating converter from 7 to 5
DEBUG:h5py._conv:Creating converter from 5 to 7
DEBUG:h5py._conv:Creating converter from 7 to 5
DEBUG:h5py._conv:Creating converter from 5 to 7
INFO:numexpr.utils:Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


# Evaluator Class

In [9]:
import numpy as np
import os
import csv
import pytrec_eval
import tqdm
from sentence_transformers import LoggingHandler, util
class CERerankingEvaluator:
    """
    This class evaluates a CrossEncoder model for the task of re-ranking.

    Given a query and a list of documents, it computes the score [query, doc_i] for all possible
    documents and sorts them in decreasing order. Then, ndcg@10 is compute to measure the quality of the ranking.

    :param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query,
     positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
    """
    def __init__(self, samples, all_metrics: set = {"recall.1"}, name: str = '', write_csv: bool = True, show_progress_bar: bool = False):
        self.samples = samples
        self.name = name
        self.all_metrics = all_metrics

        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps"] + list(all_metrics)
        self.write_csv = write_csv
        self.mean_metrics = {}
        self.show_progress_bar = show_progress_bar
    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        mean_ndcg = 0

        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

        all_ndcg_scores = []
        num_queries = 0
        num_positives = []
        num_negatives = []
        run = {}
        qrel = {}
        print("len: self.samples: " + str(len(self.samples)))
        try:
            for instance in tqdm.tqdm(self.samples):
                # print("instance: ", instance)
                qid = instance['qid']
                query = instance['query']
                positive = list(instance['positive'])
                negative = list(instance['negative'])
                positive_pids = list(instance['positive_ids'])
                negative_pids = list(instance['negative_ids'])
                docs =  negative + positive
                docs_ids = negative_pids + positive_pids
                is_relevant = [False]*len(negative) +  [True]*len(positive)

                qrel[qid] = {}
                run[qid] = {}
                for pid in positive_pids:
                    qrel[qid][pid] = 1

                if len(positive) == 0 or len(negative) == 0:
                    continue

                num_queries += 1
                num_positives.append(len(positive))
                num_negatives.append(len(negative))

                model_input = [[query, doc] for doc in docs]
                if model.config.num_labels > 1: #Cross-Encoder that predict more than 1 score, we use the last and apply softmax
                    pred_scores = model.predict(model_input, apply_softmax=True, batch_size=16, show_progress_bar = self.show_progress_bar)[:, 1].tolist()
                else:
                    pred_scores = model.predict(model_input, batch_size=16, show_progress_bar = self.show_progress_bar).tolist()
                for pred_score, did in zip(list(pred_scores), docs_ids):
                    line = "{query_id} Q0 {document_id} {rank} {score} STANDARD\n".format(query_id=qid,
                                                                                          document_id=did,
                                                                                          rank="-10",#rank,
                                                                                          score=str(pred_score))
                    run[qid][did] = float(pred_score)

            evaluator = pytrec_eval.RelevanceEvaluator(qrel, self.all_metrics)
            scores = evaluator.evaluate(run)
            self.mean_metrics = {}
            metrics_string = ""
            for metric in list(self.all_metrics):
                self.mean_metrics[metric] = np.mean([ele[metric.replace(".","_")] for ele in scores.values()])
                metrics_string = metrics_string +  "{}: {} | ".format(metric, self.mean_metrics[metric])
            print("metrics eval: ", metrics_string)
            logger.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
        except Exception as e:
            logger.error("error: ", e)
        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f: # early stopping can be done by modifying this part. You can read this csv file. Then you need to count: best_step - last step + 1. if it is >earlystopping. then, you can just do sys.exit(1) to kill the process :)
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, sum(self.mean_metrics.values())])
        return sum(self.mean_metrics.values())


#Data

## utils


### read collections

In [10]:
def read_collection(f_path):
  corpus = {}
  with open(f_path, "r") as fp:
    for line in tqdm.tqdm(fp, desc="reading {}".format(f_path)):
      did, dtext = line.strip().split("\t")
      corpus[did] = dtext
  return corpus
from glob import glob
def read_aila_documents(f_path):
  files = glob(corpus_path+"*.txt")
  corpus = {}
  for file_ in tqdm.tqdm(files, desc="reading {}".format(f_path)):
    content = open(file_, "r").read().split("\n")[1].split(":")[1]
    doc_id = file_.split("/")[-1].replace(".txt", "")
    corpus[doc_id] = content
  return corpus

## Initializing variables


In [27]:
base_write_path = "./gdrive/MyDrive/legal_essir/"
truncation_mode = None

num_epochs = 5
qlen = 254 # tokens
dlen = 254 # tokens
pos_neg_ratio = 99
max_train_samples = 0 #full train set
valid_max_queries = 10
valid_max_negatives_per_query = 100
model_name = 'nlpaueb/legal-bert-base-uncased'
queries_path = base_path + "queries_aila.tsv"
corpus_path = base_path + "Object_statutes/"
triples_train_path = base_path + "aila_train_qidpidtriples_v2.tsv"
triples_valid_path = base_path + "aila_valid_qidpidtriples_v2.tsv"
model_save_path = base_write_path + 'finetuned_CEs/{}_training_cross-encoder-'.format(model_name.replace("/", "-"))+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# training parameters
model_max_length = 512
batch_size = 64#32
accumulation_steps = 1
evaluation_steps = 1000#100 #5000
print("model_name {} | model_max_length {} | batch_size {} | accumulation_steps {} ".format(model_name, model_max_length, batch_size, accumulation_steps))

model_name nlpaueb/legal-bert-base-uncased | model_max_length 512 | batch_size 64 | accumulation_steps 1 


## Reading data

In [28]:
queries = read_collection(queries_path)
corpus =  read_aila_documents(corpus_path)

reading ./queries_aila.tsv: 50it [00:00, 44234.38it/s]
reading ./Object_statutes/: 100%|██████████| 197/197 [00:00<00:00, 33312.28it/s]


### reading corpus and queries: utils

In [29]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side = "right") # right is default btw.

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /nlpaueb/legal-bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


In [30]:
def get_truncated_dict(id_content_dict, tokenizer, max_length):
  for id_, content, in tqdm.tqdm(id_content_dict.items()):
    truncated_content = tokenizer.batch_decode(tokenizer(content, padding=True, truncation=True, return_tensors="pt", max_length=max_length)['input_ids'], skip_special_tokens=True)[0]
    id_content_dict[id_] = truncated_content
  return id_content_dict

### reading corpus and queries: main

In [31]:
queries = get_truncated_dict(queries, tokenizer, qlen)
corpus = get_truncated_dict(corpus,tokenizer, dlen)

100%|██████████| 50/50 [00:00<00:00, 309.38it/s]
100%|██████████| 197/197 [00:00<00:00, 849.34it/s]


### reading triples: utils

#### train set

In [32]:
def read_triples_train(f_path, queries, corpus, max_instances = 0 , pos_neg_ratio = 4):
  """
    :param max_instances: 0 means read full instances
  """
  samples = []
  cnt = 0
  with open(f_path, "r") as fp:
    for line in tqdm.tqdm(fp, desc="reading {}".format(f_path)):
      qid, pos_id, neg_id = line.strip().split("\t")
      query = queries[qid]
      if (cnt % (pos_neg_ratio+1)) == 0:
        passage = corpus[pos_id]
        label = 1
      else:
        passage = corpus[neg_id]
        label = 0
      samples.append(InputExample(texts=[query, passage], label=label))
      cnt += 1
      if max_instances != 0 and  cnt >= max_instances:
        break
  return samples

#### validation set

In [33]:
def read_triples_validation(f_path, queries, corpus, max_queries = 500, max_negatives_per_query = 200): # 200 negative per queries and 500 queries are good enough for evaluating during training!
  """
    :param max_instances: 0 means read full instances
  """
  samples = {}
  with open(f_path, "r") as fp:
    for line in tqdm.tqdm(fp, desc="reading {}".format(f_path)):
      qid, pos_id, neg_id = line.strip().split("\t")
      query = queries[qid]
      if qid not in samples and len(samples) < max_queries:
        samples[qid] = {'qid': qid , 'query': query, 'positive': list(), 'negative': list(), "positive_ids": list(), "negative_ids": list()}
      if qid in samples:
        if pos_id not in samples[qid]['positive_ids']:
            samples[qid]['positive'].append(corpus[pos_id])
            samples[qid]['positive_ids'].append(pos_id)
        if len(samples[qid]['negative']) < max_negatives_per_query:
            samples[qid]['negative'].append(corpus[neg_id])
            samples[qid]['negative_ids'].append(neg_id)
  return samples

### reading triples: main

#### train set

In [34]:
train_samples = read_triples_train(triples_train_path, queries, corpus, max_train_samples, pos_neg_ratio)

reading ./aila_train_qidpidtriples_v2.tsv: 18779it [00:00, 592172.28it/s]


#### validation set

In [35]:
dev_samples = read_triples_validation(triples_valid_path, queries, corpus, valid_max_queries, valid_max_negatives_per_query)

reading ./aila_valid_qidpidtriples_v2.tsv: 4501it [00:00, 880201.52it/s]


# Training

## Training parameters explaination

## Train

In [42]:
"""# Training"""
model = CrossEncoder(model_name, num_labels=1, max_length=model_max_length)
model.config.gradient_checkpointing = False#True # # we can do gradient checkpointing for all so we use less gpu memory and can have parallel run!
"""## Fit"""
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)
evaluator = CERerankingEvaluator(dev_samples, name='train-eval',  all_metrics=
                                 {"ndcg_cut.10", "map_cut.1000", "recall.10"}
                                 )#https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/cross_encoder/evaluation/CERerankingEvaluator.py
warmup_steps = 500
logging.info("Warmup-steps: {}".format(warmup_steps))
model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          evaluation_steps= evaluation_steps,
          epochs=num_epochs,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          accumulation_steps = accumulation_steps,#32, #batch 1, accumulation 32, real batch will be 32 then:)
          use_amp=True,
          optimizer_params = {'lr': 2e-5}, # sentence bert config!
          weight_decay = 0.01 # they set wd as 0 for adam. for adamw which is the correct impl of adam, wd is 0.01!
          )
model.save(model_save_path+'-latest')

DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /nlpaueb/legal-bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /nlpaueb/legal-bert-base-uncased/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
INFO:root:Use pytorch device: cuda
INFO:root:Warmup-steps: 500
Epoch:   0%|          | 0/5 [00:00<?, ?it/s]
Iteration:   0%|          | 0/294 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/294 [00:00<02:04,  2.36it/s][A
Iteration:   1%|          | 2/294 [00:00<02:03,  2.36it/s][A
Iteration:   1%|          | 3/294 [00:01<02:03,  2.36it/s][A
Iteration:   1%|▏         | 4/294 [00:01<02:02,  2.37it/s][A
I

len: self.samples: 10



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:06,  1.33it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.32it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.33it/s][A
 50%|█████     | 5/10 [00:03<00:03,  1.33it/s][A
 60%|██████    | 6/10 [00:04<00:03,  1.33it/s][A
 70%|███████   | 7/10 [00:05<00:02,  1.32it/s][A
 80%|████████  | 8/10 [00:06<00:01,  1.32it/s][A
 90%|█████████ | 9/10 [00:06<00:00,  1.33it/s][A
100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
INFO:root:Queries: 10 	 Positives: Min 2.0, Mean 4.2, Max 5.0 	 Negatives: Min 100.0, Mean 100.0, Max 100.0
INFO:root:Save model to ./gdrive/MyDrive/legal_essir/finetuned_CEs/nlpaueb-legal-bert-base-uncased_training_cross-encoder--2023-08-29_20-31-38


metrics eval:  map_cut.1000: 0.3462564228429256 | ndcg_cut.10: 0.41050826011882 | recall.10: 0.3516666666666666 | 


Epoch:  20%|██        | 1/5 [02:11<08:45, 131.44s/it]
Iteration:   0%|          | 0/294 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/294 [00:00<02:01,  2.40it/s][A
Iteration:   1%|          | 2/294 [00:00<02:01,  2.41it/s][A
Iteration:   1%|          | 3/294 [00:01<02:01,  2.40it/s][A
Iteration:   1%|▏         | 4/294 [00:01<02:00,  2.40it/s][A
Iteration:   2%|▏         | 5/294 [00:02<02:00,  2.40it/s][A
Iteration:   2%|▏         | 6/294 [00:02<02:00,  2.40it/s][A
Iteration:   2%|▏         | 7/294 [00:02<01:59,  2.40it/s][A
Iteration:   3%|▎         | 8/294 [00:03<01:59,  2.39it/s][A
Iteration:   3%|▎         | 9/294 [00:03<01:59,  2.38it/s][A
Iteration:   3%|▎         | 10/294 [00:04<01:59,  2.38it/s][A
Iteration:   4%|▎         | 11/294 [00:04<01:59,  2.37it/s][A
Iteration:   4%|▍         | 12/294 [00:05<01:58,  2.38it/s][A
Iteration:   4%|▍         | 13/294 [00:05<01:58,  2.38it/s][A
Iteration:   5%|▍         | 14/294 [00:05<01:57,  2.38it/s][A
Iteration:   5%|▌  

len: self.samples: 10



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:06,  1.33it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.33it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.33it/s][A
 50%|█████     | 5/10 [00:03<00:03,  1.33it/s][A
 60%|██████    | 6/10 [00:04<00:03,  1.33it/s][A
 70%|███████   | 7/10 [00:05<00:02,  1.33it/s][A
 80%|████████  | 8/10 [00:06<00:01,  1.33it/s][A
 90%|█████████ | 9/10 [00:06<00:00,  1.33it/s][A
100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
INFO:root:Queries: 10 	 Positives: Min 2.0, Mean 4.2, Max 5.0 	 Negatives: Min 100.0, Mean 100.0, Max 100.0
INFO:root:Save model to ./gdrive/MyDrive/legal_essir/finetuned_CEs/nlpaueb-legal-bert-base-uncased_training_cross-encoder--2023-08-29_20-31-38


metrics eval:  map_cut.1000: 0.3844759484368392 | ndcg_cut.10: 0.503059523993203 | recall.10: 0.61 | 


Epoch:  40%|████      | 2/5 [04:22<06:34, 131.50s/it]
Iteration:   0%|          | 0/294 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/294 [00:00<02:01,  2.41it/s][A
Iteration:   1%|          | 2/294 [00:00<02:01,  2.41it/s][A
Iteration:   1%|          | 3/294 [00:01<02:01,  2.40it/s][A
Iteration:   1%|▏         | 4/294 [00:01<02:00,  2.40it/s][A
Iteration:   2%|▏         | 5/294 [00:02<02:00,  2.40it/s][A
Iteration:   2%|▏         | 6/294 [00:02<01:59,  2.40it/s][A
Iteration:   2%|▏         | 7/294 [00:02<01:59,  2.40it/s][A
Iteration:   3%|▎         | 8/294 [00:03<01:59,  2.40it/s][A
Iteration:   3%|▎         | 9/294 [00:03<01:58,  2.40it/s][A
Iteration:   3%|▎         | 10/294 [00:04<01:58,  2.40it/s][A
Iteration:   4%|▎         | 11/294 [00:04<01:58,  2.39it/s][A
Iteration:   4%|▍         | 12/294 [00:05<01:57,  2.40it/s][A
Iteration:   4%|▍         | 13/294 [00:05<01:57,  2.39it/s][A
Iteration:   5%|▍         | 14/294 [00:05<01:56,  2.39it/s][A
Iteration:   5%|▌  

len: self.samples: 10



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:06,  1.33it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.33it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.33it/s][A
 50%|█████     | 5/10 [00:03<00:03,  1.33it/s][A
 60%|██████    | 6/10 [00:04<00:03,  1.33it/s][A
 70%|███████   | 7/10 [00:05<00:02,  1.33it/s][A
 80%|████████  | 8/10 [00:06<00:01,  1.33it/s][A
 90%|█████████ | 9/10 [00:06<00:00,  1.33it/s][A
100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
INFO:root:Queries: 10 	 Positives: Min 2.0, Mean 4.2, Max 5.0 	 Negatives: Min 100.0, Mean 100.0, Max 100.0
INFO:root:Save model to ./gdrive/MyDrive/legal_essir/finetuned_CEs/nlpaueb-legal-bert-base-uncased_training_cross-encoder--2023-08-29_20-31-38


metrics eval:  map_cut.1000: 0.4722537717249768 | ndcg_cut.10: 0.572427431535708 | recall.10: 0.63 | 


Epoch:  60%|██████    | 3/5 [06:36<04:24, 132.19s/it]
Iteration:   0%|          | 0/294 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/294 [00:00<02:09,  2.26it/s][A
Iteration:   1%|          | 2/294 [00:00<02:05,  2.33it/s][A
Iteration:   1%|          | 3/294 [00:01<02:04,  2.35it/s][A
Iteration:   1%|▏         | 4/294 [00:01<02:02,  2.36it/s][A
Iteration:   2%|▏         | 5/294 [00:02<02:02,  2.36it/s][A
Iteration:   2%|▏         | 6/294 [00:02<02:01,  2.37it/s][A
Iteration:   2%|▏         | 7/294 [00:02<02:00,  2.37it/s][A
Iteration:   3%|▎         | 8/294 [00:03<02:00,  2.38it/s][A
Iteration:   3%|▎         | 9/294 [00:03<01:59,  2.38it/s][A
Iteration:   3%|▎         | 10/294 [00:04<01:59,  2.38it/s][A
Iteration:   4%|▎         | 11/294 [00:04<01:58,  2.38it/s][A
Iteration:   4%|▍         | 12/294 [00:05<01:58,  2.38it/s][A
Iteration:   4%|▍         | 13/294 [00:05<01:58,  2.38it/s][A
Iteration:   5%|▍         | 14/294 [00:05<01:57,  2.38it/s][A
Iteration:   5%|▌  

len: self.samples: 10



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:06,  1.33it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.33it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.33it/s][A
 50%|█████     | 5/10 [00:03<00:03,  1.33it/s][A
 60%|██████    | 6/10 [00:04<00:03,  1.33it/s][A
 70%|███████   | 7/10 [00:05<00:02,  1.33it/s][A
 80%|████████  | 8/10 [00:06<00:01,  1.33it/s][A
 90%|█████████ | 9/10 [00:06<00:00,  1.33it/s][A
100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
INFO:root:Queries: 10 	 Positives: Min 2.0, Mean 4.2, Max 5.0 	 Negatives: Min 100.0, Mean 100.0, Max 100.0
INFO:root:Save model to ./gdrive/MyDrive/legal_essir/finetuned_CEs/nlpaueb-legal-bert-base-uncased_training_cross-encoder--2023-08-29_20-31-38


metrics eval:  map_cut.1000: 0.4845199883865149 | ndcg_cut.10: 0.5727487308637647 | recall.10: 0.6216666666666667 | 


Epoch:  80%|████████  | 4/5 [08:47<02:12, 132.06s/it]
Iteration:   0%|          | 0/294 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/294 [00:00<02:03,  2.38it/s][A
Iteration:   1%|          | 2/294 [00:00<02:02,  2.39it/s][A
Iteration:   1%|          | 3/294 [00:01<02:01,  2.39it/s][A
Iteration:   1%|▏         | 4/294 [00:01<02:01,  2.38it/s][A
Iteration:   2%|▏         | 5/294 [00:02<02:01,  2.38it/s][A
Iteration:   2%|▏         | 6/294 [00:02<02:00,  2.39it/s][A
Iteration:   2%|▏         | 7/294 [00:02<02:00,  2.39it/s][A
Iteration:   3%|▎         | 8/294 [00:03<01:59,  2.39it/s][A
Iteration:   3%|▎         | 9/294 [00:03<01:59,  2.39it/s][A
Iteration:   3%|▎         | 10/294 [00:04<01:58,  2.39it/s][A
Iteration:   4%|▎         | 11/294 [00:04<01:58,  2.39it/s][A
Iteration:   4%|▍         | 12/294 [00:05<01:58,  2.39it/s][A
Iteration:   4%|▍         | 13/294 [00:05<01:57,  2.39it/s][A
Iteration:   5%|▍         | 14/294 [00:05<01:57,  2.39it/s][A
Iteration:   5%|▌  

len: self.samples: 10



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:06,  1.33it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.33it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.32it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.33it/s][A
 50%|█████     | 5/10 [00:03<00:03,  1.33it/s][A
 60%|██████    | 6/10 [00:04<00:03,  1.33it/s][A
 70%|███████   | 7/10 [00:05<00:02,  1.33it/s][A
 80%|████████  | 8/10 [00:06<00:01,  1.33it/s][A
 90%|█████████ | 9/10 [00:06<00:00,  1.33it/s][A
100%|██████████| 10/10 [00:07<00:00,  1.33it/s]
INFO:root:Queries: 10 	 Positives: Min 2.0, Mean 4.2, Max 5.0 	 Negatives: Min 100.0, Mean 100.0, Max 100.0
INFO:root:Save model to ./gdrive/MyDrive/legal_essir/finetuned_CEs/nlpaueb-legal-bert-base-uncased_training_cross-encoder--2023-08-29_20-31-38


metrics eval:  map_cut.1000: 0.526240220991476 | ndcg_cut.10: 0.633306121866959 | recall.10: 0.6916666666666667 | 


Epoch: 100%|██████████| 5/5 [11:01<00:00, 132.23s/it]
INFO:root:Save model to ./gdrive/MyDrive/legal_essir/finetuned_CEs/nlpaueb-legal-bert-base-uncased_training_cross-encoder--2023-08-29_20-31-38-latest
