In [None]:
from vllm.benchmarks.datasets import PrefixRepetitionRandomDataset
from vllm.transformers_utils.tokenizer import get_tokenizer
import daft
from daft.functions import monotonically_increasing_id
import ray
# Initialise Ray, then switch daft's runner to Ray so the DataFrame work
# below (from_pydict / repartition / write_parquet) executes on it.
ray.init()
daft.set_runner_ray()

In [None]:
# Shared sampler for every generate_* call below, seeded for reproducibility.
# NOTE(review): all calls draw from this one dataset object, so the prompts a
# given call produces presumably depend on the order of the calls — confirm.
dataset = PrefixRepetitionRandomDataset(random_seed=42)
# Tokenizer passed to dataset.sample() for every generated dataset.
tokenizer = get_tokenizer("Qwen/Qwen3-8B")
# Destination prefix under which all Parquet outputs are written.
root_dir = "s3://my-bucket/vllm-prefix-caching-partitioned"

def generate_data(request_k: int, num_prefixes: int, num_partitions: int = 8):
    """Sample prefix-sharing prompts and write them as Parquet under ``root_dir``.

    Prompts are drawn from the module-level ``dataset`` with ``prefix_len=256``,
    ``suffix_len=256`` and ``output_len=128``, using ``num_prefixes`` distinct
    shared prefixes. Each row gets an ``id`` column alongside its ``prompt``.

    Args:
        request_k: Number of requests to sample, in thousands.
        num_prefixes: Number of distinct shared prefixes passed to the sampler.
        num_partitions: Number of partitions to repartition into before
            writing (defaults to the previously hard-coded 8).

    Returns:
        The result of ``DataFrame.write_parquet`` for the output path.
    """
    import warnings

    num_requests = request_k * 1000
    requests = dataset.sample(
        tokenizer=tokenizer,
        num_requests=num_requests,
        prefix_len=256,
        suffix_len=256,
        num_prefixes=num_prefixes,
        output_len=128,
    )
    # The output name claims "{request_k}k" rows; if the sampler returned a
    # different number (e.g. it splits requests evenly per prefix and
    # num_prefixes does not divide num_requests), say so instead of silently
    # writing a mislabelled dataset.
    if len(requests) != num_requests:
        warnings.warn(
            f"requested {num_requests} prompts with {num_prefixes} prefixes "
            f"but the sampler returned {len(requests)}",
            stacklevel=2,
        )

    prompts = [req.prompt for req in requests]
    df = daft.from_pydict({"prompt": prompts})
    # Ids are assigned before the repartition below.
    df = df.select(monotonically_increasing_id().alias("id"), "prompt")
    df = df.repartition(num_partitions)
    # NOTE(review): "0-5" appears to encode the prefix share of the prompt
    # (256 of 512 tokens), mirroring the "_0" of the no-prefix dataset — confirm.
    return df.write_parquet(f"{root_dir}/{request_k}k_0-5_{num_prefixes}.parquet")


def generate_data_no_prefix(request_k: int, num_partitions: int = 8):
    """Sample prompts with no shared prefix and write them as Parquet under ``root_dir``.

    Prompts are drawn from the module-level ``dataset`` with ``prefix_len=0``,
    ``suffix_len=512`` and ``output_len=128``; this is the no-prefix baseline
    for the datasets produced by ``generate_data``. Each row gets an ``id``
    column alongside its ``prompt``.

    Args:
        request_k: Number of requests to sample, in thousands.
        num_partitions: Number of partitions to repartition into before
            writing (defaults to the previously hard-coded 8).

    Returns:
        The result of ``DataFrame.write_parquet`` for the output path.
    """
    import warnings

    num_requests = request_k * 1000
    requests = dataset.sample(
        tokenizer=tokenizer,
        num_requests=num_requests,
        prefix_len=0,
        suffix_len=512,
        output_len=128,
    )
    # The output name claims "{request_k}k" rows; flag any mismatch rather
    # than silently writing a mislabelled dataset.
    if len(requests) != num_requests:
        warnings.warn(
            f"requested {num_requests} prompts but the sampler returned "
            f"{len(requests)}",
            stacklevel=2,
        )

    prompts = [req.prompt for req in requests]
    df = daft.from_pydict({"prompt": prompts})
    # Ids are assigned before the repartition below.
    df = df.select(monotonically_increasing_id().alias("id"), "prompt")
    df = df.repartition(num_partitions)
    return df.write_parquet(f"{root_dir}/{request_k}k_0.parquet")

In [None]:
# 200k-request datasets with shared prefixes, in increasing prefix counts.
for prefix_count in (8, 64, 512, 4096):
    generate_data(200, prefix_count)
# No-shared-prefix baseline at the same request count (last cell expression).
generate_data_no_prefix(200)