In [None]:
import ipywidgets as widgets
from IPython.display import display

import re
import pandas as pd

In [None]:
def _text_input(label, default, label_width='100px', hint='Type something'):
    """Build an enabled single-line text widget with a fixed-width description label."""
    return widgets.Text(
        value=default,
        placeholder=hint,
        description=label,
        disabled=False,
        style={'description_width': label_width},
    )


# Free-text inputs: the user name plus the catalog/schema and git organisation to work with.
widget_user = _text_input('user: ', 'testuser')
widget_git_org = _text_input('git_org ', 'Nike-Inc')
widget_catalog = _text_input('catalog:', 'spark_catalog')
widget_schema = _text_input('schema:', 'default')

# Library source is limited to the listed options (ensure_option=True).
widget_library_source = widgets.Combobox(
    value='git',
    options=['pypi', 'git'],
    ensure_option=True,
    placeholder='Choose source',
    description='library_source:',
    disabled=False,
    style={'description_width': '100px'},
)

widget_git_branch_or_commit = _text_input(
    'git_branch_or_commit:',
    'main',
    label_width='150px',
    hint='Type branch name or commit hash',
)

widget_override_version = widgets.Checkbox(
    value=False,
    description='Override SE version',
    disabled=False,
    style={'description_width': '30px'},
)

# Single row holding every input, in the order it is shown to the user.
_input_widgets = [
    widget_user,
    widget_catalog,
    widget_schema,
    widget_override_version,
    widget_library_source,
    widget_git_org,
    widget_git_branch_or_commit,
]
hbox = widgets.HBox(_input_widgets)

In [None]:
# Render the row of input widgets; the next cell reads whatever values they hold
# at the moment it is executed.
display(hbox)

In [None]:
# Read the current widget values. user, catalog and schema are all interpolated into
# table names and SQL statements later on, so they are validated here.

# Keep ASCII letters only and lowercase them so the value is safe inside table names.
user = re.sub(r'[^a-zA-Z]', '', widget_user.value).lower()
if not user:
    # An empty user would collapse the per-user prefix to "se_", so the cleanup cells
    # further down would match tables and views that belong to other users.
    raise ValueError(
        f"user must contain at least one ASCII letter, got {widget_user.value!r}"
    )

# Stray whitespace around an identifier would break the SQL built from it.
catalog = widget_catalog.value.strip()
schema = widget_schema.value.strip()
if not catalog or not schema:
    raise ValueError("catalog and schema must not be empty")

override_se_version = widget_override_version.value
library = widget_library_source.value
org = widget_git_org.value
branch_or_commit = widget_git_branch_or_commit.value

# Echo the effective values so the run is self-documenting.
print(user)
print(catalog)
print(schema)
print(override_se_version)
print(library)
print(org)
print(branch_or_commit)

In [None]:
# Everything created by this notebook is named "se_<user>_<suffix>"; tables are
# additionally qualified with the selected catalog and schema.
_object_prefix = f"se_{user}"
_table_prefix = f"{catalog}.{schema}.{_object_prefix}"

# Central settings dictionary used by all later cells (insertion order is the display order).
CONFIG = dict(
    owner=user,
    catalog=catalog,
    schema=schema,
    user=user,
    product_id=f"{_object_prefix}_product",
    in_memory_source=f"{_object_prefix}_source",
    rules_table=f"{_table_prefix}_rules",
    stats_table=f"{_table_prefix}_stats",
    target_table=f"{_table_prefix}_target",
    override_se_version=override_se_version,
    library=library,
    org=org,
    branch_or_commit=branch_or_commit,
)

# Show the settings as a two-column Key/Value table.
_config_rows = list(CONFIG.items())
config_df = pd.DataFrame(_config_rows, columns=['Key', 'Value'])
display(config_df)

In [None]:
from importlib.metadata import version

# Report which spark-expectations distribution is installed in this kernel's environment.
installed_se_version = version('spark-expectations')
print(f"---- Current SparkExpectation Version: {installed_se_version}")

### Setting up Spark Expectations

In [None]:
# Spark session with Delta Lake support
from pyspark.sql import SparkSession

# Reuse the active session if there is one; otherwise start a new one with the Delta
# SQL extension, the Delta catalog as `spark_catalog`, and the delta-spark package.
spark = (
    SparkSession.builder
    .appName("Spark Aggregation Rules")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
    .getOrCreate()
)

In [None]:
# List the databases and the tables of the current database before any cleanup.
databases_df, tables_df = (
    spark.sql(statement) for statement in ("SHOW DATABASES", "SHOW TABLES")
)
for listing_df in (databases_df, tables_df):
    listing_df.show(truncate=False)


In [None]:
# Remove tables left over from a previous run of this notebook for the same user.
db_name = f"{CONFIG['catalog']}.{CONFIG['schema']}"

# Every object this notebook creates is named "se_<user>_<suffix>". The underscore
# before the wildcard matters: without it the pattern is a bare prefix match, so user
# "test" would also match (and drop) "se_testuser_rules" belonging to another user.
pattern = f"se_{CONFIG['user']}_*"

# Set the current catalog
spark.sql(f"USE {CONFIG['catalog']}")

# SHOW TABLES also lists temporary views; those are skipped here because they are
# not tables and cannot be addressed as <db_name>.<name>.
tables_df = spark.sql(f"SHOW TABLES IN {db_name} LIKE '{pattern}'")
tables_to_drop = [row for row in tables_df.collect() if not row["isTemporary"]]

if tables_to_drop:
    print(f"Found {len(tables_to_drop)} tables to drop.")
    for row in tables_to_drop:
        table_name = row["tableName"]
        spark.sql(f"DROP TABLE IF EXISTS {db_name}.{table_name}")
        print(f"Dropped table: {db_name}.{table_name}")
else:
    print("----- No tables to drop")

In [None]:
# Remove views left over from a previous run (reuses `db_name` and `pattern`
# defined in the table-cleanup cell).
views_df = spark.sql(f"SHOW VIEWS in {db_name} LIKE '{pattern}'")
views_to_drop = views_df.collect()

if views_to_drop:
    print(f"Found {len(views_to_drop)} views to drop.")
    for row in views_to_drop:
        # Temporary views are session-scoped and must be dropped by bare name.
        # Permanent views live in `db_name`; qualify them so the drop does not
        # silently miss when `db_name` is not the session's current database.
        if row["isTemporary"]:
            view_name = row["viewName"]
        else:
            view_name = f"{db_name}.{row['viewName']}"
        spark.sql(f"DROP VIEW IF EXISTS {view_name}")
        print(f"Dropped view: {view_name}")
else:
    print("----- No views to drop")

### Defining Spark Expectations rules

In [None]:
import pandas as pd
from pyspark.sql import SparkSession

In [None]:
# The three aggregate rules differ only in the fields listed here; everything else
# (product, table, rule type, enable flags, thresholds) is shared.
_agg_rule_variants = [
    # (rule, column_name, expectation, action_if_failed, tag, description)
    ("data_existing", "sales", "count(*) > 0", "fail",
     "completeness", "Data should be present"),
    ("min_sales", "sales", "min(sales)>1000", "warn",
     "validity", "Minimum sales should be greater than 1000"),
    ("no_duplicates", "id", "count(distinct id) == count(*)", "fail",
     "uniqueness", "Each data point should have a unique id"),
]

# One dict per rule; key order is kept stable because it becomes the column order
# of the rules DataFrame built from this list.
rules_data = [
    {
        "product_id": CONFIG["product_id"],
        "table_name": CONFIG["target_table"],
        "rule_type": "agg_dq",
        "rule": rule_name,
        "column_name": column_name,
        "expectation": expectation,
        "action_if_failed": action_if_failed,
        "tag": tag,
        "description": description,
        "enable_for_source_dq_validation": True,
        "enable_for_target_dq_validation": True,
        "is_active": True,
        "enable_error_drop_alert": False,
        "error_drop_threshold": 0,
    }
    for rule_name, column_name, expectation, action_if_failed, tag, description
    in _agg_rule_variants
]

In [None]:
# Go through pandas so the list of rule dicts becomes a Spark DataFrame
# with one column per dict key.
rules_pdf = pd.DataFrame(rules_data)
rules_df = spark.createDataFrame(rules_pdf)

In [None]:
# Persist the rules as a Delta table, replacing any previous contents.
(
    rules_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(CONFIG['rules_table'])
)

In [None]:
# Display the rules as a pandas DataFrame (the cell's last expression is rendered).
rules_df.toPandas()

### Running Spark Expectations

In [None]:
from pyspark.sql import DataFrame

from spark_expectations.core import load_configurations
from spark_expectations.config.user_config import Constants as user_config

from spark_expectations.core.expectations import (
    SparkExpectations,
    WrappedDataFrameWriter,
)

In [None]:
# Writer settings reused for every table written by the run below:
# overwrite mode, Delta format.
writer = (
    WrappedDataFrameWriter()
    .mode("overwrite")
    .format("delta")
)

In [None]:
# Pass the active Spark session to spark_expectations' load_configurations.
# NOTE(review): presumably this applies the library's default settings to the
# session — confirm against the spark_expectations documentation.
load_configurations(spark) 

