Skip to content

Commit

Permalink
Adds more robust timer duration parsing (#19513)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebgorman committed Feb 24, 2024
1 parent 0f4522c commit 63188f9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `load_from_checkpoint` support for `LightningCLI` when using dependency injection ([#18105](https://github.com/Lightning-AI/lightning/pull/18105))

-
- Added robust timer duration parsing with an informative error message when parsing fails ([#19513](https://github.com/Lightning-AI/pytorch-lightning/pull/19513))

-

Expand Down
20 changes: 16 additions & 4 deletions src/lightning/pytorch/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import logging
import re
import time
from datetime import timedelta
from typing import Any, Dict, Optional, Union
Expand Down Expand Up @@ -50,6 +51,8 @@ class Timer(Callback):
verbose: Set this to ``False`` to suppress logging messages.
Raises:
MisconfigurationException:
If ``duration`` is not in the expected format.
MisconfigurationException:
If ``interval`` is not one of the supported choices.
Expand Down Expand Up @@ -86,10 +89,19 @@ def __init__(
) -> None:
super().__init__()
if isinstance(duration, str):
dhms = duration.strip().split(":")
dhms = [int(i) for i in dhms]
duration = timedelta(days=dhms[0], hours=dhms[1], minutes=dhms[2], seconds=dhms[3])
if isinstance(duration, dict):
duration_match = re.fullmatch(r"(\d+):(\d\d):(\d\d):(\d\d)", duration.strip())
if not duration_match:
raise MisconfigurationException(
f"`Timer(duration={duration!r})` is not a valid duration. "
"Expected a string in the format DD:HH:MM:SS."
)
duration = timedelta(
days=int(duration_match.group(1)),
hours=int(duration_match.group(2)),
minutes=int(duration_match.group(3)),
seconds=int(duration_match.group(4)),
)
elif isinstance(duration, dict):
duration = timedelta(**duration)
if interval not in set(Interval):
raise MisconfigurationException(
Expand Down
6 changes: 6 additions & 0 deletions tests/tests_pytorch/callbacks/test_timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def test_timer_parse_duration(duration, expected):
assert (timer.time_remaining() == expected is None) or (timer.time_remaining() == expected.total_seconds())


@pytest.mark.parametrize("duration", ["6:00:00", "60 minutes"])
def test_timer_parse_duration_misconfiguration(duration):
with pytest.raises(MisconfigurationException, match="format DD:HH:MM:SS"):
Timer(duration=duration)


def test_timer_interval_choice():
Timer(duration=timedelta(), interval="step")
Timer(duration=timedelta(), interval="epoch")
Expand Down

0 comments on commit 63188f9

Please sign in to comment.