# Unit tests in PySpark
Testing is a crucial part of software development that helps in identifying issues with your code early and often, thereby ensuring the reliability and robustness of the application. **Unit tests** are designed to test the independent pieces of logic that comprise your application. In general, tests look to validate that your logic is functioning as intended. **By asserting that the _actual_ output of our logic is identical to the _expected_ output, we can determine if the logic has been implemented correctly.** Ideally each test will cover exactly one piece of functionality, e.g.,  a specific data transformation or helper function.
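
To make the actual-versus-expected idea concrete, here is a minimal sketch in plain Python, without Spark. The helper `strip_and_lower` and its test are hypothetical and exist only for illustration; the same pattern carries over to the DataFrame transformations tested later in this notebook.

```python
def strip_and_lower(value):
    """Hypothetical helper: trim whitespace and lowercase a string; None passes through."""
    if value is None:
        return None
    return value.strip().lower()


def test_strip_and_lower():
    # Each assertion compares the actual output with the expected output.
    assert strip_and_lower("  Alice ") == "alice"
    assert strip_and_lower("BOB") == "bob"
    assert strip_and_lower(None) is None


# pytest would collect and run this function automatically; here we call it directly.
test_strip_and_lower()
```

Note that the test covers exactly one piece of functionality, so a failure points directly at `strip_and_lower`.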
 
In the context of PySpark, tests are usually centered around **comparing DataFrames** against an expected output. There are several dimensions by which DataFrames can be compared:
 
 * Schemas
 * Columns
 * Rows 
 * Entire DataFrames

In this notebook, we are going to learn about unit testing PySpark applications using the Python `pytest` framework. Tests will be written using the popular `chispa` library as well as the new PySpark native testing functions available in Spark 3.5 [link]. This tutorial is designed to keep things simple, so if you want to learn more about how `pytest` works, we recommend taking a closer look at the official documentation [link].

Let's get started!


# Working with `pytest`

## How to structure your project with pytest
Consider the following hypothetical project structure for a simple reporting use case:

```
├── notebooks/
│   ├── analysis.ipynb   # report with visualizations
├── src/
│   ├── load_config.py   # helper functions
│   └── cleaning.py
├── tests/
│   ├── main_test.py     # unit tests
├── requirements.txt     # dependencies
└── test-requirements.txt
```

The `notebooks/` folder contains `analysis.ipynb`, which reports on and visualizes some business data.  This notebook imports custom Python functions from both files in `src/` to help load and clean data like so:

```
from src.load_config import db_loader, config_handler
from src.cleaning import *
```

Since our report depends on these functions to get and prepare the data correctly, we want to write tests that validate that our functions behave as intended. To do so, we create a `tests/` folder and include our tests there. **By default, `pytest` collects test files whose names start with `test_` or end with `_test.py`.** In this example our test script is called `main_test.py`, but later on you will see examples like `test_pyspark_column_equality.py`. Both are supported and will be picked up by `pytest`.
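
As a rough illustration of the naming rule, the sketch below checks a few filenames against `test_*.py` and `*_test.py`, which are the default patterns of pytest's `python_files` setting. The helper `looks_like_test_file` is hypothetical and only approximates discovery; real collection also depends on the directories searched and on any project configuration.

```python
from fnmatch import fnmatch

# Default file patterns used by pytest's `python_files` setting.
DEFAULT_PATTERNS = ("test_*.py", "*_test.py")


def looks_like_test_file(filename):
    """Return True if the filename matches one of the default patterns."""
    return any(fnmatch(filename, pattern) for pattern in DEFAULT_PATTERNS)


for name in ["main_test.py", "test_pyspark_column_equality.py", "cleaning.py", "load_config.py"]:
    print(f"{name}: {looks_like_test_file(name)}")
```

Only the first two names match, which is why the helper modules in `src/` are never mistaken for tests.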

The last two files in our project specify the dependencies for our code. It is a best practice to keep testing dependencies separate, since packages such as `pytest` are only needed to run the test scripts.

## Invoking `pytest`

`pytest` is usually run from the system command line, but it can also be executed from the context of a notebook or Python REPL.  We'll be using the latter method to invoke our tests from the Databricks editor.  This lets us make use of Spark and other configuration variables in Databricks Runtime. 

One limitation of this approach is that Python caches imported modules, so edits to a test file are not picked up once that file has been imported. If we wanted to iterate on tests during development, we would need to call `dbutils.library.restartPython()` to clear the cache and pick up the changes. This tutorial has been structured to render this unnecessary, but it is important to note!
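
The caching behavior can be reproduced in plain Python. The sketch below writes a throwaway module (the name `demo_cached_module` is hypothetical), imports it, edits the file, and shows that a second `import` still returns the old value. It uses `importlib.reload` only to demonstrate the cache; it is not the approach this tutorial takes, and it is not a substitute for restarting the interpreter in Databricks.

```python
import importlib
import sys
import tempfile
from pathlib import Path

sys.dont_write_bytecode = True  # keep this demo free of stale .pyc files

with tempfile.TemporaryDirectory() as tmp:
    module_path = Path(tmp) / "demo_cached_module.py"
    module_path.write_text("VALUE = 'old'\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()

    import demo_cached_module
    first = demo_cached_module.VALUE

    # Edit the file on disk, then import again.
    module_path.write_text("VALUE = 'new and longer'\n")
    import demo_cached_module  # no-op: the module is already in sys.modules
    second = demo_cached_module.VALUE

    # An explicit reload re-executes the source file.
    importlib.reload(demo_cached_module)
    third = demo_cached_module.VALUE

    sys.path.remove(tmp)

print(first, second, third)
```

The second import returns the cached module, so the edit only becomes visible after the reload.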

In the following cell, we first make sure that all tests will run relative to our repository root directory.  Then we define `run_pytest`, a helper function to invoke a specific test file in our project.  Importantly, this function also fails the Databricks notebook cell execution if tests fail.  This ensures we surface errors whether we run these unit tests in an interactive session or as part of a Databricks Workflow.

In [0]:
import pytest
import os
import sys

# Run all tests in the repository root.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
repo_root = os.path.dirname(os.path.dirname(notebook_path))
os.chdir(f'/Workspace/{repo_root}')
%pwd

def run_pytest(pytest_path):
  # Skip writing pyc files on a readonly filesystem.
  sys.dont_write_bytecode = True

  retcode = pytest.main([pytest_path, "-p", "no:cacheprovider"])

  # Fail the cell execution if we have any test failures.
  assert retcode == 0, 'The pytest invocation failed. See the log above for details.'

# Test scenarios
#### Installing dependencies

Let's ensure we install any dependencies for our tests first. To make use of them we will need to restart the Python interpreter before running any tests.

In [0]:
!cp ../requirements.txt ~/.
%pip install -r ~/requirements.txt

In [0]:
dbutils.library.restartPython()

## Using `chispa`

### Schema equality

A common data transformation task is to add or remove columns from DataFrames.  In these cases we define a function that takes a DataFrame as input, alters the schema, and returns a DataFrame as output.  In `chispa` we can test the validity of our function by using the `assert_df_equality()` function.  

For example, let's say we want to compare the following two DataFrames with `assert_df_equality()`.

In [0]:
from chispa import assert_df_equality
# DF with numeric and string columns
data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df1 = spark.createDataFrame(data1, ["num", "letter"])

# DF with only numeric columns
data2 = [
        (1, 6),
        (2, 7),
        (3, 8),
        (None, None)
    ]
df2 = spark.createDataFrame(data2, ["num", "num2"])

# Compare them
assert_df_equality(df1, df2)

---------------------------------------------------------------------------
SchemasNotEqualError                      Traceback (most recent call last)
File <command-2786918639959420>, line 21
     18 df2 = spark.createDataFrame(data2, ["num", "num2"])
     20 # Compare them
---> 21 assert_df_equality(df1, df2)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-79fda558-daa4-4fe6-bb75-9a5c3be82218/lib/python3.10/site-packages/chispa/dataframe_comparer.py:22, in assert_df_equality(df1, df2, ignore_nullable, transforms, allow_nan_equality, ignore_column_order, ignore_row_order, underline_cells)
     20 df1 = reduce(lambda acc, fn: fn(acc), transforms, df1)
     21 df2 = reduce

As we expected, this fails.  `chispa` makes it really clear exactly why it failed, too, specifying the precise column differences and highlighting the discrepancies in red. 

Let's look at that same scenario, but make sure that it passes this time.

In [0]:
from chispa import assert_df_equality
# DF with numeric and string columns
data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df1 = spark.createDataFrame(data1, ["num", "letter"])

# DF with the same schema as df1 but different letters
data2 = [
        (1, "d"),
        (2, "e"),
        (3, "f"),
        (None, None)
    ]
df2 = spark.createDataFrame(data2, ["num", "letter"])

# Compare them
assert_df_equality(df1, df2)

---------------------------------------------------------------------------
DataFramesNotEqualError                   Traceback (most recent call last)
File <command-2786918639959422>, line 21
     18 df2 = spark.createDataFrame(data2, ["num", "letter"])
     20 # Compare them
---> 21 assert_df_equality(df1, df2)

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-79fda558-daa4-4fe6-bb75-9a5c3be82218/lib/python3.10/site-packages/chispa/dataframe_comparer.py:27, in assert_df_equality(df1, df2, ignore_nullable, transforms, allow_nan_equality, ignore_column_order, ignore_row_order, underline_cells)
     24     assert_generic_rows_equality(
     25         df1.collect(), df2.collect(), are_rows_equal

Aha! The schemas match but the DataFrames are not 100% equal.  Let's try this one more time and ensure that the DataFrames are identical.

In [0]:
from chispa import assert_df_equality
# DF with numeric and string columns
data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df1 = spark.createDataFrame(data1, ["num", "letter"])

# DF with the same schema and rows as df1
data2 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
df2 = spark.createDataFrame(data2, ["num", "letter"])

# Compare them
assert_df_equality(df1, df2)

`chispa` will not return any output to the console when we run tests this way. However, when we use `chispa` in combination with `pytest`, we will always get a complete report on how our tests fared. To do so, we have added `tests/chispa/test_chispa_101.py` to our project. This file contains the exact same test that we ran in the cell above; the only difference is that we wrap it in a function:

```{python}
## contents of tests/chispa/test_chispa_101.py

def test_schema_mismatch_message():
    # DF with numeric and string columns
    data1 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
    df1 = spark.createDataFrame(data1, ["num", "letter"])
    
    # DF with the same schema and rows as df1
    data2 = [
        (1, "a"),
        (2, "b"),
        (3, "c"),
        (None, None)
    ]
    df2 = spark.createDataFrame(data2, ["num", "letter"])
    
    # Compare them
    assert_df_equality(df1, df2)
```

Let's invoke this test now using `pytest`.

In [0]:
run_pytest("tests/chispa/test_chispa_101.py")

platform linux -- Python 3.10.12, pytest-7.2.0, pluggy-1.0.0
rootdir: /Workspace/Repos/dustin.vannoy@databricks.com/notebook-best-practices, configfile: pytest.ini
plugins: anyio-3.5.0, typeguard-2.13.3
collected 1 item

tests/chispa/test_chispa_101.py .                                        [100%]



### Column equality

In addition to adding or removing columns, a common data transformation is to update the contents of a column.  Let's take a slightly more realistic example for this case.  

Say we have a helper function called `remove_non_word_characters()` which ... you guessed it, strips all non-word characters from the strings in each row. 

```
import pyspark.sql.functions as F

# Removing all non-word characters in PySpark
def remove_non_word_characters(col):
    return F.regexp_replace(col, "[^\\w\\s]+", "")
```

To test this unit of logic, we will take three simple steps:

1. Create a DataFrame with rows that have non-word characters
2. Apply `remove_non_word_characters()` to the DataFrame and generate a new column with clean rows
3. Compare this DataFrame to a 2nd DataFrame with the expected output and assert that the two are equal

The following cell does just that.


In [0]:
import pyspark.sql.functions as F
from covid_analysis.transforms_spark import remove_non_word_characters
from chispa import assert_column_equality, assert_df_equality

# Dirty rows
dirty_rows = [
      ("jo&&se",),
      ("**li**",),
      ("#::luisa",),
      (None,)
  ]
source_df = spark.createDataFrame(dirty_rows, ["name"])

# Cleaned rows using function
clean_df = source_df.withColumn(
    "clean_name",
    remove_non_word_characters(F.col("name"))
)

# Expected output, should be identical to clean_df
expected_data = [
      ("jo&&se", "jose"),
      ("**li**", "li"),
      ("#::luisa", "luisa"),
      (None, None)
  ]
expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"])

assert_df_equality(clean_df, expected_df, underline_cells=True)

Looks good.  Before we invoke this test, let's consider another case:

```
data = [
        ("matt7", "matt"),
        ("bill&", "bill"),
        ("isabela*", "isabela"),
        (None, None)
      ]
```

We have included this test case in `tests/chispa/test_chispa_column_equality.py`. Take a look at the `remove_non_word_characters()` function and think about it before you invoke `pytest` in the next cell.

What do you think will happen? Will this test pass?



In [0]:
run_pytest("tests/chispa/test_chispa_column_equality.py")

platform linux -- Python 3.10.12, pytest-7.2.0, pluggy-1.0.0
rootdir: /Workspace/Repos/dustin.vannoy@databricks.com/notebook-best-practices, configfile: pytest.ini
plugins: anyio-3.5.0, typeguard-2.13.3
collected 2 items

tests/chispa/test_chispa_column_equality.py .F                           [100%]

__________________ test_remove_non_word_characters_nice_error __________________

>   ???

tests/chispa/test_chispa_column_equality.py:30: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

df = DataFrame[name: string, expected_name: string, clean_name: string]
col_name1 = 'clean_name', col_name2 = 'expected_name'

    def assert_column_equality(df, col_name1, col_name2):
        elements = df.select(col_name1, col_name2).collect()
        colName1Elements = list(map(lambda

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File <command-2786918639959407>, line 1
----> 1 run_pytest("tests/chispa/test_chispa_column_equality.py")

File <command-2770700997875077>, line 28, in run_pytest(pytest_path)
     25 retcode = pytest.main([pytest_path, "-p", "no:cacheprovider"])
     27 # Fail the cell execution if we have any test failures.
---> 28 assert retcode == 0, 'The pytest invocation failed. See the log above for details.
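
One likely explanation for the failure above: `\w` matches letters, digits, and underscores, so the `7` in `matt7` is not a non-word character and is left in place, while the test expects `matt`. The sketch below applies the same pattern with Python's `re` module. The function `remove_non_word_characters_py` is a hypothetical plain-Python stand-in, not the Spark helper; Spark's `regexp_replace` uses Java regular expressions, which should agree with Python for these ASCII-only inputs.

```python
import re

# Same pattern as in remove_non_word_characters(), written as a raw string.
NON_WORD_PATTERN = r"[^\w\s]+"


def remove_non_word_characters_py(value):
    """Hypothetical plain-Python stand-in for the Spark helper; None passes through."""
    if value is None:
        return None
    return re.sub(NON_WORD_PATTERN, "", value)


for raw in ["jo&&se", "**li**", "#::luisa", "matt7", "bill&", "isabela*", None]:
    print(repr(raw), "->", repr(remove_non_word_characters_py(raw)))
```

Every input is cleaned as expected except `matt7`, which comes back unchanged.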

### Approximate column equality

In [0]:
run_pytest("tests/chispa/test_chispa_approx_column_equality.py")

platform linux -- Python 3.10.12, pytest-7.2.0, pluggy-1.0.0
rootdir: /Workspace/Repos/dustin.vannoy@databricks.com/notebook-best-practices, configfile: pytest.ini
plugins: anyio-3.5.0, typeguard-2.13.3
collected 2 items

tests/chispa/test_chispa_approx_column_equality.py .F                    [100%]

______________________ test_approx_col_equality_different ______________________

    def test_approx_col_equality_different():
        data = [
            (1.1, 1.1),
            (2.2, 2.15),
            (3.3, 5.0),
            (None, None)
        ]
        df = spark.createDataFrame(data, ["num1", "num2"])
>       assert_approx_column_equality(df, "num1"

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File <command-2786918639959410>, line 1
----> 1 run_pytest("tests/chispa/test_chispa_approx_column_equality.py")

File <command-2770700997875077>, line 28, in run_pytest(pytest_path)
     25 retcode = pytest.main([pytest_path, "-p", "no:cacheprovider"])
     27 # Fail the cell execution if we have any test failures.
---> 28 assert retcode == 0, 'The pytest invocation failed. See the log above for details.
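
The idea behind an approximate comparison can be sketched in plain Python: two values count as equal when they differ by less than a chosen tolerance. This is not `chispa`'s implementation. The helper `approx_equal` is hypothetical, and the tolerance of `0.1` is an assumption, since the value passed in the test file is cut off in the output above.

```python
def approx_equal(a, b, precision):
    """Equal if both are None, or if both are numbers differing by less than precision."""
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    return abs(a - b) < precision


# Rows taken from the failing test shown above.
rows = [(1.1, 1.1), (2.2, 2.15), (3.3, 5.0), (None, None)]
PRECISION = 0.1  # assumed tolerance, not taken from the test file

mismatches = [(a, b) for a, b in rows if not approx_equal(a, b, PRECISION)]
print(mismatches)
```

Under this assumed tolerance, only the `(3.3, 5.0)` row is reported as a mismatch, while `(2.2, 2.15)` is close enough to pass.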

### DataFrame equality

In [0]:
run_pytest("tests/chispa/test_chispa_df_equality.py")

### Approximate DataFrame equality

In [0]:
run_pytest("tests/chispa/test_chispa_approx_df_equality.py")

## Using PySpark native tests

### `assertDataFrameEqual`

### `assertPandasOnSparkEqual`

### `assertSchemaEqual`

# When should I write unit tests?

When first working on a project we often prototype our code without paying attention to how we will modularize it and test it.  This can be fine during the early phases, but writing modular code and writing good tests go hand in hand. 