<a href="https://www.kaggle.com/code/khoatran311/baseline-model-not-yet-ready?scriptVersionId=215734880" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import numpy as np
import pandas as pd
import polars as pl
import polars.selectors as pls
import seaborn as sea
import matplotlib.pyplot as plt 
import datetime as dt
import warnings

import optuna
from lightgbm import LGBMRegressor
from xgboost import XGBRegressor
from catboost import CatBoostRegressor

warnings.filterwarnings("ignore")

In [2]:
def format_pl():
    """FLOAT DISPLAY FORMATTING"""
    pl.Config.set_fmt_float("mixed")
    """STRING FORMATTING"""
    pl.Config.set_fmt_str_lengths(50)
    """TABLE FORMATTING"""
    pl.Config.set_tbl_rows(8)
    pl.Config.set_tbl_cols(15)
    pl.Config.set_tbl_width_chars(200)
    pl.Config.set_tbl_cell_alignment("RIGHT")
    pl.Config.set_tbl_hide_dtype_separator(True)
    pl.Config.set_tbl_hide_column_data_types(True)

format_pl()

# Reading Data

In [3]:
train = pl.read_csv("/kaggle/input/playground-series-s5e1/train.csv")
test = pl.read_csv("/kaggle/input/playground-series-s5e1/test.csv")
sample_submission = pl.read_csv("/kaggle/input/playground-series-s5e1/sample_submission.csv")

train = train.drop(pl.col("id"))
test  = test.drop(pl.col("id"))

In [4]:
train

date,country,store,product,num_sold
"""2010-01-01""","""Canada""","""Discount Stickers""","""Holographic Goose""",
"""2010-01-01""","""Canada""","""Discount Stickers""","""Kaggle""",973.0
"""2010-01-01""","""Canada""","""Discount Stickers""","""Kaggle Tiers""",906.0
"""2010-01-01""","""Canada""","""Discount Stickers""","""Kerneler""",423.0
…,…,…,…,…
"""2016-12-31""","""Singapore""","""Premium Sticker Mart""","""Kaggle""",2907.0
"""2016-12-31""","""Singapore""","""Premium Sticker Mart""","""Kaggle Tiers""",2299.0
"""2016-12-31""","""Singapore""","""Premium Sticker Mart""","""Kerneler""",1242.0
"""2016-12-31""","""Singapore""","""Premium Sticker Mart""","""Kerneler Dark Mode""",1622.0


In [5]:
test

date,country,store,product
"""2017-01-01""","""Canada""","""Discount Stickers""","""Holographic Goose"""
"""2017-01-01""","""Canada""","""Discount Stickers""","""Kaggle"""
"""2017-01-01""","""Canada""","""Discount Stickers""","""Kaggle Tiers"""
"""2017-01-01""","""Canada""","""Discount Stickers""","""Kerneler"""
…,…,…,…
"""2019-12-31""","""Singapore""","""Premium Sticker Mart""","""Kaggle"""
"""2019-12-31""","""Singapore""","""Premium Sticker Mart""","""Kaggle Tiers"""
"""2019-12-31""","""Singapore""","""Premium Sticker Mart""","""Kerneler"""
"""2019-12-31""","""Singapore""","""Premium Sticker Mart""","""Kerneler Dark Mode"""


# Data Processing

In [6]:
NA_features = train.select(
    pl.all().is_null().sum()
).unpivot(
    value_name="train_NA_count"
).filter(
    pl.col("train_NA_count") > 0
)

NA_features = NA_features.with_columns(
    (100*pl.col("train_NA_count")/train.shape[0]).alias("train_NA_pct")
)

print(NA_features)

shape: (1, 3)
┌──────────┬────────────────┬──────────────┐
│ variable ┆ train_NA_count ┆ train_NA_pct │
╞══════════╪════════════════╪══════════════╡
│ num_sold ┆           8871 ┆     3.854778 │
└──────────┴────────────────┴──────────────┘


In [7]:
unique_count = train.select(
    pl.all().n_unique()
).unpivot(
    value_name="train_unique_count"
)

unique_count = unique_count.sort(by="train_unique_count")
print(unique_count)

