In [None]:
import sys

import polars as pl
from data_loader.TimeSeriesModule import MultiPartDataModule
from model_runner.model_configs import PatchMixerConfigMonthly
from model_runner.train.patchmixer_train import PatchMixerTrain
from models.patchmixer.PatchMixer import *
from utils.validation_utils import compute_validation_metrics_for_patchmixer

'''
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
https://developer.nvidia.com/cuda-12-8-0-download-archive
'''

# Data directory candidates; DIR is picked from sys.platform below.
# NOTE(review): WINDOW_DIR is a machine-specific absolute path — consider making it configurable.
MAC_DIR = '../data/'
WINDOW_DIR = 'C:/Users/USER/PycharmProjects/research/data/'

if sys.platform == 'win32':
    DIR = WINDOW_DIR
    # CUDA / PyTorch diagnostics, printed on the Windows branch only.
    # NOTE(review): `torch` is not imported explicitly in this cell; presumably it
    # arrives through the star import from models.patchmixer.PatchMixer — confirm.
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.version.cuda)
    print(torch.__version__)
    # get_device_name(0) raises when no CUDA device is visible, so only query it
    # when one is actually available.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
else:
    DIR = MAC_DIR

### Data Loader

In [None]:
# Read the monthly demand parquet from the platform-specific data directory.
target_dyn_demand = pl.read_parquet(DIR + 'target_dyn_demand_monthly.parquet')

config = PatchMixerConfigMonthly()

# Data module over the full frame; batch size 128, validation ratio 0.2.
# NOTE(review): the meaning of `is_running = False` is defined inside
# MultiPartDataModule and is not visible here — confirm.
data_module = MultiPartDataModule(
    df = target_dyn_demand,
    config = config,
    is_running = False,
    batch_size = 128,
    val_ratio = 0.2
)

# Anchor month used for the inference loader here and for the plot at the end.
# Presumably encoded as YYYYMM (the name suggests it) — verify against the data module.
plan_yyyymm = 202401
train_loader = data_module.get_train_loader()
val_loader = data_module.get_val_loader()
# Inference loader at the plan month: no parts filter, missing values filled with 'ffill'.
anchor_loader = data_module.get_inference_loader_at_plan(
    plan_dt = plan_yyyymm,
    parts_filter=None,
    fill_missing='ffill'
)

from utils.helper import collect_indices

# Compare the two loaders via whatever collect_indices returns for each.
# NOTE(review): collect_indices is a project helper; the code below only shows that
# its result has `.shape` and `.mean(0)` and is accepted by torch.allclose.
fp_tr = collect_indices(train_loader)
fp_va = collect_indices(val_loader)
print("fp shapes:", fp_tr.shape, fp_va.shape)
# True when the dim-0 means of the two results match within torch.allclose's default tolerances.
print("approx equal:", torch.allclose(fp_tr.mean(0), fp_va.mean(0)))

### Model Training

In [None]:
# Training
# NOTE(review): `optimizer` and `scheduler` are created here but are not passed to
# base_train below; in the visible notebook they are only used by the checkpoint
# save/recall cells. Confirm whether base_train builds its own optimizer.
model = PatchMixerQuantileModel(config)
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max = 10)

# base_train returns the trained model together with the train / validation histories.
trained_model, tr_hist, va_hist = PatchMixerTrain().base_train(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    point_loss = 'mae',
    epochs= 100,
    # Keep the loss defaults as they are, but
    use_intermittent=True, # training only: balance weighting + lower-quantile weighting
    # validation is handled automatically inside with plain pinball loss
)


In [None]:
# Post-training sanity check: compare the range of the targets with the range of
# the model outputs on one training batch and one anchor batch, then look for
# bounded activations in the network.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Feed inputs on the device the model actually lives on. The `device` string
# above may differ from it (e.g. a CPU-resident model on a CUDA machine), which
# would make the forward pass fail with a device mismatch.
model_device = next(model.parameters()).device

model.eval()

xb, yb, *rest = next(iter(train_loader))
print("y batch min/max:", yb.min().item(), yb.max().item())

with torch.no_grad():
    pred = model(xb.to(model_device))
pred_cpu = pred.detach().cpu()

print("pred min/max:", pred_cpu.min().item(), pred_cpu.max().item())

# Same check on the first batch of the anchor (inference) loader.
xb, *rest = next(iter(anchor_loader))
xb = xb.to(model_device)
with torch.no_grad():
    raw = model(xb)  # (B,H) or (B,Q,H)
print("RAW on anchor min/max:", raw.min().item(), raw.max().item())

print(model.head)  # or print(model) for the whole network
# Report any Sigmoid / Tanh module found in the network.
for name, m in model.named_modules():
    if isinstance(m, (nn.Sigmoid, nn.Tanh)):
        print("Found bounded activation:", name, m)


### Model Save

In [None]:
# Full Save: model weights plus optimizer / scheduler state, config and loss histories.
import os

full_ckpt_path = DIR + 'fit/patch_mixer_20250919_ltb_full.pt'
# torch.save does not create missing directories; make sure 'fit/' exists so a
# finished training run is not lost to a FileNotFoundError.
os.makedirs(os.path.dirname(full_ckpt_path), exist_ok=True)

torch.save({
    'model_state': trained_model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict(),
    'config': vars(config),
    'train_loss_hist': tr_hist,
    'val_loss_hist': va_hist,
}, full_ckpt_path)

# only for weight
# torch.save(trained_model.state_dict(), DIR + 'patch_mixer_20250919_l54_h_27.pt')

### Trained Model Recall

In [None]:
# ------------------ Full Recall ------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
# confirm every entry of this checkpoint loads under that mode.
ckpt = torch.load(DIR + 'fit/patch_mixer_20250919_ltb_full.pt', map_location=device)
# NOTE(review): the model is rebuilt from a fresh default config; the 'config'
# entry stored in the checkpoint is not consulted — confirm the two agree.
cfg = PatchMixerConfigMonthly()
model = PatchMixerQuantileModel(cfg).to(device)
model.load_state_dict(ckpt["model_state"])

# Rebuild the optimizer and scheduler on the restored model's parameters before
# loading their state. This keeps the cell runnable on a fresh kernel (no
# dependency on the training cell) and binds the optimizer to the model that was
# just loaded rather than to the previously trained instance. Constructor
# arguments mirror the training cell.
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max = 10)
optimizer.load_state_dict(ckpt["optimizer_state"])
scheduler.load_state_dict(ckpt["scheduler_state"])

# ------------------ Weight Recall ------------------
# ckpt_path = DIR + 'patch_mixer_20250918_l36_h48.pt'
# state = torch.load(ckpt_path, map_location=device)
# model.load_state_dict(state)

# Switch to inference mode; as the cell's last expression this also displays the model.
model.to(device).eval()


## 1) 검증 세트에서 예측 + 핵심 지표 계산
### 참고

학습 중 간헐 수요 가중을 사용했더라도, 공정한 비교를 위해 검증 지표는 가중치 없이 그대로 계산한다.

50% 구간 커버리지는 P25/P75 분위수를 예측하지 않아 계산할 수 없으므로, 대신 P10~P90(80%) 구간 커버리지로 모니터링한다. 헤드가 더 많은 분위수를 예측하도록 늘리면 50%/90% 구간 커버리지도 바로 확인할 수 있다.

In [None]:
# Compute validation metrics and print a one-line summary.
# The keys are looked up explicitly instead of via "...".format(**metrics):
# str.format splits a field name at the first '.', so "{COV@0.1:.3f}" would look
# up the key "COV@0" (then attribute "1") and raise KeyError.
# NOTE(review): assumes `metrics` is a mapping with exactly the keys used below,
# as implied by the original format string — confirm against
# compute_validation_metrics_for_patchmixer.
metrics = compute_validation_metrics_for_patchmixer(model, val_loader, device = device)
print(
    f"MAE={metrics['MAE']:.4f} RMSE={metrics['RMSE']:.4f} sMAPE={metrics['sMAPE']:.4f} | "
    f"Pinball={metrics['Pinball(avg)']:.4f} | "
    f"COV[q]: {metrics['COV@0.1']:.3f},{metrics['COV@0.5']:.3f},{metrics['COV@0.9']:.3f} | "
    f"I80(cov/wid): {metrics['Coverage@80']:.3f}/{metrics['Width@80']:.3f}"
)

### 2) 앵커 시점(plan_yyyymm) 배치 예측
#### (A) 간단 배치 예측 (P50만)
#### (B) 분위수 전체 저장(P10/P50/P90)

In [None]:
from model_runner.inferences.patch_mixer_inference import evaluate_PatchMixer
# (A) Single-quantile prediction at the anchor month (q_use = 0.5, return_full = False).
preds_anchor, parts_anchor = evaluate_PatchMixer(
    model, anchor_loader, device = device, q_use = 0.5, return_full = False
) # preds_anchor: (B, H)
print('(A)', preds_anchor.shape, len(parts_anchor))

