# Fusion baseline Model (GOES + Ground)

In [None]:
"""
Fusion (NPZ) — GOES images + ground/tabular → GHI (simple & clean)

Assumptions
- Filenames encode UTC hour as YYYYMMDD_HH (e.g., 20220101_00_MCMIPF.npz).
- DSRF .npz: [H,W] or [T,H,W]; for [T,H,W] only the last frame along axis 0 is used.
- MCMIPF .npz: [C,H,W] or [T,H,W]; in both cases axis 0 is treated as the channel axis.
- Tabular will be coerced to UTC and downsampled to hourly.

Paths
- DSRF  : /mnt/SOLARLAB/E_Ladino/Repo/irradiance-fusion-forecast/data_processed/GOES/DSRF
- MCMIPF: /mnt/SOLARLAB/E_Ladino/Repo/irradiance-fusion-forecast/data_processed/GOES/MCMIPF
- Parquets: ../data_processed/ground_{train,val,test}_h6.parquet
- Outputs : ../models + ../reports/figures
"""

In [None]:
from __future__ import annotations
import re, json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers, models

In [None]:
# ----------------- CONFIG -----------------
@dataclass
class Config:
    """Settings for one fusion training run, consumed by ``train_fusion``."""
    # Which satellite product to fuse with the tabular data. Compared
    # case-insensitively against "DSRF"; any other value selects the MCMIPF
    # directory, reader and model.
    product: str  # 'DSRF' or 'MCMIPF'
    # Root directories searched recursively for timestamped .npz files.
    dsrf_dir: Path
    mcmipf_dir: Path
    # Parquet files holding the tabular train / validation / test splits.
    parquet_train: Path
    parquet_val: Path
    parquet_test: Path
    # Name of the column used as the regression target.
    target_col: str = "y_ghi_h6"
    # Explicit feature columns; when None (or empty) every numeric non-target
    # column of the training split is used.
    feature_cols: Optional[List[str]] = None
    # Window length (number of steps) for both image and tabular sequences.
    seq_len: int = 12
    # Batch size for the tf.data pipelines and maximum number of fit epochs.
    batch_size: int = 8
    epochs: int = 40
    # Directory receiving the model checkpoint, metrics JSON and summary CSV.
    out_dir: Path = Path("../models")

In [None]:
# ----------------- INDEX (NPZ) -----------------
_RGX = re.compile(r"(20\d{6})[_T]?(\d{2})")  # YYYYMMDD_HH

def ts_from_name(p: Path) -> Optional[pd.Timestamp]:
    """Extract the UTC hour encoded in a file name.

    The name is searched for ``YYYYMMDD`` followed by an optional ``_`` or
    ``T`` separator and a two-digit hour (e.g. ``20220101_00_MCMIPF.npz``).

    Args:
        p: Path whose final component (``p.name``) is inspected.

    Returns:
        A timezone-aware (UTC) timestamp at minute 00 of the encoded hour, or
        ``None`` when the name has no such pattern or the matched digits do
        not form a valid date/hour (e.g. month 13 or hour 99).
    """
    m = _RGX.search(p.name)
    if m is None:
        return None
    ymd, hh = m.groups()
    try:
        return pd.to_datetime(ymd + hh + "00", format="%Y%m%d%H%M", utc=True)
    except ValueError:
        # The digits matched the regex but are not a real calendar date/hour;
        # treat the file as "no timestamp" instead of aborting the whole scan.
        return None

def list_npz_by_ts(root: Path) -> pd.Series:
    """Index every timestamped ``.npz`` file below ``root``.

    Files are visited in sorted path order; names without a parseable
    timestamp are skipped. If two files map to the same timestamp, the one
    visited last wins.

    Returns:
        A Series of paths indexed by timestamp (index named ``"ts"``),
        sorted by time.

    Raises:
        FileNotFoundError: if no file under ``root`` carries a timestamp.
    """
    by_ts = {}
    for path in sorted(Path(root).glob("**/*.npz")):
        stamp = ts_from_name(path)
        if stamp is not None:
            by_ts[stamp] = path
    if len(by_ts) == 0:
        raise FileNotFoundError(f"No .npz with timestamp pattern under {root}")
    series = pd.Series(by_ts).sort_index()
    series.index.name = "ts"
    return series

