Refactor: Simplify code in providers/amazon (#33222)
eumiro committed Aug 9, 2023
1 parent 1dcf05f commit 83bd60f
Showing 17 changed files with 112 additions and 144 deletions.
7 changes: 3 additions & 4 deletions airflow/providers/amazon/aws/hooks/base_aws.py
@@ -494,10 +494,9 @@ def _find_class_name(target_function_name: str) -> str:
responsible with catching and handling those.
"""
stack = inspect.stack()
# Find the index of the most recent frame which called the provided function name.
target_frame_index = [frame.function for frame in stack].index(target_function_name)
# Pull that frame off the stack.
target_frame = stack[target_frame_index][0]
# Find the index of the most recent frame which called the provided function name
# and pull that frame off the stack.
target_frame = next(frame for frame in stack if frame.function == target_function_name)[0]
# Get the local variables for that frame.
frame_variables = target_frame.f_locals["self"]
# Get the class object for that frame.
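The base_aws.py hunk swaps the build-a-list-then-`.index()` lookup for a single `next()` over a generator, so the stack is only scanned until the first matching frame (a missing frame now raises StopIteration rather than ValueError). A minimal standalone sketch of the same pattern; `find_caller_frame` and `target` are illustrative names, not part of the provider code:

```python
import inspect


def find_caller_frame(target_function_name: str):
    """Return the frame object of the most recent caller with the given function name."""
    stack = inspect.stack()
    # Old approach: build a full list of names, then index back into the stack:
    #   stack[[f.function for f in stack].index(target_function_name)][0]
    # New approach: stop at the first match instead of materializing the whole list.
    return next(frame for frame in stack if frame.function == target_function_name)[0]


def target():
    frame = find_caller_frame("target")
    print(frame.f_code.co_name)  # -> "target"


target()
```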
91 changes: 42 additions & 49 deletions airflow/providers/amazon/aws/hooks/batch_client.py
@@ -26,6 +26,7 @@
"""
from __future__ import annotations

import itertools as it
from random import uniform
from time import sleep
from typing import Callable
@@ -343,8 +344,17 @@ def poll_job_status(self, job_id: str, match_status: list[str]) -> bool:
:raises: AirflowException
"""
retries = 0
while True:
for retries in range(1 + self.max_retries):
if retries:
pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.max_retries,
pause,
)
self.delay(pause)

job = self.get_job_description(job_id)
job_status = job.get("status")
@@ -354,23 +364,10 @@ def poll_job_status(self, job_id: str, match_status: list[str]) -> bool:
job_status,
match_status,
)

if job_status in match_status:
return True

if retries >= self.max_retries:
raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")

retries += 1
pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.max_retries,
pause,
)
self.delay(pause)
else:
raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries")

def get_job_description(self, job_id: str) -> dict:
"""
@@ -382,12 +379,21 @@ def get_job_description(self, job_id: str) -> dict:
:raises: AirflowException
"""
retries = 0
while True:
for retries in range(self.status_retries):
if retries:
pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.status_retries,
pause,
)
self.delay(pause)

try:
response = self.get_conn().describe_jobs(jobs=[job_id])
return self.parse_job_description(job_id, response)

except botocore.exceptions.ClientError as err:
# Allow it to retry in case of exceeded quota limit of requests to AWS API
if err.response.get("Error", {}).get("Code") != "TooManyRequestsException":
@@ -398,23 +404,11 @@
"check Amazon Provider AWS Connection documentation for more details.",
str(err),
)

retries += 1
if retries >= self.status_retries:
raise AirflowException(
f"AWS Batch job ({job_id}) description error: exceeded status_retries "
f"({self.status_retries})"
)

pause = self.exponential_delay(retries)
self.log.info(
"AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds",
job_id,
retries,
self.status_retries,
pause,
else:
raise AirflowException(
f"AWS Batch job ({job_id}) description error: exceeded status_retries "
f"({self.status_retries})"
)
self.delay(pause)

@staticmethod
def parse_job_description(job_id: str, response: dict) -> dict:
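The two retry loops rewritten above (`poll_job_status` and `get_job_description`) trade a `while True` plus hand-incremented counter for a `for ... in range(...)` whose `else` clause fires only when every pass completes without returning. A hedged, self-contained sketch of that shape; `poll`, `check_once` and the backoff constants are invented for the example:

```python
import itertools
import time

_calls = itertools.count(1)


def check_once() -> bool:
    """Stand-in for one AWS status check; succeeds on the third call."""
    return next(_calls) >= 3


def poll(max_retries: int = 5) -> bool:
    # Pass 0 runs immediately; later passes pause first, mirroring the hunk above.
    for retries in range(1 + max_retries):
        if retries:
            pause = 0.1 * 2**retries  # exponential backoff (shortened for the sketch)
            print(f"retry {retries} of {max_retries} in {pause:.2f}s")
            time.sleep(pause)
        if check_once():
            return True
    else:
        # Reached only when the loop exhausts all passes without returning.
        raise RuntimeError("status checks exceeded max_retries")


print(poll())  # -> True after two retries
```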
@@ -476,7 +470,7 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
)

# If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]):
if any(c.get("logDriver", "awslogs") != "awslogs" for c in log_configs):
self.log.warning(
f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
)
@@ -494,18 +488,17 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:

# cross stream names with options (i.e. attempts X nodes) to generate all log infos
result = []
for stream in stream_names:
for option in log_options:
result.append(
{
"awslogs_stream_name": stream,
# If the user did not specify anything, the default settings are:
# awslogs-group = /aws/batch/job
# awslogs-region = `same as AWS Batch Job region`
"awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
"awslogs_region": option.get("awslogs-region", self.conn_region_name),
}
)
for stream, option in it.product(stream_names, log_options):
result.append(
{
"awslogs_stream_name": stream,
# If the user did not specify anything, the default settings are:
# awslogs-group = /aws/batch/job
# awslogs-region = `same as AWS Batch Job region`
"awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
"awslogs_region": option.get("awslogs-region", self.conn_region_name),
}
)
return result

@staticmethod
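The hunk above also lets `any(...)` consume a generator directly instead of a pre-built list, and flattens the nested stream/option loops into one pass over `itertools.product`, which yields every (stream, option) pair in the same order as the nested version. A small sketch with made-up stream names, options, and a stand-in default region:

```python
import itertools as it

stream_names = ["attempt-1", "attempt-2"]
log_options = [
    {"awslogs-group": "/aws/batch/job"},
    {"awslogs-region": "eu-west-1"},
]

# One flat loop instead of two nested ones; iteration order is unchanged.
log_infos = [
    {
        "awslogs_stream_name": stream,
        "awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
        "awslogs_region": option.get("awslogs-region", "us-east-1"),  # stand-in for conn_region_name
    }
    for stream, option in it.product(stream_names, log_options)
]

for info in log_infos:
    print(info)
```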
28 changes: 8 additions & 20 deletions airflow/providers/amazon/aws/hooks/datasync.py
@@ -125,17 +125,11 @@ def get_location_arns(

def _refresh_locations(self) -> None:
"""Refresh the local list of Locations."""
self.locations = []
next_token = None
while True:
if next_token:
locations = self.get_conn().list_locations(NextToken=next_token)
else:
locations = self.get_conn().list_locations()
locations = self.get_conn().list_locations()
self.locations = locations["Locations"]
while "NextToken" in locations:
locations = self.get_conn().list_locations(NextToken=locations["NextToken"])
self.locations.extend(locations["Locations"])
if "NextToken" not in locations:
break
next_token = locations["NextToken"]

def create_task(
self, source_location_arn: str, destination_location_arn: str, **create_task_kwargs
@@ -181,17 +175,11 @@ def delete_task(self, task_arn: str) -> None:

def _refresh_tasks(self) -> None:
"""Refreshes the local list of Tasks."""
self.tasks = []
next_token = None
while True:
if next_token:
tasks = self.get_conn().list_tasks(NextToken=next_token)
else:
tasks = self.get_conn().list_tasks()
tasks = self.get_conn().list_tasks()
self.tasks = tasks["Tasks"]
while "NextToken" in tasks:
tasks = self.get_conn().list_tasks(NextToken=tasks["NextToken"])
self.tasks.extend(tasks["Tasks"])
if "NextToken" not in tasks:
break
next_token = tasks["NextToken"]

def get_task_arns_for_location_arns(
self,
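Both `_refresh_locations` and `_refresh_tasks` now issue the first call unconditionally and keep paging only while the response carries a `NextToken`, rather than threading a `next_token` variable through a `while True`. A sketch of that pagination shape against a stubbed client; `FakeDataSync` and its pages are invented, where the real hook calls `self.get_conn().list_locations(...)`:

```python
class FakeDataSync:
    """Tiny stand-in for the boto3 DataSync client, returning two pages."""

    _pages = [
        {"Locations": ["loc-1", "loc-2"], "NextToken": "page-2"},
        {"Locations": ["loc-3"]},  # last page: no NextToken
    ]

    def list_locations(self, NextToken=None):
        return self._pages[1] if NextToken else self._pages[0]


def refresh_locations(client) -> list:
    # First page needs no token; later pages reuse the token from the previous response.
    response = client.list_locations()
    locations = list(response["Locations"])
    while "NextToken" in response:
        response = client.list_locations(NextToken=response["NextToken"])
        locations.extend(response["Locations"])
    return locations


print(refresh_locations(FakeDataSync()))  # -> ['loc-1', 'loc-2', 'loc-3']
```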
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -195,9 +195,9 @@ def get_table_primary_key(
# we only select a single column (that is a string),
# so safe to assume that there is only a single col in the record
pk_columns += [y["stringValue"] for x in response["Records"] for y in x]
if "NextToken" not in response.keys():
break
else:
if "NextToken" in response:
token = response["NextToken"]
else:
break

return pk_columns or None
7 changes: 2 additions & 5 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -1284,18 +1284,15 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False, max_retrie
bucket and trying to delete the bucket.
:return: None
"""
tries_remaining = max_retries + 1
if force_delete:
while tries_remaining:
for retry in range(max_retries):
bucket_keys = self.list_keys(bucket_name=bucket_name)
if not bucket_keys:
break
if tries_remaining <= max_retries:
# Avoid first loop
if retry: # Avoid first loop
sleep(500)

self.delete_objects(bucket=bucket_name, keys=bucket_keys)
tries_remaining -= 1

self.conn.delete_bucket(Bucket=bucket_name)

31 changes: 15 additions & 16 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -1143,34 +1143,33 @@ def stop_pipeline(
if check_interval is None:
check_interval = 10

retries = 2 # i.e. 3 calls max, 1 initial + 2 retries
while True:
for retries in (2, 1, 0):
try:
self.conn.stop_pipeline_execution(PipelineExecutionArn=pipeline_exec_arn)
break
except ClientError as ce:
# this can happen if the pipeline was transitioning between steps at that moment
if ce.response["Error"]["Code"] == "ConflictException" and retries > 0:
retries = retries - 1
if ce.response["Error"]["Code"] == "ConflictException" and retries:
self.log.warning(
"Got a conflict exception when trying to stop the pipeline, "
"retrying %s more times. Error was: %s",
retries,
ce,
)
time.sleep(0.3) # error is due to a race condition, so it should be very transient
continue
# we have to rely on the message to catch the right error here, because its type
# (ValidationException) is shared with other kinds of errors (e.g. badly formatted ARN)
if (
not fail_if_not_running
and "Only pipelines with 'Executing' status can be stopped"
in ce.response["Error"]["Message"]
):
self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce)
else:
self.log.error(ce)
raise
# we have to rely on the message to catch the right error here, because its type
# (ValidationException) is shared with other kinds of errors (e.g. badly formatted ARN)
if (
not fail_if_not_running
and "Only pipelines with 'Executing' status can be stopped"
in ce.response["Error"]["Message"]
):
self.log.warning("Cannot stop pipeline execution, as it was not running: %s", ce)
break
else:
self.log.error(ce)
raise
else:
break

res = self.describe_pipeline_exec(pipeline_exec_arn)
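In `stop_pipeline`, the hand-maintained retry counter becomes `for retries in (2, 1, 0)`: the `try` block's `else: break` exits on success, a transient conflict is retried while `retries` is still truthy, and anything else re-raises. A simplified sketch of that control flow; `TransientError`, `flaky_stop` and the call counter are stand-ins, not the SageMaker API:

```python
import time

_calls = {"n": 0}


class TransientError(Exception):
    """Stand-in for a ConflictException-style ClientError."""


def flaky_stop() -> None:
    """Fails twice with a transient error, then succeeds."""
    _calls["n"] += 1
    if _calls["n"] < 3:
        raise TransientError("pipeline is transitioning between steps")


for retries in (2, 1, 0):  # i.e. 3 calls max: 1 initial + 2 retries
    try:
        flaky_stop()
    except TransientError as err:
        if retries:  # retries left -> back off briefly and loop again
            print(f"transient error, retrying {retries} more time(s): {err}")
            time.sleep(0.1)
        else:  # retries exhausted -> give up loudly
            raise
    else:  # the call succeeded -> stop retrying
        print("stopped cleanly")
        break
```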
7 changes: 4 additions & 3 deletions airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -122,9 +122,10 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], l
bucket, prefix = self.hook.parse_s3_url(s3url=os.path.join(self.remote_base, worker_log_rel_path))
keys = self.hook.list_keys(bucket_name=bucket, prefix=prefix)
if keys:
keys = [f"s3://{bucket}/{key}" for key in keys]
messages.extend(["Found logs in s3:", *[f" * {x}" for x in sorted(keys)]])
for key in sorted(keys):
keys = sorted(f"s3://{bucket}/{key}" for key in keys)
messages.append("Found logs in s3:")
messages.extend(f" * {key}" for key in keys)
for key in keys:
logs.append(self.s3_read(key, return_error=True))
else:
messages.append(f"No logs found on s3 for ti={ti}")
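The log-handler change builds the sorted list of `s3://` URIs once and reuses it for both the messages and the reads, instead of sorting twice. A tiny illustration with a made-up bucket and keys; the `<contents of ...>` strings stand in for `s3_read()`:

```python
bucket = "example-logs-bucket"
keys = ["dag/task/2.log", "dag/task/1.log"]

messages: list[str] = []
logs: list[str] = []

# Build the fully qualified, sorted key list once and reuse it everywhere below.
uris = sorted(f"s3://{bucket}/{key}" for key in keys)
messages.append("Found logs in s3:")
messages.extend(f"  * {uri}" for uri in uris)
logs.extend(f"<contents of {uri}>" for uri in uris)

print("\n".join(messages))
```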
18 changes: 9 additions & 9 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -492,14 +492,14 @@ def __init__(
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
self.log.info("Starting resume cluster")
while self._remaining_attempts >= 1:
while self._remaining_attempts:
try:
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._remaining_attempts = self._remaining_attempts - 1
self._remaining_attempts -= 1

if self._remaining_attempts > 0:
if self._remaining_attempts:
self.log.error(
"Unable to resume cluster. %d attempts remaining.", self._remaining_attempts
)
@@ -580,14 +580,14 @@ def __init__(

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
while self._remaining_attempts >= 1:
while self._remaining_attempts:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
break
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._remaining_attempts = self._remaining_attempts - 1
self._remaining_attempts -= 1

if self._remaining_attempts > 0:
if self._remaining_attempts:
self.log.error(
"Unable to pause cluster. %d attempts remaining.", self._remaining_attempts
)
@@ -669,7 +669,7 @@ def __init__(
self.max_attempts = max_attempts

def execute(self, context: Context):
while self._attempts >= 1:
while self._attempts:
try:
self.redshift_hook.delete_cluster(
cluster_identifier=self.cluster_identifier,
@@ -678,9 +678,9 @@ def execute(self, context: Context):
)
break
except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault:
self._attempts = self._attempts - 1
self._attempts -= 1

if self._attempts > 0:
if self._attempts:
self.log.error("Unable to delete cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
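The three Redshift cluster operators keep their countdown logic but rely on integer truthiness (`while self._remaining_attempts:`) and augmented assignment instead of explicit comparisons. A compact sketch of the same countdown-and-retry shape; the `resume` stub and `ClusterBusy` exception are invented:

```python
import time

_state = {"busy_calls": 2}


class ClusterBusy(Exception):
    """Stand-in for InvalidClusterStateFault."""


def resume() -> None:
    """Fails while the fake cluster is 'busy', then succeeds."""
    if _state["busy_calls"]:
        _state["busy_calls"] -= 1
        raise ClusterBusy("cluster is not in a resumable state yet")


remaining_attempts = 5
while remaining_attempts:  # truthy while any attempts are left
    try:
        resume()
        break
    except ClusterBusy:
        remaining_attempts -= 1
        if remaining_attempts:
            print(f"unable to resume; {remaining_attempts} attempts remaining")
            time.sleep(0.1)
        else:
            raise
```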
5 changes: 1 addition & 4 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -217,10 +217,7 @@ def _standardize_secret_keys(self, secret: dict[str, Any]) -> dict[str, Any]:

conn_d: dict[str, Any] = {}
for conn_field, possible_words in possible_words_for_conn_fields.items():
try:
conn_d[conn_field] = [v for k, v in secret.items() if k in possible_words][0]
except IndexError:
conn_d[conn_field] = None
conn_d[conn_field] = next((v for k, v in secret.items() if k in possible_words), None)

return conn_d

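The secrets-manager change replaces `[...][0]` indexing guarded by `except IndexError` with `next(generator, default)`, which yields `None` directly when nothing matches. A self-contained sketch with an invented field map and secret payload:

```python
from typing import Any

possible_words_for_conn_fields = {
    "login": ["login", "user", "username"],
    "password": ["password", "pass", "key"],
    "host": ["host", "remote_host", "server"],
}

secret = {"username": "airflow", "key": "s3cr3t"}  # no host-like key on purpose

conn_d: dict[str, Any] = {}
for conn_field, possible_words in possible_words_for_conn_fields.items():
    # next() with a default replaces try/except IndexError around `[...][0]`.
    conn_d[conn_field] = next((v for k, v in secret.items() if k in possible_words), None)

print(conn_d)  # -> {'login': 'airflow', 'password': 's3cr3t', 'host': None}
```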
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -137,7 +137,7 @@ def __init__(

if self.redshift_data_api_kwargs:
for arg in ["sql", "parameters"]:
if arg in self.redshift_data_api_kwargs.keys():
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")

def _build_unload_query(
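Here and in the two transfer operators that follow, membership tests and iteration drop the redundant `.keys()` call: `arg in mapping` already checks the keys, and iterating a dict (or a pandas `GroupBy.groups`) iterates them too. A one-line illustration with a made-up kwargs dict:

```python
redshift_data_api_kwargs = {"database": "dev", "sql": "SELECT 1"}

for arg in ("sql", "parameters"):
    # `arg in mapping` already tests the keys; `.keys()` adds nothing.
    if arg in redshift_data_api_kwargs:
        print(f"Cannot include param {arg!r} in Redshift Data API kwargs")
```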
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -116,7 +116,7 @@ def __init__(

if self.redshift_data_api_kwargs:
for arg in ["sql", "parameters"]:
if arg in self.redshift_data_api_kwargs.keys():
if arg in self.redshift_data_api_kwargs:
raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")

def _build_copy_query(
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -194,7 +194,7 @@ def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]
yield "", df
else:
grouped_df = df.groupby(**self.groupby_kwargs)
for group_label in grouped_df.groups.keys():
for group_label in grouped_df.groups:
yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)

def _get_hook(self) -> DbApiHook:
