In [2]:
import pyspark.sql.functions as F
from pyspark.sql import *
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType
import pandas as pd

from feature_engineering.engineering import engineerFeatures
from modelling.model_utils import train_model
from utils.utils import gapfilling, serialize


In [7]:
# Create the Spark session for this notebook, or reuse one that already exists.
spark = (
    SparkSession.builder
    .appName("test-app")
    .getOrCreate()
)

# A threshold of -1 turns automatic broadcast joins off for every join below.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

/usr/local/lib/python3.10/site-packages/pyspark/bin/load-spark-env.sh: line 68: ps: command not found
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/12/28 18:02:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [8]:
# TODO: move the configuration dictionaries to a JSON config file.

# Features to build, keyed by feature name; each value holds that feature's
# options. Passed to engineerFeatures and train_model further down.
features_config = dict(
    discount_rate={},
    promoted_percent=dict(promoted_hierarchy="sku", group_key="subclass"),
    week_of_year={},
    promo_category={},
)

# Model settings passed to train_model.
# NOTE(review): the inference window starts (2019-11-01) before the training
# window ends (2020-01-01), so the two ranges overlap -- confirm this is intended.
model_config = dict(
    model="xgboost",
    params={},
    hierarchy_columns=["sku", "subclass", "store_id", "region_id"],
    target="units",
    train_startDate="2018-01-01",
    train_endDate="2020-01-01",
    inference_startDate="2019-11-01",
    inference_endDate="2020-12-21",
)

# Directory prefix for the intermediate parquet files written by this notebook.
path = "data/"

In [6]:
# Union of the transactional data: every file matching transactions_*.csv is
# loaded into one DataFrame. The files are read without a header row, so the
# column names and types come from the explicit schema below.
_transaction_columns = [
    ("customer_id", StringType()),
    ("week_index", StringType()),
    ("sku", StringType()),
    ("promo_cat", StringType()),
    ("discount", FloatType()),
    ("store_id", StringType()),
]

# Every field is declared nullable.
schema = StructType(
    [StructField(column_name, column_type, True)
     for column_name, column_type in _transaction_columns]
)

transactions = spark.read.csv("data/transactions_*.csv", schema=schema, header=False)

                                                                                

In [8]:
def _read_csv_with_header(csv_path):
    """Read one CSV file whose first row is a header, with schema inference on."""
    return spark.read.csv(csv_path, header="true", inferSchema="true")


# Reading data: the four reference tables share the same reader options.
customers = _read_csv_with_header("data/customers.csv")
calendar = _read_csv_with_header("data/calendar.csv")
products = _read_csv_with_header("data/products.csv")
stores = _read_csv_with_header("data/stores.csv")

In [9]:
# Adding the location hierarchy: customer -> preferred store, store -> region.
# All keys are cast to string; `transactions` declares store_id as a string too.
# NOTE: this cell rebinds `customers` and `stores` to frames that no longer
# contain `store_pref` / `store_region`, so it cannot be re-run without first
# re-running the cell that reads the CSV files.
customer_id = F.col("customer_id").cast("string")
preferred_store = F.col("store_pref").cast("string").alias("store_id")
customers = customers.select(customer_id, preferred_store)

store_id = F.col("store_id").cast("string")
store_region = F.col("store_region").cast("string").alias("region_id")
stores = stores.select(store_id, store_region).dropDuplicates()

In [10]:
# Product hierarchy: rename the raw prod_* columns and fix their types.
# Each entry is (source column, target type, output name).
_product_columns = [
    ("prod_id", "string", "sku"),
    ("prod_subclass", "string", "subclass"),
    ("prod_class", "string", "class"),
    ("prod_dept", "string", "dept"),
    ("prod_base_price", "float", "base_price"),
]

products = products.select(
    *[
        F.col(source_name).cast(target_type).alias(output_name)
        for source_name, target_type, output_name in _product_columns
    ]
).dropDuplicates()

products.show(5)

+---+--------+-----+----+----------+
|sku|subclass|class|dept|base_price|
+---+--------+-----+----+----------+
|271|       0|    0|   7|    142.86|
|289|       0|    0|   3|      20.0|
|369|       0|    0|   2|     16.67|
|410|       1|    1|  14|    525.33|
|559|       0|    0|   4|     48.57|
+---+--------+-----+----+----------+
only showing top 5 rows



In [12]:
# Daily calendar -> weekly calendar: keep one date per week (the rows whose
# day_of_week equals "0") and number the weeks 0, 1, 2, ... in date order.
week_start_dates = calendar.where(
    F.col("day_of_week") == "0"
).select(
    F.to_date(F.col("calendar_day"), "MM-dd-yy").alias("date")
).distinct()

