In [4]:
from modeling_module.training.forecaster import make_calendar_exo
import os
import time
from datetime import date

import numpy as np
import polars as pl
import torch
import sys


# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Default project data directory. NOTE: `dir` shadows the Python builtin; the
# name is kept because later cells read it. The trailing "/" matters: the
# parquet path further down is built by plain string concatenation.
dir = "C:/Users/USER/PycharmProjects/ts_forecaster_lib/raw_data/"
save_dir = os.path.join(dir, 'fit')
# Folder holding the checkpoints of one specific training run.
save_root = os.path.join(save_dir, 'Xpatchtst', '20260119')
os.makedirs(save_dir, exist_ok=True)

# ---------------------------------------------------------------------------
# Run settings
# ---------------------------------------------------------------------------
plan_week = 201103      # planning week, YYYYWW
lookback = 52           # length of the input window
horizon = 27            # number of forecast steps
batch_size = 256
freq = 'weekly'
split_mode = 'multi'
shuffle = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---------------------------------------------------------------------------
# Column names of the training frame
# ---------------------------------------------------------------------------
id_col = 'unique_id'
date_col = 'date'
y_col = 'y'

# Past exogenous continuous features that are switched on.
# Candidates currently switched off:
#   exo_p_y_lag_1w, exo_p_y_lag_52w, exo_p_weeks_since_holiday,
#   exo_p_temperature, exo_p_fuel_price, exo_p_cpi, exo_p_unemployment,
#   exo_p_markdown_sum, exo_p_markdown1 .. exo_p_markdown5,
#   exo_markdown1_isnull .. exo_markdown5_isnull
past_exo_cont_cols = (
    "exo_p_y_lag_2w",
    "exo_p_y_rollmean_4w",
    "exo_p_y_rollmean_12w",
    "exo_p_y_rollstd_4w",
)

# Past exogenous categorical features (none enabled; candidate: exo_c_woy_bucket).
past_exo_cat_cols = ()

