Improvements and changes to progress tracking dataclasses #8140

Merged · 4 commits · Jun 29, 2021
Changes from 1 commit
1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -36,6 +36,7 @@

# Specifics
/pytorch_lightning/trainer/connectors/logger_connector @tchaton @carmocca
+/pytorch_lightning/trainer/progress.py @tchaton @awaelchli @carmocca

# Metrics
/pytorch_lightning/metrics/ @SkafteNicki @ananyahjha93 @justusschock
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -30,9 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))


-- Added dataclasses for progress tracking (
-  [#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603),
-  [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))
+- Progress tracking
+    * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
@@ -85,12 +84,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fault-tolerant training
    * Add `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948))
+    * Checkpoint the loop results ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


+- Add `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


+- Add `metric_attribute` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966))


- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))


164 changes: 126 additions & 38 deletions pytorch_lightning/trainer/progress.py
@@ -11,12 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
from typing import Optional


-@dataclass
-class Tracker:
+class _DataclassStateDictMixin:
+
+    def __getstate__(self) -> dict:
+        return asdict(self)
+
+    def __setstate__(self, state: dict) -> None:
+        self.__dict__.update(state)
+
+    def state_dict(self) -> dict:
+        return self.__getstate__()
+
+    @classmethod
+    def load_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin":
+        obj = cls()
+        obj.__setstate__(state_dict)
+        return obj
+
+
+@dataclass
+class Tracker(_DataclassStateDictMixin):
"""
Track an event's progress.

@@ -28,6 +47,7 @@ class Tracker:

    Attributes set to ``None`` are treated as unused and are restricted.
    """
+
    ready: Optional[int] = 0
    started: Optional[int] = 0
    processed: Optional[int] = 0
@@ -55,14 +75,15 @@ def __repr__(self):


@dataclass
-class Progress:
+class Progress(_DataclassStateDictMixin):
"""
Track aggregated and current progress.

Args:
total: Intended to track the total progress of an event
current: Intended to track the current progress of an event
"""

total: Tracker = field(default_factory=Tracker)
current: Tracker = field(default_factory=Tracker)

Expand Down Expand Up @@ -91,90 +112,157 @@ def increment_completed(self) -> None:
        self.current.completed += 1

    @classmethod
-    def from_defaults(cls, **kwargs: Optional[int]) -> 'Progress':
+    def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
        return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))

+    def __setstate__(self, state: dict) -> None:
+        self.total.__setstate__(state["total"])
+        self.current.__setstate__(state["current"])

@dataclass
-class LoopProgress:
-    """
-    Track loop progress during execution.
-
-    Args:
-        epoch: Tracks epochs progress.
-        batch: Tracks batch progress.
-    """
-    epoch: Progress = field(default_factory=Progress)
-    batch: Progress = field(default_factory=Progress)
-
-    def increment_epoch_completed(self) -> None:
-        self.epoch.increment_completed()
-        self.reset_on_epoch()
-
-    def reset_on_epoch(self) -> None:
-        self.batch.current.reset()
-        self.epoch.current.reset()
+class BatchProgress(Progress):
+    """
+    Tracks the batch progress
+
+    Args:
+        total: Tracks the total epoch progress
+        current: Tracks the current epoch progress
+    """
+
+
+@dataclass
+class EpochProgress(Progress):
+    """
+    Tracks the epoch progress
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
+
+    Args:
+        total: Tracks the total epoch progress
+        current: Tracks the current epoch progress
+        batch: Tracks batch progress.
+    """
+
+    batch: BatchProgress = field(default_factory=BatchProgress)
+
+    def reset_on_epoch(self) -> None:
+        self.batch.current.reset()
+
+    def __setstate__(self, state: dict) -> None:
+        super().__setstate__(state)
+        self.batch.__setstate__(state["batch"])


@dataclass
-class OptimizationProgress:
+class OptimizerProgress(_DataclassStateDictMixin):
+    """
+    Track optimizer progress.
+
+    Args:
+        step: Tracks ``optimizer.step`` calls.
+        zero_grad: Tracks ``optimizer.zero_grad`` calls.
+    """
+
+    step: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None))
+    zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None))
+
+    def reset_on_epoch(self) -> None:
+        self.step.current.reset()
+        self.zero_grad.current.reset()
+
+    def __setstate__(self, state: dict) -> None:
+        self.step.__setstate__(state["step"])
+        self.zero_grad.__setstate__(state["zero_grad"])
+
+
+@dataclass
+class OptimizationProgress(_DataclassStateDictMixin):
    """
    Track optimization progress.

    Args:
        optimizer: Tracks optimizer progress.
        scheduler: Tracks scheduler progress.
    """
-    optimizer: Progress = Progress.from_defaults(processed=None)
-    scheduler: Progress = Progress.from_defaults(started=None, processed=None)
-    zero_grad: Progress = Progress.from_defaults(processed=None)
+
+    # TODO: support for multiple optimizers
+    optimizer: OptimizerProgress = field(default_factory=OptimizerProgress)
+    scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None))

    @property
    def optimizer_steps(self) -> int:
-        return self.optimizer.total.completed
+        return self.optimizer.step.total.completed

    @property
    def scheduler_steps(self) -> int:
        return self.scheduler.total.completed

+    def reset_on_epoch(self) -> None:
+        self.optimizer.reset_on_epoch()
+        self.scheduler.current.reset()
+
+    def __setstate__(self, state: dict) -> None:
+        self.optimizer.__setstate__(state["optimizer"])
+        self.scheduler.__setstate__(state["scheduler"])


@dataclass
-class TrainingProgress(Progress):
-    """
-    Extends ``Progress`` with training specific attributes
-
-    Args:
-        optimization: Tracks optimization progress
-    """
-    optimization: OptimizationProgress = field(default_factory=OptimizationProgress)
-
-
-@dataclass
-class TrainingLoopProgress(LoopProgress):
-    epoch: TrainingProgress = field(default_factory=TrainingProgress)
-
-    def reset_on_epoch(self) -> None:
-        # override to avoid resetting `epoch.current`
-        self.batch.current.reset()
+class EpochLoopProgress(_DataclassStateDictMixin):
+    """
+    Tracks epoch loop progress.
+    These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
+
+    Args:
+        epoch: Tracks epochs progress.
+    """
+
+    epoch: EpochProgress = field(default_factory=EpochProgress)
+
+    def increment_epoch_completed(self) -> None:
+        self.epoch.increment_completed()
+        self.reset_on_epoch()
+
+    def reset_on_epoch(self) -> None:
+        self.epoch.reset_on_epoch()
+        self.epoch.current.reset()
+
+    def __setstate__(self, state: dict) -> None:
+        self.epoch.__setstate__(state["epoch"])


@dataclass
-class FitLoopProgress:
-    train: TrainingLoopProgress = field(default_factory=TrainingLoopProgress)
-    val: LoopProgress = field(default_factory=LoopProgress)
+class TrainingEpochProgress(EpochProgress):
+    """
+    Extends ``EpochProgress`` with training specific attributes
+
+    Args:
+        total: Tracks the total epoch progress.
+        current: Tracks the current epoch progress.
+        batch: Tracks batch progress.
+        optim: Tracks optimization progress.
+        val: Tracks validation_loop progress.
+    """
+
+    optim: OptimizationProgress = field(default_factory=OptimizationProgress)
+    val: EpochLoopProgress = field(default_factory=EpochLoopProgress)
+
+    def __setstate__(self, state: dict) -> None:
+        super().__setstate__(state)
+        self.optim.__setstate__(state["optim"])
+        self.val.__setstate__(state["val"])


@dataclass
-class LoopState:
-    """
-    Basic dataclass to track loop progress across trainer functions during trainer execution.
-
-    This class will be removed and these attributes will live in each loop.
-    """
-
-    fit: FitLoopProgress = field(default_factory=FitLoopProgress)
-    val: LoopProgress = field(default_factory=LoopProgress)
-    test: LoopProgress = field(default_factory=LoopProgress)
-    predict: LoopProgress = field(default_factory=LoopProgress)
+class FitLoopProgress(EpochLoopProgress):
+    """
+    Extends ``EpochLoopProgress`` with fit specific attributes
+
+    Args:
+        epoch: Tracks epochs progress.
+    """
+
+    epoch: TrainingEpochProgress = field(default_factory=TrainingEpochProgress)
+
+    def reset_on_epoch(self) -> None:
+        # do not reset `epoch.current` as it should track the number of epochs this `fit` call
+        self.epoch.reset_on_epoch()
+        self.epoch.optim.reset_on_epoch()
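At the fit level, `reset_on_epoch` deliberately leaves `epoch.current` alone so it keeps counting the epochs of the ongoing `fit` call, while batch and optimizer state is cleared. A final sketch tying it together (hypothetical usage, same assumptions as above):

```python
from pytorch_lightning.trainer.progress import FitLoopProgress

fit = FitLoopProgress()
fit.epoch.increment_completed()   # one epoch finished in this fit() call
fit.reset_on_epoch()              # clears batch/optim state, keeps epoch.current
assert fit.epoch.current.completed == 1

# checkpoint round-trip through the mixin
restored = FitLoopProgress.load_state_dict(fit.state_dict())
assert restored == fit
```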