# PatchTST AB Test (Walmart) — Refactored Notebook

Variants:
- **A0**: y-only (no exogenous)
- **A1**: time/calendar exogenous (sin/cos)
- **A2**: time/calendar + holiday (vectorized, batch-safe)

Key fixes:
- `compose_exo_calendar_cb` supports **batched** `start_idx` so the DataLoader collate can call it once per batch.
- `get_train_loader(batch_size=...)` is called explicitly to avoid silent fallback to the default 32.


In [1]:
import os
import time
from datetime import date

import numpy as np
import polars as pl
import torch
import sys

# LTB modules (expected to exist in your repo)
from modeling_module.data_loader import MultiPartExoDataModule
from modeling_module.training.model_trainers.total_train import run_total_train_weekly

# optional
import matplotlib.pyplot as plt

# Setup reference: install command for the CUDA 12.8 PyTorch wheels and the
# matching CUDA toolkit page. The string below is a no-op expression kept only
# as a note.
'''
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
https://developer.nvidia.com/cuda-12-8-0-download-archive
'''

# Per-machine raw-data roots. Both end with "/" because later cells build paths
# by plain string concatenation (e.g. DIR + 'train_data/...').
MAC_DIR = "/Users/igwanhyeong/PycharmProjects/ts_forecaster_lib/raw_data/"
WINDOW_DIR = "C:/Users/USER/PycharmProjects/ts_forecaster_lib/raw_data/"

# Windows uses WINDOW_DIR; every other platform falls back to MAC_DIR.
DIR = WINDOW_DIR if sys.platform == "win32" else MAC_DIR
device = "cuda" if torch.cuda.is_available() else "cpu"

print("DIR:", DIR)
print("device:", device)
if device == "cuda":
    print("cuda:", torch.version.cuda, "gpu_count:", torch.cuda.device_count())

# Root output directory for this A/B experiment (created if missing).
save_dir = os.path.join(DIR, "fit", "walmart_patchtst_ab")
os.makedirs(save_dir, exist_ok=True)

# Per-variant output roots. Only path strings are built here; these
# subdirectories are not created by this cell.
save_root_A0 = os.path.join(save_dir, "A0_y_only")
save_root_A1 = os.path.join(save_dir, "A1_time_exog")
save_root_A2 = os.path.join(save_dir, "A2_time_holiday")



DIR: C:/Users/USER/PycharmProjects/ts_forecaster_lib/raw_data/
device: cuda
cuda: 12.8 gpu_count: 1


In [2]:
# Load the Walmart weekly training table. Per the printed head, the columns are
# unique_id (i64), dt (i64, YYYYWW week code), y (f64) and is_holiday (bool).
walmart_df = pl.read_parquet(DIR + 'train_data/walmart_train.parquet')
walmart_df.head()

unique_id,dt,y,is_holiday
i64,i64,f64,bool
1,201005,24924.5,False
1,201006,46039.49,True
1,201007,41595.55,False
1,201008,19403.54,False
1,201009,21827.9,False


In [3]:
# Experiment settings shared by all three variants (A0/A1/A2).
# lookback / horizon are numbers of time steps; the data is weekly, so weeks.
lookback = 52
horizon = 27
batch_size = 512

freq = "weekly"          # walmart dt is weekly
split_mode = "multi"     # id-disjoint split (leakage-safe)
shuffle = True           # given to the data modules; get_train_loader() calls also pass shuffle=True explicitly

# Recomputed with the same expression as in the setup cell.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Permit faster, lower-precision float32 matmul kernels (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")

print("device:", device)

device: cuda


In [4]:
# -----------------------------
# YYYYWW -> day_idx (Monday of ISO week)
# day_idx: days since 1970-01-01 (int)
# -----------------------------
import warnings
from datetime import date, timedelta

EPOCH = date(1970, 1, 1)


def _yyyyww_to_monday_dayidx(yyyyww: int) -> int:
    """Convert an ISO ``YYYYWW`` week code to the day index of that week's Monday.

    The day index counts days since ``EPOCH`` (1970-01-01). For example,
    ``201005`` (ISO 2010-W05) maps to Monday 2010-02-01, i.e. ``14641``.

    Raises:
        ValueError: if the year/week pair is not a valid ISO week
            (propagated from ``date.fromisocalendar``).
    """
    year, week = divmod(int(yyyyww), 100)
    monday = date.fromisocalendar(year, week, 1)
    return (monday - EPOCH).days


def _monday_dayidx_to_yyyyww(dayidx: int, *, strict: bool = False) -> int:
    """Convert a day index (days since ``EPOCH``) back to an ISO ``YYYYWW`` code.

    The pipeline anchors every week on its ISO Monday. A non-Monday day index
    still maps to the ISO week that contains it, but it suggests the week
    anchor was broken upstream, so it is reported instead of silently accepted.

    Args:
        dayidx: days since 1970-01-01.
        strict: if True, raise ``ValueError`` for a non-Monday ``dayidx``;
            otherwise emit a ``UserWarning`` and convert anyway.

    Returns:
        ``iso_year * 100 + iso_week``.

    Raises:
        ValueError: only when ``strict`` is True and ``dayidx`` is not a Monday.
    """
    d = EPOCH + timedelta(days=int(dayidx))
    iso_year, iso_week, iso_weekday = d.isocalendar()

    # ISO weekday: Monday=1 ... Sunday=7.
    if iso_weekday != 1:
        msg = f"dayidx={dayidx} is not Monday (iso_d={iso_weekday})"
        if strict:
            raise ValueError(msg)
        warnings.warn(msg, stacklevel=2)

    return int(iso_year) * 100 + int(iso_week)

