Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broken regex for allowed_deserialization_classes #36147

Merged
merged 12 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 12 additions & 3 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,20 @@ core:
allowed_deserialization_classes:
description: |
What classes can be imported during deserialization. This is a multi line value.
The individual items will be parsed as regexp. Python built-in classes (like dict)
are always allowed. Bare "." will be replaced so you can set airflow.* .
The individual items will be parsed as a pattern to a glob function.
Python built-in classes (like dict) are always allowed.
version_added: 2.5.0
type: string
default: 'airflow\..*'
default: 'airflow.*'
example: ~
allowed_deserialization_classes_regexp:
description: |
What classes can be imported during deserialization. This is a multi line value.
The individual items will be parsed as regexp patterns.
This is a secondary option to ``allowed_deserialization_classes``.
version_added: 2.8.1
type: string
default: ''
example: ~
killed_task_cleanup_time:
description: |
Expand Down
2 changes: 1 addition & 1 deletion airflow/config_templates/unit_tests.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ unit_test_mode = True
# We want to use a shorter timeout for task cleanup
killed_task_cleanup_time = 5
# We only allow our own classes to be deserialized in tests
allowed_deserialization_classes = airflow\..* tests\..*
allowed_deserialization_classes = airflow.* tests.*

[database]

Expand Down
27 changes: 23 additions & 4 deletions airflow/serialization/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import functools
import logging
import sys
from fnmatch import fnmatch
from importlib import import_module
from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast

Expand Down Expand Up @@ -241,7 +242,6 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
# only return string representation
if not full:
return _stringify(classname, version, value)

if not _match(classname) and classname not in _extra_allowed:
raise ImportError(
f"{classname} was not found in allow list for deserialization imports. "
Expand Down Expand Up @@ -288,7 +288,22 @@ def _convert(old: dict) -> dict:


def _match(classname: str) -> bool:
return any(p.match(classname) is not None for p in _get_patterns())
"""Checks if the given classname matches a path pattern either using glob format or regexp format."""
return _match_glob(classname) or _match_regexp(classname)


@functools.lru_cache(maxsize=None)
def _match_glob(classname: str):
"""Checks if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
patterns = _get_patterns()
return any(fnmatch(classname, p.pattern) for p in patterns)


@functools.lru_cache(maxsize=None)
def _match_regexp(classname: str):
"""Checks if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
patterns = _get_regexp_patterns()
return any(p.match(classname) is not None for p in patterns)


def _stringify(classname: str, version: int, value: T | None) -> str:
Expand Down Expand Up @@ -359,8 +374,12 @@ def _register():

@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[Pattern]:
patterns = conf.get("core", "allowed_deserialization_classes").split()
return [re2.compile(re2.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]
return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]


@functools.lru_cache(maxsize=None)
def _get_regexp_patterns() -> list[Pattern]:
return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]


_register()
11 changes: 11 additions & 0 deletions newsfragments/36147.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
The ``allowed_deserialization_classes`` flag now follows a glob pattern.

For example if one wants to add the class ``airflow.tests.custom_class`` to the
``allowed_deserialization_classes`` list, it can be done by writing the full class
name (``airflow.tests.custom_class``) or a pattern such as the ones used in glob
search (e.g., ``airflow.*``, ``airflow.tests.*``).

If you currently use a custom regexp path make sure to rewrite it as a glob pattern.

Alternatively, if you still wish to match it as a regexp pattern, add it under the new
list ``allowed_deserialization_classes_regexp`` instead.
58 changes: 54 additions & 4 deletions tests/serialization/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
SCHEMA_ID,
VERSION,
_get_patterns,
_get_regexp_patterns,
_match,
_match_glob,
_match_regexp,
deserialize,
serialize,
)
Expand All @@ -44,10 +47,16 @@
@pytest.fixture()
def recalculate_patterns():
_get_patterns.cache_clear()
_get_regexp_patterns.cache_clear()
_match_glob.cache_clear()
_match_regexp.cache_clear()
try:
yield
finally:
_get_patterns.cache_clear()
_get_regexp_patterns.cache_clear()
_match_glob.cache_clear()
_match_regexp.cache_clear()


class Z:
Expand Down Expand Up @@ -218,7 +227,7 @@ def test_serder_dataclass(self):

@conf_vars(
{
("core", "allowed_deserialization_classes"): "airflow[.].*",
("core", "allowed_deserialization_classes"): "airflow.*",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
Expand All @@ -232,13 +241,54 @@ def test_allow_list_for_imports(self):

@conf_vars(
{
("core", "allowed_deserialization_classes"): "tests.*",
("core", "allowed_deserialization_classes"): "tests.airflow.*",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_replace(self):
def test_allow_list_match(self):
assert _match("tests.airflow.deep")
assert _match("testsfault") is False
assert _match("tests.wrongpath") is False

@conf_vars(
{
("core", "allowed_deserialization_classes"): "tests.airflow.deep",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_match_class(self):
"""Test the match function when passing a full classname as
allowed_deserialization_classes
"""
assert _match("tests.airflow.deep")
assert _match("tests.airflow.FALSE") is False

@conf_vars(
{
("core", "allowed_deserialization_classes"): "",
("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\..",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_match_regexp(self):
"""Test the match function when passing a path as
allowed_deserialization_classes_regexp with no glob pattern defined
"""
assert _match("tests.airflow.deep")
assert _match("tests.wrongpath") is False

@conf_vars(
{
("core", "allowed_deserialization_classes"): "",
("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\.deep",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_match_class_regexp(self):
"""Test the match function when passing a full classname as
allowed_deserialization_classes_regexp with no glob pattern defined
"""
assert _match("tests.airflow.deep")
assert _match("tests.airflow.FALSE") is False

def test_incompatible_version(self):
data = dict(
Expand Down