diff --git a/docs/source-pytorch/governance.rst b/docs/source-pytorch/community/governance.rst
similarity index 100%
rename from docs/source-pytorch/governance.rst
rename to docs/source-pytorch/community/governance.rst
diff --git a/docs/source-pytorch/community/index.rst b/docs/source-pytorch/community/index.rst
new file mode 100644
index 0000000000000..bd500a914146e
--- /dev/null
+++ b/docs/source-pytorch/community/index.rst
@@ -0,0 +1,68 @@
+
+.. toctree::
+ :maxdepth: 1
+ :hidden:
+
+ ../generated/CODE_OF_CONDUCT.md
+ ../generated/CONTRIBUTING.md
+ ../generated/BECOMING_A_CORE_CONTRIBUTOR.md
+ governance
+ ../versioning
+ ../past_versions
+ ../generated/CHANGELOG.md
+
+#########
+Community
+#########
+
+.. raw:: html
+
+
+
+
+.. displayitem::
+ :header: Code of conduct
+ :description: Contributor Covenant Code of Conduct
+ :col_css: col-md-12
+ :button_link: ../generated/CODE_OF_CONDUCT.html
+ :height: 100
+
+.. displayitem::
+ :header: Contribution guide
+ :description: How to contribute to PyTorch Lightning
+ :col_css: col-md-12
+ :button_link: ../generated/CONTRIBUTING.html
+ :height: 100
+
+.. displayitem::
+ :header: How to become a core contributor
+ :description: Steps to become a core contributor
+ :col_css: col-md-12
+ :button_link: ../generated/BECOMING_A_CORE_CONTRIBUTOR.html
+ :height: 100
+
+.. displayitem::
+ :header: Lightning Governance
+ :description: The governance processes we follow
+ :col_css: col-md-12
+ :button_link: governance.html
+ :height: 100
+
+.. displayitem::
+ :header: Versioning
+ :description: PyTorch Lightning's versioning policy
+ :col_css: col-md-12
+ :button_link: ../versioning.html
+ :height: 100
+
+.. displayitem::
+ :header: Changelog
+ :description: All notable changes to PyTorch Lightning
+ :col_css: col-md-12
+ :button_link: ../generated/CHANGELOG.html
+ :height: 100
+
+.. raw:: html
+
+
+
diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst
new file mode 100644
index 0000000000000..2c91a7c2d0b80
--- /dev/null
+++ b/docs/source-pytorch/glossary/index.rst
@@ -0,0 +1,331 @@
+
+.. toctree::
+ :maxdepth: 1
+ :hidden:
+
+ Accelerators <../extensions/accelerator>
+ Callback <../extensions/callbacks>
+ Checkpointing <../common/checkpointing>
+ Cluster <../clouds/cluster>
+ Cloud checkpoint <../common/checkpointing_advanced>
+ Console Logging <../common/console_logs>
+ Debugging <../debug/debugging>
+ Early stopping <../common/early_stopping>
+ Experiment manager (Logger) <../visualize/experiment_managers>
+ Finetuning <../advanced/finetuning>
+ GPU <../accelerators/gpu>
+ Half precision <../common/precision>
+ HPU <../accelerators/hpu>
+ Inference <../deploy/production_intermediate>
+ IPU <../accelerators/ipu>
+ Lightning CLI <../cli/lightning_cli>
+ LightningDataModule <../data/datamodule>
+ LightningModule <../common/lightning_module>
+ Log <../visualize/loggers>
+ TPU <../accelerators/tpu>
+ Metrics
+ Model <../model/build_model.rst>
+ Model Parallel <../advanced/model_parallel>
+ Plugins <../extensions/plugins>
+ Progress bar <../common/progress_bar>
+ Production <../deploy/production_advanced>
+ Predict <../deploy/production_basic>
+ Pretrained models <../advanced/pretrained>
+ Profiler <../tuning/profiler>
+ Pruning and Quantization <../advanced/pruning_quantization>
+ Remote filesystem and FSSPEC <../common/remote_fs>
+ Strategy <../extensions/strategy>
+ Strategy registry <../advanced/strategy_registry>
+ Style guide <../starter/style_guide>
+ SWA <../advanced/training_tricks>
+ SLURM <../clouds/cluster_advanced>
+ Transfer learning <../advanced/transfer_learning>
+ Trainer <../common/trainer>
+ Torch distributed <../clouds/cluster_intermediate_2>
+
+########
+Glossary
+########
+
+.. raw:: html
+
+
+
+
+.. displayitem::
+ :header: Accelerators
+ :description: Accelerators connect the Trainer to hardware to train faster
+ :col_css: col-md-12
+ :button_link: ../extensions/accelerator.html
+ :height: 100
+
+.. displayitem::
+ :header: Callback
+ :description: Add self-contained extra functionality during training execution
+ :col_css: col-md-12
+ :button_link: ../extensions/callbacks.html
+ :height: 100
+
+.. displayitem::
+ :header: Checkpointing
+ :description: Save and load progress with checkpoints
+ :col_css: col-md-12
+ :button_link: ../common/checkpointing.html
+ :height: 100
+
+.. displayitem::
+ :header: Cluster
+ :description: Run on your own group of servers
+ :col_css: col-md-12
+ :button_link: ../clouds/cluster.html
+ :height: 100
+
+.. displayitem::
+ :header: Cloud checkpoint
+ :description: Save your models to cloud filesystems
+ :col_css: col-md-12
+ :button_link: ../common/checkpointing_advanced.html
+ :height: 100
+
+.. displayitem::
+ :header: Console Logging
+ :description: Capture more visible logs
+ :col_css: col-md-12
+ :button_link: ../common/console_logs.html
+ :height: 100
+
+.. displayitem::
+ :header: Debugging
+ :description: Fix errors in your code
+ :col_css: col-md-12
+ :button_link: ../debug/debugging.html
+ :height: 100
+
+.. displayitem::
+ :header: Early stopping
+ :description: Stop the training when no improvement is observed
+ :col_css: col-md-12
+ :button_link: ../common/early_stopping.html
+ :height: 100
+
+.. displayitem::
+ :header: Experiment manager (Logger)
+ :description: Tools for tracking and visualizing artifacts and logs
+ :col_css: col-md-12
+ :button_link: ../visualize/experiment_managers.html
+ :height: 100
+
+.. displayitem::
+ :header: Finetuning
+ :description: Technique for training pretrained models
+ :col_css: col-md-12
+ :button_link: ../advanced/finetuning.html
+ :height: 100
+
+.. displayitem::
+ :header: GPU
+ :description: Graphics Processing Unit for faster training
+ :col_css: col-md-12
+ :button_link: ../accelerators/gpu.html
+ :height: 100
+
+.. displayitem::
+ :header: Half precision
+ :description: Using different numerical formats to save memory and run faster
+ :col_css: col-md-12
+ :button_link: ../common/precision.html
+ :height: 100
+
+.. displayitem::
+ :header: HPU
+ :description: Habana Gaudi AI Processor Unit for faster training
+ :col_css: col-md-12
+ :button_link: ../accelerators/hpu.html
+ :height: 100
+
+.. displayitem::
+ :header: Inference
+ :description: Making predictions by applying a trained model to unlabeled examples
+ :col_css: col-md-12
+ :button_link: ../deploy/production_intermediate.html
+ :height: 100
+
+.. displayitem::
+ :header: IPU
+ :description: Graphcore Intelligence Processing Unit for faster training
+ :col_css: col-md-12
+ :button_link: ../accelerators/ipu.html
+ :height: 100
+
+.. displayitem::
+ :header: Lightning CLI
+ :description: A Command-line Interface (CLI) to interact with Lightning code via a terminal
+ :col_css: col-md-12
+ :button_link: ../cli/lightning_cli.html
+ :height: 100
+
+.. displayitem::
+ :header: LightningDataModule
+ :description: A shareable, reusable class that encapsulates all the steps needed to process data
+ :col_css: col-md-12
+ :button_link: ../data/datamodule.html
+ :height: 100
+
+.. displayitem::
+ :header: LightningModule
+ :description: A base class that organizes your neural network module
+ :col_css: col-md-12
+ :button_link: ../common/lightning_module.html
+ :height: 100
+
+.. displayitem::
+ :header: Log
+ :description: Outputs or results used for visualization and tracking
+ :col_css: col-md-12
+ :button_link: ../visualize/loggers.html
+ :height: 100
+
+.. displayitem::
+ :header: Metrics
+ :description: A statistic used to measure performance or other objectives we want to optimize
+ :col_css: col-md-12
+ :button_link: https://torchmetrics.readthedocs.io/en/stable/
+ :height: 100
+
+.. displayitem::
+ :header: Model
+ :description: The set of parameters and structure for a system to make predictions
+ :col_css: col-md-12
+ :button_link: ../model/build_model.html
+ :height: 100
+
+.. displayitem::
+ :header: Model Parallelism
+ :description: A way to scale training that splits a model between multiple devices.
+ :col_css: col-md-12
+ :button_link: ../advanced/model_parallel.html
+ :height: 100
+
+.. displayitem::
+ :header: Plugins
+ :description: Custom trainer integrations such as custom precision, checkpointing, or cluster environment implementations
+ :col_css: col-md-12
+ :button_link: ../extensions/plugins.html
+ :height: 100
+
+.. displayitem::
+ :header: Progress bar
+ :description: Output printed to the terminal to visualize the progression of training
+ :col_css: col-md-12
+ :button_link: ../common/progress_bar.html
+ :height: 100
+
+.. displayitem::
+ :header: Production
+ :description: Using ML models in real world systems
+ :col_css: col-md-12
+ :button_link: ../deploy/production_advanced.html
+ :height: 100
+
+.. displayitem::
+ :header: Prediction
+ :description: Computing a model's output
+ :col_css: col-md-12
+ :button_link: ../deploy/production_basic.html
+ :height: 100
+
+.. displayitem::
+ :header: Pretrained models
+ :description: Models that have already been trained for a particular task
+ :col_css: col-md-12
+ :button_link: ../advanced/pretrained.html
+ :height: 100
+
+.. displayitem::
+ :header: Profiler
+ :description: Tool to identify bottlenecks and measure the performance of different parts of a model
+ :col_css: col-md-12
+ :button_link: ../tuning/profiler.html
+ :height: 100
+
+.. displayitem::
+ :header: Pruning
+ :description: A technique to eliminate some of the model weights to reduce the model size and decrease inference requirements
+ :col_css: col-md-12
+ :button_link: ../advanced/pruning_quantization.html
+ :height: 100
+
+.. displayitem::
+ :header: Quantization
+ :description: A technique to accelerate inference and decrease the memory load while still maintaining the model accuracy
+ :col_css: col-md-12
+ :button_link: ../advanced/post_training_quantization.html
+ :height: 100
+
+.. displayitem::
+ :header: Remote filesystem and FSSPEC
+ :description: Accessing files from cloud storage providers
+ :col_css: col-md-12
+ :button_link: ../common/remote_fs.html
+ :height: 100
+
+.. displayitem::
+ :header: Strategy
+ :description: Ways the trainer controls the model distribution across training, evaluation, and prediction
+ :col_css: col-md-12
+ :button_link: ../extensions/strategy.html
+ :height: 100
+
+.. displayitem::
+ :header: Strategy registry
+ :description: A class that holds information about training strategies and allows adding new custom strategies
+ :col_css: col-md-12
+ :button_link: ../advanced/strategy_registry.html
+ :height: 100
+
+.. displayitem::
+ :header: Style guide
+ :description: Best practices to improve readability and reproducibility
+ :col_css: col-md-12
+ :button_link: ../starter/style_guide.html
+ :height: 100
+
+.. displayitem::
+ :header: SWA
+ :description: Stochastic Weight Averaging (SWA) can make your models generalize better
+ :col_css: col-md-12
+ :button_link: ../advanced/training_tricks.html#stochastic-weight-averaging
+ :height: 100
+
+.. displayitem::
+ :header: SLURM
+ :description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters
+ :col_css: col-md-12
+ :button_link: ../clouds/cluster_advanced.html
+ :height: 100
+
+.. displayitem::
+ :header: Transfer learning
+ :description: Using pre-trained models to improve learning
+ :col_css: col-md-12
+ :button_link: ../advanced/transfer_learning.html
+ :height: 100
+
+.. displayitem::
+ :header: Trainer
+ :description: The class that automates and customizes model training
+ :col_css: col-md-12
+ :button_link: ../common/trainer.html
+ :height: 100
+
+.. displayitem::
+ :header: Torch distributed
+ :description: Setup for running on distributed environments
+ :col_css: col-md-12
+ :button_link: ../clouds/cluster_intermediate_2.html
+ :height: 100
+
+.. raw:: html
+
+
+
diff --git a/docs/source-pytorch/index.rst b/docs/source-pytorch/index.rst
index 9560190717931..0bd78b6fb63a8 100644
--- a/docs/source-pytorch/index.rst
+++ b/docs/source-pytorch/index.rst
@@ -165,7 +165,7 @@ Current Lightning Users
levels/expert
.. toctree::
- :maxdepth: 2
+ :maxdepth: 1
:name: pl_docs
:caption: Core API
@@ -173,105 +173,30 @@ Current Lightning Users
common/trainer
.. toctree::
- :maxdepth: 2
+ :maxdepth: 1
:name: api
- :caption: API Reference
+ :caption: Optional API
api_references
.. toctree::
:maxdepth: 1
- :name: Common Workflows
- :caption: Common Workflows
-
- Avoid overfitting
- model/build_model.rst
- cli/lightning_cli
- common/progress_bar
- deploy/production
- advanced/training_tricks
- tuning/profiler
- Manage experiments
- Organize existing PyTorch into Lightning
- clouds/cluster
- Save and load model progress
- Save memory with half-precision
- advanced/model_parallel
- Train on single or multiple GPUs
- Train on single or multiple HPUs
- Train on single or multiple IPUs
- Train on single or multiple TPUs
- Train on MPS
- Use a pretrained model
- data/data
- model/own_your_loop
-
-.. toctree::
- :maxdepth: 1
- :name: Glossary
- :caption: Glossary
-
- Accelerators
- Callback
- Checkpointing
- Cluster
- Cloud checkpoint
- Console Logging
- Debugging
- Early stopping
- Experiment manager (Logger)
- Finetuning
- Flash
- GPU
- Half precision
- HPU
- Inference
- IPU
- Lightning CLI
- LightningDataModule
- LightningModule
- Log
- TPU
- Metrics
- Model
- Model Parallel
- Plugins
- Progress bar
- Production
- Predict
- Pretrained models
- Profiler
- Pruning and Quantization
- Remote filesystem and FSSPEC
- Strategy
- Strategy registry
- Style guide
- SWA
- SLURM
- Transfer learning
- Trainer
- Torch distributed
-
-.. toctree::
- :maxdepth: 1
- :name: Hands-on Examples
- :caption: Hands-on Examples
+ :name: Examples
+ :caption: Examples
:glob:
notebooks/**/*
+
.. toctree::
:maxdepth: 1
- :name: Community
- :caption: Community
-
- generated/CODE_OF_CONDUCT.md
- generated/CONTRIBUTING.md
- generated/BECOMING_A_CORE_CONTRIBUTOR.md
- governance
- versioning
- past_versions
- generated/CHANGELOG.md
+ :name: More
+ :caption: More
+
+ Community
+ Glossary
+ How to
+
.. raw:: html
diff --git a/docs/source-pytorch/past_versions.rst b/docs/source-pytorch/past_versions.rst
index be8f516824681..374303735ed32 100644
--- a/docs/source-pytorch/past_versions.rst
+++ b/docs/source-pytorch/past_versions.rst
@@ -3,7 +3,7 @@ Past PyTorch Lightning versions
PyTorch Lightning :doc:`evolved over time `. Here's the history of versions with links to their respective docs.
-To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/migration_guide>`.
+To help you keep up to speed, check the :doc:`Migration guide <../upgrade/migration_guide>`.
.. list-table:: Past versions
:widths: 5 50 30 15
@@ -22,7 +22,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig
`1.9.3 `_,
`1.9.4 `_,
`1.9.5 `_
- - :doc:`from 1.9 to 2.0 `
+ - :doc:`from 1.9 to 2.0 <../upgrade/from_1_9>`
* - `1.8 `_
- `Colossal-AI Strategy, Commands and Secrets for Apps, FSDP Improvements and More! `_
@@ -33,7 +33,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig
`1.8.4 `_,
`1.8.5 `_,
`1.8.6 `_
- - :doc:`from 1.8 to 2.0 `
+ - :doc:`from 1.8 to 2.0 <../upgrade/from_1_8>`
* - `1.7 `_
- `Apple Silicon support, Native FSDP, Collaborative training, and multi-GPU support with Jupyter notebooks `_
@@ -45,7 +45,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig
`1.7.5 `_,
`1.7.6 `_,
`1.7.7 `_
- - :doc:`from 1.7 to 2.0 `
+ - :doc:`from 1.7 to 2.0 <../upgrade/from_1_7>`
* - `1.6 `_
- `Support Intel's Habana Accelerator, New efficient DDP strategy (Bagua), Manual Fault-tolerance, Stability and Reliability `_
@@ -55,7 +55,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig
`1.6.3 `_,
`1.6.4 `_,
`1.6.5 `_
- - :doc:`from 1.6 to 2.0 `
+ - :doc:`from 1.6 to 2.0 <../upgrade/from_1_6>`
* - `1.5 `_
- `LightningLite, Fault-Tolerant Training, Loop Customization, Lightning Tutorials, LightningCLI v2, RichProgressBar, CheckpointIO Plugin, and Trainer Strategy Flag `_
@@ -70,7 +70,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig
`1.5.8 `_,
`1.5.9 `_,
`1.5.10 `_
- - :doc:`from 1.5 to 2.0 `
+ - :doc:`from 1.5 to 2.0 <../upgrade/from_1_5>`
* - `1.4 `_
- `TPU Pod Training, IPU Accelerator, DeepSpeed Infinity, Fully Sharded Data Parallel `_
@@ -84,7 +84,7 @@ To help you with keeping up to spead, check :doc:`Migration guide <./upgrade/mig
`1.4.7 `_,
`1.4.8 `_,
`1.4.9 `_
- - :doc:`from 1.4 to 2.0 `
+ - :doc:`from 1.4 to 2.0 <../upgrade/from_1_4>`
* - `1.3 `_
- `Lightning CLI, PyTorch Profiler, Improved Early Stopping `_
diff --git a/src/lightning/app/CHANGELOG.md b/src/lightning/app/CHANGELOG.md
index ac547a4e24c50..23468376248fb 100644
--- a/src/lightning/app/CHANGELOG.md
+++ b/src/lightning/app/CHANGELOG.md
@@ -7,14 +7,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [UnReleased] - 2023-04-DD
+- Added the property `LightningWork.public_ip` that exposes the public IP of the `LightningWork` instance ([#17742](https://github.com/Lightning-AI/lightning/pull/17742))
+
+
+- Added missing `python-multipart` dependency ([#17244](https://github.com/Lightning-AI/lightning/pull/17244))
+
+
### Changed
--
+- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100))
### Fixed
--
+- Fixed `LightningWork.internal_ip` that was mistakenly exposing the public IP instead; now exposes the private/internal IP address ([#17742](https://github.com/Lightning-AI/lightning/pull/17742))
+
+
+- Fixed resolution of latest version in CLI ([#17351](https://github.com/Lightning-AI/lightning/pull/17351))
+
+
+- Fixed property raised instead of returned ([#17595](https://github.com/Lightning-AI/lightning/pull/17595))
+
+
+- Fixed get project ([#17617](https://github.com/Lightning-AI/lightning/pull/17617), [#17666](https://github.com/Lightning-AI/lightning/pull/17666))
## [2.0.2] - 2023-04-24
diff --git a/src/lightning/app/components/database/server.py b/src/lightning/app/components/database/server.py
index 05a21894dc6f7..6da7710cfa4f0 100644
--- a/src/lightning/app/components/database/server.py
+++ b/src/lightning/app/components/database/server.py
@@ -231,9 +231,10 @@ def db_url(self) -> Optional[str]:
use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
if use_localhost:
return self.url
- if self.internal_ip != "":
- return f"http://{self.internal_ip}:{self.port}"
- return self.internal_ip
+ ip_addr = self.public_ip or self.internal_ip
+ if ip_addr != "":
+ return f"http://{ip_addr}:{self.port}"
+ return ip_addr
def on_exit(self):
self._exit_event.set()
diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py
index 1ce7f45bef318..dde7722f5da43 100644
--- a/src/lightning/app/components/serve/auto_scaler.py
+++ b/src/lightning/app/components/serve/auto_scaler.py
@@ -180,9 +180,9 @@ def __init__(
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
def get_internal_url(self) -> str:
- if not self._internal_ip:
- raise ValueError("Internal IP not set")
- return f"http://{self._internal_ip}:{self._port}"
+ if not self._public_ip:
+ raise ValueError("Public IP not set")
+ return f"http://{self._public_ip}:{self._port}"
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
@@ -386,7 +386,7 @@ def update_servers(self, server_works: List[LightningWork]):
"""
old_server_urls = set(self.servers)
current_server_urls = {
- f"http://{server._internal_ip}:{server.port}" for server in server_works if server._internal_ip
+ f"http://{server._public_ip}:{server.port}" for server in server_works if server._internal_ip
}
# doing nothing if no server work has been added/removed
diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py
index 67b3fb0361dfd..f5416c0cfae3d 100644
--- a/src/lightning/app/core/work.py
+++ b/src/lightning/app/core/work.py
@@ -60,6 +60,7 @@ class LightningWork:
"_url",
"_restarting",
"_internal_ip",
+ "_public_ip",
)
_run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
@@ -138,6 +139,7 @@ def __init__(
"_url",
"_future_url",
"_internal_ip",
+ "_public_ip",
"_restarting",
"_cloud_compute",
"_display_name",
@@ -148,6 +150,7 @@ def __init__(
self._url: str = ""
self._future_url: str = "" # The cache URL is meant to defer resolving the url values.
self._internal_ip: str = ""
+ self._public_ip: str = ""
# setattr_replacement is used by the multiprocessing runtime to send the latest changes to the main coordinator
self._setattr_replacement: Optional[Callable[[str, Any], None]] = None
self._name: str = ""
@@ -212,6 +215,15 @@ def internal_ip(self) -> str:
"""
return self._internal_ip
+ @property
+ def public_ip(self) -> str:
+ """The public ip address of this LightningWork, reachable from the internet.
+
+ By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
+ Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster.
+ """
+ return self._public_ip
+
def _on_init_end(self) -> None:
self._local_build_config.on_work_init(self)
self._cloud_build_config.on_work_init(self, self._cloud_compute)
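
For context, a minimal usage sketch of the new `public_ip` property (the flow and work names below are made up for illustration): both addresses start out as empty strings locally, and in the cloud the cluster fills them in once the work is running.

```python
from lightning.app import LightningApp, LightningFlow, LightningWork


class Server(LightningWork):
    def run(self):
        # start a server bound to 0.0.0.0:self.port here
        ...


class Flow(LightningFlow):
    def __init__(self):
        super().__init__()
        self.server = Server()

    def run(self):
        self.server.run()
        if self.server.public_ip:  # "" until the cloud cluster assigns it
            print(f"reachable from the internet: http://{self.server.public_ip}:{self.server.port}")
        elif self.server.internal_ip:  # private address, reachable only inside the cluster
            print(f"reachable inside the cluster: http://{self.server.internal_ip}:{self.server.port}")


app = LightningApp(Flow())
```
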
diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py
index 9c5a53f2b98e6..39d33785068ac 100644
--- a/src/lightning/app/utilities/proxies.py
+++ b/src/lightning/app/utilities/proxies.py
@@ -494,7 +494,8 @@ def run_once(self):
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" # noqa: S104
- self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip)
+ self.work._internal_ip = os.environ.get("LIGHTNING_NODE_PRIVATE_IP", default_internal_ip)
+ self.work._public_ip = os.environ.get("LIGHTNING_NODE_IP", "")
# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# send delta while calling `set_state`.
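
The environment-variable precedence introduced here can be summarized in a small standalone sketch (not library code; only the variable names are taken from the change above):

```python
import os
from typing import Optional, Tuple


def resolve_ips(cloudspace_host: Optional[str]) -> Tuple[str, str]:
    """Mirror of the resolution above; returns (public_ip, internal_ip)."""
    default_internal = "127.0.0.1" if cloudspace_host is None else "0.0.0.0"  # noqa: S104
    internal_ip = os.environ.get("LIGHTNING_NODE_PRIVATE_IP", default_internal)
    public_ip = os.environ.get("LIGHTNING_NODE_IP", "")  # stays "" until the cloud provides it
    return public_ip, internal_ip
```
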
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 043ae2aab61e3..04215ccb8289a 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -7,9 +7,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [UnReleased] - 2023-04-DD
+- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
+
+
+- Added Fabric internal hooks ([#17759](https://github.com/Lightning-AI/lightning/pull/17759))
+
+
### Changed
--
+- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100))
+
+
+- Support compiling a module after it was set up by Fabric ([#17529](https://github.com/Lightning-AI/lightning/pull/17529))
### Fixed
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index f12dc56ce707e..1167a92358d8c 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -43,6 +43,7 @@
has_iterable_dataset,
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
+from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -105,8 +106,7 @@ def __init__(
self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator
self._precision: Precision = self._strategy.precision
- callbacks = callbacks if callbacks is not None else []
- self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
+ self._callbacks = self._configure_callbacks(callbacks)
loggers = loggers if loggers is not None else []
self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0
@@ -212,7 +212,10 @@ def setup(
# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
- optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+ optimizers = [
+ _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
+ for optimizer in optimizers
+ ]
self._models_setup += 1
@@ -220,6 +223,8 @@ def setup(
original_module._fabric = self # type: ignore[assignment]
original_module._fabric_optimizers = optimizers # type: ignore[assignment]
+ self.call("on_after_setup", fabric=self, module=module)
+
if optimizers:
# join both types in a tuple for API convenience
return (module, *optimizers)
@@ -276,7 +281,10 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu
"""
self._validate_setup_optimizers(optimizers)
optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
- optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+ optimizers = [
+ _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
+ for optimizer in optimizers
+ ]
return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)
def setup_dataloaders(
@@ -838,6 +846,13 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
+ @staticmethod
+ def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
+ callbacks = callbacks if callbacks is not None else []
+ callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
+ callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
+ return callbacks
+
def _is_using_cli() -> bool:
return bool(int(os.environ.get("LT_CLI_USED", "0")))
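
A minimal sketch of how the new hooks surface to users (the callback class here is hypothetical; `on_after_setup` fires at the end of `Fabric.setup`, and `on_after_optimizer_step` fires after each optimizer step via the wrapper change in `wrappers.py` below):

```python
import torch
from lightning.fabric import Fabric


class TrackingCallback:
    def on_after_setup(self, fabric, module):
        print(f"setup finished for {type(module).__name__}")

    def on_after_optimizer_step(self, strategy, optimizer):
        print("optimizer step completed")


fabric = Fabric(accelerator="cpu", callbacks=[TrackingCallback()])
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)  # triggers on_after_setup

loss = model(torch.randn(4, 2)).sum()
fabric.backward(loss)
optimizer.step()  # triggers on_after_optimizer_step
```
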
diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py
index bdf3d6f3b34c6..6c63ecaa1eeb1 100644
--- a/src/lightning/fabric/utilities/imports.py
+++ b/src/lightning/fabric/utilities/imports.py
@@ -30,3 +30,6 @@
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
+
+_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
+_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py
index 10903d039ec7d..4c3c96dc5803e 100644
--- a/src/lightning/fabric/utilities/registry.py
+++ b/src/lightning/fabric/utilities/registry.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import Any
+import logging
+from typing import Any, List, Union
+
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+
+_log = logging.getLogger(__name__)
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
@@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
return False
return mod_attr.__code__ is not super_attr.__code__
+
+
+def _load_external_callbacks(group: str) -> List[Any]:
+ """Collect external callbacks registered through entry points.
+
+ The entry points are expected to be functions returning a list of callbacks.
+
+ Args:
+ group: The entry point group name to load callbacks from.
+
+ Return:
+ A list of all callbacks collected from external factories.
+ """
+ if _PYTHON_GREATER_EQUAL_3_8_0:
+ from importlib.metadata import entry_points
+
+ factories = (
+ entry_points(group=group)
+ if _PYTHON_GREATER_EQUAL_3_10_0
+ else entry_points().get(group, {}) # type: ignore[arg-type]
+ )
+ else:
+ from pkg_resources import iter_entry_points
+
+ factories = iter_entry_points(group) # type: ignore[assignment]
+
+ external_callbacks: List[Any] = []
+ for factory in factories:
+ callback_factory = factory.load()
+ callbacks_list: Union[List[Any], Any] = callback_factory()
+ callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
+ _log.info(
+ f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
+ f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
+ )
+ external_callbacks.extend(callbacks_list)
+ return external_callbacks
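
A hedged sketch of how a third-party package would plug into this helper: the mechanism is the same entry-point registration the PyTorch callback connector already used, only the group name is Fabric-specific. The package and symbol names below are made up.

```python
# my_package/callbacks.py (hypothetical third-party package)
class ExternalFabricCallback:
    def on_after_optimizer_step(self, strategy, optimizer):
        print("external callback: optimizer stepped")


def callbacks_factory():
    # The factory may return a single callback or a list; both are handled above.
    return [ExternalFabricCallback()]


# Registered in that package's setup.py (or the pyproject.toml equivalent):
#
#   setup(
#       ...,
#       entry_points={
#           "lightning.fabric.callbacks_factory": [
#               "external = my_package.callbacks:callbacks_factory",
#           ],
#       },
#   )
```
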
diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py
index 947510244d0df..8ed10f00c5511 100644
--- a/src/lightning/fabric/wrappers.py
+++ b/src/lightning/fabric/wrappers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, overload, TypeVar, Union
import torch
from lightning_utilities import WarningCache
@@ -38,7 +38,7 @@
class _FabricOptimizer:
- def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
+ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
"""FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the
optimizer step calls to the strategy plugin.
@@ -54,6 +54,7 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
self._optimizer = optimizer
self._strategy = strategy
+ self._callbacks = callbacks or []
@property
def optimizer(self) -> Optimizer:
@@ -69,10 +70,15 @@ def step(self, closure: Optional[Callable] = None) -> Any:
optimizer = self._strategy.model
else:
optimizer = self.optimizer
- return self._strategy.optimizer_step(
+ output = self._strategy.optimizer_step(
optimizer,
**kwargs,
)
+ for callback in self._callbacks:
+ hook = getattr(callback, "on_after_optimizer_step", None)
+ if callable(hook):
+ hook(strategy=self._strategy, optimizer=optimizer)
+ return output
class _FabricModule(_DeviceDtypeModuleMixin):
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 0500e336f2268..5f7ed4645d847 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
+- Made type hints public ([#17100](https://github.com/Lightning-AI/lightning/pull/17100))
+
+
-
@@ -41,6 +44,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an edge case causing overlapping samples in DDP when no global seed is set ([#17713](https://github.com/Lightning-AI/lightning/pull/17713))
+- Fallback to module available check for mlflow ([#17467](https://github.com/Lightning-AI/lightning/pull/17467))
+
+
+- Fixed LR finder max val batches ([#17636](https://github.com/Lightning-AI/lightning/pull/17636))
+
+
+- Fixed multithreading checkpoint loading ([#17678](https://github.com/Lightning-AI/lightning/pull/17678))
+
+
## [2.0.2] - 2023-04-24
### Fixed
diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py
index 8195a5c4a3b52..d649755172658 100644
--- a/src/lightning/pytorch/trainer/connectors/callback_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -18,6 +18,7 @@
from typing import Dict, List, Optional, Sequence, Union
import lightning.pytorch as pl
+from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.pytorch.callbacks import (
Callback,
Checkpoint,
@@ -33,7 +34,6 @@
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info
@@ -75,7 +75,7 @@ def on_trainer_init(
# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary)
- self.trainer.callbacks.extend(_configure_external_callbacks())
+ self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
_validate_callbacks_list(self.trainer.callbacks)
# push all model checkpoint callbacks to the end
@@ -213,42 +213,6 @@ def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
return tuner_callbacks + other_callbacks + checkpoint_callbacks
-def _configure_external_callbacks() -> List[Callback]:
- """Collect external callbacks registered through entry points.
-
- The entry points are expected to be functions returning a list of callbacks.
-
- Return:
- A list of all callbacks collected from external factories.
- """
- group = "lightning.pytorch.callbacks_factory"
-
- if _PYTHON_GREATER_EQUAL_3_8_0:
- from importlib.metadata import entry_points
-
- factories = (
- entry_points(group=group)
- if _PYTHON_GREATER_EQUAL_3_10_0
- else entry_points().get(group, {}) # type: ignore[arg-type]
- )
- else:
- from pkg_resources import iter_entry_points
-
- factories = iter_entry_points(group) # type: ignore[assignment]
-
- external_callbacks: List[Callback] = []
- for factory in factories:
- callback_factory = factory.load()
- callbacks_list: Union[List[Callback], Callback] = callback_factory()
- callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
- _log.info(
- f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
- f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
- )
- external_callbacks.extend(callbacks_list)
- return external_callbacks
-
-
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
seen_callbacks = set()
diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py
index 1f48386e3ff86..7e6b7cd0c5e91 100644
--- a/src/lightning/pytorch/trainer/connectors/signal_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -11,8 +11,7 @@
import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
-from lightning.fabric.utilities.imports import _IS_WINDOWS
-from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch.utilities.rank_zero import rank_zero_info
# copied from signal.pyi
diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py
index 259f18070362b..bfb1eeb5c5174 100644
--- a/src/lightning/pytorch/utilities/imports.py
+++ b/src/lightning/pytorch/utilities/imports.py
@@ -17,8 +17,6 @@
import torch
from lightning_utilities.core.imports import package_available, RequirementCache
-_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
-_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
diff --git a/src/version.info b/src/version.info
index e9307ca5751b2..50ffc5aa7f69f 100644
--- a/src/version.info
+++ b/src/version.info
@@ -1 +1 @@
-2.0.2
+2.0.3
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
index f5933747293e6..ed451475b5e80 100644
--- a/tests/tests_app/core/test_lightning_app.py
+++ b/tests/tests_app/core/test_lightning_app.py
@@ -119,6 +119,7 @@ def test_simple_app(tmpdir):
"_url": "",
"_future_url": "",
"_internal_ip": "",
+ "_public_ip": "",
"_paths": {},
"_port": None,
"_restarting": False,
@@ -136,6 +137,7 @@ def test_simple_app(tmpdir):
"_url": "",
"_future_url": "",
"_internal_ip": "",
+ "_public_ip": "",
"_paths": {},
"_port": None,
"_restarting": False,
@@ -982,7 +984,7 @@ def run(self):
def test_state_size_constant_growth():
app = LightningApp(SizeFlow())
MultiProcessRuntime(app, start_server=False).dispatch()
- assert app.root._state_sizes[0] <= 7965
+ assert app.root._state_sizes[0] <= 8304
assert app.root._state_sizes[20] <= 26550
diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py
index 9b174f5a3af8a..fe66a9bb10124 100644
--- a/tests/tests_app/core/test_lightning_flow.py
+++ b/tests/tests_app/core/test_lightning_flow.py
@@ -324,6 +324,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -349,6 +350,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -388,6 +390,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -413,6 +416,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py
index f921607a50734..4cbe391027d8a 100644
--- a/tests/tests_app/structures/test_structures.py
+++ b/tests/tests_app/structures/test_structures.py
@@ -46,6 +46,7 @@ def run(self):
"_restarting": False,
"_display_name": "",
"_internal_ip": "",
+ "_public_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "cpu-small",
@@ -80,6 +81,7 @@ def run(self):
"_restarting": False,
"_display_name": "",
"_internal_ip": "",
+ "_public_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "cpu-small",
@@ -114,6 +116,7 @@ def run(self):
"_restarting": False,
"_display_name": "",
"_internal_ip": "",
+ "_public_ip": "",
"_cloud_compute": {
"type": "__cloud_compute__",
"name": "cpu-small",
@@ -199,6 +202,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -233,6 +237,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
@@ -262,6 +267,7 @@ def run(self):
"_paths": {},
"_restarting": False,
"_internal_ip": "",
+ "_public_ip": "",
"_display_name": "",
"_cloud_compute": {
"type": "__cloud_compute__",
diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py
index a4661e2736aaa..8dbfe1b06b871 100644
--- a/tests/tests_app/utilities/test_proxies.py
+++ b/tests/tests_app/utilities/test_proxies.py
@@ -641,16 +641,21 @@ def test_state_observer():
@pytest.mark.parametrize(
- ("patch_constants", "environment", "expected_ip_addr"),
+ ("patch_constants", "environment", "expected_public_ip", "expected_private_ip"),
[
- ({}, {}, "127.0.0.1"),
- ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"), # noqa: S104
- ({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"),
+ ({}, {}, "", "127.0.0.1"),
+ ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "", "0.0.0.0"), # noqa: S104
+ (
+ {},
+ {"LIGHTNING_NODE_IP": "85.44.2.25", "LIGHTNING_NODE_PRIVATE_IP": "10.10.10.5"},
+ "85.44.2.25",
+ "10.10.10.5",
+ ),
],
indirect=["patch_constants"],
)
-def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_addr):
- """Test that the WorkRunner updates the internal ip address as soon as the Work starts running."""
+def test_work_runner_sets_public_and_private_ip(patch_constants, environment, expected_public_ip, expected_private_ip):
+ """Test that the WorkRunner updates the public and private address as soon as the Work starts running."""
class Work(LightningWork):
def run(self):
@@ -690,11 +695,13 @@ def run(self):
with mock.patch.dict(os.environ, environment, clear=True):
work_runner.setup()
- # The internal ip address only becomes available once the hardware is up / the work is running.
+ # The public ip address only becomes available once the hardware is up / the work is running.
+ assert work.public_ip == ""
assert work.internal_ip == ""
with contextlib.suppress(Empty):
work_runner.run_once()
- assert work.internal_ip == expected_ip_addr
+ assert work.public_ip == expected_public_ip
+ assert work.internal_ip == expected_private_ip
class WorkBi(LightningWork):
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 58ab0e34166ed..015d1aba41359 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -753,6 +753,29 @@ def test_call():
assert not callback1.mock_calls
+def test_special_callbacks():
+ """Tests special callbacks that have hooks for internal Fabric events."""
+
+ class SpecialCallback:
+ def on_after_optimizer_step(self, strategy, optimizer):
+ pass
+
+ def on_after_setup(self, fabric, module):
+ pass
+
+ callback = Mock(wraps=SpecialCallback())
+ fabric = Fabric(accelerator="cpu", callbacks=[callback])
+
+ model = torch.nn.Linear(2, 2)
+ optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+ fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
+ callback.on_after_setup.assert_called_once_with(fabric=fabric, module=fabric_model)
+
+ model(torch.randn(2, 2)).sum().backward()
+ fabric_optimizer.step()
+ callback.on_after_optimizer_step.assert_called_once_with(strategy=fabric._strategy, optimizer=optimizer)
+
+
def test_loggers_input():
"""Test the various ways in which loggers can be registered with Fabric."""
logger0 = Mock()
diff --git a/tests/tests_fabric/utilities/test_registry.py b/tests/tests_fabric/utilities/test_registry.py
new file mode 100644
index 0000000000000..75e6e12f5abff
--- /dev/null
+++ b/tests/tests_fabric/utilities/test_registry.py
@@ -0,0 +1,64 @@
+import contextlib
+from unittest import mock
+from unittest.mock import Mock
+
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
+from lightning.fabric.utilities.registry import _load_external_callbacks
+
+
+class ExternalCallback:
+ """A callback in another library that gets registered through entry points."""
+
+ pass
+
+
+def test_load_external_callbacks():
+ """Test that the connector collects Callback instances from factories registered through entry points."""
+
+ def factory_no_callback():
+ return []
+
+ def factory_one_callback():
+ return ExternalCallback()
+
+ def factory_one_callback_list():
+ return [ExternalCallback()]
+
+ def factory_multiple_callbacks_list():
+ return [ExternalCallback(), ExternalCallback()]
+
+ with _make_entry_point_query_mock(factory_no_callback):
+ callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
+ assert callbacks == []
+
+ with _make_entry_point_query_mock(factory_one_callback):
+ callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
+ assert isinstance(callbacks[0], ExternalCallback)
+
+ with _make_entry_point_query_mock(factory_one_callback_list):
+ callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
+ assert isinstance(callbacks[0], ExternalCallback)
+
+ with _make_entry_point_query_mock(factory_multiple_callbacks_list):
+ callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
+ assert isinstance(callbacks[0], ExternalCallback)
+ assert isinstance(callbacks[1], ExternalCallback)
+
+
+@contextlib.contextmanager
+def _make_entry_point_query_mock(callback_factory):
+ query_mock = Mock()
+ entry_point = Mock()
+ entry_point.name = "mocked"
+ entry_point.load.return_value = callback_factory
+ if _PYTHON_GREATER_EQUAL_3_10_0:
+ query_mock.return_value = [entry_point]
+ import_path = "importlib.metadata.entry_points"
+ elif _PYTHON_GREATER_EQUAL_3_8_0:
+ query_mock().get.return_value = [entry_point]
+ import_path = "importlib.metadata.entry_points"
+ else:
+ query_mock.return_value = [entry_point]
+ import_path = "pkg_resources.iter_entry_points"
+ with mock.patch(import_path, query_mock):
+ yield
diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
index 71262df9179e8..58f59ad760763 100644
--- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py
@@ -19,6 +19,7 @@
import pytest
import torch
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import (
EarlyStopping,
@@ -32,7 +33,6 @@
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
-from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
def test_checkpoint_callbacks_are_last(tmpdir):
diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index 36f6356f995cb..cea40b921e1a5 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -25,6 +25,7 @@
import torch
from torch import Tensor
+from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch import callbacks, Trainer
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
@@ -32,7 +33,6 @@
from lightning.pytorch.loops import _EvaluationLoop
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from tests_pytorch.helpers.runif import RunIf
if _RICH_AVAILABLE: