
Commit f94b919

deprecated: epoch indexing from 1 (Lightning-AI#2206)
* epoch indexing from 1
* chlog
* fix tests
* fix tests
* self.min_epochs
1 parent 8870a84 commit f94b919

File tree

9 files changed: +16 −19 lines


CHANGELOG.md (+1)

@@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
 - Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
 - Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
+- Changed epoch/step indexing from 1 instead of 0 ([#2206](https://github.com/PyTorchLightning/pytorch-lightning/pull/2206))

 ### Deprecated

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py (+2 −6)

@@ -42,11 +42,9 @@ def __init__(self, scheduling: dict):

         for key in scheduling:
             if not isinstance(key, int) or not isinstance(scheduling[key], int):
-                raise TypeError("All epoches and accumulation factor must be integers")
+                raise TypeError("All epochs and accumulation factor must be integers")

         minimal_epoch = min(scheduling.keys())
-        # rank_zero_warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
-        #                ' but will start from "0" in v0.8.0.', DeprecationWarning)
         if minimal_epoch < 1:
             raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
         if minimal_epoch != 1:  # if user didnt define first epoch accumulation factor
@@ -56,9 +54,7 @@ def __init__(self, scheduling: dict):
         self.epochs = sorted(scheduling.keys())

     def on_epoch_start(self, trainer, pl_module):
-        # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, ` + 1` should be removed.
-        epoch = trainer.current_epoch + 1
+        epoch = trainer.current_epoch
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
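
Note: with `trainer.current_epoch` now starting at 1, the `scheduling` keys map directly onto epoch numbers. A minimal usage sketch under the new indexing (the schedule values below are illustrative, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # keys are 1-based epoch numbers, values are accumulation factors:
    # no accumulation from epoch 1, accumulate 2 batches from epoch 5,
    # and 4 batches from epoch 10 onward
    accumulator = GradientAccumulationScheduler(scheduling={1: 1, 5: 2, 10: 4})
    trainer = Trainer(callbacks=[accumulator])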

pytorch_lightning/callbacks/progress.py (+2 −2)

@@ -96,7 +96,7 @@ def total_val_batches(self) -> int:
         if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
         elif not self.trainer.disable_validation:
-            is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
+            is_val_epoch = trainer.current_epoch % trainer.check_val_every_n_epoch == 0
             total_val_batches = trainer.num_val_batches if is_val_epoch else 0
         return total_val_batches

@@ -317,7 +317,7 @@ def on_epoch_start(self, trainer, pl_module):
         total_batches = total_train_batches + total_val_batches
         if not self.main_progress_bar.disable:
             self.main_progress_bar.reset(convert_inf(total_batches))
-            self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')
+            self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')

     def on_batch_end(self, trainer, pl_module):
         super().on_batch_end(trainer, pl_module)
pytorch_lightning/trainer/training_io.py (+1 −1)

@@ -323,7 +323,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             structured dictionary
         """
         checkpoint = {
-            'epoch': self.current_epoch + 1,
+            'epoch': self.current_epoch,
             'global_step': self.global_step + 1,
             'pytorch-ligthning_version': pytorch_lightning.__version__,
         }
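
Checkpoints now record the 1-based epoch directly, while `global_step` keeps its existing `+ 1` offset. A quick sketch of inspecting a saved checkpoint (the path is a placeholder, not from this commit):

    import torch

    # 'last.ckpt' stands in for any Lightning checkpoint file
    checkpoint = torch.load('last.ckpt', map_location='cpu')
    print(checkpoint['epoch'])        # 1-based epoch at save time
    print(checkpoint['global_step'])  # still stored with a +1 offset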

pytorch_lightning/trainer/training_loop.py (+4 −4)

@@ -336,8 +336,8 @@ def train(self):
         model.on_train_start()

         try:
-            # run all epochs
-            for epoch in range(self.current_epoch, self.max_epochs):
+            # run all epochs from actual + 1 till the maximal
+            for epoch in range(self.current_epoch + 1, self.max_epochs + 1):
                 # reset train dataloader
                 if self.reload_dataloaders_every_epoch:
                     self.reset_train_dataloader(model)
@@ -372,7 +372,7 @@ def train(self):
                 self.update_learning_rates(interval='epoch')

                 # early stopping
-                met_min_epochs = epoch >= self.min_epochs - 1
+                met_min_epochs = epoch >= self.min_epochs
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                 # TODO wrap this logic into the callback
@@ -466,7 +466,7 @@ def run_training_epoch(self):
             # RUN VAL STEP
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
-            can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+            can_check_epoch = self.current_epoch % self.check_val_every_n_epoch == 0
             can_check_val = not self.disable_validation and can_check_epoch
             should_check_val = is_val_check_batch or early_stop_epoch
             should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
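
The epoch loop now iterates over 1-based epoch numbers, which also lets `met_min_epochs` compare against `self.min_epochs` without the `- 1` correction. A minimal illustration of the new range arithmetic, assuming a fresh run (`current_epoch == 0` before training) with `max_epochs = 3`:

    current_epoch, max_epochs = 0, 3

    # old: range(current_epoch, max_epochs)          -> epochs 0, 1, 2
    # new: range(current_epoch + 1, max_epochs + 1)  -> epochs 1, 2, 3
    print(list(range(current_epoch + 1, max_epochs + 1)))  # [1, 2, 3]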

tests/callbacks/test_callbacks.py (+1 −1)

@@ -248,7 +248,7 @@ def training_step(self, *args, **kwargs):
     result = trainer.fit(model)

     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch < trainer.max_epochs
+    assert trainer.current_epoch <= trainer.max_epochs


 def test_pickling(tmpdir):

tests/models/test_hooks.py (+1 −1)

@@ -68,7 +68,7 @@ def training_epoch_end(self, outputs):
     # a metric shared in both methods gets overwritten by epoch_end
     assert metrics['shared_metric'] == 111
     # metrics are kept after each epoch
-    for i in range(num_epochs):
+    for i in range(1, num_epochs + 1):
         assert metrics[f'epoch_metric_{i}'] == i
tests/models/test_restore.py (+1 −1)

@@ -172,7 +172,7 @@ def test_dp_resume(tmpdir):
     result = trainer.fit(model)

     # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
-    real_global_epoch = trainer.current_epoch + 1
+    real_global_epoch = trainer.current_epoch

     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'

tests/trainer/test_trainer.py (+3 −3)

@@ -451,7 +451,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):

     # check training stopped at max_epochs
     assert trainer.global_step == num_train_samples * trainer.max_epochs
-    assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"
+    assert trainer.current_epoch == trainer.max_epochs, "Model did not stop at max_epochs"


 def test_trainer_min_steps_and_epochs(tmpdir):
@@ -619,7 +619,7 @@ def validation_epoch_end(self, *args, **kwargs):

     # check that val_percent_check=0 turns off validation
     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 1
+    assert trainer.current_epoch == 2
     assert not model.validation_step_invoked, \
         '`validation_step` should not run when `val_percent_check=0`'
     assert not model.validation_epoch_end_invoked, \
@@ -632,7 +632,7 @@ def validation_epoch_end(self, *args, **kwargs):
     result = trainer.fit(model)

     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 0
+    assert trainer.current_epoch == 1
     assert model.validation_step_invoked, \
         'did not run `validation_step` with `fast_dev_run=True`'
     assert model.validation_epoch_end_invoked, \
