In [2]:
# k-sensitivity, h=24
# Metrics: MAE, RMSE, viol_rate (val + test)
# Runs: k = 3, 5, 6 (k=4 already done)


import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
DATA_PATH = "dataset_2023_2025.csv"

# Forecast horizon: targets are taken H rows ahead per city and the model is
# rolled forward for H integration steps.
H = 24
# Lags (in rows) of pm2_5 used as input features.
LAGS = [1, 2, 6, 24]

# Chronological split boundaries; each upper bound is inclusive.
TRAIN_END = "2024-12-31 23:00:00"
VAL_END = "2025-06-30 23:00:00"
TEST_END = "2025-11-23 23:00:00"

# Graph construction: Gaussian edge-weight length scale (km) and the
# neighbour counts swept in this cell.
ELL_KM = 120.0
K_LIST = [3, 5, 6]

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Optimisation.
BATCH_TIMES = 128      # timestamps per mini-batch (each carries every city)
EPOCHS = 30
PATIENCE = 4           # epochs without >= 1e-4 val-MAE improvement before stopping
LR = 2e-3
WEIGHT_DECAY = 1e-6

# Implicit integrator: step size and implicit weight of the graph-Laplacian term.
DT = 1.0
GAMMA = 1.0

# full PI-GNN constraints: hinge-penalty weights for predictions above the
# pm10 target and for negative predictions.
LAMBDA_INEQ = 0.30
LAMBDA_NONNEG = 0.05

DROPOUT = 0.20

# Outputs. NOTE: SAVE_PREDICTIONS and OUT_PRED are not read by this cell.
SAVE_PREDICTIONS = False
OUT_MET = "metrics_k_sensitivity_pignn_h24.csv"
OUT_PRED = "predictions_pignn_k_sensitivity_h24.csv"

# Global seeding (numpy first, then torch).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Helpers
def make_lag_features(df, lags):
    """Return a (city_id, datetime)-sorted copy of ``df`` with pm2_5 lag columns.

    For every ``lag`` in ``lags`` a column ``pm2_5_lag{lag}`` is added holding
    the ``pm2_5`` value ``lag`` rows earlier within the same ``city_id``
    (NaN where no such row exists). The input frame is not modified.

    NOTE(review): the shift is positional, not clock-based; it equals
    "``lag`` hours ago" only if every city has a gap-free hourly series.
    """
    out = df.sort_values(["city_id", "datetime"]).copy()
    by_city = out.groupby("city_id")["pm2_5"]
    for lag in lags:
        out[f"pm2_5_lag{lag}"] = by_city.shift(lag)
    return out

def make_targets(df, h):
    """Return a (city_id, datetime)-sorted copy of ``df`` with h-step-ahead targets.

    Adds ``y_h{h}`` (from ``pm2_5``) and then ``pm10_h{h}`` (from ``pm10``),
    each taken ``h`` rows later within the same ``city_id``; a city's last
    ``h`` rows get NaN. The input frame is not modified.

    NOTE(review): positional shift, so this is "h hours ahead" only for a
    gap-free hourly series per city; confirm against the data.
    """
    out = df.sort_values(["city_id", "datetime"]).copy()
    for src, dst in (("pm2_5", f"y_h{h}"), ("pm10", f"pm10_h{h}")):
        out[dst] = out.groupby("city_id")[src].shift(-h)
    return out

def time_split_times(times):
    """Split timestamps chronologically into train / validation / test.

    Uses the module-level ``TRAIN_END``, ``VAL_END`` and ``TEST_END`` bounds:
    train is ``t <= TRAIN_END``, validation is ``TRAIN_END < t <= VAL_END``,
    test is ``VAL_END < t <= TEST_END``. Timestamps after ``TEST_END`` are
    dropped. Each part is a subset of ``pd.to_datetime(times)``.
    """
    stamps = pd.to_datetime(times)
    train_end = pd.Timestamp(TRAIN_END)
    val_end = pd.Timestamp(VAL_END)
    test_end = pd.Timestamp(TEST_END)

    in_train = stamps <= train_end
    in_val = (stamps > train_end) & (stamps <= val_end)
    in_test = (stamps > val_end) & (stamps <= test_end)
    return stamps[in_train], stamps[in_val], stamps[in_test]

