In [0]:
from pyspark.sql import functions as F
from pyspark.sql import types as T
import json

# ---------- 1. Get pr_id parameter ----------
# Declare the "pr_id" text widget (default "local_dev" for manual runs) and
# read its current value. Surrounding whitespace is dropped so that a padded
# value does not end up inside the database name built from it.
dbutils.widgets.text("pr_id", "local_dev")  # default for manual testing
pr_id = (dbutils.widgets.get("pr_id") or "").strip()

# Validate with an explicit raise instead of `assert`: assertions are removed
# when Python runs with -O, and a whitespace-only value is truthy, so it would
# pass a plain `assert pr_id`.
if not pr_id:
    raise ValueError("pr_id is required")

print(f"Running assertions for pr_id = '{pr_id}'")

# ---------- 2. Build clean DB name ----------
# "prod" maps to the unprefixed "clean" database; any other pr_id gets its
# own "<pr_id>_clean" database.
clean_db_name = "clean" if pr_id == "prod" else f"{pr_id}_clean"

print(f"Using clean DB = {clean_db_name}")

# ---------- 3. Load expected output from JSON file in the repo ----------

expected_path = "/Workspace/Repos/radomir@elfak.rs/de-lab-databricks/tests/expected/orders_expected.json"
print(f"Reading expected data from (Python file read): {expected_path}")

# Decode the fixture explicitly as UTF-8 rather than with the locale default,
# so non-ASCII values in the JSON are read the same way on every cluster.
with open(expected_path, "r", encoding="utf-8") as f:
    expected_data = json.load(f)

print("Raw Python expected data loaded from JSON:")
print(expected_data)

# Explicit schema for the expected rows; all four columns are non-nullable.
# NOTE(review): JSON integers (e.g. 100) load as Python int. Confirm the
# fixture writes `amount` with a decimal point, since DoubleType schema
# verification may reject int values.
expected_schema = T.StructType([
    T.StructField("order_id", T.StringType(), False),
    T.StructField("customer_id", T.StringType(), False),
    T.StructField("amount", T.DoubleType(), False),
    T.StructField("order_year", T.IntegerType(), False),
])

df_expected = spark.createDataFrame(expected_data, expected_schema)

print("Expected data DataFrame:")
display(df_expected)

# ---------- 4. Load actual data from clean table ----------
clean_table = f"{clean_db_name}.orders_enriched"
print(f"Reading actual data from clean table: {clean_table}")

df_actual_raw = spark.table(clean_table)

# Keep only the columns that the expected data set defines, in the same
# order, so both sides of the comparison line up.
compared_columns = ["order_id", "customer_id", "amount", "order_year"]
df_actual = df_actual_raw.select(*compared_columns)

print("Actual data (subset of columns):")
display(df_actual)

# ---------- 5. Compare expected vs actual ----------
# Compare as multisets in both directions. exceptAll() keeps duplicates, so a
# row that occurs a different number of times on each side is reported.
# subtract() is EXCEPT DISTINCT and would hide duplicated (or missing
# duplicate) rows in the clean table, letting the check pass incorrectly.
diff1 = df_actual.exceptAll(df_expected)
diff2 = df_expected.exceptAll(df_actual)

count_diff1 = diff1.count()
count_diff2 = diff2.count()

print(f"Rows in actual but not in expected: {count_diff1}")
print(f"Rows in expected but not in actual: {count_diff2}")

if count_diff1 != 0 or count_diff2 != 0:
    # Show the offending rows before failing so the run output is actionable.
    print("Differences found between actual and expected ❌")
    print("Rows in actual but not in expected:")
    display(diff1)
    print("Rows in expected but not in actual:")
    display(diff2)
    raise AssertionError("Actual clean table does not match expected table.")
else:
    print("Actual data matches expected data ✅")
