In [19]:
# cpi_pipeline.py
# CPI prediction & presentation pipeline with Import-weighted tariff features and robust calibration.
# (Fixes: robust HS2 column detection in Import.xlsx; datetime column headers; HS2 padding to 2 digits.)

import os, io, re, glob, json, zipfile, argparse, warnings
from typing import List, Dict, Tuple, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.inspection import permutation_importance
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression

warnings.filterwarnings("ignore")

# XGBoost (fallback to GBR if missing)
XGB_PRESENT = True
try:
    from xgboost import XGBRegressor
except Exception:
    XGB_PRESENT = False
    from sklearn.ensemble import GradientBoostingRegressor as XGBRegressor

from statsmodels.tsa.statespace.sarimax import SARIMAX


# ---------------- CLI ----------------
def parse_args(argv=None):
    p = argparse.ArgumentParser(description="CPI Nowcasting & Presentation Pipeline")
    p.add_argument("--base_dir", default=os.environ.get("BASE_DIR", "."), help="Folder with CPI.csv etc.")
    p.add_argument("--out_dir", default=None, help="Output directory (default: <base_dir>/outputs)")
    p.add_argument("--auto_sarima", action="store_true", help="Small SARIMA grid by AIC.")
    p.add_argument("--test_months", type=int, default=24, help="Test split length (months).")
    p.add_argument("--coverage", type=float, default=0.80, help="Min coverage for most features.")
    p.add_argument("--protected_cov", type=float, default=0.50, help="Lower coverage for Import__/tariff_/TariffWgt__.")
    p.add_argument("--min_nonnull_import_tariff", type=int, default=6, help="Min non-nulls to keep Import__/tariff_*/TariffWgt__.")
    p.add_argument("--pi_max_features", type=int, default=60, help="Max features for permutation importance.")
    p.add_argument("--target_mom", action="store_true", help="Train ML on CPI MoM and reconstruct Level.")
    p.add_argument("--model_start", default="2010-01", help="First month for the modeling window.")
    args, _ = p.parse_known_args(argv)  # ignore Jupyter -f
    return args


# ---------------- I/O helpers ----------------
def parse_date_col(df: pd.DataFrame) -> pd.Series:
    for col in df.columns:
        try:
            parsed = pd.to_datetime(df[col], errors="coerce", infer_datetime_format=True)
            if parsed.notna().mean() > 0.7:
                return parsed
        except Exception:
            pass
    idx = pd.to_datetime(df.index, errors="coerce", infer_datetime_format=True)
    if idx.notna().mean() > 0.7:
        return idx
    raise ValueError("Could not infer a date column.")

def try_numeric(df: pd.DataFrame) -> pd.DataFrame:
    out = df.copy()
    for c in out.columns:
        if out[c].dtype == "O":
            out[c] = (out[c].astype(str)
                      .str.replace(",", "", regex=False)
                      .str.replace("%", "", regex=False))
            out[c] = pd.to_numeric(out[c], errors="coerce")
    out = out.loc[:, out.notna().any(axis=0)]
    return out

def robust_read(path: str) -> Dict[str, pd.DataFrame]:
    ext = os.path.splitext(path)[1].lower()
    out: Dict[str, pd.DataFrame] = {}
    if ext == ".csv":
        try:
            out["main"] = pd.read_csv(path)
        except UnicodeDecodeError:
            out["main"] = pd.read_csv(path, encoding="latin-1")
    elif ext == ".xlsx":
        x = pd.ExcelFile(path)  # openpyxl
        for s in x.sheet_names:
            try: out[s] = x.parse(s)
            except Exception: pass
    elif ext == ".xls":
        try:
            import xlrd  # noqa
            x = pd.ExcelFile(path, engine="xlrd")
            for s in x.sheet_names:
                try: out[s] = x.parse(s)
                except Exception: pass
        except Exception:
            return {}
    else:
        return {}
    return out


# -------------- wide/long monthly unifiers --------------
def unify_monthly_series(dfs: Dict[str, pd.DataFrame], prefer_sum: bool = False) -> pd.DataFrame:
    frames = []
    for name, df in dfs.items():
        try:
            date_idx = parse_date_col(df)
            df = df.copy(); df.index = pd.to_datetime(date_idx)
            drop_candidates = [c for c in df.columns if any(k in str(c).lower() for k in ["date", "time", "period", "month"])]
            df = df.drop(columns=drop_candidates, errors="ignore")
            df = try_numeric(df).sort_index()
            agg = "sum" if prefer_sum else "mean"
            m = df.resample("MS").agg(agg)
            m.columns = [f"{name}__{c}" for c in m.columns]
            frames.append(m)
        except Exception:
            continue
    if not frames: return pd.DataFrame()
    out = pd.concat(frames, axis=1).sort_index()
    return out.groupby(level=0, axis=1).first()

def _date_like_cols(cols):
    parsed = []
    for c in cols:
        try:
            dt = pd.to_datetime(str(c), errors="coerce")
            if pd.notna(dt):
                ms = pd.Period(dt, freq="M").to_timestamp('M') + pd.offsets.MonthBegin(1)
                parsed.append((c, ms))
        except Exception:
            pass
    return parsed

def melt_wide_months(df: pd.DataFrame, prefer_sum=False) -> pd.DataFrame:
    candidates = _date_like_cols(df.columns)
    if len(candidates) < 6:  # need several months at least
        return pd.DataFrame()
    col_map = {orig: m for orig, m in candidates}
    tmp = try_numeric(df[list(col_map.keys())].copy())
    agg = np.nansum if prefer_sum else np.nanmean
    vec = agg(tmp.values, axis=0)
    ser = pd.Series(vec, index=[col_map[c] for c in tmp.columns], name="value").sort_index()
    return ser.resample("MS").mean().to_frame()

def unify_monthly_any_format(dfs: Dict[str, pd.DataFrame], prefer_sum=False, prefix=None) -> pd.DataFrame:
    frames = []
    for name, df in dfs.items():
        got = unify_monthly_series({name: df}, prefer_sum=prefer_sum)
        if got.empty or got.shape[0] <= 2:
            m = melt_wide_months(df, prefer_sum=prefer_sum)
            if not m.empty:
                m.columns = [f"{name}__{c}" for c in m.columns]
                frames.append(m)
        else:
            frames.append(got)
    if not frames: return pd.DataFrame()
    out = pd.concat(frames, axis=1).sort_index()
    out = out.groupby(level=0, axis=1).first()
    if prefix: out = out.add_prefix(prefix)
    return out


# ---------------- Tariff ZIP reader ----------------
def read_tariff_zip(zip_path: str) -> pd.DataFrame:
    if not os.path.exists(zip_path): return pd.DataFrame()
    monthly_list, hs2_list = [], []
    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.namelist():
            if not member.lower().endswith((".csv", ".xlsx", ".xls")): continue
            with z.open(member) as f: data = f.read()
            buf = io.BytesIO(data)

            df = None
            if member.lower().endswith(".csv"):
                for enc in [None, "latin-1"]:
                    try:
                        df = pd.read_csv(buf if enc is None else io.BytesIO(data),
                                         encoding=None if enc is None else "latin-1")
                        break
                    except Exception: df = None
            elif member.lower().endswith(".xlsx"):
                try: df = pd.read_excel(buf)
                except Exception: df = None
            elif member.lower().endswith(".xls"):
                try:
                    import xlrd  # noqa
                    df = pd.read_excel(buf, engine="xlrd")
                except Exception: df = None
            if df is None or df.empty: continue

            # Date
            date_col = None
            for c in df.columns:
                if any(k in str(c).lower() for k in ["date", "period", "month"]):
                    parsed = pd.to_datetime(df[c], errors="coerce")
                    if parsed.notna().mean() > 0.5: date_col = c; break
            if date_col is None:
                idx = pd.to_datetime(df.index, errors="coerce")
                if idx.notna().mean() > 0.5: df = df.copy(); df.index = idx
                else: continue
            else:
                df = df.copy(); df.index = pd.to_datetime(df[date_col]); df = df.drop(columns=[date_col], errors="ignore")

            # Rate columns
            cand = [c for c in df.columns if any(k in str(c).lower() for k in ["tariff", "duty", "rate", "advalorem"])]
            if not cand: continue
            numeric = try_numeric(df[cand])
            rates = numeric.mean(axis=1, skipna=True)
            monthly_list.append(rates.resample("MS").mean().rename("tariff_overall"))

            # HS2 split
            hs_col = None
            for c in df.columns:
                if re.search(r"\b(HS|HTS).*(CODE)?\b", str(c), re.I): hs_col = c; break
            if hs_col is None:
                for c in df.columns:
                    if "code" in str(c).lower(): hs_col = c; break
            if hs_col is not None:
                tmp = pd.DataFrame({
                    "date": df.index,
                    "hs": df[hs_col].astype(str).str.extract(r"(\d{1,2})", expand=False).str.zfill(2),
                    "rate": rates.values,
                }).dropna(subset=["hs"])
                tmp["date"] = pd.to_datetime(tmp["date"])
                grp = tmp.groupby([pd.Grouper(key="date", freq="MS"), "hs"])["rate"].mean().reset_index()
                piv = grp.pivot(index="date", columns="hs", values="rate").add_prefix("tariff_hs2_")
                hs2_list.append(piv)

    out = None
    if monthly_list:
        out = pd.concat(monthly_list, axis=1).mean(axis=1).to_frame("tariff_overall")
    if hs2_list:
        hs2 = pd.concat(hs2_list, axis=1).groupby(level=0, axis=1).mean()
        out = hs2.join(out, how="outer") if out is not None else hs2
    return out if out is not None else pd.DataFrame()


# ------------- Import.xlsx → HS2 monthly values -------------
def _extract_hs_col(df: pd.DataFrame) -> Optional[str]:
    # FIX: always treat column names as strings
    for c in df.columns:
        cl = str(c).lower()
        if re.search(r"\b(hs|hts).*code\b", cl) or cl in {"hs", "hts", "code", "commodity code", "hs code", "hts code", "chapter"}:
            return c
    for c in df.columns:
        if "code" in str(c).lower() or str(c).lower() in {"hs2", "hs_2", "chapter"}:
            return c
    return None

def read_import_hs2(import_path: str) -> pd.DataFrame:
    """Return monthly HS2 import values: columns like imp_hs2_01..imp_hs2_99 and Import__TOTAL."""
    if not os.path.exists(import_path): return pd.DataFrame()
    dfs = robust_read(import_path)
    frames = []
    for name, df in dfs.items():
        if df is None or df.empty: continue
        hs_col = _extract_hs_col(df)
        if hs_col:
            # WIDE case: rows are HS, columns are months (many datetime-like headers)
            date_cols = _date_like_cols(df.columns)
            if date_cols:
                col_map = {orig: m for orig, m in date_cols}
                cols = [hs_col] + list(col_map.keys())
                tmp = df[cols].copy()
                tmp.columns = [hs_col] + [col_map[c] for c in col_map.keys()]
                tmp = try_numeric(tmp)
                tmp["hs2"] = (tmp[hs_col].astype(str)
                              .str.extract(r"(\d{1,2})", expand=False)
                              .str.zfill(2))
                tmp = tmp.dropna(subset=["hs2"])
                melt = tmp.drop(columns=[hs_col]).melt(id_vars="hs2", var_name="date", value_name="value")
                melt["date"] = pd.to_datetime(melt["date"])
                grp = melt.groupby([pd.Grouper(key="date", freq="MS"), "hs2"])["value"].sum().reset_index()
                piv = grp.pivot(index="date", columns="hs2", values="value").add_prefix("imp_hs2_")
                frames.append(piv)
                continue
            # LONG case: hs/date/value in columns
            maybe_date = None
            for c in df.columns:
                if any(k in str(c).lower() for k in ["date", "period", "month"]):
                    maybe_date = c; break
            val_cols = [c for c in df.columns if c not in {hs_col, maybe_date} and df[c].dtype != "O"]
            if maybe_date and val_cols:
                tmp = df[[hs_col, maybe_date] + val_cols].copy()
                tmp["hs2"] = (tmp[hs_col].astype(str)
                              .str.extract(r"(\d{1,2})", expand=False)
                              .str.zfill(2))
                tmp[maybe_date] = pd.to_datetime(tmp[maybe_date])
                tmp["value"] = try_numeric(tmp[val_cols]).sum(axis=1, min_count=1)
                grp = tmp.groupby([pd.Grouper(key=maybe_date, freq="MS"), "hs2"])["value"].sum().reset_index()
                piv = grp.pivot(index=maybe_date, columns="hs2", values="value").add_prefix("imp_hs2_")
                frames.append(piv)
                continue
        # Fallback: if sheet has only totals by month
        m = unify_monthly_any_format({name: df}, prefer_sum=True)
        if not m.empty:
            m.columns = [c if str(c).startswith("Import__") else f"Import__{name}__{c}" for c in m.columns]
            frames.append(m)
    if not frames: return pd.DataFrame()
    imp = pd.concat(frames, axis=1).sort_index()
    # collapse duplicated columns and build TOTAL
    imp = imp.groupby(level=0, axis=1).sum(min_count=1)
    hs2_cols = [c for c in imp.columns if c.startswith("imp_hs2_")]
    if hs2_cols:
        imp["Import__TOTAL"] = imp[hs2_cols].sum(axis=1, min_count=1)
    elif "Import__TOTAL" not in imp.columns:
        imp["Import__TOTAL"] = imp.sum(axis=1, min_count=1)
    return imp


def build_import_weighted_tariffs(tariffs: pd.DataFrame, imp_hs2: pd.DataFrame, topn: int = 12) -> pd.DataFrame:
    """Create TariffWgt__Index and top-N HS2 weighted contributions: share(HSt) * tariff(HSt)."""
    if tariffs is None or tariffs.empty or imp_hs2 is None or imp_hs2.empty:
        return pd.DataFrame()
    hs2_tar_cols = [c for c in tariffs.columns if re.match(r"tariff_hs2_\d{2}$", c)]
    hs2_imp_cols = [c for c in imp_hs2.columns if re.match(r"imp_hs2_\d{2}$", c)]
    if not hs2_tar_cols or not hs2_imp_cols:
        return pd.DataFrame()

    # Align to monthly MS frequency
    t = tariffs[hs2_tar_cols].copy().asfreq("MS")
    m = imp_hs2[hs2_imp_cols + (["Import__TOTAL"] if "Import__TOTAL" in imp_hs2.columns else [])].copy().asfreq("MS")
    if "Import__TOTAL" not in m.columns:
        m["Import__TOTAL"] = m[hs2_imp_cols].sum(axis=1, min_count=1)

    # Shares
    shares = m[hs2_imp_cols].div(m["Import__TOTAL"].replace(0, np.nan), axis=0)
    shares = shares.fillna(0.0)

    # Pairs in common
    def pair_name(c_imp): return c_imp.replace("imp_", "tariff_")
    imp_common = [c for c in hs2_imp_cols if pair_name(c) in t.columns]
    if not imp_common:
        return pd.DataFrame()
    tar_common = [pair_name(c) for c in imp_common]

    # Weighted index
    w_contribs = shares[imp_common].values * t[tar_common].values
    index = pd.Series(np.nansum(w_contribs, axis=1), index=t.index, name="TariffWgt__Index")

    # Top-N HS2 by average share
    avg_share = shares[imp_common].mean().sort_values(ascending=False)
    top_imp_cols = list(avg_share.head(topn).index)
    top_tar_cols = [pair_name(c) for c in top_imp_cols]
    top_w = pd.DataFrame(shares[top_imp_cols].values * t[top_tar_cols].values, index=t.index,
                         columns=[f"TariffWgt__{c[-2:]}" for c in top_imp_cols])

    out = pd.concat([index, top_w], axis=1)
    return out


# ---------------- Metrics & transforms ----------------
def annualize_qoq_saar(last3_mom: List[float]) -> float:
    try: return float(np.prod([1 + m for m in last3_mom])**4 - 1.0)
    except Exception: return np.nan

def compute_qoq_saar_series(level_series: pd.Series) -> pd.Series:
    mom = level_series.pct_change()
    out = []
    for i in range(len(level_series)):
        if i < 3 or mom.iloc[i-2:i+1].isna().any(): out.append(np.nan)
        else: out.append(annualize_qoq_saar(mom.iloc[i-2:i+1].values))
    return pd.Series(out, index=level_series.index, name=f"{level_series.name}_QoQ_SAAR")

def rmse(a, b): 
    return float(mean_squared_error(a, b) ** 0.5)


# ---------------- Data assembly & features ----------------
def build_feature_table(base_dir: str, model_start: str) -> Tuple[pd.Series, pd.DataFrame, Dict]:
    files = {
        "cpi": os.path.join(base_dir, "CPI.csv"),
        "usd": os.path.join(base_dir, "USD Index.csv"),
        "rent": os.path.join(base_dir, "Rent Price.csv"),
        "house": os.path.join(base_dir, "Housing Price.csv"),
        "oil": os.path.join(base_dir, "Oil.xls"),
        "freight": os.path.join(base_dir, "Freight.xlsx"),
        "import": os.path.join(base_dir, "Import.xlsx"),
    }
    zip_paths = sorted(set(glob.glob(os.path.join(base_dir, "tariff*.zip")) +
                           glob.glob(os.path.join(base_dir, "tariff_data_*.zip"))))

    audit = {"base_dir": base_dir, "found_files": {}, "notes": [], "tariff_zips_used": []}

    def read_any(path, prefer_sum=False, label=None, prefix=None):
        label = label or os.path.basename(path)
        exists = os.path.exists(path)
        audit["found_files"][label] = {"exists": exists, "path": path}
        if not exists: return pd.DataFrame()
        m = unify_monthly_any_format(robust_read(path), prefer_sum=prefer_sum, prefix=prefix)
        audit["found_files"][label]["shape"] = list(m.shape) if m is not None else None
        return m

    # Target
    cpi_df = unify_monthly_any_format(robust_read(files["cpi"]), prefer_sum=False)
    cpi_cols = [c for c in cpi_df.columns if cpi_df[c].notna().sum() > 10]
    if not cpi_cols: raise RuntimeError("CPI.csv had no usable numeric columns.")
    y = cpi_df[cpi_cols[0]].rename("CPI")

    # Exogenous (macro/freight/etc.)
    feats = []
    feats.append(read_any(files["usd"],    False, "USD Index.csv",     "USD__"))
    feats.append(read_any(files["rent"],   False, "Rent Price.csv",    "Rent__"))
    feats.append(read_any(files["house"],  False, "Housing Price.csv", "House__"))
    feats.append(read_any(files["oil"],    False, "Oil.xls",           "Oil__"))
    feats.append(read_any(files["freight"],False, "Freight.xlsx",      "Freight__"))

    # Imports & tariffs
    imp_wide = read_any(files["import"], True, "Import.xlsx", "Import__")  # totals (if any)
    imp_hs2 = read_import_hs2(files["import"])  # HS2 values
    if not imp_wide.empty and "Import__AGG" not in imp_wide.columns:
        imp_wide["Import__AGG"] = imp_wide.filter(regex=r"^Import__").sum(axis=1, min_count=1)
    tariffs_list = []
    for zp in zip_paths:
        t = read_tariff_zip(zp); lab = os.path.basename(zp)
        audit["found_files"][lab] = {"exists": os.path.exists(zp), "path": zp, "shape": list(t.shape) if t is not None else None}
        if not t.empty: tariffs_list.append(t); audit["tariff_zips_used"].append(lab)
    tariffs = pd.concat(tariffs_list, axis=1).groupby(level=0, axis=1).mean() if tariffs_list else pd.DataFrame()

    # Import-weighted tariffs
    tariff_wgt = build_import_weighted_tariffs(tariffs, imp_hs2, topn=12)

    # Merge exogenous blocks
    X = None
    for block in [imp_wide, imp_hs2, tariffs, tariff_wgt] + feats:
        if block is None or block.empty: continue
        X = block if X is None else X.join(block, how="outer")

    if X is None or X.empty: raise RuntimeError("No exogenous features parsed.")

    # Align monthly & trim
    df = pd.concat([y, X], axis=1).sort_index()
    df = df[~df.index.duplicated(keep="first")].asfreq("MS").ffill().dropna(thresh=2)
    start_ts = pd.to_datetime(model_start) if model_start else df.index.min()
    df = df.loc[df.index >= start_ts]

    audit["target_span"] = [str(df.index.min().date()), str(df.index.max().date())]
    exog_raw = df.drop(columns=["CPI"])
    audit["exog_columns_raw_count"] = exog_raw.shape[1]
    audit["exog_span"] = [str(df.index.min().date()), str(df.index.max().date())]
    audit["import_weighted_tariff_present"] = bool(tariff_wgt is not None and not tariff_wgt.empty)

    return df["CPI"], exog_raw, audit


def engineer_features(y: pd.Series, X: pd.DataFrame,
                      coverage: float = 0.80, protected_cov: float = 0.50,
                      min_nonnull_import_tariff: int = 6):
    """Numeric-only exog; protect Import__/tariff_/TariffWgt__; replace ±∞→NaN; transforms+lags (incl. tariffs)."""
    X_num = X.select_dtypes(include=[np.number]).copy()
    df = pd.concat([y.rename("CPI"), X_num], axis=1).sort_index()

    cov = X_num.notna().mean(); nn = X_num.notna().sum()
    is_special = (X_num.columns.str.startswith("Import__")
                  | X_num.columns.str.contains("tariff_", case=False)
                  | X_num.columns.str.startswith("TariffWgt__"))

    keep_general   = cov[cov >= coverage].index.tolist()
    keep_protected = cov[is_special & ((cov >= protected_cov) | (nn >= min_nonnull_import_tariff))].index.tolist()
    exog_keep = sorted(set(keep_general) | set(keep_protected))
    if not any(col.startswith("Import__") for col in exog_keep):
        exog_keep += [c for c in X_num.columns if c == "Import__AGG"]

    Xg = X_num[exog_keep].copy()

    # Target transforms
    df["CPI_mom"] = df["CPI"].pct_change()
    df["CPI_yoy"] = df["CPI"].pct_change(12)

    # Exog transforms
    for c in exog_keep:
        s = Xg[c]
        df[f"{c}_mom"]    = s.pct_change()
        df[f"{c}_yoy"]    = s.pct_change(12)
        df[f"{c}_roll3"]  = s.rolling(3).mean()
        df[f"{c}_roll6"]  = s.rolling(6).mean()
        df[f"{c}_roll12"] = s.rolling(12).mean()

    # Replace ±∞ -> NaN
    df = df.replace([np.inf, -np.inf], np.nan)

    # Lags (CPI + key channels incl. tariffs & TariffWgt__)
    key_exog = [c for c in exog_keep
                if c.startswith(("Freight__", "Oil__", "Import__", "USD__", "TariffWgt__"))
                or ("tariff_" in c.lower())]
    base_cols = ["CPI", "CPI_mom", "CPI_yoy"] + [c for c in key_exog if c in df.columns]
    for c in base_cols:
        for L in range(1, 13):
            df[f"{c}_lag{L}"] = df[c].shift(L)

    # Coverage pass & fill
    cov2 = df.notna().mean(); nn2 = df.notna().sum()
    idx_series = cov2.index.to_series()
    is_prot2 = (idx_series.str.startswith("Import__")
                | idx_series.str.contains("tariff_", case=False)
                | idx_series.str.startswith("TariffWgt__"))
    keep_cols = cov2[(cov2 >= coverage)
                     | (is_prot2 & ((cov2 >= protected_cov) | (nn2 >= min_nonnull_import_tariff)))].index
    df = df[keep_cols].ffill().bfill()

    y_ml = df["CPI"].copy()
    X_ml = df.drop(columns=["CPI"])

    # Compact exog for SARIMAX
    sarimax_exog = None
    candidates = [c for c in df.columns if c.startswith("Freight__") and (c.endswith("_mom") or c.endswith("_roll3") or c.endswith("_roll6"))]
    if candidates: sarimax_exog = df[[candidates[0]]].copy()

    prov = {
        "exog_used_count": X_ml.shape[1],
        "exog_used_sample": X_ml.columns[:20].tolist(),
        "exog_import_cols_count": int(sum(X_ml.columns.to_series().str.startswith("Import__"))),
        "exog_tariff_cols_count": int(sum(X_ml.columns.to_series().str.contains("tariff_", case=False))),
        "exog_tariffwgt_cols_count": int(sum(X_ml.columns.to_series().str.startswith("TariffWgt__")))
    }
    return y_ml, X_ml, sarimax_exog, prov


