In [1]:
"""
Single-file optimizer with:
  • NGBoost surrogate (overflow-safe, quiet)
  • Instability-aware proposer (filters candidates with a stability classifier)
  • Integrated stability gate (short probe run + param backoff) before full sim
  • Log-based instability detection
  • Checkpoint: optimization_ngboost.pkl

Usage: python optimize_ngboost_stable.py
Requires: ngboost, scikit-learn, scipy, numpy, matplotlib

Tip: If your model already has CFL control, leave it; the stability gate
     still catches explosive parameter sets cheaply before long runs.
"""

import os
import re
import io
import pickle
import contextlib
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
from scipy.stats import qmc
from scipy.ndimage import uniform_filter
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, ConstantKernel
from sklearn.linear_model import LogisticRegression

# NGBoost surrogate
try:
    from ngboost import NGBRegressor
    from ngboost.distns import Normal
    from ngboost.scores import CRPS
    _HAS_NGBOOST = True
except Exception:
    _HAS_NGBOOST = False

# Keep import for compatibility even if unused directly here
from qg_model import QGTwoLayerModel  # noqa: F401

# =============================== Bounds =============================== #
PARAM_BOUNDS = {
    'viscosity_scale':      {'bounds': (0.5, 5.0),   'type': 'linear'},
    'drag_scale':           {'bounds': (0.5, 3.0),   'type': 'linear'},
    'eddy_diffusivity':     {'bounds': (1e3, 1e5),   'type': 'log'},
    'smagorinsky_coeff':    {'bounds': (0.0, 0.3),   'type': 'linear'},
    'energy_correction':    {'bounds': (-0.01, 0.01),'type': 'linear'},
    'enstrophy_correction': {'bounds': (0.0, 1e-6),  'type': 'log'},
}
PARAM_NAMES = list(PARAM_BOUNDS.keys())
N_PARAMS = len(PARAM_NAMES)
CKPT_FILE = 'optimization_ngboost.pkl'

# ========================= Stability Config ========================== #
@dataclass
class STAB_CFG:
    probe_days: float = 2.0           # short cheap run
    probe_damp_nu4: float = 2.0       # multiply ν4 during probe
    probe_damp_drag: float = 1.5      # multiply drag during probe
    max_fail_ratio: float = 0.80      # if failures exceed, tighten proposer filter

STAB = STAB_CFG()

# =============================== LHS ================================= #

def generate_latin_hypercube_samples(n_samples, bounds_dict):
    n_params = len(bounds_dict)
    sampler = qmc.LatinHypercube(d=n_params, seed=32)
    unit = sampler.random(n=n_samples)
    samples = np.zeros_like(unit)
    for i, (_, info) in enumerate(bounds_dict.items()):
        lo, hi = info['bounds']
        if info['type'] == 'log':
            llo = np.log10(lo) if lo > 0 else -10
            lhi = np.log10(hi)
            samples[:, i] = 10 ** (unit[:, i] * (lhi - llo) + llo)
        else:
            samples[:, i] = unit[:, i] * (hi - lo) + lo
    return samples

# ====================== Simulation + Loss & Logs ===================== #

def _make_config_with_params(config_base, params_array):
    config = config_base.copy()
    sub = {PARAM_NAMES[i]: float(params_array[i]) for i in range(len(PARAM_NAMES))}
    # Allow overrides from caller
    for k in ('nu4_scale','drag_scale'):
        if 'subgrid_params_override_multipliers' in config and k in config['subgrid_params_override_multipliers']:
            if k in sub:
                sub[k] *= config['subgrid_params_override_multipliers'][k]
            else:
                sub[k] = config['subgrid_params_override_multipliers'][k]
    config['subgrid_params'] = sub
    return config

FAIL_REGEX = r"unstable|\*\*\*\s*unstable|unstable at step|nan encountered|overflow"


def _run_simulation_capture(config, sim_days):
    from main_comparison import run_simulation  # late import
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
        results = run_simulation(config, sim_days=sim_days, save_interval_hours=12)
    logs = buf.getvalue()
    return results, logs


def run_lowres_with_params(params_array, config_base, highres_results, sim_days=180, save_outputs=True,
                           fail_regex: str = FAIL_REGEX, echo_logs: bool = False):
    print(f"\n{'='*70}\nTesting parameters:")
    for i, n in enumerate(PARAM_NAMES):
        print(f"  {n}: {float(params_array[i]):.6e}")

    try:
        results, logs = _run_simulation_capture(_make_config_with_params(config_base, params_array), sim_days)
    except Exception as e:
        print(f"  ✗ Simulation raised: {e}")
        import traceback; traceback.print_exc()
        return np.nan, None, None

    if echo_logs and logs:
        print(logs)
    if re.search(fail_regex, logs, flags=re.IGNORECASE):
        print("  ✗ Detected instability in logs → fail")
        return np.nan, None, None

    try:
        loss, detailed = compute_loss(results, highres_results, return_fields=True)
        if not np.isfinite(loss):
            print(f"  ✗ Non-finite loss: {loss}")
            return np.nan, None, None
        print(f"  ✓ Loss: {loss:.6f}")
        return loss, results, detailed
    except Exception as e:
        print(f"  ✗ Loss computation failed: {e}")
        import traceback; traceback.print_exc()
        return np.nan, None, None


# -------- Stability Gate: cheap probe then full eval, with backoff ----- #

def evaluate_with_stability_gate(params_array, config_base, highres_results, full_days=180):
    # 1) probe with stronger damping
    probe_cfg = config_base.copy()
    probe_cfg['subgrid_params_override_multipliers'] = {
        'nu4_scale': STAB.probe_damp_nu4,
        'drag_scale': STAB.probe_damp_drag,
    }
    try:
        probe_res, probe_logs = _run_simulation_capture(_make_config_with_params(probe_cfg, params_array), STAB.probe_days)
        if re.search(FAIL_REGEX, probe_logs, flags=re.IGNORECASE):
            raise RuntimeError('probe unstable')
    except Exception:
        # backoff toward mid-box
        safe = backoff_params(params_array, frac=0.5)
        try:
            probe_res, probe_logs = _run_simulation_capture(_make_config_with_params(probe_cfg, safe), STAB.probe_days)
            if re.search(FAIL_REGEX, probe_logs, flags=re.IGNORECASE):
                return np.nan, None, None
        except Exception:
            return np.nan, None, None

    # 2) full eval
    return run_lowres_with_params(params_array, config_base, highres_results, sim_days=full_days, echo_logs=False)


def backoff_params(p, frac=0.5):
    p = np.asarray(p, float).copy()
    lbs = np.array([v['bounds'][0] for v in PARAM_BOUNDS.values()])
    ubs = np.array([v['bounds'][1] for v in PARAM_BOUNDS.values()])
    mids = 0.5*(lbs+ubs)
    return mids + frac*(p - mids)

# ================================ Loss ================================= #

