In [None]:
import logging
from pyspark.sql import functions as F

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pipeline overview:
#   1. Load customer_data and transaction_data (Unity Catalog table, CSV fallback).
#   2. Keep transactions with transaction_amount > 100.
#   3. Inner-join customers to those transactions on customer_id.
#   4. Sum transaction_amount per customer_id into total_spent.
#   5. Write the result to a Delta table (CSV fallback).
#
# `spark` is not defined in this file; it is expected to be supplied by the
# runtime (e.g. a notebook session) — TODO confirm for non-notebook execution.

# Catalog/schema shared by both source tables and the target table.
# NOTE(review): these look like placeholder names — confirm the real values.
catalog_name = "correct_catalog"
schema_name = "correct_schema"


def _log_schema_and_count(df, name):
    """Print the schema of *df* and log its record count labelled with *name*."""
    df.printSchema()
    # count() is a Spark action, so each call executes the plan built so far.
    logger.info("%s record count: %s", name, df.count())


def _load_with_csv_fallback(name, qualified_table, csv_path):
    """Load *qualified_table*, falling back to the CSV at *csv_path* on any error.

    Args:
        name: Label used in log messages (e.g. "customer_data").
        qualified_table: Table identifier passed to ``spark.table``.
        csv_path: Path read with ``header=True, inferSchema=True`` if the
            table load (or its schema/count inspection) raises.

    Returns:
        The loaded DataFrame.

    Raises:
        Exception: Whatever the CSV fallback raises if it also fails.
    """
    logger.info("Loading %s from Unity Catalog", name)
    try:
        df = spark.table(qualified_table)
        logger.info("Successfully loaded %s", name)
        # Kept inside the try on purpose: a failure while inspecting the table
        # also triggers the CSV fallback, as in the original flow.
        _log_schema_and_count(df, name)
    except Exception as exc:
        # logger.exception keeps the traceback that a plain error() call drops.
        logger.exception("Failed to load %s: %s", name, exc)
        logger.info("Attempting to load %s from CSV as fallback", name)
        df = spark.read.csv(csv_path, header=True, inferSchema=True)
        logger.info("Successfully loaded %s from CSV", name)
        _log_schema_and_count(df, name)
    return df


try:
    # Node 1: Load customer_data from Unity Catalog
    customer_data = _load_with_csv_fallback(
        "customer_data",
        f"{catalog_name}.{schema_name}.customer_data",
        "/mnt/data/customer_data.csv",
    )

    # Node 2: Load transaction_data from Unity Catalog
    transaction_data = _load_with_csv_fallback(
        "transaction_data",
        f"{catalog_name}.{schema_name}.transaction_data",
        "/mnt/data/transaction_data.csv",
    )

    # Node 3: Filter transaction_data
    logger.info("Filtering transaction_data where transaction_amount > 100")
    filtered_transaction_data = transaction_data.filter(F.col("transaction_amount") > 100)
    _log_schema_and_count(filtered_transaction_data, "filtered_transaction_data")

    # Node 4: Join customer_data with filtered transaction_data
    logger.info("Joining customer_data with filtered_transaction_data on customer_id")
    join_condition = F.col("cust.customer_id") == F.col("trans.customer_id")
    # Only transaction_amount is taken from the transaction side, so the
    # result carries a single customer_id column (the customer one).
    joined_data = (
        customer_data.alias("cust")
        .join(filtered_transaction_data.alias("trans"), join_condition, "inner")
        .select("cust.*", "trans.transaction_amount")
    )
    _log_schema_and_count(joined_data, "joined_data")

    # Node 5: Aggregate data
    logger.info("Aggregating joined_data to calculate total_spent per customer_id")
    aggregated_data = joined_data.groupBy("customer_id").agg(
        F.sum("transaction_amount").alias("total_spent")
    )
    _log_schema_and_count(aggregated_data, "aggregated_data")

    # Node 6: Write to customer_spending_summary in Unity Catalog
    target_catalog = catalog_name
    target_schema = schema_name
    target_table = "customer_spending_summary"
    qualified_target = f"{target_catalog}.{target_schema}.{target_table}"

    # Ensure schema exists before creating table
    logger.info("Ensuring schema %s.%s exists", target_catalog, target_schema)
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {target_catalog}.{target_schema}")
    logger.info("Schema %s.%s ensured", target_catalog, target_schema)

    # Write to Unity Catalog target table (overwrite mode handles table replacement)
    logger.info("Writing aggregated data to %s", qualified_target)
    try:
        aggregated_data.write.format("delta").mode("overwrite").saveAsTable(qualified_target)
        logger.info("Write operation completed successfully")
    except Exception as exc:
        logger.exception("Failed to write to Unity Catalog: %s", exc)
        # Fallback to writing to CSV if table write fails
        logger.info("Attempting to write aggregated data to CSV as fallback")
        # Overwrite like the primary write; without an explicit mode the
        # fallback would fail whenever the output path already exists.
        aggregated_data.write.mode("overwrite").csv(
            "/mnt/output/customer_spending_summary.csv", header=True
        )
        logger.info("Successfully wrote aggregated data to CSV")

except Exception as exc:
    # Top-level boundary: record the traceback, then let the failure propagate.
    logger.exception("An error occurred: %s", exc)
    raise
