In [19]:
# test data setup

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType, DoubleType
import yaml
from datetime import date, timedelta

# Initialize Spark session (getOrCreate hands back the already-active session if there is one)
spark = (
    SparkSession.builder
    .appName("DataValidationTest")
    .getOrCreate()
)

# Schema for the synthetic sales table; all five columns are declared non-nullable.
schema = StructType([
    StructField(field_name, field_type, False)
    for field_name, field_type in [
        ("order_id", StringType()),
        ("order_date", DateType()),
        ("product_category", StringType()),
        ("quantity", IntegerType()),
        ("revenue", DoubleType()),
    ]
])

# Generate test data
def generate_test_data(num_records=1000, start_date=date(2024, 1, 1)):
    """Build a deterministic synthetic sales DataFrame.

    Row ``i`` (0-based) gets:
      * order_id          ``ORD-{i:04d}``
      * order_date        ``start_date + (i % 365)`` days
      * product_category  cycling through Electronics, Clothing, Books, Home
      * quantity          ``(i % 10) + 1``  (1..10)
      * revenue           ``quantity * (100.00 + i % 100)``

    Args:
        num_records: number of rows to generate (default 1000, the size this
            notebook previously hard-coded).
        start_date: date of the first order (default 2024-01-01).

    Returns:
        A Spark DataFrame built with the notebook-level ``spark`` session and
        ``schema`` defined in this cell.
    """
    categories = ["Electronics", "Clothing", "Books", "Home"]
    data = []
    for i in range(num_records):
        quantity = (i % 10) + 1
        data.append((
            f"ORD-{i:04d}",
            start_date + timedelta(days=i % 365),
            categories[i % len(categories)],
            quantity,
            quantity * (100.00 + (i % 100)),
        ))
    # Passing the explicit schema also lets num_records == 0 produce an empty frame
    # instead of failing schema inference.
    return spark.createDataFrame(data, schema)

# Materialise the synthetic rows and expose them to Spark SQL as the
# `sales_table` temp view that the validation rules query.
test_df = generate_test_data()
test_df.createOrReplaceTempView(name="sales_table")

In [20]:
# Preview the first 10 rows without truncating column values.
spark.sql("select * from sales_table").show(n=10, truncate=False)

+--------+----------+----------------+--------+-------+
|order_id|order_date|product_category|quantity|revenue|
+--------+----------+----------------+--------+-------+
|ORD-0000|2024-01-01|Electronics     |1       |100.0  |
|ORD-0001|2024-01-02|Clothing        |2       |202.0  |
|ORD-0002|2024-01-03|Books           |3       |306.0  |
|ORD-0003|2024-01-04|Home            |4       |412.0  |
|ORD-0004|2024-01-05|Electronics     |5       |520.0  |
|ORD-0005|2024-01-06|Clothing        |6       |630.0  |
|ORD-0006|2024-01-07|Books           |7       |742.0  |
|ORD-0007|2024-01-08|Home            |8       |856.0  |
|ORD-0008|2024-01-09|Electronics     |9       |972.0  |
|ORD-0009|2024-01-10|Clothing        |10      |1090.0 |
+--------+----------+----------------+--------+-------+
only showing top 10 rows



                                                                                

In [17]:
# Stopping the session here would break every later cell that reuses `spark`
# (e.g. validate_rules) under Restart & Run All, because this cell sits in the
# middle of the notebook. Set this to True only once all validation work is done.
STOP_SPARK_SESSION = False
if STOP_SPARK_SESSION:
    spark.stop()

In [9]:
import os
import sys

# Make the local `sparkdq` checkout importable. Set the SPARKDQ_HOME environment
# variable to point elsewhere instead of editing a machine-specific path here.
SPARKDQ_HOME = os.environ.get(
    "SPARKDQ_HOME", "/Users/sthitaprajnasahoo/local_learning/sparkdq"
)
# Guard so re-running this cell does not push duplicate entries onto sys.path.
if SPARKDQ_HOME not in sys.path:
    sys.path.insert(0, SPARKDQ_HOME)

