Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from slime.utils.http_utils import find_available_port, get_host_info, init_http_client
from slime.utils.iter_utils import group_by
from slime.utils.metric_checker import MetricChecker
from slime.utils.metric_utils import compute_pass_rate, dict_add_prefix
from slime.utils.metric_utils import compute_pass_rate, compute_statistics, dict_add_prefix
from slime.utils.misc import load_function
from slime.utils.ray_utils import Box
from slime.utils.types import Sample
Expand Down Expand Up @@ -492,6 +492,7 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_
if args.rollout_num_gpus:
log_dict["perf/tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus
log_dict["perf/longest_sample_tokens_per_sec"] = max(response_lengths) / rollout_time
log_dict |= dict_add_prefix(compute_statistics(response_lengths), f"rollout/response_len/")
log_dict |= _compute_zero_std_metrics(args, samples)
log_dict |= _compute_spec_metrics(args, samples)
log_dict |= dict_add_prefix(_compute_reward_cat_metrics(args, samples), f"rollout/")
Expand Down
8 changes: 8 additions & 0 deletions slime/utils/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def estimator(n, c, k):
return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples, num_correct)])


def compute_statistics(values: List[float]) -> Dict[str, float]:
values = np.array(values)
return {
"mean": np.mean(values).item(),
"median": np.median(values).item(),
}


def compression_ratio(
data: Union[str, bytes],
*,
Expand Down
Loading