Add pyspark decorator (#35247)
This adds the pyspark decorator so that Spark code can
be run inline and results, such as dataframes, can be
shared.
bolkedebruin committed Nov 1, 2023
1 parent 3724a02 commit 0a4ed7d
Showing 10 changed files with 387 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -514,6 +514,7 @@ repos:
^airflow/providers/apache/cassandra/hooks/cassandra.py$|
^airflow/providers/apache/hive/operators/hive_stats.py$|
^airflow/providers/apache/hive/transfers/vertica_to_hive.py$|
^airflow/providers/apache/spark/decorators/|
^airflow/providers/apache/spark/hooks/|
^airflow/providers/apache/spark/operators/|
^airflow/providers/exasol/hooks/exasol.py$|
@@ -542,6 +543,7 @@ repos:
^docs/apache-airflow-providers-amazon/secrets-backends/aws-ssm-parameter-store.rst$|
^docs/apache-airflow-providers-apache-hdfs/connections.rst$|
^docs/apache-airflow-providers-apache-kafka/connections/kafka.rst$|
^docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst$|
^docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst$|
^docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst$|
^docs/conf.py$|
22 changes: 22 additions & 0 deletions airflow/decorators/__init__.pyi
@@ -566,6 +566,28 @@ class TaskDecoratorCollection:
"""
@overload
def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...
@overload
def pyspark(
self,
*,
multiple_outputs: bool | None = None,
conn_id: str | None = None,
config_kwargs: dict[str, str] | None = None,
**kwargs,
) -> TaskDecorator:
"""
Wraps a Python function that is to be injected with a SparkSession.
:param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
:param conn_id: The connection ID to use for the SparkSession.
:param config_kwargs: Additional kwargs to pass to the SparkSession builder. This overrides
the config from the connection.
"""
@overload
def pyspark(
self, python_callable: Callable[FParams, FReturn] | None = None
) -> Task[FParams, FReturn]: ...

task: TaskDecoratorCollection
setup: Callable
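A minimal usage sketch of the two overloads declared above (the DAG name, connection ID, and dataframe contents are illustrative assumptions, not part of this commit):

import pendulum

from airflow.decorators import dag, task


@dag(schedule=None, start_date=pendulum.datetime(2023, 11, 1, tz="UTC"), catchup=False)
def example_pyspark_decorator():
    # Overload with arguments: choose a connection and override Spark config per task.
    # "spark_default" is used here only for illustration.
    @task.pyspark(conn_id="spark_default", config_kwargs={"spark.executor.memory": "2g"})
    def count_rows(spark, sc):
        df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
        return df.count()

    # Bare overload: no arguments, so the task runs against a local[*] master.
    @task.pyspark
    def local_version(spark, sc):
        return spark.version

    count_rows()
    local_version()


example_pyspark_decorator()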
17 changes: 17 additions & 0 deletions airflow/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
119 changes: 119 additions & 0 deletions airflow/providers/apache/spark/decorators/pyspark.py
@@ -0,0 +1,119 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.hooks.base import BaseHook
from airflow.operators.python import PythonOperator

if TYPE_CHECKING:
from airflow.utils.context import Context

SPARK_CONTEXT_KEYS = ["spark", "sc"]


class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
custom_operator_name = "@task.pyspark"

template_fields: Sequence[str] = ("op_args", "op_kwargs")

def __init__(
self,
python_callable: Callable,
op_args: Sequence | None = None,
op_kwargs: dict | None = None,
conn_id: str | None = None,
config_kwargs: dict | None = None,
**kwargs,
):
self.conn_id = conn_id
self.config_kwargs = config_kwargs or {}

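# Make the injected "spark"/"sc" parameters optional in the callable's signature
# so DAG authors do not have to pass them explicitly when calling the task.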
signature = inspect.signature(python_callable)
parameters = [
param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
for param in signature.parameters.values()
]
# mypy does not understand __signature__ attribute
# see https://github.com/python/mypy/issues/12472
python_callable.__signature__ = signature.replace(parameters=parameters) # type: ignore[attr-defined]

kwargs_to_upstream = {
"python_callable": python_callable,
"op_args": op_args,
"op_kwargs": op_kwargs,
}
super().__init__(
kwargs_to_upstream=kwargs_to_upstream,
python_callable=python_callable,
op_args=op_args,
op_kwargs=op_kwargs,
**kwargs,
)

def execute(self, context: Context):
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")

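# Resolve the Spark master: prefer the connection's host (and port) when a conn_id
# is given; otherwise fall back to running Spark locally.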
master = "local[*]"
if self.conn_id:
conn = BaseHook.get_connection(self.conn_id)
if conn.port:
master = f"{conn.host}:{conn.port}"
elif conn.host:
master = conn.host

for key, value in conn.extra_dejson.items():
conf.set(key, value)

conf.setMaster(master)

# task can override connection config
for key, value in self.config_kwargs.items():
conf.set(key, value)

spark = SparkSession.builder.config(conf=conf).getOrCreate()
sc = spark.sparkContext

if not self.op_kwargs:
self.op_kwargs = {}

op_kwargs: dict[str, Any] = dict(self.op_kwargs)
op_kwargs["spark"] = spark
op_kwargs["sc"] = sc

self.op_kwargs = op_kwargs
return super().execute(context)


def pyspark_task(
python_callable: Callable | None = None,
multiple_outputs: bool | None = None,
**kwargs,
) -> TaskDecorator:
return task_decorator_factory(
python_callable=python_callable,
multiple_outputs=multiple_outputs,
decorated_operator_class=_PySparkDecoratedOperator,
**kwargs,
)
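For reference, a hypothetical connection showing how execute() above resolves the Spark master and applies connection extras (the connection ID, host, and extra values are assumptions for illustration only):

from airflow.models import Connection

# Hypothetical connection; not part of this commit.
conn = Connection(
    conn_id="spark_cluster",
    conn_type="spark",
    host="spark://cluster.example.com",
    port=7077,
    extra='{"spark.dynamicAllocation.enabled": "true"}',
)
# Given this connection, execute() would call
#   conf.setMaster("spark://cluster.example.com:7077")
#   conf.set("spark.dynamicAllocation.enabled", "true")
# and any config_kwargs passed to @task.pyspark would then override those values.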
4 changes: 4 additions & 0 deletions airflow/providers/apache/spark/provider.yaml
@@ -83,6 +83,10 @@ connection-types:
- hook-class-name: airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook
connection-type: spark

task-decorators:
- class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
name: pyspark

additional-extras:
- name: cncf.kubernetes
dependencies:
51 changes: 51 additions & 0 deletions docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
@@ -0,0 +1,51 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
.. _howto/decorator:pyspark:

PySpark Decorator
=================

A Python callable wrapped with the ``@task.pyspark`` decorator
is injected with a SparkSession (``spark``) and SparkContext (``sc``) object.

Parameters
----------

The following parameters can be passed to the decorator:

conn_id: str
The connection ID to use for connecting to the Spark cluster. If not
specified, the spark master is set to ``local[*]``.
config_kwargs: dict
The kwargs used for initializing the SparkConf object. This overrides
the spark configuration options set in the connection.


Example
-------

The following example shows how to use the ``@task.pyspark`` decorator. Note
that the ``spark`` and ``sc`` objects are injected into the function.

.. exampleinclude:: /../../tests/system/providers/apache/spark/example_pyspark.py
:language: python
:dedent: 4
:start-after: [START task_pyspark]
:end-before: [END task_pyspark]
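The example file referenced by the include directive is not part of this diff; a minimal sketch of what such a task might look like (the connection ID and dataframe contents are illustrative assumptions):

from airflow.decorators import task

@task.pyspark(conn_id="spark-local")  # "spark-local" is a hypothetical connection ID
def spark_task(spark, sc):
    # "spark" (SparkSession) and "sc" (SparkContext) are injected by the decorator.
    df = spark.createDataFrame(
        [(1, "John Doe", 21), (2, "Jane Doe", 22)],
        ["id", "name", "age"],
    )
    df.show()
    return df.count()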
1 change: 1 addition & 0 deletions docs/apache-airflow-providers-apache-spark/index.rst
@@ -34,6 +34,7 @@
:caption: Guides

Connection types <connections/spark>
Decorators <decorators/pyspark>
Operators <operators>

.. toctree::
16 changes: 16 additions & 0 deletions tests/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
80 changes: 80 additions & 0 deletions tests/providers/apache/spark/decorators/test_pyspark.py
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.decorators import task
from airflow.models import Connection
from airflow.utils import db, timezone

DEFAULT_DATE = timezone.datetime(2021, 9, 1)


class TestPysparkDecorator:
def setup_method(self):
db.merge_conn(
Connection(
conn_id="pyspark_local",
conn_type="spark",
host="spark://none",
extra="",
)
)

@pytest.mark.db_test
@mock.patch("pyspark.SparkConf.setAppName")
@mock.patch("pyspark.sql.SparkSession")
def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_maker):
@task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
def f(spark, sc):
import random

return [random.random() for _ in range(100)]

with dag_maker():
ret = f()

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
ti = dr.get_task_instances()[0]
assert len(ti.xcom_pull()) == 100
conf_mock().set.assert_called_with("spark.executor.memory", "2g")
conf_mock().setMaster.assert_called_once_with("spark://none")
spark_mock.builder.config.assert_called_once_with(conf=conf_mock())

@pytest.mark.db_test
@mock.patch("pyspark.SparkConf.setAppName")
@mock.patch("pyspark.sql.SparkSession")
def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
e = 2

@task.pyspark
def f():
return e

with dag_maker():
ret = f()

dr = dag_maker.create_dagrun()
ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
ti = dr.get_task_instances()[0]
assert ti.xcom_pull() == e
conf_mock().setMaster.assert_called_once_with("local[*]")
spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
