diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 3481986f2102f..cb3ffabecd41b 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -316,7 +316,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: else: model_parallel_context = super().model_sharded_context() - with model_parallel_context: + with torch.cuda.amp.autocast(), model_parallel_context: yield def _set_deepspeed_activation_checkpointing(self): diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 85d069b90288d..a7b1d0b96a33f 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -24,10 +24,10 @@ class ModelParallelBoringModel(BoringModel): def __init__(self): super().__init__() - self.linear = None + self.layer = None def configure_sharded_model(self) -> None: - self.linear = torch.nn.Linear(32, 2) + self.layer = torch.nn.Linear(32, 2) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_sharded_model()