feat(tinker): Add support for built-in loss functions and checkpoint control #523
base: main
Conversation
Looks like some checks aren't passing. And have we tested this?
No, I tested some of the more foundational pieces in earlier PRs (e.g., the move from …). I can wait to merge until that run is working, probably. (This will likely turn into a mega-PR with lots of unrelated changes along the way, though, as I find more pieces I need to update.)
…control

Add two new features to TinkerBackend:

1. Built-in loss functions (tinker_loss_fn, tinker_loss_fn_config)
   - Supports Tinker's optimized losses: importance_sampling, ppo, cispo, dro
   - Uses forward_backward_async instead of forward_backward_custom_async
   - ~1.5x fewer FLOPs, up to 3x faster (per Tinker docs)
   - Default behavior unchanged (uses ART's custom loss)

2. Checkpoint control (save_checkpoint parameter)
   - When False, only saves sampler weights (fast, for inference)
   - When True (default), saves full state + optimizer (for resumption)
   - Enables faster training when full checkpoints are only needed at intervals

Both features are backwards-compatible; existing code works unchanged.
…ain()

- Add optional adam_beta1, adam_beta2, adam_eps parameters to train()
- Pass through to TinkerService via dev_config
- Use the params when calling optim_step_async with tinker.AdamParams

This allows customization of the Adam optimizer hyperparameters, which is needed when using non-default values (e.g., beta2=0.95 instead of 0.999).
Add adam_beta1, adam_beta2, and adam_eps to fix Pyright type errors when assigning these keys to the dev_config dict.
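The pass-through described in these two commits might look roughly like this. AdamParams here is a local stand-in for tinker.AdamParams, and the dev_config keys are the ones named above; the defaults shown are the usual Adam defaults and are an assumption about what ART falls back to.

```python
from dataclasses import dataclass


@dataclass
class AdamParams:
    # Stand-in for tinker.AdamParams with conventional Adam defaults.
    beta1: float = 0.9
    beta2: float = 0.999
    eps: float = 1e-8


def adam_params_from_dev_config(dev_config: dict) -> AdamParams:
    # Read the optional keys added to dev_config, falling back to defaults
    # when a key is absent (hypothetical helper, mirroring the commit text).
    return AdamParams(
        beta1=dev_config.get("adam_beta1", 0.9),
        beta2=dev_config.get("adam_beta2", 0.999),
        eps=dev_config.get("adam_eps", 1e-8),
    )
```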
- Update shift_tensor to support both 1D and 2D tensors - Replace NaN values in logprobs before JSON serialization to Tinker API - Guard Qwen3InstructRenderer patch for older tinker_cookbook versions
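The first two bullets can be illustrated with a list-based sketch. The real shift_tensor presumably operates on torch tensors; this hypothetical version only shows the 1D/2D shape handling and the NaN scrubbing that strict JSON serialization requires (JSON has no NaN literal).

```python
import json
import math


def sanitize_logprobs(logprobs: list[float], fill: float = 0.0) -> list[float]:
    # Replace NaN entries before serialization; the fill value 0.0 is an
    # assumption, not necessarily what ART uses.
    return [fill if math.isnan(x) else x for x in logprobs]


def shift_tensor(rows: list, pad: float = 0.0) -> list:
    # Shift each row left by one position, padding the tail.
    # Handles both a single row (1D) and a batch of rows (2D).
    if rows and isinstance(rows[0], list):
        return [shift_tensor(r, pad) for r in rows]
    return rows[1:] + [pad]


clean = sanitize_logprobs([-0.5, math.nan, -1.2])
payload = json.dumps(clean, allow_nan=False)  # would raise if a NaN remained
```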
Previously, if port 8000 was already in use, the server would bind to a different port via get_free_port() but the client would still try to connect to port 8000, causing connection failures. Now the port is determined once upfront and passed to both the server and client.
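The fix amounts to resolving the port once and sharing the value. A minimal sketch, assuming the default port is 8000 and that get_free_port-style fallback uses an OS-assigned ephemeral port:

```python
import socket

DEFAULT_PORT = 8000  # assumed default from the description above


def pick_port(preferred: int = DEFAULT_PORT) -> int:
    # Try the preferred port first; if it is already in use, fall back to
    # an OS-assigned free port. Either way, one value is decided up front.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", preferred))
            return preferred
    except OSError:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]


port = pick_port()
base_url = f"http://127.0.0.1:{port}"  # the same value goes to server and client
```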
Summary
This PR adds two new features to TinkerBackend that are fully backwards-compatible with existing code.

1. Built-in Loss Functions
Adds support for Tinker's optimized built-in loss functions via new parameters on TinkerBackend.train():

- tinker_loss_fn: Select from "importance_sampling", "ppo", "cispo", "dro"
- tinker_loss_fn_config: Pass loss-specific config (e.g., {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0})

Benefits:

- Uses forward_backward_async instead of forward_backward_custom_async

Default behavior unchanged: when tinker_loss_fn=None (the default), ART's custom loss implementation continues to be used.

2. Checkpoint Control
The existing save_checkpoint parameter now controls checkpoint behavior in TinkerBackend:

- save_checkpoint=True (default): saves full state + optimizer (enables training resumption)
- save_checkpoint=False: saves only the sampler weights (fast; for inference only)

This enables faster training when full checkpoints are only needed at specific intervals (e.g., at eval steps).
Usage
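A hypothetical usage sketch combining both features; it assumes a TinkerBackend instance named backend inside an ART training loop with model, trajectory_groups, step, and eval_interval in scope, and the exact train() signature may differ:

```python
# Hypothetical usage; `backend`, `model`, `trajectory_groups`, `step`, and
# `eval_interval` are assumed to exist in the surrounding training loop.
await backend.train(
    model,
    trajectory_groups,
    tinker_loss_fn="cispo",
    tinker_loss_fn_config={"clip_low_threshold": 0.0, "clip_high_threshold": 6.0},
    save_checkpoint=(step % eval_interval == 0),  # full checkpoint only at eval steps
)
```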
Files Changed
- src/art/dev/train.py: Added tinker_loss_fn, tinker_loss_fn_config, tinker_save_checkpoint to TrainConfig
- src/art/tinker/backend.py: Overrode train() with the new parameters
- src/art/tinker/service.py: Added dispatch logic for built-in vs. custom loss; added a _save_sampler_weights_only() method

Backwards Compatibility
All existing code continues to work unchanged. The new parameters are optional with sensible defaults that preserve current behavior.