# In [None]:
import itertools
import logging
from datetime import datetime

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StringType, DateType
from pyspark.sql.window import Window

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Unity Catalog tables read and written by this pipeline.
TRANSACTIONS_TABLE = "genai_demo.sas.transactions"
CUSTOMERS_TABLE = "genai_demo.sas.customers"
OUTPUT_TABLE = "genai_demo.sas.enhanced_final_data"

# Numeric columns included in the correlation report.
CORRELATION_COLUMNS = ["Sales", "Discount", "Quantity", "TotalValue"]

try:
    # Load data from Unity Catalog tables.
    # `spark` is the SparkSession supplied by the notebook/runtime (not defined here).
    transactions_df = spark.table(TRANSACTIONS_TABLE)
    customers_df = spark.table(CUSTOMERS_TABLE)

    # Step 1: Keep only transactions with a positive Sales value and a non-null Product.
    valid_txns_df = transactions_df.filter((F.col("Sales") > 0) & (F.col("Product").isNotNull()))

    # Step 2: Effective price after discount; Discount is treated as a percentage (divided by 100).
    trans_step2_df = valid_txns_df.withColumn("EffectivePrice", F.col("Sales") * (1 - F.col("Discount") / 100))

    # Step 3: Total value = EffectivePrice * Quantity.
    trans_step3_df = trans_step2_df.withColumn("TotalValue", F.col("EffectivePrice") * F.col("Quantity"))

    # Step 4: Attach customer attributes; the left join keeps transactions with no matching customer.
    full_data_df = trans_step3_df.join(customers_df, on="CustomerID", how="left")

    # Step 5: Days from JoinDate to TransDate. Null when either date is null,
    # e.g. for transactions whose CustomerID found no match in the left join.
    full_data_df = full_data_df.withColumn("TenureDays", F.datediff(F.col("TransDate"), F.col("JoinDate")))

    # Step 6: Assign Tenure Category.
    # A null TenureDays fails every `<` comparison, so without the explicit first
    # branch it would fall through to otherwise() and be labelled "Loyal".
    # NOTE(review): if this mirrors a SAS program, SAS orders missing values below
    # every number and would label these rows "New" — confirm the desired value.
    trans_step6_df = full_data_df.withColumn(
        "TenureCategory",
        F.when(F.col("TenureDays").isNull(), F.lit(None).cast("string"))
        .when(F.col("TenureDays") < 180, "New")
        .when(F.col("TenureDays") < 365, "Medium")
        .otherwise("Loyal"),
    )

    # Step 7: Flag High Value Transactions (TotalValue strictly above 2000).
    trans_step7_df = trans_step6_df.withColumn("HighValueFlag", F.col("TotalValue") > 2000)

    # Step 8: Products "A" and "C" are "Core"; everything else (including null) is "Non-Core".
    trans_step8_df = trans_step7_df.withColumn(
        "ProductGroup",
        F.when(F.col("Product").isin(["A", "C"]), "Core").otherwise("Non-Core"),
    )

    # Step 9: Sort by Product Group
    sorted_final_data_df = trans_step8_df.orderBy("ProductGroup")

    # Step 10: Standardize TotalValue as a z-score within each ProductGroup.
    # The guard yields null when the group's sample stddev is null or zero; an
    # unguarded division by zero raises when spark.sql.ansi.enabled is true.
    product_group_window = Window.partitionBy("ProductGroup")
    group_mean = F.mean("TotalValue").over(product_group_window)
    group_stddev = F.stddev("TotalValue").over(product_group_window)
    standardized_df = sorted_final_data_df.withColumn(
        "StandardizedTotalValue",
        F.when(group_stddev > 0, (F.col("TotalValue") - group_mean) / group_stddev),
    )

    # Step 11: Flag Outliers (|z| > 2).
    enhanced_final_data_df = standardized_df.withColumn("OutlierFlag", F.abs(F.col("StandardizedTotalValue")) > 2)

    # Write final data to Unity Catalog table
    enhanced_final_data_df.write.format("delta").mode("overwrite").saveAsTable(OUTPUT_TABLE)

    # Generate reports from the persisted table, so each report action scans the
    # saved output instead of re-running the whole transformation lineage.
    report_df = spark.table(OUTPUT_TABLE)

    # Frequency Analysis
    freq_df = report_df.groupBy("TenureCategory", "Region").count()
    freq_df.show()

    # Mean and Sum Statistics
    means_df = report_df.groupBy("Region", "ProductGroup").agg(
        F.mean("TotalValue").alias("MeanTotalValue"),
        F.sum("TotalValue").alias("SumTotalValue"),
        F.mean("Quantity").alias("MeanQuantity"),
        F.sum("Quantity").alias("SumQuantity"),
        F.mean("Sales").alias("MeanSales"),
        F.sum("Sales").alias("SumSales"),
    )
    means_df.show()

    # Correlation Analysis
    # DataFrame.corr(col1, col2) only compares two columns, so the full matrix is
    # built from F.corr over every column pair (diagonal included) in a single
    # aggregation pass, then mirrored into a symmetric nested dict.
    corr_pairs = list(itertools.combinations_with_replacement(CORRELATION_COLUMNS, 2))
    corr_row = report_df.agg(*[F.corr(a, b) for a, b in corr_pairs]).first()
    corr_matrix = {name: {} for name in CORRELATION_COLUMNS}
    for (a, b), value in zip(corr_pairs, corr_row):
        corr_matrix[a][b] = corr_matrix[b][a] = value
    matrix_lines = ["\t".join(["", *CORRELATION_COLUMNS])]
    for a in CORRELATION_COLUMNS:
        cells = ["null" if corr_matrix[a][b] is None else f"{corr_matrix[a][b]:.4f}" for b in CORRELATION_COLUMNS]
        matrix_lines.append("\t".join([a, *cells]))
    logger.info("Correlation Matrix:\n%s", "\n".join(matrix_lines))

    # Outlier Summary
    outlier_summary_df = report_df.groupBy("OutlierFlag").agg(
        F.mean("Sales").alias("MeanSales"),
        F.stddev("Sales").alias("StdDevSales"),
        F.mean("TotalValue").alias("MeanTotalValue"),
        F.stddev("TotalValue").alias("StdDevTotalValue"),
        F.mean("Quantity").alias("MeanQuantity"),
        F.stddev("Quantity").alias("StdDevQuantity"),
    )
    outlier_summary_df.show()

except Exception as e:
    # Log with traceback, then re-raise so the cell / job run is marked as failed
    # instead of silently appearing to succeed.
    logger.exception("An error occurred: %s", e)
    raise
