Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,20 @@ def _resolve_expires_at(expires_at: datetime | None | Literal["default"]) -> dat
"""
Resolve the expires_at value from the request body.

- ``"default"``: apply configured default_retention_days
- ``"default"``: apply configured ``[state_store] default_retention_days``.
``0`` means never expire. Negative values raise HTTP 400.
- ``None``: never expire
- datetime: use as-is
"""
if expires_at == "default":
days = conf.getint("state_store", "default_retention_days")
return datetime.now(tz=timezone.utc) + timedelta(days=days)
if days < 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"[state_store] default_retention_days must be >= 0, got {days}. "
"Set to 0 to disable expiry.",
)
return None if days == 0 else datetime.now(tz=timezone.utc) + timedelta(days=days)
Comment thread
amoghrajesh marked this conversation as resolved.
return expires_at


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,19 @@ def test_new_key_default_retention_applies_config(self, test_client, time_machin
resp = test_client.get(f"{BASE_URL}/job_id").json()
assert resp["expires_at"] == "2026-01-08T00:00:00Z"

def test_new_key_default_retention_zero_never_expires(self, test_client):
"""PUT with expires_at=default and default_retention_days=0 stores a key that never expires."""
with conf_vars({("state_store", "default_retention_days"): "0"}):
test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": "default"})
assert test_client.get(f"{BASE_URL}/job_id").json()["expires_at"] is None

def test_new_key_negative_retention_days_returns_400(self, test_client):
"""PUT with expires_at=default and default_retention_days<0 returns HTTP 400."""
with conf_vars({("state_store", "default_retention_days"): "-1"}):
resp = test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": "default"})
assert resp.status_code == 400
assert "default_retention_days" in resp.json()["detail"]

def test_new_key_never_expiry(self, test_client):
"""PUT with expires_at=null stores a key that never expires."""
test_client.put(f"{BASE_URL}/job_id", json={"value": "v", "expires_at": None})
Expand Down
7 changes: 6 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,12 @@ def set(self, key: str, value: JsonValue, *, retention: timedelta | None = None)
expires_at = now + retention
else:
days = conf.getint("state_store", "default_retention_days")
expires_at = None if days <= 0 else now + timedelta(days=days)
if days < 0:
raise ValueError(
f"[state_store] default_retention_days must be >= 0, got {days}. "
"Set to 0 to disable expiry."
)
expires_at = None if days == 0 else now + timedelta(days=days)

# if custom backend is configured, store the value on the custom backend, and return the reference
# to the stored value to store in the DB
Expand Down
6 changes: 6 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,12 @@ def test_set_global_default_zero_sends_null_expires_at(self, mock_supervisor_com
SetTaskStore(ti_id=self.TI_ID, key="job_id", value="app_001", expires_at=None)
)

def test_set_raises_on_negative_retention_days(self, mock_supervisor_comms):
"""set() raises ValueError when default_retention_days is negative."""
with conf_vars({("state_store", "default_retention_days"): "-1"}):
with pytest.raises(ValueError, match="default_retention_days must be >= 0"):
TaskStoreAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001")

def test_delete_operation(self, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = OKResponse(ok=True)

Expand Down
Loading