In [8]:
import polars as pl
import polars.selectors as cs

from typing import Tuple, List
from datetime import date

# def ewm(arr: pl.Series, alpha=0.9):
#     if arr.len() == 0:
#         return None
#     weights = (1 - alpha) ** np.arange(arr.len()-1, -1, -1)
#     return (arr * weights).sum() / weights.sum()

def feature_engineering(snapshot_path: str, between: Tuple[date, date], target:str, agg_level: List[str], fill_nulls: bool=False) -> pl.LazyFrame:
    """
    Feature engineering for sales data.
    
    Args:
        sales_lzdf (pl.LazyFrame): LazyFrame containing sales data.
        
    Returns:
        pl.LazyFrame: LazyFrame with engineered features.
    """
    
    # 1. Load
    sales_lzdf = pl.scan_parquet(
        snapshot_path,
    ).sort(agg_level + ["date"]).with_columns([
        pl.col("units_sold").log1p(),
        cs.boolean().cast(pl.Int8),  # Convert boolean to int for compatibility with LightGBM
        cs.string().cast(pl.Categorical), # Convert string categories to categorical type
    ]).rename({"units_sold": "log_units_sold"})  # Drop original units_sold column to avoid confusion

    # 2. Add calendar features
    sales_lzdf = sales_lzdf.with_columns([
        pl.col("date").dt.weekday().alias("weekday"),
        pl.col("date").dt.month().alias("month"),
        pl.col("date").dt.week().alias("weekofyear"),
        pl.col("date").dt.ordinal_day().alias("dayofyear"),
    ]).with_columns(
        cs.boolean().cast(pl.Int8)  # Convert boolean to int for compatibility with LightGBM
    )

    # # 3. Add lagged features for the previous 1, 3, 7, and 14 days
    # for lag in [1, 3, 7, 14]:
    #     # 3.1 This is a workaround for considering no registered sales on the lag day
    #     # by shifting the date by the lag and joining on the shifted date
    #     tmp = sales_lzdf.select(
    #         *agg_level,
    #         pl.col("date") + pl.duration(days=lag),
    #         pl.col("log_units_sold").alias(f"lag_{lag}d_log_units_sold"),
    #         # pl.col("is_on_promotion").alias(f"lag_{lag}d_is_on_promotion"),
    #     )
        
    #     sales_lzdf = sales_lzdf.join(tmp, on=agg_level+["date"], how="left")

        # 3.2 This is the original way to add lags, but it will not consider no registered sales on the lag day
        # sales_lzdf = sales_lzdf.with_columns(
        #     pl.col("log_units_sold")
        #         .shift_by(lag, fill_value=0)  # Fill with 0 to avoid NaN
        #         .over(["product_id", "store_id"], order_by="date")
        #         .alias(f"lag_{lag}d_log_units_sold")
        # )

    # 4. Add rolling features over the previous 3, 7, 14, and 28 days
    for window in [3, 7, 14, 28]:
        # Only consider window before current date to avoid data leakage. This is done by using closed='left'
        tmp = sales_lzdf.rolling('date', period=f'{window}d', closed='right', group_by=agg_level).agg(
            # 4.1 Calculate rolling mean, median, std, min, max, and ewm_mean to log_units_sold
            pl.col("log_units_sold").mean().alias(f"mean_{window}d_log_units_sold"),
            pl.col("log_units_sold").median().alias(f"median_{window}d_log_units_sold"),
            pl.col("log_units_sold").std().alias(f"std_{window}d_log_units_sold"),
            pl.col("log_units_sold").min().alias(f"min_{window}d_log_units_sold"),
            pl.col("log_units_sold").max().alias(f"max_{window}d_log_units_sold"),
            pl.col("log_units_sold").ewm_mean(alpha=0.9, adjust=True).last().alias(f"ewm_{window}d_log_units_sold"),
            pl.col("log_units_sold").diff().mean().alias(f"diff_mean_{window}d_log_units_sold"),
        ).with_columns(
            # 4.3 Calculate ratio of max to mean. Useful to identify outliers
            (pl.col(f"max_{window}d_log_units_sold") / pl.col(f"mean_{window}d_log_units_sold"))
                .alias(f"max_mean_ratio_{window}d_log_units_sold")
        )
        # sales_lzdf = pl.concat([sales_lzdf, tmp], how="horizontal", parallel=True)
        sales_lzdf = sales_lzdf.join(tmp, on=agg_level + ["date"], how="left")

        # sales_lzdf = sales_lzdf.with_columns(
        #     # Calculate rolling mean
        #     pl.col('log_units_sold') # Only consider window before current date to avoid data leakage
        #         .rolling_mean_by('date', window_size=f"{window}d", closed="left")
        #         .over(["product_id", "store_id"])
        #         .alias(f"mean_{window}d_log_units_sold"),
        #     # Calculate rolling median
        #     pl.col('log_units_sold') # Only consider window before current date to avoid data leakage
        #         .rolling_median_by('date', window_size=f"{window}d", closed="left")
        #         .over(["product_id", "store_id"])
        #         .alias(f"median_{window}d_log_units_sold"),
        #     # Calculate rolling standard deviation
        #     pl.col('log_units_sold') # Only consider window before current date to avoid data leakage
        #         .rolling_std_by('date', window_size=f"{window}d", closed="left")
        #         .over(["product_id", "store_id"])
        #         .alias(f"std_{window}d_log_units_sold"),
        #     # Calculate rolling min
        #     pl.col('log_units_sold') # Only consider window before current date to avoid data leakage
        #         .rolling_min_by('date', window_size=f"{window}d", closed="left")
        #         .over(["product_id", "store_id"])
        #         .alias(f"min_{window}d_log_units_sold"),
        #     # Calculate rolling max
        #     pl.col('log_units_sold') # Only consider window before current date to avoid data leakage
        #         .rolling_max_by('date', window_size=f"{window}d", closed="left")
        #         .over(["product_id", "store_id"])
        #         .alias(f"max_{window}d_log_units_sold"),
        #     # # calculate rolling exponential weighted mean
        #     # pl.col('log_units_sold') # Only consider window before current date to avoid data leakage
        #     # .rolling_map(ewm, window_size=f"{window}d", closed="left"),
                # pl.col("is_on_promotion")
                #         .rolling_sum_by('date', window_size=f"{window}d", closed="left")
                #         .over(agg_level)
                #         .alias(f"sum_{window}d_is_on_promotion")
        # )

    # 5. Add weekday rolling mean. e.i. mean of the same weekday in the past 4 weeks
    for weeks in [1, 2, 3, 4]:
        sales_lzdf = sales_lzdf.with_columns(
            pl.col("log_units_sold")
                .rolling_mean_by('date', window_size=f"{weeks}w", closed="left")
                .over(agg_level)
                .alias(f"mean_{weeks}w_log_units_sold")
        )

    # 5. Add yearly rolling mean. e.i. mean of the same day in the past 4 years
    for years in [1, 2, 3, 4]:
        sales_lzdf = sales_lzdf.with_columns(
            pl.col("log_units_sold")
                .rolling_mean_by('date', window_size=f"{years}w", closed="left")
                .over(agg_level)
                .alias(f"mean_{years}y_log_units_sold")
        )

    # 6. Add is_on_promotion column rolling sum after 16 days
    for window in [3, 7, 14]:
        tmp = sales_lzdf.rolling('date', period=f'{window}d', offset="0d", closed='right', group_by=agg_level).agg(
            pl.col("is_on_promotion").sum().alias(f"sum_next_{window}d_is_on_promotion")
        )
        sales_lzdf = sales_lzdf.join(tmp, on=agg_level + ["date"], how="left")

    # 7. Join item features
    # TODO: Esta concatenacion no deberia de estar aqui, deberia de ser otro paso del pipeline
    products_lzdf = pl.scan_parquet(
        "../../data/favorita_dataset/subset/products.parquet"
    ).with_columns(
        cs.boolean().cast(pl.Int8),  # Convert boolean to int for compatibility with LightGBM
        cs.string().cast(pl.Categorical), # Convert string categories to categorical type
    )

    sales_lzdf = sales_lzdf.join(
        products_lzdf,
        on="product_id",
        how="left"
    )

    # 8. Finally fills null values with 0
    if fill_nulls:
        sales_lzdf = sales_lzdf.fill_null(0)
    
    # 9. Filter by date range
    return sales_lzdf.filter(pl.col("date").is_between(*between))
    

# Get columns for horizons
def apply_horizon_shifting(train_dataset: pl.DataFrame, horizons: int, agg_level: List[str]):
    # Add predictions columns for horizons
    for horizon in range(1, horizons + 1):
        tmp = train_dataset.select(
            *agg_level,
            pl.col("date") - pl.duration(days=horizon),
            pl.col("log_units_sold").alias(f"h{horizon}_log_units_sold"),
        )

        train_dataset = train_dataset.join(tmp, on=agg_level+["date"], how="left")
    return train_dataset

    # # 7. Flatten for single model
    # # let's expand for horizon 1–7
    # horizons = []
    # for h in range(1,2):
    #     tmp = sales_lzdf.filter(
    #         pl.col("date") <= pl.date("2017-08-15")
    #     ).with_columns([
    #         pl.lit(h).alias("horizon")
    #     ])
    #     # align y label
    #     tmp = tmp.with_columns(
    #         pl.col("units_sold")
    #         .shift(-h, by=["store_nbr", "item_nbr"])
    #         .alias("target")
    #     )
    #     horizons.append(tmp)

    # train = pl.concat(horizons)

In [18]:
dataset = feature_engineering(
    snapshot_path="../../data/favorita_dataset/subset/sales_train.parquet",
    between=(date(2013, 1, 1), date(2017, 8, 15)),
    target="log_units_sold",
    agg_level=["product_id", "store_id"],
).collect().lazy()

