# Train

In [1]:
max_seq_length = 512
model_name = "distilbert-base-uncased" 
dataset = "msmarco_tiny"

dataset_path = "../datasets/msmarco_tiny/"
corpus_file = "tiny_collection.json"
queries_file = "topics.dl20.txt"
qrels_test_file = "qrels.dl20-passage.txt"
training_set = "msmarco_triples.train.tiny.tsv"

In [2]:

from sentence_transformers import losses, models, SentenceTransformer
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.train import TrainRetriever
import pathlib, os, tqdm
import logging

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

  from tqdm.autonotebook import tqdm, trange


In [4]:
# data_path = f"../datasets/{dataset}"
# corpus, queries, qrels = GenericDataLoader(data_path).load(split="train")

In [5]:
import collections
import pytrec_eval
import json

def load_triplets(path):
    triplets = []
    with open(path) as f:
        for line in f:
            query, positive_passage, negative_passage = line.strip().split('\t')
            triplets.append([query, positive_passage, negative_passage])
    return triplets

def load_corpus_json(path):
    with open(path, 'r') as corpus_f:
        corpus_json = json.load(corpus_f)
    return corpus_json


triplets_temp = load_triplets(f"{dataset_path}{training_set}")
corpus = load_corpus_json(f"{dataset_path}{corpus_file}")

In [6]:
#### Lexical Retrieval using Bm25 (Elasticsearch) ####

## elasticsearch settings
hostname = "localhost" #localhost
index_name = dataset # scifact
initialize = True # True - Delete existing index and re-index all documents from scratch 

number_of_shards = 1
model = BM25(index_name=index_name, hostname=hostname, initialize=initialize, number_of_shards=number_of_shards)
bm25 = EvaluateRetrieval(model)

#### Index passages into the index (seperately)
bm25.retriever.index(corpus)


2024-06-02 16:06:55 - Activating Elasticsearch....
2024-06-02 16:06:55 - Elastic Search Credentials: {'hostname': 'localhost', 'index_name': 'msmarco_tiny', 'keys': {'title': 'title', 'body': 'txt'}, 'timeout': 100, 'retry_on_timeout': True, 'maxsize': 24, 'number_of_shards': 1, 'language': 'english'}
2024-06-02 16:06:55 - Deleting previous Elasticsearch-Index named - msmarco_tiny
2024-06-02 16:06:58 - Creating fresh Elasticsearch-Index named - msmarco_tiny


  0%|          | 0/510585 [00:00<?, ?docs/s]                


In [7]:

triplets = []
hard_negatives_max = 10

#### Retrieve BM25 hard negatives => Given a positive document, find most similar lexical documents
for query_text, pos_doc_text, neg_doc_text in tqdm.tqdm(triplets_temp, desc="Retrieve Hard Negatives using BM25"):
    hits = bm25.retriever.es.lexical_multisearch(texts=[pos_doc_text], top_hits=hard_negatives_max+1)
    for (neg_id, _) in hits[0].get("hits"):
        if corpus[neg_id]["text"] != neg_doc_text:
            neg_text = corpus[neg_id]["text"]
            triplets.append([query_text, pos_doc_text, neg_text])


Retrieve Hard Negatives using BM25: 100%|██████████| 11000/11000 [03:54<00:00, 46.86it/s]


In [8]:
import pickle

with open(f"{dataset_path}{dataset}bm25_triplets.pickle", 'wb') as f:
            pickle.dump(triplets, f, protocol=pickle.HIGHEST_PROTOCOL)

In [9]:
import pickle

with open(f"{dataset_path}{dataset}bm25_triplets.pickle", 'rb') as f:
    triplets = pickle.load(f)

In [10]:
triplets_temp[0]

