fix(*:skip): Normalize avg_trainloss in PyTorch quickstart #6334
Issue
Description
The PyTorch quickstart computes running_loss across all local epochs but normalizes it only by len(trainloader). As a result, the reported train loss metric scales approximately linearly with local-epochs when local-epochs > 1, which is misleading.
Related issues/PRs
Fixes #6333
Proposal
Explanation
This PR fixes the normalization by dividing running_loss by (epochs * len(trainloader)), yielding an average loss per batch over the full local training run.
This is a single-line change to the train function, and the fix is also applied in the docs that use this code directly.
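The effect of the two normalizations can be sketched without PyTorch. This is a minimal illustration, not the quickstart's actual train function: the helper name train_loss_metrics and the batch_losses list (standing in for the per-batch loss values a real trainloader would yield) are assumptions made for the example.

```python
def train_loss_metrics(batch_losses, epochs):
    """Replay `batch_losses` for `epochs` local epochs; return both normalizations.

    Hypothetical helper: `batch_losses` stands in for the per-batch loss
    values produced while iterating a real trainloader.
    """
    num_batches = len(batch_losses)  # plays the role of len(trainloader)
    running_loss = 0.0
    for _ in range(epochs):
        for loss in batch_losses:
            running_loss += loss  # accumulated across ALL local epochs

    # Normalization before the fix: divides by batches of one epoch only.
    before_fix = running_loss / num_batches
    # Normalization after the fix: average loss per batch over the full run.
    after_fix = running_loss / (epochs * num_batches)
    return before_fix, after_fix


if __name__ == "__main__":
    for epochs in (1, 2, 4):
        before, after = train_loss_metrics([0.5, 0.5, 0.5, 0.5], epochs)
        print(f"local-epochs={epochs}: before={before}, after={after}")
```

With a constant per-batch loss, the old value grows in proportion to the number of local epochs while the new value stays fixed. In real training the loss changes between epochs, so the corrected metric stays in the same range rather than being exactly constant.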
Affected files:
examples/quickstart-pytorch/pytorchexample/task.py
framework/docs/source/tutorial-quickstart-pytorch.rst
framework/docs/source/tutorial-series-get-started-with-flower-pytorch.rst
Behavior for local-epochs=1 is unchanged, but the train loss for local-epochs>1 is corrected. This was verified by running one server round with local-epochs={1,2,4}: the loss no longer scales 2x/4x, but instead remains in the same range.
Note that examples/quickstart-pennylane/... already normalizes loss as running_loss / (epochs * len(trainloader)).
Checklist
Any other comments?
./framework/dev/format.sh currently fails on upstream main due to an unrelated missing copyright notice in framework/py/flwr/app/message_type.py, which is reproducible on main at 92b9de10b.