dataset = apply_horizon_shifting(dataset, horizons=7, agg_level=["product_id", "store_id"])

In [20]:
dataset.collect()

date,store_id,product_id,log_units_sold,is_on_promotion,weekday,month,weekofyear,dayofyear,mean_3d_log_units_sold,median_3d_log_units_sold,std_3d_log_units_sold,min_3d_log_units_sold,max_3d_log_units_sold,ewm_3d_log_units_sold,diff_mean_3d_log_units_sold,max_mean_ratio_3d_log_units_sold,mean_7d_log_units_sold,median_7d_log_units_sold,std_7d_log_units_sold,min_7d_log_units_sold,max_7d_log_units_sold,ewm_7d_log_units_sold,diff_mean_7d_log_units_sold,max_mean_ratio_7d_log_units_sold,mean_14d_log_units_sold,median_14d_log_units_sold,std_14d_log_units_sold,min_14d_log_units_sold,max_14d_log_units_sold,ewm_14d_log_units_sold,diff_mean_14d_log_units_sold,max_mean_ratio_14d_log_units_sold,mean_28d_log_units_sold,median_28d_log_units_sold,std_28d_log_units_sold,min_28d_log_units_sold,max_28d_log_units_sold,ewm_28d_log_units_sold,diff_mean_28d_log_units_sold,max_mean_ratio_28d_log_units_sold,mean_1w_log_units_sold,mean_2w_log_units_sold,mean_3w_log_units_sold,mean_4w_log_units_sold,mean_1y_log_units_sold,mean_2y_log_units_sold,mean_3y_log_units_sold,mean_4y_log_units_sold,sum_next_3d_is_on_promotion,sum_next_7d_is_on_promotion,sum_next_14d_is_on_promotion,family,class,perishable,h1_log_units_sold,h2_log_units_sold,h3_log_units_sold,h4_log_units_sold,h5_log_units_sold,h6_log_units_sold,h7_log_units_sold
date,i64,i64,f64,i8,i8,i8,i8,i16,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,i64,cat,cat,i8,f64,f64,f64,f64,f64,f64,f64
2013-01-02,3,213652,3.555348,0,3,1,1,2,3.555348,3.555348,,3.555348,3.555348,3.555348,,1.0,3.555348,3.555348,,3.555348,3.555348,3.555348,,1.0,3.555348,3.555348,,3.555348,3.555348,3.555348,,1.0,3.555348,3.555348,,3.555348,3.555348,3.555348,,1.0,,,,,,,,,0,0,0,"""GROCERY I""","""1048""",0,3.912023,3.135494,3.401197,3.332205,3.258097,2.944439,3.526361
2013-01-03,3,213652,3.912023,0,4,1,1,3,3.733686,3.733686,0.252207,3.555348,3.912023,3.879598,0.356675,1.047764,3.733686,3.733686,0.252207,3.555348,3.912023,3.879598,0.356675,1.047764,3.733686,3.733686,0.252207,3.555348,3.912023,3.879598,0.356675,1.047764,3.733686,3.733686,0.252207,3.555348,3.912023,3.879598,0.356675,1.047764,3.555348,3.555348,3.555348,3.555348,3.555348,3.555348,3.555348,3.555348,0,0,0,"""GROCERY I""","""1048""",0,3.135494,3.401197,3.332205,3.258097,2.944439,3.526361,3.433987
2013-01-04,3,213652,3.135494,0,5,1,1,4,3.534288,3.555348,0.388693,3.135494,3.912023,3.209234,-0.209927,1.106877,3.534288,3.555348,0.388693,3.135494,3.912023,3.209234,-0.209927,1.106877,3.534288,3.555348,0.388693,3.135494,3.912023,3.209234,-0.209927,1.106877,3.534288,3.555348,0.388693,3.135494,3.912023,3.209234,-0.209927,1.106877,3.733686,3.733686,3.733686,3.733686,3.733686,3.733686,3.733686,3.733686,0,0,0,"""GROCERY I""","""1048""",0,3.401197,3.332205,3.258097,2.944439,3.526361,3.433987,2.639057
2013-01-05,3,213652,3.401197,0,6,1,1,5,3.482905,3.401197,0.39466,3.135494,3.912023,3.381862,-0.255413,1.123207,3.501016,3.478273,0.324268,3.135494,3.912023,3.382018,-0.051384,1.117397,3.501016,3.478273,0.324268,3.135494,3.912023,3.382018,-0.051384,1.117397,3.501016,3.478273,0.324268,3.135494,3.912023,3.382018,-0.051384,1.117397,3.534288,3.534288,3.534288,3.534288,3.534288,3.534288,3.534288,3.534288,0,0,0,"""GROCERY I""","""1048""",0,3.332205,3.258097,2.944439,3.526361,3.433987,2.639057,2.944439
2013-01-06,3,213652,3.332205,0,7,1,1,6,3.289632,3.332205,0.137873,3.135494,3.401197,3.336648,0.098355,1.033914,3.467253,3.401197,0.290795,3.135494,3.912023,3.337185,-0.055786,1.128277,3.467253,3.401197,0.290795,3.135494,3.912023,3.337185,-0.055786,1.128277,3.467253,3.401197,0.290795,3.135494,3.912023,3.337185,-0.055786,1.128277,3.501016,3.501016,3.501016,3.501016,3.501016,3.501016,3.501016,3.501016,0,0,0,"""GROCERY I""","""1048""",0,3.258097,2.944439,3.526361,3.433987,2.639057,2.944439,2.833213
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2017-08-11,47,507478,3.332205,1,5,8,32,223,3.112031,3.295837,0.350329,2.70805,3.332205,3.275647,0.018184,1.070749,3.202917,3.295837,0.311544,2.70805,3.496508,3.275268,-0.027384,1.091664,3.223743,3.276967,0.296953,2.70805,3.713572,3.275268,-0.012639,1.151944,3.168747,3.238486,0.340994,2.397895,3.713572,3.275268,-0.0013,1.171937,3.139798,3.215648,3.218316,3.152967,3.139798,3.215648,3.218316,3.152967,0,0,0,"""EGGS""","""2502""",1,2.833213,3.332205,3.496508,3.091042,,,
2017-08-12,47,507478,2.833213,0,6,8,32,224,2.957823,2.833213,0.330209,2.70805,3.332205,2.87704,0.062582,1.126573,3.10816,3.258097,0.308219,2.70805,3.496508,2.877419,-0.070814,1.124944,3.176365,3.238486,0.30294,2.70805,3.713572,2.877419,-0.06772,1.169126,3.149672,3.218876,0.344397,2.397895,3.713572,2.877419,-0.024566,1.179034,3.202917,3.223743,3.244964,3.168747,3.202917,3.223743,3.244964,3.168747,0,0,0,"""EGGS""","""2502""",1,3.332205,3.496508,3.091042,,,,
2017-08-13,47,507478,3.332205,0,7,8,32,225,3.165874,3.332205,0.288093,2.833213,3.332205,3.28725,0.0,1.052539,3.118747,3.295837,0.315416,2.70805,3.496508,3.286726,-0.027384,1.121126,3.149125,3.238486,0.265786,2.70805,3.496508,3.286726,0.025882,1.110311,3.143804,3.218876,0.339635,2.397895,3.713572,3.286726,0.008932,1.181235,3.10816,3.176365,3.219532,3.149672,3.10816,3.176365,3.219532,3.149672,0,0,0,"""EGGS""","""2502""",1,3.496508,3.091042,,,,,
2017-08-14,47,507478,3.496508,0,1,8,33,226,3.220642,3.332205,0.345434,2.833213,3.496508,3.47573,0.331647,1.085656,3.118747,3.295837,0.315416,2.70805,3.496508,3.475529,0.110549,1.121126,3.184894,3.276967,0.277015,2.70805,3.496508,3.475529,0.007332,1.097841,3.158285,3.238486,0.345888,2.397895,3.713572,3.475529,0.034502,1.175819,3.118747,3.149125,3.202547,3.143804,3.118747,3.149125,3.202547,3.143804,0,0,0,"""EGGS""","""2502""",1,3.091042,,,,,,


