<a href="https://colab.research.google.com/github/MAY2704/ML_usecases/blob/main/spark_examples/ETL_spark_tests_3_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import necessary functions
!pip install pyspark
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, cast, udf
from pyspark.sql.types import StringType, IntegerType

# Spark entry point used by every DataFrame operation below
spark = (
    SparkSession.builder
    .appName("DataFrameExample")
    .getOrCreate()
)

# ---- Raw data ingestion ----

# Source 1: party records (id, name, age), built in pandas and handed to Spark
pandas_df_src_party = pd.DataFrame(
    {
        "id": [1, 2, 3, 4, 5],
        "name": ["Alpha", "Beta", "Charlie", "Delta", "Echo"],
        "age": [45, 76, 30, 70, 26],
    }
)
spark_df_src_data_party = spark.createDataFrame(pandas_df_src_party)
spark_df_src_data_party.show()

# Source 2: address records (id, street), keyed by the same "id" values
pandas_df_src_address = pd.DataFrame(
    {
        "id": [1, 2, 3, 4, 5],
        "street": ["Barbastraat", "Michalstraat", "Parijstraat", "Tiensesstraat", "Dieststraat"],
    }
)
spark_df_src_data_address = spark.createDataFrame(pandas_df_src_address)
spark_df_src_data_address.show()

# ---- Sample ETL stage 1 ----

# Inner join of party and address data on the shared "id" column
joined_df = spark_df_src_data_party.join(spark_df_src_data_address, "id", "inner")
joined_df.show()

# Sample ETL stage 2 (UDF for age category)
@udf(returnType=StringType())
def get_age_category(age):
    """Classify an age as "Senior", "Junior" or "Medior".

    Args:
        age: Value of the "age" column for one row; ``None`` when the
            column is null for that row.

    Returns:
        "Senior" for ages of 60 and above, "Junior" for ages of 18 and
        below, "Medior" for everything in between, and ``None`` when the
        age itself is null.
    """
    # Spark hands SQL NULL to a Python UDF as None. Comparing None with an
    # int raises TypeError and fails the whole job, so propagate the null
    # and let downstream data-quality checks report it.
    if age is None:
        return None
    if age >= 60:
        return "Senior"
    if age <= 18:
        return "Junior"
    return "Medior"

# Build the UDF column expression from "age", then attach it as "age_category"
age_category_column = get_age_category(col("age"))
df_with_age_category = joined_df.withColumn("age_category", age_category_column)

# Print the enriched DataFrame
df_with_age_category.show()

# Now, let us test the UDF of ETL

# Test 1 = Given a fixed set of input data, the real output must match expected output

def test_age_category_logic():
    """Unit test for the UDF that derives "age_category" from "age".

    Builds a fixed input DataFrame, applies ``get_age_category`` and
    compares the result with a hand-written expected DataFrame. The rows
    cover both thresholds of the UDF (18/19 and 60).

    Raises:
        AssertionError: If the processed rows differ from the expected rows.
    """

    # GIVEN input test data, including the values around each threshold
    data = [
        (1, "Test1", 35),
        (2, "Test2", 72),
        (3, "Test3", 16),
        (4, "Test4", 28),
        (5, "Test5", 60),
        (6, "Test6", 18),
        (7, "Test7", 19),
    ]
    df = spark.createDataFrame(data, ["id", "name", "age"])

    # AND GIVEN the expected results (modify based on your logic)
    expected_data = [
        (1, "Test1", 35, "Medior"),
        (2, "Test2", 72, "Senior"),
        (3, "Test3", 16, "Junior"),
        (4, "Test4", 28, "Medior"),
        (5, "Test5", 60, "Senior"),
        (6, "Test6", 18, "Junior"),
        (7, "Test7", 19, "Medior"),
    ]

    # WHEN applying the UDF to create the "age_category" column
    df_with_category = df.withColumn("age_category", get_age_category(col("age")))
    df_with_category.show()

    # THEN the results must match the expectations
    expected_df = spark.createDataFrame(expected_data, ["id", "name", "age", "age_category"])
    expected_df.show()

    # Sort both sides by the unique "id" before collecting: a DataFrame has
    # no guaranteed row order, so comparing unsorted collect() results could
    # fail even though the data is correct.
    actual_rows = df_with_category.orderBy("id").collect()
    expected_rows = expected_df.orderBy("id").collect()
    assert actual_rows == expected_rows, "Processed test data does not match with expected results, Test 1 is failed"
    print("Processed test data match with expected results!, Test 1 is pass")


test_age_category_logic()

# Test 2 = Checking the source and target count
def test_source_target_count_match():
    """Check that the source and target DataFrames hold the same number of rows.

    The source is ``spark_df_src_data_party`` and the target is
    ``df_with_age_category``.

    Raises:
        AssertionError: If the two row counts differ.
    """
    # Source count is evaluated first, then the target count
    counts_match = spark_df_src_data_party.count() == df_with_age_category.count()

    assert counts_match, "Source and target DataFrame counts do not match, Test 2 is failed"
    print("Source and target DataFrame counts match!, Test 2 is pass")


test_source_target_count_match()

# Test 3 = Checking data quality in target

def test_age_category_not_null():
    """Data-quality check: no row of the target may have a null "age_category".

    Raises:
        AssertionError: If at least one row has a null "age_category".
    """
    # Count the rows of the target whose "age_category" is null
    null_row_count = df_with_age_category.where(col("age_category").isNull()).count()

    assert null_row_count == 0, "Target DataFrame contains null values in 'age_category' column, Test 3 is failed"
    print("'age_category' column in target DataFrame has no null values!, Test 3 is pass")


test_age_category_not_null()

# Stop the SparkSession now that all ETL stages and tests above have run
spark.stop()


Collecting pyspark
  Downloading pyspark-3.5.2.tar.gz (317.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.3/317.3 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.2-py2.py3-none-any.whl size=317812365 sha256=9fca88ef29116b6169d58f4c469b60f862ce7ff8a8a9334bf9d6d985536aac0b
  Stored in directory: /root/.cache/pip/wheels/34/34/bd/03944534c44b677cd5859f248090daa9fb27b3c8f8e5f49574
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.2
+---+-------+---+
| id|   name|age|
+---+-------+---+
|  1|  Alpha| 45|
|  2|   Beta| 76|
|  3|Charlie| 30|
|  4|  Delta| 70|
|  5|   Echo| 26|
+---+-------+---+

+---+-------------+
| id|       street|
+---+-------------+
|  1|  Barbastraat|
|  2| Michalstraat|
|  3|  Parij