In [13]:
# GS-ADR for PM2.5 forecasting
#  1) Implicit macro-step dynamics:
#       (I + hdt*(D*L + kI)) c_{t+h} = c_t + hdt*s_theta(...)
#  2) Target-time cyclical features (hour/doy at t+h)
#  3) Graph-aware source term: uses neighbor-aggregated features and neighbor PM2.5
#  4) City embeddings
#  5) Train-only scaling, strict time split, saves predictions in baseline schema

import os
import numpy as np
import pandas as pd

from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F


# Input / output files.
DATA_PATH = "dataset_2023_2025.csv"  # input table, read with pd.read_csv
BASELINE_PRED_PATH = "predictions_baselines.csv"  # optional; only merged if the file exists
OUT_PRED_PATH = "predictions_pignnpp.csv"  # val/test predictions of this model
OUT_ALL_PATH = "predictions_all_with_pignnpp.csv"  # baseline rows + this model's rows

# Spatial graph over cities.
K_GRAPH = 4  # nearest neighbours requested per city (edges are symmetrised afterwards)
ELL_KM = 120.0  # length scale of the edge-weight kernel exp(-(dist_km / ELL_KM)**2)

# Horizons and lags are applied as per-city row shifts; build_time_tensor treats
# a horizon as that many hours when computing target-time calendar features.
HORIZONS = [1, 6, 24]
PM_LAGS = [1, 2, 6, 24]
CHEMS = ['carbon_monoxide', 'nitrogen_dioxide', 'sulphur_dioxide', 'ozone']
CHEM_LAGS = [1, 6, 24]

# Inclusive upper bounds of the chronological train / validation / test windows.
TRAIN_END = "2024-12-31 23:00:00"
VAL_END   = "2025-06-30 23:00:00"
TEST_END  = "2025-11-23 23:00:00"

# training
SEED = 42
BATCH_TIMES = 128  # timestamps per mini-batch; each timestamp carries all cities
EPOCHS = 30
PATIENCE = 5  # epochs without validation improvement before early stopping
LR = 2e-3
WEIGHT_DECAY = 1e-6

# physics + constraints
DT = 1.0  # multiplied by the horizon to get the implicit macro-step size
LAMBDA_INEQ = 1.0  # weight of the hinge penalty on predictions above target-time PM10
LAMBDA_NONNEG = 0.05  # weight of the hinge penalty on negative predictions

# model capacity
DROPOUT = 0.20
EMB_DIM = 8  # size of the learned per-city embedding
HIDDEN = 96  # hidden width of the source-term MLP

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed NumPy (used for batch shuffling) and PyTorch (weight init, dropout).
np.random.seed(SEED)
torch.manual_seed(SEED)

# -----------------------------
# Helpers
# -----------------------------
def time_split_times(times):
    """Partition timestamps chronologically into train / val / test.

    Boundaries come from the module constants TRAIN_END, VAL_END and TEST_END
    and are inclusive upper bounds: train is ``<= TRAIN_END``, val is
    ``(TRAIN_END, VAL_END]`` and test is ``(VAL_END, TEST_END]``. Timestamps
    after TEST_END are dropped.

    Returns the three subsets of ``pd.to_datetime(times)`` in that order.
    """
    stamps = pd.to_datetime(times)
    train_end, val_end, test_end = (
        pd.Timestamp(bound) for bound in (TRAIN_END, VAL_END, TEST_END)
    )

    in_train = stamps <= train_end
    in_val = (stamps > train_end) & (stamps <= val_end)
    in_test = (stamps > val_end) & (stamps <= test_end)

    return stamps[in_train], stamps[in_val], stamps[in_test]