In [17]:
train_dataset["date"].value_counts().sort("date", descending=True)

date,count
date,u32
2017-08-15,20
2017-08-14,20
2017-08-13,20
2017-08-12,20
2017-08-11,20
…,…
2013-01-06,20
2013-01-05,20
2013-01-04,20
2013-01-03,20


In [36]:
#transform bit string to integer
def bitstring_to_int(bitstring: str) -> int:
    """Convert a bit string to an integer."""
    return int(bitstring, 2)
# Example usage
bitstring = "1" * 31
print(bitstring_to_int(bitstring))  # Output: 13

2147483647


In [66]:
events_lzdf = pl.scan_parquet(
    "./data/favorita_dataset/subset/events.parquet")

stores_lzdf = pl.scan_parquet(
    "./data/favorita_dataset/subset/stores.parquet"
).with_columns(
    cs.boolean().cast(pl.Int8),  # Convert boolean to int for compatibility with LightGBM
    cs.string().cast(pl.Categorical), # Convert string categories to categorical type
)

events_lzdf.collect().head(10)

date,type,locale,locale_name,description
date,cat,cat,cat,str
2012-08-10,"""Holiday""","""National""","""Ecuador""","""Primer Grito de Independencia"""
2012-10-12,"""Transfer""","""National""","""Ecuador""","""Independencia de Guayaquil"""
2012-11-02,"""Holiday""","""National""","""Ecuador""","""Dia de Difuntos"""
2012-11-03,"""Holiday""","""National""","""Ecuador""","""Independencia de Cuenca"""
2012-12-05,"""Additional""","""Local""","""Quito""","""Fundacion de Quito-1"""
2012-12-06,"""Holiday""","""Local""","""Quito""","""Fundacion de Quito"""
2012-12-21,"""Additional""","""National""","""Ecuador""","""Navidad-4"""
2012-12-22,"""Additional""","""National""","""Ecuador""","""Navidad-3"""
2012-12-23,"""Additional""","""National""","""Ecuador""","""Navidad-2"""
2012-12-24,"""Bridge""","""National""","""Ecuador""","""Puente Navidad"""