shape: (5, 2)
┌──────────┬────────────────────┐
│ variable ┆ train_unique_count │
╞══════════╪════════════════════╡
│    store ┆                  3 │
│  product ┆                  5 │
│  country ┆                  6 │
│     date ┆               2557 │
│ num_sold ┆               4038 │
└──────────┴────────────────────┘


In [8]:
train.filter(
    ~pl.col("num_sold").is_null()
).with_columns(
    np.log(pl.col("num_sold")).alias("log_num_sold")
).group_by(
    pl.col("country")
).agg(
    pl.col("log_num_sold").mean().alias("mean_log_num_sold")
).filter(
    pl.col("country").is_in({"Canada", "Kenya"})
)

country,mean_log_num_sold
"""Canada""",6.559667
"""Kenya""",2.831507


In [9]:
### Fill NA values in num_sold
train = train.with_columns(
    ### Fill Canada NA num_sold with e**6.56
    pl.when((pl.col("country") == "Canada") & pl.col("num_sold").is_null())
      .then(pl.col("num_sold").fill_null(np.exp(6.5596)))  
    ### Fill Kenya NA num_sold with e**2.83
      .when((pl.col("country") == "Kenya") & pl.col("num_sold").is_null())
      .then(pl.col("num_sold").fill_null(np.exp(2.8315))) 
    ## Retain other values
      .otherwise(pl.col("num_sold"))
      .alias("num_sold")    
)

In [10]:
### Count of categorical values per categorical variable
string_df = train.select(pls.string())

for feature in string_df.columns:
    count_df = string_df.group_by(
        pl.col(feature)
    ).agg(
        pl.count()
    ).sort(feature,  descending=True)
    print(count_df)

shape: (2_557, 2)
┌────────────┬───────┐
│       date ┆ count │
╞════════════╪═══════╡
│ 2016-12-31 ┆    90 │
│ 2016-12-30 ┆    90 │
│ 2016-12-29 ┆    90 │
│ 2016-12-28 ┆    90 │
│          … ┆     … │
│ 2010-01-04 ┆    90 │
│ 2010-01-03 ┆    90 │
│ 2010-01-02 ┆    90 │
│ 2010-01-01 ┆    90 │
└────────────┴───────┘
shape: (6, 2)
┌───────────┬───────┐
│   country ┆ count │
╞═══════════╪═══════╡
│ Singapore ┆ 38355 │
│    Norway ┆ 38355 │
│     Kenya ┆ 38355 │
│     Italy ┆ 38355 │
│   Finland ┆ 38355 │
│    Canada ┆ 38355 │
└───────────┴───────┘
shape: (3, 2)
┌──────────────────────┬───────┐
│                store ┆ count │
╞══════════════════════╪═══════╡
│    Stickers for Less ┆ 76710 │
│ Premium Sticker Mart ┆ 76710 │
│    Discount Stickers ┆ 76710 │
└──────────────────────┴───────┘
shape: (5, 2)
┌────────────────────┬───────┐
│            product ┆ count │
╞════════════════════╪═══════╡
│ Kerneler Dark Mode ┆ 46026 │
│           Kerneler ┆ 46026 │
│       Kaggle Tiers ┆ 46026 │
│   