# Normalize the raw table into the columns the data modules use:
#   unique_id      -> series id as a string
#   date           -> YYYYWW week code (int32, copied from `dt`)
#   date_idx       -> day index of that ISO week's Monday (days since 1970-01-01)
#   y              -> float32 target
#   exo_is_holiday -> holiday flag as float32 (0.0 / 1.0)
# NOTE(review): the data modules below are built with date_col="date" (YYYYWW),
# while the holiday lookup is keyed by date_idx — confirm which value actually
# reaches the future-exo callback as start_idx.
df = (
    walmart_df
    .with_columns([
        pl.col("unique_id").cast(pl.Utf8),
        pl.col("dt").cast(pl.Int32).alias("yyyyww"),
        pl.col("y").cast(pl.Float32),
        pl.col("is_holiday").cast(pl.Float32).alias("exo_is_holiday"),
    ])
    # Row-wise Python call per element (map_elements), not a vectorized expression.
    .with_columns([
        pl.col("yyyyww").map_elements(_yyyyww_to_monday_dayidx, return_dtype=pl.Int32).alias("date_idx")
    ])
    .with_columns([pl.col("yyyyww").alias("date")])
    .select(["unique_id", "date", "date_idx", "y", "exo_is_holiday"])
)

# # Optional: complete missing weeks per store (step=7 days), fill y=0, holiday=0
# def complete_weekly(g: pl.DataFrame) -> pl.DataFrame:
#     g = g.sort("date_idx")
#     mn = int(g["date_idx"].min())
#     mx = int(g["date_idx"].max())
#     full = pl.DataFrame({"date_idx": pl.int_range(mn, mx + 1, step=7, eager=True).cast(pl.Int32)})
#     out = full.join(g, on="date_idx", how="left").with_columns([
#         pl.col("unique_id").fill_null(g["unique_id"][0]),
#         pl.col("date").fill_null(pl.col("date_idx").map_elements(_monday_dayidx_to_yyyyww, return_dtype=pl.Int32)),
#         pl.col("y").fill_null(0.0),
#         pl.col("exo_is_holiday").fill_null(0.0),
#     ])
#     return out
#
# df = pl.concat([complete_weekly(g) for g in df.partition_by("unique_id")], how="vertical").sort(["unique_id","date_idx"])
#
# print("df:", df.shape)
# df
df

unique_id,date,date_idx,y,exo_is_holiday
str,i32,i32,f32,f32
"""1""",201005,14641,24924.5,0.0
"""1""",201006,14648,46039.488281,1.0
"""1""",201007,14655,41595.550781,0.0
"""1""",201008,14662,19403.539062,0.0
"""1""",201009,14669,21827.900391,0.0
…,…,…,…,…
"""45""",201239,15607,508.369995,0.0
"""45""",201240,15614,628.099976,0.0
"""45""",201241,15621,1061.02002,0.0
"""45""",201242,15628,760.01001,0.0


In [5]:
from modeling_module.utils.exogenous_utils import compose_exo_calendar_cb

# Calendar (sin/cos) future-exo callback for the A1/A2 variants. Per the
# printed signature it accepts a scalar or batched start_idx plus H and device.
future_exo_cb_time = compose_exo_calendar_cb(date_type=freq)
future_exo_cb_time

<function modeling_module.utils.exogenous_utils.compose_exo_calendar_cb.<locals>.cb(start_idx: Union[int, Sequence[int], numpy.ndarray], H: int, device: str = 'cpu')>

In [6]:
# ============================================================
# Holiday lookup (vectorized) + FutureExo callback (time + holiday)
# ============================================================
# Sparse {date_idx: holiday_flag} map. Flags are collapsed across stores with
# max(), so a date counts as a holiday if any store flags it.
holiday_map_dayidx = {
    int(row[0]): float(row[1])
    for row in (
        df.select(["date_idx", "exo_is_holiday"])
          .group_by("date_idx")
          .agg(pl.max("exo_is_holiday").alias("exo_is_holiday"))
          .sort("date_idx")
          .iter_rows()
    )
}

def build_holiday_array(holiday_map_dayidx: dict[int, float], *, pad: int = 0) -> np.ndarray:
    """Densify a sparse ``{day_idx: value}`` mapping into a float32 lookup array.

    Entry ``k`` of the result holds ``holiday_map_dayidx[k]``; every other entry
    is 0. Negative keys cannot be array positions and are ignored. ``pad`` extra
    zero slots are appended after the largest key.

    Args:
        holiday_map_dayidx: sparse mapping from day index to holiday value.
        pad: number of trailing zero slots to append.

    Returns:
        ``np.zeros((1,), float32)`` when the mapping has no non-negative key
        (including the empty mapping); otherwise an array of length
        ``max_key + 1 + pad``.
    """
    count = len(holiday_map_dayidx)
    keys = np.fromiter((int(k) for k in holiday_map_dayidx), dtype=np.int64, count=count)
    values = np.fromiter((float(v) for v in holiday_map_dayidx.values()), dtype=np.float32, count=count)

    # Drop negative keys *before* sizing the array: sizing from them could
    # otherwise yield a negative length.
    keep = keys >= 0
    keys, values = keys[keep], values[keep]
    if keys.size == 0:
        return np.zeros((1,), dtype=np.float32)

    arr = np.zeros((int(keys.max()) + 1 + int(pad),), dtype=np.float32)
    arr[keys] = values  # dict keys are unique, so this matches per-key assignment
    return arr

class FutureExoTimePlusHoliday:
    """Future-exogenous callback: calendar features plus one holiday column.

    The calendar part is delegated to a calendar callback (by default the
    module-level ``future_exo_cb_time``). A holiday column is appended, looked
    up from a dense ``holiday_by_dayidx`` array at ``start + step_days * h``
    for ``h = 0..H-1``. Out-of-range positions read as 0.

    NOTE(review): the holiday array is indexed by day index (days since
    1970-01-01, the ``date_idx`` column). The data modules in this notebook are
    built with ``date_col="date"``, whose values are YYYYWW codes. If the collate
    passes those as ``start_idx``, every lookup is out of range and the holiday
    column is all zeros — TODO confirm what the collate actually passes.
    """

    def __init__(self, holiday_by_dayidx: np.ndarray, *, step_days: int = 7, calendar_cb=None):
        """
        Args:
            holiday_by_dayidx: array indexed along axis 0 by day index.
            step_days: spacing in days between consecutive horizon steps (7 = weekly).
            calendar_cb: optional calendar callback with the signature
                ``(start_idx, H, device=...)``. ``None`` means use the module-level
                ``future_exo_cb_time``, resolved at call time.
        """
        self.holiday = holiday_by_dayidx.astype(np.float32, copy=False)
        self.step_days = int(step_days)
        self.calendar_cb = calendar_cb

    def __call__(self, start_idx, H: int, device: str = "cpu"):
        """Build future exogenous features for one start index or a batch of them.

        Args:
            start_idx: a Python/NumPy integer (scalar mode), or a sequence /
                array / tensor of start indices (batch mode).
            H: horizon length.
            device: forwarded to the calendar callback. The holiday column
                follows the calendar tensor's device.

        Returns:
            float32 tensor of shape ``(H, E + 1)`` in scalar mode, or
            ``(B, H, E + 1)`` in batch mode. ``E`` is the calendar feature count.
        """
        # 1) calendar exo (batch-safe): scalar -> (H, E), batch -> (B, H, E)
        calendar_cb = self.calendar_cb if self.calendar_cb is not None else future_exo_cb_time
        cal = calendar_cb(start_idx, H, device=device)

        # 2) holiday exo, vectorized in numpy
        is_scalar = isinstance(start_idx, (int, np.integer))
        if is_scalar:
            starts = np.asarray([int(start_idx)], dtype=np.int64)
        else:
            raw = start_idx
            if torch.is_tensor(raw):
                # np.asarray cannot read device (e.g. CUDA) tensors; copy to host first.
                raw = raw.detach().cpu().numpy()
            starts = np.asarray(raw, dtype=np.int64).reshape(-1)

        H = int(H)
        offsets = self.step_days * np.arange(H, dtype=np.int64)   # (H,)
        idx = starts[:, None] + offsets[None, :]                   # (B, H)

        hol = np.zeros(idx.shape, dtype=np.float32)
        valid = (idx >= 0) & (idx < self.holiday.shape[0])
        hol[valid] = self.holiday[idx[valid]]

        # Follow whatever device the calendar callback returned (may already be CUDA).
        cal = cal.to(dtype=torch.float32)
        hol_t = torch.from_numpy(hol).unsqueeze(-1).to(cal.device)  # (B, H, 1)

        # 3) concat along the feature axis
        if is_scalar:
            return torch.cat([cal.unsqueeze(0), hol_t], dim=-1)[0]  # (H, E+1)
        return torch.cat([cal, hol_t], dim=-1)                       # (B, H, E+1)

# Dense holiday lookup, padded with 7 * (horizon + 2) extra zero days past the
# last known date. Out-of-range lookups are also masked to 0 inside
# FutureExoTimePlusHoliday, so the padding only adds explicit zero slots.
holiday_by_dayidx = build_holiday_array(holiday_map_dayidx, pad=7 * (horizon + 2))
future_exo_cb_time_plus_holiday = FutureExoTimePlusHoliday(holiday_by_dayidx, step_days=7)


In [7]:
# ============================================================
# Build DataModules / Loaders (A0/A1/A2)
# ============================================================
def build_datamodule(variant: str) -> MultiPartExoDataModule:
    """Create the data module for one A/B variant.

    The variants differ only in the future-exogenous callback:
      * "A0": none (y-only)
      * "A1": calendar features (``future_exo_cb_time``)
      * "A2": calendar + holiday (``future_exo_cb_time_plus_holiday``)
    All other settings come from the notebook-level globals.

    NOTE(review): ``date_col="date"`` selects the YYYYWW column, while the A2
    holiday lookup is keyed by day index (``date_idx``). Confirm which value
    the collate passes to the callback as ``start_idx``.

    Raises:
        ValueError: for an unknown ``variant``.
    """
    if variant not in ("A0", "A1", "A2"):
        raise ValueError(variant)

    if variant == "A1":
        exo_callback = future_exo_cb_time
    elif variant == "A2":
        exo_callback = future_exo_cb_time_plus_holiday
    else:  # "A0": y-only
        exo_callback = None

    module_kwargs = dict(
        df=df,
        id_col="unique_id",
        date_col="date",
        y_col="y",
        lookback=lookback,
        horizon=horizon,
        batch_size=batch_size,
        past_exo_cont_cols=None,
        past_exo_cat_cols=None,
        future_exo_cb=exo_callback,
        freq=freq,
        shuffle=shuffle,
        split_mode=split_mode,
    )
    return MultiPartExoDataModule(**module_kwargs)

# One data module per variant. They share df / lookback / horizon / split
# settings and differ only in the future-exo callback.
data_module_A0 = build_datamodule("A0")
data_module_A1 = build_datamodule("A1")
data_module_A2 = build_datamodule("A2")

# IMPORTANT: pass batch_size explicitly to avoid default(32) inside get_train_loader()
# Validation loaders use get_val_loader()'s own defaults (not visible here).
train_loader_A0 = data_module_A0.get_train_loader(batch_size=batch_size, shuffle=True, num_workers=0)
val_loader_A0   = data_module_A0.get_val_loader()

train_loader_A1 = data_module_A1.get_train_loader(batch_size=batch_size, shuffle=True, num_workers=0)
val_loader_A1   = data_module_A1.get_val_loader()

train_loader_A2 = data_module_A2.get_train_loader(batch_size=batch_size, shuffle=True, num_workers=0)
val_loader_A2   = data_module_A2.get_val_loader()

def inspect(loader, name: str):
    """Fetch one batch from ``loader`` and print its timing, shapes, devices and dtypes.

    Expects each batch to be a 6-tuple ``(x, y, uid, fe, pe_cont, pe_cat)`` and
    ``loader.collate_fn`` to expose a ``future_exo_cb`` attribute. When ``fe``
    has at least one feature column, the first three horizon steps of the first
    sample are printed as well.

    NOTE: the name shadows the stdlib ``inspect`` module if that module is ever
    imported into this notebook.
    """
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump, which makes it unsuitable for measuring intervals.
    t0 = time.perf_counter()
    x, _y, _uid, fe, _pe_cont, _pe_cat = next(iter(loader))
    dt = time.perf_counter() - t0

    print(f"[{name}] batch_time: {dt:.3f}s")
    print(f"[{name}] x : {tuple(x.shape)} | {x.device} | {x.dtype}")
    print(f"[{name}] fe: {tuple(fe.shape)} | {fe.device} | {fe.dtype}")
    print(f"[{name}] future_exo_cb is None?: {loader.collate_fn.future_exo_cb is None}")
    if fe.shape[-1] > 0:
        print(f"[{name}] fe sample (first 3 steps):\n{fe[0, :3, :]}")

# Smoke test: pull one training batch per variant and print its shapes/devices.
inspect(train_loader_A0, "A0")
inspect(train_loader_A1, "A1")
inspect(train_loader_A2, "A2")


[A0] batch_time: 0.105s
[A0] x : (512, 52, 1) | cpu | torch.float32
[A0] fe: (512, 27, 0) | cpu | torch.float32
[A0] future_exo_cb is None?: True
[future_exo_cb] batch call OK | miss=512 | time=0.001s
[A1] batch_time: 0.016s
[A1] x : (512, 52, 1) | cpu | torch.float32
[A1] fe: (512, 27, 2) | cpu | torch.float32
[A1] future_exo_cb is None?: False
[A1] fe sample (first 3 steps):
tensor([[-1.2028e-01,  9.9274e-01],
        [ 5.2432e-04,  1.0000e+00],
        [ 1.2132e-01,  9.9261e-01]])
[future_exo_cb] batch call OK | miss=512 | time=0.000s
[A2] batch_time: 0.017s
[A2] x : (512, 52, 1) | cpu | torch.float32
[A2] fe: (512, 27, 3) | cpu | torch.float32
[A2] future_exo_cb is None?: False
[A2] fe sample (first 3 steps):
tensor([[ 0.8849, -0.4657,  0.0000],
        [ 0.8233, -0.5676,  0.0000],
        [ 0.7487, -0.6629,  0.0000]])


In [8]:
def inspect(loader, name):
    """Print shape, device and dtype of the first batch's ``x`` and ``fe`` tensors.

    Expects 6-tuple batches ``(x, y, uid, fe, pe_cont, pe_cat)`` and a
    ``loader.collate_fn`` exposing ``future_exo_cb``. The ``fe`` sample printout
    (first sample, first three steps) is skipped when ``fe`` has no feature columns.
    """
    first_batch = next(iter(loader))
    x, _y, _uid, fe, _pe_cont, _pe_cat = first_batch

    print(f"[{name}] x:", x.shape, x.device, x.dtype)
    print(f"[{name}] fe:", fe.shape, fe.device, fe.dtype)
    callback_missing = loader.collate_fn.future_exo_cb is None
    print(f"[{name}] future_exo_cb is None?", callback_missing)

    if fe.shape[-1] <= 0:
        return
    print(f"[{name}] fe sample:", fe[0, :3, :])

# inspect(train_loader_A0, "A0")
inspect(train_loader_A1, "A1")
# inspect(train_loader_A2, "A2")

