Skip to content

Commit

Permalink
Fixes incorrect strategy init with HPUAccelerator (#19615)
Browse files Browse the repository at this point in the history
  • Loading branch information
ankitgola005 committed Mar 14, 2024
1 parent 97a95ed commit 1439da4
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,21 +408,25 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
return LightningEnvironment()

def _choose_strategy(self) -> Union[Strategy, str]:
if self._accelerator_flag == "hpu":
if not _habana_available_and_importable():
raise ImportError(
"You have asked for HPU but you miss install related integration."
" Please run `pip install lightning-habana` or see for further instructions"
" in https://github.com/Lightning-AI/lightning-Habana/."
)
if self._parallel_devices and len(self._parallel_devices) > 1:
from lightning_habana import HPUParallelStrategy
if _habana_available_and_importable():
from lightning_habana import HPUAccelerator

return HPUParallelStrategy.strategy_name
if self._accelerator_flag == "hpu" or isinstance(self._accelerator_flag, HPUAccelerator):
if self._parallel_devices and len(self._parallel_devices) > 1:
from lightning_habana import HPUParallelStrategy

from lightning_habana import SingleHPUStrategy
return HPUParallelStrategy.strategy_name

from lightning_habana import SingleHPUStrategy

return SingleHPUStrategy(device=torch.device("hpu"))
if self._accelerator_flag == "hpu" and not _habana_available_and_importable():
raise ImportError(
"You asked to run with HPU but you are missing a required dependency."
" Please run `pip install lightning-habana` or seek further instructions"
" in https://github.com/Lightning-AI/lightning-Habana/."
)

return SingleHPUStrategy(device=torch.device("hpu"))
if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator):
if self._parallel_devices and len(self._parallel_devices) > 1:
return XLAStrategy.strategy_name
Expand Down

0 comments on commit 1439da4

Please sign in to comment.