diff --git a/docs/source-pytorch/governance.rst b/docs/source-pytorch/community/governance.rst similarity index 100% rename from docs/source-pytorch/governance.rst rename to docs/source-pytorch/community/governance.rst diff --git a/docs/source-pytorch/community/index.rst b/docs/source-pytorch/community/index.rst new file mode 100644 index 0000000000000..bd500a914146e --- /dev/null +++ b/docs/source-pytorch/community/index.rst @@ -0,0 +1,68 @@ + +.. toctree:: + :maxdepth: 1 + :hidden: + + ../generated/CODE_OF_CONDUCT.md + ../generated/CONTRIBUTING.md + ../generated/BECOMING_A_CORE_CONTRIBUTOR.md + governance + ../versioning + ../past_versions + ../generated/CHANGELOG.md + +######### +Community +######### + +.. raw:: html + +
+
+ +.. displayitem:: + :header: Code of conduct + :description: Contributor Covenant Code of Conduct + :col_css: col-md-12 + :button_link: ../generated/CODE_OF_CONDUCT.html + :height: 100 + +.. displayitem:: + :header: Contribution guide + :description: How to contribute to PyTorch Lightning + :col_css: col-md-12 + :button_link: ../generated/CONTRIBUTING.html + :height: 100 + +.. displayitem:: + :header: How to Become a core contributor + :description: Steps to be a core contributor + :col_css: col-md-12 + :button_link: ../generated/BECOMING_A_CORE_CONTRIBUTOR.html + :height: 100 + +.. displayitem:: + :header: Lightning Governance + :description: The governance processes we follow + :col_css: col-md-12 + :button_link: governance.html + :height: 100 + +.. displayitem:: + :header: Versioning + :description: PyTorch Lightning's versioning policy + :col_css: col-md-12 + :button_link: ../versioning.html + :height: 100 + +.. displayitem:: + :header: Changelog + :description: All notable changes to PyTorch Lightning + :col_css: col-md-12 + :button_link: ../generated/CHANGELOG.html + :height: 100 + +.. raw:: html + +
+
diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst new file mode 100644 index 0000000000000..2c91a7c2d0b80 --- /dev/null +++ b/docs/source-pytorch/glossary/index.rst @@ -0,0 +1,331 @@ + +.. toctree:: + :maxdepth: 1 + :hidden: + + Accelerators <../extensions/accelerator> + Callback <../extensions/callbacks> + Checkpointing <../common/checkpointing> + Cluster <../clouds/cluster> + Cloud checkpoint <../common/checkpointing_advanced> + Console Logging <../common/console_logs> + Debugging <../debug/debugging> + Early stopping <../common/early_stopping> + Experiment manager (Logger) <../visualize/experiment_managers> + Finetuning <../advanced/finetuning> + GPU <../accelerators/gpu> + Half precision <../common/precision> + HPU <../accelerators/hpu> + Inference <../deploy/production_intermediate> + IPU <../accelerators/ipu> + Lightning CLI <../cli/lightning_cli> + LightningDataModule <../data/datamodule> + LightningModule <../common/lightning_module> + Log <../visualize/loggers> + TPU <../accelerators/tpu> + Metrics + Model <../model/build_model.rst> + Model Parallel <../advanced/model_parallel> + Plugins <../extensions/plugins> + Progress bar <../common/progress_bar> + Production <../deploy/production_advanced> + Predict <../deploy/production_basic> + Pretrained models <../advanced/pretrained> + Profiler <../tuning/profiler> + Pruning and Quantization <../advanced/pruning_quantization> + Remote filesystem and FSSPEC <../common/remote_fs> + Strategy <../extensions/strategy> + Strategy registry <../advanced/strategy_registry> + Style guide <../starter/style_guide> + SWA <../advanced/training_tricks> + SLURM <../clouds/cluster_advanced> + Transfer learning <../advanced/transfer_learning> + Trainer <../common/trainer> + Torch distributed <../clouds/cluster_intermediate_2> + +######## +Glossary +######## + +.. raw:: html + +
+
+ +.. displayitem:: + :header: Accelerators + :description: Accelerators connect the Trainer to hardware to train faster + :col_css: col-md-12 + :button_link: ../extensions/accelerator.html + :height: 100 + +.. displayitem:: + :header: Callback + :description: Add self-contained extra functionality during training execution + :col_css: col-md-12 + :button_link: ../extensions/callbacks.html + :height: 100 + +.. displayitem:: + :header: Checkpointing + :description: Save and load progress with checkpoints + :col_css: col-md-12 + :button_link: ../common/checkpointing.html + :height: 100 + +.. displayitem:: + :header: Cluster + :description: Run on your own group of servers + :col_css: col-md-12 + :button_link: ../clouds/cluster.html + :height: 100 + +.. displayitem:: + :header: Cloud checkpoint + :description: Save your models to cloud filesystems + :col_css: col-md-12 + :button_link: ../common/checkpointing_advanced.html + :height: 100 + +.. displayitem:: + :header: Console Logging + :description: Capture more visible logs + :col_css: col-md-12 + :button_link: ../common/console_logs.html + :height: 100 + +.. displayitem:: + :header: Debugging + :description: Fix errors in your code + :col_css: col-md-12 + :button_link: ../debug/debugging.html + :height: 100 + +.. displayitem:: + :header: Early stopping + :description: Stop the training when no improvement is observed + :col_css: col-md-12 + :button_link: ../common/early_stopping.html + :height: 100 + +.. displayitem:: + :header: Experiment manager (Logger) + :description: Tools for tracking and visualizing artifacts and logs + :col_css: col-md-12 + :button_link: ../visualize/experiment_managers.html + :height: 100 + +.. displayitem:: + :header: Finetuning + :description: Technique for training pretrained models + :col_css: col-md-12 + :button_link: ../advanced/finetuning.html + :height: 100 + +.. displayitem:: + :header: GPU + :description: Graphics Processing Unit for faster training + :col_css: col-md-12 + :button_link: ../accelerators/gpu.html + :height: 100 + +.. displayitem:: + :header: Half precision + :description: Using different numerical formats to save memory and run fatser + :col_css: col-md-12 + :button_link: ../common/precision.html + :height: 100 + +.. displayitem:: + :header: HPU + :description: Habana Gaudi AI Processor Unit for faster training + :col_css: col-md-12 + :button_link: ../accelerators/hpu.html + :height: 100 + +.. displayitem:: + :header: Inference + :description: Making predictions by applying a trained model to unlabeled examples + :col_css: col-md-12 + :button_link: ../deploy/production_intermediate.html + :height: 100 + +.. displayitem:: + :header: IPU + :description: Graphcore Intelligence Processing Unit for faster training + :col_css: col-md-12 + :button_link: ../accelerators/ipu.html + :height: 100 + +.. displayitem:: + :header: Lightning CLI + :description: A Command-line Interface (CLI) to interact with Lightning code via a terminal + :col_css: col-md-12 + :button_link: ../cli/lightning_cli.html + :height: 100 + +.. displayitem:: + :header: LightningDataModule + :description: A shareable, reusable class that encapsulates all the steps needed to process data + :col_css: col-md-12 + :button_link: ../data/datamodule.html + :height: 100 + +.. displayitem:: + :header: LightningModule + :description: A base class organizug your neural network module + :col_css: col-md-12 + :button_link: ../common/lightning_module.html + :height: 100 + +.. displayitem:: + :header: Log + :description: Outpus or results used for visualization and tracking + :col_css: col-md-12 + :button_link: ../visualize/loggers.html + :height: 100 + +.. displayitem:: + :header: Metrics + :description: A statistic used to measure performance or other objectives we want to optimize + :col_css: col-md-12 + :button_link: https://torchmetrics.readthedocs.io/en/stable/ + :height: 100 + +.. displayitem:: + :header: Model + :description: The set of parameters and structure for a system to make predictions + :col_css: col-md-12 + :button_link: ../model/build_model.html + :height: 100 + +.. displayitem:: + :header: Model Parallelism + :description: A way to scale training that splits a model between multiple devices. + :col_css: col-md-12 + :button_link: ../advanced/model_parallel.html + :height: 100 + +.. displayitem:: + :header: Plugins + :description: Custom trainer integrations such as custom precision, checkpointing or cluster environment implementation + :col_css: col-md-12 + :button_link: ../extensions/plugins.html + :height: 100 + +.. displayitem:: + :header: Progress bar + :description: Output printed to the terminal to visualize the progression of training + :col_css: col-md-12 + :button_link: ../common/progress_bar.html + :height: 100 + +.. displayitem:: + :header: Production + :description: Using ML models in real world systems + :col_css: col-md-12 + :button_link: ../deploy/production_advanced.html + :height: 100 + +.. displayitem:: + :header: Prediction + :description: Computing a model's output + :col_css: col-md-12 + :button_link: ../deploy/production_basic.html + :height: 100 + +.. displayitem:: + :header: Pretrained models + :description: Models that have already been trained for a particular task + :col_css: col-md-12 + :button_link: ../advanced/pretrained.html + :height: 100 + +.. displayitem:: + :header: Profiler + :description: Tool to identify bottlenecks and performance of different parts of a model + :col_css: col-md-12 + :button_link: ../tuning/profiler.html + :height: 100 + +.. displayitem:: + :header: Pruning + :description: A technique to eliminae some of the model weights to reduce the model size and decrease inference requirements + :col_css: col-md-12 + :button_link: ../advanced/pruning_quantization.html + :height: 100 + +.. displayitem:: + :header: Quantization + :description: A technique to accelerate the model inference speed and decrease the memory load while still maintaining the model accuracy + :col_css: col-md-12 + :button_link: ../advanced/post_training_quantization.html + :height: 100 + +.. displayitem:: + :header: Remote filesystem and FSSPEC + :description: Accessing files from cloud storage providers + :col_css: col-md-12 + :button_link: ../common/remote_fs.html + :height: 100 + +.. displayitem:: + :header: Strategy + :description: Ways the trainer controls the model distribution across training, evaluation, and prediction + :col_css: col-md-12 + :button_link: ../extensions/strategy.html + :height: 100 + +.. displayitem:: + :header: Strategy registry + :description: A class that holds information about training strategies and allows adding new custom strategies + :col_css: col-md-12 + :button_link: ../advanced/strategy_registry.html + :height: 100 + +.. displayitem:: + :header: Style guide + :description: Best practices to improve readability and reproducability + :col_css: col-md-12 + :button_link: ../starter/style_guide.html + :height: 100 + +.. displayitem:: + :header: SWA + :description: Stochastic Weight Averaging (SWA) can make your models generalize better + :col_css: col-md-12 + :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging + :height: 100 + +.. displayitem:: + :header: SLURM + :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters + :col_css: col-md-12 + :button_link: ../clouds/cluster_advanced.html + :height: 100 + +.. displayitem:: + :header: Transfer learning + :description: Using pre-trained models to improve learning + :col_css: col-md-12 + :button_link: ../advanced/transfer_learning.html + :height: 100 + +.. displayitem:: + :header: Trainer + :description: The class that automates and customizes model training + :col_css: col-md-12 + :button_link: ../common/trainer.html + :height: 100 + +.. displayitem:: + :header: Torch distributed + :description: Setup for running on distributed environments + :col_css: col-md-12 + :button_link: ../clouds/cluster_intermediate_2.html + :height: 100 + +.. raw:: html + +
+
diff --git a/docs/source-pytorch/index.rst b/docs/source-pytorch/index.rst index 9560190717931..0bd78b6fb63a8 100644 --- a/docs/source-pytorch/index.rst +++ b/docs/source-pytorch/index.rst @@ -165,7 +165,7 @@ Current Lightning Users levels/expert .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :name: pl_docs :caption: Core API @@ -173,105 +173,30 @@ Current Lightning Users common/trainer .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :name: api - :caption: API Reference + :caption: Optional API api_references .. toctree:: :maxdepth: 1 - :name: Common Workflows - :caption: Common Workflows - - Avoid overfitting - model/build_model.rst - cli/lightning_cli - common/progress_bar - deploy/production - advanced/training_tricks - tuning/profiler - Manage experiments - Organize existing PyTorch into Lightning - clouds/cluster - Save and load model progress - Save memory with half-precision - advanced/model_parallel - Train on single or multiple GPUs - Train on single or multiple HPUs - Train on single or multiple IPUs - Train on single or multiple TPUs - Train on MPS - Use a pretrained model - data/data - model/own_your_loop - -.. toctree:: - :maxdepth: 1 - :name: Glossary - :caption: Glossary - - Accelerators - Callback - Checkpointing - Cluster - Cloud checkpoint - Console Logging - Debugging - Early stopping - Experiment manager (Logger) - Finetuning - Flash - GPU - Half precision - HPU - Inference - IPU - Lightning CLI - LightningDataModule - LightningModule - Log - TPU - Metrics - Model - Model Parallel - Plugins - Progress bar - Production - Predict - Pretrained models - Profiler - Pruning and Quantization - Remote filesystem and FSSPEC - Strategy - Strategy registry - Style guide - SWA - SLURM - Transfer learning - Trainer - Torch distributed - -.. toctree:: - :maxdepth: 1 - :name: Hands-on Examples - :caption: Hands-on Examples + :name: Examples + :caption: Examples :glob: notebooks/**/* + .. toctree:: :maxdepth: 1 - :name: Community - :caption: Community - - generated/CODE_OF_CONDUCT.md - generated/CONTRIBUTING.md - generated/BECOMING_A_CORE_CONTRIBUTOR.md - governance - versioning - past_versions - generated/CHANGELOG.md + :name: More + :caption: More + + Community + Glossary + How to + .. raw:: html diff --git a/docs/source-pytorch/past_versions.rst b/docs/source-pytorch/past_versions.rst index be8f516824681..374303735ed32 100644 --- a/docs/source-pytorch/past_versions.rst +++ b/docs/source-pytorch/past_versions.rst @@ -3,7 +3,7 @@ Past PyTorch Lightning versions PyTorch Lightning :doc:`evolved over time `. Here's the history of versions with links to their respective docs. -To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/migration_guide>`. +To help you with keeping up to spead, check :doc:`Migration guide <../upgrade/migration_guide>`. .. list-table:: Past versions :widths: 5 50 30 15 @@ -22,7 +22,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig `1.9.3 `_, `1.9.4 `_, `1.9.5 `_ - - :doc:`from 1.9 to 2.0 ` + - :doc:`from 1.9 to 2.0 <../upgrade/from_1_9>` * - `1.8 `_ - `Colossal-AI Strategy, Commands and Secrets for Apps, FSDP Improvements and More! `_ @@ -33,7 +33,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig `1.8.4 `_, `1.8.5 `_, `1.8.6 `_ - - :doc:`from 1.8 to 2.0 ` + - :doc:`from 1.8 to 2.0 <../upgrade/from_1_8>` * - `1.7 `_ - `Apple Silicon support, Native FSDP, Collaborative training, and multi-GPU support with Jupyter notebooks `_ @@ -45,7 +45,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig `1.7.5 `_, `1.7.6 `_, `1.7.7 `_ - - :doc:`from 1.7 to 2.0 ` + - :doc:`from 1.7 to 2.0 <../upgrade/from_1_7>` * - `1.6 `_ - `Support Intel's Habana Accelerator, New efficient DDP strategy (Bagua), Manual Fault-tolerance, Stability and Reliability `_ @@ -55,7 +55,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig `1.6.3 `_, `1.6.4 `_, `1.6.5 `_ - - :doc:`from 1.6 to 2.0 ` + - :doc:`from 1.6 to 2.0 <../upgrade/from_1_6>` * - `1.5 `_ - `LightningLite, Fault-Tolerant Training, Loop Customization, Lightning Tutorials, LightningCLI v2, RichProgressBar, CheckpointIO Plugin, and Trainer Strategy Flag `_ @@ -70,7 +70,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig `1.5.8 `_, `1.5.9 `_, `1.5.10 `_ - - :doc:`from 1.5 to 2.0 ` + - :doc:`from 1.5 to 2.0 <../upgrade/from_1_5>` * - `1.4 `_ - `TPU Pod Training, IPU Accelerator, DeepSpeed Infinity, Fully Sharded Data Parallel `_ @@ -84,7 +84,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig `1.4.7 `_, `1.4.8 `_, `1.4.9 `_ - - :doc:`from 1.4 to 2.0 ` + - :doc:`from 1.4 to 2.0 <../upgrade/from_1_4>` * - `1.3 `_ - `Lightning CLI, PyTorch Profiler, Improved Early Stopping `_ diff --git a/src/lightning/app/CHANGELOG.md b/src/lightning/app/CHANGELOG.md index ac547a4e24c50..23468376248fb 100644 --- a/src/lightning/app/CHANGELOG.md +++ b/src/lightning/app/CHANGELOG.md @@ -7,14 +7,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [UnReleased] - 2023-04-DD +- Added the property `LightningWork.public_ip` that exposes the public IP of the `LightningWork` instance ([#17742](https://github.com/Lightning-AI/lightning/pull/17742)) + + +- Add missing python-multipart dependency ([#17244](https://github.com/Lightning-AI/lightning/pull/17244)) + + ### Changed -- +- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100)) ### Fixed -- +- Fixed `LightningWork.internal_ip` that was mistakenly exposing the public IP instead; now exposes the private/internal IP address ([#17742](https://github.com/Lightning-AI/lightning/pull/17742)) + + +- Fixed resolution of latest version in CLI ([#17351](https://github.com/Lightning-AI/lightning/pull/17351)) + + +- Fixed property raised instead of returned ([#17595](https://github.com/Lightning-AI/lightning/pull/17595)) + + +- Fixed get project ([#17617](https://github.com/Lightning-AI/lightning/pull/17617), [#17666](https://github.com/Lightning-AI/lightning/pull/17666)) ## [2.0.2] - 2023-04-24 diff --git a/src/lightning/app/components/database/server.py b/src/lightning/app/components/database/server.py index 05a21894dc6f7..6da7710cfa4f0 100644 --- a/src/lightning/app/components/database/server.py +++ b/src/lightning/app/components/database/server.py @@ -231,9 +231,10 @@ def db_url(self) -> Optional[str]: use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ if use_localhost: return self.url - if self.internal_ip != "": - return f"http://{self.internal_ip}:{self.port}" - return self.internal_ip + ip_addr = self.public_ip or self.internal_ip + if ip_addr != "": + return f"http://{ip_addr}:{self.port}" + return ip_addr def on_exit(self): self._exit_event.set() diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py index 1ce7f45bef318..dde7722f5da43 100644 --- a/src/lightning/app/components/serve/auto_scaler.py +++ b/src/lightning/app/components/serve/auto_scaler.py @@ -180,9 +180,9 @@ def __init__( raise ValueError("cold_start_proxy must be of type ColdStartProxy or str") def get_internal_url(self) -> str: - if not self._internal_ip: - raise ValueError("Internal IP not set") - return f"http://{self._internal_ip}:{self._port}" + if not self._public_ip: + raise ValueError("Public IP not set") + return f"http://{self._public_ip}:{self._port}" async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str): request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch] @@ -386,7 +386,7 @@ def update_servers(self, server_works: List[LightningWork]): """ old_server_urls = set(self.servers) current_server_urls = { - f"http://{server._internal_ip}:{server.port}" for server in server_works if server._internal_ip + f"http://{server._public_ip}:{server.port}" for server in server_works if server._internal_ip } # doing nothing if no server work has been added/removed diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py index 67b3fb0361dfd..f5416c0cfae3d 100644 --- a/src/lightning/app/core/work.py +++ b/src/lightning/app/core/work.py @@ -60,6 +60,7 @@ class LightningWork: "_url", "_restarting", "_internal_ip", + "_public_ip", ) _run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor @@ -138,6 +139,7 @@ def __init__( "_url", "_future_url", "_internal_ip", + "_public_ip", "_restarting", "_cloud_compute", "_display_name", @@ -148,6 +150,7 @@ def __init__( self._url: str = "" self._future_url: str = "" # The cache URL is meant to defer resolving the url values. self._internal_ip: str = "" + self._public_ip: str = "" # setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator self._setattr_replacement: Optional[Callable[[str, Any], None]] = None self._name: str = "" @@ -212,6 +215,15 @@ def internal_ip(self) -> str: """ return self._internal_ip + @property + def public_ip(self) -> str: + """The public ip address of this LightningWork, reachable from the internet. + + By default, this attribute returns the empty string and the ip address will only be returned once the work runs. + Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster. + """ + return self._public_ip + def _on_init_end(self) -> None: self._local_build_config.on_work_init(self) self._cloud_build_config.on_work_init(self, self._cloud_compute) diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py index 9c5a53f2b98e6..39d33785068ac 100644 --- a/src/lightning/app/utilities/proxies.py +++ b/src/lightning/app/utilities/proxies.py @@ -494,7 +494,8 @@ def run_once(self): # Set this here after the state observer is initialized, since it needs to record it as a change and send # it back to the flow default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" # noqa: S104 - self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip) + self.work._internal_ip = os.environ.get("LIGHTNING_NODE_PRIVATE_IP", default_internal_ip) + self.work._public_ip = os.environ.get("LIGHTNING_NODE_IP", "") # 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't # send delta while calling `set_state`. diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 043ae2aab61e3..04215ccb8289a 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -7,9 +7,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [UnReleased] - 2023-04-DD +- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756)) + + +- Add Fabric internal hooks ([#17759](https://github.com/Lightning-AI/lightning/pull/17759)) + + ### Changed -- +- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100)) + + +- Support compiling a module after it was set up by Fabric ([#17529](https://github.com/Lightning-AI/lightning/pull/17529)) ### Fixed diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index f12dc56ce707e..1167a92358d8c 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -43,6 +43,7 @@ has_iterable_dataset, ) from lightning.fabric.utilities.distributed import DistributedSamplerWrapper +from lightning.fabric.utilities.registry import _load_external_callbacks from lightning.fabric.utilities.seed import seed_everything from lightning.fabric.utilities.types import ReduceOp from lightning.fabric.utilities.warnings import PossibleUserWarning @@ -105,8 +106,7 @@ def __init__( self._strategy: Strategy = self._connector.strategy self._accelerator: Accelerator = self._connector.accelerator self._precision: Precision = self._strategy.precision - callbacks = callbacks if callbacks is not None else [] - self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + self._callbacks = self._configure_callbacks(callbacks) loggers = loggers if loggers is not None else [] self._loggers = loggers if isinstance(loggers, list) else [loggers] self._models_setup: int = 0 @@ -212,7 +212,10 @@ def setup( # Update the _DeviceDtypeModuleMixin's device parameter module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device) - optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + optimizers = [ + _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks) + for optimizer in optimizers + ] self._models_setup += 1 @@ -220,6 +223,8 @@ def setup( original_module._fabric = self # type: ignore[assignment] original_module._fabric_optimizers = optimizers # type: ignore[assignment] + self.call("on_after_setup", fabric=self, module=module) + if optimizers: # join both types in a tuple for API convenience return (module, *optimizers) @@ -276,7 +281,10 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu """ self._validate_setup_optimizers(optimizers) optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] - optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers] + optimizers = [ + _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks) + for optimizer in optimizers + ] return optimizers[0] if len(optimizers) == 1 else tuple(optimizers) def setup_dataloaders( @@ -838,6 +846,13 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None: if any(not isinstance(dl, DataLoader) for dl in dataloaders): raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") + @staticmethod + def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]: + callbacks = callbacks if callbacks is not None else [] + callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory")) + return callbacks + def _is_using_cli() -> bool: return bool(int(os.environ.get("LT_CLI_USED", "0"))) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index bdf3d6f3b34c6..6c63ecaa1eeb1 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -30,3 +30,6 @@ _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1 + +_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) +_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 10903d039ec7d..4c3c96dc5803e 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Any +import logging +from typing import Any, List, Union + +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 + +_log = logging.getLogger(__name__) def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool: @@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo return False return mod_attr.__code__ is not super_attr.__code__ + + +def _load_external_callbacks(group: str) -> List[Any]: + """Collect external callbacks registered through entry points. + + The entry points are expected to be functions returning a list of callbacks. + + Args: + group: The entry point group name to load callbacks from. + + Return: + A list of all callbacks collected from external factories. + """ + if _PYTHON_GREATER_EQUAL_3_8_0: + from importlib.metadata import entry_points + + factories = ( + entry_points(group=group) + if _PYTHON_GREATER_EQUAL_3_10_0 + else entry_points().get(group, {}) # type: ignore[arg-type] + ) + else: + from pkg_resources import iter_entry_points + + factories = iter_entry_points(group) # type: ignore[assignment] + + external_callbacks: List[Any] = [] + for factory in factories: + callback_factory = factory.load() + callbacks_list: Union[List[Any], Any] = callback_factory() + callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list + _log.info( + f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" + f" {', '.join(type(cb).__name__ for cb in callbacks_list)}" + ) + external_callbacks.extend(callbacks_list) + return external_callbacks diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 947510244d0df..8ed10f00c5511 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, overload, TypeVar, Union import torch from lightning_utilities import WarningCache @@ -38,7 +38,7 @@ class _FabricOptimizer: - def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None: + def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None: """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer step calls to the strategy plugin. @@ -54,6 +54,7 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None: self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer self._strategy = strategy + self._callbacks = callbacks or [] @property def optimizer(self) -> Optimizer: @@ -69,10 +70,15 @@ def step(self, closure: Optional[Callable] = None) -> Any: optimizer = self._strategy.model else: optimizer = self.optimizer - return self._strategy.optimizer_step( + output = self._strategy.optimizer_step( optimizer, **kwargs, ) + for callback in self._callbacks: + hook = getattr(callback, "on_after_optimizer_step", None) + if callable(hook): + hook(strategy=self._strategy, optimizer=optimizer) + return output class _FabricModule(_DeviceDtypeModuleMixin): diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 0500e336f2268..5f7ed4645d847 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100)) + + - @@ -41,6 +44,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an edge case causing overlapping samples in DDP when no global seed is set ([#17713](https://github.com/Lightning-AI/lightning/pull/17713)) +- Fallback to module available check for mlflow ([#17467](https://github.com/Lightning-AI/lightning/pull/17467)) + + +- Fixed LR finder max val batches ([#17636](https://github.com/Lightning-AI/lightning/pull/17636)) + + +- Fixed multithreading checkpoint loading ([#17678](https://github.com/Lightning-AI/lightning/pull/17678)) + + ## [2.0.2] - 2023-04-24 ### Fixed diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 8195a5c4a3b52..d649755172658 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Sequence, Union import lightning.pytorch as pl +from lightning.fabric.utilities.registry import _load_external_callbacks from lightning.pytorch.callbacks import ( Callback, Checkpoint, @@ -33,7 +34,6 @@ from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.trainer import call from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_info @@ -75,7 +75,7 @@ def on_trainer_init( # configure the ModelSummary callback self._configure_model_summary_callback(enable_model_summary) - self.trainer.callbacks.extend(_configure_external_callbacks()) + self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory")) _validate_callbacks_list(self.trainer.callbacks) # push all model checkpoint callbacks to the end @@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: return tuner_callbacks + other_callbacks + checkpoint_callbacks -def _configure_external_callbacks() -> List[Callback]: - """Collect external callbacks registered through entry points. - - The entry points are expected to be functions returning a list of callbacks. - - Return: - A list of all callbacks collected from external factories. - """ - group = "lightning.pytorch.callbacks_factory" - - if _PYTHON_GREATER_EQUAL_3_8_0: - from importlib.metadata import entry_points - - factories = ( - entry_points(group=group) - if _PYTHON_GREATER_EQUAL_3_10_0 - else entry_points().get(group, {}) # type: ignore[arg-type] - ) - else: - from pkg_resources import iter_entry_points - - factories = iter_entry_points(group) # type: ignore[assignment] - - external_callbacks: List[Callback] = [] - for factory in factories: - callback_factory = factory.load() - callbacks_list: Union[List[Callback], Callback] = callback_factory() - callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list - _log.info( - f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':" - f" {', '.join(type(cb).__name__ for cb in callbacks_list)}" - ) - external_callbacks.extend(callbacks_list) - return external_callbacks - - def _validate_callbacks_list(callbacks: List[Callback]) -> None: stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] seen_callbacks = set() diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 1f48386e3ff86..7e6b7cd0c5e91 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -11,8 +11,7 @@ import lightning.pytorch as pl from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.imports import _IS_WINDOWS -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 +from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0 from lightning.pytorch.utilities.rank_zero import rank_zero_info # copied from signal.pyi diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 259f18070362b..bfb1eeb5c5174 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -17,8 +17,6 @@ import torch from lightning_utilities.core.imports import package_available, RequirementCache -_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) -_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task diff --git a/src/version.info b/src/version.info index e9307ca5751b2..50ffc5aa7f69f 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.0.2 +2.0.3 diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index f5933747293e6..ed451475b5e80 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -119,6 +119,7 @@ def test_simple_app(tmpdir): "_url": "", "_future_url": "", "_internal_ip": "", + "_public_ip": "", "_paths": {}, "_port": None, "_restarting": False, @@ -136,6 +137,7 @@ def test_simple_app(tmpdir): "_url": "", "_future_url": "", "_internal_ip": "", + "_public_ip": "", "_paths": {}, "_port": None, "_restarting": False, @@ -982,7 +984,7 @@ def run(self): def test_state_size_constant_growth(): app = LightningApp(SizeFlow()) MultiProcessRuntime(app, start_server=False).dispatch() - assert app.root._state_sizes[0] <= 7965 + assert app.root._state_sizes[0] <= 8304 assert app.root._state_sizes[20] <= 26550 diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index 9b174f5a3af8a..fe66a9bb10124 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -324,6 +324,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -349,6 +350,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -388,6 +390,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -413,6 +416,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index f921607a50734..4cbe391027d8a 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -46,6 +46,7 @@ def run(self): "_restarting": False, "_display_name": "", "_internal_ip": "", + "_public_ip": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "cpu-small", @@ -80,6 +81,7 @@ def run(self): "_restarting": False, "_display_name": "", "_internal_ip": "", + "_public_ip": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "cpu-small", @@ -114,6 +116,7 @@ def run(self): "_restarting": False, "_display_name": "", "_internal_ip": "", + "_public_ip": "", "_cloud_compute": { "type": "__cloud_compute__", "name": "cpu-small", @@ -199,6 +202,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -233,6 +237,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", @@ -262,6 +267,7 @@ def run(self): "_paths": {}, "_restarting": False, "_internal_ip": "", + "_public_ip": "", "_display_name": "", "_cloud_compute": { "type": "__cloud_compute__", diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index a4661e2736aaa..8dbfe1b06b871 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -641,16 +641,21 @@ def test_state_observer(): @pytest.mark.parametrize( - ("patch_constants", "environment", "expected_ip_addr"), + ("patch_constants", "environment", "expected_public_ip", "expected_private_ip"), [ - ({}, {}, "127.0.0.1"), - ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"), # noqa: S104 - ({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"), + ({}, {}, "", "127.0.0.1"), + ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "", "0.0.0.0"), # noqa: S104 + ( + {}, + {"LIGHTNING_NODE_IP": "85.44.2.25", "LIGHTNING_NODE_PRIVATE_IP": "10.10.10.5"}, + "85.44.2.25", + "10.10.10.5", + ), ], indirect=["patch_constants"], ) -def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_addr): - """Test that the WorkRunner updates the internal ip address as soon as the Work starts running.""" +def test_work_runner_sets_public_and_private_ip(patch_constants, environment, expected_public_ip, expected_private_ip): + """Test that the WorkRunner updates the public and private address as soon as the Work starts running.""" class Work(LightningWork): def run(self): @@ -690,11 +695,13 @@ def run(self): with mock.patch.dict(os.environ, environment, clear=True): work_runner.setup() - # The internal ip address only becomes available once the hardware is up / the work is running. + # The public ip address only becomes available once the hardware is up / the work is running. + assert work.public_ip == "" assert work.internal_ip == "" with contextlib.suppress(Empty): work_runner.run_once() - assert work.internal_ip == expected_ip_addr + assert work.public_ip == expected_public_ip + assert work.internal_ip == expected_private_ip class WorkBi(LightningWork): diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 58ab0e34166ed..015d1aba41359 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -753,6 +753,29 @@ def test_call(): assert not callback1.mock_calls +def test_special_callbacks(): + """Tests special callbacks that have hooks for internal Fabric events.""" + + class SpecialCallback: + def on_after_optimizer_step(self, strategy, optimizer): + pass + + def on_after_setup(self, fabric, module): + pass + + callback = Mock(wraps=SpecialCallback()) + fabric = Fabric(accelerator="cpu", callbacks=[callback]) + + model = torch.nn.Linear(2, 2) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + fabric_model, fabric_optimizer = fabric.setup(model, optimizer) + callback.on_after_setup.assert_called_once_with(fabric=fabric, module=fabric_model) + + model(torch.randn(2, 2)).sum().backward() + fabric_optimizer.step() + callback.on_after_optimizer_step.assert_called_once_with(strategy=fabric._strategy, optimizer=optimizer) + + def test_loggers_input(): """Test the various ways in which loggers can be registered with Fabric.""" logger0 = Mock() diff --git a/tests/tests_fabric/utilities/test_registry.py b/tests/tests_fabric/utilities/test_registry.py new file mode 100644 index 0000000000000..75e6e12f5abff --- /dev/null +++ b/tests/tests_fabric/utilities/test_registry.py @@ -0,0 +1,64 @@ +import contextlib +from unittest import mock +from unittest.mock import Mock + +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 +from lightning.fabric.utilities.registry import _load_external_callbacks + + +class ExternalCallback: + """A callback in another library that gets registered through entry points.""" + + pass + + +def test_load_external_callbacks(): + """Test that the connector collects Callback instances from factories registered through entry points.""" + + def factory_no_callback(): + return [] + + def factory_one_callback(): + return ExternalCallback() + + def factory_one_callback_list(): + return [ExternalCallback()] + + def factory_multiple_callbacks_list(): + return [ExternalCallback(), ExternalCallback()] + + with _make_entry_point_query_mock(factory_no_callback): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert callbacks == [] + + with _make_entry_point_query_mock(factory_one_callback): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert isinstance(callbacks[0], ExternalCallback) + + with _make_entry_point_query_mock(factory_one_callback_list): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert isinstance(callbacks[0], ExternalCallback) + + with _make_entry_point_query_mock(factory_multiple_callbacks_list): + callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory") + assert isinstance(callbacks[0], ExternalCallback) + assert isinstance(callbacks[1], ExternalCallback) + + +@contextlib.contextmanager +def _make_entry_point_query_mock(callback_factory): + query_mock = Mock() + entry_point = Mock() + entry_point.name = "mocked" + entry_point.load.return_value = callback_factory + if _PYTHON_GREATER_EQUAL_3_10_0: + query_mock.return_value = [entry_point] + import_path = "importlib.metadata.entry_points" + elif _PYTHON_GREATER_EQUAL_3_8_0: + query_mock().get.return_value = [entry_point] + import_path = "importlib.metadata.entry_points" + else: + query_mock.return_value = [entry_point] + import_path = "pkg_resources.iter_entry_points" + with mock.patch(import_path, query_mock): + yield diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 71262df9179e8..58f59ad760763 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -19,6 +19,7 @@ import pytest import torch +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch import Callback, LightningModule, Trainer from lightning.pytorch.callbacks import ( EarlyStopping, @@ -32,7 +33,6 @@ from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0 def test_checkpoint_callbacks_are_last(tmpdir): diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 36f6356f995cb..cea40b921e1a5 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -25,6 +25,7 @@ import torch from torch import Tensor +from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 from lightning.pytorch import callbacks, Trainer from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -32,7 +33,6 @@ from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0 from tests_pytorch.helpers.runif import RunIf if _RICH_AVAILABLE: