## Fine-tuning an L1 Dense Retriever (MPNet, E5)

## Imports

In [None]:
# %pip passes the line through a shell: quote version specifiers and extras so
# ">=" is not parsed as an output redirect and "[...]" is not glob-expanded.
%pip install --upgrade tqdm==4.66.5 blingfire einops "accelerate>=0.26.0" datasets "transformers[torch]" sentence-transformers

In [2]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import ContrastiveLoss, TripletLoss, TripletDistanceMetric


from tqdm import tqdm
import numpy as np
from random import choice

import json
import glob
import torch
import torch.nn.functional as F
import subprocess
from io import BytesIO
import pandas as pd
from random import shuffle

import os
# Set TOKENIZERS_PARALLELISM=false for this process (an environment variable
# read by the Hugging Face `tokenizers` library).
os.environ.update(TOKENIZERS_PARALLELISM="false")

## Data Loading

In [None]:
# Fetch the ObliQA dataset and build NIST trec_eval; the evaluator below runs
# the resulting trec_eval/trec_eval binary to compute Recall@k and MAP@k.
!git clone https://github.com/RegNLP/ObliQADataset.git
!git clone https://github.com/usnistgov/trec_eval.git
!cd trec_eval && make

In [4]:
from itertools import chain

def _write_trec_qrels(split):
    """Write TREC qrels for one ObliQA split to ``{split}_gt.trec``.

    Reads ``ObliQADataset/ObliQA_{split}.json`` and writes one line
    ``<QuestionID> Q0 <DocumentID>:<PassageID> 1`` per listed passage, with
    spaces in PassageID replaced by underscores so each id stays a single
    whitespace-free field. Only the first occurrence of each QuestionID is
    written.

    Args:
        split: ``"train"`` or ``"test"``.
    """
    # Use `with` so the input file is closed as soon as it has been parsed.
    with open(f"ObliQADataset/ObliQA_{split}.json", encoding="utf-8") as src:
        questions = json.load(src)

    seen_ids = set()
    with open(f"{split}_gt.trec", "w", encoding="utf-8") as f:
        for each_question in questions:
            question_id = each_question['QuestionID']
            if question_id in seen_ids:
                continue
            seen_ids.add(question_id)
            for selected_passage in each_question['Passages']:
                passage_id = selected_passage['PassageID'].replace(' ', '_')
                f.write(f"{question_id} Q0 {selected_passage['DocumentID']}:{passage_id} 1\n")


_write_trec_qrels("train")
_write_trec_qrels("test")

## Custom Data Manager for Sentence Transformers

In [5]:
class evaluate:
    """Sentence Transformer data manager and evaluator.

    - Loads the ObliQA passage corpus and the train/test questions from disk.
    - Computes Recall@k and MAP@k for each split with ``trec_eval``.
    - Streams training triplets with dynamic hard-negative mining that is
      refreshed every ``epoch_size`` samples.

    Embeddings come from the notebook-global ``model``, so it must exist before
    :meth:`evaluate` runs, and :meth:`evaluate` must run before :meth:`train`
    (``train`` reuses the embeddings that ``evaluate`` caches).
    """

    def load_json_files_from_directory(self, directory_path):
        """Load every ``*.json`` file in ``directory_path``.

        Files are visited in sorted order so the corpus (and therefore the
        passage index) is reproducible across runs.

        Returns:
            list of ``(path, parsed_json)`` tuples. Files that fail to parse
            are reported and skipped.
        """
        json_data_list = []
        for json_file in sorted(glob.glob(directory_path + "/*.json")):
            with open(json_file, 'r', encoding="utf-8") as f:
                try:
                    json_data_list.append((json_file, json.load(f)))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON in file {json_file}: {e}")
        return json_data_list

    def get_passage_id(self, e):
        """Return the global passage id ``"<DocumentID>:<PassageID>"``.

        Spaces in ``PassageID`` become underscores so the id is a single
        whitespace-free field in the space-separated TREC files.
        """
        return str(e['DocumentID']) + ":" + e['PassageID'].replace(" ", "_")

    def __init__(self, topk=10):
        """Load the corpus and both question splits.

        Args:
            topk: cutoff for Recall@k / MAP@k, and the rank threshold for
                hard-negative mining (positives ranked within ``0.6 * topk``
                are treated as already solved).
        """
        self.passages, self.passages_id = self.get_data()
        self.test_questions, self.test_questions_id, self.test_gt = self.get_questions("test")
        self.train_questions, self.train_questions_id, self.train_gt = self.get_questions("train")
        self.passage_id_to_index = {e: i for i, e in enumerate(self.passages_id)}
        self.topk = topk

    def get_questions(self, type):
        """Load the questions of one split.

        Args:
            type: ``"train"`` or ``"test"``; selects ``ObliQA_{type}.json``.

        Returns:
            ``(questions, questions_id, gt)`` parallel lists, where ``gt[i]`` is
            the list of global passage ids relevant to question ``i``.
            Repeated question ids and questions with five or fewer
            space-separated tokens are skipped (and printed).
        """
        with open(f'ObliQADataset/ObliQA_{type}.json', 'r', encoding="utf-8") as file:
            data = json.load(file)

        questions = []
        questions_id = []
        gt = []
        seen_ids = set()  # O(1) duplicate check instead of scanning questions_id
        for each_question in tqdm(data):
            question = each_question['Question']
            n_tokens = len(question.split(" "))
            if each_question['QuestionID'] not in seen_ids and n_tokens > 5:
                seen_ids.add(each_question['QuestionID'])
                questions.append(question)
                questions_id.append(each_question['QuestionID'])
                gt.append([self.get_passage_id(e) for e in each_question['Passages']])
            else:
                print(question, n_tokens)
        return questions, questions_id, gt

    def get_data(self):
        """Return ``(passages, passages_id)`` for the whole retrieval corpus.

        Passages with ten or fewer space-separated tokens are dropped. Each
        passage text is prefixed with ``"Article: <id>"`` and a newline. If the
        same id appears in several files, the last file visited wins.
        """
        all_sentences = {}
        for _, data in tqdm(self.load_json_files_from_directory("ObliQADataset/StructuredRegulatoryDocuments")):
            all_sentences |= {
                self.get_passage_id(e): e['Passage']
                for e in data
                if len(e['Passage'].split(" ")) > 10
            }

        passages_id = list(all_sentences)
        passages = ["Article: " + i + "\n " + text for i, text in all_sentences.items()]
        return passages, passages_id

    def get_embedding(self, e):
        """Encode a batch of texts with the global ``model`` and L2-normalise.

        Returns a 2-D numpy array with one row per text. An empty input is
        returned unchanged.
        """
        if len(e) > 0:
            embeddings = torch.Tensor(model.encode(e, batch_size=16, show_progress_bar=False))
            return np.array(F.normalize(embeddings, p=2, dim=1))
        else:
            return e

    def compute_batch_embeddings(self, sentences, batch_size=16):
        """Embed ``sentences`` in batches and stack the results row-wise."""
        batches = []
        with torch.no_grad():
            for start in tqdm(range(0, len(sentences), batch_size)):
                batch = sentences[start:start + batch_size]
                try:
                    batches.append(self.get_embedding(batch))
                except Exception as err:
                    # Show which inputs broke the encoder before propagating.
                    print(err)
                    print(batch)
                    raise
        # Each entry is a (batch, dim) array; join them into one matrix.
        return np.concatenate(batches)

    @staticmethod
    def _parse_trec_eval(stdout):
        """Parse trec_eval stdout into ``{measure_name: value}``.

        Only the aggregate (``all``) rows are kept. Measure names are stripped
        of the padding trec_eval adds, e.g. ``"recall_10"``.
        """
        scores = {}
        for line in stdout.decode("utf-8", errors="replace").splitlines():
            fields = line.split("\t")
            if len(fields) != 3 or fields[1].strip() != "all":
                continue
            scores[fields[0].strip()] = float(fields[2])
        return scores

    def _write_run(self, path, question_ids, question_embeddings):
        """Write the top-``topk`` passages per question as a TREC run file.

        Each line is ``<qid> 0 <passage_id> <rank> <score> alg``; ranks start
        at 1, in descending score order.
        """
        relevant_passages = question_embeddings.dot(self.passages_embeddings.T)
        # Ascending argsort reversed -> best-scoring passages first.
        top = np.argsort(relevant_passages, axis=1)[:, ::-1][:, :self.topk]
        with open(path, "w") as pred_rels:
            for i, passage_indices in enumerate(tqdm(top)):
                for c, p in enumerate(passage_indices):
                    pred_rels.write(
                        f"{question_ids[i]} 0 {self.passages_id[p]} {c+1} {relevant_passages[i][p]} alg\n"
                    )

    def evaluate(self):
        """Re-embed corpus and questions, then score both splits with trec_eval.

        Side effects:
            Caches ``passages_embeddings``, ``train_question_embeddings`` and
            ``test_question_embeddings`` (reused by :meth:`train`). Writes
            ``train.pred_rels`` and ``test.pred_rels``.

        Returns:
            dict with ``train_recall``, ``train_map``, ``test_recall`` and
            ``test_map`` (Recall@topk / MAP@topk). Returns ``None`` if
            trec_eval fails or its output lacks those measures; in that case
            its stderr is printed.
        """
        self.passages_embeddings = self.compute_batch_embeddings(self.passages)
        self.train_question_embeddings = self.compute_batch_embeddings(self.train_questions)
        self.test_question_embeddings = self.compute_batch_embeddings(self.test_questions)

        splits = (
            ("train", self.train_questions_id, self.train_question_embeddings),
            ("test", self.test_questions_id, self.test_question_embeddings),
        )
        results = {}
        for split, question_ids, question_embeddings in splits:
            self._write_run(f"{split}.pred_rels", question_ids, question_embeddings)
            output = subprocess.run(
                ["trec_eval/trec_eval", "-m", f"recall.{self.topk}", "-m", f"map_cut.{self.topk}",
                 f"{split}_gt.trec", f"{split}.pred_rels"],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
            # Look measures up by name rather than by output position.
            scores = self._parse_trec_eval(output.stdout)
            recall = scores.get(f"recall_{self.topk}")
            map_score = scores.get(f"map_cut_{self.topk}")
            if output.returncode != 0 or recall is None or map_score is None:
                print(f"trec_eval failed on the {split} split:", output.stderr)
                return None
            results[f"{split}_recall"] = recall
            results[f"{split}_map"] = map_score
        return results

    def _mine_question(self, i, scores, ranking):
        """Collect hard-negative candidates for training question ``i``.

        Args:
            i: index into the training questions.
            scores: similarity of question ``i`` to every passage.
            ranking: passage indices sorted by decreasing score.

        Returns:
            list of ``(distance, [question, positive, negative, rank])``.
            Relevant passages ranked below ``0.6 * topk`` are the positives.
            For each one, the non-relevant passages ranked 3 to 15 places
            above it become negatives. ``distance`` is the score gap
            (positive minus negative) weighted by the positive's rank.
        """
        gt_ids = set(self.train_gt[i])
        gt_indices = [self.passage_id_to_index[p] for p in gt_ids if p in self.passage_id_to_index]
        min_rank = 0.6 * self.topk
        mined = []
        # Ranks of the relevant passages, ascending (vectorised instead of
        # scanning every passage in Python).
        for rank in np.flatnonzero(np.isin(ranking, gt_indices)):
            rank = int(rank)
            if rank < min_rank:
                continue
            positive = ranking[rank]
            gt_score = scores[positive]
            # Clamp both ends at 0: a negative slice bound would wrap to the
            # end of the ranking and silently select no (or the wrong) passages.
            for negative in ranking[max(rank - 15, 0):max(rank - 2, 0)]:
                if self.passages_id[negative] in gt_ids:
                    continue
                mined.append((
                    (gt_score - scores[negative]) * rank,
                    [self.train_questions[i], self.passages[positive], self.passages[negative], rank],
                ))
        return mined

    def train(self, epoch_size=500, total_epocs=200):
        """Yield ``{"anchor", "positive", "negative"}`` training triplets.

        Runs for ``epoch_size * total_epocs`` samples. At the start of every
        block of ``epoch_size`` samples, hard negatives are re-mined from the
        embeddings cached by the most recent :meth:`evaluate` call (see
        :meth:`_mine_question`). The ``epoch_size`` pairs with the lowest
        distance are kept and shuffled; if fewer pairs exist, they are
        cycled through.

        Raises:
            RuntimeError: if no hard-negative pair can be mined.
        """
        for num in range(epoch_size * total_epocs):
            if num % epoch_size == 0:
                # Embeddings are cached by evaluate(); no need to recompute them.
                relevant_passages = self.train_question_embeddings.dot(self.passages_embeddings.T)
                indices = np.argsort(-1 * relevant_passages)
                distances = []
                pair = []
                for i in tqdm(range(len(relevant_passages))):
                    for distance, triplet in self._mine_question(i, relevant_passages[i], indices[i]):
                        distances.append(distance)
                        pair.append(triplet)
                # Release the (questions x passages) matrices before training resumes.
                del indices
                del relevant_passages
                if not pair:
                    raise RuntimeError(
                        "No hard-negative pairs could be mined from the current "
                        f"embeddings (topk={self.topk})."
                    )
                cache_passages = [
                    {"anchor": pair[j][0], "positive": pair[j][1], "negative": pair[j][2]}
                    for j in np.argsort(distances)[:epoch_size]
                ]
                shuffle(cache_passages)

            # Wrap around when fewer than epoch_size pairs were mined.
            yield cache_passages[(num % epoch_size) % len(cache_passages)]

In [6]:
# necessary to override these classes to integrate the custom evaluator

from sentence_transformers.evaluation import SentenceEvaluator
class SmartEvaluator(SentenceEvaluator):
    """Adapter that exposes ``data_manager.evaluate()`` to the trainer.

    The ``model`` argument is not used directly: ``data_manager`` embeds with
    the notebook-global ``model``. That is presumably the same object the
    trainer passes in — confirm if the trainer ever wraps or copies it.
    """

    def __init__(self):
        # Run the base initialiser so the attributes SentenceEvaluator defines
        # (presumably greater_is_better as well as primary_metric) exist.
        super().__init__()
        # Must name a key of the dict data_manager.evaluate() returns ("recall"
        # alone never appears there); matches metric_for_best_model="test_recall".
        self.primary_metric = "test_recall"

    def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1):
        """Run the full retrieval evaluation.

        Returns the metrics dict from ``data_manager.evaluate()``, or ``None``
        if that evaluation failed.
        """
        return data_manager.evaluate()

