Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add is_last_batch to progress tracking #9657

Merged
merged 10 commits into from Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Progress tracking
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Add `BatchProgress` and integrate `TrainingEpochLoop.is_last_batch` ([#9657](https://github.com/PyTorchLightning/pytorch-lightning/pull/9657))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
* Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
Expand Down
23 changes: 10 additions & 13 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
Expand Up @@ -20,7 +20,7 @@
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
Expand All @@ -43,9 +43,7 @@ def __init__(self, min_steps: int, max_steps: int):
self.max_steps: int = max_steps

self.global_step: int = 0
# manually tracking which is the last batch is necessary for iterable dataset support
self.is_last_batch: Optional[bool] = None
self.batch_progress = Progress()
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()

self.batch_loop: Optional[TrainingBatchLoop] = None
Expand Down Expand Up @@ -94,18 +92,16 @@ def reset(self) -> None:
assert self.batch_loop is not None
assert self.batch_loop.optimizer_loop is not None
if self.restarting:
self.batch_progress.current.reset_on_restart()
self.batch_progress.reset_on_restart()
self.scheduler_progress.current.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

self.is_last_batch = False

# track epoch output
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

if not self.restarting or self._num_training_batches_reached():
self.batch_progress.current.reset()
self.scheduler_progress.current.reset()
self.batch_progress.reset_on_epoch()
self.scheduler_progress.reset_on_epoch()
self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()

def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
Expand All @@ -127,6 +123,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
StopIteration: When the epoch is canceled by the user returning -1
"""
batch_idx, (batch, is_last) = next(self.dataloader_iter)
self.batch_progress.is_last_batch = is_last

if not self.trainer.data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
Expand All @@ -139,8 +136,6 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_processed()

self.is_last_batch = is_last

# when returning -1 from train_step, we end epoch early
if batch_output.signal == -1:
raise StopIteration
Expand Down Expand Up @@ -178,7 +173,7 @@ def on_advance_end(self):
# -----------------------------------------
# VALIDATE IF NEEDED + CHECKPOINT CALLBACK
# -----------------------------------------
should_check_val = self._should_check_val_fx(self.batch_idx, self.is_last_batch)
should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
if should_check_val:
self.trainer.validating = True
self._run_validation()
Expand Down Expand Up @@ -259,7 +254,9 @@ def _accumulated_batches_reached(self) -> bool:

def _num_training_batches_reached(self) -> bool:
"""Checks if we are in the last batch or if there are more batches to follow."""
return self.batch_progress.current.ready == self.trainer.num_training_batches or self.is_last_batch
return (
self.batch_progress.current.ready == self.trainer.num_training_batches or self.batch_progress.is_last_batch
)

def _should_accumulate(self) -> bool:
"""Checks if the optimizer step should be performed or gradients should be accumulated for the current
Expand Down
36 changes: 30 additions & 6 deletions pytorch_lightning/trainer/progress.py
Expand Up @@ -148,6 +148,9 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int)
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))

def reset_on_epoch(self) -> None:
self.current.reset()

def reset_on_restart(self) -> None:
self.current.reset_on_restart()

Expand All @@ -158,8 +161,9 @@ def load_state_dict(self, state_dict: dict) -> None:

@dataclass
class DataLoaderProgress(Progress):
"""Tracks the dataloader progress These counters are local to a trainer rank. By default, they are not globally
synced across all ranks.
"""Tracks dataloader progress.

These counters are local to a trainer rank. By default, they are not globally synced across all ranks.

Args:
total: Tracks the total dataloader progress.
Expand All @@ -170,10 +174,30 @@ class DataLoaderProgress(Progress):
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


@dataclass
class BatchProgress(Progress):
"""Tracks batch progress.

These counters are local to a trainer rank. By default, they are not globally synced across all ranks.

Args:
total: Tracks the total batch progress.
current: Tracks the current batch progress.
is_last_batch: Whether the batch is the last one. This is useful for iterable datasets.
"""

is_last_batch: bool = False

def reset_on_epoch(self) -> None:
super().reset_on_epoch()
self.is_last_batch = False


@dataclass
class SchedulerProgress(Progress):
"""Tracks the scheduler progress. These counters are local to a trainer rank. By default, they are not globally
synced across all ranks.
"""Tracks scheduler progress.

These counters are local to a trainer rank. By default, they are not globally synced across all ranks.

Args:
total: Tracks the total scheduler progress.
Expand All @@ -197,8 +221,8 @@ class OptimizerProgress(BaseProgress):
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

def reset_on_epoch(self) -> None:
self.step.current.reset()
self.zero_grad.current.reset()
self.step.reset_on_epoch()
self.zero_grad.reset_on_epoch()

def reset_on_restart(self) -> None:
self.step.reset_on_restart()
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Expand Up @@ -1892,7 +1892,7 @@ def min_steps(self) -> Optional[int]:

@property
def is_last_batch(self) -> bool:
return self.fit_loop.epoch_loop.is_last_batch
return self.fit_loop.epoch_loop.batch_progress.is_last_batch

@property
def fit_loop(self) -> FitLoop:
Expand Down
1 change: 1 addition & 0 deletions tests/loops/test_loop_state_dict.py
Expand Up @@ -51,6 +51,7 @@ def test_loops_state_dict_structure():
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
"is_last_batch": False,
},
"epoch_loop.scheduler_progress": {
"total": {"ready": 0, "completed": 0},
Expand Down
11 changes: 10 additions & 1 deletion tests/loops/test_loops.py
Expand Up @@ -20,7 +20,9 @@

import pytest
import torch
from torch.utils.data import DataLoader

from pl_examples.bug_report_model import RandomDataset
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
Expand Down Expand Up @@ -443,6 +445,7 @@ def configure_optimizers_multiple(self):
"processed": stop_batch,
"completed": stop_batch,
},
"is_last_batch": False,
},
"epoch_loop.scheduler_progress": {
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
Expand Down Expand Up @@ -548,13 +551,16 @@ def configure_optimizers_multiple(self):

return optimizers, lr_schedulers

def train_dataloader(self):
# override to test the `is_last_batch` value
return DataLoader(RandomDataset(32, n_batches))

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
accumulate_grad_batches=accumulate_grad_batches,
progress_bar_refresh_rate=0,
Expand All @@ -563,6 +569,8 @@ def configure_optimizers_multiple(self):
)
trainer.fit(model)

assert trainer.num_training_batches == n_batches

ckpt_path = trainer.checkpoint_callback.best_model_path
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)
Expand Down Expand Up @@ -607,6 +615,7 @@ def configure_optimizers_multiple(self):
"processed": n_batches,
"completed": n_batches,
},
"is_last_batch": True,
},
"epoch_loop.scheduler_progress": {
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
Expand Down