# ----------------- READERS (NPZ) -----------------
def _normalize_dsrf(img: np.ndarray) -> np.ndarray:
    img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
    img = np.clip(img, 0, None) / 1200.0  # ~[0,1]
    return img.astype(np.float32)

def _standardize_per_channel(chw: np.ndarray) -> np.ndarray:
    chw = np.nan_to_num(chw, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
    C, H, W = chw.shape
    out = np.empty_like(chw, dtype=np.float32)
    for c in range(C):
        band = chw[c]
        med = np.median(band)
        p90 = np.percentile(band, 90)
        scale = p90 if p90 > 1e-3 else 1.0
        out[c] = (band - med) / scale
    return out

def read_dsrf_npz(npz_path: Path) -> np.ndarray:
    """Load one DSRF frame from an ``.npz`` archive as a normalized H,W,1 array.

    The array stored under ``"dsrf"`` is used when present, otherwise the
    first array in the archive. A 2-D array is taken as the frame; for a 3-D
    array only the last slice along axis 0 is used.

    Args:
        npz_path: Path of the ``.npz`` file.

    Returns:
        float32 array of shape (H, W, 1) produced by ``_normalize_dsrf``.

    Raises:
        ValueError: if the archive is empty or the array is not 2-D/3-D.
    """
    # NpzFile keeps the underlying file open until closed; the context manager
    # releases the handle deterministically (many files are read per run).
    with np.load(npz_path) as data:
        if not data.files:
            raise ValueError(f"Empty NPZ archive: {npz_path}")
        key = "dsrf" if "dsrf" in data else data.files[0]
        arr = data[key]  # arrays are decoded on access, so read inside the block
    if arr.ndim == 3:
        arr = arr[-1]                           # last frame
    elif arr.ndim != 2:
        raise ValueError(f"Unexpected DSRF shape: {arr.shape}")
    return _normalize_dsrf(arr)[..., None]      # H,W,1

def read_mcmipf_npz(npz_path: Path) -> np.ndarray:
    """Load one MCMIPF frame from an ``.npz`` archive as a standardized H,W,C array.

    The array stored under ``"mcmipf"`` is used when present, otherwise the
    first array in the archive. A 3-D array is interpreted with axis 0 as the
    channel axis; a 2-D array is treated as a single channel. Both cases go
    through ``_standardize_per_channel``, so non-finite pixels are zeroed
    before scaling regardless of the input rank.

    Args:
        npz_path: Path of the ``.npz`` file.

    Returns:
        float32 array of shape (H, W, C) (C == 1 for 2-D input).

    Raises:
        ValueError: if the archive is empty or the array is not 2-D/3-D.
    """
    # Close the archive deterministically instead of leaking its file handle.
    with np.load(npz_path) as data:
        if not data.files:
            raise ValueError(f"Empty NPZ archive: {npz_path}")
        key = "mcmipf" if "mcmipf" in data else data.files[0]
        arr = data[key]  # arrays are decoded on access, so read inside the block
    if arr.ndim == 2:
        # Promote to a single-channel stack so the 2-D case gets the same
        # NaN/inf cleaning and scaling as the multi-channel case.
        arr = arr[None, ...]
    elif arr.ndim != 3:                         # expect [C,H,W] or [T,H,W]
        raise ValueError(f"Unexpected MCMIPF shape: {arr.shape}")
    chw = _standardize_per_channel(arr.astype(np.float32))
    return np.transpose(chw, (1, 2, 0))         # H,W,C



In [None]:
# ----------------- SEQUENCES -----------------
def build_image_sequences(ts_index: pd.DatetimeIndex, fmap: pd.Series, L: int, product: str) -> Dict[pd.Timestamp, np.ndarray]:
    """Assemble hourly image windows ending at each requested timestamp.

    For every ``t`` in ``ts_index`` the window is the ``L`` hourly timestamps
    ending at ``t``. A window is kept only when every one of its timestamps
    has a file in ``fmap``.

    Args:
        ts_index: Timestamps for which a window should be built.
        fmap: Series mapping timestamp -> NPZ path.
        L: Number of hourly frames per window.
        product: ``"DSRF"`` (case-insensitive) selects the DSRF reader; any
            other value selects the MCMIPF reader.

    Returns:
        Dict mapping window-end timestamp -> float32 array of shape
        [L, H, W, C]. Empty when ``ts_index`` has fewer than ``L`` entries or
        no complete window exists.
    """
    if len(ts_index) < L:
        return {}
    use_dsrf = product.upper() == "DSRF"
    step = pd.Timedelta(hours=1)  # avoids the deprecated "1H" string alias
    available = fmap.index

    # Consecutive windows overlap in L-1 frames; cache decoded frames so each
    # file is read from disk once per call instead of up to L times.
    frame_cache: Dict[pd.Timestamp, np.ndarray] = {}

    def _frame(ts: pd.Timestamp) -> np.ndarray:
        hwc = frame_cache.get(ts)
        if hwc is None:
            path = fmap.loc[ts]
            hwc = read_dsrf_npz(path) if use_dsrf else read_mcmipf_npz(path)
            frame_cache[ts] = hwc
        return hwc

    out: Dict[pd.Timestamp, np.ndarray] = {}
    for t in ts_index:
        window = pd.date_range(end=t, periods=L, freq=step)
        if not all(tt in available for tt in window):
            continue
        # np.stack copies, so cached frames are never aliased by the output.
        stacked = np.stack([_frame(tt) for tt in window], axis=0)
        out[t] = stacked.astype(np.float32, copy=False)  # [L,H,W,C]
    return out

def build_tabular_sequences(df: pd.DataFrame, target_col: str, L: int) -> Dict[pd.Timestamp, Tuple[np.ndarray,float]]:
    """Build fixed-length feature windows paired with the target at the window end.

    Features are all numeric columns other than ``target_col``, cast to
    float32. For each row ``i >= L-1`` the window is rows ``i-L+1 .. i``
    and the label is the target value at row ``i``. A sample is skipped when
    any feature in the window or the label is NaN or infinite.

    NOTE(review): windows are taken over consecutive *rows*, not consecutive
    timestamps; if rows were dropped upstream a window can span a time gap.
    Confirm this is acceptable for the caller.

    Args:
        df: Frame indexed by the sample key (timestamps in this pipeline).
        target_col: Name of the label column.
        L: Window length in rows; must be >= 1.

    Returns:
        Dict mapping the index value of the window's last row ->
        ``(features [L, F] float32, label float)``. The feature arrays are
        views into one shared array; copy before mutating.

    Raises:
        KeyError: if ``target_col`` is not a column of ``df``.
        ValueError: if ``L`` is smaller than 1.
    """
    if target_col not in df.columns:
        raise KeyError(f"{target_col} not found")
    if L < 1:
        raise ValueError(f"L must be >= 1, got {L}")
    feat_cols = [c for c in df.columns if c != target_col and pd.api.types.is_numeric_dtype(df[c])]
    # Convert once; slicing NumPy arrays in the loop is far cheaper than
    # repeated DataFrame.iloc calls.
    X = df[feat_cols].astype("float32").to_numpy()
    y = df[target_col].astype("float32").to_numpy()
    idx = df.index
    out: Dict[pd.Timestamp, Tuple[np.ndarray,float]] = {}
    for end in range(L, len(df) + 1):
        last = end - 1
        block = X[end - L:end]
        label = y[last]
        # Reject inf as well as NaN in the features, matching the target check.
        if not (np.isfinite(block).all() and np.isfinite(label)):
            continue
        out[idx[last]] = (block, float(label))
    return out



In [None]:
# ----------------- TF.DATA -----------------
def make_tf_dataset(img_dict, tab_dict, batch: int, shuffle=True):
    """Join image and tabular samples on their shared keys into a tf.data pipeline.

    Only keys present in both dictionaries are used, in sorted order. Each
    dataset element is ``((image_window, tabular_window), target)``.

    Args:
        img_dict: key -> image window array.
        tab_dict: key -> ``(tabular window array, target)``.
        batch: Batch size.
        shuffle: When true, shuffle with a buffer of ``min(4*batch, n_samples)``.
            NOTE(review): this buffer is small relative to a time-sorted
            dataset, so shuffling is only local — confirm that is intended.

    Returns:
        ``(dataset, keys)`` where ``keys`` is the sorted list of shared keys.

    Raises:
        ValueError: if the two dictionaries share no key.
    """
    keys = sorted(set(img_dict).intersection(tab_dict))
    if len(keys) == 0:
        raise ValueError("No common timestamps between images and tabular.")

    images = np.stack([img_dict[k] for k in keys])
    tab_windows, labels = zip(*(tab_dict[k] for k in keys))
    tabular = np.stack(tab_windows)
    targets = np.array(labels, dtype="float32")

    dataset = tf.data.Dataset.from_tensor_slices(((images, tabular), targets))
    if shuffle:
        buffer_size = min(4*batch, len(targets))
        dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, keys



In [None]:
# ----------------- MODELS -----------------
def build_convlstm_fusion(input_img, input_tab, dropout=0.2):
    """Build and compile the ConvLSTM image + LSTM tabular fusion regressor.

    Args:
        input_img: 4-tuple ``(L, H, W, C)`` giving the image-sequence shape.
        input_tab: 2-tuple ``(Lt, F)`` giving the tabular-sequence shape.
        dropout: Dropout rate applied to the concatenated features.

    Returns:
        A compiled Keras model taking ``[image_seq, tabular_seq]`` and
        producing one float32 value per sample (Adam 1e-3, MSE loss, MAE metric).
    """
    n_frames, height, width, n_bands = input_img
    n_steps, n_feats = input_tab

    def _convlstm_bn(tensor, keep_sequence):
        # One ConvLSTM2D stage followed by batch normalization.
        tensor = layers.ConvLSTM2D(
            32, (3,3), padding="same", return_sequences=keep_sequence, activation="relu"
        )(tensor)
        return layers.BatchNormalization()(tensor)

    # Image branch: two ConvLSTM stages, then spatial average pooling.
    img_in = layers.Input(shape=(n_frames, height, width, n_bands))
    img_feat = _convlstm_bn(img_in, True)
    img_feat = _convlstm_bn(img_feat, False)
    img_feat = layers.GlobalAveragePooling2D()(img_feat)

    # Tabular branch: a single LSTM summarizing the sequence.
    tab_in = layers.Input(shape=(n_steps, n_feats))
    tab_feat = layers.LSTM(64)(tab_in)

    # Fusion head.
    fused = layers.Concatenate()([img_feat, tab_feat])
    fused = layers.Dropout(dropout)(fused)
    for units in (128, 64):
        fused = layers.Dense(units, activation="relu")(fused)
    prediction = layers.Dense(1, dtype="float32")(fused)

    net = models.Model([img_in, tab_in], prediction)
    net.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="mse",
        metrics=["mae"],
    )
    return net

def build_3dcnn_fusion(input_img, input_tab, dropout=0.2):
    """Build and compile the 3-D CNN image + tabular fusion regressor.

    Args:
        input_img: 4-tuple ``(L, H, W, C)`` giving the image-sequence shape.
        input_tab: 2-tuple ``(Lt, F)`` giving the tabular-sequence shape. The
            tabular branch uses an LSTM when ``Lt > 1`` and a Flatten layer
            otherwise.
        dropout: Dropout rate applied to the concatenated features.

    Returns:
        A compiled Keras model taking ``[image_seq, tabular_seq]`` and
        producing one float32 value per sample (Adam 1e-3, MSE loss, MAE metric).
    """
    n_frames, height, width, n_bands = input_img
    n_steps, n_feats = input_tab

    # Image branch: two strided Conv3D stages (spatial stride 2, temporal
    # stride 1), each followed by batch normalization, then global pooling.
    img_in = layers.Input(shape=(n_frames, height, width, n_bands))
    img_feat = img_in
    for filters in (32, 64):
        img_feat = layers.Conv3D(
            filters, (3,3,3), strides=(1,2,2), padding="same", activation="relu"
        )(img_feat)
        img_feat = layers.BatchNormalization()(img_feat)
    img_feat = layers.GlobalAveragePooling3D()(img_feat)

    # Tabular branch.
    tab_in = layers.Input(shape=(n_steps, n_feats))
    if n_steps > 1:
        tab_feat = layers.LSTM(64)(tab_in)
    else:
        tab_feat = layers.Flatten()(tab_in)
    tab_feat = layers.Dense(128, activation="relu")(tab_feat)

    # Fusion head.
    fused = layers.Concatenate()([img_feat, tab_feat])
    fused = layers.Dropout(dropout)(fused)
    for units in (128, 64):
        fused = layers.Dense(units, activation="relu")(fused)
    prediction = layers.Dense(1, dtype="float32")(fused)

    net = models.Model([img_in, tab_in], prediction)
    net.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="mse",
        metrics=["mae"],
    )
    return net



In [None]:
# ----------------- TRAIN / EVAL -----------------
def _ensure_utc_hourly(df: pd.DataFrame) -> pd.DataFrame:
    idx = df.index
    if getattr(idx, "tz", None) is None:
        df = df.tz_localize("UTC")
    else:
        df = df.tz_convert("UTC")
    return df.groupby(pd.Grouper(freq="1H")).last().dropna(how="any")

def _fusion_scores(y_true: np.ndarray, y_pred: np.ndarray, y_base: np.ndarray) -> Tuple[float, float, float]:
    """Return ``(rmse, mae, skill)`` for model predictions against a baseline.

    RMSE and MAE use every sample. Skill is ``1 - rmse_model / rmse_base``
    computed only on samples where the baseline is finite, so a baseline with
    gaps does not turn the score into NaN. When no baseline value is finite,
    or the baseline RMSE is not above 1e-9, skill falls back to 0.0.
    """
    err = y_true - y_pred
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mae = float(np.mean(np.abs(err)))

    skill = 0.0
    usable = np.isfinite(y_base)
    if usable.any():
        rmse_base = float(np.sqrt(np.mean((y_true[usable] - y_base[usable]) ** 2)))
        rmse_model = float(np.sqrt(np.mean(err[usable] ** 2)))
        if rmse_base > 1e-9:
            skill = 1.0 - rmse_model / rmse_base
    return rmse, mae, float(skill)