def build_knn_graph(city_meta, k=4, ell_km=120.0):
    """Build a symmetric, Gaussian-weighted k-nearest-neighbour graph over cities.

    Neighbours are found by haversine (great-circle) distance on the ``lat`` /
    ``lon`` columns, which are given in degrees and converted to radians here.
    Each city is linked to its ``k`` nearest *other* cities with weight
    ``exp(-(d_km / ell_km) ** 2)``. The adjacency is made symmetric by keeping
    the larger of the (i, j) and (j, i) weights.

    Args:
        city_meta: DataFrame with ``lat`` and ``lon`` columns (degrees); its
            row order defines the node order.
        k: number of neighbours per city (0 yields an edgeless graph).
        ell_km: Gaussian length scale in kilometres.

    Returns:
        (W, L): float32 ``(n, n)`` adjacency matrix and combinatorial
        Laplacian ``L = diag(W.sum(axis=1)) - W``.

    Raises:
        ValueError: if ``k`` is negative or not smaller than the number of cities.
    """
    earth_radius_km = 6371.0
    coords_rad = np.radians(city_meta[["lat", "lon"]].to_numpy())
    n = len(coords_rad)
    if not 0 <= k < n:
        raise ValueError(f"k must satisfy 0 <= k < n_cities ({n}), got {k}")

    # Request one extra neighbour because each query point normally returns itself.
    nnm = NearestNeighbors(n_neighbors=k + 1, metric="haversine").fit(coords_rad)
    dists, idxs = nnm.kneighbors(coords_rad)

    W_dir = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        # Drop the query point by index instead of assuming it sits in column 0:
        # with coincident coordinates the tie order is arbitrary, and relying
        # on column 0 could add a self-loop while discarding a real neighbour.
        not_self = idxs[i] != i
        nbrs = idxs[i][not_self][:k]
        dist_km = dists[i][not_self][:k] * earth_radius_km
        W_dir[i, nbrs] = np.exp(-(dist_km / ell_km) ** 2)

    # Symmetrise: an edge exists if either endpoint picked the other.
    W = np.maximum(W_dir, W_dir.T)
    L = (np.diag(W.sum(axis=1)) - W).astype(np.float32)
    return W, L

def hinge_pos(x):
    """Elementwise positive part ``max(x, 0)`` (ReLU), used for hinge penalties."""
    return torch.relu(x)

def mae_loss(pred, y):
    """Mean absolute error between two tensors, averaged over all elements."""
    return (pred - y).abs().mean()

def rmse_np(y, yhat):
    """Root-mean-square error of two arrays, returned as a Python float."""
    sq_err = np.square(y - yhat)
    return float(np.sqrt(sq_err.mean()))
def mae_np(y, yhat):
    """Mean absolute error of two arrays, returned as a Python float."""
    abs_err = np.abs(y - yhat)
    return float(abs_err.mean())

@torch.no_grad()
def eval_mae_rmse_viol(model, c0, X, y, pm10_tgt, idx, steps, device):
    """Evaluate ``model`` on the time indices ``idx`` without gradient tracking.

    The selected time steps are fed to the model in chunks of ``BATCH_TIMES``.
    Predictions and targets are flattened over every axis before scoring.
    The model is left in eval mode.

    Returns:
        (mae, rmse, viol): MAE and RMSE of the predictions against ``y``, and
        the fraction of predictions strictly greater than ``pm10_tgt``.

    Raises:
        ValueError: if ``idx`` is empty (nothing to concatenate).
    """
    model.eval()
    targets, preds, caps = [], [], []

    for start in range(0, len(idx), BATCH_TIMES):
        chunk = idx[start:start + BATCH_TIMES]
        out = model(
            torch.from_numpy(c0[chunk]).to(device),
            torch.from_numpy(X[chunk]).to(device),
            steps,
        )
        preds.append(out.detach().cpu().numpy())
        targets.append(y[chunk])
        caps.append(pm10_tgt[chunk])

    def _flatten(parts):
        # Stack the per-chunk arrays along time, then collapse to 1-D.
        return np.concatenate(parts, axis=0).reshape(-1)

    y_true = _flatten(targets)
    y_pred = _flatten(preds)
    pm10 = _flatten(caps)
    return mae_np(y_true, y_pred), rmse_np(y_true, y_pred), float(np.mean(y_pred > pm10))

