### Testing PySpark Application
This guide is a reference for writing robust tests for PySpark code.

To view the docs for PySpark test utils, see here. To see the code for PySpark built-in test utils, check out the Spark repository here. To see the JIRA board tickets for the PySpark test framework, see here.

#### Build a PySpark Application

Here is an example of how to start a PySpark application. Feel free to skip to the next section, “Testing Time!!,” if you already have an application you’re ready to test.

First, start your Spark Session.

In [1]:
import findspark
findspark.init()
from pyspark.sql import SparkSession

#create spark session
spark = SparkSession.builder.appName("Testing_PySpark")\
.getOrCreate()

25/01/03 05:59:56 WARN Utils: Your hostname, codespaces-e1276e resolves to a loopback address: 127.0.0.1; using 10.0.0.128 instead (on interface eth0)
25/01/03 05:59:56 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/03 05:59:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
sample_data = [{"name": "John    D.", "age": 30},
  {"name": "Alice   G.", "age": 25},
  {"name": "Bob  T.", "age": 35},
  {"name": "Eve   A.", "age": 28}]

df = spark.createDataFrame(sample_data)

In [3]:
from pyspark.sql.functions import col, regexp_replace

# Collapse each run of whitespace in the given column into a single space
def remove_extra_spaces(df, column_name):
    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))

    return df_transformed

transformed_df = remove_extra_spaces(df, "name")

transformed_df.show()

                                                                                

+---+--------+
|age|    name|
+---+--------+
| 30| John D.|
| 25|Alice G.|
| 35|  Bob T.|
| 28|  Eve A.|
+---+--------+
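The pattern itself can be sanity-checked without a Spark session. The sketch below uses Python's built-in re module as a stand-in for regexp_replace. Spark evaluates the pattern with Java's regex engine, so treat this as an approximation that should hold for a simple pattern like \s+ on ASCII input. The helper name collapse_spaces is hypothetical and not part of PySpark.

```python
import re

# Hypothetical helper: a pure-Python stand-in for
# regexp_replace(col(...), "\\s+", " ").
def collapse_spaces(value: str) -> str:
    # Replace every run of whitespace characters with one space.
    return re.sub(r"\s+", " ", value)

names = ["John    D.", "Alice   G.", "Bob  T.", "Eve   A."]
cleaned = [collapse_spaces(n) for n in names]
print(cleaned)

# Leading and trailing whitespace is collapsed to one space, not removed.
padded = collapse_spaces("  Bob  T. ")
print(repr(padded))
```

Note that the pattern collapses whitespace but does not trim it, so a test with padded input would need a separate trimming step to match an unpadded expected value.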



### Testing Time!!

Now, let's test our PySpark transformation function.

One approach is to visually inspect the resulting DataFrame. However, this can be impractical for large DataFrames or input sizes.

A better approach is to write tests. Below are some examples of how we can test our code. They apply to Spark 3.5 and later.

Note that these examples are not exhaustive, as there are many other test frameworks you can use instead of unittest or pytest. The built-in PySpark testing utility functions are standalone, meaning they are compatible with any test framework or CI test pipeline.

#### Option 1: Using Only PySpark Built-in Test Utility Functions

For simple ad-hoc validation cases, you can use PySpark testing utilities like assertDataFrameEqual and assertSchemaEqual in a standalone context. This allows you to easily test PySpark code in a notebook session. For instance, if you want to assert equality between two DataFrames:

In [4]:
import pyspark.testing
from pyspark.testing.utils import assertDataFrameEqual

# Example 1
df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical



In [5]:
# Example 2
df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
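To see why Example 2 passes, it helps to work through the tolerance arithmetic. The sketch below assumes float values are compared with the common rule |actual - expected| <= atol + rtol * |expected| (the same shape as numpy.isclose), with an assumed default atol of 1e-8 and rtol of 1e-5. The helper approx_equal is hypothetical and not part of PySpark; check the assertDataFrameEqual documentation for the exact behavior in your Spark version.

```python
# Hypothetical helper illustrating an isclose-style tolerance check.
def approx_equal(actual: float, expected: float,
                 rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    # The allowed gap grows with the magnitude of the expected value.
    return abs(actual - expected) <= atol + rtol * abs(expected)

# Gap is about 0.009; allowed gap is about 0.0109, so the values match.
loose = approx_equal(0.1, 0.109, rtol=1e-1)

# With the assumed defaults the allowed gap is about 1.1e-6, so they differ.
strict = approx_equal(0.1, 0.109)

print(loose, strict)
```

Under this rule, rtol=1e-1 tolerates roughly a 10% difference relative to the expected value, which is quite loose; a tighter tolerance is usually a better default for numeric tests.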

In [6]:
from pyspark.testing.utils import assertSchemaEqual
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType

s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])

assertSchemaEqual(s1, s2)  # pass, schemas are identical

#### Option 2: Using Unit Test
For more complex testing scenarios, you may want to use a testing framework.

One of the most popular approaches is unit testing. Let’s walk through how you can use the built-in Python unittest library to write PySpark tests. For more information about the unittest library, see here: https://docs.python.org/3/library/unittest.html.

First, you will need a Spark session. You can define the unittest hooks setUpClass and tearDownClass, marked with Python's built-in @classmethod decorator, to create the Spark session once before the tests in a class run and to stop it after they finish.

In [7]:
import unittest

class PySparkTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()


    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

In [8]:
from pyspark.testing.utils import assertDataFrameEqual

class TestTransformation(PySparkTestCase):
    def test_single_space(self):
        sample_data = [{"name": "John    D.", "age": 30},
                       {"name": "Alice   G.", "age": 25},
                       {"name": "Bob  T.", "age": 35},
                       {"name": "Eve   A.", "age": 28}]

        # Create a Spark DataFrame using the session from setUpClass
        original_df = self.spark.createDataFrame(sample_data)

        # Apply the transformation function under test
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [{"name": "John D.", "age": 30},
                         {"name": "Alice G.", "age": 25},
                         {"name": "Bob T.", "age": 35},
                         {"name": "Eve A.", "age": 28}]

        expected_df = self.spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)


When you run this test file with unittest or with the pytest command, the runner picks up every test method whose name begins with “test”.
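The name-prefix rule can be observed directly with the standard library's unittest loader, which pytest also respects when it collects unittest.TestCase subclasses. In this minimal sketch, the class and method names are hypothetical and chosen only for illustration.

```python
import unittest

class ExampleTests(unittest.TestCase):
    def test_collected(self):
        # The "test" prefix marks this method as a test.
        self.assertTrue(True)

    def helper_not_collected(self):
        # No "test" prefix, so the loader ignores this method.
        return "helper"

# Ask the loader which method names it treats as tests.
discovered = list(unittest.TestLoader().getTestCaseNames(ExampleTests))
print(discovered)
```

A test method with a misspelled prefix is skipped silently, so it is worth confirming that the reported test count matches the number of tests you wrote.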

In [9]:
# pkg/etl.py
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, regexp_replace

# Create a SparkSession
spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()

sample_data = [{"name": "John    D.", "age": 30},
  {"name": "Alice   G.", "age": 25},
  {"name": "Bob  T.", "age": 35},
  {"name": "Eve   A.", "age": 28}]

df = spark.createDataFrame(sample_data)

# Define DataFrame transformation function
def remove_extra_spaces(df, column_name):
    # Remove extra spaces from the specified column using regexp_replace
    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))

    return df_transformed

25/01/03 06:00:04 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [10]:
# pkg/test_etl.py
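import unittest

from pyspark.sql import SparkSession
from pyspark.testing.utils import assertDataFrameEqual

# When this runs as a standalone file rather than in the notebook,
# remove_extra_spaces must also be imported from the module that defines it.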

# Define unit test base class
class PySparkTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

# Define unit test
class TestTransformation(PySparkTestCase):
    def test_single_space(self):
        sample_data = [{"name": "John    D.", "age": 30},
                       {"name": "Alice   G.", "age": 25},
                       {"name": "Bob  T.", "age": 35},
                       {"name": "Eve   A.", "age": 28}]

        # Create a Spark DataFrame using the session from setUpClass
        original_df = self.spark.createDataFrame(sample_data)

        # Apply the transformation function under test
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [{"name": "John D.", "age": 30},
                         {"name": "Alice G.", "age": 25},
                         {"name": "Bob T.", "age": 35},
                         {"name": "Eve A.", "age": 28}]

        expected_df = self.spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)

In [11]:
unittest.main(argv=[''], verbosity=0, exit=False)

----------------------------------------------------------------------
Ran 1 test in 1.551s

OK


<unittest.main.TestProgram at 0x704b0dae3380>