In [None]:
import os
import sys

import polars as pl
import torch

from data_loader.TimeSeriesModule import MultiPartDataModule
from model_runner.model_configs import PatchMixerConfigWeekly
from model_runner.train.patchmixer_train import PatchMixerTrain
from models.PatchMixer.PatchMixer import *

# CUDA 12.8 build of PyTorch (run in a shell, not in this notebook):
#   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
#   https://developer.nvidia.com/cuda-12-8-0-download-archive

# Data roots: a relative path on non-Windows hosts, an absolute project path on Windows.
MAC_DIR = '../data/'
WINDOW_DIR = 'C:/Users/USER/PycharmProjects/research/data/'

if sys.platform == 'win32':
    DIR = WINDOW_DIR
    # CUDA / PyTorch diagnostics for the Windows workstation.
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.version.cuda)
    print(torch.__version__)
    # get_device_name(0) raises when no CUDA device is visible, so only query
    # it when CUDA is actually available.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
else:
    DIR = MAC_DIR

### Data Loader

In [None]:
# Weekly demand table, read from the platform-specific data directory.
target_dyn_demand_weekly = pl.read_parquet(f'{DIR}target_dyn_demand_weekly.parquet')

# Weekly PatchMixer configuration; the same object is handed to the data module.
config = PatchMixerConfigWeekly()

data_module = MultiPartDataModule(
    df=target_dyn_demand_weekly,
    config=config,
    is_running=True,
    batch_size=128,
    val_ratio=0.2,
)

# Loaders for training, validation and plain inference.
train_loader, val_loader, inference_loader = (
    data_module.get_train_loader(),
    data_module.get_val_loader(),
    data_module.get_inference_loader(),
)

# Anchor point for the anchored forecast; presumably year 2023, week 15
# (yyyyww encoding) -- confirm against the data module.
plan_yyyyww = 202315
anchor_loader = data_module.get_inference_loader_at_plan(
    plan_dt=plan_yyyyww,
    parts_filter=None,
    fill_missing='ffill',
)

In [None]:
# If the DataModule already builds sample indices, use those. Otherwise build
# rough per-batch fingerprints as below and compare them.
def collect_indices(loader, max_batches=999999):
    """Summarise every batch of ``loader`` with a crude numeric fingerprint.

    The dataset is not assumed to return original sample indices, so each
    batch is reduced to ``(element count, element sum)`` of its first item.
    This is a rough heuristic for comparing loaders, not an exact identity
    check.

    Args:
        loader: Iterable of batches; ``batch[0]`` must be a tensor.
        max_batches: Stop after this many batches have been fingerprinted.

    Returns:
        A float32 tensor of shape ``(n_batches, 2)``. When the loader yields
        no batch at all, an empty ``(0, 2)`` tensor is returned instead of
        letting ``torch.stack`` fail on an empty list.
    """
    fingerprints = []
    for seen, batch in enumerate(loader, start=1):
        x = batch[0]
        # Crude fingerprint: number of elements and their sum.
        fingerprints.append(
            torch.tensor([float(x.numel()), float(x.sum())], dtype=torch.float32)
        )
        if seen >= max_batches:
            break
    if not fingerprints:
        return torch.empty((0, 2), dtype=torch.float32)
    return torch.stack(fingerprints)

# Fingerprint the train and validation loaders, then compare the column-wise
# means of the two fingerprint tables.
fp_tr, fp_va = collect_indices(train_loader), collect_indices(val_loader)
print("fp shapes:", fp_tr.shape, fp_va.shape)
print(
    "approx equal:",
    torch.allclose(fp_tr.mean(dim=0), fp_va.mean(dim=0)),
)


### Model Training

In [None]:
# Single source of truth for the run length, so the cosine schedule and the
# training loop cannot drift apart.
N_EPOCHS = 2

model = PatchMixerQuantileModel(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# T_max follows the epoch count so the cosine schedule spans the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=N_EPOCHS)

# NOTE(review): optimizer and scheduler are not passed to base_train here;
# presumably the trainer builds its own -- confirm, otherwise the states saved
# later in the checkpoint are those of objects that never took a step.
train_model, tr_hist, va_hist = PatchMixerTrain().base_train(
    train_loader=train_loader,
    val_loader=val_loader,
    model=model,
    epochs=N_EPOCHS,
)


### Model Save

In [None]:
# Full Save: weights, optimizer/scheduler state, config and loss histories.
CKPT_PATH = DIR + 'fit/patch_mixer_20250919_running_full.pt'

# torch.save does not create missing parent directories, so make sure the
# 'fit/' folder exists before writing.
os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)

torch.save({
    'model_state': train_model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict(),
    'config': vars(config),
    'train_loss_hist': tr_hist,
    'val_loss_hist': va_hist,
}, CKPT_PATH)

