diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 540598d089ce0..5a1d22e71f177 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1117,7 +1117,7 @@ repos:
           airflow/example_dags/.*
         args:
           - "--skip"
-          - "B301,B324,B403,B404,B603"
+          - "B101,B301,B324,B403,B404,B603"
          - "--severity-level"
          - "high"  # TODO: remove this line when we fix all the issues
  - id: pylint
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 8416789f5332b..0309fa1bf56f1 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -212,7 +212,8 @@ def _run_task_by_selected_method(
     - as raw task
     - by executor
     """
-    assert not isinstance(ti, TaskInstancePydantic), "Wait for AIP-44 implementation to complete"
+    if TYPE_CHECKING:
+        assert not isinstance(ti, TaskInstancePydantic)  # Wait for AIP-44 implementation to complete
     if args.local:
         return _run_task_by_local_task_job(args, ti)
     if args.raw:
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 08d69694fd05e..13800d4dbb4aa 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -23,7 +23,6 @@
 from sqlalchemy import select, update

 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -35,6 +34,7 @@
     from pendulum import DateTime
     from sqlalchemy import Session

+    from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.models.taskmixin import DAGNode
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
@@ -197,7 +197,8 @@ def skip_all_except(
         )

         dag_run = ti.get_dagrun()
-        assert isinstance(dag_run, DagRun)
+        if TYPE_CHECKING:
+            assert isinstance(dag_run, DagRun)

         # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
         # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 57c9483cd4ee7..066c87b42f379 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1025,7 +1025,8 @@ def _email_alert(
     :meta private:
     """
     subject, html_content, html_content_err = task_instance.get_email_subject_content(exception, task=task)
-    assert task.email
+    if TYPE_CHECKING:
+        assert task.email
     try:
         send_email(task.email, subject, html_content)
     except Exception:
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 60315617e40cd..fba0259454f7f 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -406,7 +406,8 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
         from airflow.models.taskinstance import TaskInstance

         ti = context["ti"]
-        assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation to complete"
+        if not isinstance(ti, TaskInstance):
+            raise NotImplementedError("Wait for AIP-44 implementation to complete")
         task_id = self.operator.task_id
         map_indexes = ti.get_relevant_upstream_map_indexes(
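The recurring `if TYPE_CHECKING:` pattern above works because `typing.TYPE_CHECKING` is `False` at runtime: the guarded `assert` never executes, yet mypy, which treats `TYPE_CHECKING` as always true, still uses it to narrow the type. A minimal sketch of the idea — the `fetch`/`use` names are illustrative, not from this patch:

    from __future__ import annotations

    from typing import TYPE_CHECKING


    def fetch(flag: bool) -> str | None:
        # Illustrative helper that may return None.
        return "value" if flag else None


    def use(flag: bool) -> str:
        value = fetch(flag)
        if TYPE_CHECKING:
            # Never runs; mypy narrows `value` from `str | None` to `str` here.
            assert value is not None
        return value

The `xcom_arg.py` hunk is deliberately different: that check must hold at runtime, so its assert becomes an explicit `raise NotImplementedError` instead.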
diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
index e06ee912281bb..e40b892aea8e7 100644
--- a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
+++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
@@ -80,7 +80,7 @@ def __init__(
     ) -> None:
         self.adb_spark_conn_id = adb_spark_conn_id
         self.adb_spark_conn = self.get_connection(adb_spark_conn_id)
-        self.region = self.get_default_region() if region is None else region
+        self.region = region or self.get_default_region()
         super().__init__(*args, **kwargs)

     def submit_spark_app(
@@ -327,8 +327,6 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:

     def get_adb_spark_client(self) -> Client:
         """Get valid AnalyticDB MySQL Spark client."""
-        assert self.region is not None
-
         extra_config = self.adb_spark_conn.extra_dejson
         auth_type = extra_config.get("auth_type", None)
         if not auth_type:
@@ -352,7 +350,7 @@ def get_adb_spark_client(self) -> Client:
             )
         )

-    def get_default_region(self) -> str | None:
+    def get_default_region(self) -> str:
         """Get default region from connection."""
         extra_config = self.adb_spark_conn.extra_dejson
         auth_type = extra_config.get("auth_type", None)
diff --git a/airflow/providers/alibaba/cloud/hooks/oss.py b/airflow/providers/alibaba/cloud/hooks/oss.py
index a675eed7062f5..cf0e20127ddae 100644
--- a/airflow/providers/alibaba/cloud/hooks/oss.py
+++ b/airflow/providers/alibaba/cloud/hooks/oss.py
@@ -87,7 +87,7 @@ class OSSHook(BaseHook):
     def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args, **kwargs) -> None:
         self.oss_conn_id = oss_conn_id
         self.oss_conn = self.get_connection(oss_conn_id)
-        self.region = self.get_default_region() if region is None else region
+        self.region = region or self.get_default_region()
         super().__init__(*args, **kwargs)

     def get_conn(self) -> Connection:
@@ -137,7 +137,6 @@ def get_bucket(self, bucket_name: str | None = None) -> oss2.api.Bucket:
         :return: the bucket object to the bucket name.
         """
         auth = self.get_credential()
-        assert self.region is not None
         return oss2.Bucket(auth, f"https://oss-{self.region}.aliyuncs.com", bucket_name)

     @provide_bucket_name
@@ -353,7 +352,7 @@ def get_credential(self) -> oss2.auth.Auth:

         return oss2.Auth(oss_access_key_id, oss_access_key_secret)

-    def get_default_region(self) -> str | None:
+    def get_default_region(self) -> str:
         extra_config = self.oss_conn.extra_dejson
         auth_type = extra_config.get("auth_type", None)
         if not auth_type:
diff --git a/airflow/providers/amazon/aws/executors/batch/batch_executor_config.py b/airflow/providers/amazon/aws/executors/batch/batch_executor_config.py
index 5f72b021632fc..9c0b0ebce93fd 100644
--- a/airflow/providers/amazon/aws/executors/batch/batch_executor_config.py
+++ b/airflow/providers/amazon/aws/executors/batch/batch_executor_config.py
@@ -31,6 +31,7 @@

 import json
 from json import JSONDecodeError
+from typing import TYPE_CHECKING

 from airflow.configuration import conf
 from airflow.providers.amazon.aws.executors.batch.utils import (
@@ -68,8 +69,9 @@ def build_submit_kwargs() -> dict:
     if "eksPropertiesOverride" in job_kwargs:
         raise KeyError("Eks jobs are not currently supported.")

+    if TYPE_CHECKING:
+        assert isinstance(job_kwargs, dict)
     # some checks with some helpful errors
-    assert isinstance(job_kwargs, dict)
     if "containerOverrides" not in job_kwargs or "command" not in job_kwargs["containerOverrides"]:
         raise KeyError(
             'SubmitJob API needs kwargs["containerOverrides"]["command"] field,'
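One behavioural nuance in the two Alibaba hooks above: `region or self.get_default_region()` is not a strict rewrite of the ternary it replaces — the `or` form also falls back to the default when `region` is an empty string, whereas `... if region is None else region` would keep the empty string. Combined with `get_default_region()` now being typed `-> str` instead of `-> str | None`, this is what makes the downstream `assert self.region is not None` lines removable. A sketch of the difference (the function names and default value are made up):

    def pick_region_or(region: str | None, default: str = "cn-hangzhou") -> str:
        # `or` treats "" (and any other falsy value) like None.
        return region or default


    def pick_region_ternary(region: str | None, default: str = "cn-hangzhou") -> str:
        # The ternary substitutes the default only for None.
        return default if region is None else region


    print(pick_region_or(""))       # cn-hangzhou
    print(pick_region_ternary(""))  # "" (empty string kept)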
diff --git a/airflow/providers/fab/auth_manager/security_manager/override.py b/airflow/providers/fab/auth_manager/security_manager/override.py
index 1859646458efd..21daa647348e2 100644
--- a/airflow/providers/fab/auth_manager/security_manager/override.py
+++ b/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -727,8 +727,8 @@ def create_admin_standalone(self) -> tuple[str | None, str | None]:
         # If the user does not exist, make a random password and make it
         if not user_exists:
             print(f"FlaskAppBuilder Authentication Manager: Creating {user_name} user")
-            role = self.find_role("Admin")
-            assert role is not None
+            if (role := self.find_role("Admin")) is None:
+                raise AirflowException("Unable to find role 'Admin'")
             # password does not contain visually similar characters: ijlIJL1oO0
             password = "".join(random.choices("abcdefghkmnpqrstuvwxyzABCDEFGHKMNPQRSTUVWXYZ23456789", k=16))
             with open(password_path, "w") as file:
@@ -2445,8 +2445,9 @@ def _ldap_bind_indirect(self, ldap, con) -> None:
         :param ldap: The ldap module reference
         :param con: The ldap connection
         """
-        # always check AUTH_LDAP_BIND_USER is set before calling this method
-        assert self.auth_ldap_bind_user, "AUTH_LDAP_BIND_USER must be set"
+        if not self.auth_ldap_bind_user:
+            # always check AUTH_LDAP_BIND_USER is set before calling this method
+            raise ValueError("AUTH_LDAP_BIND_USER must be set")

         try:
             log.debug("LDAP bind indirect TRY with username: %r", self.auth_ldap_bind_user)
@@ -2465,8 +2466,9 @@ def _search_ldap(self, ldap, con, username):
         :param username: username to match with AUTH_LDAP_UID_FIELD
         :return: ldap object array
         """
-        # always check AUTH_LDAP_SEARCH is set before calling this method
-        assert self.auth_ldap_search, "AUTH_LDAP_SEARCH must be set"
+        if not self.auth_ldap_search:
+            # always check AUTH_LDAP_SEARCH is set before calling this method
+            raise ValueError("AUTH_LDAP_SEARCH must be set")

         # build the filter string for the LDAP search
         if self.auth_ldap_search_filter:
diff --git a/airflow/providers/microsoft/psrp/hooks/psrp.py b/airflow/providers/microsoft/psrp/hooks/psrp.py
index aa292a3325e5d..8aff72adfb941 100644
--- a/airflow/providers/microsoft/psrp/hooks/psrp.py
+++ b/airflow/providers/microsoft/psrp/hooks/psrp.py
@@ -20,7 +20,7 @@
 from contextlib import contextmanager
 from copy import copy
 from logging import DEBUG, ERROR, INFO, WARNING
-from typing import Any, Callable, Generator
+from typing import TYPE_CHECKING, Any, Callable, Generator
 from warnings import warn
 from weakref import WeakKeyDictionary
@@ -173,7 +173,8 @@ def invoke(self) -> Generator[PowerShell, None, None]:
         if local_context:
             self.__enter__()
         try:
-            assert self._conn is not None
+            if TYPE_CHECKING:
+                assert self._conn is not None
             ps = PowerShell(self._conn)
             yield ps
             ps.begin_invoke()
diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py
index c653638a564bc..5165c1c910132 100644
--- a/airflow/providers/openlineage/extractors/base.py
+++ b/airflow/providers/openlineage/extractors/base.py
@@ -78,9 +78,6 @@ def _is_operator_disabled(self) -> bool:
         )
         return fully_qualified_class_name in self.disabled_operators

-    def validate(self):
-        assert self.operator.task_type in self.get_operator_classnames()
-
     @abstractmethod
     def _execute_extraction(self) -> OperatorLineage | None: ...
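The `override.py` hunks above convert runtime guards from `assert` to explicit exceptions, and the rationale generalises to every non-`TYPE_CHECKING` change in this patch: CPython strips `assert` statements entirely under `python -O`, so an assert is not a reliable validation mechanism outside tests. A minimal before/after sketch — the `ConfigError` name and `bind` helper are hypothetical:

    class ConfigError(ValueError):
        """Raised when a required setting is missing."""


    def bind(bind_user: str | None) -> str:
        # Before: `assert bind_user, "bind user must be set"` — silently
        # removed under `python -O`. After: an explicit raise, enforced in
        # every interpreter mode.
        if not bind_user:
            raise ConfigError("bind user must be set")
        return f"binding as {bind_user}"

The walrus form in `create_admin_standalone` — `if (role := self.find_role("Admin")) is None: raise ...` — is the same idea compressed into a single statement; the `parse_time_delta` change further below uses the identical shape for regex matching.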
diff --git a/airflow/providers/yandex/secrets/lockbox.py b/airflow/providers/yandex/secrets/lockbox.py
index d9ade29646e48..2bfef152c743c 100644
--- a/airflow/providers/yandex/secrets/lockbox.py
+++ b/airflow/providers/yandex/secrets/lockbox.py
@@ -128,10 +128,8 @@ def __init__(
         self.yc_connection_id = None
         if not any([yc_oauth_token, yc_sa_key_json, yc_sa_key_json_path]):
             self.yc_connection_id = yc_connection_id or default_conn_name
-        else:
-            assert (
-                yc_connection_id is None
-            ), "yc_connection_id should not be used if other credentials are specified"
+        elif yc_connection_id is not None:
+            raise ValueError("`yc_connection_id` should not be used if other credentials are specified")

         self.folder_id = folder_id
         self.connections_prefix = connections_prefix.rstrip(sep) if connections_prefix is not None else None
diff --git a/airflow/serialization/pydantic/job.py b/airflow/serialization/pydantic/job.py
index fd805284253eb..7aec389ba9aa4 100644
--- a/airflow/serialization/pydantic/job.py
+++ b/airflow/serialization/pydantic/job.py
@@ -16,7 +16,7 @@
 # under the License.
 import datetime
 from functools import cached_property
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.base_job_runner import BaseJobRunner
@@ -53,7 +53,8 @@ def executor(self):
     def heartrate(self) -> float:
         from airflow.jobs.job import Job

-        assert self.job_type is not None
+        if TYPE_CHECKING:
+            assert self.job_type is not None
         return Job._heartrate(self.job_type)

     def is_alive(self, grace_multiplier=2.1) -> bool:
diff --git a/dev/breeze/src/airflow_breeze/commands/minor_release_command.py b/dev/breeze/src/airflow_breeze/commands/minor_release_command.py
index fe43ecc7e6de8..c31901252fa24 100644
--- a/dev/breeze/src/airflow_breeze/commands/minor_release_command.py
+++ b/dev/breeze/src/airflow_breeze/commands/minor_release_command.py
@@ -17,6 +17,7 @@
 from __future__ import annotations

 import os
+import sys

 import click
@@ -158,7 +159,18 @@ def create_constraints(version_branch):
 @option_answer
 def create_minor_version_branch(version_branch):
     for obj in version_branch.split("-"):
-        assert isinstance(int(obj), int)
+        if not obj.isdigit():
+            console_print(f"[error]Invalid `version_branch` part {obj!r}: not a digit.")
+            sys.exit(1)
+        elif len(obj) > 1 and obj.startswith("0"):
+            # `01` is a valid digit string and converts to an integer cleanly,
+            # but it is most likely a typo (e.g. for `10`), so better to stop here.
+            console_print(
+                f"[error]Found a leading zero in `version_branch` part {obj!r}; "
+                f"if it is not a typo, consider using {int(obj)} instead."
+            )
+            sys.exit(1)
+
     os.chdir(AIRFLOW_SOURCES_ROOT)
     repo_root = os.getcwd()
     console_print()
diff --git a/dev/perf/dags/elastic_dag.py b/dev/perf/dags/elastic_dag.py
index 218ccb8fc1753..e0adcdf5caf11 100644
--- a/dev/perf/dags/elastic_dag.py
+++ b/dev/perf/dags/elastic_dag.py
@@ -39,12 +39,12 @@ def parse_time_delta(time_str: str):
     :param time_str: A string identifying a duration. (eg. 2h13m)
     :return datetime.timedelta: A datetime.timedelta object or "@once"
     """
-    parts = RE_TIME_DELTA.match(time_str)
-
-    assert parts is not None, (
-        f"Could not parse any time information from '{time_str}'. "
-        f"Examples of valid strings: '8h', '2d8h5m20s', '2m4s'"
-    )
+    if (parts := RE_TIME_DELTA.match(time_str)) is None:
+        msg = (
+            f"Could not parse any time information from '{time_str}'. "
+            f"Examples of valid strings: '8h', '2d8h5m20s', '2m4s'"
+        )
+        raise ValueError(msg)

     time_params = {name: float(param) for name, param in parts.groupdict().items() if param}
     return timedelta(**time_params)  # type: ignore
diff --git a/pyproject.toml b/pyproject.toml
index 58cbcfa0e72eb..81bde956fde00 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1330,6 +1330,7 @@ extend-select = [
     "PGH004",  # Use specific rule codes when using noqa
     "PGH005",  # Invalid unittest.mock.Mock methods/attributes/properties
     "B006",  # Checks for uses of mutable objects as function argument defaults.
+    "S101",  # Checks for use of `assert` outside tests; test directories go into the exclusions below
 ]
 ignore = [
     "D203",
@@ -1384,13 +1385,14 @@ combine-as-imports = true
 "*/example_dags/*" = ["D"]
 "chart/*" = ["D"]
 "dev/*" = ["D"]
-# In addition ignore top level imports, e.g. pandas, numpy in tests
+# In addition, ignore top-level imports, e.g. pandas, numpy (TID253), and use of assert (S101) in tests
 "dev/perf/*" = ["TID253"]
-"dev/breeze/tests/*" = ["TID253"]
-"tests/*" = ["D", "TID253"]
-"docker_tests/*" = ["D", "TID253"]
-"kubernetes_tests/*" = ["D", "TID253"]
-"helm_tests/*" = ["D", "TID253"]
+"dev/check_files.py" = ["S101"]
+"dev/breeze/tests/*" = ["TID253", "S101"]
+"tests/*" = ["D", "TID253", "S101"]
+"docker_tests/*" = ["D", "TID253", "S101"]
+"kubernetes_tests/*" = ["D", "TID253", "S101"]
+"helm_tests/*" = ["D", "TID253", "S101"]

 # All of the modules which have an extra license header (i.e. that we copy from another project) need to
 # ignore E402 -- module level import not at top level
diff --git a/scripts/ci/pre_commit/pre_commit_sync_init_decorator.py b/scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
index 4e0e59bdaab49..963e9b9222537 100755
--- a/scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
+++ b/scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
@@ -24,6 +24,7 @@
 import itertools
 import pathlib
 import sys
+from typing import TYPE_CHECKING

 PACKAGE_ROOT = pathlib.Path(__file__).resolve().parents[3].joinpath("airflow")
 DAG_PY = PACKAGE_ROOT.joinpath("models", "dag.py")
@@ -93,7 +94,9 @@ def _match_arguments(
         if dec is None and ini is not None:
             yield f"Argument present in {init_name} but missing from @{deco_name}: {ini.arg}"
             return
-        assert ini is not None and dec is not None  # Because None is only possible as fillvalue.
+
+        if TYPE_CHECKING:
+            assert ini is not None and dec is not None  # Because None is only possible as fillvalue.
         if ini.arg != dec.arg:
             yield f"Argument {i + 1} mismatch: {init_name} has {ini.arg} but @{deco_name} has {dec.arg}"
diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py
index 83436cb205822..c96768e616221 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -107,7 +107,8 @@ def revision_suffix(rev: Script):

 def ensure_airflow_version(revisions: Iterable[Script]):
     for rev in revisions:
-        assert rev.module.__file__ is not None  # For Mypy.
+        if TYPE_CHECKING:  # For mypy
+            assert rev.module.__file__ is not None
         file = Path(rev.module.__file__)
         content = file.read_text()
         if not has_version(content):