Skip to content

refactor/ppo#1856

Merged
zhuzilin merged 10 commits into
THUDM:mainfrom
lilei199908:refactor_ppo
Apr 24, 2026
Merged

refactor/ppo#1856
zhuzilin merged 10 commits into
THUDM:mainfrom
lilei199908:refactor_ppo

Conversation

@lilei199908
Copy link
Copy Markdown
Collaborator

@lilei199908 lilei199908 commented Apr 24, 2026

future TODO:

  1. maybe not need foward_only for values.
  2. curret only surpot same parapall between actor and critic, maybe need more flexible.
  3. mayne can considerate muti ciritc
  4. more test

curret reward compare verl
up slime, down verl
image

Copilot AI review requested due to automatic review settings April 24, 2026 05:46
@lilei199908 lilei199908 changed the title refactor ppo [WIP] refactor ppo Apr 24, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Refactors PPO actor/critic training to decouple critic configuration from the actor and to move actor↔critic value sharing out of NCCL process-group broadcast and into Ray-mediated data passing (with DP/CP shard identity).

Changes:

  • Adds critic YAML override parsing (--critic-config-path) and a --custom-advantage-fn hook.
  • Reworks Ray training actors/groups to support passing per-worker external_data (e.g., critic values) into actor training.
  • Removes the actor↔critic NCCL “connect” path and related broadcast sync, replacing it with CPU tensor transfer helpers and shard-identity APIs.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
slime/utils/misc.py Adds helper to map critic value shards to actor workers by (DP, CP).
slime/utils/arguments.py Introduces critic YAML parsing + new CLI flags; removes old critic-only CLI flags; adjusts critic GPU assumptions.
slime/ray/train_actor.py Extends train() signature and exposes per-worker parallel value identity.
slime/ray/rollout.py Updates GPU offset/size computations (now excluding critic).
slime/ray/placement_group.py Splits critic into its own placement group and wires critic arg overrides into model creation.
slime/ray/actor_group.py Adds external_data plumbing to async_train() and exposes group shard identities.
slime/backends/megatron_utils/loss.py Adds custom_advantage_fn hook and persists computed KL into rollout data.
slime/backends/megatron_utils/data.py Replaces actor↔critic broadcast sync with CPU↔GPU tensor transfer helpers.
slime/backends/megatron_utils/arguments.py Removes critic-train-only world-size special-casing (world size now actor-based).
slime/backends/megatron_utils/actor.py Removes actor↔critic NCCL connect; critic returns CPU values; actor optionally consumes external CPU values.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread slime/utils/misc.py Outdated
Comment on lines +104 to +109
shard_to_values = {
(dp, cp): value["values"]
for (dp, cp, _, _), value in zip(critic_model.get_parallel_value_info(), per_worker_values, strict=False)
if "values" in value
}
return [{"values": shard_to_values[(dp, cp)]} for dp, cp, _, _ in actor_model.get_parallel_value_info()]
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critic_values_by_actor_worker will raise KeyError for actor workers that don't have a corresponding entry in shard_to_values (e.g., non-last PP stages, since critic train_critic() returns {} for those workers). Consider returning {} for workers with no values (or only generating entries for last-PP-stage workers), and validate the lengths/world-sizes so missing shards fail fast instead of being silently truncated by zip(..., strict=False).

Suggested change
shard_to_values = {
(dp, cp): value["values"]
for (dp, cp, _, _), value in zip(critic_model.get_parallel_value_info(), per_worker_values, strict=False)
if "values" in value
}
return [{"values": shard_to_values[(dp, cp)]} for dp, cp, _, _ in actor_model.get_parallel_value_info()]
critic_parallel_info = list(critic_model.get_parallel_value_info())
actor_parallel_info = list(actor_model.get_parallel_value_info())
if len(critic_parallel_info) != len(per_worker_values):
raise ValueError(
"critic parallel info and per_worker_values must have the same length: "
f"{len(critic_parallel_info)} != {len(per_worker_values)}"
)
critic_shards = {(dp, cp) for dp, cp, _, _ in critic_parallel_info}
actor_shards = {(dp, cp) for dp, cp, _, _ in actor_parallel_info}
missing_actor_shards = actor_shards - critic_shards
if missing_actor_shards:
raise ValueError(
"actor workers require (dp, cp) shards that do not exist in critic parallel info: "
f"{sorted(missing_actor_shards)}"
)
shard_to_values = {
(dp, cp): value["values"]
for (dp, cp, _, _), value in zip(critic_parallel_info, per_worker_values)
if "values" in value
}
return [
{"values": shard_to_values[(dp, cp)]} if (dp, cp) in shard_to_values else {}
for dp, cp, _, _ in actor_parallel_info
]

