# In [None]:
# ============================================================================
# Hierarchical LightGBM 
# ============================================================================

import pandas as pd
import numpy as np
import lightgbm as lgb
from pathlib import Path
from tqdm import tqdm
import warnings, gc, pickle
warnings.filterwarnings("ignore")

from m5_wrmsse import wrmsse 

# ----------------------------------------------------------------------------
# 1) SETUP
# ----------------------------------------------------------------------------

# Input / output locations (relative to the notebook's working directory).
RAW_DIR = Path("../data/raw")
OUT_DIR = Path("../data/hier_fast_results")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# "Fast" configuration: short history, few lag/rolling windows, few boosting rounds.
FAST_HISTORY_DAYS = 600
LAGS = [7, 14, 28]
ROLLS = [7, 14, 28]
N_ROUNDS = 200

# Echo the configuration so that a run's log records what was used.
for config_line in (
    "Config FAST:",
    f"- History days: {FAST_HISTORY_DAYS}",
    f"- Lags: {LAGS}, Rolling: {ROLLS}",
    f"- Rounds: {N_ROUNDS}",
):
    print(config_line)

# ----------------------------------------------------------------------------
# 2) DATA LOADING
# ----------------------------------------------------------------------------
calendar = pd.read_csv(
    RAW_DIR / "calendar.csv",
    usecols=["date", "d", "wm_yr_wk", "wday", "month", "year",
             "snap_CA", "snap_TX", "snap_WI"],
)
prices = pd.read_csv(
    RAW_DIR / "sell_prices.csv",
    usecols=["store_id", "item_id", "wm_yr_wk", "sell_price"],
)
sales = pd.read_csv(RAW_DIR / "sales_train_evaluation.csv")

# Shrink the wide sales table to the last FAST_HISTORY_DAYS + 28 day columns
# before melting, so the long frame stays small.
d_cols = [c for c in sales.columns if c.startswith("d_")]
d2date = dict(zip(calendar["d"], calendar["date"]))
dates = pd.to_datetime([d2date[d] for d in d_cols])
last_date = dates.max()
oldest_kept = last_date - pd.Timedelta(days=FAST_HISTORY_DAYS + 28)
keep_mask = dates >= oldest_kept
keep_cols = [c for c, m in zip(d_cols, keep_mask) if m]
# The first six columns of the sales file are taken to be the identifier columns.
id_cols = list(sales.columns[:6])
sales_small = sales[id_cols + keep_cols]

# Wide -> long, then attach calendar attributes and weekly sell prices.
long = (
    sales_small
    .melt(id_vars=["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"],
          var_name="d", value_name="sales")
    .merge(calendar[["d", "date", "wm_yr_wk", "wday", "month", "year",
                     "snap_CA", "snap_TX", "snap_WI"]],
           on="d", how="left")
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .merge(prices, on=["store_id", "item_id", "wm_yr_wk"], how="left")
)

# Single SNAP flag: 1 when the row's own state has its snap_<STATE> flag set.
long["snap"] = 0
for state in ("CA", "TX", "WI"):
    snap_active = (long["state_id"] == state) & (long[f"snap_{state}"] == 1)
    long.loc[snap_active, "snap"] = 1

# Keep only the columns needed downstream.
needed_cols = ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id",
               "date", "sales", "sell_price", "wday", "month", "year", "snap"]
long = long[needed_cols].copy()

# Final cut: FAST_HISTORY_DAYS of history plus the last 28 days used as test.
max_date = long["date"].max()
cut_date = max_date - pd.Timedelta(days=28)
hist_start = cut_date - pd.Timedelta(days=FAST_HISTORY_DAYS)
in_window = long["date"].between(hist_start, max_date)
long = long[in_window].copy()

# ----------------------------------------------------------------------------
# 3) FEATURE ENGINEERING
# ----------------------------------------------------------------------------
assert "sales" in long.columns, "Colonna 'sales' mancante prima del FE"

# Categorical encodings: map each identifier to an integer code
# (codes follow order of first appearance in `long`).
for c in ["state_id","store_id","dept_id","item_id"]:
    long[c+"_enc"] = pd.factorize(long[c])[0]

# Lightweight date features.
# In the M5 calendar, `wday` is 1 = Saturday, 2 = Sunday, ..., 7 = Friday,
# so the weekend is {1, 2}. The previous {1, 7} flagged Saturday and Friday.
long["is_weekend"] = long["wday"].isin([1,2]).astype(int)

def add_group_lag_roll(df, key_cols, val_col, prefix, lags=None, rolls=None):
    """Add per-group lag and rolling-mean columns of ``val_col``.

    Args:
        df: Frame with the columns in ``key_cols``, a ``"date"`` column and
            ``val_col``. It is not modified; a sorted copy is returned.
        key_cols: List of column names identifying one series.
        val_col: Name of the column to lag / average.
        prefix: Prefix for the new columns, which are named
            ``{prefix}_lag_{k}`` and ``{prefix}_roll_{w}``.
        lags: Row offsets for the lag columns. Defaults to the module-level
            ``LAGS``.
        rolls: Window sizes for the rolling means. Defaults to the
            module-level ``ROLLS``.

    Returns:
        A copy of ``df`` sorted by ``key_cols + ["date"]`` with the new
        columns appended. Lags are row offsets within each group, so they
        equal day offsets only when every group has one row per day.
    """
    lags = LAGS if lags is None else lags
    rolls = ROLLS if rolls is None else rolls

    df = df.sort_values(key_cols + ["date"])
    grp = df.groupby(key_cols, sort=False)[val_col]

    for lag in lags:
        df[f"{prefix}_lag_{lag}"] = grp.shift(lag).to_numpy()

    if rolls:
        # Shift once, outside the loop: the result does not depend on the window.
        # The shift also puts a NaN on the first row of every group. Because
        # rolling() needs a full window of non-NaN values by default, that NaN
        # blanks out any window that would reach into the previous group, so
        # rolling over the flat shifted series never mixes groups.
        prev_values = grp.shift(1)
        for w in rolls:
            df[f"{prefix}_roll_{w}"] = prev_values.rolling(w).mean().to_numpy()
    return df

# 3.1) Bottom level: lags/rolling means of 'sales' per 'id', kept in the main frame.
long = add_group_lag_roll(long, ["id"], "sales", "b")

# 3.2-3.4) Aggregate levels. For each level the daily mean of 'sales' is
# computed over the grouping keys, lag/rolling features are derived on that
# aggregate series, and only the prefixed feature columns are merged back.
#   item level        (item_id)            -> 'it_' columns
#   dept-store level  (dept_id, store_id)  -> 'ds_' columns
#   state-store level (state_id, store_id) -> 'ss_' columns
aggregate_levels = (
    (["item_id"], "sales_item", "it"),
    (["dept_id", "store_id"], "sales_dept_store", "ds"),
    (["state_id", "store_id"], "sales_state_store", "ss"),
)
for level_keys, level_value, level_prefix in aggregate_levels:
    join_cols = level_keys + ["date"]
    level_agg = long.groupby(join_cols, as_index=False)["sales"].mean()
    level_agg = level_agg.rename(columns={"sales": level_value})
    level_agg = add_group_lag_roll(level_agg, level_keys, level_value, level_prefix)
    level_feats = [c for c in level_agg.columns if c.startswith(f"{level_prefix}_")]
    long = long.merge(level_agg[join_cols + level_feats], on=join_cols, how="left")

assert "sales" in long.columns, "'sales' è stato perso dopo i merge"

# Cleanup: drop the last aggregate frame and reclaim memory.
del level_agg
gc.collect()

# ----------------------------------------------------------------------------
# 4) TRAIN/TEST SPLIT AND FEATURE SET
# ----------------------------------------------------------------------------
# Everything up to and including cut_date is training; later dates are test.
train_df = long.loc[long["date"] <= cut_date].copy()
test_df = long.loc[long["date"] > cut_date].copy()

base_feats = [
    "wday", "month", "year", "is_weekend", "snap",
    "sell_price", "state_id_enc", "store_id_enc", "dept_id_enc", "item_id_enc",
]

# Column prefixes of the four levels, in the order their features enter FEATS.
level_prefixes = ("b", "it", "ds", "ss")
lag_feats = [f"{p}_lag_{l}" for p in level_prefixes for l in LAGS]
roll_feats = [f"{p}_roll_{w}" for p in level_prefixes for w in ROLLS]

FEATS = base_feats + lag_feats + roll_feats

# Training only: discard rows where any feature is missing.
complete_rows = train_df[FEATS].notna().all(axis=1)
train_df = train_df[complete_rows]

# ----------------------------------------------------------------------------
# 5) TRAIN: 28 MODELS (NON-RECURSIVE) - LGB 4.x CALLBACKS
# ----------------------------------------------------------------------------
params = dict(
    objective="poisson",
    metric="rmse",
    learning_rate=0.08,
    num_leaves=31,
    max_depth=8,
    feature_fraction=0.8,
    bagging_fraction=0.8,
    bagging_freq=1,
    verbose=-1,
    n_jobs=-1,
)

# Lag/rolling columns are the only ones that must be pushed back per horizon.
history_cols = lag_feats + roll_feats

models = {}
print("Training 28 modelli")
for h in tqdm(range(1, 29)):
    # One direct model per horizon h: each row's lag/rolling features are
    # replaced by the values found h rows earlier in the same series.
    frame_h = train_df.copy()
    shifted = frame_h.groupby("id", sort=False)[history_cols].shift(h)
    frame_h[history_cols] = shifted.to_numpy()

    # The shift leaves the first h rows of every series without features.
    frame_h = frame_h.dropna(subset=FEATS)

    dtrain = lgb.Dataset(frame_h[FEATS].values, frame_h["sales"].values)

    models[h] = lgb.train(
        params=params,
        train_set=dtrain,
        num_boost_round=N_ROUNDS,
        valid_sets=[dtrain],
        valid_names=["train"],
        callbacks=[lgb.log_evaluation(period=0)],  # replaces verbose_eval
    )

    # Release the per-horizon copies before the next iteration.
    del frame_h, shifted, dtrain
    gc.collect()

# ----------------------------------------------------------------------------
# 6) PREDICT
# ----------------------------------------------------------------------------
print("Predict 28 giorni...")
test_days = sorted(test_df["date"].unique())
assert len(test_days) >= 28
test_days = test_days[:28]

# Model h was trained with lag/rolling features pushed back by h rows, i.e.
# with information available h days before the target day. For test day h
# that is the forecast origin (cut_date) for every horizon. The lag/rolling
# columns stored on the test rows must NOT be used as they are: they were
# computed from actual sales inside the test window (target leakage) and do
# not match what the models saw in training.
# NOTE: this assumes one row per id per day, so that "h rows back" in
# training is the same as "h days back".
history_cols = lag_feats + roll_feats
origin_feats = (
    long.loc[long["date"] == cut_date, ["id"] + history_cols]
        .drop_duplicates(subset="id")
        .set_index("id")
)

pred_list = []
for h, day in tqdm(list(zip(range(1,29), test_days))):
    Xtest = test_df[test_df["date"]==day].copy()
    # Overwrite the leaky features with the origin snapshot for the same id;
    # ids with no row on cut_date get NaN, which becomes 0 below as before.
    Xtest[history_cols] = origin_feats.reindex(Xtest["id"].to_numpy()).to_numpy()
    Xmat = Xtest[FEATS].fillna(0).values
    pred = models[h].predict(Xmat)
    out = Xtest[["id"]].copy()
    out["h"] = h
    # Sales cannot be negative.
    out["forecast"] = np.clip(pred, 0, None)
    pred_list.append(out)

# One row per id in the order of the original sales file, columns F1..F28;
# ids without a prediction are filled with 0.
pred_all = pd.concat(pred_list, axis=0)
pivot = pred_all.pivot(index="id", columns="h", values="forecast").reindex(
    sales["id"].tolist()).fillna(0)
pivot.columns = [f"F{i}" for i in range(1, 29)]

# ----------------------------------------------------------------------------
# 7) WRMSSE
# ----------------------------------------------------------------------------
# Score the forecast matrix (one row per id, one column per horizon day).
forecast_array = pivot.to_numpy()
score = wrmsse(forecast_array)
print(f"\n✅ Hierarchical WRMSSE: {score:.4f}")

# ----------------------------------------------------------------------------
# 8) SAVE RESULTS
# ----------------------------------------------------------------------------
# Forecast table, one pickle.
pivot.to_pickle(OUT_DIR / "hier_fast_forecasts.pkl")

# Run summary: the score together with the configuration that produced it.
run_summary = dict(
    wrmsse=float(score),
    history_days=FAST_HISTORY_DAYS,
    lags=LAGS,
    rolls=ROLLS,
    rounds=N_ROUNDS,
)
with (OUT_DIR / "hier_fast_summary.pkl").open("wb") as f:
    pickle.dump(run_summary, f)
print("Salvato in hier_fast_results/")