# Unit tests in PySpark 

---
<img align="right" src="https://media3.giphy.com/media/fRYeEj3DtgrW9VpVWs/giphy.gif?cid=ecf05e47cdpuxkms9stljb9g0axvahcqt1rr5ewhzfrpxvsm&ep=v1_gifs_search&rid=giphy.gif&ct=g" width="250" height="300"> 


Testing is a crucial part of software development that helps in identifying issues with your code early and often, thereby ensuring the reliability and robustness of the application. Unit tests are designed to test the independent pieces of logic that comprise your application. In general, tests look to validate that your logic is functioning as intended. By asserting that the actual output of our logic is identical to the expected output, we can determine if the logic has been implemented correctly. Ideally each test will cover exactly one piece of functionality, e.g., a specific data transformation or helper function.

___

 
In the context of PySpark, tests are usually centered around **comparing DataFrames** for expected output. There are several dimensions by which DataFrames can be compared:
 <br><br>
 
 * Schemas
 * Columns
 * Rows 
 * Entire DataFrames

In this notebook, we are going to learn about unit testing PySpark applications using the Python `pytest` framework. Tests will be written using the popular [`chispa`](https://github.com/MrPowers/chispa) library as well as the new [PySpark native testing functions available in Spark 3.5](https://issues.apache.org/jira/browse/SPARK-44042). This tutorial is designed to keep it simple, so if you want to learn more about how `pytest` works we recommend taking a closer look at [the official documentation](https://docs.pytest.org/en/7.4.x/).

Let's get started!


# Working with `pytest`

## How to structure your project with pytest
Consider the following hypothetical project structure for a simple reporting use case:

```
├── notebooks/
│   ├── analysis.ipynb   # report with visualizations
├── src/
│   ├── load_config.py   # helper functions
│   └── cleaning.py
├── tests/
│   ├── main_test.py     # unit tests
├── requirements.txt     # dependencies
└── test-requirements.txt
```

The `notebooks/` folder contains `analysis.ipynb`, which reports on and visualizes some business data.  This notebook imports custom Python functions from both files in `src/` to help load and clean data like so:

```
from src.load_config import db_loader, config_handler
from src.cleaning import *
```

Since our report depends on these functions to get and prepare the data correctly, we want to write tests that validate our functions are behaving as intended.  To do so, we create a `tests/` folder and include our tests there.  **By default, `pytest` collects test files whose names start with `test_` or end with `_test.py`.**  In this example our test script is called `main_test.py`, but later on you will see examples like `test_pyspark_column_equality.py`.  Both are supported and will be picked up by `pytest`.
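
As a rough illustration of the naming rule above, the sketch below mimics pytest's default file patterns (`test_*.py` and `*_test.py`) with the standard-library `fnmatch` module. The helper `looks_like_test_file` and the list of candidate filenames are hypothetical. Real collection also depends on project configuration such as `python_files`, so treat this as an approximation rather than pytest's actual logic.

```python
from fnmatch import fnmatch

# pytest's default file patterns; a project can override them via `python_files`.
DEFAULT_PATTERNS = ("test_*.py", "*_test.py")


def looks_like_test_file(filename: str) -> bool:
    """Return True if the filename matches one of the default patterns."""
    return any(fnmatch(filename, pattern) for pattern in DEFAULT_PATTERNS)


# Hypothetical filenames taken from the project layout described above.
candidates = [
    "main_test.py",
    "test_pyspark_column_equality.py",
    "cleaning.py",
    "load_config.py",
]
collected = [name for name in candidates if looks_like_test_file(name)]
print(collected)
```

Only the two test scripts match; the helper modules in `src/` would be ignored under these patterns.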

The last two files in our project specify the dependencies for our code.  It is a best practice to keep testing dependencies separate, since `pytest` is only needed to run the testing scripts.

#### Installing dependencies
To use `pytest` we will need to install it alongside any other dependencies, then restart the Python interpreter to make them available in our environment.

In [0]:
!cp ../requirements.txt ~/.
%pip install -r ~/requirements.txt
dbutils.library.restartPython()

## Invoking `pytest`

`pytest` is usually run from the system command line, but it can also be executed from the context of a notebook or Python REPL.  We'll be using the latter method to invoke our tests from the Databricks editor.  This lets us make use of Spark and other configuration variables in Databricks Runtime. 

One limitation of this approach is that test modules are cached by Python's import mechanism once they have been imported, so later edits to a test file are not picked up automatically.  If we wanted to iterate on tests during a development scenario, we would need to use `dbutils.library.restartPython()` to clear the cache and pick up changes to our tests.  This tutorial has been structured to render this unnecessary, but it is important to note!
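
The caching behavior can be made visible with a small standard-library sketch. It writes a throwaway module (the name `cached_demo_module` is hypothetical), imports it, edits the file on disk, and imports it again. `importlib.reload` is used here only to show that the source really did change; this notebook relies on `dbutils.library.restartPython()` instead.

```python
import importlib
import sys
import tempfile
from pathlib import Path

MODULE_NAME = "cached_demo_module"  # hypothetical module used only for this demo
sys.dont_write_bytecode = True      # keep stale .pyc files out of the picture

with tempfile.TemporaryDirectory() as tmp_dir:
    source = Path(tmp_dir) / f"{MODULE_NAME}.py"
    source.write_text("VALUE = 1\n")
    sys.path.insert(0, tmp_dir)
    importlib.invalidate_caches()
    try:
        module = importlib.import_module(MODULE_NAME)
        first = module.VALUE

        # Edit the file on disk, then import again: the cached module is returned.
        source.write_text("VALUE = 2\n")
        second = importlib.import_module(MODULE_NAME).VALUE

        # An explicit reload re-executes the source and picks up the edit.
        reloaded = importlib.reload(module).VALUE
    finally:
        # Clean up so the demo leaves no trace in the interpreter.
        sys.path.remove(tmp_dir)
        sys.modules.pop(MODULE_NAME, None)

print(first, second, reloaded)
```

The second import still sees the old value because the module object comes from `sys.modules`, not from disk.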

In the following cell, we first make sure that all tests will run relative to our repository root directory.  Then we define `run_pytest`, a helper function to invoke a specific test file in our project.  Importantly, **this function also fails the Databricks notebook cell execution if tests fail.**  This ensures we surface errors whether we run these unit tests in an interactive session or as part of a Databricks Workflow.

In [0]:
import pytest
import os
import sys

# Run all tests in the repository root.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
repo_root = os.path.dirname(os.path.dirname(notebook_path))
os.chdir(f'/Workspace/{repo_root}')
%pwd

def run_pytest(pytest_path):
  # Skip writing pyc files on a readonly filesystem.
  sys.dont_write_bytecode = True

  retcode = pytest.main([pytest_path, "-p", "no:cacheprovider"])

  # Fail the cell execution if we have any test failures.
  assert retcode == 0, 'The pytest invocation failed. See the log above for details.'

# Test scenarios
Before we use `run_pytest()`, let's take a closer look at how `chispa` works.

## Using `chispa`

`chispa` is a Python library for testing PySpark code.  It is authored by Matthew Powers, who is currently (September 2023) at Databricks.  There are four main functions to compare DataFrames in `chispa`:

* `assert_df_equality()`
* `assert_column_equality()`
* `assert_approx_df_equality()`
* `assert_approx_column_equality()`

**In this section we will focus on various ways to compare DataFrames using `assert_df_equality()`**, and we've included some additional tests using the other functions at the end of this section.  If you want to learn more about the library, be sure to check out [the official documentation](https://mrpowers.github.io/chispa/).

### DataFrame equality

A common data transformation task is to add or remove columns from DataFrames.  To validate that our transformations are working as intended, we can assert that the actual output DataFrame of a function is equivalent to an expected output DataFrame.  We will use this approach to test DataFrame equality for schemas, columns, as well as row or column orders.


#### Schema equality

Check out the example below where we compare two different DataFrames, one with all numeric columns and one with a mix of numeric and string columns:

In [0]:
from chispa import assert_df_equality
# DF with numeric and string columns
data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df1 = spark.createDataFrame(data1, ["num", "letter"])

# DF with only numeric columns
data2 = [
        (1, 6),
        (2, 7),
        (3, 8),
        (None, None)
    ]
df2 = spark.createDataFrame(data2, ["num", "num2"])

# Compare them
assert_df_equality(df1, df2)

As we expected, this fails.  `chispa` makes it really clear exactly why it failed, too, specifying the precise column differences and highlighting the discrepancies in red. 

**Let's run another comparison, but make sure that the schemas match.**

In [0]:
from chispa import assert_df_equality
# DF with numeric and string columns
data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df1 = spark.createDataFrame(data1, ["num", "letter"])

# DF with the same schema but different letters
data2 = [
        (1, "d"),
        (2, "e"),
        (3, "f"),
        (None, None)
    ]
df2 = spark.createDataFrame(data2, ["num", "letter"])

# Compare them
assert_df_equality(df1, df2)

Aha! The schemas match but the DataFrames are not 100% equal.  Let's try this one more time and ensure that the DataFrames are identical.

In [0]:
from chispa import assert_df_equality
# DF with numeric and string columns
data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df1 = spark.createDataFrame(data1, ["num", "letter"])

# DF with the same schema and identical contents
data2 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df2 = spark.createDataFrame(data2, ["num", "letter"])

# Compare them
assert_df_equality(df1, df2)

`chispa` does not print anything to the console when an assertion passes, as in the cell above.  However, when we use `chispa` in combination with `pytest`, we will always get a complete report on how our tests fared.  To do so, we have added `tests/chispa/test_chispa_101.py` to our project.  This file contains the same test that we ran in the cell above; the only difference is that it is wrapped in a function:

```{python}
## contents of tests/chispa/test_chispa_101.py

def test_schema_mismatch_message():
    # DF with numeric and string columns
    data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
    df1 = spark.createDataFrame(data1, ["num", "letter"])
    
    # DF with the same schema and identical contents
    data2 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
    df2 = spark.createDataFrame(data2, ["num", "letter"])
    
    # Compare them
    assert_df_equality(df1, df2)
```

Let's invoke this test now using `pytest`.

In [0]:
run_pytest("tests/chispa/test_chispa_101.py")

Now that we understand how to combine `chispa` with `pytest`, let's examine some other test scenarios.

#### Column equality

In addition to adding or removing columns, a common data transformation is to update the contents of a column.  Let's take a slightly more realistic example for this case.  

In this project under `src/transforms_spark.py` we have a helper function called `remove_non_word_characters()` which ... you guessed it, strips all non-word characters from the strings in each row:

```
import pyspark.sql.functions as F

# Removing all non-word characters in PySpark
def remove_non_word_characters(col):
    return F.regexp_replace(col, "[^\\w\\s]+", "")
```

To test this unit of logic, we will take three simple steps.

1. Create a DataFrame with rows that have non-word characters
2. Apply `remove_non_word_characters()` to the DataFrame and generate a new column with clean rows
3. Compare this DataFrame to a 2nd DataFrame with the expected output and assert that the two are equal

The following cell does just that by importing `remove_non_word_characters()` from the `src` folder in our project and applying it to some sample data.

In [0]:
import pyspark.sql.functions as F
from src.transforms_spark import remove_non_word_characters
from chispa import assert_df_equality

# Dirty rows
dirty_rows = [
      ("jo&&se",),
      ("**li**",),
      ("#::luisa",),
      (None,)
  ]
source_df = spark.createDataFrame(dirty_rows, ["name"])

# Cleaned rows using function
clean_df = source_df.withColumn(
    "clean_name",
    remove_non_word_characters(F.col("name"))
)

# Expected output, should be identical to clean_df
expected_data = [
      ("jo&&se", "jose"),
      ("**li**", "li"),
      ("#::luisa", "luisa"),
      (None, None)
  ]
expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"])

assert_df_equality(clean_df, expected_df, underline_cells=True)

Looks good.  Before we invoke this test, let's consider another case:

```
data = [
        ("matt7", "matt"),
        ("bill&", "bill"),
        ("isabela*", "isabela"),
        (None, None)
      ]
```

We have included this test case in `tests/chispa/test_chispa_column_equality.py`. In that file we import `remove_non_word_characters()` just like we did in the previous cell. 

Using the definition of `remove_non_word_characters()` from earlier, do you think this test will pass?  Why or why not?  Think about this before invoking `pytest` below.

In [0]:
run_pytest("tests/chispa/test_chispa_column_equality.py")

Our test results show that our logic was not working exactly as intended!  The `\w` class in the regular expression matches digits as well as letters, so the `7` in `matt7` is kept.  To get clean data, we would need to modify `remove_non_word_characters()` so that it also strips digits.

This example illustrates the importance of unit testing.  The author of this function used an incomplete regular expression to clean their text data, which could have been tricky to track down after the fact.  By including more test cases, such as a digit inside a string, we ensure the quality and robustness of our code!
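
The failure can be reproduced without Spark using Python's `re` module. This is only a sketch: Spark's `regexp_replace` uses Java regular expressions, which agree with Python's on these ASCII samples but are not identical in general. The `clean` helper and `LETTERS_ONLY_PATTERN` are hypothetical; the latter is one possible stricter pattern, not necessarily the fix the project would choose.

```python
import re

# Same pattern as in remove_non_word_characters(); \w matches letters, digits, and underscore.
ORIGINAL_PATTERN = r"[^\w\s]+"
# Hypothetical stricter pattern that keeps only ASCII letters and whitespace.
LETTERS_ONLY_PATTERN = r"[^a-zA-Z\s]+"


def clean(value, pattern):
    """Strip everything the pattern matches; None stays None, like a null in Spark."""
    return None if value is None else re.sub(pattern, "", value)


names = ["matt7", "bill&", "isabela*", None]
original = [clean(name, ORIGINAL_PATTERN) for name in names]
stricter = [clean(name, LETTERS_ONLY_PATTERN) for name in names]

print(original)  # "matt7" keeps its digit
print(stricter)  # the digit is removed as well
```

Note that the stricter pattern would also strip underscores and non-ASCII letters, which may or may not be desirable for real name data.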

#### Row and column order
Sometimes pipelines or analytics depend on row or column order.  `assert_df_equality()` assumes that equality extends to row and column order.  This can be overridden using the `ignore_row_order` or `ignore_column_order` parameters.  

In [0]:
from chispa import assert_df_equality

# DF with a single numeric column in ascending order
row_order1 = [
        (1,),
        (2,),
        (3,),
        (None,)
    ]
df1 = spark.createDataFrame(row_order1, ["num"])

# DF with the same values in a different row order
row_order2 = [
        (3,),
        (2,),
        (1,),
        (None,)
    ]
df2 = spark.createDataFrame(row_order2, ["num"])

# Compare them
assert_df_equality(df1, df2)

As expected, these DataFrames are not identical and our test fails.  Let's try again, but ignore row order.

In [0]:
assert_df_equality(df1, df2, ignore_row_order=True)

The contents of the DataFrames are identical, even if the order is not, and our test passes.
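
The difference between the two comparisons can be sketched in plain Python by treating each DataFrame as a list of row tuples. This is an analogy under stated assumptions, not how `chispa` is implemented; the helper names `equal_ordered` and `equal_ignoring_row_order` are hypothetical.

```python
from collections import Counter

# The same rows as in the cells above, as plain tuples.
rows_a = [(1,), (2,), (3,), (None,)]
rows_b = [(3,), (2,), (1,), (None,)]


def equal_ordered(left, right):
    """Order-sensitive comparison: position matters."""
    return list(left) == list(right)


def equal_ignoring_row_order(left, right):
    """Order-insensitive comparison: rows are compared as a multiset,
    so duplicate rows still have to appear the same number of times."""
    return Counter(left) == Counter(right)


print(equal_ordered(rows_a, rows_b))
print(equal_ignoring_row_order(rows_a, rows_b))
```

A multiset is used rather than a set so that a DataFrame with a duplicated row is not considered equal to one without the duplicate.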
___

The remaining tests with `chispa` cover approximate equality.

### Approximate equality
When dealing with floating point precision, it can be useful to test for approximate equality.  The sensitivity of this approximation can be set using the `precision` argument in `assert_approx_df_equality()` or `assert_approx_column_equality()`. 

#### Approximate column equality

For example, assume we have the following DataFrame with two floating point columns.

```
data = [
    (1.1, 1.1),
    (2.2, 2.15),
    (3.3, 3.37),
    (None, None)
    ]
df = spark.createDataFrame(data, ["num1", "num2"])

assert_approx_column_equality(df, "num1", "num2", precision=0.1)
```

With a precision of 0.1, these columns are considered approximately equal, because every pair of values differs by less than 0.1.  If we set precision to 0.01, `chispa` would fail the assertion:

In [0]:
from chispa import assert_approx_column_equality

# Create DataFrame with two columns of floating point numbers
data = [
    (1.1, 1.1),
    (2.2, 2.15),
    (3.3, 3.37),
    (None, None)
    ]
df = spark.createDataFrame(data, ["num1", "num2"])

# Compare them with hundredths-level precision
assert_approx_column_equality(df, "num1", "num2", precision=0.01)

Let's invoke this test via `pytest` and check the output.

In [0]:
run_pytest("tests/chispa/test_chispa_approx_column_equality.py")

Take a moment to consider why this test failed. What would need to change in order for it to pass?
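
One way to reason about the question above is to compute the row-wise differences by hand. The sketch below assumes that two values count as approximately equal when their absolute difference is strictly less than `precision`, and that two nulls match each other; check the `chispa` source if the exact rule matters. The helpers `approx_equal` and `mismatches` are hypothetical.

```python
def approx_equal(left, right, precision):
    """Assumed rule: both null, or absolute difference below the precision."""
    if left is None or right is None:
        return left is None and right is None
    return abs(left - right) < precision


def mismatches(rows, precision):
    """Return the rows whose two values are not approximately equal."""
    return [row for row in rows if not approx_equal(row[0], row[1], precision)]


# Same values as in the DataFrame above.
rows = [(1.1, 1.1), (2.2, 2.15), (3.3, 3.37), (None, None)]

loose = mismatches(rows, 0.1)    # differences of 0.05 and 0.07 are tolerated
strict = mismatches(rows, 0.01)  # the same differences exceed the tolerance

print(loose)
print(strict)
```

Under this rule the test would pass either with a looser precision or with input values that differ by less than 0.01.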

#### Approximate DataFrame equality

In [0]:
run_pytest("tests/chispa/test_chispa_approx_df_equality.py")

## Using PySpark native tests

### `assertDataFrameEqual`

In [0]:
import pyspark.sql.functions as F
from src.transforms_spark import remove_non_word_characters
from pyspark.testing.utils import assertDataFrameEqual

# Dirty rows
dirty_rows = [
      ("jo&&se",),
      ("**li**",),
      ("#::luisa",),
      (None,)
  ]
source_df = spark.createDataFrame(dirty_rows, ["name"])

# Cleaned rows using function
clean_df = source_df.withColumn(
    "clean_name",
    remove_non_word_characters(F.col("name"))
)

# Expected output, should be identical to clean_df
expected_data = [
      ("jo&&se", "jose"),
      ("**li**", "li"),
      ("#::luisa", "luisa"),
      (None, None)
  ]
expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"])

assertDataFrameEqual(clean_df, expected_df)

Let's run the same test as before where we had an integer mixed with the strings, but we'll use the `assertDataFrameEqual` function instead of one from `chispa`.

In [0]:
run_pytest("tests/pyspark_native/test_df_equality.py")

### `assertPandasOnSparkEqual`

In [0]:
import pyspark.pandas as ps


### `assertSchemaEqual`