Fix resuming the tqdm progress bar #13962

Merged · 13 commits · Aug 2, 2022
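This PR fixes `TQDMProgressBar` behavior when training resumes from a mid-epoch checkpoint: every bar now clears its `initial` offset whenever it is `reset()`, and the `_update_n` helper writes the bar position as an absolute value instead of advancing it by the refresh rate, so batch counts and time estimates stay correct across resumes, validation runs, and epoch boundaries.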
repro.py — new file (77 additions)

@@ -0,0 +1,77 @@
import os
import shutil
from time import sleep

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        sleep(1)  # slow each batch down so the bar's rate/ETA is observable
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        print()  # blank line so each progress-bar refresh stays visible
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("val_loss", loss)
        print()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 10), batch_size=2)

    if os.path.exists("lightning_logs"):
        shutil.rmtree("lightning_logs")

    # First run: train a single epoch, checkpointing after every step.
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        callbacks=ModelCheckpoint(monitor="train_loss", save_top_k=-1, every_n_train_steps=1),
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=train_data)

    # Second run: resume mid-epoch from the step-3 checkpoint, now with the
    # progress bar enabled — the case where the bar used to show wrong counts
    # and time estimates.
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=3,
        enable_model_summary=False,
        enable_progress_bar=True,
        callbacks=ModelCheckpoint(monitor="train_loss", save_top_k=-1, every_n_train_steps=1),
    )
    trainer.fit(
        model, train_dataloaders=train_data, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt"
    )


if __name__ == "__main__":
    run()
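The misbehavior the second `fit` call exposes can be reproduced with tqdm alone: `reset()` zeroes the bar's position `n` but leaves its `initial` offset untouched, and recent tqdm releases fall back to `(n - initial) / elapsed` for the rate behind the remaining-time estimate, so a bar created with a nonzero `initial` on resume keeps skewing the estimates after every epoch-start reset. A minimal sketch, assuming a recent tqdm (the numbers are illustrative, not from the PR):

```python
# Standalone sketch of the bug mechanism -- not part of the PR.
from tqdm import tqdm

bar = tqdm(total=5, initial=3)  # bar recreated mid-epoch after a resume
bar.update(2)                   # finish the remaining batches of the epoch

bar.reset(total=5)              # next epoch begins
assert bar.n == 0               # reset() zeroes the position ...
assert bar.initial == 3         # ... but the resume offset survives and
                                # distorts the fallback rate estimate

bar.initial = 0                 # the one-line fix this PR applies per reset
bar.close()
```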
src/pytorch_lightning/CHANGELOG.md — 4 additions & 1 deletion

@@ -399,6 +399,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))


+- Fixed `TQDMProgressBar` reset and update to show correct time estimation (2/2) ([#13962](https://github.com/Lightning-AI/lightning/pull/13962))



## [1.6.5] - 2022-07-13

@@ -454,7 +457,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/Lightning-AI/lightning/pull/12891))
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/Lightning-AI/lightning/pull/12653))
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/Lightning-AI/lightning/pull/12965))
-- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/Lightning-AI/lightning/pull/12889))
+- Fixed `TQDMProgressBar` reset and update to show correct time estimation (1/2) ([#12889](https://github.com/Lightning-AI/lightning/pull/12889))
- Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/Lightning-AI/lightning/pull/12821))


src/pytorch_lightning/callbacks/progress/tqdm_progress.py — 12 additions & 10 deletions

@@ -254,12 +254,13 @@ def on_train_start(self, *_: Any) -> None:
    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        total_batches = self.total_batches_current_epoch
        self.main_progress_bar.reset(convert_inf(total_batches))
+        self.main_progress_bar.initial = 0
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
        current = self.train_batch_idx + self._val_processed
        if self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current, self.refresh_rate)
+            _update_n(self.main_progress_bar, current)
            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -280,16 +281,17 @@ def on_validation_batch_start(
            return

        self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
+        self.val_progress_bar.initial = 0
        desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
        self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

    def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
        if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
-            _update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)
+            _update_n(self.val_progress_bar, self.val_batch_idx)

        current = self.train_batch_idx + self._val_processed
        if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current, self.refresh_rate)
+            _update_n(self.main_progress_bar, current)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self._main_progress_bar is not None and trainer.state.fn == "fit":
@@ -307,11 +309,12 @@ def on_test_batch_start(
            return

        self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
+        self.test_progress_bar.initial = 0
        self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

    def on_test_batch_end(self, *_: Any) -> None:
        if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
-            _update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)
+            _update_n(self.test_progress_bar, self.test_batch_idx)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.test_progress_bar.close()
@@ -327,11 +330,12 @@ def on_predict_batch_start(
            return

        self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
+        self.predict_progress_bar.initial = 0
        self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

    def on_predict_batch_end(self, *_: Any) -> None:
        if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
-            _update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)
+            _update_n(self.predict_progress_bar, self.predict_batch_idx)

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.predict_progress_bar.close()
@@ -375,9 +379,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    return x


-def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
+def _update_n(bar: _tqdm, value: int) -> None:
    if not bar.disable:
-        total = bar.total
-        leftover = current % refresh_rate
-        advance = leftover if (current == total and leftover != 0) else refresh_rate
-        bar.update(advance)
+        bar.n = value
+        bar.refresh()
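The `_update_n` rewrite addresses a related drift: advancing by `refresh_rate` assumes the bar crosses every refresh-rate multiple exactly once, which stops holding once the train and validation counters feed the same main bar through `current = self.train_batch_idx + self._val_processed`. Writing `n` outright makes each update self-correcting. A hypothetical side-by-side of the two styles (illustrative stand-ins for the old and new bodies, not the PR's code):

```python
# Illustrative comparison of relative vs. absolute bar updates -- not PR code.
from tqdm import tqdm


def update_relative(bar, current, refresh_rate):
    # Old style: advance by the refresh rate, with a remainder on the last batch.
    leftover = current % refresh_rate
    advance = leftover if (current == bar.total and leftover != 0) else refresh_rate
    bar.update(advance)


def update_absolute(bar, value):
    # New style: write the position directly, then redraw.
    bar.n = value
    bar.refresh()


bar = tqdm(total=10)
for current in (2, 4, 8):  # a multiple (6) was skipped, e.g. because two
    update_relative(bar, current, refresh_rate=2)  # counters feed one bar
assert bar.n == 6          # the bar now lags the true count of 8
bar.close()

bar = tqdm(total=10)
for current in (2, 4, 8):
    update_absolute(bar, current)
assert bar.n == 8          # always matches the true count
bar.close()
```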