chore: Refactoring and Cleaning Apache Providers (#24219)
koconder committed Jun 6, 2022
1 parent 0685633 commit b4a5783
Showing 16 changed files with 47 additions and 71 deletions.
6 changes: 1 addition & 5 deletions airflow/providers/apache/beam/operators/beam.py
@@ -469,11 +469,7 @@ def execute(self, context: 'Context'):
process_line_callback=process_line_callback,
)
if dataflow_job_name and self.dataflow_config.location:
multiple_jobs = (
self.dataflow_config.multiple_jobs
if self.dataflow_config.multiple_jobs
else False
)
multiple_jobs = self.dataflow_config.multiple_jobs or False
DataflowJobLink.persist(
self,
context,
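
In beam.py the multi-line conditional becomes an "or" fallback. A minimal sketch of why the two spellings agree; FakeConfig below is a made-up stand-in for the operator's dataflow_config, not the provider class:

    class FakeConfig:
        """Hypothetical stand-in for the Dataflow configuration object."""
        def __init__(self, multiple_jobs):
            self.multiple_jobs = multiple_jobs

    for value in (True, False, None):
        cfg = FakeConfig(value)
        old = cfg.multiple_jobs if cfg.multiple_jobs else False  # removed form
        new = cfg.multiple_jobs or False                         # added form
        assert old == new  # both keep True and map False/None to False

For every possible value, "x if x else False" and "x or False" evaluate identically, so the rewrite is purely cosmetic.
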
7 changes: 3 additions & 4 deletions airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -169,9 +169,8 @@ def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy:
child_policy_args = policy_args.get('child_load_balancing_policy_args', {})
if child_policy_name not in allowed_child_policies:
return TokenAwarePolicy(RoundRobinPolicy())
else:
child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
return TokenAwarePolicy(child_policy)
child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
return TokenAwarePolicy(child_policy)

# Fallback to default RoundRobinPolicy
return RoundRobinPolicy()
@@ -200,7 +199,7 @@ def record_exists(self, table: str, keys: Dict[str, str]) -> bool:
keyspace = self.keyspace
if '.' in table:
keyspace, table = table.split('.', 1)
ks_str = " AND ".join(f"{key}=%({key})s" for key in keys.keys())
ks_str = " AND ".join(f"{key}=%({key})s" for key in keys)
query = f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}"
try:
result = self.get_conn().execute(query, keys)
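
Two idioms drive the cassandra.py hunks: the else after a returning branch is dropped, and the generator iterates the dict directly instead of calling .keys(). A small sketch under assumed names (pick_policy and the toy keys dict are illustrative only):

    def pick_policy(name: str, allowed: set) -> str:
        # Once the early branch has returned, the remaining lines need no else.
        if name not in allowed:
            return "TokenAware(RoundRobin)"
        return f"TokenAware({name})"

    policy = pick_policy("DCAwareRoundRobinPolicy", {"DCAwareRoundRobinPolicy"})
    assert policy == "TokenAware(DCAwareRoundRobinPolicy)"

    # Iterating a dict already yields its keys, so keys.keys() is redundant.
    keys = {"id": "1", "name": "a"}
    ks_str = " AND ".join(f"{key}=%({key})s" for key in keys)
    assert ks_str == "id=%(id)s AND name=%(name)s"
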
2 changes: 1 addition & 1 deletion airflow/providers/apache/drill/hooks/drill.py
@@ -69,7 +69,7 @@ def get_uri(self) -> str:
host = conn_md.host
if conn_md.port is not None:
host += f':{conn_md.port}'
conn_type = 'drill' if not conn_md.conn_type else conn_md.conn_type
conn_type = conn_md.conn_type or 'drill'
dialect_driver = conn_md.extra_dejson.get('dialect_driver', 'drill+sadrill')
storage_plugin = conn_md.extra_dejson.get('storage_plugin', 'dfs')
return f'{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}'
4 changes: 2 additions & 2 deletions airflow/providers/apache/druid/hooks/druid.py
@@ -63,7 +63,7 @@ def get_conn_url(self) -> str:
conn = self.get_connection(self.druid_ingest_conn_id)
host = conn.host
port = conn.port
conn_type = 'http' if not conn.conn_type else conn.conn_type
conn_type = conn.conn_type or 'http'
endpoint = conn.extra_dejson.get('endpoint', '')
return f"{conn_type}://{host}:{port}/{endpoint}"

@@ -163,7 +163,7 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
conn_type = 'druid' if not conn.conn_type else conn.conn_type
conn_type = conn.conn_type or 'druid'
endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql')
return f'{conn_type}://{host}/{endpoint}'

9 changes: 3 additions & 6 deletions airflow/providers/apache/hdfs/hooks/webhdfs.py
@@ -98,12 +98,9 @@ def _get_client(self, namenode: str, port: int, login: str, extra_dejson: dict)
session.verify = extra_dejson.get('verify', True)

if _kerberos_security_mode:
client = KerberosClient(connection_str, session=session)
else:
proxy_user = self.proxy_user or login
client = InsecureClient(connection_str, user=proxy_user, session=session)

return client
return KerberosClient(connection_str, session=session)
proxy_user = self.proxy_user or login
return InsecureClient(connection_str, user=proxy_user, session=session)

def check_for_path(self, hdfs_path: str) -> bool:
"""
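
webhdfs.py now returns the client directly from each branch rather than assigning it to a temporary and returning at the end. A minimal sketch; make_client and its string return values are illustrative stand-ins for KerberosClient/InsecureClient, not the hook's API:

    def make_client(kerberos: bool, connection_str: str, user: str) -> str:
        # Return from each branch directly; no "client = ..." plus a trailing
        # "return client" is needed.
        if kerberos:
            return f"KerberosClient({connection_str})"
        return f"InsecureClient({connection_str}, user={user})"

    assert make_client(False, "http://namenode:9870", "airflow").startswith("InsecureClient")
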
6 changes: 3 additions & 3 deletions airflow/providers/apache/hive/operators/hive_stats.py
@@ -100,15 +100,15 @@ def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
if col in self.excluded_columns:
return {}
exp = {(col, 'non_null'): f"COUNT({col})"}
if col_type in ['double', 'int', 'bigint', 'float']:
if col_type in {'double', 'int', 'bigint', 'float'}:
exp[(col, 'sum')] = f'SUM({col})'
exp[(col, 'min')] = f'MIN({col})'
exp[(col, 'max')] = f'MAX({col})'
exp[(col, 'avg')] = f'AVG({col})'
elif col_type == 'boolean':
exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
elif col_type in ['string']:
elif col_type == 'string':
exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))'
exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})'

@@ -130,7 +130,7 @@ def execute(self, context: "Context") -> None:
exprs.update(assign_exprs)
exprs.update(self.extra_exprs)
exprs = OrderedDict(exprs)
exprs_str = ",\n ".join(v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items())
exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())

where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
where_clause = " AND\n ".join(where_clause_)
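
hive_stats.py switches the numeric type check to a set literal, the single-item list to a plain ==, and the alias string to f-strings. A sketch with a toy exprs mapping shaped like the operator's (column, metric) keys; the join separator is simplified here:

    # Set literals are the idiomatic container for "is this value one of ...".
    numeric_types = {'double', 'int', 'bigint', 'float'}
    assert 'bigint' in numeric_types and 'string' not in numeric_types

    # f-strings replace "+"-based concatenation for "<expr> AS <col>__<metric>".
    exprs = {('id', 'non_null'): 'COUNT(id)', ('id', 'max'): 'MAX(id)'}
    exprs_str = ",\n".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())
    assert exprs_str == "COUNT(id) AS id__non_null,\nMAX(id) AS id__max"
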
@@ -76,7 +76,7 @@ def parse_partition_name(partition: str) -> Tuple[Any, ...]:
schema, table_partition = first_split
second_split = table_partition.split('/', 1)
if len(second_split) == 1:
raise ValueError('Could not parse ' + partition + 'into table, partition')
raise ValueError(f'Could not parse {partition}into table, partition')
else:
table, partition = second_split
return schema, table, partition
4 changes: 1 addition & 3 deletions airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -115,9 +115,7 @@ def execute(self, context: "Context"):
with NamedTemporaryFile("w") as tmp_file:
csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
field_dict = OrderedDict()
col_count = 0
for field in cursor.description:
col_count += 1
for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor)
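
mssql_to_hive.py replaces the hand-maintained col_count counter with enumerate(..., start=1). A standalone sketch; the tuples below are simplified stand-ins for DB-API cursor description rows and the type codes are arbitrary:

    from collections import OrderedDict

    description = [("id", 3), ("", 1), ("name", 1)]  # (column name, type code)
    field_dict = OrderedDict()
    for col_count, field in enumerate(description, start=1):
        # Unnamed columns get a positional "ColumnN" label, as in the operator.
        col_position = f"Column{col_count}"
        field_dict[col_position if field[0] == '' else field[0]] = field[1]

    assert list(field_dict) == ["id", "Column2", "name"]
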
9 changes: 4 additions & 5 deletions airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -142,11 +142,11 @@ def execute(self, context: 'Context'):
if not s3_hook.check_for_wildcard_key(self.s3_key):
raise AirflowException(f"No key matches {self.s3_key}")
s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
else:
if not s3_hook.check_for_key(self.s3_key):
raise AirflowException(f"The key {self.s3_key} does not exists")
elif s3_hook.check_for_key(self.s3_key):
s3_key_object = s3_hook.get_key(self.s3_key)

else:
raise AirflowException(f"The key {self.s3_key} does not exists")
_, file_ext = os.path.splitext(s3_key_object.key)
if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")
@@ -227,8 +227,7 @@ def execute(self, context: 'Context'):
def _get_top_row_as_list(self, file_name):
with open(file_name) as file:
header_line = file.readline().strip()
header_list = header_line.split(self.delimiter)
return header_list
return header_line.split(self.delimiter)

def _match_headers(self, header_list):
if not header_list:
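
s3_to_hive.py reorders the key lookup into if/elif/else so every branch either yields a key object or raises, and _get_top_row_as_list returns the split expression directly. A hedged sketch with a dict-backed stub (StubS3Hook is illustrative, not the S3Hook API, and wildcard matching is omitted):

    class StubS3Hook:
        """Illustrative stand-in for S3Hook backed by a plain dict."""
        def __init__(self, objects):
            self._objects = objects

        def check_for_key(self, key):
            return key in self._objects

        def get_key(self, key):
            return self._objects[key]

    def resolve_key(hook, key):
        # Flattened branch order: fetch the key when present, otherwise raise.
        if hook.check_for_key(key):
            return hook.get_key(key)
        raise FileNotFoundError(f"The key {key} does not exist")

    def top_row_as_list(line: str, delimiter: str = ","):
        # Returning the expression avoids a single-use local variable.
        return line.strip().split(delimiter)

    hook = StubS3Hook({"data/input.csv": "s3-object"})
    assert resolve_key(hook, "data/input.csv") == "s3-object"
    assert top_row_as_list("a,b,c\n") == ["a", "b", "c"]
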
4 changes: 1 addition & 3 deletions airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -117,9 +117,7 @@ def execute(self, context: 'Context'):
with NamedTemporaryFile("w") as f:
csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
field_dict = OrderedDict()
col_count = 0
for field in cursor.description:
col_count += 1
for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor.iterate())
22 changes: 10 additions & 12 deletions airflow/providers/apache/kylin/hooks/kylin.py
@@ -48,16 +48,15 @@ def get_conn(self):
conn = self.get_connection(self.kylin_conn_id)
if self.dsn:
return kylinpy.create_kylin(self.dsn)
else:
self.project = self.project if self.project else conn.schema
return kylinpy.Kylin(
conn.host,
username=conn.login,
password=conn.password,
port=conn.port,
project=self.project,
**conn.extra_dejson,
)
self.project = self.project or conn.schema
return kylinpy.Kylin(
conn.host,
username=conn.login,
password=conn.password,
port=conn.port,
project=self.project,
**conn.extra_dejson,
)

def cube_run(self, datasource_name, op, **op_args):
"""
@@ -70,8 +69,7 @@ def cube_run(self, datasource_name, op, **op_args):
"""
cube_source = self.get_conn().get_datasource(datasource_name)
try:
response = cube_source.invoke_command(op, **op_args)
return response
return cube_source.invoke_command(op, **op_args)
except exceptions.KylinError as err:
raise AirflowException(f"Cube operation {op} error , Message: {err}")

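
kylin.py drops the else after the "if self.dsn: return ..." guard and returns the cube command result straight from the try block. A sketch with hypothetical names (FakeSource.invoke and CubeError are made up for illustration):

    class CubeError(Exception):
        pass

    class FakeSource:
        def invoke(self, op):
            return f"ok:{op}"

    def cube_run(source, op):
        # Returning from inside try still lets the except clause wrap any
        # failure raised by invoke(), so no temporary variable is needed.
        try:
            return source.invoke(op)
        except CubeError as err:
            raise RuntimeError(f"Cube operation {op} error, message: {err}")

    assert cube_run(FakeSource(), "build") == "ok:build"
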
6 changes: 2 additions & 4 deletions airflow/providers/apache/pinot/hooks/pinot.py
@@ -203,9 +203,7 @@ def run_cli(self, cmd: List[str], verbose: bool = True) -> str:
:param cmd: List of command going to be run by pinot-admin.sh script
:param verbose:
"""
command = [self.cmd_path]
command.extend(cmd)

command = [self.cmd_path, *cmd]
env = None
if self.pinot_admin_system_exit:
env = os.environ.copy()
@@ -273,7 +271,7 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
conn_type = 'http' if not conn.conn_type else conn.conn_type
conn_type = conn.conn_type or 'http'
endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
return f'{conn_type}://{host}/{endpoint}'

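
pinot.py assembles the admin command with iterable unpacking in a single list literal instead of extend. A short sketch; the binary path and arguments are example values only:

    cmd_path = "/opt/pinot/bin/pinot-admin.sh"  # example path, not a real install
    cmd = ["StartBroker", "-zkAddress", "localhost:2181"]

    command_old = [cmd_path]  # removed form: build, then extend
    command_old.extend(cmd)

    command_new = [cmd_path, *cmd]  # added form: one expression

    assert command_old == command_new
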
2 changes: 1 addition & 1 deletion airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -220,7 +220,7 @@ def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any:
def submit_jdbc_job(self) -> None:
"""Submit Spark JDBC job"""
self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection)
self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py")
self.submit(application=f"{os.path.dirname(os.path.abspath(__file__))}/spark_jdbc_script.py")

def get_conn(self) -> Any:
pass
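
spark_jdbc.py builds the submitted script path with an f-string rather than "+" concatenation. A sketch of the same construction; the pathlib line is an alternative spelling, not what the commit uses:

    import os
    from pathlib import Path

    # f-string form, as in the hunk: directory of this file plus the script name.
    script = f"{os.path.dirname(os.path.abspath(__file__))}/spark_jdbc_script.py"

    # A pathlib spelling of the same idea; note resolve() also follows symlinks,
    # which os.path.abspath() does not.
    script_alt = str(Path(__file__).resolve().parent / "spark_jdbc_script.py")

    print(script)
    print(script_alt)
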
26 changes: 10 additions & 16 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
import contextlib
import os
import re
import subprocess
@@ -28,10 +29,8 @@
from airflow.security.kerberos import renew_from_kt
from airflow.utils.log.logging_mixin import LoggingMixin

try:
with contextlib.suppress(ImportError, NameError):
from airflow.kubernetes import kube_client
except (ImportError, NameError):
pass


class SparkSubmitHook(BaseHook, LoggingMixin):
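
In spark_submit.py the optional kubernetes import now uses contextlib.suppress instead of try/except ... pass. A minimal sketch of the pattern; some_optional_dependency is a placeholder module that intentionally does not exist:

    import contextlib

    # Equivalent to: try: import ...  except (ImportError, NameError): pass
    with contextlib.suppress(ImportError, NameError):
        import some_optional_dependency  # placeholder; absent on purpose

    print("execution continues even though the import failed")

suppress states up front that the only handling for those exception types is to ignore them.
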
@@ -355,9 +354,7 @@ def _build_track_driver_status_command(self) -> List[str]:
self.log.info(connection_cmd)

# The driver id so we can poll for its status
if self._driver_id:
pass
else:
if not self._driver_id:
raise AirflowException(
"Invalid status: attempted to poll driver status but no driver id is known. Giving up."
)
@@ -607,17 +604,14 @@ def on_kill(self) -> None:
"""Kill Spark submit command"""
self.log.debug("Kill Command is being called")

if self._should_track_driver_status:
if self._driver_id:
self.log.info('Killing driver %s on cluster', self._driver_id)
if self._should_track_driver_status and self._driver_id:
self.log.info('Killing driver %s on cluster', self._driver_id)

kill_cmd = self._build_spark_driver_kill_command()
with subprocess.Popen(
kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
) as driver_kill:
self.log.info(
"Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
)
kill_cmd = self._build_spark_driver_kill_command()
with subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as driver_kill:
self.log.info(
"Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
)

if self._submit_sp and self._submit_sp.poll() is None:
self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
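
The remaining spark_submit.py hunks turn "if x: pass / else: raise" into a plain "if not x: raise" guard and merge the two nested if statements in on_kill into one "and" condition. A compact sketch with made-up names (poll_status, kill_driver):

    def poll_status(driver_id):
        # Guard clause instead of an empty "pass" branch followed by else/raise.
        if not driver_id:
            raise RuntimeError("attempted to poll driver status but no driver id is known")
        return f"polling {driver_id}"

    def kill_driver(track_status: bool, driver_id):
        # One combined condition replaces two nested if statements.
        if track_status and driver_id:
            return f"killing {driver_id}"
        return None

    assert poll_status("driver-20220606") == "polling driver-20220606"
    assert kill_driver(True, "driver-20220606") == "killing driver-20220606"
    assert kill_driver(True, None) is None
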
4 changes: 1 addition & 3 deletions tests/providers/apache/hive/transfers/test_mssql_to_hive.py
@@ -113,9 +113,7 @@ def test_execute_empty_description_field(self, mock_hive_hook, mock_mssql_hook,
mssql_to_hive_transfer.execute(context={})

field_dict = OrderedDict()
col_count = 0
for field in mock_mssql_hook_cursor.return_value.description:
col_count += 1
for col_count, field in enumerate(mock_mssql_hook_cursor.return_value.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position] = mssql_to_hive_transfer.type_map(field[1])
mock_hive_hook.return_value.load_file.assert_called_once_with(
5 changes: 3 additions & 2 deletions tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
@@ -164,7 +164,8 @@ def test_spark_write_to_jdbc(self, mock_writer_save):
# Given
arguments = _parse_arguments(self.jdbc_arguments)
spark_session = _create_spark_session(arguments)
spark_session.sql("CREATE TABLE IF NOT EXISTS " + arguments.metastore_table + " (key INT)")
spark_session.sql(f"CREATE TABLE IF NOT EXISTS {arguments.metastore_table} (key INT)")

# When

spark_write_to_jdbc(
@@ -191,7 +192,7 @@ def test_spark_read_from_jdbc(self, mock_reader_load):
# Given
arguments = _parse_arguments(self.jdbc_arguments)
spark_session = _create_spark_session(arguments)
spark_session.sql("CREATE TABLE IF NOT EXISTS " + arguments.metastore_table + " (key INT)")
spark_session.sql(f"CREATE TABLE IF NOT EXISTS {arguments.metastore_table} (key INT)")

# When
spark_read_from_jdbc(
