In [19]:
import ir_datasets
from collections import defaultdict
import pandas as pd
import torch
import random

MAX_QUERIES = 3000

# Load the MS MARCO passage dataset
dataset = ir_datasets.load("msmarco-passage")

# Load the queries from the development set
queries_dataset = ir_datasets.load("msmarco-passage/dev")

# Create dictionaries to store passages and queries
passages = {}
queries = {}
qrels = defaultdict(dict)

# Load passages
for passage in dataset.docs_iter():
    passages[passage.doc_id] = passage.text

# Load queries
for query in queries_dataset.queries_iter():
    queries[query.query_id] = query.text

# Load qrels
for qrel in queries_dataset.qrels_iter():
    qrels[qrel.query_id][qrel.doc_id] = qrel.relevance

# Create triplets with positives and negatives
triplets = []
passage_ids = list(passages.keys())

# Limit to max 100 queries
for query_id, doc_dict in list(qrels.items())[:MAX_QUERIES]:
    # Add positive examples
    for doc_id, relevance in doc_dict.items():
        if query_id in queries and doc_id in passages:
            triplet = (queries[query_id], passages[doc_id], relevance)
            triplets.append(triplet)
            
            # Add 8 random negative examples for this query
            negative_passages = []
            while len(negative_passages) < 8:
                neg_doc_id = random.choice(passage_ids)
                # Make sure negative example isn't actually positive
                if neg_doc_id not in doc_dict:
                    triplet = (queries[query_id], passages[neg_doc_id], 0)
                    triplets.append(triplet)
                    negative_passages.append(neg_doc_id)

# Create DataFrame from triplets
df = pd.DataFrame(triplets, columns=['query', 'passage', 'label'])

# Display first few rows
print(df.head())


                      query  \
0  . what is a corporation?   
1  . what is a corporation?   
2  . what is a corporation?   
3  . what is a corporation?   
4  . what is a corporation?   

                                             passage  label  
0  McDonald's Corporation is one of the most reco...      1  
1  Exclusive provider organization (EPO) vs. pref...      0  
2  You could train rear delts first on shoulder d...      0  
3  [ more ]. Iona is a very prominent first name ...      0  
4  Human diploid cells contain two sets of 23 chr...      0  


In [21]:
df["label"].value_counts()

label
0    25776
1     3222
Name: count, dtype: int64

In [22]:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
model = AutoModel.from_pretrained("BAAI/bge-m3")
model.eval()



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


XLMRobertaModel(
  (embeddings): XLMRobertaEmbeddings(
    (word_embeddings): Embedding(250002, 1024, padding_idx=1)
    (position_embeddings): Embedding(8194, 1024, padding_idx=1)
    (token_type_embeddings): Embedding(1, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): XLMRobertaEncoder(
    (layer): ModuleList(
      (0-23): 24 x XLMRobertaLayer(
        (attention): XLMRobertaAttention(
          (self): XLMRobertaSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): XLMRobertaSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, elem

In [23]:
tokenized_query = tokenizer(df["query"].tolist(), padding=True, truncation=True, return_tensors="pt", max_length=512)
tokenized_passage = tokenizer(df["passage"].tolist(), padding=True, truncation=True, return_tensors="pt", max_length=512)

In [45]:
DEVICE = "cuda:4"
model = model.to(DEVICE)
tokenized_query = tokenized_query.to(DEVICE)
tokenized_passage = tokenized_passage.to(DEVICE)


In [39]:
from tqdm import tqdm
import time

def compute_similarity_scores(model, tokenized_query, tokenized_passage, batch_size=128):
    num_samples = len(tokenized_query['input_ids'])
    embeddings_query = []
    embeddings_passage = []

    start_time = time.time()

    with torch.no_grad():
        # Process queries in batches
        for i in tqdm(range(0, num_samples, batch_size), desc="Processing queries"):
            batch_query = {k: v[i:i+batch_size] for k,v in tokenized_query.items()}
            model_output_query = model(**batch_query)
            embeddings_query.append(model_output_query[0][:, 0])
            
        # Process passages in batches  
        for i in tqdm(range(0, num_samples, batch_size), desc="Processing passages"):
            batch_passage = {k: v[i:i+batch_size] for k,v in tokenized_passage.items()}
            model_output_passage = model(**batch_passage)
            embeddings_passage.append(model_output_passage[0][:, 0])

    # Concatenate batches
    embeddings_query = torch.cat(embeddings_query)
    embeddings_passage = torch.cat(embeddings_passage)

    dot = (embeddings_query * embeddings_passage).sum(axis=1)
    scores = dot.cpu().numpy()

    end_time = time.time()
    print(f"\nTotal processing time: {end_time - start_time:.2f} seconds")
    
    return scores

In [40]:
df["score"] = compute_similarity_scores(model, tokenized_query, tokenized_passage, batch_size=128)

Processing queries: 100%|██████████| 227/227 [00:26<00:00,  8.53it/s]
Processing passages: 100%|██████████| 227/227 [01:58<00:00,  1.92it/s]



Total processing time: 144.93 seconds


In [41]:
batch_query = {k: v[0:128] for k,v in tokenized_query.items()}
preds = model(**batch_query)
preds

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-1.1897,  0.1981, -0.1613,  ..., -0.4873, -1.4341, -1.0992],
         [-0.5305, -0.3591,  0.5178,  ..., -0.1679, -1.1997, -0.4868],
         [-0.3913, -0.1612,  0.5222,  ..., -0.2366, -1.1795, -0.0294],
         ...,
         [-0.3763,  0.5333,  0.4813,  ..., -0.4275, -1.1674, -0.3303],
         [-0.3763,  0.5333,  0.4813,  ..., -0.4275, -1.1674, -0.3303],
         [-0.3763,  0.5333,  0.4813,  ..., -0.4275, -1.1674, -0.3303]],

        [[-1.1897,  0.1981, -0.1613,  ..., -0.4873, -1.4341, -1.0992],
         [-0.5305, -0.3591,  0.5178,  ..., -0.1679, -1.1997, -0.4868],
         [-0.3913, -0.1612,  0.5222,  ..., -0.2366, -1.1795, -0.0294],
         ...,
         [-0.3763,  0.5333,  0.4813,  ..., -0.4275, -1.1674, -0.3303],
         [-0.3763,  0.5333,  0.4813,  ..., -0.4275, -1.1674, -0.3303],
         [-0.3763,  0.5333,  0.4813,  ..., -0.4275, -1.1674, -0.3303]],

        [[-1.1897,  0.1981, -0.1613,  ..., -0.4873, -

In [43]:
wrapped_model = torch.nn.DataParallel(model, device_ids=[4,5,6,7])
wrapped_model

DataParallel(
  (module): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 1024, padding_idx=1)
      (position_embeddings): Embedding(8194, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=T

In [46]:
preds2 = wrapped_model(**batch_query)
preds.last_hidden_state[0,0,:5], preds2.last_hidden_state[0,0,:5]

(tensor([-1.1897,  0.1981, -0.1613,  0.2960, -0.2720], device='cuda:1',
        grad_fn=<SliceBackward0>),
 tensor([-1.1897,  0.1981, -0.1613,  0.2960, -0.2720], device='cuda:4',
        grad_fn=<SliceBackward0>))

In [50]:
df["score"] = compute_similarity_scores(wrapped_model, tokenized_query, tokenized_passage, batch_size=4 * 128)

Processing queries: 100%|██████████| 57/57 [00:13<00:00,  4.14it/s]
Processing passages: 100%|██████████| 57/57 [00:35<00:00,  1.59it/s]



Total processing time: 49.93 seconds
