In [6]:
def compute_smagorinsky_viscosity(psi1_h, psi2_h, KX, KY, C_smag, dx):
    """
    Compute Smagorinsky eddy viscosity: ν = (C_s Δ)² |S|
    where |S| is the strain rate magnitude.
    """
    # Velocities
    u1 = np.fft.irfft2(1j * KY * psi1_h).real
    v1 = np.fft.irfft2(-1j * KX * psi1_h).real
    
    # Strain rate tensor components
    dudx = np.fft.irfft2(1j * KX * np.fft.rfft2(u1)).real
    dudy = np.fft.irfft2(1j * KY * np.fft.rfft2(u1)).real
    dvdx = np.fft.irfft2(1j * KX * np.fft.rfft2(v1)).real
    dvdy = np.fft.irfft2(1j * KY * np.fft.rfft2(v1)).real
    
    # Strain rate magnitude: |S| = sqrt(2 S_ij S_ij)
    S11 = dudx
    S12 = 0.5 * (dudy + dvdx)
    S22 = dvdy
    S_mag = np.sqrt(2 * (S11**2 + 2*S12**2 + S22**2))
    
    # Smagorinsky viscosity
    nu_smag = (C_smag * dx)**2 * S_mag
    
    return nu_smag

def apply_subgrid_parameterization(q1_h, q2_h, psi1_h, psi2_h, state, cfg, rng):
    """
    Apply subgrid-scale parameterizations to represent missing physics.
    """
    K2 = state['K2']
    KX = state['KX']
    KY = state['KY']
    
    tendency1_h = np.zeros_like(q1_h)
    tendency2_h = np.zeros_like(q2_h)
    
    # 1. SMAGORINSKY EDDY VISCOSITY
    # Mimics dissipation from unresolved turbulent eddies
    if cfg.use_smagorinsky:
        nu_smag = compute_smagorinsky_viscosity(psi1_h, psi2_h, KX, KY, 
                                                cfg.C_smag, cfg.L/cfg.nx)
        # Apply as diffusion: -ν∇²q
        nu_smag_h = np.fft.rfft2(nu_smag)
        tendency1_h += nu_smag_h * (-K2 * q1_h)
        tendency2_h += nu_smag_h * (-K2 * q2_h)
    
    # 2. STOCHASTIC FORCING
    # Represents random fluctuations from unresolved scales
    if cfg.use_stochastic:
        # Band-limited white noise
        noise = rng.standard_normal((cfg.nx, cfg.nx))
        noise_h = np.fft.rfft2(noise)
        
        # Apply only at intermediate scales
        k_mag = np.sqrt(K2)
        kmin = cfg.stochastic_kmin * 2*np.pi / cfg.L
        kmax = cfg.stochastic_kmax * 2*np.pi / cfg.L
        mask = (k_mag >= kmin) & (k_mag <= kmax)
        
        forcing_h = np.zeros_like(noise_h)
        forcing_h[mask] = noise_h[mask] * cfg.stochastic_amp
        
        # Anti-symmetric for baroclinic injection
        tendency1_h += forcing_h
        tendency2_h += -forcing_h
    
    # 3. BACKSCATTER (optional)
    # Returns energy from subgrid to resolved scales
    if cfg.use_backscatter:
        # Negative viscosity at large scales
        k_mag = np.sqrt(K2)
        k_cutoff = 10 * 2*np.pi / cfg.L
        backscatter_mask = k_mag < k_cutoff
        
        backscatter1_h = np.zeros_like(q1_h)
        backscatter2_h = np.zeros_like(q2_h)
        backscatter1_h[backscatter_mask] = cfg.backscatter_amp * K2[backscatter_mask] * q1_h[backscatter_mask]
        backscatter2_h[backscatter_mask] = cfg.backscatter_amp * K2[backscatter_mask] * q2_h[backscatter_mask]
        
        tendency1_h += backscatter1_h
        tendency2_h += backscatter2_h
    
    return tendency1_h, tendency2_h 
    
def coarse_grain(field, factor):
    """Coarse-grain a 2D field by spatial averaging."""
    ny, nx = field.shape
    ny_c, nx_c = ny // factor, nx // factor
    coarse = field[:ny_c*factor, :nx_c*factor].reshape(
        ny_c, factor, nx_c, factor).mean(axis=(1, 3))
    return coarse