# -----------------------------
# Load data once
# -----------------------------
df = pd.read_csv(DATA_PATH)
df["datetime"] = pd.to_datetime(df["datetime"])

# Columns used downstream: identifiers/coordinates, pollutants, calendar encodings.
keep_cols = [
    "city_id", "city_name", "lat", "lon", "datetime",
    "pm2_5", "pm10",
    "carbon_monoxide", "nitrogen_dioxide", "sulphur_dioxide", "ozone",
    "doy_sin", "doy_cos", "hour_sin", "hour_cos",
]
df = df[keep_cols].copy()

df = make_lag_features(df, LAGS)
df = make_targets(df, H)

# Drop rows missing any lag feature or horizon-H target
# (e.g. at the start/end of each city's series).
needed = [f"pm2_5_lag{lag}" for lag in LAGS] + [f"y_h{H}", f"pm10_h{H}"]
df = df.dropna(subset=needed).copy()

# city order fixed: nodes are indexed 0..N-1 in ascending city_id order.
city_meta = (
    df[["city_id", "city_name", "lat", "lon"]]
    .drop_duplicates()
    .sort_values("city_id")
    .reset_index(drop=True)
)
# Guard: a city_id with more than one (city_name, lat, lon) combination would
# be counted several times in N while cid_to_idx keeps a single index for it,
# so no timestamp could ever show N distinct cities. Fail loudly instead.
_dup_city_mask = city_meta["city_id"].duplicated(keep=False)
if _dup_city_mask.any():
    raise ValueError(
        "city_id values with inconsistent city_name/lat/lon: "
        f"{city_meta.loc[_dup_city_mask, 'city_id'].unique().tolist()}"
    )

city_ids = city_meta["city_id"].to_numpy()
cid_to_idx = {cid: i for i, cid in enumerate(city_ids)}
N = len(city_ids)

# Per-node input features for the source network.
X_COLS = [
    "carbon_monoxide", "nitrogen_dioxide", "sulphur_dioxide", "ozone",
    "doy_sin", "doy_cos", "hour_sin", "hour_cos",
] + [f"pm2_5_lag{lag}" for lag in LAGS] + ["lat", "lon"]

