Sequential CombinedLoader to flatten the eval and predict loops #16726

Merged — 54 commits, Feb 17, 2023 (changes shown from 52 of 54 commits)
Commits (all by carmocca):

e816ddd  Feb 10, 2023  Refactor CombinedLoader using pytrees
ad69804  Feb 10, 2023  Refactor CombinedLoader using pytrees
98385ba  Feb 10, 2023  CHANGELOG
177ddc6  Feb 10, 2023  Merge branch 'master' into refactor/combined-loader
c5fad2c  Feb 10, 2023  Minor changes
ee451a2  Feb 10, 2023  TODO
20d3961  Feb 10, 2023  Remove unused name
c52f3b0  Feb 10, 2023  Class structure for modeiterators
55b43a1  Feb 11, 2023  The great dataloader refactor
03c7283  Feb 11, 2023  WIP: flatten
1e481ee  Feb 12, 2023  Merge branch 'master' into refactor/dataloader-everything
5c2b8a1  Feb 13, 2023  Predict cleanup
da8fd06  Feb 14, 2023  Merge branch 'master' into refactor/dataloader-everything
289e1ab  Feb 14, 2023  Hello darkness my old friend
81c7189  Feb 14, 2023  Merge branch 'master' into refactor/dataloader-everything
945a44b  Feb 14, 2023  Cleaning up
c7e0ff9  Feb 14, 2023  Flattening eval loop
151661d  Feb 14, 2023  Minor fixes
445ef53  Feb 14, 2023  Fix step kwargs
8a85c95  Feb 14, 2023  Merge branch 'master' into refactor/dataloader-everything
5e36e53  Feb 14, 2023  Limits
c251daa  Feb 14, 2023  Fix tqdm
14b5c78  Feb 15, 2023  Fixes
448598c  Feb 15, 2023  Mypy fun times
e62b0b1  Feb 15, 2023  Return dataloader_idx from SequentialMode
1af92bf  Feb 15, 2023  Forgot +1
6dc04d9  Feb 15, 2023  carmocca
44ab979  Feb 15, 2023  Fix
dd705d0  Feb 15, 2023  Merge branch 'master' into refactor/dataloader-everything
ca93edd  Feb 15, 2023  Shine
2e4514c  Feb 15, 2023  Fixing tests...
61fb921  Feb 15, 2023  Merge branch 'master' into refactor/dataloader-everything
7c8cbca  Feb 15, 2023  Fix tpu spawn tests
19accde  Feb 15, 2023  Merge branch 'master' into refactor/dataloader-everything
7886e44  Feb 15, 2023  Merge branch 'master' into refactor/dataloader-everything
67588bc  Feb 16, 2023  WIP: setup_data
796d3b7  Feb 16, 2023  Merge branch 'master' into refactor/dataloader-everything
b6d81a0  Feb 16, 2023  Fix logging
84f3e02  Feb 16, 2023  mypy
fbe1def  Feb 16, 2023  Fix tests
781120c  Feb 16, 2023  Fix batch size finder
d6ffb36  Feb 16, 2023  TODOs and FIXMEs
7f98fa1  Feb 16, 2023  Merge branch 'master' into refactor/dataloader-everything
f5038cb  Feb 16, 2023  Move single device test
0a1d8e0  Feb 16, 2023  Update CHANGELOG
b6d7ee4  Feb 16, 2023  TODO
18f2133  Feb 16, 2023  Tests
8c970b5  Feb 16, 2023  mypy
4ffba92  Feb 16, 2023  Typo
646b266  Feb 16, 2023  Merge branch 'master' into refactor/dataloader-everything
227f221  Feb 17, 2023  Merge branch 'master' into refactor/dataloader-everything
f92cd41  Feb 17, 2023  Jirka's suggestions
8120ab0  Feb 17, 2023  Remove is_fresh_start_epoch
94c7c3f  Feb 17, 2023  fit not necessary
18 changes: 18 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -42,6 +42,12 @@
- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))


- Added support for `predict_step(dataloader_iter, batch_index)` ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Added support for arbitrary iterables as dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784))

### Changed
@@ -87,6 +93,12 @@
- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))


- The top-level loops now own the data sources and combined dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- The `trainer.*_dataloader` properties now return what the user returned in their `LightningModule.*_dataloader()` hook ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- The `dataloader_idx` argument is now optional for the `on_{validation,test,predict}_batch_{start,end}` hooks. Remove it or default it to 0 if you don't use multiple dataloaders ([#16753](https://github.com/Lightning-AI/lightning/pull/16753))
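
A minimal sketch of the two hook signatures that entry now accepts (hypothetical module; the first form assumes a single validation dataloader):

```python
import lightning.pytorch as pl


class HookedModel(pl.LightningModule):
    # Single dataloader: `dataloader_idx` may now be omitted entirely.
    def on_validation_batch_start(self, batch, batch_idx):
        ...

    # Multiple dataloaders: keep the argument, or default it to 0.
    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        ...
```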


@@ -210,6 +222,12 @@
* The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664))


- Removed the `DataLoaderLoop`, `EvaluationEpochLoop`, and `PredictionEpochLoop` classes ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Removed `trainer.reset_*_dataloader()` methods in favor of `Loop.setup_data()` for the top-level loops ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute
* Removed the `LightningModule.tbptt_split_batch` hook
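To illustrate the `predict_step(dataloader_iter)` entry from the `Added` section above, here is a rough sketch (hypothetical model; parameter names follow Lightning's `batch_idx` convention, and batches are assumed to be float tensors of shape `(N, 4)`):

```python
import torch

import lightning.pytorch as pl


class IterPredictModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    # Requesting `dataloader_iter` instead of `batch` hands the raw iterator
    # to the step, so it can pull one or more batches itself (e.g. to
    # pipeline inference).
    def predict_step(self, dataloader_iter, batch_idx):
        batch = next(dataloader_iter)
        return self.layer(batch)
```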
20 changes: 10 additions & 10 deletions src/lightning/pytorch/callbacks/batch_size_finder.py
@@ -128,24 +128,25 @@ def __init__(
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         if trainer._accelerator_connector.is_distributed:
             raise MisconfigurationException("The Batch size finder is not supported with distributed strategies.")

-        running_stage = trainer.state.stage
-        assert running_stage is not None
-        dl_source = getattr(trainer._data_connector, f"_{running_stage.dataloader_prefix}_dataloader_source")
-
         # TODO: check if this can be enabled (#4040)
-        if not trainer._data_connector._train_dataloader_source.is_module():
+        if not trainer.fit_loop._data_source.is_module():
             raise MisconfigurationException(
                 "The Batch size finder cannot be used with dataloaders passed directly to `.fit()`. Please disable"
                 " the feature or incorporate the dataloader into your LightningModule or LightningDataModule."
             )

         # TODO: Add support for multiple eval dataloader
         if stage != "fit":
-            dataloaders = dl_source.dataloader()
-            if isinstance(dataloaders, list) and len(dataloaders) > 1:
+            loop = trainer._active_loop
+            assert loop is not None
+            loop.setup_data()
+            combined_loader = loop._combined_loader
+            assert combined_loader is not None
+            if len(combined_loader._flattened) > 1:
+                stage = trainer.state.stage
+                assert stage is not None
                 raise MisconfigurationException(
-                    f"The Batch size finder cannot be used with multiple {running_stage.dataloader_prefix} dataloaders."
+                    f"The Batch size finder cannot be used with multiple {stage.dataloader_prefix} dataloaders."
                 )

         if not lightning_hasattr(pl_module, self._batch_arg_name):
@@ -167,7 +168,6 @@
     def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         new_size = _scale_batch_size(
             trainer,
-            pl_module,
             self._mode,
             self._steps_per_trial,
             self._init_val,
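For context on the checks above, a usage sketch of the callback (hypothetical `model` and `datamodule`; the finder requires dataloaders to come from the LightningModule or LightningDataModule, and only one dataloader per evaluation stage):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder

# Scales `model.batch_size` (or `datamodule.batch_size`) until it hits OOM.
trainer = Trainer(callbacks=[BatchSizeFinder(mode="power", init_val=2)])
trainer.fit(model, datamodule=datamodule)  # both defined elsewhere
```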
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/prediction_writer.py
@@ -140,7 +140,7 @@ def on_predict_batch_end(
     ) -> None:
         if not self.interval.on_batch:
             return
-        batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices
+        batch_indices = trainer.predict_loop.current_batch_indices
         self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

     def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
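The attribute read above is what a `BasePredictionWriter` subclass receives as `batch_indices`. A minimal sketch of such a callback (hypothetical writer that only prints):

```python
from lightning.pytorch.callbacks import BasePredictionWriter


class PrintingWriter(BasePredictionWriter):
    def __init__(self):
        super().__init__(write_interval="batch")

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # `batch_indices` is now read from `trainer.predict_loop.current_batch_indices`
        # (the `epoch_loop` indirection was removed in this PR).
        print(dataloader_idx, batch_indices)
```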
5 changes: 3 additions & 2 deletions src/lightning/pytorch/loops/__init__.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from lightning.pytorch.loops.loop import _Loop  # noqa: F401 isort: skip (avoids circular imports)
-from lightning.pytorch.loops.dataloader import _DataLoaderLoop, _EvaluationLoop, _PredictionLoop  # noqa: F401
-from lightning.pytorch.loops.epoch import _EvaluationEpochLoop, _PredictionEpochLoop, _TrainingEpochLoop  # noqa: F401
+from lightning.pytorch.loops.epoch import _TrainingEpochLoop  # noqa: F401
+from lightning.pytorch.loops.evaluation_loop import _EvaluationLoop  # noqa: F401
 from lightning.pytorch.loops.fit_loop import _FitLoop  # noqa: F401
 from lightning.pytorch.loops.optimization import _AutomaticOptimization, _ManualOptimization  # noqa: F401
+from lightning.pytorch.loops.prediction_loop import _PredictionLoop  # noqa: F401
17 changes: 0 additions & 17 deletions src/lightning/pytorch/loops/dataloader/__init__.py

This file was deleted.

68 changes: 0 additions & 68 deletions src/lightning/pytorch/loops/dataloader/dataloader_loop.py

This file was deleted.

178 changes: 0 additions & 178 deletions src/lightning/pytorch/loops/dataloader/prediction_loop.py

This file was deleted.

2 changes: 0 additions & 2 deletions src/lightning/pytorch/loops/epoch/__init__.py
@@ -12,6 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from lightning.pytorch.loops.epoch.evaluation_epoch_loop import _EvaluationEpochLoop  # noqa: F401
-from lightning.pytorch.loops.epoch.prediction_epoch_loop import _PredictionEpochLoop  # noqa: F401
 from lightning.pytorch.loops.epoch.training_epoch_loop import _TrainingEpochLoop  # noqa: F401