def train_test_split_time(X: pd.DataFrame, y: pd.Series, test_months: int = 24):
    if len(y) < test_months + 36:
        test_months = max(12, min(18, len(y)//4))
    split_point = y.index[-test_months]
    X_tr, X_te = X.loc[:split_point - pd.offsets.MonthBegin(0)], X.loc[split_point:]
    y_tr, y_te = y.loc[:split_point - pd.offsets.MonthBegin(0)], y.loc[split_point:]
    return X_tr, X_te, y_tr, y_te


# ---------------- Modeling ----------------
def fit_sarima(y_train: pd.Series, auto: bool = False):
    if not auto:
        order, sorder = (0, 1, 1), (0, 1, 1, 12)
        res = SARIMAX(y_train, order=order, seasonal_order=sorder,
                      enforce_stationarity=False, enforce_invertibility=False).fit(disp=False)
        return (order, sorder), res

    best_aic = np.inf; best = None; best_res = None
    for p in [0, 1, 2]:
        for q in [0, 1, 2]:
            for P in [0, 1]:
                for Q in [0, 1]:
                    try:
                        order, sorder = (p, 1, q), (P, 1, Q, 12)
                        res = SARIMAX(y_train, order=order, seasonal_order=sorder,
                                      enforce_stationarity=False, enforce_invertibility=False).fit(disp=False)
                        if res.aic < best_aic:
                            best_aic, best, best_res = res.aic, (order, sorder), res
                    except Exception: pass
    if best_res is None: return fit_sarima(y_train, auto=False)
    return best, best_res

def train_model_zoo():
    models = {
        "XGB_or_GBR": XGBRegressor(
            n_estimators=600 if XGB_PRESENT else 400,
            max_depth=4,
            learning_rate=0.05,
            subsample=0.9,
            colsample_bytree=0.9 if XGB_PRESENT else None,
            random_state=42,
            tree_method="hist" if XGB_PRESENT else None,
            enable_categorical=False
        ),
        "RandomForest": RandomForestRegressor(n_estimators=500, max_depth=8, random_state=42, n_jobs=-1),
        "KNN": make_pipeline(StandardScaler(with_mean=False), KNeighborsRegressor(n_neighbors=5, weights="distance"))
    }
    return models

def _sanitize_xy(X: pd.DataFrame, y: pd.Series) -> Tuple[pd.DataFrame, pd.Series]:
    Xc = X.replace([np.inf, -np.inf], np.nan)
    yc = y.replace([np.inf, -np.inf], np.nan)
    Xc = Xc.fillna(Xc.mean()); yc = yc.fillna(yc.mean())
    return Xc.astype(np.float64), yc.astype(np.float64)

def fit_predict_zoo(models: Dict[str, object], X_train, y_train, X_test) -> Dict[str, pd.Series]:
    out = {}
    X_tr, y_tr = _sanitize_xy(X_train, y_train)
    X_te = X_test.replace([np.inf, -np.inf], np.nan).fillna(X_tr.mean()).astype(np.float64)
    for name, m in models.items():
        m.fit(X_tr, y_tr)
        out[name] = pd.Series(m.predict(X_te), index=X_test.index, name=f"{name}_pred")
    return out

def refine_sarimax(y_train, y_test, sarimax_exog):
    if sarimax_exog is None: return None, None
    exog_all = sarimax_exog.reindex(pd.concat([y_train, y_test]).index).replace([np.inf, -np.inf], np.nan).ffill()
    exog_tr = exog_all.loc[y_train.index]; exog_te = exog_all.loc[y_test.index]
    try:
        res = SARIMAX(y_train, order=(1,1,1), seasonal_order=(0,1,1,12),
                      exog=exog_tr, enforce_stationarity=False, enforce_invertibility=False).fit(disp=False)
        fc_te = res.get_forecast(steps=len(y_test), exog=exog_te).predicted_mean
        # In-sample predicted mean aligned with y_train
        fc_tr = res.get_prediction(start=y_train.index[0], end=y_train.index[-1]).predicted_mean
        return pd.Series(fc_tr, index=y_train.index, name="SARIMAX_exog_in"), pd.Series(fc_te, index=y_test.index, name="SARIMAX_exog_pred")
    except Exception:
        return None, None


# -------------- Calibration & CIs --------------
def robust_sigma(resid: pd.Series) -> float:
    resid = resid.dropna()
    if len(resid) == 0: return 0.0
    mad = np.median(np.abs(resid - np.median(resid)))
    return float(max(resid.std(ddof=1), 1.4826 * mad))  # conservative

def calibrate_bias(y_train: pd.Series,
                   sarima_train_pred: pd.Series,
                   sarimax_train_pred: Optional[pd.Series],
                   zoo_train_mean: Optional[pd.Series],
                   weights: dict,
                   window: int = 12):
    w_sarima  = float(weights.get("SARIMA", 0.0))
    w_sarimax = float(weights.get("SARIMAX_exog", 0.0))
    w_zoo     = float(sum(v for k, v in weights.items() if k not in ["SARIMA", "SARIMAX_exog"]))

    # Compose train-side ensemble using the *same* weights
    grand_train = (w_sarima * sarima_train_pred.reindex(y_train.index)).fillna(0.0)
    if sarimax_train_pred is not None:
        grand_train = grand_train.add(w_sarimax * sarimax_train_pred.reindex(y_train.index), fill_value=0.0)
    if zoo_train_mean is not None and not zoo_train_mean.empty:
        grand_train = grand_train.add(w_zoo * zoo_train_mean.reindex(y_train.index), fill_value=0.0)

    idx = y_train.dropna().index.intersection(grand_train.dropna().index)
    if len(idx) < max(6, window): window = max(6, len(idx)//2 or 6)
    tail = idx[-window:]

    # OLS slope + intercept
    X = grand_train.loc[tail].values.reshape(-1, 1)
    y = y_train.loc[tail].values
    lr = LinearRegression()
    lr.fit(X, y)
    a_raw = float(lr.intercept_); b_raw = float(lr.coef_[0])

    # Guardrails: clamp slope or fallback to intercept-only if unstable
    if not np.isfinite(b_raw) or b_raw < 0.7 or b_raw > 1.3:
        b = 1.0
        a = float(np.median(y - X.ravel()))  # robust mean shift
    else:
        a, b = a_raw, b_raw

    grand_adj = a + b * grand_train
    sigma = robust_sigma(y_train.loc[tail] - grand_adj.loc[tail])
    return a, b, grand_adj, sigma


# -------------- Evaluation & plotting --------------
def compute_metrics(y_test: pd.Series, pred_map: Dict[str, Optional[pd.Series]]):
    # Level
    rows = []
    for name, p in pred_map.items():
        if p is None: continue
        idx = y_test.dropna().index.intersection(p.dropna().index)
        if len(idx) < 1: continue
        a, b = y_test.loc[idx], p.loc[idx]
        rows.append({"Model": name, "MAE_level": mean_absolute_error(a, b), "RMSE_level": rmse(a, b)})
    level = pd.DataFrame(rows).set_index("Model").sort_values("RMSE_level") if rows else pd.DataFrame(columns=["MAE_level","RMSE_level"])

    # MoM
    actual_mom = y_test.pct_change()
    rows = []
    for name, p in pred_map.items():
        if p is None: continue
        m = p.pct_change()
        idx = actual_mom.dropna().index.intersection(m.dropna().index)
        if len(idx) < 1: continue
        a, b = actual_mom.loc[idx], m.loc[idx]
        rows.append({"Model": name, "MAE_MoM": mean_absolute_error(a, b), "RMSE_MoM": rmse(a, b)})
    mom = pd.DataFrame(rows).set_index("Model").sort_values("RMSE_MoM") if rows else pd.DataFrame(columns=["MAE_MoM","RMSE_MoM"])

    # YoY
    actual_yoy = y_test.pct_change(12)
    rows = []
    for name, p in pred_map.items():
        if p is None: continue
        m = p.pct_change(12)
        idx = actual_yoy.dropna().index.intersection(m.dropna().index)
        if len(idx) < 1: continue
        a, b = actual_yoy.loc[idx], m.loc[idx]
        rows.append({"Model": name, "MAE_YoY": mean_absolute_error(a, b), "RMSE_YoY": rmse(a, b)})
    yoy = pd.DataFrame(rows).set_index("Model").sort_values("RMSE_YoY") if rows else pd.DataFrame(columns=["MAE_YoY","RMSE_YoY"])

    # QoQ SAAR
    actual_qoq = compute_qoq_saar_series(y_test)
    rows = []
    for name, p in pred_map.items():
        if p is None: continue
        q = compute_qoq_saar_series(p)
        idx = actual_qoq.dropna().index.intersection(q.dropna().index)
        if len(idx) < 1: continue
        a, b = actual_qoq.loc[idx], q.loc[idx]
        rows.append({"Model": name, "MAE_QoQ_SAAR": mean_absolute_error(a, b), "RMSE_QoQ_SAAR": rmse(a, b)})
    qoq = pd.DataFrame(rows).set_index("Model").sort_values("RMSE_QoQ_SAAR") if rows else pd.DataFrame(columns=["MAE_QoQ_SAAR","RMSE_QoQ_SAAR"])

    return level, mom, yoy, qoq

def save_top_features_bar(series: pd.Series, path: str, title: str, topn: int = 20):
    vals = series.sort_values(ascending=False).head(topn)
    plt.figure(figsize=(10, 6))
    vals[::-1].plot(kind="barh")
    plt.title(title); plt.tight_layout(); plt.savefig(path, dpi=150); plt.close()

def rolling_rmse(a: pd.Series, b: pd.Series, window=6) -> pd.Series:
    err = (a - b)**2
    r = err.rolling(window).mean()**0.5
    return r

def pretty_plot(dates, actual, pred, title, ylabel, path, ci95=None, ci75=None):
    plt.figure(figsize=(10,5))
    plt.plot(dates, actual, label="Actual")
    plt.plot(dates, pred, label="Predicted", linestyle="--")
    if ci95 is not None:
        plt.fill_between(dates, ci95[0], ci95[1], alpha=0.18, label="95% CI")
    if ci75 is not None:
        plt.fill_between(dates, ci75[0], ci75[1], alpha=0.18, label="75% CI")
    plt.title(title); plt.ylabel(ylabel); plt.xlabel("")
    plt.legend(); plt.tight_layout(); plt.savefig(path, dpi=150); plt.close()


# ---------------- Main ----------------
def main(args=None):
    if args is None:
        args = parse_args()

    base_dir = args.base_dir
    out_dir = args.out_dir or os.path.join(base_dir, "outputs")
    os.makedirs(out_dir, exist_ok=True)

    # 1) Data
    y, X_raw, audit = build_feature_table(base_dir, args.model_start)
    X = X_raw.select_dtypes(include=[np.number]).copy()
    audit["exog_columns_numeric_count"] = X.shape[1]
    audit["exog_columns_non_numeric_dropped"] = sorted(list(set(X_raw.columns) - set(X.columns)))

    # 2) Features
    y_ml, X_ml, sarimax_exog, feat_prov = engineer_features(
        y, X, coverage=args.coverage, protected_cov=args.protected_cov,
        min_nonnull_import_tariff=args.min_nonnull_import_tariff
    )
    audit.update(feat_prov)

    # Optional: ML on MoM
    use_mom = bool(args.target_mom)
    if use_mom:
        y_mom = y.pct_change().replace([np.inf, -np.inf], np.nan)
        X_lag1 = X.shift(1).replace([np.inf, -np.inf], np.nan)
        data = pd.concat([y_mom, X_lag1], axis=1).dropna()
        y_ml = data.iloc[:, 0].rename("CPI_mom")
        X_ml = data.iloc[:, 1:]

    # Save engineered tables
    try:
        y_ml.to_csv(os.path.join(out_dir, "y_ml.csv"))
        X_ml.to_parquet(os.path.join(out_dir, "X_ml.parquet"))
    except Exception:
        X_ml.to_csv(os.path.join(out_dir, "X_ml.csv"))

    # 3) Split
    X_train, X_test, y_train, y_test = train_test_split_time(X_ml, y_ml, test_months=args.test_months)

    # Audit: presence of special features
    audit["import_cols_in_train"]    = int(sum(X_train.columns.to_series().str.startswith("Import__")))
    audit["tariff_cols_in_train"]    = int(sum(X_train.columns.to_series().str.contains("tariff_", case=False)))
    audit["tariffwgt_cols_in_train"] = int(sum(X_train.columns.to_series().str.startswith("TariffWgt__")))
    audit["import_cols_in_test"]     = int(sum(X_test.columns.to_series().str.startswith("Import__")))
    audit["tariff_cols_in_test"]     = int(sum(X_test.columns.to_series().str.contains("tariff_", case=False)))
    audit["tariffwgt_cols_in_test"]  = int(sum(X_test.columns.to_series().str.startswith("TariffWgt__")))

    # 4) SARIMA on level y (train window)
    (order, sorder), sarima_res = fit_sarima(
        y.loc[y.index <= y_test.index[-1] - pd.offsets.MonthBegin(0)],
        auto=args.auto_sarima
    )
    sarima_fc = sarima_res.get_forecast(steps=len(y_test))
    sarima_pred_level = pd.Series(sarima_fc.predicted_mean, index=y_test.index, name="SARIMA_pred")
    # In-sample predicted mean (not fittedvalues to avoid differencing quirks)
    sarima_train_pred = sarima_res.get_prediction(start=y_train.index[0], end=y_train.index[-1]).predicted_mean
    sarima_train_pred = pd.Series(sarima_train_pred, index=y_train.index, name="SARIMA_in")

    # 5) Model zoo
    models = train_model_zoo()
    zoo_preds = fit_predict_zoo(models, X_train, y_train, X_test)

    # If ML target is MoM, reconstruct to LEVEL
    if use_mom:
        start_level = y.loc[y.index < y_test.index[0]].iloc[-1]
        def mom_to_level(mom_series: pd.Series, start_val: float) -> pd.Series:
            out, cur = [], float(start_val)
            for g in mom_series.fillna(0).values:
                cur *= (1 + float(g)); out.append(cur)
            return pd.Series(out, index=mom_series.index)
        for k in list(zoo_preds.keys()):
            zoo_preds[k] = mom_to_level(zoo_preds[k], start_level).rename(zoo_preds[k].name.replace("_pred", "_level_pred"))
        y_test_level = y.loc[y_test.index]
    else:
        y_test_level = y_test

    # 6) SARIMAX refinement (train+test)
    sarimax_train_in, sarimax_pred = refine_sarimax(y.loc[y.index < y_test.index[0]], y.loc[y_test.index], sarimax_exog)

    # 7) Weighted ensemble
    preds_dict = {"SARIMA": sarima_pred_level, "SARIMAX_exog": sarimax_pred}
    preds_dict.update(zoo_preds)

    perf = {}
    for name, p in preds_dict.items():
        if p is None: continue
        idx = y_test_level.index.intersection(p.dropna().index)
        if len(idx) < 1: continue
        perf[name] = rmse(y_test_level.loc[idx], p.loc[idx])

    weights = {k: 1.0 / (v**2 + 1e-9) for k, v in perf.items()}
    s = sum(weights.values())
    if s == 0 or not weights: weights = {"SARIMA": 1.0}
    else: weights = {k: v/s for k, v in weights.items()}

    aligned = []
    for name, p in preds_dict.items():
        if p is None or name not in weights: continue
        aligned.append(weights[name] * p.reindex(y_test_level.index))
    grand_ensemble = (pd.concat(aligned, axis=1).sum(axis=1) if aligned
                      else sarima_pred_level.copy()).rename("GrandEnsemble_pred")

    # 8) Metrics (Level/MoM/YoY/QoQ SAAR)
    metrics_inputs = {"SARIMA": sarima_pred_level, "SARIMAX_exog": sarimax_pred, **zoo_preds, "GrandEnsemble": grand_ensemble}
    level_metrics, mom_metrics, yoy_metrics, qoq_metrics = compute_metrics(y_test_level, metrics_inputs)

    # 9) Calibration & CIs (train-tail)
    # Train-side zoo preds
    zoo_train_preds = {}
    X_tr_san, y_tr_san = _sanitize_xy(X_train, y_train)
    for name, m in train_model_zoo().items():
        m.fit(X_tr_san, y_tr_san)
        zoo_train_preds[name] = pd.Series(m.predict(X_tr_san), index=X_train.index)
    if use_mom:
        seed = y.loc[y.index < X_train.index[0]].iloc[-1]
        for k, srs in zoo_train_preds.items():
            cur, vals = float(seed), []
            for g in srs.fillna(0).values:
                cur *= (1 + float(g)); vals.append(cur)
            zoo_train_preds[k] = pd.Series(vals, index=X_train.index)
    avg_zoo_train = (pd.concat(zoo_train_preds.values(), axis=1).mean(axis=1)
                     if zoo_train_preds else pd.Series(index=y.index, dtype=float))

    a_cal, b_cal, grand_train_adj, sigma_train = calibrate_bias(
        y.loc[y_train.index],
        sarima_train_pred,
        sarimax_train_in,
        avg_zoo_train,
        weights,
        window=12
    )
    pred_raw = grand_ensemble
    pred_adj = a_cal + b_cal * pred_raw
    z95, z75 = 1.96, 1.150349
    ci95_lo = pred_adj - z95 * sigma_train; ci95_hi = pred_adj + z95 * sigma_train
    ci75_lo = pred_adj - z75 * sigma_train; ci75_hi = pred_adj + z75 * sigma_train

    print(f"Calibration: a={a_cal:.3f}, b={b_cal:.4f}, sigma(train)={sigma_train:.3f}")

    # 10) Main charts — Level + YoY%
    lvl_path = os.path.join(out_dir, "cpi_level_pred_vs_actual.png")
    pretty_plot(y_test_level.index, y_test_level, pred_adj,
                "CPI Level: Actual vs Predicted (Grand Ensemble, calibrated)",
                "CPI Index (1982-84=100)", lvl_path, ci95=(ci95_lo, ci95_hi), ci75=(ci75_lo, ci75_hi))

    denom12 = y.shift(12).reindex(pred_adj.index)
    yoy_actual = (y.reindex(pred_adj.index) / denom12 - 1.0) * 100.0
    yoy_pred   = (pred_adj / denom12 - 1.0) * 100.0
    yoy_ci95_lo = (ci95_lo / denom12 - 1.0) * 100.0
    yoy_ci95_hi = (ci95_hi / denom12 - 1.0) * 100.0
    yoy_ci75_lo = (ci75_lo / denom12 - 1.0) * 100.0
    yoy_ci75_hi = (ci75_hi / denom12 - 1.0) * 100.0
    mask = denom12.notna() & y_test_level.reindex(pred_adj.index).notna()
    dates = pred_adj.index[mask]
    yoy_path = os.path.join(out_dir, "cpi_yoy_pred_vs_actual.png")
    pretty_plot(dates, yoy_actual.loc[dates], yoy_pred.loc[dates],
                "CPI YoY%: Actual vs Predicted (Grand Ensemble, calibrated)",
                "YoY (%)", yoy_path,
                ci95=(yoy_ci95_lo.loc[dates], yoy_ci95_hi.loc[dates]),
                ci75=(yoy_ci75_lo.loc[dates], yoy_ci75_hi.loc[dates]))

    # 11) Additional presentation charts
    # MoM%
    denom1 = y.shift(1).reindex(pred_adj.index)
    mom_actual = (y.reindex(pred_adj.index) / denom1 - 1.0) * 100.0
    mom_pred   = (pred_adj / denom1 - 1.0) * 100.0
    mom_ci95_lo = (ci95_lo / denom1 - 1.0) * 100.0
    mom_ci95_hi = (ci95_hi / denom1 - 1.0) * 100.0
    mom_ci75_lo = (ci75_lo / denom1 - 1.0) * 100.0
    mom_ci75_hi = (ci75_hi / denom1 - 1.0) * 100.0
    mask_mom = denom1.notna() & y_test_level.reindex(pred_adj.index).notna()
    dates_mom = pred_adj.index[mask_mom]
    mom_path = os.path.join(out_dir, "cpi_mom_pred_vs_actual.png")
    pretty_plot(dates_mom, mom_actual.loc[dates_mom], mom_pred.loc[dates_mom],
                "CPI MoM%: Actual vs Predicted (calibrated)", "MoM (%)", mom_path,
                ci95=(mom_ci95_lo.loc[dates_mom], mom_ci95_hi.loc[dates_mom]),
                ci75=(mom_ci75_lo.loc[dates_mom], mom_ci75_hi.loc[dates_mom]))

    # QoQ SAAR
    qoq_actual = compute_qoq_saar_series(y.reindex(pred_adj.index))
    qoq_pred   = compute_qoq_saar_series(pred_adj)
    qoq_ci95_lo = compute_qoq_saar_series(ci95_lo)
    qoq_ci95_hi = compute_qoq_saar_series(ci95_hi)
    qoq_ci75_lo = compute_qoq_saar_series(ci75_lo)
    qoq_ci75_hi = compute_qoq_saar_series(ci75_hi)
    mask_qoq = qoq_actual.notna() & qoq_pred.notna()
    dates_qoq = qoq_pred.index[mask_qoq]
    qoq_path = os.path.join(out_dir, "cpi_qoq_saar_pred_vs_actual.png")
    pretty_plot(dates_qoq, (qoq_actual*100).loc[dates_qoq], (qoq_pred*100).loc[dates_qoq],
                "CPI QoQ SAAR: Actual vs Predicted (calibrated)", "QoQ SAAR (%)", qoq_path,
                ci95=((qoq_ci95_lo*100).loc[dates_qoq], (qoq_ci95_hi*100).loc[dates_qoq]),
                ci75=((qoq_ci75_lo*100).loc[dates_qoq], (qoq_ci75_hi*100).loc[dates_qoq]))

    # Residuals & scatter (level)
    resid = (y_test_level - pred_adj).dropna()
    plt.figure(figsize=(10,4)); plt.plot(resid.index, resid.values); plt.title("Residuals (Actual - Pred, Level)"); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "residuals_level.png"), dpi=150); plt.close()
    plt.figure(figsize=(5,5)); plt.scatter(y_test_level, pred_adj); 
    mn, mx = float(min(y_test_level.min(), pred_adj.min())), float(max(y_test_level.max(), pred_adj.max()))
    plt.plot([mn, mx], [mn, mx], linestyle="--"); plt.title("Actual vs Predicted (Level)"); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "scatter_level.png"), dpi=150); plt.close()

    # Rolling RMSE (6m)
    rr = rolling_rmse(y_test_level, pred_adj, window=6)
    plt.figure(figsize=(10,4)); plt.plot(rr.index, rr.values); plt.title("Rolling RMSE (6m, Level)"); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "rolling_rmse_6m.png"), dpi=150); plt.close()

    # 12) Train metrics
    train_rows = []
    for name, m in train_model_zoo().items():
        m.fit(X_tr_san, y_tr_san)
        p = pd.Series(m.predict(X_tr_san), index=X_train.index)
        train_rows.append({"Model": name, "Train_MAE": mean_absolute_error(y.loc[X_train.index], p),
                           "Train_RMSE": rmse(y.loc[X_train.index], p)})
    sarima_in = sarima_train_pred  # already aligned
    train_rows.append({"Model": "SARIMA", "Train_MAE": mean_absolute_error(y.loc[sarima_in.index], sarima_in),
                       "Train_RMSE": rmse(y.loc[sarima_in.index], sarima_in)})
    train_metrics = pd.DataFrame(train_rows).set_index("Model").sort_values("Train_RMSE")

    # 13) Feature importance
    best_name = level_metrics.index[0] if not level_metrics.empty else "RandomForest"
    model_for_imp = train_model_zoo().get(best_name, train_model_zoo().get("RandomForest"))
    corr_abs = X_tr_san.corrwith(y_tr_san).abs().sort_values(ascending=False)
    top_cols = corr_abs.index[:max(5, args.pi_max_features)]
    try:
        model_for_imp.fit(X_tr_san[top_cols], y_tr_san)
        X_te_san = X_test.replace([np.inf, -np.inf], np.nan).fillna(X_tr_san[top_cols].mean()).astype(np.float64)
        imp = permutation_importance(model_for_imp, X_te_san[top_cols], y_test, n_repeats=8, random_state=42, n_jobs=-1)
        imp_vals = pd.Series(imp.importances_mean, index=top_cols).sort_values(ascending=False)
        top_imp = imp_vals.head(80).to_frame("permutation_importance")
    except Exception:
        top_imp = corr_abs.head(args.pi_max_features).to_frame("proxy_importance")

    # Feature bars
    save_top_features_bar(top_imp.iloc[:,0], os.path.join(out_dir, "top_features_bar.png"),
                          "Top Feature Importances (test-window)")

    # Import/tariff attribution
    imp_imports = top_imp[top_imp.index.str.contains("Import__|tariff_|TariffWgt__", case=False, regex=True)]
    if imp_imports.empty:
        it_cols = [c for c in X_test.columns if ("Import__" in c) or ("tariff_" in c.lower()) or (c.startswith("TariffWgt__"))]
        if it_cols:
            sarima_resid_test = (y_test_level - sarima_pred_level).dropna()
            proxy = pd.Series({c: abs(sarima_resid_test.corr(X_test[c].reindex(sarima_resid_test.index))) for c in it_cols})
            imp_imports = proxy.sort_values(ascending=False).to_frame("proxy_importance")

    # 14) Save artifacts
    level_metrics.to_csv(os.path.join(out_dir, "level_metrics.csv"))
    mom_metrics.to_csv(os.path.join(out_dir, "mom_metrics.csv"))
    yoy_metrics.to_csv(os.path.join(out_dir, "yoy_metrics.csv"))
    qoq_metrics.to_csv(os.path.join(out_dir, "qoq_saar_metrics.csv"))
    train_metrics.to_csv(os.path.join(out_dir, "train_metrics.csv"))
    top_imp.to_csv(os.path.join(out_dir, "top_feature_importances.csv"))
    imp_imports.head(30).to_csv(os.path.join(out_dir, "top_import_tariff_categories.csv"))

    # Ensemble weights bar
    w_series = pd.Series(weights).sort_values(ascending=True)
    plt.figure(figsize=(8,4)); w_series.plot(kind="barh"); plt.title("Ensemble Weights"); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "ensemble_weights.png"), dpi=150); plt.close()

    with open(os.path.join(out_dir, "summary.json"), "w") as f:
        json.dump({
            "xgb_present": XGB_PRESENT,
            "best_sarima_order": str((order, sorder)),
            "charts": {"level": lvl_path, "yoy": yoy_path, "mom": mom_path, "qoq": qoq_path},
            "weights": {k: float(v) for k,v in weights.items()},
            "calibration": {"a": a_cal, "b": b_cal, "sigma_train": sigma_train}
        }, f, indent=2)

    # 15) Data audit
    audit.update({
        "final_train_span": [str(y_train.index.min().date()), str(y_train.index.max().date())],
        "final_test_span":  [str(y_test.index.min().date()),  str(y_test.index.max().date())],
        "final_exog_numeric_count": X.shape[1],
        "features_after_engineering_count": X_ml.shape[1],
        "target_mode_for_ml": "MoM (reconstructed→Level)" if use_mom else "Level"
    })
    with open(os.path.join(out_dir, "data_used_summary.json"), "w") as f:
        json.dump(audit, f, indent=2)

    # 16) Console summary
    print(f"\n=== OUTPUTS written to {out_dir} ===\n")
    print(level_metrics if not level_metrics.empty else "No level metrics computed.")
    print("\nMoM metrics:\n", mom_metrics if not mom_metrics.empty else "N/A")
    print("\nYoY metrics:\n", yoy_metrics if not yoy_metrics.empty else "N/A")
    print("\nQoQ SAAR metrics:\n", qoq_metrics if not qoq_metrics.empty else "N/A")
    print(f"\nSpecial features in TRAIN: Import={audit['import_cols_in_train']}, Tariff={audit['tariff_cols_in_train']}, TariffWgt={audit['tariffwgt_cols_in_train']}")
    print(f"Special features in TEST:  Import={audit['import_cols_in_test']},  Tariff={audit['tariff_cols_in_test']},  TariffWgt={audit['tariffwgt_cols_in_test']}")
    print("\nTop import/tariff categories (by permutation or proxy importance):")
    print(imp_imports.head(10) if not imp_imports.empty else "No import/tariff features available in this run.")


if __name__ == "__main__":
    import sys
    if "-f" in sys.argv:  # Jupyter kernel arg
        sys.argv = [sys.argv[0]]
    main()


Calibration: a=0.103, b=1.0000, sigma(train)=86.454

=== OUTPUTS written to ./outputs ===

               MAE_level  RMSE_level
Model                               
GrandEnsemble   0.747666    0.856851
SARIMAX_exog    0.963528    1.115070
KNN             9.853637   11.006709
XGB_or_GBR      9.720245   11.063773
RandomForest   13.400760   14.916692

MoM metrics:
                 MAE_MoM  RMSE_MoM
Model                            
GrandEnsemble  0.001195  0.001509
SARIMAX_exog   0.001216  0.001530
KNN            0.002408  0.002980
RandomForest   0.003047  0.003480
XGB_or_GBR     0.003113  0.004259

YoY metrics:
                 MAE_YoY  RMSE_YoY
Model                            
GrandEnsemble  0.001857  0.002227
SARIMAX_exog   0.002119  0.002608
KNN            0.026840  0.026952
XGB_or_GBR     0.028820  0.029124
RandomForest   0.036930  0.036982

QoQ SAAR metrics:
                MAE_QoQ_SAAR  RMSE_QoQ_SAAR
Model                                     
GrandEnsemble      0.010161       0.01