class MyIterableDataset(torch.utils.data.IterableDataset):
    """Stream training triplets from the notebook-global ``data_manager``.

    Args:
        epoch_size: samples between hard-negative refreshes (default
            ``820 * 16``).
        total_epochs: number of refresh cycles (default 100).
    """

    def __init__(self, epoch_size=820 * 16, total_epochs=100):
        # Zero-argument super(): `super(MyIterableDataset)` without `self`
        # builds an unbound super object and never reaches the base initialiser.
        super().__init__()
        self.epoch_size = epoch_size
        self.total_epochs = total_epochs

    def __iter__(self):
        yield from data_manager.train(epoch_size=self.epoch_size, total_epocs=self.total_epochs)

    def __len__(self):
        # Derived from the same parameters __iter__ uses, so the two cannot drift.
        return self.epoch_size * self.total_epochs


## Initializing

In [None]:
# Load the corpus and both question splits; topk=10 sets the Recall@10 / MAP@10
# cutoff and the hard-negative rank threshold (0.6 * topk).
data_manager = evaluate(topk=10)

In [None]:
# Encoder to fine-tune. `model` is also the global that evaluate.get_embedding
# uses for every retrieval embedding.
# NOTE(review): the intfloat/e5 model card recommends prefixing inputs with
# "query: " / "passage: "; this notebook adds no such prefixes — confirm intended.
MODEL_NAME = "intfloat/e5-base-v2"
model = SentenceTransformer(MODEL_NAME)