In [67]:
stores_lzdf.collect().head(10)

store_id,city,state,type,cluster,country
i32,cat,cat,cat,cat,cat
3,"""Quito""","""Pichincha""","""D""","""8""","""Ecuador"""
44,"""Quito""","""Pichincha""","""A""","""5""","""Ecuador"""
45,"""Quito""","""Pichincha""","""A""","""11""","""Ecuador"""
47,"""Quito""","""Pichincha""","""A""","""14""","""Ecuador"""


In [77]:
events_lzdf.join(
    stores_lzdf.select("store_id", "city"),
    left_on="locale_name",
    right_on="city",
    how="left",
).join(
    stores_lzdf.select("store_id", "country"),
    left_on="locale_name",
    right_on="country",
    how="left",
    # coalesce="store_id"
).collect().head(10)

  ).collect().head(10)


date,type,locale,locale_name,description,store_id,store_id_right
date,cat,cat,cat,str,i32,i32
2012-08-10,"""Holiday""","""National""","""Ecuador""","""Primer Grito de Independencia""",,3
2012-08-10,"""Holiday""","""National""","""Ecuador""","""Primer Grito de Independencia""",,44
2012-08-10,"""Holiday""","""National""","""Ecuador""","""Primer Grito de Independencia""",,45
2012-08-10,"""Holiday""","""National""","""Ecuador""","""Primer Grito de Independencia""",,47
2012-10-12,"""Transfer""","""National""","""Ecuador""","""Independencia de Guayaquil""",,3
2012-10-12,"""Transfer""","""National""","""Ecuador""","""Independencia de Guayaquil""",,44
2012-10-12,"""Transfer""","""National""","""Ecuador""","""Independencia de Guayaquil""",,45
2012-10-12,"""Transfer""","""National""","""Ecuador""","""Independencia de Guayaquil""",,47
2012-11-02,"""Holiday""","""National""","""Ecuador""","""Dia de Difuntos""",,3
2012-11-02,"""Holiday""","""National""","""Ecuador""","""Dia de Difuntos""",,44


