In [None]:
import sys

import polars as pl
import torch

from data_loader.TimeSeriesModule import MultiPartDataModule
from model_runner.train.titanl_train import TitanTrain
from models.Titan.Titans import LMMModel
import torch.nn as nn

# Environment setup (CUDA 12.8 build of PyTorch):
#   pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
#   https://developer.nvidia.com/cuda-12-8-0-download-archive

# Data root per platform. DIR is used as a plain string prefix (DIR + 'file.parquet'),
# so both values must end with '/'.
MAC_DIR = '../data/'
WINDOW_DIR = 'C:/Users/USER/PycharmProjects/research/data/'  # machine-specific absolute path

if sys.platform == 'win32':
    DIR = WINDOW_DIR
    # CUDA build diagnostics for the Windows workstation.
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.version.cuda)
    print(torch.__version__)
    # get_device_name(0) raises when no CUDA device is visible; only query it when one is.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
else:
    DIR = MAC_DIR

In [None]:
from models.Titan.common.configs import TitanConfigMonthly
from utils.validation_utils import collect_indices

# Loader settings.
BATCH_SIZE = 128
VAL_RATIO = 0.2
# Planning month (YYYYMM as int) at which the inference/anchor loader is built.
plan_yyyymm = 202401

target_dyn_demand_monthly = pl.read_parquet(DIR + 'target_dyn_demand_monthly.parquet')
cfg = TitanConfigMonthly()

data_module = MultiPartDataModule(
    df = target_dyn_demand_monthly,
    config = cfg,
    batch_size = BATCH_SIZE,
    val_ratio = VAL_RATIO,
    is_running = False
)

train_loader = data_module.get_train_loader()
val_loader = data_module.get_val_loader()
anchor_loader = data_module.get_inference_loader_at_plan(
    plan_dt = plan_yyyymm,
    parts_filter = None,
    fill_missing = 'ffill'
)

# Compare what collect_indices gathers from the train vs. val loaders.
# NOTE(review): the meaning/dtype of these tensors is defined by collect_indices — confirm.
fp_tr = collect_indices(train_loader)
fp_va = collect_indices(val_loader)
print("fp shapes:", fp_tr.shape, fp_va.shape)
# torch.mean raises on integer tensors ("could not infer output dtype"), so promote to
# float64 before averaging; float inputs keep (at least) their original precision.
print("approx equal:", torch.allclose(fp_tr.double().mean(0), fp_va.double().mean(0)))

### Model Training

In [None]:
# Seed before building the model so weight init (and DataLoader shuffling, if it uses the
# global torch generator) is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

LR = 1e-3
WEIGHT_DECAY = 1e-4
T_MAX = 10  # CosineAnnealingLR period (number of scheduler steps)

model = LMMModel(cfg)
optimizer = torch.optim.AdamW(model.parameters(), lr = LR, weight_decay = WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max = T_MAX)

# TODO(review): `optimizer` and `scheduler` are NOT passed to train_model_with_tta, so they are
# never stepped in this cell, and the checkpoint saved later contains their untouched initial
# state. If TitanTrain accepts an optimizer/scheduler, pass these in; otherwise drop them from
# the checkpoint.
trained_model, tr_hist, va_hist = TitanTrain().train_model_with_tta(
    model = model,
    train_loader = train_loader,
    val_loader = val_loader
)

In [None]:
# Sanity-check output ranges of the trained model on one train batch and one anchor batch.
# Evaluate `trained_model` (the object whose state_dict is saved and reloaded into LMMModel),
# and send every input to wherever its parameters actually live, so both batches use one device.
eval_model = trained_model
model_device = next(eval_model.parameters()).device
eval_model.eval()

xb, yb, *rest = next(iter(train_loader))
print("y batch min/max:", yb.min().item(), yb.max().item())

with torch.no_grad():
    pred = eval_model(xb.to(model_device))
pred_cpu = pred.detach().cpu()
print("pred min/max:", pred_cpu.min().item(), pred_cpu.max().item())

xb, *rest = next(iter(anchor_loader))
with torch.no_grad():
    raw = eval_model(xb.to(model_device))  # (B,H) or (B,Q,H)
print("RAW on anchor min/max:", raw.min().item(), raw.max().item())

# A bounded activation (Sigmoid/Tanh) in the network would cap the prediction range above.
print(eval_model.head)  # or print(eval_model) for the full architecture
for name, m in eval_model.named_modules():
    if isinstance(m, (nn.Sigmoid, nn.Tanh)):
        print("Found bounded activation:", name, m)

### Model Save

In [None]:
# Full checkpoint: weights + optimizer/scheduler state + config + loss history.
# NOTE(review): optimizer/scheduler are not passed to train_model_with_tta above, so the state
# saved here is their initial, unstepped state.
# NOTE(review): vars(cfg) captures only instance attributes; if TitanConfigMonthly keeps its
# settings as class attributes, this dict will be empty — confirm.
import os

full_ckpt_path = DIR + 'fit/titan_20250922_ltb_full.pt'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(full_ckpt_path), exist_ok = True)

torch.save({
    'model_state': trained_model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'scheduler_state': scheduler.state_dict(),
    'config': vars(cfg),
    'train_loss_hist': tr_hist,
    'val_loss_hist': va_hist
}, full_ckpt_path)

# Weights only (alternative):
# torch.save(trained_model.state_dict(), DIR + f'titan_tta_20250922_l{cfg.lookback}_h_{cfg.horizon}.pt')

### Load Trained Model

In [None]:

# ------------------ Full Recall ------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): the checkpoint also stores config/history objects; on PyTorch versions where
# torch.load defaults to weights_only=True, non-tensor entries may need allow-listing.
ckpt = torch.load(DIR + 'fit/titan_20250922_ltb_full.pt', map_location = device)
# NOTE(review): the model is rebuilt from the current TitanConfigMonthly defaults, not from
# ckpt['config']; if those defaults changed since saving, loading may fail or silently differ.
cfg = TitanConfigMonthly()
model = LMMModel(cfg).to(device)
model.load_state_dict(ckpt['model_state'])

# Rebuild optimizer/scheduler around the *new* model's parameters before restoring their state.
# Loading into the training cell's objects would leave them bound to the old model's parameters
# (and raises NameError on a fresh kernel). The constructor hyperparameters below are overwritten
# by the saved param_groups / scheduler state.
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-3, weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max = 10)
optimizer.load_state_dict(ckpt['optimizer_state'])
scheduler.load_state_dict(ckpt['scheduler_state'])

# ------------------ Weight Recall ------------------
# (weights-only file; the full checkpoint above is a dict and cannot be passed to load_state_dict directly)
# ckpt_path = DIR + f'titan_tta_20250922_l{cfg.lookback}_h_{cfg.horizon}.pt'
# state = torch.load(ckpt_path, map_location = device)
# model.load_state_dict(state)

model.eval()