
# MS MARCO Cohere Pre-embedded Subset (PySpark)

This notebook preprocesses the MS MARCO v2.1 Cohere embeddings with PySpark to build a compact subset that preserves query->passage coverage. It resolves the loader failure where too-small `base_limit` / `max_passage_scan` budgets prevented any relevant passages from being materialised.



## Workflow

1. Install PySpark (PyArrow is already required by the repository).
2. Configure source parquet paths, output directory, and recall parameters.
3. Load the MS MARCO parquet shards, detect schema fields, and explode the top-k passage candidates per query.
4. Join the candidate passages with the passage shards, deduplicate overlaps, and retain only queries with at least `MIN_MATCHES_REQUIRED` matched passages.
5. Export numpy-friendly artifacts (`subset.npz`, `metadata.json`, ID mappings) to `/storage/ice-shared/cs8903onl/vectordb-retrieval/datasets/mamarco_pre_embeded_subset`.


Install PySpark in the notebook environment (skip if your interpreter already has it).

In [1]:
%pip install --quiet pyspark==3.5.1

Note: you may need to restart the kernel to use updated packages.


Set the input/output locations and preprocessing knobs. Adjust limits if you want a larger subset.

In [1]:

from pathlib import Path

PASSAGES_SOURCE = Path("/storage/ice-shared/cs8903onl/vectordb-retrieval/datasets/msmarco_pre_embeded/passages_parquet/msmarco_v2.1_doc_segmented_*.parquet")
QUERIES_SOURCE = Path("/storage/ice-shared/cs8903onl/vectordb-retrieval/datasets/msmarco_pre_embeded/queries_parquet/queries.parquet")
OUTPUT_BASE = Path("/storage/ice-shared/cs8903onl/vectordb-retrieval/datasets/mamarco_pre_embeded_subset")

GROUND_TRUTH_K = 10
MIN_MATCHES_REQUIRED = 10  # must equal GROUND_TRUTH_K: matches are sliced to GROUND_TRUTH_K and the export step rejects shorter rows
MAX_QUERIES = 512          # upper bound on exported queries
TOPK_CANDIDATES = 200      # inspect top-N candidates per query before truncating
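The recall knobs constrain each other: the Spark step slices every query's matches to `GROUND_TRUTH_K` before applying the `MIN_MATCHES_REQUIRED` filter, and the export step rejects any query with fewer than `GROUND_TRUTH_K` matches. A small guard such as the hypothetical `check_recall_params` below (not part of the original notebook) would catch inconsistent settings before any Spark job runs.

```python
from typing import Optional


def check_recall_params(
    ground_truth_k: int,
    min_matches: int,
    topk_candidates: Optional[int],
    max_queries: Optional[int],
) -> None:
    """Raise ValueError if the subset knobs cannot produce a usable export.

    Hypothetical helper: it encodes the constraints implied by the Spark
    and export cells; it is not part of the original notebook.
    """
    if ground_truth_k <= 0:
        raise ValueError("GROUND_TRUTH_K must be positive")
    # Matches are sliced to ground_truth_k, so a larger threshold keeps nothing,
    # and a smaller one lets short rows through that the export step rejects.
    if min_matches != ground_truth_k:
        raise ValueError(
            f"MIN_MATCHES_REQUIRED ({min_matches}) must equal GROUND_TRUTH_K ({ground_truth_k})"
        )
    # A falsy TOPK_CANDIDATES disables the rank cut-off in the Spark cell.
    if topk_candidates and topk_candidates < ground_truth_k:
        raise ValueError("TOPK_CANDIDATES must be >= GROUND_TRUTH_K")
    # A falsy MAX_QUERIES disables the limit; negative values make no sense.
    if max_queries is not None and max_queries < 0:
        raise ValueError("MAX_QUERIES must be non-negative")


# Values from the configuration cell above.
check_recall_params(ground_truth_k=10, min_matches=10, topk_candidates=200, max_queries=512)
print("recall parameters OK")
```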


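Create a Spark session. Arrow is enabled so that any Spark-to-pandas conversion (`toPandas`, pandas UDFs) is columnar; a plain `collect()` does not use it. In local mode the tasks run inside the driver JVM, so `spark.driver.memory` bounds the heap available for decoding the parquet row groups that hold the `emb` vectors, and Spark's 1 GB default is too small for them.

In [2]:

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("MS MARCO PySpark subset builder")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.sql.shuffle.partitions", "200")
    # Placeholder value: size this to the node. It only takes effect if no
    # Spark session has been started in this kernel yet.
    .config("spark.driver.memory", "16g")
    .getOrCreate()
)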

spark


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/19 12:34:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Load the passage and query parquet files, then inspect the available schema.

In [3]:

from pyspark.sql import functions as F

passages_df = spark.read.parquet(str(PASSAGES_SOURCE))
queries_df = spark.read.parquet(str(QUERIES_SOURCE))

passages_df.printSchema()
queries_df.printSchema()


root
 |-- docid: string (nullable = true)
 |-- url: string (nullable = true)
 |-- title: string (nullable = true)
 |-- headings: string (nullable = true)
 |-- segment: string (nullable = true)
 |-- start_char: long (nullable = true)
 |-- end_char: long (nullable = true)
 |-- emb: array (nullable = true)
 |    |-- element: float (containsNull = true)

root
 |-- _id: string (nullable = true)
 |-- text: string (nullable = true)
 |-- trec-year: long (nullable = true)
 |-- emb: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- top1k_offsets: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- top1k_passage_ids: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- top1k_cossim: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- qrels: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: string (containsNull = true)



In [9]:
passages_df.show(10)

25/10/19 12:31:36 ERROR Executor: Exception in task 10.0 in stage 8.0 (TID 139)
java.lang.OutOfMemoryError: Java heap space
	at java.base/java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:64)
	at java.base/java.nio.ByteBuffer.allocate(ByteBuffer.java:363)
	at org.apache.parquet.bytes.HeapByteBufferAllocator.allocate(HeapByteBufferAllocator.java:32)
	at org.apache.parquet.hadoop.ParquetFileReader$ConsecutivePartList.readAll(ParquetFileReader.java:1842)
	at org.apache.parquet.hadoop.ParquetFileReader.internalReadRowGroup(ParquetFileReader.java:990)
	at org.apache.parquet.hadoop.ParquetFileReader.readNextRowGroup(ParquetFileReader.java:940)
	at org.apache.parquet.hadoop.ParquetFileReader.readNextFilteredRowGroup(ParquetFileReader.java:1082)
	at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase$ParquetRowGroupReaderImpl.readNextRowGroup(SpecificParquetRecordReaderBase.java:284)
	at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecord

Py4JError: py4j does not exist in the JVM

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/home/hice1/pli396/.local/lib/python3.10/site-packages/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hice1/pli396/.local/lib/python3.10/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/home/hice1/pli396/.local/lib/python3.10/site-packages/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving


In [6]:
queries_df.show(10, truncate=True)



+-------+--------------------+---------+--------------------+--------------------+--------------------+--------------------+--------------------+
|    _id|                text|trec-year|                 emb|       top1k_offsets|   top1k_passage_ids|        top1k_cossim|               qrels|
+-------+--------------------+---------+--------------------+--------------------+--------------------+--------------------+--------------------+
| 787021|what is produced ...|     2021|[-0.04473035, -0....|[17421142, 525134...|[msmarco_v2.1_doc...|[0.62506163, 0.61...|                  []|
|1049187|who recorded be m...|     2021|[-0.025443953, -0...|[31325611, 313256...|[msmarco_v2.1_doc...|[0.73272413, 0.71...|                  []|
|1049519|who said no one c...|     2021|[0.051458575, 0.0...|[56636723, 803704...|[msmarco_v2.1_doc...|[0.7079681, 0.703...|                  []|
| 788054|         what is ptf|     2021|[0.008271873, -0....|[35230535, 388727...|[msmarco_v2.1_doc...|[0.700606, 0.6036...|


Detect the relevant columns for embeddings, identifiers, offsets, and scores.

In [10]:

from typing import Iterable, Optional

def select_column(columns: Iterable[str], candidates: Iterable[str], purpose: str, required: bool = True) -> Optional[str]:
    for name in candidates:
        if name in columns:
            return name
    if required:
        raise ValueError(f"Unable to locate a column for {purpose}. Available columns: {sorted(columns)}")
    return None

passage_columns = passages_df.columns
query_columns = queries_df.columns

passage_embedding_col = select_column(passage_columns, ["emb", "embedding", "vector"], "passage embeddings")
passage_id_col = select_column(passage_columns, ["_id", "id", "passage_id", "docid"], "passage identifiers", required=False)
passage_offset_col = select_column(passage_columns, ["offset", "global_offset", "row_id", "position"], "passage offsets", required=False)

query_embedding_col = select_column(query_columns, ["emb", "embedding", "vector"], "query embeddings")
query_id_col = select_column(query_columns, ["_id", "id", "query_id"], "query identifiers")
query_candidate_ids_col = select_column(query_columns, ["top1k_passage_ids", "positive_passage_ids", "doc_ids"], "candidate passage ids", required=False)
query_candidate_offsets_col = select_column(query_columns, ["top1k_offsets", "positive_passage_offsets", "offsets"], "candidate offsets", required=False)
query_candidate_scores_col = select_column(query_columns, ["top1k_cossim", "scores", "similarities"], "candidate scores", required=False)

if not query_candidate_ids_col and not query_candidate_offsets_col:
    raise ValueError("Need at least one relevance column (ids or offsets) in the queries parquet.")

print("Passage embedding column:", passage_embedding_col)
print("Passage id column:", passage_id_col)
print("Passage offset column:", passage_offset_col)
print("Query embedding column:", query_embedding_col)
print("Query id column:", query_id_col)
print("Query candidate id column:", query_candidate_ids_col)
print("Query candidate offset column:", query_candidate_offsets_col)
print("Query candidate score column:", query_candidate_scores_col)


Passage embedding column: emb
Passage id column: docid
Passage offset column: None
Query embedding column: emb
Query id column: _id
Query candidate id column: top1k_passage_ids
Query candidate offset column: top1k_offsets
Query candidate score column: top1k_cossim


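Explode each query's top-1k candidate list (keeping each candidate's 0-based rank), join the candidates to the passage embeddings, keep only the first `TOPK_CANDIDATES` ranks, deduplicate repeated passages, and keep the `GROUND_TRUTH_K` best-ranked matches per query.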

In [11]:

from pyspark.sql import Window

zip_columns = []
if query_candidate_ids_col:
    zip_columns.append(F.col(query_candidate_ids_col))
if query_candidate_offsets_col:
    zip_columns.append(F.col(query_candidate_offsets_col))
if query_candidate_scores_col:
    zip_columns.append(F.col(query_candidate_scores_col))

queries_candidates = queries_df.select(
    F.col(query_id_col).cast("string").alias("query_id"),
    F.col(query_embedding_col).alias("query_emb"),
    F.arrays_zip(*zip_columns).alias("candidate_pairs")
)

exploded = queries_candidates.select(
    "query_id",
    "query_emb",
    F.posexplode_outer("candidate_pairs").alias("rank", "candidate_struct")
)

candidate_id_expr = F.col(f"candidate_struct.`{query_candidate_ids_col}`") if query_candidate_ids_col else F.lit(None)
candidate_offset_expr = F.col(f"candidate_struct.`{query_candidate_offsets_col}`") if query_candidate_offsets_col else F.lit(None)
candidate_score_expr = F.col(f"candidate_struct.`{query_candidate_scores_col}`") if query_candidate_scores_col else F.lit(None)

exploded = (
    exploded
    .withColumn("candidate_passage_id", candidate_id_expr.cast("string"))
    .withColumn("candidate_offset", candidate_offset_expr.cast("long"))
    .withColumn("candidate_score", candidate_score_expr.cast("double"))
    .drop("candidate_struct")
)

passage_core = passages_df.select(
    F.col(passage_embedding_col).alias("passage_emb"),
    # Typed null placeholders keep the join-key types consistent when a column is missing.
    F.col(passage_id_col).cast("string").alias("passage_id") if passage_id_col else F.lit(None).cast("string").alias("passage_id"),
    F.col(passage_offset_col).cast("long").alias("passage_offset") if passage_offset_col else F.lit(None).cast("long").alias("passage_offset")
)

join_conditions = []
if passage_id_col and query_candidate_ids_col:
    join_conditions.append(F.col("candidate_passage_id") == F.col("passage_id"))
if passage_offset_col and query_candidate_offsets_col:
    join_conditions.append(F.col("candidate_offset") == F.col("passage_offset"))

if not join_conditions:
    raise ValueError("No valid join keys found between queries and passages. Check the dataset schema.")

join_condition = join_conditions[0]
for cond in join_conditions[1:]:
    join_condition = join_condition | cond

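# Apply the rank cut-off before the join so only the first TOPK_CANDIDATES
# candidates per query are matched against the passage table.
candidates = exploded.where((F.col("candidate_passage_id").isNotNull()) | (F.col("candidate_offset").isNotNull()))
if TOPK_CANDIDATES:
    candidates = candidates.where(F.col("rank") < TOPK_CANDIDATES)

matched = candidates.join(passage_core, join_condition, how="inner")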

rank_window = Window.partitionBy("query_id", "passage_id").orderBy(F.col("rank"))
matched = matched.withColumn("_rn", F.row_number().over(rank_window)).where(F.col("_rn") == 1).drop("_rn")

match_struct = F.struct(
    F.col("rank").alias("rank"),
    F.col("passage_id").alias("passage_id"),
    F.col("passage_emb").alias("passage_emb"),
    F.col("candidate_offset").alias("passage_offset"),
    F.col("candidate_score").alias("score")
)

aggregated = matched.groupBy("query_id", "query_emb").agg(
    F.sort_array(F.collect_list(match_struct)).alias("matches")
)

trimmed = aggregated.withColumn("matches", F.slice("matches", 1, GROUND_TRUTH_K))
filtered = trimmed.where(F.size("matches") >= MIN_MATCHES_REQUIRED)

if MAX_QUERIES:
    filtered = filtered.orderBy("query_id").limit(MAX_QUERIES)

filtered = filtered.cache()
filtered_count = filtered.count()
print(f"Queries retained with >= {MIN_MATCHES_REQUIRED} matched passages: {filtered_count}")


25/10/19 12:40:47 ERROR Executor: Exception in task 47.0 in stage 13.0 (TID 169)
java.lang.OutOfMemoryError: Java heap space
	at java.base/java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:64)
	at java.base/java.nio.ByteBuffer.allocate(ByteBuffer.java:363)
	at org.apache.parquet.bytes.HeapByteBufferAllocator.allocate(HeapByteBufferAllocator.java:32)
	at org.apache.parquet.hadoop.ParquetFileReader$ConsecutivePartList.readAll(ParquetFileReader.java:1842)
	at org.apache.parquet.hadoop.ParquetFileReader.internalReadRowGroup(ParquetFileReader.java:990)
	at org.apache.parquet.hadoop.ParquetFileReader.readNextRowGroup(ParquetFileReader.java:940)
	at org.apache.parquet.hadoop.ParquetFileReader.readNextFilteredRowGroup(ParquetFileReader.java:1100)
	at org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase$ParquetRowGroupReaderImpl.readNextRowGroup(SpecificParquetRecordReaderBase.java:284)
	at org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecord

Py4JError: An error occurred while calling o160.count
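To reason about what the Spark cell keeps, it can help to trace a single query by hand. The sketch below is a hypothetical pure-Python equivalent (the name `select_ground_truth` is illustrative, not from the notebook); it assumes `MIN_MATCHES_REQUIRED == GROUND_TRUTH_K` and replaces the inner join against `passage_core` with a set-membership test on passage ids.

```python
from typing import Dict, List, Optional, Set


def select_ground_truth(
    candidate_ids: List[Optional[str]],
    known_passages: Set[str],
    ground_truth_k: int,
    topk_candidates: Optional[int],
) -> Optional[List[str]]:
    """Trace one query through the Spark cell's logic (hypothetical helper).

    Returns the ground_truth_k best-ranked matched passage ids, or None when
    the query would be dropped (MIN_MATCHES_REQUIRED assumed == ground_truth_k).
    """
    best_rank: Dict[str, int] = {}
    # posexplode: rank is the 0-based position in the top-1k list.
    for rank, pid in enumerate(candidate_ids):
        if topk_candidates and rank >= topk_candidates:
            break  # where(rank < TOPK_CANDIDATES)
        if pid is None or pid not in known_passages:
            continue  # null ids are filtered; unknown ids vanish in the inner join
        # row_number() over (query_id, passage_id) ordered by rank keeps the first hit.
        best_rank.setdefault(pid, rank)
    # sort_array orders the match structs by their first field, rank.
    ordered = sorted(best_rank, key=best_rank.__getitem__)
    matches = ordered[:ground_truth_k]  # F.slice(matches, 1, GROUND_TRUTH_K)
    return matches if len(matches) >= ground_truth_k else None


passages = {"p1", "p2", "p3", "p4"}
print(select_ground_truth(["p2", "x", "p1", "p2", "p3"], passages, 3, 200))  # ['p2', 'p1', 'p3']
print(select_ground_truth(["p2", "x", "p1", "p2", "p3"], passages, 3, 3))    # None
```

The second call shows why `TOPK_CANDIDATES` matters: cutting the list at rank 3 leaves only two distinct known passages, so the query is dropped rather than exported with short labels.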

Collect the subset to the driver and build numpy arrays for passages, queries, and recall labels.

In [None]:

import numpy as np
from datetime import datetime, timezone

subset_rows = filtered.collect()
if not subset_rows:
    raise ValueError("No queries satisfied the filtering criteria. Raise TOPK_CANDIDATES, or lower GROUND_TRUTH_K and MIN_MATCHES_REQUIRED together.")

passage_index = {}
passage_vectors = []
passage_records = []

for row in subset_rows:
    for match in row.matches:
        pid = match.passage_id
        if pid is None and match.passage_offset is not None:
            pid = f"offset:{match.passage_offset}"
        if pid is None:
            continue
        if pid not in passage_index:
            vec = np.asarray(match.passage_emb, dtype=np.float32)
            passage_index[pid] = len(passage_vectors)
            passage_vectors.append(vec)
            passage_records.append({
                "index": passage_index[pid],
                "passage_id": pid,
                "passage_offset": None if match.passage_offset is None else int(match.passage_offset),
            })

query_vectors = []
query_records = []
ground_truth = []

for row in subset_rows:
    q_vec = np.asarray(row.query_emb, dtype=np.float32)
    query_vectors.append(q_vec)
    query_records.append({"query_id": row.query_id})
    indices = []
    for match in row.matches:
        pid = match.passage_id
        if pid is None and match.passage_offset is not None:
            pid = f"offset:{match.passage_offset}"
        if pid is None:
            continue
        indices.append(passage_index[pid])
    if len(indices) < GROUND_TRUTH_K:
        raise ValueError(f"Query {row.query_id} has {len(indices)} matches, fewer than GROUND_TRUTH_K={GROUND_TRUTH_K}; set MIN_MATCHES_REQUIRED equal to GROUND_TRUTH_K.")
    ground_truth.append(indices[:GROUND_TRUTH_K])

train_array = np.stack(passage_vectors, axis=0)
query_array = np.stack(query_vectors, axis=0)

print(f"Passage matrix: {train_array.shape}")
print(f"Query matrix: {query_array.shape}")
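`filtered.collect()` pulls every retained query vector plus its `GROUND_TRUTH_K` passage vectors into the driver, with shared passages duplicated until `passage_index` deduplicates them. A back-of-the-envelope estimate of the raw float32 payload is sketched below; the 1024-dimensional embedding size is an assumption, so substitute the `embedding_dim` the notebook reports. Spark `Row` objects hold each element as a Python float, so the real driver footprint is several times larger than this figure.

```python
def collected_payload_bytes(num_queries: int, ground_truth_k: int, dim: int, bytes_per_float: int = 4) -> int:
    """Raw vector bytes returned by collect() (hypothetical helper, float32 by default)."""
    # Each retained query carries its own vector plus ground_truth_k passage vectors.
    vectors = num_queries * (1 + ground_truth_k)
    return vectors * dim * bytes_per_float


# MAX_QUERIES and GROUND_TRUTH_K from the config cell; dim=1024 is an assumption.
size = collected_payload_bytes(num_queries=512, ground_truth_k=10, dim=1024)
print(f"{size:,} bytes (~{size / 2**20:.0f} MiB)")  # 23,068,672 bytes (~22 MiB)
```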


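Persist the subset and lightweight metadata. `subset.npz` holds `train` (passage vectors), `test` (query vectors), and `ground_truth` (per-query row indices into `train`, in rank order); the JSON files map those rows back to passage and query IDs.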

In [None]:

import json

OUTPUT_BASE.mkdir(parents=True, exist_ok=True)

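np.savez_compressed(
    OUTPUT_BASE / "subset.npz",
    train=train_array,  # passage embeddings, one row per unique passage
    test=query_array,   # query embeddings, in query_index.json order
    ground_truth=np.asarray(ground_truth, dtype=np.int32),  # row indices into train
)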

metadata = {
    "created_at": datetime.now(timezone.utc).isoformat(),
    "ground_truth_k": GROUND_TRUTH_K,
    "queries": len(query_records),
    "passages": len(passage_records),
    "embedding_dim": int(train_array.shape[1]),
    "topk_candidates": TOPK_CANDIDATES,
    "min_matches_required": MIN_MATCHES_REQUIRED,
    "source_passages": str(PASSAGES_SOURCE),
    "source_queries": str(QUERIES_SOURCE),
}

with open(OUTPUT_BASE / "metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, indent=2)

with open(OUTPUT_BASE / "passage_index.json", "w", encoding="utf-8") as f:
    json.dump(passage_records, f, indent=2)

with open(OUTPUT_BASE / "query_index.json", "w", encoding="utf-8") as f:
    json.dump(query_records, f, indent=2)

print(f"Wrote subset assets to {OUTPUT_BASE}")
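Code that scores a retriever against `ground_truth` usually needs to map row indices back to MS MARCO passage IDs. Because `passage_index.json` is a list of records whose `index` field is the row in `train`, the lookup can be rebuilt with the standard library alone. The helper name `load_passage_ids` and the toy records below are illustrative, not part of the notebook.

```python
import json
import tempfile
from pathlib import Path
from typing import List


def load_passage_ids(index_path: Path) -> List[str]:
    """Return passage ids ordered so that ids[i] labels row i of subset.npz['train']."""
    with open(index_path, encoding="utf-8") as f:
        records = json.load(f)
    ids = [""] * len(records)
    for rec in records:
        ids[rec["index"]] = rec["passage_id"]
    return ids


# Toy records shaped like the notebook's passage_records (the ids are made up).
toy_records = [
    {"index": 1, "passage_id": "doc_b", "passage_offset": None},
    {"index": 0, "passage_id": "doc_a", "passage_offset": None},
]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "passage_index.json"
    path.write_text(json.dumps(toy_records), encoding="utf-8")
    ids = load_passage_ids(path)

ground_truth_row = [1, 0]
print([ids[i] for i in ground_truth_row])  # ['doc_b', 'doc_a']
```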


Sanity-check that the saved arrays can be reloaded.

In [None]:

loaded = np.load(OUTPUT_BASE / "subset.npz")
print({key: loaded[key].shape for key in loaded.files})
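Each row of `ground_truth` lists the `GROUND_TRUTH_K` passage rows ranked highest by the dataset's precomputed cosine scores, so an index built over `train` can be scored with recall@k. The function below is a minimal sketch with a hypothetical name; it assumes the retriever returns row indices into `train` for each query, in the same order as `test`.

```python
from typing import Sequence


def recall_at_k(retrieved: Sequence[Sequence[int]], ground_truth: Sequence[Sequence[int]], k: int) -> float:
    """Mean fraction of each query's top-k labels found in its top-k results."""
    if k <= 0:
        raise ValueError("k must be positive")
    if len(retrieved) != len(ground_truth):
        raise ValueError("retrieved and ground_truth must cover the same queries")
    if not ground_truth:
        return 0.0
    total = 0.0
    for found, truth in zip(retrieved, ground_truth):
        truth_k = set(truth[:k])
        hits = len(truth_k.intersection(found[:k]))
        total += hits / len(truth_k)
    return total / len(ground_truth)


gt = [[3, 1, 4], [0, 2, 5]]        # toy label rows, shaped like ground_truth
results = [[3, 4, 9], [7, 8, 0]]   # toy retriever output (row indices into train)
print(recall_at_k(results, gt, k=3))
```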


Stop the Spark session once preprocessing is complete.

In [None]:
spark.stop()