Copilot uses AI. Check for mistakes.
Comment on lines 96 to +114
logger.info(f"Creating placement group with {num_gpus} GPUs...")
pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids = _create_placement_group(num_gpus)

rollout_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[rollout_offset:]
rollout_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[rollout_offset:]
if args.use_critic:
critic_pg_reordered_bundle_indices = actor_pg_reordered_bundle_indices[critic_offset:]
critic_pg_reordered_gpu_ids = actor_pg_reordered_gpu_ids[critic_offset:]

return {
result = {
"actor": (pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids),
"critic": (pg, critic_pg_reordered_bundle_indices, critic_pg_reordered_gpu_ids) if args.use_critic else None,
"rollout": (pg, rollout_pg_reordered_bundle_indices, rollout_pg_reordered_gpu_ids),
}

if args.use_critic:
critic_num_gpus = args.critic_num_nodes * args.critic_num_gpus_per_node
logger.info(f"Creating critic placement group with {critic_num_gpus} GPUs...")
c_pg, c_bundle_indices, c_gpu_ids = _create_placement_group(critic_num_gpus)
result["critic"] = (c_pg, c_bundle_indices, c_gpu_ids)
else:
result["critic"] = None

return result
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a separate placement group for the critic removes gang-scheduling with the actor+rollout group. Because the actor+rollout PG is created (and pg.ready() awaited) before the critic PG, it's possible to reserve resources in a way that prevents the critic PG from ever being scheduled (resource fragmentation/deadlock when the cluster is near capacity). Consider allocating a single placement group for actor+critic+rollout again, or otherwise ensuring atomic/gang allocation (e.g., one PG with bundle offsets for each role, or a strategy that avoids partial allocation).

Copilot uses AI. Check for mistakes.
Comment thread slime/utils/arguments.py
Comment on lines +1488 to +1491
assert isinstance(critic_entries, list) and len(critic_entries) == 1, (
"critic config must contain exactly one entry under 'critic', e.g. "
"critic: [{name: default, overrides: {...}}]"
)
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse_critic_args uses assert for user config validation. Asserts can be disabled with python -O, turning this into undefined behavior later. Prefer raising a ValueError (or RuntimeError) with the same message so validation is always enforced.

Suggested change
assert isinstance(critic_entries, list) and len(critic_entries) == 1, (
"critic config must contain exactly one entry under 'critic', e.g. "
"critic: [{name: default, overrides: {...}}]"
)
if not (isinstance(critic_entries, list) and len(critic_entries) == 1):
raise ValueError(
"critic config must contain exactly one entry under 'critic', e.g. "
"critic: [{name: default, overrides: {...}}]"
)

Copilot uses AI. Check for mistakes.
Comment thread slime/utils/arguments.py
# Critic always uses the same GPU count as actor.
args.critic_num_gpus_per_node = args.actor_num_gpus_per_node
args.critic_num_nodes = args.actor_num_nodes

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slime_validate_args now forces critic_num_nodes/gpus_per_node to always match the actor, but some downstream logic still treats critic GPUs as part of the rollout/colocate GPU budget. Double-check that rollout_num_gpus and placement-group sizing remain consistent under --colocate with --advantage-estimator ppo; otherwise rollout server sizing/validation can diverge from the actual rollout placement group.

Suggested change
actor_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes
if args.colocate and args.use_critic:
# Under colocated PPO, actor and critic share the same placement-group GPUs.
# Keep rollout sizing aligned with the colocated actor/critic world size instead
# of implicitly budgeting critic GPUs a second time.
if not args.rollout_num_gpus:
args.rollout_num_gpus = actor_num_gpus
else:
assert args.rollout_num_gpus == actor_num_gpus, (
"When colocate is enabled with advantage_estimator=ppo, "
"rollout_num_gpus must match the colocated actor/critic GPU count "
f"({actor_num_gpus}), got {args.rollout_num_gpus}."
)

Copilot uses AI. Check for mistakes.
Comment on lines +627 to 628
advantages, returns = rollout_data["advantages"], rollout_data["returns"]

Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compute_advantages_and_returns assumes a custom advantage function will always populate rollout_data['advantages'] and rollout_data['returns'], but if it doesn't, the code will fail later with a KeyError/unclear error. Consider validating after custom_adv_fn(...) returns and raising a clear ValueError if required keys are missing or have the wrong type/shape.

Suggested change
advantages, returns = rollout_data["advantages"], rollout_data["returns"]
missing_keys = [
key for key in ("advantages", "returns") if key not in rollout_data
]
if missing_keys:
raise ValueError(
f"Custom advantage function {args.custom_advantage_fn!r} must populate "
f"rollout_data with keys 'advantages' and 'returns'; missing keys: "
f"{missing_keys}."
)
advantages = rollout_data["advantages"]
returns = rollout_data["returns"]
if not isinstance(advantages, list) or not isinstance(returns, list):
raise ValueError(
f"Custom advantage function {args.custom_advantage_fn!r} must populate "
f"rollout_data['advantages'] and rollout_data['returns'] as lists of "
f"torch.Tensor objects, got {type(advantages).__name__} and "
f"{type(returns).__name__}."
)
expected_num_samples = len(kl)
if len(advantages) != expected_num_samples or len(returns) != expected_num_samples:
raise ValueError(
f"Custom advantage function {args.custom_advantage_fn!r} returned "
f"{len(advantages)} advantages and {len(returns)} returns, expected "
f"{expected_num_samples} of each."
)
for i, (advantage, ret, k) in enumerate(
zip(advantages, returns, kl, strict=False)
):
if not isinstance(advantage, torch.Tensor) or not isinstance(ret, torch.Tensor):
raise ValueError(
f"Custom advantage function {args.custom_advantage_fn!r} must return "
f"lists of torch.Tensor objects; sample {i} has types "
f"{type(advantage).__name__} and {type(ret).__name__}."
)
if advantage.shape != k.shape or ret.shape != k.shape:
raise ValueError(
f"Custom advantage function {args.custom_advantage_fn!r} returned "
f"incompatible tensor shapes for sample {i}: advantages shape "
f"{tuple(advantage.shape)}, returns shape {tuple(ret.shape)}, "
f"expected {tuple(k.shape)}."
)

Copilot uses AI. Check for mistakes.
Comment thread slime/utils/arguments.py
Comment on lines +749 to +753
"--num-critic-only-steps",
type=int,
default=0,
help="number of iterations to linearly warmup for critic model.",
help="Number of initial rollout steps that train critic only; set >= num_rollout for critic-only runs",
)
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change removes --critic-train-only (and related critic-* CLI args), but train.py and train_async.py still read args.critic_train_only in multiple places. After this PR, those entrypoints will raise AttributeError unless they are updated (or the flag is kept with a deprecated alias).

