Simplify 'X for X in Y' to 'Y' where applicable (#33453)
eumiro committed Aug 20, 2023
1 parent 7d352e2 commit 7700fb1
Showing 21 changed files with 27 additions and 35 deletions.
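The commit replaces identity comprehensions and generator expressions of the form `X for X in Y` with the iterable `Y` itself (or a direct `list()`/`set()` call), which is shorter and avoids an unnecessary pass over the data. A minimal sketch of the pattern, using illustrative values rather than code from the diff:

```python
items = ["a", "b", "c"]

# Before: identity comprehensions and generator expressions.
copied_list = [x for x in items]       # equivalent to list(items)
copied_set = {x for x in items}        # equivalent to set(items)
joined = "".join(c for c in items)     # equivalent to "".join(items)

# After: pass the iterable directly.
copied_list = list(items)
copied_set = set(items)
joined = "".join(items)

assert copied_list == ["a", "b", "c"]
assert copied_set == {"a", "b", "c"}
assert joined == "abc"
```

The hunks below apply this pattern file by file; a few also tighten the surrounding logic.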
2 changes: 1 addition & 1 deletion airflow/lineage/__init__.py
@@ -142,7 +142,7 @@ def wrapper(self, context, *args, **kwargs):
_inlets = self.xcom_pull(
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
)
- self.inlets.extend(i for i in itertools.chain.from_iterable(_inlets))
+ self.inlets.extend(itertools.chain.from_iterable(_inlets))

elif self.inlets:
raise AttributeError("inlets is not a list, operator, string or attr annotated object")
2 changes: 1 addition & 1 deletion airflow/models/dagrun.py
@@ -1238,7 +1238,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
TI.run_id == self.run_id,
)
)
- existing_indexes = {i for i in query}
+ existing_indexes = set(query)

removed_indexes = existing_indexes.difference(range(total_length))
if removed_indexes:
2 changes: 1 addition & 1 deletion airflow/providers/apache/hive/hooks/hive.py
@@ -230,7 +230,7 @@ def run_cli(

invalid_chars_list = re.findall(r"[^a-z0-9_]", schema)
if invalid_chars_list:
- invalid_chars = "".join(char for char in invalid_chars_list)
+ invalid_chars = "".join(invalid_chars_list)
raise RuntimeError(f"The schema `{schema}` contains invalid characters: {invalid_chars}")

if schema:
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/operators/batch.py
@@ -189,7 +189,7 @@ def _check_inputs(self) -> Any:
)

if self.use_latest_image:
- if not all(elem for elem in [self.vm_publisher, self.vm_offer]):
+ if not self.vm_publisher or not self.vm_offer:
raise AirflowException(
f"If use_latest_image_and_sku is set to True then the parameters vm_publisher, "
f"vm_offer, must all be set. "
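This hunk goes slightly beyond dropping an identity generator: the new condition is equivalent to `not all(...)` over the two values by De Morgan's laws, while skipping the throwaway list. A small sketch with hypothetical publisher/offer values:

```python
# Hypothetical stand-ins for self.vm_publisher and self.vm_offer.
cases = [("Canonical", None), (None, "UbuntuServer"), ("Canonical", "UbuntuServer")]

for vm_publisher, vm_offer in cases:
    old = not all(elem for elem in [vm_publisher, vm_offer])
    new = not vm_publisher or not vm_offer
    assert old == new  # both flag a missing publisher or offer
```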
2 changes: 1 addition & 1 deletion airflow/providers/smtp/hooks/smtp.py
@@ -333,7 +333,7 @@ def _get_email_list_from_str(self, addresses: str) -> list[str]:
:return: A list of email addresses.
"""
pattern = r"\s*[,;]\s*"
- return [address for address in re.split(pattern, addresses)]
+ return re.split(pattern, addresses)

@property
def conn(self) -> Connection:
2 changes: 1 addition & 1 deletion airflow/utils/email.py
@@ -340,4 +340,4 @@ def _get_email_list_from_str(addresses: str) -> list[str]:
:return: A list of email addresses.
"""
pattern = r"\s*[,;]\s*"
- return [address for address in re2.split(pattern, addresses)]
+ return re2.split(pattern, addresses)
2 changes: 1 addition & 1 deletion airflow/utils/python_virtualenv.py
@@ -58,7 +58,7 @@ def _generate_pip_conf(conf_file: Path, index_urls: list[str]) -> None:
else:
pip_conf_options = f"index-url = {index_urls[0]}"
if len(index_urls) > 1:
- pip_conf_options += f"\nextra-index-url = {' '.join(x for x in index_urls[1:])}"
+ pip_conf_options += f"\nextra-index-url = {' '.join(index_urls[1:])}"
conf_file.write_text(f"[global]\n{pip_conf_options}")


2 changes: 1 addition & 1 deletion airflow/www/views.py
@@ -2715,7 +2715,7 @@ def confirm(self):
return redirect_or_json(
origin, msg=f"TaskGroup {group_id} could not be found", status="error", status_code=404
)
- tasks = [task for task in task_group.iter_tasks()]
+ tasks = list(task_group.iter_tasks())
elif task_id:
try:
task = dag.get_task(task_id)
2 changes: 1 addition & 1 deletion dev/airflow-license
@@ -79,5 +79,5 @@ if __name__ == "__main__":
license = parse_license_file(notice[1])
print(f"{notice[1]:<30}|{notice[2][:50]:<50}||{notice[0]:<20}||{license:<10}")

- file_count = len([name for name in os.listdir("../licenses")])
+ file_count = len(os.listdir("../licenses"))
print(f"Defined licenses: {len(notices)} Files found: {file_count}")
1 change: 0 additions & 1 deletion docker_tests/test_prod_image.py
@@ -85,7 +85,6 @@ def test_required_providers_are_installed(self):
lines = PREINSTALLED_PROVIDERS
else:
lines = (d.strip() for d in INSTALLED_PROVIDER_PATH.read_text().splitlines())
- lines = (d for d in lines)
packages_to_install = {f"apache-airflow-providers-{d.replace('.', '-')}" for d in lines}
assert len(packages_to_install) != 0

3 changes: 1 addition & 2 deletions scripts/ci/pre_commit/common_precommit_utils.py
@@ -64,8 +64,7 @@ def insert_documentation(file_path: Path, content: list[str], header: str, footer


def get_directory_hash(directory: Path, skip_path_regexp: str | None = None) -> str:
- files = [file for file in directory.rglob("*")]
- files.sort()
+ files = sorted(directory.rglob("*"))
if skip_path_regexp:
matcher = re.compile(skip_path_regexp)
files = [file for file in files if not matcher.match(os.fspath(file.resolve()))]
6 changes: 3 additions & 3 deletions tests/always/test_connection.py
@@ -347,7 +347,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
),
]

- @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
+ @pytest.mark.parametrize("test_config", test_from_uri_params)
def test_connection_from_uri(self, test_config: UriTestCaseConfig):

connection = Connection(uri=test_config.test_uri)
@@ -369,7 +369,7 @@ def test_connection_from_uri(self, test_config: UriTestCaseConfig):

self.mask_secret.assert_has_calls(expected_calls)

- @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
+ @pytest.mark.parametrize("test_config", test_from_uri_params)
def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
"""
This test verifies that when we create a conn_1 from URI, and we generate a URI from that conn, that
@@ -390,7 +390,7 @@ def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
assert connection.schema == new_conn.schema
assert connection.extra_dejson == new_conn.extra_dejson

- @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
+ @pytest.mark.parametrize("test_config", test_from_uri_params)
def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
"""
This test verifies that if we create conn_1 from attributes (rather than from URI), and we generate a
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -134,7 +134,7 @@ def pytest_print(text):
# It is very unlikely that the user wants to display only numbers, but probably
# the user just wants to count the queries.
exit_stack.enter_context(count_queries(print_fn=pytest_print))
- elif any(c for c in ["time", "trace", "sql", "parameters"]):
+ elif any(c in columns for c in ["time", "trace", "sql", "parameters"]):
exit_stack.enter_context(
trace_queries(
display_num="num" in columns,
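Unlike most hunks in this commit, this one also fixes the logic: the old expression iterated over a literal list of non-empty strings, so `any(...)` was always true no matter what the user put in `columns`; the new expression actually tests membership in `columns`. A small sketch (the `columns` value is hypothetical):

```python
columns = ["num"]  # hypothetical user selection

always_true = any(c for c in ["time", "trace", "sql", "parameters"])
membership = any(c in columns for c in ["time", "trace", "sql", "parameters"])

assert always_true is True   # non-empty strings are truthy, so this never varied
assert membership is False   # none of the trace-query columns were requested
```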
2 changes: 1 addition & 1 deletion tests/models/test_mappedoperator.py
@@ -668,7 +668,7 @@ def execute(self, context):

class ConsumeXcomOperator(PushXcomOperator):
def execute(self, context):
- assert {i for i in self.arg1} == {1, 2, 3}
+ assert set(self.arg1) == {1, 2, 3}

with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"):
op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3])
2 changes: 1 addition & 1 deletion tests/models/test_skipmixin.py
@@ -147,7 +147,7 @@ def task_group_op(k):
branch_b = EmptyOperator(task_id="branch_b")
branch_op(k) >> [branch_a, branch_b]

- task_group_op.expand(k=[i for i in range(2)])
+ task_group_op.expand(k=[0, 1])

dag_maker.create_dagrun()
branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0)
10 changes: 3 additions & 7 deletions tests/providers/amazon/aws/sensors/test_eks.py
@@ -42,13 +42,9 @@
NODEGROUP_NAME = "test_nodegroup"
TASK_ID = "test_eks_sensor"

- CLUSTER_PENDING_STATES = frozenset(frozenset({state for state in ClusterStates}) - CLUSTER_TERMINAL_STATES)
- FARGATE_PENDING_STATES = frozenset(
-     frozenset({state for state in FargateProfileStates}) - FARGATE_TERMINAL_STATES
- )
- NODEGROUP_PENDING_STATES = frozenset(
-     frozenset({state for state in NodegroupStates}) - NODEGROUP_TERMINAL_STATES
- )
+ CLUSTER_PENDING_STATES = frozenset(ClusterStates) - frozenset(CLUSTER_TERMINAL_STATES)
+ FARGATE_PENDING_STATES = frozenset(FargateProfileStates) - frozenset(FARGATE_TERMINAL_STATES)
+ NODEGROUP_PENDING_STATES = frozenset(NodegroupStates) - frozenset(NODEGROUP_TERMINAL_STATES)


class TestEksClusterStateSensor:
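Iterating over an Enum class yields its members, so `frozenset(ClusterStates)` already contains every state and the inner `{state for state in ...}` comprehension (plus the extra `frozenset(...)` wrapper) was redundant. A minimal sketch with a stand-in enum rather than the real EKS state classes:

```python
from enum import Enum


class State(Enum):
    CREATING = "CREATING"
    ACTIVE = "ACTIVE"
    DELETING = "DELETING"


TERMINAL_STATES = frozenset({State.ACTIVE})
PENDING_STATES = frozenset(State) - TERMINAL_STATES  # iterating the class yields all members

assert PENDING_STATES == {State.CREATING, State.DELETING}
```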
@@ -311,7 +311,7 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred

entry = mock.MagicMock(json_payload={"message": "TEXT"})
page = mock.MagicMock(entries=[entry, entry], next_page_token=None)
- mock_client.return_value.list_log_entries.return_value.pages = (n for n in [page])
+ mock_client.return_value.list_log_entries.return_value.pages = iter([page])

logs, metadata = stackdriver_task_handler.read(self.ti)
mock_client.return_value.list_log_entries.assert_called_once_with(
4 changes: 1 addition & 3 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -555,9 +555,7 @@ def test__write_local_data_files_csv_does_not_write_on_empty_rows(self):
files = op._write_local_data_files(cursor)
# Raises StopIteration when next is called because generator returns no files
with pytest.raises(StopIteration):
- next(files)["file_handle"]
-
- assert len([f for f in files]) == 0
+ next(files)

def test__write_local_data_files_csv_writes_empty_file_with_write_on_empty(self):
op = DummySQLToGCSOperator(
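Dropping the follow-up length check is safe: once `next(files)` has raised `StopIteration`, the generator is exhausted, so counting its remaining items could only ever give zero. A small sketch with a trivial empty generator in place of the operator's file generator:

```python
import pytest


def no_files():
    return
    yield  # unreachable; makes this a generator that yields nothing


files = no_files()
with pytest.raises(StopIteration):
    next(files)

assert list(files) == []  # an exhausted generator yields nothing further
```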
@@ -69,7 +69,7 @@ def mock_jobs(names: list[str], latest_operation_names: list[str | None]):
for job, name in zip(jobs, names):
job.name = name
mock_obj = mock.MagicMock()
- mock_obj.__aiter__.return_value = (job for job in jobs)
+ mock_obj.__aiter__.return_value = iter(jobs)
return mock_obj


8 changes: 4 additions & 4 deletions tests/sensors/test_external_task_sensor.py
@@ -135,7 +135,7 @@ def dummy_mapped_task(x: int):
return x

dummy_task()
- dummy_mapped_task.expand(x=[i for i in map_indexes])
+ dummy_mapped_task.expand(x=list(map_indexes))

SerializedDagModel.write_dag(dag)

@@ -1089,7 +1089,7 @@ def run_tasks(dag_bag, execution_date=DEFAULT_DATE, session=None):
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- tasks = sorted((ti for ti in dagrun.task_instances), key=lambda ti: ti.task_id)
+ tasks = sorted(dagrun.task_instances, key=lambda ti: ti.task_id)
for ti in tasks:
ti.refresh_from_task(dag.get_task(ti.task_id))
tis[ti.task_id] = ti
@@ -1478,7 +1478,7 @@ def dummy_task(x: int):
mode="reschedule",
)

- body = dummy_task.expand(x=[i for i in range(5)])
+ body = dummy_task.expand(x=range(5))
tail = ExternalTaskMarker(
task_id="tail",
external_dag_id=dag.dag_id,
@@ -1524,7 +1524,7 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
include_downstream=True,
include_upstream=False,
)
- task_ids = [tid for tid in dag.task_dict]
+ task_ids = list(dag.task_dict)
assert (
dag.clear(
start_date=DEFAULT_DATE,
2 changes: 1 addition & 1 deletion tests/system/providers/amazon/aws/example_s3_to_sql.py
@@ -177,7 +177,7 @@ def parse_csv_to_list(filepath):
import csv

with open(filepath, newline="") as file:
- return [row for row in csv.reader(file)]
+ return list(csv.reader(file))

transfer_s3_to_sql = S3ToSqlOperator(
task_id="transfer_s3_to_sql",
