[skyrl-train][step-wise] 1/N - Support step-wise training with step_wise_training flag
#694
Conversation
trajectory_ids: Optional[List[TrajectoryID]]
# Applicable only for step-wise training
is_last_step: Optional[List[bool]]
Both fields are optional right now, since it's not a hard requirement for all generators to send them over; they're only required for step-wise training.
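For context, a minimal sketch of how these optional fields might sit in the generator output schema (assuming a TypedDict-style `GeneratorOutput`; `TrajectoryID` is a toy stand-in here):

```python
from typing import List, Optional, TypedDict

TrajectoryID = str  # toy stand-in for the real TrajectoryID type


class GeneratorOutput(TypedDict, total=False):
    # ...existing fields (response_ids, rewards, loss_masks, ...) omitted...
    trajectory_ids: Optional[List[TrajectoryID]]
    # Applicable only for step-wise training
    is_last_step: Optional[List[bool]]
```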
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>

/gemini review
Code Review
This PR is a good first step towards integrating step-wise training natively into skyrl-train. It successfully refactors the logic from the examples/step_wise directory into the core library, controlled by a new step_wise_training flag. The changes are well-contained and the new flag provides a clear way to enable the feature. The removal of custom entrypoints and trainers for the example and moving the logic to the base classes is a great improvement for maintainability. I have one major concern regarding the padding logic for is_last_step which might lead to incorrect advantage calculations. Please see my detailed comment.
additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()

if key == "is_last_step":
    padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
The padding for is_last_step should be False (i.e., torch.zeros) instead of True (torch.ones). When is_last_step is True for padded rows, they are incorrectly included in the advantage calculation for step-wise training. This can lead to incorrect advantages being computed, as rewards from padded rows (which are cloned from other valid rows) are used as if they are from a final step of a trajectory. Although these padded rows are masked out from the loss calculation, the incorrect advantage values could still affect metrics and potentially other parts of the training logic in the future.
Suggested change:
- padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
+ padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
This is incorrect
CharlieFRuan left a comment
Thank you so much! Only one comment. We could add some unit tests as follow-ups.
response_ids=response_ids,
reward=step_reward,
loss_mask=copy.deepcopy(loss_mask),
prompt_ids=copy.deepcopy(input_ids[:current_prompt_length]),
This comment is for the line `response_ids = copy.deepcopy(input_ids[current_prompt_length:])`.
This `input_ids` is after we added the observation tokens and the next turn's generation prompt, right? Shouldn't the response IDs just be `output_ids`?
Hmm, aren't both equivalent ways to do step-wise training? i.e., you could treat (assistant response + obs, reward) as a step vs. just count (assistant response, reward) as a step?
True, I was just wondering if this adds additional computation. Indeed, I don't think the current way affects correctness.
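To make the two segmentations concrete, here is a toy sketch with made-up token ids (assuming `output_ids` holds only the assistant's newly generated tokens for the turn):

```python
import copy

input_ids = [1, 2, 3, 10, 11, 12, 20, 21]  # step prompt + assistant output + observation/next-turn prompt
current_prompt_length = 3                  # length of this step's prompt
output_ids = [10, 11, 12]                  # only the assistant's generated tokens

# Segmentation used in the PR: the step's response is everything after the step's
# prompt, i.e. the assistant response plus the appended observation tokens.
response_with_obs = copy.deepcopy(input_ids[current_prompt_length:])  # [10, 11, 12, 20, 21]

# Alternative raised in review: the step's response is just the assistant output.
response_only = output_ids                                            # [10, 11, 12]
```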
) / len(response_ids)

logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
training_input = self.pad_batch(training_input)
Just wanted to make sure the PR doesn't break the existing flow. Would this be a no-op if we're not doing step-wise training?
That's correct! The `pad_batch` logic is actually very similar to the initial `_remove_tail_data` logic: if the batch is already divisible by the DP dimensions, then there's no need for padding. The `pad_batch` call is also inserted in `convert_to_training_input`, after generation has fully finished. Without step-wise training, the existing flow has two branches: with and without dynamic sampling. In both cases, the batch size should be `train_batch_size * num_prompts` (with tail data trimmed), and the padding should be zero.
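A rough sketch of the behaviour described above (assuming a dict-of-tensors batch keyed by `sequences` and `is_last_step`, and a `dp_size` divisibility target; this is not the PR's exact implementation):

```python
import torch


def pad_batch(batch: dict, dp_size: int) -> dict:
    """Pad the batch dimension up to a multiple of dp_size (illustrative sketch).

    Assumes pad_size <= num_samples so padded rows can be cloned from valid rows.
    """
    num_samples = batch["sequences"].shape[0]
    pad_size = (-num_samples) % dp_size
    if pad_size == 0:
        return batch  # already divisible by the DP dimension: no-op

    padded = {}
    for key, tensor in batch.items():
        additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
        if key == "is_last_step":
            # padded rows get an explicit flag value (ones, in the PR as written)
            pad = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
        else:
            # padded rows are clones of existing valid rows
            pad = tensor[:pad_size].clone()
        padded[key] = torch.cat([tensor, pad], dim=0)
    return padded
```

The early return is what makes the non-step-wise flow a no-op, matching the explanation above.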
got it, thanks for the explanation!
if self.cfg.trainer.step_wise_training:
    avg_rewards: float = return_sums[data["is_last_step"][: num_samples - pad_size]].mean().item()
else:
    avg_rewards: float = return_sums.mean().item()
Would the changes here be a no-op if we're not doing step-wise training?
Yes! pad_size is 0
Got it!
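For reference, a toy illustration of the step-wise branch above (made-up numbers): only return sums from rows flagged as a trajectory's final step enter the average.

```python
import torch

return_sums = torch.tensor([0.2, 0.5, 1.0, 0.0, 0.0, 0.0])
is_last_step = torch.tensor([False, False, True, False, False, True])
pad_size = 0                      # no padded rows in this toy batch
num_samples = len(return_sums)

# Boolean indexing keeps only the last-step rows before averaging.
avg_rewards = return_sums[is_last_step[: num_samples - pad_size]].mean().item()
print(avg_rewards)  # 0.5 -- mean over the two last-step rows (1.0 and 0.0)
```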
)
training_input.metadata = {
    "uids": uids,
    "trajectory_ids": [trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]],
Just realized this is breaking: this makes `trajectory_ids` a required field in the generator output.
…h `step_wise_training` flag" (#706) Reverts #694 See #694 (comment) The PR expects `trajectory_ids` to always be in the generator output, which currently is not enforced and is breaking. `run_gsm8k.sh` fails with https://gist.github.com/CharlieFRuan/cbbef69fde60a20d483d03efb13d60bb
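For what it's worth, here is a hedged sketch of one possible guard (not the PR's code, and not necessarily the fix that was adopted): only stringify trajectory ids when the generator actually provides them. `TrajectoryID` below is a toy stand-in.

```python
from typing import List, Optional


class TrajectoryID:
    """Toy stand-in for the real TrajectoryID type."""

    def __init__(self, value: str):
        self.value = value

    def to_string(self) -> str:
        return self.value


generator_output = {"response_ids": [[1, 2, 3]]}  # no trajectory_ids key provided
trajectory_ids: Optional[List[TrajectoryID]] = generator_output.get("trajectory_ids")

metadata = {
    "trajectory_ids": [t.to_string() for t in trajectory_ids] if trajectory_ids is not None else None,
}
print(metadata)  # {'trajectory_ids': None} -- no crash when the key is absent
```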
What does this PR do?
Supports step-wise training natively with the `step_wise_training` flag.

Currently, step-wise training introduces a new generator output format and some custom book-keeping in the agent loop. For the first integration, we add this functionality as a separate generator.
I plan to have a follow-up PR where we simplify this and have the logic in the base generator.
TODO:
- Update the `step_wise` example to use the same flag

E2E Run:
Step-wise training for SkyRL-SQL:
Reference run: