feat: CISPO implementation#2187
Conversation
Signed-off-by: slikhite-1 <slikhite@nvidia.com>
Signed-off-by: slikhite-1 <slikhite@nvidia.com>
Signed-off-by: slikhite-1 <slikhite@nvidia.com>
Signed-off-by: slikhite-1 <slikhite@nvidia.com>
Signed-off-by: slikhite-1 <slikhite@nvidia.com>
|
Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
| @@ -0,0 +1,25 @@ | |||
| defaults: "grpo_math_1B.yaml" | |||
There was a problem hiding this comment.
The filename says 8B but this inherits from grpo_math_1B.yaml which sets model_name: "Qwen/Qwen2.5-1.5B", and there's no model override here.
Can you provide async lag-1 results with and without this change? That would help validate correctness.
For the nightly test, please rename this config to match the actual model (Qwen 1.5B) and add:
-
Recipe YAML at
examples/configs/recipes/llm/cispo-qwen2.5-math-1.5b-instruct-<nodes>n<gpus>g-<strategy>.yaml— setdefaults: ../../grpo_math_1B.yaml, enableuse_cispounderloss_fn:, and add model/cluster overrides. Seeexamples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yamlfor the pattern. -
Driver script at
tests/test_suites/llm/cispo-qwen2.5-math-1.5b-instruct-<nodes>n<gpus>g-<strategy>.sh— sourcecommon.env, calluv run examples/run_grpo.py --config $CONFIG_PATH ..., dump TB logs, and runcheck_metrics.py. Seetests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.shfor the pattern. -
Add to nightly list — append the driver script path to
tests/test_suites/nightly.txt.
| if self.use_cispo: | ||
| assert not self.disable_ppo_ratio, ( | ||
| "use_cispo is incompatible with disable_ppo_ratio; " | ||
| ) |
There was a problem hiding this comment.
Add an assertion here to guard against ratio_clip_c being set with CISPO — the dual-clipping block below will silently corrupt the CISPO loss if both are enabled.
| self.use_cispo = cfg.get("use_cispo", False) | ||
| if self.use_cispo: | ||
| assert not self.disable_ppo_ratio, ( | ||
| "use_cispo is incompatible with disable_ppo_ratio; " |
There was a problem hiding this comment.
| "use_cispo is incompatible with disable_ppo_ratio; " | |
| "use_cispo is incompatible with disable_ppo_ratio; CISPO needs the pi_theta/pi_theta_old ratio but disable_ppo_ratio removes it" |
| assert self.loss_type == LossType.SEQUENCE_LEVEL, ( | ||
| "sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss" | ||
| ) | ||
| self.use_cispo = cfg.get("use_cispo", False) |
There was a problem hiding this comment.
Should be cfg["use_cispo"] — no hidden default. Also, please add use_cispo: false to examples/configs/grpo_math_1B.yaml under loss_fn: alongside the other boolean flags.
|
@slikhite-1 , hi this is Peng. Are you still working on this? If you are busy with other stuff, I am happy to address the comments and push this forward. Thank you! |
What does this PR do ?
This PR implements CISPO (Clipped Importance Sampling Policy Optimization) Algorithm from Minimax M1 paper.
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information