def build_knn_graph(city_meta, k=4, ell_km=120.0):
    """Build a symmetric, Gaussian-weighted k-nearest-neighbour graph of cities.

    Parameters
    ----------
    city_meta : DataFrame with 'lat' and 'lon' columns, one row per city.
        The values are passed through ``np.radians``, so they are expected in
        degrees.
    k : number of nearest neighbours requested per city. Clamped to
        ``len(city_meta) - 1`` so small inputs do not make the neighbour
        search fail.
    ell_km : length scale (km) of the kernel ``exp(-(dist_km / ell_km) ** 2)``.

    Returns
    -------
    (W, L) : float32 arrays of shape (n, n). ``W`` is the symmetric weighted
        adjacency (an edge exists if either endpoint lists the other among its
        neighbours) and ``L = diag(W.sum(axis=1)) - W`` is the combinatorial
        graph Laplacian.
    """
    earth_radius_km = 6371.0

    n = len(city_meta)
    W = np.zeros((n, n), dtype=np.float32)

    k_eff = min(k, n - 1)
    if k_eff >= 1:
        coords_rad = np.radians(city_meta[['lat', 'lon']].to_numpy())

        nnm = NearestNeighbors(n_neighbors=k_eff, metric='haversine')
        nnm.fit(coords_rad)
        # Querying without an argument returns neighbours of the fitted points
        # with each point excluded from its own neighbour list. That avoids
        # relying on "column 0 is always the point itself", which breaks when
        # two cities share coordinates.
        dists, idxs = nnm.kneighbors()

        for i in range(n):
            for j, dist_rad in zip(idxs[i], dists[i]):
                dist_km = dist_rad * earth_radius_km
                w = np.exp(-(dist_km / ell_km) ** 2)
                # Symmetrise: keep the larger weight in both directions.
                if w > W[i, j]:
                    W[i, j] = w
                if w > W[j, i]:
                    W[j, i] = w

    D = np.diag(W.sum(axis=1))
    L = (D - W).astype(np.float32)
    return W, L

def hinge_pos(x):
    """Return the positive part of ``x`` element-wise, i.e. ``max(x, 0)``."""
    return torch.relu(x)

def mae_loss(pred, y):
    """Mean absolute error between ``pred`` and ``y`` over all elements."""
    abs_err = (pred - y).abs()
    return abs_err.mean()

# Load the raw table and parse timestamps.
df = pd.read_csv(DATA_PATH)
df['datetime'] = pd.to_datetime(df['datetime'])

# Keep only the identifier, location, time, particulate and chemistry columns.
keep_cols = [
    'city_id','city_name','lat','lon','datetime',
    'pm2_5','pm10',
] + CHEMS
df = df[keep_cols].copy()
# Sort per city in time so that the row shifts below move along each city's series.
df = df.sort_values(['city_id','datetime']).reset_index(drop=True)

# Build lags.
# NOTE(review): shift() is row-based, so "lag k" means k rows earlier within the
# city. That equals k hours only if every city's series is hourly without
# gaps — confirm against the dataset.
for lag in PM_LAGS:
    df[f'pm2_5_lag{lag}'] = df.groupby('city_id')['pm2_5'].shift(lag)

for col in CHEMS:
    for lag in CHEM_LAGS:
        df[f'{col}_lag{lag}'] = df.groupby('city_id')[col].shift(lag)

# Targets for each horizon: PM2.5 (y_h*) and PM10 (pm10_h*) h rows ahead.
# The PM10 target is used only by the inequality penalty / violation metric.
for h in HORIZONS:
    df[f'y_h{h}'] = df.groupby('city_id')['pm2_5'].shift(-h)
    df[f'pm10_h{h}'] = df.groupby('city_id')['pm10'].shift(-h)

# Drop rows missing any lag or any horizon's targets, so every horizon is
# trained and evaluated on the same set of rows (keeps things consistent).
needed = []
needed += [f'pm2_5_lag{l}' for l in PM_LAGS]
needed += [f'{c}_lag{l}' for c in CHEMS for l in CHEM_LAGS]
needed += [f'y_h{h}' for h in HORIZONS] + [f'pm10_h{h}' for h in HORIZONS]
df = df.dropna(subset=needed).copy()

# City metadata + ordering: one row per distinct (id, name, lat, lon), sorted by
# city_id; the row position becomes the city's node index in the graph.
# NOTE(review): assumes each city_id has a single name/lat/lon combination;
# otherwise drop_duplicates keeps several rows per city — confirm.
city_meta = df[['city_id','city_name','lat','lon']].drop_duplicates().sort_values('city_id').reset_index(drop=True)
city_ids = city_meta['city_id'].to_numpy()
cid_to_idx = {cid:i for i,cid in enumerate(city_ids)}
N = len(city_ids)

