Implement validation of missing template fields
shahar1 committed Dec 29, 2023
1 parent ed9080a commit fc2bad9
Showing 5 changed files with 332 additions and 0 deletions.
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
@@ -300,6 +300,17 @@ repos:
    # changes quickly - especially when we want the early modifications from the first local group
    # to be applied before the non-local pre-commits are run
    hooks:
      - id: validate-operators-init
        name: Prevent templated field logic checks in operators' __init__
        language: python
        entry: ./scripts/ci/pre_commit/pre_commit_validate_operators_init.py
        pass_filenames: true
        files: ^airflow/providers/.*/(operators|transfers|sensors)/.*\.py$
        additional_dependencies: [ 'rich>=12.4.4' ]
        exclude: |
          (?x)^(
              ^.*__init__.py$
          )$
      - id: replace-bad-characters
        name: Replace bad characters
        entry: ./scripts/ci/pre_commit/pre_commit_replace_bad_characters.py
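To see which files the new hook would receive, the `files` and `exclude` patterns above can be tried directly with Python's `re` module. This is a minimal sketch: it assumes pre-commit applies both patterns with search semantics, and the helper name `is_checked` and the sample paths are illustrative only, not taken from the commit.

```python
import re

# Patterns copied from the hook definition above.
FILES = re.compile(r"^airflow/providers/.*/(operators|transfers|sensors)/.*\.py$")
EXCLUDE = re.compile(
    r"""(?x)^(
        ^.*__init__.py$
    )$"""
)


def is_checked(path: str) -> bool:
    """Return True if the hook would be run on this path (hypothetical helper)."""
    return bool(FILES.search(path)) and not EXCLUDE.search(path)


# Illustrative paths, chosen only to exercise each branch of the patterns.
for candidate in (
    "airflow/providers/amazon/aws/operators/s3.py",
    "airflow/providers/amazon/aws/operators/__init__.py",
    "airflow/providers/amazon/aws/hooks/s3.py",
    "airflow/operators/bash.py",
):
    print(candidate, is_checked(candidate))
```

Only provider modules under `operators`, `transfers` or `sensors` pass the filter, and `__init__.py` files are dropped by the `exclude` pattern.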
2 changes: 2 additions & 0 deletions STATIC_CODE_CHECKS.rst
@@ -372,6 +372,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| update-version                                            | Update version to the latest version in the documentation    |         |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| validate-operators-init                                   | Prevent templated field logic checks in operators' __init__  |         |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| validate-pyproject                                        | Validate pyproject.toml                                      |         |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| yamllint                                                  | Check YAML files with yamllint                               |         |
1 change: 1 addition & 0 deletions dev/breeze/src/airflow_breeze/pre_commit_ids.py
@@ -131,6 +131,7 @@
    "update-supported-versions",
    "update-vendored-in-k8s-json-schema",
    "update-version",
    "validate-operators-init",
    "validate-pyproject",
    "yamllint",
]
82 changes: 82 additions & 0 deletions docs/apache-airflow/howto/custom-operator.rst
@@ -269,6 +269,88 @@ Currently available lexers:

If you use a non-existing lexer then the value of the template field will be rendered as a pretty-printed object.

Limitations
^^^^^^^^^^^
To prevent misuse, the following limitations must be observed when defining and assigning templated fields in the
operator's constructor (if the operator defines one; otherwise, see below):

1. The constructor parameter that corresponds to a templated field must have exactly the same name
   as the field. The following example is invalid, because the parameter passed into the constructor is named
   differently from the templated field:

   .. code-block:: python

       class HelloOperator(BaseOperator):
           template_fields = ("field_a",)

           def __init__(self, field_a_id) -> None:  # <- should be def __init__(self, field_a) -> None
               self.field_a = field_a_id  # <- should be self.field_a = field_a

2. The instance member of each templated field must be assigned its corresponding constructor parameter,
   either by a direct assignment or by calling the parent's constructor (in which these fields are
   defined as ``template_fields``) with an explicit assignment of the parameter.
   The following example is invalid, as the instance member ``self.field_a`` is not assigned at all, despite being a
   templated field:

   .. code-block:: python

       class HelloOperator(BaseOperator):
           template_fields = ("field_a", "field_b")

           def __init__(self, field_a, field_b) -> None:
               self.field_b = field_b

   The following example is also invalid, as the instance member ``self.field_a`` of ``MyHelloOperator`` is
   initialized implicitly as part of the ``kwargs`` passed to its parent constructor:

   .. code-block:: python

       class HelloOperator(BaseOperator):
           template_fields = ("field_a",)

           def __init__(self, field_a) -> None:
               self.field_a = field_a


       class MyHelloOperator(HelloOperator):
           template_fields = ("field_a", "field_b")

           def __init__(self, field_b, **kwargs) -> None:  # <- should be def __init__(self, field_a, field_b, **kwargs)
               super().__init__(**kwargs)  # <- should be super().__init__(field_a=field_a, **kwargs)
               self.field_b = field_b

3. Applying actions to the parameter during the assignment in the constructor is not allowed.
   Any action on the value should be applied in the ``execute()`` method.
   Therefore, the following example is invalid:

   .. code-block:: python

       class HelloOperator(BaseOperator):
           template_fields = ("field_a",)

           def __init__(self, field_a) -> None:
               self.field_a = field_a.lower()  # <- assignment should be only self.field_a = field_a

When an operator inherits from a base operator and does not define a constructor of its own, the limitations above
do not apply. However, the templated fields must be set properly in the parent according to those limitations.

Thus, the following example is valid:

.. code-block:: python

    class HelloOperator(BaseOperator):
        template_fields = ("field_a",)

        def __init__(self, field_a) -> None:
            self.field_a = field_a


    class MyHelloOperator(HelloOperator):
        template_fields = ("field_a",)

The limitations above are enforced by a pre-commit hook named ``validate-operators-init``.
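As a complement to the invalid examples, the following sketch shows one way a parent and a child operator could satisfy all three limitations at once. To keep it runnable without Airflow installed, ``BaseOperator`` here is a small stand-in class (an assumption made for illustration, not the real ``airflow.models.BaseOperator``), and the ``task_id`` argument is only passed through as an example keyword.

```python
class BaseOperator:
    """Stand-in for Airflow's BaseOperator so that the sketch runs on its own."""

    template_fields = ()

    def __init__(self, **kwargs) -> None:
        self.kwargs = kwargs


class HelloOperator(BaseOperator):
    template_fields = ("field_a",)

    def __init__(self, field_a, **kwargs) -> None:
        super().__init__(**kwargs)
        # Limitations 1 and 3: same name as the field, assigned without any transformation.
        self.field_a = field_a

    def execute(self, context):
        # Any action on the (already rendered) value belongs here, not in __init__.
        return self.field_a.lower()


class MyHelloOperator(HelloOperator):
    template_fields = ("field_a", "field_b")

    def __init__(self, field_a, field_b, **kwargs) -> None:
        # Limitation 2: field_a is passed explicitly to the parent, not hidden in kwargs.
        super().__init__(field_a=field_a, **kwargs)
        self.field_b = field_b


op = MyHelloOperator(field_a="Hello", field_b="World", task_id="hello_task")
print(op.execute(context={}))
```

Deferring ``lower()`` to ``execute()`` matters because templates are rendered after the constructor has run; transforming the raw value in ``__init__`` would act on the unrendered template string.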

Add template fields with subclassing
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
A common use case for creating a custom operator is for simply augmenting existing ``template_fields``.
236 changes: 236 additions & 0 deletions scripts/ci/pre_commit/pre_commit_validate_operators_init.py
@@ -0,0 +1,236 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import ast
import sys
from typing import Any

from rich.console import Console

console = Console(color_system="standard", width=200)
BASE_OPERATOR_CLASS_NAME = "BaseOperator"


def _is_operator(class_node: ast.ClassDef) -> bool:
    """
    Check if a given class node is an operator, based on the string suffix of the base IDs
    (ends with "BaseOperator").
    TODO: Enhance this function to work with nested inheritance trees through dynamic imports.
    :param class_node: The class node to check.
    :return: True if the class definition is of an operator, False otherwise.
    """
    for base in class_node.bases:
        if isinstance(base, ast.Name) and base.id.endswith(BASE_OPERATOR_CLASS_NAME):
            return True
    return False


def _extract_template_fields(class_node: ast.ClassDef) -> list[str]:
    """
    This method takes a class node as input and extracts the template fields from it.
    Template fields are identified by an assignment statement where the target is a variable
    named "template_fields" and the value is a tuple of constants.
    :param class_node: The class node representing the class for which template fields need to be extracted.
    :return: A list of template fields extracted from the class node.
    """
    for class_item in class_node.body:
        if isinstance(class_item, ast.Assign):
            for target in class_item.targets:
                if (
                    isinstance(target, ast.Name)
                    and target.id == "template_fields"
                    and isinstance(class_item.value, ast.Tuple)
                ):
                    return [elt.value for elt in class_item.value.elts if isinstance(elt, ast.Constant)]
        elif isinstance(class_item, ast.AnnAssign):
            if (
                isinstance(class_item.target, ast.Name)
                and class_item.target.id == "template_fields"
                and isinstance(class_item.value, ast.Tuple)
            ):
                return [elt.value for elt in class_item.value.elts if isinstance(elt, ast.Constant)]
    return []


def _handle_parent_constructor_kwargs(
    template_fields: list[str],
    ctor_stmt: ast.stmt,
    missing_assignments: list[str],
    invalid_assignments: list[str],
) -> list[str]:
    """
    This method checks if template fields are correctly assigned in a call to the parent class's
    constructor.
    It handles both the detection of missing assignments and invalid assignments.
    It assumes that if the call is valid - the parent class will correctly assign the template
    field.
    TODO: Enhance this function to work with nested inheritance trees through dynamic imports.
    :param missing_assignments: List[str] - List of template fields that have not been assigned a value.
    :param ctor_stmt: ast.Expr - AST node representing the constructor statement.
    :param invalid_assignments: List[str] - List of template fields that have been assigned incorrectly.
    :param template_fields: List[str] - List of template fields to be assigned.
    :return: List[str] - List of template fields that are still missing assignments.
    """
    if isinstance(ctor_stmt, ast.Expr):
        if (
            isinstance(ctor_stmt.value, ast.Call)
            and isinstance(ctor_stmt.value.func, ast.Attribute)
            and isinstance(ctor_stmt.value.func.value, ast.Call)
            and isinstance(ctor_stmt.value.func.value.func, ast.Name)
            and ctor_stmt.value.func.value.func.id == "super"
        ):
            for arg in ctor_stmt.value.keywords:
                if arg.arg is not None and arg.arg in template_fields:
                    if not isinstance(arg.value, ast.Name) or arg.arg != arg.value.id:
                        invalid_assignments.append(arg.arg)
            assigned_targets = [arg.arg for arg in ctor_stmt.value.keywords if arg.arg is not None]
            return list(set(missing_assignments) - set(assigned_targets))
    return missing_assignments


def _handle_constructor_statement(
    template_fields: list[str],
    ctor_stmt: ast.stmt,
    missing_assignments: list[str],
    invalid_assignments: list[str],
) -> list[str]:
    """
    This method handles a single constructor statement by doing the following actions:
    1. Removing assigned fields of template_fields from missing_assignments.
    2. Detecting invalid assignments of template fields and adding them to invalid_assignments.
    :param template_fields: List of template fields.
    :param ctor_stmt: Constructor statement (for example, self.field_name = param_name)
    :param missing_assignments: List of missing assignments.
    :param invalid_assignments: List of invalid assignments.
    :return: List of missing assignments after handling the assigned targets.
    """
    assigned_template_fields: list[str] = []
    if isinstance(ctor_stmt, ast.Assign):
        if isinstance(ctor_stmt.targets[0], ast.Attribute):
            for target in ctor_stmt.targets:
                if isinstance(target, ast.Attribute) and target.attr in template_fields:
                    if isinstance(ctor_stmt.value, ast.BoolOp) and isinstance(ctor_stmt.value.op, ast.Or):
                        _handle_assigned_field(
                            assigned_template_fields, invalid_assignments, target, ctor_stmt.value.values[0]
                        )
                    else:
                        _handle_assigned_field(
                            assigned_template_fields, invalid_assignments, target, ctor_stmt.value
                        )
        elif isinstance(ctor_stmt.targets[0], ast.Tuple) and isinstance(ctor_stmt.value, ast.Tuple):
            for target, value in zip(ctor_stmt.targets[0].elts, ctor_stmt.value.elts):
                if isinstance(target, ast.Attribute):
                    _handle_assigned_field(assigned_template_fields, invalid_assignments, target, value)
    elif isinstance(ctor_stmt, ast.AnnAssign):
        if isinstance(ctor_stmt.target, ast.Attribute) and ctor_stmt.target.attr in template_fields:
            _handle_assigned_field(
                assigned_template_fields, invalid_assignments, ctor_stmt.target, ctor_stmt.value
            )
    return list(set(missing_assignments) - set(assigned_template_fields))


def _handle_assigned_field(
    assigned_template_fields: list[str], invalid_assignments: list[str], target: ast.Attribute, value: Any
) -> None:
    """
    Handle an assigned field by its value.
    :param assigned_template_fields: A list to store the valid assigned fields.
    :param invalid_assignments: A list to store the invalid assignments.
    :param target: The target field.
    :param value: The value of the field.
    """
    if not isinstance(value, ast.Name):
        invalid_assignments.append(target.attr)
    else:
        assigned_template_fields.append(target.attr)


def _check_constructor_template_fields(class_node: ast.ClassDef, template_fields: list[str]) -> int:
    """
    This method checks a class's constructor for missing or invalid assignments of template fields.
    When there isn't a constructor - it assumes that the template fields are defined in the parent's
    constructor correctly.
    TODO: Enhance this function to work with nested inheritance trees through dynamic imports.
    :param class_node: the AST node representing the class definition
    :param template_fields: a list of template fields
    :return: the number of invalid template fields found
    """
    count = 0
    class_name = class_node.name
    missing_assignments = template_fields.copy()
    invalid_assignments: list[str] = []
    init_flag: bool = False
    for class_item in class_node.body:
        if isinstance(class_item, ast.FunctionDef) and class_item.name == "__init__":
            init_flag = True
            for ctor_stmt in class_item.body:
                missing_assignments = _handle_parent_constructor_kwargs(
                    template_fields, ctor_stmt, missing_assignments, invalid_assignments
                )
                missing_assignments = _handle_constructor_statement(
                    template_fields, ctor_stmt, missing_assignments, invalid_assignments
                )

    if init_flag and missing_assignments:
        count += len(missing_assignments)
        console.print(
            f"{class_name}'s constructor lacks direct assignments for "
            f"instance members corresponding to the following template fields "
            f"(i.e., self.field_name = field_name or super().__init__(field_name=field_name, ...) ):"
        )
        console.print(f"[red]{missing_assignments}[/red]")

    if invalid_assignments:
        count += len(invalid_assignments)
        console.print(
            f"{class_name}'s constructor contains invalid assignments to the following instance "
            f"members that should be corresponding to template fields "
            f"(i.e., self.field_name = field_name):"
        )
        console.print(f"[red]{[f'self.{entry}' for entry in invalid_assignments]}[/red]")
    return count


def main():
    """
    Check missing or invalid template fields in constructors of providers' operators.
    :return: The total number of errors found.
    """
    err = 0
    for path in sys.argv[1:]:
        console.print(f"[yellow]{path}[/yellow]")
        # Close the file deterministically instead of relying on garbage collection.
        with open(path) as source_file:
            tree = ast.parse(source_file.read(), filename=path)
        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef) and _is_operator(class_node=node):
                template_fields = _extract_template_fields(node) or []
                err += _check_constructor_template_fields(node, template_fields)
    return err


if __name__ == "__main__":
    sys.exit(main())
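The script's core technique (walk the AST, find operator-like classes, read ``template_fields``, then classify each ``self.<field> = ...`` assignment in ``__init__``) can be tried on an in-memory source string. The sketch below is a condensed re-implementation for illustration, not the script itself: ``check_source`` and the two sample classes are hypothetical, and it deliberately skips the ``super().__init__(...)`` keyword handling, tuple assignments and annotated assignments that the real script covers.

```python
import ast

# Hypothetical sample module: one compliant operator and one that breaks the rules.
SAMPLE = '''
class GoodOperator(BaseOperator):
    template_fields = ("field_a", "field_b")

    def __init__(self, field_a, field_b, **kwargs):
        super().__init__(**kwargs)
        self.field_a = field_a
        self.field_b = field_b or "default"


class BadOperator(BaseOperator):
    template_fields = ("field_a", "field_b")

    def __init__(self, field_a, field_b, **kwargs):
        super().__init__(**kwargs)
        self.field_a = field_a.lower()
'''


def check_source(source: str) -> dict:
    """Map each operator-like class name to its missing and invalid template fields."""
    report = {}
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.ClassDef):
            continue
        # Same heuristic as the script: a base whose name ends with "BaseOperator".
        if not any(isinstance(b, ast.Name) and b.id.endswith("BaseOperator") for b in node.bases):
            continue
        fields, assigned, invalid, has_init = [], set(), set(), False
        for item in node.body:
            if isinstance(item, ast.Assign) and isinstance(item.value, ast.Tuple):
                if any(isinstance(t, ast.Name) and t.id == "template_fields" for t in item.targets):
                    fields = [e.value for e in item.value.elts if isinstance(e, ast.Constant)]
        for item in node.body:
            if not (isinstance(item, ast.FunctionDef) and item.name == "__init__"):
                continue
            has_init = True
            for stmt in item.body:
                if not isinstance(stmt, ast.Assign):
                    continue
                for target in stmt.targets:
                    if isinstance(target, ast.Attribute) and target.attr in fields:
                        value = stmt.value
                        # "x or default" is tolerated: only the first operand is inspected.
                        if isinstance(value, ast.BoolOp) and isinstance(value.op, ast.Or):
                            value = value.values[0]
                        (assigned if isinstance(value, ast.Name) else invalid).add(target.attr)
        missing = sorted(set(fields) - assigned) if has_init else []
        report[node.name] = {"missing": missing, "invalid": sorted(invalid)}
    return report


print(check_source(SAMPLE))
```

As in the script, a field with an invalid assignment is never counted as assigned, so it shows up in both the missing and the invalid lists.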
