From 379f1949dc258f4822c6498418a37f38b7e79372 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Wed, 29 Oct 2025 22:57:19 +0800
Subject: [PATCH] Extract perf logging into slime/utils/train_metric_utils

---
 slime/backends/megatron_utils/data.py | 67 ++++++---------------------
 slime/utils/train_metric_utils.py     | 63 +++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 54 deletions(-)
 create mode 100644 slime/utils/train_metric_utils.py

diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py
index dd84f6c3e..65fa967bc 100644
--- a/slime/backends/megatron_utils/data.py
+++ b/slime/backends/megatron_utils/data.py
@@ -9,11 +9,11 @@
 from megatron.core import mpu
 from megatron.core.packed_seq_params import PackedSeqParams
 
+from slime.utils import train_metric_utils
 from slime.utils.data import get_minimum_num_micro_batch_size
 from slime.utils.flops_utils import calculate_fwd_flops
 from slime.utils.metric_utils import compute_pass_rate
 from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
-from slime.utils.timer import Timer
 from slime.utils.types import RolloutBatch
 
 from .cp_utils import get_sum_of_sample_mean, slice_with_cp
@@ -419,59 +419,18 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -
 
 
 def log_perf_data(rollout_id: int, args: Namespace) -> None:
-    """
-    Log timing metrics and derived TFLOPs for compute phases if available.
-
-    Only active on PP last stage, TP rank 0, and DP source rank. The step is
-    consistent with other logs.
-    """
-    timer_instance = Timer()
-    if (
-        mpu.get_tensor_model_parallel_rank() == 0
-        and mpu.is_pipeline_last_stage()
-        and mpu.get_data_parallel_rank(with_context_parallel=True) == 0
-    ):
-        log_dict = {f"perf/{key}_time": val for key, val in timer_instance.log_dict().items()}
-
-        if "perf/actor_train_time" in log_dict:
-            world_size = dist.get_world_size()
-            total_fwd_flops = calculate_fwd_flops(seqlens=timer_instance.seq_lens, args=args) / world_size / 1e12
-
-            if "perf/log_probs_time" in log_dict:
-                log_dict["perf/log_probs_tflops"] = total_fwd_flops / log_dict["perf/log_probs_time"]
-
-            if "perf/ref_log_probs_time" in log_dict:
-                log_dict["perf/ref_log_probs_tflops"] = total_fwd_flops / log_dict["perf/ref_log_probs_time"]
-
-            if log_dict["perf/actor_train_time"] > 0:
-                log_dict["perf/actor_train_tflops"] = 3 * total_fwd_flops / log_dict["perf/actor_train_time"]
-                log_dict["perf/actor_train_tok_per_s"] = (
-                    sum(timer_instance.seq_lens) / log_dict["perf/actor_train_time"]
-                )
-
-        if "perf/train_wait_time" in log_dict and "perf/train_time" in log_dict:
-            total_time = log_dict["perf/train_wait_time"] + log_dict["perf/train_time"]
-            if total_time > 0:
-                log_dict["perf/total_train_time"] = total_time
-                log_dict["perf/wait_time_ratio"] = log_dict["perf/train_wait_time"] / total_time
-
-        print(f"perf {rollout_id}: {log_dict}")
-
-        step = (
-            rollout_id
-            if not args.wandb_always_use_train_step
-            else rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
-        )
-        if args.use_wandb:
-            log_dict["rollout/step"] = step
-            wandb.log(log_dict)
-
-        if args.use_tensorboard:
-            from slime.utils.tensorboard_utils import _TensorboardAdapter
-
-            tb = _TensorboardAdapter(args)
-            tb.log(data=log_dict, step=step)
-    timer_instance.reset()
+    train_metric_utils.log_perf_data_raw(
+        rollout_id=rollout_id,
+        args=args,
+        is_primary_rank=(
+            mpu.get_tensor_model_parallel_rank() == 0
+            and mpu.is_pipeline_last_stage()
+            and mpu.get_data_parallel_rank(with_context_parallel=True) == 0
+        ),
+        compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args)
+        / dist.get_world_size()
+        / 1e12,
+    )
 
 
 def sync_actor_critic_data(
diff --git a/slime/utils/train_metric_utils.py b/slime/utils/train_metric_utils.py
new file mode 100644
index 000000000..ce9e64cef
--- /dev/null
+++ b/slime/utils/train_metric_utils.py
@@ -0,0 +1,63 @@
+from argparse import Namespace
+from copy import deepcopy
+from typing import Callable, Optional
+
+import wandb
+
+from slime.utils.timer import Timer
+
+
+def log_perf_data_raw(
+    rollout_id: int, args: Namespace, is_primary_rank: bool, compute_total_fwd_flops: Optional[Callable]
+) -> None:
+    """Log timing metrics and derived TFLOPs for compute phases if available.
+
+    Timer state is read and reset on every rank; only the primary rank logs.
+    The step is consistent with other logs.
+    """
+    timer_instance = Timer()
+    log_dict_raw = deepcopy(timer_instance.log_dict())
+    # Snapshot seq_lens before reset() clears the timer state.
+    seq_lens = deepcopy(timer_instance.seq_lens)
+    timer_instance.reset()
+
+    if not is_primary_rank:
+        return
+
+    log_dict = {f"perf/{key}_time": val for key, val in log_dict_raw.items()}
+
+    if ("perf/actor_train_time" in log_dict) and (compute_total_fwd_flops is not None):
+        total_fwd_flops = compute_total_fwd_flops(seq_lens=seq_lens)
+
+        if "perf/log_probs_time" in log_dict:
+            log_dict["perf/log_probs_tflops"] = total_fwd_flops / log_dict["perf/log_probs_time"]
+
+        if "perf/ref_log_probs_time" in log_dict:
+            log_dict["perf/ref_log_probs_tflops"] = total_fwd_flops / log_dict["perf/ref_log_probs_time"]
+
+        if log_dict["perf/actor_train_time"] > 0:
+            log_dict["perf/actor_train_tflops"] = 3 * total_fwd_flops / log_dict["perf/actor_train_time"]
+            log_dict["perf/actor_train_tok_per_s"] = sum(seq_lens) / log_dict["perf/actor_train_time"]
+
+    if "perf/train_wait_time" in log_dict and "perf/train_time" in log_dict:
+        total_time = log_dict["perf/train_wait_time"] + log_dict["perf/train_time"]
+        if total_time > 0:
+            log_dict["perf/step_time"] = total_time
+            log_dict["perf/wait_time_ratio"] = log_dict["perf/train_wait_time"] / total_time
+
+    print(f"perf {rollout_id}: {log_dict}")
+
+    step = (
+        rollout_id
+        if not args.wandb_always_use_train_step
+        else rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
+    )
+    if args.use_wandb:
+        log_dict["rollout/step"] = step
+        wandb.log(log_dict)
+
+    if args.use_tensorboard:
+        from slime.utils.tensorboard_utils import _TensorboardAdapter
+
+        tb = _TensorboardAdapter(args)
+        tb.log(data=log_dict, step=step)
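
Usage note: the extracted helper is backend-agnostic; the Megatron caller above supplies the rank predicate and the FLOPs callable. A minimal sketch of driving it from elsewhere, assuming a single-process run (the Namespace fields shown are only the flags log_perf_data_raw reads, and the FLOPs lambda is a hypothetical stand-in, not slime's real estimator):

    from argparse import Namespace

    from slime.utils import train_metric_utils

    # Disable wandb/tensorboard so the call only prints the perf dict.
    args = Namespace(use_wandb=False, use_tensorboard=False, wandb_always_use_train_step=False)

    train_metric_utils.log_perf_data_raw(
        rollout_id=0,
        args=args,
        is_primary_rank=True,  # e.g. rank 0 in a single-process run
        compute_total_fwd_flops=lambda seq_lens: sum(seq_lens) * 1e-3,  # hypothetical stand-in
    )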