Skip to content

Commit

Permalink
Improve speed to run airflow by 6x (#21438)
Browse files Browse the repository at this point in the history
By delaying expensive/slow imports to where they are needed, this gets
`airflow` printing its usage information in under 0.8s, down from almost
3s, which makes it feel much, much snappier.

By not loading BaseExecutor we can get down to <0.5s
  • Loading branch information
ashb committed Feb 9, 2022
1 parent 0a3ff43 commit 1a8a897
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 45 deletions.
9 changes: 4 additions & 5 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
import warnings
from typing import Any, Dict, List, NamedTuple, Optional, Sized

from airflow.api_connexion.exceptions import NotFound as ApiConnexionNotFound
from airflow.utils.code_utils import prepare_code_snippet
from airflow.utils.platform import is_tty


class AirflowException(Exception):
"""
Expand All @@ -44,7 +40,7 @@ class AirflowBadRequest(AirflowException):
status_code = 400


class AirflowNotFoundException(AirflowException, ApiConnexionNotFound):
class AirflowNotFoundException(AirflowException):
"""Raise when the requested object/resource is not available in the system."""

status_code = 404
Expand Down Expand Up @@ -249,6 +245,9 @@ def __init__(self, msg: str, file_path: str, parse_errors: List[FileSyntaxError]
self.parse_errors = parse_errors

def __str__(self):
from airflow.utils.code_utils import prepare_code_snippet
from airflow.utils.platform import is_tty

result = f"{self.msg}\nFilename: {self.file_path}\n\n"

for error_no, parse_error in enumerate(self.parse_errors, 1):
Expand Down
16 changes: 9 additions & 7 deletions airflow/executors/executor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import logging
from contextlib import suppress
from enum import Enum, unique
from typing import Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Tuple, Type

from airflow.exceptions import AirflowConfigException
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_constants import (
CELERY_EXECUTOR,
CELERY_KUBERNETES_EXECUTOR,
Expand All @@ -35,6 +34,9 @@

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from airflow.executors.base_executor import BaseExecutor


@unique
class ConnectorSource(Enum):
Expand All @@ -48,7 +50,7 @@ class ConnectorSource(Enum):
class ExecutorLoader:
"""Keeps constants for all the currently available executors."""

_default_executor: Optional[BaseExecutor] = None
_default_executor: Optional["BaseExecutor"] = None
executors = {
LOCAL_EXECUTOR: 'airflow.executors.local_executor.LocalExecutor',
SEQUENTIAL_EXECUTOR: 'airflow.executors.sequential_executor.SequentialExecutor',
Expand All @@ -60,7 +62,7 @@ class ExecutorLoader:
}

@classmethod
def get_default_executor(cls) -> BaseExecutor:
def get_default_executor(cls) -> "BaseExecutor":
"""Creates a new instance of the configured executor if none exists and returns it"""
if cls._default_executor is not None:
return cls._default_executor
Expand All @@ -74,7 +76,7 @@ def get_default_executor(cls) -> BaseExecutor:
return cls._default_executor

@classmethod
def load_executor(cls, executor_name: str) -> BaseExecutor:
def load_executor(cls, executor_name: str) -> "BaseExecutor":
"""
Loads the executor.
Expand All @@ -101,7 +103,7 @@ def load_executor(cls, executor_name: str) -> BaseExecutor:
return executor_cls()

@classmethod
def import_executor_cls(cls, executor_name: str) -> Tuple[Type[BaseExecutor], ConnectorSource]:
def import_executor_cls(cls, executor_name: str) -> Tuple[Type["BaseExecutor"], ConnectorSource]:
"""
Imports the executor class.
Expand All @@ -127,7 +129,7 @@ def import_executor_cls(cls, executor_name: str) -> Tuple[Type[BaseExecutor], Co
return import_string(executor_name), ConnectorSource.CUSTOM_PATH

@classmethod
def __load_celery_kubernetes_executor(cls) -> BaseExecutor:
def __load_celery_kubernetes_executor(cls) -> "BaseExecutor":
""":return: an instance of CeleryKubernetesExecutor"""
celery_executor = import_string(cls.executors[CELERY_EXECUTOR])()
kubernetes_executor = import_string(cls.executors[KUBERNETES_EXECUTOR])()
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@
import warnings
from typing import Any, Dict, ItemsView, MutableMapping, Optional, ValuesView

import jsonschema
from jsonschema import FormatChecker
from jsonschema.exceptions import ValidationError

from airflow.exceptions import AirflowException, ParamValidationError
from airflow.utils.context import Context
from airflow.utils.types import NOTSET, ArgNotSet
Expand Down Expand Up @@ -61,6 +57,10 @@ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
:param suppress_exception: To raise an exception or not when the validations fails.
If true and validations fails, the return value would be None.
"""
import jsonschema
from jsonschema import FormatChecker
from jsonschema.exceptions import ValidationError

try:
json.dumps(value)
except Exception:
Expand Down
12 changes: 4 additions & 8 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,6 @@
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.timeout import timeout

try:
from kubernetes.client.api_client import ApiClient

from airflow.kubernetes.kube_config import KubeConfig
from airflow.kubernetes.pod_generator import PodGenerator
except ImportError:
ApiClient = None

TR = TaskReschedule

_CURRENT_CONTEXT: List[Context] = []
Expand Down Expand Up @@ -2032,7 +2024,11 @@ def render_templates(self, context: Optional[Context] = None) -> None:

def render_k8s_pod_yaml(self) -> Optional[dict]:
"""Render k8s pod yaml"""
from kubernetes.client.api_client import ApiClient

from airflow.kubernetes.kube_config import KubeConfig
from airflow.kubernetes.kubernetes_helper_functions import create_pod_id # Circular import
from airflow.kubernetes.pod_generator import PodGenerator

kube_config = KubeConfig()
pod = PodGenerator.construct_pod(
Expand Down
44 changes: 32 additions & 12 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,17 @@
from airflow.utils.module_loading import as_importable_string, import_string
from airflow.utils.task_group import MappedTaskGroup, TaskGroup

try:
# isort: off
from kubernetes.client import models as k8s
from airflow.kubernetes.pod_generator import PodGenerator

# isort: on
HAS_KUBERNETES = True
except ImportError:
HAS_KUBERNETES = False

if TYPE_CHECKING:
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep

HAS_KUBERNETES: bool
try:
from kubernetes.client import models as k8s

from airflow.kubernetes.pod_generator import PodGenerator
except ImportError:
pass

log = logging.getLogger(__name__)

_OPERATOR_EXTRA_LINKS: Set[str] = {
Expand Down Expand Up @@ -313,7 +311,7 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r
return cls._encode({str(k): cls._serialize(v) for k, v in var.items()}, type_=DAT.DICT)
elif isinstance(var, list):
return [cls._serialize(v) for v in var]
elif HAS_KUBERNETES and isinstance(var, k8s.V1Pod):
elif _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, DAG):
Expand Down Expand Up @@ -374,7 +372,7 @@ def _deserialize(cls, encoded_var: Any) -> Any:
elif type_ == DAT.DATETIME:
return pendulum.from_timestamp(var)
elif type_ == DAT.POD:
if not HAS_KUBERNETES:
if not _has_kubernetes():
raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
pod = PodGenerator.deserialize_model_dict(var)
return pod
Expand Down Expand Up @@ -1120,3 +1118,25 @@ class DagDependency:
def node_id(self):
"""Node ID for graph rendering"""
return f"{self.dependency_type}:{self.source}:{self.target}:{self.dependency_id}"


def _has_kubernetes() -> bool:
global HAS_KUBERNETES
if "HAS_KUBERNETES" in globals():
return HAS_KUBERNETES

# Loading kube modules is expensive, so delay it until the last moment

try:
from kubernetes.client import models as k8s

from airflow.kubernetes.pod_generator import PodGenerator

globals()['k8s'] = k8s
globals()['PodGenerator'] = PodGenerator

# isort: on
HAS_KUBERNETES = True
except ImportError:
HAS_KUBERNETES = False
return HAS_KUBERNETES
3 changes: 2 additions & 1 deletion airflow/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.utils import cli_action_loggers
from airflow.utils.db import check_and_run_migrations, synchronize_log_template
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
from airflow.utils.platform import getuser, is_terminal_support_colors
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -93,6 +92,8 @@ def wrapper(*args, **kwargs):
try:
# Check and run migrations if necessary
if check_db:
from airflow.utils.db import check_and_run_migrations, synchronize_log_template

check_and_run_migrations()
synchronize_log_template()
return f(*args, **kwargs)
Expand Down
18 changes: 11 additions & 7 deletions airflow/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@
)
from urllib import parse

import flask
import jinja2
import jinja2.nativetypes

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
import jinja2

from airflow.models import TaskInstance

KEY_REGEX = re.compile(r'^[\w.-]+$')
Expand Down Expand Up @@ -171,8 +169,10 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> List[T]:
return [e for i in iterable for e in i]


def parse_template_string(template_string: str) -> Tuple[Optional[str], Optional[jinja2.Template]]:
def parse_template_string(template_string: str) -> Tuple[Optional[str], Optional["jinja2.Template"]]:
"""Parses Jinja template string."""
import jinja2

if "{{" in template_string: # jinja mode
return None, jinja2.Template(template_string)
else:
Expand Down Expand Up @@ -255,6 +255,8 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
For example:
'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
"""
import flask

view = conf.get('webserver', 'dag_default_view').lower()
url = flask.url_for(f"Airflow.{view}")
return f"{url}?{parse.urlencode(query)}"
Expand Down Expand Up @@ -285,16 +287,18 @@ def render_template(template: Any, context: MutableMapping[str, Any], *, native:
except Exception:
env.handle_exception() # Rewrite traceback to point to the template.
if native:
import jinja2.nativetypes

return jinja2.nativetypes.native_concat(nodes)
return "".join(nodes)


def render_template_to_string(template: jinja2.Template, context: MutableMapping[str, Any]) -> str:
def render_template_to_string(template: "jinja2.Template", context: MutableMapping[str, Any]) -> str:
"""Shorthand to ``render_template(native=False)`` with better typing support."""
return render_template(template, context, native=False)


def render_template_as_native(template: jinja2.Template, context: MutableMapping[str, Any]) -> Any:
def render_template_as_native(template: "jinja2.Template", context: MutableMapping[str, Any]) -> Any:
"""Shorthand to ``render_template(native=True)`` with better typing support."""
return render_template(template, context, native=True)

Expand Down
3 changes: 2 additions & 1 deletion airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import httpx
from itsdangerous import TimedJSONWebSignatureSerializer

from airflow.configuration import AirflowConfigException, conf
Expand Down Expand Up @@ -159,6 +158,8 @@ def _read(self, ti, try_number, metadata=None):
except Exception as f:
log += f'*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n'
else:
import httpx

url = os.path.join("http://{ti.hostname}:{worker_log_server_port}/log", log_relative_path).format(
ti=ti, worker_log_server_port=conf.get('logging', 'WORKER_LOG_SERVER_PORT')
)
Expand Down

0 comments on commit 1a8a897

Please sign in to comment.