feat: log total_valid_tokens and tokens_per_sec during SFT training.#803
xxman-google wants to merge 1 commit into NVIDIA-NeMo:main
Signed-off-by: Xuehan <xxman@google.com>
```python
total_steps: int  # Track total number of steps across all epochs
val_loss: NotRequired[float]  # Optional field - may not be present during training
consumed_samples: int
total_valid_tokens: int  # Track total number of non-padding tokens during training
```
Is this supposed to track a different metric than `RL/nemo_rl/experience/rollouts.py` (line 605 in `7b3fad8`)? Can you use that instead? That one shows up as `mean_total_tokens_per_sample`.
Yes, I am tracking non-padded tokens, while that one tracks all tokens.
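A minimal sketch of that distinction (the `pad_token_id` convention and helper name are assumptions for illustration, not the PR's actual code):

```python
import numpy as np

def count_valid_tokens(input_ids: np.ndarray, pad_token_id: int) -> int:
    # "Valid" tokens exclude padding; a raw per-sample total would not.
    return int((input_ids != pad_token_id).sum())

# Two sequences right-padded to length 5 with pad_token_id=0:
batch = np.array([[5, 6, 7, 0, 0],
                  [8, 9, 0, 0, 0]])
print(count_valid_tokens(batch, pad_token_id=0))  # 5 valid tokens
print(batch.size)                                 # 10 total tokens
```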
```python
print(f"  • {k}: {v:.2f}s ({percent:.1f}%)")

timing_metrics["valid_tokens_per_sec"] = (
    metrics["global_valid_toks"] / total_time
)
```
Would it be better to add/switch to `valid_tokens_per_sec_per_gpu`? It is normalized for num_gpus. If we can scale linearly using data parallelism with more GPUs, then `valid_tokens_per_sec_per_gpu` would be a better metric for tracking the efficiency/speed of SFT training.
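A sketch of the suggested normalization (the function name and example numbers are illustrative, not from the PR):

```python
def valid_tokens_per_sec_per_gpu(global_valid_toks: int,
                                 total_time: float,
                                 num_gpus: int) -> float:
    # Dividing by num_gpus makes throughput comparable across runs of
    # different data-parallel sizes: under linear scaling, adding GPUs
    # should leave this metric roughly constant.
    return global_valid_toks / total_time / num_gpus

# e.g. 1.2M valid tokens in 10 s on 8 GPUs:
print(valid_tokens_per_sec_per_gpu(1_200_000, 10.0, 8))  # 15000.0
```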
```python
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
    if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
```
A naive question: why do we take the mean of `global_valid_seqs` and `global_valid_toks`? Are these metrics already normalized over num_workers?
@xxman-google I think I am mainly confused by the naming. It is called "global_xxx", yet we take a mean average right after:

```python
if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
    metrics[k] = np.mean(v).item()
```
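To illustrate the naming question (all values are made up): if each microbatch logs the same already-reduced global count, the mean simply recovers it; but if the values were per-microbatch local counts, a mean would understate the true total and a sum would be needed:

```python
import numpy as np

# Case 1: every microbatch reports the same already-reduced global count.
already_global = [4096, 4096, 4096]
print(np.mean(already_global).item())  # 4096.0 -- mean recovers the count

# Case 2: each microbatch reports its own local count.
per_microbatch = [1000, 1500, 1596]
print(int(np.sum(per_microbatch)))     # 4096 -- the true total needs a sum
print(np.mean(per_microbatch).item())  # ~1365.3 -- a mean understates it
```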
@xxman-google thank you for the contribution. Just curious if you can address the last 3 comments, then we are good to go.
Closing. Continuing in #1249.
What does this PR do?
Add two additional metrics for logging during SFT. It is useful to know how many tokens (excluding padding) we have trained on in total, and to log training speed in tokens per second.
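A hedged sketch of how the two metrics relate (the function name and numbers are illustrative, not the PR's actual code):

```python
def update_metrics(total_valid_tokens: int,
                   step_valid_tokens: int,
                   step_time_s: float):
    # Running total of non-padding tokens, plus per-step throughput.
    total_valid_tokens += step_valid_tokens
    tokens_per_sec = step_valid_tokens / step_time_s
    return total_valid_tokens, tokens_per_sec

total, tps = update_metrics(1_000_000, 8192, 2.0)
print(total, tps)  # 1008192 4096.0
```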
Issues
List issues that this PR closes (syntax): N/A
Usage
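A minimal sketch of the checkpoint-resume behavior (the state-dict keys other than `total_valid_tokens` are assumptions based on the diff, and the step hook is hypothetical):

```python
# Save: the running count lives in the training state alongside
# existing fields such as total_steps / consumed_samples.
state = {"total_steps": 1200,
         "consumed_samples": 48_000,
         "total_valid_tokens": 9_800_000}

def on_train_step(state: dict, step_valid_tokens: int) -> dict:
    # Resume: keep accumulating from the restored value
    # instead of restarting at 0 after an interruption.
    state["total_valid_tokens"] += step_valid_tokens
    return state

state = on_train_step(state, 8192)
print(state["total_valid_tokens"])  # 9808192
```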
`total_valid_tokens` is saved into the checkpoint so it can be resumed when training is interrupted.

Before your PR is "Ready for review"
Pre checks:
Additional Information