From 7700fb12cc6c7a97901662e6ac6aa1e4e932d969 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= <6774676+eumiro@users.noreply.github.com>
Date: Sun, 20 Aug 2023 18:02:34 +0000
Subject: [PATCH] Simplify 'X for X in Y' to 'Y' where applicable (#33453)

---
 airflow/lineage/__init__.py | 2 +-
 airflow/models/dagrun.py | 2 +-
 airflow/providers/apache/hive/hooks/hive.py | 2 +-
 airflow/providers/microsoft/azure/operators/batch.py | 2 +-
 airflow/providers/smtp/hooks/smtp.py | 2 +-
 airflow/utils/email.py | 2 +-
 airflow/utils/python_virtualenv.py | 2 +-
 airflow/www/views.py | 2 +-
 dev/airflow-license | 2 +-
 docker_tests/test_prod_image.py | 1 -
 scripts/ci/pre_commit/common_precommit_utils.py | 3 +--
 tests/always/test_connection.py | 6 +++---
 tests/conftest.py | 2 +-
 tests/models/test_mappedoperator.py | 2 +-
 tests/models/test_skipmixin.py | 2 +-
 tests/providers/amazon/aws/sensors/test_eks.py | 10 +++-------
 .../google/cloud/log/test_stackdriver_task_handler.py | 2 +-
 .../google/cloud/transfers/test_sql_to_gcs.py | 4 +---
 .../triggers/test_cloud_storage_transfer_service.py | 2 +-
 tests/sensors/test_external_task_sensor.py | 8 ++++----
 tests/system/providers/amazon/aws/example_s3_to_sql.py | 2 +-
 21 files changed, 27 insertions(+), 35 deletions(-)

diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index e22f264fdb4e1..a2fcdf4ed5cdc 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -142,7 +142,7 @@ def wrapper(self, context, *args, **kwargs):
             _inlets = self.xcom_pull(
                 context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
             )
-            self.inlets.extend(i for i in itertools.chain.from_iterable(_inlets))
+            self.inlets.extend(itertools.chain.from_iterable(_inlets))
         elif self.inlets:
             raise AttributeError("inlets is not a list, operator, string or attr annotated object")
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 923d6f3d8af1a..99e517656a129 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -1238,7 +1238,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
                 TI.run_id == self.run_id,
             )
         )
-        existing_indexes = {i for i in query}
+        existing_indexes = set(query)
         removed_indexes = existing_indexes.difference(range(total_length))
         if removed_indexes:
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index 32b1c8e5c209b..a57429f830ef1 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -230,7 +230,7 @@ def run_cli(
         invalid_chars_list = re.findall(r"[^a-z0-9_]", schema)
         if invalid_chars_list:
-            invalid_chars = "".join(char for char in invalid_chars_list)
+            invalid_chars = "".join(invalid_chars_list)
             raise RuntimeError(f"The schema `{schema}` contains invalid characters: {invalid_chars}")
         if schema:
diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py
index bb93c3b5ad8ec..63b925a98199c 100644
--- a/airflow/providers/microsoft/azure/operators/batch.py
+++ b/airflow/providers/microsoft/azure/operators/batch.py
@@ -189,7 +189,7 @@ def _check_inputs(self) -> Any:
             )
         if self.use_latest_image:
-            if not all(elem for elem in [self.vm_publisher, self.vm_offer]):
+            if not self.vm_publisher or not self.vm_offer:
                 raise AirflowException(
                     f"If use_latest_image_and_sku is set to True then the parameters vm_publisher, "
                     f"vm_offer, must all be set. "
diff --git a/airflow/providers/smtp/hooks/smtp.py b/airflow/providers/smtp/hooks/smtp.py
index 62c1196627d88..e9a36076e824d 100644
--- a/airflow/providers/smtp/hooks/smtp.py
+++ b/airflow/providers/smtp/hooks/smtp.py
@@ -333,7 +333,7 @@ def _get_email_list_from_str(self, addresses: str) -> list[str]:
         :return: A list of email addresses.
         """
         pattern = r"\s*[,;]\s*"
-        return [address for address in re.split(pattern, addresses)]
+        return re.split(pattern, addresses)
     @property
     def conn(self) -> Connection:
diff --git a/airflow/utils/email.py b/airflow/utils/email.py
index 8e139e5b52cdc..2957e5e1d3f64 100644
--- a/airflow/utils/email.py
+++ b/airflow/utils/email.py
@@ -340,4 +340,4 @@ def _get_email_list_from_str(addresses: str) -> list[str]:
     :return: A list of email addresses.
     """
     pattern = r"\s*[,;]\s*"
-    return [address for address in re2.split(pattern, addresses)]
+    return re2.split(pattern, addresses)
diff --git a/airflow/utils/python_virtualenv.py b/airflow/utils/python_virtualenv.py
index 1abf6e1f372a1..2ce279cb36b64 100644
--- a/airflow/utils/python_virtualenv.py
+++ b/airflow/utils/python_virtualenv.py
@@ -58,7 +58,7 @@ def _generate_pip_conf(conf_file: Path, index_urls: list[str]) -> None:
     else:
         pip_conf_options = f"index-url = {index_urls[0]}"
         if len(index_urls) > 1:
-            pip_conf_options += f"\nextra-index-url = {' '.join(x for x in index_urls[1:])}"
+            pip_conf_options += f"\nextra-index-url = {' '.join(index_urls[1:])}"
     conf_file.write_text(f"[global]\n{pip_conf_options}")
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 2be20fcb2aede..b08158ea08b95 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2715,7 +2715,7 @@ def confirm(self):
                 return redirect_or_json(
                     origin, msg=f"TaskGroup {group_id} could not be found", status="error", status_code=404
                 )
-            tasks = [task for task in task_group.iter_tasks()]
+            tasks = list(task_group.iter_tasks())
         elif task_id:
             try:
                 task = dag.get_task(task_id)
diff --git a/dev/airflow-license b/dev/airflow-license
index 0f9cc7cfbb9db..2144f54b87714 100755
--- a/dev/airflow-license
+++ b/dev/airflow-license
@@ -79,5 +79,5 @@ if __name__ == "__main__":
         license = parse_license_file(notice[1])
         print(f"{notice[1]:<30}|{notice[2][:50]:<50}||{notice[0]:<20}||{license:<10}")
-    file_count = len([name for name in os.listdir("../licenses")])
+    file_count = len(os.listdir("../licenses"))
     print(f"Defined licenses: {len(notices)} Files found: {file_count}")
diff --git a/docker_tests/test_prod_image.py b/docker_tests/test_prod_image.py
index 46cba5fbfaaac..01e271a4071a9 100644
--- a/docker_tests/test_prod_image.py
+++ b/docker_tests/test_prod_image.py
@@ -85,7 +85,6 @@ def test_required_providers_are_installed(self):
             lines = PREINSTALLED_PROVIDERS
         else:
             lines = (d.strip() for d in INSTALLED_PROVIDER_PATH.read_text().splitlines())
-        lines = (d for d in lines)
         packages_to_install = {f"apache-airflow-providers-{d.replace('.', '-')}" for d in lines}
         assert len(packages_to_install) != 0
diff --git a/scripts/ci/pre_commit/common_precommit_utils.py b/scripts/ci/pre_commit/common_precommit_utils.py
index 29109a4c3433f..3bb4e1c4184c1 100644
--- a/scripts/ci/pre_commit/common_precommit_utils.py
+++ b/scripts/ci/pre_commit/common_precommit_utils.py
@@ -64,8 +64,7 @@ def insert_documentation(file_path: Path, content: list[str], header: str, foote
 def get_directory_hash(directory: Path, skip_path_regexp: str | None = None) -> str:
-    files = [file for file in directory.rglob("*")]
-    files.sort()
+    files = sorted(directory.rglob("*"))
     if skip_path_regexp:
         matcher = re.compile(skip_path_regexp)
         files = [file for file in files if not matcher.match(os.fspath(file.resolve()))]
diff --git a/tests/always/test_connection.py b/tests/always/test_connection.py
index 390f05133a038..60caa1b3a4665 100644
--- a/tests/always/test_connection.py
+++ b/tests/always/test_connection.py
@@ -347,7 +347,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
         ),
     ]
-    @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
+    @pytest.mark.parametrize("test_config", test_from_uri_params)
     def test_connection_from_uri(self, test_config: UriTestCaseConfig):
         connection = Connection(uri=test_config.test_uri)
@@ -369,7 +369,7 @@ def test_connection_from_uri(self, test_config: UriTestCaseConfig):
         self.mask_secret.assert_has_calls(expected_calls)
-    @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
+    @pytest.mark.parametrize("test_config", test_from_uri_params)
     def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
         """
         This test verifies that when we create a conn_1 from URI, and we generate a URI from that conn, that
@@ -390,7 +390,7 @@ def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
         assert connection.schema == new_conn.schema
         assert connection.extra_dejson == new_conn.extra_dejson
-    @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
+    @pytest.mark.parametrize("test_config", test_from_uri_params)
     def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
         """
         This test verifies that if we create conn_1 from attributes (rather than from URI), and we generate a
diff --git a/tests/conftest.py b/tests/conftest.py
index 2b431f1772dec..a652cb9cef480 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -134,7 +134,7 @@ def pytest_print(text):
             # It is very unlikely that the user wants to display only numbers, but probably
             # the user just wants to count the queries.
             exit_stack.enter_context(count_queries(print_fn=pytest_print))
-        elif any(c for c in ["time", "trace", "sql", "parameters"]):
+        elif any(c in columns for c in ["time", "trace", "sql", "parameters"]):
             exit_stack.enter_context(
                 trace_queries(
                     display_num="num" in columns,
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 6d4a2fbca5cf1..8366acf0b7ea7 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -668,7 +668,7 @@ def execute(self, context):
     class ConsumeXcomOperator(PushXcomOperator):
         def execute(self, context):
-            assert {i for i in self.arg1} == {1, 2, 3}
+            assert set(self.arg1) == {1, 2, 3}
     with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"):
         op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3])
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 547dbec5b4208..103abc3ef67f6 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -147,7 +147,7 @@ def task_group_op(k):
             branch_b = EmptyOperator(task_id="branch_b")
             branch_op(k) >> [branch_a, branch_b]
-        task_group_op.expand(k=[i for i in range(2)])
+        task_group_op.expand(k=[0, 1])
     dag_maker.create_dagrun()
     branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0)
diff --git a/tests/providers/amazon/aws/sensors/test_eks.py b/tests/providers/amazon/aws/sensors/test_eks.py
index fa5457f88958f..0bb625532d9ff 100644
--- a/tests/providers/amazon/aws/sensors/test_eks.py
+++ b/tests/providers/amazon/aws/sensors/test_eks.py
@@ -42,13 +42,9 @@
 NODEGROUP_NAME = "test_nodegroup"
 TASK_ID = "test_eks_sensor"
-CLUSTER_PENDING_STATES = frozenset(frozenset({state for state in ClusterStates}) - CLUSTER_TERMINAL_STATES)
-FARGATE_PENDING_STATES = frozenset(
-    frozenset({state for state in FargateProfileStates}) - FARGATE_TERMINAL_STATES
-)
-NODEGROUP_PENDING_STATES = frozenset(
-    frozenset({state for state in NodegroupStates}) - NODEGROUP_TERMINAL_STATES
-)
+CLUSTER_PENDING_STATES = frozenset(ClusterStates) - frozenset(CLUSTER_TERMINAL_STATES)
+FARGATE_PENDING_STATES = frozenset(FargateProfileStates) - frozenset(FARGATE_TERMINAL_STATES)
+NODEGROUP_PENDING_STATES = frozenset(NodegroupStates) - frozenset(NODEGROUP_TERMINAL_STATES)
 class TestEksClusterStateSensor:
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
index ca489efb45dbd..e9a79629801f1 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
@@ -311,7 +311,7 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred
         entry = mock.MagicMock(json_payload={"message": "TEXT"})
         page = mock.MagicMock(entries=[entry, entry], next_page_token=None)
-        mock_client.return_value.list_log_entries.return_value.pages = (n for n in [page])
+        mock_client.return_value.list_log_entries.return_value.pages = iter([page])
         logs, metadata = stackdriver_task_handler.read(self.ti)
         mock_client.return_value.list_log_entries.assert_called_once_with(
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index dd0e5a42d639e..03d6ca36b879c 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -555,9 +555,7 @@ def test__write_local_data_files_csv_does_not_write_on_empty_rows(self):
         files = op._write_local_data_files(cursor)
         # Raises StopIteration when next is called because generator returns no files
         with pytest.raises(StopIteration):
-            next(files)["file_handle"]
-
-        assert len([f for f in files]) == 0
+            next(files)
     def test__write_local_data_files_csv_writes_empty_file_with_write_on_empty(self):
         op = DummySQLToGCSOperator(
diff --git a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
index a7108d69380fa..ec6ed4b917f74 100644
--- a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
@@ -69,7 +69,7 @@ def mock_jobs(names: list[str], latest_operation_names: list[str | None]):
     for job, name in zip(jobs, names):
         job.name = name
     mock_obj = mock.MagicMock()
-    mock_obj.__aiter__.return_value = (job for job in jobs)
+    mock_obj.__aiter__.return_value = iter(jobs)
     return mock_obj
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index e84b3f69f48e0..4422cbe48c87f 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -135,7 +135,7 @@ def dummy_mapped_task(x: int):
             return x
         dummy_task()
-        dummy_mapped_task.expand(x=[i for i in map_indexes])
+        dummy_mapped_task.expand(x=list(map_indexes))
     SerializedDagModel.write_dag(dag)
@@ -1089,7 +1089,7 @@ def run_tasks(dag_bag, execution_date=DEFAULT_DATE, session=None):
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        tasks = sorted((ti for ti in dagrun.task_instances), key=lambda ti: ti.task_id)
+        tasks = sorted(dagrun.task_instances, key=lambda ti: ti.task_id)
         for ti in tasks:
             ti.refresh_from_task(dag.get_task(ti.task_id))
             tis[ti.task_id] = ti
@@ -1478,7 +1478,7 @@ def dummy_task(x: int):
         mode="reschedule",
     )
-    body = dummy_task.expand(x=[i for i in range(5)])
+    body = dummy_task.expand(x=range(5))
     tail = ExternalTaskMarker(
         task_id="tail",
         external_dag_id=dag.dag_id,
@@ -1524,7 +1524,7 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
         include_downstream=True,
         include_upstream=False,
     )
-    task_ids = [tid for tid in dag.task_dict]
+    task_ids = list(dag.task_dict)
     assert (
         dag.clear(
             start_date=DEFAULT_DATE,
diff --git a/tests/system/providers/amazon/aws/example_s3_to_sql.py b/tests/system/providers/amazon/aws/example_s3_to_sql.py
index ee110dce4fdc6..4910149961e20 100644
--- a/tests/system/providers/amazon/aws/example_s3_to_sql.py
+++ b/tests/system/providers/amazon/aws/example_s3_to_sql.py
@@ -177,7 +177,7 @@ def parse_csv_to_list(filepath):
         import csv
         with open(filepath, newline="") as file:
-            return [row for row in csv.reader(file)]
+            return list(csv.reader(file))
     transfer_s3_to_sql = S3ToSqlOperator(
         task_id="transfer_s3_to_sql",
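
A minimal, self-contained sketch of the comprehension-to-builtin equivalences that the commit subject refers to, appended after the patch for illustration. The values items, nested, and pattern below are made-up placeholders, not code taken from the Airflow tree.

import itertools
import re

items = ["b", "a", "c"]
nested = [[1, 2], [3]]
pattern = r"\s*[,;]\s*"

# A comprehension that only passes elements through is equivalent to the iterable itself.
assert [x for x in items] == list(items)
assert {x for x in items} == set(items)
# Functions that accept any iterable do not need a pass-through generator.
assert sorted(x for x in items) == sorted(items)
assert "".join(c for c in items) == "".join(items)
# re.split() already returns a list, so wrapping it in a comprehension is redundant.
assert [a for a in re.split(pattern, "x, y;z")] == re.split(pattern, "x, y;z")
# The same holds for itertools.chain.from_iterable() fed a pass-through generator.
assert list(itertools.chain.from_iterable(x for x in nested)) == list(
    itertools.chain.from_iterable(nested)
)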