Import Libraries

In [0]:
import unittest
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql import SparkSession
import warnings


Create sample dataframe

In [0]:
# Creating the sample DataFrame
data = [
    ("001", "Michael Jackson", 30, "New York"),
    ("002", "Jane Smith", 40, "Los Angeles"),
    ("003", "Sam Brown", None, "Chicago"),
    ("004", "BBohemian Rhapsody", 25, "Houston")
]
schema = ["CustomerID", "Name", "Age", "City"]
df = spark.createDataFrame(data, schema)
print("original dataset")
df.show()

# Applying the business logic transformations
# Transformation_1: Replace null values in 'Age' with the average 'Age'
average_age = df.select(avg(col("Age"))).first()[0]
df = df.withColumn("Age", when(col("Age").isNull(), lit(average_age)).otherwise(col("Age")))
df = df.withColumn("Age", round(col("Age"), 2))
print("Transformation_1 dataset")
df.show()

# Transformation_2: Count the distinct total number of customers
df_count = df.select('CustomerID').distinct().count()
print("Transformation_2 result")
print(f"Distinct Customer Count: {df_count}")

original dataset
+----------+------------------+----+-----------+
|CustomerID|              Name| Age|       City|
+----------+------------------+----+-----------+
|       001|   Michael Jackson|  30|   New York|
|       002|        Jane Smith|  40|Los Angeles|
|       003|         Sam Brown|null|    Chicago|
|       004|BBohemian Rhapsody|  25|    Houston|
+----------+------------------+----+-----------+

Transformation_1 dataset
+----------+------------------+-----+-----------+
|CustomerID|              Name|  Age|       City|
+----------+------------------+-----+-----------+
|       001|   Michael Jackson| 30.0|   New York|
|       002|        Jane Smith| 40.0|Los Angeles|
|       003|         Sam Brown|31.67|    Chicago|
|       004|BBohemian Rhapsody| 25.0|    Houston|
+----------+------------------+-----+-----------+

Transformation_2 result
Distinct Customer Count: 4


In [0]:
# Unit Test class
class TestCustomerTransformation(unittest.TestCase):

    def setUp(self):
        """
        Set up the test by reading the data into a DataFrame.
        """
        warnings.filterwarnings("ignore", category=ResourceWarning)
        
        # Use the existing Spark session instead of creating a new one
        self.spark = spark

    def test_null_age_values(self):
        """
        Test the Age transformation applied on the dataframe.
        """
        # Count null values in Age column
        age_null_count = df.filter(col('Age').isNull()).count()
        self.assertEqual(age_null_count, 0, 
                         f"Error: The Age_Null_count obtained after transformation is {age_null_count} (should be 0).")

    def test_distinct_customer_count(self):
        """
        Test to check the number of distinct customers.
        """
        distinct_records = df.select('CustomerID').distinct().count()
        expected_records = 5  
        self.assertEqual(distinct_records, expected_records, 
                         f"The total records in the register is {distinct_records}, which is not the expected count of {expected_records}.")


In [0]:
# Custom Test Result class to track passing tests
class CustomTestResult(unittest.TextTestResult):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.passing_tests = []

    def addSuccess(self, test):
        super().addSuccess(test)
        self.passing_tests.append(test)  # Track passing tests

def run_tests_1():
    """
    Function to run the unittest test cases manually, displaying only passing and failing tests.
    """
    suite = unittest.TestLoader().loadTestsFromTestCase(TestCustomerTransformation)
    combined_suite = unittest.TestSuite(suite)
    
    # Run the combined test suite with a custom runner
    runner = unittest.TextTestRunner(resultclass=CustomTestResult, verbosity=0)
    result = runner.run(combined_suite)
    
    # Display summary of passing and failing tests
    print("\nSummary:")

    print("\nFailing tests:")
    if result.failures or result.errors:
        for failed_test, _ in result.failures + result.errors:  # Only print the test names, not tracebacks
            print(f" - {failed_test}")
    else:
        print("No failed tests found")

# Run the tests manually
run_tests_1()

FAIL: test_distinct_customer_count (__main__.TestCustomerTransformation)
Test to check the number of distinct customers.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<command-4325684114186129>", line 28, in test_distinct_customer_count
    self.assertEqual(distinct_records, expected_records,
AssertionError: 4 != 5 : The total records in the register is 4, which is not the expected count of 5.

----------------------------------------------------------------------
Ran 2 tests in 0.826s

FAILED (failures=1)



Summary:

Failing tests:
 - test_distinct_customer_count (__main__.TestCustomerTransformation)
