
feat: log total_valid_tokens and tokens_per_sec during SFT training. #803

Closed
xxman-google wants to merge 1 commit into NVIDIA-NeMo:main from xxman-google:xx/logging

Conversation


@xxman-google xxman-google commented Jul 30, 2025

What does this PR do ?

Add two additional metrics for logging during SFT.

It is useful to know how many tokens (excluding padding) we have trained on in total, and also to log the training speed in tokens per second.

Issues

List issues that this PR closes: N/A

Usage

  • You will find two additional metrics in the wandb graphs (see below for an example).
    total_valid_tokens is saved into the checkpoint so it can be restored when interrupted training resumes.
[Screenshots: wandb graphs showing the new total_valid_tokens and tokens_per_sec metrics]
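The running valid-token count described above can be sketched roughly as follows. This is a minimal illustration, not the PR's code: it assumes a batch attention mask where 1 marks a real token and 0 marks padding, and the function name is hypothetical.

```python
# Hypothetical sketch: count non-padding tokens in a batch, assuming an
# attention mask where 1 marks real tokens and 0 marks padding.
def count_valid_tokens(attention_mask):
    """Return the number of non-padding tokens in a batch of masks."""
    return sum(sum(row) for row in attention_mask)

# Running total carried across steps (and, per the PR description,
# saved into the checkpoint so it survives a resume).
total_valid_tokens = 0
batch_mask = [
    [1, 1, 1, 0, 0],  # 3 real tokens, 2 padding
    [1, 1, 1, 1, 1],  # 5 real tokens
]
total_valid_tokens += count_valid_tokens(batch_mask)  # adds 8
```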

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@terrykong terrykong self-requested a review July 30, 2025 16:38
Signed-off-by: Xuehan <xxman@google.com>
```python
total_steps: int  # Track total number of steps across all epochs
val_loss: NotRequired[float]  # Optional field - may not be present during training
consumed_samples: int
total_valid_tokens: int  # Track total number of non-padding tokens during training
```
Collaborator

is this supposed to track a different metric than `token_count = 0`?

can you use that instead? that one shows up as mean_total_tokens_per_sample

Contributor Author

> is this supposed to track a different metric than `token_count = 0`? can you use that instead? that one shows up as mean_total_tokens_per_sample

yes, I am tracking non-padded tokens while that one tracks all tokens.

@ZhiyuLi-Nvidia ZhiyuLi-Nvidia self-requested a review August 1, 2025 23:05
```python
print(f"  • {k}: {v:.2f}s ({percent:.1f}%)")

timing_metrics["valid_tokens_per_sec"] = (
    metrics["global_valid_toks"] / total_time
)
```
Contributor

Would it be better to add or switch to valid_tokens_per_sec_per_gpu? That metric is normalized by num_gpus. If we can scale linearly with data parallelism across more GPUs, then valid_tokens_per_sec_per_gpu would be a better metric for tracking the efficiency/speed of SFT training.
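The suggested normalization can be sketched as below. This is an illustrative assumption, not code from the PR: dividing the aggregate throughput by the GPU count makes the number comparable across runs with different data-parallel sizes.

```python
# Hypothetical sketch of the per-GPU normalization suggested above.
# Names (num_gpus, total_time) are illustrative, not from the PR.
def tokens_per_sec_per_gpu(global_valid_toks, total_time, num_gpus):
    """Aggregate token throughput divided by GPU count."""
    valid_tokens_per_sec = global_valid_toks / total_time
    return valid_tokens_per_sec / num_gpus

# With linear data-parallel scaling, doubling the GPUs roughly doubles raw
# throughput but leaves the per-GPU figure flat, so efficiency regressions
# stand out regardless of job size:
small_job = tokens_per_sec_per_gpu(1_000_000, 10.0, 8)
large_job = tokens_per_sec_per_gpu(2_000_000, 10.0, 16)
```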

```python
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
    if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
```
Contributor

a naive question:
why do we take the mean of global_valid_seqs and global_valid_toks? Is the metric already normalized over num_workers?

Contributor

@xxman-google I think I am mainly confused by the naming.

It is called "global_xxx", while we take the mean average right after:

```python
if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
    metrics[k] = np.mean(v).item()
```
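The confusion above is about aggregating per-microbatch lists. A minimal sketch of the two choices, under the assumption (not confirmed in the PR) that each entry is already the globally reduced count for one microbatch:

```python
import numpy as np

# Per-microbatch values gathered during one training step (illustrative data).
global_valid_toks = [4096, 4096, 4096, 4096]

# If each entry is already the all-reduced count for that microbatch, a sum
# gives the step total, while a mean gives a per-microbatch figure -- hence
# the question of which one a "global_" name should imply.
step_total = np.sum(global_valid_toks).item()    # total tokens this step
per_mb_mean = np.mean(global_valid_toks).item()  # tokens per microbatch
```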

@ZhiyuLi-Nvidia
Contributor

@xxman-google thank you for the contribution. Just curious if you can address the last 3 comments, then we are good to go.

@terrykong
Collaborator

closing. continuing in #1249



6 participants