# In [0]:
import math

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Obtain the Spark session explicitly. getOrCreate() returns the session that
# is already active when the environment provides one (e.g. a notebook) and
# creates a new one otherwise, so this script no longer depends on a
# pre-injected global named `spark`.
spark = SparkSession.builder.getOrCreate()

# Sample rows for the SKU table: (sku_id, product_name, quantity, price).
sku_data = [
    ('sku_001', 'Product A', 10, 5.0),
    ('sku_002', 'Product B', 20, 3.0),
    ('sku_003', 'Product C', 15, 7.0),
]

# Column names, listed in the same order as the tuple fields above.
sku_columns = ['sku_id', 'product_name', 'quantity', 'price']

# Build the SKU DataFrame. Only column names are supplied, so Spark infers
# the column types from the sample values.
sku_df = spark.createDataFrame(sku_data, schema=sku_columns)

# Print the table so the fixture can be inspected in the run output.
sku_df.show()

# Revenue fixture: column names first, then one (sku_id, revenue) row for
# each of the three sample SKUs.
revenue_columns = ["sku_id", "revenue"]
revenue_data = [
    ("sku_001", 50.0),
    ("sku_002", 60.0),
    ("sku_003", 105.0),
]

# Turn the fixture into a DataFrame and print it for inspection.
revenue_df = spark.createDataFrame(revenue_data, schema=revenue_columns)
revenue_df.show()

# Regression test: the summed `revenue` column of revenue_df must agree with
# the summed quantity * price computed from sku_df.

# Tolerances for comparing the two totals. Both are floating-point sums, so an
# exact == comparison can fail on rounding noise alone; the absolute tolerance
# additionally covers the case where both totals are (near) zero.
REVENUE_REL_TOL = 1e-9
REVENUE_ABS_TOL = 1e-6

# Total recorded revenue. agg() yields a single row with a single column,
# hence the [0][0] indexing on the collected result.
total_revenue_from_revenue_table = revenue_df.agg({'revenue': 'sum'}).collect()[0][0]

# Total expected revenue derived from the SKU table (quantity * price per row).
calculated_revenue = (
    sku_df.withColumn('calculated_revenue', col('quantity') * col('price'))
    .agg({'calculated_revenue': 'sum'})
    .collect()[0][0]
)

print(f"Total revenue from revenue table: {total_revenue_from_revenue_table}")
print(f"Calculated revenue based on SKU table: {calculated_revenue}")

# A sum over no (non-null) rows comes back as None. Two missing totals must
# not count as "consistent", so a missing total is treated as a failure
# instead of being compared.
totals_available = (
    total_revenue_from_revenue_table is not None
    and calculated_revenue is not None
)
revenue_is_consistent = totals_available and math.isclose(
    total_revenue_from_revenue_table,
    calculated_revenue,
    rel_tol=REVENUE_REL_TOL,
    abs_tol=REVENUE_ABS_TOL,
)

if revenue_is_consistent:
    print("Regression test passed: Revenue is consistent")
else:
    print(f"Regression test failed: {total_revenue_from_revenue_table} != {calculated_revenue}")
    # Raise so that a failed check aborts the run instead of only logging.
    raise Exception("Regression test failed")
