In [1]:
import gc
import glob
import os
from datetime import date

import polars as pl

In [ ]:
# Target dtype for each known column. downcast_dataframe casts every column
# whose name appears here and reports (but leaves untouched) any other column.
COLUMN_TO_DTYPE = {
    "store_nbr": pl.UInt8,
    "cluster": pl.UInt8,
    "perishable": pl.UInt8,
    "class": pl.UInt16,
    "transactions": pl.UInt16,
    "oil_price": pl.Float32,
    "onpromotion": pl.UInt8,
    "item_nbr": pl.UInt32,
    "unit_sales": pl.Float32,
}
# Date window for the optional cutoff filters. In the cells visible here these
# are referenced only by the commented-out filters of the temporal_df pipeline.
START_DATE = date(2015, 1, 1)
END_DATE = date(2016, 6, 1)

In [2]:
def convert_to_parquet(download_dir: str):
    """Convert every CSV file directly inside ``download_dir`` to Parquet.

    Each ``<name>.csv`` is streamed into ``<name>.parquet`` in the same
    directory, and the CSV is deleted once the Parquet file has been written.

    Args:
        download_dir: Directory searched (non-recursively) for ``*.csv`` files.
    """
    for csv_path in glob.glob(f"{download_dir}/*.csv"):
        print(f"Converting {csv_path}")
        # Swap only the extension. A plain str.replace("csv", "parquet") would
        # also rewrite any other "csv" in the path (e.g. a "csv_dumps/" folder).
        stem, _ = os.path.splitext(csv_path)
        parquet_path = f"{stem}.parquet"
        pl.scan_csv(csv_path, try_parse_dates=True).sink_parquet(parquet_path)
        # Remove the source only after the sink call returned without raising.
        os.remove(csv_path)


convert_to_parquet("../data/favorita")

Converting ../data/favorita/train.csv
Converting ../data/favorita/transactions.csv
Converting ../data/favorita/items.csv
Converting ../data/favorita/oil.csv
Converting ../data/favorita/holidays_events.csv
Converting ../data/favorita/stores.csv


In [3]:
# Peek at the first rows of the training data to list column names and dtypes.
# head(10).collect() is used instead of LazyFrame.fetch(10), which newer Polars
# releases deprecate; both forms yield the first ten rows of the scan.
temporal_df_head = pl.scan_parquet("../data/favorita/train.parquet").head(10).collect()
pl.DataFrame({
    "col": temporal_df_head.columns,
    "dtype": temporal_df_head.dtypes
})

col,dtype
str,object
"""id""",Int64
"""date""",Date
"""store_nbr""",Int64
"""item_nbr""",Int64
"""unit_sales""",Float64
"""onpromotion""",String


In [ ]:
def downcast_dataframe(dataframe: pl.DataFrame | pl.LazyFrame, streaming: bool = True) -> pl.DataFrame | pl.LazyFrame:
    """Cast known columns to the compact dtypes listed in ``COLUMN_TO_DTYPE``.

    Columns without an entry in ``COLUMN_TO_DTYPE`` are reported on stdout and
    left as they are. A lazy input is collected, so the returned frame is
    materialised; it is then shrunk to fit and rechunked.

    Args:
        dataframe: Eager or lazy frame to downcast.
        streaming: Forwarded to ``collect`` when ``dataframe`` is lazy.

    Returns:
        The downcast frame after ``shrink_to_fit`` and ``rechunk``.
    """
    column_names = dataframe.columns

    # Report the columns that have no target dtype, in frame order.
    for name in column_names:
        if name not in COLUMN_TO_DTYPE:
            print(f"{name} not in COLUMN_TO_DTYPE")

    casts = [
        pl.col(name).cast(COLUMN_TO_DTYPE[name])
        for name in column_names
        if name in COLUMN_TO_DTYPE
    ]
    result = dataframe.with_columns(casts)
    if isinstance(result, pl.LazyFrame):
        result = result.collect(streaming=streaming)
    return result.shrink_to_fit(in_place=True).rechunk()


