Revert "[skyrl-train][step-wise] 1/N - Support step-wise training with `step_wise_training` flag" #706
Conversation
Revert "[skyrl-train][step-wise] 1/N - Support step-wise training with `step_wise_training` flag"

This reverts commit a30405f.
Code Review
This pull request reverts the integration of step-wise training from the core library and refactors it into a self-contained example. This is a good approach to isolate the experimental feature and fix the bug mentioned in the description. The changes correctly decouple the step-wise logic from the main trainer and generator.
My review focuses on the newly added/moved example code. I've identified a few areas for improvement regarding code clarity, maintainability, and performance. Specifically, I've suggested renaming a method and a dataclass in the StepWiseGenerator to avoid confusion and Liskov Substitution Principle violations, pointed out a potentially incorrect check that restricts the use of custom chat templates, and highlighted a potential performance bottleneck in the StepWiseTrainer due to a GPU-to-CPU data transfer.
```python
@dataclass
class AgentLoopOutput:
    """Output from a single agent_loop execution."""

    response_ids: List[int]
    reward: Union[List[float], float]
    stop_reason: str
    loss_mask: List[int]
    prompt_ids: List[int]
    rollout_logprobs: Optional[List[float]]
```
This AgentLoopOutput dataclass is a near-duplicate of the one in skyrl_train.generators.skyrl_gym_generator. However, they represent different concepts: here it's the output of a single step, while in the base class it's the output of a whole trajectory. This name collision is confusing.
To improve clarity and avoid duplication, consider renaming this dataclass to something more specific, like StepOutput. This would make the code easier to understand and maintain.
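A minimal sketch of the suggested rename, using the field names from the diff above (`StepOutput` is only the proposed name, not an identifier that exists in the codebase):

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class StepOutput:
    """Output of a single agent_loop step.

    Renamed from AgentLoopOutput to avoid colliding with the
    trajectory-level dataclass of the same name in
    skyrl_train.generators.skyrl_gym_generator.
    """

    response_ids: List[int]
    reward: Union[List[float], float]
    stop_reason: str
    loss_mask: List[int]
    prompt_ids: List[int]
    rollout_logprobs: Optional[List[float]]
```

With the rename, call sites that consume per-step outputs read unambiguously, and the base-class `AgentLoopOutput` keeps its trajectory-level meaning.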
```python
last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
    token_level_rewards=last_step_rewards,
    response_mask=response_mask[is_last_step],
    index=index[is_last_step.cpu().numpy()],
)
```
The expression is_last_step.cpu().numpy() involves a GPU-to-CPU data transfer within the training loop, which can be a performance bottleneck, especially with large batch sizes. Since index is a NumPy array of strings, this transfer is necessary for boolean indexing.
To optimize this, consider keeping is_last_step on the CPU if possible, or explore ways to perform the grouping logic on the GPU before this step to avoid the synchronization overhead.
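One way to sidestep the synchronization, sketched here with NumPy only (the variable names follow the diff, but the data is hypothetical): if the last-step flag can be derived from per-sample metadata that already lives on the CPU, `is_last_step` can be built as a NumPy array directly, so indexing the string `index` array requires no `.cpu().numpy()` round-trip inside the training loop.

```python
import numpy as np

# Hypothetical per-sample metadata, already resident on the CPU:
# a trajectory id for each sample, its step position, and the index
# of the final step in its trajectory.
index = np.array(["traj-0", "traj-0", "traj-1", "traj-1"])  # string group keys
step_in_traj = np.array([0, 1, 0, 1])
last_step = np.array([1, 1, 1, 1])

# Build the boolean mask as a NumPy array instead of a GPU tensor,
# so no GPU-to-CPU transfer (and implicit synchronization) is needed
# to index the string array below.
is_last_step = step_in_traj == last_step

last_step_index = index[is_last_step]  # keys for the final step of each trajectory
print(last_step_index)
```

The GPU-side tensors (`response_mask`, rewards) can still be masked with a device-resident copy of the same flag, keeping the single CPU mask as the only host-side state.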
Reverts #694
See #694 (comment)
The PR expects `trajectory_ids` to always be in the generator output, which currently is not enforced and is breaking. `run_gsm8k.sh` fails with https://gist.github.com/CharlieFRuan/cbbef69fde60a20d483d03efb13d60bb