
<div style="text-align: center; line-height: 0; padding-top: 9px;">
  <img
    src="https://databricks.com/wp-content/uploads/2018/03/db-academy-rgb-1200px.png"
    alt="Databricks Learning"
  >
</div>


# 2.4 - Creating and Executing Unit Tests

A unit test is a type of software testing that focuses on verifying the smallest parts of an application, typically individual functions or methods in isolation. The goal is to ensure that each unit of code works as expected, producing the correct output for a given input. Unit tests are typically automated and run frequently during development to catch bugs early and maintain code quality.

## Objectives

- Write simple unit tests within a notebook to verify the functionality of the code.
- Store unit tests in an external .py file, evaluating the advantages of externalizing tests for maintainability.
- Use the `pytest` package to automatically discover and execute test functions.

## REQUIRED - SELECT CLASSIC COMPUTE

Before executing cells in this notebook, please select your classic compute cluster in the lab. Be aware that **Serverless** is enabled by default.

Follow these steps to select the classic compute cluster:


1. Navigate to the top-right of this notebook and click the drop-down menu to select your cluster. By default, the notebook will use **Serverless**.

2. If your cluster is available, select it and continue to the next cell. If the cluster is not shown:

   - Click **More** in the drop-down.

   - In the **Attach to an existing compute resource** window, use the first drop-down to select your unique cluster.

**NOTE:** If your cluster has terminated, you might need to restart it in order to select it. To do this:

1. Right-click on **Compute** in the left navigation pane and select *Open in new tab*.

2. Find the triangle icon to the right of your compute cluster name and click it.

3. Wait a few minutes for the cluster to start.

4. Once the cluster is running, complete the steps above to select your cluster.

## A. Create Simple Unit Tests in a Databricks Notebook

Let's perform simple unit tests on the functions we created in the previous demonstration. To create a unit test for a function:

**a. Create the Initial Function**  
Create the initial function you want to use in production. Ensure it has clear inputs and outputs that can be easily validated.

**b. Create One or More Test Functions**  
Write a unit test function (or multiple functions to test different expectations) that calls the initial function with sample inputs and uses assertions to check whether the actual output matches the expected result. Start each test function name with the prefix `test_`, followed by a description of what you are testing.  

**c. Execute the Unit Test Functions**  
Run the tests with a testing framework such as `pytest`. Review the results to ensure the function behaves as expected, and fix any issues if a test fails.

**NOTE:** There are many available frameworks you can use. We will use `pytest` in this course. The framework you choose should be discussed with your team or organization.
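The a/b/c pattern described above can be sketched in plain Python before bringing in Spark or `pytest`. This is a minimal illustration under stated assumptions: `map_high_cholest` is a hypothetical pure-Python stand-in for the project's Spark mapping function, and its labels are taken from the expected results used later in this notebook.

```python
# Step a: a hypothetical pure-Python stand-in for the project's Spark column mapping.
# Labels mirror the expected results used later in this notebook (assumption).
def map_high_cholest(value):
    """Map a cholesterol code to a categorical label."""
    labels = {0: "Normal", 1: "Above Average", 2: "High"}
    return labels.get(value, "Unknown")


# Step b: test functions start with the test_ prefix so pytest can discover them.
def test_map_high_cholest_known_codes():
    assert map_high_cholest(0) == "Normal"
    assert map_high_cholest(1) == "Above Average"
    assert map_high_cholest(2) == "High"


def test_map_high_cholest_unknown_codes():
    # Out-of-range codes and missing values fall back to "Unknown"
    for value in (3, 4, None):
        assert map_high_cholest(value) == "Unknown"


# Step c: call the tests directly here; pytest would collect and run them automatically.
test_map_high_cholest_known_codes()
test_map_high_cholest_unknown_codes()
print("Test passed!")
```

Each test function checks one expectation group, so a failure points directly at the behavior that broke.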

1. Complete the following to view the custom functions for this project:

   a. Navigate to the main course folder **DevOps Essentials for Data Engineering**.  
   
   b. Select **src**.  

   c. Then select **helpers**.  

   d. Right-click **[project_functions.py]($../../src/helpers/project_functions.py)** and select *Open in new tab*.  

   e. Review the **project_functions.py** file. Notice it contains the following functions that we saw in the previous notebook: 

     - `get_health_csv_schema` - Returns the schema for the health data CSVs.  

     - `highcholest_map` - Maps a cholesterol value to a categorical label.  
     
     - `group_ages_map` - Maps an age value to an age group category.  

   f. Close the **project_functions.py** tab.
   
   g. Run the cell below to import the custom functions from the **helpers** module.

**NOTE:** `sys.path.append()` adds the root folder path to the system path, which allows you to import modules from that folder.
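The `sys.path.append()` call in the next cell appends the root folder every time the cell runs, so re-running it adds duplicate entries. A minimal variation that guards against this might look like the following; `add_root_to_sys_path` is a hypothetical helper, not part of the course files, and it assumes the same two-levels-up folder layout as the next cell.

```python
import os
import sys


def add_root_to_sys_path(start_path, levels_up=2):
    """Append the folder `levels_up` levels above start_path to sys.path, once."""
    root = os.path.abspath(start_path)
    for _ in range(levels_up):
        root = os.path.dirname(root)
    # Skip the append if the path is already present (e.g., when a cell is re-run)
    if root not in sys.path:
        sys.path.append(root)
    return root


root_folder_path = add_root_to_sys_path(os.getcwd())
print(f"Root folder on sys.path: {root_folder_path}")
```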

In [0]:
# Adds the folder two levels above the current working directory to the Python search path
import sys
import os

# Get the current working directory
current_path = os.getcwd()

# Get the root folder path by navigating two levels up from the current path
root_folder_path = os.path.dirname(os.path.dirname(current_path))

# Add the root folder path to the system path (this allows you to import modules from this folder)
sys.path.append(root_folder_path)

# Print a message confirming the root folder path added to sys.path
print(f'Added the following path: {root_folder_path}')

In [0]:
## Import the custom functions
from src.helpers import project_functions

2. Create the test function `test_get_health_csv_schema_match` to test the `get_health_csv_schema` function imported above. 

   Within the test function:

   - The variable **actual_schema** holds the schema result from the `get_health_csv_schema` function.

   - The variable **expected_schema** specifies the schema we expect the function to return if it works as expected.

   - The `assertSchemaEqual` function from `pyspark.testing.utils` compares **actual_schema** with **expected_schema**. If the schemas are not equal, an error is raised.


   Run the cell and confirm that no errors are returned. This shows that the `get_health_csv_schema` function is working as expected.

   [PySpark Testing Documentation](https://spark.apache.org/docs/latest/api/python/reference/pyspark.testing.html#testing)

In [0]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType, DoubleType, LongType
from pyspark.sql.functions import when, col

from pyspark.testing.utils import assertSchemaEqual

def test_get_health_csv_schema_match():

    # Get schema from our function
    actual_schema = project_functions.get_health_csv_schema()
    
    # Define the expected schema that the function should return. If that function is changed during development the unit test will pick up the error and the test will fail.
    expected_schema = StructType([
        StructField("ID", IntegerType(), True),
        StructField("PII", StringType(), True),
        StructField("date", DateType(), True),
        StructField("HighCholest", IntegerType(), True),
        StructField("HighBP", DoubleType(), True),
        StructField("BMI", DoubleType(), True),
        StructField("Age", DoubleType(), True),
        StructField("Education", DoubleType(), True),
        StructField("income", IntegerType(), True)
    ])

    # Assert the actual schema matches the expected schema
    assertSchemaEqual(actual_schema, expected_schema)
    print('Test passed!')

test_get_health_csv_schema_match()

3. Create the test function `test_high_cholest_column_valid_map` to test the `highcholest_map` function from above. Within the test function:

   - Import the `assertDataFrameEqual` function from the `pyspark.testing.utils` module, which compares two PySpark DataFrames and asserts that they are equal, checking both the data and the schema.
   
   - The variable **actual_df** holds the sample DataFrame with a new **actual** column produced by the `highcholest_map` function.
   
   - The variable **expected_df** contains the DataFrame we expect the function to produce.
   
   - The `assertDataFrameEqual` function compares **actual_df** and **expected_df**. If the data or schemas are not equal, an error is raised.

   Run the cell and confirm that no errors are returned. This shows that the `highcholest_map` function is working as expected.

   [PySpark Testing Documentation](https://spark.apache.org/docs/latest/api/python/reference/pyspark.testing.html#testing)

In [0]:
## Import assertDataFrameEqual to compare your DataFrames
from pyspark.testing.utils import assertDataFrameEqual

def test_high_cholest_column_valid_map():

    # Create the sample DataFrame to test the function
    data = [
        (0,),
        (1,), 
        (2,), 
        (3,), 
        (4,), 
        (None,)
    ]
    sample_df = spark.createDataFrame(data, ["value"])

    # Apply the function on the sample data
    actual_df = sample_df.withColumn("actual", project_functions.high_cholest_map("value"))

    # Create a static DataFrame with the expected results of the highcholest_map function above
    expected_df = spark.createDataFrame(
        [
            (0, "Normal"),
            (1, "Above Average"),
            (2, "High"),
            (3, "Unknown"),
            (4, "Unknown"),
            (None, "Unknown")
        ],
        schema=StructType([
            StructField("value", LongType(), True),
            StructField("actual", StringType(), True)
        ])
    )

    ## Compare actual_df with expected_df (data and schema). If they are not equal, an error is raised.
    assertDataFrameEqual(actual_df, expected_df)
    print('Test passed!')

test_high_cholest_column_valid_map()

4. In the cell below, the **expected_df** variable has been modified from `(0, "Normal")` to `(0, "Bad Value Cause Error")`. This causes the unit test to fail because the expected results are specified incorrectly and no longer match **actual_df**.

    Run the unit test below. Notice that the test function raises an error because the DataFrames don't match. 
    
**NOTE:** If your unit tests fail, you need to determine where the failure occurred and fix it.
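When a comparison like the one below fails, the useful part is knowing which rows differ. The following is a simplified pure-Python illustration of that idea, not PySpark's implementation of `assertDataFrameEqual`: each row is treated as a `(value, label)` pair, and `find_mismatches` is a hypothetical helper.

```python
# Simplified illustration (not PySpark's implementation) of locating the rows
# that cause an actual-vs-expected comparison to fail.
def find_mismatches(actual_rows, expected_rows):
    """Return (key, actual_label, expected_label) for each key whose label differs."""
    actual = dict(actual_rows)
    expected = dict(expected_rows)
    mismatches = []
    # Sort by repr so integer keys and None can be ordered together
    for key in sorted(set(actual) | set(expected), key=repr):
        if actual.get(key) != expected.get(key):
            mismatches.append((key, actual.get(key), expected.get(key)))
    return mismatches


actual_rows = [(0, "Normal"), (1, "Above Average"), (2, "High"), (None, "Unknown")]
expected_rows = [(0, "Bad Value Cause Error"), (1, "Above Average"), (2, "High"), (None, "Unknown")]

try:
    mismatches = find_mismatches(actual_rows, expected_rows)
    assert not mismatches, f"Rows do not match: {mismatches}"
except AssertionError as err:
    # A test runner would report this failure instead of stopping the notebook
    print(f"Test failed: {err}")
```

Here the only mismatch is the row for `0`, which points to the incorrectly specified expected value rather than a bug in the mapping function.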

In [0]:
################################################################################
## Cell causes an error because the actual df does not match the expected df
################################################################################

def test_high_cholest_column_invalid_map():

    # Create the sample DataFrame to test the function on
    data = [
        (0,),
        (1,), 
        (2,), 
        (3,), 
        (4,), 
        (None,)
    ]
    sample_df = spark.createDataFrame(data, ["value"])

    # Apply the function on the sample data
    actual_df = sample_df.withColumn("actual", project_functions.high_cholest_map("value"))

    # Create a static DataFrame with the expected results of the highcholest_map function above
    expected_df = spark.createDataFrame(
        [
            (0, "Bad Value Cause Error"),     ####### <--- Value has been changed to cause an error
            (1, "Above Average"),
            (2, "High"),
            (3, "Unknown"),
            (4, "Unknown"),
            (None, "Unknown")
        ],
        schema=StructType([
            StructField("value", LongType(), True),
            StructField("actual", StringType(), True)
        ])
    )

    ## Compare actual_df with expected_df (data and schema). If they are not equal, an error is raised.
    assertDataFrameEqual(actual_df, expected_df)

test_high_cholest_column_invalid_map()

Developing good unit tests is important because they help ensure individual components of your code work as expected, catching errors early in the development process. This leads to more reliable, maintainable code and can save debugging time later on.

## B. Create a File for the Unit Tests

In the previous cells, we implemented our unit tests within the Databricks notebook.

Typically, unit tests are placed in a separate location, such as a **tests/unit_tests/tests_file_name.py** file (or in multiple Python files), to keep them separate from the main code and maintain a clean project structure. A testing framework is then used to execute the tests.

This approach makes it easier to manage, update, and run tests without affecting the application’s core functionality.

1. Let's start by installing our testing framework package, `pytest`, on the cluster.

    [pytest documentation](https://docs.pytest.org/en/stable/)

In [0]:
!pip install pytest==8.3.4

2. Complete the following steps to view how the unit tests are defined in the Python file for this project. The unit test functions have already been placed in the **tests/unit_tests/test_spark_helper_functions.py** file in the main course folder.

    a. Navigate to the main course folder **DevOps Essentials for Data Engineering**.  
   
    b. Open the **tests/unit_tests** folder.

    c. Right-click the **[tests/unit_tests/test_spark_helper_functions.py]($../../tests/unit_tests/test_spark_helper_functions.py)** file and select *Open in new tab*. Review the file. 
    
    Notice the following:
  
      - On line 13, our custom `project_functions` module is imported from `src.helpers`.

         **NOTE:** We define the Python path setting for `pytest` in the **pytest.ini** file in the main course directory. This adds the specified directory to the Python module search path when `pytest` runs, so `pytest` can import code from the project root or sibling directories without you manually modifying `sys.path`. This makes the functions available for use in the test file.
      
      - The `@pytest.fixture` decorator defines a fixture named **spark** with session scope, which means it is set up once per test session and shared across all test functions that request it. The fixture creates a SparkSession using `SparkSession.builder.getOrCreate()` (which either retrieves an existing session or creates a new one if none exists) and then yields the session to the test functions that use the fixture, giving them access to the Spark environment.

      - The remainder of the file creates the unit test functions: 
         - `test_get_health_csv_schema_match`

         - `test_highcholest_column_map`

         - `test_age_group_column_map`

    d. After reviewing the file, close the tab.

**NOTE:** The complexity and thoroughness of your tests depend on your organization's best practices and guidelines. This is a simple demonstration. This course introduces unit testing and does not go deep into `pytest` or any other testing framework.


3. In the next cell, use `pytest` to execute all three tests in the **tests/unit_tests/test_spark_helper_functions.py** file.

    The cell runs `pytest` within Databricks using the following:

    - `sys.dont_write_bytecode = True` - By default, Python writes `.pyc` bytecode files to a `__pycache__` directory. Setting this flag to `True` prevents that, which is useful in read-only environments.

    - `pytest.main()` - This is the function that starts pytest and runs the tests.

    - `../../tests/unit_tests/test_spark_helper_functions.py` - The path to the test file to run, relative to the current working directory.

    - `-v` - This stands for "verbose". It tells pytest to provide more detailed output during test execution, including the names of tests and their results.

    - `-p no:cacheprovider` - This disables pytest's cache provider plugin. pytest normally stores information between test runs (for example, which tests failed last time) in a cache directory. The `no:cacheprovider` option tells pytest not to use the cache, which can be useful if you're running tests in an environment where writing the cache could cause issues.

    When executing pytest, it looks for:

    - Test files: By default, files named **test_\*.py** or **\*_test.py**. In this example, we explicitly pass the **tests/unit_tests/test_spark_helper_functions.py** file.

    - Test functions: Functions whose names start with `test` (by convention, `test_`).

    - For example, if you have a file like **test_example.py** with a function `test_addition()`, pytest will automatically detect and run the tests in that file.
 
    Run the cell below and view the results.
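The discovery rules described above can be approximated with the standard library. This is a simplified sketch of the default naming rules only, not how pytest is implemented; `is_test_file` and `collect_test_functions` are hypothetical helpers, and the example module is built in memory rather than read from a file.

```python
import fnmatch
import inspect
import types

# pytest's default file patterns (python_files = test_*.py *_test.py)
TEST_FILE_PATTERNS = ("test_*.py", "*_test.py")


def is_test_file(filename):
    """Return True if the file name matches one of the default test file patterns."""
    return any(fnmatch.fnmatch(filename, pattern) for pattern in TEST_FILE_PATTERNS)


def collect_test_functions(module):
    """Return names of functions in the module whose names start with 'test'."""
    return sorted(
        name
        for name, obj in inspect.getmembers(module, inspect.isfunction)
        if name.startswith("test")
    )


# Build an in-memory module that mirrors the test_example.py example above
source = '''
def add(a, b):
    return a + b

def test_addition():
    assert add(2, 3) == 5
'''
example = types.ModuleType("test_example")
exec(source, example.__dict__)

print(is_test_file("test_example.py"))   # True
print(is_test_file("example.py"))        # False
print(collect_test_functions(example))   # ['test_addition']
```

The helper function `add` is ignored because its name does not start with `test`, which is why test names need the prefix.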


**NOTE:** Depending on your organization's best practices, you can run unit tests within Databricks or locally in the IDE of your choice based on your needs.
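The pieces described in step 3 can be wrapped in a small helper so the arguments are easy to inspect and reuse. This is a sketch, not part of the course files: `build_pytest_args` and `run_unit_tests` are hypothetical names, and `pytest` is imported inside `run_unit_tests` so the argument builder works even where `pytest` is not installed.

```python
import sys


def build_pytest_args(test_path, verbose=True, disable_cache=True):
    """Build the argument list passed to pytest.main()."""
    args = [test_path]
    if verbose:
        args.append("-v")  # report each test name and its result
    if disable_cache:
        args.extend(["-p", "no:cacheprovider"])  # disable the cache provider plugin
    return args


def run_unit_tests(test_path):
    """Run pytest on test_path and fail the cell if any test fails."""
    import pytest  # imported lazily so build_pytest_args works without pytest

    sys.dont_write_bytecode = True  # do not write .pyc files to __pycache__
    retcode = pytest.main(build_pytest_args(test_path))
    assert retcode == 0, f"The pytest invocation failed with exit code {int(retcode)}."
    return retcode


# Example (hypothetical usage in this notebook):
# run_unit_tests("../../tests/unit_tests/test_spark_helper_functions.py")
```

Keeping argument construction separate from execution makes it easy to toggle options such as `-v` without editing the `pytest.main()` call itself.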

In [0]:
## import the `pytest` and `sys` modules.
import pytest
import sys

sys.dont_write_bytecode = True

retcode = pytest.main(["../../tests/unit_tests/test_spark_helper_functions.py", "-v", "-p", "no:cacheprovider"])

# Fail the cell execution if there are any test failures.
assert retcode == 0, "The pytest invocation failed. See the log for details."

4. View the output of the cell above. Notice that:

    - `pytest` automatically discovered all the unit test functions within the **tests/unit_tests/test_spark_helper_functions.py** file.
    
    - The **test_spark_helper_functions.py** file contained three test functions: 
      - `test_get_health_csv_schema_match()`

      - `test_highcholest_column_map()`

      - `test_age_group_column_map()`
    
    - All three unit tests passed.

    For more information, see:
      
    [PySpark Testing Documentation](https://spark.apache.org/docs/latest/api/python/reference/pyspark.testing.html#testing)

    [Testing PySpark Documentation](https://spark.apache.org/docs/latest/api/python/getting_started/testing_pyspark.html#Testing-PySpark)
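The `pytest` cell only asserts that `retcode == 0`, but a non-zero code can mean different things (for example, failing tests versus no tests collected). A minimal helper that translates the code into a readable message might look like this; `describe_exit_code` is hypothetical, and the mapping mirrors the exit codes documented for `pytest.ExitCode`.

```python
# Exit codes documented for pytest.ExitCode, copied here so the helper
# has no dependency on pytest itself.
PYTEST_EXIT_CODES = {
    0: "OK: all tests were collected and passed",
    1: "TESTS_FAILED: tests were collected and run but some failed",
    2: "INTERRUPTED: test execution was interrupted by the user",
    3: "INTERNAL_ERROR: an internal error happened while executing tests",
    4: "USAGE_ERROR: pytest command line usage error",
    5: "NO_TESTS_COLLECTED: no tests were collected",
}


def describe_exit_code(retcode):
    """Return a readable description of a pytest exit code (int or ExitCode)."""
    code = int(retcode)
    return PYTEST_EXIT_CODES.get(code, f"Unknown exit code {code}")


print(describe_exit_code(0))
print(describe_exit_code(5))
```

For example, the assertion in the pytest cell could become `assert retcode == 0, describe_exit_code(retcode)`.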


## Summary
Unit testing is important because it helps ensure that individual parts of your code (like functions or methods) work correctly. It catches bugs early, improves code reliability, and makes it easier to maintain and refactor your code with confidence, knowing that existing functionality is still working as expected. You can perform your unit testing within Databricks or locally using the IDE of your choice.

This demonstration was a simple, quick introduction to unit testing with the `pytest` framework. There are a variety of other frameworks available, and you must decide which one works best for your team. Also, it's important to decide on unit test best practices for your team and organization.

#### Next Steps
You can set up a continuous integration and continuous delivery or deployment (CI/CD) system, such as GitHub Actions, to automatically run your unit tests whenever your code changes. For an example, see the coverage of GitHub Actions in [Software engineering best practices for notebooks](https://docs.databricks.com/en/notebooks/best-practices.html). 


#### Additional Unit Testing Resources
- [pytest](https://docs.pytest.org/en/stable/)
- [chispa](https://github.com/MrPowers/chispa)
- [nutter](https://github.com/microsoft/nutter)
- [unittest](https://docs.python.org/3/library/unittest.html)
- [Best Practices for Unit Testing PySpark](https://www.youtube.com/watch?v=TbWcCyP2MgE)

&copy; 2025 Databricks, Inc. All rights reserved. Apache, Apache Spark, Spark, the Spark Logo, Apache Iceberg, Iceberg, and the Apache Iceberg logo are trademarks of the <a href="https://www.apache.org/" target="_blank">Apache Software Foundation</a>.<br/><br/><a href="https://databricks.com/privacy-policy" target="_blank">Privacy Policy</a> | <a href="https://databricks.com/terms-of-use" target="_blank">Terms of Use</a> | <a href="https://help.databricks.com/" target="_blank">Support</a>