In [3]:
# =============================================================
# MS-GWN-A ONLY EVALUATION (NO TRAINING)
# =============================================================

import numpy as np
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
from sklearn.metrics import r2_score

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# =============================================================
# LOAD DATA
# =============================================================

df = pd.read_csv("pems_bay_final_with_extra_features.csv", index_col="timestamp", parse_dates=True)

sensor_cols = [c for c in df.columns if c.isdigit()]
time_cols = [c for c in df.columns if c not in sensor_cols]

traffic = df[sensor_cols].to_numpy(dtype=np.float32)
time_feat = df[time_cols].to_numpy(dtype=np.float32)

mean = traffic.mean(axis=0, keepdims=True)
std  = traffic.std(axis=0, keepdims=True)
std[std == 0] = 1.0
traffic = (traffic - mean) / std

T, N = traffic.shape
F_time = time_feat.shape[1]

traffic = traffic[..., None]
time_feat_expanded = np.broadcast_to(time_feat[:, None, :], (T, N, F_time))
data = np.concatenate([traffic, time_feat_expanded], axis=2).astype(np.float32)

# =============================================================
# DATASET
# =============================================================

SEQ_LEN = 24
PRED_LEN = 3

class TrafficDataset(Dataset):
    def __init__(self, data):
        self.data = data.astype(np.float32)

    def __len__(self):
        return len(self.data) - SEQ_LEN - PRED_LEN

    def __getitem__(self, idx):
        x = self.data[idx : idx+SEQ_LEN]
        y = self.data[idx+SEQ_LEN : idx+SEQ_LEN+PRED_LEN, :, 0]
        x = torch.from_numpy(x).permute(2,1,0)
        y = torch.from_numpy(y)
        return x, y

split = int(len(data) * 0.8)
test_loader = DataLoader(TrafficDataset(data[split:]), batch_size=32, shuffle=False)

# =============================================================
# ADJACENCY
# =============================================================

with open("adj_mx_PEMS-BAY.pkl", "rb") as f:
    adj_data = pickle.load(f, encoding="latin1")

A = adj_data[2].astype(np.float32)
A = A + np.eye(A.shape[0], dtype=np.float32)
D = np.sum(A, axis=1)
D_inv_sqrt = np.diag(1.0 / np.sqrt(D + 1e-8))
A_norm = D_inv_sqrt @ A @ D_inv_sqrt
adj_mx = torch.tensor(A_norm, dtype=torch.float32).to(device)

# =============================================================
# MODEL (EXACT SAME)
# =============================================================

class NodeAttention(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)
        self.scale = math.sqrt(channels)

    def forward(self, x):
        B,C,N,T = x.shape
        x_pool = x.mean(dim=-1).permute(0,2,1)
        Q,K,V = self.query(x_pool), self.key(x_pool), self.value(x_pool)
        attn = torch.bmm(Q,K.transpose(1,2))/self.scale
        attn = F.softmax(attn, dim=-1)
        out = torch.bmm(attn,V).permute(0,2,1).unsqueeze(-1).expand(B,C,N,T)
        return out

class AdaptiveAdjacency(nn.Module):
    def __init__(self, num_nodes, adj_fixed):
        super().__init__()
        self.register_buffer('adj_fixed', adj_fixed)
        self.adj_learned = nn.Parameter(torch.randn(num_nodes, num_nodes) * 0.01)
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self):
        adj_l = F.relu(self.adj_learned)
        adj_l = adj_l/(adj_l.sum(dim=1,keepdim=True)+1e-8)
        alpha = torch.sigmoid(self.alpha)
        return alpha*self.adj_fixed+(1-alpha)*adj_l

class MultiScaleTemporalBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv_1 = nn.Conv2d(channels, channels, (1,3), padding=(0,1))
        self.conv_2 = nn.Conv2d(channels, channels, (1,3), padding=(0,2), dilation=(1,2))
        self.conv_4 = nn.Conv2d(channels, channels, (1,3), padding=(0,4), dilation=(1,4))
        self.fusion = nn.Conv2d(channels*3, channels, 1)

    def forward(self,x):
        x1=F.relu(self.conv_1(x))
        x2=F.relu(self.conv_2(x))
        x3=F.relu(self.conv_4(x))
        return self.fusion(torch.cat([x1,x2,x3],dim=1))

class GraphConvolution(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self,x,adj):
        B,C,N,T=x.shape
        x=x.permute(0,3,2,1).reshape(B*T,N,C)
        x=torch.bmm(adj.unsqueeze(0).expand(B*T,N,N),x)
        x=self.lin(x).reshape(B,T,N,-1).permute(0,3,2,1)
        return x

class TemporalAttention(nn.Module):
    def __init__(self,pred_len):
        super().__init__()
        self.attn_weights = nn.Parameter(torch.ones(pred_len)/pred_len)

    def forward(self,x):
        w=F.softmax(self.attn_weights,dim=0).view(1,-1,1)
        return x*w

class MS_GWN_A(nn.Module):
    def __init__(self,num_nodes,in_dim,out_dim,adj_fixed):
        super().__init__()
        channels=48
        self.adaptive_adj=AdaptiveAdjacency(num_nodes,adj_fixed)
        self.input_proj=nn.Conv2d(in_dim,channels,1)
        self.temporal_blocks=nn.ModuleList([MultiScaleTemporalBlock(channels) for _ in range(3)])
        self.graph_convs=nn.ModuleList([GraphConvolution(channels,channels) for _ in range(3)])
        self.node_attentions=nn.ModuleList([NodeAttention(channels) for _ in range(3)])
        self.skip_convs=nn.ModuleList([nn.Conv2d(channels,channels,1) for _ in range(3)])
        self.temporal_pool=nn.AdaptiveAvgPool2d((num_nodes,1))
        self.output_proj=nn.Sequential(nn.Linear(channels,128),nn.ReLU(),nn.Dropout(0.1),nn.Linear(128,out_dim))
        self.temporal_attn=TemporalAttention(out_dim)

    def forward(self,x):
        adj=self.adaptive_adj()
        x=self.input_proj(x)
        skips=[]
        for t,g,a,s in zip(self.temporal_blocks,self.graph_convs,self.node_attentions,self.skip_convs):
            res=x
            x=t(x)
            x=F.relu(g(x,adj))
            x=x+a(x)
            x=x+res
            skips.append(s(x))
        x=torch.stack(skips).sum(0)
        x=self.temporal_pool(x).squeeze(-1).permute(0,2,1)
        out=self.output_proj(x).permute(0,2,1)
        return self.temporal_attn(out)

# =============================================================
# LOAD MODEL
# =============================================================