def compute_loss(lowres_results, highres_results, n_days_avg=30, return_fields=False):
    nx_hr = highres_results['config']['nx']; ny_hr = highres_results['config']['ny']
    nx_lr = lowres_results['config']['nx']; ny_lr = lowres_results['config']['ny']
    print(f"  High-res grid: {nx_hr}x{ny_hr}\n  Low-res grid:  {nx_lr}x{ny_lr}")
    if nx_hr % nx_lr != 0 or ny_hr % ny_lr != 0:
        raise ValueError("Grid not evenly divisible")
    fx = nx_hr // nx_lr; fy = ny_hr // ny_lr
    times_hr = highres_results['times']; times_lr = lowres_results['times']
    th_hr = times_hr[-1] - n_days_avg; th_lr = times_lr[-1] - n_days_avg
    ih = np.where(times_hr >= th_hr)[0]; il = np.where(times_lr >= th_lr)[0]
    q1h = np.mean([highres_results['q1_history'][i] for i in ih], axis=0)
    q2h = np.mean([highres_results['q2_history'][i] for i in ih], axis=0)
    q1l = np.mean([lowres_results['q1_history'][i] for i in il], axis=0)
    q2l = np.mean([lowres_results['q2_history'][i] for i in il], axis=0)

    mh = highres_results['model']; ml = lowres_results['model']
    psi1h, psi2h = mh.q_to_psi(q1h, q2h); psi1l, psi2l = ml.q_to_psi(q1l, q2l)
    H1, H2 = mh.H1, mh.H2; Ht = H1+H2
    q_bth = (H1*q1h + H2*q2h)/Ht; psi_bth = (H1*psi1h + H2*psi2h)/Ht
    q_btl = (H1*q1l + H2*q2l)/Ht; psi_btl = (H1*psi1l + H2*psi2l)/Ht

    def coarsen(a, fx, fy):
        return uniform_filter(a, size=(fy, fx), mode='wrap')[::fy, ::fx]

    qh_c = coarsen(q_bth, fx, fy); psih_c = coarsen(psi_bth, fx, fy)
    if qh_c.shape != q_btl.shape:
        raise ValueError("Shape mismatch after coarsening")

    def nrmse(pred, tgt):
        mse = np.mean((pred-tgt)**2); std = np.std(tgt)
        return np.sqrt(mse)/(std+1e-20)

    lq = nrmse(q_btl, qh_c); lp = nrmse(psi_btl, psih_c)
    total = 0.6*lq + 0.4*lp
    if return_fields:
        return total, {
            'q_bt_hr_coarse': qh_c, 'psi_bt_hr_coarse': psih_c,
            'q_bt_lr': q_btl, 'psi_bt_lr': psi_btl,
            'loss_q_bt': lq, 'loss_psi_bt': lp, 'total_loss': total,
        }
    return total

# ======================= NGBoost + Stable Proposer ==================== #
class InstabilityAwareProposer:
    """NGBoost EI + simple stability classifier filter.
       Classifier: LogisticRegression trained on valid(1)/failed(0) labels.
       We discard candidate points with P(stable) < thresh.
    """
    def __init__(self, bounds_dict, n_candidates=4096, kappa=1.5, min_train=10,
                 winsor_p: Tuple[float,float]=(1.0, 99.0), sigma_cap_mult: float = 10.0,
                 learning_rate: float = 0.03, n_estimators: int = 600,
                 base_thresh: float = 0.30, max_thresh: float = 0.70):
        if not _HAS_NGBOOST:
            raise ImportError("NGBoost not installed. Please `pip install ngboost`.")
        self.names = list(bounds_dict.keys())
        self.lb = np.array([v['bounds'][0] for v in bounds_dict.values()], float)
        self.ub = np.array([v['bounds'][1] for v in bounds_dict.values()], float)
        self.is_log = np.array([v['type']=='log' for v in bounds_dict.values()], bool)
        self.n_candidates = n_candidates
        self.kappa = kappa
        self.min_train = min_train
        self.base_thresh = base_thresh
        self.max_thresh = max_thresh
        self.rng = np.random.default_rng(0)

        # NGBoost
        self.y_mean = 0.0; self.y_std = 1.0
        self.winsor_p = winsor_p
        self.sigma_cap_mult = sigma_cap_mult
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.ngb = None

        # Stability classifier
        self.clf = LogisticRegression(max_iter=1000, class_weight='balanced')
        self.clf_ready = False

    # ---- space transforms ---- #
    def _to_model_space(self, X):
        X = np.asarray(X, float).copy()
        for i, lg in enumerate(self.is_log):
            if lg:
                X[..., i] = np.log10(np.clip(X[..., i], 1e-300, None))
        return X
    def _from_unit(self, U):
        X = np.empty_like(U)
        for i, (lbi, ubi) in enumerate(zip(self.lb, self.ub)):
            if self.is_log[i]:
                X[:, i] = 10 ** (np.log10(lbi) + U[:, i]*(np.log10(ubi)-np.log10(lbi)))
            else:
                X[:, i] = lbi + U[:, i]*(ubi-lbi)
        return X
    def _lhs(self, n):
        s = qmc.LatinHypercube(d=len(self.names), seed=int(self.rng.integers(1,1e9)))
        return self._from_unit(s.random(n))

    # ---- target scaling ---- #
    def _winsorize(self, y):
        lo, hi = np.percentile(y, self.winsor_p)
        return np.clip(y, lo, hi)
    def _standardize_targets(self, y):
        y = self._winsorize(y.copy())
        self.y_mean = float(np.nanmean(y)); self.y_std = float(np.nanstd(y))
        if not np.isfinite(self.y_std) or self.y_std < 1e-8:
            self.y_std = 1.0
        return (y - self.y_mean) / self.y_std
    def _destandardize(self, mu, sigma):
        mu = mu * self.y_std + self.y_mean
        sigma = np.maximum(sigma * self.y_std, 1e-8)
        cap = self.sigma_cap_mult * self.y_std
        sigma = np.clip(sigma, 1e-8, cap)
        return mu, sigma

    # ---- fit models ---- #
    def _fit_ngb(self, X, y):
        Xm, ym_raw = self._to_model_space(X), y
        ym = self._standardize_targets(ym_raw)
        self.ngb = NGBRegressor(Dist=Normal, Score=CRPS, n_estimators=self.n_estimators,
                                learning_rate=self.learning_rate, natural_gradient=True,
                                minibatch_frac=1.0, random_state=123, verbose=False)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=r".*overflow.*")
            warnings.filterwarnings("ignore", message=r".*divide by zero.*")
            buf = io.StringIO()
            with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
                self.ngb.fit(Xm, ym)
    def _fit_clf(self, X, y):
        valid = np.isfinite(y)
        labels = valid.astype(int)
        if labels.sum() == 0:
            self.clf_ready = False
            return
        Xms = self._to_model_space(X)
        self.clf.fit(Xms, labels)
        self.clf_ready = True

    # ---- predict ---- #
    def _mu_sigma(self, Xc):
        preds = self.ngb.pred_dist(self._to_model_space(Xc))
        mu, sigma = self._destandardize(preds.loc, preds.scale)
        mu = np.where(np.isfinite(mu), mu, np.nanmedian(mu))
        sigma = np.where(np.isfinite(sigma), sigma, np.nanmedian(sigma[sigma>0]) if np.any(sigma>0) else 1.0)
        return mu, np.clip(sigma, 1e-8, self.sigma_cap_mult*max(self.y_std,1.0))
    def _p_stable(self, Xc):
        if not self.clf_ready:
            return np.ones(len(Xc)) * 0.5
        ps = self.clf.predict_proba(self._to_model_space(Xc))[:,1]
        ps = np.where(np.isfinite(ps), ps, 0.5)
        return ps

    # ---- propose ---- #
    def propose(self, X_hist, y_hist, iter_idx=0, total_iter=100, use='ei'):
        X_hist = np.asarray(X_hist); y_hist = np.asarray(y_hist)
        m = np.isfinite(y_hist)
        if m.sum() < 10:
            return self._lhs(1)[0]
        self._fit_ngb(X_hist[m], y_hist[m])
        self._fit_clf(X_hist, y_hist)

        C = self._lhs(self.n_candidates)
        ps = self._p_stable(C)
        # Dynamic threshold: tighten as failures dominate or as iterations progress
        fail_ratio = 1.0 - (np.isfinite(y_hist).sum() / max(len(y_hist),1))
        t_prog = min(1.0, (iter_idx+1)/max(total_iter,1))
        thresh = self.base_thresh + (self.max_thresh - self.base_thresh) * max(fail_ratio, t_prog)
        keep = ps >= thresh
        if not np.any(keep):
            # fallback: take top-K by p_stable
            topk = np.argsort(-ps)[:max(32, self.n_candidates//16)]
            Ck = C[topk]
        else:
            Ck = C[keep]

        mu, sigma = self._mu_sigma(Ck)
        if use == 'lcb':
            score = mu - 1.5 * sigma
            return Ck[int(np.nanargmin(score))]
        from scipy.stats import norm
        y_best = np.nanmin(y_hist)
        imp = y_best - mu
        Z = imp / sigma
        ei = np.where(sigma>0, imp*norm.cdf(Z) + sigma*norm.pdf(Z), 0.0)
        ei = np.where(np.isfinite(ei), ei, -1e9)
        return Ck[int(np.nanargmax(ei))]

# ============================= Optimizer ============================== #
class GPOptimizer:
    def __init__(self, bounds_dict, n_initial_samples=12, proposer=None, checkpoint_file: str = CKPT_FILE):
        self.bounds_dict = bounds_dict
        self.n_initial_samples = n_initial_samples
        self.proposer = proposer
        self.checkpoint_file = checkpoint_file
        self.X_samples = []
        self.y_samples = []
        self.detailed_outputs = []
        self.best_loss = np.inf
        self.best_params = None
        self.best_iteration = -1
        self.iteration = 0
        self.gp = GaussianProcessRegressor(kernel=ConstantKernel(1.0)*Matern(nu=2.5), alpha=1e-6,
                                           normalize_y=True, n_restarts_optimizer=10, random_state=42)

    def random_sample(self):
        s = np.zeros(len(self.bounds_dict))
        for i, (_, info) in enumerate(self.bounds_dict.items()):
            lo, hi = info['bounds']
            if info['type'] == 'log':
                s[i] = 10 ** (np.random.uniform(np.log10(lo), np.log10(hi)))
            else:
                s[i] = np.random.uniform(lo, hi)
        return s

    def initialize_samples(self):
        print(f"\n{'='*70}\nINITIALIZING WITH LATIN HYPERCUBE SAMPLING\n{'='*70}")
        return generate_latin_hypercube_samples(self.n_initial_samples, self.bounds_dict)

    def optimize(self, config_base, highres_results, max_iterations=100):
        if len(self.X_samples) < self.n_initial_samples:
            init = self.initialize_samples()
            for i, p in enumerate(init):
                print(f"\n{'='*70}\nInitial sample {i+1}/{self.n_initial_samples}\n{'='*70}")
                loss, res, det = evaluate_with_stability_gate(p, config_base, highres_results, full_days=180)
                self.X_samples.append(p); self.y_samples.append(loss); self.detailed_outputs.append(det)
                self._maybe_best(loss, p)
                self.save_progress()
        start = max(self.iteration+1, self.n_initial_samples)

        print(f"\n{'='*70}\nBAYESIAN OPTIMIZATION PHASE\n{'='*70}")
        for it in range(start, max_iterations):
            self.iteration = it
            print(f"\n{'='*70}\nIteration {it+1}/{max_iterations}\n{'='*70}")
            try:
                if self.proposer is not None:
                    x_next = self.proposer.propose(np.array(self.X_samples), np.array(self.y_samples), iter_idx=it, total_iter=max_iterations, use='ei')
                else:
                    x_next = self.random_sample()
                loss, res, det = evaluate_with_stability_gate(x_next, config_base, highres_results, full_days=180)
                self.X_samples.append(x_next); self.y_samples.append(loss); self.detailed_outputs.append(det)
                self._maybe_best(loss, x_next)
                self._print_status()
                self.save_progress()
                if (it+1) % 10 == 0:
                    try: self.plot_progress()
                    except Exception as e: print(f"  ⚠ Plotting failed: {e}")
            except KeyboardInterrupt:
                print("Interrupted — saving and exiting."); self.save_progress(); break
            except Exception as e:
                print(f"  ⚠ Iteration error: {e}")
                import traceback; traceback.print_exc()
                self.X_samples.append(self.random_sample()); self.y_samples.append(np.nan); self.detailed_outputs.append(None)
                self.save_progress(); continue
        return self.get_best_params()

    def _maybe_best(self, loss, params):
        if np.isfinite(loss) and loss < self.best_loss:
            self.best_loss = loss; self.best_params = params.copy(); self.best_iteration = len(self.X_samples)-1
            print(f"  ★ New best loss: {loss:.6f}")

    def _print_status(self):
        y = np.array(self.y_samples); n_valid = int(np.isfinite(y).sum()); n_failed = len(y)-n_valid
        print(f"\n  Status: {n_valid} successful, {n_failed} failed simulations")
        print(f"  Best loss so far: {self.best_loss:.6f} (iteration {self.best_iteration + 1})")

    def get_best_params(self):
        if self.best_params is None:
            raise ValueError("No valid parameters found during optimization!")
        return {PARAM_NAMES[i]: float(self.best_params[i]) for i in range(len(PARAM_NAMES))}

    def save_progress(self, filename: Optional[str] = None):
        fn = filename or CKPT_FILE
        data = {
            'X_samples': self.X_samples,
            'y_samples': self.y_samples,
            'detailed_outputs': self.detailed_outputs,
            'best_loss': self.best_loss,
            'best_params': self.best_params,
            'best_iteration': self.best_iteration,
            'iteration': self.iteration,
            'bounds_dict': self.bounds_dict,
            'n_initial_samples': self.n_initial_samples,
        }
        with open(fn, 'wb') as f: pickle.dump(data, f)
        print(f"  ✓ Progress saved to {fn}")

    @classmethod
    def load_progress(cls, filename: str = CKPT_FILE, proposer=None):
        print(f"Loading checkpoint from {filename}…")
        with open(filename, 'rb') as f: data = pickle.load(f)
        opt = cls(bounds_dict=data['bounds_dict'], n_initial_samples=data['n_initial_samples'], proposer=proposer, checkpoint_file=filename)
        opt.X_samples = data['X_samples']; opt.y_samples = data['y_samples']
        opt.detailed_outputs = data.get('detailed_outputs', [None]*len(data['y_samples']))
        opt.best_loss = data['best_loss']; opt.best_params = data['best_params']
        opt.best_iteration = data.get('best_iteration', -1); opt.iteration = data['iteration']
        n = len(opt.X_samples); n_valid = int(np.isfinite(opt.y_samples).sum())
        print(f"✓ Loaded checkpoint: total={n}, valid={n_valid}, failed={n-n_valid}, current_iter={opt.iteration+1}, best_loss={opt.best_loss:.6f} (iter {opt.best_iteration+1})")
        return opt

    def plot_progress(self, filename='optimization_progress_final.png'):
        n_total = 2 + len(PARAM_NAMES); n_cols = 3; n_rows = (n_total + n_cols - 1)//n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 5*n_rows))
        if n_rows == 1: axes = axes.reshape(1,-1)
        axes = axes.flatten()
        y = np.array(self.y_samples); valid = np.isfinite(y)
        it = np.arange(len(y)); itv = it[valid]; yv = y[valid]
        ax = axes[0]
        if len(yv)>0:
            ax.plot(itv, yv, 'o', alpha=0.6, label='Valid')
            if np.any(~valid):
                ax.plot(it[~valid], np.full((~valid).sum(), np.nanmax(yv)*1.1 if len(yv)>0 else 1.0), 'x', color='red', label='Failed')
            ax.axhline(self.best_loss, color='g', ls='--', lw=2, label=f'Best: {self.best_loss:.4f}')
        ax.set_title('Loss vs Iteration'); ax.legend(); ax.grid(True, alpha=0.3)
        ax = axes[1]
        best_so_far = []
        cur = np.inf
        for L in y:
            if np.isfinite(L) and L < cur: cur = L
            best_so_far.append(cur if np.isfinite(cur) else np.nan)
        bsf = np.array(best_so_far); vb = np.isfinite(bsf)
        if np.any(vb): ax.plot(it[vb], bsf[vb], 'g-', lw=2)
        ax.set_title('Cumulative Best Loss'); ax.grid(True, alpha=0.3)
        X = np.array(self.X_samples) if len(self.X_samples)>0 else np.zeros((0,len(PARAM_NAMES)))
        for i, name in enumerate(PARAM_NAMES):
            ax = axes[i+2]
            if i+2 >= len(axes): break
            if len(X)>0:
                ax.plot(it, X[:,i], 'o-', alpha=0.4, ms=4)
                if self.best_params is not None:
                    ax.axhline(self.best_params[i], color='r', ls='--', lw=2, label=f'Best: {self.best_params[i]:.4e}')
            ax.set_title(name); ax.grid(True, alpha=0.3); ax.legend(fontsize=8)
        for j in range(n_total, len(axes)): axes[j].axis('off')
        plt.tight_layout(); plt.savefig(filename, dpi=150, bbox_inches='tight'); print(f"  ✓ Progress plot saved to {filename}"); plt.close()

# =============================== Main ================================ #

def export_results(optimizer, filename='optimization_results_detailed.pkl'):
    print(f"\nExporting detailed results to {filename}…")
    results = {
        'metadata': {
            'n_total_iterations': len(optimizer.X_samples),
            'n_successful': int(np.isfinite(optimizer.y_samples).sum()),
            'n_failed': int((~np.isfinite(np.array(optimizer.y_samples))).sum()),
            'best_loss': optimizer.best_loss,
            'best_iteration': optimizer.best_iteration,
            'parameter_names': PARAM_NAMES,
            'bounds': optimizer.bounds_dict if hasattr(optimizer,'bounds_dict') else PARAM_BOUNDS,
        },
        'all_iterations': [],
        'best_result': None,
    }
    for i, (p, L, det) in enumerate(zip(optimizer.X_samples, optimizer.y_samples, optimizer.detailed_outputs)):
        results['all_iterations'].append({
            'iteration': i,
            'parameters': {name: float(p[j]) for j, name in enumerate(PARAM_NAMES)},
            'loss': float(L) if np.isfinite(L) else None,
            'is_valid': bool(np.isfinite(L)),
            'detailed_outputs': det,
        })
    if optimizer.best_params is not None:
        results['best_result'] = {
            'iteration': optimizer.best_iteration,
            'parameters': {name: float(optimizer.best_params[j]) for j, name in enumerate(PARAM_NAMES)},
            'loss': optimizer.best_loss,
            'detailed_outputs': optimizer.detailed_outputs[optimizer.best_iteration] if optimizer.best_iteration < len(optimizer.detailed_outputs) else None,
        }
    with open(filename, 'wb') as f: pickle.dump(results, f)
    print("✓ Detailed results exported.")


def main(checkpoint_file: str = CKPT_FILE, max_iterations: int = 100):
    print("\n" + "="*70)
    print("NGBoost Optimization with Stability Gate & Classifier Filter")
    print("="*70)

    if not os.path.exists('highres_results.pkl'):
        print("\n✗ Error: highres_results.pkl not found! Run main_comparison.py first.")
        return
    with open('highres_results.pkl', 'rb') as f: H = pickle.load(f)
    print(f"✓ Loaded high-res: {H['config']['nx']}x{H['config']['ny']}, {H['times'][-1]:.1f} days")

    from main_comparison import config_lowres
    config_base = config_lowres.copy()

    proposer = InstabilityAwareProposer(PARAM_BOUNDS, n_candidates=4096, base_thresh=0.30, max_thresh=0.70)

    if os.path.exists(checkpoint_file):
        print(f"\n{'='*70}\nCHECKPOINT FOUND: {checkpoint_file}\n{'='*70}")
        opt = GPOptimizer.load_progress(checkpoint_file, proposer=proposer)
    else:
        print(f"\n{'='*70}\nSTARTING NEW OPTIMIZATION\n{'='*70}")
        opt = GPOptimizer(bounds_dict=PARAM_BOUNDS, n_initial_samples=12, proposer=proposer, checkpoint_file=checkpoint_file)

    best = opt.optimize(config_base=config_base, highres_results=H, max_iterations=max_iterations)

    print("\n" + "="*70)
    print("OPTIMIZATION COMPLETE")
    print("="*70)
    y = np.array(opt.y_samples); n_valid = int(np.isfinite(y).sum()); n_failed = len(y)-n_valid
    print(f"\nTotal simulations: {len(y)}\n  Successful: {n_valid}\n  Failed: {n_failed}")
    print(f"\nBest loss: {opt.best_loss:.6f}\nBest iteration: {opt.best_iteration + 1}")
    print("\nOptimal parameters:")
    for pname, val in best.items(): print(f"  '{pname}': {val:.6e},")

    try: opt.plot_progress(filename='optimization_progress_final.png')
    except Exception as e: print(f"⚠ Final plotting failed ({e})")
    export_results(opt, filename='optimization_results_detailed.pkl')
    with open('optimal_params.pkl', 'wb') as f: pickle.dump(best, f)
    print("\n✓ Optimal parameters saved to optimal_params.pkl")

    if opt.best_iteration >= 0 and opt.best_iteration < len(opt.detailed_outputs):
        det = opt.detailed_outputs[opt.best_iteration]
        if det is not None:
            print(f"\nBest result details:\n  PV NRMSE: {det['loss_q_bt']:.6f}\n  Psi NRMSE: {det['loss_psi_bt']:.6f}\n  Total: {det['total_loss']:.6f}")

    return opt, best


if __name__ == '__main__':
    main(checkpoint_file=CKPT_FILE, max_iterations=100)



NGBoost Optimization with Stability Gate & Classifier Filter
✓ Loaded high-res: 512x256, 180.0 days

STARTING NEW OPTIMIZATION

INITIALIZING WITH LATIN HYPERCUBE SAMPLING

Initial sample 1/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 2/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 3/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 4/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 5/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 6/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 7/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 8/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 9/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 10/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 11/12
  ✓ Progress saved to optimization_ngboost.pkl

Initial sample 12/12
  ✓ Progress saved to optimization_ngboost.pkl

BAYESIAN

  X[:, i] = 10 ** (np.log10(lbi) + U[:, i]*(np.log10(ubi)-np.log10(lbi)))
  X[:, i] = 10 ** (np.log10(lbi) + U[:, i]*(np.log10(ubi)-np.log10(lbi)))



  Status: 0 successful, 18 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 19/100

  Status: 0 successful, 19 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 20/100

  Status: 0 successful, 20 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl


  ax.set_title('Loss vs Iteration'); ax.legend(); ax.grid(True, alpha=0.3)
  ax.set_title(name); ax.grid(True, alpha=0.3); ax.legend(fontsize=8)


  ✓ Progress plot saved to optimization_progress_final.png

Iteration 21/100

  Status: 0 successful, 21 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 22/100

  Status: 0 successful, 22 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 23/100

  Status: 0 successful, 23 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 24/100

  Status: 0 successful, 24 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 25/100

  Status: 0 successful, 25 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 26/100

  Status: 0 successful, 26 failed simulations
  Best loss so far: inf (iteration 0)
  ✓ Progress saved to optimization_ngboost.pkl

Iteration 27/100

  Status: 0 successful

ValueError: No valid parameters found during optimization!