def compare_statistics(ds_truth, ds_model):
    """Compare statistics between truth and model."""
    print("\n" + "="*70)
    print("STATISTICAL COMPARISON")
    print("="*70)
    
    # Coarse-graining factor
    factor = ds_truth.dims['x'] // ds_model.dims['x']
    print(f"\nCoarse-graining truth by factor {factor}×{factor} ({ds_truth.dims['x']}→{ds_model.dims['x']})")
    
    # Skip spinup (first 20%)
    spinup_frac = 0.2
    t_start_truth = int(spinup_frac * ds_truth.dims['time'])
    t_start_model = int(spinup_frac * ds_model.dims['time'])
    
    print(f"Using data after {spinup_frac*100:.0f}% spinup\n")
    
    # Get data
    zeta_truth = ds_truth['zeta1'].isel(time=slice(t_start_truth, None)).values
    zeta_model = ds_model['zeta1'].isel(time=slice(t_start_model, None)).values
    psi_truth = ds_truth['psi_bt'].isel(time=slice(t_start_truth, None)).values
    psi_model = ds_model['psi_bt'].isel(time=slice(t_start_model, None)).values
    
    # Coarse-grain truth
    zeta_truth_cg = np.array([coarse_grain(z, factor) for z in zeta_truth])
    psi_truth_cg = np.array([coarse_grain(p, factor) for p in psi_truth])
    
    # 1. VORTICITY STATISTICS
    print("1. VORTICITY STATISTICS")
    print("-" * 70)
    z_truth_mean = zeta_truth_cg.mean()
    z_truth_std = zeta_truth_cg.std()
    z_model_mean = zeta_model.mean()
    z_model_std = zeta_model.std()
    
    print(f"  Truth (coarse):  mean={z_truth_mean:.2e}, std={z_truth_std:.2e}")
    print(f"  Model:           mean={z_model_mean:.2e}, std={z_model_std:.2e}")
    print(f"  Std ratio (model/truth): {z_model_std/z_truth_std:.3f}")
    
    if z_model_std / z_truth_std < 0.7:
        print(f"  ⚠️  Model is TOO DISSIPATIVE - reduce nu4 by ~50%")
    elif z_model_std / z_truth_std > 1.3:
        print(f"  ⚠️  Model is UNDER-DISSIPATIVE - increase nu4 by ~50%")
    else:
        print(f"  ✓ Vorticity variance is reasonably matched")
    
    # 2. ENERGY STATISTICS  
    print("\n2. ENERGY STATISTICS")
    print("-" * 70)
    ke_truth = np.var(psi_truth_cg, axis=(1,2)).mean()
    ke_model = np.var(psi_model, axis=(1,2)).mean()
    
    print(f"  Truth (coarse):  KE proxy={ke_truth:.2e}")
    print(f"  Model:           KE proxy={ke_model:.2e}")
    print(f"  Energy ratio (model/truth): {ke_model/ke_truth:.3f}")
    
    # 3. VORTICITY PDFs
    print("\n3. VORTICITY PROBABILITY DISTRIBUTIONS")
    print("-" * 70)
    z_truth_flat = zeta_truth_cg.flatten()
    z_model_flat = zeta_model.flatten()
    
    # Compute percentiles
    for p in [1, 5, 25, 75, 95, 99]:
        p_truth = np.percentile(z_truth_flat, p)
        p_model = np.percentile(z_model_flat, p)
        print(f"  {p:2d}th percentile: truth={p_truth:.2e}, model={p_model:.2e}")
    
    # 4. SPATIAL SCALES
    print("\n4. SPATIAL STRUCTURE")
    print("-" * 70)
    # Typical vortex size (correlation length estimate)
    autocorr_truth = np.mean([np.corrcoef(z[0, :], z[1, :])[0,1] 
                               for z in zeta_truth_cg[-10:]])
    autocorr_model = np.mean([np.corrcoef(z[0, :], z[1, :])[0,1] 
                               for z in zeta_model[-10:]])
    print(f"  Spatial correlation (row-to-row):")
    print(f"    Truth: {autocorr_truth:.3f}")
    print(f"    Model: {autocorr_model:.3f}")
    
    # 5. TEMPORAL EVOLUTION
    print("\n5. TEMPORAL EVOLUTION")
    print("-" * 70)
    z_truth_ts = np.array([np.abs(z).max() for z in zeta_truth_cg])
    z_model_ts = np.array([np.abs(z).max() for z in zeta_model])
    
    # Growth/decay rate (linear fit to log)
    if len(z_truth_ts) > 10:
        t_truth = np.arange(len(z_truth_ts))
        t_model = np.arange(len(z_model_ts))
        
        # Fit to first half (before saturation)
        n_fit = len(z_truth_ts) // 2
        trend_truth = np.polyfit(t_truth[:n_fit], np.log(z_truth_ts[:n_fit] + 1e-20), 1)[0]
        trend_model = np.polyfit(t_model[:n_fit], np.log(z_model_ts[:n_fit] + 1e-20), 1)[0]
        
        print(f"  Growth/decay rate (first half):")
        print(f"    Truth: {trend_truth:.2e} per snapshot")
        print(f"    Model: {trend_model:.2e} per snapshot")
    
    # SUMMARY
    print("\n" + "="*70)
    print("PARAMETER TUNING RECOMMENDATIONS")
    print("="*70)
    
    vort_ratio = z_model_std / z_truth_std
    energy_ratio = ke_model / ke_truth
    
    print(f"\nCurrent parameters:")
    print(f"  nu4 = {ds_model.attrs['nu4']:.2e} m⁴/s")
    print(f"  rek = {ds_model.attrs['rek']:.2e} s⁻¹")
    
    print(f"\nMismatch summary:")
    print(f"  Vorticity std ratio:  {vort_ratio:.3f} (target: ~1.0)")
    print(f"  Energy ratio:         {energy_ratio:.3f} (target: ~1.0)")
    
    print(f"\nRecommendations:")
    if vort_ratio < 0.8 or energy_ratio < 0.8:
        nu4_new = ds_model.attrs['nu4'] * 0.5
        print(f"  ⚠️  Model too dissipative!")
        print(f"  → Try nu4 = {nu4_new:.2e} (reduce by 50%)")
        if energy_ratio < 0.5:
            print(f"  → Or increase initial condition amplitude")
    elif vort_ratio > 1.2 or energy_ratio > 1.2:
        nu4_new = ds_model.attrs['nu4'] * 2.0
        print(f"  ⚠️  Model under-dissipative!")
        print(f"  → Try nu4 = {nu4_new:.2e} (increase by 100%)")
    else:
        print(f"  ✓ Parameters are reasonably balanced!")
        print(f"  → Ready for ML parameterization experiments")
    
    print("="*70)
    
    return zeta_truth_cg, zeta_model, psi_truth_cg, psi_model