def my_exo_cb(start_idx: int, Hm: int, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Future-exogenous callback that delegates to ``make_calendar_exo``.

    Args:
        start_idx: Forwarded unchanged as the first positional argument.
        Hm: Forwarded unchanged as the second positional argument.
        device: Forwarded as ``device``. The default is evaluated once, when
            this function is defined (CUDA if available at that moment,
            otherwise CPU).

    Returns:
        Whatever ``make_calendar_exo`` returns for ``period=52``
        (per the original note: 2 features, sin and cos).
    """
    calendar_exo = make_calendar_exo(start_idx, Hm, period=52, device=device)
    return calendar_exo

# Callback handed to the data module to generate future exogenous features.
future_exo_cb = my_exo_cb

# Training frame (real data) read from the project data directory.
df = pl.read_parquet(
    os.path.join(dir, 'train_data/walmart_best_feature_train.parquet')
)

In [5]:
from modeling_module.training.forecater import DMSForecaster
import numpy as np
import polars as pl
import torch
from datetime import date, timedelta



# -----------------------------
# YYYYWW (ISO week) utilities
# -----------------------------
def yyyyww_to_monday(yyyyww: int) -> date:
    """Return the Monday of the ISO week encoded as ``YYYYWW``.

    Args:
        yyyyww: ISO year * 100 + ISO week number (e.g. ``201103``).

    Returns:
        The ``date`` of day 1 (Monday) of that ISO week.

    Raises:
        ValueError: If the week number is not valid for the ISO year
            (raised by ``date.fromisocalendar``).
    """
    iso_year, iso_week = divmod(int(yyyyww), 100)
    return date.fromisocalendar(iso_year, iso_week, 1)

def monday_to_yyyyww(d: date) -> int:
    """Encode the ISO week containing ``d`` as the integer ``YYYYWW``.

    The ISO year is used, not the calendar year, so dates near New Year
    may map to the previous or next year's week numbering.

    Args:
        d: Any date; despite the name it does not have to be a Monday.

    Returns:
        ISO year * 100 + ISO week number.
    """
    iso = d.isocalendar()
    return int(iso[0]) * 100 + int(iso[1])

def add_week(yyyyww: int, add: int) -> int:
    """Shift a ``YYYYWW`` ISO week by ``add`` weeks (negative values go back).

    The shift is done on real dates, so year boundaries and 53-week ISO
    years are handled by the calendar rather than by integer arithmetic.

    Args:
        yyyyww: Starting ISO week as ISO year * 100 + week number.
        add: Number of weeks to add.

    Returns:
        The resulting ISO week in the same ``YYYYWW`` encoding.
    """
    start_monday = yyyyww_to_monday(yyyyww)
    shifted_monday = start_monday + timedelta(weeks=int(add))
    return monday_to_yyyyww(shifted_monday)


# -----------------------------
# Result-table builder
# -----------------------------
@torch.no_grad()
def make_forecast_result_table(
    *,
    inference_loader,
    base_model: torch.nn.Module,
    quantile_model: torch.nn.Module,
    plan_week: int,
    horizon: int,
    device: str = "cuda",
    max_parts: int = 10_000,          # upper bound on the number of series processed
) -> pl.DataFrame:
    """
    Build a long-format forecast table from an inference loader.

    Every batch yielded by ``inference_loader`` is unpacked as the 5-tuple
    ``(x, uid, fe, pe_cont, pe_cat)`` (the same layout ``inspect()`` in this
    notebook assumes). Each series in the batch is predicted on its own
    (batch of one) with two forecasters:

    - ``base_model``     -> point forecast, read from ``pred["point"]``
    - ``quantile_model`` -> median forecast, read from ``pred["q50"]`` and
      falling back to ``pred["point"]`` only when ``"q50"`` is absent

    Args:
        inference_loader: Iterable of ``(x, uid, fe, pe_cont, pe_cat)`` batches.
            ``fe``, ``pe_cont`` and ``pe_cat`` may be ``None``.
        base_model: Model wrapped in a ``DMSForecaster`` for the point forecast.
        quantile_model: Model wrapped in a ``DMSForecaster`` for the q50 forecast.
        plan_week: First forecast week, encoded as ``YYYYWW`` (ISO week).
        horizon: Number of forecast weeks, starting at ``plan_week`` inclusive.
        device: Device string forwarded to ``DMSForecaster.predict``.
        max_parts: Stop after this many series have been processed.

    Returns:
        A ``pl.DataFrame`` with one row per (series, forecast week) and the
        columns ``plan_week``, ``oper_part_no``, ``forecast_week``,
        ``base_forecast``, ``quantile_forecast``. When nothing was predicted
        the frame is empty but still carries these columns.

    Note:
        Rows are produced by zipping the week list with both prediction
        vectors, so a prediction shorter than ``horizon`` silently yields
        fewer rows for that series.
    """
    n_steps = int(horizon)
    plan_week_int = int(plan_week)

    base_fc = DMSForecaster(base_model, target_channel=0, fill_mode="copy_last")
    q_fc    = DMSForecaster(quantile_model, target_channel=0, fill_mode="copy_last")

    # Forecast weeks are identical for every series: plan_week itself plus the
    # following (horizon - 1) weeks. Computed once instead of once per series.
    weeks = [add_week(plan_week_int, h) for h in range(n_steps)]

    rows = []
    n_parts = 0

    for batch in inference_loader:
        x, uid, fe, pe_cont, pe_cat = batch

        for i in range(x.size(0)):
            if n_parts >= max_parts:
                break

            # Keep the leading batch dimension (size 1) for every input.
            x1 = x[i:i+1]
            pid = uid[i]  # used as oper_part_no (map it here if a different id is needed)

            fe1 = fe[i:i+1] if fe is not None else None
            pec1 = pe_cont[i:i+1] if pe_cont is not None else None
            pek1 = pe_cat[i:i+1] if pe_cat is not None else None

            def _predict(forecaster):
                # Same call for both forecasters; a fresh part_ids list per call.
                return forecaster.predict(
                    x1,
                    horizon=n_steps,
                    device=device,
                    mode="eval",
                    part_ids=[pid] if pid is not None else None,
                    past_exo_cont=pec1,
                    past_exo_cat=pek1,
                    future_exo_batch=fe1,
                )

            # Base (point) forecast, flattened to 1-D.
            pred_base = _predict(base_fc)
            y_base = np.asarray(pred_base["point"]).reshape(-1)

            # Quantile forecast: prefer q50 explicitly. "point" is looked up
            # only when "q50" is missing, so a result without "point" no
            # longer raises KeyError.
            pred_q = _predict(q_fc)
            q_values = pred_q["q50"] if "q50" in pred_q else pred_q["point"]
            y_q50 = np.asarray(q_values).reshape(-1)

            # Single-element tensors are unwrapped so the id prints as "1", not "tensor(1)".
            pid_str = str(pid.item()) if torch.is_tensor(pid) and pid.numel() == 1 else str(pid)
            for w, bval, qval in zip(weeks, y_base.tolist(), y_q50.tolist()):
                rows.append({
                    "plan_week": plan_week_int,
                    "oper_part_no": pid_str,
                    "forecast_week": int(w),
                    "base_forecast": float(bval),
                    "quantile_forecast": float(qval),
                })

            n_parts += 1

        if n_parts >= max_parts:
            break

    if not rows:
        # Without an explicit schema an empty list gives a frame with no
        # columns at all, which breaks downstream column selection.
        return pl.DataFrame(schema={
            "plan_week": pl.Int64,
            "oper_part_no": pl.Utf8,
            "forecast_week": pl.Int64,
            "base_forecast": pl.Float64,
            "quantile_forecast": pl.Float64,
        })

    return pl.DataFrame(rows)

In [6]:
from modeling_module.utils.checkpoint import load_model_dict
from modeling_module.models import build_patchTST_quantile, build_patchTST_base
from modeling_module.data_loader import MultiPartExoDataModule
from modeling_module.training.model_trainers.total_train import run_total_train_weekly

def inspect(loader, name):
    """Print shape, device and dtype of the first batch of ``loader``.

    The batch is unpacked as ``(x, uid, fe, pe_cont, pe_cat)``. ``fe`` and
    ``pe_cont`` may be ``None`` (the forecast-table builder in this notebook
    treats them as optional too); in that case ``None`` is reported instead
    of raising ``AttributeError``. An empty loader is reported instead of
    raising ``StopIteration``.

    NOTE: this function shadows the standard-library ``inspect`` module in
    the notebook namespace; the name is kept because later cells call it.

    Args:
        loader: Iterable yielding 5-tuples ``(x, uid, fe, pe_cont, pe_cat)``.
        name: Label used as the ``[name]`` prefix of every printed line.

    Returns:
        None. Output goes to stdout only.
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        print(f"[{name}] loader is empty")
        return

    x, uid, fe, pe_cont, pe_cat = first_batch
    print(f"[{name}] x:", x.shape, x.device, x.dtype)

    if fe is None:
        print(f"[{name}] fe: None")
    else:
        print(f"[{name}] fe:", fe.shape, fe.device, fe.dtype)

    if pe_cont is None:
        print(f"[{name}] pe: None")
    else:
        print(f"[{name}] pe:", pe_cont.shape, pe_cont.device, pe_cont.dtype)

    # Show the first three steps of the first sample when future features exist.
    if fe is not None and fe.shape[-1] > 0:
        print(f"[{name}] fe sample:", fe[0, :3, :])

# Data module over the training frame, configured from the notebook settings.
data_module = MultiPartExoDataModule(
    df=df,
    id_col=id_col,
    date_col=date_col,
    y_col=y_col,
    lookback=lookback,
    horizon=horizon,
    batch_size=batch_size,
    past_exo_cont_cols=past_exo_cont_cols,
    past_exo_cat_cols=past_exo_cat_cols,
    future_exo_cb=future_exo_cb,
    freq=freq,
    shuffle=shuffle,
    split_mode=split_mode,
)

# Loader positioned at the planning week.
inference_loader = data_module.get_inference_loader_at_plan(plan_dt=plan_week)

# Sanity check of the first batch before any model is loaded.
inspect(inference_loader, 'inference_loader')

# Checkpoint key -> builder used by load_model_dict.
builders = {
    'patchtst_quantile': build_patchTST_quantile,
    'patchtst': build_patchTST_base,
}

model_loader = load_model_dict(save_root, builders, device=device)
base_model = model_loader['patchtst']
quantile_model = model_loader['patchtst_quantile']

# Long-format forecast table: one row per (series, forecast week).
result_df = make_forecast_result_table(
    inference_loader=inference_loader,
    base_model=base_model,
    quantile_model=quantile_model,
    plan_week=plan_week,
    horizon=horizon,
    device=device,
    max_parts=10_000,
)



[inference_loader] x: torch.Size([45, 52, 1]) cpu torch.float32
[inference_loader] fe: torch.Size([45, 27, 2]) cpu torch.float32
[inference_loader] pe: torch.Size([45, 52, 4]) cpu torch.float32
[inference_loader] fe sample: tensor([[ 0.7474, -0.6643],
        [ 0.6617, -0.7498],
        [ 0.5679, -0.8231]])
[load] patchtst_quantile ← C:/Users/USER/PycharmProjects/ts_forecaster_lib/raw_data/fit\Xpatchtst\20260119\weekly_PatchTSTQuantile_L52_H27.pt
[DBG-backbone-init] d_past_cont=4 cont_input_dim=48 target_input_dim=12 total_input_dim=60
[load][patchtst_quantile] missing=2 unexpected=0
  missing sample: ['revin_layer.mean', 'revin_layer.std']
  unexpected sample: []
[load] patchtst ← C:/Users/USER/PycharmProjects/ts_forecaster_lib/raw_data/fit\Xpatchtst\20260119\weekly_PatchTSTBase_L52_H27.pt
[DBG-backbone-init] d_past_cont=4 cont_input_dim=48 target_input_dim=12 total_input_dim=60
[load][patchtst] missing=2 unexpected=0
  missing sample: ['revin_layer.mean', 'revin_layer.std']
  unexpec

In [7]:
# Display the forecast table built in the previous cell as this cell's output.
result_df

plan_week,oper_part_no,forecast_week,base_forecast,quantile_forecast
i64,str,i64,f64,f64
201103,"""1""",201103,890844.25,1.3086e6
201103,"""1""",201104,858298.5,1.3666e6
201103,"""1""",201105,1.2527e6,1.4038e6
201103,"""1""",201106,1.2568e6,1.4219e6
201103,"""1""",201107,1.4789e6,1343859.5
…,…,…,…,…
201103,"""9""",201125,501744.6875,519144.125
201103,"""9""",201126,547227.5625,500923.84375
201103,"""9""",201127,520996.375,496747.4375
201103,"""9""",201128,495675.6875,500610.875
