## TCN (PyTorch) — sequences from fetch singlekeys (20 Locations)

**Yêu cầu Dataset:**
- Chạy notebook `fetch-demo-data-singlekeys.ipynb` trước (đã fetch 20 tỉnh/thành)
- Upload output thành Kaggle Dataset
- Add dataset vào notebook này

**Dữ liệu format:**
- Files theo **location_id** (UUID), không theo tên location
- Ví dụ: `{location_id}_train_2021_2023_seq_multi_lag49_h100.npz`
- Metadata: `weather_20loc/data/meta/locations.json` chứa mapping location_id -> name

**Config:**
- LAG = 49h lookback
- HORIZON = 100h forecast (~4 ngày)
- 20 locations thay vì 34/63
- tcn_channels = 192 (giảm từ 256)
- tcn_levels = 4 (giảm từ 5) — receptive field = 1 + (k−1)·2·(1+2+4+8) = 61 bước (k=3, 2 conv/block), vẫn bao trọn lookback 49h
- batch_size = 96 (tối ưu cho T4 GPU)
- AMP (Mixed Precision) để tăng tốc

**Features:**
- Accelerator: cpu / gpu
- Auto-detect paths: tự dò SEQ_DIR
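- Report: CSV theo split/horizon, gộp tất cả locations. Lưu ý: code hiện tại chưa ghi cột location_id vào CSV (helper `get_loc_id_from_file` đã có sẵn để bổ sung báo cáo theo từng location)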

In [None]:
# ============================================================
# train_tcn_weather_2step_full_v1p1.py
# TCN multi-target weather + precip 2-step (event+amount)
# 20 locations, LAG=49, HORIZON=100
# ============================================================

import os, json, gc, time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd

# ============================================================
# 0) ONE-LINE SWITCH
# ============================================================
# "gpu" requests CUDA for training (used only when torch.cuda.is_available());
# "cpu" forces CPU.
ACCELERATOR = "gpu"   # "gpu" | "cpu"
# ============================================================

# ============================================================
# 1) CONFIG (OPTIMIZED for 20 locations + Kaggle Free Tier)
# ============================================================
@dataclass
class CFG:
    """All run settings for the TCN 2-step weather model in one place.

    The script creates one module-level instance (``cfg``) and reads every
    hyper-parameter, path and split key from it.
    """

    # ---- data locations (find_base_dir falls back to a search when base_dir is absent)
    base_dir: str = "/kaggle/input/general-demo-data/weather_20loc"
    seq_dir_rel: str = "data/sequences"
    meta_dir_rel: str = "data/meta"

    lag: int = 49           # 49h lookback
    horizon: int = 100      # 100h forecast (~4 days)

    # ---- TCN model
    # levels=4, k=3 -> dilations 1, 2, 4, 8 with 2 causal convs per block:
    # receptive_field = 1 + (k - 1) * 2 * (1 + 2 + 4 + 8) = 61 steps >= lag (49)
    tcn_channels: int = 192     # Reduced from 256
    tcn_levels: int = 4         # Reduced from 5 (RF 61 still covers the 49h input)
    kernel_size: int = 3
    dropout: float = 0.15
    use_weight_norm: bool = True
    act: str = "relu"           # "relu" | "gelu" (anything else falls back to ReLU)

    # ---- train (OPTIMIZED)
    epochs: int = 20            # Reduced from 25
    batch_size: int = 96        # Optimized for T4 GPU
    lr: float = 2e-3
    weight_decay: float = 1e-4
    grad_clip: float = 1.0      # max grad-norm; <= 0 disables clipping
    patience: int = 5           # early-stop after this many epochs without val improvement

    # ---- precip (2-step: rain-event classifier + amount regressor)
    rain_thr_mm: float = 0.1        # precip >= this (mm) counts as a rain event
    amount_space: str = "log1p"     # "log1p": amount head regresses log1p(mm) | "mm"
    precip_mm_max: float = 200.0    # upper clamp for predicted amounts in mm
    loss_cls_w: float = 0.35
    loss_reg_w: float = 0.65
    pos_weight_cap: float = 20.0    # cap for the BCE pos_weight (= #dry / #wet)

    # ---- stabilize multi-target loss
    # True: non-precip targets are standardized with the train Y scaler and
    # weighted equally; the w_* weights below are used only when this is False.
    scale_y_other: bool = True

    w_temp: float = 1.0
    w_rh: float = 0.7
    w_press: float = 0.5
    w_cloud: float = 0.7
    w_wind: float = 0.8

    # ---- report
    run_report: bool = True
    # rain-probability thresholds tried on VAL: 0.05, 0.10, ..., 0.95
    p_thr_cand: Tuple[float, ...] = tuple(np.round(np.linspace(0.05, 0.95, 19), 2).tolist())
    # inclusive horizon ranges (steps ahead) for the summary tables
    bins: Tuple[Tuple[int, int], ...] = ((1,24),(25,48),(49,72),(73,100))

    # output
    out_dir: str = "/kaggle/working/tcn_weather_2step_out"

    run_id: str = ""                # empty -> $RUN_ID, else a timestamp
    deterministic: bool = False     # deterministic cuDNN, autotuner off
    num_workers: int = 2
    pin_memory: bool = True
    save_preds: bool = True         # NOTE(review): not referenced elsewhere in this notebook

    train_split_key: str = "train_2021_2023"
    val_split_key: str = "val_2024"
    test_split_key: str = "test_2025_01_to_2025_11"

cfg = CFG()

# Run id precedence: value set in code > $RUN_ID env var > current timestamp.
_env_run_id = os.environ.get("RUN_ID", "").strip()
if not cfg.run_id:
    cfg.run_id = _env_run_id if _env_run_id else time.strftime("%Y%m%d-%H%M%S")

# Every artefact of this run lives under OUT_DIR; CSV reports go to REPORT_DIR.
OUT_DIR = Path(cfg.out_dir) / cfg.run_id
REPORT_DIR = OUT_DIR / "reports"
OUT_DIR.mkdir(parents=True, exist_ok=True)
REPORT_DIR.mkdir(parents=True, exist_ok=True)

def is_gpu() -> bool:
    """Return True when the ACCELERATOR switch requests a GPU (case-insensitive)."""
    requested = ACCELERATOR.lower()
    return requested == "gpu"

def seed_all(seed=42):
    """Seed Python's, NumPy's and (when importable) PyTorch's RNGs.

    When ``cfg.deterministic`` is set, also switches cuDNN to deterministic
    kernels and disables its autotuner.

    Args:
        seed: integer seed shared by all generators.
    """
    import random
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        # Only a missing torch is tolerated here; any other failure while
        # seeding should surface instead of being silently swallowed.
        return
    torch.manual_seed(seed)
    # Documented as safe (silently ignored) when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    if getattr(cfg, "deterministic", False):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_all(42)

# ============================================================
# 2) DATA DISCOVERY + LOADING
# ============================================================
def find_base_dir(search_root: Path = Path("/kaggle/input")) -> Path:
    """Locate the dataset root, i.e. the directory containing ``cfg.seq_dir_rel``.

    Search order:
      1. ``cfg.base_dir`` when ``<base_dir>/<seq_dir_rel>`` exists;
      2. under ``search_root``: ``weather_20loc``, ``weather_34loc``,
         ``weather_63loc``, ``weather_4loc`` (each ``.../data/sequences``);
      3. any ``data/sequences`` directory;
      4. any ``*_seq_multi_lag{lag}_h{horizon}.npz`` file (root = three parents up).
    Matches within a step are sorted, so the pick does not depend on
    filesystem listing order when several datasets are attached.

    Args:
        search_root: directory scanned recursively (default: Kaggle input mount).

    Returns:
        The dataset root directory.

    Raises:
        FileNotFoundError: if no step finds a match.
    """
    preferred = Path(cfg.base_dir)
    if (preferred / cfg.seq_dir_rel).exists():
        return preferred
    root = Path(search_root)
    for pattern in ["weather_20loc", "weather_34loc", "weather_63loc", "weather_4loc"]:
        hits = sorted(root.rglob(f"{pattern}/data/sequences"))
        if hits:
            return hits[0].parent.parent
    hits = sorted(root.rglob("data/sequences"))
    if hits:
        return hits[0].parent.parent
    hits = sorted(root.rglob(f"*_seq_multi_lag{cfg.lag}_h{cfg.horizon}.npz"))
    if hits:
        # <root>/data/sequences/<file>.npz -> <root>
        return hits[0].parent.parent.parent
    raise FileNotFoundError(f"Cannot find base dir in {root}")

BASE_DIR = find_base_dir()
# Sequence npz files and location metadata sit at fixed relative paths under the root.
SEQ_DIR, META_DIR = BASE_DIR / cfg.seq_dir_rel, BASE_DIR / cfg.meta_dir_rel
print("[BASE_DIR]", BASE_DIR)
print("[SEQ_DIR]", SEQ_DIR)
print("[META_DIR]", META_DIR)

# ============================================================
# 2.1) LOCATION ID LOADING
# ============================================================
def load_location_ids() -> Tuple[List[str], Dict[str, str]]:
    """Return ``(location_ids, {location_id: display_name})``.

    Prefers ``META_DIR/locations.json``. Its ``location_ids`` list is used when
    present and non-empty; otherwise the ids listed under ``locations`` are used.
    A location without a ``name`` is displayed by its first 8 id characters.
    Without the JSON file, ids are derived from the train-split npz filenames
    in SEQ_DIR (text before ``_<train_split_key>_``) in sorted-filename order.
    """
    meta_file = META_DIR / "locations.json"
    if meta_file.exists():
        # Names can contain non-ASCII characters; do not depend on the locale encoding.
        loc_meta = json.loads(meta_file.read_text(encoding="utf-8"))
        loc_names = {
            loc["location_id"]: loc.get("name", loc["location_id"][:8])
            for loc in loc_meta.get("locations", [])
        }
        loc_ids = loc_meta.get("location_ids") or list(loc_names)
        print(f"[meta] Loaded {len(loc_ids)} location_ids from {meta_file.name}")
        return loc_ids, loc_names

    sep = f"_{cfg.train_split_key}_"
    pattern = f"*{sep}seq_multi_lag{cfg.lag}_h{cfg.horizon}.npz"
    loc_ids: List[str] = []
    seen = set()
    for fp in sorted(SEQ_DIR.glob(pattern)):
        loc_id = fp.name.split(sep)[0]
        if loc_id and loc_id not in seen:
            seen.add(loc_id)
            loc_ids.append(loc_id)
    print(f"[fallback] Found {len(loc_ids)} location_ids from file scan")
    return loc_ids, {lid: lid[:8] for lid in loc_ids}

LOCATION_IDS, LOC_NAMES = load_location_ids()
print("[LOCATION_IDS]", LOCATION_IDS)
# One "  <id> -> <name>" line per location; '?' when no name is known.
for lid in LOCATION_IDS:
    print("  {} -> {}".format(lid, LOC_NAMES.get(lid, "?")))

def list_npz(split_key: str) -> List[Path]:
    """Return the sorted npz files of one split for the configured lag/horizon.

    Raises:
        FileNotFoundError: if SEQ_DIR holds no file for ``split_key``.
    """
    pattern = f"*_{split_key}_seq_multi_lag{cfg.lag}_h{cfg.horizon}.npz"
    matches = sorted(SEQ_DIR.glob(pattern))
    if matches:
        return matches
    raise FileNotFoundError(f"No npz for split={split_key} in {SEQ_DIR}")

def get_loc_id_from_file(fp: Path) -> str:
    """Extract the location_id prefix from a sequence npz filename.

    Cuts at the first ``_<split_key>_`` marker among the train/val/test keys
    (checked in that order). If none matches, returns the text before the
    first underscore.
    """
    fname = fp.name
    split_keys = (cfg.train_split_key, cfg.val_split_key, cfg.test_split_key)
    marker = next((f"_{key}_" for key in split_keys if f"_{key}_" in fname), "_")
    return fname.partition(marker)[0]

# Resolve all three splits up front (fails fast if any split is missing).
TRAIN_FILES, VAL_FILES, TEST_FILES = (
    list_npz(split) for split in (cfg.train_split_key, cfg.val_split_key, cfg.test_split_key)
)
print("[files] train:", len(TRAIN_FILES), "val:", len(VAL_FILES), "test:", len(TEST_FILES))

def load_npz(fp: Path):
    """Load one sequence npz file.

    Args:
        fp: path to an npz archive with keys ``X``, ``Y``, ``T`` and ``meta``
            (``meta`` is a JSON string stored as a 0-d array).

    Returns:
        ``(X, Y, T, meta)``: X and Y cast to float32, T as stored, and meta as
        the decoded JSON object.
    """
    # The NpzFile keeps the archive open until closed; `with` releases it
    # deterministically instead of relying on garbage collection. This runs
    # once per file per epoch.
    with np.load(fp, allow_pickle=False) as z:
        X = z["X"].astype(np.float32, copy=False)
        Y = z["Y"].astype(np.float32, copy=False)
        Tm = z["T"]
        meta = json.loads(z["meta"].item())
    return X, Y, Tm, meta

X0, Y0, T0, meta = load_npz(TRAIN_FILES[0])
# Column names live in the npz meta; both naming schemes are accepted.
input_cols = meta.get("input_cols") or meta.get("x_cols")
target_cols = meta.get("target_cols") or meta.get("y_cols")
if input_cols is None or target_cols is None:
    raise KeyError(f"Meta missing input/target cols. Keys={list(meta.keys())}")

# target column name -> index along Y's last axis
IDX = {col: pos for pos, col in enumerate(target_cols)}
if "precipitation" not in IDX:
    raise KeyError(f"'precipitation' not found in target_cols: {target_cols}")
PRECIP_IDX = IDX["precipitation"]

# Feature / target counts come from the last axis of the first train file.
n_feat, n_tgt = X0.shape[-1], Y0.shape[-1]
del X0, Y0, T0
gc.collect()

print("[meta] n_feat:", n_feat, "n_tgt:", n_tgt)
print("[targets]", target_cols)
print("[PRECIP_IDX]", PRECIP_IDX, "->", target_cols[PRECIP_IDX])

# ============================================================
# 3) X SCALER STREAMING
# ============================================================
def compute_x_scaler_stream(files: List[Path], n_feat: int, min_sd: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
    """Per-feature mean/std of X over all ``files``, holding one file in memory at a time.

    Sums and sums of squares are accumulated in float64. The variance is
    E[x^2] - E[x]^2 computed with the float64 mean; the old code rounded the
    mean to float32 first, which loses precision for large-valued features
    such as pressure.

    Args:
        files: npz files read with ``load_npz``; each X is flattened to (-1, n_feat).
        n_feat: number of input features (last axis of X).
        min_sd: features whose std is below this are treated as constant and
            get sd = 1.0 (centred, not amplified). The previous code floored
            var at 1e-6 *before* testing ``sd < 1e-6``, so that fallback could
            never fire and near-constant features were divided by 1e-3.

    Returns:
        ``(mu, sd)`` float32 arrays of shape (n_feat,). With no samples this
        gives mu = 0 and sd = 1.
    """
    s1 = np.zeros(n_feat, dtype=np.float64)
    s2 = np.zeros(n_feat, dtype=np.float64)
    n = 0

    for fp in files:
        X, _, _, _ = load_npz(fp)
        x2 = X.reshape(-1, n_feat).astype(np.float64, copy=False)
        s1 += x2.sum(axis=0)
        s2 += (x2 * x2).sum(axis=0)
        n += x2.shape[0]
        del X, x2
        gc.collect()

    denom = max(n, 1)
    mean = s1 / denom
    # Clip tiny negatives caused by floating-point cancellation.
    var = np.maximum(s2 / denom - mean * mean, 0.0)
    sd = np.sqrt(var)
    sd = np.where(sd < min_sd, 1.0, sd)
    return mean.astype(np.float32), sd.astype(np.float32)

X_mu, X_sd = compute_x_scaler_stream(TRAIN_FILES, n_feat)
# Persist the train-set X scaler. Column names are stored as a unicode array
# (not dtype=object), so the file can be read back with np.load's default
# allow_pickle=False.
np.savez(
    OUT_DIR / "x_scaler.npz",
    mu=X_mu, sd=X_sd, input_cols=np.array(input_cols, dtype=str),
)

def scale_x(X: np.ndarray) -> np.ndarray:
    """Standardise X with the train-set per-feature mean/std; returns float32.

    X is expected as (samples, lag, n_feat); the scaler broadcasts over the last axis.
    """
    return ((X - X_mu[None,None,:]) / X_sd[None,None,:]).astype(np.float32, copy=False)

# ============================================================
# 4) Y SCALER (non-precip only) STREAMING
# ============================================================
def compute_y_scaler_stream_nonprecip(files: List[Path], n_tgt: int, precip_idx: int,
                                      min_sd: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
    """Per-target mean/std of Y over ``files`` (streamed), with identity scaling for precip.

    Uses the same moment computation as the X scaler: float64 sums and a
    float64 mean in E[y^2] - E[y]^2. Targets whose std is below ``min_sd`` get
    sd = 1.0. The previous var floor of 1e-6 made the intended ``sd < 1e-6``
    fallback unreachable.

    Args:
        files: npz files read with ``load_npz``; each Y is flattened to (-1, n_tgt).
        n_tgt: number of targets (last axis of Y).
        precip_idx: precipitation column; forced to mu = 0, sd = 1.
        min_sd: std threshold below which a target is treated as constant.

    Returns:
        ``(mu, sd)`` float32 arrays of shape (n_tgt,).
    """
    s1 = np.zeros(n_tgt, dtype=np.float64)
    s2 = np.zeros(n_tgt, dtype=np.float64)
    n = 0

    for fp in files:
        _, Y, _, _ = load_npz(fp)
        y2 = Y.reshape(-1, n_tgt).astype(np.float64, copy=False)
        s1 += y2.sum(axis=0)
        s2 += (y2 * y2).sum(axis=0)
        n += y2.shape[0]
        del Y, y2
        gc.collect()

    denom = max(n, 1)
    mean = s1 / denom
    var = np.maximum(s2 / denom - mean * mean, 0.0)   # clip cancellation negatives
    sd = np.where(np.sqrt(var) < min_sd, 1.0, np.sqrt(var))

    mu = mean.astype(np.float32)
    sd = sd.astype(np.float32)
    # Precip is handled by the event/amount terms of the loss, not by standardisation.
    mu[precip_idx] = 0.0
    sd[precip_idx] = 1.0
    return mu, sd

if cfg.scale_y_other:
    Y_mu, Y_sd = compute_y_scaler_stream_nonprecip(TRAIN_FILES, n_tgt, PRECIP_IDX)
    # Unicode (not object) dtype for the column names, so the file loads with
    # np.load's default allow_pickle=False.
    np.savez(
        OUT_DIR / "y_scaler_nonprecip.npz",
        mu=Y_mu, sd=Y_sd, target_cols=np.array(target_cols, dtype=str),
    )
else:
    # Identity scaler (mu = 0, sd = 1): targets are compared in raw units.
    Y_mu = np.zeros(n_tgt, dtype=np.float32)
    Y_sd = np.ones(n_tgt, dtype=np.float32)

# ============================================================
# 5) POS_WEIGHT for precip event
# ============================================================
def compute_pos_weight_stream(files: List[Path], precip_idx: int, thr_mm: float, cap: float) -> float:
    """BCE ``pos_weight`` for the rain-event head: (#dry / #wet) over all train cells, capped.

    A cell (sample, horizon step) is wet when precipitation >= ``thr_mm``.
    With no wet cells the ratio is huge (1e-9 denominator guard) and ends up
    clipped to ``cap``.
    """
    n_wet = 0.0
    n_cells = 0.0
    for path in files:
        _, Y, _, _ = load_npz(path)
        precip = Y[..., precip_idx]
        n_wet += float(np.count_nonzero(precip >= thr_mm))
        n_cells += float(precip.size)
        del Y, precip
        gc.collect()
    ratio = (n_cells - n_wet) / (n_wet + 1e-9)
    return float(min(ratio, cap))

POS_WEIGHT = compute_pos_weight_stream(TRAIN_FILES, PRECIP_IDX, cfg.rain_thr_mm, cfg.pos_weight_cap)
print("[pos_weight]", POS_WEIGHT)

# ============================================================
# 6) TORCH + MODEL (TCN)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

print("[torch]", torch.__version__)
if torch.cuda.is_available():
    # Allow TF32 tensor-core math for matmul and cuDNN convs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # cuDNN autotuning is off in deterministic mode.
    torch.backends.cudnn.benchmark = (not cfg.deterministic)
try:
    torch.set_float32_matmul_precision("high")
except AttributeError:
    # Older torch releases lack this setting; keep their default. Other
    # errors are no longer silently swallowed.
    pass

def maybe_weight_norm(conv: nn.Module, enabled: bool) -> nn.Module:
    """Wrap ``conv`` with weight normalisation when ``enabled``, best-effort.

    Tries the parametrization-based ``weight_norm`` first, then the legacy
    ``torch.nn.utils.weight_norm``. If both fail, returns ``conv`` unchanged.
    """
    if not enabled:
        return conv

    def _parametrized(module):
        from torch.nn.utils.parametrizations import weight_norm as wn
        return wn(module)

    def _legacy(module):
        from torch.nn.utils import weight_norm as wn2
        return wn2(module)

    for apply_wn in (_parametrized, _legacy):
        try:
            return apply_wn(conv)
        except Exception:
            continue
    return conv

class ArrayDataset(Dataset):
    """Map-style Dataset over two index-aligned numpy arrays.

    Item ``i`` is ``(X[i], Y[i])`` converted with ``torch.from_numpy``, which
    shares memory with the arrays (no copy).
    """

    def __init__(self, X: np.ndarray, Y: np.ndarray):
        self.X, self.Y = X, Y

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, i):
        x_row, y_row = self.X[i], self.Y[i]
        return torch.from_numpy(x_row), torch.from_numpy(y_row)

def precip_log_to_mm(pred_log: torch.Tensor, mm_max: float) -> torch.Tensor:
    """Invert log1p (``expm1``) and clamp the resulting amounts to [0, mm_max] mm."""
    upper = float(mm_max)
    amount_mm = torch.expm1(pred_log)
    return torch.clamp(amount_mm, min=0.0, max=upper)

def _act(name: str):
    """Return a fresh activation module: GELU for "gelu" (any case), else ReLU (also for None/"")."""
    key = (name or "relu").lower()
    return nn.GELU() if key == "gelu" else nn.ReLU()

class CausalConv1d(nn.Module):
    """Dilated 1-D convolution padded on the left only (causal).

    ``(kernel_size - 1) * dilation`` zeros are prepended on the last (time)
    axis before an unpadded Conv1d, so output length equals input length and
    step t depends only on inputs at steps <= t.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int, dilation: int, use_weight_norm: bool):
        super().__init__()
        self.pad = dilation * (kernel_size - 1)
        self.conv = maybe_weight_norm(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, dilation=dilation),
            use_weight_norm,
        )

    def forward(self, x):
        left_padded = F.pad(x, (self.pad, 0))
        return self.conv(left_padded)

class TCNBlock(nn.Module):
    """Residual TCN block: two causal dilated convs, each followed by activation and dropout.

    When ``in_ch != out_ch`` a 1x1 conv projects the residual path. The sum
    of both paths goes through a final activation. Time length is preserved.
    """

    def __init__(self, in_ch: int, out_ch: int, k: int, d: int, dropout: float, use_weight_norm: bool, act_name: str):
        super().__init__()
        # Submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1 = CausalConv1d(in_ch, out_ch, kernel_size=k, dilation=d, use_weight_norm=use_weight_norm)
        self.act1 = _act(act_name)
        self.drop1 = nn.Dropout(dropout)

        self.conv2 = CausalConv1d(out_ch, out_ch, kernel_size=k, dilation=d, use_weight_norm=use_weight_norm)
        self.act2 = _act(act_name)
        self.drop2 = nn.Dropout(dropout)

        self.down = None if in_ch == out_ch else nn.Conv1d(in_ch, out_ch, kernel_size=1)
        self.out_act = _act(act_name)

    def forward(self, x):
        residual = x if self.down is None else self.down(x)
        out = x
        for conv, act, drop in ((self.conv1, self.act1, self.drop1),
                                (self.conv2, self.act2, self.drop2)):
            out = drop(act(conv(out)))
        return self.out_act(out + residual)

class TCNEncoder(nn.Module):
    """Stack of ``levels`` TCNBlocks with dilation ``2**level`` (1, 2, 4, ...).

    The first block maps ``in_ch`` to ``hidden`` channels; later blocks keep
    ``hidden``. Tensors use the Conv1d layout (batch, channels, time).
    """

    def __init__(self, in_ch: int, hidden: int, levels: int, k: int, dropout: float, use_weight_norm: bool, act_name: str):
        super().__init__()
        self.net = nn.Sequential(*[
            TCNBlock(
                in_ch if level == 0 else hidden, hidden,
                k=k, d=2 ** level, dropout=dropout,
                use_weight_norm=use_weight_norm, act_name=act_name,
            )
            for level in range(levels)
        ])

    def forward(self, x):
        return self.net(x)

class TCNWeather2Step(nn.Module):
    """TCN encoder with two linear heads for multi-target, multi-horizon forecasts.

    ``forward(x)`` takes x of shape (batch, lag, n_feat) and returns:
      * ``y``: (batch, horizon, n_tgt) regression outputs. The precip column
        is the amount head: log1p(mm) for amount_space="log1p", or
        non-negative mm via softplus for amount_space="mm".
      * ``rain_logit``: (batch, horizon) logits of the rain-event classifier.
    Both heads read only the encoder output at the final time step.
    """

    def __init__(self, n_feat, n_tgt, horizon, hidden, levels, k, dropout, precip_idx,
                 amount_space="log1p", use_weight_norm=True, act_name="relu"):
        super().__init__()
        self.horizon = horizon
        self.n_tgt = n_tgt
        self.precip_idx = precip_idx
        self.amount_space = amount_space

        self.enc = TCNEncoder(
            in_ch=n_feat, hidden=hidden, levels=levels, k=k,
            dropout=dropout, use_weight_norm=use_weight_norm, act_name=act_name
        )
        self.head_main = nn.Linear(hidden, horizon * n_tgt)
        self.head_rain = nn.Linear(hidden, horizon)

    def forward(self, x):
        h = self.enc(x.transpose(1, 2))     # Conv1d layout: (batch, n_feat, lag)
        h_last = h[:, :, -1]
        y = self.head_main(h_last).view(-1, self.horizon, self.n_tgt)
        rain_logit = self.head_rain(h_last).view(-1, self.horizon)

        if self.amount_space == "mm":
            # Build y out-of-place. The former in-place write
            # `y[..., i] = F.softplus(y[..., i])` overwrote softplus' saved
            # input, so backward() raised "modified by an inplace operation".
            i = self.precip_idx % self.n_tgt   # also accept negative indices
            precip = F.softplus(y[..., i:i + 1])
            y = torch.cat([y[..., :i], precip, y[..., i + 1:]], dim=-1)
        return y, rain_logit

def count_params(model: nn.Module) -> int:
    """Number of trainable (requires_grad) parameter elements in ``model``."""
    trainable = filter(lambda t: t.requires_grad, model.parameters())
    return sum(t.numel() for t in trainable)

def receptive_field(levels: int, k: int, convs_per_block: int = 2) -> int:
    """Receptive field (in time steps) of a TCN with dilations 1, 2, ..., 2**(levels-1).

    RF = 1 + (k - 1) * convs_per_block * sum(dilations); e.g. levels=4, k=3 -> 61.
    """
    total_dilation = sum(1 << level for level in range(levels))
    return int(1 + (k - 1) * convs_per_block * total_dilation)

def build_optimizer(model: nn.Module):
    """AdamW over all model parameters, using ``cfg.lr`` and ``cfg.weight_decay``."""
    params = model.parameters()
    return torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

def safe_set_weight(w: torch.Tensor, name: str, value: float):
    """Set the last-axis weight of target ``name`` in place; no-op if the target is unknown."""
    col = IDX.get(name)
    if col is not None:
        w[..., col] = float(value)

def build_target_weights(device, dtype=torch.float32) -> torch.Tensor:
    """Per-target weights of shape (1, 1, n_tgt) for the non-precip regression loss.

    Precip always gets weight 0 (loss_fn handles it with separate event/amount
    terms). With ``cfg.scale_y_other`` all other targets weigh 1; otherwise the
    ``cfg.w_*`` values are applied to the targets that exist in IDX.
    """
    weights = torch.ones((1, 1, n_tgt), device=device, dtype=dtype)
    weights[..., PRECIP_IDX] = 0.0
    if not cfg.scale_y_other:
        per_target = (
            ("temperature_2m", cfg.w_temp),
            ("relative_humidity_2m", cfg.w_rh),
            ("surface_pressure", cfg.w_press),
            ("cloud_cover", cfg.w_cloud),
            ("u10", cfg.w_wind),
            ("v10", cfg.w_wind),
        )
        for col_name, value in per_target:
            safe_set_weight(weights, col_name, value)
    return weights

def loss_fn(y_pred, rain_logit, y_true, pos_weight_value: float,
            y_mu_t: torch.Tensor, y_sd_t: torch.Tensor,
            w_tgt: torch.Tensor):
    """Combined 2-step loss: non-precip regression + rain-event BCE + wet-step amount.

    * other targets: weighted smooth-L1, on standardized values when
      ``cfg.scale_y_other`` (precip is excluded by its zero weight in ``w_tgt``);
    * event: BCE-with-logits against ``y_mm >= cfg.rain_thr_mm`` with ``pos_weight``;
    * amount: smooth-L1 on wet steps only, in mm (clamped) or log1p space.
      With no wet step the term is a zero that still depends on the graph.
    Returns the scalar ``loss_other + loss_cls_w * loss_cls + loss_reg_w * loss_reg``.
    """
    # --- non-precip regression
    if cfg.scale_y_other:
        mu = y_mu_t[None, None, :]
        sd = y_sd_t[None, None, :]
        yp_s, yt_s = (y_pred - mu) / sd, (y_true - mu) / sd
    else:
        yp_s, yt_s = y_pred, y_true
    loss_other = (F.smooth_l1_loss(yp_s, yt_s, reduction="none") * w_tgt).mean()

    # --- rain event classification
    y_mm = y_true[..., PRECIP_IDX]
    rain_label = (y_mm >= cfg.rain_thr_mm).float()
    loss_cls = F.binary_cross_entropy_with_logits(
        rain_logit, rain_label,
        pos_weight=torch.tensor(pos_weight_value, device=y_true.device, dtype=y_true.dtype),
    )

    # --- rain amount on wet steps
    wet = rain_label > 0.5
    if cfg.amount_space == "mm":
        pred_amt = y_pred[..., PRECIP_IDX].clamp(min=0.0, max=float(cfg.precip_mm_max))
        true_amt = y_mm
    else:
        pred_amt = y_pred[..., PRECIP_IDX]
        true_amt = torch.log1p(y_mm.clamp(min=0.0))
    loss_reg = F.smooth_l1_loss(pred_amt[wet], true_amt[wet]) if wet.any() else pred_amt.mean() * 0.0

    return loss_other + cfg.loss_cls_w * loss_cls + cfg.loss_reg_w * loss_reg

# ============================================================
# 7) SANITY CHECKS
# ============================================================
def sanity_checks():
    """Validate the first training file against the config and the fitted X scaler.

    Checks array rank and shapes (lag, horizon, n_feat, n_tgt), that
    precipitation is non-negative, and that scaled X / raw Y are finite on the
    first 2048 samples.

    Raises:
        ValueError: on any failed check. Explicit raises replace ``assert``,
            which is stripped under ``python -O``.
    """
    X, Y, _, _ = load_npz(TRAIN_FILES[0])
    if X.ndim != 3 or Y.ndim != 3:
        raise ValueError(f"Expected 3-D X and Y, got X{X.shape} Y{Y.shape}")
    if X.shape[1] != cfg.lag:
        raise ValueError(f"X lag mismatch: {X.shape}")
    if Y.shape[1] != cfg.horizon:
        raise ValueError(f"Y horizon mismatch: {Y.shape}")
    if X.shape[2] != n_feat:
        raise ValueError(f"n_feat mismatch: {X.shape}")
    if Y.shape[2] != n_tgt:
        raise ValueError(f"n_tgt mismatch: {Y.shape}")

    pmin = float(np.min(Y[..., PRECIP_IDX]))
    if pmin < -1e-6:
        raise ValueError(f"Precip appears negative (min={pmin}).")
    Xs = scale_x(X[: min(2048, X.shape[0])])
    if not np.isfinite(Xs).all():
        raise ValueError("Scaled X contains NaN/Inf.")
    Ys = Y[: min(2048, Y.shape[0])]
    if not np.isfinite(Ys).all():
        raise ValueError("Y contains NaN/Inf.")

    print("[sanity] OK | precip_min_mm:", pmin, "| X_scaled finite:", True)
    del X, Y, Xs, Ys
    gc.collect()

# ============================================================
# 8) TRAIN (GPU/CPU) — streaming over files
# ============================================================
def train_gpu_cpu():
    """Train TCNWeather2Step, streaming train/val npz files one at a time.

    Each epoch shuffles the order of TRAIN_FILES, loads one file, scales X
    with the train-set scaler and runs shuffled mini-batches (AMP fp16 on
    CUDA). After every epoch, the sample-weighted mean loss over VAL_FILES
    drives ReduceLROnPlateau, best-checkpoint saving and early stopping
    (``cfg.patience`` epochs without improvement).

    Returns:
        Path of the best checkpoint (``OUT_DIR / "tcn_weather2step_best.pt"``).

    Raises:
        RuntimeError: if no epoch produced a finite validation loss, so no
            checkpoint was written. Previously this only surfaced later as a
            missing-file error in ``run_report``.
    """
    sanity_checks()

    device = torch.device("cuda" if (is_gpu() and torch.cuda.is_available()) else "cpu")
    print("[device]", device)

    rf = receptive_field(cfg.tcn_levels, cfg.kernel_size, convs_per_block=2)
    print(f"[TCN] levels={cfg.tcn_levels} k={cfg.kernel_size} receptive_field≈{rf} (input_len={cfg.lag})")

    model = TCNWeather2Step(
        n_feat=n_feat, n_tgt=n_tgt, horizon=cfg.horizon,
        hidden=cfg.tcn_channels, levels=cfg.tcn_levels, k=cfg.kernel_size,
        dropout=cfg.dropout, precip_idx=PRECIP_IDX, amount_space=cfg.amount_space,
        use_weight_norm=cfg.use_weight_norm, act_name=cfg.act
    ).to(device)
    print("[model] params:", count_params(model))

    opt = build_optimizer(model)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=2)

    use_amp = (device.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    y_mu_t = torch.tensor(Y_mu, device=device, dtype=torch.float32)
    y_sd_t = torch.tensor(Y_sd, device=device, dtype=torch.float32)
    w_tgt  = build_target_weights(device=device, dtype=torch.float32)

    # One set of loader options for train AND val. Previously train hard-coded
    # 2 workers while val used cfg values even on CPU, where pin_memory only
    # produces a warning. Workers and pinning only help host->GPU copies.
    on_cuda = device.type == "cuda"
    loader_opts = dict(
        batch_size=cfg.batch_size,
        num_workers=(cfg.num_workers if on_cuda else 0),
        pin_memory=(cfg.pin_memory and on_cuda),
    )

    def batch_loss(xb, yb):
        """Forward pass + combined loss, under fp16 autocast when AMP is on."""
        if use_amp:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                yp, rl = model(xb)
                return loss_fn(yp, rl, yb, POS_WEIGHT, y_mu_t, y_sd_t, w_tgt)
        yp, rl = model(xb)
        return loss_fn(yp, rl, yb, POS_WEIGHT, y_mu_t, y_sd_t, w_tgt)

    clip = bool(cfg.grad_clip and cfg.grad_clip > 0)
    best_val = float("inf")
    bad = 0
    saved = False
    va_loss = float("nan")
    best_path = OUT_DIR / "tcn_weather2step_best.pt"

    for ep in range(1, cfg.epochs + 1):
        # ---------------- train ----------------
        model.train()
        tr_sum = 0.0
        tr_n = 0

        files = TRAIN_FILES.copy()
        np.random.shuffle(files)

        for fp in files:
            X, Y, _, _ = load_npz(fp)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, shuffle=True, drop_last=True, **loader_opts)

            for xb, yb in loader:
                xb = xb.to(device, non_blocking=True)
                yb = yb.to(device, non_blocking=True)

                opt.zero_grad(set_to_none=True)
                loss = batch_loss(xb, yb)
                if use_amp:
                    scaler.scale(loss).backward()
                    if clip:
                        scaler.unscale_(opt)   # clip the true (unscaled) gradient norm
                        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                    scaler.step(opt)
                    scaler.update()
                else:
                    loss.backward()
                    if clip:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                    opt.step()

                tr_sum += float(loss.item()) * xb.size(0)
                tr_n += xb.size(0)

            del X, Y, ds, loader
            gc.collect()

        tr_loss = tr_sum / max(tr_n, 1)

        # ---------------- validate ----------------
        model.eval()
        va_sum = 0.0
        va_n = 0
        with torch.no_grad():
            for fp in VAL_FILES:
                X, Y, _, _ = load_npz(fp)
                X = scale_x(X)
                ds = ArrayDataset(X, Y)
                loader = DataLoader(ds, shuffle=False, drop_last=False, **loader_opts)

                for xb, yb in loader:
                    xb = xb.to(device, non_blocking=True)
                    yb = yb.to(device, non_blocking=True)
                    loss = batch_loss(xb, yb)
                    va_sum += float(loss.item()) * xb.size(0)
                    va_n += xb.size(0)

                del X, Y, ds, loader
                gc.collect()

        va_loss = va_sum / max(va_n, 1)
        sched.step(va_loss)

        lr_now = opt.param_groups[0]["lr"]
        print(f"epoch {ep:02d} | lr {lr_now:.2e} | train {tr_loss:.6f} | val {va_loss:.6f}")

        # A NaN/Inf val loss never compares as an improvement.
        if va_loss < best_val - 1e-6:
            best_val = va_loss
            bad = 0
            ckpt_meta = dict(meta)
            ckpt_meta.update({
                "n_feat": n_feat,
                "n_tgt": n_tgt,
                "precip_idx": PRECIP_IDX,
                "input_cols": input_cols,
                "target_cols": target_cols,
            })
            torch.save(
                {
                    "model": model.state_dict(),
                    "cfg": cfg.__dict__,
                    "meta": ckpt_meta,
                    "pos_weight": POS_WEIGHT,
                    "x_scaler": {"mu": X_mu.tolist(), "sd": X_sd.tolist()},
                    "y_scaler_nonprecip": {"mu": Y_mu.tolist(), "sd": Y_sd.tolist()},
                },
                best_path
            )
            saved = True
        else:
            bad += 1
            if bad >= cfg.patience:
                print("[early stop]")
                break

    if not saved:
        raise RuntimeError(
            f"No checkpoint written to {best_path}: no epoch produced a finite "
            f"validation loss (last val={va_loss}, epochs={cfg.epochs}). "
            "Check inputs for NaN/Inf or a diverging run."
        )
    print("[best val]", best_val, "->", best_path)
    return best_path

# ============================================================
# 9) REPORT
# ============================================================
def event_metrics_counts(y_true01: np.ndarray, y_pred01: np.ndarray):
    """Return ``(tp, fp, fn)`` for 0/1-coded truth and prediction arrays of equal shape."""
    true_pos, true_neg = y_true01 == 1, y_true01 == 0
    pred_pos, pred_neg = y_pred01 == 1, y_pred01 == 0
    return (
        int(np.count_nonzero(true_pos & pred_pos)),
        int(np.count_nonzero(true_neg & pred_pos)),
        int(np.count_nonzero(true_pos & pred_neg)),
    )

def tune_pthr_on_val(model, device) -> Tuple[float, pd.DataFrame]:
    """Pick the rain-probability threshold with the best pooled event F1 on VAL.

    Every candidate in ``cfg.p_thr_cand`` is scored on all (sample, horizon)
    cells of VAL_FILES pooled together. A cell is a true event when observed
    precipitation >= ``cfg.rain_thr_mm``.

    Returns:
        ``(best_thr, table)``. The table holds p_thr/precision/recall/f1/tp/fp/fn
        per candidate, sorted by F1 descending. On F1 ties, ``best_thr`` is the
        first candidate (np.argmax).
    """
    thresholds = np.array(cfg.p_thr_cand, dtype=np.float32)
    tp = np.zeros(len(thresholds), dtype=np.int64)
    fp = np.zeros(len(thresholds), dtype=np.int64)
    fn = np.zeros(len(thresholds), dtype=np.int64)

    model.eval()
    with torch.no_grad():
        for npz_path in VAL_FILES:
            X, Y, _, _ = load_npz(npz_path)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=False)

            for xb, yb in loader:
                _, logit = model(xb.to(device))
                prob = torch.sigmoid(logit).detach().cpu().numpy().astype(np.float32).reshape(-1)
                wet = (yb[..., PRECIP_IDX].numpy().astype(np.float32) >= cfg.rain_thr_mm).reshape(-1)

                # (n_thresholds, n_cells) decisions: one row per candidate threshold.
                fired = prob[None, :] >= thresholds[:, None]
                tp += (fired & wet).sum(axis=1)
                fp += (fired & ~wet).sum(axis=1)
                fn += (~fired & wet).sum(axis=1)

            del X, Y, ds, loader
            gc.collect()

    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)

    best_thr = float(thresholds[int(np.argmax(f1))])

    table = pd.DataFrame({
        "p_thr": thresholds,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp, "fp": fp, "fn": fn
    }).sort_values("f1", ascending=False).reset_index(drop=True)

    return best_thr, table

def accumulate_precip_stats(model, device, files: List[Path], p_thr: float, split_name: str, amount_space: str):
    """Per-horizon rain-event and "hard" amount metrics over ``files``.

    For every horizon step 1..H it counts event tp/fp/fn/npos (event = observed
    precip >= cfg.rain_thr_mm; predicted event = sigmoid(logit) >= p_thr). It
    also accumulates absolute and squared errors of the "hard" amount: the
    predicted mm where an event is predicted, otherwise 0 mm. MAE/RMSE are
    averaged over all samples, wet and dry.

    Args:
        model: trained TCNWeather2Step (switched to eval mode here).
        device: torch device used for inference.
        files: npz files to evaluate.
        p_thr: probability threshold for a predicted rain event.
        split_name: label written to the "split" column.
        amount_space: "mm" (precip output already in mm) or log1p (converted
            with precip_log_to_mm).

    Returns:
        DataFrame with one row per horizon step.
    """
    H = cfg.horizon
    tp = np.zeros(H, dtype=np.int64)
    fp = np.zeros(H, dtype=np.int64)
    fn = np.zeros(H, dtype=np.int64)
    npos = np.zeros(H, dtype=np.int64)

    abs_hard = np.zeros(H, dtype=np.float64)
    sq_hard  = np.zeros(H, dtype=np.float64)

    n_all = 0

    model.eval()
    with torch.no_grad():
        for fp_npz in files:
            X, Y, _, _ = load_npz(fp_npz)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=False)

            for xb, yb in loader:
                xb = xb.to(device)
                yb = yb.to(device)

                yp, logit = model(xb)
                p = torch.sigmoid(logit)

                # observed precip in mm, (batch, horizon)
                y_mm = yb[..., PRECIP_IDX]
                y_mm_np = y_mm.detach().cpu().numpy().astype(np.float32)
                y_ev_np = (y_mm_np >= cfg.rain_thr_mm)

                if amount_space == "mm":
                    pred_mm = yp[..., PRECIP_IDX].clamp(min=0.0, max=float(cfg.precip_mm_max))
                else:
                    pred_log = yp[..., PRECIP_IDX]
                    pred_mm = precip_log_to_mm(pred_log, cfg.precip_mm_max)

                pred_mm_np = pred_mm.detach().cpu().numpy().astype(np.float32)
                p_np = p.detach().cpu().numpy().astype(np.float32)
                pred_ev_np = (p_np >= p_thr)

                # sum over the batch axis -> per-horizon counts
                tp += (pred_ev_np & y_ev_np).sum(axis=0).astype(np.int64)
                fp += (pred_ev_np & (~y_ev_np)).sum(axis=0).astype(np.int64)
                fn += ((~pred_ev_np) & y_ev_np).sum(axis=0).astype(np.int64)
                npos += y_ev_np.sum(axis=0).astype(np.int64)

                # "hard" forecast: amount only where an event is predicted, else 0 mm
                hard = np.where(pred_ev_np, pred_mm_np, 0.0).astype(np.float32)
                err_h = hard - y_mm_np

                abs_hard += np.abs(err_h).sum(axis=0)
                sq_hard  += (err_h * err_h).sum(axis=0)

                n_all += y_mm_np.shape[0]

            del X, Y, ds, loader
            gc.collect()

    # 1e-9 guards horizons with no predicted / observed events
    prec = tp / (tp + fp + 1e-9)
    rec  = tp / (tp + fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)

    denom_all = max(n_all, 1)
    mae_hard = abs_hard / denom_all
    rmse_hard = np.sqrt(sq_hard / denom_all)

    rep = pd.DataFrame({
        "split": split_name,
        "horizon": np.arange(1, H+1),
        "rain_thr_mm": cfg.rain_thr_mm,
        "p_thr": float(p_thr),
        "event_precision": prec,
        "event_recall": rec,
        "event_f1": f1,
        "tp": tp, "fp": fp, "fn": fn,
        "npos": npos,
        "final_mae_all_hard": mae_hard,
        "final_rmse_all_hard": rmse_hard,
    })
    return rep

def accumulate_other_targets(model, device, files: List[Path], split_name: str):
    """Per-horizon MAE/RMSE for every non-precip target over ``files``.

    Errors are computed between raw model outputs and raw Y (no Y scaler is
    applied here), accumulated in float64 and averaged over samples.

    Args:
        model: trained TCNWeather2Step (switched to eval mode here).
        device: torch device used for inference.
        files: npz files to evaluate.
        split_name: label written to the "split" column.

    Returns:
        Long-format DataFrame with columns split/target/horizon/mae/rmse,
        one row per (target, horizon step).
    """
    H = cfg.horizon
    T = len(target_cols)
    targets = [t for t in range(T) if t != PRECIP_IDX]

    # rows follow the order of `targets`, columns are horizon steps
    sum_abs = np.zeros((len(targets), H), dtype=np.float64)
    sum_sq  = np.zeros((len(targets), H), dtype=np.float64)
    n_all = 0

    model.eval()
    with torch.no_grad():
        for fp_npz in files:
            X, Y, _, _ = load_npz(fp_npz)
            X = scale_x(X)
            ds = ArrayDataset(X, Y)
            loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, drop_last=False)

            for xb, yb in loader:
                xb = xb.to(device)
                yb = yb.to(device)
                yp, _ = model(xb)

                y_true = yb.detach().cpu().numpy().astype(np.float32)
                y_pred = yp.detach().cpu().numpy().astype(np.float32)

                for i, ti in enumerate(targets):
                    de = (y_pred[:, :, ti] - y_true[:, :, ti]).astype(np.float64)
                    sum_abs[i] += np.abs(de).sum(axis=0)
                    sum_sq[i]  += (de * de).sum(axis=0)

                n_all += y_true.shape[0]

            del X, Y, ds, loader
            gc.collect()

    denom = max(n_all, 1)
    rows = []
    for i, ti in enumerate(targets):
        name = target_cols[ti]
        mae = (sum_abs[i] / denom).astype(np.float64)
        rmse = np.sqrt(sum_sq[i] / denom).astype(np.float64)
        for h in range(H):
            rows.append({
                "split": split_name,
                "target": name,
                "horizon": h+1,
                "mae": float(mae[h]),
                "rmse": float(rmse[h]),
            })
    return pd.DataFrame(rows)

def summarize_bins_precip(rep_split: pd.DataFrame):
    """Average per-horizon precip metrics within each ``cfg.bins`` horizon range.

    Ranges are inclusive on both ends; ranges with no rows are skipped.
    ``npos_sum`` is the total number of observed events in the range.
    """
    summary = []
    horizon = rep_split["horizon"]
    for lo, hi in cfg.bins:
        chunk = rep_split[horizon.between(lo, hi)]
        if chunk.empty:
            continue
        summary.append({
            "split": chunk["split"].iloc[0],
            "horizon_bin": f"{lo}-{hi}",
            "event_f1_mean": float(chunk["event_f1"].mean()),
            "event_recall_mean": float(chunk["event_recall"].mean()),
            "event_precision_mean": float(chunk["event_precision"].mean()),
            "mae_all_hard_mean": float(chunk["final_mae_all_hard"].mean()),
            "rmse_all_hard_mean": float(chunk["final_rmse_all_hard"].mean()),
            "npos_sum": int(chunk["npos"].sum()),
        })
    return pd.DataFrame(summary)

def summarize_bins_other(rep_long: pd.DataFrame):
    """Bin-average MAE/RMSE per (split, target) over the ``cfg.bins`` horizon ranges.

    Ranges are inclusive; empty ranges are skipped. Groups come out in
    ``DataFrame.groupby`` order (sorted by split, then target).
    """
    out_rows = []
    for (split, target), grp in rep_long.groupby(["split", "target"]):
        hz = grp["horizon"]
        for lo, hi in cfg.bins:
            part = grp[hz.between(lo, hi)]
            if part.empty:
                continue
            out_rows.append({
                "split": split,
                "target": target,
                "horizon_bin": f"{lo}-{hi}",
                "mae_mean": float(part["mae"].mean()),
                "rmse_mean": float(part["rmse"].mean()),
            })
    return pd.DataFrame(out_rows)

def run_report(best_ckpt_path: Path):
    """Reload the best checkpoint and write VAL/TEST evaluation CSVs to REPORT_DIR.

    Steps:
      1. tune the rain-probability threshold on VAL (max pooled event F1);
      2. per-horizon precip metrics and horizon-bin summaries for VAL and TEST;
      3. per-horizon MAE/RMSE (long format) and bin summaries for the
         non-precip targets.
    Architecture hyper-parameters come from the checkpoint's saved cfg,
    falling back to the current ``cfg``.
    """
    # Full unpickling (weights_only=False): only load checkpoints written by
    # this script, never untrusted files.
    ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
    cfg_ck = ckpt.get("cfg", {})
    amount_space = str(cfg_ck.get("amount_space", cfg.amount_space))

    model = TCNWeather2Step(
        n_feat=n_feat, n_tgt=n_tgt, horizon=cfg.horizon,
        hidden=int(cfg_ck.get("tcn_channels", cfg.tcn_channels)),
        levels=int(cfg_ck.get("tcn_levels", cfg.tcn_levels)),
        k=int(cfg_ck.get("kernel_size", cfg.kernel_size)),
        dropout=float(cfg_ck.get("dropout", cfg.dropout)),
        precip_idx=PRECIP_IDX,
        amount_space=amount_space,
        use_weight_norm=bool(cfg_ck.get("use_weight_norm", cfg.use_weight_norm)),
        act_name=str(cfg_ck.get("act", cfg.act)),
    )
    model.load_state_dict(ckpt["model"], strict=True)

    # Same device rule as training: honour ACCELERATOR="cpu" instead of using
    # whichever GPU happens to be visible.
    device = torch.device("cuda" if (is_gpu() and torch.cuda.is_available()) else "cpu")
    model.to(device)
    model.eval()

    best_p_thr, thr_rank = tune_pthr_on_val(model, device)
    thr_rank.to_csv(REPORT_DIR / "precip_thr_tuning_val.csv", index=False)
    print("[REPORT] best_p_thr (VAL F1) =", best_p_thr)

    rep_p_val = accumulate_precip_stats(model, device, VAL_FILES, best_p_thr, "val", amount_space)
    rep_p_te  = accumulate_precip_stats(model, device, TEST_FILES, best_p_thr, "test", amount_space)
    rep_p = pd.concat([rep_p_val, rep_p_te], ignore_index=True)
    rep_p.to_csv(REPORT_DIR / f"tcn_precip_report_val_test_h1-{cfg.horizon}.csv", index=False)

    sum_p = pd.concat([summarize_bins_precip(rep_p_val), summarize_bins_precip(rep_p_te)], ignore_index=True)
    sum_p["p_thr"] = best_p_thr
    sum_p["rain_thr_mm"] = cfg.rain_thr_mm
    sum_p.to_csv(REPORT_DIR / "tcn_precip_summary_bins_val_test.csv", index=False)

    rep_o_val = accumulate_other_targets(model, device, VAL_FILES, "val")
    rep_o_te  = accumulate_other_targets(model, device, TEST_FILES, "test")
    rep_o = pd.concat([rep_o_val, rep_o_te], ignore_index=True)
    rep_o.to_csv(REPORT_DIR / "tcn_other_targets_report_val_test_long.csv", index=False)

    sum_o = summarize_bins_other(rep_o)
    sum_o.to_csv(REPORT_DIR / "tcn_other_targets_summary_bins_val_test.csv", index=False)

    print("[REPORT] saved in:", REPORT_DIR)

# ============================================================
# 10) MAIN
# ============================================================
def main():
    """Script entry point: train, optionally write VAL/TEST reports, then print OUT_DIR."""
    print("ACCELERATOR:", ACCELERATOR)
    print(f"CONFIG: LAG={cfg.lag}, HORIZON={cfg.horizon}, TCN_CHANNELS={cfg.tcn_channels}, LEVELS={cfg.tcn_levels}")
    ckpt_path = train_gpu_cpu()
    if cfg.run_report:
        run_report(ckpt_path)
    print("DONE. out_dir =", OUT_DIR)


if __name__ == "__main__":
    main()