Skip to content

Commit

Permalink
Fix import future annotations in venv jinja template (#40208)
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Jun 14, 2024
1 parent 67798b2 commit d5a7544
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
1 change: 1 addition & 0 deletions airflow/utils/python_virtualenv_script.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
specific language governing permissions and limitations
under the License.
-#}
from __future__ import annotations

import {{ pickling_library }}
import sys
Expand Down
30 changes: 30 additions & 0 deletions tests/decorators/test_python_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import sys
from importlib.util import find_spec
from subprocess import CalledProcessError
from typing import Any

import pytest

from airflow.decorators import setup, task, teardown
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState

pytestmark = pytest.mark.db_test

Expand All @@ -37,6 +39,8 @@
CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None
CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed")

_Invalid = Any


class TestPythonVirtualenvDecorator:
@CLOUDPICKLE_MARKER
Expand Down Expand Up @@ -350,3 +354,29 @@ def f():
assert teardown_task.is_teardown
assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun
ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

def test_invalid_annotation(self, dag_maker):
import uuid

unique_id = uuid.uuid4().hex
value = {"unique_id": unique_id}

# Functions that throw an error
# if `from __future__ import annotations` is missing
@task.virtualenv(multiple_outputs=False, do_xcom_push=True)
def in_venv(value: dict[str, _Invalid]) -> _Invalid:
assert isinstance(value, dict)
return value["unique_id"]

with dag_maker():
ret = in_venv(value)

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
ti = dr.get_task_instances()[0]

assert ti.state == TaskInstanceState.SUCCESS

xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value")
assert isinstance(xcom, str)
assert xcom == unique_id

0 comments on commit d5a7544

Please sign in to comment.