Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ def __hash__(self):
"pyspark.sql.tests.streaming.test_streaming_foreach_batch",
"pyspark.sql.tests.streaming.test_streaming_kafka_rtm",
"pyspark.sql.tests.streaming.test_streaming_listener",
"pyspark.sql.tests.streaming.test_state",
"pyspark.sql.tests.streaming.test_streaming_offline_state_repartition",
"pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
"pyspark.sql.tests.pandas.streaming.test_pandas_transform_with_state",
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,13 @@
"Invalid streaming source name '<source_name>'. Source names must contain only ASCII letters, digits, and underscores."
]
},
"INVALID_TIMEOUT_DURATION_STRING": {
"message": [
"Provided timeout duration string '<duration>' is invalid.",
" Use a Spark interval string such as '5 minutes' or '1 hour 30 minutes'.",
" Only seconds may have a fractional quantity, e.g. '1.5 seconds'."
]
},
"INVALID_TIMEOUT_TIMESTAMP": {
"message": [
"Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
Expand Down
96 changes: 85 additions & 11 deletions python/pyspark/sql/streaming/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,91 @@
#
import datetime
import json
from typing import Tuple, Optional
import re
from typing import Tuple, Optional, Union

from pyspark.sql.types import Row, StructType, TimestampType
from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError

__all__ = ["GroupState", "GroupStateTimeout"]

# Microseconds per unit, accumulating at us precision before converting to ms.
# Months and years use 31 days/month - Spark's structured streaming watermark
# convention (IntervalUtils.getDuration default daysPerMonth=31).
_TIMEOUT_DURATION_UNIT_TO_MICROS = {
"year": 12 * 31 * 24 * 60 * 60 * 1_000_000,
"years": 12 * 31 * 24 * 60 * 60 * 1_000_000,
"month": 31 * 24 * 60 * 60 * 1_000_000,
"months": 31 * 24 * 60 * 60 * 1_000_000,
"week": 7 * 24 * 60 * 60 * 1_000_000,
"weeks": 7 * 24 * 60 * 60 * 1_000_000,
"day": 24 * 60 * 60 * 1_000_000,
"days": 24 * 60 * 60 * 1_000_000,
"hour": 60 * 60 * 1_000_000,
"hours": 60 * 60 * 1_000_000,
"minute": 60 * 1_000_000,
"minutes": 60 * 1_000_000,
"second": 1_000_000,
"seconds": 1_000_000,
"millisecond": 1_000,
"milliseconds": 1_000,
"microsecond": 1,
"microseconds": 1,
}

_TIMEOUT_DURATION_SECOND_UNITS = frozenset({"second", "seconds"})

# Quantity is either a normal decimal (5, 1.5) or a leading-dot decimal (.5).
# Sign and quantity may be separated by optional whitespace (Scala's TRIM_BEFORE_VALUE).
# Quantity and unit MUST be separated by at least one whitespace (Scala's VALUE state
# requires whitespace to transition to TRIM_BEFORE_UNIT; no space is INVALID_VALUE).
_TIMEOUT_DURATION_COMPONENT = re.compile(
r"([+-]?)\s*((?:\d+(?:\.\d+)?|\.\d+))\s+"
r"(years?|months?|weeks?|days?|hours?|minutes?|seconds?|milliseconds?|microseconds?)",
re.IGNORECASE,
)

# Full-string validator: optional 'interval' keyword, then one or more components.
_TIMEOUT_DURATION_VALID = re.compile(
r"^(interval\s+)?"
r"(\s*[+-]?\s*(?:\d+(?:\.\d+)?|\.\d+)\s+"
r"(years?|months?|weeks?|days?|hours?|minutes?|seconds?|milliseconds?|microseconds?)\s*)+$",
re.IGNORECASE,
)


def _parse_timeout_duration(duration: str) -> int:
"""Convert a Spark interval string to milliseconds.

Supported format: [interval] [sign]quantity unit [[sign]quantity unit ...]
Supported units: years, months, weeks, days, hours, minutes, seconds,
milliseconds, microseconds. Months and years are converted using 31 days/month,
matching Spark's IntervalUtils.getDuration default (daysPerMonth=31).
Only seconds may carry a fractional quantity (e.g. '1.5 seconds').
Results are truncated to millisecond precision via integer division.
"""
if not _TIMEOUT_DURATION_VALID.match(duration.strip()):
raise PySparkValueError(
errorClass="INVALID_TIMEOUT_DURATION_STRING",
messageParameters={"duration": duration},
)
# Strip optional 'interval' keyword before component parsing.
s = re.sub(r"^interval\s+", "", duration.strip(), flags=re.IGNORECASE)
total_micros = 0
for sign, quantity_str, unit in _TIMEOUT_DURATION_COMPONENT.findall(s):
unit_lower = unit.lower()
is_fractional = "." in quantity_str
if is_fractional and unit_lower not in _TIMEOUT_DURATION_SECOND_UNITS:
raise PySparkValueError(
errorClass="INVALID_TIMEOUT_DURATION_STRING",
messageParameters={"duration": duration},
)
quantity = float(quantity_str) if is_fractional else int(quantity_str)
component_micros = round(quantity * _TIMEOUT_DURATION_UNIT_TO_MICROS[unit_lower])
total_micros += -component_micros if sign == "-" else component_micros
# Integer division to ms matches Spark's TimeUnit.convert behaviour.
return total_micros // 1_000


class GroupStateTimeout:
"""
Expand Down Expand Up @@ -170,21 +248,17 @@ def remove(self) -> None:
self._updated = False
self._removed = True

def setTimeoutDuration(self, durationMs: int) -> None:
def setTimeoutDuration(self, durationMs: Union[int, str]) -> None:
"""
Set the timeout duration in ms for this key.
Processing time timeout must be enabled.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add a versionchanged to doc that str is supported?

Copy link
Copy Markdown
Contributor Author

@brijrajk brijrajk Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, added! A versionchanged note is useful here because this is a behavioral change to an existing method — users upgrading from an older Spark version would not know that string durations are now accepted unless the API docs call it out explicitly.


.. versionchanged:: 5.0.0
`durationMs` now also accepts a Spark interval string such as
``'5 minutes'`` or ``'1 hour 30 minutes'``.
"""
if isinstance(durationMs, str):
# TODO(SPARK-40437): Support string representation of durationMs.
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"expected_type": "int",
"arg_name": "durationMs",
"arg_type": type(durationMs).__name__,
},
)
durationMs = _parse_timeout_duration(durationMs)

if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout:
raise PySparkRuntimeError(
Expand Down
Loading