In [0]:
# COMMAND ----------
# Install and import pytest
%pip install pytest

import pytest

# COMMAND ----------
# Run your ETL notebook in this cell so the tests below can reuse its functions and variables (e.g. log_audit, audit_log_table_full_name)

# COMMAND ----------
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType
from datetime import datetime

def test_log_audit_creates_record():
    """
    Test that log_audit writes to the audit log table.

    Calls log_audit for the table name "pytest_table", reads back the most
    recent audit row for that table name (ordered by ``timestamp``
    descending) and asserts its ``status``, ``initial_count`` and
    ``final_count`` values.
    """
    log_audit("pytest_table", "Started", 5, 10, "pytest run")

    df = (
        spark.table(audit_log_table_full_name)
        .filter(F.col("table_name") == "pytest_table")
        .orderBy(F.desc("timestamp"))
    )
    row = df.first()
    # first() returns None for an empty DataFrame; fail with a clear message
    # instead of a TypeError raised by subscripting None below.
    assert row is not None, "log_audit did not write a row for 'pytest_table'"
    assert row["status"] == "Started"
    assert row["initial_count"] == 5
    assert row["final_count"] == 10

def test_fillna_and_trim():
    """
    Test that nulls are filled, strings are trimmed and duplicate ids collapse.

    Builds a three-row DataFrame in which two rows share ``id`` 1, fills
    nulls with per-column defaults, trims every string column, drops
    duplicates on ``id`` and checks both the number of surviving rows and
    their cleaned values.
    """
    schema = StructType([
        StructField("id", IntegerType(), True),
        StructField("name", StringType(), True),
        StructField("category", StringType(), True),
        StructField("cost", DoubleType(), True),
    ])
    data = [
        (1, "  Alice  ", None, None),
        (2, None, " Electronics  ", 100.0),
        (1, " Alice ", "Unknown", 0.0),  # duplicate
    ]
    df = spark.createDataFrame(data, schema=schema)

    defaults = {"id": 0, "name": "Unknown", "category": "Unknown", "cost": 0.0}
    df_filled = df.fillna(defaults)

    string_columns = [
        field.name
        for field in df_filled.schema.fields
        if isinstance(field.dataType, StringType)
    ]
    for column_name in string_columns:
        df_filled = df_filled.withColumn(column_name, F.trim(F.col(column_name)))

    df_final = df_filled.dropDuplicates(["id"])

    rows = df_final.collect()
    # The dict below collapses repeated ids by itself, so the de-duplication
    # has to be checked on the raw row count or a broken dropDuplicates
    # would go unnoticed.
    assert len(rows) == 2

    result = {row["id"]: (row["name"], row["category"], row["cost"]) for row in rows}
    assert result[1] == ("Alice", "Unknown", 0.0)
    assert result[2] == ("Unknown", "Electronics", 100.0)

def test_timestamp_fill():
    """
    Verify that a null timestamp is replaced by a fallback value.

    A single-row DataFrame with a null ``created_at`` is built, the null is
    swapped for 2023-01-01 00:00:00 through a ``when``/``otherwise``
    expression, and the value read back is compared with that fallback.
    """
    fallback = datetime(2023, 1, 1)

    fields = [
        StructField("id", IntegerType(), True),
        StructField("created_at", TimestampType(), True),
    ]
    source_df = spark.createDataFrame([(1, None)], schema=StructType(fields))

    # Keep the existing value when present, otherwise use the fallback.
    created_at = F.col("created_at")
    patched = F.when(created_at.isNull(), F.lit(fallback)).otherwise(created_at)

    first_row = source_df.withColumn("created_at", patched).first()
    assert first_row["created_at"] == fallback

# COMMAND ----------
# Run pytest programmatically inside the notebook.
# NOTE(review): pytest collects tests from files on disk — confirm that the
# test functions defined in the cells above are actually discovered.
exit_code = pytest.main(["-q"])
# pytest.main() reports failures only through its return value, so without
# this check a failing run would leave the notebook looking successful.
# Any non-zero code is treated as a failure, including 5 (no tests collected).
if exit_code != 0:
    raise RuntimeError(f"pytest run failed with exit code {int(exit_code)}")
