[Feat] Add Loops Restart #8131

Closed
wants to merge 668 commits
Changes from all commits (668 commits)
1763d8f
test
awaelchli Jun 7, 2021
d718498
update trainer
awaelchli Jun 7, 2021
6d98a07
integrate latest changes from logger connector refactor poc
awaelchli Jun 7, 2021
7ca1049
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
515ad9f
Minor changes
carmocca Jun 7, 2021
b03591c
update changelog
awaelchli Jun 7, 2021
0aa8428
Remove unused argument
carmocca Jun 7, 2021
24b41e3
Update CHANGELOG
carmocca Jun 7, 2021
6d71e6a
Copy call_hook changes
carmocca Jun 7, 2021
44ad4ac
Docs
carmocca Jun 7, 2021
2c74018
Fix ref
carmocca Jun 7, 2021
b15984b
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 7, 2021
e8021bb
merge
tchaton Jun 8, 2021
9747023
move to cpu
tchaton Jun 8, 2021
d9ae37a
Bad merge
carmocca Jun 8, 2021
bad51c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
273bc92
remove pdb
tchaton Jun 8, 2021
f214632
remove pdb
tchaton Jun 8, 2021
5fdf3c5
merge
tchaton Jun 8, 2021
99543a7
Refactor to
carmocca Jun 8, 2021
738c810
Avoid partial
carmocca Jun 8, 2021
6a7637d
trigger ci
carmocca Jun 8, 2021
8077cf9
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
aff9e3d
Bad merge
carmocca Jun 8, 2021
464f581
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
461332b
integrate latest logger connector changes
awaelchli Jun 8, 2021
417ad31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
9321b11
remove grad norm dicts list
awaelchli Jun 8, 2021
e75a958
Diff
carmocca Jun 8, 2021
007dcac
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
2e4bb24
Bad merge
carmocca Jun 8, 2021
f5154ae
Reuse metrics_to_scalars
carmocca Jun 8, 2021
558cdf4
Use active loop
carmocca Jun 8, 2021
90d71bf
Move to device
carmocca Jun 8, 2021
d7f1761
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
6ce6762
resolve test
tchaton Jun 8, 2021
fba9a87
properties first
awaelchli Jun 8, 2021
fd967af
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
79c73b9
define union
awaelchli Jun 8, 2021
37a0b9d
Update logger connector
carmocca Jun 8, 2021
aaea387
Update result
carmocca Jun 8, 2021
e2f69ce
Update imports
carmocca Jun 8, 2021
6037833
Update after rename
carmocca Jun 8, 2021
3804963
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 8, 2021
499da76
Refactor reduce_fx and op
carmocca Jun 8, 2021
6eb448a
Fix test after rename
carmocca Jun 8, 2021
f871cbd
mypy
carmocca Jun 8, 2021
5631b53
manual merge poc changes
awaelchli Jun 9, 2021
d10d5c7
integrate latest changes from logger connector poc
awaelchli Jun 9, 2021
7b6803a
Fix test
carmocca Jun 9, 2021
9bfedc9
Refactor test
carmocca Jun 9, 2021
c9c7829
Deprecate `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`
carmocca Jun 9, 2021
e3dde0b
Undo field
carmocca Jun 9, 2021
bae2139
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
2c167cc
rename
awaelchli Jun 9, 2021
99db497
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
832dfb9
rename
awaelchli Jun 9, 2021
f92e01d
imports
awaelchli Jun 9, 2021
b15fc34
loop hygiene
awaelchli Jun 9, 2021
7175a50
yapf on loops
awaelchli Jun 9, 2021
59d6227
protected new loop trigger
awaelchli Jun 9, 2021
e1d4fd2
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 9, 2021
a7c3555
Replace code
carmocca Jun 9, 2021
501224d
Fix names and imports
carmocca Jun 9, 2021
dee7e5f
Remove metric_attribute
carmocca Jun 9, 2021
4eb9757
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
d4bb357
integrate latest logger connector changes
awaelchli Jun 9, 2021
c9b4e9e
resolve todo dataloading reset
awaelchli Jun 10, 2021
a3ef0aa
re-add notebooks
awaelchli Jun 10, 2021
b071532
Merge branch 'master' into refactor/logger-connector-poc
awaelchli Jun 10, 2021
53deef8
add missing init
awaelchli Jun 10, 2021
93fd682
bad merge
awaelchli Jun 10, 2021
80c406e
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
a041b6f
remove iteration count method
awaelchli Jun 10, 2021
e080be8
todo for a fix in #5007
awaelchli Jun 10, 2021
4950821
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 10, 2021
c56adc1
remove NEW_LOOP guard
awaelchli Jun 10, 2021
5e72d1d
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
bace4a2
flake8
awaelchli Jun 10, 2021
71bfb6f
exclude coverage
awaelchli Jun 10, 2021
acc6d4f
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 10, 2021
41e0e64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
643bef0
flake8 vs yapf wars
awaelchli Jun 10, 2021
4b6bd18
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 10, 2021
536574a
integrate #7917, remove teardown from training loop
awaelchli Jun 10, 2021
b28fb09
update "accumulated_batches_reached" condition
awaelchli Jun 11, 2021
6f17688
remove public loop properties
awaelchli Jun 11, 2021
6dd4e1d
make skip backward protected again
awaelchli Jun 11, 2021
c394267
typing base loop
awaelchli Jun 11, 2021
4adae06
typing fit loop
awaelchli Jun 11, 2021
c49875d
typing training_batch_loop
awaelchli Jun 11, 2021
80edb75
typing training epoch loop
awaelchli Jun 11, 2021
8b54505
fix merge error
justusschock Jun 11, 2021
9fd8ed1
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 11, 2021
e4ffa6c
integrate train loop changes from master
awaelchli Jun 11, 2021
69ed0e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2021
eeebc9a
fix tpipes moving model to cpu and leaving it there.
awaelchli Jun 12, 2021
ce9dd2a
don't reset fit loop
awaelchli Jun 12, 2021
80e225a
fix test iteration count <-> batch_idx reset
awaelchli Jun 14, 2021
4880b26
replace torch.Tensor -> Tensor
awaelchli Jun 14, 2021
5461f73
fix attribute error to block_ddp_sync_behaviour
awaelchli Jun 14, 2021
a2d3f0d
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 14, 2021
0fe6d9f
ignore mypy errors
awaelchli Jun 14, 2021
5497fc0
fix flake8 and yapf conflict
awaelchli Jun 14, 2021
4c51c45
remove redundant override
awaelchli Jun 14, 2021
8f68b61
Apply suggestions from code review
awaelchli Jun 14, 2021
0150f6c
Apply suggestions from code review
awaelchli Jun 14, 2021
fd90c10
Apply suggestions from code review
awaelchli Jun 14, 2021
153d264
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
4eb0eb1
remove all empty space between atoms
awaelchli Jun 14, 2021
70cdb14
carlos
awaelchli Jun 14, 2021
bf26aa3
Apply suggestions from code review
justusschock Jun 14, 2021
ffc4f45
Apply suggestions from code review
justusschock Jun 14, 2021
79f8c18
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 14, 2021
3373cc8
resolve a todo integrating on_train_batch_end with on_advance_end
awaelchli Jun 14, 2021
e1a40c0
clarify what is todo and what is fixme
awaelchli Jun 14, 2021
b5bb08a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
5d98009
shorten a docstring
awaelchli Jun 14, 2021
03bce7a
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 14, 2021
42c9ad6
wip
tchaton Jun 14, 2021
f001f81
move on_epoch_start to on_run_start of training loop
awaelchli Jun 14, 2021
24fa859
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 14, 2021
12086c5
add tracking
tchaton Jun 14, 2021
d191fe1
Update pytorch_lightning/loops/base.py
awaelchli Jun 15, 2021
1d21065
update class names in changelog
awaelchli Jun 15, 2021
d8377d5
wip
tchaton Jun 15, 2021
f249351
update
tchaton Jun 15, 2021
7d5b3f3
add zero_grad
tchaton Jun 15, 2021
1b2b251
Merge branch 'refactor/loops/loops_everywhere_train' into progress_tr…
tchaton Jun 15, 2021
743d262
add decription
tchaton Jun 15, 2021
2d8c441
add empty teardown method
awaelchli Jun 15, 2021
1ae88a4
update on comments
tchaton Jun 15, 2021
7763afd
update on comments
tchaton Jun 15, 2021
f874182
added skip property
awaelchli Jun 15, 2021
2ef0fe0
Merge branch 'refactor/loops/loops_everywhere_train' into progress_tr…
tchaton Jun 15, 2021
e2bb1d2
update on comments
tchaton Jun 15, 2021
ec25ab6
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 15, 2021
27927c4
Merge branch 'master' into progress_tracking
tchaton Jun 15, 2021
8806b00
update
tchaton Jun 15, 2021
4011d85
update changelog
tchaton Jun 15, 2021
d08cb38
update
tchaton Jun 15, 2021
fcdfd39
resolve failing tests
tchaton Jun 15, 2021
5557beb
remove typing
tchaton Jun 15, 2021
f35c7a1
Merge branch 'master' into progress_tracking
kaushikb11 Jun 15, 2021
8c74a3b
update on comments
tchaton Jun 16, 2021
e80230c
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 16, 2021
f036f3b
update on comments
tchaton Jun 16, 2021
48d720a
move optimizer_idx to batchProgress
tchaton Jun 16, 2021
6d653c2
update
tchaton Jun 16, 2021
741ea4d
Merge branch 'master' into progress_tracking
tchaton Jun 16, 2021
f4b91af
remove useless code
tchaton Jun 16, 2021
6f9a3b8
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 16, 2021
769364b
add a space on docstring
tchaton Jun 17, 2021
39de7e9
Merge branch 'master' into progress_tracking
carmocca Jun 18, 2021
fda5688
Merge branch 'master' into progress_tracking
carmocca Jun 19, 2021
e2db5fe
Minor changes
carmocca Jun 19, 2021
b947a33
Unused code
carmocca Jun 19, 2021
281847a
Update CHANGELOG
carmocca Jun 19, 2021
0038219
Merge branch 'master' into progress_tracking
tchaton Jun 25, 2021
d39d524
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 25, 2021
2a45166
update
tchaton Jun 25, 2021
bfd8ac2
update
tchaton Jun 25, 2021
66eb9e0
wip
tchaton Jun 25, 2021
f2d2f32
add FastForwardSampler
tchaton Jun 25, 2021
9d8ffdc
update
tchaton Jun 25, 2021
4d6feca
resolve a bug
tchaton Jun 25, 2021
3e0af69
resolve bug
tchaton Jun 25, 2021
4de9581
Merge branch 'master' into training_restart
tchaton Jun 27, 2021
3d183bc
add support for validation
tchaton Jun 27, 2021
94f052e
update
tchaton Jun 27, 2021
a446639
update
tchaton Jun 27, 2021
be49acc
kill processes
tchaton Jun 27, 2021
6690dcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2021
b05617a
add mechanism to kill on deadlock detection
tchaton Jun 27, 2021
56e2763
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jun 27, 2021
3da64e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2021
5c1a639
wip
tchaton Jun 28, 2021
1df225d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2021
6df9968
update
tchaton Jun 28, 2021
8ecd834
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2021
e62798d
add support for accumulate_grad_batches
tchaton Jun 28, 2021
e3b80ca
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jun 28, 2021
6c38083
Merge branch 'master' into training_restart
tchaton Jun 29, 2021
11fa777
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jun 29, 2021
b9be984
resolve bugs
tchaton Jun 30, 2021
a630b2e
resolve tracking
tchaton Jun 30, 2021
51671d0
Rename
carmocca Jul 1, 2021
cd28db2
Merge branch 'master' into training_restart
tchaton Jul 1, 2021
25da045
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
carmocca Jul 1, 2021
97cbd9a
wip
tchaton Jul 1, 2021
ba276b9
Comments after call
carmocca Jul 1, 2021
c08fe50
Merge branch 'master' into training_restart
carmocca Jul 1, 2021
ac7cff7
update
tchaton Jul 1, 2021
d4917e5
update
tchaton Jul 1, 2021
2ecab37
add partial support for iterative dataset
tchaton Jul 2, 2021
25b1ac2
update
tchaton Jul 2, 2021
7c9058f
Merge branch 'fast_forward_samplers' into training_restart
tchaton Jul 2, 2021
588d0a9
added some logic for samplers restart
tchaton Jul 4, 2021
7dd42c0
update
tchaton Jul 5, 2021
16a58c6
resolve bug
tchaton Jul 5, 2021
756baca
fix attribute error
awaelchli Jul 6, 2021
2f54117
add simple test for random dataset (wip)
awaelchli Jul 6, 2021
c6a774c
wip
tchaton Jul 6, 2021
a946159
update
tchaton Jul 6, 2021
a2b74f0
resolve bug
tchaton Jul 6, 2021
668b02e
wip
tchaton Jul 6, 2021
3827208
wip
tchaton Jul 6, 2021
1f4ef8c
wip
tchaton Jul 6, 2021
770a78b
resolved tests
tchaton Jul 6, 2021
76f1f53
update on comments
tchaton Jul 7, 2021
3aaf0ea
update
tchaton Jul 7, 2021
ed056aa
update
tchaton Jul 7, 2021
f1cdcdc
Merge branch 'master' into add_fast_forward_sampler
tchaton Jul 7, 2021
81bf954
Update pytorch_lightning/utilities/auto_restart.py
tchaton Jul 7, 2021
3cd6859
Merge branch 'master' into training_restart
tchaton Jul 7, 2021
54e0a24
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jul 7, 2021
76f0503
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
7a05094
update on comments
tchaton Jul 7, 2021
bff288c
Merge branch 'add_fast_forward_sampler' of https://github.com/PyTorch…
tchaton Jul 7, 2021
98ec265
Update pytorch_lightning/utilities/auto_restart.py
tchaton Jul 7, 2021
82b1cf1
resolve bug
tchaton Jul 7, 2021
8972d82
update
tchaton Jul 7, 2021
1fb8c02
move properties to top
awaelchli Jul 7, 2021
f086edb
update docs for fast forward sampler
awaelchli Jul 7, 2021
7450388
move public attribute to top
awaelchli Jul 7, 2021
5e43757
add missing super call
awaelchli Jul 7, 2021
eae11c3
update docs for state_dict
awaelchli Jul 7, 2021
efcb882
fix merge conflict
awaelchli Jul 7, 2021
c068704
add missing super() call
awaelchli Jul 7, 2021
79ff550
move property to top
awaelchli Jul 7, 2021
d433bb4
update
tchaton Jul 7, 2021
dfbb8eb
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jul 7, 2021
50ac617
update on comments
tchaton Jul 7, 2021
733e329
Merge branch 'add_fast_forward_sampler' of https://github.com/PyTorch…
tchaton Jul 7, 2021
f111826
resolve bug
tchaton Jul 7, 2021
67a3691
update
tchaton Jul 7, 2021
14bea6b
wip
tchaton Jul 7, 2021
4eee70a
resolve bug
tchaton Jul 7, 2021
8b93505
update
tchaton Jul 7, 2021
322600c
wip
tchaton Jul 7, 2021
028d773
update on comments
tchaton Jul 7, 2021
0be0b5d
some refactor
tchaton Jul 7, 2021
5c3e328
activate coverage for CaptureIterableDataset
tchaton Jul 7, 2021
461bee9
update on comments
tchaton Jul 7, 2021
2de5290
update
tchaton Jul 7, 2021
bee5536
Merge branch 'master' into training_restart
tchaton Jul 7, 2021
d98e0fd
Merge branch 'add_fast_forward_sampler' into training_restart
tchaton Jul 7, 2021
cb3c2f9
wip
tchaton Jul 7, 2021
07331ab
resolve training loop
tchaton Jul 7, 2021
7e2bf72
wip
tchaton Jul 8, 2021
CHANGELOG.md: 6 changes (5 additions, 1 deletion)
@@ -31,7 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Progress tracking
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Added dataclasses for progress tracking([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))
* Integrate progress tracking with the training loops ([#7976](https://github.com/PyTorchLightning/pytorch-lightning/pull/7976))
* Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244))

@@ -137,6 +138,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `restore` function and `restarting` attribute to base `Loop` ([#8247](https://github.com/PyTorchLightning/pytorch-lightning/pull/8247))


- Added `FastForwardSampler` and `CaptureIterableDataset` ([#8307](https://github.com/PyTorchLightning/pytorch-lightning/pull/8307))


### Changed


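For context on the `FastForwardSampler` entry above: it is the utility (living in `pytorch_lightning/utilities/auto_restart.py`, touched by several commits in this PR) that lets a dataloader resume mid-epoch by skipping the samples already consumed before the failure. The snippet below is only a minimal, self-contained sketch of that idea; the class name, fields, and methods are assumptions for illustration, not the actual Lightning API.

```python
from typing import Dict, Iterator

from torch.utils.data import SequentialSampler


class FastForwardSamplerSketch:
    """Hypothetical stand-in: wraps a sampler and skips indices consumed before a restart."""

    def __init__(self, sampler) -> None:
        self._sampler = sampler
        self._current_iteration = 0  # number of indices already handed out

    def state_dict(self) -> Dict:
        return {"current_iteration": self._current_iteration}

    def load_state_dict(self, state_dict: Dict) -> None:
        self._current_iteration = state_dict["current_iteration"]

    def __iter__(self) -> Iterator[int]:
        for i, index in enumerate(self._sampler):
            if i < self._current_iteration:
                continue  # fast-forward past what was already seen
            self._current_iteration += 1
            yield index


# Simulate a crash after 3 samples and a restart from the saved state.
sampler = FastForwardSamplerSketch(SequentialSampler(range(8)))
it = iter(sampler)
consumed = [next(it) for _ in range(3)]    # [0, 1, 2]
checkpoint = sampler.state_dict()          # {"current_iteration": 3}

restored = FastForwardSamplerSketch(SequentialSampler(range(8)))
restored.load_state_dict(checkpoint)
assert list(restored) == [3, 4, 5, 6, 7]
```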
pytorch_lightning/core/optimizer.py: 1 change (1 addition, 0 deletions)
@@ -211,6 +211,7 @@ def closure_dis():
profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"

self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
self._trainer.fit_loop.epoch_loop.batch_loop.optim_progress.optimizer.step.increment_processed()
self._total_optimizer_step_calls += 1

def __repr__(self):
pytorch_lightning/loops/base.py: 90 changes (82 additions, 8 deletions)
@@ -13,11 +13,12 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, OrderedDict

from deprecate import void

import pytorch_lightning as pl
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -46,7 +47,44 @@ class Loop(ABC):
def __init__(self) -> None:
self.iteration_count: int = 0
self.trainer: Optional['pl.Trainer'] = None
self._cached_state: Optional[Dict] = None
self.restarting = False
self._loops = OrderedDict()
self._progress = OrderedDict()

def __setattr__(self, name: str, value: Any) -> None:
if isinstance(value, Loop):
self._loops[name] = value
elif isinstance(value, BaseProgress):
self._progress[name] = value
else:
object.__setattr__(self, name, value)

def __getattr__(self, name) -> Any:
loops = self.__dict__.get('_loops')
if loops is None:
raise MisconfigurationException("The parent Loop's `__init__` method was not called.")

if name in loops:
return loops[name]

progress = self.__dict__.get('_progress')

if name in progress:
return progress[name]

if name not in self.__dict__:
raise AttributeError(f"{self.__class__.__name__} Loop doesn't have attribute {name}.")

return self.__dict__[name]

def __delattr__(self, name) -> None:
if name in self._loops:
del self._loops[name]
elif name in self._progress:
del self._progress[name]
else:
object.__delattr__(self, name)

@property
@abstractmethod
@@ -89,7 +127,8 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
return self.on_skip()

if self.restarting:
self.restore()
self.restore(self._cached_state)
self._cached_state = None
self.restarting = False
else:
self.reset()
@@ -108,7 +147,8 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
output = self.on_run_end()
return output

def restore(self) -> None:
@abstractmethod
def restore(self, state: Optional[Dict] = None) -> None:
"""Restore the internal state of the loop the beginning of run if restarting is ``True``."""

@abstractmethod
@@ -142,9 +182,43 @@ def on_run_end(self) -> Any:
def teardown(self) -> None:
"""Use to release memory etc."""

def load_state_dict(self, state_dict: Dict) -> None:
"""Restore the loop state from the provided state_dict."""

@abstractmethod
def state_dict(self) -> Dict:
"""Return the loop current states."""
return {}
"""Current Loop state"""

def get_state_dict(self, destination: Optional[OrderedDict] = None, prefix: Optional[str] = '') -> OrderedDict:
if destination is None:
destination = OrderedDict()

destination[prefix + "state_dict"] = self.state_dict()

for name, progress in self._progress.items():
destination[prefix + name] = progress.state_dict()

for name, loop in self._loops.items():
loop.get_state_dict(destination, prefix + name + '.')
return destination

def _load_from_state_dict(self, state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs):
self._cached_state = state_dict[prefix + "state_dict"]

for name, progress in self._progress.items():
progress.load_state_dict(state_dict[prefix + name])

def load_state_dict(self, state_dict: Dict, strict: bool = True):

missing_keys = []
unexpected_keys = []
error_msgs = []

state_dict = state_dict.copy()

def load(loop, prefix=''):
loop._load_from_state_dict(state_dict, prefix, True, missing_keys, unexpected_keys, error_msgs)
loop.restarting = True
for name, loop_children in loop._loops.items():
if loop_children is not None:
load(loop_children, prefix + name + '.')

load(self)
load = None # break load->load reference cycle
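To make the new `__setattr__`/`__getattr__` registration and the prefixed `get_state_dict`/`load_state_dict` walk above easier to follow, here is a minimal, self-contained sketch of the same mechanism. It deliberately does not import `pytorch_lightning`; the `LoopSketch`/`ProgressSketch` names and the trimmed-down bodies are assumptions for illustration only.

```python
from collections import OrderedDict
from typing import Any, Dict, Optional


class ProgressSketch:
    """Hypothetical stand-in for a progress-tracking dataclass."""

    def __init__(self) -> None:
        self.ready = 0
        self.completed = 0

    def state_dict(self) -> Dict:
        return {"ready": self.ready, "completed": self.completed}


class LoopSketch:
    """Hypothetical stand-in for `Loop`: child loops and progress objects are
    intercepted on assignment and registered, so the state dict can be built
    recursively with dotted prefixes."""

    def __init__(self) -> None:
        self.restarting = False
        self._loops: "OrderedDict[str, LoopSketch]" = OrderedDict()
        self._progress: "OrderedDict[str, ProgressSketch]" = OrderedDict()

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, LoopSketch):
            self._loops[name] = value
        elif isinstance(value, ProgressSketch):
            self._progress[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name: str) -> Any:
        # only called when normal lookup fails, i.e. for registered children
        for registry in ("_loops", "_progress"):
            container = self.__dict__.get(registry, {})
            if name in container:
                return container[name]
        raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")

    def state_dict(self) -> Dict:
        return {}  # loop-specific payload would go here

    def get_state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = OrderedDict()
        destination[prefix + "state_dict"] = self.state_dict()
        for name, progress in self._progress.items():
            destination[prefix + name] = progress.state_dict()
        for name, loop in self._loops.items():
            loop.get_state_dict(destination, prefix + name + ".")
        return destination


# A tiny fit_loop -> epoch_loop nesting produces a flat dict with dotted keys:
fit_loop = LoopSketch()
fit_loop.epoch_progress = ProgressSketch()
fit_loop.epoch_loop = LoopSketch()
fit_loop.epoch_loop.batch_progress = ProgressSketch()
assert list(fit_loop.get_state_dict()) == [
    "state_dict", "epoch_progress", "epoch_loop.state_dict", "epoch_loop.batch_progress"
]
```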
pytorch_lightning/loops/batch/training_batch_loop.py: 67 changes (61 additions, 6 deletions)
@@ -69,10 +69,8 @@ def connect(
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
if optim_progress is not None:
self.optim_progress = optim_progress
self.progress = progress or self.progress
self.optim_progress = optim_progress or self.optim_progress

@property
def done(self) -> bool:
@@ -98,6 +96,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
return AttributeDict(signal=0, training_step_output=[[]])

self.progress.increment_ready()

# hook
self.trainer.logger_connector.on_batch_start()
response = self.trainer.call_hook("on_batch_start")
@@ -114,12 +114,23 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
self.batch_outputs = None # free memory
return output

def reset(self) -> None:
def _initialize(self):
"""Resets the loop state"""
self._hiddens = None
self.batch_idx = 0
self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]

def restore(self) -> None:
"""Restore the loop state"""
self._initialize()

def reset(self) -> None:
"""Resets the loop state"""
self._initialize()

# reset tracking
self.optim_progress.optimizer.reset_on_epoch()

def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
"""Splits the data into tbptt splits

@@ -131,6 +142,14 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
void(batch_idx, dataloader_idx)
self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.progress.increment_started()
return super().on_advance_start(*args, **kwargs)

def on_advance_end(self) -> None:
self.progress.increment_completed()
return super().on_advance_end()

def advance(self, batch, batch_idx, dataloader_idx):
"""Runs the train step together with optimization (if necessary) on the current batch split

@@ -148,7 +167,17 @@ def advance(self, batch, batch_idx, dataloader_idx):
self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)

if self.trainer.lightning_module.automatic_optimization:
for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
active_optimizers = self.get_active_optimizers(batch_idx)
for opt_idx, optimizer in active_optimizers:

# handle optimization restart
if self.trainer.is_restarting:
if len(active_optimizers) > 1 and opt_idx < self.progress.current.completed:
continue

# track optimizer_idx
self.optim_progress.optimizer_idx = opt_idx

result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(result.training_step_output)
Expand All @@ -158,6 +187,8 @@ def advance(self, batch, batch_idx, dataloader_idx):
if result:
self.batch_outputs[0].append(result.training_step_output)

self.progress.increment_processed()

def teardown(self) -> None:
# release memory
self._remaining_splits = None
@@ -238,8 +269,14 @@ def _training_step_and_backward_closure(
"""

result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)

if result is not None:
return_result.update(result)

# this should be done only if result.loss exists
if not self.should_accumulate():
self.optim_progress.optimizer.step.increment_started()

return return_result.loss

def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
@@ -409,6 +446,8 @@ def _optimizer_step(
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)

self.optim_progress.optimizer.step.increment_ready()

# model hook
model_ref.optimizer_step(
self.trainer.current_epoch,
@@ -421,13 +460,17 @@ def _optimizer_step(
using_lbfgs=is_lbfgs,
)

self.optim_progress.optimizer.step.increment_completed()

def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.

Args:
optimizer: the current optimizer
"""
self.optim_progress.optimizer.zero_grad.increment_started()
self.trainer.call_hook('on_before_zero_grad', optimizer)
self.optim_progress.optimizer.zero_grad.increment_ready()

def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.
@@ -437,8 +480,11 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
optimizer: the current optimizer
opt_idx: the index of the current optimizer
"""

self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

self.optim_progress.optimizer.zero_grad.increment_completed()

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

@@ -700,3 +746,12 @@ def _truncated_bptt_steps(self) -> int:
if lightning_module.truncated_bptt_steps > 0:
return lightning_module.truncated_bptt_steps
return self.trainer.truncated_bptt_steps or 0

def state_dict(self) -> Dict:
return {"progress": self.progress.state_dict(), "optim_progress": self.optim_progress.state_dict()}

def load_state_dict(self, state_dict: Dict) -> None:
if "progress" in state_dict:
self.progress.load_state_dict(state_dict['progress'])
if "optim_progress" in state_dict:
self.optim_progress.load_state_dict(state_dict['optim_progress'])
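
The `increment_ready/started/processed/completed` calls sprinkled through the diff above act on the progress-tracking dataclasses introduced in #6603/#7574/#8140, which are not part of this diff. Below is a hypothetical stand-in, just to show the lifecycle the batch loop drives for a single optimizer step; the field names and the `state_dict` round-trip are assumptions, not the actual `pytorch_lightning/trainer/progress.py` API.

```python
from dataclasses import asdict, dataclass
from typing import Dict


@dataclass
class TrackerSketch:
    """Hypothetical stand-in for a progress tracker with four counters."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def increment_ready(self) -> None:
        self.ready += 1

    def increment_started(self) -> None:
        self.started += 1

    def increment_processed(self) -> None:
        self.processed += 1

    def increment_completed(self) -> None:
        self.completed += 1

    def state_dict(self) -> Dict:
        return asdict(self)

    def load_state_dict(self, state_dict: Dict) -> None:
        for key, value in state_dict.items():
            setattr(self, key, value)


# Rough order of the optimizer-step increments in the hunks above:
step = TrackerSketch()
step.increment_ready()      # `_optimizer_step`, before the `optimizer_step` model hook
step.increment_started()    # `_training_step_and_backward_closure`, when not accumulating
step.increment_processed()  # `LightningOptimizer`, after `__optimizer_step` returns
step.increment_completed()  # `_optimizer_step`, after the model hook returns
assert step.state_dict() == {"ready": 1, "started": 1, "processed": 1, "completed": 1}

# On restart, the counters come back via the loop's `load_state_dict` path shown above.
restored = TrackerSketch()
restored.load_state_dict(step.state_dict())
assert restored.completed == 1
```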