def build_time_tensor(df, h):
    """Pivot the long (datetime, city) table into dense per-timestamp tensors.

    Only timestamps at which all ``N`` cities are present are kept. Within a
    timestamp, cities are ordered by their index in ``cid_to_idx``.

    Args:
        df: long-format frame with ``datetime``, ``city_id``, ``pm2_5``,
            ``y_h{h}``, ``pm10_h{h}`` and every column in ``X_COLS``.
        h: horizon used to select the ``y_h{h}`` / ``pm10_h{h}`` columns.

    Returns:
        times: sorted timestamps kept (length T).
        X: float32 ``(T, N, len(X_COLS))`` features.
        c0: float32 ``(T, N, 1)`` current ``pm2_5``.
        y: float32 ``(T, N, 1)`` values of ``y_h{h}``.
        pm10_tgt: float32 ``(T, N, 1)`` values of ``pm10_h{h}``.

    Raises:
        ValueError: if a kept timestamp has more than one row for some city.
    """
    tgt_col, pm10_col = f"y_h{h}", f"pm10_h{h}"
    df2 = df[["datetime", "city_id", "pm2_5", tgt_col, pm10_col] + X_COLS].copy()
    df2["cid_idx"] = df2["city_id"].map(cid_to_idx)

    # Keep only timestamps at which every city is present.
    counts = df2.groupby("datetime")["cid_idx"].nunique()
    full_times = counts[counts == N].index
    df2 = df2[df2["datetime"].isin(full_times)]

    times = pd.to_datetime(np.sort(df2["datetime"].unique()))
    T = len(times)
    if len(df2) != T * N:
        raise ValueError(
            f"expected exactly {N} rows per kept timestamp ({T * N} total), "
            f"got {len(df2)}; duplicate (datetime, city) rows?"
        )

    # With exactly one row per (timestamp, city), sorting by (datetime, cid_idx)
    # yields T contiguous blocks of N cities in index order, so a single
    # reshape replaces a per-timestamp groupby loop.
    df2 = df2.sort_values(["datetime", "cid_idx"])

    def _stack(cols):
        # (T*N, len(cols)) rows -> (T, N, len(cols)) float32 tensor.
        return df2[cols].to_numpy(np.float32).reshape(T, N, len(cols))

    X = _stack(X_COLS)
    c0 = _stack(["pm2_5"])
    y = _stack([tgt_col])
    pm10_tgt = _stack([pm10_col])
    return times, X, c0, y, pm10_tgt

# Dense per-time tensors for horizon H and the chronological train/val/test split.
times, X, c0, y, pm10_tgt = build_time_tensor(df, H)
tr_times, va_times, te_times = time_split_times(times)

# Row position of every timestamp, used to turn each split into integer indices.
time_to_i = dict(zip(times, range(len(times))))
tr_idx, va_idx, te_idx = (
    np.array([time_to_i[stamp] for stamp in split_times], dtype=int)
    for split_times in (tr_times, va_times, te_times)
)

# Standardise features with statistics from the training period only, then
# apply the same transform to every time step and city.
scaler = StandardScaler()
scaler.fit(X[tr_idx].reshape(-1, X.shape[-1]))
X = (
    scaler.transform(X.reshape(-1, X.shape[-1]))
    .astype(np.float32)
    .reshape(X.shape)
)