In [None]:
# Per-run user configuration; left empty here (this is where settings such as
# slack/email notifications would be supplied).
notification_conf = {}

# Options for the statistics stream: streaming is switched off.
stats_streaming_config_dict = {user_config.se_enable_streaming: False}

In [None]:
# Build the SparkExpectations instance for this user's product: rules come from
# `rules_df`, statistics go to the per-user stats table, and the same Delta
# overwrite writer is used for the stats table and the target/error tables.
se = SparkExpectations(
    product_id=CONFIG["product_id"],
    rules_df=rules_df,
    stats_table=CONFIG["stats_table"],
    stats_table_writer=writer,
    target_and_error_table_writer=writer,
    stats_streaming_options=stats_streaming_config_dict,
)

In [None]:
# Sample input. Some rows are there to exercise the rules defined above:
# id 4 appears twice (Mike and Bryan), Ron has no sales value, and James'
# sales (900) are below 1000.
_input_columns = ("id", "name", "sales")
_input_records = [
    (1, "Alice", 3500),
    (2, "Bob", 2800),
    (3, "Charlie", 4200),
    (4, "Mike", 2300),
    (5, "Ron", None),
    (6, "Zach", 3900),
    (7, "Alex", 4100),
    (8, "Steve", 4200),
    (9, "James", 900),
    (10, "Dan", 2500),
    (4, "Bryan", 1600),
]
data = [dict(zip(_input_columns, record)) for record in _input_records]

# Convert via pandas to a Spark DataFrame and print it in full.
input_pdf = pd.DataFrame(data)
input_df = spark.createDataFrame(input_pdf)
input_df.show(truncate=False)

In [None]:
@se.with_expectations(
    target_table=CONFIG["target_table"],
    write_to_table=True,
    write_to_temp_table=True,
    user_conf=notification_conf,
)
def get_dataset():
    """Return the sample input DataFrame for the decorated expectations run.

    As a side effect the DataFrame is registered as a temporary view named
    CONFIG["in_memory_source"], so later cells can query the raw input with SQL.
    """
    source_df: DataFrame = input_df
    source_df.createOrReplaceTempView(CONFIG["in_memory_source"])
    return source_df

In [None]:
# Call the decorated function.
# NOTE(review): presumably this is what runs the rule checks and writes the
# target/stats tables configured on the decorator — confirm against the
# spark_expectations documentation.
get_dataset()

In [None]:
# Load the whole stats table into pandas and display it.
stats_table_name = CONFIG['stats_table']
query_stats_table = f"SELECT * FROM {stats_table_name}"
stats_rows_df = spark.sql(query_stats_table)
query_stats_table_df = stats_rows_df.toPandas()
query_stats_table_df

In [None]:
# Rule failures are captured here: show the `source_agg_dq_results` value of the
# stats row with index label 0 (the first row returned by the query above).
query_stats_table_df.loc[0,'source_agg_dq_results']

In [None]:
# List the databases and the tables of the current database again, after the run.
databases_df, tables_df = (
    spark.sql(statement) for statement in ("SHOW DATABASES", "SHOW TABLES")
)
for listing_df in (databases_df, tables_df):
    listing_df.show(truncate=False)


In [None]:
# Read back everything written to the target table, ordered by run id and then id
# so the output is stable and rows of the same run stay together.
query_target_table = f"""
SELECT *
FROM {CONFIG['target_table']} 
ORDER BY meta_dq_run_id, id
"""

# spark.sql either returns a DataFrame or raises, so no None check is needed.
final_data_set_df = spark.sql(query_target_table)
final_data_set_df.show(truncate=False)

In [None]:
def _row_count(relation_name):
    """Return the number of rows in the given table or view."""
    return spark.sql(f"SELECT COUNT(*) AS count FROM {relation_name}").collect()[0]['count']


# Row counts of the raw input (temp view) and of the written target table.
input_count = _row_count(CONFIG['in_memory_source'])
output_count = _row_count(CONFIG['target_table'])

# Input rows whose id has no match in the target table.
removed_rows_df = spark.sql(f"""
SELECT s.*
FROM {CONFIG['in_memory_source']} s
LEFT ANTI JOIN {CONFIG['target_table']} t
ON s.id = t.id
""")
removed_rows_count = removed_rows_df.count()

# Summarise the three counts in one small table.
_summary_rows = [
    ("input", input_count),
    ("output", output_count),
    ("removed_records", removed_rows_count),
]
comparison_df = spark.createDataFrame(_summary_rows, ["table", "record_count"])
comparison_df.show()

# Only list the removed rows when there are any.
if removed_rows_count:
    removed_rows_df.show(truncate=False)