Avoid use of assert outside of the tests (#37718)
Taragolis committed Mar 11, 2024
1 parent ccc9bb5 commit c0b849a
Showing 18 changed files with 66 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1117,7 +1117,7 @@ repos:
airflow/example_dags/.*
args:
- "--skip"
- "B301,B324,B403,B404,B603"
- "B101,B301,B324,B403,B404,B603"
- "--severity-level"
- "high" # TODO: remove this line when we fix all the issues
- id: pylint
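Bandit's B101 ("assert used") is skipped here because the equivalent Ruff rule S101 is enabled in pyproject.toml further down and now covers the same pattern. The reason both rules exist: assert statements are removed entirely when Python runs with the -O flag, so they cannot act as runtime validation. A minimal sketch, not from the Airflow codebase:

    # Running this with "python -O" silently skips the assert,
    # while the explicit raise still fires.
    def set_region(region):
        assert region is not None, "region must be set"  # removed under -O
        return region

    def set_region_checked(region):
        if region is None:
            raise ValueError("region must be set")  # always enforced
        return region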
3 changes: 2 additions & 1 deletion airflow/cli/commands/task_command.py
@@ -212,7 +212,8 @@ def _run_task_by_selected_method(
- as raw task
- by executor
"""
assert not isinstance(ti, TaskInstancePydantic), "Wait for AIP-44 implementation to complete"
if TYPE_CHECKING:
assert not isinstance(ti, TaskInstancePydantic) # Wait for AIP-44 implementation to complete
if args.local:
return _run_task_by_local_task_job(args, ti)
if args.raw:
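The replacement keeps the assert for static analysis only: typing.TYPE_CHECKING is False at runtime, so the block never executes, while mypy treats it as True and narrows ti to TaskInstance afterwards. A self-contained sketch of the pattern (the class names are illustrative stand-ins):

    from __future__ import annotations

    from typing import TYPE_CHECKING, Union

    class TaskInstance: ...
    class TaskInstancePydantic: ...

    def run_task(ti: Union[TaskInstance, TaskInstancePydantic]) -> None:
        if TYPE_CHECKING:
            # never executed at runtime; mypy narrows ti to TaskInstance below
            assert not isinstance(ti, TaskInstancePydantic)
        print(type(ti).__name__)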
5 changes: 3 additions & 2 deletions airflow/models/skipmixin.py
@@ -23,7 +23,6 @@
from sqlalchemy import select, update

from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -35,6 +34,7 @@
from pendulum import DateTime
from sqlalchemy import Session

from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.models.taskmixin import DAGNode
from airflow.serialization.pydantic.dag_run import DagRunPydantic
@@ -197,7 +197,8 @@ def skip_all_except(
)

dag_run = ti.get_dagrun()
assert isinstance(dag_run, DagRun)
if TYPE_CHECKING:
assert isinstance(dag_run, DagRun)

# TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
# pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
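Because DagRun is now needed only where the type-checker looks (the guarded assert and annotations), its module-level import is dropped and re-added inside the if TYPE_CHECKING: block, removing the runtime import without losing type coverage. A generic sketch of the relocation, with a stand-in package name:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by type-checkers, never imported at runtime
        from mypackage.models.dagrun import DagRun

    def latest_dagrun(ti) -> DagRun:
        dag_run = ti.get_dagrun()
        if TYPE_CHECKING:
            assert isinstance(dag_run, DagRun)
        return dag_run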
3 changes: 2 additions & 1 deletion airflow/models/taskinstance.py
@@ -1025,7 +1025,8 @@ def _email_alert(
:meta private:
"""
subject, html_content, html_content_err = task_instance.get_email_subject_content(exception, task=task)
assert task.email
if TYPE_CHECKING:
assert task.email
try:
send_email(task.email, subject, html_content)
except Exception:
3 changes: 2 additions & 1 deletion airflow/models/xcom_arg.py
@@ -406,7 +406,8 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
from airflow.models.taskinstance import TaskInstance

ti = context["ti"]
assert isinstance(ti, TaskInstance), "Wait for AIP-44 implementation to complete"
if not isinstance(ti, TaskInstance):
raise NotImplementedError("Wait for AIP-44 implementation to complete")

task_id = self.operator.task_id
map_indexes = ti.get_relevant_upstream_map_indexes(
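Here the unsupported case can genuinely occur at runtime (a TaskInstancePydantic arriving before AIP-44 support is complete), so the assert becomes an explicit NotImplementedError rather than a TYPE_CHECKING-only check. A rough sketch of that decision, with a simplified stand-in class:

    from __future__ import annotations

    class TaskInstance:
        def xcom_pull(self, task_ids: str) -> object:
            return f"value from {task_ids}"

    def resolve(ti: object) -> object:
        # reachable in production, so it must not be an assert
        if not isinstance(ti, TaskInstance):
            raise NotImplementedError("Wait for AIP-44 implementation to complete")
        return ti.xcom_pull(task_ids="upstream")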
6 changes: 2 additions & 4 deletions airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
@@ -80,7 +80,7 @@ def __init__(
) -> None:
self.adb_spark_conn_id = adb_spark_conn_id
self.adb_spark_conn = self.get_connection(adb_spark_conn_id)
self.region = self.get_default_region() if region is None else region
self.region = region or self.get_default_region()
super().__init__(*args, **kwargs)

def submit_spark_app(
@@ -327,8 +327,6 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:

def get_adb_spark_client(self) -> Client:
"""Get valid AnalyticDB MySQL Spark client."""
assert self.region is not None

extra_config = self.adb_spark_conn.extra_dejson
auth_type = extra_config.get("auth_type", None)
if not auth_type:
@@ -352,7 +350,7 @@ def get_adb_spark_client(self) -> Client:
)
)

def get_default_region(self) -> str | None:
def get_default_region(self) -> str:
"""Get default region from connection."""
extra_config = self.adb_spark_conn.extra_dejson
auth_type = extra_config.get("auth_type", None)
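In the Alibaba hooks the asserts disappear by tightening the contract instead: get_default_region() now returns str (raising when no region can be derived) and the constructor falls back with `region or self.get_default_region()`, so self.region is always a str and call sites no longer need `assert self.region is not None`. Note that `or` also treats an empty string as missing, which the previous `if region is None` did not. A sketch under those assumptions, with stand-in names:

    from __future__ import annotations

    class SparkHookSketch:
        def __init__(self, region: str | None = None) -> None:
            # always a str afterwards; no later "assert self.region is not None" needed
            self.region: str = region or self.get_default_region()

        def get_default_region(self) -> str:
            region = self._connection_extra().get("region")
            if not region:
                raise ValueError("No region is specified for the connection")
            return region

        def _connection_extra(self) -> dict[str, str]:
            # stand-in for the connection's extra_dejson lookup
            return {"region": "cn-hangzhou"}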
5 changes: 2 additions & 3 deletions airflow/providers/alibaba/cloud/hooks/oss.py
@@ -87,7 +87,7 @@ class OSSHook(BaseHook):
def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args, **kwargs) -> None:
self.oss_conn_id = oss_conn_id
self.oss_conn = self.get_connection(oss_conn_id)
self.region = self.get_default_region() if region is None else region
self.region = region or self.get_default_region()
super().__init__(*args, **kwargs)

def get_conn(self) -> Connection:
@@ -137,7 +137,6 @@ def get_bucket(self, bucket_name: str | None = None) -> oss2.api.Bucket:
:return: the bucket object to the bucket name.
"""
auth = self.get_credential()
assert self.region is not None
return oss2.Bucket(auth, f"https://oss-{self.region}.aliyuncs.com", bucket_name)

@provide_bucket_name
@@ -353,7 +352,7 @@ def get_credential(self) -> oss2.auth.Auth:

return oss2.Auth(oss_access_key_id, oss_access_key_secret)

def get_default_region(self) -> str | None:
def get_default_region(self) -> str:
extra_config = self.oss_conn.extra_dejson
auth_type = extra_config.get("auth_type", None)
if not auth_type:
@@ -31,6 +31,7 @@

import json
from json import JSONDecodeError
from typing import TYPE_CHECKING

from airflow.configuration import conf
from airflow.providers.amazon.aws.executors.batch.utils import (
@@ -68,8 +69,9 @@ def build_submit_kwargs() -> dict:
if "eksPropertiesOverride" in job_kwargs:
raise KeyError("Eks jobs are not currently supported.")

if TYPE_CHECKING:
assert isinstance(job_kwargs, dict)
# some checks with some helpful errors
assert isinstance(job_kwargs, dict)
if "containerOverrides" not in job_kwargs or "command" not in job_kwargs["containerOverrides"]:
raise KeyError(
'SubmitJob API needs kwargs["containerOverrides"]["command"] field,'
14 changes: 8 additions & 6 deletions airflow/providers/fab/auth_manager/security_manager/override.py
@@ -727,8 +727,8 @@ def create_admin_standalone(self) -> tuple[str | None, str | None]:
# If the user does not exist, make a random password and make it
if not user_exists:
print(f"FlaskAppBuilder Authentication Manager: Creating {user_name} user")
role = self.find_role("Admin")
assert role is not None
if (role := self.find_role("Admin")) is None:
raise AirflowException("Unable to find role 'Admin'")
# password does not contain visually similar characters: ijlIJL1oO0
password = "".join(random.choices("abcdefghkmnpqrstuvwxyzABCDEFGHKMNPQRSTUVWXYZ23456789", k=16))
with open(password_path, "w") as file:
@@ -2445,8 +2445,9 @@ def _ldap_bind_indirect(self, ldap, con) -> None:
:param ldap: The ldap module reference
:param con: The ldap connection
"""
# always check AUTH_LDAP_BIND_USER is set before calling this method
assert self.auth_ldap_bind_user, "AUTH_LDAP_BIND_USER must be set"
if not self.auth_ldap_bind_user:
# always check AUTH_LDAP_BIND_USER is set before calling this method
raise ValueError("AUTH_LDAP_BIND_USER must be set")

try:
log.debug("LDAP bind indirect TRY with username: %r", self.auth_ldap_bind_user)
@@ -2465,8 +2466,9 @@ def _search_ldap(self, ldap, con, username):
:param username: username to match with AUTH_LDAP_UID_FIELD
:return: ldap object array
"""
# always check AUTH_LDAP_SEARCH is set before calling this method
assert self.auth_ldap_search, "AUTH_LDAP_SEARCH must be set"
if not self.auth_ldap_search:
# always check AUTH_LDAP_SEARCH is set before calling this method
raise ValueError("AUTH_LDAP_SEARCH must be set")

# build the filter string for the LDAP search
if self.auth_ldap_search_filter:
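In create_admin_standalone the lookup-then-assert becomes an assignment expression guarded by a real exception, so a missing Admin role fails loudly even with assertions disabled; the LDAP helpers get the same treatment with ValueError. A minimal sketch of the walrus-plus-raise shape (the exception type here is illustrative):

    def get_required_role(find_role, name: str):
        # assignment expression keeps the lookup and the None check together
        if (role := find_role(name)) is None:
            raise RuntimeError(f"Unable to find role {name!r}")
        return role

    roles = {"Admin": "admin-role"}
    print(get_required_role(roles.get, "Admin"))   # admin-role
    # get_required_role(roles.get, "Viewer") would raise RuntimeError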
5 changes: 3 additions & 2 deletions airflow/providers/microsoft/psrp/hooks/psrp.py
@@ -20,7 +20,7 @@
from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR, INFO, WARNING
from typing import Any, Callable, Generator
from typing import TYPE_CHECKING, Any, Callable, Generator
from warnings import warn
from weakref import WeakKeyDictionary

@@ -173,7 +173,8 @@ def invoke(self) -> Generator[PowerShell, None, None]:
if local_context:
self.__enter__()
try:
assert self._conn is not None
if TYPE_CHECKING:
assert self._conn is not None
ps = PowerShell(self._conn)
yield ps
ps.begin_invoke()
3 changes: 0 additions & 3 deletions airflow/providers/openlineage/extractors/base.py
@@ -78,9 +78,6 @@ def _is_operator_disabled(self) -> bool:
)
return fully_qualified_class_name in self.disabled_operators

def validate(self):
assert self.operator.task_type in self.get_operator_classnames()

@abstractmethod
def _execute_extraction(self) -> OperatorLineage | None:
...
6 changes: 2 additions & 4 deletions airflow/providers/yandex/secrets/lockbox.py
@@ -128,10 +128,8 @@ def __init__(
self.yc_connection_id = None
if not any([yc_oauth_token, yc_sa_key_json, yc_sa_key_json_path]):
self.yc_connection_id = yc_connection_id or default_conn_name
else:
assert (
yc_connection_id is None
), "yc_connection_id should not be used if other credentials are specified"
elif yc_connection_id is not None:
raise ValueError("`yc_connection_id` should not be used if other credentials are specified")

self.folder_id = folder_id
self.connections_prefix = connections_prefix.rstrip(sep) if connections_prefix is not None else None
5 changes: 3 additions & 2 deletions airflow/serialization/pydantic/job.py
@@ -16,7 +16,7 @@
# under the License.
import datetime
from functools import cached_property
from typing import Optional
from typing import TYPE_CHECKING, Optional

from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.base_job_runner import BaseJobRunner
@@ -53,7 +53,8 @@ def executor(self):
def heartrate(self) -> float:
from airflow.jobs.job import Job

assert self.job_type is not None
if TYPE_CHECKING:
assert self.job_type is not None
return Job._heartrate(self.job_type)

def is_alive(self, grace_multiplier=2.1) -> bool:
14 changes: 13 additions & 1 deletion dev/breeze/src/airflow_breeze/commands/minor_release_command.py
@@ -17,6 +17,7 @@
from __future__ import annotations

import os
import sys

import click

@@ -158,7 +159,18 @@ def create_constraints(version_branch):
@option_answer
def create_minor_version_branch(version_branch):
for obj in version_branch.split("-"):
assert isinstance(int(obj), int)
if not obj.isdigit():
console_print(f"[error]Failed: `version_branch` part {obj!r} is not a digit.")
sys.exit(1)
elif len(obj) > 1 and obj.startswith("0"):
# `01` is a valid digit string and converts to an integer cleanly,
# but it is most likely a typo (e.g. for 10), so better to stop here
console_print(
    f"[error]Found a leading zero in `version_branch` part {obj!r}; ",
    f"if it is not a typo, consider using {int(obj)} instead.",
)
sys.exit(1)

os.chdir(AIRFLOW_SOURCES_ROOT)
repo_root = os.getcwd()
console_print()
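The new loop replaces `assert isinstance(int(obj), int)` (which only proved that int() succeeded) with explicit checks: every dash-separated part of the branch name must be all digits, and leading zeros are rejected as probable typos. A standalone sketch of the same rules, collecting errors instead of exiting:

    def version_branch_errors(version_branch: str) -> list[str]:
        errors = []
        for part in version_branch.split("-"):
            if not part.isdigit():
                errors.append(f"part {part!r} is not a number")
            elif len(part) > 1 and part.startswith("0"):
                # "01" converts to int cleanly but is most likely a typo (e.g. for 10)
                errors.append(f"part {part!r} has a leading zero")
        return errors

    print(version_branch_errors("2-9"))    # []
    print(version_branch_errors("2-09"))   # ["part '09' has a leading zero"]
    print(version_branch_errors("v2-9"))   # ["part 'v2' is not a number"]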
12 changes: 6 additions & 6 deletions dev/perf/dags/elastic_dag.py
@@ -39,12 +39,12 @@ def parse_time_delta(time_str: str):
:param time_str: A string identifying a duration. (eg. 2h13m)
:return datetime.timedelta: A datetime.timedelta object or "@once"
"""
parts = RE_TIME_DELTA.match(time_str)

assert parts is not None, (
f"Could not parse any time information from '{time_str}'. "
f"Examples of valid strings: '8h', '2d8h5m20s', '2m4s'"
)
if (parts := RE_TIME_DELTA.match(time_str)) is None:
msg = (
f"Could not parse any time information from '{time_str}'. "
f"Examples of valid strings: '8h', '2d8h5m20s', '2m4s'"
)
raise ValueError(msg)

time_params = {name: float(param) for name, param in parts.groupdict().items() if param}
return timedelta(**time_params) # type: ignore
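parse_time_delta now raises ValueError when the regex does not match instead of asserting, again using an assignment expression. The RE_TIME_DELTA pattern itself sits outside the hunk, so the one below is only an assumed equivalent for strings such as '2d8h5m20s':

    from __future__ import annotations

    import re
    from datetime import timedelta

    # assumed pattern; the real RE_TIME_DELTA in elastic_dag.py may differ
    RE_TIME_DELTA = re.compile(
        r"^((?P<days>\d+)d)?((?P<hours>\d+)h)?((?P<minutes>\d+)m)?((?P<seconds>\d+)s)?$"
    )

    def parse_time_delta(time_str: str) -> timedelta:
        if (parts := RE_TIME_DELTA.match(time_str)) is None:
            raise ValueError(
                f"Could not parse any time information from '{time_str}'. "
                f"Examples of valid strings: '8h', '2d8h5m20s', '2m4s'"
            )
        time_params = {name: float(value) for name, value in parts.groupdict().items() if value}
        return timedelta(**time_params)

    print(parse_time_delta("2h13m"))   # 2:13:00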
14 changes: 8 additions & 6 deletions pyproject.toml
@@ -1330,6 +1330,7 @@ extend-select = [
"PGH004", # Use specific rule codes when using noqa
"PGH005", # Invalid unittest.mock.Mock methods/attributes/properties
"B006", # Checks for uses of mutable objects as function argument defaults.
"S101", # Checks for use of `assert` outside tests; test directories should be added to the per-file exclusions
]
ignore = [
"D203",
@@ -1384,13 +1385,14 @@ combine-as-imports = true
"*/example_dags/*" = ["D"]
"chart/*" = ["D"]
"dev/*" = ["D"]
# In addition ignore top level imports, e.g. pandas, numpy in tests
# In addition, ignore top level imports, e.g. pandas, numpy (TID253) and use of assert (S101) in tests
"dev/perf/*" = ["TID253"]
"dev/breeze/tests/*" = ["TID253"]
"tests/*" = ["D", "TID253"]
"docker_tests/*" = ["D", "TID253"]
"kubernetes_tests/*" = ["D", "TID253"]
"helm_tests/*" = ["D", "TID253"]
"dev/check_files.py" = ["S101"]
"dev/breeze/tests/*" = ["TID253", "S101"]
"tests/*" = ["D", "TID253", "S101"]
"docker_tests/*" = ["D", "TID253", "S101"]
"kubernetes_tests/*" = ["D", "TID253", "S101"]
"helm_tests/*" = ["D", "TID253", "S101"]

# All of the modules which have an extra license header (i.e. that we copy from another project) need to
# ignore E402 -- module level import not at top level
5 changes: 4 additions & 1 deletion scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
@@ -24,6 +24,7 @@
import itertools
import pathlib
import sys
from typing import TYPE_CHECKING

PACKAGE_ROOT = pathlib.Path(__file__).resolve().parents[3].joinpath("airflow")
DAG_PY = PACKAGE_ROOT.joinpath("models", "dag.py")
@@ -93,7 +94,9 @@ def _match_arguments(
if dec is None and ini is not None:
yield f"Argument present in {init_name} but missing from @{deco_name}: {ini.arg}"
return
assert ini is not None and dec is not None # Because None is only possible as fillvalue.

if TYPE_CHECKING:
assert ini is not None and dec is not None # Because None is only possible as fillvalue.

if ini.arg != dec.arg:
yield f"Argument {i + 1} mismatch: {init_name} has {ini.arg} but @{deco_name} has {dec.arg}"
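The trailing comment ("None is only possible as fillvalue") refers to an itertools.zip_longest pairing earlier in the function: once the two early returns have handled a one-sided None, both values are guaranteed to be real nodes, so the remaining assert exists purely to narrow types for mypy. A simplified sketch of that pairing, using strings instead of ast.arg nodes:

    from itertools import zip_longest

    init_args = ["dag_id", "schedule", "tags"]
    deco_args = ["dag_id", "schedule"]

    for i, (ini, dec) in enumerate(zip_longest(init_args, deco_args, fillvalue=None)):
        if ini is None:
            print(f"Argument present in @dag but missing from __init__: {dec}")
            continue
        if dec is None:
            print(f"Argument present in __init__ but missing from @dag: {ini}")
            continue
        # both sides are real values here; None could only have come from the fillvalue
        if ini != dec:
            print(f"Argument {i + 1} mismatch: __init__ has {ini} but @dag has {dec}")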
3 changes: 2 additions & 1 deletion scripts/in_container/run_migration_reference.py
@@ -107,7 +107,8 @@ def revision_suffix(rev: Script):

def ensure_airflow_version(revisions: Iterable[Script]):
for rev in revisions:
assert rev.module.__file__ is not None # For Mypy.
if TYPE_CHECKING: # For mypy
assert rev.module.__file__ is not None
file = Path(rev.module.__file__)
content = file.read_text()
if not has_version(content):