# week_index is later used as the join key against transactions.week_index,
# so it has to be a gap-free sequence. monotonically_increasing_id() only
# guarantees increasing, unique ids (they jump between partitions), whereas
# row_number() over a date-ordered window is consecutive by construction.
# The window has no partitioning, so Spark gathers the rows into a single
# partition to number them. `Window` is in scope through
# `from pyspark.sql import *`.
chronological = Window.orderBy(F.col("date").asc())

weekly_calendar = week_start_dates.select(
    # row_number() starts at 1; subtract 1 to keep the index zero-based.
    (F.row_number().over(chronological) - 1).cast("string").alias("week_index"),
    F.col("date"),
).sort(
    F.col("date").asc()
)
weekly_calendar.show(5)

+----------+----------+
|week_index|      date|
+----------+----------+
|         0|2018-01-01|
|         1|2018-01-08|
|         2|2018-01-15|
|         3|2018-01-22|
|         4|2018-01-29|
+----------+----------+
only showing top 5 rows



In [14]:
# Aggregate the transactions to one row per (sku, store_id, week_index), then
# swap the week index for its calendar date.
#   units     - number of transaction rows in the group
#   promo_cat - F.first: whichever row of the group comes first; the result
#               depends on row order, which is not fixed after a shuffle
#   discount  - largest discount seen in the group
weekly_aggregates = [
    F.count("*").alias("units"),
    F.first("promo_cat").alias("promo_cat"),
    F.max("discount").alias("discount"),
]

demand_data = (
    transactions
    .groupby("sku", "store_id", "week_index")
    .agg(*weekly_aggregates)
    # Inner join: weeks with no match in weekly_calendar are dropped.
    .join(weekly_calendar, on=["week_index"], how="inner")
    .drop("week_index")
)

# serialize is a project helper; its result is displayed here.
# NOTE(review): presumably it writes the frame to the given parquet path and
# returns it re-read from disk -- confirm in utils.utils.
serialize(spark, demand_data, path + "demand_data.parquet").show(5)


[Stage 37:>                                                         (0 + 8) / 8]

22/12/28 17:44:29 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers
22/12/28 17:44:29 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 84.47% for 8 writers
22/12/28 17:44:31 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers


                                                                                

+---+--------+-----+---------+--------+----------+
|sku|store_id|units|promo_cat|discount|      date|
+---+--------+-----+---------+--------+----------+
|  0|      87|    1|      nan|    null|2020-02-03|
|100|      83|    1|      nan|    null|2020-02-03|
|101|      24|    1|      nan|    null|2020-02-03|
|101|      55|    1|      nan|    null|2020-02-03|
|102|      40|    1|      nan|    null|2020-02-03|
+---+--------+-----+---------+--------+----------+
only showing top 5 rows



In [16]:
# Reload the weekly demand written by the previous step. Parquet files store
# their own schema, so the CSV-only reader options `header` / `inferSchema`
# do not apply and are not passed.
demand_data = spark.read.parquet(path + "demand_data.parquet")

# gapfilling is a project helper, called with the date / product / location
# column names.
# NOTE(review): judging by the displayed result it adds zero-unit rows for
# missing (date, sku, store_id) combinations -- confirm in utils.utils.
sales_filled_data = gapfilling(
    demand_data,
    date_column="date",
    product_column="sku",
    location_column="store_id",
)

serialize(spark, sales_filled_data, path + "sales_filled_data.parquet").show(5)


[Stage 52:>                                                         (0 + 8) / 9]

22/12/28 17:44:59 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers
22/12/28 17:44:59 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 84.47% for 8 writers
22/12/28 17:45:02 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers


                                                                                

+----------+---+--------+-----+---------+--------+
|      date|sku|store_id|units|promo_cat|discount|
+----------+---+--------+-----+---------+--------+
|2018-01-01|  0|      10|    0|     null|    null|
|2018-01-01|  0|      20|    0|     null|    null|
|2018-01-01|  0|      29|    0|     null|    null|
|2018-01-01|  0|      34|    0|     null|    null|
|2018-01-01|  0|      41|    0|     null|    null|
+----------+---+--------+-----+---------+--------+
only showing top 5 rows



In [18]:
# Adding product and location hierarchies to the demand data.
# The gap-filled demand is reloaded from parquet. Parquet carries its own
# schema, so the CSV-only options `header` / `inferSchema` are not passed.
# NOTE: `demand_data` is rebound here to the gap-filled frame.
demand_data = spark.read.parquet(path + "sales_filled_data.parquet")

# Inner joins: rows whose store_id or sku has no match in `stores` /
# `products` are dropped from sales_data.
sales_data = demand_data.join(
    stores, on="store_id", how="inner"
).join(
    products, on="sku", how="inner"
)

serialize(spark, sales_data, path + "sales_data.parquet").show(5)

[Stage 74:>                                                         (0 + 8) / 9]

