In [0]:
%run ../jobs/transformer

In [0]:
# tests/test_transformer
import unittest
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, IntegerType, StringType,StructType
import pytest
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType
from pyspark.sql import functions as F


class TestTransformer(unittest.TestCase):
    """Tests for ``Transformer.transform_orders`` and ``Transformer.enrich_orders``.

    ``Transformer`` is not imported in this file; it must already be in scope
    (presumably via the ``%run ../jobs/transformer`` cell above - confirm when
    running outside a notebook).
    """

    # Typed schema for the orders input of ``enrich_orders``.
    orders_schema = StructType([
        StructField("order_id", IntegerType(), True),
        StructField("order_date", StringType(), True),
        StructField("ship_date", StringType(), True),
        StructField("customer_id", IntegerType(), True),
        StructField("product_id", IntegerType(), True),
        StructField("profit", DoubleType(), True),
    ])

    customers_schema = StructType([
        StructField("customer_id", IntegerType(), True),
        StructField("customer_name", StringType(), True),
        StructField("country", StringType(), True),
    ])

    products_schema = StructType([
        StructField("product_id", IntegerType(), True),
        StructField("category", StringType(), True),
        StructField("sub_category", StringType(), True),
    ])

    # All-string schema for the raw input of ``transform_orders``; needed
    # whenever Spark cannot infer column types (no rows, or a None value).
    raw_orders_schema = StructType([
        StructField("order_date", StringType(), True),
        StructField("ship_date", StringType(), True),
        StructField("profit", StringType(), True),
    ])

    # Column names used when the raw orders frame is built from plain tuples.
    raw_orders_columns = ["order_date", "ship_date", "profit"]

    @classmethod
    def setUpClass(cls):
        """Create the Spark session and the Transformer shared by all tests."""
        cls.spark = SparkSession.builder.appName("TransformerTest").getOrCreate()
        # Pass the session created above instead of a module-level ``spark``
        # global, which only exists in some notebook environments.
        cls.transformer = Transformer(cls.spark)
        # NOTE(review): the session is deliberately not stopped in a
        # tearDownClass - getOrCreate() may hand back a session shared with
        # the host notebook; confirm before adding cleanup.

    # ------------------------------------------------------------------
    # transform_orders
    # ------------------------------------------------------------------
    def test_transform_orders_date_parse(self):
        """The date columns are still present after transform_orders."""
        # The leading "1" is named explicitly so that the date strings line up
        # with order_date / ship_date instead of being shifted by one column.
        data = [("1", "1/1/2020", "2/1/2020", 100.5)]
        df = self.spark.createDataFrame(
            data, ["order_id", "order_date", "ship_date", "profit"]
        )
        result = self.transformer.transform_orders(df)
        self.assertIn("order_date", result.columns)
        self.assertIn("ship_date", result.columns)

    def test_valid_transformation(self):
        """Day-first date strings and a numeric profit string are converted."""
        data = [("1/2/2023", "5/2/2023", "100.50")]
        df = self.spark.createDataFrame(data, self.raw_orders_columns)
        result = self.transformer.transform_orders(df).collect()[0]

        self.assertEqual(str(result["order_date"]), "2023-02-01")
        self.assertEqual(str(result["ship_date"]), "2023-02-05")
        self.assertAlmostEqual(result["profit"], 100.50, delta=0.001)

    def test_multiple_rows(self):
        """Every input row is transformed and row order is preserved."""
        data = [
            ("1/1/2023", "2/1/2023", "50"),
            ("10/12/2022", "15/12/2022", "75.25"),
        ]
        df = self.spark.createDataFrame(data, self.raw_orders_columns)
        result = self.transformer.transform_orders(df).collect()

        self.assertEqual(len(result), 2)
        self.assertEqual(str(result[0]["order_date"]), "2023-01-01")
        self.assertEqual(result[1]["profit"], 75.25)

    # ---- negative cases ----
    def test_missing_column(self):
        """A frame without ship_date makes transform_orders fail."""
        data = [("1/1/2023", "100")]
        df = self.spark.createDataFrame(data, ["order_date", "profit"])

        with self.assertRaises(Exception):
            self.transformer.transform_orders(df).collect()

    def test_invalid_date_format(self):
        """Dates that are not in the expected format come back as None."""
        data = [("2023-01-01", "2023-01-02", "200")]
        df = self.spark.createDataFrame(data, self.raw_orders_columns)
        result = self.transformer.transform_orders(df).collect()[0]

        self.assertIsNone(result["order_date"])
        self.assertIsNone(result["ship_date"])

    def test_non_numeric_profit(self):
        """A profit string that is not a number comes back as None."""
        data = [("1/1/2023", "2/1/2023", "abc")]
        df = self.spark.createDataFrame(data, self.raw_orders_columns)
        result = self.transformer.transform_orders(df).collect()[0]

        self.assertIsNone(result["profit"])

    # ---- edge cases ----
    def test_empty_dataframe(self):
        """An empty input frame yields an empty result."""
        df = self.spark.createDataFrame([], schema=self.raw_orders_schema)
        result = self.transformer.transform_orders(df)
        self.assertEqual(result.count(), 0)

    def test_null_values(self):
        """A null order_date stays null while the other columns are converted."""
        data = [(None, "2/1/2023", "100")]
        df = self.spark.createDataFrame(data, schema=self.raw_orders_schema)
        result = self.transformer.transform_orders(df).collect()[0]

        self.assertIsNone(result["order_date"])
        self.assertEqual(str(result["ship_date"]), "2023-01-02")
        self.assertEqual(result["profit"], 100.0)

    # ------------------------------------------------------------------
    # enrich_orders
    # ------------------------------------------------------------------
    def test_enrich_orders_left_join(self):
        """A matching customer's name is attached to the order."""
        orders = self.spark.createDataFrame(
            [("O1", "C1", "P1", 50.0, "1/1/2020", "1/1/2020")],
            ["order_id", "customer_id", "product_id", "profit", "order_date", "ship_date"],
        )
        customers = self.spark.createDataFrame(
            [("C1", "Alice", "India")],
            ["customer_id", "customer_name", "country"],
        )
        products = self.spark.createDataFrame(
            [("P1", "Furniture", "Chair")],
            ["product_id", "category", "sub_category"],
        )

        result = self.transformer.enrich_orders(orders, customers, products)
        self.assertEqual(result.count(), 1)
        self.assertEqual(result.first()["customer_name"], "Alice")

    # ---- standard case ----
    def test_standard_enrichment(self):
        """Customer and product attributes, rounded profit and year are set."""
        orders = self.spark.createDataFrame(
            [(1, "2023-01-01", "2023-01-05", 100, 200, 123.456)],
            schema=self.orders_schema,
        )
        customers = self.spark.createDataFrame(
            [(100, "Alice", "USA")],
            schema=self.customers_schema,
        )
        products = self.spark.createDataFrame(
            [(200, "Electronics", "Mobile")],
            schema=self.products_schema,
        )

        result = self.transformer.enrich_orders(orders, customers, products).collect()[0]

        self.assertEqual(result["order_id"], 1)
        self.assertEqual(result["customer_name"], "Alice")
        self.assertEqual(result["country"], "USA")
        self.assertEqual(result["category"], "Electronics")
        self.assertEqual(result["sub_category"], "Mobile")
        self.assertEqual(result["profit"], 123.46)  # 123.456 rounded to 2 places
        self.assertEqual(result["year"], 2023)

    # ---- multiple orders ----
    def test_multiple_orders(self):
        """Each order is matched with its own customer and product."""
        orders = self.spark.createDataFrame(
            [
                (1, "2023-01-01", "2023-01-05", 100, 200, 10.123),
                (2, "2022-02-01", "2022-02-03", 101, 201, 20.456),
            ],
            schema=self.orders_schema,
        )
        customers = self.spark.createDataFrame(
            [(100, "Alice", "USA"), (101, "Bob", "Canada")],
            schema=self.customers_schema,
        )
        products = self.spark.createDataFrame(
            [(200, "Electronics", "Mobile"), (201, "Furniture", "Chair")],
            schema=self.products_schema,
        )

        result = self.transformer.enrich_orders(orders, customers, products).collect()

        self.assertEqual(len(result), 2)
        self.assertEqual({row["customer_name"] for row in result}, {"Alice", "Bob"})
        self.assertEqual(
            {row["category"] for row in result}, {"Electronics", "Furniture"}
        )

    # ---- negative cases ----
    def test_missing_customer(self):
        """An order without a matching customer keeps null customer fields."""
        orders = self.spark.createDataFrame(
            [(1, "2023-01-01", "2023-01-05", 999, 200, 50.0)],
            schema=self.orders_schema,
        )
        customers = self.spark.createDataFrame([], schema=self.customers_schema)
        products = self.spark.createDataFrame(
            [(200, "Electronics", "Laptop")], schema=self.products_schema
        )

        result = self.transformer.enrich_orders(orders, customers, products).collect()[0]

        self.assertIsNone(result["customer_name"])
        self.assertIsNone(result["country"])
        self.assertEqual(result["category"], "Electronics")

    def test_missing_product(self):
        """An order without a matching product keeps null product fields."""
        orders = self.spark.createDataFrame(
            [(1, "2023-01-01", "2023-01-05", 100, 999, 75.0)],
            schema=self.orders_schema,
        )
        customers = self.spark.createDataFrame(
            [(100, "Charlie", "UK")], schema=self.customers_schema
        )
        products = self.spark.createDataFrame([], schema=self.products_schema)

        result = self.transformer.enrich_orders(orders, customers, products).collect()[0]

        self.assertEqual(result["customer_name"], "Charlie")
        self.assertIsNone(result["category"])
        self.assertIsNone(result["sub_category"])

    # ---- empty orders ----
    def test_empty_orders(self):
        """No orders in means no rows out, whatever the lookup tables hold."""
        orders = self.spark.createDataFrame([], schema=self.orders_schema)
        customers = self.spark.createDataFrame(
            [(100, "Alice", "USA")], schema=self.customers_schema
        )
        products = self.spark.createDataFrame(
            [(200, "Electronics", "TV")], schema=self.products_schema
        )

        result = self.transformer.enrich_orders(orders, customers, products)

        self.assertEqual(result.count(), 0)

    # ---- null profit ----
    def test_null_profit(self):
        """A null profit stays null after enrichment."""
        orders = self.spark.createDataFrame(
            [(1, "2023-01-01", "2023-01-05", 100, 200, None)],
            schema=self.orders_schema,
        )
        customers = self.spark.createDataFrame(
            [(100, "Alice", "USA")], schema=self.customers_schema
        )
        products = self.spark.createDataFrame(
            [(200, "Electronics", "Mobile")], schema=self.products_schema
        )

        result = self.transformer.enrich_orders(orders, customers, products).collect()[0]

        self.assertIsNone(result["profit"])

    # ---- null dates ----
    def test_null_dates(self):
        """Null order/ship dates produce a null year."""
        orders = self.spark.createDataFrame(
            [(1, None, None, 100, 200, 50.0)],
            schema=self.orders_schema,
        )
        customers = self.spark.createDataFrame(
            [(100, "Alice", "USA")], schema=self.customers_schema
        )
        products = self.spark.createDataFrame(
            [(200, "Electronics", "Mobile")], schema=self.products_schema
        )

        result = self.transformer.enrich_orders(orders, customers, products).collect()[0]

        self.assertIsNone(result["year"])

if __name__ == "__main__":
    # An explicit argv stops unittest from parsing the host process's own
    # command line, and exit=False stops it from calling sys.exit() once the
    # run finishes.
    unittest.main(
        argv=["first-arg-is-ignored"],
        exit=False,
    )


  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
  self._sock = None
.
----------------------------------------------------------------------
Ran 16 tests in 16.749s

OK
