In [1]:
import os
import sys
import yaml
import json
import tiktoken
import openai
import torch
from torch.utils.data import DataLoader

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator

import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util, InputExample, SentencesDataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, InformationRetrievalEvaluator
import logging
from datetime import datetime
import gzip
import csv

import tarfile
import tqdm
import numpy as np
import wandb

from transformers import AutoTokenizer, AutoModel

root_path = '/home/ec2-user/sarang/wiki_cheat'

sys.path.insert(0, os.path.abspath(root_path))
os.chdir(root_path)


In [2]:
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

## Load dataset

In [3]:
### Combined dataset
test_data_path = 'data/combined_test_data.json'
with open(test_data_path, 'r') as fp:
    combined_test_data = json.load(fp)

train_data_path = 'data/combined_train_data.json'
with open(train_data_path, 'r') as fp:
    combined_train_data = json.load(fp)

### Synth dataset
test_data_path = 'data/test_data_synth_nf.json'
with open(test_data_path, 'r') as fp:
    test_data_synth = json.load(fp)

train_data_path = 'data/train_data_synth_nf.json'
with open(train_data_path, 'r') as fp:
    train_data_synth = json.load(fp)

### Wikiqa dataset
test_data_path = 'data/test_data_wikiqa_nf.json'
with open(test_data_path, 'r') as fp:
    test_data_wikiqa = json.load(fp)

train_data_path = 'data/train_data_wikiqa_nf.json'
with open(train_data_path, 'r') as fp:
    train_data_wikiqa = json.load(fp)

In [43]:
combined_test_data[0]