# model
class MLP(nn.Module):
    """Two-hidden-layer perceptron applied over the last axis of its input.

    Layout: Linear(in_dim -> hidden) -> ReLU -> Dropout ->
    Linear(hidden -> hidden) -> ReLU -> Dropout -> Linear(hidden -> out_dim),
    stored as ``self.net`` (an ``nn.Sequential``).
    """

    def __init__(self, in_dim, hidden=64, dropout=0.2, out_dim=1):
        super().__init__()
        layers = []
        fan_in = in_dim
        for _ in range(2):
            layers.extend([nn.Linear(fan_in, hidden), nn.ReLU(), nn.Dropout(dropout)])
            fan_in = hidden
        layers.append(nn.Linear(hidden, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class PIGNN(nn.Module):
    """Graph model that rolls a learned source term through implicit graph steps.

    A per-node source ``s`` is predicted once from the features ``X`` plus a
    learned per-node embedding and is held fixed over all steps. The state
    ``c`` is then advanced ``steps`` times by solving

        (I + gamma*dt*D*L + dt*k*I) c_next = c + dt * ((1 - gamma) * D * (-L c) + s)

    i.e. a theta-scheme step for ``dc/dt = -D L c - k c + s``. The Laplacian
    term is weighted implicitly by ``gamma``, and the ``k`` (decay) term is
    fully implicit. ``D`` and ``k`` are learned positive scalars.

    Args:
        x_dim: number of input features per node.
        L: ``(n, n)`` graph Laplacian tensor; its size fixes the node count.
        dt: integration step size.
        gamma: implicit weight of the Laplacian term (1.0 = fully implicit).
        emb_dim: size of the learned per-node embedding.
        hidden, dropout: width and dropout of the source MLP.
    """

    def __init__(self, x_dim, L, dt=1.0, gamma=1.0, emb_dim=8, hidden=64, dropout=0.2):
        super().__init__()
        # Non-persistent buffer: moves/casts with the module under .to()/.double(),
        # but stays out of state_dict, so checkpoint keys are unchanged.
        self.register_buffer("L", L, persistent=False)
        self.dt = dt
        self.gamma = gamma
        # Node count comes from L rather than from a module-level global.
        self.emb = nn.Embedding(L.shape[0], emb_dim)
        self.source = MLP(x_dim + emb_dim, hidden=hidden, dropout=dropout, out_dim=1)
        # Unconstrained raw parameters; D() and k() map them through softplus.
        self.D_raw = nn.Parameter(torch.tensor(0.0))
        self.k_raw = nn.Parameter(torch.tensor(-1.0))

    def D(self):
        """Positive Laplacian coefficient: ``softplus(D_raw) + 1e-6``."""
        return F.softplus(self.D_raw) + 1e-6

    def k(self):
        """Positive decay rate: ``softplus(k_raw) + 1e-6``."""
        return F.softplus(self.k_raw) + 1e-6

    def forward(self, c0, X, steps):
        """Advance ``c0`` by ``steps`` implicit steps and return the final state.

        Args:
            c0: initial state, ``(B, n, 1)`` in this script.
            X: node features, ``(B, n, x_dim)``.
            steps: number of integration steps; 0 returns ``c0`` unchanged.
        """
        B, n_nodes, _ = X.shape
        node_idx = torch.arange(n_nodes, device=X.device)
        emb = self.emb(node_idx).unsqueeze(0).expand(B, -1, -1)
        # Source term is computed once and reused at every step.
        s = self.source(torch.cat([X, emb], dim=-1))

        c = c0
        D = self.D()
        k = self.k()
        I = torch.eye(self.L.shape[0], device=c.device, dtype=c.dtype)
        # System matrix is identical for every step.
        A = I + self.gamma * self.dt * D * self.L + self.dt * k * I

        for _ in range(steps):
            Lc = torch.einsum("ij,bjk->bik", self.L, c)
            rhs = c + self.dt * ((1 - self.gamma) * D * (-Lc) + s)
            c = torch.linalg.solve(A.expand(rhs.shape[0], -1, -1), rhs)

        return c

# Train loop for each k
results = []
# NOTE: all_pred is never filled in this cell (SAVE_PREDICTIONS / OUT_PRED are
# not read here); it is kept so later cells that reference it still find it.
all_pred = []

for k_graph in K_LIST:
    # Re-seed at the start of every run. Otherwise each k inherits the RNG
    # state left by the previous k, so weight init, dropout masks and batch
    # order differ between runs and confound the k-sensitivity comparison.
    # The first k in K_LIST reproduces the seeding it had before.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # rebuild graph Laplacian for this k
    Wk, Lk_np = build_knn_graph(city_meta, k=k_graph, ell_km=ELL_KM)
    Lk = torch.tensor(Lk_np, device=DEVICE)

    print("\n" + "=" * 60)
    print(f"PI-GNN | h={H} | k={k_graph}")
    print("=" * 60)

    model = PIGNN(x_dim=X.shape[-1], L=Lk, dt=DT, gamma=GAMMA, hidden=64, dropout=DROPOUT).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val = float("inf")
    best_state = None
    bad = 0
    # Shuffle a private copy: the global tr_idx stays chronological, and every
    # run starts its permutation sequence from the same order.
    epoch_order = tr_idx.copy()

    for ep in range(1, EPOCHS + 1):
        model.train()
        np.random.shuffle(epoch_order)
        tr_losses = []

        for b0 in range(0, len(epoch_order), BATCH_TIMES):
            ii = epoch_order[b0:b0 + BATCH_TIMES]
            c0b = torch.from_numpy(c0[ii]).to(DEVICE)
            Xb = torch.from_numpy(X[ii]).to(DEVICE)
            yb = torch.from_numpy(y[ii]).to(DEVICE)
            pm10b = torch.from_numpy(pm10_tgt[ii]).to(DEVICE)

            pred = model(c0b, Xb, H)

            # MAE data term + hinge penalties for pred > pm10 target and pred < 0.
            loss = mae_loss(pred, yb)
            loss = loss + LAMBDA_INEQ * torch.mean(hinge_pos(pred - pm10b))
            loss = loss + LAMBDA_NONNEG * torch.mean(hinge_pos(-pred))

            opt.zero_grad()
            loss.backward()
            opt.step()
            tr_losses.append(loss.item())

        val_mae, val_rmse, val_viol = eval_mae_rmse_viol(model, c0, X, y, pm10_tgt, va_idx, H, DEVICE)
        print(f"[k={k_graph}] ep {ep:02d} | train {np.mean(tr_losses):.4f} | "
              f"val_MAE {val_mae:.4f} | val_RMSE {val_rmse:.4f} | val_viol {100*val_viol:.2f}% | "
              f"D {model.D().item():.4f} k {model.k().item():.4f}")

        # Early stopping on validation MAE with a 1e-4 minimum improvement.
        if val_mae < best_val - 1e-4:
            best_val = val_mae
            best_state = {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= PATIENCE:
                print(f"[k={k_graph}] Early stop. Best val_MAE={best_val:.4f}")
                break

    # Restore the best-validation weights before final scoring.
    if best_state is not None:
        model.load_state_dict(best_state)

    # final val/test
    val_mae, val_rmse, val_viol = eval_mae_rmse_viol(model, c0, X, y, pm10_tgt, va_idx, H, DEVICE)
    te_mae, te_rmse, te_viol = eval_mae_rmse_viol(model, c0, X, y, pm10_tgt, te_idx, H, DEVICE)

    results.append({
        "model": "PI-GNN",
        "k_graph": k_graph,
        "horizon_h": H,
        "val_MAE": val_mae, "val_RMSE": val_rmse, "val_viol_rate": val_viol,
        "test_MAE": te_mae, "test_RMSE": te_rmse, "test_viol_rate": te_viol,
        "D_eff": float(model.D().item()),
        "k_eff": float(model.k().item()),
    })

res_df = pd.DataFrame(results)
res_df.to_csv(OUT_MET, index=False)
print("\nSaved:", OUT_MET)
print(res_df.sort_values("test_MAE"))


PI-GNN | h=24 | k=3
[k=3] ep 01 | train 17.2587 | val_MAE 15.1624 | val_RMSE 22.4972 | val_viol 35.49% | D 0.6475 k 0.2948
[k=3] ep 02 | train 11.9904 | val_MAE 14.8532 | val_RMSE 21.9255 | val_viol 36.84% | D 0.4990 k 0.2972
[k=3] ep 03 | train 11.6149 | val_MAE 14.5831 | val_RMSE 21.3843 | val_viol 38.75% | D 0.4038 k 0.2989
[k=3] ep 04 | train 11.4129 | val_MAE 14.6269 | val_RMSE 21.6229 | val_viol 34.53% | D 0.3415 k 0.3009
[k=3] ep 05 | train 11.2432 | val_MAE 14.7396 | val_RMSE 21.9484 | val_viol 32.10% | D 0.3016 k 0.3018
[k=3] ep 06 | train 11.1488 | val_MAE 14.4789 | val_RMSE 21.3944 | val_viol 35.21% | D 0.2797 k 0.3018
[k=3] ep 07 | train 11.1038 | val_MAE 14.2778 | val_RMSE 20.8339 | val_viol 38.57% | D 0.2677 k 0.3013
[k=3] ep 08 | train 11.0257 | val_MAE 14.2522 | val_RMSE 20.9564 | val_viol 35.77% | D 0.2634 k 0.3012
[k=3] ep 09 | train 10.9645 | val_MAE 14.4333 | val_RMSE 21.3673 | val_viol 34.31% | D 0.2642 k 0.3008
[k=3] ep 10 | train 10.9119 | val_MAE 14.2533 | val_