"""
Simple 2-layer QG solver - bypasses pyqg entirely
Uses pseudospectral method for stability
"""
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from dataclasses import dataclass
from tqdm import trange

@dataclass
class Config:
    # Resolution
    nx: int = 64
    L: float = 2.0e6  # meters
    run_mode: str = "truth"  # "truth" or "model"
    
    # Physical parameters
    beta: float = 1.6e-11
    H1: float = 500.0
    H2: float = 4500.0
    rd: float = 4.0e4
    U1: float = 0.0
    U2: float = 0.0
    
    # TUNABLE SUBGRID-SCALE PARAMETERS
    # These represent missing physics at unresolved scales
    nu4: float = 1e10  # hyperviscosity [1e8, 1e11]
    rek: float = 5e-7  # bottom drag [1e-8, 1e-6]
    
    # Smagorinsky eddy viscosity (subgrid turbulence model)
    use_smagorinsky: bool = False
    C_smag: float = 0.2  # Smagorinsky constant [0.1, 0.5]
    
    # Backscatter (energy return from subgrid to resolved)
    use_backscatter: bool = False
    backscatter_amp: float = 0.01  # [0.0, 0.1]
    
    # Stochastic forcing (representing subgrid fluctuations)
    use_stochastic: bool = False
    stochastic_amp: float = 0.001  # [0.0, 0.01]
    stochastic_kmin: int = 8
    stochastic_kmax: int = 16
    
    # Time
    dt: float = 1800.0  # 30 min
    days: float = 30.0
    snap_every: int = 8
    
    # Output
    out_nc: str = "simple_qg.nc"

def setup_resolution(cfg):
    """Adjust parameters based on resolution mode."""
    if cfg.run_mode == "truth":
        cfg.nx = 256  # High resolution (was 512, too slow)
        cfg.dt = 900.0  # 15 min
        cfg.nu4 = 5e8   # WEAK dissipation - let eddies develop
        cfg.rek = 1e-7
        cfg.snap_every = 16  # every 4 hours
        cfg.out_nc = "qg_truth_256.nc"
        
        # Truth: NO subgrid parameterization (resolves scales explicitly)
        cfg.use_smagorinsky = False
        cfg.use_backscatter = False
        cfg.use_stochastic = False
        
        print("="*70)
        print("HIGH-RESOLUTION TRUTH (256×256)")
        print("="*70)
        print(f"  Grid spacing: {cfg.L/cfg.nx/1e3:.1f} km")
        print(f"  Resolves eddies down to: ~{2*cfg.L/cfg.nx/1e3:.1f} km")
        print(f"  Dissipation: MINIMAL (let dynamics develop)")
        print(f"  Subgrid models: DISABLED (resolves explicitly)")
        
    elif cfg.run_mode == "model":
        cfg.nx = 64  # Coarse resolution
        cfg.dt = 1800.0  # 30 min
        cfg.nu4 = 5e9   # Moderate dissipation
        cfg.rek = 3e-7
        cfg.snap_every = 4  # every 2 hours
        cfg.out_nc = "qg_model_64.nc"
        
        # Model: ADD subgrid parameterization (compensate for missing scales)
        cfg.use_smagorinsky = True   # ENABLE Smagorinsky eddy viscosity
        cfg.C_smag = 0.15             # TUNABLE
        cfg.use_backscatter = False   # Optional: energy backscatter
        cfg.backscatter_amp = 0.02
        cfg.use_stochastic = True     # ENABLE stochastic forcing
        cfg.stochastic_amp = 0.005    # TUNABLE
        
        print("="*70)
        print("LOW-RESOLUTION MODEL (64×64) - WITH SUBGRID PARAMETERIZATION")
        print("="*70)
        print(f"  Grid spacing: {cfg.L/cfg.nx/1e3:.1f} km")
        print(f"  Resolves eddies down to: ~{2*cfg.L/cfg.nx/1e3:.1f} km")
        print(f"  Missing scales: {2*cfg.L/cfg.nx/1e3:.0f}-{2*cfg.L/256/1e3:.0f} km")
    
    print(f"\n  TUNABLE SUBGRID PARAMETERS:")
    print(f"    nu4 (hyperviscosity): {cfg.nu4:.2e} m⁴/s [1e8, 1e11]")
    print(f"    rek (bottom drag): {cfg.rek:.2e} s⁻¹ [1e-8, 1e-6]")
    if cfg.use_smagorinsky:
        print(f"    C_smag (Smagorinsky): {cfg.C_smag:.3f} [0.1, 0.5]")
        print(f"      → Mimics eddy viscosity from unresolved turbulence")
    if cfg.use_backscatter:
        print(f"    backscatter_amp: {cfg.backscatter_amp:.4f} [0.0, 0.1]")
        print(f"      → Returns energy from subgrid to resolved")
    if cfg.use_stochastic:
        print(f"    stochastic_amp: {cfg.stochastic_amp:.4f} [0.0, 0.01]")
        print(f"      → Random forcing mimicking subgrid fluctuations")
    print()
    
    return cfg

def initialize_model(cfg):
    """Initialize spectral operators and state."""
    nx = cfg.nx
    L = cfg.L
    dx = L / nx
    
    # Wavenumbers (using rfft for efficiency)
    kx = np.fft.rfftfreq(nx, d=dx) * 2 * np.pi
    ky = np.fft.fftfreq(nx, d=dx) * 2 * np.pi
    KX, KY = np.meshgrid(kx, ky)  # Default indexing for rfft2
    K2 = KX**2 + KY**2
    K2[0, 0] = 1.0  # avoid division by zero
    
    # Deformation wavenumber
    kd = 1.0 / cfg.rd
    F = kd**2
    
    # Dissipation operators - store as multiplicative factors per timestep
    # For RK3, dissipation should be: exp(-nu*k^4*dt) per step
    dissipation_factor = np.exp(-cfg.nu4 * K2**2 * cfg.dt)
    drag_factor = np.exp(-cfg.rek * cfg.dt)
    
    print(f"[INIT] Deformation radius: {cfg.rd/1e3:.0f} km")
    print(f"[INIT] Dissipation at kmax: {dissipation_factor.min():.4f}")
    print(f"[INIT] Bottom drag factor: {drag_factor:.6f}")
    
    # Initialize with simple vortex pair
    y = np.linspace(0, L, nx, endpoint=False)
    x = np.linspace(0, L, nx, endpoint=False)
    X, Y = np.meshgrid(x, y)  # Default indexing for spatial coordinates
    
    sigma = 0.15 * L
    psi1 = 10.0 * (np.exp(-((X-0.3*L)**2 + (Y-0.5*L)**2)/(2*sigma**2)) - 
                   np.exp(-((X-0.7*L)**2 + (Y-0.5*L)**2)/(2*sigma**2)))
    psi2 = -0.6 * psi1
    
    # Convert to spectral PV
    psi1_h = np.fft.rfft2(psi1)
    psi2_h = np.fft.rfft2(psi2)
    
    q1_h = -K2 * psi1_h + F * (psi2_h - psi1_h)
    q2_h = -K2 * psi2_h + F * (psi1_h - psi2_h)
    
    state = {
        'q1_h': q1_h,
        'q2_h': q2_h,
        'KX': KX,
        'KY': KY,
        'K2': K2,
        'F': F,
        'dissipation_factor': dissipation_factor,
        'drag_factor': drag_factor,
        'dx': dx,
    }
    
    print(f"[INIT] Grid: {nx}×{nx}, dx={dx/1e3:.1f} km")
    print(f"[INIT] q1 amplitude: {np.abs(q1_h).max():.2e}")
    print(f"[INIT] Timestep: {cfg.dt/60:.1f} min")
    
    return state

