In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import warnings
import itertools
warnings.filterwarnings("ignore")

sys.path.append('../src')

import pandas as pd
import seaborn as sns
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from sales_project.utils import reduce_size, compare_disributions, get_bins
from sales_project.plotters import linear_plot

# Let pandas print up to 100 rows before truncating a displayed frame.
pd.set_option('display.max_rows', 100)

# Apply a single seaborn theme to every figure produced in this notebook.
sns.set_theme(context='talk', style="darkgrid", palette='dark', font='sans-serif')
# Register tqdm's progress_* methods (e.g. progress_apply) on pandas objects.
tqdm.pandas()

In [3]:
# Load the EDA artifact and order rows by (store, family, date).
# Every per-group window/shift further down is computed in row order, so this
# sort is what makes them chronological. The original index labels are kept.
df = (
    pd.read_csv('../data/artifacts/data_after_eda.csv', parse_dates=['date'])
    .sort_values(by=['store_nbr', 'family', 'date'])
)

In [None]:
# Cast the key and calendar columns to integers in a single astype call.
int_columns = ['store_nbr', 'cluster', 'year', 'month', 'weekday']
df = df.astype({name: 'int' for name in int_columns})

In [7]:
# Shrink the frame's memory footprint with the project helper.
# NOTE(review): the return value is discarded, so this presumably mutates `df`
# in place (the narrow dtypes reported by df.info() below are consistent with
# that) - confirm against sales_project.utils.reduce_size.
reduce_size(df)

  0%|          | 0/20 [00:00<?, ?it/s]

In [8]:
# Inspect column dtypes after the size reduction.
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 3036528 entries, 0 to 3036527
Data columns (total 20 columns):
 #   Column                    Dtype         
---  ------                    -----         
 0   id                        uint32        
 1   date                      datetime64[ns]
 2   store_nbr                 uint8         
 3   family                    object        
 4   onpromotion               float32       
 5   is_submission             bool          
 6   city                      object        
 7   state                     object        
 8   type                      object        
 9   cluster                   uint8         
 10  dcoilwtico                float32       
 11  transactions              float32       
 12  median_sales_over_family  float32       
 13  relative_sales            float32       
 14  subset                    object        
 15  scaled_dcoilwtico         float32       
 16  is_promoted               bool          
 17  year         

# 1. Rolling and expanding features

In [9]:
def add_rolling_and_expanding_features(df: pd.DataFrame, feat_to_roll: str, kind: str, window: int) -> pd.DataFrame:
    """Add a per-(store_nbr, family) rolling or expanding mean of one column.

    The statistic is computed in the frame's current row order inside each
    group, so ``df`` must already be sorted chronologically within every
    (store_nbr, family) pair. Each window ends at, and includes, the current
    row.

    Args:
        df: Frame containing ``store_nbr``, ``family`` and ``feat_to_roll``.
            It is modified in place (a new column is added) and also returned.
            Its index labels must be unique, because the result is aligned
            back to the rows by label.
        feat_to_roll: Name of the column to average.
        kind: ``'rolling'`` for a fixed-size window, ``'expanding'`` for a
            cumulative mean from the start of each group.
        window: Window length for ``kind='rolling'``; ignored (may be
            ``None``) for ``kind='expanding'``.

    Returns:
        The same ``df`` object with one extra column named
        ``"<feat>.rolling.mean.window.<window>"`` or ``"<feat>.expanding.mean"``.

    Raises:
        ValueError: If ``kind`` is neither ``'rolling'`` nor ``'expanding'``.
    """
    if kind not in ('rolling', 'expanding'):
        raise ValueError(f"kind must be 'rolling' or 'expanding', got {kind!r}")

    grouped = df.groupby(["store_nbr", "family"])[feat_to_roll]

    if kind == 'rolling':
        new_feat = f"{feat_to_roll}.{kind}.mean.window.{window}"
        windowed = grouped.rolling(window=window, min_periods=1)
    else:
        new_feat = f"{feat_to_roll}.{kind}.mean"
        windowed = grouped.expanding(min_periods=1)

    # The windowed result is indexed by (store_nbr, family, <original label>).
    # Drop only the two group-key levels so the assignment aligns on the
    # original row labels. Renumbering the result 0..N-1 would put values on
    # the wrong rows whenever df's index is not already 0..N-1 in group order.
    df[new_feat] = windowed.mean().droplevel([0, 1])
    return df

In [10]:
# (kind, window) pairs to compute; the expanding mean takes no window.
combinations = [
    ('rolling', 7),
    ('expanding', None),
]

for kind, window in tqdm(combinations):
    df = add_rolling_and_expanding_features(
        df=df,
        feat_to_roll='relative_sales',
        kind=kind,
        window=window,
    )

  0%|          | 0/2 [00:00<?, ?it/s]

# 2. Lag features

In [11]:
# Columns for which 1..7-step lags are created within each (store, family)
# series.
feats_to_lag = [
    "dcoilwtico",
    "transactions",
    "onpromotion",
    "relative_sales",
    "relative_sales.rolling.mean.window.7",
    "relative_sales.expanding.mean",
]

# Build the groupby once and collect every lagged frame before concatenating.
# Calling pd.concat inside the loop would re-copy the whole, ever wider frame
# on each of the seven iterations.
grouped_feats = df.groupby(["store_nbr", "family"])[feats_to_lag]
lagged_frames = [
    grouped_feats.shift(lag).add_suffix(f".lag.{lag}")
    for lag in tqdm(range(1, 8))
]

# Shifted frames keep df's index, so the column-wise concat aligns row by row.
# NOTE: re-running this cell appends a second copy of the lag columns.
df = pd.concat([df, *lagged_frames], axis=1)


  0%|          | 0/7 [00:00<?, ?it/s]

In [12]:
# Shrink the widened frame again now that the lag columns have been added.
# NOTE(review): the return value is discarded, so this presumably mutates `df`
# in place - confirm against sales_project.utils.reduce_size.
reduce_size(df)

  0%|          | 0/64 [00:00<?, ?it/s]

In [13]:
# Check the final column list and dtypes before saving.
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 3036528 entries, 0 to 3036527
Data columns (total 64 columns):
 #   Column                                      Dtype         
---  ------                                      -----         
 0   id                                          uint32        
 1   date                                        datetime64[ns]
 2   store_nbr                                   uint8         
 3   family                                      object        
 4   onpromotion                                 float32       
 5   is_submission                               bool          
 6   city                                        object        
 7   state                                       object        
 8   type                                        object        
 9   cluster                                     uint8         
 10  dcoilwtico                                  float32       
 11  transactions                                float32    

In [14]:
# Persist the engineered features without the index column.
# CSV is plain text, so column dtypes are not stored in the file; a reader
# has to re-parse dates and re-apply any dtype reduction after loading.
df.to_csv("../data/artifacts/df_with_fe.csv", index=False)