Skip to content

Commit

Permalink
Add a migration script for encrypted trigger kwargs (apache#38358)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala authored and utkarsharma2 committed Apr 22, 2024
1 parent 7bde025 commit b0aead1
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""update trigger kwargs type
Revision ID: 1949afb29106
Revises: ee1467d4aa35
Create Date: 2024-03-17 22:09:09.406395
"""
import sqlalchemy as sa

from airflow.models.trigger import Trigger
from alembic import op

from airflow.utils.sqlalchemy import ExtendedJSON

# revision identifiers, used by Alembic.
revision = "1949afb29106"
down_revision = "ee1467d4aa35"
branch_labels = None
depends_on = None
airflow_version = "2.9.0"


def upgrade():
"""Update trigger kwargs type to string"""
with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=sa.Text(), )


def downgrade():
"""Unapply update trigger kwargs type to string"""
with op.batch_alter_table("trigger") as batch_op:
batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
41 changes: 40 additions & 1 deletion airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
"2.7.0": "405de8318b3a",
"2.8.0": "10b52ebd31f7",
"2.8.1": "88344c1d9134",
"2.9.0": "1fd565369930",
"2.9.0": "1949afb29106",
}


Expand Down Expand Up @@ -972,6 +972,33 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))


def encrypt_trigger_kwargs(*, session: Session) -> None:
"""Encrypt trigger kwargs."""
from airflow.models.trigger import Trigger
from airflow.serialization.serialized_objects import BaseSerialization

for trigger in session.query(Trigger):
# convert serialized dict to string and encrypt it
trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs))
session.commit()


def decrypt_trigger_kwargs(*, session: Session) -> None:
"""Decrypt trigger kwargs."""
from airflow.models.trigger import Trigger
from airflow.serialization.serialized_objects import BaseSerialization

if not inspect(session.bind).has_table(Trigger.__tablename__):
# table does not exist, nothing to do
# this can happen when we downgrade to an old version before the Trigger table was added
return

for trigger in session.query(Trigger):
# decrypt the string and convert it to serialized dict
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()


def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
Check unique conn_id in connection table.
Expand Down Expand Up @@ -1639,6 +1666,12 @@ def upgradedb(
_reserialize_dags(session=session)
add_default_pool_if_not_exists(session=session)
synchronize_log_template(session=session)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
_get_current_revision(session=session),
):
encrypt_trigger_kwargs(session=session)


@provide_session
Expand Down Expand Up @@ -1711,6 +1744,12 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
else:
log.info("Applying downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)
if _revision_greater(
config,
_REVISION_HEADS_MAP["2.9.0"],
to_revision,
):
decrypt_trigger_kwargs(session=session)


def drop_airflow_models(connection):
Expand Down

0 comments on commit b0aead1

Please sign in to comment.