In [1]:
import unittest

import pandas as pd
from pyspark.dbutils import DBUtils
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

from common.features.filteration.common_filter import get_common_date_filter
from common.utils.utilities import call_conf

# Notebook-wide handles shared by the cells below.
# Reuse the active Spark session if one exists, otherwise start one.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)
# Databricks utilities bound to that session (from pyspark.dbutils).
dbutils = DBUtils(spark)
# Project configuration object; its contents are defined by
# common.utils.utilities.call_conf (not visible here).
conf = call_conf()

ModuleNotFoundError: No module named 'common'

In [None]:
class TestSeasonalColumn(unittest.Testcase):
    def test_seasonal_column(self):
        df_input = spark.createDataFrame(
            pd.DataFrame(
                {
                    "week": ["2022-02-25", "2022-02-25", "2022-02-25"],
                    "stg_item_category_desc_txt": ["Cheese Slices", "Cheese Slices", "Cheese Slices"],
                    "stg_outlet_cd": ["B", "A", "C"],
                    "weekly_sales_qty":["11.0","14.0",".0"]
                    
                }
            )
        )

        df_expected = spark.createDataFrame(
            pd.DataFrame(
                {
                   "week": ["2022-02-25", "2022-02-25", "2022-04-25"],
                    "stg_item_category_desc_txt": ["Cheese Slices", "Cheese Slices", "Cheese Slices"],
                    "stg_outlet_cd": ["B", "A", "C"],
                    "weekly_sales_qty":["11.0","14.0","0"],
                    "month":["2","2","4"],
                    "year":["2022","2022","2022"],
                    "week_of_year":["8","8","11"],
                    "is_seasonal":["0","0","1"]
                }
            )
        )

        df_output = create_seasonal_column(
            final_df
        )

        df_expected_pd = (
            df_expected.toPandas()
            .sort_values(["stg_outlet_cd"])
            .reset_index(drop=True)
        )

        df_output_pd = (
            df_output.toPandas()
            .sort_values(["stg_outlet_cd"])
            .reset_index(drop=True)
        )

        # ensuring the columns have the same order
        df_output_pd = df_output_pd[df_expected_pd.columns]
        pd.testing.assert_frame_equal(df_output_pd, df_expected_pd, check_dtype=False)

In [None]:
def create_seasonal_column(final_df, seasonal_months=(4,)):
    """
    Add a binary "is_seasonal" column based on the calendar month.

    The frame is first passed through ``extract_month_year_week`` (defined
    elsewhere in the notebook), which is expected to add a "month" column.
    Rows whose month is in ``seasonal_months`` get 1. All other rows get 0,
    including rows whose month is null, because ``when`` falls through to
    ``otherwise``.

    Args:
        final_df: Input Spark DataFrame containing a "week" column.
        seasonal_months: Iterable of month numbers (1-12) treated as seasonal.
            Defaults to ``(4,)`` (April only), matching the original behaviour.

    Returns:
        seasonal_col_in_final_df: The DataFrame returned by
        ``extract_month_year_week`` with an additional integer "is_seasonal"
        column.
    """
    transaction_with_month_week = extract_month_year_week(final_df)

    # Materialise the months once so generators are not consumed lazily.
    months = list(seasonal_months)

    # 1 for seasonal months, 0 otherwise (integer literals -> IntegerType).
    seasonal_col_in_final_df = transaction_with_month_week.withColumn(
        "is_seasonal",
        f.when(f.col("month").isin(months), 1).otherwise(0),
    )

    return seasonal_col_in_final_df