def _save_fusion_plots(cfg: "Config", idx_te: pd.DatetimeIndex, y_true: np.ndarray,
                       y_pred: np.ndarray, y_base: np.ndarray,
                       rmse: float, mae: float, skill: float) -> None:
    """Write the time-series, scatter and residual figures for the test split."""
    OUT_FIG = Path("../reports/figures"); OUT_FIG.mkdir(parents=True, exist_ok=True)
    N = min(400, len(y_true))  # only the first 400 samples are drawn in the time series

    plt.figure(figsize=(12, 3.6))
    plt.plot(idx_te[:N], y_true[:N], label="truth", lw=1.4)
    plt.plot(idx_te[:N], y_pred[:N], label=f"{cfg.product} fusion", lw=1.1)
    plt.plot(idx_te[:N], y_base[:N], label="baseline", lw=1.0, alpha=0.7)
    plt.title(f"Test — Truth vs Fusion ({cfg.product}) vs Baseline ({cfg.target_col})")
    plt.ylabel("GHI (W/m²)" if cfg.target_col.startswith("y_ghi") else "target")
    plt.xlabel("Time"); plt.grid(True, ls="--", alpha=0.3); plt.legend()
    plt.xticks(rotation=45); plt.tight_layout()
    plt.savefig(OUT_FIG / f"{cfg.product}_fusion_ts_test.png", dpi=140)
    plt.close()

    lim_min = float(min(np.min(y_true), np.min(y_pred)))
    lim_max = float(max(np.max(y_true), np.max(y_pred)))
    plt.figure(figsize=(4.8, 4.8))
    plt.scatter(y_true, y_pred, s=10, alpha=0.5)
    plt.plot([lim_min, lim_max], [lim_min, lim_max], linestyle="--", linewidth=1.0)
    plt.xlabel("Actual"); plt.ylabel("Predicted")
    plt.title(f"{cfg.product} Fusion — Actual vs Predicted\nRMSE={rmse:.3f}  MAE={mae:.3f}  Skill={skill:.3f}")
    plt.grid(True, ls="--", alpha=0.3); plt.tight_layout()
    plt.savefig(OUT_FIG / f"{cfg.product}_fusion_scatter.png", dpi=140)
    plt.close()

    resid = y_pred - y_true
    plt.figure(figsize=(6, 3.2))
    plt.hist(resid, bins=50, alpha=0.85)
    plt.axvline(0, linestyle="--", linewidth=1)
    plt.title(f"{cfg.product} Fusion — Residuals (mean={np.mean(resid):.3f})")
    plt.xlabel("Residual"); plt.ylabel("Frequency")
    plt.grid(True, ls="--", alpha=0.3); plt.tight_layout()
    plt.savefig(OUT_FIG / f"{cfg.product}_fusion_residuals.png", dpi=140)
    plt.close()


def train_fusion(cfg: Config):
    """Train and evaluate one image + tabular fusion model.

    Steps: load the three parquet splits and bring them to hourly UTC, index
    the NPZ frames of the selected product, build aligned image/tabular
    windows, fit the model with early stopping and checkpointing, then score
    the test split against a simple baseline and write metrics and figures.

    Side effects: creates ``cfg.out_dir``; writes the best-model checkpoint,
    ``metrics_<product>.json`` and (appending) ``fusion_test_summary.csv``
    there; writes three PNG figures under ``../reports/figures``; prints the
    metrics.

    Args:
        cfg: Run configuration.

    Returns:
        ``(model, metrics)`` where ``metrics`` is the dict that was written
        to the JSON file.

    Raises:
        FileNotFoundError: if no timestamped NPZ file is found.
        ValueError: if a split has no timestamp shared by images and tabular data.
    """
    cfg.out_dir.mkdir(parents=True, exist_ok=True)
    use_dsrf = cfg.product.upper() == "DSRF"

    # Tabular → UTC hourly
    tr = _ensure_utc_hourly(pd.read_parquet(cfg.parquet_train).sort_index())
    va = _ensure_utc_hourly(pd.read_parquet(cfg.parquet_val).sort_index())
    te = _ensure_utc_hourly(pd.read_parquet(cfg.parquet_test).sort_index())

    feat_cols = cfg.feature_cols or [c for c in tr.columns if c != cfg.target_col and pd.api.types.is_numeric_dtype(tr[c])]
    tr = tr[feat_cols + [cfg.target_col]]
    va = va[feat_cols + [cfg.target_col]]
    te = te[feat_cols + [cfg.target_col]]

    # NPZ images → restrict to the tabular time span. The span is taken over
    # all three splits so it does not depend on their chronological order.
    fmap = list_npz_by_ts(cfg.dsrf_dir if use_dsrf else cfg.mcmipf_dir)
    tmin = min(tr.index.min(), va.index.min(), te.index.min())
    tmax = max(tr.index.max(), va.index.max(), te.index.max())
    fmap = fmap[(fmap.index >= tmin) & (fmap.index <= tmax)]

    # Build sequences on intersection
    img_tr = build_image_sequences(fmap.index[fmap.index.isin(tr.index)], fmap, cfg.seq_len, cfg.product)
    img_va = build_image_sequences(fmap.index[fmap.index.isin(va.index)], fmap, cfg.seq_len, cfg.product)
    img_te = build_image_sequences(fmap.index[fmap.index.isin(te.index)], fmap, cfg.seq_len, cfg.product)

    tab_tr = build_tabular_sequences(tr, cfg.target_col, cfg.seq_len)
    tab_va = build_tabular_sequences(va, cfg.target_col, cfg.seq_len)
    tab_te = build_tabular_sequences(te, cfg.target_col, cfg.seq_len)

    ds_tr, keys_tr = make_tf_dataset(img_tr, tab_tr, cfg.batch_size, shuffle=True)
    ds_va, keys_va = make_tf_dataset(img_va, tab_va, cfg.batch_size, shuffle=False)
    ds_te, keys_te = make_tf_dataset(img_te, tab_te, cfg.batch_size, shuffle=False)

    # Shapes (read from one training batch; axis 0 is the batch dimension)
    (ximg_s, xtab_s), _ = next(iter(ds_tr.take(1)))
    L,H,W,C = ximg_s.shape[1:]; Lt,F = xtab_s.shape[1:]

    # Model
    build_model = build_convlstm_fusion if use_dsrf else build_3dcnn_fusion
    model = build_model((L,H,W,C), (Lt,F))

    ckpt = str((cfg.out_dir / f"best_{cfg.product.lower()}_fusion.keras").resolve())
    cbs = [
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
        tf.keras.callbacks.ModelCheckpoint(ckpt, monitor="val_loss", save_best_only=True),
    ]

    model.fit(ds_tr, validation_data=ds_va, epochs=cfg.epochs, callbacks=cbs)

    # Evaluate. The test dataset is not shuffled, so predictions come back in
    # keys_te order; one predict call replaces a per-batch Python loop.
    # reshape(-1) keeps the arrays 1-D even for a single-sample test set.
    y_pred = np.asarray(model.predict(ds_te, verbose=0)).reshape(-1)
    y_true = np.array([tab_te[t][1] for t in keys_te], dtype="float32").reshape(-1)

    # Baseline: first available reference column, else the training median.
    base_src = None
    for c in ["ghi_qc","ghi_sg_definitive","ghi_qc_lag1"]:
        if c in te.columns: base_src = te[c]; break
    if base_src is None:
        base_src = pd.Series(np.nanmedian(tr[cfg.target_col]), index=te.index)

    idx_te = pd.DatetimeIndex(keys_te)
    y_base = base_src.reindex(idx_te).to_numpy(dtype="float32")

    rmse, mae, skill = _fusion_scores(y_true, y_pred, y_base)

    out = {"product": cfg.product, "seq_len": int(cfg.seq_len),
           "rmse": rmse, "mae": mae, "skill": float(skill), "n_test": int(len(y_true))}
    with open(cfg.out_dir / f"metrics_{cfg.product.lower()}.json", "w") as f:
        json.dump(out, f, indent=2)

    # Summary CSV: create with a header on first use, append afterwards.
    summary_csv = cfg.out_dir / "fusion_test_summary.csv"
    first_write = not summary_csv.exists()
    pd.DataFrame([out]).to_csv(summary_csv, mode="w" if first_write else "a",
                               index=False, header=first_write)

    # Plots
    _save_fusion_plots(cfg, idx_te, y_true, y_pred, y_base, rmse, mae, skill)

    print(json.dumps(out, indent=2))
    return model, out



In [None]:
# ----------------- MAIN -----------------
if __name__ == "__main__":
    # Shared settings for both runs; only `product` changes between them.
    goes_root = "/mnt/SOLARLAB/E_Ladino/Repo/irradiance-fusion-forecast/data_processed/GOES"
    cfg = Config(
        product="DSRF",
        dsrf_dir=Path("/mnt/SOLARLAB/E_Ladino/Repo/irradiance-fusion-forecast/data_processed/GOES/DSRF"),
        mcmipf_dir=Path("/mnt/SOLARLAB/E_Ladino/Repo/irradiance-fusion-forecast/data_processed/GOES/MCMIPF"),
        parquet_train=Path("../data_processed/ground_train_h6.parquet"),
        parquet_val=Path("../data_processed/ground_val_h6.parquet"),
        parquet_test=Path("../data_processed/ground_test_h6.parquet"),
        seq_len=12,
        batch_size=8,
        epochs=40,
        out_dir=Path("../models"),
    )

    # Train one fusion model per satellite product, DSRF first.
    runs = (
        ("DSRF", "\n>>> Training fusion with DSRF (NPZ, hourly UTC)…"),
        ("MCMIPF", "\n>>> Training fusion with MCMIPF (NPZ, hourly UTC)…"),
    )
    for product_name, banner in runs:
        print(banner)
        cfg.product = product_name
        train_fusion(cfg)
