In [0]:
# jobs/transformer
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, to_date, year, round as spark_round,sum, min,max,countDistinct,when,avg
from pyspark.sql.types import DoubleType
from pyspark.sql.utils import AnalysisException
from pyspark.sql.window import Window


class Transformer:
    def __init__(self, spark: SparkSession) -> None:
        """Keep a reference to the given Spark session as ``self.spark``."""
        self.spark = spark

    def transform_orders(self, orders_df: DataFrame) -> DataFrame:
        """
        Transforms the orders DataFrame by:
        - Converting 'order_date' and 'ship_date' to date type, parsed with
          the 'd/M/yyyy' pattern
        - Casting 'profit' to double type

        Parameters:
            orders_df (DataFrame): Input DataFrame with raw order data. Must
                contain the columns order_date, ship_date and profit.

        Returns:
            DataFrame: Transformed DataFrame with updated schema; columns
                other than the three listed above are passed through.

        Raises:
            ValueError: If DataFrame is invalid or required columns are missing
            Exception: For unexpected transformation errors (the original
                error is chained as ``__cause__``)
        """
        try:
            # isinstance() is already False for None, so no separate None test.
            if not isinstance(orders_df, DataFrame):
                raise ValueError("Invalid DataFrame provided.")

            required_cols = ("order_date", "ship_date", "profit")
            available = orders_df.columns
            missing_cols = [c for c in required_cols if c not in available]
            if missing_cols:
                raise ValueError(f"Missing required columns: {', '.join(missing_cols)}")

            return (
                orders_df
                .withColumn("order_date", to_date(col("order_date"), "d/M/yyyy"))
                .withColumn("ship_date", to_date(col("ship_date"), "d/M/yyyy"))
                .withColumn("profit", col("profit").cast(DoubleType()))
            )

        except ValueError as ve:
            # Keep the documented exception type: ValueError is still an
            # Exception, so callers that catch Exception keep working, while
            # callers that follow the docstring can now catch ValueError.
            raise ValueError(f"ValueError error during order transformation: {ve}") from ve
        except Exception as e:
            raise Exception(f"Unexpected error during order transformation: {e}") from e

    def enrich_orders(self, orders: DataFrame, customers: DataFrame, products: DataFrame) -> DataFrame:
        """
        Enriches orders by joining with customers and products datasets.

        Both joins are inner joins, so an order without a matching customer
        or product is dropped from the result. Profit is rounded to two
        decimals and a 'year' column is derived from the order date.

        Parameters:
            orders (DataFrame): Orders DataFrame (must contain order_id, order_date, ship_date, customer_id, product_id, profit)
            customers (DataFrame): Customers DataFrame (must contain customer_id, customer_name, country)
            products (DataFrame): Products DataFrame (must contain product_id, category, sub_category)

        Returns:
            DataFrame: Enriched DataFrame with the columns order_id,
                order_date, ship_date, customer_id, customer_name, country,
                product_id, category, sub_category, profit and year.

        Raises:
            ValueError: If any input DataFrame is invalid or required columns are missing.
            Exception: For unexpected join or transformation errors.
        """
        try:
            inputs = {"orders": orders, "customers": customers, "products": products}

            # First pass: every argument must be a DataFrame before any
            # schema is inspected.
            for name, df in inputs.items():
                if not isinstance(df, DataFrame):
                    raise ValueError(f"Invalid DataFrame provided for {name}.")

            # Required columns for each DataFrame
            required_cols = {
                "orders": ["order_id", "order_date", "ship_date", "customer_id", "product_id", "profit"],
                "customers": ["customer_id", "customer_name", "country"],
                "products": ["product_id", "category", "sub_category"],
            }

            # Second pass: report missing columns per DataFrame.
            for name, df in inputs.items():
                available = df.columns
                missing_cols = [c for c in required_cols[name] if c not in available]
                if missing_cols:
                    raise ValueError(f"Missing required columns in {name} DataFrame: {', '.join(missing_cols)}")

            # Every selected column is qualified with its source alias so the
            # select stays unambiguous even when customers or products carry
            # a column with the same name as an orders column.
            enriched_df = (
                orders.alias("o")
                .join(customers.alias("c"), col("o.customer_id") == col("c.customer_id"), "inner")
                .join(products.alias("p"), col("o.product_id") == col("p.product_id"), "inner")
                .select(
                    col("o.order_id"),
                    col("o.order_date"),
                    col("o.ship_date"),
                    col("c.customer_id"),
                    col("c.customer_name"),
                    col("c.country"),
                    col("o.product_id"),
                    col("p.category"),
                    col("p.sub_category"),
                    spark_round(col("o.profit"), 2).alias("profit"),
                    year(col("o.order_date")).alias("year"),
                )
            )

            return enriched_df

        except ValueError as ve:
            # Validation errors (bad schema, missing columns, invalid
            # DataFrame) keep the ValueError type so callers can handle them
            # explicitly; the message text is unchanged.
            raise ValueError(f"ValueError error during order enrichment: {ve}") from ve
        except Exception as e:
            # Any unexpected Spark or join errors are wrapped with context for debugging.
            raise Exception(f"Unexpected error during order enrichment: {e}") from e




    def validate_dataframe(self, df: DataFrame, expected_cols: list, df_name: str):
        """
        Check that ``df`` is a Spark DataFrame exposing every expected column.

        Parameters:
            df (DataFrame): Object to validate.
            expected_cols (list): Column names that must be present.
            df_name (str): Label used for ``df`` in error messages.

        Raises:
            TypeError: If ``df`` is not a Spark DataFrame.
            ValueError: If one or more expected columns are absent.
        """
        if isinstance(df, DataFrame):
            available = df.columns
            missing = [name for name in expected_cols if name not in available]
            if not missing:
                return
            raise ValueError(f"{df_name} is missing required columns: {missing}")

        raise TypeError(f"{df_name} must be a Spark DataFrame, got {type(df)}")


    def create_enriched_customer_table(
        self,
        customers_df: DataFrame,
        orders_df: DataFrame,
        products_df: DataFrame,
    ) -> DataFrame:
        """
        Create an enriched customer table with aggregated metrics.

        Every row of ``customers_df`` is kept (left joins) and extended with:
        total_orders, total_quantity, total_spent, total_discount,
        total_profit, avg_order_value, first_order_date, last_order_date,
        fav_category and fav_sub_category. The added columns are null for
        customers that have no orders.

        Parameters:
            customers_df (DataFrame): Customers; must contain customer_id,
                customer_name, email, phone, address, segment, country, city,
                state, postal_code and region.
            orders_df (DataFrame): Orders; must contain order_id, order_date,
                ship_date, ship_mode, customer_id, product_id, quantity,
                price, discount and profit.
            products_df (DataFrame): Products; must contain product_id,
                category, sub_category, product_name, state and
                price_per_product.

        Returns:
            DataFrame: ``customers_df`` joined with the per-customer metrics
                and favourites. Each joined side holds at most one row per
                customer_id, so customer rows are not multiplied.

        Raises:
            RuntimeError: If validation fails or Spark raises an
                AnalysisException; the original error is chained as
                ``__cause__``.
        """
        try:
            # Validate schemas
            self.validate_dataframe(
                customers_df,
                ["customer_id", "customer_name", "email", "phone", "address",
                 "segment", "country", "city", "state", "postal_code", "region"],
                "customers_df",
            )
            self.validate_dataframe(
                orders_df,
                ["order_id", "order_date", "ship_date", "ship_mode", "customer_id",
                 "product_id", "quantity", "price", "discount", "profit"],
                "orders_df",
            )
            self.validate_dataframe(
                products_df,
                ["product_id", "category", "sub_category", "product_name", "state", "price_per_product"],
                "products_df",
            )

            # Left join keeps orders whose product is missing from products_df.
            orders_products = orders_df.join(products_df, "product_id", "left")

            # Per-row amount: price * quantity minus the row's discount value.
            line_amount = F.col("price") * F.col("quantity") - F.col("discount")
            order_count = F.countDistinct("order_id")

            # Aggregate per customer
            customer_metrics = orders_products.groupBy("customer_id").agg(
                order_count.alias("total_orders"),
                F.sum("quantity").alias("total_quantity"),
                F.round(F.sum(line_amount), 2).alias("total_spent"),
                F.round(F.sum("discount"), 2).alias("total_discount"),
                F.round(F.sum("profit"), 2).alias("total_profit"),
                # when() without otherwise() yields null for a zero count, so
                # the division returns null instead of dividing by zero.
                F.round(
                    F.sum(line_amount) / F.when(order_count != 0, order_count), 2
                ).alias("avg_order_value"),
                F.min("order_date").alias("first_order_date"),
                F.max("order_date").alias("last_order_date"),
            )

            fav_cat = self._favorite_per_customer(orders_products, "category", "fav_category")
            fav_subcat = self._favorite_per_customer(orders_products, "sub_category", "fav_sub_category")

            # Final enriched customers
            return (
                customers_df.join(customer_metrics, "customer_id", "left")
                .join(fav_cat, "customer_id", "left")
                .join(fav_subcat, "customer_id", "left")
            )

        except (AnalysisException, ValueError, TypeError) as e:
            raise RuntimeError(f"Error creating enriched customer table: {str(e)}") from e

    @staticmethod
    def _favorite_per_customer(orders_products: DataFrame, value_col: str, alias: str) -> DataFrame:
        """Return one row per customer_id with its most frequent ``value_col`` as ``alias``."""
        # row_number() (unlike rank()) assigns a single winner even when
        # several values share the top count; the secondary sort on the value
        # itself makes that winner deterministic and prefers non-null values.
        by_popularity = Window.partitionBy("customer_id").orderBy(
            F.desc("count"), F.col(value_col).asc_nulls_last()
        )
        return (
            orders_products.groupBy("customer_id", value_col)
            .count()
            .withColumn("rank", F.row_number().over(by_popularity))
            .filter(F.col("rank") == 1)
            .select("customer_id", F.col(value_col).alias(alias))
        )

    def create_enriched_product_table(
        self,
        customers_df: DataFrame,
        orders_df: DataFrame,
        products_df: DataFrame,
    ) -> DataFrame:
        """
        Create an enriched product table with aggregated metrics.

        Every row of ``products_df`` is kept (left joins) and extended with:
        total_orders, total_quantity_sold, total_revenue, total_profit,
        avg_discount, distinct_customers and best_region (the customer region
        with the highest summed profit for that product). The added columns
        are null for products that have no orders.

        Parameters:
            customers_df (DataFrame): Customers; must contain customer_id and
                region.
            orders_df (DataFrame): Orders; must contain order_id, order_date,
                customer_id, product_id, quantity, price, discount and profit.
            products_df (DataFrame): Products; must contain product_id,
                category, sub_category, product_name, state and
                price_per_product.

        Returns:
            DataFrame: ``products_df`` joined with the per-product metrics and
                best region. Each joined side holds at most one row per
                product_id, so product rows are not multiplied.

        Raises:
            RuntimeError: If validation fails or Spark raises an
                AnalysisException; the original error is chained as
                ``__cause__``.
        """
        try:
            # Validate schemas
            self.validate_dataframe(customers_df, ["customer_id", "region"], "customers_df")
            self.validate_dataframe(
                orders_df,
                ["order_id", "order_date", "customer_id", "product_id", "quantity", "price", "discount", "profit"],
                "orders_df",
            )
            self.validate_dataframe(
                products_df,
                ["product_id", "category", "sub_category", "product_name", "state", "price_per_product"],
                "products_df",
            )

            # Per-row amount: price * quantity minus the row's discount value.
            line_amount = F.col("price") * F.col("quantity") - F.col("discount")

            # Aggregate per product
            product_metrics = orders_df.groupBy("product_id").agg(
                F.countDistinct("order_id").alias("total_orders"),
                F.sum("quantity").alias("total_quantity_sold"),
                F.round(F.sum(line_amount), 2).alias("total_revenue"),
                F.round(F.sum("profit"), 2).alias("total_profit"),
                F.round(F.avg("discount"), 2).alias("avg_discount"),
                F.countDistinct("customer_id").alias("distinct_customers"),
            )

            # Only customer_id and region are needed from customers; projecting
            # first keeps unrelated customer columns from colliding with order
            # columns in the joined frame.
            customer_regions = customers_df.select("customer_id", "region")
            product_region = (
                orders_df.join(customer_regions, "customer_id", "left")
                .groupBy("product_id", "region")
                .agg(F.sum("profit").alias("region_profit"))
            )

            # row_number() (unlike rank()) picks a single region even when two
            # regions tie on profit; the secondary sort on region makes the
            # choice deterministic and prefers non-null regions.
            by_profit = Window.partitionBy("product_id").orderBy(
                F.desc("region_profit"), F.col("region").asc_nulls_last()
            )
            best_region = (
                product_region
                .withColumn("rank", F.row_number().over(by_profit))
                .filter(F.col("rank") == 1)
                .select("product_id", F.col("region").alias("best_region"))
            )

            # Final enriched products
            return (
                products_df
                .join(product_metrics, "product_id", "left")
                .join(best_region, "product_id", "left")
            )
        except (AnalysisException, ValueError, TypeError) as e:
            raise RuntimeError(f"Error creating enriched product table: {str(e)}") from e
