Fix self.log(on_epoch=True) on_batch_start (Lightning-AI#9780)
carmocca authored and rohitgr7 committed Oct 18, 2021
1 parent da126ba commit db8470a
Showing 8 changed files with 72 additions and 86 deletions.
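
For context, here is a minimal sketch (not part of the commit) of the user-facing pattern this fix targets: calling `self.log(..., on_epoch=True)` from `on_train_batch_start`. The model, dataset, and metric names are illustrative assumptions, and the snippet assumes a Lightning release that includes this change.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Tiny synthetic dataset so the sketch is self-contained."""

    def __init__(self, size: int = 32, length: int = 16):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_batch_start(self, batch, batch_idx):
        # Logging with on_epoch=True from this hook is the case fixed by this
        # commit: the logger connector now receives the batch index before
        # `on_batch_start`/`on_train_batch_start` run.
        self.log("batch_idx_seen", float(batch_idx), on_step=False, on_epoch=True)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=4)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False, enable_model_summary=False)
    trainer.fit(DemoModel())
```
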
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -530,11 +530,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `broadcast` in `DDPPlugin` and `DDPSpawnPlugin` to respect the `src` input ([#9691](https://github.com/PyTorchLightning/pytorch-lightning/pull/9691))


- Fixed `self.log(on_epoch=True)` for the `on_batch_start` and `on_train_batch_start` hooks ([#9780](https://github.com/PyTorchLightning/pytorch-lightning/pull/9780))


- Fixed restoring training state during `trainer.fit` only ([#9413](https://github.com/PyTorchLightning/pytorch-lightning/pull/9413))


- Fixed DeepSpeed and Lightning both calling the scheduler ([#9788](https://github.com/PyTorchLightning/pytorch-lightning/pull/9788))


- Fixed missing arguments when saving hyperparameters from the parent class but not from the child class ([#9800](https://github.com/PyTorchLightning/pytorch-lightning/pull/9800))


49 changes: 5 additions & 44 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,9 +23,6 @@
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

@@ -43,7 +40,6 @@ def __init__(self) -> None:
self.manual_loop = ManualOptimization()

self._outputs: _OUTPUTS_TYPE = []
self._warning_cache: WarningCache = WarningCache()
self._remaining_splits: Optional[List[Any]] = None

@property
@@ -59,42 +55,6 @@ def connect(
if manual_loop is not None:
self.manual_loop = manual_loop

def run(self, batch: Any, batch_idx: int) -> AttributeDict:
"""Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks.
Args:
batch: the current batch to run the train step on
batch_idx: the index of the current batch
"""
if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
return AttributeDict(signal=0, outputs=[])

# hook
self.trainer.logger_connector.on_batch_start()
response = self.trainer.call_hook("on_batch_start")
if response == -1:
return AttributeDict(signal=-1)

# hook
# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_start
extra_kwargs = (
{"dataloader_idx": 0}
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)
response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
if response == -1:
return AttributeDict(signal=-1)

self.trainer.fit_loop.epoch_loop.batch_progress.increment_started()

super().run(batch, batch_idx)

output, self._outputs = AttributeDict(signal=0, outputs=self._outputs), None # free memory
return output

def reset(self) -> None:
"""Resets the loop state."""
self._outputs = []
@@ -117,11 +77,10 @@ def advance(self, batch, batch_idx):
batch_idx: the index of the current batch
"""
void(batch)
split_idx, split_batch = self._remaining_splits.pop(0)
self.split_idx = split_idx
self.split_idx, split_batch = self._remaining_splits.pop(0)

# let logger connector extract current batch size
self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
self.trainer.logger_connector.on_train_split_start(self.split_idx, split_batch)

# choose which loop will run the optimization
if self.trainer.lightning_module.automatic_optimization:
@@ -135,10 +94,12 @@ def advance(self, batch, batch_idx):
# then `advance` doesn't finish and an empty dict is returned
self._outputs.append(outputs)

def on_run_end(self) -> None:
def on_run_end(self) -> _OUTPUTS_TYPE:
self.optimizer_loop._hiddens = None
# this is not necessary as the manual loop runs for only 1 iteration, but just in case
self.manual_loop._hiddens = None
output, self._outputs = self._outputs, None # free memory
return output

def teardown(self) -> None:
# release memory
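
Why the `run` override can simply be deleted: the base `Loop.run` already drives `reset`/`advance` and returns whatever `on_run_end` returns. The sketch below is a simplified paraphrase of that contract (not the actual base-class source), showing why returning the collected outputs from `on_run_end` replaces the old `AttributeDict(signal=..., outputs=...)` wrapper.

```python
from typing import Any


class LoopSketch:
    """Simplified paraphrase of the Lightning ``Loop.run`` control flow."""

    @property
    def done(self) -> bool:
        raise NotImplementedError

    def reset(self) -> None:
        ...

    def advance(self, *args: Any, **kwargs: Any) -> None:
        ...

    def on_advance_end(self) -> None:
        ...

    def on_run_end(self) -> Any:
        ...

    def run(self, *args: Any, **kwargs: Any) -> Any:
        self.reset()
        while not self.done:
            try:
                self.advance(*args, **kwargs)
                self.on_advance_end()
            except StopIteration:
                break
        # Whatever ``on_run_end`` returns becomes the return value of ``run``,
        # so ``TrainingBatchLoop.on_run_end`` can hand its outputs straight
        # back to ``TrainingEpochLoop.advance``.
        return self.on_run_end()
```
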
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -233,10 +233,10 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
Raises:
AssertionError: If the number of dataloaders is None (has not yet been set).
"""
self.trainer.logger_connector.on_batch_start()
self.trainer.logger_connector.on_batch_start(batch_idx)

assert self._num_dataloaders is not None
self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)
self.trainer.logger_connector.on_evaluation_batch_start(batch, dataloader_idx, self._num_dataloaders)

if self.trainer.testing:
self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
41 changes: 33 additions & 8 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -28,6 +28,7 @@
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -57,6 +58,7 @@ def __init__(self, min_steps: int, max_steps: int):

self._results = ResultCollection(training=True)
self._outputs: _OUTPUTS_TYPE = []
self._warning_cache = WarningCache()
self._dataloader_iter: Optional[Iterator] = None
# caches the loaded dataloader state until dataloader objects are available
self._dataloader_state_dict: Dict[str, Any] = {}
@@ -151,14 +153,37 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

self.batch_progress.increment_ready()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)
if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
batch_output = []
else:
# hook
self.trainer.logger_connector.on_batch_start(batch_idx)
response = self.trainer.call_hook("on_batch_start")
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration

# TODO: Update this in v1.7 (deprecation: #9816)
model_fx = self.trainer.lightning_module.on_train_batch_start
extra_kwargs = (
{"dataloader_idx": 0}
if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
else {}
)

self.batch_progress.increment_processed()
# hook
response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
if response == -1:
self.batch_progress.increment_processed()
raise StopIteration

# when returning -1 from train_step, we end epoch early
if batch_output.signal == -1:
raise StopIteration
self.batch_progress.increment_started()

with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(batch, batch_idx)

self.batch_progress.increment_processed()

# update non-plateau LR schedulers
# update epoch-interval ones only when we are at the end of training epoch
@@ -167,7 +192,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

batch_end_outputs = self._prepare_outputs_training_batch_end(
batch_output.outputs,
batch_output,
automatic=self.trainer.lightning_module.automatic_optimization,
num_optimizers=len(self.trainer.optimizers),
)
@@ -186,7 +211,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self.batch_progress.increment_completed()

if is_overridden("training_epoch_end", self.trainer.lightning_module):
self._outputs.append(batch_output.outputs)
self._outputs.append(batch_output)

# -----------------------------------------
# SAVE METRICS TO LOGGERS AND PROGRESS_BAR
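
The early-exit behaviour previously expressed as `AttributeDict(signal=-1)` is now a `StopIteration` raised directly in `TrainingEpochLoop.advance`. From user code the contract is unchanged; below is a hedged sketch (hypothetical model, not from this commit) of relying on it.

```python
import torch
import pytorch_lightning as pl


class ShortEpochModel(pl.LightningModule):
    """Hypothetical model that skips the rest of each training epoch after 10 batches."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 asks Lightning to skip the remaining batches of the
        # current epoch. After this commit that signal is turned into a
        # ``StopIteration`` inside ``TrainingEpochLoop.advance`` instead of an
        # ``AttributeDict(signal=-1)`` bubbled up from the batch loop.
        if batch_idx >= 10:
            return -1

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```
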
10 changes: 4 additions & 6 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -138,15 +138,14 @@ def _increment_eval_log_step(self) -> None:
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
def on_evaluation_batch_start(self, batch: Any, dataloader_idx: int, num_dataloaders: int) -> None:
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

# track batch_size
assert self.trainer._results is not None
self.trainer._results.extract_batch_size(batch)
self._batch_idx = batch_idx

def update_eval_step_metrics(self) -> None:
if self.trainer.sanity_checking:
@@ -213,14 +212,12 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
Train metric updates
"""

def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
def on_train_split_start(self, split_idx: int, split_batch: Any) -> None:
assert self.trainer._results is not None
# when the user requests `dataloader_iter`, we can't track the batch_size
# and this is left to user responsibility.
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
self.trainer._results.extract_batch_size(split_batch)

self._batch_idx = batch_idx
self._split_idx = split_idx

def update_train_step_metrics(self) -> None:
@@ -267,7 +264,8 @@ def _log_gpus_metrics(self) -> None:
def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self) -> None:
def on_batch_start(self, batch_idx: int) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False

def epoch_end_reached(self) -> None:
8 changes: 2 additions & 6 deletions tests/loops/test_evaluation_loop_flow.py
@@ -64,10 +64,8 @@ def backward(self, loss, optimizer, optimizer_idx):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -129,10 +127,8 @@ def backward(self, loss, optimizer, optimizer_idx):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
30 changes: 10 additions & 20 deletions tests/loops/test_training_loop_flow_scalar.py
@@ -147,10 +147,8 @@ def backward(self, loss, optimizer, optimizer_idx):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -221,10 +219,8 @@ def backward(self, loss, optimizer, optimizer_idx):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
assert out.signal == 0
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)

train_step_out = out.outputs
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
assert isinstance(train_step_out["loss"], torch.Tensor)
@@ -311,8 +307,7 @@ def training_step(self, batch, batch_idx):
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
if not batch_idx % 2:
assert out.outputs == []
assert out.signal == 0
assert out == []


def test_training_step_none_batches(tmpdir):
@@ -321,7 +316,6 @@ def test_training_step_none_batches(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()

self.counter = 0

def collate_none_when_even(self, batch):
@@ -333,12 +327,17 @@ def collate_none_when_even(self, batch):
return result

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), collate_fn=self.collate_none_when_even)
return DataLoader(RandomDataset(32, 4), collate_fn=self.collate_none_when_even)

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
if batch_idx % 2 == 0:
assert outputs == []
else:
assert outputs

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=4,
limit_val_batches=1,
max_epochs=4,
enable_model_summary=False,
@@ -348,12 +347,3 @@ def train_dataloader(self):

with pytest.warns(UserWarning, match=r".*train_dataloader yielded None.*"):
trainer.fit(model)

trainer.state.stage = RunningStage.TRAINING

# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.fit_loop.epoch_loop.batch_loop.run(batch, batch_idx)
if not batch_idx % 2:
assert out.outputs == []
assert out.signal == 0
12 changes: 12 additions & 0 deletions tests/trainer/logging_/test_train_loop_logging.py
@@ -276,11 +276,21 @@ def on_train_epoch_start(self, _, pl_module):
pl_module, "on_train_epoch_start", on_steps=self.choices, on_epochs=[True], prob_bars=self.choices
)

def on_batch_start(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_batch_end(self, _, pl_module):
self.make_logging(
pl_module, "on_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_train_batch_start(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_train_batch_start", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
)

def on_train_batch_end(self, _, pl_module, *__):
self.make_logging(
pl_module, "on_train_batch_end", on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices
Expand Down Expand Up @@ -323,7 +333,9 @@ def training_step(self, batch, batch_idx):
"on_train_start": 1,
"on_epoch_start": 1,
"on_train_epoch_start": 1,
"on_train_batch_start": 2,
"on_train_batch_end": 2,
"on_batch_start": 2,
"on_batch_end": 2,
"on_train_epoch_end": 1,
"on_epoch_end": 1,

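To close, a hedged sketch of the callback-side pattern exercised by the updated test: logging from `on_batch_start` and `on_train_batch_start` with `on_epoch=True`. The callback name and metric keys are illustrative, and `*args` absorbs the `dataloader_idx` argument that older hook signatures still pass.

```python
import pytorch_lightning as pl


class BatchStartLogger(pl.Callback):
    """Illustrative callback: log from the batch-start hooks fixed by this commit."""

    def on_batch_start(self, trainer, pl_module, *args):
        # Epoch-level aggregation now works because the logger connector
        # receives the batch index before this hook runs.
        pl_module.log("batches_seen", 1.0, on_step=True, on_epoch=True, reduce_fx="sum")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, *args):
        pl_module.log("train_batches_seen", 1.0, on_step=True, on_epoch=True, reduce_fx="sum")
```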