In [None]:
!pip install neuralhydrology

In [None]:
import os
import copy
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import Dict
from torch import Tensor
import pandas as pd

from neuralhydrology.modelzoo.handoff_forecast_lstm import HandoffForecastLSTM
from neuralhydrology.utils.config import Config

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value for repeatable runs.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Preprocess

In [None]:
# Root folder of the prepared input tables.
DATA_DIR = "/kaggle/input/full-hydro"

# Load the five CSV tables (static attributes, two meteorological tables,
# discharge and precipitation) in one pass.
attrs, meteo, hres, q, precip = (
    pd.read_csv(os.path.join(DATA_DIR, file_name))
    for file_name in (
        "step5_static.csv",
        "step5_era5.csv",
        "step5_ecmwf.csv",
        "step5_discharge.csv",
        "step5_precip.csv",
    )
)

# Columns identifying one record: basin id and date.
keys = ["grdc_no", "date"]

# Parse the date column of every time-indexed table (attrs is merged on basin only).
for _df in (meteo, hres, q, precip):
    _df["date"] = pd.to_datetime(_df["date"])

# Per-basin 95th percentile of observed discharge, taken on the raw values
# (before the log1p transform below) with missing observations excluded.
_q_observed = q[q["q_obs"].notna()]
flood_thr = _q_observed.groupby("grdc_no")["q_obs"].quantile(0.95).to_dict()

# Feature-group switches. In this cell USE_METEO only affects LOG1P_COLS and
# USE_LOG is not read at all.
USE_METEO = True
USE_HRES = True
USE_NOAA = True
USE_STATIC = True
USE_LOG = True

# Start from the meteo table and inner-join the optional tables on (basin, date).
# Discharge is always required, hence the unconditional last entry.
df = meteo.copy()
for _enabled, _table in ((USE_HRES, hres), (USE_NOAA, precip), (True, q)):
    if _enabled:
        df = df.merge(_table, on=keys, how="inner")

# Static attributes are constant per basin, so they are joined on the basin id only.
if USE_STATIC:
    df = df.merge(attrs, on="grdc_no", how="left")

df = df.sort_values(keys).reset_index(drop=True)

# Columns that get a log1p transform: the enabled precipitation columns plus the target.
# NOTE(review): "tp_x" / "tp_y" only exist when the meteo/hres merge produced
# suffixed columns - confirm before switching USE_HRES off.
LOG1P_COLS = [
    col
    for enabled, col in ((USE_METEO, "tp_x"), (USE_HRES, "tp_y"), (USE_NOAA, "prcp_NOAA"))
    if enabled
]
LOG1P_COLS.append("q_obs")

for c in LOG1P_COLS:
    if c not in df.columns:
        print(f"[warn] column not found, skipping log1p: {c}")
        continue
    # Negative values are clipped to zero before the transform.
    df[c] = np.log1p(np.clip(df[c].to_numpy(), 0, None))

print("Total records:", len(df))
df.columns

In [None]:
# Base names of the meteorological variables; the "_x" / "_y" variants below
# are the suffixed column names used for the hindcast and forecast inputs.
METEO_BASE = ["t2m", "tp", "sp", "ssr", "str"]

# Hindcast inputs: the "_x" columns plus the NOAA precipitation column.
METEO_COLS = [base + "_x" for base in METEO_BASE]
METEO_COLS.append("prcp_NOAA")

# Forecast inputs: the "_y" columns.
HRES_COLS = [base + "_y" for base in METEO_BASE]

# Name of the prediction target column.
TARGET = "q_obs"

# Static basin attributes fed to the model.
STATIC_COLS = [
    "area",
    "mean_elevation",
    "mean_slope",
    "clay_fraction",
    "sand_fraction",
    "silt_fraction",
    "forest_cover",
    "urban_cover",
    "cropland_cover",
    "aridity_index",
    "mean_precip",
]

# Window layout: number of hindcast rows followed by the number of forecast rows.
SEQ_H, SEQ_F = 365, 1

In [None]:
def compute_norm_stats(df, basin_ids):
    """Compute per-column normalisation statistics from the selected basins.

    Args:
        df: Combined frame holding a ``grdc_no`` column plus the columns listed
            in ``STATIC_COLS``, ``HRES_COLS`` and ``METEO_COLS``.
        basin_ids: Basin ids whose rows are used for the statistics.

    Returns:
        dict of pandas Series indexed by column name, with keys
        ``ATTR_MEAN`` / ``ATTR_STD`` (static attributes, one value per basin:
        the first non-null entry of each column), ``HRES_MEAN`` / ``HRES_STD``
        and ``METEO_MEAN`` / ``METEO_STD`` (computed over all selected rows).
        A standard deviation that is zero or undefined is returned as 1.0 so
        that dividing by it never produces inf/NaN.
    """

    def _safe_std(frame):
        # pandas uses ddof=1, so std is NaN when only one basin/row is selected;
        # treat that like a zero std and fall back to 1.0.
        return frame.std().replace(0, 1.0).fillna(1.0)

    # Boolean indexing already yields a new frame and nothing below mutates it,
    # so no explicit copy is needed.
    df_sub = df[df["grdc_no"].isin(basin_ids)]

    # One row per basin so that long records do not dominate the static stats.
    basin_static = df_sub.groupby("grdc_no")[STATIC_COLS].first()

    stats = {
        "ATTR_MEAN": basin_static.mean(),
        "ATTR_STD":  _safe_std(basin_static),

        "HRES_MEAN": df_sub[HRES_COLS].mean(),
        "HRES_STD":  _safe_std(df_sub[HRES_COLS]),

        "METEO_MEAN": df_sub[METEO_COLS].mean(),
        "METEO_STD":  _safe_std(df_sub[METEO_COLS]),
    }
    return stats

In [None]:
def make_data_for_nh(df_basin, t_idx, stats):
    """Build one model sample (batch dimension of 1) for a single basin.

    The window covers rows ``[t_idx - SEQ_H, t_idx + SEQ_F)`` of the
    date-ordered basin frame: the first ``SEQ_H`` rows are the hindcast part,
    the remaining ``SEQ_F`` rows the forecast part.

    Args:
        df_basin: Rows of one basin, with a ``date`` column and the columns in
            ``METEO_COLS``, ``HRES_COLS``, ``STATIC_COLS`` and ``TARGET``.
        t_idx: Positional index of the first forecast row.
        stats: Normalisation statistics as returned by ``compute_norm_stats``.

    Returns:
        dict with
            ``x_d_hindcast``: {column: tensor (1, SEQ_H, 1)} normalised METEO_COLS,
            ``x_d_forecast``: {column: tensor (1, SEQ_F, 1)} normalised HRES_COLS,
            ``x_s``: tensor (1, len(STATIC_COLS)) normalised static attributes,
            ``y``: tensor (1, SEQ_H + SEQ_F, 1) target values (not normalised here).

    Raises:
        AssertionError: if the window does not fit inside the basin frame.
    """
    # This function runs once per sample, so re-sorting the whole basin frame
    # on every call is wasted work. Only sort when the dates are not already
    # ordered; all access below is positional, so the index itself is irrelevant.
    if not df_basin["date"].is_monotonic_increasing:
        df_basin = df_basin.sort_values("date").reset_index(drop=True)

    start = t_idx - SEQ_H
    end   = t_idx + SEQ_F
    assert start >= 0 and end <= len(df_basin), "window does not fit inside the basin record"

    window = df_basin.iloc[start:end]

    hind = window.iloc[:SEQ_H]
    fore = window.iloc[SEQ_H:]

    # Hindcast inputs, z-scored per column.
    x_d_hindcast = {}
    for c in METEO_COLS:
        vals = ((hind[c] - stats["METEO_MEAN"][c]) / stats["METEO_STD"][c]).to_numpy(dtype=np.float32)
        x_d_hindcast[c] = torch.tensor(vals).unsqueeze(-1).unsqueeze(0)   # (1, SEQ_H, 1)

    # Forecast inputs, z-scored per column.
    x_d_forecast = {}
    for c in HRES_COLS:
        vals = ((fore[c] - stats["HRES_MEAN"][c]) / stats["HRES_STD"][c]).to_numpy(dtype=np.float32)
        x_d_forecast[c] = torch.tensor(vals).unsqueeze(-1).unsqueeze(0)   # (1, SEQ_F, 1)

    # Static attributes (normalised), read from the first row of the window.
    x_s_vals = ((window.iloc[0][STATIC_COLS] - stats["ATTR_MEAN"]) / stats["ATTR_STD"]).to_numpy(dtype=np.float32)
    x_s = torch.tensor(x_s_vals).unsqueeze(0)

    # Target over the full window; callers slice off the forecast steps.
    y = torch.tensor(
        window[[TARGET]].values, dtype=torch.float32
    ).unsqueeze(0)                  # (1, SEQ_H+SEQ_F, 1)

    data = {
        "x_d_hindcast": x_d_hindcast,
        "x_d_forecast": x_d_forecast,
        "x_s": x_s,
        "y": y
    }

    return data

print("Total basins:", df["grdc_no"].nunique())
print("SEQ_H:", SEQ_H, "SEQ_F:", SEQ_F)

# A basin with n rows yields n - SEQ_H - SEQ_F + 1 windows (zero when too short).
total_windows = sum(
    max(0, len(basin_frame) - SEQ_H - SEQ_F + 1)
    for _, basin_frame in df.groupby("grdc_no")
)
print("Total samples after windowing:", total_windows)



In [None]:
def _fc_tanh(width):
    """Settings of a fully connected net with one hidden layer, tanh and no dropout."""
    return {
        "type": "fc",
        "hiddens": [width],
        "activation": "tanh",
        "dropout": 0.0,
    }


# Model configuration for the handoff forecast LSTM.
# NOTE(review): each input list below is wrapped in an outer list, exactly as
# before - confirm this nesting is what the installed neuralhydrology expects.
cfg = Config({
    "model": "HandoffForecastLSTM",

    # Sequence layout: full window length and its forecast tail.
    "seq_length": SEQ_H + SEQ_F,      # 366
    "forecast_seq_length": SEQ_F,     # 1

    # Input features.
    "hindcast_inputs": [METEO_COLS],
    "forecast_inputs": [HRES_COLS],
    "dynamic_inputs": [HRES_COLS + METEO_COLS],
    "static_attributes": STATIC_COLS,

    # Embedding networks for static and dynamic inputs.
    "statics_embedding": _fc_tanh(32),
    "dynamics_embedding": _fc_tanh(64),

    # Core model parameters.
    "hidden_size": 128,
    "initial_forget_bias": 1.0,
    "output_dropout": 0.0,

    # Output head.
    "head": "regression",
    "target_variables": [TARGET],

    # Settings marked as required in the original cell.
    "nan_handling_method": None,
    "timestep_counter": False,

    # Network configured under the "state_handoff_network" key.
    "state_handoff_network": _fc_tanh(128),
})


In [None]:
# Pick the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model from the base config, in training mode, on that device.
model = HandoffForecastLSTM(cfg)
model.train()
model.to(device)


In [None]:
def preprocess_basins(df):
    """Split the combined frame into one date-sorted frame per basin.

    Args:
        df: Frame with at least the ``grdc_no`` and ``date`` columns.

    Returns:
        dict mapping each basin id to ``{"df": <rows of that basin sorted by
        date with a fresh 0..n-1 index>, "n_days": <number of rows>}``.
    """
    ordered_frames = (
        (basin_id, group.sort_values("date").reset_index(drop=True))
        for basin_id, group in df.groupby("grdc_no")
    )
    return {
        basin_id: {"df": frame, "n_days": len(frame)}
        for basin_id, frame in ordered_frames
    }

# Both names refer to the same dict; `basins` is rebound later, `basins_all` is not.
basins_all = basins = preprocess_basins(df)

In [None]:
class HydroWindowDataset(Dataset):
    """Index of all (basin, t) window positions; tensors are built in the collate step.

    For a basin with ``n_days`` rows, ``t`` runs from ``seq_h`` to
    ``n_days - seq_f`` inclusive. Each item is the plain tuple ``(basin, t)``.
    """

    def __init__(self, basins, seq_h, seq_f):
        self.basins = basins
        self.seq_h = seq_h
        self.seq_f = seq_f
        # Enumerate every valid window position, basin by basin.
        self.samples = [
            (basin_key, t_idx)
            for basin_key, info in basins.items()
            for t_idx in range(seq_h, info["n_days"] - seq_f + 1)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Window index over every basin currently held in `basins`.
dataset = HydroWindowDataset(basins=basins, seq_h=SEQ_H, seq_f=SEQ_F)
print("Total samples:", len(dataset))

def make_collate_fn(basins_dict, stats):
    """Create a DataLoader ``collate_fn`` that turns (basin, t) pairs into a batch.

    Args:
        basins_dict: Mapping basin id -> {"df": basin frame, ...}.
        stats: Normalisation statistics passed on to ``make_data_for_nh``.

    Returns:
        A function mapping a list of ``(basin, t_idx)`` tuples to a dict with
        ``basin_id`` (long tensor), ``x_d_hindcast`` / ``x_d_forecast``
        (dicts column -> stacked tensor), ``x_s`` and ``y`` (stacked tensors).
        Stacking adds the batch dimension in front of each per-sample tensor.
    """

    def collate_fn(batch):
        basin_ids = []
        samples = []
        for basin, t_idx in batch:
            basin_ids.append(int(basin))
            samples.append(make_data_for_nh(basins_dict[basin]["df"], t_idx, stats))

        # Index [0] drops the per-sample batch dimension of size 1 before stacking.
        hindcast = {
            col: torch.stack([s["x_d_hindcast"][col][0] for s in samples], dim=0)
            for col in METEO_COLS
        }
        forecast = {
            col: torch.stack([s["x_d_forecast"][col][0] for s in samples], dim=0)
            for col in HRES_COLS
        }
        statics = torch.stack([s["x_s"][0] for s in samples], dim=0)
        targets = torch.stack([s["y"][0] for s in samples], dim=0)

        return {
            "basin_id": torch.tensor(basin_ids, dtype=torch.long),
            "x_d_hindcast": hindcast,
            "x_d_forecast": forecast,
            "x_s": statics,
            "y": targets,
        }

    return collate_fn


def train_one_epoch(model, loader, optimizer, loss_fn, device, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Only the last ``cfg.forecast_seq_length`` steps of the prediction and the
    target enter the loss (``cfg`` is the module-level config). The batch dicts
    yielded by ``loader`` are modified in place when moved to ``device``.

    Args:
        model: Model called as ``model(batch)``, returning a dict with ``y_hat``.
        loader: Iterable of batch dicts as produced by the collate function.
        optimizer: Optimiser stepping the model parameters.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        device: Device the tensors are moved to.
        epoch: Epoch number, used for logging only.

    Returns:
        Sum of the batch losses divided by ``len(loader)``.
    """
    model.train()
    running_loss = 0.0

    for step, batch in enumerate(loader, start=1):
        # Transfer every tensor the model reads to the target device.
        for group in ("x_d_hindcast", "x_d_forecast"):
            features = batch[group]
            for name in features:
                features[name] = features[name].to(device)
        batch["x_s"] = batch["x_s"].to(device)
        batch["y"] = batch["y"].to(device)

        optimizer.zero_grad()

        prediction = model(batch)

        # Compare only the forecast tail of target and prediction.
        target_tail = batch["y"][:, -cfg.forecast_seq_length:, :]
        pred_tail = prediction["y_hat"][:, -cfg.forecast_seq_length:, :]
        loss = loss_fn(pred_tail, target_tail)

        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss

        # Log the first batch and every 50th one.
        if step == 1 or step % 50 == 0:
            print(
                f"[Epoch {epoch:03d}] "
                f"Batch {step:05d}/{len(loader)} | "
                f"Batch Loss: {step_loss:.6f}"
            )

    avg_loss = running_loss / len(loader)

    print(
        f"===== Epoch {epoch:03d} DONE | "
        f"Avg Loss: {avg_loss:.6f} ====="
    )

    return avg_loss


# Helpers

In [None]:
def basin_window_counts(basins_dict, seq_h, seq_f):
    """Return the number of full windows each basin provides.

    Args:
        basins_dict: Mapping basin id -> dict with an ``n_days`` entry.
        seq_h: Hindcast length of a window.
        seq_f: Forecast length of a window.

    Returns:
        dict basin id -> ``max(0, n_days - seq_h - seq_f + 1)``.
    """
    return {
        basin: max(0, int(info["n_days"]) - seq_h - seq_f + 1)
        for basin, info in basins_dict.items()
    }

def greedy_test_split_by_windows(basin_ids, counts, test_frac=0.10):
    """Split basins so the held-out part reaches ``test_frac`` of all windows.

    Basins are visited from the largest to the smallest window count and moved
    to the held-out list until the accumulated count reaches the target, so the
    held-out share can exceed ``test_frac``.

    Args:
        basin_ids: Basin ids to split (must be hashable).
        counts: Mapping basin id -> window count; missing ids count as 0.
        test_frac: Fraction of the total window count to hold out.

    Returns:
        ``(train, test)``: two lists of basin ids, both ordered by descending
        window count.
    """
    ordered = sorted(basin_ids, key=lambda b: counts.get(b, 0), reverse=True)
    total = sum(counts.get(b, 0) for b in ordered)
    target = test_frac * total

    test, held_out_windows = [], 0
    for b in ordered:
        # The accumulated count never decreases, so nothing is added after
        # the target has been reached.
        if held_out_windows >= target:
            break
        test.append(b)
        held_out_windows += counts.get(b, 0)

    # Build the lookup set once instead of once per basin.
    held_out = set(test)
    train = [b for b in ordered if b not in held_out]
    return train, test

@torch.no_grad()
def predict_all(model, loader, device):
    """Run the model over ``loader`` and collect forecast-step targets and predictions.

    Only the last ``cfg.forecast_seq_length`` steps are kept (``cfg`` is the
    module-level config). The batch dicts are modified in place when moved to
    ``device``.

    Returns:
        Tuple of three concatenated 1-D NumPy arrays:
        ``(basin ids, true values, predicted values)``.
    """
    model.eval()
    id_chunks, true_chunks, pred_chunks = [], [], []

    for batch in loader:
        # Basin ids are only bookkeeping and never leave the CPU.
        ids = batch["basin_id"]
        ids = ids.cpu().numpy() if torch.is_tensor(ids) else np.asarray(ids)
        id_chunks.append(ids.reshape(-1))

        for group in ("x_d_hindcast", "x_d_forecast"):
            features = batch[group]
            for name in features:
                features[name] = features[name].to(device)
        batch["x_s"] = batch["x_s"].to(device)
        batch["y"] = batch["y"].to(device)

        prediction = model(batch)

        # Keep the forecast tail only and flatten it.
        true_tail = batch["y"][:, -cfg.forecast_seq_length:, :]
        pred_tail = prediction["y_hat"][:, -cfg.forecast_seq_length:, :]
        true_chunks.append(true_tail.detach().cpu().numpy().reshape(-1))
        pred_chunks.append(pred_tail.detach().cpu().numpy().reshape(-1))

    return (
        np.concatenate(id_chunks, axis=0),
        np.concatenate(true_chunks, axis=0),
        np.concatenate(pred_chunks, axis=0),
    )

@torch.no_grad()
def eval_one_epoch(model, loader, loss_fn, device):
    """Evaluate the model on ``loader`` without gradients and return the mean batch loss.

    As in training, only the last ``cfg.forecast_seq_length`` steps enter the
    loss (``cfg`` is the module-level config). The batch dicts are modified in
    place when moved to ``device``.

    Returns:
        Sum of the batch losses divided by ``len(loader)``.
    """
    model.eval()
    batch_losses = []

    for batch in loader:
        # Transfer every tensor the model reads to the target device.
        for group in ("x_d_hindcast", "x_d_forecast"):
            features = batch[group]
            for name in features:
                features[name] = features[name].to(device)
        batch["x_s"] = batch["x_s"].to(device)
        batch["y"] = batch["y"].to(device)

        prediction = model(batch)

        target_tail = batch["y"][:, -cfg.forecast_seq_length:, :]
        pred_tail = prediction["y_hat"][:, -cfg.forecast_seq_length:, :]

        batch_losses.append(loss_fn(pred_tail, target_tail).item())

    return sum(batch_losses) / len(loader)

def regression_metrics(y_true, y_pred):
    """Compute basic regression scores for two equally long value sequences.

    Args:
        y_true: Observed values (any array-like, flattened internally).
        y_pred: Predicted values (any array-like, flattened internally).

    Returns:
        dict of plain Python floats:
            ``MSE``  mean squared error,
            ``RMSE`` square root of the MSE,
            ``Bias`` mean of (prediction - observation),
            ``r``    Pearson correlation, NaN when it is undefined
                     (fewer than two values or a constant input),
            ``NSE``  1 - SS_res / SS_tot, NaN when SS_tot is not positive.
        Empty input yields NaN for every entry.
    """
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.asarray(y_pred).reshape(-1)

    nan = float("nan")
    if y_true.size == 0:
        # Nothing to score; avoid NumPy's "mean of empty slice" warnings.
        return {"MSE": nan, "RMSE": nan, "Bias": nan, "r": nan, "NSE": nan}

    e = y_pred - y_true
    mse = float(np.mean(e**2))
    rmse = float(np.sqrt(mse))
    bias = float(np.mean(e))

    ss_res = float(np.sum(e**2))
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))

    def _varies(values):
        # max - min is exactly 0 for a constant array (and NaN-safe: NaN > 0 is False).
        return values.size >= 2 and float(np.max(values) - np.min(values)) > 0

    # The correlation is undefined for constant input; np.corrcoef would return
    # NaN there but also emit RuntimeWarnings, so short-circuit instead.
    if _varies(y_true) and _varies(y_pred):
        r = float(np.corrcoef(y_true, y_pred)[0, 1])
    else:
        r = nan

    nse = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else nan

    return {"MSE": mse, "RMSE": rmse, "Bias": bias, "r": r, "NSE": nse}

# Training

In [None]:
# ---- Basin-level train / val / test split, sized by window count ----
counts = basin_window_counts(basins_all, SEQ_H, SEQ_F)
basin_ids = list(basins_all)

# First hold out the test basins (target: 10 % of all windows) ...
trainval_ids, test_ids = greedy_test_split_by_windows(basin_ids, counts, test_frac=0.10)

# ... then split the remainder again; 0.1 / 0.9 aims at another 10 % of the
# overall total, assuming the remainder holds about 90 % of the windows.
val_frac_of_trainval = 0.1 / 0.9
train_ids, val_ids = greedy_test_split_by_windows(trainval_ids, counts, test_frac=val_frac_of_trainval)


def _count_windows(ids):
    """Total number of windows contributed by the given basins."""
    return sum(map(counts.__getitem__, ids))


n_all, n_tr, n_val, n_test = map(_count_windows, (basin_ids, train_ids, val_ids, test_ids))

print("Basins total:", len(basin_ids))
for _split, _ids, _n in (("Train", train_ids, n_tr), ("Val", val_ids, n_val), ("Test", test_ids, n_test)):
    print(f"{_split} basins:", len(_ids), "windows:", _n, "frac:", _n / n_all)

# ---- Training hyper-parameters ----
BATCH_SIZE = 8
MAX_EPOCHS = 50
PATIENCE = 15        # epochs without improvement before early stopping
MIN_DELTA = 1e-6     # minimum decrease of the validation loss counted as improvement

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Separate MSE loss instances for training and validation.
train_loss_fn = nn.MSELoss()
val_loss_fn = nn.MSELoss()

# Keep the base config and a dict copy of it for deriving the final config.
cfg_base = cfg
cfg_dict_base = cfg.as_dict()

In [None]:
# Final training run: build the loaders, train with early stopping on the
# validation loss, restore the best weights, plot the loss curves and score the
# held-out test basins.

# NOTE(review): the 'name' label reads hs256 / lr1e-3 / do0.1, but the values
# actually used below are hidden_size=128, lr=0.0001 and output_dropout=0.0 -
# confirm which one is intended.
best_hp = {'name': 'hs256_lr1e-3_wd1e-4_do0.1', 'hidden_size': 128, 'lr': 0.0001, 'weight_decay': 0.0001, 'output_dropout': 0.0}
# This is the epoch budget, not the number of epochs actually run (the loop
# below may stop early).
final_epochs = MAX_EPOCHS
basins_train = {b: basins_all[b] for b in train_ids}
basins_val = {b: basins_all[b] for b in val_ids}
basins_test = {b: basins_all[b] for b in test_ids}

# Normalisation statistics come from the training basins only and are reused
# for the validation and test loaders.
stats_train = compute_norm_stats(df, train_ids)
# `basins` is a scratch name that is rebound for each split below.
basins = basins_train
train_ds = HydroWindowDataset(basins, SEQ_H, SEQ_F)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                             collate_fn=make_collate_fn(basins, stats_train))
basins = basins_val
val_ds = HydroWindowDataset(basins, SEQ_H, SEQ_F)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                            collate_fn=make_collate_fn(basins, stats_train))

# Derive the final config from the base config dict. Rebinding the global
# `cfg` also affects the train/eval helpers, which read cfg.forecast_seq_length.
cfg_final_dict = copy.deepcopy(cfg_dict_base)
cfg_final_dict["hidden_size"] = best_hp["hidden_size"]
cfg_final_dict["output_dropout"] = best_hp["output_dropout"]
cfg = Config(cfg_final_dict)

final_model = HandoffForecastLSTM(cfg).to(device)
final_opt = torch.optim.Adam(final_model.parameters(), lr=best_hp["lr"], weight_decay=best_hp["weight_decay"])

# Early-stopping state: best validation loss so far, a copy of the weights
# that achieved it, and the number of consecutive epochs without improvement.
best_val = float("inf")
best_state = None
bad_epochs = 0

history = {"epoch": [], "train_loss": [], "val_loss": []}

for epoch in range(1, MAX_EPOCHS + 1):
    train_loss = train_one_epoch(final_model, train_loader, final_opt, train_loss_fn, device, epoch)
    val_loss = eval_one_epoch(final_model, val_loader, val_loss_fn, device)

    history["epoch"].append(epoch)
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)

    print(f"[Epoch {epoch:03d}] train_loss={train_loss:.6f}  val_loss={val_loss:.6f}")

    # Count an epoch as an improvement only if it beats the best loss by more
    # than MIN_DELTA.
    if val_loss < best_val - MIN_DELTA:
        best_val = val_loss
        # Deep copy, because state_dict() returns references to the live
        # parameters that later epochs keep updating.
        best_state = copy.deepcopy(final_model.state_dict())
        bad_epochs = 0
    else:
        bad_epochs += 1
        if bad_epochs >= PATIENCE:
            print(f"Early stopping at epoch {epoch:03d} (best val_loss={best_val:.6f})")
            break

# load best checkpoint
if best_state is not None:
    final_model.load_state_dict(best_state)

# Training and validation loss per epoch, saved as an image and displayed.
plt.figure()
plt.plot(history["epoch"], history["train_loss"], label="train")
plt.plot(history["epoch"], history["val_loss"], label="val")
plt.title("Loss Plot")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.savefig("/kaggle/working/loss.png", dpi=200)
plt.show()

# Evaluate on test set
# (basins_test is rebuilt here with the same expression as at the top of the cell.)
basins_test = {b: basins_all[b] for b in test_ids}
basins = basins_test
test_ds = HydroWindowDataset(basins, SEQ_H, SEQ_F)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                         collate_fn=make_collate_fn(basins_test, stats_train))


# The target column was log1p-transformed during preprocessing, so the values
# returned here are in log1p space.
basin_id, y_true_log, y_pred_log = predict_all(final_model, test_loader, device)

# metrics in log space
m_log = regression_metrics(y_true_log, y_pred_log)

# invert log1p back to original discharge units
y_true_q = np.expm1(y_true_log)
y_pred_q = np.expm1(y_pred_log)
# Negative values after the inversion are clipped to zero.
y_true_q = np.clip(y_true_q, 0, None)
y_pred_q = np.clip(y_pred_q, 0, None)

m_q = regression_metrics(y_true_q, y_pred_q)
print("\nTEST metrics (log1p space):", m_log)
print("TEST metrics (original q):", m_q)

In [None]:
# Persist the trained model together with everything needed to reuse it.
os.makedirs("outputs", exist_ok=True)
save_path = "outputs/handoff_lstm_best_cv.pt"

# The model was trained on inputs normalised with `stats_train`; without these
# values the weights cannot be applied to new data. Each pandas Series is
# stored as a plain {column: value} dict so the file holds no pandas objects.
norm_stats = {name: series.to_dict() for name, series in stats_train.items()}

torch.save(
    {
        "state_dict": final_model.state_dict(),
        "cfg": cfg.as_dict(),
        "best_hp": best_hp,
        # Epoch budget of the run (MAX_EPOCHS), not the early-stopped epoch count.
        "final_epochs": final_epochs,
        "test_metrics_log": m_log,
        "test_metrics_q": m_q,
        "test_basins": test_ids,
        "norm_stats": norm_stats,
    },
    save_path
)

print("Saved model to:", save_path)