In [3]:
import numpy as np
import pandas as pd
from pathlib import Path
import argparse

def normalize_station_id(x):
    try:
        return str(int(x))
    except:
        return str(x).strip()

def make_windows(X, M, window=16, stride=8):
    # X: (N, L), M: (N, L)
    N, L = X.shape
    windows = []
    masks = []
    for start in range(0, L - window + 1, stride):
        w = X[:, start:start+window]        # (N, window)
        m = M[:, start:start+window]
        # transpose to (window, N)
        windows.append(w.T[..., np.newaxis])  # (window, N, 1)
        masks.append(m.T[..., np.newaxis])
    if len(windows) == 0:
        return np.zeros((0, window, N, 1), dtype=np.float32), np.zeros((0, window, N, 1), dtype=np.float32)
    Xb = np.stack(windows, axis=0).astype(np.float32)  # (B, T, N, 1)
    Mb = np.stack(masks, axis=0).astype(np.float32)
    return Xb, Mb

def main(data_dir="data/Discharge", window=16, stride=8, save_to="data/Discharge/processed.npz"):
    base = Path(data_dir)
    src = pd.read_csv(base / "SSC_discharge.csv", parse_dates=["date"]).set_index("date")
    tgt = pd.read_csv(base / "SSC_pooled.csv", parse_dates=["date"]).set_index("date")
    flow = pd.read_csv(base / "SSC_sites_flow_direction.csv", index_col=0)

    # normalize station names
    stations = [normalize_station_id(c) for c in src.columns if c != "date"]
    src = src.loc[:, src.columns]  # keep original col order
    src.columns = stations
    tgt.columns = [normalize_station_id(c) for c in tgt.columns]
    flow.index = flow.index.map(normalize_station_id)
    flow.columns = flow.columns.map(normalize_station_id)

    # intersect dates
    common = src.index.intersection(tgt.index).sort_values()
    src = src.loc[common]
    tgt = tgt.loc[common]

    # Ensure flow contains all stations: reorder flow to the station order
    missing = [s for s in stations if s not in flow.index]
    if missing:
        print("WARNING: the following stations are missing in flow matrix:", missing)
    # Reindex (if index missing, rows/cols will be NaN -> fill zeros)
    flow = flow.reindex(index=stations, columns=stations, fill_value=0)
    adj = flow.values.astype(np.float32)

    # build arrays (N, L)
    X_source = src.values.T.astype(np.float64)  # (N, L)
    X_target_raw = tgt.values.T  # (N, L) contains NaNs

    mask_target = (~pd.DataFrame(X_target_raw).isna()).astype(float).values  # (N, L)
    X_target_filled = np.nan_to_num(X_target_raw, nan=0.0).astype(np.float64)

    # normalize per-station using source (dense)
    station_mean = np.nanmean(X_source, axis=1, keepdims=True)
    station_std = np.nanstd(X_source, axis=1, keepdims=True) + 1e-6

    X_target_norm = (X_target_filled - station_mean) / station_std

    # make windows
    Xb, Mb = make_windows(X_target_norm, mask_target, window=window, stride=stride)
    # Save meta and arrays
    np.savez_compressed(save_to,
                        x=Xb.astype(np.float32),
                        mask=Mb.astype(np.float32),
                        adj=adj.astype(np.float32),
                        station_ids=np.array(stations),
                        mean=station_mean.astype(np.float32).squeeze(),
                        std=station_std.astype(np.float32).squeeze())
    print(f"Saved processed data to {save_to}. Shapes: x={Xb.shape}, mask={Mb.shape}, adj={adj.shape}")

if __name__ == "__main__":
    import sys
    sys.argv = [sys.argv[0]]
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="data/Discharge")
    parser.add_argument("--window", type=int, default=16)
    parser.add_argument("--stride", type=int, default=8)
    parser.add_argument("--out", default="data/Discharge/processed.npz")
    args = parser.parse_args()
    main(data_dir=args.data_dir, window=args.window, stride=args.stride, save_to=args.out)


Saved processed data to data/Discharge/processed.npz. Shapes: x=(339, 16, 20, 1), mask=(339, 16, 20, 1), adj=(20, 20)


In [4]:
import numpy as np
data = np.load("data/Discharge/processed.npz")
x, mask, adj = data["x"], data["mask"], data["adj"]
print(x.shape, mask.shape, adj.shape)

(339, 16, 20, 1) (339, 16, 20, 1) (20, 20)


In [5]:
import torch
import torch.nn as nn
import numpy as np
from lib.nn.models.grin import GRINet

device = "cpu"

x_t = torch.tensor(x, dtype=torch.float32).to(device)
mask_t = torch.tensor(mask, dtype=torch.bool).to(device)
adj_t = torch.tensor(adj, dtype=torch.float32).to(device)

B, T, N, C = x_t.shape

model = GRINet(
    adj=adj_t.numpy(),
    d_in=1,
    d_hidden=64,
    d_ff=64,
    ff_dropout=0.1,
    n_layers=1,
    kernel_size=2,
    decoder_order=1,
    impute_only_holes=True
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.L1Loss(reduction="none")

epochs = 20

for ep in range(epochs):
    model.train()
    total_loss = 0
    
    for i in range(B):
        x_i = x_t[i:i+1]
        m_i = mask_t[i:i+1]

        aug_mask = (torch.rand_like(m_i.float()) < 0.1) & m_i
        train_mask = m_i & (~aug_mask)

        optimizer.zero_grad()
        imputed, _ = model(x_i, train_mask)

        # only evaluate added missing positions
        target_mask = (~train_mask) & m_i  # only original observed that we masked
        if target_mask.sum() == 0:
            continue

        loss = (loss_fn(imputed, x_i) * target_mask.float()).sum() / target_mask.sum()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {ep+1}, Loss = {total_loss:.6f}")


Epoch 1, Loss = 82.578490
Epoch 2, Loss = 64.579495
Epoch 3, Loss = 69.506373
Epoch 4, Loss = 60.363521
Epoch 5, Loss = 70.075543
Epoch 6, Loss = 61.893148
Epoch 7, Loss = 64.442609
Epoch 8, Loss = 60.818612
Epoch 9, Loss = 64.965758
Epoch 10, Loss = 63.308463
Epoch 11, Loss = 66.587658
Epoch 12, Loss = 67.658133
Epoch 13, Loss = 69.100616
Epoch 14, Loss = 49.487439
Epoch 15, Loss = 59.068139
Epoch 16, Loss = 62.477311
Epoch 17, Loss = 66.249729
Epoch 18, Loss = 57.942661
Epoch 19, Loss = 59.419797
Epoch 20, Loss = 44.194377


In [6]:
model.eval()
with torch.no_grad():
    final_output = model(x_t, mask_t)     

In [7]:
W = final_output.shape[0]  # 339
window = final_output.shape[1]  # 16
stride = 8

L_total = (W - 1) * stride + window
print("Total length:", L_total)

Total length: 2720


In [8]:
import numpy as np

def merge_windows(imputed, L_total, window=16, stride=8):
    # imputed: (B, window, N, 1)
    B, w, N, _ = imputed.shape
    full = np.zeros((L_total, N))
    count = np.zeros((L_total, N))

    pos = 0
    for i in range(B):
        end = min(pos + window, L_total)
        take = end - pos
        full[pos:end] += imputed[i, :take, :, 0]
        count[pos:end] += 1
        pos += stride

    full = full / np.maximum(count, 1)
    return full

In [9]:
merged_norm = merge_windows(final_output.numpy(), L_total)
print("Merged shape:", merged_norm.shape)

Merged shape: (2720, 20)


In [10]:
data = np.load("data/Discharge/processed.npz")
mean = data["mean"]  # shape (20,)
std = data["std"]
merged_real = merged_norm * std + mean

In [11]:
tgt = pd.read_csv("data/Discharge/SSC_pooled.csv", parse_dates=["date"]).set_index("date")
L_final = merged_real.shape[0]
tgt = tgt.iloc[:L_final]

target_array = tgt.values
original_missing = np.isnan(target_array)
valid_mask = ~original_missing

np.random.seed(42)
synthetic_mask = valid_mask & (np.random.rand(*valid_mask.shape) < 0.2)

y_true = target_array[synthetic_mask]
y_pred = merged_real[synthetic_mask]


In [12]:
from sklearn.metrics import mean_absolute_error, mean_squared_error
mae = mean_absolute_error(y_true, y_pred)
rmse = mean_squared_error(y_true, y_pred)
print("MAE =", mae)
print("RMSE =", rmse)

MAE = 5.997624942979036e-05
RMSE = 1.92834783018547e-08
