Fix MyPy errors in Apache Providers (#20422)
Part of #19891

The .pyi additions are there to handle "default_args" passed in
example DAGs. Some of the obligatory parameters are currently
(and correctly) passed via default_args, and we have no good
mechanism yet to make MyPy understand this (it would require
writing a custom MyPy plugin), so the stubs are the best
workaround we have for now.
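
As an illustration, consider a hypothetical example DAG (not part of this commit) using the pattern the stubs accommodate: an obligatory argument such as table is supplied through default_args, which MyPy cannot trace without a plugin.

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor

with DAG(
    dag_id="cassandra_example",  # hypothetical DAG, for illustration only
    start_date=datetime(2021, 1, 1),
    default_args={"table": "my_keyspace.my_table"},  # obligatory parameter set here
) as dag:
    # No `table` argument appears below, so MyPy would flag this call against
    # the real signature; the .pyi stub relaxes `table` to Optional instead.
    check_record = CassandraRecordSensor(
        task_id="check_record",
        keys={"p1": "v1"},
    )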
potiuk committed Dec 29, 2021
1 parent da88ed1 commit 485ff6c
Showing 22 changed files with 164 additions and 46 deletions.
9 changes: 6 additions & 3 deletions airflow/providers/apache/cassandra/sensors/record.py
@@ -20,11 +20,14 @@
 of a record in a Cassandra cluster.
 """
 
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict
 
 from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class CassandraRecordSensor(BaseSensorOperator):
     """
@@ -58,8 +61,8 @@ class CassandraRecordSensor(BaseSensorOperator):
     def __init__(
         self,
         *,
-        table: str,
         keys: Dict[str, str],
+        table: str,
         cassandra_conn_id: str = CassandraHook.default_conn_name,
         **kwargs: Any,
     ) -> None:
@@ -68,7 +71,7 @@ def __init__(
         self.table = table
         self.keys = keys
 
-    def poke(self, context: Dict[str, str]) -> bool:
+    def poke(self, context: "Context") -> bool:
         self.log.info('Sensor check existence of record: %s', self.keys)
         hook = CassandraHook(self.cassandra_conn_id)
         return hook.record_exists(self.table, self.keys)
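
The same change recurs in every file below: the runtime Dict annotation on poke/execute is replaced by the Context class, imported only under TYPE_CHECKING. A minimal sketch of the pattern (a simplified stand-in, not the provider's actual class):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers such as MyPy; at runtime this
    # import never executes, so it cannot introduce an import cycle.
    from airflow.utils.context import Context

class ExampleSensor:
    def poke(self, context: "Context") -> bool:
        # The quoted annotation is a forward reference resolved lazily,
        # which is why the guarded import is sufficient.
        return True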
30 changes: 30 additions & 0 deletions airflow/providers/apache/cassandra/sensors/record.pyi
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Dict, Any
+
+from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
+
+class CassandraRecordSensor:
+    def __init__(
+        self,
+        *,
+        keys: Optional[Dict[str, str]] = None,
+        table: Optional[str] = None,
+        cassandra_conn_id: str = CassandraHook.default_conn_name,
+        **kwargs: Any,
+    ) -> None: ...
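
Because a .pyi stub takes precedence over the corresponding .py module during type checking (per PEP 484), these Optional signatures are what MyPy sees, while the implementation keeps its strict required arguments. A hypothetical snippet that now type-checks even though table is omitted:

from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor

# MyPy consults record.pyi, where `table` is Optional with a None default,
# so omitting it here raises no error; at runtime default_args must supply it.
sensor = CassandraRecordSensor(task_id="check_record", keys={"p1": "v1"})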
13 changes: 10 additions & 3 deletions airflow/providers/apache/cassandra/sensors/table.py
@@ -21,11 +21,14 @@
 of a table in a Cassandra cluster.
 """
 
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any
 
 from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class CassandraTableSensor(BaseSensorOperator):
     """
@@ -54,13 +57,17 @@ class CassandraTableSensor(BaseSensorOperator):
     template_fields = ('table',)
 
     def __init__(
-        self, *, table: str, cassandra_conn_id: str = CassandraHook.default_conn_name, **kwargs: Any
+        self,
+        *,
+        table: str,
+        cassandra_conn_id: str = CassandraHook.default_conn_name,
+        **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
         self.cassandra_conn_id = cassandra_conn_id
         self.table = table
 
-    def poke(self, context: Dict[Any, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
         self.log.info('Sensor check existence of table: %s', self.table)
         hook = CassandraHook(self.cassandra_conn_id)
         return hook.table_exists(self.table)
29 changes: 29 additions & 0 deletions airflow/providers/apache/cassandra/sensors/table.pyi
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Any
+
+from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
+
+class CassandraTableSensor:
+    def __init__(
+        self,
+        *,
+        table: Optional[str] = None,
+        cassandra_conn_id: str = CassandraHook.default_conn_name,
+        **kwargs: Any,
+    ) -> None: ...
7 changes: 5 additions & 2 deletions airflow/providers/apache/druid/operators/druid.py
@@ -16,11 +16,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from airflow.models import BaseOperator
 from airflow.providers.apache.druid.hooks.druid import DruidHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class DruidOperator(BaseOperator):
     """
@@ -57,7 +60,7 @@ def __init__(
         self.timeout = timeout
         self.max_ingestion_time = max_ingestion_time
 
-    def execute(self, context: Dict[Any, Any]) -> None:
+    def execute(self, context: "Context") -> None:
         hook = DruidHook(
             druid_ingest_conn_id=self.conn_id,
             timeout=self.timeout,
7 changes: 5 additions & 2 deletions airflow/providers/apache/druid/transfers/hive_to_druid.py
@@ -18,12 +18,15 @@
 
 """This module contains operator to move data from Hive to Druid."""
 
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from airflow.models import BaseOperator
 from airflow.providers.apache.druid.hooks.druid import DruidHook
 from airflow.providers.apache.hive.hooks.hive import HiveCliHook, HiveMetastoreHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 LOAD_CHECK_INTERVAL = 5
 DEFAULT_TARGET_PARTITION_SIZE = 5000000
 
@@ -116,7 +119,7 @@ def __init__(
         self.hive_tblproperties = hive_tblproperties or {}
         self.job_properties = job_properties
 
-    def execute(self, context: Dict[str, Any]) -> None:
+    def execute(self, context: "Context") -> None:
         hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
         self.log.info("Extracting data from Hive")
         hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
11 changes: 7 additions & 4 deletions airflow/providers/apache/hdfs/sensors/hdfs.py
@@ -18,12 +18,15 @@
 import logging
 import re
 import sys
-from typing import Any, Dict, List, Optional, Pattern, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Type
 
 from airflow import settings
 from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 log = logging.getLogger(__name__)
 
 
@@ -115,7 +118,7 @@ def filter_for_ignored_ext(
         log.debug('HdfsSensor.poke: after ext filter result is %s', result)
         return result
 
-    def poke(self, context: Dict[Any, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
         """Get a snakebite client connection and check for file."""
         sb_client = self.hook(self.hdfs_conn_id).get_conn()
         self.log.info('Poking for file %s', self.filepath)
@@ -149,7 +152,7 @@ def __init__(self, regex: Pattern[str], *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.regex = regex
 
-    def poke(self, context: Dict[Any, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
         """
         Poke matching files in a directory with self.regex
@@ -182,7 +185,7 @@ def __init__(self, be_empty: bool = False, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
         self.be_empty = be_empty
 
-    def poke(self, context: Dict[str, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
         """
         Poke for a non empty directory
7 changes: 5 additions & 2 deletions airflow/providers/apache/hdfs/sensors/web_hdfs.py
@@ -15,10 +15,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any
 
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class WebHdfsSensor(BaseSensorOperator):
     """Waits for a file or folder to land in HDFS"""
@@ -30,7 +33,7 @@ def __init__(self, *, filepath: str, webhdfs_conn_id: str = 'webhdfs_default', *
         self.filepath = filepath
         self.webhdfs_conn_id = webhdfs_conn_id
 
-    def poke(self, context: Dict[Any, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
         from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook
 
         hook = WebHDFSHook(self.webhdfs_conn_id)
7 changes: 5 additions & 2 deletions airflow/providers/apache/hive/operators/hive.py
@@ -17,14 +17,17 @@
 # under the License.
 import os
 import re
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.apache.hive.hooks.hive import HiveCliHook
 from airflow.utils import operator_helpers
 from airflow.utils.operator_helpers import context_to_airflow_vars
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class HiveOperator(BaseOperator):
     """
@@ -133,7 +136,7 @@ def prepare_template(self) -> None:
         if self.script_begin_tag and self.script_begin_tag in self.hql:
             self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:])
 
-    def execute(self, context: Dict[str, Any]) -> None:
+    def execute(self, context: "Context") -> None:
         self.log.info('Executing: %s', self.hql)
         self.hook = self.get_hook()
 
7 changes: 5 additions & 2 deletions airflow/providers/apache/hive/operators/hive_stats.py
@@ -18,14 +18,17 @@
 import json
 import warnings
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 from airflow.providers.presto.hooks.presto import PrestoHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class HiveStatsCollectionOperator(BaseOperator):
     """
@@ -116,7 +119,7 @@ def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
 
         return exp
 
-    def execute(self, context: Optional[Dict[str, Any]] = None) -> None:
+    def execute(self, context: "Context") -> None:
         metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
         table = metastore.get_table(table_name=self.table)
         field_types = {col.name: col.type for col in table.sd.cols}
7 changes: 5 additions & 2 deletions airflow/providers/apache/hive/sensors/hive_partition.py
@@ -15,11 +15,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class HivePartitionSensor(BaseSensorOperator):
     """
@@ -67,7 +70,7 @@ def __init__(
         self.partition = partition
         self.schema = schema
 
-    def poke(self, context: Dict[str, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
         if '.' in self.table:
             self.schema, self.table = self.table.split('.')
         self.log.info('Poking for table %s.%s, partition %s', self.schema, self.table, self.partition)
7 changes: 5 additions & 2 deletions airflow/providers/apache/hive/sensors/metastore_partition.py
@@ -15,10 +15,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any
 
 from airflow.sensors.sql import SqlSensor
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class MetastorePartitionSensor(SqlSensor):
     """
@@ -67,7 +70,7 @@ def __init__(
         # constructor below and apply_defaults will no longer throw an exception.
         super().__init__(**kwargs)
 
-    def poke(self, context: Dict[str, Any]) -> Any:
+    def poke(self, context: "Context") -> Any:
         if self.first_poke:
             self.first_poke = False
             if '.' in self.table:
7 changes: 5 additions & 2 deletions airflow/providers/apache/hive/sensors/named_hive_partition.py
@@ -15,10 +15,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Any, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, List, Tuple
 
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class NamedHivePartitionSensor(BaseSensorOperator):
     """
@@ -92,7 +95,7 @@ def poke_partition(self, partition: str) -> Any:
         self.log.info('Poking for %s.%s/%s', schema, table, partition)
         return self.hook.check_for_named_partition(schema, table, partition)
 
-    def poke(self, context: Dict[str, Any]) -> bool:
+    def poke(self, context: "Context") -> bool:
 
         number_of_partitions = len(self.partition_names)
         poke_index_start = self.next_index_to_poke