From adeb7f7cba2ab2b16be2e006c17e140fe91fdf77 Mon Sep 17 00:00:00 2001
From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
Date: Thu, 25 Apr 2024 15:20:40 -0400
Subject: [PATCH] Fix trigger kwarg encryption migration (#39246)
Do the encryption in the migration itself, and fix support for offline
migrations as well.
The offline up migration won't actually encrypt the trigger kwargs, as there
isn't a safe way to accomplish that, so the decryption process checks whether
the kwargs are encrypted and short-circuits if they aren't.
The offline down migration will now print a warning that the offline
downgrade will fail if any trigger rows exist. I think this is
the best we can do for that scenario (and folks willing to do offline
migrations will hopefully be able to understand the situation).
This also solves the "encrypting the already encrypted kwargs" bug in
2.9.0.
---
.../0140_2_9_0_update_trigger_kwargs_type.py | 46 ++++++++++++++++---
airflow/models/trigger.py | 10 +++-
airflow/utils/db.py | 39 ----------------
docs/apache-airflow/img/airflow_erd.sha256 | 2 +-
docs/apache-airflow/img/airflow_erd.svg | 8 ++--
docs/apache-airflow/migrations-ref.rst | 2 +-
tests/models/test_trigger.py | 17 +++++++
7 files changed, 72 insertions(+), 52 deletions(-)
diff --git a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
index dbde1201e4cd0..2d57686e43f46 100644
--- a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
+++ b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
@@ -16,18 +16,22 @@
# specific language governing permissions and limitations
# under the License.
-"""update trigger kwargs type
+"""update trigger kwargs type and encrypt
Revision ID: 1949afb29106
Revises: ee1467d4aa35
Create Date: 2024-03-17 22:09:09.406395
"""
+import json
+from textwrap import dedent
+
+from alembic import context, op
import sqlalchemy as sa
+from sqlalchemy.orm import lazyload
+from airflow.serialization.serialized_objects import BaseSerialization
from airflow.models.trigger import Trigger
-from alembic import op
-
from airflow.utils.sqlalchemy import ExtendedJSON
# revision identifiers, used by Alembic.
@@ -38,13 +42,43 @@
airflow_version = "2.9.0"
+def get_session() -> sa.orm.Session:
+ conn = op.get_bind()
+ sessionmaker = sa.orm.sessionmaker()
+ return sessionmaker(bind=conn)
+
def upgrade():
- """Update trigger kwargs type to string"""
+ """Update trigger kwargs type to string and encrypt"""
with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=sa.Text(), )
+ if not context.is_offline_mode():
+ session = get_session()
+ try:
+ for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+ trigger.kwargs = trigger.kwargs
+ session.commit()
+ finally:
+ session.close()
+
def downgrade():
- """Unapply update trigger kwargs type to string"""
+ """Unapply update trigger kwargs type to string and encrypt"""
+ if context.is_offline_mode():
+ print(dedent("""
+ ------------
+ -- WARNING: Unable to decrypt trigger kwargs automatically in offline mode!
+ -- If any trigger rows exist when you do an offline downgrade, the migration will fail.
+ ------------
+ """))
+ else:
+ session = get_session()
+ try:
+ for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+ trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
+ session.commit()
+ finally:
+ session.close()
+
with op.batch_alter_table("trigger") as batch_op:
- batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
+ batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using='kwargs::json')
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index d3509b377d532..670d88d6142bb 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -116,7 +116,15 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization
- decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))
+ # We weren't able to encrypt the kwargs in all migration paths,
+ # so we need to handle the case where they are not encrypted.
+ # Triggers aren't long lasting, so we can skip encrypting them now.
+ if encrypted_kwargs.startswith("{"):
+ decrypted_kwargs = json.loads(encrypted_kwargs)
+ else:
+ decrypted_kwargs = json.loads(
+ get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
+ )
return BaseSerialization.deserialize(decrypted_kwargs)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index c0d282a587343..b7997498bb47d 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))
-def encrypt_trigger_kwargs(*, session: Session) -> None:
- """Encrypt trigger kwargs."""
- from airflow.models.trigger import Trigger
- from airflow.serialization.serialized_objects import BaseSerialization
-
- for trigger in session.query(Trigger):
- # convert serialized dict to string and encrypt it
- trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs))
- session.commit()
-
-
-def decrypt_trigger_kwargs(*, session: Session) -> None:
- """Decrypt trigger kwargs."""
- from airflow.models.trigger import Trigger
- from airflow.serialization.serialized_objects import BaseSerialization
-
- if not inspect(session.bind).has_table(Trigger.__tablename__):
- # table does not exist, nothing to do
- # this can happen when we downgrade to an old version before the Trigger table was added
- return
-
- for trigger in session.scalars(select(Trigger.encrypted_kwargs)):
- # decrypt the string and convert it to serialized dict
- trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
- session.commit()
-
-
def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
Check unique conn_id in connection table.
@@ -1666,12 +1639,6 @@ def upgradedb(
_reserialize_dags(session=session)
add_default_pool_if_not_exists(session=session)
synchronize_log_template(session=session)
- if _revision_greater(
- config,
- _REVISION_HEADS_MAP["2.9.0"],
- _get_current_revision(session=session),
- ):
- encrypt_trigger_kwargs(session=session)
@provide_session
@@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
else:
log.info("Applying downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)
- if _revision_greater(
- config,
- _REVISION_HEADS_MAP["2.9.0"],
- to_revision,
- ):
- decrypt_trigger_kwargs(session=session)
def drop_airflow_models(connection):
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index 8947b7e631598..cdcf039446dd5 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-072fb4b43a86ccb57765ec3f163350519773be83ab38b7ac747d25e1197233e8
\ No newline at end of file
+77757e21aee500cb7fe7fd75e0f158633a0037d4d74e6f45eb14238f901ebacd
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index bf4c6c94906a0..fb280ee0ea7fc 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
@@ -1421,28 +1421,28 @@
task_instance--xcom
-0..N
+1
1
task_instance--xcom
-1
+0..N
1
task_instance--xcom
-0..N
+1
1
task_instance--xcom
-1
+0..N
1
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index d989564d91670..d858879d545fc 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -41,7 +41,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
+=================================+===================+===================+==============================================================+
| ``677fdbb7fc54`` (head) | ``1949afb29106`` | ``2.10.0`` | add new executor field to db |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
-| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type |
+| ``1949afb29106`` | ``ee1467d4aa35`` | ``2.9.0`` | update trigger kwargs type and encrypt |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
| ``ee1467d4aa35`` | ``b4078ac230a1`` | ``2.9.0`` | add display name for dag and task instance |
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index a3dd6ce35afbf..6be2086f34112 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import datetime
+import json
from typing import Any, AsyncIterator
import pytest
@@ -27,6 +28,7 @@
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger
from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
@@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs():
assert isinstance(trigger_row.encrypted_kwargs, str)
assert "value1" not in trigger_row.encrypted_kwargs
assert "value2" not in trigger_row.encrypted_kwargs
+
+
+def test_kwargs_not_encrypted():
+ """
+ Tests that we don't decrypt kwargs if they aren't encrypted.
+ We weren't able to encrypt the kwargs in all migration paths.
+ """
+ trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+ # force the `encrypted_kwargs` to be unencrypted, like they would be after an offline upgrade
+ trigger.encrypted_kwargs = json.dumps(
+ BaseSerialization.serialize({"param1": "value1", "param2": "value2"})
+ )
+
+ assert trigger.kwargs["param1"] == "value1"
+ assert trigger.kwargs["param2"] == "value2"