In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import polars as pl
from datetime import datetime
import gc
from temporal_fusion_transformer.src.experiments.favorita import read_temporal

start_date = datetime(2016, 1, 1)
end_date = datetime(2016, 6, 1)
data_dir = "../data/favorita"

In [3]:
gc.collect()

0

In [4]:
temporal = read_temporal(data_dir)
temporal.collect(streaming=True)

date,store_nbr,item_nbr,log_sales,onpromotion,traj_id,open
date,u8,u32,f32,u8,str,i8
2016-01-02,54,594045,0.693147,0,"""54_594045""",1
2016-01-03,54,594045,0.0,1,"""54_594045""",1
2016-01-04,54,594045,0.693147,0,"""54_594045""",1
2016-01-05,54,594045,0.693147,0,"""54_594045""",1
2016-01-06,54,594045,0.0,0,"""54_594045""",1
2016-01-07,54,594045,0.693147,0,"""54_594045""",1
2016-01-08,54,594045,0.693147,0,"""54_594045""",1
2016-01-09,54,594045,0.693147,0,"""54_594045""",1
2016-01-10,54,594045,0.693147,0,"""54_594045""",1
2016-01-11,54,594045,0.693147,0,"""54_594045""",1


In [6]:
store_info = (
    pl.read_parquet(f"{data_dir}/stores.parquet")
    .with_columns(pl.col("cluster").cast(pl.UInt8))
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)

items = (
    pl.read_parquet(f"{data_dir}/items.parquet")
    .with_columns([pl.col("perishable").cast(pl.UInt8), pl.col("class").cast(pl.UInt16)])
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)
transactions = (
    pl.scan_parquet(f"{data_dir}/transactions.parquet")
    .with_columns([pl.col("store_nbr").cast(pl.UInt8), pl.col("transactions").cast(pl.UInt16)])
    .collect()
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)
oil = (
    pl.read_parquet(f"{data_dir}/oil.parquet")
    .rename({"dcoilwtico": "oil_price"})
    .with_columns(pl.col("oil_price").cast(pl.Float32))
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)
holidays = pl.scan_parquet(f"{data_dir}/holidays_events.parquet")

national_holidays = (
    holidays.filter(pl.col("locale") == "National")
    .select(["description", "date"])
    .rename({"description": "national_hol"})
    .collect()
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)
regional_holidays = (
    holidays.filter(pl.col("locale") == "Regional")
    .select(["description", "locale_name", "date"])
    .rename({"locale_name": "state", "description": "regional_hol"})
    .collect()
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)
local_holidays = (
    holidays.filter(pl.col("locale") == "Local")
    .select(["description", "locale_name", "date"])
    .rename({"locale_name": "city", "description": "local_hol"})
    .collect()
    .shrink_to_fit(in_place=True)
    .rechunk()
    .lazy()
)

In [30]:
df: pl.DataFrame = (
    temporal.join(oil, on="date", how="left")
    .join(store_info, on="store_nbr", how="left")
    .join(items, on="item_nbr", how="left")
    .join(transactions, on=["store_nbr", "date"])
    .join(national_holidays, on="date", how="left")
    .join(regional_holidays, on=["date", "state"], how="left")
    .join(local_holidays, on=["date", "city"], how="left")
    # .with_columns(
    #    [
    #        pl.col("oil_price").fill_null(strategy="forward"),
    #        pl.col("national_hol").fill_null(""),
    #        pl.col("regional_hol").fill_null(""),
    #        pl.col("local_hol").fill_null(""),
    #        pl.col("date").dt.month().alias("month"),
    #        pl.col("date").dt.day().alias("day_of_month"),
    #        pl.col("date").dt.weekday().alias("day_of_week"),
    #    ]
    # )
    # .filter(pl.col("oil_price").is_not_null())
    # .sort("date", "traj_id")
    .collect(streaming=True)
    .shrink_to_fit(in_place=True)
    .rechunk()
)

df

date,store_nbr,item_nbr,log_sales,onpromotion,traj_id,open,oil_price,city,state,type,cluster,family,class,perishable,transactions,national_hol,regional_hol,local_hol
date,u8,u32,f32,u8,str,i8,f32,str,str,str,u8,str,u16,u8,u16,str,str,str
2016-01-02,54,594045,0.693147,0,"""54_594045""",1,,,,,,,,,1004,,,
2016-01-05,54,594045,0.693147,0,"""54_594045""",1,35.970001,,,,,,,,866,,,
2016-01-06,54,594045,0.0,0,"""54_594045""",1,33.970001,,,,,,,,808,,,
2016-01-07,54,594045,0.693147,0,"""54_594045""",1,33.290001,,,,,,,,733,,,
2016-01-08,54,594045,0.693147,0,"""54_594045""",1,33.200001,,,,,,,,868,,,
2016-01-09,54,594045,0.693147,0,"""54_594045""",1,,,,,,,,,814,,,
2016-01-10,54,594045,0.693147,0,"""54_594045""",1,,,,,,,,,1134,,,
2016-01-11,54,594045,0.693147,0,"""54_594045""",1,31.42,,,,,,,,827,,,
2016-01-12,54,594045,0.693147,0,"""54_594045""",1,30.42,,,,,,,,784,,,
2016-01-13,54,594045,0.693147,0,"""54_594045""",1,30.42,,,,,,,,674,,,


In [31]:
df.describe()

describe,date,store_nbr,item_nbr,log_sales,onpromotion,traj_id,open,oil_price,city,state,type,cluster,family,class,perishable,transactions,national_hol,regional_hol,local_hol
str,str,f64,f64,f64,f64,str,f64,f64,str,str,str,f64,str,f64,f64,f64,str,str,str
"""count""","""19802907""",19802907.0,19802907.0,19802907.0,19802907.0,"""19802907""",19802907.0,19802907.0,"""19802907""","""19802907""","""19802907""",19802907.0,"""19802907""",19802907.0,19802907.0,19802907.0,"""19802907""","""19802907""","""19802907"""
"""null_count""","""0""",0.0,0.0,0.0,0.0,"""0""",0.0,6430704.0,"""19802907""","""19802907""","""19802907""",19802907.0,"""19802907""",19802907.0,19802907.0,0.0,"""14669476""","""19802907""","""19802907"""
"""mean""",,27.418253,1091500.0,1.209991,0.069158,,1.0,37.697075,,,,,,,,1723.330538,,,
"""std""",,15.868059,534417.495393,1.074361,0.253722,,0.0,6.446909,,,,,,,,952.885533,,,
"""min""","""2016-01-02""",1.0,99197.0,-6.907756,0.0,"""10_1001305""",1.0,26.190001,,,,,,,,6.0,"""Batalla de Pic…",,
"""25%""",,13.0,668752.0,0.0,0.0,,1.0,31.65,,,,,,,,1054.0,,,
"""50%""",,28.0,1143686.0,1.098612,0.0,,1.0,37.299999,,,,,,,,1425.0,,,
"""75%""",,42.0,1463790.0,1.94591,0.0,,1.0,42.759998,,,,,,,,2130.0,,,
"""max""","""2016-06-01""",54.0,2037487.0,10.695168,1.0,"""9_999547""",1.0,49.360001,,,,,,,,6194.0,"""Viernes Santo""",,


In [36]:
store_info.select("store_nbr").collect().dtypes

[Int64]

In [35]:
temporal.select("store_nbr").collect().dtypes

[UInt8]