@corbt corbt commented Jan 20, 2026

Summary

This PR adds two new features to TinkerBackend that are fully backwards-compatible with existing code.

1. Built-in Loss Functions

Adds support for Tinker's optimized built-in loss functions via new parameters on TinkerBackend.train():

  • tinker_loss_fn: Select from "importance_sampling", "ppo", "cispo", "dro"
  • tinker_loss_fn_config: Pass loss-specific config (e.g., {"clip_low_threshold": 0.0, "clip_high_threshold": 6.0})

Benefits:

  • Uses forward_backward_async instead of forward_backward_custom_async
  • ~1.5x fewer FLOPs per training step
  • Up to 3x faster wall time (per Tinker docs)

Default behavior is unchanged: when tinker_loss_fn=None (the default), the backend continues to use ART's custom loss implementation.
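The dispatch between the two paths can be sketched roughly as below. This is a simplified illustration only; the actual logic lives in src/art/tinker/service.py, and the loss_fn/loss_fn_config keyword names on forward_backward_async are assumptions for the sketch, not the confirmed Tinker client signature.

```python
# Simplified sketch of the loss-function dispatch (illustration only; the
# real service code and the Tinker client API may differ in details).

BUILTIN_LOSS_FNS = {"importance_sampling", "ppo", "cispo", "dro"}

async def forward_backward(client, batch, tinker_loss_fn=None, tinker_loss_fn_config=None):
    if tinker_loss_fn is None:
        # Default path: ART's custom loss via forward_backward_custom_async.
        return await client.forward_backward_custom_async(batch)
    if tinker_loss_fn not in BUILTIN_LOSS_FNS:
        raise ValueError(f"unknown tinker_loss_fn: {tinker_loss_fn!r}")
    # Built-in path: Tinker computes the loss server-side, which is where
    # the FLOP and wall-time savings come from.
    return await client.forward_backward_async(
        batch,
        loss_fn=tinker_loss_fn,
        loss_fn_config=tinker_loss_fn_config or {},
    )
```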

2. Checkpoint Control

The existing save_checkpoint parameter now controls checkpoint behavior in TinkerBackend:

  • save_checkpoint=True (default): Saves full state + optimizer (enables training resumption)
  • save_checkpoint=False: Only saves sampler weights (fast, for inference only)

This enables faster training when full checkpoints are only needed at specific intervals (e.g., at eval steps).

Usage

backend = art.TinkerBackend()

# Use Tinker's built-in CISPO (faster)
result = await backend.train(
    model,
    trajectory_groups,
    learning_rate=5e-6,
    tinker_loss_fn="cispo",
    tinker_loss_fn_config={"clip_low_threshold": 0.0, "clip_high_threshold": 6.0},
    save_checkpoint=False,  # Fast: only saves sampler weights
)

# Use ART's custom loss (default, backwards compatible)
result = await backend.train(
    model,
    trajectory_groups,
    learning_rate=5e-6,
    ppo=False,
    epsilon=1.0,
)

Files Changed

  • src/art/dev/train.py: Added tinker_loss_fn, tinker_loss_fn_config, tinker_save_checkpoint to TrainConfig
  • src/art/tinker/backend.py: Overrode train() with new parameters
  • src/art/tinker/service.py: Added dispatch logic for built-in vs custom loss, added _save_sampler_weights_only() method

Backwards Compatibility

All existing code continues to work unchanged. The new parameters are optional with sensible defaults that preserve current behavior.

@corbt corbt requested a review from bradhilton January 20, 2026 02:28
bradhilton commented Jan 20, 2026

looks like some checks aren't passing. and have we tested this?

corbt commented Jan 20, 2026

have we tested this?

No. I tested some of the more foundational pieces in earlier PRs (e.g., the move from model.train to backend.train), but since this one is harder to test independently, I was planning to test it as part of the Simple code migration, which will depend on this functionality.

I can probably wait to merge until that run is working. (This will likely turn into a mega-PR with lots of unrelated changes along the way, though, as I find more pieces I need to update.)

Cursor Bot added 4 commits January 21, 2026 19:22
…control

Add two new features to TinkerBackend:

1. Built-in loss functions (tinker_loss_fn, tinker_loss_fn_config)
   - Supports Tinker's optimized losses: importance_sampling, ppo, cispo, dro
   - Uses forward_backward_async instead of forward_backward_custom_async
   - ~1.5x fewer FLOPs, up to 3x faster (per Tinker docs)
   - Default behavior unchanged (uses ART's custom loss)

2. Checkpoint control (save_checkpoint parameter)
   - When False, only saves sampler weights (fast, for inference)
   - When True (default), saves full state + optimizer (for resumption)
   - Enables faster training when full checkpoints only needed at intervals

Both features are backwards-compatible - existing code works unchanged.
…ain()

- Add optional adam_beta1, adam_beta2, adam_eps parameters to train()
- Pass through to TinkerService via dev_config
- Use params when calling optim_step_async with tinker.AdamParams

This allows customization of Adam optimizer hyperparameters, which is
needed when using non-default values (e.g., beta2=0.95 instead of 0.999).
Add adam_beta1, adam_beta2, and adam_eps to fix Pyright type errors
when assigning these keys to the dev_config dict.
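The pass-through described above could look roughly like the helper below. This is a hypothetical sketch: the resolve_adam_params name and the beta1/beta2/eps field names are assumptions for illustration; the real code passes the values to optim_step_async via tinker.AdamParams.

```python
# Hypothetical sketch of merging the optional Adam overrides from dev_config
# with defaults (field names beta1/beta2/eps are illustrative assumptions).

ADAM_DEFAULTS = {"beta1": 0.9, "beta2": 0.999, "eps": 1e-8}

def resolve_adam_params(dev_config):
    params = dict(ADAM_DEFAULTS)
    for key, field in (
        ("adam_beta1", "beta1"),
        ("adam_beta2", "beta2"),
        ("adam_eps", "eps"),
    ):
        value = dev_config.get(key)
        if value is not None:
            params[field] = value
    return params
```

With this shape, a caller who only sets adam_beta2=0.95 gets that override while beta1 and eps keep their defaults.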
Cursor Bot added 3 commits January 22, 2026 00:33
- Update shift_tensor to support both 1D and 2D tensors
- Replace NaN values in logprobs before JSON serialization to Tinker API
- Guard Qwen3InstructRenderer patch for older tinker_cookbook versions
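The NaN-replacement fix matters because Python's json.dumps emits the non-standard token NaN by default, which a strict JSON parser on the API side will reject. A minimal sketch of the idea (illustration only; the sanitize_logprobs name is hypothetical, and the real code operates on the structures sent to the Tinker API):

```python
import json
import math

def sanitize_logprobs(logprobs, fill=0.0):
    # Replace NaN values so the payload serializes to strictly valid JSON.
    # (Sketch of the fix only, not the actual implementation.)
    return [fill if math.isnan(x) else x for x in logprobs]

# Without sanitizing, json.dumps([float("nan")]) would produce "[NaN]",
# which is not valid JSON; after sanitizing it serializes cleanly.
payload = json.dumps(sanitize_logprobs([0.1, float("nan"), -2.3]))
```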
Previously, if port 8000 was already in use, the server would bind to a
different port via get_free_port() but the client would still try to
connect to port 8000, causing connection failures. Now the port is
determined once upfront and passed to both the server and client.
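The fix amounts to choosing the port once and threading the same value to both sides. The get_free_port name comes from the commit message above; this implementation of it is an assumption, shown with the common bind-to-port-0 idiom:

```python
import socket

def get_free_port():
    # Ask the OS for an ephemeral port by binding to port 0 (a common idiom;
    # the project's actual get_free_port may be implemented differently).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

# Determine the port once up front, then pass the same value to both the
# server and the client, instead of hard-coding 8000 on the client side
# while the server silently falls back to a different free port.
port = get_free_port()
server_addr = ("127.0.0.1", port)            # server binds here
client_url = f"http://127.0.0.1:{port}"      # client connects to the same port
```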