In [None]:
# Build one daily series per (store, item) pair ("traj_id").
# The result is an eager DataFrame: downcast_dataframe collects lazy input.
temporal_df: pl.DataFrame = (
    pl.scan_parquet("../data/favorita/train.parquet")
    .drop("id")
    # cutoff dataset to reduce memory requirements.
    #.filter(pl.col("date") >= START_DATE)
    #.filter(pl.col("date") <= END_DATE)
    # "onpromotion" is read as a string column. A native comparison yields the
    # same boolean column (nulls stay null) without an opaque Python callback.
    .with_columns(pl.col("onpromotion") == "True")
    .with_columns(pl.format("{}_{}", "store_nbr", "item_nbr").alias("traj_id"))
    # remove_returns_data: drop every trajectory that contains a negative sale.
    .filter(pl.col("unit_sales").min().over("traj_id") >= 0)
    # Rows present in the raw data are marked open; rows added by the upsample
    # below get open = 0.
    .with_columns(open=pl.lit(1).cast(pl.Int8))
    .sort("date", "traj_id")
    # upsample is only available on eager frames, hence the collect/lazy pair.
    .collect(streaming=True)
    .upsample("date", every="1d", by="traj_id")
    .lazy()
    .with_columns(
        [
            pl.col(i).fill_null(strategy="forward")
            for i in ["store_nbr", "item_nbr", "onpromotion"]
        ]
    )
    .with_columns(pl.col("open").fill_null(0))
    # NOTE(review): trajectories with unit_sales == 0 pass the filter above and
    # log(0) is -inf -- confirm whether zero sales can occur in this data.
    .with_columns(pl.col("unit_sales").log())
    .rename({"unit_sales": "log_sales"})
    .with_columns(pl.col("log_sales").fill_null(strategy="forward"))
    .pipe(downcast_dataframe, streaming=True)
)
gc.collect()
temporal_df.head(10)

  .with_columns(pl.col("onpromotion").map(lambda x: None if x is None else x == "True"))


In [ ]:
# Dimension tables. downcast_dataframe collects lazy input, so every frame
# built through it below is an eager DataFrame.
store_info = downcast_dataframe(pl.scan_parquet("../data/favorita/stores.parquet"))
items = downcast_dataframe(pl.scan_parquet("../data/favorita/items.parquet"))
transactions = downcast_dataframe(pl.scan_parquet("../data/favorita/transactions.parquet"))
oil = downcast_dataframe(
    pl.scan_parquet("../data/favorita/oil.parquet").rename({"dcoilwtico": "oil_price"})
)
holidays = pl.scan_parquet("../data/favorita/holidays_events.parquet")


def _holidays_for_locale(locale: str, renames: dict[str, str]):
    """Keep the holidays of one ``locale`` with the renamed columns plus ``date``."""
    return downcast_dataframe(
        holidays.filter(pl.col("locale") == locale)
        # The keys of ``renames`` give the source columns in selection order.
        .select([*renames, "date"])
        .rename(renames)
    )


national_holidays = _holidays_for_locale("National", {"description": "national_hol"})
regional_holidays = _holidays_for_locale(
    "Regional", {"description": "regional_hol", "locale_name": "state"}
)
local_holidays = _holidays_for_locale(
    "Local", {"description": "local_hol", "locale_name": "city"}
)

In [ ]:
# Enrich the daily sales series with oil prices, store/item metadata,
# transaction counts, holidays and calendar features.
#
# temporal_df and the dimension tables are eager DataFrames (downcast_dataframe
# collects), so the query is rebuilt lazily here; otherwise the final
# .collect(streaming=True) would fail because DataFrame has no collect().
# .lazy() is also valid on a LazyFrame, so this works with either kind of input.
joined_df = (
    temporal_df.lazy()
    .join(oil.lazy(), on="date", how="left")
    # NOTE(review): the forward fills in this chain are not partitioned by
    # traj_id, so a value can carry over from the preceding trajectory -- confirm
    # this is acceptable.
    .with_columns(pl.col("oil_price").fill_null(strategy="forward"))
    .join(store_info.lazy(), on="store_nbr")
    .join(items.lazy(), on="item_nbr")
    # Left join: an inner join would drop every (store, date) row that has no
    # transactions record and would leave the fill_null below with nothing to fill.
    .join(transactions.lazy(), on=["store_nbr", "date"], how="left")
    .with_columns(pl.col("transactions").fill_null(strategy="forward"))
    .join(national_holidays.lazy(), on="date", how="left")
    .join(regional_holidays.lazy(), on=["date", "state"], how="left")
    .join(local_holidays.lazy(), on=["date", "city"], how="left")
    .with_columns(
        [
            # Days without a holiday get an empty description instead of null.
            pl.col("national_hol").fill_null(""),
            pl.col("regional_hol").fill_null(""),
            pl.col("local_hol").fill_null(""),
            pl.col("date").dt.year().alias("year"),
            pl.col("date").dt.month().alias("month"),
            pl.col("date").dt.day().alias("day_of_month"),
            pl.col("date").dt.weekday().alias("day_of_week"),
        ]
    )
    # Drop rows that still have no oil price after the forward fill.
    .filter(pl.col("oil_price").is_not_null())
    .sort("traj_id", "date")
    .collect(streaming=True)
    .shrink_to_fit(in_place=True)
    .rechunk()
)