model = MS_GWN_A(num_nodes=N,in_dim=data.shape[2],out_dim=PRED_LEN,adj_fixed=adj_mx).to(device)
#model.load_state_dict(torch.load("ms_gwn_a_best.pth",map_location=device))
state_dict = torch.load("ms_gwn_a_best.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)

model.eval()
print("Model Loaded Successfully ✓")

# =============================================================
# EVALUATE
# =============================================================
from tqdm import tqdm

mae = 0
mse = 0
mape = 0
count = 0

preds_list = []
trues_list = []

with torch.no_grad():
    for x, y in tqdm(test_loader, desc="Evaluating", ncols=100):

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        pred = model(x)

        # regression metrics
        mae += torch.abs(pred - y).sum().item()
        mse += ((pred - y) ** 2).sum().item()

        mask = y != 0
        mape += (torch.abs((pred - y) / (y + 1e-8))[mask]).sum().item()
        count += y.numel()

        preds_list.append(pred.cpu())
        trues_list.append(y.cpu())

# stack once at end (memory safe)
preds = torch.cat(preds_list).numpy()
trues = torch.cat(trues_list).numpy()

# metrics
mae /= count
rmse = (mse / count) ** 0.5
mape = (mape / count) * 100

real_mae = mae * std.mean()
real_rmse = rmse * std.mean()

print("\n===== FINAL RESULTS =====")
print("Normalized MAE :", mae)
print("Normalized RMSE:", rmse)
print("MAPE:", mape, "%")
print("Real MAE :", real_mae)
print("Real RMSE:", real_rmse)

# R2
from sklearn.metrics import r2_score
r2 = r2_score(trues.reshape(-1), preds.reshape(-1))
print("R2 Score:", r2)
print("=========================")


Device: cuda
Model Loaded Successfully ✓


Evaluating: 100%|█████████████████████████████████████████████████| 325/325 [00:43<00:00,  7.51it/s]



===== FINAL RESULTS =====
Normalized MAE : 0.14823707404772518
Normalized RMSE: 0.30634279095634
MAPE: 118.81603055029434 %
Real MAE : 1.2684547704842077
Real RMSE: 2.621354860707189
R2 Score: 0.9076744318008423


In [4]:
# Save normalization + check window
np.save("train_mean.npy", mean)
np.save("train_std.npy", std)
print("Saved ✓")

# Check sensor 400863 at window 9012
idx = sensor_cols.index('400863')
window = 9012

true_norm = data[window, idx, 0]
true_real = float(true_norm * std[0, idx] + mean[0, idx])
print(f"True value at window 9012: {true_real:.2f} mph")

Saved ✓
True value at window 9012: 22.20 mph


In [7]:
# Check what timestamp corresponds to window 9012
idx = sensor_cols.index('400863')
window = 9012

# Get the actual timestamp at window 9012
df_check = pd.read_csv("pems_bay_final_with_extra_features.csv", 
                        index_col="timestamp", parse_dates=True)
df_check = df_check.sort_index()

timestamp = df_check.index[window]
csv_value = float(df_check.loc[timestamp, '400863'])

print(f"Window 9012 → Timestamp: {timestamp}")
print(f"CSV value at that timestamp: {csv_value:.2f} mph")
print(f"Data array value:           {true_real:.2f} mph")
print(f"Match: {abs(csv_value - true_real) < 1.0}")

Window 9012 → Timestamp: 06-02-2017 07:00
CSV value at that timestamp: 55.00 mph
Data array value:           22.20 mph
Match: False


In [9]:
print("Is index monotonic?", df_check.index.is_monotonic_increasing)
print("Duplicate timestamps:", df_check.index.duplicated().sum())
print("Total rows:", len(df_check))

Is index monotonic? True
Duplicate timestamps: 0
Total rows: 52116


In [13]:
print("First 5 timestamps:", df_check.index[:5].tolist())
print("Timestamp dtype:", df_check.index.dtype)
print("Window 9012 timestamp:", df_check.index[9012])

First 5 timestamps: ['01-01-2017 00:00', '01-01-2017 00:05', '01-01-2017 00:10', '01-01-2017 00:15', '01-01-2017 00:20']
Timestamp dtype: object
Window 9012 timestamp: 06-02-2017 07:00


In [15]:
df_check = pd.read_csv("pems_bay_final_with_extra_features.csv", 
                        index_col="timestamp")
df_check.index = pd.to_datetime(df_check.index, format="%d-%m-%Y %H:%M")

print("First 5:", df_check.index[:5].tolist())
print("Window 9012:", df_check.index[9012])
print("CSV value at window 9012:", df_check.iloc[9012]['400863'])
print("Data array value:", true_real)

First 5: [Timestamp('2017-01-01 00:00:00'), Timestamp('2017-01-01 00:05:00'), Timestamp('2017-01-01 00:10:00'), Timestamp('2017-01-01 00:15:00'), Timestamp('2017-01-01 00:20:00')]
Window 9012: 2017-02-01 07:00:00
CSV value at window 9012: 22.2
Data array value: 22.200000762939453
