Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent assignment of non JSON serializable values to DagRun.conf dict #35096

Merged
merged 13 commits into from
Nov 14, 2023
51 changes: 49 additions & 2 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import itertools
import json
import os
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -96,6 +97,37 @@ class TISchedulingDecision(NamedTuple):
finished_tis: list[TI]


class ConfDict(dict):
    """Dictionary that only accepts JSON serializable content.

    The content is validated when the dictionary is constructed and on every
    ``conf[key] = value`` assignment. Only ``__init__`` and ``__setitem__``
    are overridden by this class.
    """

    def __init__(self, val=None):
        """Build the dictionary from ``val``.

        :param val: initial mapping; ``None`` (the default) gives an empty dictionary.
        :raises AirflowException: if ``val`` is not a dict or cannot be serialized to JSON.
        """
        # ``None`` is the declared default, so it has to mean "no configuration"
        # instead of tripping the "must be a dict" check in ``is_jsonable``.
        super().__init__(self.is_jsonable({} if val is None else val))

    def __setitem__(self, key, value):
        """Validate the single ``key``/``value`` pair, then store it.

        :raises AirflowException: if the pair cannot be serialized to JSON.
        """
        self.is_jsonable({key: value})
        super().__setitem__(key, value)

    @staticmethod
    def is_jsonable(conf: dict) -> dict:
        """Return ``conf`` unchanged once it is known to be a JSON serializable dict.

        :param conf: candidate value to validate.
        :return: the very same ``conf`` object.
        :raises AirflowException: if ``json.dumps`` rejects ``conf`` with a
            ``TypeError``, or if ``conf`` is not a ``dict``.
        """
        try:
            json.dumps(conf)
        except TypeError as err:
            # Chain the original error so the offending type stays visible in the traceback.
            raise AirflowException("Cannot assign non JSON Serializable value") from err
        if not isinstance(conf, dict):
            raise AirflowException(f"Object of type {type(conf)} must be a dict")
        return conf

    @staticmethod
    def dump_check(conf: str) -> str:
        """Return the JSON document ``conf`` unchanged if it decodes to a dict.

        :param conf: JSON text to inspect.
        :return: the original string, not the decoded value.
        :raises TypeError: if the decoded value is not a ``dict``.

        Errors raised by ``json.loads`` itself are not caught here.
        """
        val = json.loads(conf)
        if not isinstance(val, dict):
            raise TypeError(f"Object of type {type(val)} must be a dict")
        return conf


def _creator_note(val):
"""Creator the ``note`` association proxy."""
if isinstance(val, str):
Expand Down Expand Up @@ -126,7 +158,7 @@ class DagRun(Base, LoggingMixin):
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
conf = Column(PickleType)
_conf = Column("conf", PickleType)
# These two must be either both NULL or both datetime.
data_interval_start = Column(UtcDateTime)
data_interval_end = Column(UtcDateTime)
Expand Down Expand Up @@ -228,7 +260,12 @@ def __init__(
self.execution_date = execution_date
self.start_date = start_date
self.external_trigger = external_trigger
self.conf = conf or {}

if isinstance(conf, str):
self._conf = ConfDict.dump_check(conf)
else:
self._conf = ConfDict(conf or {})

if state is not None:
self.state = state
if queued_at is NOTSET:
Expand Down Expand Up @@ -258,6 +295,16 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
)
return run_id

def get_conf(self):
    """Return the value currently stored in ``_conf``."""
    return self._conf

def set_conf(self, value):
    """Wrap ``value`` in a ``ConfDict`` and store the result in ``_conf``.

    Any exception raised by the ``ConfDict`` constructor propagates to the caller.
    """
    validated = ConfDict(value)
    self._conf = validated
jscheffl marked this conversation as resolved.
Show resolved Hide resolved

@declared_attr
def conf(cls):
    """Expose the ``_conf`` attribute under the public name ``conf``.

    The returned synonym uses a property built from ``get_conf`` and
    ``set_conf`` as its descriptor, so reads and writes of ``conf`` go
    through those two methods.
    """
    # ``declared_attr`` calls this function with the mapped class, not an
    # instance, so the parameter is named ``cls`` to match what it receives.
    return synonym("_conf", descriptor=property(cls.get_conf, cls.set_conf))

@property
def stats_tags(self) -> dict[str, str]:
    """Return the ``dag_id`` and ``run_type`` of this run, passed through ``prune_dict``."""
    tags = {"dag_id": self.dag_id, "run_type": self.run_type}
    return prune_dict(tags)
Expand Down
14 changes: 14 additions & 0 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2617,3 +2617,17 @@ def test_dag_run_id_config(session, dag_maker, pattern, run_id, result):
else:
with pytest.raises(AirflowException):
dag_maker.create_dagrun(run_id=run_id)


def test_dagrun_conf():
    """Check that ``DagRun.conf`` keeps valid content and rejects invalid assignments."""
    dag_run = DagRun(conf={"test": 1234})
    assert dag_run.conf == {"test": 1234}

    # Assigning a value that cannot be serialized to JSON must fail.
    with pytest.raises(AirflowException) as non_json_err:
        dag_run.conf["non_json"] = timezone.utcnow()
    assert str(non_json_err.value) == "Cannot assign non JSON Serializable value"

    # Replacing the whole conf with something that is not a dict must fail too.
    value = 1
    with pytest.raises(AirflowException) as not_dict_err:
        dag_run.conf = value
    assert str(not_dict_err.value) == f"Object of type {type(value)} must be a dict"