
Deterministic global sampling of corpus ids #66

Merged: 16 commits into main, May 13, 2024

Conversation

@Waino (Collaborator) commented May 6, 2024

Tasks (corpus ids) are now sampled in a deterministic way, such that all TaskQueueManagers can be aware of which task every other device is training on.

One benefit is that it is no longer necessary to communicate ready_t during gradient sync. This simplifies the algorithm, but does not seem to have any performance impact.

The PR refactors multiple locations in the code to use a joint representation of DistributedComponents, to determine which parameters are on which devices and how they need to be communicated (a sketch follows the list below):

  • broadcast of initial model parameters
  • gradient communication during training
  • optimizer
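
As a rough illustration of what such a joint representation could look like (a hypothetical sketch; the class and function names below are assumptions, not mammoth's actual API), each component records which ranks hold a replica and which process group is used to communicate it, so the broadcast, gradient, and optimizer code can all loop over the same objects:

# Hypothetical sketch, not mammoth's actual API: one record per shared module,
# listing the ranks that hold a replica and the process group used to
# communicate its parameters and gradients.
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn


@dataclass
class DistributedComponentSketch:
    name: str                        # e.g. "encoder_xx" or "generator_yy"
    module: nn.Module                # parameters belonging to this component
    ranks: List[int]                 # global ranks holding a replica
    group: Optional[object] = None   # torch.distributed process group, or None

    def needs_communication(self) -> bool:
        return len(self.ranks) > 1


def broadcast_initial_params(components, my_rank):
    # The same loop can serve init broadcast, gradient sync, and optimizer
    # construction, instead of several slightly different copies of it.
    for comp in components:
        if my_rank not in comp.ranks or not comp.needs_communication():
            continue
        for p in comp.module.parameters():
            torch.distributed.broadcast(p.data, src=comp.ranks[0], group=comp.group)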

There are additional places that could be refactored in future work:

  • model_builder
  • module splitter

This PR also includes a SimpleLookAheadBucketing, which provides dynamic minibatching with guarantees on the maximum batch size. This allows increasing the minibatch size without causing VRAM OOM.
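
To make the sampling idea concrete, here is a minimal sketch of how deterministic global sampling could work (function and parameter names are illustrative, not mammoth's actual code): every rank seeds an RNG with the same seed and training step, so all TaskQueueManagers compute the identical assignment without any communication.

import numpy as np

def sample_corpus_ids_for_step(step, corpora_per_rank, weights_per_rank, seed=1234):
    # Hypothetical sketch: every rank runs this identical computation, so all
    # ranks agree on which corpus each rank trains on at this step, with no
    # communication (and hence no need for ready_t).
    rng = np.random.default_rng(seed + step)
    assignment = {}
    for rank in sorted(corpora_per_rank):      # fixed iteration order for determinism
        corpora = corpora_per_rank[rank]
        weights = np.asarray(weights_per_rank[rank], dtype=float)
        assignment[rank] = rng.choice(corpora, p=weights / weights.sum())
    return assignment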

Waino added 13 commits April 15, 2024 10:03
Naming scheme now consistently uses my_ or get_my_ prefix for local.
There are 4 places where almost the same loop over distributed
components is performed, with subtle differences.
1) in train_single.py, when broadcasting initialized parameters
2) in trainer.py, when communicating the gradients
3) in utils/optimizers.py, when stepping the optimizer
4) in utils/module_splitter.py, when saving a checkpoint

DRY could be greatly improved by refactoring these.
Tests are passing, but the modification is still only partway done
- In order to simplify, CPU and single-GPU now also use multiprocessing,
  placing the dataloader in a separate process.

Several places have been refactored to use the new distributed
components:
- Init broadcast in train_single.py
- Gradient communication in trainer.py
  (still uses only_ready_reduce_and_rescale_grads though)
- Sub-optimizer construction in utils/optimizers.py

Two places have been identified as potential future candidates:
- model_builder
- module splitter for checkpointing

The task distribution is now logged in the TQM of the main rank.
It no longer has to be done in the dataloader, as each TQM has a global
view.
This guarantees that the producer and consumer don't share state
Reads in data until there is a few minibatches' worth of it.
Guaranteed not to exceed the maximum minibatch size, to avoid VRAM OOM.
Sorts examples locally according to length (max of source and target).
This allows some minibatches to use less padding.
Avoids the need to communicate ready_t
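
The bucketing commits above can be summarised with a short sketch (hypothetical names, not the actual SimpleLookAheadBucketing implementation): buffer a few minibatches' worth of examples, sort the buffer by length (max of source and target), and cut it into batches whose padded size never exceeds the limit.

def look_ahead_bucketing(examples, max_batch_tokens, look_ahead_minibatches=4):
    # Sketch of the idea only. Buffers roughly `look_ahead_minibatches` batches
    # worth of examples, then sorts and cuts them into minibatches.
    buffer, buffered_tokens = [], 0
    for ex in examples:
        buffer.append(ex)
        buffered_tokens += max(len(ex['src']), len(ex['tgt']))
        if buffered_tokens >= look_ahead_minibatches * max_batch_tokens:
            yield from _cut_into_batches(buffer, max_batch_tokens)
            buffer, buffered_tokens = [], 0
    if buffer:
        yield from _cut_into_batches(buffer, max_batch_tokens)


def _cut_into_batches(buffer, max_batch_tokens):
    # Sorting by length lets neighbouring examples share a batch with little padding.
    buffer.sort(key=lambda ex: max(len(ex['src']), len(ex['tgt'])))
    batch = []
    for ex in buffer:
        ex_len = max(len(ex['src']), len(ex['tgt']))
        # After sorting, `ex` is the longest element seen so far, so the padded
        # size of the batch would be (len(batch) + 1) * ex_len tokens.
        if batch and (len(batch) + 1) * ex_len > max_batch_tokens:
            yield batch
            batch = []
        batch.append(ex)
    if batch:
        yield batch
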
@Waino Waino requested a review from TimotheeMickus May 6, 2024 11:06
@TimotheeMickus (Collaborator) left a comment:

LGTM, but I'm not sure about dropping the sentence-level batching + max len padding trick, which thus far is the one with the best tok/s, right? To what extent could this be added back in, given the current implementation?

mammoth/bin/train.py (resolved review thread)
@@ -35,81 +36,48 @@ def broadcast_tensors(tensors, src=0, group=None):
torch.distributed.broadcast(t, src, group=group)


-def only_ready_reduce_and_rescale_grads(named_parameters, group=None):
+def managed_reduce_and_rescale_grads(
@TimotheeMickus (Collaborator) commented:

"managed" ? I'm not sure about the clarity of the terminology here, could you elaborate on that?

@Waino (Collaborator Author) commented May 6, 2024

We should rename this, if we can come up with a more descriptive name.

"Reduce and rescale grads or dummy grads, with the caller deciding which one to do."

Maybe externally_managed_reduce_and_rescale_grads would do the trick?
The idea is that only_ready uses the metadata provided by the has_grad hook to make its own decisions about how to communicate certain parameters, but the new implementation needs to be managed by something else: the caller passes in has_local_gradient and gradient_norm.
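
To illustrate the distinction, here is a hedged sketch (the signature and body are assumptions, not the code in this PR): the externally managed variant no longer inspects per-parameter metadata, it just does the communication that the caller tells it to.

import torch

def externally_managed_reduce_and_rescale_grads(
    named_parameters, has_local_gradient, gradient_norm, group=None,
):
    # The caller, which knows the global task assignment, decides whether this
    # component was trained locally this step; this function only communicates.
    for _name, p in named_parameters:
        if not has_local_gradient or p.grad is None:
            # Dummy gradient: participate in the all_reduce with zeros so the
            # ranks that did train this component are not left waiting.
            p.grad = torch.zeros_like(p.data)
        torch.distributed.all_reduce(p.grad.data, group=group)
        p.grad.data.div_(max(gradient_norm, 1))   # rescale by number of contributors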

mammoth/distributed/communication.py (outdated, resolved)
help="Maximum number of bins for batching.",
default=4,
help="The number of minibatches that will be yielded once bucketing is complete. "
"Recommended value: same as accum_count, or at least a multiple of it."
@TimotheeMickus (Collaborator) commented:

bucket is a terrible name for that?

@Waino (Collaborator Author) commented:

Bucket is indeed a terrible name. I reused the existing parameter without renaming it, while totally changing its meaning. Classic mammoth.

The plan was to either clean this up or throw this away based on the benchmark results, but unfortunately they were somewhat inconclusive.

mammoth/inputters/dataloader.py (outdated, resolved)
choices=["sents", "tokens"],
help="Batch grouping for batch_size. Standard is sents. Tokens will do dynamic batching",
help="Batch grouping for batch_size. Standard is tokens (max of src and tgt). Sents is unimplemented.",
@TimotheeMickus (Collaborator) commented:

so no support for max len padding, which is the thing that seems to work the least bad for now?

@@ -83,55 +54,63 @@ def build_torch_optimizer(model, opts, task_queue_manager):
Returns:
A ``torch.optim.Optimizer`` instance.
@TimotheeMickus (Collaborator) commented:

feels like we're missing a class constructor signature inspection + kwargs popping approach
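
As a rough sketch of that suggestion (an assumed helper, not existing mammoth code), the constructor signature of the chosen optimizer class can be inspected and only the matching options kept, instead of maintaining per-optimizer branches:

import inspect
import torch

def build_sub_optimizer(optim_cls, params, all_opts):
    # Keep only the keyword arguments that this optimizer's constructor accepts.
    accepted = inspect.signature(optim_cls.__init__).parameters
    kwargs = {k: v for k, v in all_opts.items() if k in accepted}
    return optim_cls(params, **kwargs)

# e.g. Adam gets lr and betas, and the SGD-only momentum option is dropped:
# build_sub_optimizer(torch.optim.Adam, model.parameters(),
#                     {'lr': 1e-4, 'betas': (0.9, 0.998), 'momentum': 0.9})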

@TimotheeMickus TimotheeMickus self-requested a review May 6, 2024 11:43
@TimotheeMickus (Collaborator) left a comment:

(Apparently issues with logging the TQM on this branch. Has this been smoketested / tested? if so I'll dismiss this second review)

[2024-05-06 11:46:02,478 1083  INFO] world_size = 4, queue_size = 40
[2024-05-06 11:46:02,481 1083  INFO] in task_queue_manager: node_rank 0 local_rank 0
Traceback (most recent call last):
  File "/home/nloppi/Tiedemann_project/mammoth/mammoth/train.py", line 6, in <module>
    main()
  File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/bin/train.py", line 304, in main
    train(opts)
  File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/bin/train.py", line 245, in train
    logger.info(f'TaskQueueManager: {global_task_queue_manager}')
  File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/distributed/tasks.py", line 310, in __repr__
    kwargs = ',\n '.join(
  File "/home/nloppi/Tiedemann_project/mammoth/mammoth/mammoth/distributed/tasks.py", line 311, in <genexpr>
    f'{key}={pformat(self.__getattribute__(key))}'
AttributeError: 'TaskQueueManager' object has no attribute 'node_rank'

@Waino (Collaborator Author) commented May 6, 2024

> LGTM, but I'm not sure about dropping the sentence-level batching + max len padding trick, which thus far is the one with the best tok/s, right? To what extent could this be added back in, given the current implementation?

The sentence-level batching + max len padding trick is implemented as a special case of the spiral bucketing data loader.

If we want both dynamic batching and the sentence-level batching + max len padding trick, then the easiest way is to keep all 3 implementations. However, it would be quite confusing to have two different ways of doing dynamic batching, of which one is known to be flaky. So maybe the dynamic spiral bucketing would be there but inaccessible through the config.

A better solution would be to refactor the sentence-level batching + max len padding trick so that it doesn't need the spiral bucketing data loader.

For benchmarking, it was just easier to rip it out.

@@ -223,27 +202,17 @@ def zero_grad(self):
for name in self.optimizers:
self.optimizers[name].zero_grad()

def step(self, grad_scaler=None):
"""Step through all the suboptimizers"""
def managed_step(self, gradient_syncs, grad_scaler=None):
@Waino (Collaborator Author) commented:

Here we also use the term managed with the same meaning.
The caller must supply gradient syncs (a sequence of DistributedComponentGradientSync; a type annotation should be added), which determine whether or not to step each suboptimizer.
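
A minimal sketch of that behaviour (attribute names are assumptions, not the actual fields): step only the suboptimizers whose component received a gradient this step, as decided by the caller.

def managed_step(self, gradient_syncs, grad_scaler=None):
    # `gradient_syncs`: one record per distributed component, produced by the
    # caller, indicating whether this device computed a gradient for it.
    for sync in gradient_syncs:
        if not sync.has_local_gradient:
            continue                      # nothing to step for this component
        optimizer = self.optimizers[sync.component_name]
        if grad_scaler is not None:
            grad_scaler.step(optimizer)   # AMP-aware step
        else:
            optimizer.step()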

@Waino (Collaborator Author) commented May 6, 2024

> (Apparently issues with logging the TQM on this branch. Has this been smoketested / tested? if so I'll dismiss this second review)

In commit fbe4f5c I tried to reduce the log spam somewhat, and one of the actions was to only log a single TQM (the global one) instead of each device logging its own. It seems that a local TQM could be represented, but the __repr__ of the global one was broken.

I fixed the bug.

For some reason, setting the logging level of the logger doesn't work.
Neither --verbose nor --log_file_level seems to affect the level of the
logger.

Therefore, messages are still logged as warnings, but only shown if the
verbose flag is set.
@TimotheeMickus (Collaborator) commented:

fix linting and this should be GTG

@TimotheeMickus (Collaborator) commented:

LGTM! Will let you do the honors of merging :']

@Waino (Collaborator Author) commented May 13, 2024

Closes #8

@Waino merged commit fc51156 into main May 13, 2024
2 checks passed