Refactor usage of str() in providers (#34320)
eumiro committed Sep 28, 2023
1 parent ca3ce78 commit 7ebf422
Showing 30 changed files with 70 additions and 70 deletions.
6 changes: 3 additions & 3 deletions airflow/providers/alibaba/cloud/log/oss_task_handler.py
@@ -189,16 +189,16 @@ def oss_write(self, log, remote_log_location, append=True) -> bool:
if append and self.oss_log_exists(oss_remote_log_location):
head = self.hook.head_key(self.bucket_name, oss_remote_log_location)
pos = head.content_length
- self.log.info("log write pos is: %s", str(pos))
+ self.log.info("log write pos is: %s", pos)
try:
self.log.info("writing remote log: %s", oss_remote_log_location)
self.hook.append_string(self.bucket_name, log, oss_remote_log_location, pos)
except Exception:
self.log.exception(
"Could not write logs to %s, log write pos is: %s, Append is %s",
oss_remote_log_location,
- str(pos),
- str(append),
+ pos,
+ append,
)
return False
return True
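Worth noting for this hunk (and for most of the logging changes below): `%s`-style logging formatting stringifies its arguments itself, and only when the record is actually emitted, so the explicit `str()` calls were redundant. A minimal standalone sketch of the equivalence, using plain `logging` rather than the Airflow handler:

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger("example")

pos = 1024  # stands in for head.content_length

# Both calls emit identical output: %s applies str() to the argument.
log.info("log write pos is: %s", str(pos))
log.info("log write pos is: %s", pos)

# Lazy %-formatting also means the conversion is skipped entirely
# when the record is filtered out by the log level.
log.debug("not formatted at INFO level: %s", pos)
```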
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
@@ -839,7 +839,7 @@ def test_connection(self):
return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items())

except Exception as e:
- return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}")
+ return False, f"{type(e).__name__!r} error occurred while testing connection: {e}"

@cached_property
def waiter_path(self) -> os.PathLike[str] | None:
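The removed `str()` here wrapped an f-string, which already evaluates to `str`, so the call was a no-op. A small sketch with a stand-in exception (not the hook's real error path):

```python
e = ValueError("boom")  # stand-in for whatever test_connection() raised

wrapped = str(f"{type(e).__name__!r} error occurred while testing connection: {e}")
plain = f"{type(e).__name__!r} error occurred while testing connection: {e}"

assert wrapped == plain
print(plain)  # 'ValueError' error occurred while testing connection: boom
```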
9 changes: 5 additions & 4 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -350,7 +350,8 @@ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: str | None
:param delimiter: the delimiter marks key hierarchy.
:return: False if the prefix does not exist in the bucket and True if it does.
"""
- prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
+ if not prefix.endswith(delimiter):
+     prefix += delimiter
prefix_split = re.split(rf"(\w+[{delimiter}])$", prefix, 1)
previous_level = prefix_split[0]
plist = self.list_prefixes(bucket_name, previous_level, delimiter)
@@ -544,7 +545,8 @@ async def check_for_prefix_async(
:param delimiter: the delimiter marks key hierarchy.
:return: False if the prefix does not exist in the bucket and True if it does.
"""
- prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
+ if not prefix.endswith(delimiter):
+     prefix += delimiter
prefix_split = re.split(rf"(\w+[{delimiter}])$", prefix, 1)
previous_level = prefix_split[0]
plist = await self.list_prefixes_async(client, bucket_name, previous_level, delimiter)
@@ -576,8 +578,7 @@ async def get_files_async(
response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
async for page in response:
if "Contents" in page:
- _temp = [k for k in page["Contents"] if isinstance(k.get("Size", None), (int, float))]
- keys = keys + _temp
+ keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float)))
return keys

@staticmethod
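Two side notes on the hunks above. `prefix[-1]` raises `IndexError` on an empty prefix, whereas `str.endswith` simply returns `False`, so the rewritten guard is marginally more robust as well as clearer; and `list.extend` accepts a generator, which removes the throwaway `_temp` list. A rough illustration with made-up data (not the S3 hook itself):

```python
delimiter = "/"

def normalize(prefix: str) -> str:
    # endswith() copes with "" gracefully; prefix[-1] would raise IndexError.
    if not prefix.endswith(delimiter):
        prefix += delimiter
    return prefix

print(normalize("logs"))   # logs/
print(normalize("logs/"))  # logs/
print(normalize(""))       # /

# extend() consumes the generator directly -- no intermediate list needed.
keys: list[dict] = []
page = {"Contents": [{"Key": "a", "Size": 10}, {"Key": "b"}]}
keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float)))
print(keys)  # [{'Key': 'a', 'Size': 10}]
```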
2 changes: 1 addition & 1 deletion airflow/providers/apache/spark/hooks/spark_submit.py
@@ -268,7 +268,7 @@ def _build_spark_submit_command(self, application: str) -> list[str]:
connection_cmd += ["--master", self._connection["master"]]

for key in self._conf:
- connection_cmd += ["--conf", f"{key}={str(self._conf[key])}"]
+ connection_cmd += ["--conf", f"{key}={self._conf[key]}"]
if self._env_vars and (self._is_kubernetes or self._is_yarn):
if self._is_yarn:
tmpl = "spark.yarn.appMasterEnv.{}={}"
4 changes: 2 additions & 2 deletions airflow/providers/celery/executors/celery_executor_utils.py
@@ -154,8 +154,8 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None =
setproctitle(f"airflow task supervisor: {command_to_exec}")
args.func(args)
ret = 0
- except Exception as e:
-     log.exception("[%s] Failed to execute task %s.", celery_task_id, str(e))
+ except Exception:
+     log.exception("[%s] Failed to execute task.", celery_task_id)
ret = 1
finally:
Sentry.flush()
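`log.exception()` only makes sense inside an `except` block, and it appends the active exception's type, message, and traceback to the record on its own, so interpolating `str(e)` into the message duplicated information the log line already carries. A standalone sketch:

```python
import logging

logging.basicConfig(format="%(message)s")
log = logging.getLogger("example")

celery_task_id = "abc123"  # made-up id for the demo

try:
    raise RuntimeError("task blew up")
except Exception:
    # The traceback, ending in "RuntimeError: task blew up", is appended
    # automatically; the exception object does not need to be passed in.
    log.exception("[%s] Failed to execute task.", celery_task_id)
```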
@@ -438,10 +438,10 @@ def _change_state(
if self.kube_config.delete_worker_pods:
if state != TaskInstanceState.FAILED or self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_name=pod_name, namespace=namespace)
- self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
+ self.log.info("Deleted pod: %s in namespace %s", key, namespace)
else:
self.kube_scheduler.patch_pod_executor_done(pod_name=pod_name, namespace=namespace)
- self.log.info("Patched pod %s in namespace %s to mark it as done", str(key), str(namespace))
+ self.log.info("Patched pod %s in namespace %s to mark it as done", key, namespace)

try:
self.running.remove(key)
4 changes: 2 additions & 2 deletions airflow/providers/cncf/kubernetes/python_kubernetes_script.py
@@ -31,9 +31,9 @@ def _balance_parens(after_decorator):
while num_paren:
current = after_decorator.popleft()
if current == "(":
- num_paren = num_paren + 1
+ num_paren += 1
elif current == ")":
- num_paren = num_paren - 1
+ num_paren -= 1
return "".join(after_decorator)


2 changes: 1 addition & 1 deletion airflow/providers/cncf/kubernetes/utils/delete_from.py
@@ -138,7 +138,7 @@ def _delete_from_yaml_single_item(
else:
resp = getattr(k8s_api, f"delete_{kind}")(name=name, body=body, **kwargs)
if verbose:
- print(f"{kind} deleted. status='{str(resp.status)}'")
+ print(f"{kind} deleted. status='{resp.status}'")
return resp


35 changes: 18 additions & 17 deletions airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -729,23 +729,24 @@ def extract_xcom_kill(self, pod: V1Pod):
self._exec_pod_command(resp, "kill -s SIGINT 1")

def _exec_pod_command(self, resp, command: str) -> str | None:
- res = None
- if resp.is_open():
-     self.log.info("Running command... %s\n", command)
-     resp.write_stdin(command + "\n")
-     while resp.is_open():
-         resp.update(timeout=1)
-         while resp.peek_stdout():
-             res = res + resp.read_stdout() if res else resp.read_stdout()
-         error_res = None
-         while resp.peek_stderr():
-             error_res = error_res + resp.read_stderr() if error_res else resp.read_stderr()
-         if error_res:
-             self.log.info("stderr from command: %s", error_res)
-             break
-         if res:
-             return res
- return res
+ res = ""
+ if not resp.is_open():
+     return None
+ self.log.info("Running command... %s", command)
+ resp.write_stdin(f"{command}\n")
+ while resp.is_open():
+     resp.update(timeout=1)
+     while resp.peek_stdout():
+         res += resp.read_stdout()
+     error_res = ""
+     while resp.peek_stderr():
+         error_res += resp.read_stderr()
+     if error_res:
+         self.log.info("stderr from command: %s", error_res)
+         break
+     if res:
+         return res
+ return None


class OnFinishAction(str, enum.Enum):
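The `_exec_pod_command` rewrite above replaces a nested conditional with a guard clause and accumulates output into empty strings rather than `None`, which eliminates the `res + ... if res else ...` pattern; the only visible difference is that the function now returns `None` explicitly when no stdout was collected. A rough sketch of the same control flow against a fake stream object (all names here are made up for illustration; only the stdout path is modelled):

```python
from __future__ import annotations

class FakeResp:
    """Minimal stand-in for the kubernetes exec stream; stdout only."""

    def __init__(self, chunks: list[str]) -> None:
        self._chunks = list(chunks)
        self._open = True

    def is_open(self) -> bool:
        return self._open

    def update(self, timeout: int = 1) -> None:
        if not self._chunks:
            self._open = False  # close once everything has been read

    def peek_stdout(self) -> bool:
        return bool(self._chunks)

    def read_stdout(self) -> str:
        return self._chunks.pop(0)

def collect(resp) -> str | None:
    res = ""
    if not resp.is_open():  # guard clause instead of wrapping everything in `if`
        return None
    while resp.is_open():
        resp.update(timeout=1)
        while resp.peek_stdout():
            res += resp.read_stdout()  # "" + chunk needs no None check
        if res:
            return res
    return None

print(collect(FakeResp(["foo", "bar"])))  # foobar
print(collect(FakeResp([])))              # None
```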
4 changes: 2 additions & 2 deletions airflow/providers/common/sql/operators/sql.py
@@ -945,8 +945,8 @@ def __init__(
sqlexp = ", ".join(self.metrics_sorted)
sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="

- self.sql1 = sqlt + "'{{ ds }}'"
- self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
+ self.sql1 = f"{sqlt}'{{{{ ds }}}}'"
+ self.sql2 = f"{sqlt}'{{{{ macros.ds_add(ds, {self.days_back}) }}}}'"

def execute(self, context: Context):
hook = self.get_db_hook()
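The only subtle part of the hunk above is brace escaping: inside an f-string, `{{` and `}}` emit literal braces, so `'{{{{ ds }}}}'` renders the Jinja placeholder `'{{ ds }}'` while `{self.days_back}` is still interpolated. A quick check of the equivalence with stand-in values:

```python
sqlt = "SELECT metric FROM tbl WHERE ds="  # stand-in query prefix
days_back = -7

old_sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(days_back) + ") }}'"
new_sql2 = f"{sqlt}'{{{{ macros.ds_add(ds, {days_back}) }}}}'"

assert old_sql2 == new_sql2
print(new_sql2)
# SELECT metric FROM tbl WHERE ds='{{ macros.ds_add(ds, -7) }}'
```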
6 changes: 2 additions & 4 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -339,13 +339,11 @@ def _create_sql_query(self) -> str:
elif isinstance(self._validate, int):
if self._validate < 0:
raise AirflowException(
- "Number of rows for validation should be positive, got: " + str(self._validate)
+ f"Number of rows for validation should be positive, got: {self._validate}"
)
validation = f"VALIDATE {self._validate} ROWS\n"
else:
- raise AirflowException(
-     "Incorrect data type for validate parameter: " + str(type(self._validate))
- )
+ raise AirflowException(f"Incorrect data type for validate parameter: {type(self._validate)}")
# TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection
sql = f"""COPY INTO {self._table_name}{storage_cred}
FROM {location}
@@ -140,7 +140,7 @@ def _check_table_partitions(self) -> list:
if self.table_name.split(".")[0] == "delta":
_fully_qualified_table_name = self.table_name
else:
- _fully_qualified_table_name = str(self.catalog + "." + self.schema + "." + self.table_name)
+ _fully_qualified_table_name = f"{self.catalog}.{self.schema}.{self.table_name}"
self.log.debug("Table name generated from arguments: %s", _fully_qualified_table_name)
_joiner_val = " AND "
_prefix = f"SELECT 1 FROM {_fully_qualified_table_name} WHERE"
8 changes: 4 additions & 4 deletions airflow/providers/dbt/cloud/operators/dbt.py
@@ -142,7 +142,7 @@ def execute(self, context: Context):

if self.wait_for_termination and isinstance(self.run_id, int):
if self.deferrable is False:
- self.log.info("Waiting for job run %s to terminate.", str(self.run_id))
+ self.log.info("Waiting for job run %s to terminate.", self.run_id)

if self.hook.wait_for_job_run_status(
run_id=self.run_id,
@@ -151,7 +151,7 @@ def execute(self, context: Context):
check_interval=self.check_interval,
timeout=self.timeout,
):
- self.log.info("Job run %s has completed successfully.", str(self.run_id))
+ self.log.info("Job run %s has completed successfully.", self.run_id)
else:
raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")

@@ -173,7 +173,7 @@ def execute(self, context: Context):
method_name="execute_complete",
)
elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
- self.log.info("Job run %s has completed successfully.", str(self.run_id))
+ self.log.info("Job run %s has completed successfully.", self.run_id)
return self.run_id
elif job_run_status in (
DbtCloudJobRunStatus.CANCELLED.value,
@@ -211,7 +211,7 @@ def on_kill(self) -> None:
check_interval=self.check_interval,
timeout=self.timeout,
):
- self.log.info("Job run %s has been cancelled successfully.", str(self.run_id))
+ self.log.info("Job run %s has been cancelled successfully.", self.run_id)

@cached_property
def hook(self):
2 changes: 1 addition & 1 deletion airflow/providers/docker/operators/docker_swarm.py
@@ -143,7 +143,7 @@ def _run_service(self) -> None:
)
if self.service is None:
raise Exception("Service should be set here")
- self.log.info("Service started: %s", str(self.service))
+ self.log.info("Service started: %s", self.service)

# wait for the service to start the task
while not self.cli.tasks(filters={"service": self.service["ID"]}):
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -954,7 +954,7 @@ def _get_sqlproxy_instance_specification(self) -> str:
def _get_sqlproxy_instance_specification(self) -> str:
instance_specification = self._get_instance_socket_name()
if self.sql_proxy_use_tcp:
- instance_specification += "=tcp:" + str(self.sql_proxy_tcp_port)
+ instance_specification += f"=tcp:{self.sql_proxy_tcp_port}"
return instance_specification

def create_connection(self) -> Connection:
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -426,7 +426,7 @@ def _check_dataflow_job_state(self, job) -> bool:
if current_state in DataflowJobStatus.AWAITING_STATES:
return self._wait_until_finished is False

- self.log.debug("Current job: %s", str(job))
+ self.log.debug("Current job: %s", job)
raise Exception(
f"Google Cloud Dataflow job {job['name']} is in an unexpected terminal state: {current_state}, "
f"expected terminal state: {self._expected_terminal_state}"
@@ -896,7 +896,7 @@ def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
)

if append_job_name:
- safe_job_name = base_job_name + "-" + str(uuid.uuid4())[:8]
+ safe_job_name = f"{base_job_name}-{uuid.uuid4()!s:.8}"
else:
safe_job_name = base_job_name

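The replacement for `str(uuid.uuid4())[:8]` leans on two format-spec features: `!s` converts the UUID to `str`, and a precision of `.8` truncates a string to its first eight characters, so the two spellings are equivalent. A quick demonstration:

```python
import uuid

u = uuid.uuid4()

sliced = str(u)[:8]
formatted = f"{u!s:.8}"

assert sliced == formatted
print(formatted)  # e.g. 3f2b9c1a -- the first 8 hex digits of the UUID
```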
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/dataflow.py
@@ -899,7 +899,7 @@ def _append_uuid_to_job_name(self):
job_body = self.body.get("launch_parameter") or self.body.get("launchParameter")
job_name = job_body.get("jobName")
if job_name:
- job_name += f"-{str(uuid.uuid4())[:8]}"
+ job_name += f"-{uuid.uuid4()!s:.8}"
job_body["jobName"] = job_name
self.log.info("Job name was changed to %s", job_name)

6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -615,7 +615,7 @@ def _wait_for_cluster_in_deleting_state(self, hook: DataprocHook) -> None:
if time_left < 0:
raise AirflowException(f"Cluster {self.cluster_name} is still DELETING state, aborting")
time.sleep(time_to_sleep)
- time_left = time_left - time_to_sleep
+ time_left -= time_to_sleep
try:
self._get_cluster(hook)
except NotFound:
@@ -630,7 +630,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
if time_left < 0:
raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting")
time.sleep(time_to_sleep)
- time_left = time_left - time_to_sleep
+ time_left -= time_to_sleep
cluster = self._get_cluster(hook)
return cluster

@@ -1599,7 +1599,7 @@ class DataprocSubmitPySparkJobOperator(DataprocJobBaseOperator):

@staticmethod
def _generate_temp_filename(filename):
- return f"{time:%Y%m%d%H%M%S}_{str(uuid.uuid4())[:8]}_{ntpath.basename(filename)}"
+ return f"{time:%Y%m%d%H%M%S}_{uuid.uuid4()!s:.8}_{ntpath.basename(filename)}"

def _upload_file_temp(self, bucket, local_file):
"""Upload a local file to a Google Cloud Storage bucket."""
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/sensors/bigquery_dts.py
@@ -137,7 +137,7 @@ def poke(self, context: Context) -> bool:
timeout=self.request_timeout,
metadata=self.metadata,
)
- self.log.info("Status of %s run: %s", self.run_id, str(run.state))
+ self.log.info("Status of %s run: %s", self.run_id, run.state)

if run.state in (TransferState.FAILED, TransferState.CANCELLED):
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
@@ -199,7 +199,7 @@ def _decide_and_flush(self, converted_rows_with_action: dict[FlushAction, list])
else:
message = (
"FlushAction not found in the data. Please check the FlushAction in "
- "the operator. Converted Rows with Action: " + str(converted_rows_with_action)
+ f"the operator. Converted Rows with Action: {converted_rows_with_action}"
)
raise AirflowException(message)
return total_data_count
2 changes: 1 addition & 1 deletion airflow/providers/google/leveldb/operators/leveldb.py
@@ -88,7 +88,7 @@ def execute(self, context: Context) -> str | None:
keys=self.keys,
values=self.values,
)
- self.log.info("Done. Returned value was: %s", str(value))
+ self.log.info("Done. Returned value was: %s", value)
leveldb_hook.close_conn()
str_value = value if value is None else value.decode()
return str_value
2 changes: 1 addition & 1 deletion airflow/providers/grpc/hooks/grpc.py
@@ -83,7 +83,7 @@ def get_conn(self) -> grpc.Channel:
base_url = self.conn.host

if self.conn.port:
- base_url = base_url + ":" + str(self.conn.port)
+ base_url += f":{self.conn.port}"

auth_type = self._get_field("auth_type")

6 changes: 3 additions & 3 deletions airflow/providers/http/hooks/http.py
@@ -103,10 +103,10 @@ def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session:
# schema defaults to HTTP
schema = conn.schema if conn.schema else "http"
host = conn.host if conn.host else ""
- self.base_url = schema + "://" + host
+ self.base_url = f"{schema}://{host}"

if conn.port:
- self.base_url = self.base_url + ":" + str(conn.port)
+ self.base_url += f":{conn.port}"
if conn.login:
session.auth = self.auth_type(conn.login, conn.password)
elif self._auth_type:
@@ -329,7 +329,7 @@ async def run(
self.base_url = schema + "://" + host

if conn.port:
- self.base_url = self.base_url + ":" + str(conn.port)
+ self.base_url += f":{conn.port}"
if conn.login:
auth = self.auth_type(conn.login, conn.password)
if conn.extra:
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/log/wasb_task_handler.py
@@ -147,7 +147,7 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], l
blob_names = self.hook.get_blobs_list(container_name=self.wasb_container, prefix=prefix)
except HttpResponseError as e:
messages.append(f"tried listing blobs with prefix={prefix} and container={self.wasb_container}")
- messages.append("could not list blobs " + str(e))
+ messages.append(f"could not list blobs {e}")
self.log.exception("can't list blobs")

if blob_names:
@@ -125,7 +125,7 @@ def get_sftp_files_map(self) -> list[SftpFile]:
sftp_complete_path, prefix=prefix, delimiter=delimiter
)

- self.log.info("Found %s files at sftp source path: %s", str(len(found_files)), self.sftp_source_path)
+ self.log.info("Found %d files at sftp source path: %s", len(found_files), self.sftp_source_path)

for file in found_files:
future_blob_name = self.get_full_path_blob(file)