# Triplet loss over (anchor, positive, negative) using cosine distance, margin 0.2.
loss = TripletLoss(model, distance_metric=TripletDistanceMetric.COSINE, triplet_margin=0.2)

In [None]:
# Freeze every parameter of the first module (model[0]) except the pooler and
# "layer.11" (presumably the last encoder layer of this 12-layer base model), to
# save compute and keep fine-tuning stable. Trainable parameter names are printed.
for param_name, param in model[0].named_parameters():
    trainable = ".pooler." in param_name or "layer.11" in param_name
    if trainable:
        print(param_name)
    else:
        param.requires_grad = False

## Fine-tuning

In [10]:
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, BatchSamplers
import accelerate

args = SentenceTransformerTrainingArguments(
    # Run bookkeeping and logging
    output_dir="regnlp/",
    run_name="layer11.e5-basev2",
    report_to="wandb",
    logging_steps=10,
    # Optimisation: a fixed step budget. 820 steps x 100 epochs matches
    # MyIterableDataset's 820*16*100 samples at the loader's batch size of 16.
    max_steps=820 * 100,
    learning_rate=2e-6,
    lr_scheduler_type='inverse_sqrt',
    warmup_ratio=0,
    # Batching and precision.
    # NOTE(review): custom_trainer builds its own DataLoader (batch size 16, no
    # batch sampler), so the train batch size and batch_sampler below likely do
    # not affect training batches.
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    fp16=True,
    bf16=False,
    # Evaluate and checkpoint together every 800 steps
    eval_strategy="steps",
    eval_steps=800,
    save_strategy="steps",
    save_steps=800,
    save_total_limit=3,
    metric_for_best_model="test_recall",
)


