Replace OrderedDict with plain dict (#33508)
eumiro committed Aug 20, 2023
1 parent c9452c8 commit 63e6eab
Showing 19 changed files with 101 additions and 137 deletions.
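
Background for the change (an illustration, not part of the diff): since Python 3.7 the language guarantees that a plain dict preserves insertion order, so OrderedDict is only needed for its extra behaviour (move_to_end(), popitem(last=False), order-sensitive equality between two OrderedDict instances). A minimal sketch of the two properties this commit relies on:

    from collections import OrderedDict

    d = {"b": 1, "a": 2}
    d["c"] = 3

    # A plain dict iterates in insertion order on Python 3.7+.
    assert list(d) == ["b", "a", "c"]

    # Comparing a dict with an OrderedDict ignores order, so existing
    # equality checks (e.g. in the tests below) behave the same after the swap.
    assert d == OrderedDict([("c", 3), ("a", 2), ("b", 1)])
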
15 changes: 7 additions & 8 deletions airflow/configuration.py
@@ -31,7 +31,6 @@
import sys
import warnings
from base64 import b64encode
from collections import OrderedDict
from configparser import ConfigParser, NoOptionError, NoSectionError
from contextlib import contextmanager
from copy import deepcopy
@@ -1331,12 +1330,12 @@ def getsection(self, section: str) -> ConfigOptionsDictType | None:
if not self.has_section(section) and not self._default_values.has_section(section):
return None
if self._default_values.has_section(section):
_section: ConfigOptionsDictType = OrderedDict(self._default_values.items(section))
_section: ConfigOptionsDictType = dict(self._default_values.items(section))
else:
_section = OrderedDict()
_section = {}

if self.has_section(section):
_section.update(OrderedDict(self.items(section)))
_section.update(self.items(section))

section_prefix = self._env_var_name(section, "")
for env_var in sorted(os.environ.keys()):
@@ -1487,7 +1486,7 @@ def _include_secrets(
opt = value.replace("%", "%%")
else:
opt = value
config_sources.setdefault(section, OrderedDict()).update({key: opt})
config_sources.setdefault(section, {}).update({key: opt})
del config_sources[section][key + "_secret"]

def _include_commands(
@@ -1510,7 +1509,7 @@ def _include_commands(
opt_to_set = str(opt_to_set).replace("%", "%%")
if opt_to_set is not None:
dict_to_update: dict[str, str | tuple[str, str]] = {key: opt_to_set}
config_sources.setdefault(section, OrderedDict()).update(dict_to_update)
config_sources.setdefault(section, {}).update(dict_to_update)
del config_sources[section][key + "_cmd"]

def _include_envs(
@@ -1548,7 +1547,7 @@ def _include_envs(
# with AIRFLOW_. Therefore, we need to make it a special case.
if section != "kubernetes_environment_variables":
key = key.lower()
config_sources.setdefault(section, OrderedDict()).update({key: opt})
config_sources.setdefault(section, {}).update({key: opt})

def _filter_by_source(
self,
@@ -1709,7 +1708,7 @@ def _replace_section_config_with_display_sources(
include_cmds: bool,
include_secret: bool,
):
sect = config_sources.setdefault(section, OrderedDict())
sect = config_sources.setdefault(section, {})
if isinstance(config, AirflowConfigParser):
with config.suppress_future_warnings():
items: Iterable[tuple[str, Any]] = config.items(section=section, raw=raw)
4 changes: 2 additions & 2 deletions airflow/executors/base_executor.py
@@ -21,7 +21,7 @@
import logging
import sys
import warnings
from collections import OrderedDict, defaultdict
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
@@ -121,7 +121,7 @@ class BaseExecutor(LoggingMixin):
def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
self.parallelism: int = parallelism
self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
self.queued_tasks: dict[TaskInstanceKey, QueuedTaskInstanceType] = {}
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
7 changes: 3 additions & 4 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -24,7 +24,6 @@
import subprocess
import time
import warnings
from collections import OrderedDict
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Any, Iterable, Mapping

@@ -350,7 +349,7 @@ def load_df(
:param table: target Hive table, use dot notation to target a
specific database
:param field_dict: mapping from column name to hive data type.
Note that it must be OrderedDict so as to keep columns' order.
Note that Python dict is ordered so it keeps columns' order.
:param delimiter: field delimiter in the file
:param encoding: str encoding to use when writing DataFrame to file
:param pandas_kwargs: passed to DataFrame.to_csv
@@ -371,7 +370,7 @@ def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
"V": "STRING", # void
}

order_type = OrderedDict()
order_type = {}
for col, dtype in df.dtypes.items():
order_type[col] = dtype_kind_hive_type[dtype.kind]
return order_type
@@ -427,7 +426,7 @@ def load_file(
:param delimiter: field delimiter in the file
:param field_dict: A dictionary of the fields name in the file
as keys and their Hive types as values.
Note that it must be OrderedDict so as to keep columns' order.
Note that Python dict is ordered so it keeps columns' order.
:param create: whether to create the table if it doesn't exist
:param overwrite: whether to overwrite the data in table or partition
:param partition: target partition as a dict of partition columns
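
To illustrate the updated docstring note above ("Python dict is ordered so it keeps columns' order") — a hypothetical sketch, not code from this commit: a field_dict built one column at a time yields Hive columns in the same order, with no OrderedDict required.

    # Hypothetical field_dict mapping column names to Hive types.
    field_dict = {}
    for col, hive_type in [("id", "BIGINT"), ("name", "STRING"), ("created", "TIMESTAMP")]:
        field_dict[col] = hive_type

    # Insertion order carries through to the rendered column list.
    columns = ", ".join(f"{name} {hive_type}" for name, hive_type in field_dict.items())
    assert columns == "id BIGINT, name STRING, created TIMESTAMP"
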
2 changes: 0 additions & 2 deletions airflow/providers/apache/hive/operators/hive_stats.py
@@ -19,7 +19,6 @@

import json
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.exceptions import AirflowException
@@ -134,7 +133,6 @@ def execute(self, context: Context) -> None:
assign_exprs = self.get_default_exprs(col, col_type)
exprs.update(assign_exprs)
exprs.update(self.extra_exprs)
exprs = OrderedDict(exprs)
exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())

where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
8 changes: 4 additions & 4 deletions airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -19,7 +19,6 @@
from __future__ import annotations

import csv
from collections import OrderedDict
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

@@ -117,9 +116,10 @@ def execute(self, context: Context):
cursor.execute(self.sql)
with NamedTemporaryFile(mode="w", encoding="utf-8") as tmp_file:
csv_writer = csv.writer(tmp_file, delimiter=self.delimiter)
field_dict = OrderedDict()
for col_count, (key, val) in enumerate(cursor.description, start=1):
field_dict[key or f"Column{col_count}"] = self.type_map(val)
field_dict = {}
for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position if field[0] == "" else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor)
tmp_file.flush()

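
The rewritten loop above (and its twin in vertica_to_hive.py below) iterates over DB-API cursor.description tuples, where index 0 is the column name and index 1 the type code, and falls back to a positional "ColumnN" name when the column name is empty. A self-contained sketch of that pattern with a stubbed description and type map (no real database connection is assumed):

    # Stand-ins for cursor.description and the operator's type_map().
    description = [("id", 3), ("", 1), ("name", 1)]  # (name, type_code, ...) per PEP 249
    type_map = {3: "INT", 1: "STRING"}.get

    field_dict = {}
    for col_count, field in enumerate(description, start=1):
        col_position = f"Column{col_count}"
        field_dict[col_position if field[0] == "" else field[0]] = type_map(field[1])

    # Unnamed columns get positional names; ordering follows the result set.
    assert field_dict == {"id": "INT", "Column2": "STRING", "name": "STRING"}
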
3 changes: 1 addition & 2 deletions airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -19,7 +19,6 @@
from __future__ import annotations

import csv
from collections import OrderedDict
from contextlib import closing
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
@@ -147,7 +146,7 @@ def execute(self, context: Context):
quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
escapechar=self.escapechar,
)
field_dict = OrderedDict()
field_dict = {}
if cursor.description is not None:
for field in cursor.description:
field_dict[field[0]] = self.type_map(field[1])
8 changes: 4 additions & 4 deletions airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -19,7 +19,6 @@
from __future__ import annotations

import csv
from collections import OrderedDict
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Sequence

@@ -121,9 +120,10 @@ def execute(self, context: Context):
cursor.execute(self.sql)
with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
csv_writer = csv.writer(f, delimiter=self.delimiter)
field_dict = OrderedDict()
for col_count, (key, val) in enumerate(cursor.description, start=1):
field_dict[key or f"Column{col_count}"] = self.type_map(val)
field_dict = {}
for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position if field[0] == "" else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor.iterate())
f.flush()
cursor.close()
11 changes: 5 additions & 6 deletions airflow/providers_manager.py
@@ -27,7 +27,6 @@
import sys
import traceback
import warnings
from collections import OrderedDict
from dataclasses import dataclass
from functools import wraps
from time import perf_counter
@@ -484,7 +483,7 @@ def initialize_providers_list(self):
self._discover_all_airflow_builtin_providers_from_local_sources()
self._discover_all_providers_from_packages()
self._verify_all_providers_all_compatible()
self._provider_dict = OrderedDict(sorted(self._provider_dict.items()))
self._provider_dict = dict(sorted(self._provider_dict.items()))

def _verify_all_providers_all_compatible(self):
from packaging import version as packaging_version
@@ -507,7 +506,7 @@ def initialize_providers_hooks(self):
"""Lazy initialization of providers hooks."""
self.initialize_providers_list()
self._discover_hooks()
self._hook_provider_dict = OrderedDict(sorted(self._hook_provider_dict.items()))
self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))

@provider_info_cache("taskflow_decorators")
def initialize_providers_taskflow_decorator(self):
@@ -829,21 +828,21 @@ def _discover_hooks(self) -> None:
provider,
provider_uses_connection_types,
)
self._hook_provider_dict = OrderedDict(sorted(self._hook_provider_dict.items()))
self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))

@provider_info_cache("import_all_hooks")
def _import_info_from_all_hooks(self):
"""Force-import all hooks and initialize the connections/fields."""
# Retrieve all hooks to make sure that all of them are imported
_ = list(self._hooks_lazy_dict.values())
self._field_behaviours = OrderedDict(sorted(self._field_behaviours.items()))
self._field_behaviours = dict(sorted(self._field_behaviours.items()))

# Widgets for connection forms are currently used in two places:
# 1. In the UI Connections, expected same order that it defined in Hook.
# 2. cli command - `airflow providers widgets` and expected that it in alphabetical order.
# It is not possible to recover original ordering after sorting,
# that the main reason why original sorting moved to cli part:
# self._connection_form_widgets = OrderedDict(sorted(self._connection_form_widgets.items()))
# self._connection_form_widgets = dict(sorted(self._connection_form_widgets.items()))

def _discover_taskflow_decorators(self) -> None:
for name, info in self._provider_dict.items():
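
The providers_manager.py hunks keep the existing dict(sorted(...)) idiom: sorting the items once at construction time produces a plain dict that iterates in sorted key order, which is all the OrderedDict wrapper added. The commented-out widgets line stays in insertion order on purpose, with sorting deferred to the CLI. A small sketch of the idiom (the provider mapping below is hypothetical):

    # Hypothetical provider mapping; keys are package names.
    providers = {
        "apache-airflow-providers-http": "4.5.0",
        "apache-airflow-providers-amazon": "8.5.1",
    }

    # Rebuilding from sorted items fixes the iteration order without OrderedDict.
    providers = dict(sorted(providers.items()))
    assert list(providers) == [
        "apache-airflow-providers-amazon",
        "apache-airflow-providers-http",
    ]
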
53 changes: 22 additions & 31 deletions tests/api_connexion/endpoints/test_provider_endpoint.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

from collections import OrderedDict
from unittest import mock

import pytest
@@ -25,36 +24,28 @@
from airflow.security import permissions
from tests.test_utils.api_connexion_utils import create_user, delete_user

MOCK_PROVIDERS = OrderedDict(
[
(
"apache-airflow-providers-amazon",
ProviderInfo(
"1.0.0",
{
"package-name": "apache-airflow-providers-amazon",
"name": "Amazon",
"description": "`Amazon Web Services (AWS) <https://aws.amazon.com/>`__.\n",
"versions": ["1.0.0"],
},
"package",
),
),
(
"apache-airflow-providers-apache-cassandra",
ProviderInfo(
"1.0.0",
{
"package-name": "apache-airflow-providers-apache-cassandra",
"name": "Apache Cassandra",
"description": "`Apache Cassandra <http://cassandra.apache.org/>`__.\n",
"versions": ["1.0.0"],
},
"package",
),
),
]
)
MOCK_PROVIDERS = {
"apache-airflow-providers-amazon": ProviderInfo(
"1.0.0",
{
"package-name": "apache-airflow-providers-amazon",
"name": "Amazon",
"description": "`Amazon Web Services (AWS) <https://aws.amazon.com/>`__.\n",
"versions": ["1.0.0"],
},
"package",
),
"apache-airflow-providers-apache-cassandra": ProviderInfo(
"1.0.0",
{
"package-name": "apache-airflow-providers-apache-cassandra",
"name": "Apache Cassandra",
"description": "`Apache Cassandra <http://cassandra.apache.org/>`__.\n",
"versions": ["1.0.0"],
},
"package",
),
}


@pytest.fixture(scope="module")
15 changes: 8 additions & 7 deletions tests/core/test_configuration.py
@@ -25,7 +25,6 @@
import tempfile
import textwrap
import warnings
from collections import OrderedDict
from unittest import mock
from unittest.mock import patch

@@ -540,12 +539,14 @@ def test_getsection(self):
test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)

assert OrderedDict([("key1", "hello"), ("key2", "airflow")]) == test_conf.getsection("test")
assert OrderedDict(
[("key3", "value3"), ("testkey", "testvalue"), ("testpercent", "with%percent")]
) == test_conf.getsection("testsection")
assert {"key1": "hello", "key2": "airflow"} == test_conf.getsection("test")
assert {
"key3": "value3",
"testkey": "testvalue",
"testpercent": "with%percent",
} == test_conf.getsection("testsection")

assert OrderedDict([("key", "value")]) == test_conf.getsection("new_section")
assert {"key": "value"} == test_conf.getsection("new_section")

assert test_conf.getsection("non_existent_section") is None

@@ -574,7 +575,7 @@ def test_kubernetes_environment_variables_section(self):
test_conf = AirflowConfigParser(default_config=parameterized_config(test_config_default))
test_conf.read_string(test_config)

assert OrderedDict([("key1", "hello"), ("AIRFLOW_HOME", "/root/airflow")]) == test_conf.getsection(
assert {"key1": "hello", "AIRFLOW_HOME": "/root/airflow"} == test_conf.getsection(
"kubernetes_environment_variables"
)

21 changes: 8 additions & 13 deletions tests/providers/amazon/aws/transfers/test_salesforce_to_s3.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

from collections import OrderedDict
from unittest import mock

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -31,18 +30,14 @@
AWS_CONNECTION_ID = "aws_default"
SALESFORCE_RESPONSE = {
"records": [
OrderedDict(
[
(
"attributes",
OrderedDict(
[("type", "Lead"), ("url", "/services/data/v42.0/sobjects/Lead/00Q3t00001eJ7AnEAK")]
),
),
("Id", "00Q3t00001eJ7AnEAK"),
("Company", "Hello World Inc"),
]
)
{
"attributes": {
"type": "Lead",
"url": "/services/data/v42.0/sobjects/Lead/00Q3t00001eJ7AnEAK",
},
"Id": "00Q3t00001eJ7AnEAK",
"Company": "Hello World Inc",
}
],
"totalSize": 1,
"done": True,
