This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Gradient accumulation with the distributed trainer #3537

Merged
merged 4 commits on Dec 17, 2019

Conversation

dirkgr
Member

@dirkgr dirkgr commented Dec 17, 2019

I re-did gradient accumulation with the distributed trainer.
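
For context, a minimal sketch of the gradient accumulation pattern being added (not the trainer's actual code; the model, optimizer, and data are toy stand-ins, and num_gradient_accumulation_steps mirrors the trainer's self._num_gradient_accumulation_steps):

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

num_gradient_accumulation_steps = 4
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(12)]

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(batches):
    loss = loss_fn(model(inputs), targets)
    # Scale the loss so the accumulated gradient matches one large batch.
    (loss / num_gradient_accumulation_steps).backward()
    if (i + 1) % num_gradient_accumulation_steps == 0:
        # Step and reset once per group of accumulated batches.
        optimizer.step()
        optimizer.zero_grad()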

@dirkgr dirkgr self-assigned this Dec 17, 2019
@@ -343,21 +353,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
train_generator_tqdm = train_generator

cumulative_batch_size = 0
-for batch in train_generator_tqdm:
+for batches in lazy_groups_of(train_generator_tqdm, self._num_gradient_accumulation_steps):
brendan-ai2
Contributor

Nit: Maybe accumulation_groups (or batch_group or whatever) instead of batches? Might be confusing for a reader to see lines like cur_batch = sum(training_util.get_batch_size(batch) for batch in batches).
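
For illustration, a self-contained sketch of the grouping behind the new loop; lazy_groups_of and get_batch_size here are simplified stand-ins rather than the allennlp.common.util / training_util implementations, and batch_group follows the naming suggested in this comment:

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def lazy_groups_of(iterable: Iterable[T], group_size: int) -> Iterator[List[T]]:
    # Lazily yield lists of up to group_size items from the iterable.
    iterator = iter(iterable)
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group

def get_batch_size(batch: dict) -> int:
    # Toy stand-in: real batches are tensor dicts; here a batch just records its size.
    return batch["size"]

batch_stream = ({"size": 32} for _ in range(10))
for batch_group in lazy_groups_of(batch_stream, group_size=4):
    cur_batch = sum(get_batch_size(batch) for batch in batch_group)
    print(len(batch_group), cur_batch)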

@@ -323,7 +331,9 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

# Get tqdm for the training batches
train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle)
brendan-ai2
Contributor

The lazy_groups_of call should be here to create a new generator and tqdm should be applied to that. As in

raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle)
previously. Otherwise tqdm will give an inaccurate count of the batches.
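
A sketch of the ordering suggested here: group the raw batch generator first, then hand the grouped generator to tqdm so the progress bar's total counts accumulation groups rather than individual batches. The iterator and batch objects are toy stand-ins; only the group-then-tqdm ordering and the adjusted total mirror the suggestion:

import math
from itertools import islice
from tqdm import tqdm

def lazy_groups_of(iterable, group_size):
    # Simplified stand-in for allennlp.common.util.lazy_groups_of.
    iterator = iter(iterable)
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group

num_batches = 10                     # pretend this came from the data iterator
num_gradient_accumulation_steps = 4

raw_train_generator = (f"batch-{i}" for i in range(num_batches))  # like self.iterator(self.train_data, ...)
train_generator = lazy_groups_of(raw_train_generator, num_gradient_accumulation_steps)

# The bar's total must count groups, not individual batches,
# or the reported progress fraction will be wrong.
num_training_groups = math.ceil(num_batches / num_gradient_accumulation_steps)
train_generator_tqdm = tqdm(train_generator, total=num_training_groups)

for batch_group in train_generator_tqdm:
    pass  # forward/backward over each batch in batch_group, then one optimizer step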

@dirkgr
Member Author

dirkgr commented Dec 17, 2019

@brendan-ai2, I took some license with the naming, and fixed the tqdm issue while I was at it.

@dirkgr dirkgr mentioned this pull request Dec 17, 2019
Contributor

@brendan-ai2 brendan-ai2 left a comment


LGTM. Thanks!

2 participants