In [None]:
import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split

from dataset.traffic_dataset import TrafficDataset
from dataset.dataset_config import edge_index, edge_attr
#from models.baselines import STGCN
from models.STLinear import STLinear
from models.STLinear_biased_models import STLinear_SPE
from utils.Trainer import Trainer

# 2) collate_fn
def collate_fn(batch_list):
    """Batch a list of samples exposing ``.x`` / ``.y`` tensors.

    Each sample's ``x`` and ``y`` are stacked along a new leading batch axis:
    inputs become ``[B, T, E, C]`` and targets ``[B, n_pred, E, D]``.

    Returns:
        ``(xs, ys)`` tuple of stacked tensors.
    """
    inputs = [sample.x for sample in batch_list]
    targets = [sample.y for sample in batch_list]
    return torch.stack(inputs, dim=0), torch.stack(targets, dim=0)


In [None]:

# 3) dataset
SEED = 42          # seeds torch/numpy and the train/val split
TRAIN_FRAC = 0.8   # fraction of samples used for training
BATCH_SIZE = 512

# Seed global RNGs so DataLoader shuffling and model initialisation are
# repeatable under Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

# NOTE: allow_pickle=True lets np.load execute arbitrary code from the file;
# only load trusted, locally produced .npy files here.
dataset_np = np.load('dataset/traffic_dataset_13_smoothen.npy', allow_pickle=True)
dataset = TrafficDataset(dataset_np, window=12, randomize=False)

train_size = int(len(dataset) * TRAIN_FRAC)
val_size = len(dataset) - train_size
# A dedicated seeded generator pins the split itself, so every Optuna trial and
# the final model are scored on the same validation samples across sessions.
# NOTE(review): if samples are overlapping sliding windows (window=12), a random
# split lets validation windows share timesteps with training windows, which can
# make the validation loss optimistic; consider a chronological split if the
# dataset indices are time-ordered — confirm against TrafficDataset.
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# shape test: one batch tells us the node count E and the in/out channel sizes
x0, y0 = next(iter(train_loader))
B, T, E, C_in = x0.shape
_, n_pred, _, C_out = y0.shape

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Dataset shapes: x0={x0.shape}, y0={y0.shape}, device={device}")


In [None]:

# 4) Optuna Objective
def objective(trial):
    # --- Hyperparameter suggestions ---
    # common
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    dropout = trial.suggest_uniform("dropout", 0.1, 0.3)
    # hyperparameter for GNN
    kernel_size = trial.suggest_categorical("kernel_size", [17, 33, 65])
    K = trial.suggest_int("K", 1, 3)
    num_layers = trial.suggest_int("num_layers", 2, 4)
    num_heads = trial.suggest_int("num_heads", 1,2,4)

    input_embedding_dim = trial.suggest_categorical("input_embedding_dim", [16, 32, 64])
    tod_embedding_dim = trial.suggest_categorical("tod_embedding_dim", [16, 32, 64])
    dow_embedding_dim = trial.suggest_categorical("dow_embedding_dim", [16, 32, 64])
    spatial_embedding_dim = trial.suggest_categorical("spatial_embedding_dim", [0, 16, 32, 64])
    adaptive_embedding_dim = trial.suggest_categorical("adaptive_embedding_dim", [0, 16, 32, 64])
    spe_dim = trial.suggest_categorical("spe_dim", [16, 32, 64])
    spe_out_dim = trial.suggest_categorical("spe_out_dim", [16, 32, 64])


    model = STLinear_SPE(
        num_nodes =E,
        kernel_size=kernel_size, #odd number
        num_heads=num_heads,
        num_layers=num_layers,
        dropout=dropout,
        input_embedding_dim = input_embedding_dim,
        tod_embedding_dim = tod_embedding_dim,
        dow_embedding_dim = dow_embedding_dim,
        spatial_embedding_dim = spatial_embedding_dim,
        adaptive_embedding_dim = adaptive_embedding_dim,
        spe_dim = spe_dim,
        spe_out_dim = spe_out_dim
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=weight_decay)
    criterion = torch.nn.L1Loss()

    # --- run Trainer ---
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        valid_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=50,
        device=device,
        print_interval=0,    # no print
        plot_interval=0,     # no plot
        early_stopping_patience=4
    )
    trainer.fit()

    # return best loss
    valid_loss = trainer.get_best_valid_loss() 
    return valid_loss


In [None]:
# 5) Run Optuna Study
# TPE sampler with a fixed seed so the sequence of suggested configs is repeatable.
tpe_sampler = optuna.samplers.TPESampler(seed=42)
study = optuna.create_study(direction="minimize", sampler=tpe_sampler)
study.optimize(objective, n_trials=30)

# 6) check best result
print("Best validation loss:", study.best_value)
print("Best hyperparameters:")
for param_name, param_value in study.best_params.items():
    print(f"  {param_name}: {param_value}")

In [None]:
# 7) train with best hyperparameters
# NOTE(review): val_loader was also used to select these hyperparameters, so the
# final validation loss is not an unbiased estimate; evaluate on a held-out test
# split before reporting results.
best_params = study.best_params
best_model = STLinear_SPE(
    num_nodes=E,
    # architecture
    kernel_size=best_params['kernel_size'],  # odd number
    num_heads=best_params['num_heads'],
    num_layers=best_params['num_layers'],
    dropout=best_params['dropout'],
    # embedding dims
    input_embedding_dim=best_params['input_embedding_dim'],
    tod_embedding_dim=best_params['tod_embedding_dim'],
    dow_embedding_dim=best_params['dow_embedding_dim'],
    spatial_embedding_dim=best_params['spatial_embedding_dim'],
    adaptive_embedding_dim=best_params['adaptive_embedding_dim'],
    spe_dim=best_params['spe_dim'],
    spe_out_dim=best_params['spe_out_dim'],
).to(device)

best_opt = AdamW(best_model.parameters(), lr=5e-5, weight_decay=best_params['weight_decay'])
trainer = Trainer(
    model=best_model,
    train_loader=train_loader,
    valid_loader=val_loader,
    optimizer=best_opt,
    criterion=torch.nn.L1Loss(),
    epochs=60,
    device=device,
    print_interval=0,
    plot_interval=2,
    auto_save=True,
    save_dir='./final_model',
)
trainer.fit()
hist = trainer.get_history()

# Learning curves; matplotlib.pyplot (plt) is imported in the first cell.
fig, ax = plt.subplots()
ax.plot(hist['train_loss'], label='Train Loss')
ax.plot(hist['valid_loss'], label='Val Loss')
# assumes one history entry per epoch — confirm against Trainer.get_history
ax.set(title='STLinear_SPE with best Optuna params', xlabel='Epoch', ylabel='L1 loss')
ax.legend()
plt.show()