Merged
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))


- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
- Added `lightning.fabric.plugins.Precision.module_init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))


- `lightning.fabric.strategies.Strategy.tensor_init_context()` context manager to instantiate tensors efficiently directly on device and dtype ([#17607](https://github.com/Lightning-AI/lightning/pull/17607))
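To show the renamed hook from the user's side, here is a rough usage sketch (the layer size is arbitrary; a CUDA machine and a recent PyTorch are assumed so tensors are created directly on the target device):

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, precision="16-true")
fabric.launch()

# Fabric.init_module() enters Strategy.module_init_context(), which in turn uses
# Precision.module_init_context(): parameters are created directly on the target
# device and in the precision plugin's dtype.
with fabric.init_module():
    model = torch.nn.Linear(4096, 4096)

assert model.weight.dtype == torch.float16
```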
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -116,7 +116,10 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
m.compute_type_is_set = False
return module

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self.dtype)

def module_init_context(self) -> ContextManager:
if self.ignore_modules:
# cannot patch the Linear class if the user wants to skip some submodules
raise RuntimeError(
@@ -125,7 +128,7 @@ def init_context(self) -> ContextManager:
" may initialize the layers on-device, defeating the purpose of quantization. You can remove"
" `ignore_modules` or remove the `init_module` context manager."
)
dtype_ctx = _DtypeContextManager(self.dtype)
dtype_ctx = self.tensor_init_context()
# TODO: this could also support replacing `Embedding` and `Conv1D`
context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls})
stack = ExitStack()
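For reference, the class-replacement trick used in `module_init_context` can be pictured with the simplified stand-in below. It only illustrates the idea behind `_ClassReplacementContextManager`, not Lightning's actual implementation:

```python
import contextlib
import torch


@contextlib.contextmanager
def replace_linear(quantized_linear_cls):
    """Temporarily make `torch.nn.Linear` resolve to a quantized Linear class."""
    original = torch.nn.Linear
    torch.nn.Linear = quantized_linear_cls  # layers built inside the block use the replacement
    try:
        yield
    finally:
        torch.nn.Linear = original  # always restore the real class, even on error
```

This also explains the `RuntimeError` above: the patch is global, so there is no way to skip the submodules listed in `ignore_modules` while the user instantiates the model.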
5 changes: 4 additions & 1 deletion src/lightning/fabric/plugins/precision/deepspeed.py
@@ -66,11 +66,14 @@ def convert_module(self, module: Module) -> Module:
return module.to(dtype=self._desired_dtype)
return module

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
if "true" not in self.precision:
return nullcontext()
return _DtypeContextManager(self._desired_dtype)

def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/double.py
@@ -30,11 +30,14 @@ class DoublePrecision(Precision):
def convert_module(self, module: Module) -> Module:
return module.double()

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(torch.double)

def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

def forward_context(self) -> ContextManager:
return _DtypeContextManager(torch.double)
return self.tensor_init_context()

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -103,13 +103,16 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
buffer_dtype=buffer_dtype,
)

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)

def module_init_context(self) -> ContextManager:
return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

def forward_context(self) -> ContextManager:
if "mixed" in self.precision:
return torch.autocast("cuda", dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16))
return _DtypeContextManager(self._desired_input_dtype)
return self.tensor_init_context()

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
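The split matters most for this plugin: tensors follow the input dtype, while parameters follow the FSDP parameter dtype from the mixed-precision config (falling back to float32). A small sketch, assuming a PyTorch build with FSDP available and the `FSDPPrecision` import path shown in this diff:

```python
import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

plugin = FSDPPrecision("16-mixed")

with plugin.tensor_init_context():
    x = torch.zeros(4)  # created in the dtype expected for inputs

with plugin.module_init_context():
    layer = torch.nn.Linear(4, 4)  # params follow mixed_precision_config.param_dtype or float32
```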
7 changes: 5 additions & 2 deletions src/lightning/fabric/plugins/precision/half.py
@@ -39,11 +39,14 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No
def convert_module(self, module: Module) -> Module:
return module.to(dtype=self._desired_input_dtype)

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)

def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

def forward_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)
return self.tensor_init_context()

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)
6 changes: 5 additions & 1 deletion src/lightning/fabric/plugins/precision/precision.py
@@ -53,7 +53,11 @@ def convert_module(self, module: Module) -> Module:
"""
return module

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
"""Controls how tensors get created (device, dtype)."""
return nullcontext()

def module_init_context(self) -> ContextManager:
"""Instantiate module parameters or tensors in the precision type this plugin handles.

This is optional and depends on the precision limitations during optimization.
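With these two hooks on the base class, a custom plugin only needs to return context managers. A minimal sketch of a hypothetical subclass (the class name and dtype choice are illustrative):

```python
import torch
from typing import ContextManager

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _DtypeContextManager


class BFloat16InitPrecision(Precision):
    """Create both plain tensors and module parameters in bfloat16."""

    def tensor_init_context(self) -> ContextManager:
        return _DtypeContextManager(torch.bfloat16)

    def module_init_context(self) -> ContextManager:
        # same behaviour for parameters as for plain tensors
        return self.tensor_init_context()
```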
src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -93,8 +93,11 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
module = module.to(dtype=self.dtype)
return module

def init_context(self) -> ContextManager:
dtype_ctx = _DtypeContextManager(self.dtype)
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self.dtype)

def module_init_context(self) -> ContextManager:
dtype_ctx = self.tensor_init_context()
stack = ExitStack()
if self.replace_layers:
import transformer_engine.pytorch as te
5 changes: 3 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
@@ -329,16 +329,17 @@ def module_to_device(self, module: Module) -> None:
pass

def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
precision_init_ctx = self.precision.init_context()
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
empty_ctx = _EmptyInit(enabled=bool(empty_init))
stack = ExitStack()
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
stack.enter_context(torch.device("meta"))
elif _TORCH_GREATER_EQUAL_1_13:
stack.enter_context(_EmptyInit(enabled=bool(empty_init)))
stack.enter_context(empty_ctx)
stack.enter_context(precision_init_ctx)
stack.enter_context(module_sharded_ctx)
return stack
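From a user's perspective, the stack above is what makes meta-device initialization work with FSDP. A rough sketch, assuming PyTorch >= 2.1 and two GPUs; `MyLargeModel` is a placeholder:

```python
from lightning.fabric import Fabric

fabric = Fabric(strategy="fsdp", devices=2)
fabric.launch()

with fabric.init_module(empty_init=True):
    model = MyLargeModel()  # parameters land on the meta device, no real memory used yet

model = fabric.setup(model)  # FSDP materializes, resets and shards the parameters here
```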
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/strategy.py
@@ -120,7 +120,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:

def tensor_init_context(self) -> ContextManager:
"""Controls how tensors get created (device, dtype)."""
precision_init_ctx = self.precision.init_context()
precision_init_ctx = self.precision.tensor_init_context()
stack = ExitStack()
if _TORCH_GREATER_EQUAL_2_0:
stack.enter_context(self.root_device)
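A quick sketch of this hook in use, assuming `Fabric.init_tensor()` is available as the public wrapper around it (and a CUDA machine):

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=1, precision="16-true")
fabric.launch()

# Tensors created inside the context land on the strategy's device and in the
# precision plugin's dtype, with no extra .to() copies afterwards.
with fabric.init_tensor():
    x = torch.randn(8, 8)
```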
2 changes: 1 addition & 1 deletion src/lightning/fabric/strategies/xla_fsdp.py
@@ -194,7 +194,7 @@ def module_to_device(self, module: Module) -> None:
pass

def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
precision_init_ctx = self.precision.init_context()
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
stack = ExitStack()
if _TORCH_GREATER_EQUAL_1_13:
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -83,7 +83,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for meta-device initialization with `Trainer.init_module(empty_init=True)` in FSDP ([#18385](https://github.com/Lightning-AI/lightning/pull/18385))


- Added `lightning.pytorch.plugins.PrecisionPlugin.init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))
- Added `lightning.pytorch.plugins.PrecisionPlugin.module_init_context()` and `lightning.pytorch.strategies.Strategy.tensor_init_context()` context managers to control model and tensor instantiation ([#18004](https://github.com/Lightning-AI/lightning/pull/18004))


- Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
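The Trainer-side counterpart looks similar; a hedged sketch (`MyLightningModule` is a placeholder, and a CUDA machine is assumed):

```python
import lightning.pytorch as pl

trainer = pl.Trainer(accelerator="cuda", devices=1, precision="16-true")

# Weights are created on the target device and in half precision right away.
with trainer.init_module():
    model = MyLightningModule()

trainer.fit(model)
```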
23 changes: 9 additions & 14 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Union
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Optional, Union

import torch
from lightning_utilities import apply_to_collection
@@ -23,7 +23,7 @@

import lightning.pytorch as pl
from lightning.fabric.plugins.precision.deepspeed import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.fabric.utilities.types import Steppable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -77,18 +77,13 @@ def convert_module(self, module: Module) -> Module:
def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_dtype)

@contextmanager
def init_context(self) -> Generator[None, None, None]:
def tensor_init_context(self) -> ContextManager:
if "true" not in self.precision:
yield
return

default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_dtype)
try:
yield
finally:
torch.set_default_dtype(default_dtype)
return nullcontext()
return _DtypeContextManager(self._desired_dtype)

def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

def backward( # type: ignore[override]
self,
26 changes: 7 additions & 19 deletions src/lightning/pytorch/plugins/precision/double.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, Literal
from typing import Any, ContextManager, Generator, Literal

import torch
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities.rank_zero import rank_zero_deprecation
@@ -34,19 +34,11 @@ class DoublePrecisionPlugin(PrecisionPlugin):
def convert_module(self, module: nn.Module) -> nn.Module:
return module.double()

@contextmanager
def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.

See: :func:`torch.set_default_dtype`
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(torch.float64)

"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
yield
finally:
torch.set_default_dtype(default_dtype)
def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
Expand All @@ -55,12 +47,8 @@ def forward_context(self) -> Generator[None, None, None]:
See: :func:`torch.set_default_dtype`

"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
with self.tensor_init_context():
yield
finally:
torch.set_default_dtype(default_dtype)

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.double)
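The try/finally blocks removed above are exactly what `_DtypeContextManager` encapsulates; conceptually it behaves like this simplified stand-in (not the actual Lightning implementation):

```python
import torch


class SimpleDtypeContext:
    """Set a default dtype on enter, restore the previous one on exit."""

    def __init__(self, dtype: torch.dtype) -> None:
        self._new_dtype = dtype

    def __enter__(self) -> None:
        self._previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(self._new_dtype)

    def __exit__(self, *exc) -> None:
        # restore even if an exception was raised inside the block
        torch.set_default_dtype(self._previous_dtype)
```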
5 changes: 4 additions & 1 deletion src/lightning/pytorch/plugins/precision/fsdp.py
@@ -112,7 +112,10 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
buffer_dtype=buffer_dtype,
)

def init_context(self) -> ContextManager:
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)

def module_init_context(self) -> ContextManager:
return _DtypeContextManager(self.mixed_precision_config.param_dtype or torch.float32)

def forward_context(self) -> ContextManager:
20 changes: 6 additions & 14 deletions src/lightning/pytorch/plugins/precision/half.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, Literal
from typing import Any, ContextManager, Generator, Literal

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module

from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin


@@ -40,19 +40,11 @@ def __init__(self, precision: Literal["bf16-true", "16-true"] = "16-true") -> No
def convert_module(self, module: Module) -> Module:
return module.to(dtype=self._desired_input_dtype)

@contextmanager
def init_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type when initializing module parameters or tensors.

See: :func:`torch.set_default_dtype`
def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self._desired_input_dtype)

"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self._desired_input_dtype)
try:
yield
finally:
torch.set_default_dtype(default_dtype)
def module_init_context(self) -> ContextManager:
return self.tensor_init_context()

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -360,7 +360,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
empty_init_context = _EmptyInit(enabled=bool(empty_init))
else:
empty_init_context = nullcontext()
with empty_init_context, self.precision_plugin.init_context():
with empty_init_context, self.precision_plugin.tensor_init_context():
yield

@contextmanager
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/strategy.py
@@ -502,7 +502,7 @@ def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[No
"""
device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext()
empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
with empty_init_context, device_context, self.precision_plugin.init_context():
with empty_init_context, device_context, self.precision_plugin.tensor_init_context():
yield

@contextmanager
2 changes: 1 addition & 1 deletion src/lightning/pytorch/trainer/call.py
@@ -105,7 +105,7 @@ def _call_configure_model(trainer: "pl.Trainer") -> None:
# we don't normally check for this before calling the hook. it is done here to avoid instantiating the context
# managers
if is_overridden("configure_model", trainer.lightning_module):
with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context():
with trainer.strategy.tensor_init_context(), trainer.strategy.model_sharded_context(), trainer.precision_plugin.module_init_context(): # noqa: E501
_call_lightning_module_hook(trainer, "configure_model")


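In practice this means parameters created inside `configure_model` are instantiated under the device, sharding and precision contexts at once. A short sketch of a LightningModule using the hook (the layer sizes are arbitrary):

```python
import torch
import lightning.pytorch as pl


class BigTransformer(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.transformer = None

    def configure_model(self) -> None:
        if self.transformer is not None:  # the hook should stay idempotent
            return
        # Created under tensor_init_context(), model_sharded_context() and the
        # precision plugin's module_init_context() entered by the Trainer above.
        layer = torch.nn.TransformerEncoderLayer(d_model=1024, nhead=16)
        self.transformer = torch.nn.TransformerEncoder(layer, num_layers=12)
```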
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_all.py
@@ -17,9 +17,9 @@ def test_default_dtype_is_restored(precision):
precision = FSDPPrecision("16-true")

contexts = (
(precision.init_context, precision.forward_context)
(precision.module_init_context, precision.forward_context)
if not isinstance(precision, DeepSpeedPrecision)
else (precision.init_context,)
else (precision.module_init_context,)
)
for context in contexts:
assert torch.get_default_dtype() is torch.float32
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -46,7 +46,7 @@ def __init__(self, in_features, out_features, bias=True, *_, **__):

# same logic as in `test_default_dtype_is_restored`
assert torch.get_default_dtype() is torch.float32
with pytest.raises(RuntimeError, match="foo"), precision.init_context():
with pytest.raises(RuntimeError, match="foo"), precision.module_init_context():
assert torch.get_default_dtype() is not torch.float32
raise RuntimeError("foo")
assert torch.get_default_dtype() is torch.float32
@@ -65,7 +65,7 @@ def __init__(self):
_NF4Linear = vars(module)["_NF4Linear"]
_NF4Linear._quantize_weight = Mock()

with precision.init_context():
with precision.module_init_context():
assert torch.get_default_dtype() == torch.float16
model = MyModule()
assert isinstance(model.l1, _NF4Linear)
4 changes: 2 additions & 2 deletions tests/tests_fabric/plugins/precision/test_deepspeed.py
@@ -79,9 +79,9 @@ def test_selected_dtype(precision, expected_dtype):
("16-true", torch.float16),
],
)
def test_init_context(precision, expected_dtype):
def test_module_init_context(precision, expected_dtype):
plugin = DeepSpeedPrecision(precision=precision)
with plugin.init_context():
with plugin.module_init_context():
model = torch.nn.Linear(2, 2)
assert torch.get_default_dtype() == expected_dtype
assert model.weight.dtype == expected_dtype