Replace sequence concatenation by unpacking in Airflow providers (#33933)
hussein-awala committed Aug 31, 2023
1 parent f63a94d commit 55976af
Showing 13 changed files with 19 additions and 29 deletions.
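
The same pattern recurs in every file below: a new list or tuple that used to be built by converting and concatenating (list(xs) + [y]) is now built with iterable unpacking ([*xs, y]). As a minimal illustrative sketch (not part of the commit, using a hypothetical unload_options value), the two spellings produce the same result, but unpacking accepts any iterable directly and skips the intermediate list() call:

unload_options = ("PARALLEL OFF",)  # hypothetical value; may be a tuple, list, or other iterable

# old style: convert to a list, then concatenate
options = list(unload_options) + ["HEADER"]

# new style: unpack the iterable straight into the new list
options = [*unload_options, "HEADER"]

assert options == ["PARALLEL OFF", "HEADER"]
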
4 changes: 1 addition & 3 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -131,9 +131,7 @@ def __init__(
             )

         if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
-            self.unload_options = list(self.unload_options) + [
-                "HEADER",
-            ]
+            self.unload_options = [*self.unload_options, "HEADER"]

         if self.redshift_data_api_kwargs:
             for arg in ["sql", "parameters"]:
10 changes: 3 additions & 7 deletions airflow/providers/apache/beam/hooks/beam.py
@@ -188,9 +188,7 @@ def _start_pipeline(
         process_line_callback: Callable[[str], None] | None = None,
         working_directory: str | None = None,
     ) -> None:
-        cmd = command_prefix + [
-            f"--runner={self.runner}",
-        ]
+        cmd = [*command_prefix, f"--runner={self.runner}"]
         if variables:
             cmd.extend(beam_options_to_args(variables))
         run_beam_command(
@@ -261,7 +259,7 @@ def start_python_pipeline(
                 requirements=py_requirements,
             )

-        command_prefix = [py_interpreter] + py_options + [py_file]
+        command_prefix = [py_interpreter, *py_options, py_file]

         beam_version = (
             subprocess.check_output(
@@ -506,9 +504,7 @@ async def start_pipeline_async(
         command_prefix: list[str],
         working_directory: str | None = None,
     ) -> int:
-        cmd = command_prefix + [
-            f"--runner={self.runner}",
-        ]
+        cmd = [*command_prefix, f"--runner={self.runner}"]
         if variables:
             cmd.extend(beam_options_to_args(variables))
         return await self.run_beam_command_async(
2 changes: 1 addition & 1 deletion airflow/providers/apache/hive/hooks/hive.py
@@ -163,7 +163,7 @@ def _prepare_cli_cmd(self) -> list[Any]:

         hive_params_list = self.hive_cli_params.split()

-        return [hive_bin] + cmd_extra + hive_params_list
+        return [hive_bin, *cmd_extra, *hive_params_list]

     def _validate_beeline_parameters(self, conn):
         if ":" in conn.host or "/" in conn.host or ";" in conn.host:
2 changes: 1 addition & 1 deletion airflow/providers/apache/pig/hooks/pig.py
@@ -79,7 +79,7 @@ def run_cli(self, pig: str, pig_opts: str | None = None, verbose: bool = True) -
                 pig_opts_list = pig_opts.split()
                 pig_cmd.extend(pig_opts_list)

-            pig_cmd.extend(["-f", fname] + cmd_extra)
+            pig_cmd.extend(["-f", fname, *cmd_extra])

             if verbose:
                 self.log.info("%s", " ".join(pig_cmd))
7 changes: 4 additions & 3 deletions airflow/providers/cncf/kubernetes/pod_generator.py
@@ -345,9 +345,10 @@ def reconcile_containers(
         client_container = extend_object_field(base_container, client_container, "volume_devices")
         client_container = merge_objects(base_container, client_container)

-        return [client_container] + PodGenerator.reconcile_containers(
-            base_containers[1:], client_containers[1:]
-        )
+        return [
+            client_container,
+            *PodGenerator.reconcile_containers(base_containers[1:], client_containers[1:]),
+        ]

     @classmethod
     def construct_pod(
2 changes: 1 addition & 1 deletion airflow/providers/databricks/hooks/databricks_sql.py
@@ -83,7 +83,7 @@ def __init__(

     def _get_extra_config(self) -> dict[str, Any | None]:
         extra_params = copy(self.databricks_conn.extra_dejson)
-        for arg in ["http_path", "session_configuration"] + self.extra_parameters:
+        for arg in ["http_path", "session_configuration", *self.extra_parameters]:
             if arg in extra_params:
                 del extra_params[arg]

2 changes: 1 addition & 1 deletion airflow/providers/docker/operators/docker.py
@@ -335,7 +335,7 @@ def _run_image(self) -> list[str] | str | None:
             with TemporaryDirectory(prefix="airflowtmp", dir=self.host_tmp_dir) as host_tmp_dir_generated:
                 tmp_mount = Mount(self.tmp_dir, host_tmp_dir_generated, "bind")
                 try:
-                    return self._run_image_with_mounts(self.mounts + [tmp_mount], add_tmp_variable=True)
+                    return self._run_image_with_mounts([*self.mounts, tmp_mount], add_tmp_variable=True)
                 except APIError as e:
                     if host_tmp_dir_generated in str(e):
                         self.log.warning(
2 changes: 1 addition & 1 deletion airflow/providers/elasticsearch/log/es_task_handler.py
@@ -364,7 +364,7 @@ def set_context(self, ti: TaskInstance) -> None:
         if self.json_format:
             self.formatter = ElasticsearchJSONFormatter(
                 fmt=self.formatter._fmt,
-                json_fields=self.json_fields + [self.offset_field],
+                json_fields=[*self.json_fields, self.offset_field],
                 extras={
                     "dag_id": str(ti.dag_id),
                     "task_id": str(ti.task_id),
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/datafusion.py
@@ -783,7 +783,7 @@ def __init__(
         if success_states:
             self.success_states = success_states
         else:
-            self.success_states = SUCCESS_STATES + [PipelineStates.RUNNING]
+            self.success_states = [*SUCCESS_STATES, PipelineStates.RUNNING]

     def execute(self, context: Context) -> str:
         hook = DataFusionHook(
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -45,7 +45,8 @@ class BigQueryToMsSqlOperator(BigQueryToSqlBaseOperator):
     :param mssql_conn_id: reference to a specific mssql hook
     """

-    template_fields: Sequence[str] = tuple(BigQueryToSqlBaseOperator.template_fields) + (
+    template_fields: Sequence[str] = (
+        *BigQueryToSqlBaseOperator.template_fields,
         "source_project_dataset_table",
     )
     operator_extra_links = (BigQueryTableLink(),)
5 changes: 1 addition & 4 deletions airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
@@ -40,10 +40,7 @@ class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
     :param mysql_conn_id: Reference to :ref:`mysql connection id <howto/connection:mysql>`.
     """

-    template_fields: Sequence[str] = tuple(BigQueryToSqlBaseOperator.template_fields) + (
-        "dataset_id",
-        "table_id",
-    )
+    template_fields: Sequence[str] = (*BigQueryToSqlBaseOperator.template_fields, "dataset_id", "table_id")

     def __init__(
         self,
5 changes: 1 addition & 4 deletions airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
@@ -36,10 +36,7 @@ class BigQueryToPostgresOperator(BigQueryToSqlBaseOperator):
     :param postgres_conn_id: Reference to :ref:`postgres connection id <howto/connection:postgres>`.
     """

-    template_fields: Sequence[str] = tuple(BigQueryToSqlBaseOperator.template_fields) + (
-        "dataset_id",
-        "table_id",
-    )
+    template_fields: Sequence[str] = (*BigQueryToSqlBaseOperator.template_fields, "dataset_id", "table_id")

     def __init__(
         self,
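
The three BigQuery transfer operators above apply the same idea to tuples: the subclass extends the parent's template_fields by unpacking it into a new tuple rather than calling tuple(...) and concatenating. A minimal standalone sketch with hypothetical Base/Child classes (not the real operators):

from typing import Sequence


class Base:
    template_fields: Sequence[str] = ("sql", "destination")


class Child(Base):
    # old style: tuple(Base.template_fields) + ("table_id",)
    # new style: unpack the parent's fields directly into the new tuple
    template_fields: Sequence[str] = (*Base.template_fields, "table_id")


assert Child.template_fields == ("sql", "destination", "table_id")
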
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/utils/mlengine_prediction_summary.py
@@ -139,7 +139,7 @@ def MakeSummary(pcoll, metric_fn, metric_keys):
     return (
         pcoll
         | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
-        | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
+        | "PairWith1" >> beam.Map(lambda tup: (*tup, 1))
         | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(*([sum] * (len(metric_keys) + 1))))
         | "AverageAndMakeDict"
         >> beam.Map(
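
The final hunk swaps tuple concatenation for unpacking inside a Beam Map lambda. For appending a single element to a tuple the two expressions are equivalent, as this small sketch (independent of the Beam pipeline, with a made-up tuple value) shows:

tup = ("label", 0.75)  # hypothetical per-instance metrics tuple

old_style = tup + (1,)  # concatenate with a one-element tuple
new_style = (*tup, 1)   # unpack into a new tuple

assert old_style == new_style == ("label", 0.75, 1)
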
