[Examples] Add an example for step-wise training #436
SumanthRH merged 44 commits into NovaSky-AI:main
Conversation
/gemini review
Code Review
This pull request introduces a valuable example for step-wise training, which is a great addition. The implementation of custom components like the StepWiseGenerator and StepWiseTrainer is well-structured. I've identified a few areas for improvement, primarily concerning correctness in the example's run script and padding logic, as well as some opportunities to enhance performance and code readability. Addressing these points will make the example more robust and easier to follow.
```python
additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
...
if key == "is_last_step":
    padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
```
I didn't follow why this should be ones instead of zeros -- can you explain?
`is_last_step` needs to have at least one non-zero entry for each trajectory - each trajectory has at least one last step. If you pad with all zeros, that means the padding trajectories have no last step at all.
Ideally `pad_batch` is very generic and we can use it in other examples as well, but `is_last_step` is special and I'd rather do the padding correctly here.
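Putting this thread together, here is a minimal sketch of what a generic `pad_batch` with the `is_last_step` special case might look like. The function name comes from the discussion above, but its exact signature and the dict-of-tensors batch layout are assumptions for illustration:

```python
import torch


def pad_batch(batch: dict, pad_size: int) -> dict:
    """Pad every tensor in `batch` with `pad_size` extra rows (illustrative sketch).

    Most keys are zero-padded, but `is_last_step` is padded with ones so that
    every padding "trajectory" still has a last step; all-zero padding would
    produce trajectories with no last step at all.
    """
    padded = {}
    for key, tensor in batch.items():
        # Preserve any trailing dimensions beyond the batch dimension.
        additional_dims = tuple(tensor.shape[1:]) if len(tensor.shape) > 1 else ()
        if key == "is_last_step":
            padding = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
        else:
            padding = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
        padded[key] = torch.cat([tensor, padding], dim=0)
    return padded
```

The design choice debated above is visible in the branch: padding with ones keeps the invariant "one last step per trajectory" even for filler rows.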
/gemini review
Code Review
This pull request introduces an example for step-wise training, where each turn in a conversation is treated as a separate sample. This is a significant feature addition, implemented through a new StepWiseGenerator and StepWiseTrainer that customize the data generation and training loop. The changes also include a custom advantage estimation function tailored for outcome rewards in a multi-turn setting and a corresponding evaluation function.
My review focuses on the clarity and maintainability of the new implementation. The core logic for step-wise processing appears correct. I've suggested minor improvements to the method signatures in StepWiseTrainer to enhance code clarity by explicitly marking unused parameters inherited from the base class. The modifications to existing utility functions to support this new training paradigm are well-designed for extensibility.
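The suggestion about explicitly marking unused parameters inherited from the base class can be illustrated with a generic sketch; the class and method names below are hypothetical, not the actual `StepWiseTrainer` API:

```python
class BaseTrainer:
    """Hypothetical base class whose method signature subclasses must match."""

    def train_step(self, batch, extra_info=None):
        raise NotImplementedError


class StepWiseTrainer(BaseTrainer):
    def train_step(self, batch, extra_info=None):
        # `extra_info` is kept for signature compatibility with the base
        # class but is not used in the step-wise path; deleting it up front
        # makes that explicit to readers and linters.
        del extra_info
        return sum(batch)
```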
```python
if generator_output["rollout_metrics"] is not None:
    self.all_metrics.update(generator_output["rollout_metrics"])
...
# don't validate - will error out
```
Could you add just a little more detail on why it will error out, just for posterity :)
# What does this PR do?

Adds an example for step-wise training where each turn is represented as an individual sample in the batch. Currently, the example still assumes outcome rewards.

Implements:
- A custom generator for providing inputs and outputs at each step as an individual sample
- A custom trainer that can handle step-wise generator output
- A custom advantage estimation function that will compute advantages for the last step and broadcast it to the other steps in that trajectory
- A custom evaluation function to calculate metrics correctly

Currently this uses TITO with multi-turn chat templating for a simple demonstration. We simply append responses and observations to a running list of input ids. The generator is not yet compatible with qwen3 or gpt-oss like chat templating, where think tokens are removed; this will be added as a follow-up.

There are many bits that can be cleaned up (for example, the padding logic is brittle at the moment given the special handling for the tensor `is_last_step`), but it works as an initial example. I've tested convergence with SkyRL2SQL for the first 20 steps and it seems to match the original wandb curve.

<img width="1045" height="595" alt="Screenshot 2025-10-08 at 4 39 13 PM" src="https://github.com/user-attachments/assets/78b1f135-cb0a-4553-afc9-d032bc1459a7" />
<img width="1037" height="636" alt="Screenshot 2025-10-08 at 4 39 25 PM" src="https://github.com/user-attachments/assets/5d44fcf6-9d21-4e3f-8f02-c12843fbaecd" />

Original curve for reference: https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-sql/reports/SkyRL-SQL---VmlldzoxMzM0MTAyMw?accessToken=vrqncoa32qcobvvpuo672yji4gweguk6tjxvaflk1zh73fn70j6l5rj8j619uvry

---------

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
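The custom advantage estimation described in the PR summary (compute an advantage at each trajectory's last step, then broadcast it to the other steps of that trajectory) could be sketched roughly as follows. Function and tensor names are illustrative, and any group normalization (e.g. GRPO-style mean/std) is omitted for brevity:

```python
import torch


def broadcast_last_step_advantages(
    rewards: torch.Tensor,         # (num_steps,) outcome reward, set at last steps
    is_last_step: torch.Tensor,    # (num_steps,) 1 at the final step of each trajectory
    trajectory_ids: torch.Tensor,  # (num_steps,) trajectory each step belongs to
) -> torch.Tensor:
    """Sketch: per-trajectory outcome advantage broadcast to all steps."""
    advantages = torch.zeros_like(rewards, dtype=torch.float32)
    for traj in trajectory_ids.unique():
        mask = trajectory_ids == traj
        # Advantage derived from the outcome reward at the trajectory's last step...
        last_adv = rewards[mask & (is_last_step == 1)]
        # ...copied to every step in the same trajectory.
        advantages[mask] = last_adv
    return advantages
```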