In [3]:
import random

import pandas as pd
from datasets import Dataset
from datasets import load_dataset

In [4]:

# Load the MS MARCO dataset (config "v1.1") from the Hugging Face Hub
ds = load_dataset(path="microsoft/ms_marco", name="v1.1")

In [41]:
print(ds["train"].num_rows)  # number of training examples (queries)

82326


In [6]:
# Initialize mappings
doc_text_to_id = {}
doc_id_to_text = {}
next_doc_id = 0

In [22]:
# Walk every passage of every training example.
# - Each distinct passage text is registered under the next sequential id the
#   first time it is seen (mappings are shared with the init cell).
# - `all_documents` records every occurrence, duplicates included, in encounter order.
all_documents = []
for example in ds["train"]:
    for text in example["passages"]["passage_text"]:
        all_documents.append(text)
        if text in doc_text_to_id:
            continue  # already registered under an earlier id
        doc_text_to_id[text] = next_doc_id
        doc_id_to_text[next_doc_id] = text
        next_doc_id += 1
# One entry was appended per passage, so this is the total passage count
total_passages_count = len(all_documents)

In [53]:
# Unique documents, then total passages (the latter includes duplicates)
print(next_doc_id, total_passages_count, sep="\n")

626907
676193


In [29]:
# Build [query_id, pos_doc_id, neg_doc_id] triplets per batch
def build_triplets_batch(batch, text_to_id=None, id_to_text=None, max_attempts=10, rng=None):
    """Build at most one (query, positive, negative) triplet per query in a batch.

    Meant for ``Dataset.map(..., batched=True)``.

    - The positive is the first passage whose ``is_selected`` flag is truthy.
      Any further selected passages are ignored.
    - The negative is drawn uniformly from the deduplicated document table
      (``id_to_text``) rather than from the raw passage list. Passages that
      recur across many queries are therefore not over-sampled.
    - A passage attached to the query itself is never used as its negative.

    A query is skipped when it has no selected passage, when its positive text
    has no id, or when no document outside the query's own passages exists.

    Args:
        batch: Column-oriented batch with ``"query_id"`` and ``"passages"``.
            Each passages entry has parallel ``"passage_text"`` and
            ``"is_selected"`` lists.
        text_to_id: Passage text -> doc id. Defaults to the global
            ``doc_text_to_id``.
        id_to_text: Doc id -> passage text. Defaults to the global
            ``doc_id_to_text``. Random draws pick ids from
            ``range(len(id_to_text))``, matching the sequential ids assigned
            when the mappings were built. Ids outside that range are still
            reachable through the fallback scan.
        max_attempts: Random draws to try before falling back to a full scan.
        rng: Object providing ``randrange`` and ``choice`` (e.g.
            ``random.Random(seed)``). Defaults to the module-level ``random``.

    Returns:
        Dict with equal-length lists ``"query_id"``, ``"pos_doc_id"`` and
        ``"neg_doc_id"``.
    """
    if text_to_id is None:
        text_to_id = doc_text_to_id
    if id_to_text is None:
        id_to_text = doc_id_to_text
    if rng is None:
        rng = random
    num_docs = len(id_to_text)

    output = {
        "query_id": [],
        "pos_doc_id": [],
        "neg_doc_id": [],
    }

    for query_id, passage_dict in zip(batch["query_id"], batch["passages"]):
        passage_texts = passage_dict["passage_text"]
        is_selected_flags = passage_dict["is_selected"]

        # Positive: id of the first selected passage (None if none is selected)
        pos_doc_id = next(
            (text_to_id.get(text) for text, flag in zip(passage_texts, is_selected_flags) if flag),
            None,
        )
        if pos_doc_id is None:
            continue

        # None of this query's own passages may serve as its negative
        current_query_docs = set(passage_texts)

        # Fast path: a few uniform draws over unique doc ids
        neg_doc_id = None
        for _ in range(max_attempts if num_docs else 0):
            candidate = rng.randrange(num_docs)
            candidate_text = id_to_text.get(candidate)
            if candidate_text is not None and candidate_text not in current_query_docs:
                neg_doc_id = candidate
                break

        if neg_doc_id is None:
            # Random draws can all miss on a tiny corpus. Scan instead, so a query
            # is dropped only when no valid negative exists at all.
            fallback_ids = [
                doc_id for doc_id, text in id_to_text.items() if text not in current_query_docs
            ]
            if not fallback_ids:
                continue
            neg_doc_id = rng.choice(fallback_ids)

        output["query_id"].append(query_id)
        output["pos_doc_id"].append(pos_doc_id)
        output["neg_doc_id"].append(neg_doc_id)

    return output

