From 0a4ed7d557809ad81ecc50d197c33c8d178c42ce Mon Sep 17 00:00:00 2001
From: bolkedebruin
Date: Wed, 1 Nov 2023 11:42:58 +0100
Subject: [PATCH] Add pyspark decorator (#35247)

This adds the pyspark decorator so that Spark can be run inline and
results, such as dataframes, can be shared.
---
 .pre-commit-config.yaml                       |   2 +
 airflow/decorators/__init__.pyi               |  22 ++++
 .../apache/spark/decorators/__init__.py       |  17 +++
 .../apache/spark/decorators/pyspark.py        | 119 ++++++++++++++++++
 airflow/providers/apache/spark/provider.yaml  |   4 +
 .../decorators/pyspark.rst                    |  51 ++++++++
 .../index.rst                                 |   1 +
 .../apache/spark/decorators/__init__.py       |  16 +++
 .../apache/spark/decorators/test_pyspark.py   |  80 ++++++++++++
 .../providers/apache/spark/example_pyspark.py |  75 +++++++++++
 10 files changed, 387 insertions(+)
 create mode 100644 airflow/providers/apache/spark/decorators/__init__.py
 create mode 100644 airflow/providers/apache/spark/decorators/pyspark.py
 create mode 100644 docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
 create mode 100644 tests/providers/apache/spark/decorators/__init__.py
 create mode 100644 tests/providers/apache/spark/decorators/test_pyspark.py
 create mode 100644 tests/system/providers/apache/spark/example_pyspark.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b7c80a810678a..e6b59100bf298 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -514,6 +514,7 @@ repos:
           ^airflow/providers/apache/cassandra/hooks/cassandra.py$|
           ^airflow/providers/apache/hive/operators/hive_stats.py$|
           ^airflow/providers/apache/hive/transfers/vertica_to_hive.py$|
+          ^airflow/providers/apache/spark/decorators/|
           ^airflow/providers/apache/spark/hooks/|
           ^airflow/providers/apache/spark/operators/|
           ^airflow/providers/exasol/hooks/exasol.py$|
@@ -542,6 +543,7 @@ repos:
           ^docs/apache-airflow-providers-amazon/secrets-backends/aws-ssm-parameter-store.rst$|
           ^docs/apache-airflow-providers-apache-hdfs/connections.rst$|
           ^docs/apache-airflow-providers-apache-kafka/connections/kafka.rst$|
+          ^docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst$|
           ^docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst$|
           ^docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst$|
           ^docs/conf.py$|
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 0c3e94bf5c38c..f718e35777d9b 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -566,6 +566,28 @@ class TaskDecoratorCollection:
         """
     @overload
     def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...
+    @overload
+    def pyspark(
+        self,
+        *,
+        multiple_outputs: bool | None = None,
+        conn_id: str | None = None,
+        config_kwargs: dict[str, str] | None = None,
+        **kwargs,
+    ) -> TaskDecorator:
+        """
+        Wraps a Python function that is to be injected with a SparkSession.
+
+        :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
+            Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
+        :param conn_id: The connection ID to use for the SparkSession.
+        :param config_kwargs: Additional kwargs to pass to the SparkSession builder. This overrides
+            the config from the connection.
+        """
+    @overload
+    def pyspark(
+        self, python_callable: Callable[FParams, FReturn] | None = None
+    ) -> Task[FParams, FReturn]: ...
 
 task: TaskDecoratorCollection
 setup: Callable
diff --git a/airflow/providers/apache/spark/decorators/__init__.py b/airflow/providers/apache/spark/decorators/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/airflow/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/spark/decorators/pyspark.py b/airflow/providers/apache/spark/decorators/pyspark.py
new file mode 100644
index 0000000000000..6f576b03a2009
--- /dev/null
+++ b/airflow/providers/apache/spark/decorators/pyspark.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
+from airflow.hooks.base import BaseHook
+from airflow.operators.python import PythonOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+SPARK_CONTEXT_KEYS = ["spark", "sc"]
+
+
+class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
+    custom_operator_name = "@task.pyspark"
+
+    template_fields: Sequence[str] = ("op_args", "op_kwargs")
+
+    def __init__(
+        self,
+        python_callable: Callable,
+        op_args: Sequence | None = None,
+        op_kwargs: dict | None = None,
+        conn_id: str | None = None,
+        config_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        self.conn_id = conn_id
+        self.config_kwargs = config_kwargs or {}
+
+        signature = inspect.signature(python_callable)
+        parameters = [
+            param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
+            for param in signature.parameters.values()
+        ]
+        # mypy does not understand __signature__ attribute
+        # see https://github.com/python/mypy/issues/12472
+        python_callable.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]
+
+        kwargs_to_upstream = {
+            "python_callable": python_callable,
+            "op_args": op_args,
+            "op_kwargs": op_kwargs,
+        }
+        super().__init__(
+            kwargs_to_upstream=kwargs_to_upstream,
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            **kwargs,
+        )
+
+    def execute(self, context: Context):
+        from pyspark import SparkConf
+        from pyspark.sql import SparkSession
+
+        conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")
+
+        master = "local[*]"
+        if self.conn_id:
+            conn = BaseHook.get_connection(self.conn_id)
+            if conn.port:
+                master = f"{conn.host}:{conn.port}"
+            elif conn.host:
+                master = conn.host
+
+            for key, value in conn.extra_dejson.items():
+                conf.set(key, value)
+
+        conf.setMaster(master)
+
+        # task can override connection config
+        for key, value in self.config_kwargs.items():
+            conf.set(key, value)
+
+        spark = SparkSession.builder.config(conf=conf).getOrCreate()
+        sc = spark.sparkContext
+
+        if not self.op_kwargs:
+            self.op_kwargs = {}
+
+        op_kwargs: dict[str, Any] = dict(self.op_kwargs)
+        op_kwargs["spark"] = spark
+        op_kwargs["sc"] = sc
+
+        self.op_kwargs = op_kwargs
+        return super().execute(context)
+
+
+def pyspark_task(
+    python_callable: Callable | None = None,
+    multiple_outputs: bool | None = None,
+    **kwargs,
+) -> TaskDecorator:
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=multiple_outputs,
+        decorated_operator_class=_PySparkDecoratedOperator,
+        **kwargs,
+    )
diff --git a/airflow/providers/apache/spark/provider.yaml b/airflow/providers/apache/spark/provider.yaml
index be3e791a9259c..9316f80fa051d 100644
--- a/airflow/providers/apache/spark/provider.yaml
+++ b/airflow/providers/apache/spark/provider.yaml
@@ -83,6 +83,10 @@ connection-types:
   - hook-class-name: airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook
     connection-type: spark
 
+task-decorators:
+  - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
+    name: pyspark
+
 additional-extras:
   - name: cncf.kubernetes
     dependencies:
diff --git a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
new file mode 100644
index 0000000000000..28b51ec848a90
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
@@ -0,0 +1,51 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements. See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership. The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+
+
+.. _howto/decorator:pyspark:
+
+PySpark Decorator
+=================
+
+A Python callable wrapped with the ``@task.pyspark`` decorator is injected
+with a SparkSession (``spark``) and a SparkContext (``sc``) object.
+
+Parameters
+----------
+
+The following parameters can be passed to the decorator:
+
+conn_id: str
+    The connection ID to use for connecting to the Spark cluster. If not
+    specified, the Spark master is set to ``local[*]``.
+config_kwargs: dict
+    The kwargs used for initializing the SparkConf object. This overrides
+    the Spark configuration options set in the connection.
+
+
+Example
+-------
+
+The following example shows how to use the ``@task.pyspark`` decorator. Note
+that the ``spark`` and ``sc`` objects are injected into the function.
+
+.. exampleinclude:: /../../tests/system/providers/apache/spark/example_pyspark.py
+    :language: python
+    :dedent: 4
+    :start-after: [START task_pyspark]
+    :end-before: [END task_pyspark]
diff --git a/docs/apache-airflow-providers-apache-spark/index.rst b/docs/apache-airflow-providers-apache-spark/index.rst
index fc3438f0d9b4b..af2d24ba48a26 100644
--- a/docs/apache-airflow-providers-apache-spark/index.rst
+++ b/docs/apache-airflow-providers-apache-spark/index.rst
@@ -34,6 +34,7 @@
     :caption: Guides
 
     Connection types
+    Decorators
     Operators
 
 .. toctree::
diff --git a/tests/providers/apache/spark/decorators/__init__.py b/tests/providers/apache/spark/decorators/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/spark/decorators/test_pyspark.py b/tests/providers/apache/spark/decorators/test_pyspark.py
new file mode 100644
index 0000000000000..ea307c9c3c2e7
--- /dev/null
+++ b/tests/providers/apache/spark/decorators/test_pyspark.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.decorators import task
+from airflow.models import Connection
+from airflow.utils import db, timezone
+
+DEFAULT_DATE = timezone.datetime(2021, 9, 1)
+
+
+class TestPysparkDecorator:
+    def setup_method(self):
+        db.merge_conn(
+            Connection(
+                conn_id="pyspark_local",
+                conn_type="spark",
+                host="spark://none",
+                extra="",
+            )
+        )
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_maker):
+        @task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
+        def f(spark, sc):
+            import random
+
+            return [random.random() for _ in range(100)]
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert len(ti.xcom_pull()) == 100
+        conf_mock().set.assert_called_with("spark.executor.memory", "2g")
+        conf_mock().setMaster.assert_called_once_with("spark://none")
+        spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
+        e = 2
+
+        @task.pyspark
+        def f():
+            return e
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert ti.xcom_pull() == e
+        conf_mock().setMaster.assert_called_once_with("local[*]")
+        spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
diff --git a/tests/system/providers/apache/spark/example_pyspark.py b/tests/system/providers/apache/spark/example_pyspark.py
new file mode 100644
index 0000000000000..c671cb40e3a63
--- /dev/null
+++ b/tests/system/providers/apache/spark/example_pyspark.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import typing
+
+import pendulum
+
+if typing.TYPE_CHECKING:
+    import pandas as pd
+    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
+
+from airflow.decorators import dag, task
+
+
+@dag(
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=["example"],
+)
+def example_pyspark():
+    """
+    ### Example Pyspark DAG
+    This is an example DAG which uses pyspark
+    """
+
+    # [START task_pyspark]
+    @task.pyspark(conn_id="spark-local")
+    def spark_task(spark: SparkSession, sc: SparkContext) -> pd.DataFrame:
+        df = spark.createDataFrame(
+            [
+                (1, "John Doe", 21),
+                (2, "Jane Doe", 22),
+                (3, "Joe Bloggs", 23),
+            ],
+            ["id", "name", "age"],
+        )
+        df.show()
+
+        return df.toPandas()
+
+    # [END task_pyspark]
+
+    @task
+    def print_df(df: pd.DataFrame):
+        print(df)
+
+    df = spark_task()
+    print_df(df)
+
+
+# work around pre-commit
+dag = example_pyspark()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)