In [11]:
train = train.with_columns(
    pl.col("date").str.to_date(format="%Y-%m-%d")
).with_columns(
    ### DATETIME EXTRACTION
    pl.col("date").dt.year().alias("year"),
    pl.col("date").dt.month().alias("month"),
    pl.col("date").dt.day().alias("day"),
    pl.col("date").dt.weekday().alias("weekday"),
    pl.col("date").dt.quarter().alias("quarter")
).with_columns(
    ### TRIG TRANSFORMS ON DATETIME FEATURES
    # Yearly pattern
    (np.cos(2*np.pi*pl.col("month")/12)).alias("cos_12months"),
    (np.sin(2*np.pi*pl.col("month")/12)).alias("sin_12months"),
    # Weekly pattern
    (np.cos(2*np.pi*pl.col("weekday")/7)).alias("cos_7days"),
    (np.sin(2*np.pi*pl.col("weekday")/7)).alias("sin_7days"),
    # Quarterly pattern if it exists??
    (np.cos(2*np.pi*pl.col("quarter")/4)).alias("cos_4quarters"),
    (np.sin(2*np.pi*pl.col("quarter")/4)).alias("sin_4quarters"),
).with_columns(
    ### APPROXIMATE YEARLY SEASONALITY OBTAINED FROM ANALYZING COUNTRIES
    pl.when( ((pl.col("month")==12) & (pl.col("day")==31)) | 
             ((pl.col("month")==1) & (pl.col("day")==1))
           ).then(1)
      .otherwise(0)
      .alias("global_yearly_seasonality")
).with_columns(
    ### APPROXIMATE WEEKLY SEASONALITY OBTAINED FROM ANALYZING COUNTRIES VERSION 1
    pl.when( ((pl.col("month")==1) & (pl.col("day")%7==3)) | 
             ((pl.col("month")!=1) & (pl.col("day")%7==0))
           ).then(1)
      .otherwise(0)
      .alias("global_weekly_seasonality_v1")
).with_columns(
    ### APPROXIMATE WEEKLY SEASONALITY VERSION 2 (Intuitively obtained)
    pl.when(pl.col("weekday") == 7)
      .then(1)
      .otherwise(0)
      .alias("global_weekly_seasonality_v2")
).with_columns(
    ### GLUEING VALUES OF CATEGORICAL VARIABLES FOR COMPLEX RELATIONSHIPS
    pl.concat_str([pl.col("store"), pl.lit("_"), pl.col("product")]).alias("store_product"),
    pl.concat_str([pl.col("country"), pl.lit("_"), pl.col("store"), pl.lit("_"), pl.col("product")]).alias("country_store_product")
)

In [12]:
train

date,country,store,product,num_sold,year,month,…,cos_4quarters,sin_4quarters,global_yearly_seasonality,global_weekly_seasonality_v1,global_weekly_seasonality_v2,store_product,country_store_product
2010-01-01,"""Canada""","""Discount Stickers""","""Holographic Goose""",705.989242,2010,1,…,6.1232e-17,1.0,1,0,0,"""Discount Stickers_Holographic Goose""","""Canada_Discount Stickers_Holographic Goose"""
2010-01-01,"""Canada""","""Discount Stickers""","""Kaggle""",973.0,2010,1,…,6.1232e-17,1.0,1,0,0,"""Discount Stickers_Kaggle""","""Canada_Discount Stickers_Kaggle"""
2010-01-01,"""Canada""","""Discount Stickers""","""Kaggle Tiers""",906.0,2010,1,…,6.1232e-17,1.0,1,0,0,"""Discount Stickers_Kaggle Tiers""","""Canada_Discount Stickers_Kaggle Tiers"""
2010-01-01,"""Canada""","""Discount Stickers""","""Kerneler""",423.0,2010,1,…,6.1232e-17,1.0,1,0,0,"""Discount Stickers_Kerneler""","""Canada_Discount Stickers_Kerneler"""
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2016-12-31,"""Singapore""","""Premium Sticker Mart""","""Kaggle""",2907.0,2016,12,…,1.0,-2.4493e-16,1,0,0,"""Premium Sticker Mart_Kaggle""","""Singapore_Premium Sticker Mart_Kaggle"""
2016-12-31,"""Singapore""","""Premium Sticker Mart""","""Kaggle Tiers""",2299.0,2016,12,…,1.0,-2.4493e-16,1,0,0,"""Premium Sticker Mart_Kaggle Tiers""","""Singapore_Premium Sticker Mart_Kaggle Tiers"""
2016-12-31,"""Singapore""","""Premium Sticker Mart""","""Kerneler""",1242.0,2016,12,…,1.0,-2.4493e-16,1,0,0,"""Premium Sticker Mart_Kerneler""","""Singapore_Premium Sticker Mart_Kerneler"""
2016-12-31,"""Singapore""","""Premium Sticker Mart""","""Kerneler Dark Mode""",1622.0,2016,12,…,1.0,-2.4493e-16,1,0,0,"""Premium Sticker Mart_Kerneler Dark Mode""","""Singapore_Premium Sticker Mart_Kerneler Dark Mode"""


# Baseline Models

In [13]:
def lgbm_objective(trial):
    pass

In [14]:
def catboost_objective(trial):
    pass

In [15]:
def xgb_objective(trial):
    pass