# MS MARCO Data Preparation from Parquet Files

This notebook processes the MS MARCO Passage Ranking dataset from local parquet files for our Learning to Rank (LTR) project. We'll convert it into a structured JSON format that our learning-to-rank model can consume.

## 1. Introduction and Setup

First, let's import the necessary libraries and set up our environment.

In [1]:
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import random
from pathlib import Path
from collections import defaultdict

In [2]:
# Configuration
DATA_DIR = Path("../data")
RAW_DATA_DIR = DATA_DIR / "raw"
PROCESSED_DATA_DIR = DATA_DIR / "processed"

# Create directories if they don't exist
PROCESSED_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Dataset size controls - useful for development
MAX_SAMPLES = None  # Set to a number for development, None for full dataset
USE_SAMPLE_FOR_STATS = 5000  # Number of samples to use for statistics

## 2. Loading the MS MARCO Parquet Files

Load the MS MARCO dataset from local parquet files in the `data/raw` directory.

In [3]:
# File paths for the parquet files
file_paths = {
    'train': RAW_DATA_DIR / 'train-00000-of-00001.parquet',
    'validation': RAW_DATA_DIR / 'validation-00000-of-00001.parquet',
    'test': RAW_DATA_DIR / 'test-00000-of-00001.parquet'
}

# Verify that the files exist
for name, path in file_paths.items():
    if not path.exists():
        print(f"Warning: {path} does not exist!")
    else:
        print(f"Found {name} file at {path}")

Found train file at ../data/raw/train-00000-of-00001.parquet
Found validation file at ../data/raw/validation-00000-of-00001.parquet
Found test file at ../data/raw/test-00000-of-00001.parquet


In [4]:
def load_msmarco_from_parquet(file_path, max_samples=None):
    """Load MS MARCO data from a parquet file."""
    print(f"Loading data from {file_path}...")
    
    # Load the parquet file
    df = pd.read_parquet(file_path)
    
    # Limit the number of samples if specified
    if max_samples is not None:
        df = df.head(max_samples)
    
    print(f"Loaded {len(df)} rows with columns: {df.columns.tolist()}")
    
    return df

In [5]:
# Load data from all available files
data_frames = {}
for split, path in file_paths.items():
    if path.exists():
        try:
            df = load_msmarco_from_parquet(path, max_samples=MAX_SAMPLES)
            data_frames[split] = df
        except Exception as e:
            print(f"Error loading {split} data: {e}")

# Combine all splits if desired
if len(data_frames) > 0:
    # We can optionally combine all splits, but for now let's keep them separate
    # all_data = pd.concat(data_frames.values(), ignore_index=True)
    # print(f"Combined dataset has {len(all_data)} rows")
    
    # For now we'll primarily use the training data for our processing
    main_data = data_frames.get('train', next(iter(data_frames.values())))
    print(f"Using {len(main_data)} examples for processing")
else:
    print("No data was loaded. Please check the file paths.")

Loading data from ../data/raw/train-00000-of-00001.parquet...
Loaded 82326 rows with columns: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers']
Loading data from ../data/raw/validation-00000-of-00001.parquet...
Loaded 10047 rows with columns: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers']
Loading data from ../data/raw/test-00000-of-00001.parquet...
Loaded 9650 rows with columns: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers']
Using 82326 examples for processing


## 3. Exploratory Data Analysis

Let's examine the data to understand its structure and characteristics.
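Each row's `passages` field is a struct of parallel sequences (`passage_text`, `is_selected`, and possibly `url`), with one entry per candidate passage. After `pd.read_parquet`, these sequences may arrive as NumPy arrays rather than Python lists, in which case list-only methods such as `.index()` are unavailable. A minimal sketch of a helper that works with either form; `selected_indices` is a hypothetical name, and the toy row below is invented for illustration:

```python
import numpy as np

def selected_indices(passages):
    """Return the positions of all passages flagged as selected (is_selected == 1).

    np.asarray accepts both Python lists and NumPy arrays, so the helper
    does not care which form the parquet loader produced.
    """
    flags = np.asarray(passages["is_selected"])
    return np.flatnonzero(flags == 1).tolist()

# Toy row shaped like the parquet `passages` struct (values are made up)
toy_row_passages = {
    "passage_text": np.array(["first passage", "second passage", "third passage"], dtype=object),
    "is_selected": np.array([0, 1, 0]),
}
print(selected_indices(toy_row_passages))  # → [1]
```

Returning every selected index (rather than only the first) makes it easy to spot rows where more than one passage is flagged.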

In [7]:
# Display a sample row to understand the structure
if 'main_data' in locals() and len(main_data) > 0:
    sample_row = main_data.iloc[0]
    print("Sample query:")
    print(f"Query ID: {sample_row['query_id']}")
    print(f"Query text: {sample_row['query']}")
    print(f"Query type: {sample_row['query_type']}")
    print("\nPassages structure:")
    passages = sample_row['passages']
    print(f"Number of passages: {len(passages['passage_text'])}")
    
    # Find the selected passage; is_selected may be a NumPy array or a list
    is_selected = np.asarray(passages['is_selected'])
    selected_indices = np.where(is_selected == 1)[0]  # Array of indices where is_selected == 1
    selected_idx = selected_indices[0] if len(selected_indices) > 0 else None
    
    if selected_idx is not None:
        print(f"\nSelected passage (index {selected_idx}):")
        selected_text = passages['passage_text'][selected_idx]
        display_text = selected_text[:200] + "..." if len(selected_text) > 200 else selected_text
        print(display_text)
        if 'url' in passages:
            print(f"URL: {passages['url'][selected_idx]}")
    else:
        print("\nNo passage is marked as selected")

Sample query:
Query ID: 19699
Query text: what is rba
Query type: description

Passages structure:
Number of passages: 10

In [None]:
# Analyze query lengths
if 'main_data' in locals() and len(main_data) > 0:
    # Use a reproducible random sample for statistics if the dataset is large
    stats_data = main_data.sample(n=USE_SAMPLE_FOR_STATS, random_state=42) if len(main_data) > USE_SAMPLE_FOR_STATS else main_data
    
    query_lengths = [len(query.split()) for query in stats_data['query']]
    avg_query_length = sum(query_lengths) / len(query_lengths) if query_lengths else 0
    
    plt.figure(figsize=(10, 6))
    plt.hist(query_lengths, bins=30, alpha=0.7)
    plt.axvline(avg_query_length, color='r', linestyle='dashed', linewidth=2, label=f'Avg: {avg_query_length:.2f} words')
    plt.title('Distribution of Query Lengths')
    plt.xlabel('Number of Words')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"Average query length: {avg_query_length:.2f} words")
    print(f"Minimum query length: {min(query_lengths)} words")
    print(f"Maximum query length: {max(query_lengths)} words")

In [None]:
# Analyze passage lengths
if 'main_data' in locals() and len(main_data) > 0:
    # Use a reproducible random sample for statistics if the dataset is large
    stats_data = main_data.sample(n=USE_SAMPLE_FOR_STATS, random_state=42) if len(main_data) > USE_SAMPLE_FOR_STATS else main_data
    
    # Get selected passages
    selected_passages = []
    for _, row in stats_data.iterrows():
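        passages = row['passages']
        is_selected = list(passages['is_selected'])  # convert in case it loaded as a NumPy array (no .index())
        if 1 in is_selected:
            selected_idx = is_selected.index(1)  # first selected passage
            selected_passages.append(passages['passage_text'][selected_idx])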
    
    passage_lengths = [len(passage.split()) for passage in selected_passages if passage]
    avg_passage_length = sum(passage_lengths) / len(passage_lengths) if passage_lengths else 0
    
    plt.figure(figsize=(10, 6))
    plt.hist(passage_lengths, bins=50, alpha=0.7)
    plt.axvline(avg_passage_length, color='r', linestyle='dashed', linewidth=2, label=f'Avg: {avg_passage_length:.2f} words')
    plt.title('Distribution of Selected Passage Lengths')
    plt.xlabel('Number of Words')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"Average passage length: {avg_passage_length:.2f} words")
    print(f"Minimum passage length: {min(passage_lengths)} words")
    print(f"Maximum passage length: {max(passage_lengths)} words")

In [None]:
# Analyze number of passages per query
if 'main_data' in locals() and len(main_data) > 0:
    # Use a reproducible random sample for statistics if the dataset is large
    stats_data = main_data.sample(n=USE_SAMPLE_FOR_STATS, random_state=42) if len(main_data) > USE_SAMPLE_FOR_STATS else main_data
    
    passages_per_query = [len(row['passages']['passage_text']) for _, row in stats_data.iterrows()]
    avg_passages = sum(passages_per_query) / len(passages_per_query) if passages_per_query else 0
    
    plt.figure(figsize=(10, 6))
    plt.hist(passages_per_query, bins=30, alpha=0.7)
    plt.axvline(avg_passages, color='r', linestyle='dashed', linewidth=2, label=f'Avg: {avg_passages:.2f} passages')
    plt.title('Distribution of Passages per Query')
    plt.xlabel('Number of Passages')
    plt.ylabel('Count')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"Average passages per query: {avg_passages:.2f}")
    print(f"Minimum passages per query: {min(passages_per_query)}")
    print(f"Maximum passages per query: {max(passages_per_query)}")

## 4. Data Processing

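Now, let's process the data into our structured JSON format. It has three parts: `queries` maps each query ID to its text, `passages` maps generated passage IDs (`p0`, `p1`, and so on) to passage text, and `matches` records, for every query that has a selected passage, the IDs of all its candidate passages (`suggested`) and of the selected one (`selected`).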
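For orientation, here is a hand-built sketch of the three dictionaries for a single toy row, assuming the same ID scheme as `extract_structured_data` (string query IDs, sequential `p{counter}` passage IDs starting at 0). The query and passage text are invented placeholders, not real MS MARCO content:

```python
import json

# One invented row: a query with three candidate passages, the second marked as selected
toy_query_id, toy_query_text = "123", "example query"
toy_texts = ["passage A", "passage B", "passage C"]
toy_flags = [0, 1, 0]

# Sequential passage IDs, mirroring the "p{counter}" scheme (counter starts at 0)
toy_ids = [f"p{i}" for i in range(len(toy_texts))]

toy_queries = {toy_query_id: toy_query_text}
toy_passages = dict(zip(toy_ids, toy_texts))
toy_matches = {
    toy_query_id: {
        "suggested": toy_ids,                     # every candidate passage for the query
        "selected": toy_ids[toy_flags.index(1)],  # the passage labelled relevant
    }
}

print(json.dumps({"queries": toy_queries, "passages": toy_passages, "matches": toy_matches}, indent=2))
```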

In [None]:
def extract_structured_data(df):
    """Extract structured data from the MS MARCO dataframe."""
    queries = {}
    passages = {}
    matches = {}
    
    passage_id_counter = 0
    
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing data"):
        query_id = str(row['query_id'])
        query_text = row['query']
        
        # Add query to queries dictionary (kept even if no passage is selected)
        queries[query_id] = query_text
        
        # Process passages
        row_passages = row['passages']
        passage_texts = row_passages['passage_text']
        # Parquet lists may load as NumPy arrays, which lack .index(); convert to a list
        is_selected = list(row_passages['is_selected'])
        
        # Find the index of the first selected passage, if any
        selected_idx = is_selected.index(1) if 1 in is_selected else None
        
        if selected_idx is not None:
            # Create unique IDs for all passages in this row
            passage_ids = [f"p{passage_id_counter + i}" for i in range(len(passage_texts))]
            
            # Add passages to passages dictionary
            for pid, text in zip(passage_ids, passage_texts):
                passages[pid] = text
            
            # Create match entry
            matches[query_id] = {
                "suggested": passage_ids,  # All passages
                "selected": passage_ids[selected_idx]  # The selected passage
            }
            
            # Increment counter for next row
            passage_id_counter += len(passage_texts)
    
    return queries, passages, matches

In [None]:
# Process the data
if 'main_data' in locals() and len(main_data) > 0:
    queries, passages, matches = extract_structured_data(main_data)
    
    print(f"Extracted {len(queries)} queries")
    print(f"Extracted {len(passages)} passages")
    print(f"Created {len(matches)} query-passage matches")

## 5. Creating JSON Structure

Now, let's put everything together into our desired JSON structure.
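Before saving, a cheap referential-integrity check can catch bugs in the extraction step. A minimal sketch; `check_structure` is a hypothetical helper not defined elsewhere in this notebook. On the real data it would be called as `check_structure(queries, passages, matches)`, and given how `extract_structured_data` builds the dictionaries it should return an empty list:

```python
def check_structure(queries, passages, matches):
    """Return a list of human-readable problems found in the queries/passages/matches dicts.

    Checks that every match points at a known query, that every suggested
    passage ID exists, and that the selected passage is one of the suggested ones.
    """
    problems = []
    for query_id, match in matches.items():
        if query_id not in queries:
            problems.append(f"match {query_id}: unknown query ID")
        missing = [pid for pid in match["suggested"] if pid not in passages]
        if missing:
            problems.append(f"match {query_id}: unknown passage IDs {missing}")
        if match["selected"] not in match["suggested"]:
            problems.append(f"match {query_id}: selected passage not among suggested")
    return problems

# Toy data with one deliberate error: p9 does not exist in the passages dict
toy_queries = {"q1": "example query"}
toy_passages = {"p0": "passage A", "p1": "passage B"}
toy_matches = {"q1": {"suggested": ["p0", "p1", "p9"], "selected": "p1"}}
print(check_structure(toy_queries, toy_passages, toy_matches))
```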

In [None]:
# Create the final dataset structure
if 'queries' in locals() and 'passages' in locals() and 'matches' in locals():
    msmarco_data = {
        "queries": queries,
        "passages": passages,
        "matches": matches
    }
    
    # Preview the structure
    print("Dataset Structure Preview:")
    print(f"Number of queries: {len(msmarco_data['queries'])}")
    print(f"Number of passages: {len(msmarco_data['passages'])}")
    print(f"Number of matches: {len(msmarco_data['matches'])}")
    
    # Show a sample match
    if matches:
        sample_query_id = next(iter(msmarco_data['matches']))
        sample_match = msmarco_data['matches'][sample_query_id]
        sample_query = msmarco_data['queries'][sample_query_id]
        sample_selected = msmarco_data['passages'][sample_match['selected']]
        
        print("\nSample match:")
        print(f"Query ID: {sample_query_id}")
        print(f"Query text: {sample_query}")
        print(f"Selected passage ID: {sample_match['selected']}")
        display_text = str(sample_selected)[:200] + "..." if len(str(sample_selected)) > 200 else sample_selected
        print(f"Selected passage text: {display_text}")
        print(f"Number of suggested passages: {len(sample_match['suggested'])}")

## 6. Saving and Loading

Let's save our processed data to JSON files and test loading it back.
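`json.dump` only handles built-in Python types, so a stray NumPy value (for example an `int64` read from parquet that was not converted with `str()` or `int()`) raises `TypeError`. The extraction step already converts query IDs with `str()`, but if numeric fields are added to the saved records later, a `default=` hook is one way to cope. A minimal sketch; `numpy_default` is a hypothetical helper name:

```python
import json
import numpy as np

def numpy_default(obj):
    """Fallback passed to json.dump(s): convert NumPy values to built-in types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else is genuinely unsupported; re-raise like the stock encoder does
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

record = {"query_id": np.int64(19699), "is_selected": np.array([0, 1, 0])}
encoded = json.dumps(record, default=numpy_default)
print(encoded)  # → {"query_id": 19699, "is_selected": [0, 1, 0]}
```

To use it when saving, pass it through in `save_to_json`: `json.dump(data, f, indent=indent, ensure_ascii=False, default=numpy_default)`.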

In [None]:
# Save the data to JSON files
def save_to_json(data, file_path, pretty=True):
    """Save data to a JSON file."""
    indent = 4 if pretty else None
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=indent, ensure_ascii=False)
    return file_path

# Save each component separately
if 'msmarco_data' in locals():
    print("Saving data to JSON files...")
    queries_file = save_to_json(msmarco_data["queries"], PROCESSED_DATA_DIR / "queries.json")
    passages_file = save_to_json(msmarco_data["passages"], PROCESSED_DATA_DIR / "passages.json")
    matches_file = save_to_json(msmarco_data["matches"], PROCESSED_DATA_DIR / "matches.json")
    
    print(f"Data saved to:\n- {queries_file}\n- {passages_file}\n- {matches_file}")

In [None]:
# Test loading the data back
def load_from_json(file_path):
    """Load data from a JSON file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

try:
    print("Testing data loading...")
    loaded_queries = load_from_json(PROCESSED_DATA_DIR / "queries.json")
    loaded_passages = load_from_json(PROCESSED_DATA_DIR / "passages.json")
    loaded_matches = load_from_json(PROCESSED_DATA_DIR / "matches.json")
    
    print(f"Loaded {len(loaded_queries)} queries.")
    print(f"Loaded {len(loaded_passages)} passages.")
    print(f"Loaded {len(loaded_matches)} matches.")
except Exception as e:
    print(f"Error loading JSON files: {e}")

## 7. Sample Triplet Generation

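Now, let's demonstrate how to create triplets (query, positive passage, negative passage) for training our LTR model. For each query, the selected passage is the positive and every other suggested passage for the same query is a negative.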
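For context on how such triplets are consumed: one common formulation of triplet loss, assuming the query and passages have already been encoded into vectors and scored with cosine similarity $s$, penalizes the model whenever the positive does not beat the negative by at least a margin $m$, i.e. $L = \max(0,\ m - s(q, p^+) + s(q, p^-))$. A minimal NumPy sketch; the cosine scoring and the margin of 0.2 are assumptions, not choices made elsewhere in this notebook:

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def triplet_loss(query_vec, pos_vec, neg_vec, margin=0.2):
    """Hinge loss: zero once the positive outscores the negative by at least `margin`."""
    return max(0.0, margin - cosine(query_vec, pos_vec) + cosine(query_vec, neg_vec))

q = np.array([1.0, 0.0])
p = np.array([1.0, 0.0])  # same direction as the query: similarity 1.0
n = np.array([0.0, 1.0])  # orthogonal to the query: similarity 0.0
print(triplet_loss(q, p, n))  # positive wins by 1.0 >= margin, so the loss is 0.0
print(triplet_loss(q, n, p))  # roles swapped: 0.2 - 0.0 + 1.0 = 1.2
```

In an actual training loop the same expression would be computed on batched encoder outputs in the training framework so that gradients flow back into the encoders.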

In [None]:
def generate_triplets(matches, queries, passages, num_triplets=10):
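    """Generate up to `num_triplets` (query, positive, negative) triplets for triplet-loss training.

    Each query's selected passage is the positive; every other suggested passage is a negative.
    """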
    triplets = []
    
    for query_id, match_data in matches.items():
        positive_id = match_data["selected"]
        negative_ids = [pid for pid in match_data["suggested"] if pid != positive_id]
        
        if negative_ids:  # Skip if no negatives
            for negative_id in negative_ids:
                triplet = {
                    "query_id": query_id,
                    "query_text": queries[query_id],
                    "positive_id": positive_id,
                    "positive_text": passages[positive_id],
                    "negative_id": negative_id,
                    "negative_text": passages[negative_id]
                }
                triplets.append(triplet)
                
                # Break if we've reached the desired number of triplets
                if len(triplets) >= num_triplets:
                    return triplets
    
    return triplets

# Generate sample triplets
try:
    sample_triplets = generate_triplets(loaded_matches, loaded_queries, loaded_passages, num_triplets=5)
    
    # Display sample triplets
    print(f"Generated {len(sample_triplets)} sample triplets.")
    for i, triplet in enumerate(sample_triplets, 1):
        print(f"\nTriplet {i}:")
        print(f"Query: {triplet['query_text']}")
        pos_display = str(triplet['positive_text'])[:100] + "..." if len(str(triplet['positive_text'])) > 100 else triplet['positive_text']
        neg_display = str(triplet['negative_text'])[:100] + "..." if len(str(triplet['negative_text'])) > 100 else triplet['negative_text']
        print(f"Positive: {pos_display}")
        print(f"Negative: {neg_display}")
except Exception as e:
    print(f"Error generating triplets: {e}")

## 8. Next Steps

Now that we have prepared our data, here are the next steps for our Learning to Rank project:

1. **Create the PyTorch Dataset Class**
   - Implement a dataset class that reads our JSON files
   - Generate triplets on-the-fly or load pre-generated triplets
   - Apply tokenization and preprocessing

2. **Implement Encoders**
   - Start with a simple encoder for queries and documents
   - Experiment with different architectures later

3. **Implement Triplet Loss Training**
   - Use the generated triplets to train with triplet loss
   - Monitor training metrics

4. **Evaluation**
   - Implement ranking metrics (NDCG, MRR, etc.)
   - Evaluate on a test set

5. **Hard Negative Mining**
   - Implement strategies for finding better negative examples
   - Experiment with online hard negative mining
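The evaluation metrics in step 4 fit this data's shape naturally, since each query in `matches` has a single `selected` passage among its `suggested` candidates. A minimal sketch of MRR@k and binary-relevance NDCG@k, assuming a model produces a ranked list of passage IDs per query; `ranked` is a made-up ranking and the cutoff of 10 is a common convention rather than something fixed by this notebook:

```python
import math

def reciprocal_rank(ranked_ids, relevant_id, k=10):
    """1/rank of the relevant passage within the top k, or 0 if it is absent."""
    for rank, pid in enumerate(ranked_ids[:k], start=1):
        if pid == relevant_id:
            return 1.0 / rank
    return 0.0

def ndcg_single_relevant(ranked_ids, relevant_id, k=10):
    """NDCG@k with binary relevance and exactly one relevant passage.

    The ideal DCG is 1 (relevant item at rank 1), so NDCG = 1 / log2(rank + 1).
    """
    for rank, pid in enumerate(ranked_ids[:k], start=1):
        if pid == relevant_id:
            return 1.0 / math.log2(rank + 1)
    return 0.0

def mean_metric(metric, rankings, matches, k=10):
    """Average a per-query metric over all queries that have a ranking."""
    scores = [metric(rankings[qid], matches[qid]["selected"], k) for qid in rankings]
    return sum(scores) / len(scores) if scores else 0.0

# Hypothetical rankings for two queries
toy_matches = {"q1": {"selected": "p1"}, "q2": {"selected": "p7"}}
ranked = {"q1": ["p1", "p0", "p2"], "q2": ["p5", "p6", "p7"]}
print(mean_metric(reciprocal_rank, ranked, toy_matches))       # (1 + 1/3) / 2
print(mean_metric(ndcg_single_relevant, ranked, toy_matches))  # (1 + 1/log2(4)) / 2 = 0.75
```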

The processed data we've created provides a solid foundation for these next steps.