import time

# Time the first batch from a fresh iterator (includes iterator setup cost).
# NOTE: time.time() is wall-clock; time.perf_counter() is the monotonic choice
# for interval timing.
t0 = time.time()
b = next(iter(train_loader_A1))
print("first batch ok, dt=", time.time() - t0)

t0 = time.time()
b = next(iter(val_loader_A1))
print("first val batch ok, dt=", time.time() - t0)

[A1] x: torch.Size([512, 52, 1]) cpu torch.float32
[A1] fe: torch.Size([512, 27, 2]) cpu torch.float32
[A1] future_exo_cb is None? False
[A1] fe sample: tensor([[ 0.6617, -0.7498],
        [ 0.5679, -0.8231],
        [ 0.4643, -0.8857]])
first batch ok, dt= 0.01300191879272461
[future_exo_cb] batch call OK | miss=512 | time=0.001s
first val batch ok, dt= 0.00599980354309082


In [11]:
# ============================================
# CELL 5) Run training (keeps the LTB total_train format)
# - Only the exogenous variants A/B/C (A0/A1/A2 here) are compared, so
#   use_ssl_pretrain is fixed to False.
# ============================================
from modeling_module.training.model_trainers.total_train import run_total_train_weekly
# NOTE: this label is printed even though the A0 training call below is
# commented out, so it does not mean A0 was trained in this run.
print('run result_A0')
# results_A0 = run_total_train_weekly(
#     train_loader_A0,
#     val_loader_A0,
#     device=device,
#     lookback=lookback,
#     horizon=horizon,
#     warmup_epochs=5,
#     spike_epochs=10,
#     save_dir=save_root_A0,
#     use_exogenous_mode = False,
#     models_to_run=["patchtst"],
#     use_ssl_pretrain=False,
# )

# print('run result_A1')
# results_A1 = run_total_train_weekly(
#     train_loader_A1,
#     val_loader_A1,
#     device=device,
#     lookback=lookback,
#     horizon=horizon,
#     warmup_epochs=5,
#     spike_epochs=10,
#     save_dir=save_root_A1,
#     use_exogenous_mode = True,
#     models_to_run=["patchtst"],
#     use_ssl_pretrain=False,
# )

print('run result_A2')
# NOTE(review): the last recorded run logged "exogenous dimension:: 2" and then
# "point head rebuilt: d_future 2 -> 3" (the loader provides E=3). Check why
# the trainer first inferred 2 for the holiday variant.
results_A2 = run_total_train_weekly(
    train_loader_A2,
    val_loader_A2,
    device=device,
    lookback=lookback,
    horizon=horizon,
    warmup_epochs=5,
    spike_epochs=10,
    save_dir=save_root_A2,
    use_exogenous_mode = True,
    models_to_run=["patchtst"],
    use_ssl_pretrain=False,
)


run result_A0
run result_A2

[total_train] === RUN: patchtst (weekly) ===
exogenous dimension:: 2
PatchTST Base (Weekly)
[future_exo_cb] batch call OK | miss=9 | time=0.001s
[train_patchtst] point head rebuilt: d_future 2 -> 3
[train_patchtst] loader provides fe_cont(E=3), so future_exo_cb disabled.

[train_patchtst] ===== Stage 1/2 =====
  - spike: OFF
  - epochs: 5 | lr=0.0003 | horizon_decay=False
