feat: add callback hooks to standalone GRPO trainer #198

Merged: abrichr merged 1 commit into main from fix/eval-infra-and-forced-override on Mar 28, 2026

@abrichr abrichr commented Mar 28, 2026

Summary

Adds four optional callback hooks to the standalone GRPO trainer so that callers no longer need to monkey-patch it. Requested by a customer; the hooks replace 3 of their 6 monkey-patches.

Callbacks

| Hook | Signature | Use case |
| --- | --- | --- |
| on_model_loaded | (model, processor) -> None | Gradient checkpointing, custom submodule setup |
| on_before_collect | (task_id, env) -> None | WAA health checks, tunnel verification |
| on_rollout_complete | (rollout, index) -> None | Per-rollout W&B logging, screenshot capture |
| on_step_complete | (step, rollouts, metrics) -> None | Per-step W&B logging, early stopping |
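
To make the call order concrete, here is a minimal sketch of how a trainer could accept and invoke the four hooks. `SketchTrainer`, its config keys (`num_steps`, `task_ids`), and the placeholder model, env, and rollout objects are all hypothetical and are not taken from the real GRPOTrainer; in particular, calling `on_before_collect` once per task is an assumption.

```python
from typing import Any, Callable, Optional


class SketchTrainer:
    """Hypothetical stand-in for illustration; not the real GRPOTrainer."""

    def __init__(
        self,
        config: dict,
        *,  # hooks are keyword-only and default to None (no-op)
        on_model_loaded: Optional[Callable[[Any, Any], None]] = None,
        on_before_collect: Optional[Callable[[str, Any], None]] = None,
        on_rollout_complete: Optional[Callable[[Any, int], None]] = None,
        on_step_complete: Optional[Callable[[int, list, dict], None]] = None,
    ) -> None:
        self.config = config
        self._on_model_loaded = on_model_loaded
        self._on_before_collect = on_before_collect
        self._on_rollout_complete = on_rollout_complete
        self._on_step_complete = on_step_complete

    def train(self) -> None:
        # Placeholders: a real trainer would load a model and processor here.
        model, processor = object(), object()
        if self._on_model_loaded is not None:
            self._on_model_loaded(model, processor)

        for step in range(self.config["num_steps"]):
            rollouts = []
            for index, task_id in enumerate(self.config["task_ids"]):
                env = object()  # placeholder environment
                if self._on_before_collect is not None:
                    self._on_before_collect(task_id, env)
                # Placeholder rollout with a fixed reward.
                rollout = {"task_id": task_id, "reward": 1.0}
                rollouts.append(rollout)
                if self._on_rollout_complete is not None:
                    self._on_rollout_complete(rollout, index)

            # Placeholder metrics using the keys from the PR's usage example.
            reward_mean = sum(r["reward"] for r in rollouts) / len(rollouts)
            metrics = {"reward_mean": reward_mean, "loss": 0.0}
            if self._on_step_complete is not None:
                self._on_step_complete(step, rollouts, metrics)
```

Checking each hook against `None` before calling it keeps the default path identical to a trainer without hooks, which is what makes the change backward compatible.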

Usage

import wandb

def log_step(step, rollouts, metrics):
    wandb.log({"reward": metrics["reward_mean"], "loss": metrics["loss"]}, step=step)

def check_waa(task_id, env):
    if not env.health_check():
        raise ConnectionError("WAA tunnel died")

trainer = GRPOTrainer(
    config,
    on_step_complete=log_step,
    on_before_collect=check_waa,
)
trainer.train()
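
The usage above covers two of the four hooks. As a hedged sketch of the other two, the callbacks below match the documented `(model, processor)` and `(rollout, index)` signatures. The `gradient_checkpointing_enable()` method (as found on Hugging Face transformers models) and the `reward` attribute on a rollout are assumptions, so both are looked up defensively; the `SimpleNamespace` objects are stand-ins used only to exercise the callbacks.

```python
from types import SimpleNamespace


def enable_checkpointing(model, processor):
    # Assumption: the model may expose gradient_checkpointing_enable(),
    # as Hugging Face transformers models do. Skip silently if it does not.
    enable = getattr(model, "gradient_checkpointing_enable", None)
    if callable(enable):
        enable()


def make_rollout_recorder(records):
    """Return an on_rollout_complete callback that appends to `records`."""

    def record(rollout, index):
        # Assumption: a rollout may carry a `reward` attribute; None otherwise.
        records.append({"index": index, "reward": getattr(rollout, "reward", None)})

    return record


# Exercise the callbacks directly with stand-in objects.
calls = []
fake_model = SimpleNamespace(gradient_checkpointing_enable=lambda: calls.append("ckpt"))
enable_checkpointing(fake_model, processor=None)

records = []
on_rollout_complete = make_rollout_recorder(records)
on_rollout_complete(SimpleNamespace(reward=0.5), 0)
```

These would be passed to the trainer as `on_model_loaded=enable_checkpointing` and `on_rollout_complete=make_rollout_recorder(records)`, in the same way as the hooks in the usage example.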

Test plan

  • All callbacks default to None (no-op), so existing callers are unaffected (backward compatible)
  • Code compiles and hook signatures match the documented types
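
The first test-plan item could be automated with a small signature check. This is a sketch only: `assert_hooks_optional` and `StubTrainer` are hypothetical names, and the stub stands in for the real trainer class, whose import path is not shown in this PR. The keyword-only requirement follows the commit message.

```python
import inspect

HOOK_NAMES = (
    "on_model_loaded",
    "on_before_collect",
    "on_rollout_complete",
    "on_step_complete",
)


def assert_hooks_optional(trainer_cls):
    """Check that every hook is a keyword-only parameter defaulting to None."""
    params = inspect.signature(trainer_cls.__init__).parameters
    for name in HOOK_NAMES:
        param = params[name]  # KeyError here means the hook is missing
        assert param.kind is inspect.Parameter.KEYWORD_ONLY, name
        assert param.default is None, name


class StubTrainer:
    """Hypothetical stand-in with the constructor shape described in the PR."""

    def __init__(
        self,
        config,
        *,
        on_model_loaded=None,
        on_before_collect=None,
        on_rollout_complete=None,
        on_step_complete=None,
    ):
        self.config = config


assert_hooks_optional(StubTrainer)
```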

🤖 Generated with Claude Code

Four optional callback hooks eliminate the need for monkey-patching:

- on_model_loaded(model, processor): Custom model setup (gradient
  checkpointing on specific submodules, hook attachment)
- on_before_collect(task_id, env): WAA health checks, tunnel
  verification, task-specific setup before rollout collection
- on_rollout_complete(rollout, index): Per-rollout W&B logging,
  screenshot/thought capture
- on_step_complete(step, rollouts, metrics): Per-step W&B logging,
  early stopping, custom evaluation

All callbacks are keyword-only with None defaults (no-op).
Eliminates 3 of 6 monkey-patches reported by customer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abrichr abrichr merged commit 28b0193 into main Mar 28, 2026