Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent assignment of non JSON serializable values to DagRun.conf dict #35096

Merged
merged 13 commits into from
Nov 14, 2023
38 changes: 35 additions & 3 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import itertools
import json
import os
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -95,6 +96,29 @@ class TISchedulingDecision(NamedTuple):
unfinished_tis: list[TI]
finished_tis: list[TI]

class ConfDict(dict):
    """Dictionary that only accepts JSON-serializable content.

    Validation runs at construction time and on every entry point that can add
    new items (item assignment, ``update``, ``setdefault`` and ``|=``).  The
    built-in ``dict`` versions of the latter three do not route through
    ``__setitem__``, so they are overridden here to keep the guarantee intact.
    """

    def __init__(self, val=None):
        """Initialise from *val*; ``None`` produces an empty mapping.

        :param val: a dict of JSON-serializable content, or ``None``.
        :raises AirflowException: if *val* is not a dict or cannot be
            serialized to JSON.
        """
        super().__init__(self.is_jsonable(val))

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` after checking both are JSON-serializable."""
        # Validate the pair as a one-item dict so the key is checked as well.
        self.is_jsonable({key: value})
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Validate the incoming items as a whole, then merge them.

        Accepts the same arguments as ``dict.update``.  Nothing is merged if
        any incoming item fails validation.
        """
        incoming = dict(*args, **kwargs)
        self.is_jsonable(incoming)
        super().update(incoming)

    def setdefault(self, key, default=None):
        """Return ``self[key]``, inserting a validated ``default`` if missing."""
        if key not in self:
            self[key] = default
        return self[key]

    def __ior__(self, other):
        """Implement ``self |= other`` through the validating ``update``."""
        self.update(other)
        return self

    @staticmethod
    def is_jsonable(conf: dict) -> dict:
        """Check that *conf* is a JSON-serializable dict and return it.

        :param conf: candidate configuration; ``None`` is treated as empty.
        :return: ``{}`` when *conf* is ``None``, otherwise *conf* itself.
        :raises AirflowException: if *conf* cannot be serialized to JSON or is
            not a dict.
        """
        if conf is None:
            return {}
        try:
            json.dumps(conf)
        except (TypeError, ValueError) as err:
            # TypeError: unsupported value or key type.
            # ValueError: circular reference detected by the encoder.
            raise AirflowException("Cannot assign non JSON Serializable value") from err
        if not isinstance(conf, dict):
            raise AirflowException(f"{conf} must be a dict")
        return conf

def _creator_note(val):
"""Creator the ``note`` association proxy."""
Expand Down Expand Up @@ -126,7 +150,7 @@ class DagRun(Base, LoggingMixin):
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
conf = Column(PickleType)
_conf = Column(PickleType)
jscheffl marked this conversation as resolved.
Show resolved Hide resolved
# These two must be either both NULL or both datetime.
data_interval_start = Column(UtcDateTime)
data_interval_end = Column(UtcDateTime)
Expand Down Expand Up @@ -210,7 +234,7 @@ def __init__(
execution_date: datetime | None = None,
start_date: datetime | None = None,
external_trigger: bool | None = None,
conf: Any | None = None,
_conf: Any | None = None,
jscheffl marked this conversation as resolved.
Show resolved Hide resolved
state: DagRunState | None = None,
run_type: str | None = None,
dag_hash: str | None = None,
Expand All @@ -228,7 +252,7 @@ def __init__(
self.execution_date = execution_date
self.start_date = start_date
self.external_trigger = external_trigger
self.conf = conf or {}
self._conf = ConfDict(_conf)
jscheffl marked this conversation as resolved.
Show resolved Hide resolved
if state is not None:
self.state = state
if queued_at is NOTSET:
Expand Down Expand Up @@ -258,6 +282,14 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
)
return run_id

@property
def conf(self):
    """Return the run configuration held in ``_conf``."""
    current = self._conf
    return current

@conf.setter
def conf(self, value):
    """Wrap *value* in a ``ConfDict`` and store the result in ``_conf``."""
    wrapped = ConfDict(value)
    self._conf = wrapped
jscheffl marked this conversation as resolved.
Show resolved Hide resolved

@property
def stats_tags(self) -> dict[str, str]:
    """Return the ``dag_id`` / ``run_type`` tag mapping passed through ``prune_dict``."""
    # NOTE(review): prune_dict presumably drops empty entries — confirm against its definition.
    tags = {"dag_id": self.dag_id, "run_type": self.run_type}
    return prune_dict(tags)
Expand Down