In [None]:
class TestSalesFeatures(unittest.TestCase):
    """Unit tests for ``calculate_sales_features``."""

    def test_sales_features(self):
        """Rolling 2/3/4-row averages are computed per outlet/category, ordered by week."""
        # Outlet "A" rows are deliberately supplied out of chronological order
        # so the test verifies that the window is ordered by ``week``; outlet
        # "B" verifies that partitions do not bleed into each other.
        df_input = spark.createDataFrame(
            pd.DataFrame(
                {
                    "week": [
                        "2022-02-11",
                        "2022-01-28",
                        "2022-02-18",
                        "2022-02-04",
                        "2022-02-04",
                        "2022-01-28",
                    ],
                    "stg_item_category_desc_txt": ["Ambient Instant N"] * 6,
                    "stg_outlet_cd": ["A", "A", "A", "A", "B", "B"],
                    "weekly_sales_qty": [765.0, 588.0, 914.0, 885.0, 200.0, 100.0],
                }
            )
        )

        # Expected rows listed chronologically per partition.
        # Outlet A sales by week: 588, 885, 765, 914.
        #   2-row avg: 588, (588+885)/2, (885+765)/2, (765+914)/2
        #   3-row avg: 588, (588+885)/2, (588+885+765)/3, (885+765+914)/3
        #   4-row avg: 588, (588+885)/2, (588+885+765)/3, (588+885+765+914)/4
        # Outlet B sales by week: 100, 200 -> every window averages what exists.
        df_expected_pd = pd.DataFrame(
            {
                "week": [
                    "2022-01-28",
                    "2022-02-04",
                    "2022-02-11",
                    "2022-02-18",
                    "2022-01-28",
                    "2022-02-04",
                ],
                "stg_item_category_desc_txt": ["Ambient Instant N"] * 6,
                "stg_outlet_cd": ["A", "A", "A", "A", "B", "B"],
                "weekly_sales_qty": [588.0, 885.0, 765.0, 914.0, 100.0, 200.0],
                "avg_last_2_weeks": [588.0, 736.5, 825.0, 839.5, 100.0, 150.0],
                "avg_last_3_weeks": [588.0, 736.5, 746.0, 2564.0 / 3, 100.0, 150.0],
                "avg_last_4_weeks": [588.0, 736.5, 746.0, 788.0, 100.0, 150.0],
            }
        )

        df_output = calculate_sales_features(df_input)

        # Spark does not guarantee row order, so sort both sides on the full
        # row key before comparing.
        sort_keys = ["stg_outlet_cd", "stg_item_category_desc_txt", "week"]

        df_expected_pd = df_expected_pd.sort_values(sort_keys).reset_index(drop=True)

        df_output_pd = (
            df_output.toPandas()
            .sort_values(sort_keys)
            .reset_index(drop=True)
        )

        # ensuring the columns have the same order
        df_output_pd = df_output_pd[df_expected_pd.columns]
        pd.testing.assert_frame_equal(df_output_pd, df_expected_pd, check_dtype=False)

In [None]:
def calculate_sales_features(primary_key_with_weekly_sales_df, window_sizes=(2, 3, 4)):
    """
    Add trailing rolling-average sales columns to a weekly sales DataFrame.

    Rows are partitioned by ``(stg_outlet_cd, stg_item_category_desc_txt)`` and
    ordered by ``week``. For every ``n`` in ``window_sizes`` a column named
    ``avg_last_{n}_weeks`` is added, holding the mean of ``weekly_sales_qty``
    over the current row and the ``n - 1`` preceding rows of its partition.
    The first rows of a partition average over the fewer than ``n`` rows
    that exist.

    Note: the window counts rows, not calendar weeks. If a partition is
    missing some weeks, "the last n weeks" will span more than n calendar
    weeks. Ordering assumes ``week`` sorts chronologically (e.g. ISO date
    strings or a date type) -- confirm against the upstream schema.

    Args:
        primary_key_with_weekly_sales_df (DataFrame): Spark DataFrame
            containing columns 'stg_outlet_cd', 'stg_item_category_desc_txt',
            'week', and 'weekly_sales_qty'.
        window_sizes (iterable of int): Window lengths, in rows, to compute.
            Defaults to ``(2, 3, 4)``, which produces ``avg_last_2_weeks``,
            ``avg_last_3_weeks`` and ``avg_last_4_weeks``.

    Returns:
        DataFrame: The input DataFrame with one additional average column per
        window size, appended in ``window_sizes`` order.

    Raises:
        ValueError: If any window size is smaller than 1.
    """
    # Materialize once so generators/iterators can be validated and reused.
    window_sizes = tuple(window_sizes)
    invalid = [n for n in window_sizes if n < 1]
    if invalid:
        raise ValueError(f"window sizes must be >= 1, got {invalid}")

    window_spec = (
        Window.partitionBy("stg_outlet_cd", "stg_item_category_desc_txt")
        .orderBy("week")
    )

    df = primary_key_with_weekly_sales_df
    for n in window_sizes:
        # rowsBetween(-(n - 1), 0): the current row plus the n - 1 rows before it.
        df = df.withColumn(
            f"avg_last_{n}_weeks",
            f.avg(f.col("weekly_sales_qty")).over(
                window_spec.rowsBetween(-(n - 1), 0)
            ),
        )

    return df