# (B) Same call with return_full = True; `parts_anchor` is overwritten by this call.
# NOTE(review): the structure of `preds_anchor_q` is decided by evaluate_PatchMixer
# and is not visible here; the print below emits the whole object, not its shape.
preds_anchor_q, parts_anchor = evaluate_PatchMixer(
    model, anchor_loader, device = device, q_use = 0.5, return_full = True
)
print('(B)', preds_anchor_q)

### 3) Overlap-Averaging 기반 DMS 추론(안정화 포함, 옵션)
학습 시 DMS(직접 다스텝)로 훈련된 모델이라면, 겹침 평균 + 1-step 안정화 가드로 더 매끈한 결과를 얻을 수 있음.

In [None]:
import inspect
from model_runner.inferences.patch_mixer_inference import PatchMixer_DMSForecaster
from tqdm import tqdm

# Debug output: object ids and source files of the model and of the forecaster wrapping it.
print("outer model id:", id(model))
print("outer model file:", inspect.getfile(model.__class__))
dms = PatchMixer_DMSForecaster(model, target_channel = 0, fill_mode = 'copy_last')
print("dms.model id:", id(dms.model))
print("dms class file:", inspect.getfile(PatchMixer_DMSForecaster))

# Keyword options forwarded unchanged to forecast_overlap_avg for every batch.
dms_forecast_opts = dict(
    horizon = 60,
    q_use = 0.5, return_full = False,
    use_winsor = True, use_multi_guard = True, use_dampen = True,
    clip_q = (0.05, 0.95), clip_mul = 2.0,  # tune the strength to the domain
    max_growth = 1.10, max_step_up = 0.10, max_step_down = 0.30,
    damp_min = 0.2, damp_max = 0.6,
)

all_preds, all_parts = [], []
all_x_first = None

with torch.no_grad():
    for i, batch in enumerate(tqdm(anchor_loader)):
        # A batch is either (x, part) or a 3-tuple whose middle element is ignored.
        if len(batch) != 2:
            x, _, part = batch
        else:
            x, part = batch

        all_parts.extend(part)

        if i == 0:
            # first batch x (taken before the device move) for the history graph
            all_x_first = x
        x = x.to(device)

        y_hat = dms.forecast_overlap_avg(x, **dms_forecast_opts)   # (B, H)
        all_preds.append(y_hat.cpu())

# Stack the per-batch forecasts along dim 0.
preds_anchor_dms = torch.cat(all_preds, dim = 0) # (B_total, H)

## 4) IMS(autoregressive) 추론

In [None]:
from model_runner.inferences.patch_mixer_inference import PatchMixer_IMSForecaster
# Imported here as well so this cell does not rely on the optional DMS cell having run.
from tqdm import tqdm

# IMS forecaster wrapped around the same model; forecasts are collected per
# batch on CPU and concatenated at the end.
ims = PatchMixer_IMSForecaster(model, target_channel = 0, fill_mode = 'copy_last')
model.eval()
all_preds = []
with torch.no_grad():
    for batch in tqdm(anchor_loader):
        # A batch is either (x, part) or a 3-tuple whose middle element is ignored.
        if len(batch) == 2:
            x, part = batch
        else:
            x, _, part = batch
        x = x.to(device)

        y_hat = ims.forecast(
            x, horizon = 60,
            q_use = 0.5, return_full = False,
            use_winsor = True,use_dampen = True, damp = 0.5
        )

        all_preds.append(y_hat.cpu())

# Stack the per-batch forecasts along dim 0.
preds_anchor_ims = torch.cat(all_preds, dim = 0)

In [None]:
from utils.plot_utils import plot_anchored_forecasts_yyyymm_multi

# Plot IMS vs DMS forecasts at the plan month against the demand table.
# `all_parts` and `preds_anchor_dms` come from the DMS cell, `preds_anchor_ims` from the IMS cell.
# NOTE(review): presumably the rows of both prediction tensors line up with
# `all_parts` because every loop iterates the same anchor_loader — confirm that the
# loader order is deterministic.
# NOTE(review): the meaning of k = 343 is defined by the plotting helper and is not
# visible here — confirm.
plot_anchored_forecasts_yyyymm_multi(
    df            = target_dyn_demand,
    parts         = all_parts,
    preds_dict    = {"IMS": preds_anchor_ims, "DMS": preds_anchor_dms},
    plan_yyyymm   = plan_yyyymm,
    lookback      = PatchMixerConfigMonthly().lookback,  # read from a fresh config instance
    k             = 343,
    outdir        = "./plots",
    prefix        = "anchored_ims_vs_dms"
)