In [None]:
import sys

import polars as pl
import torch

from data_loader.TimeSeriesModule import MultiPartDataModule
from training.config import TrainingConfig
from training.model_trainers.total_train import run_total_train_monthly
from utils.checkpoint import save_model_dict, load_model_dict
from utils.plot_utils import plot_val_per_part, plot_120_months_many

# Environment setup notes (Windows + CUDA 12.8 build of PyTorch):
#   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
#   CUDA toolkit: https://developer.nvidia.com/cuda-12-8-0-download-archive

# Data directory candidates; DIR is the one selected for the current platform.
MAC_DIR = '../data/'
WINDOW_DIR = 'C:/Users/USER/PycharmProjects/research/data/'

if sys.platform == 'win32':
    DIR = WINDOW_DIR
    # Print the CUDA / PyTorch environment so a misconfigured install is obvious.
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.version.cuda)
    print(torch.__version__)
    # get_device_name(0) raises when no CUDA device/build is present, so only
    # query it when CUDA is usable; otherwise the notebook continues on CPU.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
else:
    DIR = MAC_DIR

In [1]:
import torch

class DummyUsesExo(torch.nn.Module):
    """Minimal stand-in model whose output depends only on ``future_exo``.

    Used to check that exogenous features actually reach the model: the
    history input ``x`` contributes nothing except its batch size and device.

    Args:
        H: Number of steps in the output (second dimension of the result).
        exo_dim: Size of the last dimension of ``future_exo``.
    """

    def __init__(self, H, exo_dim):
        super().__init__()
        self.H = H
        # Bias-free linear map that collapses the exo feature axis to one value.
        self.w = torch.nn.Linear(exo_dim, 1, bias=False)

    def forward(self, x, future_exo=None, mode=None):
        """Return a (B, H) tensor.

        Args:
            x: Input tensor; only ``x.size(0)`` and ``x.device`` are read.
            future_exo: Optional tensor of shape (B, H, exo_dim). When given,
                it is projected through ``self.w`` and the trailing size-1
                axis is squeezed away.
            mode: Accepted for call compatibility; not used.

        Returns:
            Projection of ``future_exo`` with shape (B, H), or zeros of shape
            (B, self.H) on ``x``'s device when ``future_exo`` is None.
        """
        if future_exo is not None:
            projected = self.w(future_exo)  # (B, H, exo_dim) -> (B, H, 1)
            return projected.squeeze(-1)    # (B, H, 1) -> (B, H)
        batch_size = x.size(0)
        return torch.zeros(batch_size, self.H, device=x.device)

# Shapes for the sanity check: batch, lookback length, channels, horizon, exo features.
B = 2
L = 24
C = 1
H = 12
exo_dim = 2
x = torch.randn(B, L, C)


def make_exo(offset):
    """Build a (B, H, 2) tensor of sin/cos features with period 12.

    Args:
        offset: First step index; steps ``offset .. offset + H - 1`` are used.

    Returns:
        A tensor whose last axis is ``[sin(2*pi*t/12), cos(2*pi*t/12)]`` for
        each step ``t``, with the same values repeated for every batch entry
        (an expanded view, not a copy).
    """
    steps = torch.arange(offset, offset + H).float()
    phase = 2 * torch.pi * steps / 12
    features = torch.stack((torch.sin(phase), torch.cos(phase)), dim=-1)
    return features.unsqueeze(0).expand(B, -1, -1)

# Sanity check: the adapter must pass `future_exo` through to the model.
# The dummy model's output depends only on `future_exo`, so two different
# exogenous inputs must give different outputs.
m = DummyUsesExo(H, exo_dim)
from training.adapters import DefaultAdapter
ad = DefaultAdapter()

y1 = ad.forward(m, x, future_exo=make_exo(0), mode='train')
y2 = ad.forward(m, x, future_exo=make_exo(3), mode='train')
exo_diff = (y1 - y2).abs().sum().item()
print("diff:", exo_diff)  # must be greater than 0 when exo is forwarded correctly
# Enforce the check instead of relying on someone reading the printed value.
assert exo_diff > 0, "DefaultAdapter.forward did not propagate future_exo to the model"


diff: 18.067934036254883


In [None]:
# Read the monthly demand parquet file from the data directory selected in the
# first cell (DIR).
target_dyn_demand_monthly = pl.read_parquet(DIR + 'target_dyn_demand_monthly.parquet')

