[Feat] Add Loops Restart #8131

Closed · wants to merge 668 commits
Changes from 250 commits
Commits (668 total)
1763d8f
test
awaelchli Jun 7, 2021
d718498
update trainer
awaelchli Jun 7, 2021
6d98a07
integrate latest changes from logger connector refactor poc
awaelchli Jun 7, 2021
7ca1049
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2021
515ad9f
Minor changes
carmocca Jun 7, 2021
b03591c
update changelog
awaelchli Jun 7, 2021
0aa8428
Remove unused argument
carmocca Jun 7, 2021
24b41e3
Update CHANGELOG
carmocca Jun 7, 2021
6d71e6a
Copy call_hook changes
carmocca Jun 7, 2021
44ad4ac
Docs
carmocca Jun 7, 2021
2c74018
Fix ref
carmocca Jun 7, 2021
b15984b
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 7, 2021
e8021bb
merge
tchaton Jun 8, 2021
9747023
move to cpu
tchaton Jun 8, 2021
d9ae37a
Bad merge
carmocca Jun 8, 2021
bad51c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
273bc92
remove pdb
tchaton Jun 8, 2021
f214632
remove pdb
tchaton Jun 8, 2021
5fdf3c5
merge
tchaton Jun 8, 2021
99543a7
Refactor to
carmocca Jun 8, 2021
738c810
Avoid partial
carmocca Jun 8, 2021
6a7637d
trigger ci
carmocca Jun 8, 2021
8077cf9
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
aff9e3d
Bad merge
carmocca Jun 8, 2021
464f581
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 8, 2021
461332b
integrate latest logger connector changes
awaelchli Jun 8, 2021
417ad31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
9321b11
remove grad norm dicts list
awaelchli Jun 8, 2021
e75a958
Diff
carmocca Jun 8, 2021
007dcac
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
2e4bb24
Bad merge
carmocca Jun 8, 2021
f5154ae
Reuse metrics_to_scalars
carmocca Jun 8, 2021
558cdf4
Use active loop
carmocca Jun 8, 2021
90d71bf
Move to device
carmocca Jun 8, 2021
d7f1761
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
6ce6762
resolve test
tchaton Jun 8, 2021
fba9a87
properties first
awaelchli Jun 8, 2021
fd967af
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 8, 2021
79c73b9
define union
awaelchli Jun 8, 2021
37a0b9d
Update logger connector
carmocca Jun 8, 2021
aaea387
Update result
carmocca Jun 8, 2021
e2f69ce
Update imports
carmocca Jun 8, 2021
6037833
Update after rename
carmocca Jun 8, 2021
3804963
Merge branch 'refactor/logger-connector-poc' of https://github.com/Py…
carmocca Jun 8, 2021
499da76
Refactor reduce_fx and op
carmocca Jun 8, 2021
6eb448a
Fix test after rename
carmocca Jun 8, 2021
f871cbd
mypy
carmocca Jun 8, 2021
5631b53
manual merge poc changes
awaelchli Jun 9, 2021
d10d5c7
integrate latest changes from logger connector poc
awaelchli Jun 9, 2021
7b6803a
Fix test
carmocca Jun 9, 2021
9bfedc9
Refactor test
carmocca Jun 9, 2021
c9c7829
Deprecate `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`
carmocca Jun 9, 2021
e3dde0b
Undo field
carmocca Jun 9, 2021
bae2139
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
2c167cc
rename
awaelchli Jun 9, 2021
99db497
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
832dfb9
rename
awaelchli Jun 9, 2021
f92e01d
imports
awaelchli Jun 9, 2021
b15fc34
loop hygiene
awaelchli Jun 9, 2021
7175a50
yapf on loops
awaelchli Jun 9, 2021
59d6227
protected new loop trigger
awaelchli Jun 9, 2021
e1d4fd2
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 9, 2021
a7c3555
Replace code
carmocca Jun 9, 2021
501224d
Fix names and imports
carmocca Jun 9, 2021
dee7e5f
Remove metric_attribute
carmocca Jun 9, 2021
4eb9757
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 9, 2021
d4bb357
integrate latest logger connector changes
awaelchli Jun 9, 2021
c9b4e9e
resolve todo dataloading reset
awaelchli Jun 10, 2021
a3ef0aa
re-add notebooks
awaelchli Jun 10, 2021
b071532
Merge branch 'master' into refactor/logger-connector-poc
awaelchli Jun 10, 2021
53deef8
add missing init
awaelchli Jun 10, 2021
93fd682
bad merge
awaelchli Jun 10, 2021
80c406e
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
a041b6f
remove iteration count method
awaelchli Jun 10, 2021
e080be8
todo for a fix in #5007
awaelchli Jun 10, 2021
4950821
Merge branch 'master' into refactor/logger-connector-poc
carmocca Jun 10, 2021
c56adc1
remove NEW_LOOP guard
awaelchli Jun 10, 2021
5e72d1d
Merge branch 'refactor/logger-connector-poc' into refactor/loops/loop…
awaelchli Jun 10, 2021
bace4a2
flake8
awaelchli Jun 10, 2021
71bfb6f
exclude coverage
awaelchli Jun 10, 2021
acc6d4f
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 10, 2021
41e0e64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 10, 2021
643bef0
flake8 vs yapf wars
awaelchli Jun 10, 2021
4b6bd18
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 10, 2021
536574a
integrate #7917, remove teardown from training loop
awaelchli Jun 10, 2021
b28fb09
update "accumulated_batches_reached" condition
awaelchli Jun 11, 2021
6f17688
remove public loop properties
awaelchli Jun 11, 2021
6dd4e1d
make skip backward protected again
awaelchli Jun 11, 2021
c394267
typing base loop
awaelchli Jun 11, 2021
4adae06
typing fit loop
awaelchli Jun 11, 2021
c49875d
typing training_batch_loop
awaelchli Jun 11, 2021
80edb75
typing training epoch loop
awaelchli Jun 11, 2021
8b54505
fix merge error
justusschock Jun 11, 2021
9fd8ed1
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 11, 2021
e4ffa6c
integrate train loop changes from master
awaelchli Jun 11, 2021
69ed0e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2021
eeebc9a
fix tpipes moving model to cpu and leaving it there.
awaelchli Jun 12, 2021
ce9dd2a
don't reset fit loop
awaelchli Jun 12, 2021
80e225a
fix test iteration count <-> batch_idx reset
awaelchli Jun 14, 2021
4880b26
replace torch.Tensor -> Tensor
awaelchli Jun 14, 2021
5461f73
fix attribute error to block_ddp_sync_behaviour
awaelchli Jun 14, 2021
a2d3f0d
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 14, 2021
0fe6d9f
ignore mypy errors
awaelchli Jun 14, 2021
5497fc0
fix flake8 and yapf conflict
awaelchli Jun 14, 2021
4c51c45
remove redundant override
awaelchli Jun 14, 2021
8f68b61
Apply suggestions from code review
awaelchli Jun 14, 2021
0150f6c
Apply suggestions from code review
awaelchli Jun 14, 2021
fd90c10
Apply suggestions from code review
awaelchli Jun 14, 2021
153d264
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
4eb0eb1
remove all empty space between atoms
awaelchli Jun 14, 2021
70cdb14
carlos
awaelchli Jun 14, 2021
bf26aa3
Apply suggestions from code review
justusschock Jun 14, 2021
ffc4f45
Apply suggestions from code review
justusschock Jun 14, 2021
79f8c18
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 14, 2021
3373cc8
resolve a todo integrating on_train_batch_end with on_advance_end
awaelchli Jun 14, 2021
e1a40c0
clarify what is todo and what is fixme
awaelchli Jun 14, 2021
b5bb08a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2021
5d98009
shorten a docstring
awaelchli Jun 14, 2021
03bce7a
Merge remote-tracking branch 'origin/refactor/loops/loops_everywhere_…
awaelchli Jun 14, 2021
42c9ad6
wip
tchaton Jun 14, 2021
f001f81
move on_epoch_start to on_run_start of training loop
awaelchli Jun 14, 2021
24fa859
Merge branch 'master' into refactor/loops/loops_everywhere_train
awaelchli Jun 14, 2021
12086c5
add tracking
tchaton Jun 14, 2021
d191fe1
Update pytorch_lightning/loops/base.py
awaelchli Jun 15, 2021
1d21065
update class names in changelog
awaelchli Jun 15, 2021
d8377d5
wip
tchaton Jun 15, 2021
f249351
update
tchaton Jun 15, 2021
7d5b3f3
add zero_grad
tchaton Jun 15, 2021
1b2b251
Merge branch 'refactor/loops/loops_everywhere_train' into progress_tr…
tchaton Jun 15, 2021
743d262
add decription
tchaton Jun 15, 2021
2d8c441
add empty teardown method
awaelchli Jun 15, 2021
1ae88a4
update on comments
tchaton Jun 15, 2021
7763afd
update on comments
tchaton Jun 15, 2021
f874182
added skip property
awaelchli Jun 15, 2021
2ef0fe0
Merge branch 'refactor/loops/loops_everywhere_train' into progress_tr…
tchaton Jun 15, 2021
e2bb1d2
update on comments
tchaton Jun 15, 2021
ec25ab6
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 15, 2021
27927c4
Merge branch 'master' into progress_tracking
tchaton Jun 15, 2021
8806b00
update
tchaton Jun 15, 2021
4011d85
update changelog
tchaton Jun 15, 2021
d08cb38
update
tchaton Jun 15, 2021
fcdfd39
resolve failing tests
tchaton Jun 15, 2021
5557beb
remove typing
tchaton Jun 15, 2021
f35c7a1
Merge branch 'master' into progress_tracking
kaushikb11 Jun 15, 2021
8c74a3b
update on comments
tchaton Jun 16, 2021
e80230c
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 16, 2021
f036f3b
update on comments
tchaton Jun 16, 2021
48d720a
move optimizer_idx to batchProgress
tchaton Jun 16, 2021
6d653c2
update
tchaton Jun 16, 2021
741ea4d
Merge branch 'master' into progress_tracking
tchaton Jun 16, 2021
f4b91af
remove useless code
tchaton Jun 16, 2021
6f9a3b8
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 16, 2021
769364b
add a space on docstring
tchaton Jun 17, 2021
39de7e9
Merge branch 'master' into progress_tracking
carmocca Jun 18, 2021
fda5688
Merge branch 'master' into progress_tracking
carmocca Jun 19, 2021
e2db5fe
Minor changes
carmocca Jun 19, 2021
b947a33
Unused code
carmocca Jun 19, 2021
281847a
Update CHANGELOG
carmocca Jun 19, 2021
0038219
Merge branch 'master' into progress_tracking
tchaton Jun 25, 2021
d39d524
Merge branch 'progress_tracking' of https://github.com/PyTorchLightni…
tchaton Jun 25, 2021
2a45166
update
tchaton Jun 25, 2021
bfd8ac2
update
tchaton Jun 25, 2021
66eb9e0
wip
tchaton Jun 25, 2021
f2d2f32
add FastForwardSampler
tchaton Jun 25, 2021
9d8ffdc
update
tchaton Jun 25, 2021
4d6feca
resolve a bug
tchaton Jun 25, 2021
3e0af69
resolve bug
tchaton Jun 25, 2021
4de9581
Merge branch 'master' into training_restart
tchaton Jun 27, 2021
3d183bc
add support for validation
tchaton Jun 27, 2021
94f052e
update
tchaton Jun 27, 2021
a446639
update
tchaton Jun 27, 2021
be49acc
kill processes
tchaton Jun 27, 2021
6690dcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2021
b05617a
add mechanism to kill on deadlock detection
tchaton Jun 27, 2021
56e2763
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jun 27, 2021
3da64e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2021
5c1a639
wip
tchaton Jun 28, 2021
1df225d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2021
6df9968
update
tchaton Jun 28, 2021
8ecd834
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2021
e62798d
add support for accumulate_grad_batches
tchaton Jun 28, 2021
e3b80ca
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jun 28, 2021
6c38083
Merge branch 'master' into training_restart
tchaton Jun 29, 2021
11fa777
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jun 29, 2021
b9be984
resolve bugs
tchaton Jun 30, 2021
a630b2e
resolve tracking
tchaton Jun 30, 2021
51671d0
Rename
carmocca Jul 1, 2021
cd28db2
Merge branch 'master' into training_restart
tchaton Jul 1, 2021
25da045
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
carmocca Jul 1, 2021
97cbd9a
wip
tchaton Jul 1, 2021
ba276b9
Comments after call
carmocca Jul 1, 2021
c08fe50
Merge branch 'master' into training_restart
carmocca Jul 1, 2021
ac7cff7
update
tchaton Jul 1, 2021
d4917e5
update
tchaton Jul 1, 2021
2ecab37
add partial support for iterative dataset
tchaton Jul 2, 2021
25b1ac2
update
tchaton Jul 2, 2021
7c9058f
Merge branch 'fast_forward_samplers' into training_restart
tchaton Jul 2, 2021
588d0a9
added some logic for samplers restart
tchaton Jul 4, 2021
7dd42c0
update
tchaton Jul 5, 2021
16a58c6
resolve bug
tchaton Jul 5, 2021
756baca
fix attribute error
awaelchli Jul 6, 2021
2f54117
add simple test for random dataset (wip)
awaelchli Jul 6, 2021
c6a774c
wip
tchaton Jul 6, 2021
a946159
update
tchaton Jul 6, 2021
a2b74f0
resolve bug
tchaton Jul 6, 2021
668b02e
wip
tchaton Jul 6, 2021
3827208
wip
tchaton Jul 6, 2021
1f4ef8c
wip
tchaton Jul 6, 2021
770a78b
resolved tests
tchaton Jul 6, 2021
76f1f53
update on comments
tchaton Jul 7, 2021
3aaf0ea
update
tchaton Jul 7, 2021
ed056aa
update
tchaton Jul 7, 2021
f1cdcdc
Merge branch 'master' into add_fast_forward_sampler
tchaton Jul 7, 2021
81bf954
Update pytorch_lightning/utilities/auto_restart.py
tchaton Jul 7, 2021
3cd6859
Merge branch 'master' into training_restart
tchaton Jul 7, 2021
54e0a24
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jul 7, 2021
76f0503
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2021
7a05094
update on comments
tchaton Jul 7, 2021
bff288c
Merge branch 'add_fast_forward_sampler' of https://github.com/PyTorch…
tchaton Jul 7, 2021
98ec265
Update pytorch_lightning/utilities/auto_restart.py
tchaton Jul 7, 2021
82b1cf1
resolve bug
tchaton Jul 7, 2021
8972d82
update
tchaton Jul 7, 2021
1fb8c02
move properties to top
awaelchli Jul 7, 2021
f086edb
update docs for fast forward sampler
awaelchli Jul 7, 2021
7450388
move public attribute to top
awaelchli Jul 7, 2021
5e43757
add missing super call
awaelchli Jul 7, 2021
eae11c3
update docs for state_dict
awaelchli Jul 7, 2021
efcb882
fix merge conflict
awaelchli Jul 7, 2021
c068704
add missing super() call
awaelchli Jul 7, 2021
79ff550
move property to top
awaelchli Jul 7, 2021
d433bb4
update
tchaton Jul 7, 2021
dfbb8eb
Merge branch 'training_restart' of https://github.com/PyTorchLightnin…
tchaton Jul 7, 2021
50ac617
update on comments
tchaton Jul 7, 2021
733e329
Merge branch 'add_fast_forward_sampler' of https://github.com/PyTorch…
tchaton Jul 7, 2021
f111826
resolve bug
tchaton Jul 7, 2021
67a3691
update
tchaton Jul 7, 2021
14bea6b
wip
tchaton Jul 7, 2021
4eee70a
resolve bug
tchaton Jul 7, 2021
8b93505
update
tchaton Jul 7, 2021
322600c
wip
tchaton Jul 7, 2021
028d773
update on comments
tchaton Jul 7, 2021
0be0b5d
some refactor
tchaton Jul 7, 2021
5c3e328
activate coverage for CaptureIterableDataset
tchaton Jul 7, 2021
461bee9
update on comments
tchaton Jul 7, 2021
2de5290
update
tchaton Jul 7, 2021
bee5536
Merge branch 'master' into training_restart
tchaton Jul 7, 2021
d98e0fd
Merge branch 'add_fast_forward_sampler' into training_restart
tchaton Jul 7, 2021
cb3c2f9
wip
tchaton Jul 7, 2021
07331ab
resolve training loop
tchaton Jul 7, 2021
7e2bf72
wip
tchaton Jul 8, 2021
6 changes: 3 additions & 3 deletions CHANGELOG.md
@@ -30,9 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))


- Added dataclasses for progress tracking (
[#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603),
[#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))
- Progress tracking
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))
* Integrate progress tracking with the training loops ([#7976](https://github.com/PyTorchLightning/pytorch-lightning/pull/7976))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
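
For orientation before the file diffs: a minimal sketch of what progress-tracking dataclasses of this kind typically look like — monotonic `ready`/`started`/`processed`/`completed` counters with `increment_*` helpers and a per-epoch reset. The class and field names below are illustrative assumptions, not necessarily the exact API added by this PR.

```python
from dataclasses import dataclass, field


@dataclass
class ProgressState:
    """Counters for the four stages an event can pass through."""
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = self.started = self.processed = self.completed = 0


@dataclass
class Progress:
    """Totals across the whole run plus counters for the current epoch."""
    total: ProgressState = field(default_factory=ProgressState)
    current: ProgressState = field(default_factory=ProgressState)

    def increment_ready(self) -> None:
        self.total.ready += 1
        self.current.ready += 1

    def increment_started(self) -> None:
        self.total.started += 1
        self.current.started += 1

    def increment_processed(self) -> None:
        self.total.processed += 1
        self.current.processed += 1

    def increment_completed(self) -> None:
        self.total.completed += 1
        self.current.completed += 1

    def reset_on_epoch(self) -> None:
        # totals survive a restart; only the per-epoch counters are cleared
        self.current.reset()
```

Because `total` is never reset, a resumed run can read how much work was already finished, while `current` records where inside the interrupted epoch it stopped.
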
70 changes: 69 additions & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,10 +23,12 @@
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import TrainBatchLoopProgress, TrainingProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -53,6 +55,24 @@ def __init__(self) -> None:
self._optimizer_freq_cumsum: Optional[int] = None
self._remaining_splits: Optional[List[Any]] = None
self._skip_backward: bool = False
self._progress: Optional[TrainBatchLoopProgress] = None
self._progress_optimization: Optional[TrainingProgress] = None

@property
def progress(self) -> Optional[TrainBatchLoopProgress]:
return self._progress

@progress.setter
def progress(self, progress: TrainBatchLoopProgress):
self._progress = progress

@property
def progress_optimization(self) -> Optional[TrainingProgress]:
return self._progress_optimization

@progress_optimization.setter
def progress_optimization(self, progress_optimization: TrainingProgress):
self._progress_optimization = progress_optimization

@property
def done(self) -> bool:
@@ -66,6 +86,11 @@ def optimizer_freq_cumsum(self) -> int:
self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
return self._optimizer_freq_cumsum

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
# TODO(@justusschock): can we make this a weakref/proxy?
void(*args, **kwargs)
self.trainer = trainer

def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks

@@ -78,6 +103,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
return AttributeDict(signal=0, training_step_output=[[]])

self.progress.increment_ready()

# hook
self.trainer.logger_connector.on_batch_start()
response = self.trainer.call_hook("on_batch_start")
@@ -100,6 +127,10 @@ def reset(self) -> None:
self.batch_idx = 0
self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))]

if not self.trainer.is_restarting:
# reset tracking
self.progress_optimization.optimization.reset_on_epoch()

def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
"""Splits the data into tbptt splits

@@ -111,6 +142,14 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
void(batch_idx, dataloader_idx)
self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
self.progress.increment_started()
return super().on_advance_start(*args, **kwargs)

def on_advance_end(self) -> None:
self.progress.increment_completed()
return super().on_advance_end()

def advance(self, batch, batch_idx, dataloader_idx):
"""Runs the train step together with optimization (if necessary) on the current batch split

@@ -128,7 +167,19 @@ def advance(self, batch, batch_idx, dataloader_idx):
self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)

if self.trainer.lightning_module.automatic_optimization:
for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
active_optimizers = self.get_active_optimizers(batch_idx)
for opt_idx, optimizer in active_optimizers:

# handle optimization restart
if self.trainer.is_restarting:
if len(active_optimizers) > 1 and opt_idx < self.progress.current.completed:
continue
else:
self.trainer.is_restarting = False

# track optimizer_idx
self.progress.optimizer_idx = opt_idx

result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
if result:
self.batch_outputs[opt_idx].append(result.training_step_output)
@@ -138,6 +189,8 @@ def advance(self, batch, batch_idx, dataloader_idx):
if result:
self.batch_outputs[0].append(result.training_step_output)

self.progress.increment_processed()

def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int:
"""Gets the number of active optimizers based on their frequency"""
return len(self.get_active_optimizers(batch_idx))
@@ -217,8 +270,13 @@ def _training_step_and_backward_closure(
"""

result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)

if result is not None:
return_result.update(result)

# this should be done only if result.loss exists
self.progress_optimization.optimization.optimizer.increment_started()

return return_result.loss

def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
@@ -250,6 +308,8 @@ def _on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
# insert after step hook
self.trainer.call_hook("on_after_backward")

self.progress_optimization.optimization.optimizer.increment_ready()

# when in dev debugging track the losses
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

@@ -400,14 +460,20 @@ def _optimizer_step(
using_lbfgs=is_lbfgs,
)

self.progress_optimization.optimization.optimizer.increment_completed()

def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.

Args:
optimizer: the current optimizer
"""
self.progress_optimization.optimization.zero_grad.increment_ready()

self.trainer.call_hook('on_before_zero_grad', optimizer)

self.progress_optimization.optimization.zero_grad.increment_started()

def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.

@@ -418,6 +484,8 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
"""
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

self.progress_optimization.optimization.zero_grad.increment_completed()

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

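
The `increment_ready`/`increment_started`/`increment_processed`/`increment_completed` calls added above give every batch and optimizer step a fixed lifecycle. That is what makes a restart resumable: on reload the loop compares the last completed count against the work being replayed and skips what already finished. A rough sketch of the skip-on-restart idea, reusing the illustrative `Progress` class from the sketch above (this is not the PR's exact bookkeeping):

```python
from typing import Callable, List, Tuple

from torch.optim import Optimizer


def advance_optimizers(
    progress: "Progress",  # the illustrative Progress sketched earlier
    active_optimizers: List[Tuple[int, Optimizer]],
    run_optimization: Callable[[int, Optimizer], None],
    restarting: bool,
) -> bool:
    """Rough sketch: skip optimizer steps that completed before an interruption."""
    for opt_idx, optimizer in active_optimizers:
        if restarting and opt_idx < progress.current.completed:
            # this step finished in the interrupted run; fast-forward past it
            continue
        restarting = False
        progress.increment_started()
        run_optimization(opt_idx, optimizer)
        progress.increment_completed()
    # the caller keeps the updated restarting flag for the next loop level
    return restarting
```
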
44 changes: 42 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -20,6 +20,7 @@
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import TrainingLoopProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -48,12 +49,28 @@ def __init__(self, min_steps: int, max_steps: int):
self.is_last_batch: Optional[bool] = None

self.batch_loop: Optional[TrainingBatchLoop] = None
self._progress: Optional[TrainingLoopProgress] = None

self._dataloader_idx: Optional[int] = None
self._warning_cache: WarningCache = WarningCache()
self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
self._results = ResultCollection(training=True)

@property
def progress(self) -> TrainingLoopProgress:
if not self._progress:
self._progress = TrainingLoopProgress()
self.batch_loop.progress = self._progress.batch
self.batch_loop.progress_optimization = self._progress.epoch
return self._progress

@progress.setter
def progress(self, progress: TrainingLoopProgress) -> None:
if progress:
self.batch_loop.progress = progress.batch
self.batch_loop.progress_optimization = progress.epoch
self._progress = progress

@property
def results(self) -> ResultCollection:
return self._results
@@ -63,13 +80,21 @@ def batch_idx(self) -> int:
"""Returns the current batch index (within this epoch)"""
return self.iteration_count

@property
def total_optimizer_step(self) -> int:
return self.progress.epoch.optimization.optimizer.total.completed

@property
def current_batch_seen(self) -> int:
return self.progress.batch.current.completed

@property
def done(self) -> bool:
"""Returns whether the training should be stopped.
The criteria are that the number of steps reached the max steps,
the last batch is reached or the trainer signals to stop (e.g. by early stopping).
"""
max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
max_steps_reached = self.max_steps is not None and (self.total_optimizer_step) >= self.max_steps
return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
@@ -88,12 +113,22 @@ def reset(self) -> None:
# track epoch output
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

if not self.trainer.is_restarting:
# reset tracking
self.progress.reset_on_epoch()
else:
self.batches_seen = self.current_batch_seen

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
self.progress.epoch.increment_ready()

# hook
self.trainer.logger_connector.on_epoch_start()
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")

self.progress.epoch.increment_started()

def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
"""Runs a single training batch.

@@ -216,15 +251,20 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
'HINT: remove the return statement in training_epoch_end'
)

self.progress.epoch.increment_processed()

# call train epoch end hooks
self._on_train_epoch_end_hook(processed_outputs)
self.trainer.call_hook('on_epoch_end')
self.trainer.logger_connector.on_epoch_end()

self.progress.epoch.increment_completed()

return self._epoch_output

def teardown(self) -> None:
"""Frees memory of tracked epoch outputs."""
self.epoch_output = None
self._epoch_output = None

def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
"""Runs ``on_train_epoch_end hook``."""
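
The `progress` property and setter added to the epoch loop exist so that one progress object is shared with the child batch loop and so that a restored object can be pushed down the loop hierarchy. A minimal sketch of that wiring, again using the illustrative `Progress` class from above (the class names here are assumptions, not the PR's API):

```python
from typing import Optional


class ChildLoop:
    """Placeholder child loop whose progress is assigned by its parent."""

    def __init__(self) -> None:
        self.progress: Optional[Progress] = None


class ParentLoop:
    """Owns the progress object and shares the same instance with its child."""

    def __init__(self) -> None:
        self.child = ChildLoop()
        self._progress: Optional[Progress] = None

    @property
    def progress(self) -> Progress:
        if self._progress is None:
            # created lazily on first access; the child gets the same object
            self._progress = Progress()
            self.child.progress = self._progress
        return self._progress

    @progress.setter
    def progress(self, progress: Progress) -> None:
        # a progress object restored from a checkpoint is pushed down the hierarchy
        if progress is not None:
            self.child.progress = progress
            self._progress = progress
```

Sharing one instance means counters incremented deep inside the batch loop are immediately visible to the parent loops that decide when training is done.
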
36 changes: 32 additions & 4 deletions pytorch_lightning/loops/fit_loop.py
@@ -21,6 +21,7 @@
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import FitLoopProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_info

@@ -53,6 +54,7 @@ def __init__(
self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
self.epoch_loop = TrainingEpochLoop(min_steps, max_steps)
self.val_loop = EvaluationLoop()
self._progress: Optional[FitLoopProgress] = None

@property
def results(self) -> ResultCollection:
@@ -114,6 +116,16 @@ def max_steps(self, value: int) -> None:
# TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
self.epoch_loop.max_steps = value

@property
def total_epoch_completed(self) -> int:
"""Returns the total number of epoch completed"""
return self.progress.train.epoch.total.completed

@property
def total_optimizer_step_completed(self) -> int:
"""Returns the total number of optimizer step completed"""
return self.progress.train.epoch.optimization.optimizer.total.completed

@property
def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss"""
@@ -137,14 +149,14 @@ def done(self) -> bool:
or if the maximum number of steps or epochs is reached.
"""
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
stop_steps = self.max_steps is not None and self.global_step >= self.max_steps
stop_epochs = self.max_epochs is not None and self.current_epoch >= self.max_epochs
stop_steps = self.max_steps is not None and self.total_optimizer_step_completed >= self.max_steps
stop_epochs = self.max_epochs is not None and self.total_epoch_completed >= self.max_epochs

should_stop = False
if self.trainer.should_stop:
# early stopping
met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
met_min_epochs = self.total_epoch_completed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.total_optimizer_step_completed >= self.min_steps if self.min_steps else True
if met_min_epochs and met_min_steps:
should_stop = True
else:
@@ -171,8 +183,24 @@ def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
def reset(self) -> None:
"""Resets the internal state of this loop"""

@property
def progress(self) -> FitLoopProgress:
if not self._progress:
self._progress = FitLoopProgress(train=self.epoch_loop.progress)
return self._progress

@progress.setter
def progress(self, progress: FitLoopProgress) -> None:
if progress:
self._progress = progress
self.epoch_loop.progress = progress.train

def on_run_start(self) -> None:
"""Calls the ``on_train_start`` hook."""

# reset current epoch counter to 0
self.progress.train.epoch.current.reset()

self.results.to(device=self.trainer.lightning_module.device)
self.trainer.call_hook("on_train_start")

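
With the counters in place, the fit loop's `done` property is expressed against completed totals from the progress object rather than `global_step` and `current_epoch`, which is roughly the change made above. A hedged sketch of that stopping logic as a standalone function (the name and signature are illustrative):

```python
from typing import Optional


def fit_should_stop(
    epochs_completed: int,
    optimizer_steps_completed: int,
    max_epochs: Optional[int] = None,
    max_steps: Optional[int] = None,
    min_epochs: Optional[int] = None,
    min_steps: Optional[int] = None,
    early_stop_requested: bool = False,
) -> bool:
    """Rough sketch: stop decisions driven entirely by completed counters."""
    if max_steps is not None and optimizer_steps_completed >= max_steps:
        return True
    if max_epochs is not None and epochs_completed >= max_epochs:
        return True
    if early_stop_requested:
        # honour early stopping only once the configured minimums are met
        met_min_epochs = epochs_completed >= min_epochs if min_epochs else True
        met_min_steps = optimizer_steps_completed >= min_steps if min_steps else True
        return met_min_epochs and met_min_steps
    return False
```

For example, `fit_should_stop(epochs_completed=3, optimizer_steps_completed=1200, max_epochs=3)` returns `True` once the third epoch has completed, regardless of the step count.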