In [1]:
import unittest
import pandas as pd
from common.features.filteration.common_filter import get_common_date_filter
from common.utils.utilities import call_conf
from pyspark.dbutils import DBUtils
from pyspark.sql import SparkSession

# Reuse the active Spark session if one exists, otherwise create a new one.
spark = SparkSession.builder.getOrCreate()
# dbutils handle built from the Spark session. pyspark.dbutils is presumably
# only available on a Databricks runtime — TODO confirm. Not referenced by any
# of the cells visible in this notebook.
dbutils = DBUtils(spark)
# Project configuration loaded from the shared `common` package; its schema is
# not visible here. preprocess_item_info reads conf["department_name"] from it.
# NOTE(review): the saved output of this cell shows
# `ModuleNotFoundError: No module named 'common'`, so in that run these lines
# never executed.
conf = call_conf()

ModuleNotFoundError: No module named 'common'

In [None]:
class TestPreprocessItem(unittest.TestCase):
    """Unit test for ``preprocess_item_info``.

    NOTE(review): ``preprocess_item_info`` is defined in a cell *below* this
    one, so that cell must be executed before this test is run.
    """

    # Departments the test expects to keep. Pinned here so the test does not
    # depend on whatever the project configuration (``call_conf()``) contains.
    TEST_CONF = {"department_name": ["Beverages", "Chilled", "Grocery"]}

    @staticmethod
    def _to_sorted_pandas(sdf):
        """Collect a Spark DataFrame to pandas, sorted by item code with a fresh index."""
        return (
            sdf.toPandas()
            .sort_values(["stg_item_cd"])
            .reset_index(drop=True)
        )

    def test_preprocess_item(self):
        """Rows whose department is in the configured list are kept; others are dropped."""
        df_input = spark.createDataFrame(
            pd.DataFrame(
                {
                    "stg_item_cd": ["1016782", "94111", "839302", "908458"],
                    "stg_item_sub_segment_desc_txt": ["Regular", "Cheddar", "Prepared Meals", "Shower Gel"],
                    "stg_item_segment_desc_txt": ["Instant Powdered Chocolate Drinks", "Cheddar Cheese", "Prepared Meals", "Body Cleansing"],
                    "stg_item_category_desc_txt": ["Instant Chocolate Drinks", "Chilled", "Cheese Blocks", "Hand & Body Care"],
                    "stg_item_dept_desc_txt": ["Beverages", "Chilled", "Grocery", "Household"],
                }
            )
        )

        # Only the "Household" row (908458) is dropped; the other three rows
        # pass through with all of their columns intact.
        df_expected = spark.createDataFrame(
            pd.DataFrame(
                {
                    "stg_item_cd": ["1016782", "94111", "839302"],
                    "stg_item_sub_segment_desc_txt": ["Regular", "Cheddar", "Prepared Meals"],
                    "stg_item_segment_desc_txt": ["Instant Powdered Chocolate Drinks", "Cheddar Cheese", "Prepared Meals"],
                    "stg_item_category_desc_txt": ["Instant Chocolate Drinks", "Chilled", "Cheese Blocks"],
                    "stg_item_dept_desc_txt": ["Beverages", "Chilled", "Grocery"],
                }
            )
        )

        df_output = preprocess_item_info(self.TEST_CONF, df_input)

        df_expected_pd = self._to_sorted_pandas(df_expected)
        df_output_pd = self._to_sorted_pandas(df_output)

        # Align column order so the comparison is order-insensitive.
        df_output_pd = df_output_pd[df_expected_pd.columns]
        pd.testing.assert_frame_equal(df_output_pd, df_expected_pd, check_dtype=False)

In [5]:
def preprocess_item_info(conf, item_info_df):
    """
    Keep only the items whose department is listed in the configuration.

    Rows are kept (not removed) when ``stg_item_dept_desc_txt`` matches one of
    the names in ``conf["department_name"]``. All other rows are dropped, and
    every column is passed through unchanged.

    Parameters:
    conf (dict): Configuration mapping. ``conf["department_name"]`` holds the
        department names to keep (the notebook's test uses
        ["Beverages", "Chilled", "Grocery"]).
    item_info_df (DataFrame): Spark DataFrame of item information containing a
        ``stg_item_dept_desc_txt`` column.

    Returns:
    DataFrame: ``item_info_df`` restricted to the configured departments.

    Raises:
    KeyError: if ``conf`` has no ``"department_name"`` entry.
    """
    department_names = conf["department_name"]
    # Reference the column through the DataFrame itself. The previous version
    # used ``f.col``, but ``f`` (pyspark.sql.functions) was never imported.
    dept_col = item_info_df["stg_item_dept_desc_txt"]
    filtered_item_info_df = item_info_df.filter(dept_col.isin(department_names))
    return filtered_item_info_df