Refactor: Simplify code in Apache/Alibaba providers (#33227)
eumiro committed Aug 22, 2023
1 parent 770228e commit 32feab4
Showing 6 changed files with 44 additions and 58 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
@@ -321,7 +321,7 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True
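
Note: the rewritten predicate is equivalent to the old one; it accepts ints (including 0) and non-empty strings, and rejects empty strings and all other types. An illustrative standalone check (sample values invented for this note, not part of the commit):

    conf = {"spark.executor.cores": 2, "spark.eventLog.enabled": "true"}
    assert all(isinstance(v, (str, int)) and v != "" for v in conf.values())
    assert not all(isinstance(v, (str, int)) and v != "" for v in {"k": ""}.values())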

6 changes: 2 additions & 4 deletions airflow/providers/apache/beam/hooks/beam.py
@@ -104,10 +104,8 @@ def process_fd(
     fd_to_log = {proc.stderr: log.warning, proc.stdout: log.info}
     func_log = fd_to_log[fd]
 
-    while True:
-        line = fd.readline().decode()
-        if not line:
-            return
+    for line in iter(fd.readline, b""):
+        line = line.decode()
         if process_line_callback:
             process_line_callback(line)
         func_log(line.rstrip("\n"))
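
The two-argument form iter(callable, sentinel) calls fd.readline() until it returns the sentinel b"" at EOF, replacing the manual while/readline/break dance. A minimal illustration with an in-memory stream standing in for the process pipe:

    import io

    fd = io.BytesIO(b"first\nsecond\n")
    for line in iter(fd.readline, b""):
        print(line.decode().rstrip("\n"))  # prints "first", then "second"
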
47 changes: 22 additions & 25 deletions airflow/providers/apache/beam/triggers/beam.py
@@ -85,32 +85,29 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook()
-        while True:
-            try:
-                return_code = await hook.start_python_pipeline_async(
-                    variables=self.variables,
-                    py_file=self.py_file,
-                    py_options=self.py_options,
-                    py_interpreter=self.py_interpreter,
-                    py_requirements=self.py_requirements,
-                    py_system_site_packages=self.py_system_site_packages,
-                )
-                if return_code == 0:
-                    yield TriggerEvent(
-                        {
-                            "status": "success",
-                            "message": "Pipeline has finished SUCCESSFULLY",
-                        }
-                    )
-                    return
-                else:
-                    yield TriggerEvent({"status": "error", "message": "Operation failed"})
-                    return
-
-            except Exception as e:
-                self.log.exception("Exception occurred while checking for pipeline state")
-                yield TriggerEvent({"status": "error", "message": str(e)})
-                return
+        try:
+            return_code = await hook.start_python_pipeline_async(
+                variables=self.variables,
+                py_file=self.py_file,
+                py_options=self.py_options,
+                py_interpreter=self.py_interpreter,
+                py_requirements=self.py_requirements,
+                py_system_site_packages=self.py_system_site_packages,
+            )
+        except Exception as e:
+            self.log.exception("Exception occurred while checking for pipeline state")
+            yield TriggerEvent({"status": "error", "message": str(e)})
+        else:
+            if return_code == 0:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": "Pipeline has finished SUCCESSFULLY",
+                    }
+                )
+            else:
+                yield TriggerEvent({"status": "error", "message": "Operation failed"})
+        return
 
     def _get_async_hook(self) -> BeamAsyncHook:
         return BeamAsyncHook(runner=self.runner)
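
Every branch of the old loop body ended in return, so the while True wrapper never looped; the rewrite also moves result handling into an else clause, which runs only when the awaited call raised nothing. A hedged sketch of that try/except/else flow (function and names invented for this note):

    def classify(fetch_return_code):
        try:
            return_code = fetch_return_code()
        except Exception as e:
            return {"status": "error", "message": str(e)}
        else:
            if return_code == 0:
                return {"status": "success"}
            return {"status": "error", "message": "Operation failed"}

    assert classify(lambda: 0) == {"status": "success"}
    assert classify(lambda: 1 / 0)["status"] == "error"
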
33 changes: 13 additions & 20 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -277,13 +277,11 @@ def run_cli(
         )
         self.sub_process = sub_process
         stdout = ""
-        while True:
-            line = sub_process.stdout.readline()
-            if not line:
-                break
-            stdout += line.decode("UTF-8")
+        for line in iter(sub_process.stdout.readline, b""):
+            line = line.decode()
+            stdout += line
             if verbose:
-                self.log.info(line.decode("UTF-8").strip())
+                self.log.info(line.strip())
         sub_process.wait()
 
         if sub_process.returncode:
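
bytes.decode() defaults to "utf-8", so dropping the explicit "UTF-8" argument above is behavior-preserving:

    assert b"caf\xc3\xa9".decode() == b"caf\xc3\xa9".decode("UTF-8") == "café"
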
@@ -704,25 +702,20 @@ def _get_max_partition_from_part_specs(
         # Assuming all specs have the same keys.
         if partition_key not in part_specs[0].keys():
             raise AirflowException(f"Provided partition_key {partition_key} is not in part_specs.")
-        is_subset = None
-        if filter_map:
-            is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys()))
-        if filter_map and not is_subset:
+        if filter_map and not set(filter_map).issubset(part_specs[0]):
             raise AirflowException(
                 f"Keys in provided filter_map {', '.join(filter_map.keys())} "
                 f"are not subset of part_spec keys: {', '.join(part_specs[0].keys())}"
             )
 
-        candidates = [
-            p_dict[partition_key]
-            for p_dict in part_specs
-            if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
-        ]
-
-        if not candidates:
-            return None
-        else:
-            return max(candidates)
+        return max(
+            (
+                p_dict[partition_key]
+                for p_dict in part_specs
+                if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
+            ),
+            default=None,
+        )
 
     def max_partition(
         self,
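
max() with a default= argument returns that default instead of raising ValueError on an empty sequence, which is what lets the intermediate candidates list and the explicit None check disappear. Illustrative (values invented for this note):

    assert max((x for x in [] if x), default=None) is None
    assert max((x for x in [3, 1, 2]), default=None) == 3
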
12 changes: 5 additions & 7 deletions airflow/providers/apache/livy/hooks/livy.py
@@ -432,7 +432,7 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool:
         if (
             vals is None
             or not isinstance(vals, (tuple, list))
-            or any(1 for val in vals if not isinstance(val, (str, int, float)))
+            or not all(isinstance(val, (str, int, float)) for val in vals)
         ):
             raise ValueError("List of strings expected")
         return True
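
The positive formulation not all(pred(val) ...) says the same thing as the old any(1 for val ... if not pred(val)) and short-circuits on the first offending value. Illustrative:

    ok = ["--conf", 1, 2.5]
    assert all(isinstance(v, (str, int, float)) for v in ok)
    assert not all(isinstance(v, (str, int, float)) for v in ["--conf", None])
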
@@ -448,7 +448,7 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True

@@ -542,8 +542,7 @@ async def _do_api_call_async(
         else:
             return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"}
 
-        attempt_num = 1
-        while True:
+        for attempt_num in range(1, 1 + self.retry_limit):
             response = await request_func(
                 url,
                 json=data if self.method in ("POST", "PATCH") else None,
@@ -568,7 +567,6 @@
                 # Don't retry.
                 return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"}
 
-            attempt_num += 1
             await asyncio.sleep(self.retry_delay)
 
     def _generate_base_url(self, conn: Connection) -> str:
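
The for loop over range(1, 1 + self.retry_limit) bounds the retries structurally and makes the hand-incremented attempt_num counter redundant. A hedged sketch of the shape (simplified signature, not the hook's actual API):

    import asyncio

    async def do_call_with_retries(request_once, retry_limit=3, retry_delay=1.0):
        for attempt_num in range(1, 1 + retry_limit):
            response = await request_once()
            if response is not None:  # success: stop retrying
                return response
            await asyncio.sleep(retry_delay)
        return {"status": "error", "message": "retry limit reached"}
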
@@ -815,7 +813,7 @@ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool:
         if (
             vals is None
             or not isinstance(vals, (tuple, list))
-            or any(1 for val in vals if not isinstance(val, (str, int, float)))
+            or not all(isinstance(val, (str, int, float)) for val in vals)
         ):
             raise ValueError("List of strings expected")
         return True
@@ -831,6 +829,6 @@ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
         if conf:
             if not isinstance(conf, dict):
                 raise ValueError("'conf' argument must be a dict")
-            if not all((v and isinstance(v, str)) or isinstance(v, int) for v in conf.values()):
+            if not all(isinstance(v, (str, int)) and v != "" for v in conf.values()):
                 raise ValueError("'conf' values must be either strings or ints")
         return True
2 changes: 1 addition & 1 deletion airflow/providers/apache/spark/hooks/spark_sql.py
@@ -134,7 +134,7 @@ def _prepare_command(self, cmd: str | list[str]) -> list[str]:
             connection_cmd += ["--num-executors", str(self._num_executors)]
         if self._sql:
             sql = self._sql.strip()
-            if sql.endswith(".sql") or sql.endswith(".hql"):
+            if sql.endswith((".sql", ".hql")):
                 connection_cmd += ["-f", sql]
             else:
                 connection_cmd += ["-e", sql]
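
str.endswith() accepts a tuple of suffixes, which collapses the two chained checks into one call:

    assert "report.sql".endswith((".sql", ".hql"))
    assert "report.hql".endswith((".sql", ".hql"))
    assert not "report.txt".endswith((".sql", ".hql"))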
