diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 2bde2c83fbc12..29094969f7da2 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -19,6 +19,7 @@ import functools import gzip +import json import logging from io import BytesIO as IO from itertools import chain @@ -37,6 +38,43 @@ logger = logging.getLogger(__name__) +def _mask_variable_fields(extra_fields): + """ + The variable requests values and args comes in this form: + [('key', 'key_content'),('val', 'val_content'), ('description', 'description_content')] + So we need to mask the 'val_content' field if 'key_content' is in the mask list. + """ + result = [] + keyname = None + for k, v in extra_fields: + if k == "key": + keyname = v + result.append((k, v)) + elif keyname and k == "val": + x = secrets_masker.redact(v, keyname) + result.append((k, x)) + keyname = None + else: + result.append((k, v)) + return result + + +def _mask_connection_fields(extra_fields): + """Mask connection fields""" + result = [] + for k, v in extra_fields: + if k == "extra": + try: + extra = json.loads(v) + extra = [(k, secrets_masker.redact(v, k)) for k, v in extra.items()] + result.append((k, json.dumps(dict(extra)))) + except json.JSONDecodeError: + result.append((k, "Encountered non-JSON in `extra` field")) + else: + result.append((k, secrets_masker.redact(v, k))) + return result + + def action_logging(func: Callable | None = None, event: str | None = None) -> Callable[[T], T]: """Decorator to log user actions""" @@ -57,6 +95,10 @@ def wrapper(*args, **kwargs): for k, v in chain(request.values.items(multi=True), request.view_args.items()) if k not in fields_skip_logging ] + if event and event.startswith("variable."): + extra_fields = _mask_variable_fields(extra_fields) + if event and event.startswith("connection."): + extra_fields = _mask_connection_fields(extra_fields) params = {k: v for k, v in chain(request.values.items(), request.view_args.items())} diff --git a/tests/test_utils/www.py b/tests/test_utils/www.py index 2f48c304b4a19..8491d54094051 100644 --- a/tests/test_utils/www.py +++ b/tests/test_utils/www.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import ast from unittest import mock from airflow.models import Log @@ -73,3 +74,55 @@ def _check_last_log(session, dag_id, event, execution_date): assert len(logs) >= 1 assert logs[0].extra session.query(Log).delete() + + +def _check_last_log_masked_connection(session, dag_id, event, execution_date): + logs = ( + session.query( + Log.dag_id, + Log.task_id, + Log.event, + Log.execution_date, + Log.owner, + Log.extra, + ) + .filter( + Log.dag_id == dag_id, + Log.event == event, + Log.execution_date == execution_date, + ) + .order_by(Log.dttm.desc()) + .limit(5) + .all() + ) + assert len(logs) >= 1 + extra = ast.literal_eval(logs[0].extra) + for k, v in extra: + if k == "password": + assert v == "***" + if k == "extra": + assert v == '{"x_secret": "***", "y_secret": "***"}' + + +def _check_last_log_masked_variable(session, dag_id, event, execution_date): + logs = ( + session.query( + Log.dag_id, + Log.task_id, + Log.event, + Log.execution_date, + Log.owner, + Log.extra, + ) + .filter( + Log.dag_id == dag_id, + Log.event == event, + Log.execution_date == execution_date, + ) + .order_by(Log.dttm.desc()) + .limit(5) + .all() + ) + assert len(logs) >= 1 + extra_dict = ast.literal_eval(logs[0].extra) + assert extra_dict == [("key", "x_secret"), ("val", "***")] diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index 28d3f9570f44e..a884bbe793b42 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -28,7 +28,7 @@ from airflow.utils.session import create_session from airflow.www.extensions import init_views from airflow.www.views import ConnectionFormWidget, ConnectionModelView -from tests.test_utils.www import _check_last_log, check_content_in_response +from tests.test_utils.www import _check_last_log, _check_last_log_masked_connection, check_content_in_response CONNECTION = { "conn_id": "test_conn", @@ -40,6 +40,12 @@ "password": "admin", } +CONNECTION_WITH_EXTRA = CONNECTION.update( + { + "extra": '{"x_secret": "testsecret","y_secret": "test"}', + } +) + @pytest.fixture(autouse=True) def clear_connections(): @@ -54,6 +60,12 @@ def test_create_connection(admin_client, session): _check_last_log(session, dag_id=None, event="connection.create", execution_date=None) +def test_action_logging_connection_masked_secrets(session, admin_client): + init_views.init_connection_form() + admin_client.post("/connection/add", data=CONNECTION_WITH_EXTRA, follow_redirects=True) + _check_last_log_masked_connection(session, dag_id=None, event="connection.create", execution_date=None) + + def test_prefill_form_null_extra(): mock_form = mock.Mock() mock_form.data = {"conn_id": "test", "extra": None, "conn_type": "test"} diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 607249f04e37c..13ede6273d78d 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -28,8 +28,8 @@ from airflow.utils.types import DagRunType from airflow.www import app from airflow.www.views import action_has_dag_edit_access -from tests.test_utils.db import clear_db_runs -from tests.test_utils.www import _check_last_log, check_content_in_response +from tests.test_utils.db import clear_db_runs, clear_db_variables +from tests.test_utils.www import _check_last_log, _check_last_log_masked_variable, check_content_in_response EXAMPLE_DAG_DEFAULT_DATE = timezone.utcnow().replace(hour=0, minute=0, second=0, microsecond=0) @@ -86,6 +86,13 @@ def dagruns(bash_dag, sub_dag, xcom_dag): clear_db_runs() +@pytest.fixture(autouse=True) +def clean_db(): + clear_db_variables() + yield + clear_db_variables() + + @action_has_dag_edit_access def some_view_action_which_requires_dag_edit_access(*args) -> bool: return True @@ -156,11 +163,17 @@ def delete_variable(session, key): def test_action_logging_variables_post(session, admin_client): - form = dict(key="random", value="random") + form = dict(key="random", val="random") admin_client.post("/variable/add", data=form) session.commit() _check_last_log(session, dag_id=None, event="variable.create", execution_date=None) - delete_variable(session, key="random") + + +def test_action_logging_variables_masked_secrets(session, admin_client): + form = dict(key="x_secret", val="randomval") + admin_client.post("/variable/add", data=form) + session.commit() + _check_last_log_masked_variable(session, dag_id=None, event="variable.create", execution_date=None) def test_calendar(admin_client, dagruns):