Copilot uses AI. Check for mistakes.
Comment thread slime/ray/actor_group.py
Comment on lines +111 to +129
def async_train(self, rollout_id, rollout_data_ref, external_data=None):
"""Do one rollout training. Returns a list of Ray refs (one per worker).

For critics, each ref resolves to ``{"values": [cpu tensors...]}`` (or ``{}``
for non-last-PP-stage workers). Actor refs resolve to ``None``.

``external_data`` may be a list (one item per worker) or a single dict
broadcast to all workers.
"""
if isinstance(external_data, list):
assert len(external_data) == len(self._actor_handlers)
return [
actor.train.remote(rollout_id, rollout_data_ref, external_data=ed)
for actor, ed in zip(self._actor_handlers, external_data, strict=False)
]
return [
actor.train.remote(rollout_id, rollout_data_ref, external_data=external_data)
for actor in self._actor_handlers
]
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

async_train() now supports external_data, but there are currently no call sites passing critic-produced values into the actor. With the removal of sync_actor_critic_data, PPO actor training will compute advantages without rollout_data['values'] and will fail in compute_advantages_and_returns() (PPO branch requires values). Wire the training loop to (1) ray.get() critic results, (2) map them to actor workers (DP+CP), and (3) pass per-worker external_data into actor_model.async_train(...).

Copilot uses AI. Check for mistakes.
Comment thread slime/utils/arguments.py
Comment on lines +1501 to +1508
ignored_keys = {"num_nodes", "num_gpus_per_node"}

# Apply overrides from the YAML config.
# Unspecified keys inherit from actor_args via deepcopy.
for key, value in critic_config.items():
if key in ignored_keys:
logger.info(f"Ignoring critic config key '{key}'; critic GPU allocation always follows actor.")
continue
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignored_keys = {"num_nodes", "num_gpus_per_node"} doesn't match the actual CLI attribute names in this repo (e.g., actor_num_nodes, actor_num_gpus_per_node, critic_num_nodes, ...). As written, a critic YAML could override actor_num_nodes/actor_num_gpus_per_node inside critic_args, leading to inconsistent DP calculations and confusing behavior. Consider explicitly ignoring/forbidding overrides of the GPU allocation keys that should always follow the actor (and maybe validating that the YAML doesn't contain them).