def pv_to_streamfunction(q1_h, q2_h, K2, F):
    """Invert PV to get streamfunction (2-layer QG).
    
    From the 2-layer QG equations:
    q1 = ∇²ψ1 + F(ψ2 - ψ1)
    q2 = ∇²ψ2 + F(ψ1 - ψ2)
    
    In spectral space:
    q1 = -k²ψ1 + F(ψ2 - ψ1)
    q2 = -k²ψ2 + F(ψ1 - ψ2)
    
    Solve the linear system for ψ1, ψ2.
    """
    # Avoid division by zero at k=0
    K2_safe = K2.copy()
    K2_safe[0, 0] = 1.0
    
    # Matrix inversion for 2×2 system
    # [-(k²+F)    F    ] [ψ1]   [q1]
    # [   F   -(k²+F) ] [ψ2] = [q2]
    
    det = (K2_safe + F)**2 - F**2
    det[0, 0] = 1.0  # avoid division by zero
    
    psi1_h = (-(K2_safe + F) * q1_h + F * q2_h) / det
    psi2_h = (F * q1_h - (K2_safe + F) * q2_h) / det
    
    # Set k=0 mode to zero (no mean flow)
    psi1_h[0, 0] = 0.0
    psi2_h[0, 0] = 0.0
    
    return psi1_h, psi2_h

def compute_advection(q_h, psi_h, KX, KY):
    """Compute J(psi, q) = -u*dq/dx - v*dq/dy."""
    # Velocities from streamfunction
    u_h = 1j * KY * psi_h
    v_h = -1j * KX * psi_h
    
    # PV gradients
    dqdx_h = 1j * KX * q_h
    dqdy_h = 1j * KY * q_h
    
    # Transform to physical space
    u = np.fft.irfft2(u_h).real
    v = np.fft.irfft2(v_h).real
    dqdx = np.fft.irfft2(dqdx_h).real
    dqdy = np.fft.irfft2(dqdy_h).real
    
    # Advection in physical space (dealiased)
    adv = -(u * dqdx + v * dqdy)
    
    # Back to spectral
    adv_h = np.fft.rfft2(adv)
    
    return adv_h

def time_step(state, cfg, rng=None):
    """Forward Euler with dissipation and subgrid parameterization."""
    q1_h = state['q1_h']
    q2_h = state['q2_h']
    K2 = state['K2']
    F = state['F']
    KX = state['KX']
    KY = state['KY']
    dt = cfg.dt
    
    if rng is None:
        rng = np.random.default_rng()
    
    # Get streamfunctions
    psi1_h, psi2_h = pv_to_streamfunction(q1_h, q2_h, K2, F)
    
    # Advection for both layers
    adv1_h = compute_advection(q1_h, psi1_h, KX, KY)
    adv2_h = compute_advection(q2_h, psi2_h, KX, KY)
    
    # Subgrid-scale parameterization (only if enabled)
    subgrid1_h, subgrid2_h = apply_subgrid_parameterization(
        q1_h, q2_h, psi1_h, psi2_h, state, cfg, rng)
    
    # Forward Euler step
    q1_h_new = q1_h + dt * (adv1_h + subgrid1_h)
    q2_h_new = q2_h + dt * (adv2_h + subgrid2_h)
    
    # Apply explicit dissipation (hyperviscosity + bottom drag)
    q1_h_new *= state['dissipation_factor']
    q2_h_new *= state['dissipation_factor'] * state['drag_factor']
    
    state['q1_h'] = q1_h_new
    state['q2_h'] = q2_h_new
    
    return state

def run_simulation(cfg):
    """Run the QG simulation."""
    cfg = setup_resolution(cfg)
    state = initialize_model(cfg)
    total_steps = int(cfg.days * 86400 / cfg.dt)
    
    rng = np.random.default_rng(cfg.seed if hasattr(cfg, 'seed') else 42)
    
    times, psi_bt_list, zeta1_list = [], [], []
    
    for step in trange(total_steps, desc=f"QG {cfg.run_mode}"):
        state = time_step(state, cfg, rng)
        
        if (step + 1) % cfg.snap_every == 0:
            # Compute diagnostics
            psi1_h, psi2_h = pv_to_streamfunction(state['q1_h'], state['q2_h'], 
                                                   state['K2'], state['F'])
            
            psi1 = np.fft.irfft2(psi1_h, s=(cfg.nx, cfg.nx)).real
            psi2 = np.fft.irfft2(psi2_h, s=(cfg.nx, cfg.nx)).real
            
            psi_bt = 0.5 * (psi1 + psi2)
            zeta1 = -np.fft.irfft2(state['K2'] * psi1_h, s=(cfg.nx, cfg.nx)).real
            
            times.append((step + 1) * cfg.dt)
            psi_bt_list.append(psi_bt.astype(np.float32))
            zeta1_list.append(zeta1.astype(np.float32))
            
            if len(times) % 20 == 0:
                max_zeta = np.abs(zeta1).max()
                max_psi = np.abs(psi_bt).max()
                print(f"  t={times[-1]/86400:.1f}d, max|ζ|={max_zeta:.2e}, max|ψ|={max_psi:.2e}")
    
    # Save
    attrs = {
        'nx': cfg.nx, 'nu4': cfg.nu4, 'rek': cfg.rek, 'dt': cfg.dt,
        'use_smagorinsky': int(cfg.use_smagorinsky),
        'C_smag': cfg.C_smag if cfg.use_smagorinsky else 0.0,
        'use_stochastic': int(cfg.use_stochastic),
        'stochastic_amp': cfg.stochastic_amp if cfg.use_stochastic else 0.0,
    }
    
    ds = xr.Dataset(
        data_vars=dict(
            psi_bt=(("time","y","x"), np.stack(psi_bt_list)),
            zeta1=(("time","y","x"), np.stack(zeta1_list)),
        ),
        coords=dict(
            time=np.arange(len(times)),
            y=np.linspace(0, cfg.L, cfg.nx, endpoint=False),
            x=np.linspace(0, cfg.L, cfg.nx, endpoint=False),
        ),
        attrs=attrs
    )
    ds.to_netcdf(cfg.out_nc)
    print(f"\n[OK] Saved {cfg.out_nc} with {len(times)} snapshots")
    
    return ds

