
[Benchmark] Fix ZeRO-3 step log#31

Merged
zarzen merged 3 commits into awslabs:main from comaniac:fix_zero3_log
Jan 31, 2023

Conversation

@comaniac
Contributor

Description

Fix the benchmark utility train_with_torch to account for the micro batch size when printing the log. It now accepts an optional micro_batch_size and prints the loss only once per global batch.
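A minimal sketch of the behavior described above: when gradient accumulation splits a global batch into micro batches, the loss is accumulated and logged once per global batch. The function and parameter names `train_with_torch` and `micro_batch_size` come from the PR description; the loop body and the `model_step` callback are assumptions for illustration, not the actual benchmark code.

```python
def train_with_torch(model_step, num_steps, global_batch_size, micro_batch_size=None):
    # Hypothetical sketch: model_step() runs one micro batch and returns its loss.
    micro = micro_batch_size or global_batch_size
    grad_accum = global_batch_size // micro  # micro steps per global batch
    accum_loss, global_step = 0.0, 0
    for step in range(num_steps):
        accum_loss += model_step()
        # Only log once a full global batch has been processed.
        if (step + 1) % grad_accum == 0:
            global_step += 1
            print(f"step {global_step}: loss {accum_loss / grad_accum:.4f}")
            accum_loss = 0.0
```

With `micro_batch_size=None` (or equal to the global batch size) every step logs, preserving the old behavior; with a smaller micro batch size, intermediate micro-batch losses are averaged into one line per global batch.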

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

cc @zarzen

@comaniac
Contributor Author

Per offline discussion, we now have two functions to deal with different cases:

  1. train_with_deepspeed_engine: Train with ZeRO but without pipeline parallelism. This function uses the DeepSpeed engine's .global_steps attribute to print the loss, so we don't need to worry about the batch size or DP size.
  2. train_with_torch: Train with the PyTorch native runtime. In this case we assume no DP, no PP, and no gradient accumulation, so each micro batch is also the global batch. This is currently only used by WideResNet with TP.

In addition, this PR reduces the number of training steps in CI to shorten the CI time.

@comaniac comaniac mentioned this pull request Jan 31, 2023
@zarzen zarzen merged commit 1b1ef74 into awslabs:main Jan 31, 2023
@comaniac comaniac deleted the fix_zero3_log branch February 3, 2023 23:25