YandexCloud provider: Support new Yandex SDK features: log_group_id, user-agent, maven packages (#20103)
Piatachock committed Dec 14, 2021
1 parent 1a2a249 commit 41c49c7
Showing 6 changed files with 77 additions and 5 deletions.
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import uuid
from datetime import datetime

from airflow import DAG
@@ -81,7 +81,7 @@
'-input',
's3a://data-proc-public/jobs/sources/data/cities500.txt.bz2',
'-output',
f's3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/dataproc/job/results',
f's3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/dataproc/job/results/{uuid.uuid4()}',
],
properties={
'yarn.app.mapreduce.am.resource.mb': '2048',
@@ -113,6 +113,9 @@
properties={
'spark.submit.deployMode': 'cluster',
},
packages=['org.slf4j:slf4j-simple:1.7.30'],
repositories=['https://repo1.maven.org/maven2'],
exclude_packages=['com.amazonaws:amazon-kinesis-client'],
)

create_pyspark_job = DataprocCreatePysparkJobOperator(
@@ -129,7 +132,7 @@
],
args=[
's3a://data-proc-public/jobs/sources/data/cities500.txt.bz2',
f's3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/jobs/results/${{JOB_ID}}',
f's3a://{S3_BUCKET_NAME_FOR_JOB_LOGS}/dataproc/job/results/${{JOB_ID}}',
],
jar_file_uris=[
's3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar',
@@ -139,6 +142,9 @@
properties={
'spark.submit.deployMode': 'cluster',
},
packages=['org.slf4j:slf4j-simple:1.7.30'],
repositories=['https://repo1.maven.org/maven2'],
exclude_packages=['com.amazonaws:amazon-kinesis-client'],
)

delete_cluster = DataprocDeleteClusterOperator(
16 changes: 15 additions & 1 deletion airflow/providers/yandex/hooks/yandex.py
@@ -80,6 +80,20 @@ def get_connection_form_widgets() -> Dict[str, Any]:
),
}

@classmethod
def provider_user_agent(cls) -> Optional[str]:
"""Construct User-Agent from Airflow core & provider package versions"""
import airflow
from airflow.providers_manager import ProvidersManager

try:
manager = ProvidersManager()
provider_name = manager.hooks[cls.conn_type].package_name
provider = manager.providers[provider_name]
return f'apache-airflow/{airflow.__version__} {provider_name}/{provider.version}'
except KeyError:
warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager")

@staticmethod
def get_ui_field_behaviour() -> Dict:
"""Returns custom field behaviour"""
@@ -107,7 +121,7 @@ def __init__(
self.connection = self.get_connection(self.connection_id)
self.extras = self.connection.extra_dejson
credentials = self._get_credentials()
self.sdk = yandexcloud.SDK(**credentials)
self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(), **credentials)
self.default_folder_id = default_folder_id or self._get_field('folder_id', False)
self.default_public_ssh_key = default_public_ssh_key or self._get_field('public_ssh_key', False)
self.client = self.sdk.client
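
As a quick orientation for the new user-agent plumbing, here is a minimal, hedged sketch of how it can be exercised; the version numbers in the comment are illustrative only and are not pinned by this commit.

from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook

# provider_user_agent() joins the Airflow core version with the version of the
# yandex provider package registered in ProvidersManager, yielding a string
# such as 'apache-airflow/2.2.3 apache-airflow-providers-yandex/2.1.0'.
# The hook constructor now forwards that string to yandexcloud.SDK(user_agent=...).
user_agent = YandexCloudBaseHook.provider_user_agent()
print(user_agent)

If the provider package is not registered with ProvidersManager, the method warns and returns None, so callers should treat the value as optional.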
40 changes: 40 additions & 0 deletions airflow/providers/yandex/operators/yandexcloud_dataproc.py
@@ -93,6 +93,9 @@ class DataprocCreateClusterOperator(BaseOperator):
:param computenode_decommission_timeout: Timeout to gracefully decommission nodes during downscaling.
In seconds.
:type computenode_decommission_timeout: int
:param log_group_id: Id of log group to write logs. By default logs will be sent to default log group.
To disable cloud log sending set cluster property dataproc:disable_cloud_logging = true
:type log_group_id: str
"""

def __init__(
@@ -127,6 +130,7 @@ def __init__(
computenode_cpu_utilization_target: Optional[int] = None,
computenode_decommission_timeout: Optional[int] = None,
connection_id: Optional[str] = None,
log_group_id: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -159,6 +163,7 @@ def __init__(
self.computenode_preemptible = computenode_preemptible
self.computenode_cpu_utilization_target = computenode_cpu_utilization_target
self.computenode_decommission_timeout = computenode_decommission_timeout
self.log_group_id = log_group_id

self.hook: Optional[DataprocHook] = None

@@ -195,6 +200,7 @@ def execute(self, context) -> None:
computenode_preemptible=self.computenode_preemptible,
computenode_cpu_utilization_target=self.computenode_cpu_utilization_target,
computenode_decommission_timeout=self.computenode_decommission_timeout,
log_group_id=self.log_group_id,
)
context['task_instance'].xcom_push(key='cluster_id', value=operation_result.response.id)
context['task_instance'].xcom_push(key='yandexcloud_connection_id', value=self.yandex_conn_id)
@@ -399,6 +405,14 @@ class DataprocCreateSparkJobOperator(BaseOperator):
:type cluster_id: Optional[str]
:param connection_id: ID of the Yandex.Cloud Airflow connection.
:type connection_id: Optional[str]
:param packages: List of maven coordinates of jars to include on the driver and executor classpaths.
:type packages: Optional[Iterable[str]]
:param repositories: List of additional remote repositories to search for the maven coordinates
given with --packages.
:type repositories: Optional[Iterable[str]]
:param exclude_packages: List of groupId:artifactId, to exclude while resolving the dependencies
provided in --packages to avoid dependency conflicts.
:type exclude_packages: Optional[Iterable[str]]
"""

template_fields = ['cluster_id']
@@ -416,6 +430,9 @@ def __init__(
name: str = 'Spark job',
cluster_id: Optional[str] = None,
connection_id: Optional[str] = None,
packages: Optional[Iterable[str]] = None,
repositories: Optional[Iterable[str]] = None,
exclude_packages: Optional[Iterable[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -429,6 +446,9 @@ def __init__(
self.name = name
self.cluster_id = cluster_id
self.connection_id = connection_id
self.packages = packages
self.repositories = repositories
self.exclude_packages = exclude_packages
self.hook: Optional[DataprocHook] = None

def execute(self, context) -> None:
@@ -447,6 +467,9 @@ def execute(self, context) -> None:
file_uris=self.file_uris,
args=self.args,
properties=self.properties,
packages=self.packages,
repositories=self.repositories,
exclude_packages=self.exclude_packages,
name=self.name,
cluster_id=cluster_id,
)
@@ -476,6 +499,14 @@ class DataprocCreatePysparkJobOperator(BaseOperator):
:type cluster_id: Optional[str]
:param connection_id: ID of the Yandex.Cloud Airflow connection.
:type connection_id: Optional[str]
:param packages: List of maven coordinates of jars to include on the driver and executor classpaths.
:type packages: Optional[Iterable[str]]
:param repositories: List of additional remote repositories to search for the maven coordinates
given with --packages.
:type repositories: Optional[Iterable[str]]
:param exclude_packages: List of groupId:artifactId, to exclude while resolving the dependencies
provided in --packages to avoid dependency conflicts.
:type exclude_packages: Optional[Iterable[str]]
"""

template_fields = ['cluster_id']
@@ -493,6 +524,9 @@ def __init__(
name: str = 'Pyspark job',
cluster_id: Optional[str] = None,
connection_id: Optional[str] = None,
packages: Optional[Iterable[str]] = None,
repositories: Optional[Iterable[str]] = None,
exclude_packages: Optional[Iterable[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -506,6 +540,9 @@ def __init__(
self.name = name
self.cluster_id = cluster_id
self.connection_id = connection_id
self.packages = packages
self.repositories = repositories
self.exclude_packages = exclude_packages
self.hook: Optional[DataprocHook] = None

def execute(self, context) -> None:
@@ -524,6 +561,9 @@ def execute(self, context) -> None:
file_uris=self.file_uris,
args=self.args,
properties=self.properties,
packages=self.packages,
repositories=self.repositories,
exclude_packages=self.exclude_packages,
name=self.name,
cluster_id=cluster_id,
)
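
To illustrate the new operator arguments together, here is a hedged usage sketch; the DAG id, bucket name, log group id, and Spark main class below are placeholders rather than values taken from this commit, while the packages, repositories, and exclude_packages values mirror the updated example DAG.

from datetime import datetime

from airflow import DAG
from airflow.providers.yandex.operators.yandexcloud_dataproc import (
    DataprocCreateClusterOperator,
    DataprocCreateSparkJobOperator,
)

with DAG('example_dataproc_new_args', start_date=datetime(2021, 12, 1), schedule_interval=None) as dag:
    create_cluster = DataprocCreateClusterOperator(
        task_id='create_cluster',
        zone='ru-central1-c',
        s3_bucket='my-logs-bucket',      # placeholder bucket for cluster logs
        log_group_id='my-log-group-id',  # new: route cluster logs to a specific Cloud Logging group
    )

    create_spark_job = DataprocCreateSparkJobOperator(
        task_id='create_spark_job',
        main_class='org.apache.spark.examples.SparkPi',  # placeholder; substitute a class from the jar below
        main_jar_file_uri='s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar',
        # New Maven-resolution arguments, analogous to spark-submit's
        # --packages / --repositories / --exclude-packages options.
        packages=['org.slf4j:slf4j-simple:1.7.30'],
        repositories=['https://repo1.maven.org/maven2'],
        exclude_packages=['com.amazonaws:amazon-kinesis-client'],
    )

    create_cluster >> create_spark_job

When cluster_id is not passed explicitly, the job operator falls back to the cluster id that the create-cluster task pushes to XCom (see the xcom_push call above), so the two tasks can be chained as shown.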
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -452,6 +452,7 @@ args
argv
arn
arraysize
artifactId
asana
asc
ascii
2 changes: 1 addition & 1 deletion setup.py
@@ -508,7 +508,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'pywinrm~=0.4',
]
yandex = [
'yandexcloud>=0.97.0',
'yandexcloud>=0.122.0',
]
zendesk = [
'zdesk',
11 changes: 11 additions & 0 deletions tests/providers/yandex/operators/test_yandexcloud_dataproc.py
@@ -60,6 +60,9 @@
'cFDe6faKCxH6iDRteo4D8L8BxwzN42uZSB0nfmjkIxFTcEU3mFSXEbWByg78aoddMrAAjatyrhH1pON6P0='
]

# https://cloud.yandex.com/en-ru/docs/logging/concepts/log-group
LOG_GROUP_ID = 'my_log_group_id'


class DataprocClusterCreateOperatorTest(TestCase):
def setUp(self):
@@ -87,6 +90,7 @@ def test_create_cluster(self, create_cluster_mock, *_):
connection_id=CONNECTION_ID,
s3_bucket=S3_BUCKET_NAME_FOR_LOGS,
cluster_image_version=CLUSTER_IMAGE_VERSION,
log_group_id=LOG_GROUP_ID,
)
context = {'task_instance': MagicMock()}
operator.execute(context)
@@ -122,6 +126,7 @@ def test_create_cluster(self, create_cluster_mock, *_):
],
subnet_id='my_subnet_id',
zone='ru-central1-c',
log_group_id=LOG_GROUP_ID,
)
context['task_instance'].xcom_push.assert_has_calls(
[
@@ -300,6 +305,9 @@ def test_create_spark_job_operator(self, create_spark_job_mock, *_):
main_jar_file_uri='s3a://data-proc-public/jobs/sources/java/dataproc-examples-1.0.jar',
name='Spark job',
properties={'spark.submit.deployMode': 'cluster'},
packages=None,
repositories=None,
exclude_packages=None,
)

@patch('airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials')
@@ -359,4 +367,7 @@ def test_create_pyspark_job_operator(self, create_pyspark_job_mock, *_):
name='Pyspark job',
properties={'spark.submit.deployMode': 'cluster'},
python_file_uris=['s3a://some-in-bucket/jobs/sources/pyspark-001/geonames.py'],
packages=None,
repositories=None,
exclude_packages=None,
)
