Skip to content

Commit

Permalink
Revert "Prevent assignment of non JSON serializable values to DagRun.…
Browse files Browse the repository at this point in the history
…conf dict (#35096)" (#35959)

This reverts commit 84c40a7.

(cherry picked from commit 4a7c746)
  • Loading branch information
ephraimbuddy committed Dec 5, 2023
1 parent 54dc2b9 commit 7e9b6a4
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 63 deletions.
51 changes: 2 additions & 49 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import itertools
import json
import os
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -98,37 +97,6 @@ class TISchedulingDecision(NamedTuple):
finished_tis: list[TI]


class ConfDict(dict):
"""Custom dictionary for storing only JSON serializable values."""

def __init__(self, val=None):
super().__init__(self.is_jsonable(val))

def __setitem__(self, key, value):
self.is_jsonable({key: value})
super().__setitem__(key, value)

@staticmethod
def is_jsonable(conf: dict) -> dict | None:
"""Prevent setting non-json attributes."""
try:
json.dumps(conf)
except TypeError:
raise AirflowException("Cannot assign non JSON Serializable value")
if isinstance(conf, dict):
return conf
else:
raise AirflowException(f"Object of type {type(conf)} must be a dict")

@staticmethod
def dump_check(conf: str) -> str:
val = json.loads(conf)
if isinstance(val, dict):
return conf
else:
raise TypeError(f"Object of type {type(val)} must be a dict")


def _creator_note(val):
"""Creator the ``note`` association proxy."""
if isinstance(val, str):
Expand Down Expand Up @@ -159,7 +127,7 @@ class DagRun(Base, LoggingMixin):
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
_conf = Column("conf", PickleType)
conf = Column(PickleType)
# These two must be either both NULL or both datetime.
data_interval_start = Column(UtcDateTime)
data_interval_end = Column(UtcDateTime)
Expand Down Expand Up @@ -261,12 +229,7 @@ def __init__(
self.execution_date = execution_date
self.start_date = start_date
self.external_trigger = external_trigger

if isinstance(conf, str):
self._conf = ConfDict.dump_check(conf)
else:
self._conf = ConfDict(conf or {})

self.conf = conf or {}
if state is not None:
self.state = state
if queued_at is NOTSET:
Expand Down Expand Up @@ -296,16 +259,6 @@ def validate_run_id(self, key: str, run_id: str) -> str | None:
)
return run_id

def get_conf(self):
return self._conf

def set_conf(self, value):
self._conf = ConfDict(value)

@declared_attr
def conf(self):
return synonym("_conf", descriptor=property(self.get_conf, self.set_conf))

@property
def stats_tags(self) -> dict[str, str]:
return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
Expand Down
14 changes: 0 additions & 14 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -2618,17 +2618,3 @@ def test_dag_run_id_config(session, dag_maker, pattern, run_id, result):
else:
with pytest.raises(AirflowException):
dag_maker.create_dagrun(run_id=run_id)


def test_dagrun_conf():
dag_run = DagRun(conf={"test": 1234})
assert dag_run.conf == {"test": 1234}

with pytest.raises(AirflowException) as err:
dag_run.conf["non_json"] = timezone.utcnow()
assert str(err.value) == "Cannot assign non JSON Serializable value"

with pytest.raises(AirflowException) as err:
value = 1
dag_run.conf = value
assert str(err.value) == f"Object of type {type(value)} must be a dict"

0 comments on commit 7e9b6a4

Please sign in to comment.