In [29]:
def load_config(file_path):
    """Load a YAML validation config from ``file_path``.

    Uses ``yaml.safe_load``, so the file cannot construct arbitrary Python
    objects. The file is decoded as UTF-8 explicitly instead of relying on the
    platform's locale-dependent default encoding.

    Returns:
        The parsed YAML document. For the rules file used in this notebook,
        that is a dict with a ``tables`` list. An empty file yields ``None``.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

def execute_sql_rule(spark, sql_expression, variables, verbose=True):
    """Render a templated SQL rule with ``variables`` and run it.

    Args:
        spark: session used to run the query (anything exposing ``.sql``).
        sql_expression: SQL text with ``str.format`` placeholders such as
            ``{start_date}``.
        variables: mapping of placeholder name to substituted text. Values are
            spliced in verbatim (no quoting or escaping), so they must come
            from a trusted source such as the notebook's own config, never
            from user input.
        verbose: when True (the default, matching previous behaviour), print
            the rendered SQL and preview the result with ``DataFrame.show()``.

    Returns:
        The DataFrame produced by the query.
    """
    formatted_sql = sql_expression.format(**variables)
    if verbose:
        print(formatted_sql)
    df = spark.sql(formatted_sql)
    if verbose:
        # show() prints the table itself and returns None. Wrapping it in print()
        # emitted a stray "None" line after every preview.
        df.show()
    return df

def compare_result(result_df, expected_result):
    """Check one expectation against the first row of ``result_df``.

    ``expected_result`` is a mapping with keys:
        column:   name of the column in ``result_df`` to check.
        operator: one of '>', '>=', '<', '<=', '==', '!=', 'between'.
        value:    the bound to compare against. For 'between' it is a
                  two-element sequence ``[low, high]``, inclusive on both ends.

    Returns:
        True if the first row satisfies the expectation, False otherwise.
        A NULL metric (e.g. MAX over no matching rows) or an empty
        ``result_df`` counts as a failure instead of returning None.

    Raises:
        ValueError: if ``operator`` is not supported. Previously an unknown
            operator fell through and silently returned None.
    """
    column = expected_result['column']
    operator = expected_result['operator']
    value = expected_result['value']

    # Build the check as a Spark Column expression so Spark applies its own
    # type coercion between the column and the configured bound.
    comparisons = {
        '>': lambda c: c > value,
        '>=': lambda c: c >= value,
        '<': lambda c: c < value,
        '<=': lambda c: c <= value,
        '==': lambda c: c == value,
        '!=': lambda c: c != value,
        'between': lambda c: (c >= value[0]) & (c <= value[1]),
    }
    if operator not in comparisons:
        raise ValueError(
            f"Unsupported operator {operator!r}; expected one of {sorted(comparisons)}"
        )

    row = result_df.select(comparisons[operator](col(column))).first()
    # first() yields None for an empty DataFrame; comparing against NULL yields NULL.
    if row is None or row[0] is None:
        return False
    return bool(row[0])

def validate_rules(spark, config, variables):
    """Run every rule of every table in ``config`` and collect one record per rule.

    Tables are taken from ``config['tables']`` and rules from each table's
    ``rules`` list, both in file order. Each rule's SQL template is rendered
    with ``variables`` and executed via ``execute_sql_rule``. Its
    ``expected_result`` is then checked via ``compare_result``.

    Returns:
        A list of dicts with keys ``table_name``, ``rule_id``,
        ``sql_expression`` (the unrendered template), ``expected_result`` and
        ``is_valid``.
    """
    def _evaluate(table, rule):
        # Execute, then compare, then build the record.
        rule_df = execute_sql_rule(spark, rule['sql_expression'], variables)
        passed = compare_result(rule_df, rule['expected_result'])
        return {
            'table_name': table['name'],
            'rule_id': rule['rule_id'],
            'sql_expression': rule['sql_expression'],
            'expected_result': rule['expected_result'],
            'is_valid': passed,
        }

    return [
        _evaluate(table, rule)
        for table in config['tables']
        for rule in table['rules']
    ]

def store_results(results, output_path="/path/to/validation_results",
                  spark_session=None, fmt="delta"):
    """Persist validation records as a table, overwriting any previous run.

    ``expected_result`` is stored as a JSON string. Its ``value`` entry is a
    scalar for some rules and a list for others, so Spark's schema inference
    over the raw nested dicts fails. An explicit schema is used so that an
    empty ``results`` list writes an empty table instead of failing inference.

    Args:
        results: list of dicts as produced by ``validate_rules``.
        output_path: destination path. The default is the original placeholder
            and should be overridden for real use.
        spark_session: session used to build the DataFrame. Defaults to the
            notebook's global ``spark``.
        fmt: writer format (default ``"delta"``, which requires Delta Lake to be
            available on the session).

    Returns:
        The DataFrame that was written.
    """
    import json

    session = spark_session if spark_session is not None else spark
    schema_ddl = (
        "table_name string, rule_id string, sql_expression string, "
        "expected_result string, is_valid boolean"
    )
    rows = [
        (
            record['table_name'],
            record['rule_id'],
            record['sql_expression'],
            # default=str keeps non-JSON values (e.g. dates parsed from YAML) serialisable.
            json.dumps(record['expected_result'], default=str),
            record['is_valid'],
        )
        for record in results
    ]
    results_df = session.createDataFrame(rows, schema_ddl)
    results_df.write.format(fmt).mode("overwrite").save(output_path)
    return results_df

In [37]:
config = load_config('config/validation_rules.yaml')

# Template values substituted into each rule's SQL via str.format.
# `categories` is spliced in as-is, so it must already be a SQL tuple literal.
variables = dict(
    start_date="2024-01-01",
    categories="('Electronics', 'Clothing')",
)

In [38]:
# Display the parsed rules config: a `tables` list whose entries carry a name and
# a list of rules (rule_id, sql_expression template, expected_result).
config

{'tables': [{'name': 'sales_table',
   'rules': [{'rule_id': 'rule1',
     'sql_expression': "SELECT MAX(revenue) as max_revenue FROM sales_table WHERE order_date > '{start_date}'",
     'expected_result': {'column': 'max_revenue',
      'operator': '>',
      'value': 10000}},
    {'rule_id': 'rule2',
     'sql_expression': 'SELECT AVG(quantity) as avg_quantity FROM sales_table WHERE product_category IN {categories}',
     'expected_result': {'column': 'avg_quantity',
      'operator': 'between',
      'value': [5, 100]}}]}]}

In [39]:
# Run every configured rule. The rendered SQL and result previews below are
# printed by execute_sql_rule as each rule executes.
validation_results = validate_rules(spark=spark, config=config, variables=variables)

SELECT MAX(revenue) as max_revenue FROM sales_table WHERE order_date > '2024-01-01'
+-----------+
|max_revenue|
+-----------+
|     1990.0|
+-----------+

None
SELECT AVG(quantity) as avg_quantity FROM sales_table WHERE product_category IN ('Electronics', 'Clothing')
+------------+
|avg_quantity|
+------------+
|         5.5|
+------------+

None


In [40]:
# One record per rule. is_valid=False marks a rule whose expected_result was not met.
validation_results

[{'table_name': 'sales_table',
  'rule_id': 'rule1',
  'sql_expression': "SELECT MAX(revenue) as max_revenue FROM sales_table WHERE order_date > '{start_date}'",
  'expected_result': {'column': 'max_revenue',
   'operator': '>',
   'value': 10000},
  'is_valid': False},
 {'table_name': 'sales_table',
  'rule_id': 'rule2',
  'sql_expression': 'SELECT AVG(quantity) as avg_quantity FROM sales_table WHERE product_category IN {categories}',
  'expected_result': {'column': 'avg_quantity',
   'operator': 'between',
   'value': [5, 100]},
  'is_valid': True}]