# Graph: weighted adjacency W and Laplacian L, moved to the compute device.
W_np, L_np = build_knn_graph(city_meta, k=K_GRAPH, ell_km=ELL_KM)
L = torch.tensor(L_np, device=DEVICE)  # (N,N)
W = torch.tensor(W_np, device=DEVICE)  # (N,N)

# Row-normalized adjacency for neighbor aggregation; the clamp avoids a
# division by zero for a node without edges.
deg = W.sum(dim=1, keepdim=True).clamp_min(1e-6)
A_norm = W / deg  # (N,N)

print(f"Device: {DEVICE}")
print(f"Cities: {N}  Graph k={K_GRAPH}  Avg deg={(W_np>0).sum(axis=1).mean():.2f}")

# Build time-indexed tensors (T,N,F) per horizon
def build_time_tensor(df_in, horizon_h):
    """Arrange the long per-city table into dense (time, city, feature) arrays.

    Only timestamps at which all ``N`` cities are present are kept. Cities are
    ordered along axis 1 by their index in the module-level ``cid_to_idx``.

    Parameters
    ----------
    df_in : DataFrame holding the current chemistry columns, the lag columns,
        'pm2_5', 'lat', 'lon', 'city_id', 'datetime' and the ``y_h{h}`` /
        ``pm10_h{h}`` target columns for the requested horizon.
    horizon_h : forecast horizon; added to 'datetime' as hours to compute the
        target-time calendar features and used to select the target columns.

    Returns
    -------
    times : DatetimeIndex of the T retained timestamps, ascending.
    X : float32 (T, N, F) input features (unscaled).
    c0 : float32 (T, N, 1) current PM2.5.
    y : float32 (T, N, 1) PM2.5 target at the horizon.
    pm10_tgt : float32 (T, N, 1) PM10 at the horizon.
    X_COLS : list of the F feature column names, in the order used in ``X``.

    Raises
    ------
    ValueError : if a retained timestamp does not have exactly one row per
        city (e.g. duplicated (datetime, city) rows).
    """
    df2 = df_in.copy()
    df2['cid_idx'] = df2['city_id'].map(cid_to_idx)

    # target-time cyclical features (t+h)
    dt_tgt = df2['datetime'] + pd.to_timedelta(horizon_h, unit='h')
    hour = dt_tgt.dt.hour.values
    doy = dt_tgt.dt.dayofyear.values

    df2['hour_sin_tgt'] = np.sin(2*np.pi*hour/24.0)
    df2['hour_cos_tgt'] = np.cos(2*np.pi*hour/24.0)
    # The period is fixed at 365 days, so day 366 of a leap year lands slightly
    # past one full cycle.
    df2['doy_sin_tgt']  = np.sin(2*np.pi*doy/365.0)
    df2['doy_cos_tgt']  = np.cos(2*np.pi*doy/365.0)

    # feature columns (PM10 excluded)
    X_COLS = (
        list(CHEMS)                                              # current chems
        + [f'{c}_lag{l}' for c in CHEMS for l in CHEM_LAGS]      # chemical lags
        + [f'pm2_5_lag{l}' for l in PM_LAGS]                     # PM2.5 lags
        + ['doy_sin_tgt', 'doy_cos_tgt',
           'hour_sin_tgt', 'hour_cos_tgt']                       # target-time calendar
        + ['lat', 'lon']                                         # static
    )

    # keep only full times (all cities present)
    counts = df2.groupby('datetime')['cid_idx'].nunique()
    full_times = counts[counts == N].index
    df2 = df2[df2['datetime'].isin(full_times)]

    # Time-major, city-minor row order: row ti*N + n is (time ti, city n), so a
    # plain reshape yields the (T, N, ...) layout without a per-timestamp loop.
    df2 = df2.sort_values(['datetime', 'cid_idx'])

    times = pd.to_datetime(np.sort(df2['datetime'].unique()))
    T = len(times)

    if len(df2) != T * N:
        raise ValueError(
            f"expected exactly {N} rows per timestamp ({T * N} rows in total), "
            f"got {len(df2)}; check for duplicated (datetime, city_id) rows"
        )

    def _stack(cols):
        # (T*N, len(cols)) -> (T, N, len(cols)), float32
        return df2[cols].to_numpy(np.float32).reshape(T, N, len(cols))

    X = _stack(X_COLS)
    c0 = _stack(['pm2_5'])
    y = _stack([f'y_h{horizon_h}'])
    pm10_tgt = _stack([f'pm10_h{horizon_h}'])

    return times, X, c0, y, pm10_tgt, X_COLS

