Split Precision.init_context into Precision.tensor_init_context and Precision.module_init_context  #18703

@carmocca

Outline & Motivation

The Fabric strategy interface has two hooks that customize instantiation: one for modules and one for general tensors.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/strategies/strategy.py#L121-L144
These correspond to fabric.init_module and fabric.init_tensor.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/fabric.py#L667-L704

The Trainer strategy has only one such hook, because modules must be initialized inside configure_model: the processes have not been launched before that point.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/strategies/strategy.py#L494-L506
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/trainer/call.py#L105-L109

The tensor init contexts currently call precision.init_context. This has been fine so far, but two recent plugins, TransformerEnginePrecision and BitsandbytesPrecision, do both dtype management and class replacement in this hook.
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/transformer_engine.py#L96-L110
https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/bitsandbytes.py#L118-L128

This means that a snippet like

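import torch
from lightning.fabric import Fabric

fabric = Fabric(precision="transformer-engine")
with fabric.init_tensor():
    l = torch.nn.Linear(16, 16)  # layer sizes are arbitrary, for illustration only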

will replace the class of l when it should not: init_tensor is the hook for general tensors, not for module replacement.
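
The problem can be reproduced in miniature without Lightning or PyTorch. The sketch below is a toy model, not the real implementation: `Linear`, `ReplacementLinear`, `layer_registry`, `SingleHookPrecision` and `StrategySketch` are all hypothetical stand-ins, and routing both strategy hooks through the single precision hook is an assumption made for illustration.

```python
from contextlib import contextmanager


class Linear:
    """Stand-in for a regular layer class (hypothetical, not torch.nn.Linear)."""


class ReplacementLinear(Linear):
    """Stand-in for the drop-in layer a precision plugin swaps in."""


# Toy registry that plays the role of "which class does Linear resolve to".
layer_registry = {"Linear": Linear}


class SingleHookPrecision:
    """Toy plugin with one hook that also performs class replacement."""

    @contextmanager
    def init_context(self):
        original = layer_registry["Linear"]
        layer_registry["Linear"] = ReplacementLinear
        try:
            yield
        finally:
            layer_registry["Linear"] = original


class StrategySketch:
    """Toy strategy: both hooks end up in the same precision hook (assumption)."""

    def __init__(self, precision):
        self.precision = precision

    def tensor_init_context(self):
        return self.precision.init_context()

    def module_init_context(self):
        return self.tensor_init_context()


strategy = StrategySketch(SingleHookPrecision())
with strategy.tensor_init_context():
    layer = layer_registry["Linear"]()

# The tensor-only context still swapped the class.
print(type(layer).__name__)  # ReplacementLinear
```

Because the plugin has a single hook, the strategy has no way to ask for dtype handling without also triggering the class swap.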

Pitch

Add Precision.tensor_init_context and Precision.module_init_context, matching the API in the strategy so that the plugin can separate this logic.
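
A minimal sketch of what the split could look like, again as a toy model rather than the actual Lightning code. Only the hook names `tensor_init_context` and `module_init_context` come from this proposal; `state`, `override`, `ReplacingPrecision` and the stand-in layer classes are hypothetical, and having `module_init_context` also apply the dtype handling is an assumption.

```python
from contextlib import contextmanager, nullcontext


class Linear:
    """Stand-in for a regular layer class (hypothetical, not torch.nn.Linear)."""


class ReplacementLinear(Linear):
    """Stand-in for the drop-in layer a precision plugin swaps in."""


# Toy global state that the hooks temporarily override.
state = {"default_dtype": "float32", "linear_cls": Linear}


@contextmanager
def override(key, value):
    """Temporarily set state[key] to value and restore it on exit."""
    previous = state[key]
    state[key] = value
    try:
        yield
    finally:
        state[key] = previous


class Precision:
    """Base plugin sketch: both proposed hooks default to a no-op."""

    def tensor_init_context(self):
        return nullcontext()

    def module_init_context(self):
        return nullcontext()


class ReplacingPrecision(Precision):
    """Toy plugin that separates dtype management from class replacement."""

    def tensor_init_context(self):
        # dtype management only
        return override("default_dtype", "bfloat16")

    @contextmanager
    def module_init_context(self):
        # dtype management plus class replacement (assumed composition)
        with self.tensor_init_context(), override("linear_cls", ReplacementLinear):
            yield


plugin = ReplacingPrecision()
with plugin.tensor_init_context():
    tensor_layer = state["linear_cls"]()  # class is left alone
with plugin.module_init_context():
    module_layer = state["linear_cls"]()  # class is replaced
```

With this shape, a strategy's tensor hook can delegate to `tensor_init_context` and its module hook to `module_init_context`, so the tensor-only path no longer swaps classes.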

For the Trainer specifically, since there's no Strategy.module_init_context, this line https://github.com/Lightning-AI/lightning/blob/master/src/lightning/fabric/plugins/precision/bitsandbytes.py#L118-L128 will have to call PrecisionPlugin.module_init_context.
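
One way the Trainer side might be wired, sketched with toy classes. Everything here is hypothetical (`PrecisionSketch`, `TrainerStrategySketch`, `call_configure_model`, the `events` log), and the exact call site in the Trainer is an assumption; the point is only that the code wrapping configure_model enters the precision plugin's module hook directly, since the strategy exposes no module hook of its own.

```python
from contextlib import contextmanager

events = []  # records which hooks were entered, in order


class PrecisionSketch:
    """Toy precision plugin exposing the two proposed hooks."""

    @contextmanager
    def tensor_init_context(self):
        events.append("precision.tensor_init_context")
        yield

    @contextmanager
    def module_init_context(self):
        events.append("precision.module_init_context")
        yield


class TrainerStrategySketch:
    """Toy Trainer strategy: it only has a tensor init hook."""

    def __init__(self, precision_plugin):
        self.precision_plugin = precision_plugin

    def tensor_init_context(self):
        return self.precision_plugin.tensor_init_context()


def call_configure_model(strategy, configure_model):
    """Hypothetical call site that wraps configure_model in both contexts."""
    with strategy.tensor_init_context(), strategy.precision_plugin.module_init_context():
        return configure_model()


strategy = TrainerStrategySketch(PrecisionSketch())
model = call_configure_model(strategy, lambda: "model")
```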

cc @justusschock @awaelchli @carmocca

Labels: fabric, pl, plugin, refactor