Suggested change
ignored_keys = {"num_nodes", "num_gpus_per_node"}
# Apply overrides from the YAML config.
# Unspecified keys inherit from actor_args via deepcopy.
for key, value in critic_config.items():
if key in ignored_keys:
logger.info(f"Ignoring critic config key '{key}'; critic GPU allocation always follows actor.")
continue
forbidden_gpu_override_keys = {
# Legacy generic names.
"num_nodes",
"num_gpus_per_node",
# Actual CLI attribute names used by this repo.
"actor_num_nodes",
"actor_num_gpus_per_node",
"critic_num_nodes",
"critic_num_gpus_per_node",
}
invalid_gpu_override_keys = sorted(set(critic_config).intersection(forbidden_gpu_override_keys))
if invalid_gpu_override_keys:
raise ValueError(
"Critic config cannot override GPU allocation keys that must follow the actor: "
+ ", ".join(invalid_gpu_override_keys)
)
# Apply overrides from the YAML config.
# Unspecified keys inherit from actor_args via deepcopy.
for key, value in critic_config.items():

Copilot uses AI. Check for mistakes.
Comment thread slime/ray/rollout.py
Comment on lines 975 to 979
"""Total number of megatron (actor + critic) GPU slots in the placement group."""
if args.debug_rollout_only:
return 0
if args.critic_train_only:
return args.critic_num_nodes * args.critic_num_gpus_per_node
num = args.actor_num_nodes * args.actor_num_gpus_per_node
if args.use_critic:
num += args.critic_num_nodes * args.critic_num_gpus_per_node
return num
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_compute_megatron_num_gpus no longer includes critic GPUs, but its docstring still says "actor + critic". Update the docstring (and any dependent comments) to reflect the new placement-group layout so future changes don’t reintroduce incorrect offsets/sizing.

Copilot uses AI. Check for mistakes.
Comment thread slime/ray/placement_group.py Outdated
c_pg, c_bundle_indices, c_gpu_ids = _create_placement_group(critic_num_gpus)
result["critic"] = (c_pg, c_bundle_indices, c_gpu_ids)
else:
result["critic"] = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
result["critic"] = None
result["critic"] = result["actor"] if args.use_critic else None

Comment thread slime/utils/arguments.py Outdated
),
)
parser.add_argument(
"--custom-advantage-fn",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please follow the naming convention, use --custom-xxx-function-path

Comment thread slime/utils/arguments.py Outdated
# Critic-specific: disable features that only apply to actors
critic_args.kl_coef = 0
critic_args.use_opd = False
critic_args.normalize_advantages = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to turn off normalize adv here?

Comment thread slime/utils/misc.py Outdated
return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0)


def critic_values_by_actor_worker(actor_model, critic_model, per_worker_values):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we only supports setting critic and actor as the same parallel config, maybe we can simply pass the data from critic to actor and leave this change to future PRs.

Comment thread train.py Outdated
per_worker_values = ray.get(value_refs)

if actor_trains_this_step:
actor_external = critic_values_by_actor_worker(actor_model, critic_model, per_worker_values)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we don't need critic_values_by_actor_worker, can we pass the value_refs to actor instread of materialize it in the control process?

Comment thread train_async.py
@lilei199908 lilei199908 changed the title [WIP] refactor ppo refactor/ppo Apr 24, 2026
@zhuzilin zhuzilin merged commit 75af529 into THUDM:main Apr 24, 2026
21 of 24 checks passed
FortPercent pushed a commit to HyperdriveHustle/slime that referenced this pull request May 8, 2026
When offload_train=True, training jobs crash with:

  [torch_memory_saver.cpp] CUresult error: 1 (invalid argument)
  file=csrc/core.cpp func=free line=81

right after the first rollout. The crash fires inside the post-rollout
actor.sleep() invoked by the outer training loop (e.g.
interface.py:481 -> actor_model.offload() -> actor.sleep()).

Two compounding causes (only the combination triggers the crash):

1. PR THUDM#1856 (refactor/ppo, commit 75af529) added a trailing
   self.sleep() at the end of MegatronTrainRayActor.train(). Outer
   loops still call actor.sleep() right after async_train returns,
   which yields two consecutive sleep() calls with no intervening
   wake_up().

