Fixing MyPy issues inside providers/microsoft (#20409)
khalidmammadov committed Dec 23, 2021
1 parent f0cf15c commit e63e23c
Showing 24 changed files with 320 additions and 35 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/log/wasb_task_handler.py

@@ -107,7 +107,7 @@ def close(self) -> None:
         # Mark closed so we don't double write if close is called twice
         self.closed = True
 
-    def _read(self, ti, try_number: str, metadata: Optional[str] = None) -> Tuple[str, Dict[str, bool]]:
+    def _read(self, ti, try_number: int, metadata: Optional[str] = None) -> Tuple[str, Dict[str, bool]]:
         """
         Read logs of given task instance and try_number from Wasb remote storage.
         If failed, read the log from task instance host machine.
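The fix here is small but real: Airflow's log machinery passes try_number as an integer try count, so annotating it as str made every correct caller a type error. A rough sketch of the mismatch mypy reports, using trimmed, hypothetical signatures:

from typing import Dict, Optional, Tuple

def _read(ti, try_number: int, metadata: Optional[str] = None) -> Tuple[str, Dict[str, bool]]:
    # Stand-in for the handler method: fetch the remote log for one try.
    return f"log for try {try_number}", {"end_of_log": True}

# With the old annotation (try_number: str), a call such as _read(ti, 1)
# fails type checking:
#   error: Argument 2 to "_read" has incompatible type "int"; expected "str"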
9 changes: 6 additions & 3 deletions airflow/providers/microsoft/azure/operators/adls.py

@@ -15,11 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class ADLSDeleteOperator(BaseOperator):
     """
@@ -57,7 +60,7 @@ def __init__(
         self.ignore_not_found = ignore_not_found
         self.azure_data_lake_conn_id = azure_data_lake_conn_id
 
-    def execute(self, context: dict) -> Any:
+    def execute(self, context: "Context") -> Any:
         hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
         return hook.remove(path=self.path, recursive=self.recursive, ignore_not_found=self.ignore_not_found)
 
@@ -96,7 +99,7 @@ def __init__(
         self.path = path
         self.azure_data_lake_conn_id = azure_data_lake_conn_id
 
-    def execute(self, context: dict) -> list:
+    def execute(self, context: "Context") -> list:
         hook = AzureDataLakeHook(azure_data_lake_conn_id=self.azure_data_lake_conn_id)
         self.log.info('Getting list of ADLS files in path: %s', self.path)
         return hook.list(path=self.path)
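This is the pattern that repeats throughout the commit: import TYPE_CHECKING, guard the Context import behind it, and annotate execute with the quoted forward reference "Context". Because typing.TYPE_CHECKING is False at runtime, the guarded import is seen only by type checkers, avoiding runtime import cost and potential import cycles while still giving mypy the precise type instead of a bare dict. A minimal sketch of the idiom with a hypothetical operator:

from typing import TYPE_CHECKING, Any

from airflow.models import BaseOperator

if TYPE_CHECKING:
    # Evaluated by mypy only; never imported when the DAG actually runs.
    from airflow.utils.context import Context


class ExampleContextOperator(BaseOperator):  # hypothetical operator
    def execute(self, context: "Context") -> Any:
        # "Context" is just a string (forward reference) at runtime, but mypy
        # resolves it to the real class via the guarded import above.
        return context["ds"]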
7 changes: 5 additions & 2 deletions airflow/providers/microsoft/azure/operators/adx.py

@@ -18,14 +18,17 @@
 #
 
 """This module contains Azure Data Explorer operators"""
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from azure.kusto.data._models import KustoResultTable
 
 from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.adx import AzureDataExplorerHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class AzureDataExplorerQueryOperator(BaseOperator):
     """
@@ -66,7 +69,7 @@ def get_hook(self) -> AzureDataExplorerHook:
         """Returns new instance of AzureDataExplorerHook"""
         return AzureDataExplorerHook(self.azure_data_explorer_conn_id)
 
-    def execute(self, context: dict) -> Union[KustoResultTable, str]:
+    def execute(self, context: "Context") -> Union[KustoResultTable, str]:
         """
         Run KQL Query on Azure Data Explorer (Kusto).
         Returns `PrimaryResult` of Query v2 HTTP response contents
7 changes: 5 additions & 2 deletions airflow/providers/microsoft/azure/operators/batch.py

@@ -16,14 +16,17 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from typing import Any, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
 from azure.batch import models as batch_models
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class AzureBatchOperator(BaseOperator):
     """
@@ -266,7 +269,7 @@ def _check_inputs(self) -> Any:
             "Some required parameters are missing.Please you must set all the required parameters. "
         )
 
-    def execute(self, context: dict) -> None:
+    def execute(self, context: "Context") -> None:
         self._check_inputs()
         self.hook.connection.config.retry_policy = self.batch_max_retries
 
8 changes: 6 additions & 2 deletions airflow/providers/microsoft/azure/operators/container_instances.py

@@ -19,7 +19,7 @@
 import re
 from collections import namedtuple
 from time import sleep
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 from azure.mgmt.containerinstance.models import (
     Container,
@@ -39,6 +39,10 @@
 from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook
 from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
 Volume = namedtuple(
     'Volume',
     ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'],
@@ -195,7 +199,7 @@ def __init__(
         self.ip_address = ip_address
         self.ports = ports
 
-    def execute(self, context: dict) -> int:
+    def execute(self, context: "Context") -> int:
         # Check name again in case it was templated.
         self._check_name(self.name)
 
6 changes: 5 additions & 1 deletion airflow/providers/microsoft/azure/operators/cosmos.py

@@ -15,10 +15,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import TYPE_CHECKING
 
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class AzureCosmosInsertDocumentOperator(BaseOperator):
     """
@@ -54,7 +58,7 @@ def __init__(
         self.document = document
         self.azure_cosmos_conn_id = azure_cosmos_conn_id
 
-    def execute(self, context: dict) -> None:
+    def execute(self, context: "Context") -> None:
         # Create the hook
         hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id)
 
37 changes: 37 additions & 0 deletions airflow/providers/microsoft/azure/operators/cosmos.pyi

@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Optional
+
+from airflow.models import BaseOperator
+
+class AzureCosmosInsertDocumentOperator(BaseOperator):
+    """
+    A stub file to suppress MyPy issues due to not supplying
+    mandatory parameters to the operator
+    """
+
+    def __init__(
+        self,
+        *,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        document: Optional[dict] = None,
+        azure_cosmos_conn_id: str = 'azure_cosmos_default',
+        **kwargs,
+    ) -> None: ...
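This .pyi stub sits next to cosmos.py, and type checkers read the stub instead of the module itself. Declaring the constructor parameters as Optional keyword arguments relaxes what mypy demands at call sites, per the stub's own docstring, for callers that don't pass the otherwise-mandatory parameters directly, without changing runtime behavior at all. A sketch of the mechanics under hypothetical module names:

# greeter.py -- hypothetical runtime module
class Greeter:
    def __init__(self, name: str) -> None:  # `name` is required at runtime
        self.name = name


# greeter.pyi -- the stub mypy reads *instead of* greeter.py:
#
#     from typing import Optional
#
#     class Greeter:
#         def __init__(self, name: Optional[str] = None) -> None: ...
#
# With the stub in place, `Greeter()` passes type checking, even though
# calling it without arguments would raise TypeError if actually executed.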
7 changes: 5 additions & 2 deletions airflow/providers/microsoft/azure/operators/data_factory.py

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from airflow.hooks.base import BaseHook
 from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance
@@ -25,6 +25,9 @@
     AzureDataFactoryPipelineRunStatus,
 )
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class AzureDataFactoryPipelineRunLink(BaseOperatorLink):
     """Constructs a link to monitor a pipeline run in Azure Data Factory."""
@@ -148,7 +151,7 @@ def __init__(
         self.timeout = timeout
         self.check_interval = check_interval
 
-    def execute(self, context: Dict) -> None:
+    def execute(self, context: "Context") -> None:
         self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
         self.log.info(f"Executing the {self.pipeline_name} pipeline.")
         response = self.hook.run_pipeline(
7 changes: 5 additions & 2 deletions airflow/providers/microsoft/azure/operators/wasb_delete_blob.py

@@ -16,11 +16,14 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from airflow.models import BaseOperator
 from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class WasbDeleteBlobOperator(BaseOperator):
     """
@@ -64,7 +67,7 @@ def __init__(
         self.is_prefix = is_prefix
         self.ignore_if_missing = ignore_if_missing
 
-    def execute(self, context: dict) -> None:
+    def execute(self, context: "Context") -> None:
         self.log.info('Deleting blob: %s\n in wasb://%s', self.blob_name, self.container_name)
         hook = WasbHook(wasb_conn_id=self.wasb_conn_id)
 
39 changes: 39 additions & 0 deletions airflow/providers/microsoft/azure/operators/wasb_delete_blob.pyi

@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Optional, Any
+
+from airflow.models import BaseOperator
+
+class WasbDeleteBlobOperator(BaseOperator):
+    """
+    A stub file to suppress MyPy issues due to not supplying
+    mandatory parameters to the operator
+    """
+
+    def __init__(
+        self,
+        *,
+        container_name: Optional[str] = None,
+        blob_name: Optional[str] = None,
+        wasb_conn_id: str = 'wasb_default',
+        check_options: Any = None,
+        is_prefix: bool = False,
+        ignore_if_missing: bool = False,
+        **kwargs,
+    ) -> None: ...
8 changes: 6 additions & 2 deletions airflow/providers/microsoft/azure/sensors/cosmos.py

@@ -15,10 +15,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import TYPE_CHECKING
 
 from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class AzureCosmosDocumentSensor(BaseSensorOperator):
     """
@@ -62,7 +66,7 @@ def __init__(
         self.collection_name = collection_name
         self.document_id = document_id
 
-    def poke(self, context: dict) -> bool:
-        self.log.info("*** Intering poke")
+    def poke(self, context: "Context") -> bool:
+        self.log.info("*** Entering poke")
         hook = AzureCosmosDBHook(self.azure_cosmos_conn_id)
         return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None
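Alongside the annotation, this hunk also fixes a log-message typo ("Intering" becomes "Entering"). For sensors the idiom is the same as in the operators above, except the typed method is poke, which must return a bool; a minimal sketch with a hypothetical sensor:

from typing import TYPE_CHECKING

from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
    from airflow.utils.context import Context


class ExampleDocumentSensor(BaseSensorOperator):  # hypothetical sensor
    def poke(self, context: "Context") -> bool:
        # The scheduler keeps re-invoking poke until it returns True.
        self.log.info("*** Entering poke")
        return True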
37 changes: 37 additions & 0 deletions airflow/providers/microsoft/azure/sensors/cosmos.pyi

@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Optional
+
+from airflow.sensors.base import BaseSensorOperator
+
+class AzureCosmosDocumentSensor(BaseSensorOperator):
+    """
+    A stub file to suppress MyPy issues due to not supplying
+    mandatory parameters to the operator
+    """
+
+    def __init__(
+        self,
+        *,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        document_id: Optional[str] = None,
+        azure_cosmos_conn_id: str = "azure_cosmos_default",
+        **kwargs,
+    ) -> None: ...
7 changes: 5 additions & 2 deletions airflow/providers/microsoft/azure/sensors/data_factory.py

@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Optional
 
 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryHook,
@@ -24,6 +24,9 @@
 )
 from airflow.sensors.base import BaseSensorOperator
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class AzureDataFactoryPipelineRunStatusSensor(BaseSensorOperator):
     """
@@ -58,7 +61,7 @@ def __init__(
         self.resource_group_name = resource_group_name
         self.factory_name = factory_name
 
-    def poke(self, context: Dict) -> bool:
+    def poke(self, context: "Context") -> bool:
         self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
         pipeline_run_status = self.hook.get_pipeline_run_status(
             run_id=self.run_id,