In [43]:
# Sample one triplet per query across the whole training split.
# Seed Python's global RNG first so the random negatives are reproducible.
# map() is not given num_proc here, so batches run in this process and use this RNG.
SEED = 42
random.seed(SEED)

batched_output = ds["train"].map(
    build_triplets_batch,
    batched=True,
    remove_columns=ds["train"].column_names,
    desc="Building triplets",  # progress bar
)

In [44]:
# Peek at one triplet, then report how many triplets were produced
first_triplet = batched_output[0]
print(first_triplet)
print(batched_output.num_rows)

{'query_id': 19699, 'pos_doc_id': 5, 'neg_doc_id': 592984}
79704


In [46]:
# Create DataFrames for each table.
# Read whole columns instead of iterating `ds["train"]` row by row. Row iteration
# materialises every example, including all of its passages, just to read two fields.
train_split = ds["train"]
queries_df = pd.DataFrame({
    "query_id": list(train_split["query_id"]),
    "query_text": list(train_split["query"]),
})
documents_df = pd.DataFrame(list(doc_id_to_text.items()), columns=["doc_id", "doc_text"])
# to_pandas() builds the frame from the underlying Arrow table in one step,
# instead of iterating the Dataset into Python dicts first
triplets_df = batched_output.to_pandas()[["query_id", "pos_doc_id", "neg_doc_id"]]

# Save tables
queries_df.to_csv("queries.tsv", sep="\t", index=False)
documents_df.to_csv("documents.tsv", sep="\t", index=False)
triplets_df.to_csv("triplets.tsv", sep="\t", index=False)

In [57]:
# Out of curiosity: are any positive documents shared by more than one query?
# pos_doc_counts has one row per positive doc id. Its "query_id" column holds how
# many triplets (queries) use that doc as their positive.
pos_doc_counts = (
    triplets_df
    .groupby("pos_doc_id")["query_id"]
    .count()
    .reset_index()
)
duplicate_pos_docs = pos_doc_counts.query("query_id > 1")

# To inspect the queries behind one shared doc:
#   triplets_df[triplets_df["pos_doc_id"] == some_doc_id]
message = (
    f"Found {len(duplicate_pos_docs)} positive documents that are used by multiple queries"
    if len(duplicate_pos_docs) > 0
    else "No duplicate positive documents found across queries."
)
print(message)

Found 1862 positive documents that are used by multiple queries


In [59]:
# Publish the three tables to the Hugging Face Hub.
# Pushing requires Hub credentials with write access to HUB_NAMESPACE.
HUB_NAMESPACE = "amyf"  # change to publish under a different user/org

queries_dataset = Dataset.from_pandas(queries_df)
documents_dataset = Dataset.from_pandas(documents_df)
triplets_dataset = Dataset.from_pandas(triplets_df)

queries_dataset.push_to_hub(f"{HUB_NAMESPACE}/ms-marco-queries-train")
documents_dataset.push_to_hub(f"{HUB_NAMESPACE}/ms-marco-documents-train")
triplets_dataset.push_to_hub(f"{HUB_NAMESPACE}/ms-marco-triplets-train")

Creating parquet from Arrow format: 100%|██████████| 83/83 [00:00<00:00, 2520.80ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.33s/it]
Creating parquet from Arrow format: 100%|██████████| 627/627 [00:00<00:00, 1155.24ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:18<00:00, 18.29s/it]
Creating parquet from Arrow format: 100%|██████████| 80/80 [00:00<00:00, 2701.32ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.09s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/amyf/ms-marco-triplets-train/commit/4a8b99603c7ca9dfd3d10a97a82655763c48b67b', commit_message='Upload dataset', commit_description='', oid='4a8b99603c7ca9dfd3d10a97a82655763c48b67b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/amyf/ms-marco-triplets-train', endpoint='https://huggingface.co', repo_type='dataset', repo_id='amyf/ms-marco-triplets-train'), pr_revision=None, pr_num=None)