['is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'It is generally safe for pregnant women to eat chocolate because studies have shown to prove certain benefits of eating chocolate during pregnancy. However, pregnant women should ensure their caffeine intake is below 200 mg per day.']

In [11]:
triplets[0]

['is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'Should I limit caffeine during pregnancy? If youâ\x80\x99re pregnant, you should limit the amount of caffeine you have to 200 milligrams (mg) a day â\x80\x93 the equivalent of two mugs of instant coffee. Caffeine is found naturally in lots of foods, such as coffee, tea and chocolate.']

In [12]:
triplets[1]

['is a little caffeine ok during pregnancy',
 'We donâ\x80\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\x80\x99s best to limit the amount you get each day. If youâ\x80\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.',
 'Limit the amount of caffeine you get each day to 200 mg during pregnancy. Drinks and foods with caffeine incldue coffee, tea, energy drinks, soft drinks and chocolate. Limit the amount of caffeine you get each day to 200 mg during pregnancy.']

In [25]:
#### Provide any sentence-transformers or HF model
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

#### Provide a high batch-size to train better with triplets!
retriever = TrainRetriever(model=model, batch_size=12)



2024-06-02 14:22:32 - Use pytorch device_name: cuda


In [26]:

#### Prepare triplets samples
train_samples = retriever.load_train_triplets(triplets=triplets)
train_dataloader = retriever.prepare_train_triplets(train_samples)

#### Training SBERT with cosine-product
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)

#### Prepare dev evaluator
# ir_evaluator = retriever.load_ir_evaluator(dev_corpus, dev_queries, dev_qrels)

#### If no dev set is present from above use dummy evaluator
ir_evaluator = retriever.load_dummy_evaluator()

#### Provide model save path
model_save_path = os.path.join(os.getcwd(), "../output", "{}-v2-{}-bm25-hard-negs".format(model_name, dataset))
os.makedirs(model_save_path, exist_ok=True)


Adding Input Examples: 100%|██████████| 10081/10081 [00:00<00:00, 56787.88it/s]


2024-06-02 14:22:47 - Loaded 120965 training pairs.


In [27]:
#### Configure Train params
num_epochs = 10
evaluation_steps = 5000
warmup_steps = int(len(train_samples) * num_epochs / retriever.batch_size * 0.1)

retriever.fit(train_objectives=[(train_dataloader, train_loss)], 
                evaluator=ir_evaluator, 
                epochs=num_epochs,
                output_path=model_save_path,
                warmup_steps=warmup_steps,
                evaluation_steps=evaluation_steps,
                use_amp=True)


2024-06-02 14:22:51 - Starting to Train...




Step,Training Loss,Validation Loss,Sequential Score
5000,0.1183,No log,1717364138.819354
10000,0.0551,No log,1717364900.493634
10080,0.0551,No log,1717364916.61267
15000,0.022,No log,1717365665.962539
20000,0.0136,No log,1717366427.345196
20160,0.0136,No log,1717366455.142648
25000,0.0091,No log,1717367190.91179
30000,0.0075,No log,1717367955.457952
30240,0.0075,No log,1717367996.056918
35000,0.0053,No log,1717368727.613854


2024-06-02 14:35:38 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 14:48:20 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 14:48:36 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:01:05 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:13:47 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:14:15 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:26:30 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:39:15 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:39:56 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


                                                                             

2024-06-02 15:52:07 - Save model to /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-msmarco_tiny-bm25-hard-negs


: 

# Evaluate

In [9]:
# Loading test set
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

2024-05-29 03:12:57 - Loading Corpus...


100%|██████████| 5183/5183 [00:00<00:00, 17583.97it/s]


2024-05-29 03:12:58 - Loaded 5183 TEST Documents.
2024-05-29 03:12:58 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers

In [10]:
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

## Load retriever from saved model

model = DRES(models.SentenceBERT(model_save_path), batch_size=128)
retriever = EvaluateRetrieval(model, score_function="cos_sim")

#### Retrieve dense results (format of results is identical to qrels)
results = retriever.retrieve(corpus, queries)

2024-05-29 03:13:01 - Loading faiss with AVX2 support.
2024-05-29 03:13:02 - Successfully loaded faiss with AVX2 support.
2024-05-29 03:13:02 - Use pytorch device_name: cuda
2024-05-29 03:13:02 - Load pretrained SentenceTransformer: /mnt/c/D_drive/UCSD/Quarters/Q3/DSC253-Adv_txt_mining/Project/slm4search/src/../output/distilbert-base-uncased-v2-scifact-bm25-hard-negs
2024-05-29 03:13:09 - Encoding Queries...


Batches: 100%|██████████| 3/3 [00:00<00:00,  3.91it/s]


2024-05-29 03:13:10 - Sorting Corpus by document length (Longest first)...
2024-05-29 03:13:10 - Scoring Function: Cosine Similarity (cos_sim)
2024-05-29 03:13:10 - Encoding Batch 1/1...


Batches: 100%|██████████| 41/41 [00:36<00:00,  1.14it/s]


In [11]:
#### Evaluate your retrieval using NDCG@k, MAP@K ...
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
ndcg, _map, recall, precision

2024-05-29 03:13:46 - Retriever evaluation for k in: [1, 3, 5, 10, 100, 1000]
2024-05-29 03:13:46 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-05-29 03:13:46 - 

2024-05-29 03:13:46 - NDCG@1: 0.6767
2024-05-29 03:13:46 - NDCG@3: 0.7256
2024-05-29 03:13:46 - NDCG@5: 0.7348
2024-05-29 03:13:46 - NDCG@10: 0.7430
2024-05-29 03:13:46 - NDCG@100: 0.7634
2024-05-29 03:13:46 - NDCG@1000: 0.7725
2024-05-29 03:13:46 - 

2024-05-29 03:13:46 - MAP@1: 0.6409
2024-05-29 03:13:46 - MAP@3: 0.7052
2024-05-29 03:13:46 - MAP@5: 0.7131
2024-05-29 03:13:46 - MAP@10: 0.7168
2024-05-29 03:13:46 - MAP@100: 0.7207
2024-05-29 03:13:46 - MAP@1000: 0.7210
2024-05-29 03:13:46 - 

2024-05-29 03:13:46 - Recall@1: 0.6409
2024-05-29 03:13:46 - Recall@3: 0.7594
2024-05-29 03:13:46 - Recall@5: 0.7855
2024-05-29 03:13:46 - Recall@10: 0.8089
2024-05-29 03:13:46 - Recall@100: 0.9061
2024-05-29 03:13:46 - Recall@1000: 0.9782

({'NDCG@1': 0.67667,
  'NDCG@3': 0.72557,
  'NDCG@5': 0.73475,
  'NDCG@10': 0.74305,
  'NDCG@100': 0.76343,
  'NDCG@1000': 0.77253},
 {'MAP@1': 0.64094,
  'MAP@3': 0.70522,
  'MAP@5': 0.71306,
  'MAP@10': 0.71684,
  'MAP@100': 0.72071,
  'MAP@1000': 0.72102},
 {'Recall@1': 0.64094,
  'Recall@3': 0.75939,
  'Recall@5': 0.7855,
  'Recall@10': 0.80894,
  'Recall@100': 0.90611,
  'Recall@1000': 0.97822},
 {'P@1': 0.67667,
  'P@3': 0.27667,
  'P@5': 0.17267,
  'P@10': 0.08933,
  'P@100': 0.01013,
  'P@1000': 0.0011})