{'query': 'HOW AFRICAN AMERICANS WERE IMMIGRATED TO THE US',
 'title': 'African immigration to the United States',
 'pos': 'As such, African immigrants are to be distinguished from African American people, the latter of whom are descendants of mostly West and Central Africans who were involuntarily brought to the United States by means of the historic Atlantic slave trade .',
 'negs': ['African immigration to the United States refers to immigrants to the United States who are or were nationals of Africa .',
  'The term African in the scope of this article refers to geographical or national origins rather than racial affiliation.',
  'From the Immigration and Nationality Act of 1965 to 2007, an estimated total of 0.8 to 0.9 million Africans immigrated to the United States, accounting for roughly 3.3% of total immigration to the United States during this period.',
  'African immigrants in the United States come from almost all regions in Africa and do not constitute a homogeneous group.'

#### Quickly check how many samples are exceeding the max token limit. Only 1 so not a lot. Do not need to chunk the wiki passages further. 

In [33]:


tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

input_examples = []
for data in combined_train_data:
    input_examples.append(InputExample(texts=[data['query'], data['pos']]))
len(input_examples)

max_toks = 0
cnt = 0

for inp in input_examples:
    text = inp.texts[1]
    tokens = tokenizer([text], padding=True)
    if len(tokens['input_ids'][0]) > 512:
        cnt+=1
    max_toks = max(max_toks, len(tokens['input_ids'][0]))

max_toks, cnt

Token indices sequence length is longer than the specified maximum sequence length for this model (545 > 512). Running this sequence through the model will result in indexing errors


(545, 1)

## Evaluation Data

In [4]:
test_queries = { str(idx): data['query'] for idx, data in enumerate(combined_test_data) }

test_passages = { str(idx)+'_'+str(jdx+1): neg for idx, data in enumerate(combined_test_data) 
                 for jdx, neg in enumerate(data['negs']) }

for idx, data in enumerate(combined_test_data):
    test_passages[str(idx)+'_'+str(0)] = data['pos']

test_relevant_docs = { str(idx): set([str(idx)+'_'+str(0)]) for idx, data in enumerate(combined_test_data) }

In [5]:
len(test_passages)

24093

## Train code

In [41]:
def train(train_config, train_input_examples, test_queries, test_passages, test_relevant_docs):
    train_config['model_save_path'] = 'train_embedder/models/'+train_config['base_model'].replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    model = SentenceTransformer(train_config['base_model'])
    train_dataloader = DataLoader(input_examples, shuffle=True, batch_size=train_config['train_batch_size'])
    train_loss = losses.MultipleNegativesRankingLoss(model=model)
    
    evaluator = InformationRetrievalEvaluator(test_queries, test_passages, test_relevant_docs)
    warmup_steps = math.ceil(len(train_dataloader) * train_config['epochs'] * 0.1) #10% of train data for warm-up
    logging.info("Warmup-steps: {}".format(warmup_steps))
    
    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=evaluator,
              epochs=train_config['epochs'],
              warmup_steps=warmup_steps,
              output_path=train_config['model_save_path'],
              use_amp=True,
              show_progress_bar=True
            )

## Retriever Experiments

### v1.0 Base model, with 16 as the batch size, with the only the synth data
### v1.1 Base model, with 44 as the batch size, with the only the synth data
### v2.0 Base model, with 44 as the batch size, with the combined dataset

In [42]:
%%time
input_examples = []
for data in combined_train_data:
    input_examples.append(InputExample(texts=[data['query'], data['pos']]))

base_model_name = 'sentence-transformers/all-mpnet-base-v2'

train_config ={
        "base_model": base_model_name,
        "epochs": 10,
        "train_batch_size" : 44,
        "warmup_steps": 50,
        "max_length": 512,
        "device": 'cuda',
        "train": True
}

train(train_config, input_examples, test_queries, test_passages, test_relevant_docs)

2024-01-28 15:48:56 - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
2024-01-28 15:48:56 - Use pytorch device: cuda
2024-01-28 15:48:56 - Warmup-steps: 28


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/274 [00:00<?, ?it/s]

2024-01-28 15:49:56 - Information Retrieval Evaluation on  dataset after epoch 0:
2024-01-28 15:50:32 - Queries: 2184
2024-01-28 15:50:32 - Corpus: 24093

2024-01-28 15:50:32 - Score-Function: cos_sim
2024-01-28 15:50:32 - Accuracy@1: 83.38%
2024-01-28 15:50:32 - Accuracy@3: 95.01%
2024-01-28 15:50:32 - Accuracy@5: 97.02%
2024-01-28 15:50:32 - Accuracy@10: 98.63%
2024-01-28 15:50:32 - Precision@1: 83.38%
2024-01-28 15:50:32 - Precision@3: 31.67%
2024-01-28 15:50:32 - Precision@5: 19.40%
2024-01-28 15:50:32 - Precision@10: 9.86%
2024-01-28 15:50:32 - Recall@1: 83.38%
2024-01-28 15:50:32 - Recall@3: 95.01%
2024-01-28 15:50:32 - Recall@5: 97.02%
2024-01-28 15:50:32 - Recall@10: 98.63%
2024-01-28 15:50:32 - MRR@10: 0.8943
2024-01-28 15:50:32 - NDCG@10: 0.9172
2024-01-28 15:50:32 - MAP@100: 0.8950
2024-01-28 15:50:32 - Score-Function: dot_score
2024-01-28 15:50:32 - Accuracy@1: 83.42%
2024-01-28 15:50:32 - Accuracy@3: 95.05%
2024-01-28 15:50:32 - Accuracy@5: 97.02%
2024-01-28 15:50:32 - Acc

### v3.0 Base model, with 16 as the batch size, with the combined dataset, adding a hard negative as well this time
#### Batch size 32 and above kept throwing OOM
#### So there are about ~1000 dps which don't have a neg

In [5]:
%%time
base_model_name = 'sentence-transformers/all-mpnet-base-v2'

input_examples = []
for data in combined_train_data:
    if data['negs']:
        input_examples.append(InputExample(texts=[data['query'], data['pos'],data['negs'][0]]))
        
train_config ={
        "base_model": base_model_name,
        "epochs": 10,
        "train_batch_size" : 44,
        "warmup_steps": 50,
        "max_length": 512,
        "device": 'cuda',
        "train": True
}

train(train_config, input_examples, test_queries, test_passages, test_relevant_docs)

2024-01-28 02:55:09 - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
2024-01-28 02:55:09 - Use pytorch device: cuda
2024-01-28 02:55:09 - Warmup-steps: 479


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 02:56:42 - Information Retrieval Evaluation on  dataset after epoch 0:
2024-01-28 02:57:18 - Queries: 2184
2024-01-28 02:57:18 - Corpus: 24093

2024-01-28 02:57:18 - Score-Function: cos_sim
2024-01-28 02:57:18 - Accuracy@1: 81.64%
2024-01-28 02:57:18 - Accuracy@3: 92.58%
2024-01-28 02:57:18 - Accuracy@5: 95.33%
2024-01-28 02:57:18 - Accuracy@10: 97.44%
2024-01-28 02:57:18 - Precision@1: 81.64%
2024-01-28 02:57:18 - Precision@3: 30.86%
2024-01-28 02:57:18 - Precision@5: 19.07%
2024-01-28 02:57:18 - Precision@10: 9.74%
2024-01-28 02:57:18 - Recall@1: 81.64%
2024-01-28 02:57:18 - Recall@3: 92.58%
2024-01-28 02:57:18 - Recall@5: 95.33%
2024-01-28 02:57:18 - Recall@10: 97.44%
2024-01-28 02:57:18 - MRR@10: 0.8751
2024-01-28 02:57:18 - NDCG@10: 0.8995
2024-01-28 02:57:18 - MAP@100: 0.8761
2024-01-28 02:57:18 - Score-Function: dot_score
2024-01-28 02:57:18 - Accuracy@1: 81.68%
2024-01-28 02:57:18 - Accuracy@3: 92.58%
2024-01-28 02:57:18 - Accuracy@5: 95.33%
2024-01-28 02:57:18 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 02:58:50 - Information Retrieval Evaluation on  dataset after epoch 1:
2024-01-28 02:59:26 - Queries: 2184
2024-01-28 02:59:26 - Corpus: 24093

2024-01-28 02:59:26 - Score-Function: cos_sim
2024-01-28 02:59:26 - Accuracy@1: 82.23%
2024-01-28 02:59:26 - Accuracy@3: 91.80%
2024-01-28 02:59:26 - Accuracy@5: 94.78%
2024-01-28 02:59:26 - Accuracy@10: 96.98%
2024-01-28 02:59:26 - Precision@1: 82.23%
2024-01-28 02:59:26 - Precision@3: 30.60%
2024-01-28 02:59:26 - Precision@5: 18.96%
2024-01-28 02:59:26 - Precision@10: 9.70%
2024-01-28 02:59:26 - Recall@1: 82.23%
2024-01-28 02:59:26 - Recall@3: 91.80%
2024-01-28 02:59:26 - Recall@5: 94.78%
2024-01-28 02:59:26 - Recall@10: 96.98%
2024-01-28 02:59:26 - MRR@10: 0.8756
2024-01-28 02:59:26 - NDCG@10: 0.8987
2024-01-28 02:59:26 - MAP@100: 0.8770
2024-01-28 02:59:26 - Score-Function: dot_score
2024-01-28 02:59:26 - Accuracy@1: 82.23%
2024-01-28 02:59:26 - Accuracy@3: 91.80%
2024-01-28 02:59:26 - Accuracy@5: 94.78%
2024-01-28 02:59:26 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:01:00 - Information Retrieval Evaluation on  dataset after epoch 2:
2024-01-28 03:01:36 - Queries: 2184
2024-01-28 03:01:36 - Corpus: 24093

2024-01-28 03:01:36 - Score-Function: cos_sim
2024-01-28 03:01:36 - Accuracy@1: 79.95%
2024-01-28 03:01:36 - Accuracy@3: 91.03%
2024-01-28 03:01:36 - Accuracy@5: 93.68%
2024-01-28 03:01:36 - Accuracy@10: 96.29%
2024-01-28 03:01:36 - Precision@1: 79.95%
2024-01-28 03:01:36 - Precision@3: 30.34%
2024-01-28 03:01:36 - Precision@5: 18.74%
2024-01-28 03:01:36 - Precision@10: 9.63%
2024-01-28 03:01:36 - Recall@1: 79.95%
2024-01-28 03:01:36 - Recall@3: 91.03%
2024-01-28 03:01:36 - Recall@5: 93.68%
2024-01-28 03:01:36 - Recall@10: 96.29%
2024-01-28 03:01:36 - MRR@10: 0.8595
2024-01-28 03:01:36 - NDCG@10: 0.8849
2024-01-28 03:01:36 - MAP@100: 0.8611
2024-01-28 03:01:36 - Score-Function: dot_score
2024-01-28 03:01:36 - Accuracy@1: 79.95%
2024-01-28 03:01:36 - Accuracy@3: 91.07%
2024-01-28 03:01:36 - Accuracy@5: 93.68%
2024-01-28 03:01:36 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:03:08 - Information Retrieval Evaluation on  dataset after epoch 3:
2024-01-28 03:03:44 - Queries: 2184
2024-01-28 03:03:44 - Corpus: 24093

2024-01-28 03:03:44 - Score-Function: cos_sim
2024-01-28 03:03:44 - Accuracy@1: 81.55%
2024-01-28 03:03:44 - Accuracy@3: 91.90%
2024-01-28 03:03:44 - Accuracy@5: 94.92%
2024-01-28 03:03:44 - Accuracy@10: 97.16%
2024-01-28 03:03:44 - Precision@1: 81.55%
2024-01-28 03:03:44 - Precision@3: 30.63%
2024-01-28 03:03:44 - Precision@5: 18.98%
2024-01-28 03:03:44 - Precision@10: 9.72%
2024-01-28 03:03:44 - Recall@1: 81.55%
2024-01-28 03:03:44 - Recall@3: 91.90%
2024-01-28 03:03:44 - Recall@5: 94.92%
2024-01-28 03:03:44 - Recall@10: 97.16%
2024-01-28 03:03:44 - MRR@10: 0.8719
2024-01-28 03:03:44 - NDCG@10: 0.8964
2024-01-28 03:03:44 - MAP@100: 0.8731
2024-01-28 03:03:44 - Score-Function: dot_score
2024-01-28 03:03:44 - Accuracy@1: 81.59%
2024-01-28 03:03:44 - Accuracy@3: 91.85%
2024-01-28 03:03:44 - Accuracy@5: 94.92%
2024-01-28 03:03:44 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:05:16 - Information Retrieval Evaluation on  dataset after epoch 4:
2024-01-28 03:05:52 - Queries: 2184
2024-01-28 03:05:52 - Corpus: 24093

2024-01-28 03:05:52 - Score-Function: cos_sim
2024-01-28 03:05:52 - Accuracy@1: 83.65%
2024-01-28 03:05:52 - Accuracy@3: 93.13%
2024-01-28 03:05:52 - Accuracy@5: 95.15%
2024-01-28 03:05:52 - Accuracy@10: 97.71%
2024-01-28 03:05:52 - Precision@1: 83.65%
2024-01-28 03:05:52 - Precision@3: 31.04%
2024-01-28 03:05:52 - Precision@5: 19.03%
2024-01-28 03:05:52 - Precision@10: 9.77%
2024-01-28 03:05:52 - Recall@1: 83.65%
2024-01-28 03:05:52 - Recall@3: 93.13%
2024-01-28 03:05:52 - Recall@5: 95.15%
2024-01-28 03:05:52 - Recall@10: 97.71%
2024-01-28 03:05:52 - MRR@10: 0.8874
2024-01-28 03:05:52 - NDCG@10: 0.9094
2024-01-28 03:05:52 - MAP@100: 0.8884
2024-01-28 03:05:52 - Score-Function: dot_score
2024-01-28 03:05:52 - Accuracy@1: 83.65%
2024-01-28 03:05:52 - Accuracy@3: 93.13%
2024-01-28 03:05:52 - Accuracy@5: 95.15%
2024-01-28 03:05:52 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:07:26 - Information Retrieval Evaluation on  dataset after epoch 5:
2024-01-28 03:08:02 - Queries: 2184
2024-01-28 03:08:02 - Corpus: 24093

2024-01-28 03:08:02 - Score-Function: cos_sim
2024-01-28 03:08:02 - Accuracy@1: 83.01%
2024-01-28 03:08:02 - Accuracy@3: 92.54%
2024-01-28 03:08:02 - Accuracy@5: 95.15%
2024-01-28 03:08:02 - Accuracy@10: 97.25%
2024-01-28 03:08:02 - Precision@1: 83.01%
2024-01-28 03:08:02 - Precision@3: 30.85%
2024-01-28 03:08:02 - Precision@5: 19.03%
2024-01-28 03:08:02 - Precision@10: 9.73%
2024-01-28 03:08:02 - Recall@1: 83.01%
2024-01-28 03:08:02 - Recall@3: 92.54%
2024-01-28 03:08:02 - Recall@5: 95.15%
2024-01-28 03:08:02 - Recall@10: 97.25%
2024-01-28 03:08:02 - MRR@10: 0.8825
2024-01-28 03:08:02 - NDCG@10: 0.9046
2024-01-28 03:08:02 - MAP@100: 0.8838
2024-01-28 03:08:02 - Score-Function: dot_score
2024-01-28 03:08:02 - Accuracy@1: 83.01%
2024-01-28 03:08:02 - Accuracy@3: 92.54%
2024-01-28 03:08:02 - Accuracy@5: 95.15%
2024-01-28 03:08:02 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:09:33 - Information Retrieval Evaluation on  dataset after epoch 6:
2024-01-28 03:10:09 - Queries: 2184
2024-01-28 03:10:09 - Corpus: 24093

2024-01-28 03:10:10 - Score-Function: cos_sim
2024-01-28 03:10:10 - Accuracy@1: 82.88%
2024-01-28 03:10:10 - Accuracy@3: 92.77%
2024-01-28 03:10:10 - Accuracy@5: 95.05%
2024-01-28 03:10:10 - Accuracy@10: 97.48%
2024-01-28 03:10:10 - Precision@1: 82.88%
2024-01-28 03:10:10 - Precision@3: 30.92%
2024-01-28 03:10:10 - Precision@5: 19.01%
2024-01-28 03:10:10 - Precision@10: 9.75%
2024-01-28 03:10:10 - Recall@1: 82.88%
2024-01-28 03:10:10 - Recall@3: 92.77%
2024-01-28 03:10:10 - Recall@5: 95.05%
2024-01-28 03:10:10 - Recall@10: 97.48%
2024-01-28 03:10:10 - MRR@10: 0.8818
2024-01-28 03:10:10 - NDCG@10: 0.9046
2024-01-28 03:10:10 - MAP@100: 0.8829
2024-01-28 03:10:10 - Score-Function: dot_score
2024-01-28 03:10:10 - Accuracy@1: 82.83%
2024-01-28 03:10:10 - Accuracy@3: 92.72%
2024-01-28 03:10:10 - Accuracy@5: 95.05%
2024-01-28 03:10:10 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:11:42 - Information Retrieval Evaluation on  dataset after epoch 7:
2024-01-28 03:12:17 - Queries: 2184
2024-01-28 03:12:17 - Corpus: 24093

2024-01-28 03:12:18 - Score-Function: cos_sim
2024-01-28 03:12:18 - Accuracy@1: 83.15%
2024-01-28 03:12:18 - Accuracy@3: 92.54%
2024-01-28 03:12:18 - Accuracy@5: 94.92%
2024-01-28 03:12:18 - Accuracy@10: 97.21%
2024-01-28 03:12:18 - Precision@1: 83.15%
2024-01-28 03:12:18 - Precision@3: 30.85%
2024-01-28 03:12:18 - Precision@5: 18.98%
2024-01-28 03:12:18 - Precision@10: 9.72%
2024-01-28 03:12:18 - Recall@1: 83.15%
2024-01-28 03:12:18 - Recall@3: 92.54%
2024-01-28 03:12:18 - Recall@5: 94.92%
2024-01-28 03:12:18 - Recall@10: 97.21%
2024-01-28 03:12:18 - MRR@10: 0.8821
2024-01-28 03:12:18 - NDCG@10: 0.9041
2024-01-28 03:12:18 - MAP@100: 0.8833
2024-01-28 03:12:18 - Score-Function: dot_score
2024-01-28 03:12:18 - Accuracy@1: 83.15%
2024-01-28 03:12:18 - Accuracy@3: 92.54%
2024-01-28 03:12:18 - Accuracy@5: 94.92%
2024-01-28 03:12:18 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:13:49 - Information Retrieval Evaluation on  dataset after epoch 8:
2024-01-28 03:14:25 - Queries: 2184
2024-01-28 03:14:25 - Corpus: 24093

2024-01-28 03:14:25 - Score-Function: cos_sim
2024-01-28 03:14:25 - Accuracy@1: 83.42%
2024-01-28 03:14:25 - Accuracy@3: 92.90%
2024-01-28 03:14:25 - Accuracy@5: 95.19%
2024-01-28 03:14:25 - Accuracy@10: 97.39%
2024-01-28 03:14:25 - Precision@1: 83.42%
2024-01-28 03:14:25 - Precision@3: 30.97%
2024-01-28 03:14:25 - Precision@5: 19.04%
2024-01-28 03:14:25 - Precision@10: 9.74%
2024-01-28 03:14:25 - Recall@1: 83.42%
2024-01-28 03:14:25 - Recall@3: 92.90%
2024-01-28 03:14:25 - Recall@5: 95.19%
2024-01-28 03:14:25 - Recall@10: 97.39%
2024-01-28 03:14:25 - MRR@10: 0.8848
2024-01-28 03:14:25 - NDCG@10: 0.9067
2024-01-28 03:14:25 - MAP@100: 0.8860
2024-01-28 03:14:25 - Score-Function: dot_score
2024-01-28 03:14:25 - Accuracy@1: 83.33%
2024-01-28 03:14:25 - Accuracy@3: 92.90%
2024-01-28 03:14:25 - Accuracy@5: 95.19%
2024-01-28 03:14:25 - Acc

Iteration:   0%|          | 0/479 [00:00<?, ?it/s]

2024-01-28 03:15:56 - Information Retrieval Evaluation on  dataset after epoch 9:
2024-01-28 03:16:32 - Queries: 2184
2024-01-28 03:16:32 - Corpus: 24093

2024-01-28 03:16:33 - Score-Function: cos_sim
2024-01-28 03:16:33 - Accuracy@1: 83.42%
2024-01-28 03:16:33 - Accuracy@3: 92.77%
2024-01-28 03:16:33 - Accuracy@5: 95.10%
2024-01-28 03:16:33 - Accuracy@10: 97.34%
2024-01-28 03:16:33 - Precision@1: 83.42%
2024-01-28 03:16:33 - Precision@3: 30.92%
2024-01-28 03:16:33 - Precision@5: 19.02%
2024-01-28 03:16:33 - Precision@10: 9.73%
2024-01-28 03:16:33 - Recall@1: 83.42%
2024-01-28 03:16:33 - Recall@3: 92.77%
2024-01-28 03:16:33 - Recall@5: 95.10%
2024-01-28 03:16:33 - Recall@10: 97.34%
2024-01-28 03:16:33 - MRR@10: 0.8844
2024-01-28 03:16:33 - NDCG@10: 0.9063
2024-01-28 03:16:33 - MAP@100: 0.8856
2024-01-28 03:16:33 - Score-Function: dot_score
2024-01-28 03:16:33 - Accuracy@1: 83.42%
2024-01-28 03:16:33 - Accuracy@3: 92.77%
2024-01-28 03:16:33 - Accuracy@5: 95.10%
2024-01-28 03:16:33 - Acc

### v4.0 Base modle with 44 as the batch size, with the only the synth data, add Title to the passage context

In [45]:
%%time
input_examples = []
for data in train_data_synth:
    input_examples.append(InputExample(texts=[data['query'], data['title']+ ": " + data['pos']]))

base_model_name = 'sentence-transformers/all-mpnet-base-v2'

train_config ={
        "base_model": base_model_name,
        "epochs": 10,
        "train_batch_size" : 44,
        "warmup_steps": 50,
        "max_length": 512,
        "device": 'cuda',
        "train": True
}

train(train_config, input_examples, test_queries, test_passages, test_relevant_docs)

2024-01-28 16:01:35 - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
2024-01-28 16:01:35 - Use pytorch device: cuda
2024-01-28 16:01:35 - Warmup-steps: 182


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:02:32 - Information Retrieval Evaluation on  dataset after epoch 0:
2024-01-28 16:03:08 - Queries: 2184
2024-01-28 16:03:08 - Corpus: 24093

2024-01-28 16:03:08 - Score-Function: cos_sim
2024-01-28 16:03:08 - Accuracy@1: 83.38%
2024-01-28 16:03:08 - Accuracy@3: 94.78%
2024-01-28 16:03:08 - Accuracy@5: 97.12%
2024-01-28 16:03:08 - Accuracy@10: 98.58%
2024-01-28 16:03:08 - Precision@1: 83.38%
2024-01-28 16:03:08 - Precision@3: 31.59%
2024-01-28 16:03:08 - Precision@5: 19.42%
2024-01-28 16:03:08 - Precision@10: 9.86%
2024-01-28 16:03:08 - Recall@1: 83.38%
2024-01-28 16:03:08 - Recall@3: 94.78%
2024-01-28 16:03:08 - Recall@5: 97.12%
2024-01-28 16:03:08 - Recall@10: 98.58%
2024-01-28 16:03:08 - MRR@10: 0.8938
2024-01-28 16:03:08 - NDCG@10: 0.9167
2024-01-28 16:03:08 - MAP@100: 0.8944
2024-01-28 16:03:08 - Score-Function: dot_score
2024-01-28 16:03:08 - Accuracy@1: 83.29%
2024-01-28 16:03:08 - Accuracy@3: 94.78%
2024-01-28 16:03:08 - Accuracy@5: 97.12%
2024-01-28 16:03:08 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:04:05 - Information Retrieval Evaluation on  dataset after epoch 1:
2024-01-28 16:04:41 - Queries: 2184
2024-01-28 16:04:41 - Corpus: 24093

2024-01-28 16:04:41 - Score-Function: cos_sim
2024-01-28 16:04:41 - Accuracy@1: 82.37%
2024-01-28 16:04:41 - Accuracy@3: 94.05%
2024-01-28 16:04:41 - Accuracy@5: 96.57%
2024-01-28 16:04:41 - Accuracy@10: 98.12%
2024-01-28 16:04:41 - Precision@1: 82.37%
2024-01-28 16:04:41 - Precision@3: 31.35%
2024-01-28 16:04:41 - Precision@5: 19.31%
2024-01-28 16:04:41 - Precision@10: 9.81%
2024-01-28 16:04:41 - Recall@1: 82.37%
2024-01-28 16:04:41 - Recall@3: 94.05%
2024-01-28 16:04:41 - Recall@5: 96.57%
2024-01-28 16:04:41 - Recall@10: 98.12%
2024-01-28 16:04:41 - MRR@10: 0.8855
2024-01-28 16:04:41 - NDCG@10: 0.9093
2024-01-28 16:04:41 - MAP@100: 0.8863
2024-01-28 16:04:41 - Score-Function: dot_score
2024-01-28 16:04:41 - Accuracy@1: 82.42%
2024-01-28 16:04:41 - Accuracy@3: 94.00%
2024-01-28 16:04:41 - Accuracy@5: 96.57%
2024-01-28 16:04:41 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:05:39 - Information Retrieval Evaluation on  dataset after epoch 2:
2024-01-28 16:06:15 - Queries: 2184
2024-01-28 16:06:15 - Corpus: 24093

2024-01-28 16:06:16 - Score-Function: cos_sim
2024-01-28 16:06:16 - Accuracy@1: 81.91%
2024-01-28 16:06:16 - Accuracy@3: 93.96%
2024-01-28 16:06:16 - Accuracy@5: 96.25%
2024-01-28 16:06:16 - Accuracy@10: 98.08%
2024-01-28 16:06:16 - Precision@1: 81.91%
2024-01-28 16:06:16 - Precision@3: 31.32%
2024-01-28 16:06:16 - Precision@5: 19.25%
2024-01-28 16:06:16 - Precision@10: 9.81%
2024-01-28 16:06:16 - Recall@1: 81.91%
2024-01-28 16:06:16 - Recall@3: 93.96%
2024-01-28 16:06:16 - Recall@5: 96.25%
2024-01-28 16:06:16 - Recall@10: 98.08%
2024-01-28 16:06:16 - MRR@10: 0.8827
2024-01-28 16:06:16 - NDCG@10: 0.9071
2024-01-28 16:06:16 - MAP@100: 0.8835
2024-01-28 16:06:16 - Score-Function: dot_score
2024-01-28 16:06:16 - Accuracy@1: 81.91%
2024-01-28 16:06:16 - Accuracy@3: 93.96%
2024-01-28 16:06:16 - Accuracy@5: 96.25%
2024-01-28 16:06:16 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:07:12 - Information Retrieval Evaluation on  dataset after epoch 3:
2024-01-28 16:07:48 - Queries: 2184
2024-01-28 16:07:48 - Corpus: 24093

2024-01-28 16:07:48 - Score-Function: cos_sim
2024-01-28 16:07:48 - Accuracy@1: 82.65%
2024-01-28 16:07:48 - Accuracy@3: 94.28%
2024-01-28 16:07:48 - Accuracy@5: 96.66%
2024-01-28 16:07:48 - Accuracy@10: 98.21%
2024-01-28 16:07:48 - Precision@1: 82.65%
2024-01-28 16:07:48 - Precision@3: 31.43%
2024-01-28 16:07:48 - Precision@5: 19.33%
2024-01-28 16:07:48 - Precision@10: 9.82%
2024-01-28 16:07:48 - Recall@1: 82.65%
2024-01-28 16:07:48 - Recall@3: 94.28%
2024-01-28 16:07:48 - Recall@5: 96.66%
2024-01-28 16:07:48 - Recall@10: 98.21%
2024-01-28 16:07:48 - MRR@10: 0.8883
2024-01-28 16:07:48 - NDCG@10: 0.9117
2024-01-28 16:07:48 - MAP@100: 0.8891
2024-01-28 16:07:48 - Score-Function: dot_score
2024-01-28 16:07:48 - Accuracy@1: 82.60%
2024-01-28 16:07:48 - Accuracy@3: 94.28%
2024-01-28 16:07:48 - Accuracy@5: 96.66%
2024-01-28 16:07:48 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:08:46 - Information Retrieval Evaluation on  dataset after epoch 4:
2024-01-28 16:09:22 - Queries: 2184
2024-01-28 16:09:22 - Corpus: 24093

2024-01-28 16:09:22 - Score-Function: cos_sim
2024-01-28 16:09:22 - Accuracy@1: 81.78%
2024-01-28 16:09:22 - Accuracy@3: 93.82%
2024-01-28 16:09:22 - Accuracy@5: 96.29%
2024-01-28 16:09:22 - Accuracy@10: 97.94%
2024-01-28 16:09:22 - Precision@1: 81.78%
2024-01-28 16:09:22 - Precision@3: 31.27%
2024-01-28 16:09:22 - Precision@5: 19.26%
2024-01-28 16:09:22 - Precision@10: 9.79%
2024-01-28 16:09:22 - Recall@1: 81.78%
2024-01-28 16:09:22 - Recall@3: 93.82%
2024-01-28 16:09:22 - Recall@5: 96.29%
2024-01-28 16:09:22 - Recall@10: 97.94%
2024-01-28 16:09:22 - MRR@10: 0.8817
2024-01-28 16:09:22 - NDCG@10: 0.9060
2024-01-28 16:09:22 - MAP@100: 0.8827
2024-01-28 16:09:22 - Score-Function: dot_score
2024-01-28 16:09:22 - Accuracy@1: 81.64%
2024-01-28 16:09:22 - Accuracy@3: 93.82%
2024-01-28 16:09:22 - Accuracy@5: 96.29%
2024-01-28 16:09:22 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:10:19 - Information Retrieval Evaluation on  dataset after epoch 5:
2024-01-28 16:10:55 - Queries: 2184
2024-01-28 16:10:55 - Corpus: 24093

2024-01-28 16:10:55 - Score-Function: cos_sim
2024-01-28 16:10:55 - Accuracy@1: 82.23%
2024-01-28 16:10:55 - Accuracy@3: 93.59%
2024-01-28 16:10:55 - Accuracy@5: 96.34%
2024-01-28 16:10:55 - Accuracy@10: 97.94%
2024-01-28 16:10:55 - Precision@1: 82.23%
2024-01-28 16:10:55 - Precision@3: 31.20%
2024-01-28 16:10:55 - Precision@5: 19.27%
2024-01-28 16:10:55 - Precision@10: 9.79%
2024-01-28 16:10:55 - Recall@1: 82.23%
2024-01-28 16:10:55 - Recall@3: 93.59%
2024-01-28 16:10:55 - Recall@5: 96.34%
2024-01-28 16:10:55 - Recall@10: 97.94%
2024-01-28 16:10:55 - MRR@10: 0.8841
2024-01-28 16:10:55 - NDCG@10: 0.9078
2024-01-28 16:10:55 - MAP@100: 0.8851
2024-01-28 16:10:55 - Score-Function: dot_score
2024-01-28 16:10:55 - Accuracy@1: 82.19%
2024-01-28 16:10:55 - Accuracy@3: 93.59%
2024-01-28 16:10:55 - Accuracy@5: 96.34%
2024-01-28 16:10:55 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:11:53 - Information Retrieval Evaluation on  dataset after epoch 6:
2024-01-28 16:12:29 - Queries: 2184
2024-01-28 16:12:29 - Corpus: 24093

2024-01-28 16:12:29 - Score-Function: cos_sim
2024-01-28 16:12:29 - Accuracy@1: 82.65%
2024-01-28 16:12:29 - Accuracy@3: 94.23%
2024-01-28 16:12:29 - Accuracy@5: 96.47%
2024-01-28 16:12:29 - Accuracy@10: 98.17%
2024-01-28 16:12:29 - Precision@1: 82.65%
2024-01-28 16:12:29 - Precision@3: 31.41%
2024-01-28 16:12:29 - Precision@5: 19.29%
2024-01-28 16:12:29 - Precision@10: 9.82%
2024-01-28 16:12:29 - Recall@1: 82.65%
2024-01-28 16:12:29 - Recall@3: 94.23%
2024-01-28 16:12:29 - Recall@5: 96.47%
2024-01-28 16:12:29 - Recall@10: 98.17%
2024-01-28 16:12:29 - MRR@10: 0.8880
2024-01-28 16:12:29 - NDCG@10: 0.9113
2024-01-28 16:12:29 - MAP@100: 0.8888
2024-01-28 16:12:29 - Score-Function: dot_score
2024-01-28 16:12:29 - Accuracy@1: 82.55%
2024-01-28 16:12:29 - Accuracy@3: 94.23%
2024-01-28 16:12:29 - Accuracy@5: 96.47%
2024-01-28 16:12:29 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:13:26 - Information Retrieval Evaluation on  dataset after epoch 7:
2024-01-28 16:14:02 - Queries: 2184
2024-01-28 16:14:02 - Corpus: 24093

2024-01-28 16:14:02 - Score-Function: cos_sim
2024-01-28 16:14:02 - Accuracy@1: 82.78%
2024-01-28 16:14:02 - Accuracy@3: 94.41%
2024-01-28 16:14:02 - Accuracy@5: 96.70%
2024-01-28 16:14:02 - Accuracy@10: 98.26%
2024-01-28 16:14:02 - Precision@1: 82.78%
2024-01-28 16:14:02 - Precision@3: 31.47%
2024-01-28 16:14:02 - Precision@5: 19.34%
2024-01-28 16:14:02 - Precision@10: 9.83%
2024-01-28 16:14:02 - Recall@1: 82.78%
2024-01-28 16:14:02 - Recall@3: 94.41%
2024-01-28 16:14:02 - Recall@5: 96.70%
2024-01-28 16:14:02 - Recall@10: 98.26%
2024-01-28 16:14:02 - MRR@10: 0.8894
2024-01-28 16:14:02 - NDCG@10: 0.9127
2024-01-28 16:14:02 - MAP@100: 0.8903
2024-01-28 16:14:02 - Score-Function: dot_score
2024-01-28 16:14:02 - Accuracy@1: 82.74%
2024-01-28 16:14:02 - Accuracy@3: 94.41%
2024-01-28 16:14:02 - Accuracy@5: 96.70%
2024-01-28 16:14:02 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:14:59 - Information Retrieval Evaluation on  dataset after epoch 8:
2024-01-28 16:15:35 - Queries: 2184
2024-01-28 16:15:35 - Corpus: 24093

2024-01-28 16:15:36 - Score-Function: cos_sim
2024-01-28 16:15:36 - Accuracy@1: 82.60%
2024-01-28 16:15:36 - Accuracy@3: 94.18%
2024-01-28 16:15:36 - Accuracy@5: 96.52%
2024-01-28 16:15:36 - Accuracy@10: 98.31%
2024-01-28 16:15:36 - Precision@1: 82.60%
2024-01-28 16:15:36 - Precision@3: 31.39%
2024-01-28 16:15:36 - Precision@5: 19.30%
2024-01-28 16:15:36 - Precision@10: 9.83%
2024-01-28 16:15:36 - Recall@1: 82.60%
2024-01-28 16:15:36 - Recall@3: 94.18%
2024-01-28 16:15:36 - Recall@5: 96.52%
2024-01-28 16:15:36 - Recall@10: 98.31%
2024-01-28 16:15:36 - MRR@10: 0.8883
2024-01-28 16:15:36 - NDCG@10: 0.9119
2024-01-28 16:15:36 - MAP@100: 0.8890
2024-01-28 16:15:36 - Score-Function: dot_score
2024-01-28 16:15:36 - Accuracy@1: 82.65%
2024-01-28 16:15:36 - Accuracy@3: 94.18%
2024-01-28 16:15:36 - Accuracy@5: 96.52%
2024-01-28 16:15:36 - Acc

Iteration:   0%|          | 0/182 [00:00<?, ?it/s]

2024-01-28 16:16:33 - Information Retrieval Evaluation on  dataset after epoch 9:
2024-01-28 16:17:09 - Queries: 2184
2024-01-28 16:17:09 - Corpus: 24093

2024-01-28 16:17:09 - Score-Function: cos_sim
2024-01-28 16:17:09 - Accuracy@1: 82.33%
2024-01-28 16:17:09 - Accuracy@3: 94.14%
2024-01-28 16:17:09 - Accuracy@5: 96.52%
2024-01-28 16:17:09 - Accuracy@10: 98.26%
2024-01-28 16:17:09 - Precision@1: 82.33%
2024-01-28 16:17:09 - Precision@3: 31.38%
2024-01-28 16:17:09 - Precision@5: 19.30%
2024-01-28 16:17:09 - Precision@10: 9.83%
2024-01-28 16:17:09 - Recall@1: 82.33%
2024-01-28 16:17:09 - Recall@3: 94.14%
2024-01-28 16:17:09 - Recall@5: 96.52%
2024-01-28 16:17:09 - Recall@10: 98.26%
2024-01-28 16:17:09 - MRR@10: 0.8865
2024-01-28 16:17:09 - NDCG@10: 0.9105
2024-01-28 16:17:09 - MAP@100: 0.8873
2024-01-28 16:17:09 - Score-Function: dot_score
2024-01-28 16:17:09 - Accuracy@1: 82.33%
2024-01-28 16:17:09 - Accuracy@3: 94.14%
2024-01-28 16:17:09 - Accuracy@5: 96.52%
2024-01-28 16:17:09 - Acc

## Evaluation ( Eval on the combined test dataset )

In [8]:
evaluator = InformationRetrievalEvaluator(test_queries, test_passages, test_relevant_docts)

In [9]:
model_name = 'sentence-transformers/all-mpnet-base-v2'
model = SentenceTransformer(model_name, device='cuda')

  return self.fget.__get__(instance, owner)()


### Baseline original model 

In [12]:
evaluator.compute_metrices(model)

{'cos_sim': {'accuracy@k': {1: 0.7962454212454212,
   3: 0.9317765567765568,
   5: 0.9601648351648352,
   10: 0.9789377289377289},
  'precision@k': {1: 0.7962454212454212,
   3: 0.31059218559218554,
   5: 0.19203296703296702,
   10: 0.09789377289377288},
  'recall@k': {1: 0.7962454212454212,
   3: 0.9317765567765568,
   5: 0.9601648351648352,
   10: 0.9789377289377289},
  'ndcg@k': {10: 0.8956225018672112},
  'mrr@k': {10: 0.8680139470318038},
  'map@k': {100: 0.869036756178447}},
 'dot_score': {'accuracy@k': {1: 0.7957875457875457,
   3: 0.9317765567765568,
   5: 0.9601648351648352,
   10: 0.9789377289377289},
  'precision@k': {1: 0.7957875457875457,
   3: 0.31059218559218554,
   5: 0.19203296703296702,
   10: 0.09789377289377288},
  'recall@k': {1: 0.7957875457875457,
   3: 0.9317765567765568,
   5: 0.9601648351648352,
   10: 0.9789377289377289},
  'ndcg@k': {10: 0.8953935641382733},
  'mrr@k': {10: 0.8677086967265534},
  'map@k': {100: 0.8687315058731968}}}

### Model trained on the combined dataset, adding a Hard negative. Let the model train for a lot longer and then maybe the accuracy might increase

In [13]:
model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-28_02-55-09'
model = SentenceTransformer(model_name, device='cuda')
evaluator.compute_metrices(model)

{'cos_sim': {'accuracy@k': {1: 0.8365384615384616,
   3: 0.9313186813186813,
   5: 0.9514652014652014,
   10: 0.9771062271062271},
  'precision@k': {1: 0.8365384615384616,
   3: 0.31043956043956045,
   5: 0.19029304029304028,
   10: 0.0977106227106227},
  'recall@k': {1: 0.8365384615384616,
   3: 0.9313186813186813,
   5: 0.9514652014652014,
   10: 0.9771062271062271},
  'ndcg@k': {10: 0.9093850273309033},
  'mrr@k': {10: 0.8874282298389435},
  'map@k': {100: 0.8883934271428281}},
 'dot_score': {'accuracy@k': {1: 0.8365384615384616,
   3: 0.9313186813186813,
   5: 0.9514652014652014,
   10: 0.9771062271062271},
  'precision@k': {1: 0.8365384615384616,
   3: 0.31043956043956045,
   5: 0.19029304029304028,
   10: 0.0977106227106227},
  'recall@k': {1: 0.8365384615384616,
   3: 0.9313186813186813,
   5: 0.9514652014652014,
   10: 0.9771062271062271},
  'ndcg@k': {10: 0.9093850273309033},
  'mrr@k': {10: 0.8874282298389435},
  'map@k': {100: 0.8883963622419172}}}

### Model trained on new combined dataset without a hard negative. Batch size 44

In [14]:
model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-28_02-28-33'
model = SentenceTransformer(model_name, device='cuda')
evaluator.compute_metrices(model)

{'cos_sim': {'accuracy@k': {1: 0.8498168498168498,
   3: 0.9551282051282052,
   5: 0.9757326007326007,
   10: 0.9885531135531136},
  'precision@k': {1: 0.8498168498168498,
   3: 0.31837606837606836,
   5: 0.19514652014652012,
   10: 0.09885531135531135},
  'recall@k': {1: 0.8498168498168498,
   3: 0.9551282051282052,
   5: 0.9757326007326007,
   10: 0.9885531135531136},
  'ndcg@k': {10: 0.9264136299687921},
  'mrr@k': {10: 0.9056961887319026},
  'map@k': {100: 0.9061929212927448}},
 'dot_score': {'accuracy@k': {1: 0.8493589743589743,
   3: 0.9551282051282052,
   5: 0.9757326007326007,
   10: 0.9885531135531136},
  'precision@k': {1: 0.8493589743589743,
   3: 0.31837606837606836,
   5: 0.19514652014652012,
   10: 0.09885531135531135},
  'recall@k': {1: 0.8493589743589743,
   3: 0.9551282051282052,
   5: 0.9757326007326007,
   10: 0.9885531135531136},
  'ndcg@k': {10: 0.9262446417607205},
  'mrr@k': {10: 0.9054672510029649},
  'map@k': {100: 0.905962300198153}}}

### Model trained on the original synth dataset. 

#### Batch size 44

In [18]:
model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-27_20-14-10'
model = SentenceTransformer(model_name, device='cuda')
evaluator.compute_metrices(model)

{'cos_sim': {'accuracy@k': {1: 0.8598901098901099,
   3: 0.9565018315018315,
   5: 0.9752747252747253,
   10: 0.9880952380952381},
  'precision@k': {1: 0.8598901098901099,
   3: 0.3188339438339438,
   5: 0.19505494505494506,
   10: 0.09880952380952379},
  'recall@k': {1: 0.8598901098901099,
   3: 0.9565018315018315,
   5: 0.9752747252747253,
   10: 0.9880952380952381},
  'ndcg@k': {10: 0.9302517795134545},
  'mrr@k': {10: 0.9110340717483572},
  'map@k': {100: 0.9114858255539129}},
 'dot_score': {'accuracy@k': {1: 0.8598901098901099,
   3: 0.9565018315018315,
   5: 0.9752747252747253,
   10: 0.9880952380952381},
  'precision@k': {1: 0.8598901098901099,
   3: 0.3188339438339438,
   5: 0.19505494505494506,
   10: 0.09880952380952379},
  'recall@k': {1: 0.8598901098901099,
   3: 0.9565018315018315,
   5: 0.9752747252747253,
   10: 0.9880952380952381},
  'ndcg@k': {10: 0.9302517795134545},
  'mrr@k': {10: 0.9110340717483572},
  'map@k': {100: 0.911484142188259}}}

#### Batch size 32

In [19]:
model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-27_19-46-50'
model = SentenceTransformer(model_name, device='cuda')
evaluator.compute_metrices(model)

{'cos_sim': {'accuracy@k': {1: 0.8443223443223443,
   3: 0.9496336996336996,
   5: 0.9743589743589743,
   10: 0.9858058608058609},
  'precision@k': {1: 0.8443223443223443,
   3: 0.31654456654456653,
   5: 0.19487179487179485,
   10: 0.09858058608058608},
  'recall@k': {1: 0.8443223443223443,
   3: 0.9496336996336996,
   5: 0.9743589743589743,
   10: 0.9858058608058609},
  'ndcg@k': {10: 0.9222560706815424},
  'mrr@k': {10: 0.9010830934938074},
  'map@k': {100: 0.9016895475162915}},
 'dot_score': {'accuracy@k': {1: 0.8447802197802198,
   3: 0.9496336996336996,
   5: 0.9743589743589743,
   10: 0.9858058608058609},
  'precision@k': {1: 0.8447802197802198,
   3: 0.31654456654456653,
   5: 0.19487179487179485,
   10: 0.09858058608058608},
  'recall@k': {1: 0.8447802197802198,
   3: 0.9496336996336996,
   5: 0.9743589743589743,
   10: 0.9858058608058609},
  'ndcg@k': {10: 0.9224250588896141},
  'mrr@k': {10: 0.9013120312227452},
  'map@k': {100: 0.9019184852452292}}}

#### Batch size 16

In [21]:
model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-27_19-11-04'
model = SentenceTransformer(model_name, device='cuda')
evaluator.compute_metrices(model)

{'cos_sim': {'accuracy@k': {1: 0.8397435897435898,
   3: 0.9464285714285714,
   5: 0.9665750915750916,
   10: 0.9835164835164835},
  'precision@k': {1: 0.8397435897435898,
   3: 0.31547619047619047,
   5: 0.1933150183150183,
   10: 0.09835164835164835},
  'recall@k': {1: 0.8397435897435898,
   3: 0.9464285714285714,
   5: 0.9665750915750916,
   10: 0.9835164835164835},
  'ndcg@k': {10: 0.918007903869739},
  'mrr@k': {10: 0.8962708587708583},
  'map@k': {100: 0.8970177927622983}},
 'dot_score': {'accuracy@k': {1: 0.8397435897435898,
   3: 0.9459706959706959,
   5: 0.9665750915750916,
   10: 0.9835164835164835},
  'precision@k': {1: 0.8397435897435898,
   3: 0.3153235653235653,
   5: 0.1933150183150183,
   10: 0.09835164835164835},
  'recall@k': {1: 0.8397435897435898,
   3: 0.9459706959706959,
   5: 0.9665750915750916,
   10: 0.9835164835164835},
  'ndcg@k': {10: 0.9179761623670253},
  'mrr@k': {10: 0.8962327024827021},
  'map@k': {100: 0.8969785462944804}}}

### Model trained with title added as a prefix to the passage

#### Without title in the test set

In [46]:
model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-28_16-01-35'
model = SentenceTransformer(model_name, device='cuda')
evaluator.compute_metrices(model)

2024-01-28 16:18:53 - Load pretrained SentenceTransformer: train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-28_16-01-35
2024-01-28 16:19:29 - Queries: 2184
2024-01-28 16:19:29 - Corpus: 24093

2024-01-28 16:19:29 - Score-Function: cos_sim
2024-01-28 16:19:29 - Accuracy@1: 83.38%
2024-01-28 16:19:29 - Accuracy@3: 94.78%
2024-01-28 16:19:29 - Accuracy@5: 97.12%
2024-01-28 16:19:29 - Accuracy@10: 98.58%
2024-01-28 16:19:29 - Precision@1: 83.38%
2024-01-28 16:19:29 - Precision@3: 31.59%
2024-01-28 16:19:29 - Precision@5: 19.42%
2024-01-28 16:19:29 - Precision@10: 9.86%
2024-01-28 16:19:29 - Recall@1: 83.38%
2024-01-28 16:19:29 - Recall@3: 94.78%
2024-01-28 16:19:29 - Recall@5: 97.12%
2024-01-28 16:19:29 - Recall@10: 98.58%
2024-01-28 16:19:29 - MRR@10: 0.8938
2024-01-28 16:19:29 - NDCG@10: 0.9167
2024-01-28 16:19:29 - MAP@100: 0.8944
2024-01-28 16:19:29 - Score-Function: dot_score
2024-01-28 16:19:29 - Accuracy@1: 83.29%
2024-01-28 16:19:29 - Accuracy@3: 94.78%
2024-01

{'cos_sim': {'accuracy@k': {1: 0.8337912087912088,
   3: 0.9478021978021978,
   5: 0.9711538461538461,
   10: 0.9858058608058609},
  'precision@k': {1: 0.8337912087912088,
   3: 0.3159340659340659,
   5: 0.1942307692307692,
   10: 0.09858058608058606},
  'recall@k': {1: 0.8337912087912088,
   3: 0.9478021978021978,
   5: 0.9711538461538461,
   10: 0.9858058608058609},
  'ndcg@k': {10: 0.9167376423168481},
  'mrr@k': {10: 0.8938246773068198},
  'map@k': {100: 0.8943720453694837}},
 'dot_score': {'accuracy@k': {1: 0.8328754578754579,
   3: 0.9478021978021978,
   5: 0.9711538461538461,
   10: 0.9853479853479854},
  'precision@k': {1: 0.8328754578754579,
   3: 0.3159340659340659,
   5: 0.1942307692307692,
   10: 0.09853479853479853},
  'recall@k': {1: 0.8328754578754579,
   3: 0.9478021978021978,
   5: 0.9711538461538461,
   10: 0.9853479853479854},
  'ndcg@k': {10: 0.9162673102109987},
  'mrr@k': {10: 0.8933210143031569},
  'map@k': {100: 0.8939100074074458}}}

#### Lets evaluate when we add title as well to the passage in our test set. Performs even worse

In [47]:
test_queries = { str(idx): data['query'] for idx, data in enumerate(combined_test_data) }

test_passages = { str(idx)+'_'+str(jdx+1): data['title']+ ": " + neg for idx, data in enumerate(combined_test_data) 
                 for jdx, neg in enumerate(data['negs']) }

for idx, data in enumerate(combined_test_data):
    test_passages[str(idx)+'_'+str(0)] = data['pos']

test_relevant_docs = { str(idx): set([str(idx)+'_'+str(0)]) for idx, data in enumerate(combined_test_data) }

local_evaluator = InformationRetrievalEvaluator(test_queries, test_passages, test_relevant_docts)

model_name = 'train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-28_16-01-35'
model = SentenceTransformer(model_name, device='cuda')
local_evaluator.compute_metrices(model)

2024-01-28 16:21:18 - Load pretrained SentenceTransformer: train_embedder/models/sentence-transformers-all-mpnet-base-v2-2024-01-28_16-01-35
2024-01-28 16:21:55 - Queries: 2184
2024-01-28 16:21:55 - Corpus: 24093

2024-01-28 16:21:55 - Score-Function: cos_sim
2024-01-28 16:21:55 - Accuracy@1: 81.91%
2024-01-28 16:21:55 - Accuracy@3: 93.41%
2024-01-28 16:21:55 - Accuracy@5: 95.92%
2024-01-28 16:21:55 - Accuracy@10: 97.80%
2024-01-28 16:21:55 - Precision@1: 81.91%
2024-01-28 16:21:55 - Precision@3: 31.14%
2024-01-28 16:21:55 - Precision@5: 19.18%
2024-01-28 16:21:55 - Precision@10: 9.78%
2024-01-28 16:21:55 - Recall@1: 81.91%
2024-01-28 16:21:55 - Recall@3: 93.41%
2024-01-28 16:21:55 - Recall@5: 95.92%
2024-01-28 16:21:55 - Recall@10: 97.80%
2024-01-28 16:21:55 - MRR@10: 0.8803
2024-01-28 16:21:55 - NDCG@10: 0.9046
2024-01-28 16:21:55 - MAP@100: 0.8813
2024-01-28 16:21:55 - Score-Function: dot_score
2024-01-28 16:21:55 - Accuracy@1: 81.91%
2024-01-28 16:21:55 - Accuracy@3: 93.41%
2024-01

{'cos_sim': {'accuracy@k': {1: 0.8191391941391941,
   3: 0.9340659340659341,
   5: 0.9592490842490843,
   10: 0.978021978021978},
  'precision@k': {1: 0.8191391941391941,
   3: 0.31135531135531136,
   5: 0.19184981684981683,
   10: 0.09780219780219779},
  'recall@k': {1: 0.8191391941391941,
   3: 0.9340659340659341,
   5: 0.9592490842490843,
   10: 0.978021978021978},
  'ndcg@k': {10: 0.9045665650601143},
  'mrr@k': {10: 0.8803002718181286},
  'map@k': {100: 0.8813429766961823}},
 'dot_score': {'accuracy@k': {1: 0.8191391941391941,
   3: 0.9340659340659341,
   5: 0.9592490842490843,
   10: 0.978021978021978},
  'precision@k': {1: 0.8191391941391941,
   3: 0.31135531135531136,
   5: 0.19184981684981683,
   10: 0.09780219780219779},
  'recall@k': {1: 0.8191391941391941,
   3: 0.9340659340659341,
   5: 0.9592490842490843,
   10: 0.978021978021978},
  'ndcg@k': {10: 0.9045665650601143},
  'mrr@k': {10: 0.8803002718181286},
  'map@k': {100: 0.8813429766961823}}}