FairScale integration #5242
Conversation
@jacobdanovitch you might find this interesting!
100 lines of code for distributed training. 1400 lines for checkpointing. I see why you gave it all those stars.
I'm not a big fan of those wrappers, or wrapper factories. They change the model init API in unintuitive ways, they are annoying to pass around, they mess with serialization. Is there some way we could do this from outside the model, cleanly? For example, what if we gave regexes that tell the trainer which modules to wrap? Or maybe an API where the model can optionally return a list of modules that it would like wrapped by the trainer during initialization? What scenarios would be broken if we had that approach?
I also seem to remember you saying that this wrapper factory approach mirrors the approach that FairScale took. Can you point me to some examples of that?
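For illustration, a minimal sketch of the regex-based alternative being floated here (all names are hypothetical, and `wrap_fn` stands in for whatever the accelerator would expose; this is not code from the PR):

```python
import re
from typing import Callable, List

import torch


def wrap_matching_modules(
    model: torch.nn.Module,
    patterns: List[str],
    wrap_fn: Callable[[torch.nn.Module], torch.nn.Module],
) -> None:
    """Replace every submodule whose qualified name matches one of the regexes
    with whatever ``wrap_fn`` returns for it.

    Hypothetical helper: the trainer could call this with patterns taken
    straight from the config, so the model itself never has to know about
    wrapping or sharding. Assumes the matched modules are not nested inside
    each other.
    """
    compiled = [re.compile(p) for p in patterns]
    modules = dict(model.named_modules())
    for name, module in list(modules.items()):
        if not name:  # skip the root module itself
            continue
        if any(p.fullmatch(name) for p in compiled):
            parent_name, _, child_name = name.rpartition(".")
            parent = modules[parent_name] if parent_name else model
            setattr(parent, child_name, wrap_fn(module))
```

The trainer could then accept something like a `modules_to_wrap: List[str]` argument from the config and apply this after the model is instantiated.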
CHANGELOG.md (outdated):
- The type of the `grad_norm` parameter of `GradientDescentTrainer` is now `Union[float, bool]`,
  with a default value of `False`. `False` means gradients are not rescaled and the gradient
  norm is never even calculated. `True` means the gradients are still not rescaled but the gradient
  norm is calculated and passed on to callbacks. A `float` value means gradients are rescaled.
I haven't seen the code yet, but I'm not too wild about this API. That means you have to know whether some other component needs the gradient norm or not. I'd rather provide a function called `get_grad_norm()` or something like that, which calculates it lazily.
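A hedged sketch of that lazy idea (a hypothetical helper, not what the PR ended up with):

```python
from typing import Callable, Iterable, Optional

import torch


def make_grad_norm_getter(
    parameters: Iterable[torch.nn.Parameter],
) -> Callable[[], torch.Tensor]:
    """Return a zero-argument ``get_grad_norm()`` that computes the total
    gradient norm lazily and caches it, so components that never ask for it
    never pay for the computation."""
    params = list(parameters)
    cached: Optional[torch.Tensor] = None

    def get_grad_norm() -> torch.Tensor:
        nonlocal cached
        if cached is None:
            grads = [p.grad.detach() for p in params if p.grad is not None]
            if grads:
                cached = torch.norm(torch.stack([torch.norm(g) for g in grads]))
            else:
                cached = torch.tensor(0.0)
        return cached

    return get_grad_norm
```

The trainer would build a fresh getter each batch and hand it to callbacks in place of a precomputed norm.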
allennlp/nn/parallel/ddp_wrapper.py (outdated):

return amp.GradScaler()

class DdpWrapper(Registrable):
This is not really a wrapper though. This is more like a wrapper factory. Is there any scenario where we would create this object and then wrap multiple models with it?
It's a terrible name. I've renamed it `DdpAccelerator`.
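For context, roughly the shape of the object under discussion, as I understand it (a sketch only; the actual interface lives in `allennlp/nn/parallel/` and may differ in signatures and details):

```python
from typing import Tuple

import torch
from torch.nn.parallel import DistributedDataParallel


class DdpAccelerator:
    """Sketch of the 'wrapper factory' idea: a single object, built from the
    config, that knows how to wrap a whole model for distributed training and,
    for sharded implementations, individual submodules during model
    construction."""

    def wrap_model(
        self, model: torch.nn.Module
    ) -> Tuple[torch.nn.Module, torch.nn.Module]:
        # Return (original model, wrapped model): the trainer runs forward and
        # backward through the wrapped one but saves/loads state from the
        # original. Assumes the process group is already initialized.
        return model, DistributedDataParallel(model)

    def wrap_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # No-op for vanilla DDP; a FairScale FSDP implementation would shard
        # the submodule here.
        return module
```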
"model", it gets specified as "ddp_wrapper" in the "distributed" part of the config, and is then | ||
passed in to the model separately. | ||
"model", it gets specified as "ddp_accelerator" in the "distributed" part of the config, and is then | ||
passed in to the model automatically. |
Why? Can we just pass it to the model and save ourselves one exception?
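To make the config side concrete, a hedged sketch of where this would live, written as a Python dict rather than the Jsonnet used in the models PR (the `fairscale_fsdp` type name and the `mixed_precision` key are assumptions, not verbatim from this PR):

```python
# Only the relevant fragment of a training config; "model", "trainer",
# "data_loader", etc. stay as usual.
config = {
    "distributed": {
        "cuda_devices": [0, 1],
        "ddp_accelerator": {
            "type": "fairscale_fsdp",   # assumed registered name
            "mixed_precision": True,    # assumed option
        },
    },
}
```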
Still TODO ☑️ 👇
(stars indicate relative difficulty / complexity / how much I'm dreading this item)
- Make `BeamSearch` a lazy parameter to the T5 module and model
- Add a training config in `allennlp_models/training_configs/generation/` for fine-tuning T5 on CNN/DM
- Add a `do_auto_wrap` option.
- Clean up `_MODULE_SHARDED_FLAG` and `_WRAPPED_MODULE_GETTER`. Maybe have a function in `nn.util` like `set_module_sharded` which would take a `wrapped_module_getter` function, and another `is_sharded_module` function. Or could just create a mixin base class with the needed methods (see the sketch below).
- Remove the `allennlp-models` branch patch in `.github/workflows/ci.yml`, and do the same in the corresponding `allennlp-models` PR (`requirements.txt` and `Makefile`)

For reviewers of this PR, I would suggest you start by looking at the new functionality provided in `allennlp/nn/parallel/` and `allennlp/nn/checkpoint/`. Then look at `allennlp/modules/transformer/t5.py` to see how these features are integrated into a model. The training config in the models PR is a complete example of how to specify these options in a config.
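A sketch of the mixin idea from the TODO list above (the class and function names come from that list; `get_original_module` and everything else are assumptions):

```python
import torch


class ShardedModuleMixin:
    """Mixin that a sharded wrapper (e.g. an FSDP-wrapped module) would
    inherit from, replacing the private ``_MODULE_SHARDED_FLAG`` /
    ``_WRAPPED_MODULE_GETTER`` attributes."""

    def get_original_module(self) -> torch.nn.Module:
        """Return the underlying, unwrapped module."""
        raise NotImplementedError


def is_sharded_module(module: torch.nn.Module) -> bool:
    # Replaces checking for a private flag attribute on the module.
    return isinstance(module, ShardedModuleMixin)
```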