In [None]:
# NOTE(review): plan_yyyymm is not referenced by any other visible cell —
# confirm whether it is still needed.
plan_yyyymm = 202305

# Training configuration built with its defaults.
train_cfg = TrainingConfig()

# Wrap the demand table in the project's data module and get the train and
# validation loaders used by the training and plotting cells.
# Presumably val_ratio is the fraction held out for validation; the meaning of
# is_running is not visible here — TODO confirm against MultiPartDataModule.
data_module = MultiPartDataModule(
    target_dyn_demand_monthly,
    train_cfg,
    batch_size = 64,
    val_ratio = 0.2,
    is_running = False
)
train_loader = data_module.get_train_loader()
val_loader = data_module.get_val_loader()

In [None]:
# Run the project's monthly training routine on the train/validation loaders.
# The result is later passed to save_model_dict together with per-name configs,
# so it is presumably keyed by model display name — TODO confirm.
model_dict = run_total_train_monthly(train_loader, val_loader)

In [None]:
# Display the training result as the cell output for a quick inspection.
model_dict

In [None]:
from models.PatchTST.common.configs import PatchTSTConfigMonthly
from models.Titan.common.configs import TitanConfigMonthly
from models.PatchMixer.common.configs import PatchMixerConfigMonthly

# Checkpoint directory and compute device (CUDA when available, otherwise CPU).
save_dir = DIR + 'fit'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One config object per model family; variants of a family share the same object.
pm_config = PatchMixerConfigMonthly(
    device=device,
    loss_mode='quantile',
    quantiles=(0.1, 0.5, 0.9),
)

ti_config = TitanConfigMonthly(
    device=device,
    loss_mode='point',
    point_loss='huber',
)

pt_config = PatchTSTConfigMonthly(
    device=device,
    loss_mode='auto',
    quantiles=(0.1, 0.5, 0.9),
)

# Single table of (display name, builder key, config); both lookup dicts are
# derived from it so the two mappings cannot drift apart.
_model_registry = (
    ("PatchMixer Base", "patchmixer_base", pm_config),
    ("PatchMixer Quantile", "patchmixer_quantile", pm_config),
    ("Titan Base", "titan_base", ti_config),
    ("Titan LMM", "titan_lmm", ti_config),
    ("Titan Seq2Seq", "titan_seq2seq", ti_config),
    ("PatchTST Base", "patchtst_base", pt_config),
)

cfg_map = {display_name: cfg for display_name, _, cfg in _model_registry}
builder_key_by_name = {display_name: key for display_name, key, _ in _model_registry}

# Save the trained models along with their configs and builder keys.
save_index = save_model_dict(
    model_dict,
    save_dir,
    cfg_by_name=cfg_map,
    builder_key_by_name=builder_key_by_name,
)

# Load
from models.model_builder import (
    build_patch_mixer_base, build_patch_mixer_quantile,
    build_titan_base, build_titan_lmm, build_titan_seq2seq,
    build_patchTST_base
)

def _with_default_cfg(build_fn, config_cls):
    """Wrap ``build_fn`` so a falsy ``cfg`` is replaced by a fresh ``config_cls()``."""
    def _build(cfg):
        return build_fn(cfg or config_cls())
    return _build


# Builder key -> callable taking a config (or None) and returning a model.
builders = {
    "patchmixer_base": _with_default_cfg(build_patch_mixer_base, PatchMixerConfigMonthly),
    "patchmixer_quantile": _with_default_cfg(build_patch_mixer_quantile, PatchMixerConfigMonthly),
    "titan_base": _with_default_cfg(build_titan_base, TitanConfigMonthly),
    "titan_lmm": _with_default_cfg(build_titan_lmm, TitanConfigMonthly),
    "titan_seq2seq": _with_default_cfg(build_titan_seq2seq, TitanConfigMonthly),
    "patchtst_base": _with_default_cfg(build_patchTST_base, PatchTSTConfigMonthly),
}
loaded = load_model_dict(save_dir, builders, device=device)

# NOTE(review): plot_out is assigned but not passed to the plotting call below —
# confirm whether plot_120_months_many should receive an output directory.
plot_out = DIR + 'plot'
plot_120_months_many(
    loaded,
    val_loader,
    device=device,
    use_truth=True,
    max_plots=100,
    show=True,
)