-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
您的功能請求是否與問題相關?請描述。
CustomTQDMProgressBar 的 unit_scale 應該依各個模式(train / validation / test / predict)分別設定。
描述您想要的解決方案
透過 Callback 的介面(各模式的 hook)動態設定 unit_scale。
以下是我的實作。這個寫法依賴 lightning==2.5.3 中 TQDMProgressBar 的內部實作(例如 `_update_n`、`convert_inf`),日後升級 Lightning 時需留意官方是否修改了這部分的實作。
from typing import Any
from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar, _update_n, convert_inf
class CustomTQDMProgressBar(TQDMProgressBar):
    """TQDM progress bar whose counters show samples instead of batches.

    Each mode is scaled independently:

    * training: the parent's batch counter is kept and tqdm's ``unit_scale`` is
      set to the training dataloader's ``batch_size`` at every epoch start, so
      tqdm renders ``n`` and ``total`` multiplied by the batch size;
    * validation / test / predict: the bar's ``total`` and ``n`` are set
      directly in samples, derived from the current dataloader's
      ``batch_size`` and ``len(dataset)``.

    When a dataloader exposes no usable ``batch_size`` (``None``, or the
    attribute is missing), that mode falls back to counting batches so that
    ``n`` and ``total`` always share the same unit.

    NOTE: written against ``TQDMProgressBar`` from lightning==2.5.3 and its
    private helpers (``_update_n``, ``convert_inf``, ``_should_update``);
    re-check these overrides whenever Lightning is upgraded.
    """

    # ------------------------------------------------------------------ train

    def on_train_epoch_start(self, trainer, *_):
        """Let the parent reset the training bar, then scale it by the batch size."""
        super().on_train_epoch_start(trainer, *_)
        # ``getattr`` rather than ``.batch_size`` so that several train loaders
        # (a list/dict) or a loader without ``batch_size`` do not raise. A falsy
        # value leaves tqdm's counter unscaled. Because the parent counts
        # batches, a partial last batch makes the displayed total slightly
        # exceed the dataset size.
        self.train_progress_bar.unit_scale = getattr(trainer.train_dataloader, "batch_size", None)

    # ---------------------------------------------------------------- helpers

    @staticmethod
    def _select_dataloader(dataloaders: Any, dataloader_idx: int) -> Any:
        """Return the dataloader for ``dataloader_idx`` from a trainer attribute.

        The trainer's ``*_dataloaders`` attribute may hold a list/tuple or a
        dict of loaders; presumably (confirm against the Lightning version in
        use) it is the bare loader when only one was given, which cannot be
        indexed, so it is returned as is.
        """
        if isinstance(dataloaders, (list, tuple)):
            return dataloaders[dataloader_idx]
        if isinstance(dataloaders, dict):
            return list(dataloaders.values())[dataloader_idx]
        return dataloaders

    @staticmethod
    def _sample_total(dataloader: Any, total_batches: Any) -> "int | None":
        """Return the value the bar should count up to, or ``None`` if unknown.

        ``total_batches`` is the parent's ``total_*_batches_current_dataloader``
        (the batches Lightning will run for this loader, which also reflects
        sanity checking and ``limit_*_batches``). It is converted to samples
        with the loader's ``batch_size`` and capped at ``len(dataset)``, so a
        partial final batch does not push the total past the dataset size.
        Without a batch size the total stays in batch units.
        """
        total_batches = convert_inf(total_batches)
        batch_size = getattr(dataloader, "batch_size", None)
        if batch_size is None:
            return total_batches
        try:
            dataset_len = len(dataloader.dataset)
        except (AttributeError, TypeError):
            # e.g. iterable-style datasets that do not define ``__len__``.
            dataset_len = None
        if total_batches is None:
            return dataset_len
        sample_total = batch_size * total_batches
        return sample_total if dataset_len is None else min(sample_total, dataset_len)

    @staticmethod
    def _samples_done(batch_size: "int | None", batch_idx: int, total: "int | None") -> int:
        """Return the progress value after batch ``batch_idx`` has finished.

        Assumes every batch except possibly the last one holds exactly
        ``batch_size`` samples; capping at ``total`` lets a partial final batch
        land exactly on the total without inspecting the batch's structure.
        Without a batch size the progress is counted in batches.
        """
        done = batch_idx + 1 if batch_size is None else batch_size * (batch_idx + 1)
        return done if total is None else min(done, total)

    def _reset_eval_bar(self, bar: Any, description: str, dataloader: Any, total_batches: Any) -> None:
        """Restart ``bar`` for a new dataloader with a sample-based total."""
        bar.reset(self._sample_total(dataloader, total_batches))
        bar.initial = 0
        bar.set_description(description)

    def _advance_eval_bar(
        self, bar: Any, batch_size: "int | None", batch_idx: int, total_batches: Any
    ) -> None:
        """Move ``bar`` to the progress reached after batch ``batch_idx``.

        The refresh decision is made on batch counts because ``refresh_rate``
        is expressed in batches; only the displayed value is in samples.
        """
        if self._should_update(batch_idx + 1, convert_inf(total_batches)):
            _update_n(bar, self._samples_done(batch_size, batch_idx, bar.total))

    # ------------------------------------------------------------- validation

    def on_validation_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Reset the validation bar whenever a new validation dataloader starts."""
        if not self.has_dataloader_changed(dataloader_idx):
            return
        dataloader = self._select_dataloader(trainer.val_dataloaders, dataloader_idx)
        # Remembered for the matching ``on_validation_batch_end`` calls.
        self.val_batch_size = getattr(dataloader, "batch_size", None)
        desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
        self._reset_eval_bar(
            self.val_progress_bar,
            f"{desc} DataLoader {dataloader_idx}",
            dataloader,
            self.total_val_batches_current_dataloader,
        )

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Advance the validation bar to the number of samples processed."""
        self._advance_eval_bar(
            self.val_progress_bar,
            self.val_batch_size,
            batch_idx,
            self.total_val_batches_current_dataloader,
        )

    # ---------------------------------------------------------------- predict

    def on_predict_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Reset the predict bar whenever a new predict dataloader starts."""
        if not self.has_dataloader_changed(dataloader_idx):
            return
        dataloader = self._select_dataloader(trainer.predict_dataloaders, dataloader_idx)
        # Remembered for the matching ``on_predict_batch_end`` calls.
        self.predict_batch_size = getattr(dataloader, "batch_size", None)
        self._reset_eval_bar(
            self.predict_progress_bar,
            f"{self.predict_description} DataLoader {dataloader_idx}",
            dataloader,
            self.total_predict_batches_current_dataloader,
        )

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Advance the predict bar to the number of samples processed."""
        self._advance_eval_bar(
            self.predict_progress_bar,
            self.predict_batch_size,
            batch_idx,
            self.total_predict_batches_current_dataloader,
        )

    # ------------------------------------------------------------------- test

    def on_test_batch_start(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ):
        """Reset the test bar whenever a new test dataloader starts."""
        if not self.has_dataloader_changed(dataloader_idx):
            return
        dataloader = self._select_dataloader(trainer.test_dataloaders, dataloader_idx)
        # Remembered for the matching ``on_test_batch_end`` calls.
        self.test_batch_size = getattr(dataloader, "batch_size", None)
        self._reset_eval_bar(
            self.test_progress_bar,
            f"{self.test_description} DataLoader {dataloader_idx}",
            dataloader,
            self.total_test_batches_current_dataloader,
        )

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Advance the test bar to the number of samples processed."""
        self._advance_eval_bar(
            self.test_progress_bar,
            self.test_batch_size,
            batch_idx,
            self.total_test_batches_current_dataloader,
        )
Metadata
Assignees
Labels
No labels