Skip to content

Commit

Permalink
Optimise Airflow DB backend usage in Azure Provider (#33750)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Aug 26, 2023
1 parent c077d19 commit faf3253
Show file tree
Hide file tree
Showing 30 changed files with 872 additions and 857 deletions.
15 changes: 10 additions & 5 deletions airflow/providers/microsoft/azure/hooks/adx.py
Expand Up @@ -26,6 +26,7 @@
from __future__ import annotations

import warnings
from functools import cached_property
from typing import Any

from azure.identity import DefaultAzureCredential
Expand Down Expand Up @@ -76,8 +77,8 @@ class AzureDataExplorerHook(BaseHook):
conn_type = "azure_data_explorer"
hook_name = "Azure Data Explorer"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand All @@ -94,8 +95,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
Expand All @@ -116,7 +117,11 @@ def get_ui_field_behaviour() -> dict[str, Any]:
def __init__(self, azure_data_explorer_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn_id = azure_data_explorer_conn_id
self.connection = self.get_conn() # todo: make this a property, or just delete

@cached_property
def connection(self) -> KustoClient:
"""Return a KustoClient object (cached)."""
return self.get_conn()

def get_conn(self) -> KustoClient:
"""Return a KustoClient object."""
Expand Down
23 changes: 11 additions & 12 deletions airflow/providers/microsoft/azure/hooks/batch.py
Expand Up @@ -19,14 +19,14 @@

import time
from datetime import timedelta
from functools import cached_property
from typing import Any

from azure.batch import BatchServiceClient, batch_auth, models as batch_models
from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.microsoft.azure.utils import AzureIdentityCredentialAdapter, get_field
from airflow.utils import timezone

Expand All @@ -52,8 +52,8 @@ def _get_field(self, extras, name):
field_name=name,
)

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand All @@ -63,8 +63,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
"account_url": StringField(lazy_gettext("Batch Account URL"), widget=BS3TextFieldWidget()),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "host", "extra"],
Expand All @@ -77,20 +77,19 @@ def get_ui_field_behaviour() -> dict[str, Any]:
def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn_id = azure_batch_conn_id
self.connection = self.get_conn()

def _connection(self) -> Connection:
"""Get connected to Azure Batch service."""
conn = self.get_connection(self.conn_id)
return conn
@cached_property
def connection(self) -> BatchServiceClient:
"""Get the Batch client connection (cached)."""
return self.get_conn()

def get_conn(self):
def get_conn(self) -> BatchServiceClient:
"""
Get the Batch client connection.
:return: Azure Batch client
"""
conn = self._connection()
conn = self.get_connection(self.conn_id)

batch_account_url = self._get_field(conn.extra_dejson, "account_url")
if not batch_account_url:
Expand Down
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import warnings
from functools import cached_property

from azure.mgmt.containerinstance import ContainerInstanceManagementClient
from azure.mgmt.containerinstance.models import ContainerGroup
Expand Down Expand Up @@ -47,7 +48,10 @@ class AzureContainerInstanceHook(AzureBaseHook):

def __init__(self, azure_conn_id: str = default_conn_name) -> None:
super().__init__(sdk_client=ContainerInstanceManagementClient, conn_id=azure_conn_id)
self.connection = self.get_conn()

@cached_property
def connection(self):
return self.get_conn()