In [2]:
import polars.selectors as cs
sales_lzdf.filter(
    (pl.col("store_id") == 3) 
    & (pl.col("product_id") == 213652)
    # & (pl.col("lag_1d_log_units_sold").is_null())
    # & (pl.col("date").is_between(pl.date(2013,12,23), pl.date(2013,12,31)) )
).select(
    # pl.all()
    'date', "lag_1d_log_units_sold", "log_units_sold", "perishable", "is_on_promotion", "pct_next_14d_is_on_promotion",
    # "date", pl.col.log_units_sold, cs.contains("diff"), #cs.contains("promo")
).filter(
    pl.col("date") >= pl.date(2016,11,10)
    # pl.col("sum_3d_after_is_on_promotion") > 0
).collect().head(10)
# sales_lzdf.collect()

date,lag_1d_log_units_sold,log_units_sold,perishable,is_on_promotion,pct_next_14d_is_on_promotion
date,f64,f64,i8,i8,f64
2016-11-10,3.091042,3.135494,0,0,0.642857
2016-11-11,3.135494,3.663562,0,0,0.714286
2016-11-12,3.663562,2.833213,0,0,0.785714
2016-11-13,2.833213,3.332205,0,0,0.857143
2016-11-14,3.332205,3.135494,0,0,0.928571
2016-11-15,3.135494,3.218876,0,0,1.0
2016-11-16,3.218876,3.951244,0,1,1.0
2016-11-17,3.951244,3.258097,0,1,1.0
2016-11-18,3.258097,3.465736,0,1,1.0
2016-11-19,3.465736,3.526361,0,1,1.0


In [28]:


# get target set for train set and validation set
def pop_columns(df: pl.DataFrame, col_names: List[str]) -> pl.DataFrame:
    return pl.DataFrame(
        [
            df.drop_in_place(col_name)
            for col_name in col_names
        ]
    )

def split_stage(
    dataset: pl.LazyFrame,
    interval: Tuple[date, date],
    target_cols: List[str],
) -> Tuple[pl.DataFrame, pl.DataFrame]:

    x_df = dataset.filter( pl.col("date").is_between(*interval)).collect()
    y_df = pop_columns(x_df, target_cols)
    date_df =pop_columns(x_df, ["date"])

    return x_df, y_df, date_df

def save_df(df: pl.DataFrame, path: str):
    df.write_parquet(
        path,
        compression="snappy",
        # row_group_size=1000000,
        # partition_by=["store_id", "product_id"],
    )

def save_split(
    dataset: pl.LazyFrame,
    interval: Tuple[date, date],
    tag: str = "train",
    horizon: int = 7
):
    target_cols = [f"h{h}_log_units_sold" for h in range(1, horizon+1)]

    x_train_df, y_train_df, dates_train_df = split_stage(
        dataset=dataset,
        interval=interval,
        target_cols=target_cols,
    )

    save_df(x_train_df, f"../../data/favorita_dataset/subset/x_{tag}.parquet")
    save_df(y_train_df, f"../../data/favorita_dataset/subset/y_{tag}.parquet")
    save_df(dates_train_df, f"../../data/favorita_dataset/subset/dates_{tag}.parquet")
    

In [29]:
train_interval = (date(2013, 1, 1), date(2016, 12, 31))
save_split(dataset, train_interval, tag="train", horizon=7)

valid_interval = (date(2016, 12, 1), date(2017, 8, 15))
save_split(dataset, valid_interval, tag="valid", horizon=7)

In [None]:
import numpy as np
import polars as pl

alpha = 0.9
window = 2

dates = [
    "2020-01-01",
    "2020-01-02",
    "2020-01-03",
    "2020-01-04",
    "2020-01-06",
    "2020-01-07",
    "2020-01-01",
    "2020-01-02",
]

