In [0]:
%run ../jobs/aggregator

In [0]:
import unittest
from pyspark.sql import SparkSession
import pytest
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType,StringType
from datetime import date


class TestAggregator(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.master("local[1]").appName("AggregatorTest").getOrCreate()
        cls.aggregator = Aggregator(spark)
        data = [("2020", "Furniture", "Chair", "C1", "Alice", 100.0),
                ("2020", "Furniture", "Chair", "C1", "Alice", 200.0)]
        cls.enriched_orders = cls.spark.createDataFrame(
            data, ["year", "category", "sub_category", "customer_id", "customer_name", "profit"]
        )
    
    def test_aggregate_profit(self):
        result = self.aggregator.aggregate_profit(self.enriched_orders)
        self.assertEqual(result.collect()[0]["profit_sum"], 300.0)
    
    def test_single_year_aggregation(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [(2023, 100.55), (2023, 200.45)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_year(df).collect()

        assert len(result) == 1
        assert result[0]["year"] == 2023
        assert result[0]["profit_sum"] == 301.0  # rounded to 2 decimals
    
    def test_multiple_years_aggregation(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [(2022, 100.123), (2022, 200.456), (2023, 50.789)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = {row["year"]: row["profit_sum"] for row in self.aggregator.aggregate_profit_by_year(df).collect()}

        assert result[2022] == 300.58  # (100.123 + 200.456) rounded
        assert result[2023] == 50.79


    def test_negative_and_zero_profits(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [(2023, -50.25), (2023, 0.0), (2023, 100.25)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_year(df).collect()[0]

        assert result["profit_sum"] == 50.0  # (-50.25 + 0 + 100.25)
    
    # -------------------
    # Edge Cases
    # -------------------

    def test_empty_dataframe(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([], schema=schema)

        result = self.aggregator.aggregate_profit_by_year(df).collect()

        assert result == []


    def test_all_null_profits(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [(2023, None), (2023, None)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_year(df).collect()

        # sum of nulls should be null
        assert result[0]["profit_sum"] is None


    def test_all_null_years(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [(None, 100.0), (None, 200.0)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_year(df).collect()

        # GroupBy(None) still creates a single group with None as key
        assert result[0]["year"] is None
        assert result[0]["profit_sum"] == 300.0

    #--------negative test cases
    def test_missing_year_column(self):
        schema = StructType([
            StructField("profit", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([(100.0,), (200.0,)], schema=schema)

        with pytest.raises(Exception):
            self.aggregator.aggregate_profit_by_year(None, df).collect()


    def test_invalid_profit_type(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", IntegerType(), True),  # not DoubleType
        ])
        df = self.spark.createDataFrame([(2023, 100), (2023, 200)], schema=schema)

        # should still work, Spark can sum integers
        result = self.aggregator.aggregate_profit_by_year( df).collect()[0]

        assert result["profit_sum"] == 300.0


    def test_single_customer(self):
        schema = StructType([
            StructField("customer_id", StringType(), True),
            StructField("customer_name", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [("C1", "Alice", 100.55), ("C1", "Alice", 200.45)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_customer(df).collect()

        assert len(result) == 1
        assert result[0]["customer_id"] == "C1"
        assert result[0]["profit_sum"] == 301.0


    def test_multiple_customers(self):
        schema = StructType([
            StructField("customer_id", StringType(), True),
            StructField("customer_name", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [("C1", "Alice", 100.0), ("C2", "Bob", 200.0)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = {row["customer_id"]: row["profit_sum"] for row in self.aggregator.aggregate_profit_by_customer(df).collect()}

        assert result["C1"] == 100.0
        assert result["C2"] == 200.0


    def test_rounding(self):
        schema = StructType([
            StructField("customer_id", StringType(), True),
            StructField("customer_name", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [("C1", "Alice", 10.126), ("C1", "Alice", 20.125)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_customer(df).collect()
        assert result[0]["profit_sum"] == 30.25   # rounded to 2 decimals


    def test_empty_dataframe(self):
        schema = StructType([
            StructField("customer_id", StringType(), True),
            StructField("customer_name", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([], schema=schema)

        result = self.aggregator.aggregate_profit_by_customer(df).collect()
        assert result == []


    def test_zero_profits(self):
        schema = StructType([
            StructField("customer_id", StringType(), True),
            StructField("customer_name", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [("C1", "Alice", 0.0), ("C1", "Alice", 0.0)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_customer(df).collect()
        assert result[0]["profit_sum"] == 0.0


    def test_missing_columns(self):
        schema = StructType([
            StructField("id", StringType(), True),
            StructField("name", StringType(), True),
            StructField("value", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([("1", "X", 10.0)], schema=schema)

        with pytest.raises(Exception):  # should fail due to missing customer_id
            self.aggregator.aggregate_profit_by_customer(df).collect()


    def test_large_profits(self):
        schema = StructType([
            StructField("customer_id", StringType(), True),
            StructField("customer_name", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [("C1", "Alice", 1e12), ("C1", "Alice", 1e12)]
        df = self.spark.createDataFrame(data, schema=schema)

        result = self.aggregator.aggregate_profit_by_customer(df).collect()
        assert result[0]["profit_sum"] == 2e12

    # --- Positive Cases ---

    def test_valid_aggregation(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("product_category", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [
            (2023, "Furniture", 100.55),
            (2023, "Furniture", 200.45),
            (2023, "Technology", 50.0),
            (2024, "Furniture", 300.0),
        ]
        df = self.spark.createDataFrame(data, schema=schema)
        result = self.aggregator.aggregate_profit_by_year_and_category(df).collect()

        expected = {
            (2023, "Furniture"): 301.0,   # rounded 100.55 + 200.45
            (2023, "Technology"): 50.0,
            (2024, "Furniture"): 300.0,
        }

        for row in result:
            assert expected[(row["year"], row["product_category"])] == row["profit_sum"]


    def test_single_group(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("product_category", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([(2025, "Office Supplies", 999.99)], schema=schema)

        result = self.aggregator.aggregate_profit_by_year_and_category(df).collect()

        assert result[0]["year"] == 2025
        assert result[0]["product_category"] == "Office Supplies"
        assert result[0]["profit_sum"] == 999.99


    # --- Edge Cases ---

    def test_empty_dataframe(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("product_category", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([], schema=schema)
        result = self.aggregator.aggregate_profit_by_year_and_category(df).collect()

        assert result == []


    def test_null_profit_values(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("product_category", StringType(), True),
            StructField("profit", DoubleType(), True),
        ])
        data = [(2023, "Furniture", None), (2023, "Furniture", 100.0)]
        df = self.spark.createDataFrame(data, schema=schema)
        result = self.aggregator.aggregate_profit_by_year_and_category(df).collect()

        # Should ignore null and just sum valid values
        assert result[0]["profit_sum"] == 100.0


    # --- Negative Cases ---

    def test_missing_columns(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("profit", DoubleType(), True),
        ])
        df = self.spark.createDataFrame([(2023, 100.0)], schema=schema)

        with pytest.raises(Exception):
            result = self.aggregator.aggregate_profit_by_year_and_category(df)
            result.collect()  # force evaluation


    def test_invalid_profit_type(self):
        schema = StructType([
            StructField("year", IntegerType(), True),
            StructField("product_category", StringType(), True),
            StructField("profit", StringType(), True),  # invalid type
        ])
        data = [(2023, "Furniture", "abc"), (2023, "Furniture", "xyz")]
        df = self.spark.createDataFrame(data, schema=schema)

        with pytest.raises(Exception):
            result = self.aggregator.aggregate_profit_by_year_and_category(df)
            result.collect()  # trigger Spark evaluation

if __name__ == "__main__":
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
.
----------------------------------------------------------------------
Ran 18 tests in 9.330s

OK