# Model: Graph-aware source + city embeddings + implicit macro-step physics
class SourceNet(nn.Module):
    """Three-layer MLP that maps per-node inputs to one scalar per node.

    Two hidden blocks of Linear -> ReLU -> Dropout are followed by a linear
    read-out to a single value. The network acts on the last dimension only.
    """

    def __init__(self, in_dim, hidden=96, dropout=0.2):
        super().__init__()
        layers = []
        width_in = in_dim
        for _ in range(2):
            layers += [nn.Linear(width_in, hidden), nn.ReLU(), nn.Dropout(dropout)]
            width_in = hidden
        layers.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        """Map ``z`` of shape (B, N, in_dim) to an output of shape (B, N, 1)."""
        return self.net(z)

class PIGNNPP(nn.Module):
    """Graph-aware source network combined with one implicit diffusion-decay step.

    The prediction solves, per batch element,

        (I + hdt * (D * L + k * I)) c_pred = c0 + hdt * s

    where ``hdt = steps * dt``, ``L`` is the graph Laplacian, ``D`` and ``k``
    are learned positive scalars and ``s`` is the output of ``SourceNet``.

    Parameters
    ----------
    x_dim : number of input features per node.
    L : (N, N) tensor, graph Laplacian.
    A_norm : (N, N) tensor, row-normalised adjacency used to average
        neighbour features and neighbour state.
    n_cities : number of nodes N (size of the embedding table).
    emb_dim : size of the learned per-city embedding.
    hidden, dropout : passed to ``SourceNet``.
    dt : time step multiplied by ``steps`` in ``forward``.
    """

    def __init__(self, x_dim, L, A_norm, n_cities, emb_dim=8, hidden=96, dt=1.0, dropout=0.2):
        super().__init__()
        # Registered as buffers so they follow the module through .to(device) /
        # .double(); non-persistent so the state_dict keys stay the same as
        # when these were plain attributes.
        self.register_buffer("L", L, persistent=False)
        self.register_buffer("A", A_norm, persistent=False)
        self.dt = dt

        self.city_emb = nn.Embedding(n_cities, emb_dim)
        self.register_buffer("city_idx", torch.arange(n_cities, dtype=torch.long))

        # Source gets: [x, x_nb, c0, c_nb, emb]
        self.source = SourceNet(in_dim=2*x_dim + 2*1 + emb_dim, hidden=hidden, dropout=dropout)

        # Unconstrained parameters; D() and k() map them to positive values.
        self.D_raw = nn.Parameter(torch.tensor(0.0))
        self.k_raw = nn.Parameter(torch.tensor(-1.0))

    def D(self):
        """Positive coefficient multiplying the Laplacian (softplus of ``D_raw``)."""
        return F.softplus(self.D_raw) + 1e-6

    def k(self):
        """Positive coefficient multiplying the identity (softplus of ``k_raw``)."""
        return F.softplus(self.k_raw) + 1e-6

    def forward(self, c0, x, steps):
        """Predict the state after ``steps`` time steps with one implicit solve.

        c0: (B,N,1) current state
        x : (B,N,F) node features
        steps: number of ``dt`` steps folded into the single macro-step

        Returns a (B,N,1) tensor.
        """
        batch = x.shape[0]

        # neighbor-aggregated features/state
        x_nb = torch.einsum('ij,bjk->bik', self.A, x)      # (B,N,F)
        c_nb = torch.einsum('ij,bjk->bik', self.A, c0)     # (B,N,1)

        # city embeddings (N,emb) -> (B,N,emb)
        emb = self.city_emb(self.city_idx).unsqueeze(0).expand(batch, -1, -1)

        # source input
        z = torch.cat([x, x_nb, c0, c_nb, emb], dim=-1)    # (B,N, 2F+2+emb)
        s = self.source(z)                                 # (B,N,1)

        # physics implicit macro-step
        D = self.D()
        k = self.k()
        hdt = steps * self.dt

        n_nodes = self.L.shape[0]
        eye = torch.eye(n_nodes, device=c0.device, dtype=c0.dtype)

        # The system matrix is the same for every batch element; expand() gives
        # a batched view without copying.
        A_sys = eye + hdt * (D * self.L + k * eye)         # (N,N)
        A_b = A_sys.unsqueeze(0).expand(c0.shape[0], -1, -1)

        rhs = c0 + hdt * s                                 # (B,N,1)
        pred = torch.linalg.solve(A_b, rhs)
        return pred

# Evaluation
@torch.no_grad()
def eval_mae(model, c0, X, y, pm10_tgt, times_idx, steps):
    """Evaluate ``model`` on the timestamps selected by ``times_idx``.

    Parameters
    ----------
    model : module called as ``model(c0_batch, X_batch, steps)``.
    c0, X, y, pm10_tgt : NumPy arrays indexed by time along axis 0.
    times_idx : integer indices into axis 0 selecting the evaluation set.
    steps : horizon passed through to the model.

    Returns
    -------
    (mae, viol) : floats. ``mae`` is the mean absolute error over every
        (time, city) element and ``viol`` is the fraction of elements whose
        prediction exceeds the PM10 target. Both are accumulated as sums and
        counts across batches, so a short final batch is not over-weighted.
        Both are NaN when ``times_idx`` is empty.
    """
    model.eval()
    abs_err_sum = 0.0
    viol_count = 0.0
    n_elems = 0

    for b0 in range(0, len(times_idx), BATCH_TIMES):
        idx = times_idx[b0:b0+BATCH_TIMES]
        c0b = torch.from_numpy(c0[idx]).to(DEVICE)
        Xb  = torch.from_numpy(X[idx]).to(DEVICE)
        yb  = torch.from_numpy(y[idx]).to(DEVICE)
        pm10b = torch.from_numpy(pm10_tgt[idx]).to(DEVICE)

        pred = model(c0b, Xb, steps)

        abs_err_sum += torch.sum(torch.abs(pred - yb)).item()
        viol_count += ((pred - pm10b) > 0).float().sum().item()
        n_elems += pred.numel()

    if n_elems == 0:
        return float("nan"), float("nan")

    return float(abs_err_sum / n_elems), float(viol_count / n_elems)

# Train one horizon
def train_one_horizon(h):
    """Train and evaluate one PIGNNPP model for forecast horizon ``h``.

    Steps: build the (T, N, F) tensors for this horizon from the module-level
    ``df``, split the timestamps chronologically, fit a StandardScaler on the
    training timestamps only and apply it to all features, then train with
    Adam on MAE plus the optional hinge penalties. Validation MAE drives early
    stopping; the best weights are restored before the test evaluation.

    Returns
    -------
    (model, scaler, times, X, c0, y, pm10_tgt, va_idx, te_idx) where ``X`` is
    the *scaled* feature tensor and ``va_idx`` / ``te_idx`` index into axis 0
    of the arrays.
    """
    times, X, c0, y, pm10_tgt, X_COLS = build_time_tensor(df, h)
    tr_times, va_times, te_times = time_split_times(times)

    # Map each timestamp to its position along the time axis of the tensors.
    time_to_i = {t:i for i,t in enumerate(times)}
    tr_idx = np.array([time_to_i[t] for t in tr_times], dtype=int)
    va_idx = np.array([time_to_i[t] for t in va_times], dtype=int)
    te_idx = np.array([time_to_i[t] for t in te_times], dtype=int)

    # scale X using TRAIN only (statistics are per feature, pooled over
    # training times and cities), then transform every split with them
    scaler = StandardScaler()
    scaler.fit(X[tr_idx].reshape(-1, X.shape[-1]))
    X = scaler.transform(X.reshape(-1, X.shape[-1])).astype(np.float32).reshape(X.shape)

    model = PIGNNPP(
        x_dim=X.shape[-1],
        L=L,
        A_norm=A_norm,
        n_cities=N,
        emb_dim=EMB_DIM,
        hidden=HIDDEN,
        dt=DT,
        dropout=DROPOUT
    ).to(DEVICE)

    opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val = 1e18      # lowest validation MAE seen so far
    best_state = None    # CPU copy of the weights at that epoch
    bad = 0              # consecutive epochs without sufficient improvement
    steps = int(h)

    for ep in range(1, EPOCHS+1):
        model.train()
        # In-place shuffle of the training timestamps (uses NumPy's global RNG).
        np.random.shuffle(tr_idx)

        tr_losses = []
        for b0 in range(0, len(tr_idx), BATCH_TIMES):
            idx = tr_idx[b0:b0+BATCH_TIMES]
            c0b = torch.from_numpy(c0[idx]).to(DEVICE)
            Xb  = torch.from_numpy(X[idx]).to(DEVICE)
            yb  = torch.from_numpy(y[idx]).to(DEVICE)
            pm10b = torch.from_numpy(pm10_tgt[idx]).to(DEVICE)

            pred = model(c0b, Xb, steps)
            loss = mae_loss(pred, yb)

            # constraints (train only): penalise predictions above the
            # target-time PM10 and predictions below zero
            if LAMBDA_INEQ > 0:
                loss = loss + LAMBDA_INEQ * torch.mean(hinge_pos(pred - pm10b))
            if LAMBDA_NONNEG > 0:
                loss = loss + LAMBDA_NONNEG * torch.mean(hinge_pos(-pred))

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Logged value includes the penalty terms, so it is not a pure MAE.
            tr_losses.append(loss.item())

        val_mae, val_viol = eval_mae(model, c0, X, y, pm10_tgt, va_idx, steps)
        D_val = model.D().item()
        k_val = model.k().item()

        print(f"[h={h:2d}] ep {ep:02d} | train {np.mean(tr_losses):.4f} | val_MAE {val_mae:.4f} | "
              f"val_viol {100*val_viol:.2f}% | D {D_val:.4f} k {k_val:.4f}")

        # Early stopping: an epoch counts as an improvement only if validation
        # MAE drops by more than 1e-4.
        if val_mae < best_val - 1e-4:
            best_val = val_mae
            best_state = {kk: vv.detach().cpu().clone() for kk, vv in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= PATIENCE:
                print(f"[h={h}] Early stop. Best val_MAE={best_val:.4f}")
                break

    # Restore the best-validation weights (skipped if no epoch ever improved).
    if best_state is not None:
        model.load_state_dict(best_state)

    test_mae, test_viol = eval_mae(model, c0, X, y, pm10_tgt, te_idx, steps)
    print(f"[h={h:2d}] TEST | MAE {test_mae:.4f} | viol {100*test_viol:.2f}%")

    return model, scaler, times, X, c0, y, pm10_tgt, va_idx, te_idx

# Predict + save
@torch.no_grad()
def predict_split(model, times, X, c0, y, pm10_tgt, idx, h, split_name, model_name):
    """Run the model over the timestamps in ``idx`` and return long-format rows.

    One dict is produced per (timestamp, city) pair with the keys city_id,
    datetime, split, model, horizon_h, y_true, y_pred and pm10_true. Rows are
    ordered by position in ``idx`` and then by city index. ``datetime`` is the
    forecast origin taken from ``times``.
    """
    model.eval()
    horizon = int(h)
    rows = []

    for start in range(0, len(idx), BATCH_TIMES):
        batch_idx = idx[start:start + BATCH_TIMES]
        state = torch.from_numpy(c0[batch_idx]).to(DEVICE)
        feats = torch.from_numpy(X[batch_idx]).to(DEVICE)

        # (B,N,1) predictions brought back to NumPy on the host
        pred = model(state, feats, horizon).detach().cpu().numpy()

        for pos, ti in enumerate(batch_idx):
            stamp = times[ti]
            rows.extend(
                {
                    "city_id": int(city_ids[n]),
                    "datetime": pd.Timestamp(stamp),
                    "split": split_name,
                    "model": model_name,
                    "horizon_h": int(h),
                    "y_true": float(y[ti, n, 0]),
                    "y_pred": float(pred[pos, n, 0]),
                    "pm10_true": float(pm10_tgt[ti, n, 0]),
                }
                for n in range(N)
            )

    return rows

# Run all horizons: train one model per horizon and collect its validation and
# test predictions as long-format rows.
all_pred_rows = []

for h in HORIZONS:
    model, scaler, times, X, c0, y, pm10_tgt, va_idx, te_idx = train_one_horizon(h)

    # Model label written to the output; the suffix marks runs trained with
    # at least one constraint penalty enabled.
    tag = "PI-GNN++"
    if (LAMBDA_INEQ > 0) or (LAMBDA_NONNEG > 0):
        tag = "PI-GNN++_constraints"

    all_pred_rows += predict_split(model, times, X, c0, y, pm10_tgt, va_idx, h, "val", tag)
    all_pred_rows += predict_split(model, times, X, c0, y, pm10_tgt, te_idx, h, "test", tag)

# Save this model's predictions on their own.
pred_df = pd.DataFrame(all_pred_rows)
pred_df = pred_df.sort_values(["split","horizon_h","model","city_id","datetime"]).reset_index(drop=True)
pred_df.to_csv(OUT_PRED_PATH, index=False)
print(f"\nSaved: {OUT_PRED_PATH}  rows={len(pred_df)}  models={pred_df['model'].nunique()}")

# If baseline predictions are available, append this model's rows to them and
# save the combined table.
# NOTE(review): assumes the baseline file uses the same column names — confirm.
if os.path.exists(BASELINE_PRED_PATH):
    base = pd.read_csv(BASELINE_PRED_PATH, parse_dates=["datetime"])
    pred_df2 = pred_df.copy()
    pred_df2["datetime"] = pd.to_datetime(pred_df2["datetime"])
    combo = pd.concat([base, pred_df2], ignore_index=True)
    combo.to_csv(OUT_ALL_PATH, index=False)
    print(f"Saved: {OUT_ALL_PATH}  rows={len(combo)}")

Device: cpu
Cities: 29  Graph k=4  Avg deg=4.90
[h= 1] ep 01 | train 4.1647 | val_MAE 3.8813 | val_viol 16.47% | D 0.5864 k 0.3062
[h= 1] ep 02 | train 2.7254 | val_MAE 3.5433 | val_viol 15.28% | D 0.5375 k 0.3049
[h= 1] ep 03 | train 2.5602 | val_MAE 3.4039 | val_viol 17.17% | D 0.4978 k 0.3026
[h= 1] ep 04 | train 2.4651 | val_MAE 3.3828 | val_viol 15.44% | D 0.4656 k 0.2997
[h= 1] ep 05 | train 2.3890 | val_MAE 3.2875 | val_viol 16.64% | D 0.4406 k 0.2961
[h= 1] ep 06 | train 2.3586 | val_MAE 3.2765 | val_viol 16.40% | D 0.4171 k 0.2922
[h= 1] ep 07 | train 2.3143 | val_MAE 3.3081 | val_viol 13.21% | D 0.3967 k 0.2879
[h= 1] ep 08 | train 2.2907 | val_MAE 3.1779 | val_viol 15.34% | D 0.3777 k 0.2831
[h= 1] ep 09 | train 2.2697 | val_MAE 3.1110 | val_viol 14.47% | D 0.3585 k 0.2782
[h= 1] ep 10 | train 2.2537 | val_MAE 3.1796 | val_viol 13.13% | D 0.3409 k 0.2733
[h= 1] ep 11 | train 2.2136 | val_MAE 3.1453 | val_viol 12.84% | D 0.3275 k 0.2682
[h= 1] ep 12 | train 2.1890 | val_MAE 3