def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
"""
Expand Down
10 changes: 7 additions & 3 deletions airflow/providers/microsoft/azure/hooks/container_registry.py
Expand Up @@ -18,6 +18,7 @@
"""Hook for Azure Container Registry."""
from __future__ import annotations

from functools import cached_property
from typing import Any

from azure.mgmt.containerinstance.models import ImageRegistryCredential
Expand All @@ -39,8 +40,8 @@ class AzureContainerRegistryHook(BaseHook):
conn_type = "azure_container_registry"
hook_name = "Azure Container Registry"

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "extra"],
Expand All @@ -59,7 +60,10 @@ def get_ui_field_behaviour() -> dict[str, Any]:
def __init__(self, conn_id: str = "azure_registry") -> None:
super().__init__()
self.conn_id = conn_id
self.connection = self.get_conn()

@cached_property
def connection(self) -> ImageRegistryCredential:
return self.get_conn()

def get_conn(self) -> ImageRegistryCredential:
conn = self.get_connection(self.conn_id)
Expand Down
15 changes: 10 additions & 5 deletions airflow/providers/microsoft/azure/hooks/data_lake.py
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import Any

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
Expand Down Expand Up @@ -256,8 +257,8 @@ class AzureDataLakeStorageV2Hook(BaseHook):
conn_type = "adls"
hook_name = "Azure Date Lake Storage V2"

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
"""Returns connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
Expand All @@ -272,8 +273,8 @@ def get_connection_form_widgets() -> dict[str, Any]:
),
}

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour."""
return {
"hidden_fields": ["schema", "port"],
Expand All @@ -296,7 +297,11 @@ def __init__(self, adls_conn_id: str, public_read: bool = False) -> None:
super().__init__()
self.conn_id = adls_conn_id
self.public_read = public_read
self.service_client = self.get_conn()

@cached_property
def service_client(self) -> DataLakeServiceClient:
"""Return the DataLakeServiceClient object (cached)."""
return self.get_conn()

def get_conn(self) -> DataLakeServiceClient: # type: ignore[override]
"""Return the DataLakeServiceClient object."""
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/hooks/wasb.py
Expand Up @@ -27,6 +27,7 @@

import logging
import os
from functools import cached_property
from typing import Any, Union
from urllib.parse import urlparse

Expand Down Expand Up @@ -123,7 +124,6 @@ def __init__(
super().__init__()
self.conn_id = wasb_conn_id
self.public_read = public_read
self.blob_service_client: BlobServiceClient = self.get_conn()

logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
try:
Expand All @@ -142,6 +142,11 @@ def _get_field(self, extra_dict, field_name):
return extra_dict[field_name] or None
return extra_dict.get(f"{prefix}{field_name}") or None

@cached_property
def blob_service_client(self) -> BlobServiceClient:
"""Return the BlobServiceClient object (cached)."""
return self.get_conn()

def get_conn(self) -> BlobServiceClient:
"""Return the BlobServiceClient object."""
conn = self.get_connection(self.conn_id)
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/microsoft/azure/log/wasb_task_handler.py
Expand Up @@ -67,7 +67,6 @@ def __init__(
self.wasb_container = wasb_container
self.remote_base = wasb_log_folder
self.log_relative_path = ""
self._hook = None
self.closed = False
self.upload_on_close = True
self.delete_local_copy = (
Expand Down
3 changes: 2 additions & 1 deletion airflow/providers/microsoft/azure/operators/batch.py
Expand Up @@ -179,7 +179,8 @@ def __init__(
self.should_delete_pool = should_delete_pool

@cached_property
def hook(self):
def hook(self) -> AzureBatchHook:
"""Create and return an AzureBatchHook (cached)."""
return self.get_hook()

def _check_inputs(self) -> Any:
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/operators/data_factory.py
Expand Up @@ -18,6 +18,7 @@

import time
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
Expand Down Expand Up @@ -159,8 +160,12 @@ def __init__(
self.check_interval = check_interval
self.deferrable = deferrable

@cached_property
def hook(self) -> AzureDataFactoryHook:
"""Create and return an AzureDataFactoryHook (cached)."""
return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)

def execute(self, context: Context) -> None:
self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
self.log.info("Executing the %s pipeline.", self.pipeline_name)
response = self.hook.run_pipeline(
pipeline_name=self.pipeline_name,
Expand Down
9 changes: 6 additions & 3 deletions airflow/providers/microsoft/azure/operators/synapse.py
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from azure.synapse.spark.models import SparkBatchJobOptions
Expand Down Expand Up @@ -73,10 +74,12 @@ def __init__(
self.timeout = timeout
self.check_interval = check_interval

@cached_property
def hook(self):
"""Create and return an AzureSynapseHook (cached)."""
return AzureSynapseHook(azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool)

def execute(self, context: Context) -> None:
self.hook = AzureSynapseHook(
azure_synapse_conn_id=self.azure_synapse_conn_id, spark_pool=self.spark_pool
)
self.log.info("Executing the Synapse spark job.")
response = self.hook.run_spark_job(payload=self.payload)
self.log.info(response)
Expand Down
7 changes: 6 additions & 1 deletion airflow/providers/microsoft/azure/sensors/data_factory.py
Expand Up @@ -18,6 +18,7 @@

import warnings
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
Expand Down Expand Up @@ -72,8 +73,12 @@ def __init__(

self.deferrable = deferrable

@cached_property
def hook(self):
"""Create and return an AzureDataFactoryHook (cached)."""
return AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)

def poke(self, context: Context) -> bool:
self.hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
pipeline_run_status = self.hook.get_pipeline_run_status(
run_id=self.run_id,
resource_group_name=self.resource_group_name,
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Expand Up @@ -154,6 +154,7 @@ BaseView
BaseXCom
bashrc
batchGet
BatchServiceClient
bc
bcc
bdist
Expand Down

0 comments on commit faf3253

Please sign in to comment.