Refactor: Think positively in providers (#34279)
eumiro committed Sep 12, 2023
1 parent e4d44fc commit 05036e6
Showing 5 changed files with 48 additions and 49 deletions.
81 changes: 40 additions & 41 deletions airflow/providers/amazon/aws/operators/emr.py
@@ -749,52 +749,51 @@ def execute(self, context: Context) -> str | None:
         job_flow_overrides = self.job_flow_overrides
         response = self._emr_hook.create_job_flow(job_flow_overrides)
 
-        if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
+        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Job flow creation failed: {response}")
-        else:
-            self._job_flow_id = response["JobFlowId"]
-            self.log.info("Job flow with id %s created", self._job_flow_id)
-            EmrClusterLink.persist(
-                context=context,
-                operator=self,
-                region_name=self._emr_hook.conn_region_name,
-                aws_partition=self._emr_hook.conn_partition,
-                job_flow_id=self._job_flow_id,
-            )
-            if self._job_flow_id:
-                EmrLogsLink.persist(
-                    context=context,
-                    operator=self,
-                    region_name=self._emr_hook.conn_region_name,
-                    aws_partition=self._emr_hook.conn_partition,
-                    job_flow_id=self._job_flow_id,
-                    log_uri=get_log_uri(emr_client=self._emr_hook.conn, job_flow_id=self._job_flow_id),
-                )
-            if self.deferrable:
-                self.defer(
-                    trigger=EmrCreateJobFlowTrigger(
-                        job_flow_id=self._job_flow_id,
-                        aws_conn_id=self.aws_conn_id,
-                        poll_interval=self.waiter_delay,
-                        max_attempts=self.waiter_max_attempts,
-                    ),
-                    method_name="execute_complete",
-                    # timeout is set to ensure that if a trigger dies, the timeout does not restart
-                    # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
-                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
-                )
-            if self.wait_for_completion:
-                self._emr_hook.get_waiter("job_flow_waiting").wait(
-                    ClusterId=self._job_flow_id,
-                    WaiterConfig=prune_dict(
-                        {
-                            "Delay": self.waiter_delay,
-                            "MaxAttempts": self.waiter_max_attempts,
-                        }
-                    ),
-                )
-
-            return self._job_flow_id
+
+        self._job_flow_id = response["JobFlowId"]
+        self.log.info("Job flow with id %s created", self._job_flow_id)
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=self._emr_hook.conn_region_name,
+            aws_partition=self._emr_hook.conn_partition,
+            job_flow_id=self._job_flow_id,
+        )
+        if self._job_flow_id:
+            EmrLogsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self._emr_hook.conn_region_name,
+                aws_partition=self._emr_hook.conn_partition,
+                job_flow_id=self._job_flow_id,
+                log_uri=get_log_uri(emr_client=self._emr_hook.conn, job_flow_id=self._job_flow_id),
+            )
+        if self.deferrable:
+            self.defer(
+                trigger=EmrCreateJobFlowTrigger(
+                    job_flow_id=self._job_flow_id,
+                    aws_conn_id=self.aws_conn_id,
+                    poll_interval=self.waiter_delay,
+                    max_attempts=self.waiter_max_attempts,
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
+        if self.wait_for_completion:
+            self._emr_hook.get_waiter("job_flow_waiting").wait(
+                ClusterId=self._job_flow_id,
+                WaiterConfig=prune_dict(
+                    {
+                        "Delay": self.waiter_delay,
+                        "MaxAttempts": self.waiter_max_attempts,
+                    }
+                ),
+            )
+        return self._job_flow_id
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
@@ -940,10 +939,10 @@ def execute(self, context: Context) -> None:
         self.log.info("Terminating JobFlow %s", self.job_flow_id)
         response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
 
-        if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
+        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"JobFlow termination failed: {response}")
-        else:
-            self.log.info("Terminating JobFlow with id %s", self.job_flow_id)
+
+        self.log.info("Terminating JobFlow with id %s", self.job_flow_id)
 
         if self.deferrable:
             self.defer(
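
Both emr.py hunks apply the same "think positively" refactor: the negated equality (if not x == 200) forced the success path into an else: block, while the positive form (if x != 200) becomes a guard clause that raises early, letting the rest of the method sit one indentation level shallower. A minimal sketch of the before/after shape, using illustrative names rather than the real EMR hook API:

    # Before: negated comparison, success path nested under else
    def create_before(response: dict) -> str:
        if not response["HTTPStatusCode"] == 200:
            raise RuntimeError(f"creation failed: {response}")
        else:
            return response["JobFlowId"]

    # After: positive guard clause, success path at the top level
    def create_after(response: dict) -> str:
        if response["HTTPStatusCode"] != 200:
            raise RuntimeError(f"creation failed: {response}")
        return response["JobFlowId"]
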
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/s3.py
@@ -470,7 +470,7 @@ def execute(self, context: Context):
         if not exactly_one(self.keys is None, self.prefix is None):
             raise AirflowException("Either keys or prefix should be set.")
 
-        if isinstance(self.keys, (list, str)) and not bool(self.keys):
+        if isinstance(self.keys, (list, str)) and not self.keys:
             return
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
 
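
The s3.py change also removes a redundant bool() call: not already truth-tests its operand, so "not self.keys" and "not bool(self.keys)" agree for every value. A quick check with sample values (illustrative only, not taken from the operator):

    # an empty list or empty string is falsy either way; bool() adds nothing under not
    for keys in (None, [], "", ["a"], "a"):
        assert (not bool(keys)) == (not keys)
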
10 changes: 5 additions & 5 deletions airflow/providers/google/cloud/operators/compute.py
@@ -215,7 +215,7 @@ def execute(self, context: Context) -> dict:
             )
         except exceptions.NotFound as e:
             # We actually expect to get 404 / Not Found here as the instance should not yet exist
-            if not e.code == 404:
+            if e.code != 404:
                 raise e
         else:
             self.log.info("The %s Instance already exists", self.resource_id)
@@ -386,7 +386,7 @@ def execute(self, context: Context) -> dict:
         except exceptions.NotFound as e:
             # We actually expect to get 404 / Not Found here as the template should
             # not yet exist
-            if not e.code == 404:
+            if e.code != 404:
                 raise e
         else:
             self.log.info("The %s Instance already exists", self.resource_id)
@@ -960,7 +960,7 @@ def execute(self, context: Context) -> dict:
         except exceptions.NotFound as e:
             # We actually expect to get 404 / Not Found here as the template should
             # not yet exist
-            if not e.code == 404:
+            if e.code != 404:
                 raise e
         else:
             self.log.info("The %s Template already exists.", existing_template)
@@ -1222,7 +1222,7 @@ def execute(self, context: Context) -> dict:
         except exceptions.NotFound as e:
             # We actually expect to get 404 / Not Found here as the template should
             # not yet exist
-            if not e.code == 404:
+            if e.code != 404:
                 raise e
         else:
             self.log.info(
@@ -1541,7 +1541,7 @@ def execute(self, context: Context) -> dict:
         except exceptions.NotFound as e:
             # We actually expect to get 404 / Not Found here as the Instance Group Manager should
             # not yet exist
-            if not e.code == 404:
+            if e.code != 404:
                 raise e
         else:
             self.log.info("The %s Instance Group Manager already exists", existing_instance_group_manager)
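
All five compute.py hunks share one get-or-create idiom: fetch the resource, treat 404 / Not Found as the expected "does not exist yet" signal, and re-raise anything else, now spelled positively as e.code != 404. A self-contained sketch of the idiom, with a stand-in exception class in place of google.api_core.exceptions.NotFound:

    class NotFound(Exception):
        # stand-in for google.api_core.exceptions.NotFound, which carries an HTTP code
        code = 404

    def get_or_create(get_resource, create_resource):
        try:
            return get_resource()  # happy path: the resource already exists
        except NotFound as e:
            if e.code != 404:  # positive check: anything other than 404 is a real error
                raise
        return create_resource()  # expected miss: create the resource now
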
2 changes: 1 addition & 1 deletion airflow/providers/sftp/operators/sftp.py
@@ -123,7 +123,7 @@ def execute(self, context: Any) -> str | list[str] | None:
                 f"!= {len(remote_filepath_array)} paths in remote_filepath"
             )
 
-        if not (self.operation.lower() == SFTPOperation.GET or self.operation.lower() == SFTPOperation.PUT):
+        if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT):
             raise TypeError(
                 f"Unsupported operation value {self.operation}, "
                 f"expected {SFTPOperation.GET} or {SFTPOperation.PUT}."
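
In sftp.py a negated chain of equality checks joined with "or" collapses into a single membership test, which reads directly and extends naturally if more operations are ever added. A sketch with stand-in constants in place of SFTPOperation.GET and SFTPOperation.PUT:

    GET, PUT = "get", "put"  # stand-ins for SFTPOperation.GET / SFTPOperation.PUT

    operation = "delete"
    # old shape: if not (operation.lower() == GET or operation.lower() == PUT):
    if operation.lower() not in (GET, PUT):
        raise TypeError(f"Unsupported operation value {operation}, expected {GET} or {PUT}.")
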
@@ -34,7 +34,7 @@
 class TestAwsBaseAsyncHook:
     @staticmethod
     def compare_aio_cred(first, second):
-        if not type(first) == type(second):
+        if type(first) != type(second):
             return False
         if first.access_key != second.access_key:
             return False
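
The test helper gets the same treatment: type(first) != type(second) states the mismatch directly instead of negating an equality. Both spellings compare exact runtime types, so even a subclass instance differs from its base, as this small check shows:

    assert type(1) != type("1")   # different types compare unequal
    assert type(True) != type(1)  # bool subclasses int, yet the exact types differ
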
