In [1]:
import sys, site
# Sanity check: which interpreter and site-packages directories this kernel is using.
print(sys.executable, site.getsitepackages(), sep='\n')

/opt/anaconda3/envs/ts_forecaster/bin/python
['/opt/anaconda3/envs/ts_forecaster/lib/python3.12/site-packages']


In [2]:
import sys

import polars as pl
import torch

from modeling_module.data_loader import MultiPartExoDataModule

# CUDA 12.8 build of PyTorch for the Windows/NVIDIA machine:
#   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
#   https://developer.nvidia.com/cuda-12-8-0-download-archive

# Machine-specific data roots; DIR is selected from the running platform.
MAC_DIR = '/Users/igwanhyeong/PycharmProjects/data_research/raw_data/'
WINDOW_DIR = 'C:/Users/USER/PycharmProjects/research/raw_data/'

if sys.platform == 'win32':
    DIR = WINDOW_DIR
    # CUDA diagnostics for the Windows GPU machine.
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.version.cuda)
    print(torch.__version__)
    # get_device_name(0) raises when no CUDA device is visible, so only query it when one exists.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
else:
    DIR = MAC_DIR
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Checkpoint/output directory for this run (note: reassigned by a later cell).
save_dir = DIR + 'fit/20251109_LTB'


In [None]:
# Distinct part numbers in target_dyn_demand.parquet.
# Only the one needed column is read, and maintain_order makes the displayed order deterministic.
pl.read_parquet(DIR + 'target_dyn_demand.parquet', columns=['oper_part_no']).unique(maintain_order=True)

In [None]:

# Overrides the save_dir set in the setup cell; this run's checkpoints live under 20251115_LTB.
save_dir = f'{DIR}fit/20251115_LTB'

In [None]:
# Keep only parts with strictly more than this many monthly rows
# (the original note marks 43 as roughly the Q75 of per-part sequence length).
MIN_SEQUENCE_LEN = 43

demand_monthly_raw = (
    pl.read_parquet(DIR + 'target_dyn_demand_monthly.parquet')
    .sort(['oper_part_no', 'demand_dt'])
)

target_dyn_demand_monthly = (
    demand_monthly_raw
    # 1-based month index within each part; rows are already sorted by (oper_part_no, demand_dt).
    # A window expression replaces the per-group Python callback (map_groups + pl.arange).
    .with_columns((pl.int_range(pl.len()) + 1).over('oper_part_no').alias('sequence'))
    # max(sequence) equals the number of rows per part, so filter on the group length directly.
    # Null part ids are excluded explicitly: the previous right join never matched null keys.
    # A row filter preserves the sorted order; a right join against an unordered group_by
    # result does not, which made the data fed to the loaders nondeterministic.
    .filter(
        pl.col('oper_part_no').is_not_null()
        & (pl.len().over('oper_part_no') > MIN_SEQUENCE_LEN)
    )
    .select(['oper_part_no', 'demand_dt', 'sequence', 'demand_qty'])
)

# Part ids that survived the length filter (kept so later code referring to it still works).
filtered_target = target_dyn_demand_monthly.select('oper_part_no').unique(maintain_order=True)

target_dyn_demand_monthly


In [None]:
# NOTE(review): plan_yyyymm is not referenced by any visible cell; kept in case other code reads it.
plan_yyyymm = 201801
# Window sizes, in rows of the monthly table.
lookback, horizon = 12, 3

# 'sequence' (per-part month index built above) is passed as the past continuous exogenous column.
data_module = MultiPartExoDataModule(
    target_dyn_demand_monthly,
    lookback=lookback,
    horizon=horizon,
    batch_size=128,
    val_ratio=0.2,
    past_exo_cont_cols=('sequence',),
    is_running=False,
)

train_loader, val_loader = data_module.get_train_loader(), data_module.get_val_loader()

In [None]:
from modeling_module.training.model_trainers.total_train import run_total_train_monthly

# Train the monthly model suite; save_dir is the same directory the checkpoint-loading cell reads.
train_kwargs = dict(lookback=lookback, horizon=horizon, save_dir=save_dir)
model_dict = run_total_train_monthly(train_loader, val_loader, **train_kwargs)

In [None]:
from modeling_module.utils.checkpoint import load_model_dict
# Load
from modeling_module.models.model_builder import (
    build_patch_mixer_quantile,
    build_patchTST_base, build_patchTST_quantile, build_patch_mixer_base, build_titan_base, build_titan_lmm,
    build_titan_seq2seq,
)

# Fall back to CPU so this cell also runs on the Mac (no CUDA). The previous hard-coded
# 'cuda' failed there and overwrote the device chosen in the setup cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint name -> builder function; presumably used by load_model_dict to rebuild each
# architecture before restoring its weights (TODO confirm against load_model_dict).
builders = {
    "patchmixer_base": build_patch_mixer_base,
    "patchmixer_quantile": build_patch_mixer_quantile,
    "titan_base": build_titan_base,
    "titan_lmm": build_titan_lmm,
    "titan_seq2seq": build_titan_seq2seq,
    "patchtst_base": build_patchTST_base,
    "patchtst_quantile": build_patchTST_quantile,
}
loaded = load_model_dict(save_dir, builders, device = device)


In [None]:
%load_ext autoreload
%autoreload 2

import importlib, modeling_module.utils.plot_utils as pu
import modeling_module.training.forecaster as fo
importlib.reload(pu)
importlib.reload(fo)

def my_exo_cb(start_idx: int, Hm: int, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Future-exogenous callback: calendar features for ``Hm`` steps starting at ``start_idx``.

    Delegates to ``fo.make_calendar_exo`` with a period of 12. Per the original note the
    result presumably has 2 features (sin, cos) — TODO confirm against make_calendar_exo.
    The ``device`` default is resolved once, when this function is defined.
    """
    calendar_period = 12
    return fo.make_calendar_exo(start_idx, Hm, period=calendar_period, device=device)

# Plot validation-mode forecasts for the loaded models, using calendar features as future exogenous input.
plot_device = "cuda" if torch.cuda.is_available() else "cpu"
pu.plot_120m(
    models=loaded,            # name -> model, e.g. {"PatchMixer": pm_model, "Titan": ti_model, ...}
    loader=val_loader,        # yields (xb, yb[, part_ids])
    device=plot_device,
    mode="val",               # validation mode
    max_plots=5,
    out_dir=None,
    show=True,
    future_exo_cb=my_exo_cb,
)