# Testing PySpark

This guide is a reference for writing robust tests for PySpark code.

To view the docs for PySpark test utils, see here. To see the code for PySpark built-in test utils, check out the Spark repository here. To see the JIRA board tickets for the PySpark test framework, see here.

## Build a PySpark Application
Here is an example of how to start a PySpark application. Feel free to skip to the next section, “Testing your PySpark Application,” if you already have an application you’re ready to test.

First, start your Spark Session.

In [1]:
from pyspark.sql import SparkSession 
from pyspark.sql.functions import col 

# Create a SparkSession 
spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate() 



Next, create a DataFrame.

In [2]:
sample_data = [{"name": "John    D.", "age": 30}, 
  {"name": "Alice   G.", "age": 25}, 
  {"name": "Bob  T.", "age": 35}, 
  {"name": "Eve   A.", "age": 28}] 

df = spark.createDataFrame(sample_data)

Now, let’s define and apply a transformation function to our DataFrame.

In [3]:
from pyspark.sql.functions import col, regexp_replace

# Remove additional spaces in name
def remove_extra_spaces(df, column_name):
    # Remove extra spaces from the specified column
    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))
    
    return df_transformed

transformed_df = remove_extra_spaces(df, "name")

transformed_df.show()

+---+--------+
|age|    name|
+---+--------+
| 30| John D.|
| 25|Alice G.|
| 35|  Bob T.|
| 28|  Eve A.|
+---+--------+
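To sanity-check the expected values without starting Spark, the same pattern can be tried in plain Python. This is only a rough sketch: `collapse_whitespace` is a hypothetical helper, and it uses Python's `re` module, whereas Spark evaluates `regexp_replace` with Java's regex engine, so the two may differ on unusual whitespace characters. For the ordinary spaces in the sample data the result should match the table above.

```python
import re


def collapse_whitespace(text: str) -> str:
    """Replace every run of whitespace with a single space (hypothetical helper)."""
    return re.sub(r"\s+", " ", text)


samples = ["John    D.", "Alice   G.", "Bob  T.", "Eve   A."]
for raw in samples:
    print(repr(raw), "->", repr(collapse_whitespace(raw)))

# Leading and trailing runs are collapsed to one space, not stripped.
print(repr(collapse_whitespace("  Bob  T. ")))
```

Note that the pattern collapses whitespace but does not trim it, so a value with leading or trailing spaces keeps one space on each side.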



                                                                                

You can also do this using Spark Connect. The only difference is using remote() when you create your SparkSession, for example:

spark = SparkSession.builder.remote("sc://localhost").appName("Sample PySpark ETL").getOrCreate()

For more information on how to use Spark Connect and its benefits, see: https://spark.apache.org/docs/latest/spark-connect-overview.html

## Testing your PySpark Application
Now let’s test our PySpark transformation function. 

One option is to simply eyeball the resulting DataFrame. However, this quickly becomes impractical for large DataFrames or inputs.

A better way is to write tests. Here are some examples of how we can test our code.


### Option 1: No test framework

For simple ad-hoc validation, PySpark testing utils such as assertDataFrameEqual and assertSchemaEqual (available in PySpark 3.5 and later) can be used on their own, without a test framework.
For example, in a notebook session you can assert equality between two DataFrames:



In [8]:
from pyspark.testing.utils import assertDataFrameEqual

# Example 1
df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical

# Example 2
df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
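To see why Example 2 passes with `rtol=1e-1`, here is a minimal pure-Python sketch of a tolerance check. It assumes a `numpy.isclose`-style rule, `|actual - expected| <= atol + rtol * |expected|`, which is my understanding of how assertDataFrameEqual compares floats; `approx_equal` is a hypothetical helper, and its default tolerances are assumptions chosen to mirror the documented defaults.

```python
def approx_equal(actual: float, expected: float,
                 rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Tolerance check in the style of numpy.isclose (assumed rule, see lead-in)."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# The float values from Example 2: 0.1 vs 0.109 differ by about 0.009.
print(approx_equal(0.1, 0.109, rtol=1e-1))  # allowed difference is about 0.0109
print(approx_equal(0.1, 0.109))             # allowed difference is about 0.000001
print(approx_equal(3.23, 3.23))             # identical values always pass
```

Under this rule the first and third comparisons pass and the second fails, which is why the DataFrames in Example 2 only compare equal once `rtol` is loosened.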

You can also simply compare two DataFrame schemas:

In [10]:
from pyspark.testing.utils import assertSchemaEqual
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType

s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])

assertSchemaEqual(s1, s2)  # pass, schemas are identical


### Option 2: Unit Test
For more complex testing scenarios, you may want to use a testing framework.

One of the most common approaches is unit testing. Let’s walk through how you can use the built-in Python unittest library to write PySpark tests.
First, you will need a Spark session. You can define the setUpClass and tearDownClass hooks of unittest.TestCase, each marked with Python’s built-in @classmethod decorator, to set up and tear down a Spark session once per test class.

In [9]:
import unittest

class PySparkTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate() 

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

Now let’s write a unittest class.

In [15]:
from pyspark.testing.utils import assertDataFrameEqual

class TestTransformation(PySparkTestCase):
    def test_single_space(self):
        sample_data = [{"name": "John    D.", "age": 30},
                       {"name": "Alice   G.", "age": 25},
                       {"name": "Bob  T.", "age": 35},
                       {"name": "Eve   A.", "age": 28}]

        # create a Spark DataFrame using the session from setUpClass
        original_df = self.spark.createDataFrame(sample_data)

        # apply the transformation function from before
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [{"name": "John D.", "age": 30},
                         {"name": "Alice G.", "age": 25},
                         {"name": "Bob T.", "age": 35},
                         {"name": "Eve A.", "age": 28}]

        expected_df = self.spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)

When run, unittest picks up every method of a TestCase subclass whose name begins with “test”.
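The naming rule can be checked directly with the standard library's `unittest.TestLoader`. In this small sketch the class and method names are made up for illustration; only methods whose names start with the default `test` prefix are reported.

```python
import unittest


class TestNaming(unittest.TestCase):
    def test_collected(self):
        pass

    def testAlsoCollected(self):
        pass

    def helper_not_collected(self):
        pass

    def check_test_suffix(self):
        # "test" appears in the name, but not as a prefix
        pass


# getTestCaseNames returns the method names the loader would run
names = unittest.TestLoader().getTestCaseNames(TestNaming)
print(sorted(names))
```

Helper methods can therefore live alongside tests in the same class, as long as their names do not start with `test`.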

### Option 3: pytest

We can also write our tests with pytest, which is one of the most popular Python testing frameworks. 

A pytest fixture lets us share a Spark session across tests and tear it down when the tests are complete.

In [17]:
import pytest

@pytest.fixture(scope="session")
def spark():
    spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()
    yield spark
    spark.stop()

We can then define our tests like this:

In [19]:
from pyspark.testing.utils import assertDataFrameEqual

def test_single_space(spark):
    # "spark" is injected by the pytest fixture defined above
    sample_data = [{"name": "John    D.", "age": 30},
                   {"name": "Alice   G.", "age": 25},
                   {"name": "Bob  T.", "age": 35},
                   {"name": "Eve   A.", "age": 28}]

    # create a Spark DataFrame
    original_df = spark.createDataFrame(sample_data)

    # apply the transformation function from before
    transformed_df = remove_extra_spaces(original_df, "name")

    expected_data = [{"name": "John D.", "age": 30},
                     {"name": "Alice G.", "age": 25},
                     {"name": "Bob T.", "age": 35},
                     {"name": "Eve A.", "age": 28}]

    expected_df = spark.createDataFrame(expected_data)

    assertDataFrameEqual(transformed_df, expected_df)

When you run your test file with the pytest command, it picks up all functions whose names begin with “test”.

pytest /path/to/file

### Option 4: Other test framework(s)
There are many other test frameworks you can use instead of unittest or pytest. The PySpark testing util functions are standalone, so they are compatible with any test framework or CI test pipeline.

## Putting it all together!

Let’s see all the steps together in a unit test example.

In [21]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, regexp_replace
from pyspark.testing.utils import assertDataFrameEqual

import unittest

# Create a SparkSession
spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()

sample_data = [{"name": "John    D.", "age": 30},
               {"name": "Alice   G.", "age": 25},
               {"name": "Bob  T.", "age": 35},
               {"name": "Eve   A.", "age": 28}]

df = spark.createDataFrame(sample_data)

def remove_extra_spaces(df, column_name):
    # Remove extra spaces from the specified column using regexp_replace
    df_transformed = df.withColumn(column_name,
                                   regexp_replace(col(column_name), "\\s+", " "))

    return df_transformed

# Define unit test base class
class PySparkTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()


# Define unit test
class TestTransformation(PySparkTestCase):
    def test_single_space(self):
        sample_data = [{"name": "John    D.", "age": 30},
                       {"name": "Alice   G.", "age": 25},
                       {"name": "Bob  T.", "age": 35},
                       {"name": "Eve   A.", "age": 28}]

        # create a Spark DataFrame using the session from setUpClass
        original_df = self.spark.createDataFrame(sample_data)

        # apply the transformation function from before
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [{"name": "John D.", "age": 30},
                         {"name": "Alice G.", "age": 25},
                         {"name": "Bob T.", "age": 35},
                         {"name": "Eve A.", "age": 28}]

        expected_df = self.spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)