## GRU (PyTorch) — sequences from fetch singlekeys (20 Locations)

**Yêu cầu Dataset:**
- Chạy notebook `fetch-demo-data-singlekeys.ipynb` trước (đã fetch 20 tỉnh/thành)
- Upload output thành Kaggle Dataset
- Add dataset vào notebook này

**Dữ liệu format:**
- Files theo **location_id** (UUID), không theo tên location
- Ví dụ: `{location_id}_train_2021_2023_seq_multi_lag49_h100.npz`
- Metadata: `weather_20loc/data/meta/locations.json` chứa mapping location_id -> name

**Config:**
- LAG = 49h lookback
- HORIZON = 100h forecast (~4 ngày)  
- 20 locations thay vì 34/63
- Batch size = 96, Hidden = 192 (tối ưu cho T4 GPU)
- AMP (Mixed Precision) để tăng tốc

**Features:**
- Accelerator: gpu / cpu
- Auto-detect paths: tự dò SEQ_DIR trong `/kaggle/input/<dataset>/weather_20loc/data/sequences`
- Report: bao gồm location_id column trong CSV

In [None]:
# ============================================================
# train_gru_weather_v3_full.py
# GRU multi-target weather + precip 2-step (event+amount)
# v3: streaming (no concat), true log1p amount, pos_weight BCE, clean AMP API
# v3.1: location_id tracking from metadata
# 20 locations, LAG=49, HORIZON=100
# ============================================================

import os, json, gc, time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd

# ============================================================
# 0) ONE-LINE SWITCH
# ============================================================
ACCELERATOR = "gpu"   # "gpu" | "cpu"; train_gpu_cpu() uses CUDA only when this is "gpu" AND CUDA is available
# ============================================================

# ============================================================
# 1) CONFIG (OPTIMIZED for 20 locations + Kaggle Free Tier)
# ============================================================
@dataclass
class CFG:
    # === DATA PATHS (auto-detected) ===
    base_dir: str = "/kaggle/input/general-demo-data/weather_20loc"
    seq_dir_rel: str = "data/sequences"
    meta_dir_rel: str = "data/meta"

    lag: int = 49           # 49h lookback
    horizon: int = 100      # 100h forecast (~4 days)

    # === LOCATION BATCHING ===
    start_loc_idx: int = 0
    end_loc_idx: int = -1   # -1 = all remaining

    # model (OPTIMIZED: smaller for Kaggle)
    hidden: int = 192       # Reduced from 256
    layers: int = 2
    dropout: float = 0.15

    # train (OPTIMIZED)
    epochs: int = 20        # Reduced from 25
    batch_size: int = 96    # Optimized for T4 GPU
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0
    patience: int = 5

    # precip
    rain_thr_mm: float = 0.1
    amount_space: str = "log1p"
    precip_mm_max: float = 200.0
    loss_cls_w: float = 0.35
    loss_reg_w: float = 0.65
    pos_weight_cap: float = 20.0

    # weights for other targets
    w_temp: float = 1.0
    w_rh: float = 0.5
    w_press: float = 0.2
    w_cloud: float = 0.5
    w_wind: float = 0.7

    # TPU (disabled)
    tpu_cores: int = 8

    # report
    run_report: bool = True
    p_thr_cand: Tuple[float, ...] = tuple(np.round(np.linspace(0.05, 0.95, 19), 2).tolist())
    bins: Tuple[Tuple[int, int], ...] = ((1,24),(25,48),(49,72),(73,100))

    # output
    out_dir: str = "/kaggle/working/gru_weather_v3_out"

# Instantiate the global config and create the output folders up front
# (checkpoint + scaler go to OUT_DIR, CSV reports to OUT_DIR/reports).
cfg = CFG()
OUT_DIR = Path(cfg.out_dir)
OUT_DIR.mkdir(parents=True, exist_ok=True)
REPORT_DIR = OUT_DIR / "reports"
REPORT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================
# 2) HELPERS
# ============================================================
def is_tpu() -> bool:
    """Return True when the ACCELERATOR switch is "tpu" (case-insensitive)."""
    selected = ACCELERATOR.lower()
    return selected == "tpu"
def is_gpu() -> bool:
    """Return True when the ACCELERATOR switch is "gpu" (case-insensitive)."""
    selected = ACCELERATOR.lower()
    return selected == "gpu"

def seed_all(seed=42):
    """Seed Python's `random`, NumPy and (when installed) PyTorch RNGs.

    Args:
        seed: Integer seed applied to every generator.

    Only a missing torch installation is tolerated; any other failure while
    seeding torch is allowed to surface instead of being silently ignored.
    """
    import random
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        # torch is optional at this point of the script; nothing more to seed.
        return
    torch.manual_seed(seed)

seed_all(42)  # seed once at import time, before any file shuffling or model init

def find_base_dir() -> Path:
    """Find the base directory containing data/sequences and data/meta.

    Search order:
      1. ``cfg.base_dir`` if ``<base_dir>/<cfg.seq_dir_rel>`` exists.
      2. ``/kaggle/input/**/<known dataset folder>/data/sequences``.
      3. Any ``/kaggle/input/**/data/sequences``.
      4. Any sequence ``.npz`` matching the configured lag/horizon; the base
         is taken to be three levels above the file (``base/data/sequences/x.npz``).

    Returns:
        Path of the dataset root.

    Raises:
        FileNotFoundError: if none of the strategies finds a candidate.
    """
    configured = Path(cfg.base_dir)
    if (configured / cfg.seq_dir_rel).exists():
        return configured

    root = Path("/kaggle/input")

    def _first_hit(glob_pattern: str):
        # rglob yields entries in filesystem order, which is not stable across
        # mounts; taking the smallest path makes the choice reproducible when
        # several datasets match.
        return min(root.rglob(glob_pattern), default=None)

    # Try weather_20loc first, then others
    for folder in ["weather_20loc", "weather_34loc", "weather_63loc", "weather_4loc"]:
        hit = _first_hit(f"{folder}/data/sequences")
        if hit is not None:
            return hit.parent.parent

    hit = _first_hit("data/sequences")
    if hit is not None:
        return hit.parent.parent

    hit = _first_hit(f"*_seq_multi_lag{cfg.lag}_h{cfg.horizon}.npz")
    if hit is not None:
        return hit.parent.parent.parent

    raise FileNotFoundError("Cannot find base dir in /kaggle/input")

# Resolve the dataset root once; every later path is derived from it.
BASE_DIR = find_base_dir()
SEQ_DIR = BASE_DIR / cfg.seq_dir_rel     # per-location sequence .npz files
META_DIR = BASE_DIR / cfg.meta_dir_rel   # locations.json
print("[BASE_DIR]", BASE_DIR)
print("[SEQ_DIR]", SEQ_DIR)
print("[META_DIR]", META_DIR)

# ============================================================
# 2.1) LOCATION ID LOADING
# ============================================================
def load_location_ids() -> Tuple[List[str], Dict[str, str]]:
    """Load location_ids from metadata file, or fallback to scanning files.

    Returns:
        ``(loc_ids, loc_names)`` where ``loc_ids`` is the ordered list of
        location ids and ``loc_names`` maps id -> display name. In the
        file-scan fallback the "name" is the first 8 characters of the id.
    """
    meta_file = META_DIR / "locations.json"
    if meta_file.exists():
        # Explicit UTF-8: location names are not guaranteed to be ASCII and the
        # platform default encoding must not decide how they are decoded.
        with open(meta_file, "r", encoding="utf-8") as f:
            meta = json.load(f)
        loc_names = {loc["location_id"]: loc["name"] for loc in meta.get("locations", [])}
        # Prefer the explicit id list; if it is missing or empty, fall back to
        # the ids of the "locations" entries instead of ending up with none.
        loc_ids = list(meta.get("location_ids") or loc_names)
        print(f"[meta] Loaded {len(loc_ids)} location_ids from {meta_file.name}")
        return loc_ids, loc_names

    # Fallback: scan files
    pattern = f"*_train_2021_2023_seq_multi_lag{cfg.lag}_h{cfg.horizon}.npz"
    prefixes = (fp.name.split("_train_")[0] for fp in sorted(SEQ_DIR.glob(pattern)))
    # dict.fromkeys drops duplicates while keeping first-seen (sorted) order.
    loc_ids = [lid for lid in dict.fromkeys(prefixes) if lid]
    print(f"[fallback] Found {len(loc_ids)} location_ids from file scan")
    return loc_ids, {lid: lid[:8] for lid in loc_ids}

# All known locations plus the id -> name mapping.
LOCATION_IDS_ALL, LOC_NAMES = load_location_ids()

# === LOCATION BATCHING ===
# Select the half-open slice [start_loc_idx:end_loc_idx]; a negative end index
# means "through the last location".
_start = cfg.start_loc_idx
_end = cfg.end_loc_idx if cfg.end_loc_idx >= 0 else len(LOCATION_IDS_ALL)
LOCATION_IDS = LOCATION_IDS_ALL[_start:_end]
print(f"[LOCATION BATCH] Using {len(LOCATION_IDS)}/{len(LOCATION_IDS_ALL)} locations (idx {_start}:{_end})")
for lid in LOCATION_IDS:
    # '?' is printed when the metadata has no name for this id.
    print(f"  {lid} -> {LOC_NAMES.get(lid, '?')}")

def list_npz(split_key: str) -> List[Path]:
    """Return the sorted sequence files of one split for the selected locations.

    Files are named ``{location_id}_{split_key}_seq_multi_lag{lag}_h{horizon}.npz``.
    Only files whose location id is in ``LOCATION_IDS`` are returned, so the
    location-batching settings actually restrict the data being used. If no
    file matches any selected id (e.g. metadata and file names disagree), a
    warning is printed and every file of the split is returned.

    Args:
        split_key: Split tag embedded in the file name, e.g. ``"val_2024"``.

    Raises:
        FileNotFoundError: if the split has no files at all in ``SEQ_DIR``.
    """
    suffix = f"_{split_key}_seq_multi_lag{cfg.lag}_h{cfg.horizon}.npz"
    files = sorted(SEQ_DIR.glob(f"*{suffix}"))
    if not files:
        raise FileNotFoundError(f"No npz for split={split_key} in {SEQ_DIR}")

    wanted = {f"{lid}{suffix}" for lid in LOCATION_IDS}
    selected = [fp for fp in files if fp.name in wanted]
    if selected:
        return selected

    if wanted:
        print(f"[WARN] no {split_key} file matches the selected location_ids; using all {len(files)} files")
    return files

def get_loc_id_from_file(fp: Path) -> str:
    """Return the location id encoded at the start of a sequence file name.

    The id is everything before the first known ``_<split>_`` marker; when no
    marker is present, it is the text before the first underscore.
    """
    filename = fp.name
    known_splits = ("train_2021_2023", "val_2024", "test_2025_01_to_2025_11")
    for split_tag in known_splits:
        prefix, found, _rest = filename.partition(f"_{split_tag}_")
        if found:
            return prefix
    return filename.partition("_")[0]