[train_patchtst] Effective TrainingConfig:
{
  "device": "cuda",
  "log_every": 100,
  "use_amp": true,
  "lookback": 52,
  "horizon": 27,
  "epochs": 5,
  "lr": 0.0003,
  "weight_decay": 0.001,
  "t_max": 40,
  "patience": 100,
  "max_grad_norm": 30.0,
  "amp_device": "cuda",
  "loss_mode": "point",
  "point_loss": "huber",
  "huber_delta": 0.8,
  "q_star": 0.5,
  "use_cost_q_star": false,
  "Cu": 1.0,
  "Co": 1.0,
  "quantiles": [
    0.1,
    0.5,
    0.9
  ],
  "use_intermittent": true,
  "alpha_zero": 3.0,
  "alpha_pos": 1.0,
  "gamma_run": 0.3,
  "cap": null,
  "use_horizon_decay": f

KeyboardInterrupt: 

In [None]:
from modeling_module.utils.metrics import mae, smape

from modeling_module.utils.metrics import rmse


def load_model_ckpt(model, ckpt_path: str, device: str):
    """Restore ``model`` weights from a checkpoint file and prepare it for inference.

    The checkpoint is first loaded onto the CPU. It must be a mapping with a
    ``"model_state"`` entry, which is loaded with ``strict=True``. The same
    ``model`` object is moved to ``device``, switched to eval mode and returned.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    weights = checkpoint["model_state"]
    model.load_state_dict(weights, strict=True)

    model.to(device)
    model.eval()
    return model

@torch.no_grad()
def eval_on_loader(model, loader, device: str, future_exo_cb=None):
    """Run ``model`` over ``loader`` and collect targets and point predictions.

    Each batch is expected to be ``(x, y, uid, fe, pe_cont, pe_cat)``.

    Future exogenous input is chosen per batch:
      * If the loader already supplies ``fe`` with at least one feature column,
        it is used as-is. This matches training: the trainer logged "loader
        provides fe_cont(E=3), so future_exo_cb disabled."
      * Otherwise, if ``future_exo_cb`` is given, it is called with ``batch[2]``
        as ``start_idx`` (legacy path; one start index for the whole batch).
        NOTE(review): in this notebook ``batch[2]`` is the series id, not a
        start index — verify before relying on this path.
      * Otherwise no future exogenous input is passed.

    Returns:
        ``(y_all, yhat_all)`` numpy arrays concatenated along the batch axis.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    ys, yhats = [], []

    for batch in loader:
        x = batch[0].to(device)  # history window, e.g. [B, lookback, 1]
        y = batch[1].to(device)  # targets, e.g. [B, horizon]

        loader_fe = batch[3]
        if torch.is_tensor(loader_fe) and loader_fe.shape[-1] > 0:
            future_exo = loader_fe.to(device)
        elif future_exo_cb is not None:
            start_idx = batch[2]
            if torch.is_tensor(start_idx):
                # reshape(-1) also handles 0-d tensors, which cannot be indexed.
                start_idx = int(start_idx.reshape(-1)[0].item())
            fe = future_exo_cb(start_idx, model.horizon, device=device)
            if fe.dim() == 2:
                # Scalar-style callback output [H, D] -> broadcast to [B, H, D].
                fe = fe.unsqueeze(0).expand(x.size(0), -1, -1)
            future_exo = fe.to(device)
        else:
            future_exo = None

        past_exo_cont = batch[4].to(device) if batch[4] is not None else None
        past_exo_cat = batch[5].to(device) if batch[5] is not None else None

        # Keyword arguments follow the PatchTSTPointModel.forward signature.
        yhat = model(
            x,
            future_exo=future_exo,
            past_exo_cont=past_exo_cont,
            past_exo_cat=past_exo_cat,
        )

        ys.append(y.detach().cpu().numpy())
        yhats.append(yhat.detach().cpu().numpy())

    if not ys:
        raise ValueError("eval_on_loader: loader yielded no batches")
    return np.concatenate(ys, axis=0), np.concatenate(yhats, axis=0)


def _extract_pred_from_output(out, prefer_q=0.5):
    """Pull a point-prediction tensor out of a model output.

    A model output may be a tensor, a tuple/list, or a dict, so it is
    unwrapped defensively:

      * tensor: used directly;
      * tuple/list: the first element is used (e.g. ``(pred, aux)``) and must
        be a tensor;
      * dict: the first tensor found under one of the common keys below, or
        the only tensor value if the dict holds exactly one.

    Tensor and tuple/list predictions that are 3-D are treated as quantile
    outputs and reduced to ``[B, H]`` with ``_select_point_from_quantile_tensor``
    (other ranks pass through unchanged). Dict predictions always go through
    that function.

    Args:
        out: raw model output.
        prefer_q: quantile level used to pick the point forecast.

    Returns:
        ``(yhat_point, out)`` — the point prediction and the untouched output.

    Raises:
        TypeError: unsupported output type, or a tuple/list whose first element
            is not a tensor (or that is empty).
        KeyError: a dict output without a recognizable prediction tensor.
    """

    def _point(t):
        # Only 3-D tensors are interpreted as [B, H, Q] / [B, Q, H] here.
        return _select_point_from_quantile_tensor(t, prefer_q=prefer_q) if t.ndim == 3 else t

    # 1) plain tensor
    if torch.is_tensor(out):
        return _point(out), out

    # 2) tuple/list: most commonly (pred, aux) or (q_pred, ...)
    if isinstance(out, (tuple, list)):
        first = out[0] if len(out) > 0 else None
        if torch.is_tensor(first):
            return _point(first), out
        raise TypeError(f"Unsupported tuple/list output types: {[type(x) for x in out]}")

    # 3) dict: look through common key names
    if isinstance(out, dict):
        key_candidates = [
            "yhat", "pred", "prediction", "y_pred", "output",
            "q_pred", "quantiles", "yq", "y_hat"
        ]
        for k in key_candidates:
            if k in out and torch.is_tensor(out[k]):
                # Could be [B,H], [B,H,Q] or [B,Q,H].
                return _select_point_from_quantile_tensor(out[k], prefer_q=prefer_q), out

        # Fall back to the only tensor value, if there is exactly one.
        tensor_items = [(k, v) for k, v in out.items() if torch.is_tensor(v)]
        if len(tensor_items) == 1:
            return _select_point_from_quantile_tensor(tensor_items[0][1], prefer_q=prefer_q), out

        raise KeyError(f"Cannot find tensor prediction in dict output. keys={list(out.keys())}")

    raise TypeError(f"Unsupported model output type: {type(out)}")


def _select_point_from_quantile_tensor(t: torch.Tensor, prefer_q=0.5):
    """Reduce a prediction tensor to a ``[B, H]`` point forecast.

    * ``[B, H]``: returned unchanged.
    * 3-D: horizon and quantile axes are told apart by size. The larger
      trailing dim is taken as the horizon (``[B, H, Q]`` when ``d1 > d2``,
      otherwise ``[B, Q, H]``). This misreads outputs whose horizon is not
      larger than the number of quantiles (e.g. ``[B, 1, 3]``).
    * The quantile is chosen by position, ``round((Q - 1) * prefer_q)``. This
      picks the intended level only if the levels are sorted and evenly spaced
      (e.g. ``[0.1, 0.5, 0.9]`` with ``prefer_q=0.5`` -> index 1). Python's
      ``round`` rounds halves to even.

    Raises:
        ValueError: for any rank other than 2 or 3.
    """
    if t.ndim == 2:
        return t
    if t.ndim != 3:
        raise ValueError(f"Unexpected prediction tensor shape: {t.shape}")

    _, d1, d2 = t.shape
    quantile_last = d1 > d2  # True -> [B,H,Q], False -> [B,Q,H]
    q_len = d2 if quantile_last else d1
    q_idx = int(round((q_len - 1) * prefer_q))
    return t[:, :, q_idx] if quantile_last else t[:, q_idx, :]


@torch.no_grad()
def eval_on_loader_quantile(model, loader, device, prefer_q=0.5, future_exo_cb=None):
    """Run ``model`` over ``loader`` and collect targets and point forecasts.

    Raw model outputs are reduced to a point forecast with
    ``_extract_pred_from_output`` (position-based quantile pick controlled by
    ``prefer_q``). Each batch is expected to be ``(x, y, uid, fe, pe_cont, pe_cat)``.

    Future exogenous input is chosen per batch:
      * loader-provided ``fe`` with at least one feature column is used as-is
        (the trainer logged that it disables ``future_exo_cb`` when the loader
        provides ``fe``, so this matches training);
      * otherwise ``future_exo_cb`` is called with ``batch[2]`` as ``start_idx``.
        NOTE(review): ``batch[2]`` is the series id in this notebook — verify
        before relying on this path;
      * otherwise no future exogenous input is passed.

    Returns:
        ``(y_all, yhat_all)`` numpy arrays concatenated along the batch axis.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    ys, yhats = [], []

    for batch in loader:
        x = batch[0].to(device)
        y = batch[1].to(device)

        loader_fe = batch[3]
        if torch.is_tensor(loader_fe) and loader_fe.shape[-1] > 0:
            future_exo = loader_fe.to(device)
        elif future_exo_cb is not None:
            start_idx = batch[2]
            if torch.is_tensor(start_idx):
                # reshape(-1) also handles 0-d tensors, which cannot be indexed.
                start_idx = int(start_idx.reshape(-1)[0].item())
            fe = future_exo_cb(start_idx, model.horizon, device=device)
            if fe.dim() == 2:
                # Scalar-style callback output [H, D] -> broadcast to [B, H, D].
                fe = fe.unsqueeze(0).expand(x.size(0), -1, -1)
            future_exo = fe.to(device)
        else:
            future_exo = None

        past_exo_cont = batch[4].to(device) if batch[4] is not None else None
        past_exo_cat = batch[5].to(device) if batch[5] is not None else None

        out = model(x, future_exo=future_exo, past_exo_cont=past_exo_cont, past_exo_cat=past_exo_cat)

        yhat_point, _ = _extract_pred_from_output(out, prefer_q=prefer_q)
        ys.append(y.detach().cpu().numpy())
        yhats.append(yhat_point.detach().cpu().numpy())

    if not ys:
        raise ValueError("eval_on_loader_quantile: loader yielded no batches")
    return np.concatenate(ys), np.concatenate(yhats)

In [None]:


# ============================================
# CELL 6) Load checkpoints + evaluate (keeps the flow of the attached notebook)
# - quantile models: use eval_on_loader_quantile
# - point models: use eval_on_loader
# ============================================
from modeling_module.utils.checkpoint import load_model_dict
from modeling_module.models import build_patchTST_base, build_patchTST_quantile
from modeling_module.utils.metrics import mae, smape, rmse

# Match `builders` to the model keys you actually saved.
# (The attached notebook used patchtst_quantile.)
builders = {
    # "patchtst_quantile": build_patchTST_quantile,
    "patchtst": build_patchTST_base,
}

# NOTE: this cell only loads checkpoints. The A0/A1 training calls are commented
# out in the training cell, so their checkpoints must already exist on disk.
# model_A0 = load_model_dict(save_root_A0, builders, device=device)["patchtst_quantile"]
# model_A1 = load_model_dict(save_root_A1, builders, device=device)["patchtst_quantile"]
# model_A2 = load_model_dict(save_root_A2, builders, device=device)["patchtst_quantile"]
model_A0 = load_model_dict(save_root_A0, builders, device = device)['patchtst']
model_A1 = load_model_dict(save_root_A1, builders, device = device)['patchtst']
model_A2 = load_model_dict(save_root_A2, builders, device = device)['patchtst']


# Assumes the eval utilities from the attached notebook are used as-is.
# (They are defined in an earlier cell; import them instead if they live in a separate module.)
# from your_notebook_utils import eval_on_loader_quantile

# y0, yhat0 = eval_on_loader_quantile(model_A0, val_loader_A0, device, prefer_q=0.5)
# y1, yhat1 = eval_on_loader_quantile(model_A1, val_loader_A1, device, prefer_q=0.5, future_exo_cb = future_exo_cb_time)
# y2, yhat2 = eval_on_loader_quantile(model_A2, val_loader_A2, device, prefer_q=0.5, future_exo_cb = future_exo_cb_time_plus_holiday)

# NOTE(review): the batch's third element is unpacked as `uid` in inspect(), so
# any callback-based future exo computed from it should be double-checked.
y0, yhat0 = eval_on_loader(model_A0, val_loader_A0, device=device)
y1, yhat1 = eval_on_loader(model_A1, val_loader_A1, device=device, future_exo_cb = future_exo_cb_time)
y2, yhat2 = eval_on_loader(model_A2, val_loader_A2, device=device, future_exo_cb = future_exo_cb_time_plus_holiday)

# Metrics pool every series and every horizon step into one flat vector per
# variant; each variant is scored against its own validation targets.
metric_A0 = {
    "MAE": float(mae(y0.reshape(-1), yhat0.reshape(-1))),
    "RMSE": float(rmse(y0.reshape(-1), yhat0.reshape(-1))),
    "SMAPE": float(smape(y0.reshape(-1), yhat0.reshape(-1))),
}
metric_A1 = {
    "MAE": float(mae(y1.reshape(-1), yhat1.reshape(-1))),
    "RMSE": float(rmse(y1.reshape(-1), yhat1.reshape(-1))),
    "SMAPE": float(smape(y1.reshape(-1), yhat1.reshape(-1))),
}
metric_A2 = {
    "MAE": float(mae(y2.reshape(-1), yhat2.reshape(-1))),
    "RMSE": float(rmse(y2.reshape(-1), yhat2.reshape(-1))),
    "SMAPE": float(smape(y2.reshape(-1), yhat2.reshape(-1))),
}

metric_A0, metric_A1, metric_A2

In [None]:
import os, glob, torch

def peek_ckpt(save_root):
    """Print the config stored in one checkpoint under ``save_root``.

    Looks for ``*.pt``, ``*.pth`` and ``*.ckpt`` files directly inside
    ``save_root`` and inspects the lexicographically last path. That is not
    necessarily the newest file. The config is read from the first truthy entry
    among ``"cfg"``, ``"config"`` and ``"model_cfg"``. For dict configs, the
    first 30 keys plus any ``d_future`` / ``head`` entries are printed;
    otherwise the object's ``__dict__`` is printed.

    Raises:
        FileNotFoundError: if ``save_root`` contains no checkpoint file.
    """
    patterns = ("*.pt", "*.pth", "*.ckpt")
    candidates = sorted(
        path for pattern in patterns for path in glob.glob(os.path.join(save_root, pattern))
    )
    if not candidates:
        raise FileNotFoundError(f"no checkpoint (*.pt, *.pth, *.ckpt) found in {save_root!r}")
    ckpt_path = candidates[-1]

    ckpt = torch.load(ckpt_path, map_location="cpu")
    print("ckpt:", ckpt_path)
    if not hasattr(ckpt, "get"):
        # e.g. a pickled whole model rather than a dict-like checkpoint.
        print("ckpt type:", type(ckpt))
        return

    cfg = ckpt.get("cfg") or ckpt.get("config") or ckpt.get("model_cfg")
    print("cfg type:", type(cfg))
    if isinstance(cfg, dict):
        # head / d_future commonly live here.
        print("cfg keys:", list(cfg.keys())[:30])
        if "d_future" in cfg:
            print("cfg[d_future] =", cfg["d_future"])
        if "head" in cfg:
            print("cfg[head] =", cfg["head"])
    else:
        # May be a dataclass / namespace.
        print("cfg dict:", getattr(cfg, "__dict__", None))

# Inspect the stored config of the A1 checkpoint (uncomment the others as needed).
# peek_ckpt(save_root_A0)
peek_ckpt(save_root_A1)
# peek_ckpt(save_root_A2)


In [None]:
# ============================================
# CELL 7) (Optional) Prediction plots (same style as the attached notebook)
# ============================================
import matplotlib.pyplot as plt

def plot_samples(y_true, preds: dict, max_n: int = 64):
    """Plot the first ``min(max_n, len(y_true))`` samples, one subplot row each.

    Each subplot draws ``y_true[i]`` (label "true") and ``v[i]`` for every
    ``label, v`` in ``preds``. The legend appears on the first subplot only.
    When there is nothing to plot (``n <= 0``) the function returns without
    drawing, rather than passing ``nrows=0`` to ``plt.subplots``, which
    matplotlib rejects.
    """
    n = min(max_n, y_true.shape[0])
    if n <= 0:
        return

    # squeeze=False always yields an (n, 1) array, so n == 1 needs no special case.
    _, axes = plt.subplots(nrows=n, ncols=1, figsize=(12, 2.2 * n), sharex=True, squeeze=False)

    for i, ax in enumerate(axes[:, 0]):
        ax.plot(y_true[i], label="true")
        for label, pred in preds.items():
            ax.plot(pred[i], label=label)
        ax.set_title(f"sample={i}", fontsize=9)
        if i == 0:
            ax.legend(loc="upper right", fontsize=8)

    plt.tight_layout()
    plt.show()

# Overlay the first 32 validation samples of each variant.
# NOTE(review): y0 (from val_loader_A0) is used as ground truth for all three
# variants. This assumes the three validation loaders yield samples in the same
# order — TODO confirm (e.g. validation is not shuffled).
plot_samples(
    y_true=y0,
    preds={
        "A0_y_only": yhat0,
        "A1_time": yhat1,
        "A2_time+holiday": yhat2,
    },
    max_n=32,
)