df = pl.DataFrame({
    "store_id": [3,3,3,3,3,3,2,2],
    "product_id": [10,10,10,10,10,10,10,10],
    "date": dates,
    "b": [1,2,3,4,5,6,1,2],
}).with_columns(
    pl.col("date").str.strptime(pl.Date).set_sorted()
)

# [
#     # pl.col("b"). #.last()#.shift(-1).first().fill_null(0)
# ]
df.with_columns(
    pl.col("b")
        # .rolling("date", period="2d", closed="left", offset="0d")
        # .list.mean()
        .rolling_min(2)
        .over(["product_id", "store_id"], order_by="date")
        .alias("new")
        # .rolling("date", period="2d", closed="left", offset="0d")
        # .mean()
        #     pl.col("b").shift(-1)  # Shift to get the next value in the group
        # )
        # .shift(-1)  # Shift to get the next value in the group
        
        # .over(["product_id", "store_id"])
        # .ewm_mean(alpha=0.9, adjust=True)
        # .alias("new")
)

store_id,product_id,date,b,c,new
i64,i64,date,i64,i64,i64
3,10,2020-01-01,1,1,
3,10,2020-01-02,2,1,1.0
3,10,2020-01-03,3,1,2.0
3,10,2020-01-04,4,1,3.0
3,10,2020-01-06,5,1,4.0
3,10,2020-01-07,6,1,5.0
2,10,2020-01-01,1,1,
2,10,2020-01-02,2,1,1.0


In [11]:
df.rolling('date', period='2d', closed='right', group_by=["product_id", "store_id"]).agg(
    pl.exclude("date"),
    pl.col("b").mean().alias("b_rolling_window"),
)

product_id,store_id,date,b,c,b_rolling_window
i64,i64,date,list[i64],list[i64],f64
10,2,2020-01-01,[1],[1],1.0
10,2,2020-01-02,"[1, 2]","[1, 1]",1.5
10,3,2020-01-01,[1],[1],1.0
10,3,2020-01-02,"[1, 2]","[1, 1]",1.5
10,3,2020-01-03,"[2, 3]","[1, 1]",2.5
10,3,2020-01-04,"[3, 4]","[1, 1]",3.5
10,3,2020-01-06,[5],[1],5.0
10,3,2020-01-07,"[5, 6]","[1, 1]",5.5


In [8]:
df2 = df.__copy__()

for horizon in range(1, 8):
    tmp = df2.select(
        *["product_id", "store_id"],
        pl.col("date") - pl.duration(days=horizon),
        pl.col("b").alias(f"h{horizon}"),
    )
    df2 = df2.join(tmp, on=["product_id", "store_id", "date"], how="left")
df2

store_id,product_id,date,b,c,h1,h2,h3,h4,h5,h6,h7
i64,i64,date,i64,i64,i64,i64,i64,i64,i64,i64,i64
3,10,2020-01-01,1,1,2.0,3.0,4.0,,5.0,6.0,
3,10,2020-01-02,2,1,3.0,4.0,,5.0,6.0,,
3,10,2020-01-03,3,1,4.0,,5.0,6.0,,,
3,10,2020-01-04,4,1,,5.0,6.0,,,,
3,10,2020-01-06,5,1,6.0,,,,,,
3,10,2020-01-07,6,1,,,,,,,
2,10,2020-01-01,1,1,2.0,,,,,,
2,10,2020-01-02,2,1,,,,,,,


In [6]:
df

store_id,product_id,date,b,c,h1_log_units_sold,h2_log_units_sold,h3_log_units_sold,h4_log_units_sold,h5_log_units_sold,h6_log_units_sold,h7_log_units_sold
i64,i64,date,i64,i64,i64,i64,i64,i64,i64,i64,i64
3,10,2020-01-01,1,1,2.0,3.0,4.0,,5.0,6.0,
3,10,2020-01-02,2,1,3.0,4.0,,5.0,6.0,,
3,10,2020-01-03,3,1,4.0,,5.0,6.0,,,
3,10,2020-01-04,4,1,,5.0,6.0,,,,
3,10,2020-01-06,5,1,6.0,,,,,,
3,10,2020-01-07,6,1,,,,,,,
2,10,2020-01-01,1,1,2.0,,,,,,
2,10,2020-01-02,2,1,,,,,,,


In [None]:
df.with_columns(
    pl.col("b")
        .rolling_mean_by('date', window_size=f"{window}d", closed="right")
        .over(["product_id", "store_id"])
        .alias("b_rolling_mean")
).collect()

