In [None]:
import ray
import pyarrow as pa
import pyarrow.parquet as pq

In [None]:
NUM_ROWS_PER_ROW_GROUP = 100_000
# Backward-compatible alias for the constant's original, misspelled name.
NUN_ROWS_PER_ROW_GROUP = NUM_ROWS_PER_ROW_GROUP
# Column whose list values are cast to list<float32> before writing.
COLUMN_TO_CAST_FLOAT_LIST_TYPE = "values"

# The SPREAD strategy tries to distribute tasks evenly across the cluster's nodes.
# Retry on `OSError` to ride out transient failures such as S3 rate limiting.
@ray.remote(num_cpus=1, max_retries=20, scheduling_strategy="SPREAD", retry_exceptions=[OSError])
def merge_files(bucket, task_index, *read_task_list):
    """Merge the tables produced by ``read_task_list`` into ``{bucket}/{task_index}.parquet``.

    Tables yielded by the read tasks are buffered until at least
    ``NUM_ROWS_PER_ROW_GROUP`` rows are pending. The buffer is then
    concatenated, its ``COLUMN_TO_CAST_FLOAT_LIST_TYPE`` column is cast to
    ``list<float32>``, and the result is written with a single
    ``ParquetWriter.write`` call, so row groups hold roughly that many rows.
    The total number of rows in the file is decided by which read tasks the
    caller passes in, not by this function.

    Args:
        bucket: Output directory/prefix; joined to the file name with "/".
        task_index: Used as the output file's base name.
        *read_task_list: Callables; each call returns an iterable of pyarrow
            tables (presumably Ray Data ``ReadTask`` objects -- confirm).

    Returns:
        Total number of rows written. If no rows were read, no file is
        created and 0 is returned.
    """
    file_name = f"{bucket}/{task_index}.parquet"
    writer = None
    total_rows = 0
    current_rows = 0
    current_batches = []

    def write_current_batches():
        """Concatenate, cast and write the buffered tables, then reset the buffer."""
        nonlocal current_batches, current_rows, writer, total_rows
        new_batch = pa.concat_tables(current_batches)

        # Cast the specified column to list<float32> data type.
        values_field = new_batch.schema.field(COLUMN_TO_CAST_FLOAT_LIST_TYPE)
        values_field_idx = new_batch.schema.get_field_index(COLUMN_TO_CAST_FLOAT_LIST_TYPE)
        new_values_field = values_field.with_type(pa.list_(pa.float32()))
        new_schema = new_batch.schema.set(values_field_idx, new_values_field)
        new_batch = new_batch.cast(new_schema)

        # Open the output lazily so its schema comes from the first written chunk.
        if writer is None:
            writer = pq.ParquetWriter(file_name, schema=new_schema)
        writer.write(new_batch)
        total_rows += new_batch.num_rows
        current_batches = []
        current_rows = 0

    try:
        for read_task in read_task_list:
            for batch in read_task():
                # replace_schema_metadata() returns a NEW table with the schema
                # metadata removed; it does not modify `batch` in place.
                batch = batch.replace_schema_metadata()
                current_rows += batch.num_rows
                current_batches.append(batch)
                if current_rows >= NUM_ROWS_PER_ROW_GROUP:
                    write_current_batches()

        # Flush the remainder that did not fill a whole row group.
        if current_rows > 0:
            write_current_batches()
    finally:
        # Close even on failure so the file handle is released before Ray
        # retries the task; `writer` is still None if nothing was written.
        if writer is not None:
            writer.close()
    return total_rows

In [None]:
# Output path containing embeddings generated from notebook `1_generate_embeddings.ipynb`
embedding_input_path = "YOUR-EMBEDDINGS-BUCKET-HERE"
# Destination bucket/prefix for the merged Parquet files.
output_path = "YOUR-OUTPUT-BUCKET-HERE"

ds = ray.data.read_parquet(embedding_input_path)
num_rows = ds.count()

# NOTE: the fragment list and read tasks below come from private Ray Data
# internals (underscore-prefixed attributes), which may change between Ray
# versions.
parquet_datasource = ds._logical_plan._dag._datasource
num_input_files = len(parquet_datasource._pq_fragments)
# Ask for parallelism equal to the file count, aiming for one read task per input file.
read_tasks = parquet_datasource.get_read_tasks(num_input_files)

print(f"Number of rows: {num_rows}")
print(f"Number of input files: {num_input_files}")
print(f"Number of read tasks: {len(read_tasks)}")
print(f"Output path: {output_path}")

In [None]:
# Upper bound on rows per merged output file (1M). An integer, since it is a row count.
NUM_ROWS_PER_FILE = 1_000_000

# Row count of each input file, read from the datasource's private `_metadata`
# list (private Ray Data internals -- may change between Ray versions).
num_rows_per_file = [
    file_metadata.num_rows for file_metadata in ds._logical_plan._dag._datasource._metadata
]
# The pairing below assumes read_tasks[i] reads the file described by
# _metadata[i]; this check only confirms that the counts match.
assert len(num_rows_per_file) == len(read_tasks)
read_tasks_with_rows = list(zip(num_rows_per_file, read_tasks))

# Decide the grouping of input files so that each merged output file holds at
# most NUM_ROWS_PER_FILE rows, using First Fit Decreasing bin packing: visit
# files from largest to smallest and put each into the first output file that
# still has room, otherwise open a new one. An input file that alone exceeds
# NUM_ROWS_PER_FILE ends up in its own output file.
read_tasks_with_rows.sort(key=lambda t: t[0], reverse=True)
merged_read_tasks = []  # one (total_rows, [read_task, ...]) tuple per output file
for file_rows, task in read_tasks_with_rows:
    for bin_idx, (bin_rows, bin_tasks) in enumerate(merged_read_tasks):
        if bin_rows + file_rows <= NUM_ROWS_PER_FILE:
            bin_tasks.append(task)
            merged_read_tasks[bin_idx] = (bin_rows + file_rows, bin_tasks)
            break
    else:
        # No existing output file has room: start a new one.
        merged_read_tasks.append((file_rows, [task]))

num_output_files = len(merged_read_tasks)
print(f"Number of output files: {num_output_files}")
# Use distinct loop names so the dataset-wide `num_rows` global is not overwritten.
for bin_rows, bin_tasks in merged_read_tasks:
    print(f"Number of output rows: {bin_rows}, Number of input files: {len(bin_tasks)}")


In [None]:
# Main Ray code to run the file merging and write.
%%time

result = []
for i, this_bin in enumerate(merged_read_tasks):
    result.append(merge_files.remote(output_path, i, *(this_bin[1])))
ray.get(result)