Description
Outline & Motivation
The Fabric strategy interface has two hooks that customize instantiation: one for modules and one for general tensors.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/strategies/strategy.py#L121-L144
These correspond to fabric.init_module and fabric.init_tensor.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/fabric.py#L667-L704
The Trainer only has one, because modules need to be initialized in configure_model: the processes haven't been launched yet before that hook runs.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/strategies/strategy.py#L494-L506
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/trainer/call.py#L105-L109
The tensor init contexts currently call precision.init_context. This has been fine so far, but two recent plugins, TransformerEnginePrecision and BitsandbytesPrecision, do both dtype management and class replacement in this hook.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/transformer_engine.py#L96-L110
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/bitsandbytes.py#L118-L128
This means that a snippet like

```python
fabric = Fabric(precision="transformer-engine")
with fabric.init_tensor():
    l = torch.nn.Linear(2, 2)
```

will replace the class of l when it should not.
Pitch
Add Precision.tensor_init_context and Precision.module_init_context, matching the API in the strategy so that the plugin can separate this logic.
For the Trainer specifically, since there is no Strategy.module_init_context, this line https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/bitsandbytes.py#L118-L128 will have to call PrecisionPlugin.module_init_context instead.
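A minimal sketch of the pitched split, assuming the hook names from this issue; the bodies only record strings for illustration, whereas the real plugins would set default dtypes and patch module classes:

```python
from contextlib import contextmanager

class Precision:
    """Hypothetical sketch of the pitched API, not Lightning's implementation."""

    def __init__(self):
        self.calls = []

    @contextmanager
    def tensor_init_context(self):
        # dtype management only: safe to use for plain tensor creation
        self.calls.append("dtype set")
        yield
        self.calls.append("dtype restored")

    @contextmanager
    def module_init_context(self):
        # class replacement layered on top of the dtype management,
        # so it only applies while modules are being instantiated
        with self.tensor_init_context():
            self.calls.append("classes patched")
            yield
            self.calls.append("classes unpatched")
```

With this split, fabric.init_tensor can delegate to tensor_init_context and never trigger class replacement, while fabric.init_module (and the Trainer's configure_model path) uses module_init_context and gets both behaviors.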
Additional context
No response