From adeb7f7cba2ab2b16be2e006c17e140fe91fdf77 Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Thu, 25 Apr 2024 15:20:40 -0400 Subject: [PATCH] Fix trigger kwarg encryption migration (#39246) Do the encryption in the migration itself, and fix support for offline migrations as well. The offline up migration won't actually encrypt the trigger kwargs as there isn't a safe way to accomplish that, so the decryption processes checks and short circuits if it isn't encrypted. The offline down migration will now print out a warning that the offline migration will fail if there are any running triggers. I think this is the best we can do for that scenario (and folks willing to do offline migrations will hopefully be able to understand the situation). This also solves the "encrypting the already encrypted kwargs" bug in 2.9.0. --- .../0140_2_9_0_update_trigger_kwargs_type.py | 46 ++++++++++++++++--- airflow/models/trigger.py | 10 +++- airflow/utils/db.py | 39 ---------------- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 8 ++-- docs/apache-airflow/migrations-ref.rst | 2 +- tests/models/test_trigger.py | 17 +++++++ 7 files changed, 72 insertions(+), 52 deletions(-) diff --git a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py index dbde1201e4cd0..2d57686e43f46 100644 --- a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py +++ b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py @@ -16,18 +16,22 @@ # specific language governing permissions and limitations # under the License. -"""update trigger kwargs type +"""update trigger kwargs type and encrypt Revision ID: 1949afb29106 Revises: ee1467d4aa35 Create Date: 2024-03-17 22:09:09.406395 """ +import json +from textwrap import dedent + +from alembic import context, op import sqlalchemy as sa +from sqlalchemy.orm import lazyload +from airflow.serialization.serialized_objects import BaseSerialization from airflow.models.trigger import Trigger -from alembic import op - from airflow.utils.sqlalchemy import ExtendedJSON # revision identifiers, used by Alembic. @@ -38,13 +42,43 @@ airflow_version = "2.9.0" +def get_session() -> sa.orm.Session: + conn = op.get_bind() + sessionmaker = sa.orm.sessionmaker() + return sessionmaker(bind=conn) + def upgrade(): - """Update trigger kwargs type to string""" + """Update trigger kwargs type to string and encrypt""" with op.batch_alter_table("trigger") as batch_op: batch_op.alter_column("kwargs", type_=sa.Text(), ) + if not context.is_offline_mode(): + session = get_session() + try: + for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)): + trigger.kwargs = trigger.kwargs + session.commit() + finally: + session.close() + def downgrade(): - """Unapply update trigger kwargs type to string""" + """Unapply update trigger kwargs type to string and encrypt""" + if context.is_offline_mode(): + print(dedent(""" + ------------ + -- WARNING: Unable to decrypt trigger kwargs automatically in offline mode! + -- If any trigger rows exist when you do an offline downgrade, the migration will fail. + ------------ + """)) + else: + session = get_session() + try: + for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)): + trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs)) + session.commit() + finally: + session.close() + with op.batch_alter_table("trigger") as batch_op: - batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json") + batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using='kwargs::json') diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index d3509b377d532..670d88d6142bb 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -116,7 +116,15 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]: from airflow.models.crypto import get_fernet from airflow.serialization.serialized_objects import BaseSerialization - decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")) + # We weren't able to encrypt the kwargs in all migration paths, + # so we need to handle the case where they are not encrypted. + # Triggers aren't long lasting, so we can skip encrypting them now. + if encrypted_kwargs.startswith("{"): + decrypted_kwargs = json.loads(encrypted_kwargs) + else: + decrypted_kwargs = json.loads( + get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8") + ) return BaseSerialization.deserialize(decrypted_kwargs) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index c0d282a587343..b7997498bb47d 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None: session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id)) -def encrypt_trigger_kwargs(*, session: Session) -> None: - """Encrypt trigger kwargs.""" - from airflow.models.trigger import Trigger - from airflow.serialization.serialized_objects import BaseSerialization - - for trigger in session.query(Trigger): - # convert serialized dict to string and encrypt it - trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs)) - session.commit() - - -def decrypt_trigger_kwargs(*, session: Session) -> None: - """Decrypt trigger kwargs.""" - from airflow.models.trigger import Trigger - from airflow.serialization.serialized_objects import BaseSerialization - - if not inspect(session.bind).has_table(Trigger.__tablename__): - # table does not exist, nothing to do - # this can happen when we downgrade to an old version before the Trigger table was added - return - - for trigger in session.scalars(select(Trigger.encrypted_kwargs)): - # decrypt the string and convert it to serialized dict - trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs)) - session.commit() - - def check_conn_id_duplicates(session: Session) -> Iterable[str]: """ Check unique conn_id in connection table. @@ -1666,12 +1639,6 @@ def upgradedb( _reserialize_dags(session=session) add_default_pool_if_not_exists(session=session) synchronize_log_template(session=session) - if _revision_greater( - config, - _REVISION_HEADS_MAP["2.9.0"], - _get_current_revision(session=session), - ): - encrypt_trigger_kwargs(session=session) @provide_session @@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session: else: log.info("Applying downgrade migrations.") command.downgrade(config, revision=to_revision, sql=show_sql_only) - if _revision_greater( - config, - _REVISION_HEADS_MAP["2.9.0"], - to_revision, - ): - decrypt_trigger_kwargs(session=session) def drop_airflow_models(connection): diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 8947b7e631598..cdcf039446dd5 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -072fb4b43a86ccb57765ec3f163350519773be83ab38b7ac747d25e1197233e8 \ No newline at end of file +77757e21aee500cb7fe7fd75e0f158633a0037d4d74e6f45eb14238f901ebacd \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index bf4c6c94906a0..fb280ee0ea7fc 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1421,28 +1421,28 @@ task_instance--xcom -0..N +1 1 task_instance--xcom -1 +0..N 1 task_instance--xcom -0..N +1 1 task_instance--xcom -1 +0..N 1 diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index d989564d91670..d858879d545fc 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -41,7 +41,7 @@ Here's the list of all the Database Migrations that are executed via when you ru +=================================+===================+===================+==============================================================+ | ``677fdbb7fc54`` (head) | ``1949afb29106`` | ``2.10.0`` | add new executor field to db | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ -| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type | +| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type and encrypt | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ | ``ee1467d4aa35`` | ``b4078ac230a1`` | ``2.9.0`` | add display name for dag and task instance | +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index a3dd6ce35afbf..6be2086f34112 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -17,6 +17,7 @@ from __future__ import annotations import datetime +import json from typing import Any, AsyncIterator import pytest @@ -27,6 +28,7 @@ from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import TaskInstance, Trigger from airflow.operators.empty import EmptyOperator +from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone from airflow.utils.session import create_session @@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs(): assert isinstance(trigger_row.encrypted_kwargs, str) assert "value1" not in trigger_row.encrypted_kwargs assert "value2" not in trigger_row.encrypted_kwargs + + +def test_kwargs_not_encrypted(): + """ + Tests that we don't decrypt kwargs if they aren't encrypted. + We weren't able to encrypt the kwargs in all migration paths. + """ + trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) + # force the `encrypted_kwargs` to be unencrypted, like they would be after an offline upgrade + trigger.encrypted_kwargs = json.dumps( + BaseSerialization.serialize({"param1": "value1", "param2": "value2"}) + ) + + assert trigger.kwargs["param1"] == "value1" + assert trigger.kwargs["param2"] == "value2"