# Per-split file lists; list_npz raises FileNotFoundError if a split is missing.
TRAIN_FILES = list_npz("train_2021_2023")
VAL_FILES   = list_npz("val_2024")
TEST_FILES  = list_npz("test_2025_01_to_2025_11")
print("[files] train:", len(TRAIN_FILES), "val:", len(VAL_FILES), "test:", len(TEST_FILES))

def load_npz(fp: Path):
    """Load one sequence archive and return ``(X, Y, T, meta)``.

    Args:
        fp: Path of an ``.npz`` file holding the arrays ``X``, ``Y``, ``T``
            and a JSON string under ``meta``.

    Returns:
        X and Y as float32 arrays (no copy if already float32), T unchanged,
        and ``meta`` decoded from JSON.
    """
    # np.load returns a lazily-read NpzFile that keeps the archive open; the
    # context manager closes it once the arrays have been materialised, so
    # streaming over many files and epochs does not leak file handles.
    with np.load(fp, allow_pickle=False) as z:
        X = z["X"].astype(np.float32, copy=False)
        Y = z["Y"].astype(np.float32, copy=False)
        T = z["T"]
        meta = json.loads(z["meta"].item())
    return X, Y, T, meta

# Read one training file up front to learn the column names and array widths.
X0, Y0, T0, meta = load_npz(TRAIN_FILES[0])
# Column lists are accepted under either key spelling.
input_cols = meta.get("input_cols") or meta.get("x_cols")
target_cols = meta.get("target_cols") or meta.get("y_cols")
if input_cols is None or target_cols is None:
    raise KeyError(f"Meta missing input/target cols. Keys={list(meta.keys())}")
# Target name -> index on the last axis of Y.
IDX = {c:i for i,c in enumerate(target_cols)}
if "precipitation" not in IDX:
    raise KeyError(f"'precipitation' not found in meta['target_cols'] = {target_cols}")
PRECIP_IDX = IDX["precipitation"]
# Feature/target counts are taken from the last axis of X and Y.
n_feat = X0.shape[-1]
n_tgt  = Y0.shape[-1]
# Only the metadata is needed from here on; drop the arrays again.
del X0, Y0, T0
gc.collect()

print("[meta] n_feat:", n_feat, "n_tgt:", n_tgt)
print("[targets]", target_cols)

# ============================================================
# 3) SCALER (X only) — computed streaming over train files
# ============================================================
def compute_x_scaler_stream(files: List[Path], n_feat: int) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-feature mean and standard deviation over all files, one file at a time.

    Sums and sums of squares are accumulated in float64 and the variance is
    formed from the float64 mean. (Squaring a mean that was already rounded
    to float32 loses precision for features whose magnitude is large relative
    to their spread.)

    Args:
        files: Sequence archives readable by ``load_npz``.
        n_feat: Size of the last axis of ``X``.

    Returns:
        ``(mu, sd)`` as float32 arrays of length ``n_feat``. The variance is
        floored at 1e-6, so ``sd`` is never smaller than 1e-3.
    """
    total = np.zeros(n_feat, dtype=np.float64)
    total_sq = np.zeros(n_feat, dtype=np.float64)
    count = 0

    for fp in files:
        X, _, _, _ = load_npz(fp)
        flat = X.reshape(-1, n_feat).astype(np.float64, copy=False)
        total += flat.sum(axis=0)
        total_sq += (flat * flat).sum(axis=0)
        count += flat.shape[0]
        # Free the per-file arrays before loading the next file.
        del X, flat
        gc.collect()

    denom = max(count, 1)
    mean64 = total / denom
    # E[x^2] - E[x]^2, floored so constant features do not divide by ~0.
    var = np.maximum(total_sq / denom - mean64 ** 2, 1e-6)
    mu = mean64.astype(np.float32)
    sd = np.sqrt(var).astype(np.float32)
    return mu, sd

X_mu, X_sd = compute_x_scaler_stream(TRAIN_FILES, n_feat)
# Persist the scaler next to the checkpoint. Column names are stored as a
# unicode array rather than dtype=object so the file needs no pickling and can
# be read back with np.load(..., allow_pickle=False).
np.savez(OUT_DIR / "x_scaler.npz",
         mu=X_mu, sd=X_sd, input_cols=np.array(input_cols, dtype=str))

def scale_x(X: np.ndarray) -> np.ndarray:
    """Standardise X along its last axis with the global X_mu / X_sd; returns float32."""
    centered = X - X_mu[None, None, :]
    standardized = centered / X_sd[None, None, :]
    return standardized.astype(np.float32, copy=False)

# ============================================================
# 4) POS_WEIGHT for event (streaming over train files)
# ============================================================
def compute_pos_weight_stream(files: List[Path], precip_idx: int, thr_mm: float, cap: float) -> float:
    """Return the capped negatives/positives ratio of the rain event over all files.

    A target value counts as positive when ``Y[..., precip_idx] >= thr_mm``.
    Files are processed one by one; the ratio is limited to at most ``cap``.
    """
    n_pos = 0.0
    n_total = 0.0
    for path in files:
        Y = load_npz(path)[1]
        precip = Y[..., precip_idx]
        n_pos += float((precip >= thr_mm).sum())
        n_total += float(precip.size)
        # Release the file's arrays before moving on.
        del Y, precip
        gc.collect()

    n_neg = n_total - n_pos
    ratio = n_neg / (n_pos + 1e-9)  # epsilon avoids division by zero when nothing is positive
    return float(min(ratio, cap))

# Class-imbalance weight for the rain-event BCE loss, computed on the train split only.
POS_WEIGHT = compute_pos_weight_stream(TRAIN_FILES, PRECIP_IDX, cfg.rain_thr_mm, cfg.pos_weight_cap)
print("[pos_weight]", POS_WEIGHT)

# ============================================================
# 5) TORCH + MODEL
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# With CUDA present: allow TF32 for matmul and cuDNN, and turn on cuDNN's
# benchmark mode (auto-selection of kernels).
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

class ArrayDataset(Dataset):
    """Map-style dataset over two NumPy arrays indexed along their first axis."""

    def __init__(self, X: np.ndarray, Y: np.ndarray):
        self.X, self.Y = X, Y

    def __len__(self):
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, i):
        # from_numpy wraps the NumPy row without copying it.
        sample_x = torch.from_numpy(self.X[i])
        sample_y = torch.from_numpy(self.Y[i])
        return sample_x, sample_y

def precip_log_to_mm(pred_log: torch.Tensor, mm_max: float) -> torch.Tensor:
    """Invert log1p (via expm1) and clamp the result to the range [0, mm_max]."""
    upper = float(mm_max)
    return torch.clamp(torch.expm1(pred_log), min=0.0, max=upper)

class GRUWeather2Step(nn.Module):
    """GRU encoder with two linear heads: multi-target regression and rain-event logits.

    Args:
        n_feat: Number of input features per time step.
        n_tgt: Number of regression targets per forecast step.
        horizon: Number of forecast steps.
        hidden: GRU hidden size.
        layers: Number of stacked GRU layers.
        dropout: Inter-layer GRU dropout (ignored when ``layers == 1``).
        precip_idx: Index of the precipitation column among the targets.
        amount_space: ``"mm"`` passes the precipitation column through
            softplus (non-negative output); any other value leaves the raw
            head output untouched.
    """

    def __init__(self, n_feat, n_tgt, horizon, hidden, layers, dropout, precip_idx, amount_space="log1p"):
        super().__init__()
        self.horizon = horizon
        self.n_tgt = n_tgt
        self.precip_idx = precip_idx
        self.amount_space = amount_space

        self.gru = nn.GRU(
            input_size=n_feat,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            dropout=(dropout if layers > 1 else 0.0),
        )
        self.head_main = nn.Linear(hidden, horizon * n_tgt)
        self.head_rain = nn.Linear(hidden, horizon)

    def forward(self, x):
        """Return ``(y, rain_logit)`` with shapes (B, horizon, n_tgt) and (B, horizon)."""
        _, h = self.gru(x)
        h_last = h[-1]  # final hidden state of the top GRU layer
        y = self.head_main(h_last).view(-1, self.horizon, self.n_tgt)
        rain_logit = self.head_rain(h_last).view(-1, self.horizon)

        if self.amount_space == "mm":
            # Build the result out-of-place. Writing softplus(y[..., i]) back
            # into y[..., i] overwrites the tensor softplus saved for its
            # backward pass, which makes autograd raise during training.
            col = self.precip_idx % self.n_tgt  # also accepts a negative index
            is_precip = torch.arange(self.n_tgt, device=y.device) == col
            y = torch.where(is_precip, F.softplus(y).to(y.dtype), y)
        return y, rain_logit

def build_optimizer(model: nn.Module):
    """Create an AdamW optimizer over all model parameters using cfg.lr / cfg.weight_decay."""
    trainable = model.parameters()
    return torch.optim.AdamW(trainable, lr=cfg.lr, weight_decay=cfg.weight_decay)

def safe_set_weight(w: torch.Tensor, name: str, value: float):
    """Write `value` into w[..., IDX[name]] in place; do nothing if `name` is not a target."""
    column = IDX.get(name)
    if column is None:
        return
    w[..., column] = float(value)

def loss_fn(y_pred, rain_logit, y_true, pos_weight_value: float):
    """Combine three terms into one scalar training loss.

    * other targets: per-target weighted smooth-L1 averaged over all elements
      (the precipitation column has weight 0 in this term);
    * rain event: BCE-with-logits on ``y_true[..., PRECIP_IDX] >= cfg.rain_thr_mm``
      using ``pos_weight_value``;
    * rain amount: smooth-L1 on rainy positions only, in mm or log1p space
      depending on ``cfg.amount_space``.

    Returns ``loss_other + cfg.loss_cls_w * loss_cls + cfg.loss_reg_w * loss_reg``.
    """
    # --- non-precipitation targets ---
    elementwise = F.smooth_l1_loss(y_pred, y_true, reduction="none")
    w = torch.ones((1, 1, y_true.shape[-1]), device=y_true.device, dtype=y_true.dtype)
    per_target_weights = (
        ("temperature_2m", cfg.w_temp),
        ("relative_humidity_2m", cfg.w_rh),
        ("surface_pressure", cfg.w_press),
        ("cloud_cover", cfg.w_cloud),
        ("u10", cfg.w_wind),
        ("v10", cfg.w_wind),
    )
    for target_name, target_weight in per_target_weights:
        safe_set_weight(w, target_name, target_weight)
    w[..., PRECIP_IDX] = 0.0
    loss_other = (elementwise * w).mean()

    # --- rain event (classification) ---
    true_mm = y_true[..., PRECIP_IDX]
    is_rain = (true_mm >= cfg.rain_thr_mm).float()
    pw = torch.tensor(pos_weight_value, device=y_true.device, dtype=y_true.dtype)
    loss_cls = F.binary_cross_entropy_with_logits(rain_logit, is_rain, pos_weight=pw)

    # --- rain amount (regression on rainy positions only) ---
    rainy = is_rain > 0.5
    pred_amount = y_pred[..., PRECIP_IDX]
    if cfg.amount_space == "mm":
        true_amount = true_mm
    else:
        true_amount = torch.log1p(true_mm.clamp(min=0.0))

    if rainy.any():
        loss_reg = F.smooth_l1_loss(pred_amount[rainy], true_amount[rainy])
    else:
        # Zero-valued term that is still connected to the graph.
        loss_reg = pred_amount.mean() * 0.0

    return loss_other + cfg.loss_cls_w * loss_cls + cfg.loss_reg_w * loss_reg

# ============================================================
# 6) TRAIN — streaming over files (no concat)
# ============================================================
def train_gpu_cpu():
    """Train the model by streaming over the train files and return the checkpoint path.

    Each epoch visits the train files in a fresh random order, builds a
    DataLoader per file, and then evaluates the loss on every validation file.
    The checkpoint is rewritten whenever the validation loss improves by more
    than 1e-6; training stops after ``cfg.patience`` epochs without
    improvement. Mixed precision (float16 autocast + GradScaler) is used only
    when the device is CUDA.

    If no epoch of THIS run improves the validation loss (e.g. it is NaN from
    the start), the final weights are saved anyway, with a warning. Later
    stages then never load a missing file or a stale checkpoint left behind
    by an earlier run.

    Returns:
        POSIX string path of the checkpoint file.
    """
    device = torch.device("cuda" if (is_gpu() and torch.cuda.is_available()) else "cpu")
    print("[device]", device)

    model = GRUWeather2Step(
        n_feat=n_feat, n_tgt=n_tgt, horizon=cfg.horizon,
        hidden=cfg.hidden, layers=cfg.layers, dropout=cfg.dropout,
        precip_idx=PRECIP_IDX, amount_space=cfg.amount_space
    ).to(device)
    opt = build_optimizer(model)

    use_amp = (device.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    best_val = float("inf")
    bad = 0
    best_path = OUT_DIR / "gru_weather2step_best.pt"
    saved_this_run = False

    def _batch_loss(xb, yb):
        # Forward pass + loss, under float16 autocast when AMP is active.
        if use_amp:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                yp, rl = model(xb)
                return loss_fn(yp, rl, yb, POS_WEIGHT)
        yp, rl = model(xb)
        return loss_fn(yp, rl, yb, POS_WEIGHT)

    def _save_checkpoint():
        # Weights plus everything needed to rebuild the model and the input scaling.
        ckpt_meta = dict(meta)
        ckpt_meta.update({
            "n_feat": n_feat,
            "n_tgt": n_tgt,
            "precip_idx": PRECIP_IDX,
            "input_cols": input_cols,
            "target_cols": target_cols,
        })
        torch.save(
            {"model": model.state_dict(), "cfg": cfg.__dict__, "meta": ckpt_meta,
             "pos_weight": POS_WEIGHT, "x_scaler": {"mu": X_mu.tolist(), "sd": X_sd.tolist()}},
            best_path
        )

    for ep in range(1, cfg.epochs + 1):
        # ---------------- train ----------------
        model.train()
        tr_sum = 0.0
        tr_n = 0

        files = TRAIN_FILES.copy()
        np.random.shuffle(files)

        for fp in files:
            X, Y, _, _ = load_npz(fp)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(
                ds, batch_size=cfg.batch_size, shuffle=True,
                num_workers=(2 if device.type == "cuda" else 0),
                pin_memory=(device.type == "cuda"), drop_last=True
            )

            for xb, yb in loader:
                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)

                opt.zero_grad(set_to_none=True)

                if use_amp:
                    loss = _batch_loss(xb, yb)
                    scaler.scale(loss).backward()
                    if cfg.grad_clip and cfg.grad_clip > 0:
                        # Gradients must be unscaled before their norm is clipped.
                        scaler.unscale_(opt)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                    scaler.step(opt)
                    scaler.update()
                else:
                    loss = _batch_loss(xb, yb)
                    loss.backward()
                    if cfg.grad_clip and cfg.grad_clip > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                    opt.step()

                # Sample-weighted running sum so the epoch loss is a per-sample mean.
                tr_sum += float(loss.item()) * xb.size(0)
                tr_n += xb.size(0)

            del X, Y, ds, loader
            gc.collect()

        tr_loss = tr_sum / max(tr_n, 1)

        # ---------------- validate ----------------
        model.eval()
        va_sum = 0.0
        va_n = 0
        with torch.no_grad():
            for fp in VAL_FILES:
                X, Y, _, _ = load_npz(fp)
                X = scale_x(X)
                ds = ArrayDataset(X, Y)
                loader = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0, drop_last=False)

                for xb, yb in loader:
                    xb = xb.to(device, non_blocking=True)
                    yb = yb.to(device, non_blocking=True)

                    loss = _batch_loss(xb, yb)

                    va_sum += float(loss.item()) * xb.size(0)
                    va_n += xb.size(0)

                del X, Y, ds, loader
                gc.collect()

        va_loss = va_sum / max(va_n, 1)

        print(f"epoch {ep:02d} | train {tr_loss:.6f} | val {va_loss:.6f}")

        # ---------------- checkpoint / early stopping ----------------
        if va_loss < best_val - 1e-6:
            best_val = va_loss
            bad = 0
            _save_checkpoint()
            saved_this_run = True
        else:
            bad += 1
            if bad >= cfg.patience:
                print("[early stop]")
                break

    if not saved_this_run:
        # A NaN validation loss never compares as "better", so no checkpoint
        # would exist for this run; persist the final weights instead.
        print("[WARN] validation loss never improved; saving final weights to", best_path)
        _save_checkpoint()

    print("[best val]", best_val, "->", best_path)
    return best_path.as_posix()

# ============================================================
# 7) REPORT — streaming (no concat), tune P_THR on VAL, export CSV
# ============================================================
def event_metrics_counts(y_true01: np.ndarray, y_pred01: np.ndarray):
    """Return (tp, fp, fn) as Python ints for two 0/1 arrays of equal shape."""
    is_event = (y_true01 == 1)
    no_event = (y_true01 == 0)
    said_yes = (y_pred01 == 1)
    said_no = (y_pred01 == 0)

    hits = int((is_event & said_yes).sum())
    false_alarms = int((no_event & said_yes).sum())
    misses = int((is_event & said_no).sum())
    return hits, false_alarms, misses

def tune_pthr_on_val(model, device, amount_space: str) -> Tuple[float, pd.DataFrame]:
    """Pick the rain-probability threshold with the best F1 over the validation files.

    Every candidate in ``cfg.p_thr_cand`` is scored on all validation samples
    and horizons pooled together. ``amount_space`` is accepted for signature
    parity with the other report helpers and is not used here.

    Returns:
        ``(best_threshold, ranking)`` where ``ranking`` lists precision,
        recall, F1 and the raw counts per candidate, sorted by F1 descending.
    """
    cand = np.array(cfg.p_thr_cand, dtype=np.float32)
    n_cand = len(cand)
    tp_k = np.zeros(n_cand, dtype=np.int64)
    fp_k = np.zeros(n_cand, dtype=np.int64)
    fn_k = np.zeros(n_cand, dtype=np.int64)

    model.eval()
    with torch.no_grad():
        for val_path in VAL_FILES:
            X, Y, _, _ = load_npz(val_path)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=0, drop_last=False)

            for xb, yb in loader:
                xb = xb.to(device)
                yb = yb.to(device)

                _, logit = model(xb)
                prob = torch.sigmoid(logit).detach().cpu().numpy().astype(np.float32).reshape(-1)
                truth_mm = yb[..., PRECIP_IDX].detach().cpu().numpy().astype(np.float32)
                truth_ev = (truth_mm >= cfg.rain_thr_mm).reshape(-1)

                # One row per candidate threshold: shape (n_cand, n_points).
                pred_ev = prob[None, :] >= cand[:, None]
                tp_k += (pred_ev & truth_ev).sum(axis=1)
                fp_k += (pred_ev & ~truth_ev).sum(axis=1)
                fn_k += (~pred_ev & truth_ev).sum(axis=1)

            del X, Y, ds, loader
            gc.collect()

    # Epsilons keep the ratios defined when a count is zero.
    prec = tp_k / (tp_k + fp_k + 1e-9)
    rec = tp_k / (tp_k + fn_k + 1e-9)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)

    best_thr = float(cand[int(np.argmax(f1))])

    table = pd.DataFrame({
        "p_thr": cand,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "tp": tp_k, "fp": fp_k, "fn": fn_k
    })
    rank = table.sort_values("f1", ascending=False).reset_index(drop=True)

    return best_thr, rank

def accumulate_precip_stats(model, device, files: List[Path], p_thr: float, split_name: str, amount_space: str):
    """Compute per-horizon precipitation metrics over `files`, one file at a time.

    Event metrics compare ``sigmoid(logit) >= p_thr`` with
    ``truth >= cfg.rain_thr_mm``. The "hard" amount forecast is the predicted
    mm where an event is predicted and 0 elsewhere; its MAE/RMSE are taken
    over all samples.

    Returns:
        DataFrame with one row per horizon step (1..cfg.horizon).
    """
    H = cfg.horizon
    tp_h = np.zeros(H, dtype=np.int64)
    fp_h = np.zeros(H, dtype=np.int64)
    fn_h = np.zeros(H, dtype=np.int64)
    npos_h = np.zeros(H, dtype=np.int64)
    abs_err_h = np.zeros(H, dtype=np.float64)
    sq_err_h = np.zeros(H, dtype=np.float64)
    n_samples = 0

    model.eval()
    with torch.no_grad():
        for seq_path in files:
            X, Y, _, _ = load_npz(seq_path)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=0, drop_last=False)

            for xb, yb in loader:
                xb = xb.to(device)
                yb = yb.to(device)

                yp, logit = model(xb)

                truth_mm = yb[..., PRECIP_IDX].detach().cpu().numpy().astype(np.float32)
                truth_ev = (truth_mm >= cfg.rain_thr_mm)

                # Bring the amount head back to millimetres.
                raw_amount = yp[..., PRECIP_IDX]
                if amount_space == "mm":
                    amount_mm = raw_amount.clamp(min=0.0, max=float(cfg.precip_mm_max))
                else:
                    amount_mm = precip_log_to_mm(raw_amount, cfg.precip_mm_max)
                amount_mm = amount_mm.detach().cpu().numpy().astype(np.float32)

                prob = torch.sigmoid(logit).detach().cpu().numpy().astype(np.float32)
                pred_ev = (prob >= p_thr)

                # Event confusion counts, summed over the batch axis.
                tp_h += (pred_ev & truth_ev).sum(axis=0).astype(np.int64)
                fp_h += (pred_ev & (~truth_ev)).sum(axis=0).astype(np.int64)
                fn_h += ((~pred_ev) & truth_ev).sum(axis=0).astype(np.int64)
                npos_h += truth_ev.sum(axis=0).astype(np.int64)

                # Gated amount: zero wherever no event is predicted.
                gated = np.where(pred_ev, amount_mm, 0.0).astype(np.float32)
                diff = gated - truth_mm
                abs_err_h += np.abs(diff).sum(axis=0)
                sq_err_h += (diff * diff).sum(axis=0)

                n_samples += truth_mm.shape[0]

            del X, Y, ds, loader
            gc.collect()

    prec = tp_h / (tp_h + fp_h + 1e-9)
    rec = tp_h / (tp_h + fn_h + 1e-9)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)

    denom = max(n_samples, 1)
    mae_hard = abs_err_h / denom
    rmse_hard = np.sqrt(sq_err_h / denom)

    return pd.DataFrame({
        "split": split_name,
        "horizon": np.arange(1, H+1),
        "rain_thr_mm": cfg.rain_thr_mm,
        "p_thr": float(p_thr),
        "event_precision": prec,
        "event_recall": rec,
        "event_f1": f1,
        "tp": tp_h, "fp": fp_h, "fn": fn_h,
        "npos": npos_h,
        "final_mae_all_hard": mae_hard,
        "final_rmse_all_hard": rmse_hard,
    })

def accumulate_other_targets(model, device, files: List[Path], split_name: str, amount_space: str):
    """Compute per-horizon MAE/RMSE for every target except precipitation.

    Files are processed one at a time. ``amount_space`` is accepted for
    signature parity with the precipitation report and is not used here.

    Returns:
        Long-format DataFrame with columns split, target, horizon, mae, rmse.
    """
    H = cfg.horizon
    other_cols = [t for t in range(len(target_cols)) if t != PRECIP_IDX]
    abs_acc = np.zeros((len(other_cols), H), dtype=np.float64)
    sq_acc = np.zeros((len(other_cols), H), dtype=np.float64)
    n_samples = 0

    model.eval()
    with torch.no_grad():
        for seq_path in files:
            X, Y, _, _ = load_npz(seq_path)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=0, drop_last=False)

            for xb, yb in loader:
                xb = xb.to(device)
                yb = yb.to(device)
                yp, _ = model(xb)

                truth = yb.detach().cpu().numpy().astype(np.float32)
                pred = yp.detach().cpu().numpy().astype(np.float32)

                for row, col in enumerate(other_cols):
                    # Errors are promoted to float64 before accumulation.
                    err = (pred[:, :, col] - truth[:, :, col]).astype(np.float64)
                    abs_acc[row] += np.abs(err).sum(axis=0)
                    sq_acc[row] += (err * err).sum(axis=0)

                n_samples += truth.shape[0]

            del X, Y, ds, loader
            gc.collect()

    denom = max(n_samples, 1)
    rows = []
    for row, col in enumerate(other_cols):
        target_name = target_cols[col]
        mae = (abs_acc[row] / denom).astype(np.float64)
        rmse = np.sqrt(sq_acc[row] / denom).astype(np.float64)
        rows.extend(
            {
                "split": split_name,
                "target": target_name,
                "horizon": step + 1,
                "mae": float(mae[step]),
                "rmse": float(rmse[step]),
            }
            for step in range(H)
        )
    return pd.DataFrame(rows)

def summarize_bins_precip(rep_split: pd.DataFrame):
    """Average the per-horizon precipitation metrics inside each cfg.bins range.

    Bins with no matching horizon rows are skipped; ``npos`` is summed rather
    than averaged.
    """
    summary = []
    for lo, hi in cfg.bins:
        in_bin = rep_split[rep_split["horizon"].between(lo, hi)]  # inclusive on both ends
        if in_bin.empty:
            continue
        summary.append({
            "split": in_bin["split"].iloc[0],
            "horizon_bin": f"{lo}-{hi}",
            "event_f1_mean": float(in_bin["event_f1"].mean()),
            "event_recall_mean": float(in_bin["event_recall"].mean()),
            "event_precision_mean": float(in_bin["event_precision"].mean()),
            "mae_all_hard_mean": float(in_bin["final_mae_all_hard"].mean()),
            "rmse_all_hard_mean": float(in_bin["final_rmse_all_hard"].mean()),
            "npos_sum": int(in_bin["npos"].sum()),
        })
    return pd.DataFrame(summary)

def summarize_bins_other(rep_long: pd.DataFrame):
    """Average MAE/RMSE per (split, target) inside each cfg.bins horizon range.

    Bins with no matching horizon rows are skipped.
    """
    summary = []
    for (split, target), group in rep_long.groupby(["split", "target"]):
        for lo, hi in cfg.bins:
            in_bin = group[group["horizon"].between(lo, hi)]  # inclusive on both ends
            if in_bin.empty:
                continue
            summary.append({
                "split": split,
                "target": target,
                "horizon_bin": f"{lo}-{hi}",
                "mae_mean": float(in_bin["mae"].mean()),
                "rmse_mean": float(in_bin["rmse"].mean()),
            })
    return pd.DataFrame(summary)

def run_report(best_ckpt_path: Path):
    """Rebuild the model from a checkpoint and write all validation/test CSV reports.

    Steps: tune the rain-probability threshold on the validation files, then
    export per-horizon and binned metrics for precipitation and for the other
    targets into REPORT_DIR.
    """
    ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
    saved_cfg = ckpt.get("cfg", {})
    amount_space = str(saved_cfg.get("amount_space", cfg.amount_space))

    # Architecture comes from the checkpoint, falling back to the live config.
    arch = {
        "hidden": int(saved_cfg.get("hidden", cfg.hidden)),
        "layers": int(saved_cfg.get("layers", cfg.layers)),
        "dropout": float(saved_cfg.get("dropout", cfg.dropout)),
    }
    model = GRUWeather2Step(
        n_feat=n_feat, n_tgt=n_tgt, horizon=cfg.horizon,
        precip_idx=PRECIP_IDX, amount_space=amount_space,
        **arch
    )
    model.load_state_dict(ckpt["model"], strict=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # 1) threshold tuning on VAL
    best_p_thr, thr_rank = tune_pthr_on_val(model, device, amount_space)
    thr_rank.to_csv(REPORT_DIR / "precip_thr_tuning_val.csv", index=False)
    print("[REPORT] best_p_thr (VAL F1) =", best_p_thr)

    splits = (("val", VAL_FILES), ("test", TEST_FILES))

    # 2) precipitation: per-horizon report + binned summary
    precip_frames = [
        accumulate_precip_stats(model, device, split_files, best_p_thr, split_name, amount_space)
        for split_name, split_files in splits
    ]
    rep_p = pd.concat(precip_frames, ignore_index=True)
    rep_p.to_csv(REPORT_DIR / f"gru_precip_report_val_test_h1-{cfg.horizon}.csv", index=False)

    sum_p = pd.concat([summarize_bins_precip(frame) for frame in precip_frames], ignore_index=True)
    sum_p["p_thr"] = best_p_thr
    sum_p["rain_thr_mm"] = cfg.rain_thr_mm
    sum_p.to_csv(REPORT_DIR / "gru_precip_summary_bins_val_test.csv", index=False)

    # 3) remaining targets: long report + binned summary
    other_frames = [
        accumulate_other_targets(model, device, split_files, split_name, amount_space)
        for split_name, split_files in splits
    ]
    rep_o = pd.concat(other_frames, ignore_index=True)
    rep_o.to_csv(REPORT_DIR / "gru_other_targets_report_val_test_long.csv", index=False)

    sum_o = summarize_bins_other(rep_o)
    sum_o.to_csv(REPORT_DIR / "gru_other_targets_summary_bins_val_test.csv", index=False)

    print("[REPORT] saved in:", REPORT_DIR)

# ============================================================
# 8) MAIN
# ============================================================
def export_traced_gru(best_ckpt_path: Path, out_dir: Path):
    """Rebuild the model from a checkpoint and save a TorchScript trace of it.

    Model dimensions are recovered from the checkpoint itself: lag, horizon
    and architecture from its stored config (falling back to the live ``cfg``
    or the literal defaults below), feature/target counts from the weight
    shapes. A checkpoint trained with a different horizon than the current
    ``cfg`` therefore still yields the correct target count.

    Args:
        best_ckpt_path: Checkpoint written by the training loop.
        out_dir: Directory for the traced module (created if missing).

    Returns:
        Path of the saved ``gru_weather2step_best_traced.pt``.
    """
    best_ckpt_path = Path(best_ckpt_path)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    obj = torch.load(str(best_ckpt_path), map_location="cpu", weights_only=False)
    sd = obj["model"]
    cfg_d = obj.get("cfg", {})
    meta_d = obj.get("meta", {})

    # Prefer the values the checkpoint was trained with over the live config.
    lag = int(cfg_d.get("lag", cfg.lag))
    horizon = int(cfg_d.get("horizon", cfg.horizon))

    # Input width = number of columns of the first GRU input weight matrix.
    if "gru.weight_ih_l0" in sd:
        n_feat_from_sd = sd["gru.weight_ih_l0"].shape[1]
    else:
        n_feat_from_sd = None
    n_feat_e = int(n_feat_from_sd or meta_d.get("n_feat", 12))

    # head_main emits horizon * n_tgt values per sample.
    if "head_main.weight" in sd:
        n_tgt_from_sd = sd["head_main.weight"].shape[0] // horizon
    else:
        n_tgt_from_sd = None
    n_tgt_e = int(n_tgt_from_sd or meta_d.get("n_tgt", 7))

    precip_idx_e = int(meta_d.get("precip_idx", 1))

    hidden = int(cfg_d.get("hidden", 192))
    layers = int(cfg_d.get("layers", 2))
    dropout = float(cfg_d.get("dropout", 0.0))
    amount_space = str(cfg_d.get("amount_space", "log1p"))

    m = GRUWeather2Step(
        n_feat=n_feat_e, n_tgt=n_tgt_e, horizon=horizon,
        hidden=hidden, layers=layers, dropout=dropout,
        precip_idx=precip_idx_e, amount_space=amount_space
    )
    m.load_state_dict(sd, strict=True)
    m.eval()

    # Dummy batch of one sequence; a plain forward pass first so shape
    # problems surface before tracing.
    x = torch.zeros(1, lag, n_feat_e, dtype=torch.float32)
    with torch.no_grad():
        _ = m(x)

    traced = torch.jit.trace(m, x, strict=False)
    ts_path = out_dir / "gru_weather2step_best_traced.pt"
    traced.save(str(ts_path))
    print("[EXPORT] GRU traced saved:", ts_path)
    return ts_path

def main():
    """Run the whole pipeline: train, optional report, best-effort TorchScript export."""
    print("ACCELERATOR:", ACCELERATOR)
    print("torch:", torch.__version__)
    print(f"CONFIG: LAG={cfg.lag}, HORIZON={cfg.horizon}, HIDDEN={cfg.hidden}")

    checkpoint = Path(train_gpu_cpu())

    if cfg.run_report:
        run_report(checkpoint)

    # The export is optional: a failure is reported but does not abort the run.
    try:
        traced_path = export_traced_gru(checkpoint, OUT_DIR)
        reloaded = torch.jit.load(str(traced_path), map_location="cpu")
        reloaded.eval()
        print("[EXPORT][VERIFY] torch.jit.load OK:", traced_path)
    except Exception as e:
        print("[WARN] TorchScript export failed:", e)

    print("DONE. out_dir =", OUT_DIR)


if __name__ == "__main__":
    main()