From adef5d45194ea2f666ef19f3b9601c5dffa7ca36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 4 Oct 2023 17:52:45 +0200 Subject: [PATCH] Create context managers before entering any with ExitStack --- .../fabric/plugins/precision/bitsandbytes.py | 4 ++-- .../fabric/plugins/precision/transformer_engine.py | 13 +++++++------ src/lightning/fabric/strategies/deepspeed.py | 3 ++- src/lightning/fabric/strategies/fsdp.py | 6 ++++-- src/lightning/fabric/strategies/strategy.py | 6 ++++-- src/lightning/fabric/strategies/xla_fsdp.py | 6 ++++-- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 490dee8b205d4..5e3ef06dfb6f2 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -120,10 +120,10 @@ def init_context(self) -> ContextManager: if self.ignore_modules: # cannot patch the Linear class if the user wants to skip some submodules return dtype_ctx - stack = ExitStack() - stack.enter_context(dtype_ctx) # TODO: this could also support replacing `Embedding` and `Conv1D` context_manager = _ClassReplacementContextManager({"torch.nn.Linear": self._linear_cls}) + stack = ExitStack() + stack.enter_context(dtype_ctx) stack.enter_context(context_manager) return stack diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index 9fdf75d05c53d..a9486de0b92c4 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -94,9 +94,8 @@ def convert_module(self, module: torch.nn.Module) -> torch.nn.Module: return module def init_context(self) -> ContextManager: + dtype_ctx = _DtypeContextManager(self.dtype) stack = ExitStack() - stack.enter_context(_DtypeContextManager(self.dtype)) - if self.replace_layers: import transformer_engine.pytorch as te @@ -107,15 +106,17 @@ def init_context(self) -> ContextManager: } ) stack.enter_context(context_manager) + stack.enter_context(dtype_ctx) return stack def forward_context(self) -> ContextManager: - stack = ExitStack() - stack.enter_context(_DtypeContextManager(self.dtype)) - + dtype_ctx = _DtypeContextManager(self.dtype) import transformer_engine.pytorch as te - stack.enter_context(te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)) + autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe) + stack = ExitStack() + stack.enter_context(dtype_ctx) + stack.enter_context(autocast_ctx) return stack def convert_input(self, data: Any) -> Any: diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index a7bd9a9655dcd..43d40247aadb0 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -350,10 +350,11 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag raise NotImplementedError( f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled." ) + module_sharded_ctx = self.module_sharded_context() stack = ExitStack() if not self.zero_stage_3: stack.enter_context(super().module_init_context(empty_init=empty_init)) - stack.enter_context(self.module_sharded_context()) + stack.enter_context(module_sharded_ctx) return stack def module_sharded_context(self) -> ContextManager: diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index fc076784c5630..d51f6f0029c74 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -337,6 +337,8 @@ def module_to_device(self, module: Module) -> None: pass def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + precision_init_ctx = self.precision.init_context() + module_sharded_ctx = self.module_sharded_context() stack = ExitStack() if _TORCH_GREATER_EQUAL_2_1 and empty_init: # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: @@ -345,8 +347,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag stack.enter_context(torch.device("meta")) elif _TORCH_GREATER_EQUAL_1_13: stack.enter_context(_EmptyInit(enabled=bool(empty_init))) - stack.enter_context(self.precision.init_context()) - stack.enter_context(self.module_sharded_context()) + stack.enter_context(precision_init_ctx) + stack.enter_context(module_sharded_ctx) return stack def module_sharded_context(self) -> ContextManager: diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index a9401ddeafc7e..9747835690321 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -120,10 +120,11 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader: def tensor_init_context(self) -> ContextManager: """Controls how tensors get created (device, dtype).""" + precision_init_ctx = self.precision.init_context() stack = ExitStack() if _TORCH_GREATER_EQUAL_2_0: stack.enter_context(self.root_device) - stack.enter_context(self.precision.init_context()) + stack.enter_context(precision_init_ctx) return stack def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: @@ -137,10 +138,11 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag If ``None``, the strategy will decide. Some strategies may not support all options. """ + tensor_init_ctx = self.tensor_init_context() stack = ExitStack() if _TORCH_GREATER_EQUAL_1_13: stack.enter_context(_EmptyInit(enabled=bool(empty_init))) - stack.enter_context(self.tensor_init_context()) + stack.enter_context(tensor_init_ctx) return stack def setup_module_and_optimizers( diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index e9d1e005e4a92..30ebd5c1b5c5a 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -194,11 +194,13 @@ def module_to_device(self, module: Module) -> None: pass def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: + precision_init_ctx = self.precision.init_context() + module_sharded_ctx = self.module_sharded_context() stack = ExitStack() if _TORCH_GREATER_EQUAL_1_13: stack.enter_context(_EmptyInit(enabled=bool(empty_init))) - stack.enter_context(self.precision.init_context()) - stack.enter_context(self.module_sharded_context()) + stack.enter_context(precision_init_ctx) + stack.enter_context(module_sharded_ctx) return stack def module_sharded_context(self) -> ContextManager: