diff --git a/airflow/api/client/__init__.py b/airflow/api/client/__init__.py
index b7f9c78d30a49..49224b5336a32 100644
--- a/airflow/api/client/__init__.py
+++ b/airflow/api/client/__init__.py
@@ -26,7 +26,7 @@ def get_current_api_client() -> Client:
     """Return current API Client based on current Airflow configuration"""
-    api_module = import_module(conf.get('cli', 'api_client'))  # type: Any
+    api_module = import_module(conf.get_mandatory_value('cli', 'api_client'))  # type: Any
     auth_backends = api.load_auth()
     session = None
     for backend in auth_backends:
diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index 14fa529991e9b..b2752c2be7c25 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -19,7 +19,7 @@
 import os
 from pathlib import Path
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
 from urllib.parse import urlparse
 
 from airflow.configuration import conf
@@ -29,30 +29,32 @@
 # in this file instead of from airflow.cfg. Currently
 # there are other log format and level configurations in
 # settings.py and cli.py. Please see AIRFLOW-1455.
-LOG_LEVEL: str = conf.get('logging', 'LOGGING_LEVEL').upper()
+LOG_LEVEL: str = conf.get_mandatory_value('logging', 'LOGGING_LEVEL').upper()
 
 # Flask appbuilder's info level log is very verbose,
 # so it's set to 'WARN' by default.
-FAB_LOG_LEVEL: str = conf.get('logging', 'FAB_LOGGING_LEVEL').upper()
+FAB_LOG_LEVEL: str = conf.get_mandatory_value('logging', 'FAB_LOGGING_LEVEL').upper()
 
-LOG_FORMAT: str = conf.get('logging', 'LOG_FORMAT')
+LOG_FORMAT: str = conf.get_mandatory_value('logging', 'LOG_FORMAT')
 
-COLORED_LOG_FORMAT: str = conf.get('logging', 'COLORED_LOG_FORMAT')
+COLORED_LOG_FORMAT: str = conf.get_mandatory_value('logging', 'COLORED_LOG_FORMAT')
 
 COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG')
 
-COLORED_FORMATTER_CLASS: str = conf.get('logging', 'COLORED_FORMATTER_CLASS')
+COLORED_FORMATTER_CLASS: str = conf.get_mandatory_value('logging', 'COLORED_FORMATTER_CLASS')
 
-BASE_LOG_FOLDER: str = conf.get('logging', 'BASE_LOG_FOLDER')
+BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER')
 
-PROCESSOR_LOG_FOLDER: str = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')
+PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')
 
-DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')
+DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get_mandatory_value(
+    'logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION'
+)
 
-FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_FILENAME_TEMPLATE')
+FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_FILENAME_TEMPLATE')
 
-PROCESSOR_FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE')
+PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE')
 
 DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
     'version': 1,
@@ -116,7 +118,7 @@
     },
 }
 
-EXTRA_LOGGER_NAMES: str = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
+EXTRA_LOGGER_NAMES: Optional[str] = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
 if EXTRA_LOGGER_NAMES:
     new_loggers = {
         logger_name.strip(): {
@@ -171,7 +173,7 @@
 
 if REMOTE_LOGGING:
-    ELASTICSEARCH_HOST: str = conf.get('elasticsearch', 'HOST')
+    ELASTICSEARCH_HOST: Optional[str] = conf.get('elasticsearch', 'HOST')
 
     # Storage bucket URL for remote logging
     # S3 buckets should start with "s3://"
@@ -179,7 +181,7 @@
     # GCS buckets should start with "gs://"
     # WASB buckets should start with "wasb"
     # just to help Airflow select correct handler
-    REMOTE_BASE_LOG_FOLDER: str = conf.get('logging', 'REMOTE_BASE_LOG_FOLDER')
+    REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'REMOTE_BASE_LOG_FOLDER')
 
     if REMOTE_BASE_LOG_FOLDER.startswith('s3://'):
         S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
@@ -207,7 +209,7 @@
         DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS)
     elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'):
-        key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
+        key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None)
         GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
             'task': {
                 'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler',
@@ -235,7 +237,7 @@
         DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
     elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
-        key_path = conf.get('logging', 'GOOGLE_KEY_PATH', fallback=None)
+        key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None)
         # stackdriver:///airflow-tasks => airflow-tasks
         log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
         STACKDRIVER_REMOTE_HANDLERS = {
@@ -260,14 +262,14 @@
         }
         DEFAULT_LOGGING_CONFIG['handlers'].update(OSS_REMOTE_HANDLERS)
     elif ELASTICSEARCH_HOST:
-        ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE')
-        ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get('elasticsearch', 'END_OF_LOG_MARK')
-        ELASTICSEARCH_FRONTEND: str = conf.get('elasticsearch', 'frontend')
+        ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value('elasticsearch', 'LOG_ID_TEMPLATE')
+        ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value('elasticsearch', 'END_OF_LOG_MARK')
+        ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value('elasticsearch', 'frontend')
         ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT')
         ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT')
-        ELASTICSEARCH_JSON_FIELDS: str = conf.get('elasticsearch', 'JSON_FIELDS')
-        ELASTICSEARCH_HOST_FIELD: str = conf.get('elasticsearch', 'HOST_FIELD')
-        ELASTICSEARCH_OFFSET_FIELD: str = conf.get('elasticsearch', 'OFFSET_FIELD')
+        ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value('elasticsearch', 'JSON_FIELDS')
+        ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value('elasticsearch', 'HOST_FIELD')
+        ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value('elasticsearch', 'OFFSET_FIELD')
 
         ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = {
             'task': {
diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py
index bbc8aa6a7faee..9d81c6353fba2 100644
--- a/airflow/config_templates/default_celery.py
+++ b/airflow/config_templates/default_celery.py
@@ -59,14 +59,14 @@ def _broker_supports_visibility_timeout(url):
 
 try:
     if celery_ssl_active:
-        if 'amqp://' in broker_url:
+        if broker_url and 'amqp://' in broker_url:
             broker_use_ssl = {
                 'keyfile': conf.get('celery', 'SSL_KEY'),
                 'certfile': conf.get('celery', 'SSL_CERT'),
                 'ca_certs': conf.get('celery', 'SSL_CACERT'),
                 'cert_reqs': ssl.CERT_REQUIRED,
             }
-        elif 'redis://' in broker_url:
+        elif broker_url and 'redis://' in broker_url:
             broker_use_ssl = {
                 'ssl_keyfile': conf.get('celery', 'SSL_KEY'),
                 'ssl_certfile': conf.get('celery', 'SSL_CERT'),
@@ -92,7 +92,7 @@ def _broker_supports_visibility_timeout(url):
         f'all necessary certs and key ({e}).'
     )
 
-result_backend = DEFAULT_CELERY_CONFIG['result_backend']
+result_backend = str(DEFAULT_CELERY_CONFIG['result_backend'])
 if 'amqp://' in result_backend or 'redis://' in result_backend or 'rpc://' in result_backend:
     log.warning(
         "You have configured a result_backend of %s, it is highly recommended "
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 2b9bf6a6d23e5..91b8a4e94e512 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import datetime
 import functools
 import json
@@ -34,9 +33,12 @@
 from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError  # type: ignore
 from contextlib import suppress
 from json.decoder import JSONDecodeError
-from typing import Any, Dict, List, Optional, Tuple, Union
+from re import Pattern
+from typing import IO, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 from urllib.parse import urlparse
 
+from typing_extensions import overload
+
 from airflow.exceptions import AirflowConfigException
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
 from airflow.utils import yaml
@@ -52,6 +54,11 @@
 
 _SQLITE3_VERSION_PATTERN = re.compile(r"(?P<version>^\d+(?:\.\d+)*)\D?.*$")
 
+ConfigType = Union[str, int, float, bool]
+ConfigOptionsDictType = Dict[str, ConfigType]
+ConfigSectionSourcesType = Dict[str, Union[str, Tuple[str, str]]]
+ConfigSourcesType = Dict[str, ConfigSectionSourcesType]
+
 
 def _parse_sqlite_version(s: str) -> Tuple[int, ...]:
     match = _SQLITE3_VERSION_PATTERN.match(s)
@@ -60,7 +67,17 @@ def _parse_sqlite_version(s: str) -> Tuple[int, ...]:
     return tuple(int(p) for p in match.group("version").split("."))
 
 
-def expand_env_var(env_var):
+@overload
+def expand_env_var(env_var: None) -> None:
+    ...
+
+
+@overload
+def expand_env_var(env_var: str) -> str:
+    ...
+
+
+def expand_env_var(env_var: Optional[str]) -> Optional[str]:
     """
     Expands (potentially nested) env vars by
     repeatedly applying `expandvars` and `expanduser` until
     interpolation stops having
@@ -76,7 +93,7 @@
             env_var = interpolated
 
 
-def run_command(command):
+def run_command(command: str) -> str:
     """Runs command and returns stdout"""
     process = subprocess.Popen(
         shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True
@@ -92,7 +109,7 @@ def run_command(command):
     return output
 
 
-def _get_config_value_from_secret_backend(config_key):
+def _get_config_value_from_secret_backend(config_key: str) -> Optional[str]:
     """Get Config option values from Secret Backend"""
     try:
         secrets_client = get_custom_secret_backend()
@@ -108,7 +125,7 @@
     )
 
 
-def _default_config_file_path(file_name: str):
+def _default_config_file_path(file_name: str) -> str:
     templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
     return os.path.join(templates_dir, file_name)
 
@@ -131,7 +148,7 @@ class AirflowConfigParser(ConfigParser):
     # is to not store password on boxes in text files.
     # These configs can also be fetched from Secrets backend
     # following the "{section}__{name}__secret" pattern
-    sensitive_config_values = {
+    sensitive_config_values: Set[Tuple[str, str]] = {
         ('database', 'sql_alchemy_conn'),
         ('core', 'fernet_key'),
         ('celery', 'broker_url'),
@@ -147,7 +164,7 @@ class AirflowConfigParser(ConfigParser):
     # A mapping of (new section, new option) -> (old section, old option, since_version).
     # When reading new option, the old option will be checked to see if it exists. If it does a
     # DeprecationWarning will be issued and the old option will be used instead
-    deprecated_options = {
+    deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]] = {
         ('celery', 'worker_precheck'): ('core', 'worker_precheck', '2.0.0'),
         ('logging', 'base_log_folder'): ('core', 'base_log_folder', '2.0.0'),
         ('logging', 'remote_logging'): ('core', 'remote_logging', '2.0.0'),
@@ -206,7 +223,7 @@ class AirflowConfigParser(ConfigParser):
     # A mapping of old default values that we want to change and warn the user
     # about. Mapping of section -> setting -> { old, replace, by_version }
-    deprecated_values = {
+    deprecated_values: Dict[str, Dict[str, Tuple[Pattern, str, str]]] = {
         'core': {
             'hostname_callable': (re.compile(r':'), r'.', '2.1'),
         },
@@ -225,7 +242,7 @@
             'log_filename_template': (
                 re.compile(re.escape("{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log")),
                 "XX-set-after-default-config-loaded-XX",
-                3.0,
+                '3.0',
             ),
         },
         'api': {
@@ -239,7 +256,7 @@
             'log_id_template': (
                 re.compile('^' + re.escape('{dag_id}-{task_id}-{run_id}-{try_number}') + '$'),
                 '{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}',
-                3.0,
+                '3.0',
             )
         },
     }
@@ -260,12 +277,12 @@ class AirflowConfigParser(ConfigParser):
     """Mapping of (section,option) to the old value that was upgraded"""
 
     # This method transforms option names on every read, get, or set operation.
-    # This changes from the default behaviour of ConfigParser from lowercasing
+    # This changes from the default behaviour of ConfigParser from lower-casing
     # to instead be case-preserving
     def optionxform(self, optionstr: str) -> str:
         return optionstr
 
-    def __init__(self, default_config=None, *args, **kwargs):
+    def __init__(self, default_config: Optional[str] = None, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         self.upgraded_values = {}
@@ -392,7 +409,7 @@ def _validate_config_dependencies(self):
         from airflow.utils.docs import get_docs_url
 
-        # Some of the features in storing rendered fields require sqlite version >= 3.15.0
+        # Some features in storing rendered fields require sqlite version >= 3.15.0
         min_sqlite_version = (3, 15, 0)
         if _parse_sqlite_version(sqlite3.sqlite_version) < min_sqlite_version:
             min_sqlite_version_str = ".".join(str(s) for s in min_sqlite_version)
@@ -401,16 +418,16 @@ def _validate_config_dependencies(self):
                 f"See {get_docs_url('howto/set-up-database.html#setting-up-a-sqlite-database')}"
             )
 
-    def _using_old_value(self, old, current_value):
+    def _using_old_value(self, old: Pattern, current_value: str) -> bool:
         return old.search(current_value) is not None
 
-    def _update_env_var(self, section, name, new_value):
+    def _update_env_var(self, section: str, name: str, new_value: str):
         env_var = self._env_var_name(section, name)
         # Set it as an env var so that any subprocesses keep the same override!
         os.environ[env_var] = new_value
 
     @staticmethod
-    def _create_future_warning(name, section, current_value, new_value, version):
+    def _create_future_warning(name: str, section: str, current_value: Any, new_value: Any, version: str):
         warnings.warn(
             f'The {name!r} setting in [{section}] has the old default value of {current_value!r}. '
             f'This value has been changed to {new_value!r} in the running config, but '
@@ -423,7 +440,7 @@ def _create_future_warning(name, section, current_value, new_value, version):
     def _env_var_name(self, section: str, key: str) -> str:
         return f'{self.ENV_VAR_PREFIX}{section.upper()}__{key.upper()}'
 
-    def _get_env_var_option(self, section, key):
+    def _get_env_var_option(self, section: str, key: str):
         # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
         env_var = self._env_var_name(section, key)
         if env_var in os.environ:
@@ -442,7 +459,7 @@ def _get_env_var_option(self, section, key):
             return _get_config_value_from_secret_backend(os.environ[env_var_secret_path])
         return None
 
-    def _get_cmd_option(self, section, key):
+    def _get_cmd_option(self, section: str, key: str):
         fallback_key = key + '_cmd'
         # if this is a valid command key...
         if (section, key) in self.sensitive_config_values:
@@ -451,7 +468,7 @@ def _get_cmd_option(self, section, key):
             return run_command(command)
         return None
 
-    def _get_secret_option(self, section, key):
+    def _get_secret_option(self, section: str, key: str) -> Optional[str]:
         """Get Config option values from Secret Backend"""
         fallback_key = key + '_secret'
         # if this is a valid secret key...
@@ -461,7 +478,13 @@ def _get_secret_option(self, section, key):
             return _get_config_value_from_secret_backend(secrets_path)
         return None
 
-    def get(self, section, key, **kwargs):
+    def get_mandatory_value(self, section: str, key: str, **kwargs) -> str:
+        value = self.get(section, key, **kwargs)
+        if value is None:
+            raise ValueError(f"The value {section}/{key} should be set!")
+        return value
+
+    def get(self, section: str, key: str, **kwargs) -> Optional[str]:  # type: ignore[override]
         section = str(section).lower()
         key = str(key).lower()
 
@@ -487,7 +510,7 @@ def get(self, section, key, **kwargs):
 
         return self._get_option_from_default_config(section, key, **kwargs)
 
-    def _get_option_from_default_config(self, section, key, **kwargs):
+    def _get_option_from_default_config(self, section: str, key: str, **kwargs) -> Optional[str]:
         # ...then the default config
         if self.airflow_defaults.has_option(section, key) or 'fallback' in kwargs:
             return expand_env_var(self.airflow_defaults.get(section, key, **kwargs))
@@ -497,55 +520,68 @@ def _get_option_from_default_config(self, section, key, **kwargs):
 
         raise AirflowConfigException(f"section/key [{section}/{key}] not found in config")
 
-    def _get_option_from_secrets(self, deprecated_key, deprecated_section, key, section):
+    def _get_option_from_secrets(
+        self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str
+    ) -> Optional[str]:
         # ...then from secret backends
         option = self._get_secret_option(section, key)
         if option:
             return option
-        if deprecated_section:
+        if deprecated_section and deprecated_key:
             option = self._get_secret_option(deprecated_section, deprecated_key)
             if option:
                 self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                 return option
         return None
 
-    def _get_option_from_commands(self, deprecated_key, deprecated_section, key, section):
+    def _get_option_from_commands(
+        self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str
+    ) -> Optional[str]:
         # ...then commands
         option = self._get_cmd_option(section, key)
         if option:
             return option
-        if deprecated_section:
+        if deprecated_section and deprecated_key:
             option = self._get_cmd_option(deprecated_section, deprecated_key)
             if option:
                 self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                 return option
         return None
 
-    def _get_option_from_config_file(self, deprecated_key, deprecated_section, key, kwargs, section):
+    def _get_option_from_config_file(
+        self,
+        deprecated_key: Optional[str],
+        deprecated_section: Optional[str],
+        key: str,
+        kwargs: Dict[str, Any],
+        section: str,
+    ) -> Optional[str]:
         # ...then the config file
         if super().has_option(section, key):
             # Use the parent's methods to get the actual config here to be able to
             # separate the config from default config.
             return expand_env_var(super().get(section, key, **kwargs))
-        if deprecated_section:
+        if deprecated_section and deprecated_key:
             if super().has_option(deprecated_section, deprecated_key):
                 self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                 return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs))
         return None
 
-    def _get_environment_variables(self, deprecated_key, deprecated_section, key, section):
+    def _get_environment_variables(
+        self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str
+    ) -> Optional[str]:
         # first check environment variables
         option = self._get_env_var_option(section, key)
         if option is not None:
             return option
-        if deprecated_section:
+        if deprecated_section and deprecated_key:
             option = self._get_env_var_option(deprecated_section, deprecated_key)
             if option is not None:
                 self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                 return option
         return None
 
-    def getboolean(self, section, key, **kwargs):
+    def getboolean(self, section: str, key: str, **kwargs) -> bool:  # type: ignore[override]
         val = str(self.get(section, key, **kwargs)).lower().strip()
         if '#' in val:
             val = val.split('#')[0].strip()
@@ -559,9 +595,13 @@ def getboolean(self, section, key, **kwargs):
             f'Current value: "{val}".'
         )
 
-    def getint(self, section, key, **kwargs):
+    def getint(self, section: str, key: str, **kwargs) -> int:  # type: ignore[override]
         val = self.get(section, key, **kwargs)
-
+        if val is None:
+            raise AirflowConfigException(
+                f'Failed to convert value None to int. '
+                f'Please check "{key}" key in "{section}" section is set.'
+            )
         try:
             return int(val)
         except ValueError:
@@ -570,9 +610,13 @@ def getint(self, section, key, **kwargs):
             f'Current value: "{val}".'
         )
 
-    def getfloat(self, section, key, **kwargs):
+    def getfloat(self, section: str, key: str, **kwargs) -> float:  # type: ignore[override]
         val = self.get(section, key, **kwargs)
-
+        if val is None:
+            raise AirflowConfigException(
+                f'Failed to convert value None to float. '
+                f'Please check "{key}" key in "{section}" section is set.'
+            )
         try:
             return float(val)
         except ValueError:
@@ -581,7 +625,7 @@ def getfloat(self, section, key, **kwargs):
             f'Current value: "{val}".'
         )
 
-    def getimport(self, section, key, **kwargs):
+    def getimport(self, section: str, key: str, **kwargs) -> Any:
         """
         Reads options, imports the full qualified name, and returns the object.
 
@@ -602,7 +646,9 @@ def getimport(self, section, key, **kwargs):
             f'Current value: "{full_qualified_path}".'
         )
 
-    def getjson(self, section, key, fallback=_UNSET, **kwargs) -> Union[dict, list, str, int, float, None]:
+    def getjson(
+        self, section: str, key: str, fallback=_UNSET, **kwargs
+    ) -> Union[dict, list, str, int, float, None]:
         """
         Return a config value parsed from a JSON string.
 
@@ -620,7 +666,7 @@ def getjson(self, section, key, fallback=_UNSET, **kwargs) -> Union[dict, list,
         except (NoSectionError, NoOptionError):
             return default
 
-        if len(data) == 0:
+        if not data:
             return default if default is not _UNSET else None
 
         try:
@@ -628,7 +674,9 @@ def getjson(self, section, key, fallback=_UNSET, **kwargs) -> Union[dict, list,
         except JSONDecodeError as e:
             raise AirflowConfigException(f'Unable to parse [{section}] {key!r} as valid json') from e
 
-    def gettimedelta(self, section, key, fallback=None, **kwargs) -> Optional[datetime.timedelta]:
+    def gettimedelta(
+        self, section: str, key: str, fallback: Any = None, **kwargs
+    ) -> Optional[datetime.timedelta]:
         """
         Gets the config value for the given section and key, and converts it into datetime.timedelta object.
         If the key is missing, then it is considered as `None`.
@@ -662,13 +710,26 @@ def gettimedelta(self, section, key, fallback=None, **kwargs) -> Optional[dateti
 
         return fallback
 
-    def read(self, filenames, encoding=None):
+    def read(
+        self,
+        filenames: Union[
+            str,
+            bytes,
+            os.PathLike,
+            Iterable[Union[str, bytes, os.PathLike]],
+        ],
+        encoding=None,
+    ):
         super().read(filenames=filenames, encoding=encoding)
 
-    def read_dict(self, dictionary, source=''):
+    # RawConfigParser declares "dictionary" as Mapping from collections.abc, which is not
+    # subscriptable - so we have to use Dict here.
+    def read_dict(  # type: ignore[override]
+        self, dictionary: Dict[str, Dict[str, Any]], source: str = ''
+    ):
         super().read_dict(dictionary=dictionary, source=source)
 
-    def has_option(self, section, option):
+    def has_option(self, section: str, option: str) -> bool:
         try:
             # Using self.get() to avoid reimplementing the priority order
             # of config variables (env, config, cmd, defaults)
@@ -678,7 +739,7 @@ def has_option(self, section, option):
         except (NoOptionError, NoSectionError):
             return False
 
-    def remove_option(self, section, option, remove_default=True):
+    def remove_option(self, section: str, option: str, remove_default: bool = True):
         """
         Remove an option if it exists in config from a file or
         default config. If both of config have the same option, this removes
@@ -690,7 +751,7 @@ def remove_option(self, section, option, remove_default=True):
         if self.airflow_defaults.has_option(section, option) and remove_default:
             self.airflow_defaults.remove_option(section, option)
 
-    def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float, bool]]]:
+    def getsection(self, section: str) -> Optional[ConfigOptionsDictType]:
         """
         Returns the section as a dict. Values are converted to int, float, bool as required.
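
Review note: a quick sketch of how the stricter getter pattern above behaves, using a bare `ConfigParser` subclass rather than the full `AirflowConfigParser` (the `StrictParser` name and the toy config values are illustrative only, not part of this patch):

```python
from configparser import ConfigParser
from typing import Optional


class StrictParser(ConfigParser):
    """Toy parser mirroring the get()/get_mandatory_value() split in this patch."""

    def get(self, section, key, **kwargs) -> Optional[str]:  # type: ignore[override]
        # fallback=None turns a missing option into None instead of an exception
        return super().get(section, key, fallback=kwargs.get('fallback'))

    def get_mandatory_value(self, section: str, key: str, **kwargs) -> str:
        value = self.get(section, key, **kwargs)
        if value is None:
            raise ValueError(f"The value {section}/{key} should be set!")
        return value


parser = StrictParser()
parser.read_dict({'logging': {'base_log_folder': '/tmp/logs'}})
assert parser.get_mandatory_value('logging', 'base_log_folder') == '/tmp/logs'
assert parser.get('logging', 'missing', fallback=None) is None  # typed Optional[str]
```

mypy can then treat `get_mandatory_value` results as plain `str`, which is why most call sites in this diff switch to it while `get` is re-typed as `Optional[str]`.
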
@@ -700,9 +761,8 @@ def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float,
         """
         if not self.has_section(section) and not self.airflow_defaults.has_section(section):
             return None
-
         if self.airflow_defaults.has_section(section):
-            _section = OrderedDict(self.airflow_defaults.items(section))
+            _section: ConfigOptionsDictType = OrderedDict(self.airflow_defaults.items(section))
         else:
             _section = OrderedDict()
 
@@ -719,40 +779,48 @@ def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float,
             _section[key] = self._get_env_var_option(section, key)
 
         for key, val in _section.items():
+            if val is None:
+                raise AirflowConfigException(
+                    f'Failed to convert value automatically. '
+                    f'Please check "{key}" key in "{section}" section is set.'
+                )
             try:
-                val = int(val)
+                _section[key] = int(val)
             except ValueError:
                 try:
-                    val = float(val)
+                    _section[key] = float(val)
                 except ValueError:
-                    if val.lower() in ('t', 'true'):
-                        val = True
-                    elif val.lower() in ('f', 'false'):
-                        val = False
-            _section[key] = val
+                    if isinstance(val, str) and val.lower() in ('t', 'true'):
+                        _section[key] = True
+                    elif isinstance(val, str) and val.lower() in ('f', 'false'):
+                        _section[key] = False
         return _section
 
-    def write(self, fp, space_around_delimiters=True):
+    def write(self, fp: IO, space_around_delimiters: bool = True):  # type: ignore[override]
         # This is based on the configparser.RawConfigParser.write method code to add support for
         # reading options from environment variables.
+        # Various type ignores below deal with less-than-perfect RawConfigParser superclass typing
         if space_around_delimiters:
-            delimiter = f" {self._delimiters[0]} "
+            delimiter = f" {self._delimiters[0]} "  # type: ignore[attr-defined]
         else:
-            delimiter = self._delimiters[0]
-        if self._defaults:
-            self._write_section(fp, self.default_section, self._defaults.items(), delimiter)
-        for section in self._sections:
-            self._write_section(fp, section, self.getsection(section).items(), delimiter)
+            delimiter = self._delimiters[0]  # type: ignore[attr-defined]
+        if self._defaults:  # type: ignore
+            self._write_section(  # type: ignore[attr-defined]
+                fp, self.default_section, self._defaults.items(), delimiter  # type: ignore[attr-defined]
+            )
+        for section in self._sections:  # type: ignore[attr-defined]
+            item_section: ConfigOptionsDictType = self.getsection(section)  # type: ignore[assignment]
+            self._write_section(fp, section, item_section.items(), delimiter)  # type: ignore[attr-defined]
 
     def as_dict(
         self,
-        display_source=False,
-        display_sensitive=False,
-        raw=False,
-        include_env=True,
-        include_cmds=True,
-        include_secret=True,
-    ) -> Dict[str, Dict[str, str]]:
+        display_source: bool = False,
+        display_sensitive: bool = False,
+        raw: bool = False,
+        include_env: bool = True,
+        include_cmds: bool = True,
+        include_secret: bool = True,
+    ) -> ConfigSourcesType:
         """
         Returns the current configuration as an OrderedDict of OrderedDicts.
 
@@ -785,7 +853,7 @@ def as_dict(
         :return: Dictionary, where the key is the name of the section and the content is
             the dictionary with the name of the parameter and its value.
""" - config_sources: Dict[str, Dict[str, str]] = {} + config_sources: ConfigSourcesType = {} configs = [ ('default', self.airflow_defaults), ('airflow.cfg', self), @@ -813,20 +881,34 @@ def as_dict( return config_sources - def _include_secrets(self, config_sources, display_sensitive, display_source, raw): + def _include_secrets( + self, + config_sources: ConfigSourcesType, + display_sensitive: bool, + display_source: bool, + raw: bool, + ): for (section, key) in self.sensitive_config_values: - opt = self._get_secret_option(section, key) - if opt: + value: Optional[str] = self._get_secret_option(section, key) + if value: if not display_sensitive: - opt = '< hidden >' + value = '< hidden >' if display_source: - opt = (opt, 'secret') + opt: Union[str, Tuple[str, str]] = (value, 'secret') elif raw: - opt = opt.replace('%', '%%') + opt = value.replace('%', '%%') + else: + opt = value config_sources.setdefault(section, OrderedDict()).update({key: opt}) del config_sources[section][key + '_secret'] - def _include_commands(self, config_sources, display_sensitive, display_source, raw): + def _include_commands( + self, + config_sources: ConfigSourcesType, + display_sensitive: bool, + display_source: bool, + raw: bool, + ): for (section, key) in self.sensitive_config_values: opt = self._get_cmd_option(section, key) if not opt: @@ -840,7 +922,13 @@ def _include_commands(self, config_sources, display_sensitive, display_source, r config_sources.setdefault(section, OrderedDict()).update({key: opt}) del config_sources[section][key + '_cmd'] - def _include_envs(self, config_sources, display_sensitive, display_source, raw): + def _include_envs( + self, + config_sources: ConfigSourcesType, + display_sensitive: bool, + display_source: bool, + raw: bool, + ): for env_var in [ os_environment for os_environment in os.environ if os_environment.startswith(self.ENV_VAR_PREFIX) ]: @@ -868,7 +956,12 @@ def _include_envs(self, config_sources, display_sensitive, display_source, raw): key = key.lower() config_sources.setdefault(section, OrderedDict()).update({key: opt}) - def _filter_by_source(self, config_sources, display_source, getter_func): + def _filter_by_source( + self, + config_sources: ConfigSourcesType, + display_source: bool, + getter_func, + ): """ Deletes default configs from current configuration (an OrderedDict of OrderedDicts) if it would conflict with special sensitive_config_values. 
@@ -904,14 +997,20 @@ def _filter_by_source(self, config_sources, display_source, getter_func):
                 continue
             # Check to see if bare setting is the same as defaults
             if display_source:
-                opt, source = config_sources[section][key]
+                # when display_source = true, we know that the config_sources contains tuple
+                opt, source = config_sources[section][key]  # type: ignore
             else:
                 opt = config_sources[section][key]
             if opt == self.airflow_defaults.get(section, key):
                 del config_sources[section][key]
 
     @staticmethod
-    def _replace_config_with_display_sources(config_sources, configs, display_source, raw):
+    def _replace_config_with_display_sources(
+        config_sources: ConfigSourcesType,
+        configs: Iterable[Tuple[str, ConfigParser]],
+        display_source: bool,
+        raw: bool,
+    ):
         for (source_name, config) in configs:
             for section in config.sections():
                 AirflowConfigParser._replace_section_config_with_display_sources(
@@ -920,13 +1019,19 @@ def _replace_config_with_display_sources(config_sources, configs, display_source
 
     @staticmethod
     def _replace_section_config_with_display_sources(
-        config, config_sources, display_source, raw, section, source_name
+        config: ConfigParser,
+        config_sources: ConfigSourcesType,
+        display_source: bool,
+        raw: bool,
+        section: str,
+        source_name: str,
     ):
         sect = config_sources.setdefault(section, OrderedDict())
         for (k, val) in config.items(section=section, raw=raw):
             if display_source:
-                val = (val, source_name)
-            sect[k] = val
+                sect[k] = (val, source_name)
+            else:
+                sect[k] = val
 
     def load_test_config(self):
         """
@@ -948,7 +1053,7 @@ def load_test_config(self):
         self.read(TEST_CONFIG_FILE)
 
     @staticmethod
-    def _warn_deprecate(section, key, deprecated_section, deprecated_name):
+    def _warn_deprecate(section: str, key: str, deprecated_section: str, deprecated_name: str):
         if section == deprecated_section:
             warnings.warn(
                 f'The {deprecated_name} option in [{section}] has been renamed to {key} - '
@@ -981,16 +1086,17 @@ def __setstate__(self, state):
         self.__dict__.update(state)
 
 
-def get_airflow_home():
+def get_airflow_home() -> str:
     """Get path to Airflow Home"""
     return expand_env_var(os.environ.get('AIRFLOW_HOME', '~/airflow'))
 
 
-def get_airflow_config(airflow_home):
+def get_airflow_config(airflow_home) -> str:
     """Get Path to airflow.cfg path"""
-    if 'AIRFLOW_CONFIG' not in os.environ:
+    airflow_config_var = os.environ.get('AIRFLOW_CONFIG')
+    if airflow_config_var is None:
         return os.path.join(airflow_home, 'airflow.cfg')
-    return expand_env_var(os.environ['AIRFLOW_CONFIG'])
+    return expand_env_var(airflow_config_var)
 
 
 def _parameterized_config_from_template(filename) -> str:
@@ -1005,7 +1111,7 @@ def _parameterized_config_from_template(filename) -> str:
     raise RuntimeError(f"Template marker not found in {path!r}")
 
 
-def parameterized_config(template):
+def parameterized_config(template) -> str:
     """
     Generates a configuration from the provided template + variables defined in
     current scope
@@ -1016,20 +1122,21 @@ def parameterized_config(template):
     return template.format(**all_vars)
 
 
-def get_airflow_test_config(airflow_home):
+def get_airflow_test_config(airflow_home) -> str:
     """Get path to unittests.cfg"""
     if 'AIRFLOW_TEST_CONFIG' not in os.environ:
         return os.path.join(airflow_home, 'unittests.cfg')
-    return expand_env_var(os.environ['AIRFLOW_TEST_CONFIG'])
+    # It will never return None
+    return expand_env_var(os.environ['AIRFLOW_TEST_CONFIG'])  # type: ignore[return-value]
 
 
-def _generate_fernet_key():
+def _generate_fernet_key() -> str:
     from cryptography.fernet import Fernet
 
     return Fernet.generate_key().decode()
 
 
-def initialize_config():
+def initialize_config() -> AirflowConfigParser:
     """
     Load the Airflow config files.
 
@@ -1039,9 +1146,9 @@ def initialize_config():
     default_config = _parameterized_config_from_template('default_airflow.cfg')
 
-    conf = AirflowConfigParser(default_config=default_config)
+    local_conf = AirflowConfigParser(default_config=default_config)
 
-    if conf.getboolean('core', 'unit_test_mode'):
+    if local_conf.getboolean('core', 'unit_test_mode'):
         # Load test config only
         if not os.path.isfile(TEST_CONFIG_FILE):
             from cryptography.fernet import Fernet
@@ -1055,7 +1162,7 @@ def initialize_config():
             cfg = _parameterized_config_from_template('default_test.cfg')
             file.write(cfg)
 
-        conf.load_test_config()
+        local_conf.load_test_config()
     else:
         # Load normal config
         if not os.path.isfile(AIRFLOW_CONFIG):
@@ -1071,9 +1178,9 @@ def initialize_config():
 
         log.info("Reading the config from %s", AIRFLOW_CONFIG)
 
-        conf.read(AIRFLOW_CONFIG)
+        local_conf.read(AIRFLOW_CONFIG)
 
-        if conf.has_option('core', 'AIRFLOW_HOME'):
+        if local_conf.has_option('core', 'AIRFLOW_HOME'):
             msg = (
                 'Specifying both AIRFLOW_HOME environment variable and airflow_home '
                 'in the config file is deprecated. Please use only the AIRFLOW_HOME '
@@ -1081,7 +1188,7 @@ def initialize_config():
             )
             if 'AIRFLOW_HOME' in os.environ:
                 warnings.warn(msg, category=DeprecationWarning)
-            elif conf.get('core', 'airflow_home') == AIRFLOW_HOME:
+            elif local_conf.get('core', 'airflow_home') == AIRFLOW_HOME:
                 warnings.warn(
                     'Specifying airflow_home in the config file is deprecated. As you '
                     'have left it at the default value you should remove the setting '
@@ -1089,13 +1196,13 @@ def initialize_config():
                     category=DeprecationWarning,
                 )
             else:
-                AIRFLOW_HOME = conf.get('core', 'airflow_home')
+                AIRFLOW_HOME = local_conf.get('core', 'airflow_home')  # type: ignore[assignment]
                 warnings.warn(msg, category=DeprecationWarning)
 
     # They _might_ have set unit_test_mode in the airflow.cfg, we still
     # want to respect that and then load the unittests.cfg
-    if conf.getboolean('core', 'unit_test_mode'):
-        conf.load_test_config()
+    if local_conf.getboolean('core', 'unit_test_mode'):
+        local_conf.load_test_config()
 
     # Make it no longer a proxy variable, just set it to an actual string
     global WEBSERVER_CONFIG
@@ -1106,7 +1214,7 @@ def initialize_config():
         log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG)
         shutil.copy(_default_config_file_path('default_webserver_config.py'), WEBSERVER_CONFIG)
 
-    return conf
+    return local_conf
 
 
 # Historical convenience functions to access config entries
@@ -1122,7 +1230,7 @@ def load_test_config():
     conf.load_test_config()
 
 
-def get(*args, **kwargs):
+def get(*args, **kwargs) -> Optional[ConfigType]:
     """Historical get"""
     warnings.warn(
         "Accessing configuration method 'get' directly from the configuration module is "
@@ -1134,7 +1242,7 @@ def get(*args, **kwargs):
     return conf.get(*args, **kwargs)
 
 
-def getboolean(*args, **kwargs):
+def getboolean(*args, **kwargs) -> bool:
     """Historical getboolean"""
     warnings.warn(
         "Accessing configuration method 'getboolean' directly from the configuration module is "
@@ -1146,7 +1254,7 @@ def getboolean(*args, **kwargs):
     return conf.getboolean(*args, **kwargs)
 
 
-def getfloat(*args, **kwargs):
+def getfloat(*args, **kwargs) -> float:
     """Historical getfloat"""
     warnings.warn(
         "Accessing configuration method 'getfloat' directly from the configuration module is "
@@ -1158,7 +1266,7 @@ def getfloat(*args, **kwargs):
     return conf.getfloat(*args, **kwargs)
 
 
-def getint(*args, **kwargs):
+def getint(*args, **kwargs) -> int:
     """Historical getint"""
     warnings.warn(
         "Accessing configuration method 'getint' directly from the configuration module is "
@@ -1170,7 +1278,7 @@ def getint(*args, **kwargs):
     return conf.getint(*args, **kwargs)
 
 
-def getsection(*args, **kwargs):
+def getsection(*args, **kwargs) -> Optional[ConfigOptionsDictType]:
     """Historical getsection"""
    warnings.warn(
         "Accessing configuration method 'getsection' directly from the configuration module is "
@@ -1182,7 +1290,7 @@ def getsection(*args, **kwargs):
     return conf.getsection(*args, **kwargs)
 
 
-def has_option(*args, **kwargs):
+def has_option(*args, **kwargs) -> bool:
     """Historical has_option"""
     warnings.warn(
         "Accessing configuration method 'has_option' directly from the configuration module is "
@@ -1194,7 +1302,7 @@ def has_option(*args, **kwargs):
     return conf.has_option(*args, **kwargs)
 
 
-def remove_option(*args, **kwargs):
+def remove_option(*args, **kwargs) -> bool:
     """Historical remove_option"""
     warnings.warn(
         "Accessing configuration method 'remove_option' directly from the configuration module is "
@@ -1206,7 +1314,7 @@ def remove_option(*args, **kwargs):
     return conf.remove_option(*args, **kwargs)
 
 
-def as_dict(*args, **kwargs):
+def as_dict(*args, **kwargs) -> ConfigSourcesType:
     """Historical as_dict"""
     warnings.warn(
         "Accessing configuration method 'as_dict' directly from the configuration module is "
@@ -1218,7 +1326,7 @@ def as_dict(*args, **kwargs):
     return conf.as_dict(*args, **kwargs)
 
 
-def set(*args, **kwargs):
+def set(*args, **kwargs) -> None:
     """Historical set"""
     warnings.warn(
         "Accessing configuration method 'set' directly from the configuration module is "
@@ -1227,7 +1335,7 @@ def set(*args, **kwargs):
         DeprecationWarning,
         stacklevel=2,
     )
-    return conf.set(*args, **kwargs)
+    conf.set(*args, **kwargs)
 
 
 def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
@@ -1247,9 +1355,8 @@ def get_custom_secret_backend() -> Optional[BaseSecretsBackend]:
 
     if secrets_backend_cls:
         try:
-            alternative_secrets_config_dict = json.loads(
-                conf.get(section='secrets', key='backend_kwargs', fallback='{}')
-            )
+            backends: Any = conf.get(section='secrets', key='backend_kwargs', fallback='{}')
+            alternative_secrets_config_dict = json.loads(backends)
         except JSONDecodeError:
             alternative_secrets_config_dict = {}
 
@@ -1277,14 +1384,14 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]:
 
 
 @functools.lru_cache(maxsize=None)
-def _DEFAULT_CONFIG():
+def _DEFAULT_CONFIG() -> str:
     path = _default_config_file_path('default_airflow.cfg')
     with open(path) as fh:
         return fh.read()
 
 
 @functools.lru_cache(maxsize=None)
-def _TEST_CONFIG():
+def _TEST_CONFIG() -> str:
     path = _default_config_file_path('default_test.cfg')
     with open(path) as fh:
         return fh.read()
@@ -1311,7 +1418,6 @@ def __getattr__(name):
 
 # Setting AIRFLOW_HOME and AIRFLOW_CONFIG from environment variables, using
 # "~/airflow" and "$AIRFLOW_HOME/airflow.cfg" respectively as defaults.
-
 AIRFLOW_HOME = get_airflow_home()
 AIRFLOW_CONFIG = get_airflow_config(AIRFLOW_HOME)
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index f16a913725401..dcfb6bedafb90 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -395,7 +395,10 @@ def __init__(
         os.set_blocking(self._direct_scheduler_conn.fileno(), False)
 
         self._parallelism = conf.getint('scheduler', 'parsing_processes')
-        if conf.get('database', 'sql_alchemy_conn').startswith('sqlite') and self._parallelism > 1:
+        if (
+            conf.get_mandatory_value('database', 'sql_alchemy_conn').startswith('sqlite')
+            and self._parallelism > 1
+        ):
             self.log.warning(
                 "Because we cannot use more than 1 thread (parsing_processes = "
                 "%d) when using sqlite. So we set parallelism to 1.",
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index b98d2a80005dc..723060db0f179 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -71,8 +71,7 @@ def get_default_executor(cls) -> "BaseExecutor":
 
         from airflow.configuration import conf
 
-        executor_name = conf.get('core', 'EXECUTOR')
-
+        executor_name = conf.get_mandatory_value('core', 'EXECUTOR')
         cls._default_executor = cls.load_executor(executor_name)
 
         return cls._default_executor
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 034baa179bb5e..0a93d07ffb992 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -142,7 +142,7 @@ def __init__(
         self._log = log
 
         # Check what SQL backend we use
-        sql_conn: str = conf.get('database', 'sql_alchemy_conn').lower()
+        sql_conn: str = conf.get_mandatory_value('database', 'sql_alchemy_conn').lower()
         self.using_sqlite = sql_conn.startswith('sqlite')
         self.using_mysql = sql_conn.startswith('mysql')
 
         # Dag Processor agent - not used in Dag Processor standalone mode.
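
Review note: the `sql_alchemy_conn` checks above all follow the same prefix-matching convention on the SQLAlchemy URI; a small self-contained sketch of that convention (the function name and sample URIs are illustrative, not from this patch):

```python
def detect_backend(sql_conn: str) -> str:
    """Classify a SQLAlchemy connection URI by its scheme prefix."""
    sql_conn = sql_conn.lower()
    if sql_conn.startswith('sqlite'):
        return 'sqlite'
    if sql_conn.startswith('mysql'):
        return 'mysql'
    return 'other'


assert detect_backend('sqlite:////home/user/airflow.db') == 'sqlite'
assert detect_backend('MySQL+mysqldb://user:pass@host/db') == 'mysql'
assert detect_backend('postgresql+psycopg2://user:pass@host/db') == 'other'
```

Since `conf.get` is now typed `Optional[str]`, calling `.lower().startswith(...)` directly on its result no longer type-checks, hence the switch to `get_mandatory_value` at these call sites.
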
@@ -153,7 +153,10 @@ def __init__(
 
         if conf.getboolean('smart_sensor', 'use_smart_sensor'):
             compatible_sensors = set(
-                map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(','))
+                map(
+                    lambda l: l.strip(),
+                    conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(','),
+                )
             )
             docs_url = get_docs_url('concepts/smart-sensors.html#migrating-to-deferrable-operators')
             warnings.warn(
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 2c53a8b75d8cc..8d2e06442a2e5 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -57,17 +57,17 @@
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
 
-DEFAULT_OWNER: str = conf.get("operators", "default_owner")
+DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
 DEFAULT_POOL_SLOTS: int = 1
 DEFAULT_PRIORITY_WEIGHT: int = 1
-DEFAULT_QUEUE: str = conf.get("operators", "default_queue")
+DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
 DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
 DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
 DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
     conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
 )
 DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
-DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta = conf.gettimedelta(
+DEFAULT_TASK_EXECUTION_TIMEOUT: Optional[datetime.timedelta] = conf.gettimedelta(
     "core", "default_task_execution_timeout"
 )
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 0fc8250ae6638..f48883ec6bc98 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -326,8 +326,8 @@ def __init__(
         sla_miss_callback: Optional[
             Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]
         ] = None,
-        default_view: str = conf.get('webserver', 'dag_default_view').lower(),
-        orientation: str = conf.get('webserver', 'dag_orientation'),
+        default_view: str = conf.get_mandatory_value('webserver', 'dag_default_view').lower(),
+        orientation: str = conf.get_mandatory_value('webserver', 'dag_orientation'),
         catchup: bool = conf.getboolean('scheduler', 'catchup_by_default'),
         on_success_callback: Optional[DagStateChangeCallback] = None,
         on_failure_callback: Optional[DagStateChangeCallback] = None,
@@ -2815,7 +2815,7 @@ def get_default_view(self) -> str:
         have a value
         """
         # This is for backwards-compatibility with old dags that don't have None as default_view
-        return self.default_view or conf.get('webserver', 'dag_default_view').lower()
+        return self.default_view or conf.get_mandatory_value('webserver', 'dag_default_view').lower()
 
     @property
     def safe_dag_id(self):
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c49867c107377..0d30727615c5c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2228,7 +2228,7 @@ def get_email_subject_content(
 
         def render(key: str, content: str) -> str:
             if conf.has_option('email', key):
-                path = conf.get('email', key)
+                path = conf.get_mandatory_value('email', key)
                 with open(path) as f:
                     content = f.read()
             return render_template_to_string(jinja_env.from_string(content), jinja_context)
diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py
index b4581aaea3fe3..45cae0fa4e31f 100644
--- a/airflow/providers/apache/hive/operators/hive.py
+++ b/airflow/providers/apache/hive/operators/hive.py
@@ -100,11 +100,15 @@ def __init__(
         self.mapred_queue = mapred_queue
         self.mapred_queue_priority = mapred_queue_priority
         self.mapred_job_name = mapred_job_name
-        self.mapred_job_name_template = conf.get(
+
+        job_name_template = conf.get(
             'hive',
             'mapred_job_name_template',
             fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}",
         )
+        if job_name_template is None:
+            raise ValueError("Job name template should be set!")
+        self.mapred_job_name_template: str = job_name_template
 
         # assigned lazily - just for consistency we can create the attribute with a
         # `None` initial value, later it will be populated by the execute method.
diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index bcbeb4b99c032..0f5dc2f7307cc 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -632,7 +632,7 @@ def on_kill(self) -> None:
             # we still attempt to kill the yarn application
             renew_from_kt(self._principal, self._keytab, exit_on_fail=False)
             env = os.environ.copy()
-            env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache')
+            env["KRB5CCNAME"] = airflow_conf.get_mandatory_value('kerberos', 'ccache')
 
             with subprocess.Popen(
                 kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
diff --git a/airflow/providers/qubole/hooks/qubole.py b/airflow/providers/qubole/hooks/qubole.py
index a62243e9a037d..7896fbd352807 100644
--- a/airflow/providers/qubole/hooks/qubole.py
+++ b/airflow/providers/qubole/hooks/qubole.py
@@ -227,7 +227,7 @@ def get_results(
         """
         if fp is None:
             iso = datetime.datetime.utcnow().isoformat()
-            logpath = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))
+            logpath = os.path.expanduser(conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER'))
             resultpath = logpath + '/' + self.dag_id + '/' + self.task_id + '/results'
             pathlib.Path(resultpath).mkdir(parents=True, exist_ok=True)
             fp = open(resultpath + '/' + iso, 'wb')
diff --git a/airflow/security/kerberos.py b/airflow/security/kerberos.py
index ce8043b2eb35e..e8fc86af7259c 100644
--- a/airflow/security/kerberos.py
+++ b/airflow/security/kerberos.py
@@ -38,7 +38,7 @@
 import subprocess
 import sys
 import time
-from typing import Optional
+from typing import List, Optional
 
 from airflow.configuration import conf
 
@@ -59,7 +59,9 @@ def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = Tr
     # minutes to give ourselves a large renewal buffer.
     renewal_lifetime = f"{conf.getint('kerberos', 'reinit_frequency')}m"
 
-    cmd_principal = principal or conf.get('kerberos', 'principal').replace("_HOST", socket.getfqdn())
+    cmd_principal = principal or conf.get_mandatory_value('kerberos', 'principal').replace(
+        "_HOST", socket.getfqdn()
+    )
 
     if conf.getboolean('kerberos', 'forwardable'):
         forwardable = '-f'
@@ -71,8 +73,8 @@ def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = Tr
     else:
         include_ip = '-A'
 
-    cmdv = [
-        conf.get('kerberos', 'kinit_path'),
+    cmdv: List[str] = [
+        conf.get_mandatory_value('kerberos', 'kinit_path'),
         forwardable,
         include_ip,
         "-r",
@@ -81,7 +83,7 @@ def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = Tr
         "-t",
         keytab,  # specify keytab
         "-c",
-        conf.get('kerberos', 'ccache'),  # specify credentials cache
+        conf.get_mandatory_value('kerberos', 'ccache'),  # specify credentials cache
         cmd_principal,
     ]
     log.info("Re-initialising kerberos from keytab: %s", " ".join(shlex.quote(f) for f in cmdv))
@@ -129,10 +131,10 @@ def perform_krb181_workaround(principal: str):
     :param principal: principal name
     :return: None
     """
-    cmdv = [
-        conf.get('kerberos', 'kinit_path'),
+    cmdv: List[str] = [
+        conf.get_mandatory_value('kerberos', 'kinit_path'),
         "-c",
-        conf.get('kerberos', 'ccache'),
+        conf.get_mandatory_value('kerberos', 'ccache'),
         "-R",
     ]  # Renew ticket_cache
@@ -162,7 +164,7 @@ def detect_conf_var() -> bool:
     Sun Java Krb5LoginModule in Java6, so we need to take an action to work
     around it.
     """
-    ticket_cache = conf.get('kerberos', 'ccache')
+    ticket_cache = conf.get_mandatory_value('kerberos', 'ccache')
 
     with open(ticket_cache, 'rb') as file:
         # Note: this file is binary, so we check against a bytearray.
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index 5cbe009c82b95..7f1cd87c3dd65 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -143,7 +143,7 @@ def __init__(
         self._validate_input_values()
         self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor')
         self.sensors_support_sensor_service = set(
-            map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(','))
+            map(lambda l: l.strip(), conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(','))
         )
 
     def _validate_input_values(self) -> None:
diff --git a/airflow/settings.py b/airflow/settings.py
index 5693575ed9b93..8b50bde0a179e 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -45,7 +45,7 @@
 TIMEZONE = pendulum.tz.timezone('UTC')
 try:
-    tz = conf.get("core", "default_timezone")
+    tz = conf.get_mandatory_value("core", "default_timezone")
     if tz == "system":
         TIMEZONE = pendulum.tz.local_timezone()
     else:
@@ -77,7 +77,7 @@
 PLUGINS_FOLDER: Optional[str] = None
 LOGGING_CLASS_PATH: Optional[str] = None
 DONOT_MODIFY_HANDLERS: Optional[bool] = None
-DAGS_FOLDER: str = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
+DAGS_FOLDER: str = os.path.expanduser(conf.get_mandatory_value('core', 'DAGS_FOLDER'))
 
 engine: Engine
 Session: Callable[..., SASession]
diff --git a/airflow/utils/email.py b/airflow/utils/email.py
index f9b8fe23efe2f..ec0095e983056 100644
--- a/airflow/utils/email.py
+++ b/airflow/utils/email.py
@@ -94,9 +94,13 @@ def send_email_smtp(
     """
     smtp_mail_from = conf.get('smtp', 'SMTP_MAIL_FROM')
 
-    if smtp_mail_from:
+    if smtp_mail_from is not None:
         mail_from = smtp_mail_from
     else:
+        if from_email is None:
+            raise Exception(
+                "You should set from email - either by smtp/smtp_mail_from config or `from_email` parameter"
+            )
         mail_from = from_email
 
     msg, recipients = build_mime_message(
@@ -188,7 +192,7 @@ def send_mime_email(
     dryrun: bool = False,
 ) -> None:
     """Send MIME email."""
-    smtp_host = conf.get('smtp', 'SMTP_HOST')
+    smtp_host = conf.get_mandatory_value('smtp', 'SMTP_HOST')
     smtp_port = conf.getint('smtp', 'SMTP_PORT')
     smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
     smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 90be4ab172842..89d0b7d9fc97a 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -247,7 +247,7 @@ def _find_path_from_directory(
 def find_path_from_directory(
     base_dir_path: str,
     ignore_file_name: str,
-    ignore_file_syntax: str = conf.get('core', 'DAG_IGNORE_FILE_SYNTAX', fallback="regexp"),
+    ignore_file_syntax: str = conf.get_mandatory_value('core', 'DAG_IGNORE_FILE_SYNTAX', fallback="regexp"),
 ) -> Generator[str, None, None]:
     """
     Recursively search the base path and return the list of file paths that should not be ignored.
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index f4a1107bc37cd..e8366805b4ee8 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -256,7 +256,7 @@ def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
     """
     import flask
 
-    view = conf.get('webserver', 'dag_default_view').lower()
+    view = conf.get_mandatory_value('webserver', 'dag_default_view').lower()
     return flask.url_for(f"Airflow.{view}", **query)
 
 
diff --git a/airflow/utils/mixins.py b/airflow/utils/mixins.py
index fbe0b24d561bc..4d0165e4dfc40 100644
--- a/airflow/utils/mixins.py
+++ b/airflow/utils/mixins.py
@@ -30,7 +30,7 @@ def _get_multiprocessing_start_method(self) -> str:
         mp_start_method is set in configs, else, it uses the OS default.
         """
         if conf.has_option('core', 'mp_start_method'):
-            return conf.get('core', 'mp_start_method')
+            return conf.get_mandatory_value('core', 'mp_start_method')
 
         method = multiprocessing.get_start_method()
         if not method:
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index de4ad01e6901e..6d751d106ab89 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -37,7 +37,7 @@
 
 utc = pendulum.tz.timezone('UTC')
 
-using_mysql = conf.get('database', 'sql_alchemy_conn').lower().startswith('mysql')
+using_mysql = conf.get_mandatory_value('database', 'sql_alchemy_conn').lower().startswith('mysql')
 
 
 class UtcDateTime(TypeDecorator):
diff --git a/tests/cli/commands/test_dag_processor_command.py b/tests/cli/commands/test_dag_processor_command.py
index 8129c31047223..21cec4cee2f8b 100644
--- a/tests/cli/commands/test_dag_processor_command.py
+++ b/tests/cli/commands/test_dag_processor_command.py
@@ -43,7 +43,7 @@ def setUpClass(cls):
     )
     @mock.patch("airflow.cli.commands.dag_processor_command.DagFileProcessorManager")
     @pytest.mark.skipif(
-        conf.get('database', 'sql_alchemy_conn').lower().startswith('sqlite'),
+        conf.get_mandatory_value('database', 'sql_alchemy_conn').lower().startswith('sqlite'),
         reason="Standalone Dag Processor doesn't support sqlite.",
     )
     def test_start_manager(
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index fac04fd50802d..a998b69392693 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -426,7 +426,7 @@ def setUp(self) -> None:
         self.execution_date = timezone.make_aware(datetime(2017, 1, 1))
         self.execution_date_str = self.execution_date.isoformat()
         self.task_args = ['tasks', 'run', self.dag_id, self.task_id, '--local', self.execution_date_str]
-        self.log_dir = conf.get('logging', 'base_log_folder')
+        self.log_dir = conf.get_mandatory_value('logging', 'base_log_folder')
         self.log_filename = f"dag_id={self.dag_id}/run_id={self.run_id}/task_id={self.task_id}/attempt=1.log"
         self.ti_log_file_path = os.path.join(self.log_dir, self.log_filename)
         self.parser = cli_parser.get_parser()
diff --git a/tests/test_utils/system_tests_class.py b/tests/test_utils/system_tests_class.py
index be9cf73486dae..846c34c3e02dd 100644
--- a/tests/test_utils/system_tests_class.py
+++ b/tests/test_utils/system_tests_class.py
@@ -19,6 +19,7 @@
 import shutil
 from datetime import datetime
 from pathlib import Path
+from typing import Optional
 from unittest import TestCase
 
 from airflow.configuration import AIRFLOW_HOME, AirflowConfigParser, get_airflow_config
@@ -31,6 +32,12 @@
 DEFAULT_DAG_FOLDER = os.path.join(AIRFLOW_MAIN_FOLDER, "airflow", "example_dags")
 
 
+def get_default_logs_if_none(logs: Optional[str]) -> str:
+    if logs is None:
+        return os.path.join(AIRFLOW_HOME, 'logs')
+    return logs
+
+
 def resolve_logs_folder() -> str:
     """
     Returns LOGS folder specified in current Airflow config.
@@ -39,13 +46,13 @@ def resolve_logs_folder() -> str:
     conf = AirflowConfigParser()
     conf.read(config_file)
     try:
-        logs = conf.get("logging", "base_log_folder")
+        return get_default_logs_if_none(conf.get("logging", "base_log_folder"))
     except AirflowException:
         try:
-            logs = conf.get("core", "base_log_folder")
+            return get_default_logs_if_none(conf.get("core", "base_log_folder"))
         except AirflowException:
-            logs = os.path.join(AIRFLOW_HOME, 'logs')
-    return logs
+            pass
+    return get_default_logs_if_none(None)
 
 
 class SystemTest(TestCase, LoggingMixin):
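
Closing review note on the `@overload` pattern used for `expand_env_var` in `airflow/configuration.py`: it lets mypy preserve "str in, str out" without forcing callers to narrow an `Optional`. A runnable sketch (the function body is reconstructed to match the behavior the docstring describes, not copied verbatim from the patch):

```python
import os
from typing import Optional, overload


@overload
def expand_env_var(env_var: None) -> None:
    ...


@overload
def expand_env_var(env_var: str) -> str:
    ...


def expand_env_var(env_var: Optional[str]) -> Optional[str]:
    """Repeatedly apply expandvars/expanduser until the value stops changing."""
    if not env_var:
        return env_var
    while True:
        interpolated = os.path.expanduser(os.path.expandvars(str(env_var)))
        if interpolated == env_var:
            return interpolated
        env_var = interpolated


home_path: str = expand_env_var('$HOME/airflow')  # mypy sees str, not Optional[str]
nothing: None = expand_env_var(None)  # and None passes through unchanged
```

This is the same trade-off `get_mandatory_value` makes elsewhere in the diff: push the None-handling to one well-typed place so the many call sites stay clean.
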