2. The existing actor.sleep() body has a latent invariant violation:
   destroy_process_groups() releases NCCL GPU buffers back to PyTorch's
   CUDA Caching Allocator (CCA) free pool, so the underlying ptrs are
   still tracked by torch_memory_saver as ACTIVE.
   torch_memory_saver.pause() then unmaps every ACTIVE entry, marking
   those CCA-cached ptrs as PAUSED while they still sit in CCA's free
   pool. The next clear_memory() (head of the second sleep) calls
   torch.cuda.empty_cache(), which routes those cached segments through
   cudaFree -> LD_PRELOAD hook -> TorchMemorySaver::free() ->
   cuMemUnmap on already-unmapped -> CURESULT_CHECK -> exit(1).

Distinct from THUDM#1786 / THUDM#1690 (those report func=pause line=133, runtime
API cudaError, host-pinned-memory exhaustion). Ours is func=free
line=81, driver API CUresult, double-unmap.

Fix:

- sleep() / wake_up() get a _is_paused guard so a second sleep() (or
  spurious wake_up()) becomes a no-op.
- An extra clear_memory() between destroy_process_groups() and
  torch_memory_saver.pause() drains CCA's free pool, so pause() only
  marks truly-active ptrs as PAUSED -- restoring the invariant.

Verified on a 2x8 H800 GLM-4.7-Flash GRPO run with offload_train=True:
4 complete rollouts (rollout_id=0..3) finished cleanly; idempotent-skip
log fires once per rollout as expected; zero CURESULT errors.

See THUDM#1895 for full upstream report and reproduction.
SamitHuang added a commit to SamitHuang/slime that referenced this pull request May 17, 2026
* temp save rfc

Signed-off-by: SamitHuang <285365963@qq.com>

* add plan

Signed-off-by: SamitHuang <285365963@qq.com>

* update

Signed-off-by: SamitHuang <285365963@qq.com>