22/12/28 17:45:25 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers
22/12/28 17:45:25 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 84.47% for 8 writers
22/12/28 17:45:31 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers




22/12/28 17:45:31 WARN MemoryManager: Total allocation exceeds 95.00% (906,992,014 bytes) of heap memory
Scaling row group sizes to 96.54% for 7 writers


                                                                                

+---+--------+----------+-----+---------+--------+---------+--------+-----+----+----------+
|sku|store_id|      date|units|promo_cat|discount|region_id|subclass|class|dept|base_price|
+---+--------+----------+-----+---------+--------+---------+--------+-----+----+----------+
|102|      11|2018-01-01|    0|     null|    null|        1|       0|    0|   9|     164.0|
|102|       1|2018-01-22|    0|     null|    null|        2|       0|    0|   9|     164.0|
|102|      11|2018-01-15|    0|     null|    null|        1|       0|    0|   9|     164.0|
|102|       1|2018-02-26|    0|     null|    null|        2|       0|    0|   9|     164.0|
|102|      11|2018-03-12|    0|     null|    null|        1|       0|    0|   9|     164.0|
+---+--------+----------+-----+---------+--------+---------+--------+-----+----+----------+
only showing top 5 rows



## Feature Engineering

In [12]:
# Reload the joined sales data. Parquet carries its own schema, so the
# CSV-only options `header` / `inferSchema` are not passed.
sales_data = spark.read.parquet(path + "sales_data.parquet")

# Build the features listed in features_config with the project helper.
engineered_data = engineerFeatures(
    data=sales_data,
    config=features_config
)

serialize(spark, engineered_data, path + "engineered_data.parquet").show(5)

                                                                                

+--------+---+--------+----------+-----+---------+--------+---------+-----+----+----------+--------------+----------------------+-------------+---------------+
|subclass|sku|store_id|      date|units|promo_cat|discount|region_id|class|dept|base_price|_discount_rate|_promoted_sku_subclass|_week_of_year|_promo_category|
+--------+---+--------+----------+-----+---------+--------+---------+-----+----+----------+--------------+----------------------+-------------+---------------+
|       0|102|      11|2018-01-01|    0|     null|    null|        1|    0|   9|     164.0|           1.0|                   296|            1|           null|
|       0|358|      41|2018-03-26|    0|     null|    null|        1|    0|   2|      15.0|           1.0|                   296|           13|           null|
|       0|102|       1|2018-01-22|    0|     null|    null|        2|    0|   9|     164.0|           1.0|                   296|            4|           null|
|       0|358|      37|2018-06-18|    0|

In [9]:
# Reload the engineered features. Parquet carries its own schema, so the
# CSV-only options `header` / `inferSchema` are not passed.
engineered_data = spark.read.parquet(path + "engineered_data.parquet")

# Keep only the SKUs whose numeric id is below 400. `sku` is stored as a
# string, so it is cast explicitly instead of relying on Spark's implicit
# string-to-number coercion in the comparison.
# NOTE(review): the 400 cut-off looks like a way to shrink the training set --
# confirm, and consider moving it to the config cell.
engineered_data = engineered_data.where(F.col("sku").cast("int") < 400)

train_model(engineered_data, model_config=model_config, features_config=features_config)

                                                                                

[2022-12-28 18:02:50.181724] Saving the train SPARK dataframe


                                                                                

[2022-12-28 18:02:57.631199] Reading the train dataframe using Pandas
[2022-12-28 18:02:58.346298] One-hot encoding the train dataframe
[2022-12-28 18:03:28.064621] Transforming the train one-hot encoded data into a CSR matrix
[2022-12-28 18:04:41.285533] Saving the test SPARK dataframe


                                                                                

[2022-12-28 18:04:47.225439] Reading the test dataframe using Pandas
[2022-12-28 18:04:47.658583] One-hot encoding the test dataframe
[2022-12-28 18:04:57.403324] Transforming the test one-hot encoded data into a CSR matrix
[2022-12-28 18:05:08.000476] mse=0.11511720835514748


Unnamed: 0,sku,subclass,store_id,region_id,_week_of_year,_promo_category,_discount_rate,_promoted_sku_subclass,units,forecast
0,259,0,85,4,45,,1.0,296,4,1.331852
1,259,0,39,1,39,,1.0,296,0,-0.001409
2,102,0,1,2,47,,1.0,296,0,-0.001042
3,102,0,1,2,51,,1.0,296,1,1.203483
4,102,0,1,2,3,,1.0,296,0,-0.001042
...,...,...,...,...,...,...,...,...,...,...
2399995,7,1,28,1,41,,1.0,290,0,-0.001409
2399996,7,1,57,2,40,,1.0,290,2,1.331852
2399997,7,1,28,1,42,,1.0,290,0,-0.001409
2399998,7,1,57,2,50,,1.0,290,1,1.328338