# only for weight
# torch.save(train_model.state_dict(), DIR + 'patch_mixer_20250919_l54_h_27.pt')

### Trained Model Recall

In [None]:
# ------------------ Full Recall ------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
# if the stored config/history hold non-plain Python objects, loading may need
# weights_only=False -- confirm against the installed version.
ckpt = torch.load(DIR + 'fit/patch_mixer_20250919_running_full.pt', map_location=device)
cfg = PatchMixerConfigWeekly()
model = PatchMixerQuantileModel(cfg).to(device)
model.load_state_dict(ckpt["model_state"])
# Do NOT re-instantiate the model after this point: a fresh
# PatchMixerQuantileModel(cfg) would throw away the weights just restored.

# Rebuild optimizer and scheduler on the restored model's parameters so that
# their state refers to this model (and so this cell also works when the
# training cell was skipped). The constructor hyper-parameters are only
# placeholders; load_state_dict overwrites them with the checkpointed values.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=2)
optimizer.load_state_dict(ckpt["optimizer_state"])
scheduler.load_state_dict(ckpt["scheduler_state"])

# ------------------ Weight Recall ------------------
# ckpt_path = DIR + 'patch_mixer_20250918_l36_h48.pt'
# state = torch.load(ckpt_path, map_location=device)
# model.load_state_dict(state)

# Inference mode (disables dropout etc.); being the last expression, this also
# displays the model in the cell output.
model.to(device).eval()


In [None]:
from model_runner.inferences.patch_mixer_inference import PatchMixer_DMSForecaster
from tqdm import tqdm

dms = PatchMixer_DMSForecaster(model, target_channel=0, fill_mode='copy_last')

# Forecast options, identical for every batch, so they are defined once.
dms_forecast_kwargs = dict(
    horizon=27,
    q_use=0.5,
    return_full=False,
    use_winsor=False,
    use_multi_guard=False,
    use_dampen=False,
    clip_q=(0.05, 0.95),
    clip_mul=2.0,  # tune the clipping strength to the domain
    max_growth=1.10,
    max_step_up=0.10,
    max_step_down=0.30,
    damp_min=0.2,
    damp_max=0.6,
)

all_preds, all_parts = [], []
all_x_first = None  # x of the first batch, kept for the history graph

with torch.no_grad():
    for i, batch in enumerate(tqdm(anchor_loader)):
        # A batch is either (x, part) or (x, <ignored>, part).
        if len(batch) != 2:
            x, _, part = batch
        else:
            x, part = batch

        all_parts.extend(part)

        # Remember the first batch before it is moved to the device.
        if i == 0:
            all_x_first = x
        x = x.to(device)

        y_hat = dms.forecast_overlap_avg(x, **dms_forecast_kwargs)  # (B, H)
        all_preds.append(y_hat.cpu())

preds_anchor_dms = torch.cat(all_preds, dim=0)  # (B_total, H)

In [None]:
from model_runner.inferences.patch_mixer_inference import PatchMixer_IMSForecaster

ims = PatchMixer_IMSForecaster(model, target_channel=0, fill_mode='copy_last')
model.eval()

# Forecast options, identical for every batch, so they are defined once.
ims_forecast_kwargs = dict(
    horizon=27,
    q_use=0.5,
    return_full=False,
    use_winsor=False,
    use_dampen=False,
    damp=0.5,
)

all_preds = []
with torch.no_grad():
    for batch in tqdm(anchor_loader):
        # A batch is either (x, part) or (x, <ignored>, part).
        if len(batch) != 2:
            x, _, part = batch
        else:
            x, part = batch
        x = x.to(device)

        y_hat = ims.forecast(x, **ims_forecast_kwargs)

        # Collect predictions on the CPU so they can be concatenated afterwards.
        all_preds.append(y_hat.cpu())

preds_anchor_ims = torch.cat(all_preds, dim=0)

In [None]:
from utils.plot_utils import plot_anchored_forecasts_yyyyww_multi

# Plot IMS vs. DMS anchored forecasts for the parts collected during inference.
# The lookback comes from the weekly config that built the loaders; the
# previous PatchMixerConfigMonthly() was never imported explicitly and does not
# match the weekly data/model used throughout this notebook.
# NOTE(review): all_parts was gathered in the DMS loop; this assumes
# anchor_loader yields parts in the same order on the IMS pass -- confirm the
# loader does not shuffle.
plot_anchored_forecasts_yyyyww_multi(
    df            = target_dyn_demand_weekly,
    parts         = all_parts,  # part ids gathered in the DMS inference loop
    preds_dict    = {"IMS": preds_anchor_ims, "DMS": preds_anchor_dms},
    plan_yyyyww   = plan_yyyyww,
    lookback      = config.lookback,
    k             = 343,
    outdir        = "./plots",
    prefix        = "anchored_ims_vs_dms"
)