Add dataset_processor_path CLI knob for custom datasets#4031

Merged
copybara-service[bot] merged 1 commit into main from pr/dataset-processor-path on Jun 8, 2026
@py4 py4 commented Jun 1, 2026

Collaborator

Add a dataset_processor_path CLI/yaml knob that lets users plug in a custom process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x) -> dict function from a user-provided Python file, instead of editing maxtext to support a new dataset shape.

Why: the built-in utils_rl.process_data is hardcoded for a small set of dataset schemas (GSM8K, etc.). For users running RL on custom datasets with different answer columns / cleaning rules, the alternative was either (1) edit maxtext source (fork divergence) or (2) reformat the dataset to look like GSM8K (lossy). This knob gives a clean third option: ship your dataset processor as a Python file and point maxtext at it.

Changes (2 files, +41/-16 lines):

  • src/maxtext/trainers/post_train/rl/train_rl.py:
    • New _load_custom_callable(module_path, function_name) helper that uses importlib.util.spec_from_file_location to load a function from an arbitrary .py file (without adding to sys.path).
    • prepare_datasets checks trainer_config.dataset_processor_path; if set, loads process_data from that file and substitutes for utils_rl.process_data in the dataset pipeline.
  • src/maxtext/configs/post_train/rl.yml: new top-level knob dataset_processor_path: '' with comment documenting the signature contract.
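The loader described above can be sketched roughly as follows. This is a minimal illustration of loading a function by file path with `importlib.util.spec_from_file_location`, not the code merged in this PR: the name `load_custom_callable`, the error handling, and the throwaway processor file are all assumptions made for the example.

```python
import importlib.util
import os
import sys
import tempfile


def load_custom_callable(module_path, function_name):
    """Load `function_name` from the Python file at `module_path`.

    Hypothetical stand-in for the PR's `_load_custom_callable`; the real
    helper may validate inputs differently.
    """
    if not os.path.isfile(module_path):
        raise FileNotFoundError(f"No such file: {module_path}")
    # The module name is only a label; nothing is added to sys.path or sys.modules.
    module_name = os.path.splitext(os.path.basename(module_path))[0]
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    fn = getattr(module, function_name, None)
    if not callable(fn):
        raise AttributeError(f"{module_path} does not define a callable {function_name!r}")
    return fn


# Usage: write a throwaway processor file, then load it by path.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "user_process_data.py")
    with open(path, "w", encoding="utf-8") as f:
        f.write(
            "def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):\n"
            "    return {'prompts': x['q'], 'question': x['q'], 'answer': x['a']}\n"
        )
    sys_path_before = list(sys.path)
    process_data = load_custom_callable(path, "process_data")
    result = process_data("my/dataset", None, None, None, {"q": "2+2?", "a": "4"})
    sys_path_unchanged = sys.path == sys_path_before
```

Because the module is built from a spec and executed directly, the user's file does not have to live on the import path, and a file name that happens to match an installed package does not shadow it.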

Backward compatible: default empty string falls back to utils_rl.process_data (identical to old behavior). The _load_custom_callable helper is only invoked when the user explicitly sets the path.
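The fallback behavior can be pictured as a small selection step. The sketch below is an assumption about the shape of that logic, not the merged code: `builtin_process_data` is a stand-in for `utils_rl.process_data`, and `select_processor` and its `loader` argument are hypothetical names.

```python
def builtin_process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
    """Stand-in for utils_rl.process_data; the real function is not reproduced here."""
    return {"prompts": x["question"], "question": x["question"], "answer": x["answer"]}


def select_processor(dataset_processor_path, loader):
    """Return the built-in processor unless a custom path is configured.

    `loader` is any callable taking (module_path, function_name), for example
    a path-based import helper. It is only called when the path is non-empty.
    """
    if not dataset_processor_path:
        return builtin_process_data
    return loader(dataset_processor_path, "process_data")
```

With the default `dataset_processor_path: ''`, the loader is never called, so existing configs keep the old code path.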

User-facing contract:

# user_process_data.py
def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
    return {"prompts": ..., "question": ..., "answer": ...}

python3 -m maxtext.trainers.post_train.rl.train_rl rl.yml \
  dataset_processor_path=/path/to/user_process_data.py \
  ...
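For illustration, a filled-in processor might look like the sketch below. Everything dataset-specific here is hypothetical: the `problem` and `solution` column names, the `Final answer:` cleaning rule, and the raw-text `PROMPT_TEMPLATE` are invented for the example. It also assumes `x` is a single dataset row passed as a dict, and it ignores `model_tokenizer`, `template_config`, and `tmvp_config`, which a real processor would likely use.

```python
# user_process_data.py (illustrative only)

# Hypothetical raw-text prompt template; not taken from maxtext.
PROMPT_TEMPLATE = "Question: {question}\nAnswer:"


def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
    """Map one row with hypothetical 'problem'/'solution' columns to the expected keys."""
    question = x["problem"].strip()
    # Hypothetical cleaning rule: keep only the text after the last "Final answer:" marker.
    answer = x["solution"].rsplit("Final answer:", 1)[-1].strip()
    return {
        "prompts": PROMPT_TEMPLATE.format(question=question),
        "question": question,
        "answer": answer,
    }
```

The only part taken from the contract above is the signature and the three returned keys; how each value is built is up to the processor.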

Checklist

  • Tested locally with a custom processor file (VTC-style raw-text prompt template); produced expected outputs
  • Backward compatible: default empty string preserves utils_rl.process_data behavior
  • No effect on non-RL paths (only prepare_datasets in the RL trainer touched)
  • _load_custom_callable doesn't pollute sys.path (uses spec_from_file_location)

codecov Bot commented Jun 1, 2026

Codecov Report

❌ Patch coverage is 73.68421% with 5 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/trainers/post_train/rl/train_rl.py | 40.00% | 2 Missing and 1 partial ⚠️ |
| src/maxtext/trainers/post_train/rl/utils_rl.py | 85.71% | 1 Missing and 1 partial ⚠️ |

@py4 py4 force-pushed the pr/dataset-processor-path branch 3 times, most recently from 6d91a2c to 2842052 Compare June 2, 2026 21:16
Comment thread tests/post_training/unit/load_custom_callable_test.py

@khatwanimohit khatwanimohit left a comment

Collaborator

LGTM

@A9isha A9isha left a comment

Collaborator

Just one request for the refactor; the rest looks good, thank you Pooya!

Comment thread src/maxtext/trainers/post_train/rl/train_rl.py Outdated
@py4 py4 force-pushed the pr/dataset-processor-path branch 2 times, most recently from 2ea1e05 to 2fc7b50 Compare June 5, 2026 22:32
Currently utils_rl.process_data hard-branches on `dataset_name == "openai/gsm8k"`
to call `extract_hash_answer`. Other datasets either work as-is (if their
`answer` column is already clean) or require editing utils_rl directly.

Add an optional `dataset_processor_path` config: a filesystem path to a
user-provided Python file with a `process_data(dataset_name, tokenizer,
template_config, tmvp_config, x) -> dict` function. When set, that function
replaces the built-in one for all train/eval dataset map() calls.

Default (`dataset_processor_path: ''`) keeps existing behavior unchanged.

Also adds `_load_custom_callable` helper used by this and the upcoming
custom reward CLI knob.
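The hard-branch described in this commit message, and the way a processor is applied per row, can be sketched as follows. This is an assumption-laden illustration rather than the maxtext source: the body of `extract_hash_answer` reflects the common GSM8K convention that the final answer follows a `####` marker, the `process_data` body is a simplified stand-in, and plain `map` with `functools.partial` stands in for whatever dataset `map()` call the trainer actually uses.

```python
import functools


def extract_hash_answer(text):
    """Return the text after '####', or None if the marker is absent (assumed behavior)."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
    """Simplified stand-in for the built-in processor with its GSM8K special case."""
    answer = x["answer"]
    if dataset_name == "openai/gsm8k":
        answer = extract_hash_answer(answer)
    return {"prompts": x["question"], "question": x["question"], "answer": answer}


# Bind the four fixed arguments once, then apply the result to each row.
rows = [{"question": "What is 6 * 7?", "answer": "6 * 7 = 42\n#### 42"}]
row_fn = functools.partial(process_data, "openai/gsm8k", None, None, None)
processed = list(map(row_fn, rows))
```

A custom processor with the same five-argument signature can be bound and mapped in exactly the same way, which is why swapping the function is enough to support a new dataset shape.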
@py4 py4 force-pushed the pr/dataset-processor-path branch from 2fc7b50 to f2d4f3b Compare June 8, 2026 18:12
@copybara-service copybara-service Bot merged commit f93627f into main Jun 8, 2026
30 checks passed
@copybara-service copybara-service Bot deleted the pr/dataset-processor-path branch June 8, 2026 21:47

4 participants