In [None]:
import sys

import polars as pl
import torch
from data_loader.TimeSeriesModule import MultiPartDataModule
from model_runner.train.patchmixer_train import PatchMixerTrain
from models.patchmixer.PatchMixer import *

# CUDA-enabled PyTorch install notes (GPU machine):
#   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
#   https://developer.nvidia.com/cuda-12-8-0-download-archive

# Platform-specific data roots; DIR is the one the later cells read from.
MAC_DIR = '../data/'
WINDOW_DIR = 'C:/Users/USER/PycharmProjects/research/data/'

if sys.platform == 'win32':
    DIR = WINDOW_DIR
    # Quick CUDA / PyTorch diagnostics for the Windows machine.
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.version.cuda)
    print(torch.__version__)
    # get_device_name(0) raises when no CUDA device is usable, so only query
    # it when one is actually available.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
else:
    DIR = MAC_DIR

In [None]:
# Demand table read from parquet; this frame feeds the loaders below and is
# passed again to the anchored-forecast plot at the end of the notebook.
target_dyn_demand = pl.read_parquet(DIR + 'target_dyn_demand.parquet')

# Project data module built with the monthly PatchMixer config.
data_module = MultiPartDataModule(
    df=target_dyn_demand,
    config=PatchMixerConfigMonthly(),
    batch_size=128,
    val_ratio=0.2,
)

# Anchor month (YYYYMM) handed to the inference loader.
plan_yyyymm = 202012

train_loader, val_loader = data_module.get_train_loader(), data_module.get_val_loader()
anchor_loader = data_module.get_inference_loader_at_plan(
    plan_yyyymm, parts_filter=None, fill_missing='ffill'
)

In [None]:
# Training (kept commented out)
# Un-comment to retrain. The last line saves the model's state_dict under the
# same file name that the checkpoint-loading cell reads.
# NOTE(review): this snippet builds the model with PatchMixerConfig() and unpacks
# `model, _`, whereas the checkpoint-loading cell calls
# PatchMixerQuantileModel(PatchMixerConfigMonthly()) directly — confirm the two
# agree before re-running.
# model, _ = PatchMixerQuantileModel(PatchMixerConfig())
#
# result = PatchMixerTrain().base_train(
#     train_loader=train_loader,
#     val_loader=val_loader,
#     model=model,
#     epochs= 100
# )
# torch.save(model.state_dict(), DIR + 'patch_mixer_20250918_l36_h48.pt')


In [None]:
# Restore the trained PatchMixer checkpoint for inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ckpt_path = DIR + 'patch_mixer_20250918_l36_h48.pt'

model = PatchMixerQuantileModel(PatchMixerConfigMonthly())

# The file is expected to hold a plain state_dict (it is passed straight to
# load_state_dict), so use the restricted unpickler: weights_only=True refuses
# to execute arbitrary pickled objects from the checkpoint file.
state = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state)

# Move the model to the selected device and switch it to eval mode.
# As the cell's last expression this also displays the module repr.
model.to(device).eval()


## 1) 검증 세트에서 예측 + 핵심 지표 계산
### 참고

공정성을 위해 검증 시에는 가중치 없이 그대로 평가(학습 중 간헐 가중을 썼더라도).

50% 구간 커버리지는 P25/P75 분위수가 없어 계산할 수 없으므로, 대신 P10~P90(80%) 구간 커버리지를 모니터링한다. 헤드가 더 많은 분위수(예: P5/P25/P75/P95)를 예측하도록 늘리면 50%/90% 구간 커버리지도 바로 확인할 수 있다.

In [None]:
# Validation-set evaluation: point metrics on the median forecast plus
# interval metrics on the P10-P90 band.
from model_runner.inferences.patch_mixer_inference import evaluate_PatchMixer_with_truth
from utils.losses import pinball_loss_weighted

# 1) Full quantile predictions and ground truth:
#    preds_q is (B, Q, H), trues is (B, H, 1)
preds_q, trues, part_ids = evaluate_PatchMixer_with_truth(
    model, val_loader, device=device, q_use=0.5, return_full=True
)

# 2) Drop the trailing singleton axis and split the quantile axis -> (B, H) each
y_true = trues.squeeze(-1)
q10 = preds_q[:, 0, :]
q50 = preds_q[:, 1, :]
q90 = preds_q[:, 2, :]

# 3) Point-forecast metrics, using the median (P50) as the point estimate
abs_err = (q50 - y_true).abs()
sq_err = (q50 - y_true).pow(2)
mae = abs_err.mean().item()
rmse = sq_err.mean().sqrt().item()
# 1e-8 keeps the sMAPE denominator away from zero
smape = (2.0 * abs_err / (y_true.abs() + q50.abs() + 1e-8)).mean().item()

# 4) Probabilistic metrics: mean pinball loss, plus coverage and width of the
#    80% interval (P10 ~ P90)
pinball = pinball_loss_weighted(
    preds_q, y_true, quantiles=(0.1, 0.5, 0.9), weights=None
).item()

inside_80 = (q10 <= y_true) & (y_true <= q90)
coverage_80 = inside_80.float().mean().item()
avg_interval_width = (q90 - q10).mean().item()

print(f"MAE={mae:.4f} RMSE={rmse:.4f} sMAPE={smape:.4f}")
print(f"Pinball(avg) = {pinball:.4f} Coverage@80={coverage_80:.4f} Width@80={avg_interval_width:.4f}")

## 2) 앵커 시점(plan_yyyymm) 배치 예측
### (A) 간단 배치 예측 (P50만)
### (B) 분위수 전체 저장(P10/P50/P90)

In [None]:
from model_runner.inferences.patch_mixer_inference import evaluate_PatchMixer

# (A) Median-only batch forecast at the anchor month.
preds_anchor, parts_anchor = evaluate_PatchMixer(
    model, anchor_loader, device=device, q_use=0.5, return_full=False
)  # preds_anchor: (B, H)
print('(A)', preds_anchor.shape, len(parts_anchor))

# (B) Same loader with return_full=True to keep every predicted quantile.
preds_anchor_q, parts_anchor = evaluate_PatchMixer(
    model, anchor_loader, device=device, q_use=0.5, return_full=True
)
# Report shape and part count, as (A) does, instead of dumping the whole
# tensor into the cell output.
print('(B)', preds_anchor_q.shape, len(parts_anchor))

## 3) Overlap-Averaging 기반 DMS 추론(안정화 포함, 옵션)
DMS(직접 다스텝) 방식으로 학습된 모델이라면, 겹침 평균(overlap averaging)과 1-step 안정화 가드를 적용해 더 매끈한 예측을 얻을 수 있다.

In [None]:
from model_runner.inferences.patch_mixer_inference import PatchMixer_DMSForecaster
from tqdm import tqdm

# DMS inference over the anchor loader using overlap averaging with the
# stabilisation options enabled.
dms = PatchMixer_DMSForecaster(model, target_channel=0, fill_mode='copy_last')
model.eval()  # same explicit eval switch as the IMS cell

# The horizon is loop-invariant: read it from the config once instead of
# constructing a new config object for every batch.
horizon = PatchMixerConfigMonthly().horizon

all_preds = []
all_parts = []      # part identifiers in loader order; reused by the plotting cells
all_x_first = None  # first batch's inputs, kept for the history plot

with torch.no_grad():
    for batch in tqdm(anchor_loader):
        # The loader yields either (x, part) or (x, <unused>, part).
        if len(batch) == 2:
            x, part = batch
        else:
            # Avoid the name `_` here: assigning it would clobber IPython's
            # last-output variable.
            x, _unused, part = batch

        all_parts.extend(part)

        if all_x_first is None:
            all_x_first = x  # captured before the device transfer
        x = x.to(device)

        y_hat = dms.forecast_overlap_avg(
            x, horizon=horizon,
            q_use=0.5, return_full=False,
            use_winsor=True, use_multi_guard=True, use_dampen=True,
            clip_q=(0.05, 0.95), clip_mul=2.0,  # tune these strengths for the domain
            max_growth=1.10, max_step_up=0.10, max_step_down=0.30,
            damp_min=0.2, damp_max=0.6,
        )  # (B, H)
        all_preds.append(y_hat.cpu())

preds_anchor_dms = torch.cat(all_preds, dim=0)  # (B_total, H)

## 4) IMS(autoregressive) 추론

In [None]:
from model_runner.inferences.patch_mixer_inference import PatchMixer_IMSForecaster
# Imported here as well so this cell does not depend on the DMS cell having
# been run first.
from tqdm import tqdm

# IMS (autoregressive) inference over the anchor loader.
ims = PatchMixer_IMSForecaster(model, target_channel=0, fill_mode='copy_last')
model.eval()

# The horizon is loop-invariant: read it from the config once instead of
# constructing a new config object for every batch.
horizon = PatchMixerConfigMonthly().horizon

all_preds = []
with torch.no_grad():
    for batch in tqdm(anchor_loader):
        # The loader yields either (x, part) or (x, <unused>, part).
        if len(batch) == 2:
            x, part = batch
        else:
            # Avoid the name `_` here: assigning it would clobber IPython's
            # last-output variable.
            x, _unused, part = batch
        x = x.to(device)

        y_hat = ims.forecast(
            x, horizon=horizon,
            q_use=0.5, return_full=False,
            use_winsor=True, use_dampen=True, damp=0.5,
        )

        all_preds.append(y_hat.cpu())

preds_anchor_ims = torch.cat(all_preds, dim=0)

In [None]:
from utils.plot_utils import plot_two_preds_same_ylim

# Compare IMS and DMS forecasts against the history of the first anchor batch.
# Predictions and part ids are sliced to that batch's size so they line up
# with all_x_first.
n_first = all_x_first.size(0)

plot_two_preds_same_ylim(
    x_batch=all_x_first,
    y_hat_a=preds_anchor_ims[:n_first],
    y_hat_b=preds_anchor_dms[:n_first],
    parts=all_parts[:n_first],
    k=3,
    target_channel=0,
    outdir="./plots",
    prefix="ims_vs_dms",
    label_a="IMS",
    label_b="DMS",
)

In [None]:
from utils.plot_utils import plot_anchored_forecasts_yyyymm_multi

# Anchored forecast plots at plan_yyyymm, with IMS and DMS passed side by side
# in preds_dict.
plot_anchored_forecasts_yyyymm_multi(
    df=target_dyn_demand,
    parts=all_parts,  # part ids collected during the DMS inference loop
    preds_dict={"IMS": preds_anchor_ims, "DMS": preds_anchor_dms},
    plan_yyyymm=plan_yyyymm,
    lookback=PatchMixerConfigMonthly().lookback,
    k=343,
    outdir="./plots",
    prefix="anchored_ims_vs_dms",
)