if __name__ == "__main__":
    # Run both truth and model
    print("\n" + "="*70)
    print("STEP 1: HIGH-RESOLUTION TRUTH")
    print("="*70)
    cfg_truth = Config(run_mode="truth", days=30)
    ds_truth = run_simulation(cfg_truth)
    
    print("\n" + "="*70)
    print("STEP 2: LOW-RESOLUTION MODEL")
    print("="*70)
    cfg_model = Config(run_mode="model", days=30)
    
    ds_model = run_simulation(cfg_model)
    
    # Statistical comparison
    zeta_truth_cg, zeta_model, psi_truth_cg, psi_model = compare_statistics(ds_truth, ds_model)
    
    # Detailed comparison plots
    print("\n" + "="*70)
    print("GENERATING COMPARISON PLOTS")
    print("="*70)
    
    fig = plt.figure(figsize=(16, 10))
    
    # Row 1: Vorticity snapshots
    ax1 = plt.subplot(3, 3, 1)
    im = ax1.pcolormesh(zeta_truth_cg[-1], cmap='RdBu_r', shading='nearest')
    ax1.set_title(f"Truth (coarse-grained to {cfg_model.nx}×{cfg_model.nx})")
    ax1.set_ylabel("Vorticity")
    plt.colorbar(im, ax=ax1)
    
    ax2 = plt.subplot(3, 3, 2)
    im = ax2.pcolormesh(zeta_model[-1], cmap='RdBu_r', shading='nearest')
    ax2.set_title(f"Model ({cfg_model.nx}×{cfg_model.nx})")
    plt.colorbar(im, ax=ax2)
    
    ax3 = plt.subplot(3, 3, 3)
    diff = zeta_model[-1] - zeta_truth_cg[-1]
    vmax_diff = np.percentile(np.abs(diff), 98)
    im = ax3.pcolormesh(diff, cmap='RdBu_r', vmin=-vmax_diff, vmax=vmax_diff, shading='nearest')
    ax3.set_title("Difference (model - truth)")
    plt.colorbar(im, ax=ax3)
    
    # Row 2: PDFs
    ax4 = plt.subplot(3, 3, 4)
    ax4.hist(zeta_truth_cg.flatten(), bins=50, alpha=0.5, label='Truth', density=True)
    ax4.hist(zeta_model.flatten(), bins=50, alpha=0.5, label='Model', density=True)
    ax4.set_xlabel("Vorticity")
    ax4.set_ylabel("PDF")
    ax4.legend()
    ax4.set_title("Vorticity Distribution")
    ax4.set_yscale('log')
    
    ax5 = plt.subplot(3, 3, 5)
    ax5.hist(psi_truth_cg.flatten(), bins=50, alpha=0.5, label='Truth', density=True)
    ax5.hist(psi_model.flatten(), bins=50, alpha=0.5, label='Model', density=True)
    ax5.set_xlabel("Streamfunction")
    ax5.set_ylabel("PDF")
    ax5.legend()
    ax5.set_title("Streamfunction Distribution")
    
    # Row 2, col 3: Q-Q plot
    ax6 = plt.subplot(3, 3, 6)
    quantiles = np.linspace(0.01, 0.99, 50)
    q_truth = np.quantile(zeta_truth_cg.flatten(), quantiles)
    q_model = np.quantile(zeta_model.flatten(), quantiles)
    ax6.scatter(q_truth, q_model, alpha=0.5, s=20)
    lims = [min(q_truth.min(), q_model.min()), max(q_truth.max(), q_model.max())]
    ax6.plot(lims, lims, 'k--', alpha=0.5, label='Perfect match')
    ax6.set_xlabel("Truth quantiles")
    ax6.set_ylabel("Model quantiles")
    ax6.set_title("Q-Q Plot (Vorticity)")
    ax6.legend()
    ax6.grid(True, alpha=0.3)
    
    # Row 3: Time series
    ax7 = plt.subplot(3, 3, 7)
    zeta_truth_ts = [np.abs(z).max() for z in zeta_truth_cg]
    zeta_model_ts = [np.abs(z).max() for z in zeta_model]
    time_truth = np.arange(len(zeta_truth_ts)) * cfg_truth.snap_every * cfg_truth.dt / 86400
    time_model = np.arange(len(zeta_model_ts)) * cfg_model.snap_every * cfg_model.dt / 86400
    ax7.plot(time_truth, zeta_truth_ts, label='Truth', alpha=0.7, linewidth=2)
    ax7.plot(time_model, zeta_model_ts, label='Model', alpha=0.7, linewidth=2)
    ax7.set_xlabel("Time (days)")
    ax7.set_ylabel("max|ζ|")
    ax7.legend()
    ax7.grid(True, alpha=0.3)
    ax7.set_title("Vorticity Evolution")
    
    ax8 = plt.subplot(3, 3, 8)
    ke_truth_ts = np.var(psi_truth_cg, axis=(1,2))
    ke_model_ts = np.var(psi_model, axis=(1,2))
    ax8.plot(time_truth, ke_truth_ts, label='Truth', alpha=0.7, linewidth=2)
    ax8.plot(time_model, ke_model_ts, label='Model', alpha=0.7, linewidth=2)
    ax8.set_xlabel("Time (days)")
    ax8.set_ylabel("KE proxy")
    ax8.legend()
    ax8.grid(True, alpha=0.3)
    ax8.set_title("Energy Evolution")
    
    # Row 3, col 3: Ratio time series
    ax9 = plt.subplot(3, 3, 9)
    # Interpolate to common times for ratio
    ratio = np.interp(time_truth, time_model, zeta_model_ts) / np.array(zeta_truth_ts)
    ax9.plot(time_truth, ratio, linewidth=2)
    ax9.axhline(y=1.0, color='k', linestyle='--', alpha=0.5, label='Perfect match')
    ax9.fill_between(time_truth, 0.8, 1.2, alpha=0.2, color='green', label='Acceptable range')
    ax9.set_xlabel("Time (days)")
    ax9.set_ylabel("Ratio (model/truth)")
    ax9.set_title("Vorticity Ratio Evolution")
    ax9.legend()
    ax9.grid(True, alpha=0.3)
    ax9.set_ylim([0.5, 1.5])
    
    plt.tight_layout()
    plt.savefig("qg_detailed_comparison.png", dpi=150, bbox_inches='tight')
    print("[OK] Saved qg_detailed_comparison.png")
    plt.show()
    
    print("\n" + "="*70)
    print("NEXT STEPS FOR ML PARAMETERIZATION")
    print("="*70)
    print("\n1. Files generated:")
    print(f"   - {cfg_truth.out_nc} (high-resolution truth)")
    print(f"   - {cfg_model.out_nc} (low-resolution model)")
    print(f"   - qg_detailed_comparison.png (analysis)")
    
    print("\n2. Key differences despite same IC:")
    print(f"   - Truth resolves {cfg_truth.nx//cfg_model.nx}× more scales")
    print(f"   - Missing scales: {2*cfg_model.L/cfg_model.nx/1e3:.0f}-{2*cfg_truth.L/cfg_truth.nx/1e3:.0f} km")
    print(f"   - These missing eddies affect large-scale flow!")
    
    print("\n3. For ML parameterization:")
    print("   a) Input:  coarse-grid state (64×64)")
    print("   b) Target: effect of missing scales on resolved flow")
    print("   c) Method: NN learns correction = truth_tendency - model_tendency")
    
    print("\n4. Parameters to tune (if not using ML):")
    print("   - nu4: currently {:.2e} → try range [1e9, 1e12]".format(cfg_model.nu4))
    print("   - rek: currently {:.2e} → try range [1e-8, 1e-6]".format(cfg_model.rek))
    print("="*70)


STEP 1: HIGH-RESOLUTION TRUTH
HIGH-RESOLUTION TRUTH (256×256)
  Grid spacing: 7.8 km
  Resolves eddies down to: ~15.6 km
  Dissipation: MINIMAL (let dynamics develop)
  Subgrid models: DISABLED (resolves explicitly)

  TUNABLE SUBGRID PARAMETERS:
    nu4 (hyperviscosity): 5.00e+08 m⁴/s [1e8, 1e11]
    rek (bottom drag): 1.00e-07 s⁻¹ [1e-8, 1e-6]

[INIT] Deformation radius: 40 km
[INIT] Dissipation at kmax: 0.0000
[INIT] Bottom drag factor: 0.999910
[INIT] Grid: 256×256, dx=7.8 km
[INIT] q1 amplitude: 1.30e+02
[INIT] Timestep: 15.0 min


QG truth:  11%|█▏        | 328/2880 [00:04<00:32, 77.83it/s]

  t=3.3d, max|ζ|=2.52e-08, max|ψ|=2.28e-01


QG truth:  22%|██▎       | 648/2880 [00:08<00:28, 78.34it/s]

  t=6.7d, max|ζ|=1.66e-08, max|ψ|=2.54e-01


QG truth:  34%|███▎      | 970/2880 [00:12<00:25, 75.23it/s]

  t=10.0d, max|ζ|=1.55e-08, max|ψ|=3.58e-01


QG truth:  45%|████▍     | 1290/2880 [00:16<00:21, 74.39it/s]

  t=13.3d, max|ζ|=1.42e-08, max|ψ|=4.59e-01


QG truth:  56%|█████▌    | 1610/2880 [00:21<00:17, 74.50it/s]

  t=16.7d, max|ζ|=1.31e-08, max|ψ|=5.57e-01


QG truth:  67%|██████▋   | 1930/2880 [00:25<00:12, 73.18it/s]

  t=20.0d, max|ζ|=1.22e-08, max|ψ|=6.52e-01


QG truth:  78%|███████▊  | 2250/2880 [00:29<00:08, 74.55it/s]

  t=23.3d, max|ζ|=1.14e-08, max|ψ|=7.45e-01


QG truth:  89%|████████▉ | 2570/2880 [00:34<00:04, 74.50it/s]

  t=26.7d, max|ζ|=1.07e-08, max|ψ|=8.35e-01


QG truth: 100%|██████████| 2880/2880 [00:38<00:00, 75.02it/s]


  t=30.0d, max|ζ|=1.02e-08, max|ψ|=9.22e-01


getfattr: Removing leading '/' from absolute path names



[OK] Saved qg_truth_256.nc with 180 snapshots

STEP 2: LOW-RESOLUTION MODEL


TypeError: Config.__init__() got an unexpected keyword argument 'forcing_type'