Skip to content

[功能建議] Make CustomTQDMProgressBar greater #1

@kunkunlin1221

Description

@kunkunlin1221

您的功能請求是否與問題相關?請描述。
CustomTQDMProgressBar 的 unit_scale 應該在每個模式(train/val/test/predict)下分開設定

描述您想要的解決方案
藉由 Callback 的 interface,動態設定 unit_scale。

下面是我的實作。這個寫法相依於 lightning==2.5.3 的 TQDMProgressBar 內部實作,要注意官方之後若更新這部分的實作,這裡也需要跟著調整。

from typing import Any

from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar, _update_n, convert_inf


class CustomTQDMProgressBar(TQDMProgressBar):
    """A ``TQDMProgressBar`` that counts progress in *samples* instead of batches.

    Each mode (train / validation / test / predict) gets its own sample-based
    scale: training uses tqdm's ``unit_scale``, while the other modes reset the
    bar total to the dataset's sample count and advance it by
    ``batch_size * batch_idx + len(current_batch)``.

    NOTE(review): this relies on internal helpers of lightning==2.5.3's
    ``TQDMProgressBar`` (``_update_n``, ``convert_inf``,
    ``has_dataloader_changed``) — re-verify on lightning upgrades.
    """

    @staticmethod
    def _batch_len(batch: Any) -> int:
        """Return the number of samples in ``batch``.

        Supports dict-like batches (length of the first value, matching the
        original ``list(batch.keys())[0]`` behavior) as well as plain
        sequences / tensors that implement ``len``.
        """
        if hasattr(batch, "keys"):
            first_key = next(iter(batch.keys()))
            return len(batch[first_key])
        return len(batch)

    @staticmethod
    def _dataset_size(dataloader: Any):
        """Return ``len(dataloader.dataset)`` or ``None`` when it has no length.

        Iterable-style datasets may not implement ``__len__``; tqdm accepts
        ``total=None`` (unknown total) in that case.
        """
        try:
            return len(dataloader.dataset)
        except TypeError:
            return None

    def on_train_epoch_start(self, trainer, *_):
        super().on_train_epoch_start(trainer, *_)
        # Scale the displayed batch count by the batch size so the training
        # bar reads in samples.
        self.train_progress_bar.unit_scale = trainer.train_dataloader.batch_size

    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        # Only re-initialize the bar when we move to a new dataloader.
        if not self.has_dataloader_changed(dataloader_idx):
            return

        dataloader = self.trainer.val_dataloaders[dataloader_idx]
        # Total is in samples to match the sample-count `n` reported in
        # on_validation_batch_end.
        self.val_progress_bar.reset(convert_inf(self._dataset_size(dataloader)))
        self.val_progress_bar.initial = 0
        desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
        self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
        self.val_batch_size = dataloader.batch_size

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Samples processed so far: full batches completed plus the current one.
        n = self.val_batch_size * batch_idx + self._batch_len(batch)
        if self._should_update(n, self.val_progress_bar.total):
            _update_n(self.val_progress_bar, n)

    def on_predict_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if not self.has_dataloader_changed(dataloader_idx):
            return

        dataloader = self.trainer.predict_dataloaders[dataloader_idx]
        # Bug fix: the total must be in *samples* (to match the sample-count
        # `n` from on_predict_batch_end), not in batches
        # (total_predict_batches_current_dataloader) as before — otherwise the
        # bar overshoots its total.
        self.predict_progress_bar.reset(convert_inf(self._dataset_size(dataloader)))
        self.predict_progress_bar.initial = 0
        self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
        self.predict_batch_size = dataloader.batch_size

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        n = self.predict_batch_size * batch_idx + self._batch_len(batch)
        if self._should_update(n, self.predict_progress_bar.total):
            _update_n(self.predict_progress_bar, n)

    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        if not self.has_dataloader_changed(dataloader_idx):
            return

        dataloader = self.trainer.test_dataloaders[dataloader_idx]
        # Bug fix: sample-based total, consistent with on_test_batch_end
        # (previously used the batch count, mismatching the sample-count `n`).
        self.test_progress_bar.reset(convert_inf(self._dataset_size(dataloader)))
        self.test_progress_bar.initial = 0
        self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
        self.test_batch_size = dataloader.batch_size

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        n = self.test_batch_size * batch_idx + self._batch_len(batch)
        if self._should_update(n, self.test_progress_bar.total):
            _update_n(self.test_progress_bar, n)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions