SAM Callback DDP and Multi Data Fixes. #187
Merged
Two bugs are addressed in this PR.

In DDP, the PyTorch Lightning `trainer` API is different, so the original `trainer.model.optimizers` lookup will not work properly. This PR switches the logic to use the task module to get the optimizer names instead, which is consistent across training pipelines, as sketched below.
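A minimal sketch of the first fix, assuming the SAM logic lives in a Lightning callback. `SAMCallback` and the `optimizer_names` attribute on the task module are hypothetical names reconstructed from this description, not the actual codebase:

```python
import pytorch_lightning as pl


class SAMCallback(pl.Callback):
    """Sketch of the DDP fix: read optimizer names from the task module."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Under DDP, `trainer.model` is the DistributedDataParallel wrapper
        # around the LightningModule, so the original lookup
        # `trainer.model.optimizers` does not resolve the same way it does
        # in single-process training.
        #
        # Reading the names from the task module (the LightningModule that
        # Lightning passes to every callback hook) behaves the same across
        # pipelines:
        self.optimizer_names = pl_module.optimizer_names  # hypothetical attribute
```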
In multi-data training, the datasets do not necessarily have the same number of samples. When there is sample imbalance, the `_compute_loss()` function only returns losses for datasets that are still being processed, while `optimizer_names` still contains mappings for datasets with no samples left to process. This was causing problems when gathering losses and computing the "global loss". This PR adds a few checks to `extract_optimizer_specific_loss` to fix these issues, as sketched below.
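A minimal, self-contained sketch of the second fix. The signature, dictionary shapes, and key names are assumptions based on this description, not the real function:

```python
from typing import Dict

import torch


def extract_optimizer_specific_loss(
    loss_dict: Dict[str, torch.Tensor],
    optimizer_names: Dict[str, str],
    optimizer_name: str,
) -> torch.Tensor:
    """Sum the losses belonging to one optimizer, skipping datasets that
    produced no loss this step (i.e. datasets with no samples left)."""
    selected = []
    for dataset_name, opt_name in optimizer_names.items():
        if opt_name != optimizer_name:
            continue
        # Added check: an exhausted dataset has no entry in the loss dict,
        # so skip it instead of failing when gathering losses.
        if dataset_name not in loss_dict:
            continue
        selected.append(loss_dict[dataset_name])
    if not selected:
        # Added check: no dataset mapped to this optimizer contributed a
        # loss this step; return a zero so the "global loss" reduction
        # still works. (Device/dtype handling is omitted in this sketch.)
        return torch.tensor(0.0)
    return torch.stack(selected).sum()
```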