Deterministic global sampling of corpus ids #66
Conversation
Naming scheme now consistently uses my_ or get_my_ prefix for local.
There are 4 places where almost the same loop over distributed components is performed, with subtle differences:
1) in train_single.py, when broadcasting initialized parameters
2) in trainer.py, when communicating the gradients
3) in utils/optimizers.py, when stepping the optimizer
4) in utils/module_splitter.py, when saving a checkpoint
DRY could be greatly improved by refactoring these.
Tests are passing, but the modification is still only partway done.
In order to simplify, CPU and single-GPU now also use multiprocessing, placing the dataloader in a separate process.

Several places have been refactored to use the new distributed components:
- Init broadcast in train_single.py
- Gradient communication in trainer.py (still uses only_ready_reduce_and_rescale_grads, though)
- Sub-optimizer construction in utils/optimizers.py

Two places have been identified as potential future candidates:
- model_builder
- module splitter for checkpointing

The task distribution is now logged in the TQM of the main rank. It no longer has to be done in the dataloader, as each TQM has a global view.
This guarantees that the producer and consumer don't share state
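For illustration, a minimal sketch of that producer/consumer split, assuming a bounded multiprocessing queue; the function and variable names here are placeholders, not the actual mammoth API:

```python
# Illustrative sketch only: a dataloader process feeding batches to the trainer
# through a bounded queue, so the two sides share no mutable Python state.
import torch.multiprocessing as mp


def produce_batches(batch_queue, dataset):
    """Runs in the dataloader process; the trainer never touches `dataset`."""
    for batch in dataset:
        batch_queue.put(batch)
    batch_queue.put(None)  # sentinel: no more data


def consume_batches(batch_queue):
    """Runs in the trainer process; yields batches as they arrive."""
    while True:
        batch = batch_queue.get()
        if batch is None:
            return
        yield batch


if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    queue = ctx.Queue(maxsize=40)       # bounded queue applies backpressure to the producer
    dataset = list(range(100))          # stand-in for the real corpus iterator
    producer = ctx.Process(target=produce_batches, args=(queue, dataset))
    producer.start()
    for batch in consume_batches(queue):
        pass                            # training step would go here
    producer.join()
```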
Reads in data until there is a few minibatches' worth of it. Guaranteed not to exceed the maximum minibatch size, to avoid VRAM OOM. Sorts examples locally according to length (max of source and target), which allows some minibatches to use less padding.
Avoids the need to communicate ready_t
Closes #8
LGTM, but I'm not sure about dropping the sentence-level batching + max len padding trick, which thus far is the one with the best tok/s, right?
To what extent could this be added back in, given the current implementation?
@@ -35,81 +36,48 @@ def broadcast_tensors(tensors, src=0, group=None):
         torch.distributed.broadcast(t, src, group=group)


-def only_ready_reduce_and_rescale_grads(named_parameters, group=None):
+def managed_reduce_and_rescale_grads(
"managed" ? I'm not sure about the clarity of the terminology here, could you elaborate on that?
We should rename this, if we can come up with a more descriptive name.
"Reduce and rescale grads or dummy grads, with the caller deciding which one to do."
Maybe externally_managed_reduce_and_rescale_grads would do the trick?
The idea is that only_ready uses the metadata provided by the has_grad hook to make its own decisions about how to communicate certain parameters, but the new implementation needs to be managed by something else, with the caller passing in the has_local_gradient and gradient_norm.
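For concreteness, a minimal sketch of what that externally managed contract could look like; this is an assumption about the shape of the function, not the actual mammoth implementation:

```python
# Hypothetical sketch of the contract discussed above, not the real implementation:
# the caller decides whether this rank contributes real gradients, and supplies
# the normalization factor instead of having it communicated via ready_t.
import torch
import torch.distributed


def externally_managed_reduce_and_rescale_grads(
    named_parameters, has_local_gradient, gradient_norm, group=None
):
    for _name, param in named_parameters:
        if not has_local_gradient or param.grad is None:
            # Participate in the collective with zeros so other ranks don't block.
            param.grad = torch.zeros_like(param)
        torch.distributed.all_reduce(param.grad, group=group)
        param.grad.div_(max(gradient_norm, 1))
```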
help="Maximum number of bins for batching.", | ||
default=4, | ||
help="The number of minibatches that will be yielded once bucketing is complete. " | ||
"Recommended value: same as accum_count, or at least a multiple of it." |
bucket is a terrible name for that?
Bucket is indeed a terrible name. I reused the existing parameter without renaming it, while totally changing its meaning. Classic mammoth.
The plan was to either clean this up or throw this away based on the benchmark results, but unfortunately they were somewhat inconclusive.
choices=["sents", "tokens"], | ||
help="Batch grouping for batch_size. Standard is sents. Tokens will do dynamic batching", | ||
help="Batch grouping for batch_size. Standard is tokens (max of src and tgt). Sents is unimplemented.", |
so no support for max len padding, which is the thing that seems to work the least bad for now?
@@ -83,55 +54,63 @@ def build_torch_optimizer(model, opts, task_queue_manager):
     Returns:
         A ``torch.optim.Optimizer`` instance.
feels like we're missing a class constructor signature inspection + kwargs popping approach
(Apparently there are issues with logging the TQM on this branch. Has this been smoke-tested / tested? If so, I'll dismiss this second review.)
[2024-05-06 11:46:02,478 1083 INFO] world_size = 4, queue_size = 40
[2024-05-06 11:46:02,481 1083 INFO] in task_queue_manager: node_rank 0 local_rank 0
Traceback (most recent call last):
File "/home/nloppi/Tiedemann_project/mammoth/mammoth/train.py", line 6, in <module>
main()
File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/bin/train.py", line 304, in main
train(opts)
File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/bin/train.py", line 245, in train
logger.info(f'TaskQueueManager: {global_task_queue_manager}')
File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/distributed/tasks.py", line 310, in __repr__
kwargs = ',\n '.join(
File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/distributed/tasks.py", line 311, in <genexpr>
f'{key}={pformat(self.__getattribute__(key))}'
AttributeError: 'TaskQueueManager' object has no attribute 'node_rank'
The sentence-level batching + max len padding trick is implemented as a special case of the spiral bucketing data loader. If we want both dynamic batching and the sentence-level batching + max len padding trick, then the easiest way is to keep all 3 implementations. However, it would be quite confusing to have two different ways of doing dynamic batching, of which one is known to be flaky. So maybe the dynamic spiral bucketing would be there but inaccessible through the config.

A better solution would be to refactor the sentence-level batching + max len padding trick so that it doesn't need the spiral bucketing data loader. For benchmarking, it was just easier to rip it out.
@@ -223,27 +202,17 @@ def zero_grad(self):
         for name in self.optimizers:
             self.optimizers[name].zero_grad()

-    def step(self, grad_scaler=None):
-        """Step through all the suboptimizers"""
+    def managed_step(self, gradient_syncs, grad_scaler=None):
Here we also use the term managed with the same meaning. The caller must supply gradient syncs (a sequence of DistributedComponentGradientSync; should add a type annotation) which determine whether or not to step each suboptimizer.
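A rough sketch of how such a managed step could work; the class shapes and field names below are guesses for illustration, not the actual mammoth code:

```python
# Illustrative sketch: stepping only the suboptimizers whose distributed
# component received real gradients, as decided by the caller.
from dataclasses import dataclass


@dataclass
class DistributedComponentGradientSync:  # stand-in for the real class
    component_name: str
    has_local_gradient: bool


class MultipleOptimizer:  # simplified; attribute names are assumptions
    def __init__(self, optimizers):
        self.optimizers = optimizers     # dict: component name -> torch optimizer

    def zero_grad(self):
        for name in self.optimizers:
            self.optimizers[name].zero_grad()

    def managed_step(self, gradient_syncs, grad_scaler=None):
        """Step only the suboptimizers for components with local gradients."""
        for sync in gradient_syncs:
            if not sync.has_local_gradient:
                continue
            optimizer = self.optimizers[sync.component_name]
            if grad_scaler is not None:
                grad_scaler.step(optimizer)  # torch.cuda.amp.GradScaler API
            else:
                optimizer.step()
```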
In commit fbe4f5c I fixed the bug.
For some reason, setting the logging level of the logger doesn't work. Neither --verbose nor --log_file_level seems to affect the level of the logger. Therefore, messages are still logged as warnings, but only shown if the verbose flag is set.
fix linting and this should be GTG
LGTM! Will let you do the honors of merging :']
Closes #8
Tasks (corpus ids) are now sampled in a deterministic way, such that all TaskQueueManagers can be aware of which task every other device is training on.
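As a rough illustration of the idea (the function name and weighting scheme below are assumptions, not the actual TaskQueueManager code), every rank can reproduce the full schedule from a shared seed instead of being told what the others are doing:

```python
# Sketch of deterministic global task sampling: every rank seeds an identical
# RNG, so each rank can recompute the full mapping (global rank -> corpus id)
# for any training step without communication.
import numpy as np


def sample_corpus_ids(step, corpus_ids, weights, n_ranks, seed=1234):
    """Return the corpus id assigned to every rank at a given training step."""
    rng = np.random.default_rng(seed + step)   # same seed on every rank
    probs = np.asarray(weights, dtype=float)
    probs /= probs.sum()
    return rng.choice(corpus_ids, size=n_ranks, p=probs)


# Every rank computes the identical schedule and picks out its own entry:
schedule = sample_corpus_ids(step=0, corpus_ids=['en-de', 'en-fr', 'en-cs'],
                             weights=[2, 1, 1], n_ranks=4)
my_corpus_id = schedule[2]   # e.g. on global rank 2
```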
One benefit is that it is no longer necessary to communicate ready_t during gradient sync. This simplifies the algorithm, but does not seem to have any performance impact.
The PR includes refactoring of multiple locations in the code to use a joint representation of DistributedComponents, to determine which parameters are on which devices, and how they need to be communicated.
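A hypothetical sketch of what such a joint representation might look like (field and method names are assumptions for illustration, not mammoth's actual classes):

```python
# Hypothetical sketch: one object per shared module, reusable by init broadcast,
# gradient sync, suboptimizer construction and checkpointing instead of four
# hand-rolled loops over distributed components.
from dataclasses import dataclass
from typing import Dict

import torch.nn as nn


@dataclass
class DistributedComponent:
    name: str                 # e.g. an encoder or decoder shared by a language group
    module: nn.Module         # the parameters this component owns
    group: object             # torch.distributed process group that shares it
    is_on_this_rank: bool     # whether the local device holds these parameters

    def named_parameters(self):
        for pname, param in self.module.named_parameters():
            yield f'{self.name}.{pname}', param


def components_on_this_rank(components: Dict[str, DistributedComponent]):
    """The single loop that the duplicated call sites could share."""
    return [c for c in components.values() if c.is_on_this_rank]
```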
There are additional places that could be refactored in future work: model_builder and the module splitter for checkpointing.
This PR also includes a SimpleLookAheadBucketing, which provides dynamic minibatching with guarantees on the maximum batch size. This allows increasing the minibatch size without causing VRAM OOM.
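For illustration, a minimal sketch of the look-ahead bucketing idea; for brevity it measures batch size in examples rather than tokens, and all names are placeholders rather than the real SimpleLookAheadBucketing API:

```python
# Simplified sketch of look-ahead bucketing: read a few minibatches' worth of
# examples, sort them locally by length so that neighbouring examples need
# similar padding, then emit minibatches that never exceed the maximum size.
def look_ahead_bucketing(examples, max_batch_size, look_ahead_minibatches=4):
    def drain(buffer):
        # Local sort by max(src, tgt) length reduces padding within minibatches.
        buffer.sort(key=lambda ex: max(len(ex[0]), len(ex[1])))
        for i in range(0, len(buffer), max_batch_size):
            yield buffer[i:i + max_batch_size]   # never larger than max_batch_size

    buffer = []
    for example in examples:                      # example = (src_tokens, tgt_tokens)
        buffer.append(example)
        if len(buffer) >= max_batch_size * look_ahead_minibatches:
            yield from drain(buffer)
            buffer = []
    if buffer:
        yield from drain(buffer)
```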