In [None]:
# NOTE(review): WANDB_WATCH is read by the Hugging Face wandb integration;
# "all" presumably logs both gradients and parameters — confirm in wandb docs.
%env WANDB_WATCH=all

In [12]:
from sentence_transformers.trainer import SentenceTransformerTrainer
from torch.utils.data import IterableDataset
from typing import Dict, Union, Any
class custom_trainer(SentenceTransformerTrainer):
    """SentenceTransformerTrainer whose training batches come from
    ``MyIterableDataset`` (dynamically mined triplets) rather than a
    ``train_dataset``.

    NOTE(review): the loader uses a fixed batch size of 16 and no batch
    sampler, so the training arguments' train batch size and ``batch_sampler``
    do not shape these batches.
    """

    def get_train_dataloader(self):
        """Return a main-process DataLoader over freshly mined triplets."""
        triplets = MyIterableDataset()
        return torch.utils.data.DataLoader(
            triplets,
            batch_size=16,
            num_workers=0,
            collate_fn=self.data_collator,
        )


# Training data comes from custom_trainer.get_train_dataloader and evaluation
# from SmartEvaluator, so no datasets or collator are passed explicitly.
trainer = custom_trainer(
    model=model,
    args=args,
    loss=loss,
    evaluator=SmartEvaluator(),
    train_dataset=None,
    eval_dataset=None,
    data_collator=None,
)

In [None]:
# Evaluate before training: trainer.evaluate() runs SmartEvaluator, which calls
# data_manager.evaluate(). That caches the passage and question embeddings
# data_manager.train() needs for hard-negative mining.
print(trainer.evaluate())
t = trainer.train()

In [None]:
# Final evaluation of the fine-tuned model (Recall@k / MAP@k on train and test).
trainer.evaluate()