diff --git a/pyproject.toml b/pyproject.toml
index 05eba62c50402..15f0293bb1c8a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,6 @@ module = [
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
-    "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.trainer.callback_hook",
     "pytorch_lightning.trainer.connectors.callback_connector",
     "pytorch_lightning.trainer.connectors.data_connector",
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index 4d20e784e0d29..62bb1c308480b 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
 
 import torch
 from torch import Tensor
@@ -29,15 +29,17 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.xla import _XLALauncher
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -58,7 +60,7 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
     def __init__(
         self,
         accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
-        parallel_devices: Optional[List[int]] = None,
+        parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         debug: bool = False,
@@ -72,6 +74,7 @@ def __init__(
             precision_plugin=precision_plugin,
             start_method="fork",
         )
+        self._checkpoint_io: Optional[CheckpointIO]
         self.debug = debug
         self._launched = False
 
@@ -95,17 +98,16 @@ def root_device(self) -> torch.device:
         return xm.xla_device()
 
     @staticmethod
-    def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
-        if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
-
-        for dataloader in dataloaders:
+    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
+        def check_has_len(dataloader: DataLoader) -> None:
             if not has_len(dataloader):
                 raise MisconfigurationException(
                     "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                     " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                 )
 
+        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)
+
     @staticmethod
     def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         """Validate and fail fast if the dataloaders were passed directly to fit."""
@@ -118,24 +120,29 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         )
         for source in sources:
             if not source.is_module():
+                assert source.instance is not None
+                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
                 TPUSpawnStrategy._validate_dataloader(source.instance)
 
-    def connect(self, model: "pl.LightningModule") -> None:
+    def connect(self, model: "pl.LightningModule") -> None:  # type: ignore
         TPUSpawnStrategy._validate_patched_dataloaders(model)
         self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model))
         return super().connect(model)
 
-    def _configure_launcher(self):
+    def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.accelerator
         self.accelerator.setup(trainer)
 
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = "1"
 
+        assert self.model
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
+        assert isinstance(self.model.module, Module)
         set_shared_parameters(self.model.module, shared_params)
         self.setup_precision_plugin()
 
@@ -143,7 +150,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
             self.setup_optimizers(trainer)
             optimizers_to_device(self.optimizers, self.root_device)
 
-    def _setup_model(self, model: Module) -> Module:
+    def _setup_model(self, model: Module) -> Module:  # type: ignore
         return model
 
     @property
@@ -168,11 +175,11 @@ def configure_ddp(self) -> None:
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
-    def barrier(self, name: Optional[str] = None) -> None:
+    def barrier(self, name: Optional[str] = None, *args: Any, **kwargs: Any) -> None:
         if self.is_distributed:
             rendezvous(name)
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not self.is_distributed:
             return obj
         buffer = io.BytesIO()
@@ -184,7 +191,9 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(
+        self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    ) -> Tensor:
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)
 
@@ -203,20 +212,23 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
         return output
 
-    def _worker_setup(self, process_idx: int):
+    def _worker_setup(self, process_idx: int) -> None:
         self._launched = True
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.val_step_context():
             return self.model(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
+        assert self.model is not None
         with self.precision_plugin.test_step_context():
             return self.model(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.predict_step_context():
             return self.model(*args, **kwargs)
 
diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py
index 7831316a98ae1..e1aca404722db 100644
--- a/src/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -516,7 +516,7 @@ def is_defined(self) -> bool:
         return not self.is_module() or is_overridden(self.name, self.instance)
 
     def is_module(self) -> bool:
-        """Returns whether the the DataLoader source is a LightningModule or a LightningDataModule.
+        """Returns whether the DataLoader source is a LightningModule or a LightningDataModule.
 
         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
diff --git a/src/pytorch_lightning/utilities/apply_func.py b/src/pytorch_lightning/utilities/apply_func.py
index cfeb48c423332..8729520ee9d96 100644
--- a/src/pytorch_lightning/utilities/apply_func.py
+++ b/src/pytorch_lightning/utilities/apply_func.py
@@ -76,7 +76,7 @@ def apply_to_collection(
     dtype: Union[type, Any, Tuple[Union[type, Any]]],
     function: Callable,
     *args: Any,
-    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
     include_none: bool = True,
     **kwargs: Any,
 ) -> Any: