In [68]:
import argparse
import numpy as np
import pandas as pd
import os
import random
import torch
import torch.nn as nn
import datetime
import time
import matplotlib.pyplot as plt
from torchinfo import summary
import yaml
import json
import sys
import glob
import copy
from tqdm import tqdm, trange

In [69]:
sys.path.append("..")
from lib.utils import (
    MaskedMAELoss,
    MaskedHuberLoss,
    print_log,
    seed_everything,
    set_cpu_num,
    masked_mae_loss,
    CustomJSONEncoder,
)
from lib.metrics import RMSE_MAE_MAPE
from lib.data_prepare import get_dataloaders_from_index_data, load_inrix_data_with_details
from model.STGWformer_eval import STGWformer

In [70]:
DATASET = "STGWformer_INRIX_MANHATTAN"  # upper-cased later to select the YAML section / checkpoint prefix
MODE = 'test'  # NOTE(review): not referenced by the visible cells
# Prefer the second GPU (as originally configured) but degrade gracefully on
# machines that do not have one, instead of failing on 'cuda:1'.
if torch.cuda.device_count() > 1:
    DEVICE = 'cuda:1'
elif torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
SHIFT = True  # forwarded as `shift=` to load_inrix_data_with_details
SCALER = None  # placeholder; replaced by the scaler returned from the data-loading cell

In [71]:
@torch.no_grad()
def inference_graph(model):
    """Build a node-to-node graph from the model's adaptive embedding.

    The embedding is multiplied with itself transposed over its last two axes,
    passed through ``model.pooling`` with axes 0 and 2 swapped (and swapped
    back afterwards), then through ReLU and a softmax over the last axis.

    NOTE(review): assumes ``model.adaptive_embedding`` is 3-D — TODO confirm
    its axis meaning against the model definition.
    """
    emb = model.adaptive_embedding
    similarity = emb @ emb.transpose(1, 2)
    # Swap axes 0 and 2 so the pooling module sees the leading axis last, then swap back.
    pooled = model.pooling(similarity.transpose(0, 2)).transpose(0, 2)
    return nn.functional.softmax(nn.functional.relu(pooled), dim=-1)

@torch.no_grad()
def eval_model(model, valset_loader, criterion, device=None, scaler=None):
    """Return the mean per-batch loss of ``model`` over ``valset_loader``.

    Predictions are mapped back to the original scale with
    ``scaler.inverse_transform`` before the loss against the targets is taken.
    Runs under ``torch.no_grad()`` so no autograd graph is built.

    Args:
        model: network to evaluate; switched to eval mode.
        valset_loader: iterable of ``(x_batch, y_batch)`` tensor pairs.
        criterion: callable ``criterion(pred, target)`` returning a scalar tensor.
        device: device for the batches; defaults to the notebook-level ``DEVICE``.
        scaler: object exposing ``inverse_transform``; defaults to ``SCALER``.

    Returns:
        Unweighted mean of the per-batch losses (each batch counts once,
        regardless of its size).

    Raises:
        ValueError: if the loader yields no batches.
    """
    # Fall back to the notebook globals so existing three-argument calls keep working.
    device = DEVICE if device is None else device
    scaler = SCALER if scaler is None else scaler

    model.eval()
    batch_losses = []
    for x_batch, y_batch in valset_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        out_batch = scaler.inverse_transform(model(x_batch))
        batch_losses.append(criterion(out_batch, y_batch).item())

    if not batch_losses:
        raise ValueError("eval_model: the loader produced no batches.")
    return np.mean(batch_losses)


@torch.no_grad()
def predict(model, loader, device=None, scaler=None, progress=True):
    """Run ``model`` over ``loader`` and collect de-normalised predictions.

    Args:
        model: network to run; switched to eval mode.
        loader: iterable of ``(x_batch, y_batch)`` tensor pairs.
        device: device for the inputs; defaults to the notebook-level ``DEVICE``.
        scaler: object exposing ``inverse_transform``; defaults to ``SCALER``.
        progress: wrap the loader in a tqdm progress bar (default True).

    Returns:
        ``(y, out)`` numpy arrays of shape ``(samples, out_steps, num_nodes)``.
        Assumes the model output is ``(batch, out_steps, num_nodes, output_dim)``
        (axis 2 is read as the node count) and that targets have that layout too.

    Raises:
        ValueError: if the loader yields no batches.
    """
    device = DEVICE if device is None else device
    scaler = SCALER if scaler is None else scaler

    model.eval()
    ys, outs = [], []
    batches = tqdm(loader) if progress else loader
    for x_batch, y_batch in batches:
        out_batch = scaler.inverse_transform(model(x_batch.to(device)))
        outs.append(out_batch.cpu().numpy())
        # Targets are only collected, so there is no need to move them to the device.
        ys.append(y_batch.cpu().numpy())

    if not outs:
        raise ValueError("predict: the loader produced no batches.")

    # Read the step count from the output instead of hard-coding 1, so
    # multi-step models keep a real per-step axis for the metrics.
    _, out_steps, num_nodes, _ = outs[-1].shape
    out = np.concatenate(outs, axis=0).reshape(-1, out_steps, num_nodes)  # (samples, out_steps, num_nodes)
    y = np.concatenate(ys, axis=0).reshape(-1, out_steps, num_nodes)
    return y, out

@torch.no_grad()
def test_model(model, testset_loader, log=None):
    """Evaluate ``model`` on ``testset_loader`` and log RMSE / MAE / MAPE.

    Logs the metrics over all output steps, then one line per step, followed
    by the elapsed time of the prediction pass.
    """
    model.eval()
    print_log("--------- Test ---------", log=log)

    # perf_counter is monotonic, so the measured interval is unaffected by
    # system clock adjustments (unlike time.time()).
    start = time.perf_counter()
    y_true, y_pred = predict(model, testset_loader)
    elapsed = time.perf_counter() - start

    rmse_all, mae_all, mape_all = RMSE_MAE_MAPE(y_true, y_pred)
    lines = ["All Steps RMSE = %.5f, MAE = %.5f, MAPE = %.5f\n" % (rmse_all, mae_all, mape_all)]
    # Axis 1 of the predictions is the output-step axis.
    for step in range(y_pred.shape[1]):
        rmse, mae, mape = RMSE_MAE_MAPE(y_true[:, step, :], y_pred[:, step, :])
        lines.append("Step %d RMSE = %.5f, MAE = %.5f, MAPE = %.5f\n" % (step + 1, rmse, mae, mape))

    print_log("".join(lines), log=log, end="")
    print_log("Inference time: %.2f s" % elapsed, log=log)

In [72]:
# Set an int here for a reproducible run; None draws a fresh random seed.
seed = None
if seed is None:
    seed = random.randint(0, 1000)
# Echo the seed so a run that used a drawn seed can be reproduced later.
print(f"Random seed: {seed}")
seed_everything(seed)
set_cpu_num(1)

In [73]:
# NOTE(review): this path is not referenced by any visible cell; the checkpoint
# that is actually loaded is picked by the "set model saving path" cell from
# files matching ../saved_models/{model_name}-{dataset}-*.pt. This path is also
# relative to a different directory ("saved_models/" vs "../saved_models/").
path_to_Weight = 'saved_models/STGWformer-STGWFORMER_INRIX_MANHATTAN-2025-10-07-02-18-59.pt'

In [74]:
dataset = DATASET.upper()
data_path = f"../data/{dataset}"  # NOTE(review): not referenced by the visible cells

model_name = STGWformer.__name__

# Hyper-parameters for this run live under the upper-cased dataset key.
with open(f"{model_name}.yaml", "r") as f:
    cfg = yaml.safe_load(f)
cfg = cfg[dataset]

# ------------------------------- make log file ------------------------------ #

now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
log_path = "../logs/"
os.makedirs(log_path, exist_ok=True)
# "w" starts from an empty file; the handle stays open because later cells log to it.
log = open(os.path.join(log_path, f"{model_name}-{dataset}-{now}.log"), "w", encoding="utf-8")

# ------------------------------- load dataset ------------------------------- #

print_log(dataset, log=log)

# Raw INRIX inputs live outside the repository; point this at your local copy.
INRIX_ROOT = "/home/dachuan/Productivities/Spectral GAT"

(trainset_loader, valset_loader, testset_loader, SCALER, adj_mx, gdf, tmc) = (
    load_inrix_data_with_details(
        os.path.join(INRIX_ROOT, "NY", "adj_manhattan.npy"),
        os.path.join(INRIX_ROOT, "SPGAT", "Data", "speed_19_Manhattan_5min_py36"),
        os.path.join(INRIX_ROOT, "NY", "Manhattan_FinalVersion.shp"),
        os.path.join(INRIX_ROOT, "NY", "TMC_FinalVersion.csv"),
        tod=cfg.get("time_of_day"),
        dow=cfg.get("day_of_week"),
        batch_size=cfg.get("batch_size", 64),
        history_seq_len=cfg.get("in_steps"),
        future_seq_len=cfg.get("out_steps"),
        log=log,
        train_ratio=cfg.get("train_size", 0.6),
        valid_ratio=cfg.get("val_size", 0.2),
        shift=SHIFT,
    )
)

print_log(log=log)
# NOTE(review): this iterates over the first axis of adj_mx. That gives one tensor
# per support matrix if adj_mx is a list of matrices, but one tensor per ROW if it
# is a single 2-D array. TODO confirm. `supports` is not used by the visible cells.
supports = [torch.tensor(i).to(DEVICE) for i in adj_mx]

STGWFORMER_INRIX_MANHATTAN
STGWFORMER_INRIX_MANHATTAN
--- Building Sequences ---
--- Scaling Sequences ---
Trainset:	x-(63057, 1212, 12, 1)	y-(63057, 1212, 1, 1)
Valset:  	x-(21019, 1212, 12, 1)  	y-(21019, 1212, 1, 1)
Testset:	x-(21020, 1212, 12, 1)	y-(21020, 1212, 1, 1)



In [75]:
# ---------------------- set loss, optimizer, scheduler ---------------------- #
from functools import partial

model = STGWformer(**cfg["model_args"]).to(DEVICE)
criterion = MaskedMAELoss()  # MaskedHuberLoss()
# NOTE(review): optimizer and scheduler are not used by the evaluation cells of
# this notebook; they only mirror the training setup.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cfg["lr"],
    weight_decay=cfg.get("weight_decay", 0),
    eps=cfg.get("eps", 1e-8),
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=cfg["milestones"],
    # Use MultiStepLR's own default rather than passing gamma=None when the key
    # is missing. `verbose` is omitted: it is deprecated in recent PyTorch.
    gamma=cfg.get("lr_decay_rate", 0.1),
)



In [76]:
# --------------------------- set model saving path -------------------------- #
save_path = "../saved_models/"

model_files = glob.glob(os.path.join(save_path, f"{model_name}-{dataset}-*.pt"))
if not model_files:
    raise ValueError("No saved model found for testing.")
# Newest by modification time; on POSIX, ctime is the inode-change time, not creation time.
latest_model = max(model_files, key=os.path.getmtime)
print_log(f"Loading the latest model: {latest_model}", log=log)
# map_location loads the weights straight onto DEVICE, whatever device they were saved from.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(latest_model, map_location=DEVICE))
model = model.to(DEVICE)

Loading the latest model: ../saved_models/STGWformer-STGWFORMER_INRIX_MANHATTAN-2025-10-07-02-18-59.pt


In [77]:
# --------------------------- print model structure -------------------------- #

print_log("---------", model_name, "---------", log=log)
cfg_dump = json.dumps(cfg, ensure_ascii=False, indent=4, cls=CustomJSONEncoder)
print_log(cfg_dump, log=log)

# torchinfo input size: [batch, in_steps, num_nodes, features]. The feature
# count is taken from a real training batch so it matches what the loader yields.
summary_input_size = [
    cfg["batch_size"],
    cfg["in_steps"],
    cfg["num_nodes"],
    next(iter(trainset_loader))[0].shape[-1],
]
model_summary = summary(model, summary_input_size, verbose=0, device=DEVICE)  # verbose=0: avoid printing twice
print_log(model_summary, log=log)
print_log(log=log)


--------- STGWformer ---------
{
    "num_nodes": 1212,
    "in_steps": 12,
    "out_steps": 1,
    "train_size": 0.6,
    "val_size": 0.2,
    "time_of_day": false,
    "day_of_week": false,
    "lr": 0.001,
    "weight_decay": 0.0015,
    "milestones": [
        25,
        45,
        65
    ],
    "lr_decay_rate": 0.1,
    "batch_size": 8,
    "max_epochs": 300,
    "early_stop": 30,
    "use_cl": false,
    "cl_step_size": 2500,
    "model_args": {
        "num_nodes": 1212,
        "in_steps": 12,
        "out_steps": 1,
        "steps_per_day": 288,
        "input_dim": 1,
        "output_dim": 1,
        "input_embedding_dim": 24,
        "tod_embedding_dim": 0,
        "dow_embedding_dim": 0,
        "adaptive_embedding_dim": 12,
        "kernel_size": [
            1
        ],
        "num_heads": 4,
        "num_layers": 6,
        "dropout": 0.1,
        "dropout_a": 0.35
    }
}
Layer (type:depth-idx)                        Output Shape              Param #
STGWformer    

In [78]:
# --------------------------- evaluate model performance --------------------------- #
# Logs overall and per-step RMSE / MAE / MAPE on the test split, plus the inference time.
test_model(model, testset_loader, log=log)

--------- Test ---------


100%|██████████| 2627/2627 [01:31<00:00, 28.57it/s]


All Steps RMSE = 5.18845, MAE = 2.50357, MAPE = 18.29603
Step 1 RMSE = 5.18845, MAE = 2.50357, MAPE = 18.29603
Inference time: 91.99 s


In [79]:
# --------------------------- inference graph --------------------------- #
# Softmax-normalised node-to-node graph built from the model's adaptive embedding
# (see inference_graph); it is computed here but not used by the visible cells.
graph = inference_graph(model)

### Inference wavelets

In [80]:
import itertools
import types

@torch.no_grad()
def average_U_instance(model, loader, steps=8, layer_idx=0):
    """Average the attention matrix U returned by one wavelet layer.

    Temporarily wraps ``model.attn_layers_s[layer_idx].locals.wavelet_layer.attention``
    so every U it returns during a forward pass is recorded. It then runs the
    model on up to ``steps`` batches of ``loader`` and averages the recorded U's.

    Each recorded U is averaged over its two leading axes (assumed batch and
    time, i.e. a (B, T, N, N) layout — TODO confirm). The resulting (N, N)
    matrices are then averaged with equal weight.

    Args:
        model: model exposing ``attn_layers_s[i].locals.wavelet_layer``.
        loader: iterable of batches; a list/tuple batch contributes its first element.
        steps: maximum number of batches to run (fewer if the loader is shorter).
        layer_idx: index into ``model.attn_layers_s`` of the layer to inspect.

    Returns:
        numpy array ``Ubar`` of shape (N, N).

    Raises:
        RuntimeError: if no U was captured (empty loader, ``steps <= 0``, or a
            forward pass that never calls ``attention``).
    """
    device = next(model.parameters()).device
    model.eval()

    wave = model.attn_layers_s[layer_idx].locals.wavelet_layer  # noflayer instance
    # Remember whether `attention` already was an instance attribute, so the patch
    # can be undone exactly (normally it is a class method and the override is deleted).
    had_instance_attr = "attention" in vars(wave)
    orig_attention = wave.attention

    captured = []

    def tapped_attention(self, x_BTNC, A_TNN):
        U, P, A_BTNN = orig_attention(x_BTNC, A_TNN)  # call the real implementation
        captured.append(U.detach().cpu())
        return U, P, A_BTNN

    # Patch only this instance; MethodType binds `self` like a regular method.
    wave.attention = types.MethodType(tapped_attention, wave)

    U_acc, n = None, 0
    try:
        # islice stops cleanly when the loader has fewer than `steps` batches.
        for batch in itertools.islice(loader, max(steps, 0)):
            batch_x = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(batch_x.to(device))  # every U produced is appended to `captured`

            for U in captured:
                Umean = U.mean(0).mean(0).numpy()  # (B, T, N, N) -> (N, N)
                U_acc = Umean if U_acc is None else U_acc + Umean
                n += 1
            captured.clear()
    finally:
        # Undo the patch no matter what happened above.
        if had_instance_attr:
            wave.attention = orig_attention
        else:
            del wave.attention

    if n == 0:
        raise RuntimeError("average_U_instance: never captured any U. "
                           "Check that the forward actually calls wavelet_layer.attention.")
    return U_acc / n  # Ubar (N,N)

In [95]:
# --------------------------- inference wavelets --------------------------- #
def get_coe_cheb(wave):
    """Read the sigmoid-squashed wavelet coefficients of a wavelet layer.

    Returns ``(a1, a2, cheb)``:
      * ``a1``, ``a2`` are entries 1 and 2 of ``sigmoid(wave.temp)``, or both
        equal ``float(wave.alpha_)`` when ``wave.alpha_`` is not None;
      * ``cheb`` is ``sigmoid(wave.cheb)`` as a numpy array.
    """
    def _squash(t):
        return torch.sigmoid(t).detach().cpu().numpy()

    coe = _squash(wave.temp)
    cheb = _squash(wave.cheb)
    if wave.alpha_ is None:
        return float(coe[1]), float(coe[2]), cheb
    # A fixed alpha_ overrides both learned coefficients.
    alpha = float(wave.alpha_)
    return alpha, alpha, cheb

def compute_betas(K, a1, a2, cheb, r0):
    """Return the K polynomial coefficients as a float64 array of shape (K,).

    For hop index s = 0..K-1 (polynomial power k = s + 1):
        beta_s = (1 - a2) * a2**(K - 1 - s) * (1 - (1 - a1) * r_s)
    with r_0 = r0 and r_{s+1} = r_s * cheb[s]. The ratio stops changing once
    ``cheb`` has no entry for the current hop.
    """
    coeffs = []
    ratio = float(r0)
    for s in range(K):
        weight = 1.0 - (1.0 - a1) * ratio
        coeffs.append((1.0 - a2) * (a2 ** (K - 1 - s)) * weight)
        # Advance the ratio for the next hop, if there is one and cheb covers it.
        if s + 1 < K and s < len(cheb):
            ratio *= float(cheb[s])
    return np.array(coeffs, dtype=float)

def poly_operator(Ubar, betas, symmetrize=True):
    """Return sum_k betas[k-1] * U**k (matrix powers, k = 1..len(betas)).

    U is the symmetric part 0.5 * (Ubar + Ubar.T) when ``symmetrize`` is true,
    otherwise Ubar itself. The result is a float64 (N, N) array.
    """
    base = 0.5 * (Ubar + Ubar.T) if symmetrize else Ubar
    size = base.shape[0]
    result = np.zeros((size, size), dtype=float)
    power = np.eye(size)
    for coeff in betas:
        power = power @ base
        result += coeff * power
    return result

def centered_poly_operator(Ubar, betas):
    """Polynomial in the symmetric part of Ubar, with a rank-1 term removed per power.

    Returns sum_k betas[k-1] * (U**k - J), where U = 0.5 * (Ubar + Ubar.T) and
    J = 1 · piᵀ. Here pi is the absolute value of U's eigenvector with the
    largest eigenvalue, normalised to sum to 1 (an approximate stationary
    distribution).
    """
    sym = 0.5 * (Ubar + Ubar.T)
    _, eigvecs = np.linalg.eigh(sym)
    pi = np.abs(eigvecs[:, -1])
    pi = pi / pi.sum()
    size = sym.shape[0]
    J = np.ones((size, 1)) @ pi[None, :]  # every row equals pi

    result = np.zeros_like(sym)
    power = np.eye(size)
    for coeff in betas:
        power = power @ sym
        result += coeff * (power - J)  # remove the rank-1 limit at each hop
    return result

def poly_operator_unsym(Ubar, betas):
    """Return sum_k betas[k-1] * Ubar**k (k = 1..len(betas)) without symmetrising.

    The accumulator takes Ubar's dtype (``np.zeros_like``), so e.g. a float32
    input yields a float32 result.
    """
    acc = np.zeros_like(Ubar)
    hop = np.eye(Ubar.shape[0])
    for beta_k in betas:
        hop = hop @ Ubar  # Ubar raised to the current power
        acc += beta_k * hop
    return acc

def row_topk(U, k=8):
    """Keep the k largest entries in each row of U, zero the rest, renormalise rows.

    Rows whose kept entries sum to 0 stay all zeros. ``k <= 0`` zeroes every
    entry, and ``k >= U.shape[1]`` keeps every entry. Ties at the cut-off follow
    ``np.argsort`` order. The input array is not modified.
    """
    U2 = U.copy()
    n_cols = U2.shape[1]
    # Drop the (n_cols - k) smallest entries per row. An explicit stop index is
    # used because `[:, :-k]` becomes `[:, :0]` for k == 0 and would keep everything.
    n_drop = min(max(n_cols - k, 0), n_cols)
    drop_idx = np.argsort(U2, axis=1)[:, :n_drop]
    U2[np.arange(U2.shape[0])[:, None], drop_idx] = 0.0
    # Renormalise rows, leaving all-zero rows untouched.
    rs = U2.sum(axis=1, keepdims=True)
    rs[rs == 0] = 1
    return U2 / rs

def row_normalize(M):
    """Clip M at zero and scale each row to sum to 1; all-zero rows stay zero."""
    nonneg = np.maximum(M, 0.0)
    totals = nonneg.sum(axis=1, keepdims=True)
    # Divide only where the row total is non-zero; other rows keep the zeros from `out`.
    return np.divide(nonneg, totals, out=np.zeros_like(nonneg), where=totals != 0)

def center_and_spectral_norm(M):
    """Subtract a rank-1 'stationary' matrix from M and scale to unit spectral norm.

    J = 1 · piᵀ, where pi is the absolute value of the eigenvector with the
    largest eigenvalue of the symmetric part 0.5 * (M + M.T), normalised to sum
    to 1. The result (M - J) is divided by its largest singular value, or left
    unscaled when that value is 0.
    """
    _, eigvecs = np.linalg.eigh(0.5 * (M + M.T))
    pi = np.abs(eigvecs[:, -1])
    pi = pi / pi.sum()
    centered = M - np.ones((M.shape[0], 1)) @ pi[None, :]
    top_sv = np.linalg.svd(centered, compute_uv=False)[0]
    return centered / (top_sv if top_sv > 0 else 1.0)

def min_max_normalize(M):
    """Linearly rescale M so its minimum maps to 0 and its maximum to 1.

    A constant input (max == min) has no range to scale by. It is mapped to all
    zeros instead of producing 0/0 = NaN. Float inputs keep their dtype, and
    integer inputs produce float64.
    """
    m_min = np.min(M)
    m_range = np.max(M) - m_min
    if m_range == 0:
        return np.zeros_like(M, dtype=np.result_type(M, 1.0))
    return (M - m_min) / m_range


In [105]:
wave = model.attn_layers_s[0].locals.wavelet_layer
a1, a2, cheb = get_coe_cheb(wave)
# A single hop (K=1) with starting ratio r0=0.5.
betas = compute_betas(K=1, a1=a1, a2=a2, cheb=cheb, r0=0.5)

Ubar = average_U_instance(model, trainset_loader, steps=8)
# NOTE(review): with K=1, poly_operator_unsym returns betas[0] * Ubar, and min-max
# normalisation cancels a positive scale factor. So whenever betas[0] > 0,
# m_indices is simply the min-max-normalised Ubar.
m_indices = min_max_normalize(poly_operator_unsym(Ubar, betas))

In [108]:
# Number of entries strictly above 0.5 (counts the mask directly instead of copying the selected values).
int(np.count_nonzero(m_indices > 0.5))

1468928

In [109]:
# Number of entries strictly below 0.5 (counts the mask directly instead of copying the selected values).
int(np.count_nonzero(m_indices < 0.5))

16

In [111]:
@torch.no_grad()
def debug_U_stats(model, loader, steps=1):
    """Print summary statistics of the attention logits of the first wavelet layer.

    Temporarily replaces ``model.attn_layers_s[0].locals.wavelet_layer.attention``
    with a re-implementation that records two values per call:
      * the mean over rows of the standard deviation of the logits ``e`` taken
        over each row's neighbours (entries where ``A > 0``);
      * the fraction of entries where ``A > 0`` (mask density).
    It then runs the model on up to ``steps`` batches (taking ``batch[0]`` as
    input) and prints both averages.

    NOTE(review): the replacement returns ``(U, 0.5 * U, A_BTNN)`` instead of
    calling the real ``attention``, so the rest of the forward pass sees that
    substitute. TODO confirm it matches the layer's own attention.

    Raises:
        RuntimeError: if no statistics were collected.
    """
    import itertools

    device = next(model.parameters()).device
    model.eval()  # e.g. disable dropout, matching average_U_instance
    wave = model.attn_layers_s[0].locals.wavelet_layer  # noflayer instance
    # Remember whether `attention` was an instance attribute so the patch can be undone exactly.
    had_instance_attr = "attention" in vars(wave)
    orig_attn = wave.attention
    stats = []

    def tapped_attn(self, x, A):
        # Recompute the attention logits (intended as a copy of the layer's attention).
        a1 = self.a[:self.in_features, :]
        a2 = self.a[self.in_features:, :]
        feat_1 = torch.matmul(x, a1)
        feat_2 = torch.matmul(x, a2)
        e = self.leakyrelu(feat_1 + feat_2.transpose(-2, -1))  # (B,T,N,N)

        A_BTNN = A.unsqueeze(0).expand(x.size(0), -1, -1, -1)
        mask = A_BTNN > 0

        # Std of the logits over each row's neighbours only. Non-neighbours are
        # excluded from both the mean and the variance; zero-filling them would
        # bias the std whenever the mask is sparse. Rows without neighbours give 0.
        ef = e.float()
        maskf = mask.float()
        count = maskf.sum(dim=-1, keepdim=True).clamp(min=1.0)
        mean = (ef * maskf).sum(dim=-1, keepdim=True) / count
        var = (((ef - mean) ** 2) * maskf).sum(dim=-1, keepdim=True) / count
        stats.append((var.sqrt().mean().item(), maskf.mean().item()))

        neg_inf = torch.finfo(e.dtype).min
        e_masked = torch.where(mask, e, e.new_full((), neg_inf))
        U = torch.softmax(e_masked, dim=-1)
        return U, 0.5 * U, A_BTNN

    wave.attention = types.MethodType(tapped_attn, wave)
    try:
        # islice stops cleanly if the loader has fewer than `steps` batches.
        for batch in itertools.islice(loader, max(steps, 0)):
            model(batch[0].to(device))
    finally:
        if had_instance_attr:
            wave.attention = orig_attn
        else:
            del wave.attention

    if not stats:
        raise RuntimeError("debug_U_stats: no statistics collected; "
                           "check that the forward calls wavelet_layer.attention.")
    row_std_mean = float(np.mean([s for s, _ in stats]))
    mask_density_mean = float(np.mean([d for _, d in stats]))
    print(f"[U debug] mean row-logit std: {row_std_mean:.6f}  |  mask density: {mask_density_mean:.3f}")

In [112]:
# Prints the mean per-row logit std and the mask density for one training batch.
debug_U_stats(model, trainset_loader, steps=1)

[U debug] mean row-logit std: 0.000000  |  mask density: 1.000
