Skip to content

Commit

Permalink
Encrypt all trigger attributes (apache#38233)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala authored and utkarsharma2 committed Apr 22, 2024
1 parent 63dc96c commit 3cdf0e3
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 7 deletions.
7 changes: 6 additions & 1 deletion airflow/api_connexion/schemas/trigger_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

from marshmallow import fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field

from airflow.models import Trigger
Expand All @@ -32,6 +33,10 @@ class Meta:

id = auto_field(dump_only=True)
classpath = auto_field(dump_only=True)
kwargs = auto_field(dump_only=True)
kwargs = fields.Method("get_kwars", dump_only=True)
created_date = auto_field(dump_only=True)
triggerer_id = auto_field(dump_only=True)

@staticmethod
def get_kwars(trigger: Trigger) -> str:
return str(trigger.kwargs)
4 changes: 3 additions & 1 deletion airflow/cli/commands/rotate_fernet_key_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from sqlalchemy import select

from airflow.models import Connection, Variable
from airflow.models import Connection, Trigger, Variable
from airflow.utils import cli as cli_utils
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import create_session
Expand All @@ -36,3 +36,5 @@ def rotate_fernet_key(args):
conn.rotate_fernet_key()
for var in session.scalars(select(Variable).where(Variable.is_encrypted)):
var.rotate_fernet_key()
for trigger in session.scalars(select(Trigger)):
trigger.rotate_fernet_key()
47 changes: 43 additions & 4 deletions airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from traceback import format_exception
from typing import TYPE_CHECKING, Any, Iterable

from sqlalchemy import Column, Integer, String, delete, func, or_, select, update
from sqlalchemy import Column, Integer, String, Text, delete, func, or_, select, update
from sqlalchemy.orm import joinedload, relationship
from sqlalchemy.sql.functions import coalesce

Expand All @@ -30,7 +30,7 @@
from airflow.utils import timezone
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
Expand Down Expand Up @@ -62,7 +62,7 @@ class Trigger(Base):

id = Column(Integer, primary_key=True)
classpath = Column(String(1000), nullable=False)
kwargs = Column(ExtendedJSON, nullable=False)
encrypted_kwargs = Column("kwargs", Text, nullable=False)
created_date = Column(UtcDateTime, nullable=False)
triggerer_id = Column(Integer, nullable=True)

Expand All @@ -83,9 +83,48 @@ def __init__(
) -> None:
super().__init__()
self.classpath = classpath
self.kwargs = kwargs
self.encrypted_kwargs = self._encrypt_kwargs(kwargs)
self.created_date = created_date or timezone.utcnow()

@property
def kwargs(self) -> dict[str, Any]:
"""Return the decrypted kwargs of the trigger."""
return self._decrypt_kwargs(self.encrypted_kwargs)

@kwargs.setter
def kwargs(self, kwargs: dict[str, Any]) -> None:
"""Set the encrypted kwargs of the trigger."""
self.encrypted_kwargs = self._encrypt_kwargs(kwargs)

@staticmethod
def _encrypt_kwargs(kwargs: dict[str, Any]) -> str:
"""Encrypt the kwargs of the trigger."""
import json

from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization

serialized_kwargs = BaseSerialization.serialize(kwargs)
return get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")

@staticmethod
def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
"""Decrypt the kwargs of the trigger."""
import json

from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization

decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))

return BaseSerialization.deserialize(decrypted_kwargs)

def rotate_fernet_key(self):
"""Encrypts data with a new key. See: :ref:`security/fernet`."""
from airflow.models.crypto import get_fernet

self.encrypted_kwargs = get_fernet().rotate(self.encrypted_kwargs.encode("utf-8")).decode("utf-8")

@classmethod
@internal_api_call
def from_object(cls, trigger: BaseTrigger) -> Trigger:
Expand Down
5 changes: 5 additions & 0 deletions docs/apache-airflow/authoring-and-scheduling/deferring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ Triggers can be as complex or as simple as you want, provided they meet the desi

If you are new to writing asynchronous Python, be very careful when writing your ``run()`` method. Python's async model means that code can block the entire process if it does not correctly ``await`` when it does a blocking operation. Airflow attempts to detect process blocking code and warn you in the triggerer logs when it happens. You can enable extra checks by Python by setting the variable ``PYTHONASYNCIODEBUG=1`` when you are writing your trigger to make sure you're writing non-blocking code. Be especially careful when doing filesystem calls, because if the underlying filesystem is network-backed, it can be blocking.

Sensitive information in triggers
'''''''''''''''''''''''''''''''''
Since Airflow 2.9.0, triggers kwargs are serialized and encrypted before being stored in the database. This means that any sensitive information you pass to a trigger will be stored in the database in an encrypted form, and decrypted when it is read from the database.


High Availability
-----------------

Expand Down
43 changes: 42 additions & 1 deletion tests/models/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
from __future__ import annotations

import datetime
from typing import Any, AsyncIterator

import pytest
import pytz
from cryptography.fernet import Fernet

from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import TaskInstance, Trigger
from airflow.operators.empty import EmptyOperator
from airflow.triggers.base import TriggerEvent
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from tests.test_utils.config import conf_vars

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -337,3 +340,41 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)

assert trigger_ids_query == [(2,), (1,)]


class SensitiveKwargsTrigger(BaseTrigger):
"""
A trigger that has sensitive kwargs.
"""

def __init__(self, param1: str, param2: str):
super().__init__()
self.param1 = param1
self.param2 = param2

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
"tests.models.test_trigger.SensitiveKwargsTrigger",
{
"param1": self.param1,
"param2": self.param2,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
yield TriggerEvent({})


@conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_serialize_sensitive_kwargs():
"""
Tests that sensitive kwargs are encrypted.
"""
trigger_instance = SensitiveKwargsTrigger(param1="value1", param2="value2")
trigger_row: Trigger = Trigger.from_object(trigger_instance)

assert trigger_row.kwargs["param1"] == "value1"
assert trigger_row.kwargs["param2"] == "value2"
assert isinstance(trigger_row.encrypted_kwargs, str)
assert "value1" not in trigger_row.encrypted_kwargs
assert "value2" not in trigger_row.encrypted_kwargs

0 comments on commit 3cdf0e3

Please sign in to comment.