In [0]:
%run ../jobs/reader

In [0]:
%run ../jobs/data_quality

In [0]:
%run ../jobs/transformer

In [0]:
%run ../jobs/aggregator

In [0]:
%run ../jobs/utility

In [0]:
import unittest
from datetime import date

from delta.tables import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col

class TestUtility(unittest.TestCase):
    """Unit tests for the Utility, Transformer and Aggregator jobs.

    Utility, Transformer and Aggregator are not defined in this cell; they come
    from the ``%run ../jobs/...`` cells earlier in the notebook. Tests are
    grouped by the task they cover (Task 1 to Task 5).
    """

    @classmethod
    def setUpClass(cls):
        """Create the shared Spark session, the jobs under test and fixture data."""
        cls.spark = (
            SparkSession.builder
            .appName("test_utility")
            .getOrCreate()
        )
        cls.utility = Utility(cls.spark)
        cls.aggregator = Aggregator(cls.spark)
        cls.transformer = Transformer(cls.spark)

        # Names of catalog tables created by the Task 1 tests; they are dropped
        # again in tearDownClass so repeated runs do not leave tables behind.
        cls._created_tables = []

        # Small two-column frame reused by the Task 1 table-creation tests.
        data = [("1", "Abhi"), ("2", "Shashi")]
        columns = ["id", "name"]
        cls.df = cls.spark.createDataFrame(data, columns)

        # Mock customer extract for the Task 4 aggregation test; this file must
        # exist on DBFS for the suite to start.
        test_data_path = "dbfs:/FileStore/Mock/customer_sample_mock.csv"
        cls.enrich_df = (
            cls.spark.read.option("inferSchema", True)
            .csv(test_data_path, header=True)
        )

        # Pre-enriched rows published as the "enriched_test" temp view that the
        # Task 5 SQL queries read from.
        enriched_rows = [
            (2020, "Electronics", 101, "Alice", 100.125),
            (2020, "Clothing",   102, "Bob",   200.456),
            (2021, "Electronics", 101, "Alice", 300.789),
            (2021, "Clothing",   102, "Bob",   400.111),
        ]
        cls.test_df = cls.spark.createDataFrame(
            enriched_rows, ["year", "category", "customer_id", "customer_name", "profit"]
        )
        cls.test_df.createOrReplaceTempView("enriched_test")

    @classmethod
    def tearDownClass(cls):
        """Drop the catalog tables and the temp view this suite created."""
        for table_name in cls._created_tables:
            cls.spark.sql(f"DROP TABLE IF EXISTS {table_name}")
        cls.spark.catalog.dropTempView("enriched_test")

    def _create_table(self, table_name):
        """Create a Delta table from ``self.df`` and register it for cleanup.

        Returns whatever name ``Utility.create_delta_table`` reports for the table.
        """
        created_name = self.utility.create_delta_table(self.df, table_name)
        if created_name:
            self._created_tables.append(created_name)
        return created_name

    def sample_data_enrich(self):
        """Build small orders/customers/products frames for the Task 3 tests.

        Returns ``(orders_df, customers_df, products_df)`` -- note this order
        differs from ``sample_data_enrich2``.
        """
        orders = [
            # Alice's orders
            (1001, date(2022, 5, 1), date(2022, 5, 5), 1, 101, 123.456),
            (1002, date(2023, 3, 15), date(2023, 3, 20), 1, 101, 300.111),
            (1003, date(2024, 4, 1), date(2024, 4, 7), 1, 104, 850.777),

            # Bob's orders
            (1004, date(2022, 6, 2), date(2022, 6, 10), 2, 102, 456.789),
            (1005, date(2024, 1, 10), date(2024, 1, 15), 2, 102, 250.499),

            # Charlie's order
            (1006, date(2023, 7, 18), date(2023, 7, 25), 3, 103, 789.555),

            # Diana's order
            (1007, date(2024, 2, 5), date(2024, 2, 12), 4, 104, 999.999),
        ]

        customers = [
            (1, "Alice", "USA"),
            (2, "Bob", "India"),
            (3, "Charlie", "UK"),
            (4, "Diana", "Germany"),
        ]

        products = [
            (101, "Electronics", "Mobile"),
            (102, "Clothing", "Shirts"),
            (103, "Furniture", "Chairs"),
            (104, "Electronics", "Laptop"),
        ]

        orders_df = self.spark.createDataFrame(
            orders, ["order_id", "order_date", "ship_date", "customer_id", "product_id", "profit"]
        )
        customers_df = self.spark.createDataFrame(
            customers, ["customer_id", "customer_name", "country"]
        )
        products_df = self.spark.createDataFrame(
            products, ["product_id", "category", "sub_category"]
        )

        return orders_df, customers_df, products_df

    #============ Task 1 test cases ==========================================

    def test_create_raw_table_exist(self):
        """
        Validates table existence after creation
        """
        table_name = self._create_table("test_table1")
        tables = [t.name for t in self.spark.catalog.listTables()]
        self.assertIn(table_name, tables)

    def test_create_raw_table(self):
        """
        Tests if table is a delta table
        """
        table_name = self._create_table("test_table2")
        # DeltaTable.forName raises if the name does not resolve to a Delta table.
        delta_tbl = DeltaTable.forName(self.spark, table_name)
        self.assertIsNotNone(delta_tbl)

    def test_table_schema(self):
        """
        Tests schema of the table created
        """
        table_name = self._create_table("test_table3")
        schema_fields = [f.name for f in self.spark.table(table_name).schema.fields]
        self.assertListEqual(schema_fields, ["id", "name"])

    def test_table_data_count(self):
        """
        Tests record count of table created
        """
        # Own table name so this test does not share state with test_table_schema.
        table_name = self._create_table("test_table4")
        row_count = self.spark.table(table_name).count()
        self.assertEqual(row_count, 2)

    # ================Task:2 test cases ===================================
    def sample_data_enrich2(self):
        """Build customers/products/orders frames with repeated purchases for Task 2.

        Returns ``(customers_df, orders_df, products_df)``.
        """
        # Customers
        customers = [
            (1, "Alice", "alice@mail.com", "111", "Addr1", "Consumer", "USA", "New York", "NY", "10001", "East"),
            (2, "Bob", "bob@mail.com", "222", "Addr2", "Corporate", "USA", "Los Angeles", "CA", "90001", "West"),
            (3, "Charlie", "charlie@mail.com", "333", "Addr3", "Home Office", "USA", "Chicago", "IL", "60007", "Central")
        ]
        customers_df = self.spark.createDataFrame(
            customers,
            ["customer_id", "customer_name", "email", "phone", "address",
             "segment", "country", "city", "state", "postal_code", "region"])

        # Products
        products = [
            (101, "Furniture", "Chairs", "Office Chair", "NY", 120.0),
            (102, "Technology", "Phones", "iPhone", "CA", 999.0),
            (103, "Office Supplies", "Paper", "A4 Paper", "IL", 10.0)
        ]
        products_df = self.spark.createDataFrame(
            products,
            ["product_id", "category", "sub_category", "product_name", "state", "price_per_product"])

        # Orders (with REPEATED purchases)
        orders = [
            # Alice orders multiple times
            (1001, date(2023,1,1), date(2023,1,5), "Second Class", 1, 101, 2, 120.0, 10.0, 30.0),   # Office Chair
            (1002, date(2023,1,2), date(2023,1,6), "First Class", 1, 102, 1, 999.0, 50.0, 200.0),   # iPhone
            (1005, date(2023,3,1), date(2023,3,5), "Standard", 1, 103, 5, 10.0, 2.0, 5.0),          # A4 Paper
            (1006, date(2023,3,10), date(2023,3,15), "Standard", 1, 101, 1, 120.0, 0.0, 15.0),      # Office Chair again

            # Bob orders the same product twice
            (1003, date(2023,2,1), date(2023,2,3), "Standard", 2, 103, 10, 10.0, 0.0, 20.0),        # A4 Paper
            (1007, date(2023,2,15), date(2023,2,18), "Second Class", 2, 103, 15, 10.0, 1.0, 30.0), # A4 Paper again

            # Charlie orders once
            (1004, date(2023,2,5), date(2023,2,8), "Second Class", 3, 101, 1, 120.0, 0.0, 15.0),   # Office Chair
        ]
        orders_df = self.spark.createDataFrame(
            orders,
            ["order_id", "order_date", "ship_date", "ship_mode", "customer_id",
             "product_id", "quantity", "price", "discount", "profit"])

        return customers_df, orders_df, products_df

    def test_enriched_customer_table(self):
        """Per-customer order/quantity totals and favourite category/sub-category."""
        customers_df, orders_df, products_df = self.sample_data_enrich2()

        enriched_customers = self.transformer.create_enriched_customer_table(customers_df, orders_df, products_df)

        # Ensure enrichment columns exist
        self.assertIn("total_orders", enriched_customers.columns)
        self.assertIn("fav_category", enriched_customers.columns)

        # Fetch Alice's metrics (customer_id=1)
        alice = enriched_customers.filter("customer_id = 1").first()
        self.assertIsNotNone(alice, "customer_id=1 missing from enriched customers")
        self.assertEqual(alice.total_orders, 4)             # Alice has 4 orders total
        self.assertEqual(alice.total_quantity, 9)           # Quantities: 2 + 1 + 5 + 1
        self.assertEqual(alice.fav_category, "Furniture")   # Alice bought Office Chair twice

        # Fetch Bob's metrics (customer_id=2)
        bob = enriched_customers.filter("customer_id = 2").first()
        self.assertIsNotNone(bob, "customer_id=2 missing from enriched customers")
        self.assertEqual(bob.total_orders, 2)               # Bob has 2 orders
        self.assertEqual(bob.total_quantity, 25)            # Quantities: 10 + 15
        self.assertEqual(bob.fav_sub_category, "Paper")     # Both orders are A4 Paper

    def test_enriched_product_table(self):
        """Per-product order/quantity totals, best region and distinct buyers."""
        customers_df, orders_df, products_df = self.sample_data_enrich2()

        enriched_products = self.transformer.create_enriched_product_table(customers_df, orders_df, products_df)

        # Ensure enrichment columns exist
        self.assertIn("total_orders", enriched_products.columns)
        self.assertIn("best_region", enriched_products.columns)

        # Fetch Office Chair metrics (product_id=101)
        chair = enriched_products.filter("product_id = 101").first()
        self.assertIsNotNone(chair, "product_id=101 missing from enriched products")
        self.assertEqual(chair.total_orders, 3)             # Alice(2) + Charlie(1)
        self.assertEqual(chair.total_quantity_sold, 4)      # Quantities: 2 + 1 + 1
        self.assertEqual(chair.best_region, "East")         # Alice from East bought most

        # Fetch Paper metrics (product_id=103)
        paper = enriched_products.filter("product_id = 103").first()
        self.assertIsNotNone(paper, "product_id=103 missing from enriched products")
        self.assertEqual(paper.total_orders, 3)             # Bob(2) + Alice(1)
        self.assertEqual(paper.total_quantity_sold, 30)     # Quantities: 5 (Alice) + 10 + 15 (Bob)
        self.assertEqual(paper.distinct_customers, 2)       # Alice & Bob both bought Paper

    # ================Task:3 test cases ===================================
    def test_enriched_orders_rowcount(self):
        """enrich_orders keeps one row per order."""
        orders, customers, products = self.sample_data_enrich()
        enriched = self.transformer.enrich_orders(orders, customers, products)
        self.assertEqual(enriched.count(), 7)  # 7 total orders

    def test_profit_rounding(self):
        """enrich_orders rounds profit to two decimal places."""
        orders, customers, products = self.sample_data_enrich()
        enriched = self.transformer.enrich_orders(orders, customers, products)
        profits = {row.order_id: row.profit for row in enriched.collect()}
        self.assertEqual(profits[1001], 123.46)
        self.assertEqual(profits[1004], 456.79)
        self.assertEqual(profits[1007], 1000.00)

    def test_groupby_year_total_profit(self):
        """Yearly profit totals computed from the enriched orders."""
        orders, customers, products = self.sample_data_enrich()
        enriched = self.transformer.enrich_orders(orders, customers, products)
        yearly = (enriched.groupBy("year")
                  .agg(F.round(F.sum("profit"), 2).alias("total_profit"))
                  .orderBy("year"))
        result = [(r["year"], r["total_profit"]) for r in yearly.collect()]
        expected = [
            (2022, 580.25),   # 123.46 (Alice) + 456.79 (Bob)
            (2023, 1089.67),  # 300.11 (Alice) + 789.56 (Charlie)
            (2024, 2101.28),  # 850.78 (Alice) + 250.50 (Bob) + 1000.00 (Diana)
        ]
        self.assertEqual(result, expected)

    def test_customer_order_counts(self):
        """Customer names are joined onto every order."""
        orders, customers, products = self.sample_data_enrich()
        enriched = self.transformer.enrich_orders(orders, customers, products)
        counts = enriched.groupBy("customer_name").count()
        result = {r["customer_name"]: r["count"] for r in counts.collect()}
        expected = {
            "Alice": 3,   # 3 orders across years
            "Bob": 2,
            "Charlie": 1,
            "Diana": 1,
        }
        self.assertEqual(result, expected)

    def test_distinct_categories(self):
        """Product categories are joined onto the enriched orders."""
        orders, customers, products = self.sample_data_enrich()
        enriched = self.transformer.enrich_orders(orders, customers, products)
        categories = {row.category for row in enriched.select("category").distinct().collect()}
        expected = {"Electronics", "Clothing", "Furniture"}
        self.assertEqual(categories, expected)

    # ================Task:4 test cases ===================================
    def test_aggregate_with_query_year(self):
        """Check aggregate_profit totals on the mock CSV for two customer/year pairs.

        Despite its name, this test exercises ``aggregate_profit`` rather than
        ``aggregate_with_query``.
        """
        agg_data = self.aggregator.aggregate_profit(self.enrich_df)
        cust1_rows = agg_data.filter((col("customer_id") == 1) & (col("year") == 2025)).collect()
        cust2_rows = agg_data.filter((col("customer_id") == 2) & (col("year") == 2024)).collect()

        # Check row counts before indexing so a missing group fails with a clear
        # assertion instead of an IndexError.
        self.assertEqual(len(cust1_rows), 1)
        self.assertTrue(cust2_rows, "no aggregate row for customer_id=2, year=2024")

        row = cust1_rows[0].asDict()
        row2 = cust2_rows[0].asDict()
        self.assertEqual(row["year"], 2025)
        self.assertEqual(row["category"], "furniture")
        self.assertEqual(row["Sub_Category"], "chair")
        self.assertEqual(row["customer_name"], "Abhi")
        self.assertEqual(row["profit_sum"], 60)
        self.assertEqual(row2["profit_sum"], 50)

    #=============== Task:5 test cases ==================================
    def test_year_query(self):
        """Total profit per year via aggregate_with_query over enriched_test."""
        query = """
            SELECT year, ROUND(SUM(profit),2) AS total_profit
            FROM enriched_test
            GROUP BY year
            ORDER BY year
        """
        result = self.aggregator.aggregate_with_query(query).collect()
        result_dict = {row["year"]: row["total_profit"] for row in result}

        # Dict equality also pins the number of groups.
        self.assertEqual(result_dict, {
            2020: 300.58,  # 100.125 + 200.456
            2021: 700.90,  # 300.789 + 400.111
        })

    def test_year_category_query(self):
        """Total profit per (year, category) via aggregate_with_query."""
        query = """
            SELECT year, category, ROUND(SUM(profit),2) AS total_profit
            FROM enriched_test
            GROUP BY year, category
            ORDER BY year
        """
        result = self.aggregator.aggregate_with_query(query).collect()
        result_dict = {(row["year"], row["category"]): row["total_profit"] for row in result}

        self.assertEqual(result_dict, {
            (2020, "Electronics"): 100.13,
            (2020, "Clothing"): 200.46,
            (2021, "Electronics"): 300.79,
            (2021, "Clothing"): 400.11,
        })

    def test_customer_query(self):
        """Total profit per customer via aggregate_with_query."""
        query = """
            SELECT customer_id, customer_name, ROUND(SUM(profit),2) AS total_profit
            FROM enriched_test
            GROUP BY customer_id, customer_name
        """
        result = self.aggregator.aggregate_with_query(query).collect()
        result_dict = {(row["customer_id"], row["customer_name"]): row["total_profit"] for row in result}

        self.assertEqual(result_dict, {
            (101, "Alice"): 400.91,  # 100.125 + 300.789
            (102, "Bob"): 600.57,    # 200.456 + 400.111
        })

    def test_year_customer_query(self):
        """Total profit per (year, customer) via aggregate_with_query."""
        query = """
            SELECT year, customer_id, customer_name, ROUND(SUM(profit),2) AS total_profit
            FROM enriched_test
            GROUP BY year, customer_id, customer_name
            ORDER BY year
        """
        result = self.aggregator.aggregate_with_query(query).collect()
        result_dict = {(row["year"], row["customer_id"], row["customer_name"]): row["total_profit"] for row in result}

        self.assertEqual(result_dict, {
            (2020, 101, "Alice"): 100.13,
            (2020, 102, "Bob"): 200.46,
            (2021, 101, "Alice"): 300.79,
            (2021, 102, "Bob"): 400.11,
        })

if __name__ == "__main__":
    # argv=[''] stops unittest from parsing sys.argv (which in a notebook is not
    # ours to interpret); exit=False stops it from calling sys.exit() at the end.
    program = unittest.main(argv=[''], exit=False)
    # With exit=False a failing run would only be printed and the cell/job
    # would still succeed, so surface failures explicitly.
    if not program.result.wasSuccessful():
        raise RuntimeError(
            f"{len(program.result.failures)} failure(s) and "
            f"{len(program.result.errors)} error(s) in unit tests"
        )


  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
.

test_year_category_query


  self._sock = None
  self._sock = None
.
----------------------------------------------------------------------
Ran 16 tests in 63.368s

OK
