SAM Callback DDP and Multi Data Fixes. #187

Merged: 2 commits into IntelLabs:main from sam-multidata-and-ddp-fix on Apr 19, 2024

melo-gonzo
Collaborator

This PR addresses two bugs:

  1. Under DDP, the PyTorch Lightning trainer API is different, so the original trainer.model.optimizers lookup does not work properly. This PR switches the logic to use the task module to get the optimizer names instead, which is consistent across training pipelines.

  2. In multi-data training, the datasets do not necessarily have the same number of samples. When the sample counts are imbalanced, the _compute_losses() function returns losses only for datasets that are still being processed, while optimizer_names still contains mappings for datasets with no samples left to process. This mismatch caused problems when gathering losses and computing the "global loss". This PR adds a few checks to extract_optimizer_specific_loss to handle these cases.
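To illustrate the second fix, here is a minimal sketch of the kind of check described above. It is not the code from this PR: the function names `pick_loss_for_optimizer` and `global_loss`, the shape of the `losses` dictionary, and the `(dataset, task)` tuples in `optimizer_names` are all assumptions made for illustration, and plain floats stand in for loss tensors.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical layout: losses[dataset][task] -> loss value for this step.
Losses = Dict[str, Dict[str, float]]


def pick_loss_for_optimizer(
    losses: Losses,
    optimizer_names: List[Tuple[str, str]],
    opt_idx: int,
) -> Optional[float]:
    """Return the loss owned by optimizer `opt_idx`, or None if it is absent.

    `optimizer_names` always lists every (dataset, task) pair, but `losses`
    only has entries for datasets that still produced samples this step.
    """
    dataset, task_name = optimizer_names[opt_idx]
    # An exhausted dataset has no entry at all, so look it up defensively.
    dataset_losses = losses.get(dataset)
    if dataset_losses is None:
        return None
    return dataset_losses.get(task_name)


def global_loss(losses: Losses, optimizer_names: List[Tuple[str, str]]) -> float:
    """Sum only the losses that are actually present in this step."""
    picked = (
        pick_loss_for_optimizer(losses, optimizer_names, idx)
        for idx in range(len(optimizer_names))
    )
    return sum(value for value in picked if value is not None)


# Example step in which dataset_b has already run out of samples.
optimizer_names = [("dataset_a", "energy"), ("dataset_b", "force")]
step_losses = {"dataset_a": {"energy": 0.5}}
```

In this sketch, a caller would skip the optimizer step whenever `pick_loss_for_optimizer` returns `None`, rather than indexing into the loss dictionary and raising a `KeyError`.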

…n the loss, adding fix for ddp where trainer class switches (using task module to get opt names now).
@laserkelvin laserkelvin added the bug Something isn't working label Apr 18, 2024
Collaborator

@laserkelvin laserkelvin left a comment

Just one comment - merge after you've added it.

@@ -779,11 +783,12 @@ def on_before_optimizer_step(
 org_weights = self._first_step(optimizer)
 with torch.enable_grad():
 loss = task._compute_losses(self.batch)
-if len(trainer.optimizers) > 1:
-loss = self.extract_optimizer_specific_loss(trainer, optimizer, loss)
+if len(task.optimizers()) > 1:

Could you just throw a comment in here just to say this is the multitask case?

@melo-gonzo melo-gonzo merged commit 46c1737 into IntelLabs:main Apr 19, 2024
2 of 3 checks passed
@melo-gonzo melo-gonzo deleted the sam-multidata-and-ddp-fix branch April 19, 2024 15:54