Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 13 additions & 54 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from megatron.core import mpu
from megatron.core.packed_seq_params import PackedSeqParams

from slime.utils import train_metric_utils
from slime.utils.data import get_minimum_num_micro_batch_size
from slime.utils.flops_utils import calculate_fwd_flops
from slime.utils.metric_utils import compute_pass_rate
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
from slime.utils.timer import Timer
from slime.utils.types import RolloutBatch

from .cp_utils import get_sum_of_sample_mean, slice_with_cp
Expand Down Expand Up @@ -419,59 +419,18 @@ def log_passrate(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -


def log_perf_data(rollout_id: int, args: Namespace) -> None:
    """
    Log timing metrics and derived TFLOPs for compute phases if available.

    Thin Megatron-specific wrapper around
    ``train_metric_utils.log_perf_data_raw``: this module decides *which* rank
    is the primary logging rank (PP last stage, TP rank 0, and DP source rank)
    and how to convert sequence lengths into total forward TFLOPs per GPU; the
    shared helper does the actual metric derivation and wandb/tensorboard
    logging. The step is consistent with other logs.

    Args:
        rollout_id: index of the current rollout, used as the logging step.
        args: parsed training arguments (wandb/tensorboard flags, batch sizes).
    """
    train_metric_utils.log_perf_data_raw(
        rollout_id=rollout_id,
        args=args,
        # Primary rank: last PP stage, TP rank 0, DP (with CP) rank 0 —
        # mirrors the gating the inlined implementation used previously.
        is_primary_rank=(
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.is_pipeline_last_stage()
            and mpu.get_data_parallel_rank(with_context_parallel=True) == 0
        ),
        # Total forward FLOPs across all sequences, averaged per rank and
        # converted to TFLOPs (1e12).
        compute_total_fwd_flops=lambda seq_lens: calculate_fwd_flops(seqlens=seq_lens, args=args)
        / dist.get_world_size()
        / 1e12,
    )


def sync_actor_critic_data(
Expand Down
56 changes: 56 additions & 0 deletions slime/utils/train_metric_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from argparse import Namespace
from copy import deepcopy
from typing import Callable

import wandb

from slime.utils.timer import Timer


def log_perf_data_raw(
    rollout_id: int, args: Namespace, is_primary_rank: bool, compute_total_fwd_flops: Callable
) -> None:
    """
    Derive and log per-phase timing/throughput metrics from the global Timer.

    All ranks snapshot and reset the Timer (so accumulators stay in sync),
    but only the primary rank computes derived metrics and logs them to
    stdout and, when enabled, wandb / tensorboard.

    Args:
        rollout_id: index of the current rollout; used as the logging step.
        args: parsed training arguments (wandb/tensorboard flags, batch sizes).
        is_primary_rank: True on exactly the rank that should emit logs.
        compute_total_fwd_flops: callable mapping ``seq_lens`` (list of
            sequence lengths) to total forward TFLOPs per rank; may be None
            to skip FLOPs-derived metrics.
    """
    timer_instance = Timer()
    log_dict_raw = deepcopy(timer_instance.log_dict())
    # BUGFIX: snapshot seq_lens BEFORE reset(). reset() clears the timer's
    # accumulated state, and seq_lens is needed below for the TFLOPs and
    # tok/s computations (the pre-refactor code read it before resetting).
    seq_lens = list(timer_instance.seq_lens)
    timer_instance.reset()

    if not is_primary_rank:
        return

    log_dict = {f"perf/{key}_time": val for key, val in log_dict_raw.items()}

    if ("perf/actor_train_time" in log_dict) and (compute_total_fwd_flops is not None):
        total_fwd_flops = compute_total_fwd_flops(seq_lens=seq_lens)

        if "perf/log_probs_time" in log_dict:
            log_dict["perf/log_probs_tflops"] = total_fwd_flops / log_dict["perf/log_probs_time"]

        if "perf/ref_log_probs_time" in log_dict:
            log_dict["perf/ref_log_probs_tflops"] = total_fwd_flops / log_dict["perf/ref_log_probs_time"]

        if log_dict["perf/actor_train_time"] > 0:
            # Train does fwd + bwd ~= 3x the forward FLOPs.
            log_dict["perf/actor_train_tflops"] = 3 * total_fwd_flops / log_dict["perf/actor_train_time"]
            log_dict["perf/actor_train_tok_per_s"] = sum(seq_lens) / log_dict["perf/actor_train_time"]

    if "perf/train_wait_time" in log_dict and "perf/train_time" in log_dict:
        total_time = log_dict["perf/train_wait_time"] + log_dict["perf/train_time"]
        if total_time > 0:
            log_dict["perf/step_time"] = total_time
            log_dict["perf/wait_time_ratio"] = log_dict["perf/train_wait_time"] / total_time

    print(f"perf {rollout_id}: {log_dict}")

    step = (
        rollout_id
        if not args.wandb_always_use_train_step
        else rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
    )
    if args.use_wandb:
        log_dict["rollout/step"] = step
        wandb.log(log_dict)

    if args.use_tensorboard:
        from slime.utils.tensorboard_utils import _TensorboardAdapter

        tb = _TensorboardAdapter(args)
        tb.log(data=log_dict, step=step)
Loading