In [213]:
# Example of using shift with group_by
# This will shift the column 'b' by 1 within each group defined by 'product_id' and 'store_id'
df.group_by(["product_id", "store_id"]).agg(
    [
        pl.col("b").shift(-1),
        pl.col("date"),
        pl.col("b").diff().mean().alias("diff_b"),
        # pl.col("b").ro(alpha=0.9).alias("ewm_mean_b"),
    ]
)#.explode(["date", "b", "diff_b"])

product_id,store_id,b,date,diff_b
i64,i64,list[i64],list[date],f64
10,3,"[2, 3, … null]","[2020-01-01, 2020-01-02, … 2020-01-07]",1.0
10,2,"[2, null]","[2020-01-01, 2020-01-02]",1.0


In [285]:
df.select(
    [
        "product_id",
        "store_id",
        "date",
        pl.col("c")
            .rolling_sum_by('date', window_size="2d", offset="0d", closed="left")
    ]
)

TypeError: Expr.rolling_sum_by() got an unexpected keyword argument 'offset'

In [113]:
dff.to_arrow()

pyarrow.Table
store_id: dictionary<values=large_string, indices=uint32, ordered=0>
----
store_id: [  -- dictionary:
["3","2"]  -- indices:
[0,0,0,0,1,1]]

In [53]:
import numpy as np

x = np.array([1, 2, 3])
alpha = 0.9

ewm = [
    x[0],
    (2*1 + 1*0.1) / (1 + 0.1),
    (3*1 + 2*0.1 + 1*0.01) / (1 + 0.1 + 0.01)
]
print(ewm)
# [1.0, 1.9090909090909092, 2.891891891891892]

[1, 1.909090909090909, 2.8918918918918917]


In [95]:
out = sales_lzdf.rolling('date', period='3d', closed='left', group_by=["product_id", "store_id"]).agg(
    [
        pl.sum("log_units_sold").alias("sum_log_units_sold"),
        pl.median("log_units_sold").alias("median_log_units_sold"),
        pl.std("log_units_sold").alias("std_log_units_sold"),
        pl.mean("log_units_sold").alias("mean_log_units_sold"),
        pl.min("log_units_sold").alias("min_log_units_sold"),
        pl.max("log_units_sold").alias("max_log_units_sold"),
    ]
).collect()
out.filter(
    (pl.col("store_id") == 3) 
    & (pl.col("product_id") == 213652)
)
# .select(
#     pl.col("date"),cs.contains("units_sold")
# )

product_id,store_id,date,sum_log_units_sold,median_log_units_sold,std_log_units_sold,mean_log_units_sold,min_log_units_sold,max_log_units_sold
i64,i64,date,f64,f64,f64,f64,f64,f64
213652,3,2013-01-02,0.0,,,,,
213652,3,2013-01-03,3.555348,3.555348,,3.555348,3.555348,3.555348
213652,3,2013-01-04,7.467371,3.733686,0.252207,3.733686,3.555348,3.912023
213652,3,2013-01-05,10.602865,3.555348,0.388693,3.534288,3.135494,3.912023
213652,3,2013-01-06,10.448715,3.401197,0.39466,3.482905,3.135494,3.912023
…,…,…,…,…,…,…,…,…
213652,3,2017-08-11,9.465215,3.135494,0.33208,3.155072,2.833213,3.496508
213652,3,2017-08-12,9.587817,3.258097,0.335987,3.195939,2.833213,3.496508
213652,3,2017-08-13,9.035749,2.944439,0.220332,3.011916,2.833213,3.258097
213652,3,2017-08-14,9.092907,2.944439,0.198547,3.030969,2.890372,3.258097


In [None]:
# # 8. LightGBM train
# # lgb_train = lgb.Dataset(
# #     train.drop(["date","target"]).to_pandas(),
# #     label=train["target"].to_numpy()
# # )

# # params = {
# #     "objective": "regression",
# #     "metric": "rmse",
# #     "learning_rate": 0.05,
# #     "num_leaves": 63,
# #     "feature_fraction": 0.8
# # }

# # bst = lgb.train(
# #     params,
# #     lgb_train,
# #     num_boost_round=1000,
# #     valid_sets=[lgb_train],
# #     early_stopping_rounds=50,
# #     verbose_eval=100
# # )