* [docker] remove true on policy patches (THUDM#1661)

Co-authored-by: Copilot <copilot@github.com>

* [fix]: Qwen3.5-35B-A3B 8-GPU: set TP size to 2 for num_query_groups=2 (THUDM#1662)

* Remove FSDP support (THUDM#1664)

Co-authored-by: Copilot <copilot@github.com>

* docs: add OpenClaw-RL to projects built upon slime (THUDM#1635)

* qwen2.5 0.5b non-colocate (first attempt ok, but nccl error later)

Signed-off-by: samithuang <285365963@qq.com>

* add convert script

* add setup doc

* Support setting update weights in sglang_config (THUDM#1665)

Co-authored-by: Copilot <copilot@github.com>

* fix nccl error by NcclBridge subprocess

* eliminate gpu to cpu weight transfer

Signed-off-by: samithuang <285365963@qq.com>

* Revise weight synchronization strategy in goal plan

Reorder weight synchronization support for colocate and non-colocate scenarios in the goal plan.

* [fix] Fix numerical accuracy issue in dynamic sampling filter (THUDM#1674)

* sync from internal (THUDM#1677)

Co-authored-by: Copilot <copilot@github.com>

* bugfixes from community (THUDM#1678)

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: yueming-yuan <yym022502@gmail.com>
Co-authored-by: coding-famer <chenhegu0109@gmail.com>

* Fix: pass return_tensors in text_kwargs for transformers>=5.0.0 compatibility (THUDM#1648)

* Fix missing packed_seq_params in bshd qkv_format (THUDM#1649)

* [Multimodal][Model] Qwen3.5 VL training example/support (THUDM#1676)

* update docs (THUDM#1680)

Co-authored-by: Copilot <copilot@github.com>

* update docs (THUDM#1681)

Co-authored-by: Copilot <copilot@github.com>

* support offloading non-updatable server (THUDM#1668)

Co-authored-by: Copilot <copilot@github.com>

* bugfix (THUDM#1685)

Co-authored-by: Copilot <copilot@github.com>

* fix: handle Qwen3.5 in quantize_params_fp8 (THUDM#1683)

* bugfix (THUDM#1687)

Co-authored-by: Copilot <copilot@github.com>

* Fix Qwen3.5 & Qwen3-Next linear attention cu_seqlens missing (THUDM#1686)

Co-authored-by: benyi <huangliangmeng.hlm@alibaba-inc.com>

* fix: use semantic version comparison for PyTorch >= 2.6 detection (THUDM#1667)

* [Fix] Minor fix for properly finishing / flushing wandb logging metrics at exit (THUDM#1592)

Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>

* Autofix/issue 1578 hf2megatron arg suffix (THUDM#1636)

* bugfix (THUDM#1688)

Co-authored-by: Copilot <copilot@github.com>

* fix(examples): update strands_sglang example to v0.3.x API (THUDM#1684)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* [docker] cherry pick qwen3.5 bugfix (THUDM#1691)

Co-authored-by: Copilot <copilot@github.com>

* bugfix/fix Qwen3.5 dense model precision bug in TP_SIZE>1 from sglang (THUDM#1705)

* Fix/qwen3 5 mtp bridge (THUDM#1702)

Co-authored-by: benyi <huangliangmeng.hlm@alibaba-inc.com>

* support epd for glm4.6v (THUDM#1704)

* [docker] support epd for glm4.6v (THUDM#1707)

Co-authored-by: Copilot <copilot@github.com>

* remove script

* [docker] store v0.5.9 patch (THUDM#1710)

Co-authored-by: Copilot <copilot@github.com>

* Add GLM-4.7-Flash MTP training support (THUDM#1712)

* [release] bump to v0.2.3 (THUDM#1682)

Co-authored-by: Copilot <copilot@github.com>

* feat: add GLM-4.6V MoE VL bridge with CP support (THUDM#1715)

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: resolve rope_theta from rope_parameters dict in HF config validation (THUDM#1720)

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* [docker] patches for glm4.6v, kimi k2.5 and dsa cp only (THUDM#1722)

Co-authored-by: Copilot <copilot@github.com>

* [docker] support IndexCache

* Fix CUDA IPC cache leaks during weight updates (THUDM#1731)

Co-authored-by: Copilot <copilot@github.com>

* [docker] update megatron (THUDM#1729)

Co-authored-by: Copilot <copilot@github.com>

* [docker] Fix IndexCache with mla model (THUDM#1736)

Co-authored-by: Copilot <copilot@github.com>

* [slime-router] support pd disaggregation and remove radix tree middleware (THUDM#1735)

* Fix glm4v megatron bridge (THUDM#1738)

Co-authored-by: Copilot <copilot@github.com>

* [docker] update sglang patch (THUDM#1743)

Co-authored-by: Copilot <copilot@github.com>

* feat: GLM4V multimodal support improvements (THUDM#1745)

Co-authored-by: Copilot <copilot@github.com>

* feat: placeholder worker type, metrics router, and GPQA letter range (THUDM#1746)

Co-authored-by: Copilot <copilot@github.com>

* always enable_metrics and remove dp context (THUDM#1747)

Co-authored-by: Copilot <copilot@github.com>

* fix: resolve SP/CP gradient inflation in FLA (linear attention) layers (THUDM#1748)

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Update MTP example configs, rename GLM-4.5 to GLM-4.7, clean scripts (THUDM#1749)

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Support qwen3.5 loss mask for multi-turn SFT (THUDM#1742)

Co-authored-by: benyi <huangliangmeng.hlm@alibaba-inc.com>

* fix: propagate moe_token_dispatcher_type in bridge model provider (THUDM#1737)

* fix: resolve rope_theta from rope_parameters in DeepseekV32Bridge (THUDM#1734)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* chore: translate remaining Chinese comments to English (THUDM#1726)

* feat: add Qwen3.5-4B model support (THUDM#1721)

* fix: http_utils. disable system proxy for internal SGLang httpx clients (THUDM#1714)

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: auto-detect GPUs in qwen3-4b script (THUDM#1700)

* fix: quote `$MOE_LAYER_FREQ` (THUDM#1689)

* disable router health_check and allow prompt_data is None (THUDM#1751)

Co-authored-by: Copilot <copilot@github.com>

* Router for vllm (#5)

* Draft router design

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Add vllm router

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Add router to script

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix gpu memory utilization

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix output token ids

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Add more nccl flag

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix bug

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

---------

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* small fix on qwen3-235b-a22b launch script (THUDM#1719)

* sync internal bugfix (THUDM#1765)

Co-authored-by: Copilot <copilot@github.com>

* Fix uploading sglang metrics to wandb (THUDM#1768)

Co-authored-by: Copilot <copilot@github.com>

* use zhuzilin/sgl-router for sglang-router (THUDM#1770)

Co-authored-by: Copilot <copilot@github.com>

* [docker] update sgl-router (THUDM#1772)

Co-authored-by: Copilot <copilot@github.com>

* [Multimodal] Add Multimodal OPD support (THUDM#1760)

* refactor: remove slime router (THUDM#1773)

Co-authored-by: Copilot <copilot@github.com>

* Add rollout trace timeline viewer (THUDM#1776)

Co-authored-by: Hanyu Zhang <hanyu.zhang@aminer.cn>

* [Fix] Fix duplicate Megatron LR scheduler resume when optimizer state is not loaded (THUDM#1775)

* Support FP8 conversion for Qwen3.5 (THUDM#1769)

* fix typo (THUDM#1759)

Co-authored-by: shiqirui <shiqirui@kupasai.com>

* [Fix]Fix some bugs/clean up (THUDM#1756)

* (fix):not have encoder_only attr cause run failed (THUDM#1741)

Co-authored-by: wangch <wangch@wangchdeMacBook-Air.local>

* update docs

* remove redundant envvar

* some minor cleanup

* [release] bump to v0.2.4  (THUDM#1777)

Co-authored-by: Copilot <copilot@github.com>

* Plan refactor vllm/sglang

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Code implemented

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix bug

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix bug

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix bug

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix port

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix config

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix bug MOE weight sync

* Fix bug vllm transfer weight

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix weight sync

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix config

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Change name config

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* pass critic role through to create RayTrainGroup (THUDM#1797)

* fix qwen3.5 397B converting error when enable expert parallel (THUDM#1799)

Co-authored-by: 周鹤云 <zhouheyun@xiaohongshu.com>

* fix(geo3k-vlm-sft): remove --apply-chat-template from SFT launch script (THUDM#1791)

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* Add host memory metrics to available_memory function (THUDM#1764)

* [WIP] fix loss oom (THUDM#1788)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* sync from internal (THUDM#1805)

* sync from internal (THUDM#1807)

* feat: add npu patch for qwen3-vl-8b grpo & ppo (THUDM#1750)

Signed-off-by: cjy0x <isjunyi.chen@gmail.com>
Co-authored-by: shiyuan680 <917935075@qq.com>
Co-authored-by: PengchengShi00 <spc117369@gmail.com>

* fix missing position_ids in log-prob forward step (THUDM#1809)

* feat: add support for including missing weights from origin HF checkp… (THUDM#1812)

* [Fix] Initialize grad_norm before found_inf skip path (THUDM#1762)

* [conda] Add install custom sgl-router to build_conda.sh (THUDM#1813)

* Revert no_grad for entropy to prevent comm stuck in dsa (THUDM#1822)

* Add fallback for get_seqlen_balanced_partitions (THUDM#1823)

* Resolve review

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Try colocated vllm weight

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* docs: add Relax to notable projects in README (THUDM#1834)

* Bugfix: use cpu instead of cuda in convert_torch_dist_to_hf.py when --add-missing-from-origin-hf is set (THUDM#1828)

* [fix] eval sample logging when sample is a list (THUDM#1836)

* [Draft] Local runable dev

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* [Fix]  Fix cuda-python pin in build_conda.sh (THUDM#1827)

* fix entropy bug and update code (THUDM#1846)

* Revert "Add fallback for get_seqlen_balanced_partitions" (THUDM#1848)

* fix (THUDM#1849)

* Fix offload train

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Add support for NVIDIA DGX Spark (GB10 / sm_121a, arm64) (THUDM#1835)

* Fix offload train

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix offload_rollout

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix vllm offload

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix offload traing

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix offload weight

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix offload weight

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* refactor/ppo (THUDM#1856)

* [docker] cleanup sglang patch (THUDM#1859)

* [docker] update v0.5.9 patch

* Rename critic config to megatron config (THUDM#1866)

* [Fix] Use Ray ObjectRef await instead of asyncio.to_thread in distributed POST (THUDM#1873)

* chore: include length context in slice_log_prob_with_cp assert (THUDM#1862)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* [docker] upgrade megatron to 1dcf0dafa (THUDM#1867)

* fix ppo value head load bugs (THUDM#1878)

* [docker] upgrade sglang to v0.5.10.post1 (THUDM#1874)

* [docs] update docs

* [docker] update megatron-bridge and add qwen3.6 tests (THUDM#1884)

* fix lint

* Fix(checkpoint): add resume/pause in save_model() for offload_train (fixes THUDM#1886) (THUDM#1888)

* fix ppo value offload bugs (THUDM#1882)

* fix qwen3.6 hf config validation bug (THUDM#1889)

* Add missing metrics to log (THUDM#1890)

* fix(qwen3_next): use torch.get_default_dtype() — get_current_dtype do… (THUDM#1883)

Co-authored-by: yeqinghe <yeqinghe@MacBook-Pro-6.local>

* Fix location error in install script (THUDM#1877)

* Only allow --allgather-cp for DSA model (THUDM#1891)

* Migrate internal feature (THUDM#1897)

* [Fix]  Fix distributed POST actor concurrency split (THUDM#1880)

Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>

* Fix CI: update rollout_data_postprocess plugin contract for new call site (THUDM#1902)

Co-authored-by: jingshenghang <shenghang.jing@aminer.cn>

* Patch Megatron TP grad coalesce to chunked all-reduce (THUDM#1899)

* fix: harden retool rollout against multi-turn / retry desync (THUDM#1861)

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Fix log file

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix import engine group

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

* Fix rebase code

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>

---------

Signed-off-by: SamitHuang <285365963@qq.com>
Signed-off-by: samithuang <285365963@qq.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: cjy0x <isjunyi.chen@gmail.com>
Co-authored-by: SamitHuang <285365963@qq.com>
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: none0663 <none0663@outlook.com>
Co-authored-by: Yinjie Wang <yinjie@uchicago.edu>
Co-authored-by: Fengqing Jiang <43953876+Django-Jiang@users.noreply.github.com>
Co-authored-by: yueming-yuan <yym022502@gmail.com>
Co-authored-by: coding-famer <chenhegu0109@gmail.com>
Co-authored-by: Lawrence Wu <lawrence.wu@harmonic.fun>
Co-authored-by: huang3eng <huang3eng@gmail.com>
Co-authored-by: benyi <huangliangmeng.hlm@alibaba-inc.com>
Co-authored-by: Aaron Batilo <AaronBatilo@gmail.com>
Co-authored-by: Silun Wang <igeekwang@gmail.com>
Co-authored-by: Chengxing Xie <91449279+yitianlian@users.noreply.github.com>
Co-authored-by: Yuan He <33579950+Lawhy@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Mor Zusman <mor.zusmann@gmail.com>
Co-authored-by: append-only <shw20010329@163.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Nan Jiang <59716405+nanjiangwill@users.noreply.github.com>
Co-authored-by: Xuan Wang <49010704+stevewx@users.noreply.github.com>
Co-authored-by: Hubert Wang <huberthyw@gmail.com>
Co-authored-by: Hou Shihao <shhou007@gmail.com>
Co-authored-by: DongzhuoranZhou <110855293+DongzhuoranZhou@users.noreply.github.com>
Co-authored-by: Ailuntz <130897222+ailuntz@users.noreply.github.com>
Co-authored-by: Zhuohao Li <garrick0508@gmail.com>
Co-authored-by: Hanyu Zhang <hanyu.zhang@aminer.cn>
Co-authored-by: Kang Yu <kangy.me@gmail.com>
Co-authored-by: peterjc123 <peter_jiachen@163.com>
Co-authored-by: qrskannbara <94727257+albaNnaksqr@users.noreply.github.com>
Co-authored-by: shiqirui <shiqirui@kupasai.com>
Co-authored-by: wangyufak <wangch9@xiaopeng.com>
Co-authored-by: wangch <wangch@wangchdeMacBook-Air.local>
Co-authored-by: Xintong Li <znculee@gmail.com>
Co-authored-by: TM <tianmingxu.tmxu@gmail.com>
Co-authored-by: 周鹤云 <zhouheyun@xiaohongshu.com>
Co-authored-by: LiLei <77353389+lilei199908@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: cjy0x <isjunyi.chen@gmail.com>
Co-authored-by: shiyuan680 <917935075@qq.com>
Co-authored-by: PengchengShi00 <spc117369@gmail.com>
Co-authored-by: 杨睿 <595403043@qq.com>
Co-authored-by: Mathew Han <49226490+mathewjhan@users.noreply.github.com>
Co-authored-by: haoxuanJIA <116806014+boots-coder@users.noreply.github.com>
Co-authored-by: ryang <38470282+ryang-max@users.noreply.github.com>
Co-authored-by: Leo Fan <84952531+leofan-lab@users.noreply.github.com>
Co-authored-by: Long Yijun <156500868+Procrastinatorrrr@users.noreply.github.com>
Co-authored-by: HeatherLiuzh <heather996lzh@gmail.com>
Co-authored-by: yeqinghe <yeqinghe@MacBook-Pro-6.local>
Co-authored-by: tao W <122036357+selfanti@users.noreply.github.com>
Co-authored-by: jingshenghang <48083555+jingshenghang@users.noreply.github.com>
Co-authored-by: jingshenghang <shenghang.jing@aminer.cn>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants