# In [None]:
"""
ML-Based Subgrid Parameter Optimization using Gaussian Process
Optimizes low-res parameters to match high-res PV and streamfunction fields
"""

import numpy as np
import pickle
import os
from scipy.optimize import differential_evolution
from scipy.stats import qmc
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, RBF, ConstantKernel
import matplotlib.pyplot as plt
from qg_model import QGTwoLayerModel
from scipy.ndimage import uniform_filter

# ============================================================================
# PARAMETER BOUNDS
# ============================================================================

# Search space for the subgrid parameters.
#
# Each entry maps a parameter name to:
#   'bounds'      : (lower, upper) limits of the search interval
#   'type'        : 'log' -> the samplers in this file draw uniformly in log10
#                   space; any other value -> uniformly in linear space
#   'description' : human-readable note (not read by the code in this file)
PARAM_BOUNDS = {
    'viscosity_scale': {
        'bounds': (0.5, 5.0),
        'type': 'linear',
        'description': 'Multiplies hyperviscosity coefficient',
    },
    'drag_scale': {
        'bounds': (0.5, 3.0),
        'type': 'linear',
        'description': 'Multiplies Ekman drag coefficient',
    },
    'eddy_diffusivity': {
        'bounds': (1e3, 1e5),
        'type': 'log',
        'description': 'Additional biharmonic diffusion (m²/s)',
    },
    'smagorinsky_coeff': {
        'bounds': (0.0, 0.3),
        'type': 'linear',
        'description': 'Smagorinsky eddy viscosity coefficient',
    },
    'energy_correction': {
        'bounds': (-0.01, 0.01),
        'type': 'linear',
        'description': 'Energy backscatter correction',
    },
    'enstrophy_correction': {
        # A lower bound of 0 has no log10; the samplers in this file replace
        # the log10 lower limit with -10 when the lower bound is not > 0, so
        # sampled values effectively start at 1e-10 rather than at 0.
        'bounds': (0.0, 1e-6),
        'type': 'log',
        'description': 'Enstrophy dissipation rate',
    },
}

# Parameter order: column i of every parameter array used in this file
# corresponds to PARAM_NAMES[i] (dict insertion order of PARAM_BOUNDS).
PARAM_NAMES = list(PARAM_BOUNDS.keys())
N_PARAMS = len(PARAM_NAMES)

# ============================================================================
# LATIN HYPERCUBE SAMPLING
# ============================================================================

def generate_latin_hypercube_samples(n_samples, bounds_dict, seed=42):
    """
    Generate Latin Hypercube samples for initial exploration

    Samples are first drawn in the unit hypercube and then mapped to each
    parameter's interval: linearly for most parameters, and uniformly in
    log10 space for parameters whose 'type' is 'log'.

    Parameters:
    -----------
    n_samples : int
        Number of samples to generate
    bounds_dict : dict
        Dictionary of parameter bounds; each value must provide
        'bounds' (lower, upper) and 'type' ('log' or anything else for linear)
    seed : int, optional
        Seed passed to the Latin Hypercube sampler. The default (42) keeps
        the sample set identical between calls, so the same design is
        regenerated on every run unless a different seed is requested.

    Returns:
    --------
    samples : np.ndarray (n_samples, n_params)
        Parameter samples; column order follows the iteration order of
        bounds_dict
    """
    n_params = len(bounds_dict)
    sampler = qmc.LatinHypercube(d=n_params, seed=seed)
    unit_samples = sampler.random(n=n_samples)

    # Transform to actual parameter space, one column per parameter
    samples = np.empty_like(unit_samples)

    for i, param_info in enumerate(bounds_dict.values()):
        lower, upper = param_info['bounds']
        unit_column = unit_samples[:, i]

        if param_info['type'] == 'log':
            # Log-uniform sampling. log10 is undefined for a non-positive
            # lower bound, so fall back to an exponent floor of -10.
            log_lower = np.log10(lower) if lower > 0 else -10
            log_upper = np.log10(upper)
            samples[:, i] = 10 ** (unit_column * (log_upper - log_lower) + log_lower)
        else:
            # Linear sampling
            samples[:, i] = unit_column * (upper - lower) + lower

    return samples

# ============================================================================
# SIMULATION RUNNER
# ============================================================================

def run_lowres_with_params(params_array, config_base, highres_results, sim_days=180, save_outputs=True):
    """
    Evaluate one candidate parameter vector with a low-res simulation

    The parameter vector is attached to a shallow copy of the base config
    under the key 'subgrid_params', the simulation is run, and the result is
    scored against the high-res reference with compute_loss.

    Parameters:
    -----------
    params_array : np.ndarray (n_params,)
        Parameter values, ordered like PARAM_NAMES
    config_base : dict
        Base low-res configuration (not modified; a shallow copy is used)
    highres_results : dict
        High-res results for comparison
    sim_days : int
        Simulation duration passed to run_simulation
    save_outputs : bool
        Accepted for interface compatibility; not used inside this function

    Returns:
    --------
    loss : float
        Combined loss for PV and streamfunction (NaN if failed)
    results : dict or None
        Simulation results (None if failed)
    detailed_outputs : dict or None
        Detailed outputs including fields used for loss computation
        (None if failed)
    """
    # Imported lazily, at call time (create_initial_conditions is imported
    # but not used below)
    from main_comparison import run_simulation, create_initial_conditions

    # Attach the candidate parameters to a copy of the base configuration
    config = config_base.copy()
    config['subgrid_params'] = {
        name: float(params_array[idx]) for idx, name in enumerate(PARAM_NAMES)
    }

    print(f"\n{'='*70}")
    print(f"Testing parameters:")
    for name, value in config['subgrid_params'].items():
        print(f"  {name}: {value:.6e}")

    # Any failure (crash or non-finite loss) is reported as this triple
    failed = (np.nan, None, None)

    try:
        # Run the simulation and score it against the reference
        results = run_simulation(config, sim_days=sim_days, save_interval_hours=12)
        loss, detailed_outputs = compute_loss(results, highres_results, return_fields=True)

        if np.isfinite(loss):
            print(f"  Loss: {loss:.6f}")
            return loss, results, detailed_outputs

        print(f"  ✗ Loss is not finite: {loss}")
        return failed

    except Exception as e:
        # Deliberately broad: a failed simulation must not abort the search
        print(f"  ✗ Simulation failed: {e}")
        import traceback
        traceback.print_exc()
        return failed

# ============================================================================
# LOSS COMPUTATION
# ============================================================================

def compute_loss(lowres_results, highres_results, n_days_avg=30, return_fields=False,
                 weight_pv=0.6, weight_psi=0.4):
    """
    Compute loss between low-res and high-res (coarsened)
    Focus ONLY on barotropic PV and streamfunction, time-averaged over the
    last n_days_avg time units of each run

    Parameters:
    -----------
    lowres_results : dict
        Low-res simulation results. Keys read here: 'config' (with 'nx',
        'ny'), 'times', 'q1_history', 'q2_history', 'model'
    highres_results : dict
        High-res simulation results (same keys as lowres_results)
    n_days_avg : int
        Length of the averaging window, measured back from the last entry
        of 'times' (default: 30)
    return_fields : bool
        Whether to return the fields used for loss computation
    weight_pv : float
        Weight of the barotropic PV NRMSE in the total loss (default: 0.6)
    weight_psi : float
        Weight of the barotropic streamfunction NRMSE (default: 0.4)

    Returns:
    --------
    loss : float
        Combined loss for barotropic PV and streamfunction
    detailed_outputs : dict (optional)
        Fields used for loss computation if return_fields=True

    Raises:
    -------
    ValueError
        If the high-res grid is not an integer multiple of the low-res grid,
        if the averaging window contains no snapshots, or if the coarsened
        high-res field does not match the low-res field shape
    """
    # Verify dimensions
    nx_hr = highres_results['config']['nx']
    ny_hr = highres_results['config']['ny']
    nx_lr = lowres_results['config']['nx']
    ny_lr = lowres_results['config']['ny']

    print(f"  High-res grid: {nx_hr}x{ny_hr}")
    print(f"  Low-res grid:  {nx_lr}x{ny_lr}")

    # Check if coarsening is possible
    if nx_hr % nx_lr != 0 or ny_hr % ny_lr != 0:
        raise ValueError(f"High-res grid ({nx_hr}x{ny_hr}) not evenly divisible by "
                        f"low-res grid ({nx_lr}x{ny_lr})")

    coarsen_factor_x = nx_hr // nx_lr
    coarsen_factor_y = ny_hr // ny_lr

    print(f"  Coarsening factor: {coarsen_factor_x}x in X, {coarsen_factor_y}x in Y")

    # Get time indices for the last n_days_avg of each run. The time axes are
    # converted to arrays so that plain Python lists are accepted as well.
    times_hr = np.asarray(highres_results['times'], dtype=float)
    times_lr = np.asarray(lowres_results['times'], dtype=float)

    indices_hr = np.flatnonzero(times_hr >= times_hr[-1] - n_days_avg)
    indices_lr = np.flatnonzero(times_lr >= times_lr[-1] - n_days_avg)

    if len(indices_hr) == 0 or len(indices_lr) == 0:
        # Only reachable for a negative window; fail loudly instead of
        # averaging an empty list into NaN.
        raise ValueError(f"Averaging window of {n_days_avg} days contains no snapshots")

    print(f"  Averaging over last {n_days_avg} days:")
    print(f"    High-res: {len(indices_hr)} snapshots")
    print(f"    Low-res:  {len(indices_lr)} snapshots")

    # Average fields over the selected window
    q1_hr_avg = np.mean([highres_results['q1_history'][i] for i in indices_hr], axis=0)
    q2_hr_avg = np.mean([highres_results['q2_history'][i] for i in indices_hr], axis=0)
    q1_lr_avg = np.mean([lowres_results['q1_history'][i] for i in indices_lr], axis=0)
    q2_lr_avg = np.mean([lowres_results['q2_history'][i] for i in indices_lr], axis=0)

    print(f"  High-res field shapes: q1={q1_hr_avg.shape}, q2={q2_hr_avg.shape}")
    print(f"  Low-res field shapes:  q1={q1_lr_avg.shape}, q2={q2_lr_avg.shape}")

    # Compute streamfunctions from the time-mean PV using each run's own model
    model_hr = highres_results['model']
    model_lr = lowres_results['model']

    psi1_hr_avg, psi2_hr_avg = model_hr.q_to_psi(q1_hr_avg, q2_hr_avg)
    psi1_lr_avg, psi2_lr_avg = model_lr.q_to_psi(q1_lr_avg, q2_lr_avg)

    # Compute ONLY barotropic components (depth-weighted average).
    # Layer depths are taken from the high-res model and applied to both runs.
    H1, H2 = model_hr.H1, model_hr.H2
    H_total = H1 + H2

    q_bt_hr = (H1 * q1_hr_avg + H2 * q2_hr_avg) / H_total
    psi_bt_hr = (H1 * psi1_hr_avg + H2 * psi2_hr_avg) / H_total

    q_bt_lr = (H1 * q1_lr_avg + H2 * q2_lr_avg) / H_total
    psi_bt_lr = (H1 * psi1_lr_avg + H2 * psi2_lr_avg) / H_total

    # Coarsen high-res to low-res grid for fair comparison
    def coarsen(field, factor_x, factor_y):
        """Coarsen field with potentially different factors in x and y"""
        # Apply a periodic box filter, then keep every factor-th point.
        # NOTE(review): uniform_filter windows are centred on each point, so
        # the retained averages are not aligned to factor-sized blocks —
        # confirm this matches the model's grid convention.
        filtered = uniform_filter(field, size=(factor_y, factor_x), mode='wrap')
        return filtered[::factor_y, ::factor_x]

    q_bt_hr_coarse = coarsen(q_bt_hr, coarsen_factor_x, coarsen_factor_y)
    psi_bt_hr_coarse = coarsen(psi_bt_hr, coarsen_factor_x, coarsen_factor_y)

    print(f"  Coarsened high-res shapes: q_bt={q_bt_hr_coarse.shape}, psi_bt={psi_bt_hr_coarse.shape}")

    # Verify shapes match
    if q_bt_hr_coarse.shape != q_bt_lr.shape:
        raise ValueError(f"Shape mismatch after coarsening! HR coarse: {q_bt_hr_coarse.shape}, "
                        f"LR: {q_bt_lr.shape}")

    # Compute NRMSE (Normalized Root Mean Square Error)
    def nrmse(pred, target):
        mse = np.mean((pred - target)**2)
        std = np.std(target)
        # The tiny offset avoids division by zero for a constant target
        return np.sqrt(mse) / (std + 1e-20)

    loss_q_bt = nrmse(q_bt_lr, q_bt_hr_coarse)
    loss_psi_bt = nrmse(psi_bt_lr, psi_bt_hr_coarse)

    print(f"  Barotropic PV NRMSE:           {loss_q_bt:.6f}")
    print(f"  Barotropic Streamfunction NRMSE: {loss_psi_bt:.6f}")

    # Weighted combination. With the defaults PV is weighted slightly higher
    # than the streamfunction since it is the prognostic variable.
    total_loss = weight_pv * loss_q_bt + weight_psi * loss_psi_bt

    if return_fields:
        detailed_outputs = {
            'q_bt_hr_coarse': q_bt_hr_coarse,
            'psi_bt_hr_coarse': psi_bt_hr_coarse,
            'q_bt_lr': q_bt_lr,
            'psi_bt_lr': psi_bt_lr,
            'loss_q_bt': loss_q_bt,
            'loss_psi_bt': loss_psi_bt,
            'total_loss': total_loss,
            'weight_pv': weight_pv,
            'weight_psi': weight_psi,
        }
        return total_loss, detailed_outputs

    return total_loss

# ============================================================================
# GAUSSIAN PROCESS OPTIMIZER
# ============================================================================

class GPOptimizer:
    """
    Gaussian Process-based Bayesian Optimization

    The optimizer first evaluates a Latin Hypercube design and then proposes
    new points by maximising Expected Improvement over random candidates.

    Samples are stored in raw parameter space (X_samples), but the GP
    surrogate (self.gp) is fitted and queried in a normalised unit-cube
    space (log10-scaled for 'log' parameters). The search dimensions differ
    by many orders of magnitude, and the kernel uses a single length scale.
    """

    def __init__(self, bounds_dict, n_initial_samples=10):
        """
        Initialize GP optimizer

        Parameters:
        -----------
        bounds_dict : dict
            Parameter bounds; each value provides 'bounds' (lower, upper)
            and 'type' ('log' or linear)
        n_initial_samples : int
            Number of initial Latin Hypercube samples
        """
        self.bounds_dict = bounds_dict
        self.n_params = len(bounds_dict)
        self.n_initial_samples = n_initial_samples

        # Storage (the three lists are kept index-aligned)
        self.X_samples = []  # Parameter samples
        self.y_samples = []  # Loss values (NaN for failed evaluations)
        self.detailed_outputs = []  # Detailed outputs for each sample
        self.best_loss = np.inf
        self.best_params = None
        self.best_iteration = -1  # Index into X_samples of the best sample
        self.iteration = 0

        # GP model
        kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
        self.gp = GaussianProcessRegressor(
            kernel=kernel,
            alpha=1e-6,
            normalize_y=True,
            n_restarts_optimizer=10,
            random_state=42
        )

    @property
    def param_names(self):
        """Parameter names in column order (iteration order of bounds_dict)"""
        return list(self.bounds_dict)

    @staticmethod
    def _log10_range(lower, upper):
        """Return (log10 lower, log10 upper), using -10 when lower is not > 0"""
        log_lower = np.log10(lower) if lower > 0 else -10
        return log_lower, np.log10(upper)

    def _to_unit_cube(self, X):
        """
        Map raw parameter vectors to the unit hypercube used by the GP

        Parameters:
        -----------
        X : np.ndarray
            Parameter vector(s) in raw space, 1D or 2D

        Returns:
        --------
        unit : np.ndarray (n_points, n_params)
            Each column rescaled to [0, 1] over its bounds; 'log' parameters
            are rescaled in log10 space
        """
        X = np.atleast_2d(np.asarray(X, dtype=float))
        unit = np.empty_like(X)
        for i, param_info in enumerate(self.bounds_dict.values()):
            lower, upper = param_info['bounds']
            if param_info['type'] == 'log':
                lo, hi = self._log10_range(lower, upper)
                # Clamp at the lower edge so log10 never sees a value <= 0
                column = np.log10(np.maximum(X[:, i], 10.0 ** lo))
            else:
                lo, hi = lower, upper
                column = X[:, i]
            unit[:, i] = (column - lo) / (hi - lo)
        return unit

    def random_sample(self):
        """
        Generate a random sample within bounds

        Returns:
        --------
        sample : np.ndarray
            Random parameter sample (log-uniform for 'log' parameters,
            uniform otherwise), drawn from the global NumPy random state
        """
        sample = np.zeros(self.n_params)
        for i, param_info in enumerate(self.bounds_dict.values()):
            lower, upper = param_info['bounds']
            if param_info['type'] == 'log':
                log_lower, log_upper = self._log10_range(lower, upper)
                sample[i] = 10 ** (np.random.uniform(log_lower, log_upper))
            else:
                sample[i] = np.random.uniform(lower, upper)
        return sample

    def initialize_samples(self):
        """Generate initial Latin Hypercube samples"""
        print(f"\n{'='*70}")
        print(f"INITIALIZING WITH LATIN HYPERCUBE SAMPLING")
        print(f"{'='*70}")
        print(f"Generating {self.n_initial_samples} initial samples...")

        samples = generate_latin_hypercube_samples(self.n_initial_samples, self.bounds_dict)
        return samples

    def acquisition_function(self, X, xi=0.01):
        """
        Expected Improvement acquisition function

        Parameters:
        -----------
        X : np.ndarray
            Candidate points in raw parameter space (can be 1D or 2D)
        xi : float
            Exploration parameter

        Returns:
        --------
        ei : float or np.ndarray
            Expected improvement (scalar for a single point, array
            otherwise). Zero wherever the GP reports zero predictive
            standard deviation, and zero everywhere if no finite loss has
            been recorded yet.
        """
        from scipy.stats import norm

        # Ensure X is 2D and expressed in the GP's normalised coordinates
        X_unit = self._to_unit_cube(X)

        mu, sigma = self.gp.predict(X_unit, return_std=True)
        ei = np.zeros_like(mu, dtype=float)

        # Use best valid loss
        losses = np.asarray(self.y_samples, dtype=float)
        valid_losses = losses[np.isfinite(losses)]

        if valid_losses.size > 0:
            imp = valid_losses.min() - mu - xi
            # Only evaluate where sigma > 0: avoids dividing by zero and
            # leaves EI at exactly 0 for points with no predictive spread.
            positive = sigma > 0
            Z = imp[positive] / sigma[positive]
            ei[positive] = imp[positive] * norm.cdf(Z) + sigma[positive] * norm.pdf(Z)

        # Return scalar if a single point was evaluated
        if ei.shape[0] == 1:
            return float(ei[0])
        return ei

    def propose_next_sample(self, n_candidates=200):
        """
        Propose next sample point using acquisition function

        The GP is refitted on all finite-loss samples, then Expected
        Improvement is evaluated on random candidates and the best one is
        returned. Falls back to a purely random sample when fewer than five
        valid samples exist, when the GP fit fails, or when no candidate
        yields a finite EI.

        Parameters:
        -----------
        n_candidates : int
            Number of random candidate points to score (default: 200)

        Returns:
        --------
        next_sample : np.ndarray
            Proposed parameter sample (raw parameter space)
        """
        X = np.array(self.X_samples)
        y = np.array(self.y_samples, dtype=float)

        # Filter out failed simulations (NaN/inf values)
        valid_mask = np.isfinite(y)
        n_valid = np.sum(valid_mask)

        print(f"  Valid samples: {n_valid}/{len(y)}")

        if n_valid < 5:
            # Not enough valid samples for reliable GP
            print("  ⚠ Warning: Too few valid samples, using random exploration")
            return self.random_sample()

        # Fit GP on valid data only, in normalised coordinates
        try:
            self.gp.fit(self._to_unit_cube(X[valid_mask]), y[valid_mask])
            print(f"  ✓ GP fitted successfully on {n_valid} valid samples")
        except Exception as e:
            print(f"  ⚠ Warning: GP fitting failed ({e}), using random exploration")
            return self.random_sample()

        # Score all random candidates with a single batched GP prediction
        candidates = np.array([self.random_sample() for _ in range(n_candidates)])
        try:
            ei = np.atleast_1d(self.acquisition_function(candidates))
        except Exception as e:
            print(f"  ⚠ Warning: Acquisition evaluation failed ({e})")
            ei = np.full(len(candidates), np.nan)

        finite = np.isfinite(ei)
        if not np.any(finite):
            print("  ⚠ Warning: Acquisition optimization failed, using random sample")
            return self.random_sample()

        best_idx = int(np.argmax(np.where(finite, ei, -np.inf)))
        print(f"  ✓ Proposed sample with EI = {ei[best_idx]:.6e}")
        return candidates[best_idx]

    def _record_sample(self, params, loss, detailed):
        """Append one evaluation to the history and update the running best"""
        self.X_samples.append(params)
        self.y_samples.append(loss)
        self.detailed_outputs.append(detailed)

        if np.isfinite(loss) and loss < self.best_loss:
            self.best_loss = loss
            self.best_params = params.copy()
            self.best_iteration = len(self.X_samples) - 1
            print(f"  ★ New best loss: {loss:.6f}")

    def optimize(self, config_base, highres_results, max_iterations=500):
        """
        Run Bayesian optimization

        Resumes from whatever samples are already stored: a partially
        completed initial phase continues with the remaining Latin Hypercube
        points, and the Bayesian phase continues until the total number of
        stored samples reaches max_iterations.

        Parameters:
        -----------
        config_base : dict
            Base low-res config
        highres_results : dict
            High-res results
        max_iterations : int
            Maximum total number of evaluations (initial + Bayesian)

        Returns:
        --------
        best_params : dict
            Best parameters found

        Raises:
        -------
        ValueError
            If no evaluation produced a finite loss
        """
        # Check if we should resume or start fresh
        n_existing = len(self.X_samples)

        if n_existing >= self.n_initial_samples:
            # Already have enough samples, skip to Bayesian optimization.
            # The stored sample count (not self.iteration, which is not
            # advanced during the initial phase) defines where to resume.
            start_iteration = n_existing
            print(f"\n{'='*70}")
            print(f"RESUMING OPTIMIZATION")
            print(f"{'='*70}")
            print(f"Already have {n_existing} samples, skipping initial sampling phase")
            print(f"Starting Bayesian optimization from iteration {start_iteration + 1}")
        else:
            # Phase 1: Initial sampling. The Latin Hypercube design is
            # generated with a fixed seed, so it is identical on every run
            # and the first n_existing points can be skipped on resume.
            initial_samples = self.initialize_samples()
            if n_existing > 0:
                print(f"Resuming initial sampling: {n_existing} samples already evaluated")

            for i in range(n_existing, self.n_initial_samples):
                params = initial_samples[i]
                print(f"\n{'='*70}")
                print(f"Initial sample {i+1}/{self.n_initial_samples}")
                print(f"{'='*70}")
                loss, _, detailed = run_lowres_with_params(params, config_base, highres_results)

                self._record_sample(params, loss, detailed)

                # Save after each iteration
                self.save_progress()

            start_iteration = len(self.X_samples)

        # Phase 2: Bayesian optimization
        print(f"\n{'='*70}")
        print(f"BAYESIAN OPTIMIZATION PHASE")
        print(f"{'='*70}")

        for iteration in range(start_iteration, max_iterations):
            # Reset per-iteration state so the error handler never reuses
            # parameters from a previous iteration or records a sample twice.
            next_params = None
            recorded = False
            try:
                self.iteration = iteration

                print(f"\n{'='*70}")
                print(f"Iteration {self.iteration + 1}/{max_iterations}")
                print(f"{'='*70}")

                # Propose next sample
                next_params = self.propose_next_sample()

                # Evaluate
                loss, _, detailed = run_lowres_with_params(next_params, config_base, highres_results)

                self._record_sample(next_params, loss, detailed)
                recorded = True

                # Print current status
                n_valid = np.sum(np.isfinite(self.y_samples))
                n_failed = len(self.y_samples) - n_valid
                print(f"\n  Status: {n_valid} successful, {n_failed} failed simulations")
                print(f"  Best loss so far: {self.best_loss:.6f} (iteration {self.best_iteration + 1})")

                # Save progress every iteration
                self.save_progress()

                # Plot progress every 10 iterations
                if (self.iteration + 1) % 10 == 0:
                    try:
                        self.plot_progress()
                    except Exception as e:
                        print(f"  ⚠ Warning: Plotting failed ({e}), continuing optimization...")
                        import traceback
                        traceback.print_exc()

            except KeyboardInterrupt:
                print("\n\n⚠ Optimization interrupted by user!")
                print("Saving progress before exiting...")
                self.save_progress()
                break

            except Exception as e:
                print(f"\n⚠ Error in iteration {self.iteration + 1}: {e}")
                import traceback
                traceback.print_exc()
                print("Continuing to next iteration...")
                # Save NaN for this failed iteration (unless the sample was
                # already recorded before the error occurred)
                if not recorded:
                    failed_params = next_params if next_params is not None else self.random_sample()
                    self.X_samples.append(failed_params)
                    self.y_samples.append(np.nan)
                    self.detailed_outputs.append(None)
                self.save_progress()
                continue

        return self.get_best_params()

    def get_best_params(self):
        """
        Return best parameters as dictionary

        Raises:
        -------
        ValueError
            If no valid parameters have been found yet
        """
        if self.best_params is None:
            raise ValueError("No valid parameters found during optimization!")

        return {
            name: float(self.best_params[i])
            for i, name in enumerate(self.param_names)
        }

    def save_progress(self, filename='optimization_progress.pkl'):
        """
        Save optimization progress

        The checkpoint is written to a temporary file next to the target and
        then moved into place, so an interruption during the write does not
        corrupt an existing checkpoint.
        """
        data = {
            'X_samples': self.X_samples,
            'y_samples': self.y_samples,
            'detailed_outputs': self.detailed_outputs,
            'best_loss': self.best_loss,
            'best_params': self.best_params,
            'best_iteration': self.best_iteration,
            'iteration': self.iteration,
            'bounds_dict': self.bounds_dict,
            'n_initial_samples': self.n_initial_samples,
        }
        tmp_filename = f"{filename}.tmp"
        with open(tmp_filename, 'wb') as f:
            pickle.dump(data, f)
        os.replace(tmp_filename, filename)
        print(f"  ✓ Progress saved to {filename}")

    @classmethod
    def load_progress(cls, filename='optimization_progress.pkl'):
        """
        Load optimization progress from file

        Parameters:
        -----------
        filename : str
            Checkpoint file

        Returns:
        --------
        optimizer : GPOptimizer
            Restored optimizer instance (the GP itself is not stored; it is
            refitted when the next sample is proposed)
        """
        print(f"Loading checkpoint from {filename}...")
        # SECURITY: pickle can execute arbitrary code on load; only open
        # checkpoint files that were written by this program.
        with open(filename, 'rb') as f:
            data = pickle.load(f)

        # Create optimizer instance
        optimizer = cls(
            bounds_dict=data['bounds_dict'],
            n_initial_samples=data['n_initial_samples']
        )

        # Restore state (older checkpoints may lack the optional keys)
        optimizer.X_samples = data['X_samples']
        optimizer.y_samples = data['y_samples']
        optimizer.detailed_outputs = data.get('detailed_outputs', [None] * len(data['y_samples']))
        optimizer.best_loss = data['best_loss']
        optimizer.best_params = data['best_params']
        optimizer.best_iteration = data.get('best_iteration', -1)
        optimizer.iteration = data['iteration']

        n_samples = len(optimizer.X_samples)
        n_valid = np.sum(np.isfinite(optimizer.y_samples))

        print(f"✓ Loaded checkpoint:")
        print(f"  Total samples: {n_samples}")
        print(f"  Valid samples: {n_valid}")
        print(f"  Failed samples: {n_samples - n_valid}")
        print(f"  Current iteration: {optimizer.iteration + 1}")
        print(f"  Best loss: {optimizer.best_loss:.6f} (iteration {optimizer.best_iteration + 1})")

        return optimizer

    def plot_progress(self, filename='optimization_progress.png'):
        """
        Plot optimization progress

        Produces one figure with the loss per iteration, the cumulative best
        loss, and one panel per parameter, and saves it to filename.
        """
        param_names = self.param_names

        # We need 2 plots for loss + all parameters
        n_total_plots = 2 + len(param_names)
        n_cols = 3
        n_rows = (n_total_plots + n_cols - 1) // n_cols  # Ceiling division

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 5*n_rows))

        # Flatten to 1D array for easier indexing (works for one or many rows)
        axes = np.asarray(axes).flatten()

        # Filter valid samples for plotting
        y_array = np.array(self.y_samples)
        valid_mask = np.isfinite(y_array)

        iterations = np.arange(len(self.y_samples))
        iterations_valid = iterations[valid_mask]
        y_valid = y_array[valid_mask]

        # Loss vs iteration
        ax = axes[0]
        if len(y_valid) > 0:
            ax.plot(iterations_valid, y_valid, 'o', alpha=0.6, label='Valid')
            if np.any(~valid_mask):
                # Failed runs have no loss; draw them just above the worst
                # valid loss so they remain visible
                max_y = np.max(y_valid)
                ax.plot(iterations[~valid_mask], np.ones(np.sum(~valid_mask)) * max_y * 1.1,
                       'x', color='red', markersize=10, label='Failed')
            ax.axhline(self.best_loss, color='g', linestyle='--', linewidth=2,
                      label=f'Best: {self.best_loss:.4f}')
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Loss')
        ax.set_title('Loss vs Iteration')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Best loss cumulative
        ax = axes[1]
        best_so_far = []
        current_best = np.inf
        for loss in y_array:
            if np.isfinite(loss) and loss < current_best:
                current_best = loss
            best_so_far.append(current_best if np.isfinite(current_best) else np.nan)

        best_so_far = np.array(best_so_far)
        valid_best = np.isfinite(best_so_far)
        if np.any(valid_best):
            ax.plot(iterations[valid_best], best_so_far[valid_best], 'g-', linewidth=2)
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Best Loss So Far')
        ax.set_title('Cumulative Best Loss')
        ax.grid(True, alpha=0.3)

        # Parameter evolution
        X = np.array(self.X_samples)
        for i, param_name in enumerate(param_names):
            ax_idx = i + 2  # Start from index 2
            if ax_idx >= len(axes):
                print(f"  Warning: Not enough subplots for parameter {param_name}")
                break

            ax = axes[ax_idx]
            # Plot all samples
            ax.plot(iterations, X[:, i], 'o-', alpha=0.4, markersize=4)
            # Highlight valid samples
            if len(iterations_valid) > 0:
                ax.plot(iterations_valid, X[valid_mask, i], 'o', alpha=0.8, markersize=6)
            # Mark best
            if self.best_params is not None:
                ax.axhline(self.best_params[i], color='r', linestyle='--', linewidth=2,
                          label=f'Best: {self.best_params[i]:.4e}')
            ax.set_xlabel('Iteration')
            ax.set_ylabel(param_name)
            ax.set_title(f'{param_name} Evolution')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for idx in range(n_total_plots, len(axes)):
            axes[idx].axis('off')

        plt.tight_layout()
        plt.savefig(filename, dpi=150, bbox_inches='tight')
        print(f"  ✓ Progress plot saved to {filename}")
        plt.close()

    def export_results(self, filename='optimization_results.pkl'):
        """
        Export detailed results including all parameters, losses, and model outputs

        Parameters:
        -----------
        filename : str
            Output filename
        """
        print(f"\nExporting detailed results to {filename}...")

        param_names = self.param_names

        # Organize results
        results = {
            'metadata': {
                'n_total_iterations': len(self.X_samples),
                'n_successful': np.sum(np.isfinite(self.y_samples)),
                'n_failed': np.sum(~np.isfinite(self.y_samples)),
                'best_loss': self.best_loss,
                'best_iteration': self.best_iteration,
                'parameter_names': param_names,
                'bounds': self.bounds_dict,
            },
            'all_iterations': [],
            'best_result': None,
        }

        # Store all iterations
        for i, (params, loss, detailed) in enumerate(zip(self.X_samples, self.y_samples, self.detailed_outputs)):
            is_valid = bool(np.isfinite(loss))
            results['all_iterations'].append({
                'iteration': i,
                'parameters': {name: float(params[j]) for j, name in enumerate(param_names)},
                'loss': float(loss) if is_valid else None,
                'is_valid': is_valid,
                'detailed_outputs': detailed,
            })

        # Store best result
        if self.best_params is not None:
            # best_iteration may be -1 for checkpoints that did not store it;
            # a negative index would silently pick the last sample's outputs.
            has_outputs = 0 <= self.best_iteration < len(self.detailed_outputs)
            results['best_result'] = {
                'iteration': self.best_iteration,
                'parameters': {name: float(self.best_params[j]) for j, name in enumerate(param_names)},
                'loss': self.best_loss,
                'detailed_outputs': self.detailed_outputs[self.best_iteration] if has_outputs else None,
            }

        # Save to file
        with open(filename, 'wb') as f:
            pickle.dump(results, f)

        print(f"✓ Detailed results exported:")
        print(f"  Total iterations: {results['metadata']['n_total_iterations']}")
        print(f"  Successful: {results['metadata']['n_successful']}")
        print(f"  Failed: {results['metadata']['n_failed']}")
        print(f"  Best loss: {results['metadata']['best_loss']:.6f}")
        print(f"  Best iteration: {results['metadata']['best_iteration'] + 1}")

# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main(checkpoint_file='optimization_progress.pkl', max_iterations=500):
    """
    Main optimization routine with automatic resume capability

    Loads the high-res reference data from ``highres_results.pkl``, resumes
    from ``checkpoint_file`` when that file exists (otherwise creates a new
    ``GPOptimizer``), runs ``optimizer.optimize`` and then writes the final
    progress plot, the detailed results and the optimal parameters to disk.

    Parameters:
    -----------
    checkpoint_file : str
        Path of the progress checkpoint to resume from
    max_iterations : int
        Target total number of iterations; samples already stored in the
        checkpoint count towards it. Passed through to ``optimizer.optimize``.

    Returns:
    --------
    (optimizer, best_params) : tuple
        The optimizer instance and the value returned by
        ``optimizer.optimize`` (iterated below as a name -> value mapping).
        Returns None instead of a tuple when ``highres_results.pkl`` does
        not exist, so callers must check before unpacking.
    """
    
    print("\n" + "="*70)
    print("ML-BASED SUBGRID PARAMETER OPTIMIZATION")
    print("="*70)
    
    # Load high-res results
    if not os.path.exists('highres_results.pkl'):
        print("\n✗ Error: highres_results.pkl not found!")
        print("  Please run main_comparison.py first to generate high-res data.")
        return
    
    print("\nLoading high-res results...")
    # NOTE: pickle.load can execute arbitrary code; only load files that were
    # produced by this pipeline.
    with open('highres_results.pkl', 'rb') as f:
        highres_results = pickle.load(f)
    print(f"✓ Loaded high-res: {highres_results['config']['nx']}x{highres_results['config']['ny']}, "
          f"{highres_results['times'][-1]:.1f} days")
    
    # Load base low-res config (imported here, as in the original, so the
    # import only happens once the high-res data is known to exist)
    from main_comparison import config_lowres
    config_base = config_lowres.copy()
    
    # Check if checkpoint exists
    if os.path.exists(checkpoint_file):
        print(f"\n{'='*70}")
        print(f"CHECKPOINT FOUND: {checkpoint_file}")
        print(f"{'='*70}")
        
        # Load existing progress
        optimizer = GPOptimizer.load_progress(checkpoint_file)
        
        n_existing = len(optimizer.X_samples)
        
        if n_existing >= max_iterations:
            print(f"\n✓ Already completed {n_existing} iterations (target: {max_iterations})")
            print("  No additional iterations needed.")
        elif n_existing >= optimizer.n_initial_samples:
            print(f"\nResuming Bayesian optimization phase...")
            print(f"  Will run {max_iterations - n_existing} more iterations")
        else:
            # Clamp at zero: a target smaller than the initial-sample budget
            # would otherwise be reported as a negative iteration count.
            n_bayesian = max(0, max_iterations - optimizer.n_initial_samples)
            print(f"\nWill complete initial sampling then continue to Bayesian optimization")
            print(f"  Need {optimizer.n_initial_samples - n_existing} more initial samples")
            print(f"  Then {n_bayesian} Bayesian iterations")
    else:
        print(f"\n{'='*70}")
        print(f"STARTING NEW OPTIMIZATION")
        print(f"{'='*70}")
        
        # Initialize new optimizer
        optimizer = GPOptimizer(
            bounds_dict=PARAM_BOUNDS,
            # 2× number of parameters, derived so it tracks PARAM_BOUNDS
            n_initial_samples=2 * N_PARAMS,
        )
    
    # Run optimization
    best_params = optimizer.optimize(
        config_base=config_base,
        highres_results=highres_results,
        max_iterations=max_iterations,
    )
    
    # Print results
    print("\n" + "="*70)
    print("OPTIMIZATION COMPLETE")
    print("="*70)
    
    # Count successful/failed runs (a run counts as failed when its loss is
    # not finite)
    y_array = np.asarray(optimizer.y_samples)
    n_valid = int(np.sum(np.isfinite(y_array)))
    n_failed = len(y_array) - n_valid
    
    print(f"\nTotal simulations: {len(y_array)}")
    print(f"  Successful: {n_valid}")
    print(f"  Failed: {n_failed}")
    print(f"\nBest loss: {optimizer.best_loss:.6f}")
    print(f"Best iteration: {optimizer.best_iteration + 1}")
    print("\nOptimal parameters:")
    for param_name, value in best_params.items():
        print(f"  '{param_name}': {value:.6e},")
    
    # Final plot (best effort: a plotting failure must not lose the results)
    try:
        optimizer.plot_progress(filename='optimization_progress_final.png')
    except Exception as e:
        print(f"⚠ Warning: Final plotting failed ({e})")
    
    # Export detailed results
    optimizer.export_results(filename='optimization_results_detailed.pkl')
    
    # Save final optimal parameters
    with open('optimal_params.pkl', 'wb') as f:
        pickle.dump(best_params, f)
    print("\n✓ Optimal parameters saved to optimal_params.pkl")
    
    # Print summary of best result
    if 0 <= optimizer.best_iteration < len(optimizer.detailed_outputs):
        best_detailed = optimizer.detailed_outputs[optimizer.best_iteration]
        if best_detailed is not None:
            print(f"\nBest result details:")
            print(f"  PV NRMSE: {best_detailed['loss_q_bt']:.6f}")
            print(f"  Streamfunction NRMSE: {best_detailed['loss_psi_bt']:.6f}")
            print(f"  Total loss: {best_detailed['total_loss']:.6f}")
    
    return optimizer, best_params

def load_and_analyze_results(results_file='optimization_results_detailed.pkl'):
    """
    Load and analyze saved optimization results
    
    Prints the run counts, the best result (parameters and component
    losses when available) and loss statistics over the valid runs.
    
    Parameters:
    -----------
    results_file : str
        Path to detailed results file
    
    Returns:
    --------
    results : dict
        Loaded results
    """
    print(f"Loading results from {results_file}...")
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(results_file, 'rb') as f:
        results = pickle.load(f)
    
    print("\n" + "="*70)
    print("OPTIMIZATION RESULTS SUMMARY")
    print("="*70)
    
    meta = results['metadata']
    n_total = meta['n_total_iterations']
    n_successful = meta['n_successful']
    print(f"\nTotal iterations: {n_total}")
    print(f"  Successful: {n_successful}")
    print(f"  Failed: {meta['n_failed']}")
    if n_total > 0:
        print(f"  Success rate: {100*n_successful/n_total:.1f}%")
    else:
        # An empty results file has no meaningful rate; avoid dividing by zero.
        print("  Success rate: n/a (no iterations recorded)")
    
    print(f"\nBest result (iteration {meta['best_iteration']+1}):")
    print(f"  Loss: {meta['best_loss']:.6f}")
    
    best_result = results['best_result']
    if best_result is not None:
        print(f"\n  Best parameters:")
        for name, value in best_result['parameters'].items():
            print(f"    {name}: {value:.6e}")
        
        det = best_result['detailed_outputs']
        if det is not None:
            print(f"\n  Component losses:")
            print(f"    PV NRMSE: {det['loss_q_bt']:.6f}")
            print(f"    Streamfunction NRMSE: {det['loss_psi_bt']:.6f}")
    
    # Print statistics of valid runs
    valid_losses = [it['loss'] for it in results['all_iterations'] if it['is_valid']]
    if valid_losses:
        print(f"\nLoss statistics (valid runs only):")
        print(f"  Mean: {np.mean(valid_losses):.6f}")
        print(f"  Std: {np.std(valid_losses):.6f}")
        print(f"  Min: {np.min(valid_losses):.6f}")
        print(f"  Max: {np.max(valid_losses):.6f}")
        print(f"  Median: {np.median(valid_losses):.6f}")
    
    return results

def plot_parameter_distributions(results_file='optimization_results_detailed.pkl', 
                                 output_file='parameter_distributions.png'):
    """
    Plot distributions of parameters from all valid iterations
    
    Draws one scatter panel per parameter (value vs. position in the list
    of valid iterations, colored by loss) and marks the best parameter
    value with a dashed line whenever a best result was recorded.
    Prints a message and returns without plotting when there are no valid
    iterations.
    
    Parameters:
    -----------
    results_file : str
        Path to detailed results file
    output_file : str
        Output filename for plot
    """
    print(f"Loading results from {results_file}...")
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(results_file, 'rb') as f:
        results = pickle.load(f)
    
    # Extract valid iterations
    valid_iters = [it for it in results['all_iterations'] if it['is_valid']]
    
    if not valid_iters:
        print("No valid iterations found!")
        return
    
    # Extract parameters and losses
    param_names = results['metadata']['parameter_names']
    n_params = len(param_names)
    
    losses = np.array([it['loss'] for it in valid_iters])
    param_values = {
        name: np.array([it['parameters'][name] for it in valid_iters])
        for name in param_names
    }
    
    # The best marker depends only on whether a best result was stored.
    # metadata['best_iteration'] indexes ALL iterations, so it cannot be
    # compared against the length of the filtered valid list.
    best_result = results.get('best_result')
    best_params = best_result['parameters'] if best_result is not None else None
    
    # Create plot (squeeze=False always yields a 2-D axes array)
    n_cols = 3
    n_rows = (n_params + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4*n_rows),
                             squeeze=False)
    axes = axes.flatten()
    
    # Plot each parameter
    x_positions = np.arange(len(valid_iters))
    for ax, name in zip(axes, param_names):
        # Scatter plot colored by loss
        scatter = ax.scatter(x_positions, param_values[name], c=losses,
                             cmap='viridis_r', alpha=0.6, s=50)
        
        # Mark best
        if best_params is not None:
            best_val = best_params[name]
            ax.axhline(best_val, color='r', linestyle='--', linewidth=2,
                       label=f'Best: {best_val:.4e}')
            # Only build a legend when there is a labelled artist to show
            ax.legend()
        
        ax.set_xlabel('Valid Iteration Index')
        ax.set_ylabel(name)
        ax.set_title(f'{name}')
        ax.grid(True, alpha=0.3)
        
        # Add colorbar
        fig.colorbar(scatter, ax=ax, label='Loss')
    
    # Hide unused subplots
    for ax in axes[n_params:]:
        ax.axis('off')
    
    fig.tight_layout()
    fig.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f"✓ Parameter distributions plot saved to {output_file}")
    # Close this figure explicitly rather than whichever one is current
    plt.close(fig)

def compare_fields(results_file='optimization_results_detailed.pkl',
                   iteration=None,
                   output_file='field_comparison.png'):
    """
    Plot comparison of high-res and low-res fields for a specific iteration
    
    Produces a 2x3 figure: one row for PV and one for streamfunction, each
    showing the coarsened high-res field, the low-res field and their
    difference. Prints a message and returns without plotting when the
    iteration is out of range, is not valid, or has no detailed outputs.
    
    Parameters:
    -----------
    results_file : str
        Path to detailed results file
    iteration : int or None
        Zero-based iteration to plot (None = best iteration)
    output_file : str
        Output filename
    """
    print(f"Loading results from {results_file}...")
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(results_file, 'rb') as f:
        results = pickle.load(f)
    
    if iteration is None:
        iteration = results['metadata']['best_iteration']
        print(f"Using best iteration: {iteration + 1}")
    
    # Reject indices outside the stored iterations. Without this check a
    # negative index would silently select an iteration counted from the
    # end, and an index past the end would raise IndexError.
    all_iterations = results['all_iterations']
    if not 0 <= iteration < len(all_iterations):
        print(f"Iteration {iteration + 1} is out of range "
              f"(results contain {len(all_iterations)} iterations)!")
        return
    
    iter_data = all_iterations[iteration]
    
    if not iter_data['is_valid']:
        print(f"Iteration {iteration + 1} is not valid!")
        return
    
    detailed = iter_data['detailed_outputs']
    
    if detailed is None:
        print(f"No detailed outputs for iteration {iteration + 1}!")
        return
    
    # One entry per figure row: (label, high-res field, low-res field, loss key)
    rows = [
        ('PV', detailed['q_bt_hr_coarse'], detailed['q_bt_lr'], 'loss_q_bt'),
        ('Streamfunction', detailed['psi_bt_hr_coarse'], detailed['psi_bt_lr'],
         'loss_psi_bt'),
    ]
    
    # Create comparison plot
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    for row_axes, (label, field_hr, field_lr, loss_key) in zip(axes, rows):
        # Shared color limits so the two fields are directly comparable
        vmin = min(field_hr.min(), field_lr.min())
        vmax = max(field_hr.max(), field_lr.max())
        
        im_hr = row_axes[0].imshow(field_hr, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        row_axes[0].set_title(f'High-res {label} (coarsened)')
        fig.colorbar(im_hr, ax=row_axes[0])
        
        im_lr = row_axes[1].imshow(field_lr, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        row_axes[1].set_title(f'Low-res {label}')
        fig.colorbar(im_lr, ax=row_axes[1])
        
        # The difference panel keeps its own automatic color limits
        im_err = row_axes[2].imshow(field_lr - field_hr, cmap='RdBu_r')
        row_axes[2].set_title(f'{label} Error (NRMSE: {detailed[loss_key]:.4f})')
        fig.colorbar(im_err, ax=row_axes[2])
    
    # Add overall title
    params_str = ', '.join(f'{k}={v:.3e}' for k, v in iter_data['parameters'].items())
    fig.suptitle(f'Iteration {iteration + 1} - Loss: {iter_data["loss"]:.4f}\n{params_str}', 
                 fontsize=10)
    
    fig.tight_layout()
    fig.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f"✓ Field comparison saved to {output_file}")
    # Close this figure explicitly rather than whichever one is current
    plt.close(fig)

if __name__ == "__main__":
    # Run optimization with automatic resume
    outcome = main(max_iterations=500)
    
    if outcome is None:
        # main() returns None (after printing the reason) when the high-res
        # reference file is missing; unpacking it would raise TypeError and
        # there would be no fresh results to analyze.
        print("\nOptimization did not run; skipping analysis.")
    else:
        optimizer, best_params = outcome
        
        # Analyze results
        print("\n" + "="*70)
        print("ANALYZING RESULTS")
        print("="*70)
        
        results = load_and_analyze_results('optimization_results_detailed.pkl')
        
        # Plot parameter distributions
        plot_parameter_distributions('optimization_results_detailed.pkl')
        
        # Plot best field comparison
        compare_fields('optimization_results_detailed.pkl')


ML-BASED SUBGRID PARAMETER OPTIMIZATION

Loading high-res results...
✓ Loaded high-res: 512x256, 180.0 days

CHECKPOINT FOUND: optimization_progress.pkl
Loading checkpoint from optimization_progress.pkl...
✓ Loaded checkpoint:
  Total samples: 92
  Valid samples: 11
  Failed samples: 81
  Current iteration: 93
  Best loss: 0.941489 (iteration 6)

Resuming Bayesian optimization phase...
  Will run 408 more iterations

RESUMING OPTIMIZATION
Already have 92 samples, skipping initial sampling phase
Starting Bayesian optimization from iteration 93

BAYESIAN OPTIMIZATION PHASE

Iteration 94/500
  Valid samples: 11/92
  ✓ GP fitted successfully on 11 valid samples
  ✓ Proposed sample with EI = 6.636841e-01

Testing parameters:
  viscosity_scale: 1.503735e+00
  drag_scale: 2.215525e+00
  eddy_diffusivity: 2.969417e+03
  smagorinsky_coeff: 2.835742e-01
  energy_correction: -5.940144e-03
  enstrophy_correction: 3.768131e-08

Running LowRes_64x32 Simulation
Grid: 64 x 32
Resolution: 31.2 km per 

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 6369 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           13.537600
  Barotropic Streamfunction NRMSE: 2.872482
  Loss: 9.271553

  Status: 16 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 99/500
  Valid samples: 16/97
  ✓ GP fitted successfully on 16 valid samples
  ✓ Proposed sample with EI = 4.922436e-01

Testing parameters:
  viscosity_scale: 4.430997e+00
  drag_scale: 2.190848e+00
  eddy_diffusivity: 5.000774e+03
  smagorinsky_coeff: 1.106622e-03
  energy_correction: -9.541072e-03
  enstrophy_correction: 1.620744e-07

Runn

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7240 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           1.910096
  Barotropic Streamfunction NRMSE: 2.915204
  Loss: 2.312139

  Status: 22 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 105/500
  Valid samples: 22/103
  ✓ GP fitted successfully on 22 valid samples
  ✓ Proposed sample with EI = 4.311820e-01

Testing parameters:
  viscosity_scale: 3.399729e+00
  drag_scale: 2.298944e+00
  eddy_diffusivity: 3.934275e+04
  smagorinsky_coeff: 1.908477e-01
  energy_correction: -1.003039e-03
  enstrophy_correction: 2.335064e-07

Run

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7685 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           181568326200618208.000000
  Barotropic Streamfunction NRMSE: 3778462187106295.500000
  Loss: 110452380595213440.000000

  Status: 31 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 114/500
  Valid samples: 31/112
  ✓ GP fitted successfully on 31 valid samples
  ✓ Proposed sample with EI = 1.296590e+18

Testing parameters:
  viscosity_scale: 1.847974e+00
  drag_scale: 5.474354e-01
  eddy_diffusivity: 5.643650e+04
  smagorinsky_coeff: 1.374988e-01
  energy_correction: -1.385

ABNORMAL: 

You might also want to scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  _check_optimize_result("lbfgs", opt_res)


  ✓ GP fitted successfully on 32 valid samples
  ✓ Proposed sample with EI = 4.815137e+17

Testing parameters:
  viscosity_scale: 1.233088e+00
  drag_scale: 2.976116e+00
  eddy_diffusivity: 6.335602e+04
  smagorinsky_coeff: 7.946411e-03
  energy_correction: -8.317074e-03
  enstrophy_correction: 5.014996e-07

Running LowRes_64x32 Simulation
Grid: 64 x 32
Resolution: 31.2 km per grid point

Subgrid Parameters:
  viscosity_scale: 1.2330876491986777
  drag_scale: 2.9761159510593673
  eddy_diffusivity: 63356.015585161076
  smagorinsky_coeff: 0.007946411482243343
  energy_correction: -0.00831707407940199
  enstrophy_correction: 5.014996236080657e-07

Initial Energy: 5.940e+02
Initial Enstrophy: 8.404e-12

Integrating...

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field 

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 6542 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           2.491390
  Barotropic Streamfunction NRMSE: 3.214971
  Loss: 2.780822

  Status: 36 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 119/500
  Valid samples: 36/117
  ✓ GP fitted successfully on 36 valid samples
  ✓ Proposed sample with EI = 6.532020e+16

Testing parameters:
  viscosity_scale: 2.110982e+00
  drag_scale: 5.915551e-01
  eddy_diffusivity: 6.690980e+04
  smagorinsky_coeff: 1.727164e-01
  energy_correction: 6.773408e-03
  enstrophy_correction: 3.244382e-08

Runn

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 6499 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           2.468200
  Barotropic Streamfunction NRMSE: 2.660230
  Loss: 2.545012

  Status: 43 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 126/500
  Valid samples: 43/124
  ✓ GP fitted successfully on 43 valid samples
  ✓ Proposed sample with EI = 7.721548e+15

Testing parameters:
  viscosity_scale: 2.728149e+00
  drag_scale: 1.191543e+00
  eddy_diffusivity: 2.194633e+04
  smagorinsky_coeff: 1.392955e-01
  energy_correction: 7.927214e-03
  enstrophy_correction: 3.592194e-09

Runn

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7631 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           1.791399
  Barotropic Streamfunction NRMSE: 2.808979
  Loss: 2.198431

  Status: 49 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 132/500
  Valid samples: 49/130
  ✓ GP fitted successfully on 49 valid samples
  ✓ Proposed sample with EI = 6.571408e+15

Testing parameters:
  viscosity_scale: 3.467610e+00
  drag_scale: 1.677730e+00
  eddy_diffusivity: 4.926983e+04
  smagorinsky_coeff: 3.785740e-02
  energy_correction: -2.126133e-03
  enstrophy_correction: 1.952881e-09

Run

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7689 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           15.606466
  Barotropic Streamfunction NRMSE: 2.516124
  Loss: 10.370329

  Status: 54 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 137/500
  Valid samples: 54/135
  ✓ GP fitted successfully on 54 valid samples
  ✓ Proposed sample with EI = 6.019644e+15

Testing parameters:
  viscosity_scale: 1.318166e+00
  drag_scale: 7.398389e-01
  eddy_diffusivity: 2.749491e+04
  smagorinsky_coeff: 5.703587e-02
  energy_correction: 6.497120e-03
  enstrophy_correction: 1.067160e-07

Ru

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 6740 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           2.669866
  Barotropic Streamfunction NRMSE: 2.210568
  Loss: 2.486147

  Status: 56 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 139/500
  Valid samples: 56/137
  ✓ GP fitted successfully on 56 valid samples
  ✓ Proposed sample with EI = 5.873565e+15

Testing parameters:
  viscosity_scale: 2.637424e+00
  drag_scale: 2.649868e+00
  eddy_diffusivity: 7.591774e+04
  smagorinsky_coeff: 1.703757e-01
  energy_correction: 2.553278e-03
  enstrophy_correction: 2.791627e-07

Runn

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 8567 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           1.897902
  Barotropic Streamfunction NRMSE: 2.578314
  Loss: 2.170067

  Status: 65 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 148/500
  Valid samples: 65/146
  ✓ GP fitted successfully on 65 valid samples
  ✓ Proposed sample with EI = 5.247724e+15

Testing parameters:
  viscosity_scale: 4.518013e+00
  drag_scale: 1.125801e+00
  eddy_diffusivity: 5.873542e+04
  smagorinsky_coeff: 2.375688e-01
  energy_correction: -5.505347e-03
  enstrophy_correction: 1.689142e-08

Run

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7650 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           1.804317
  Barotropic Streamfunction NRMSE: 2.805331
  Loss: 2.204723

  Status: 72 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 155/500
  Valid samples: 72/153
  ✓ GP fitted successfully on 72 valid samples
  ✓ Proposed sample with EI = 5.449767e+15

Testing parameters:
  viscosity_scale: 2.712886e+00
  drag_scale: 9.996299e-01
  eddy_diffusivity: 6.705343e+04
  smagorinsky_coeff: 9.683600e-02
  energy_correction: 2.021179e-03
  enstrophy_correction: 3.886681e-07

Runn

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7018 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           8.003413
  Barotropic Streamfunction NRMSE: 2.635749
  Loss: 5.856348

  Status: 76 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_progress.pkl

Iteration 159/500
  Valid samples: 76/157
  ✓ GP fitted successfully on 76 valid samples
  ✓ Proposed sample with EI = 4.733137e+15

Testing parameters:
  viscosity_scale: 1.405678e+00
  drag_scale: 2.364584e+00
  eddy_diffusivity: 4.823912e+04
  smagorinsky_coeff: 2.510694e-01
  energy_correction: -8.294061e-04
  enstrophy_correction: 9.379218e-08

Run

  'q2_skew': np.mean((q2 - np.mean(q2))**3) / (np.std(q2)**3 + 1e-20),
  ret = umr_sum(arr, axis, dtype, out, keepdims, where=where)
  'q2_skew': np.mean((q2 - np.mean(q2))**3) / (np.std(q2)**3 + 1e-20),
  'q2_kurt': np.mean((q2 - np.mean(q2))**4) / (np.std(q2)**4 + 1e-20),
  'q2_kurt': np.mean((q2 - np.mean(q2))**4) / (np.std(q2)**4 + 1e-20),
  'q2_kurt': np.mean((q2 - np.mean(q2))**4) / (np.std(q2)**4 + 1e-20),
  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 8162 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           1235138596996087210922540823862608420320496396734443304545113427917265944374001188657159072715038301436448886473185749368832.000000
  Barotropic Streamfunction NRMSE: 20634221587879821426380622634492230140903228201303740814429498650889290493720039840070383933483375416480517511031526785024.000000
  Loss: 749336846832804261545872868143562176758957636167030913283221825881902429287137479413707650250173793624513487596675982688256.000000

  Status: 80 successful, 81 failed simulations
  Best loss so far: 0.244561 (iteration 96)
  ✓ Progress saved to optimization_pro

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7405 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           3.123874
  Barotropic Streamfunction NRMSE: 2.268293
  Loss: 2.781642

  Status: 92 successful, 81 failed simulations
  Best loss so far: 0.217110 (iteration 167)
  ✓ Progress saved to optimization_progress.pkl

Iteration 175/500
  Valid samples: 92/173
  ✓ GP fitted successfully on 92 valid samples
  ✓ Proposed sample with EI = 3.496686e+121

Testing parameters:
  viscosity_scale: 3.319661e+00
  drag_scale: 8.073336e-01
  eddy_diffusivity: 4.245371e+04
  smagorinsky_coeff: 2.740375e-01
  energy_correction: -3.937158e-03
  enstrophy_correction: 1.077099e-07

R

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 6990 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           977298679.743897
  Barotropic Streamfunction NRMSE: 15976178.169295
  Loss: 592769679.114056

  Status: 100 successful, 81 failed simulations
  Best loss so far: 0.217110 (iteration 167)
  ✓ Progress saved to optimization_progress.pkl

Iteration 183/500
  Valid samples: 100/181
  ✓ GP fitted successfully on 100 valid samples
  ✓ Proposed sample with EI = 6.041970e+121

Testing parameters:
  viscosity_scale: 1.261304e+00
  drag_scale: 2.036684e+00
  eddy_diffusivity: 3.412531e+04
  smagorinsky_coeff: 1.053934e-02
  energy_correction: -4.047612e-03
  enstrophy_c

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7770 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           2.382114
  Barotropic Streamfunction NRMSE: 3.484473
  Loss: 2.823057

  Status: 113 successful, 81 failed simulations
  Best loss so far: 0.217110 (iteration 167)
  ✓ Progress saved to optimization_progress.pkl

Iteration 196/500
  Valid samples: 113/194
  ✓ GP fitted successfully on 113 valid samples
  ✓ Proposed sample with EI = 3.003121e+121

Testing parameters:
  viscosity_scale: 1.599075e+00
  drag_scale: 1.733293e+00
  eddy_diffusivity: 3.594000e+04
  smagorinsky_coeff: 6.286555e-02
  energy_correction: -2.883884e-03
  enstrophy_correction: 7.172008e-10

  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 6360 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           1.751525
  Barotropic Streamfunction NRMSE: 2.557900
  Loss: 2.074075

  Status: 118 successful, 81 failed simulations
  Best loss so far: 0.217110 (iteration 167)
  ✓ Progress saved to optimization_progress.pkl
  ✓ Progress plot saved to optimization_progress.png

Iteration 201/500
  Valid samples: 118/199
  ✓ GP fitted successfully on 118 valid samples
  ✓ Proposed sample with EI = 2.962492e+121

Testing parameters:
  viscosity_scale: 1.817748e+00
  drag_scale: 1.360245e+00
  eddy_diffusivity: 7.000430e+04
  smagorinsky_coeff: 1.099428e-01
  energy_correctio

  'q2_kurt': np.mean((q2 - np.mean(q2))**4) / (np.std(q2)**4 + 1e-20),
  'q2_kurt': np.mean((q2 - np.mean(q2))**4) / (np.std(q2)**4 + 1e-20),
  'q2_kurt': np.mean((q2 - np.mean(q2))**4) / (np.std(q2)**4 + 1e-20),
  jac = ax * by - ay * bx
  jac = ax * by - ay * bx



*** Unstable at step 7586 ***

LowRes_64x32 Simulation Complete!
  High-res grid: 512x256
  Low-res grid:  64x32
  Coarsening factor: 8x in X, 8x in Y
  Averaging over last 30 days:
    High-res: 61 snapshots
    Low-res:  61 snapshots
  High-res field shapes: q1=(256, 512), q2=(256, 512)
  Low-res field shapes:  q1=(32, 64), q2=(32, 64)
  Coarsened high-res shapes: q_bt=(32, 64), psi_bt=(32, 64)
  Barotropic PV NRMSE:           118937615919802614051629454628473254264266190641208972506427715268586055431960022919151616.000000
  Barotropic Streamfunction NRMSE: 2053971529490854195888752440865465816763871529406990256176101114989242622812835966615552.000000
  Loss: 72184158163677912936288477398845066218541264185157804930301304608137663302871070206328832.000000

  Status: 120 successful, 81 failed simulations
  Best loss so far: 0.217110 (iteration 167)
  ✓ Progress saved to optimization_progress.pkl

Iteration 203/500
  Valid samples: 120/201
  ✓ GP fitted successfully on 120 valid sample