In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, sum as spark_sum,round
from pyspark.sql.types import NumericType
from pyspark.sql import SparkSession 


class Aggregator:
    """Profit aggregations over Spark DataFrames, plus a raw SQL runner.

    Each ``aggregate_profit*`` method validates its input, groups by a fixed
    set of columns and returns one row per group with the sum of ``profit``
    rounded to two decimals in a column named ``profit_sum``.
    """

    def __init__(self, spark: SparkSession):
        """Keep the session that ``aggregate_with_query`` executes SQL on.

        Args:
            spark (SparkSession): Session used to run SQL queries.
        """
        self.spark = spark

    @staticmethod
    def _validate_profit_frame(df, required_cols, ignore_case=False) -> None:
        """Check that ``df`` is a DataFrame with the needed columns and a numeric profit.

        Args:
            df: Object to validate.
            required_cols: Column names that must be present in ``df``.
                Must include ``"profit"``.
            ignore_case (bool): Compare column names case-insensitively.

        Raises:
            TypeError: If ``df`` is not a Spark DataFrame or 'profit' is not numeric.
            ValueError: If ``df`` has no columns or required columns are missing.
        """
        if not isinstance(df, DataFrame):
            raise TypeError("Input must be a Spark DataFrame")

        if not df.columns:
            raise ValueError("Input DataFrame is empty (no columns)")

        def key(name):
            return name.lower() if ignore_case else name

        present = {key(name) for name in df.columns}
        missing = {name for name in required_cols if key(name) not in present}
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # 'profit' is guaranteed to exist here because it is a required column.
        profit_field = next(f for f in df.schema.fields if key(f.name) == "profit")
        if not isinstance(profit_field.dataType, NumericType):
            raise TypeError("Column 'profit' must be numeric")

    @staticmethod
    def _sum_profit(df, group_cols):
        """Group ``df`` by ``group_cols`` and return the rounded profit sum per group."""
        # ``round`` is pyspark.sql.functions.round (imported at the top of the
        # file), not the Python builtin.
        return df.groupBy(*group_cols).agg(
            round(spark_sum(col("profit")), 2).alias("profit_sum")
        )

    def aggregate_profit_by_year(self, df: DataFrame) -> DataFrame:
        """
        Aggregate profit by year.

        Args:
            df (DataFrame): Input Spark DataFrame with columns 'year' and 'profit'.

        Returns:
            DataFrame: Aggregated DataFrame with profit sum per year (rounded to 2 decimals).

        Raises:
            TypeError: If df is not a Spark DataFrame or if profit column is not numeric.
            ValueError: If required columns are missing or DataFrame is empty.
            RuntimeError: If Spark aggregation fails.
        """
        self._validate_profit_frame(df, ("year", "profit"))

        try:
            return self._sum_profit(df, ("year",))
        except Exception as e:
            raise RuntimeError(f"Failed to aggregate profit by year: {e}") from e

    def aggregate_profit_by_customer(self, df: DataFrame) -> DataFrame:
        """
        Aggregate profit by customer.

        Args:
            df (DataFrame): Input Spark DataFrame with columns
                            'customer_id', 'customer_name', and 'profit'.

        Returns:
            DataFrame: Aggregated DataFrame with profit sum per customer (rounded to 2 decimals).

        Raises:
            TypeError: If df is not a Spark DataFrame or if 'profit' column is not numeric.
            ValueError: If required columns are missing or DataFrame is empty.
            RuntimeError: If Spark aggregation fails.
        """
        self._validate_profit_frame(df, ("customer_id", "customer_name", "profit"))

        try:
            return self._sum_profit(df, ("customer_id", "customer_name"))
        except Exception as e:
            raise RuntimeError(f"Failed to aggregate profit by customer: {e}") from e

    def aggregate_profit_by_year_and_category(self, df: DataFrame) -> DataFrame:
        """
        Aggregate profit by year and product category.

        Args:
            df (DataFrame): Input Spark DataFrame with columns
                            'year', 'product_category', and 'profit'.

        Returns:
            DataFrame: Aggregated DataFrame with profit sum per year and
                       product category (rounded to 2 decimals).

        Raises:
            RuntimeError: On any failure. Unlike the other aggregations,
                validation errors (wrong input type, missing columns,
                non-numeric 'profit') are also reported as RuntimeError,
                with the original TypeError/ValueError attached as the cause.
        """
        # NOTE(review): validation deliberately stays inside the try block so
        # callers that already catch RuntimeError for bad input keep working.
        try:
            self._validate_profit_frame(df, ("year", "product_category", "profit"))
            return self._sum_profit(df, ("year", "product_category"))
        except Exception as e:
            raise RuntimeError(
                f"Error in aggregate_profit_by_year_and_category: {e}"
            ) from e

    def aggregate_profit(self, enriched_orders: DataFrame) -> DataFrame:
        """
        Aggregate profit by year, category, sub-category, and customer.

        Args:
            enriched_orders (DataFrame): Input Spark DataFrame containing
                'year', 'category', 'Sub_Category', 'customer_id',
                'customer_name', and 'profit'. Column names are matched
                case-insensitively, so 'sub_category' is accepted as well.

        Returns:
            DataFrame: Aggregated DataFrame with profit sums per grouping
                    (rounded to 2 decimals).

        Raises:
            TypeError: If enriched_orders is not a Spark DataFrame or 'profit' is not numeric.
            ValueError: If required columns are missing or DataFrame is empty.
            RuntimeError: If Spark aggregation fails.
        """
        group_cols = ("year", "category", "Sub_Category", "customer_id", "customer_name")

        # Validate exactly the columns that are grouped on. Matching ignores
        # case so that frames using 'sub_category' and frames using
        # 'Sub_Category' are both accepted.
        # NOTE(review): this assumes Spark resolves column names
        # case-insensitively (spark.sql.caseSensitive left at its default);
        # otherwise a differently-cased column passes validation and the
        # groupBy below fails with RuntimeError - confirm for this deployment.
        self._validate_profit_frame(
            enriched_orders, group_cols + ("profit",), ignore_case=True
        )

        try:
            return self._sum_profit(enriched_orders, group_cols)
        except Exception as e:
            raise RuntimeError(f"Failed to aggregate profit: {e}") from e

    def aggregate_with_query(self,  query: str) -> DataFrame:
        """
        Run a SQL query on this aggregator's Spark session and return the result.

        The query text is passed to Spark unchanged, so callers must not build
        it from untrusted input.

        Args:
            query (str): SQL query to execute.

        Returns:
            DataFrame: Result of the SQL query.

        Raises:
            TypeError: If query is not a string.
            ValueError: If query is empty or only whitespace.
            RuntimeError: If SQL execution fails.
        """
        if not isinstance(query, str):
            raise TypeError("Query must be a string")

        if not query.strip():
            raise ValueError("Query cannot be empty")

        try:
            # Use the session given to __init__ rather than relying on a
            # global ``spark`` name existing in the calling environment.
            return self.spark.sql(query)
        except Exception as e:
            raise RuntimeError(f"Failed to execute SQL query: {e}") from e
