diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 58baf7d2a63cd..2c58a555a8d6f 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue where `self.log`-ing a tensor would create a user warning from PyTorch about cloning tensors ([#14599](https://github.com/Lightning-AI/lightning/pull/14599))
 
+- Break HPU Graphs into two parts (forward + backward as one and optimizer as another) for better performance ([#14656](https://github.com/Lightning-AI/lightning/pull/14656))
+
+
 - Fixed compatibility when `torch.distributed` is not available ([#14454](https://github.com/Lightning-AI/lightning/pull/14454))
 
diff --git a/src/pytorch_lightning/strategies/hpu_parallel.py b/src/pytorch_lightning/strategies/hpu_parallel.py
index 96c66224ed72b..fdca6813c44f3 100644
--- a/src/pytorch_lightning/strategies/hpu_parallel.py
+++ b/src/pytorch_lightning/strategies/hpu_parallel.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch.distributed
+from torch.nn import Module
+from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
@@ -137,10 +139,22 @@ def broadcast(self, obj: object, src: int = 0) -> object:  # type: ignore
         broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
-    def training_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
-        # Break lazy accumulation of graph after every step
+    def on_after_backward(self) -> None:
+        # Break lazy accumulation of graph after fwd+bwd
         htcore.mark_step()
-        return step_output
+
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
+        # Break lazy accumulation of graph after optimizer
+        htcore.mark_step()
+        return optimizer_output
 
     def validation_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
         # Break lazy accumulation of graph after every step
diff --git a/src/pytorch_lightning/strategies/single_hpu.py b/src/pytorch_lightning/strategies/single_hpu.py
index 1e91150cded22..5d6ead0358744 100644
--- a/src/pytorch_lightning/strategies/single_hpu.py
+++ b/src/pytorch_lightning/strategies/single_hpu.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union
+
+from torch.nn import Module
+from torch.optim.optimizer import Optimizer
 
 import pytorch_lightning as pl
 from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
@@ -79,10 +82,22 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
     def model_to_device(self) -> None:
         self.model.to(self.root_device)  # type: ignore
 
-    def training_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
-        # Break lazy accumulation of graph after every step
+    def on_after_backward(self) -> None:
+        # Break lazy accumulation of graph after fwd+bwd
         htcore.mark_step()
-        return step_output
+
+    def optimizer_step(
+        self,
+        optimizer: Optimizer,
+        opt_idx: int,
+        closure: Callable[[], Any],
+        model: Optional[Union["pl.LightningModule", Module]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
+        # Break lazy accumulation of graph after optimizer
+        htcore.mark_step()
+        return optimizer_output
 
     def validation_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
         # Break lazy accumulation of graph after every step
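
For context, the sketch below shows the plain-PyTorch HPU lazy-mode training loop that this change maps onto: calling htcore.mark_step() once after loss.backward() and once after optimizer.step() splits the lazily accumulated graph into a forward+backward segment and an optimizer segment, which is what the new on_after_backward hook and optimizer_step override do inside the strategies. This is only an illustrative sketch, assuming a machine where Habana's habana_frameworks package and an "hpu" device are available; the model, loss, optimizer, and data_loader below are placeholders and are not part of this PR.

import torch
import habana_frameworks.torch.core as htcore  # Habana lazy-mode API (assumed installed)

device = torch.device("hpu")  # lazy mode accumulates ops until mark_step() is called
model = torch.nn.Linear(32, 4).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

# Placeholder data: four batches of random inputs/targets.
data_loader = [(torch.randn(8, 32), torch.randint(0, 4, (8,))) for _ in range(4)]

for inputs, targets in data_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    htcore.mark_step()  # graph 1: forward + backward (mirrors on_after_backward above)
    optimizer.step()
    htcore.mark_step()  # graph 2: optimizer update (mirrors the optimizer_step override)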