In [None]:
import ray
import pandas as pd
import uuid
import numpy as np
from time import time
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Initialize Ray for this notebook session. `ignore_reinit_error=True` keeps the
# cell safe to re-run in a live kernel, where a second bare `ray.init()` would
# raise because Ray is already initialized.
ray.init(ignore_reinit_error=True)

In [None]:
# Path to a CSV whose single column, `files`, lists the URLs of the parquet
# files to process.
INPUT_FILES_CSV = "INPUT-FILE-PATHS.csv"

input_files = pd.read_csv(INPUT_FILES_CSV)["files"].to_list()

# Cell output: how many input files were listed, plus the first URL as a quick
# sanity check.
(len(input_files), input_files[0])

In [None]:
# Read only the columns the chunking step uses. The read parallelism is capped
# explicitly so the input is not split into more pieces than necessary.
# NOTE(review): newer Ray releases appear to deprecate `parallelism` in favor of
# `override_num_blocks` -- confirm against the installed Ray version.
ds = ray.data.read_parquet(input_files, columns=["content", "url", "timestamp"], parallelism=2000)

In [None]:
# Chunking configuration.
chunk_size = 512       # target chunk size (presumably in tokens -- confirm)
words_to_tokens = 1.2  # tokens-per-word ratio used to convert between the two

# Despite its name, `num_tokens` is a *word* budget: it is handed to the
# splitter as `chunk_size`, and the splitter measures length with `num_words`.
num_tokens = int(chunk_size // words_to_tokens)


def num_words(x):
    """Return the number of whitespace-separated words in `x`."""
    return len(x.split())


# Splitter that produces non-overlapping chunks of at most `num_tokens` words,
# keeping the separators it splits on.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=num_tokens,
    chunk_overlap=0,
    length_function=num_words,
    keep_separator=True,
)

def chunk(row, *, min_tokens=20, max_tokens=4000):
    """Split one document row into chunk rows.

    Args:
        row: Mapping with the keys `content` (the document text), `url`, and
            `timestamp` (any object exposing a `.timestamp()` method).
        min_tokens: Documents whose estimated token count is below this value
            are dropped.
        max_tokens: Documents whose estimated token count is above this value
            are dropped.

    Returns:
        A list with one dict per chunk, each holding `id`
        (`"<document_id>_<chunk index>"`), `text` and `metadata`. All chunks of
        a document share the same metadata dict. The list is empty when the
        document has no usable text or falls outside the token bounds.
    """
    content = row['content']
    # Skip rows with missing or non-text content instead of raising on
    # `.split()` and failing the task for the whole block.
    if not isinstance(content, str) or not content:
        return []

    # Estimate the token count from the word count with the shared
    # `words_to_tokens` ratio, so this filter stays consistent with the
    # splitter's chunk-size budget if the ratio is ever changed.
    length = num_words(content) * words_to_tokens
    if length < min_tokens or length > max_tokens:
        return []

    chunks = splitter.split_text(content)
    # uuid5 is deterministic, so the same URL always maps to the same id.
    document_id = str(uuid.uuid5(uuid.NAMESPACE_URL, row['url']))
    # One random draw per document, shared by all of its chunks. It is kept as
    # a length-1 float32 array so the output schema is unchanged.
    rand_coeff = np.random.rand(1).astype(np.float32)
    metadata = {
        'source': row['url'],
        'timestamp': row['timestamp'].timestamp(),
        'document_id': document_id,
        'rand_coeff': rand_coeff
    }

    # `text` (not `chunk`) as the loop variable avoids shadowing this function.
    return [
        {
            'id': f"{document_id}_{i}",
            'text': text,
            'metadata': metadata
        }
        for i, text in enumerate(chunks)
    ]

In [None]:
data_context = ray.data.DataContext.get_current()

# Cap the target block size at 400 MB (value presumably in bytes -- confirm).
# Together with the batch size passed to `map_batches` in the write cell, this
# keeps the output files in the hundreds-of-MBs range.
data_context.target_max_block_size = 400 * 1000**2

# The job is long, so tolerate errored blocks instead of aborting on the first
# one. The limit is the total number of errored blocks allowed, set here to
# the number of input files.
data_context.max_errored_blocks = len(input_files)

In [None]:
%%time

# Destination for the chunked parquet files in cloud storage,
# e.g. replace with an S3 path.
CHUNK_OUTPUT_PATH = "YOUR-OUTPUT-BUCKET-HERE"

# Expand every document row into its chunk rows, then run an identity
# `map_batches` pass whose only purpose is to regroup rows into batches of
# 350_000 before writing.
chunked_ds = ds.flat_map(chunk).map_batches(lambda batch: batch, batch_size=350_000)

chunked_ds.write_parquet(CHUNK_OUTPUT_PATH)
print(f"===> Done writing chunked files to {CHUNK_OUTPUT_PATH}")