In [0]:
%pip install pytest


Python interpreter will be restarted.
Collecting pytest
  Downloading pytest-8.4.0-py3-none-any.whl (363 kB)
Collecting exceptiongroup>=1
  Downloading exceptiongroup-1.3.0-py3-none-any.whl (16 kB)
Collecting iniconfig>=1
  Downloading iniconfig-2.1.0-py3-none-any.whl (6.0 kB)
Collecting pluggy<2,>=1.5
  Downloading pluggy-1.6.0-py3-none-any.whl (20 kB)
Collecting typing-extensions>=4.6.0
  Downloading typing_extensions-4.14.0-py3-none-any.whl (43 kB)
Installing collected packages: typing-extensions, pluggy, iniconfig, exceptiongroup, pytest
  Attempting uninstall: typing-extensions
    Found existing installation: typing-extensions 4.1.1
    Not uninstalling typing-extensions at /databricks/python3/lib/python3.9/site-packages, outside environment /local_disk0/.ephemeral_nfs/envs/pythonEnv-b22b6341-1f4d-4a58-8c13-65464120b8e3
    Can't uninstall 'typing-extensions'. No files were found to uninstall.
  Attempting uninstall: pluggy
    Found existing installation: pluggy 1.0.0
    Not un

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import round, col
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType
import pytest


In [0]:
@pytest.fixture(scope="module")
def spark():
    """Module-scoped fixture yielding a SparkSession for the tests.

    The session is obtained with ``getOrCreate()`` using master ``local[*]``
    and application name ``test``. It is stopped when the fixture is torn down.
    """
    builder = SparkSession.builder.master("local[*]").appName("test")
    session = builder.getOrCreate()
    yield session
    # Teardown: runs after the last test in the module has used the fixture.
    session.stop()


In [0]:
# Sample schemas for the three input tables used by the enrichment test.
# Each column is added with the default nullability.
orders_schema = (
    StructType()
    .add("order_id", IntegerType())
    .add("customer_id", IntegerType())
    .add("product_id", IntegerType())
    .add("profit", DoubleType())
)

customers_schema = (
    StructType()
    .add("customer_id", IntegerType())
    .add("customer_name", StringType())
    .add("country", StringType())
)

products_schema = (
    StructType()
    .add("product_id", IntegerType())
    .add("category", StringType())
    .add("sub_category", StringType())
)


In [0]:
def test_enriched_table(spark):
    """Check that orders are enriched with customer and product attributes.

    Builds small in-memory orders, customers and products DataFrames, left-joins
    orders to customers on ``customer_id`` and to products on ``product_id``,
    and adds ``profit_rounded`` (profit rounded to 2 decimal places). The
    enriched rows are then compared against the expected values.

    Args:
        spark: SparkSession supplied by the ``spark`` fixture.
    """
    # Sample data
    orders_data = [
        (1, 101, 1001, 123.4567),
        (2, 102, 1002, 78.9),
    ]
    customers_data = [
        (101, "Alice", "USA"),
        (102, "Bob", "UK"),
    ]
    products_data = [
        (1001, "Electronics", "Phones"),
        (1002, "Furniture", "Chairs"),
    ]

    # Sample DataFrames
    orders_df = spark.createDataFrame(orders_data, orders_schema)
    customers_df = spark.createDataFrame(customers_data, customers_schema)
    products_df = spark.createDataFrame(products_data, products_schema)

    # Enrich orders: customers first, then products, then the rounded profit.
    # `round` here is pyspark.sql.functions.round, not the Python builtin.
    enriched_df = (
        orders_df
        .join(customers_df, "customer_id", "left")
        .join(products_df, "product_id", "left")
        .withColumn("profit_rounded", round(col("profit"), 2))
    )

    compared_fields = ["order_id", "customer_name", "country", "category", "sub_category"]

    # Joins do not guarantee row order, so sort before comparing positionally.
    result = (
        enriched_df
        .select(*compared_fields, "profit_rounded")
        .orderBy("order_id")
        .collect()
    )

    expected = [
        (1, "Alice", "USA", "Electronics", "Phones", 123.46),
        (2, "Bob", "UK", "Furniture", "Chairs", 78.90),
    ]

    # zip() stops at the shorter input, so an empty or short result would
    # otherwise pass without any row being checked.
    assert len(result) == len(expected), (
        f"expected {len(expected)} rows, got {len(result)}"
    )

    # Must be much smaller than 0.01: a looser tolerance would accept the
    # unrounded profit (123.4567 is within 0.01 of 123.46).
    tolerance = 1e-9

    for row, exp in zip(result, expected):
        for field, want in zip(compared_fields, exp):
            assert row[field] == want, f"{field}: expected {want!r}, got {row[field]!r}"
        assert abs(row.profit_rounded - exp[5]) < tolerance, (
            f"profit_rounded: expected {exp[5]!r}, got {row.profit_rounded!r}"
        )
