In [31]:
import numpy as np
import matplotlib.pyplot as plt
from numba import njit
import os

In [32]:
def total_water(state):
    """Summarise total water qt = qv + qc over the whole grid.

    Parameters
    ----------
    state : mapping with array entries 'qv' and 'qc'.

    Returns
    -------
    tuple of three Python floats: (sum of qt, minimum of qt, maximum of qt).
    """
    total = state['qv'] + state['qc']
    reducers = (total.sum, total.min, total.max)
    return tuple(float(reduce()) for reduce in reducers)


In [33]:
def plot_state(state, step, params):
    """Redraw the current figure with the centre-y vertical slice of the state.

    The background colour is total potential temperature (theta0 + theta_p)
    mapped through the coolwarm colormap; cells with qc > 1e-4 are painted
    yellow; (u, w) arrows are overlaid on every third cell.

    Parameters
    ----------
    state  : mapping with 3-D arrays 'u', 'w', 'qc', 'theta_p' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with 'theta0', 'dx', 'dz'.
    """
    u_field, w_field = state['u'], state['w']
    cloud = state['qc']
    theta_total = params['theta0'] + state['theta_p']

    nx, ny, nz = u_field.shape
    mid_y = ny // 2
    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x_km = (np.arange(nx) + 0.5)*dx / 1000.0
    z_m = (np.arange(nz) + 0.5)*dz

    temp_xz = theta_total[:, mid_y, :]
    u_xz = u_field[:, mid_y, :]
    w_xz = w_field[:, mid_y, :]

    # Colour every cell by temperature, then overwrite cloudy cells with yellow.
    scale = plt.Normalize(vmin=temp_xz.min(), vmax=temp_xz.max())
    pixels = plt.cm.coolwarm(scale(temp_xz))
    pixels[cloud[:, mid_y, :] > 1e-4] = (1.0, 1.0, 0.0, 1.0)

    plt.clf()
    plt.gcf().set_size_inches(6, 4)
    bounds = [x_km.min(), x_km.max(), z_m.min(), z_m.max()]
    # imshow wants (rows=z, cols=x, rgba), hence the axis swap.
    plt.imshow(np.transpose(pixels, (1, 0, 2)), origin='lower', aspect='auto', extent=bounds)

    # Thin the arrows so the plot stays readable.
    every = slice(None, None, 3)
    grid_x, grid_z = np.meshgrid(x_km[every], z_m[every], indexing='ij')
    plt.quiver(grid_x, grid_z, u_xz[every, every], w_xz[every, every],
               scale=20, width=0.002, color='k')

    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    plt.title(f"Step {step} | center slice | yellow = qc")
    plt.tight_layout()
    plt.pause(0.001)

def save_state(state, step, params, out_dir="/content/stateImages", skip=1):
    """Render the centre-y vertical slice of the state and save it as a PNG.

    The image shows total potential temperature (theta0 + theta_p) through the
    coolwarm colormap, cells with qc > 1e-4 painted yellow, and (u, w) arrows.
    The file is written to ``<out_dir>/state_<step:05d>.png`` at 150 dpi.

    Unlike the interactive plot helpers, this function draws on its own
    figure and always closes it, so it neither clobbers the caller's current
    figure nor leaks figures when saving fails.

    Parameters
    ----------
    state   : mapping with 3-D arrays 'u', 'w', 'qc', 'theta_p' indexed (x, y, z).
    step    : integer step number, used in the title and the file name.
    params  : mapping with 'theta0', 'dx', 'dz'.
    out_dir : directory for the image (created if missing). The default keeps
              the original hard-coded location.
    skip    : stride between plotted arrows in both directions (default 1,
              i.e. one arrow per cell, as before). Values below 1 are treated as 1.
    """
    u = state['u']; w = state['w']
    qc = state['qc']
    theta_p = state['theta_p']
    theta0 = params['theta0']

    nx, ny, nz = u.shape
    j = ny // 2
    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x = (np.arange(nx) + 0.5)*dx / 1000.0
    z = (np.arange(nz) + 0.5)*dz

    T_slice = (theta0 + theta_p)[:, j, :]
    u_slice = u[:, j, :]
    w_slice = w[:, j, :]
    qc_slice = qc[:, j, :]

    # Colour by temperature, then overwrite cloudy cells with opaque yellow.
    norm = plt.Normalize(vmin=T_slice.min(), vmax=T_slice.max())
    rgba = plt.cm.coolwarm(norm(T_slice))
    rgba[qc_slice > 1e-4] = (1.0, 1.0, 0.0, 1.0)

    skip = max(1, int(skip))
    Xq, Zq = np.meshgrid(x[::skip], z[::skip], indexing='ij')

    fig = plt.figure(figsize=(6, 4))
    try:
        ax = fig.add_subplot(111)
        extent = [x.min(), x.max(), z.min(), z.max()]
        # imshow wants (rows=z, cols=x, rgba), hence the axis swap.
        ax.imshow(rgba.transpose(1, 0, 2), origin='lower', aspect='auto', extent=extent)
        ax.quiver(Xq, Zq, u_slice[::skip, ::skip], w_slice[::skip, ::skip],
                  scale=20, width=0.002, color='k')

        ax.set_xlabel('x (km)')
        ax.set_ylabel('z (m)')
        ax.set_title(f"Step {step} | center slice | yellow = qc")
        fig.tight_layout()

        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, f"state_{step:05d}.png")
        fig.savefig(save_path, dpi=150)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

def plot_qv(state, step, params):
    """Redraw the current figure with the centre-y slice of water vapour qv.

    Parameters
    ----------
    state  : mapping with a 3-D array 'qv' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with grid spacings 'dx' and 'dz'.
    """
    vapor = state['qv']
    nx, ny, nz = vapor.shape
    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x_km = (np.arange(nx) + 0.5) * dx / 1000.0
    z_m = (np.arange(nz) + 0.5) * dz
    bounds = [x_km.min(), x_km.max(), z_m.min(), z_m.max()]

    plt.clf()
    plt.gcf().set_size_inches(6, 4)

    # Transpose so that z runs along the image rows.
    image = plt.imshow(vapor[:, ny // 2, :].T,
                       origin='lower', aspect='auto',
                       extent=bounds, cmap=plt.cm.Blues)
    plt.colorbar(image).set_label('qv (kg/kg)')

    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    plt.title(f"Step {step} | center slice | water vapor mixing ratio qv")
    plt.tight_layout()
    plt.pause(0.001)

def plot_qc(state, step, params):
    """Redraw the current figure with the centre-y slice of cloud water qc.

    Parameters
    ----------
    state  : mapping with a 3-D array 'qc' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with grid spacings 'dx' and 'dz'.
    """
    cloud = state['qc']
    nx, ny, nz = cloud.shape
    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x_km = (np.arange(nx) + 0.5) * dx / 1000.0
    z_m = (np.arange(nz) + 0.5) * dz
    bounds = [x_km.min(), x_km.max(), z_m.min(), z_m.max()]

    plt.clf()
    plt.gcf().set_size_inches(6, 4)

    # Transpose so that z runs along the image rows.
    image = plt.imshow(cloud[:, ny // 2, :].T,
                       origin='lower', aspect='auto',
                       extent=bounds, cmap=plt.cm.Greens)
    plt.colorbar(image).set_label('qc (kg/kg)')

    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    plt.title(f"Step {step} | center slice | cloud water mixing ratio qc")
    plt.tight_layout()
    plt.pause(0.001)
def plot_theta_p(state, step, params):
    """Redraw the current figure with the centre-y slice of the perturbation θ'.

    Only the perturbation field state['theta_p'] is plotted (the background
    theta0 is NOT added), and the colorbar and title say so. Previously the
    labels claimed full potential temperature, which misrepresented the data.

    Parameters
    ----------
    state  : mapping with a 3-D array 'theta_p' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with grid spacings 'dx' and 'dz'.
    """
    theta_p = state['theta_p']

    nx, ny, nz = theta_p.shape
    j = ny // 2  # center y-slice

    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x = (np.arange(nx) + 0.5) * dx / 1000.0
    z = (np.arange(nz) + 0.5) * dz

    # Perturbation only; add params['theta0'] here if the full field is wanted.
    theta_p_slice = theta_p[:, j, :]

    plt.clf()
    fig = plt.gcf()
    fig.set_size_inches(6, 4)

    extent = [x.min(), x.max(), z.min(), z.max()]

    im = plt.imshow(
        theta_p_slice.T,  # transpose so z runs along image rows
        origin='lower',
        aspect='auto',
        extent=extent,
        cmap=plt.cm.coolwarm
    )

    cbar = plt.colorbar(im)
    cbar.set_label("θ' (K)")

    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    plt.title(f"Step {step} | center slice | potential temperature perturbation θ'")
    plt.tight_layout()
    plt.pause(0.001)
def plot_bottom_wind(state, step, params):
    """Redraw the current figure with horizontal wind arrows at level k = 0.

    If the largest horizontal speed in that level is below 1e-8 a text note is
    shown instead of arrows. The title reports the maximum speed.

    Parameters
    ----------
    state  : mapping with 3-D arrays 'u' and 'v' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with grid spacings 'dx' and 'dy'.
    """
    u_all = state['u']
    v_all = state['v']
    nx, ny, nz = u_all.shape
    dx, dy = params['dx'], params['dy']

    # Cell-centre coordinates in km.
    x_km = (np.arange(nx) + 0.5) * dx / 1000.0
    y_km = (np.arange(ny) + 0.5) * dy / 1000.0

    level = 0  # lowest model level
    u_bot = u_all[:, :, level]
    v_bot = v_all[:, :, level]
    max_speed = np.sqrt(u_bot**2 + v_bot**2).max()

    plt.clf()
    plt.gcf().set_size_inches(6, 4)

    grid_x, grid_y = np.meshgrid(x_km, y_km, indexing='ij')
    every = slice(None, None, 2)

    calm = max_speed < 1e-8
    if not calm:
        # Arrow length is in data units so that weak winds remain visible.
        plt.quiver(grid_x[every, every], grid_y[every, every],
                   u_bot[every, every], v_bot[every, every],
                   scale_units='xy', scale=1, width=0.002, color='k')
    else:
        plt.text(0.5, 0.5, "No significant bottom-layer wind",
                 ha='center', va='center', transform=plt.gca().transAxes)

    plt.xlabel('x (km)')
    plt.ylabel('y (km)')
    plt.title(f"Step {step} | bottom layer horizontal wind (max={max_speed:.3f} m/s)")
    plt.tight_layout()
    plt.pause(0.001)
def plot_pi(state, step, params):
    """Redraw the current figure with the centre-y slice of the field state['pi'].

    Parameters
    ----------
    state  : mapping with a 3-D array 'pi' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with grid spacings 'dx' and 'dz'.
    """
    exner_pert = state['pi']
    nx, ny, nz = exner_pert.shape
    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x_km = (np.arange(nx) + 0.5) * dx / 1000.0
    z_m = (np.arange(nz) + 0.5) * dz
    bounds = [x_km.min(), x_km.max(), z_m.min(), z_m.max()]

    plt.clf()
    plt.gcf().set_size_inches(6, 4)

    # Diverging colormap; transpose so z runs along the image rows.
    image = plt.imshow(exner_pert[:, ny // 2, :].T,
                       origin='lower', aspect='auto',
                       extent=bounds, cmap=plt.cm.PuOr)
    plt.colorbar(image).set_label('π (Exner perturbation)')

    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    plt.title(f"Step {step} | center slice | Exner perturbation π")
    plt.tight_layout()
    plt.pause(0.001)
def plot_relative_humidity(state, step, params):
    """Redraw the current figure with the centre-y slice of relative humidity.

    RH = qv / qs(T, p0), where T = (theta0 + theta_p) * exner(p0); RH is set
    to 0 where qs <= 0. If any cell has qc > 0, the min/max RH over those
    cells is printed as a diagnostic. A red contour marks RH = 1.

    The contour is drawn AFTER the figure is cleared and the image is shown;
    previously it was issued before plt.clf() and therefore never appeared.

    Parameters
    ----------
    state  : mapping with 3-D arrays 'qv', 'theta_p', 'qc' indexed (x, y, z).
    step   : value shown in the title.
    params : mapping with 'theta0', 'p0', 'dx', 'dz'.
    """
    qv      = state['qv']
    theta_p = state['theta_p']
    theta0  = params['theta0']
    p0      = params['p0']

    nx, ny, nz = qv.shape
    j = ny // 2  # center y-slice

    dx, dz = params['dx'], params['dz']

    # Cell-centre coordinates: x in km, z in m.
    x = (np.arange(nx) + 0.5) * dx / 1000.0
    z = (np.arange(nz) + 0.5) * dz

    # Temperature from total potential temperature and the Exner function.
    theta = theta0 + theta_p
    Pi = exner(p0)
    T  = theta * Pi
    qs = saturation_mixing_ratio(T, p0)

    # Relative humidity, guarded against non-positive qs.
    RH = np.zeros_like(qv)
    valid = qs > 0.0
    RH[valid] = qv[valid] / qs[valid]

    # Diagnostic: RH range over cells that currently hold cloud water.
    cloudy = state["qc"] > 0
    if cloudy.sum() > 0:
        RH_cloud = (RH[cloudy].min(), RH[cloudy].max())
        print("RH over cloudy cells:", RH_cloud)

    RH_slice = RH[:, j, :]

    plt.clf()
    fig = plt.gcf()
    fig.set_size_inches(6, 4)

    extent = [x.min(), x.max(), z.min(), z.max()]

    im = plt.imshow(
        RH_slice.T,
        origin='lower',
        aspect='auto',
        extent=extent,
        cmap=plt.cm.Blues,
        vmin=0.0,
        vmax=1.0,  # colour scale stops at 100%; supersaturated cells show as the darkest colour
    )

    # Saturation line on top of the image. Skipped when RH = 1 is not inside
    # the data range, which would only produce a "no contour levels" warning.
    if RH_slice.min() <= 1.0 <= RH_slice.max():
        plt.contour(x, z, RH_slice.T, levels=[1.0], colors='red')

    cbar = plt.colorbar(im)
    cbar.set_label('Relative Humidity (RH)')

    plt.xlabel('x (km)')
    plt.ylabel('z (m)')
    plt.title(f"Step {step} | center slice | Relative Humidity")
    plt.tight_layout()
    plt.pause(0.001)


In [34]:
def compute_dt(state, params, cfl=0.4, dt_max=0.5):
    """Pick a conservative time step.

    The result is the smallest of:
      - the advective limit  cfl / max(|u|/dx, |v|/dy, |w|/dz, 1e-8),
      - the diffusive limit  0.25 * min(dx, dy, dz)**2 / nu  (only when
        params contains 'nu' > 0),
      - the hard cap dt_max.

    Parameters
    ----------
    state  : mapping with velocity arrays 'u', 'v', 'w'.
    params : mapping with 'dx', 'dy', 'dz' and optionally 'nu'.
    cfl    : target Courant number.
    dt_max : upper bound on the returned step.
    """
    spacings = (params['dx'], params['dy'], params['dz'])

    # Start from a tiny floor so a motionless field never divides by zero.
    rates = [1e-8]
    for component, h in zip(('u', 'v', 'w'), spacings):
        peak = float(np.max(np.abs(state[component])))
        if peak > 0:
            rates.append(peak / h)

    candidates = [cfl / max(rates)]

    viscosity = params.get('nu', 0.0)
    if viscosity > 0.0:
        h_small = min(spacings)
        candidates.append(0.25 * h_small * h_small / viscosity)

    candidates.append(dt_max)
    return min(candidates)


In [35]:
def _zeros_like_state(state):
    """Return a dict of zero arrays matching the prognostic fields' shapes."""
    return {
        'u': np.zeros_like(state['u']),
        'v': np.zeros_like(state['v']),
        'w': np.zeros_like(state['w']),
        'theta_p': np.zeros_like(state['theta_p']),
        'qv': np.zeros_like(state['qv']),
        'qc': np.zeros_like(state['qc']),
    }


def _copy_state(state):
    """Deep copy of the prognostic fields + pi."""
    return {
        'u': state['u'].copy(),
        'v': state['v'].copy(),
        'w': state['w'].copy(),
        'theta_p': state['theta_p'].copy(),
        'qv': state['qv'].copy(),
        'qc': state['qc'].copy(),
        'pi': state.get('pi', np.zeros_like(state['u'])).copy()
    }


In [36]:
# 2nd-order upwind advection for scalars and velocity + wiring into compute_rhs + tests.
# Periodic boundaries in all directions. Returns RHS tendencies (no time integration inside).
#
# What this does:
#   Advection computes the conservative form:  ∂φ/∂t = -∇·(φ U)
#   We build face fluxes Fx,Fy,Fz using a 2nd-order upwind interpolation of φ to faces,
#   with face velocities from centered averages. Then take the flux divergence.
#
# Note:
#   - This is a pure RHS operator. You still use RK2 (or any scheme) outside to advance in time.
#   - Periodic in x,y,z via np.roll.
#
import numpy as np

def _face_average(a, axis):
    """Centered average to faces along 'axis' with periodic wrap: a_{i+1/2} = 0.5*(a_i + a_{i+1})."""
    return 0.5 * (a + np.roll(a, -1, axis=axis))

def _upwind2_to_faces(phi, vel_faces, axis):
    """
    2nd-order upwind interpolation of phi to faces along 'axis', using periodic BCs.
    For each face, choose the upwind-biased stencil based on sign of face velocity.
      If u_face > 0:  phi_{i+1/2} ≈ (3/2) phi_i   - (1/2) phi_{i-1}
      If u_face < 0:  phi_{i+1/2} ≈ (3/2) phi_{i+1} - (1/2) phi_{i+2}
    """
    phi_i   = phi
    phi_im1 = np.roll(phi,  1, axis=axis)
    phi_ip1 = np.roll(phi, -1, axis=axis)
    phi_ip2 = np.roll(phi, -2, axis=axis)

    # candidate reconstructions
    phi_face_pos = 1.5*phi_i   - 0.5*phi_im1
    phi_face_neg = 1.5*phi_ip1 - 0.5*phi_ip2

    # select by velocity sign
    return np.where(vel_faces > 0.0, phi_face_pos, phi_face_neg)

In [37]:
# -----------------
# Laplacian (periodic) and diffusion operators
# -----------------
def laplacian_periodic(phi, dx, dy, dz):
    """
    Periodic 3D Laplacian of phi using the centred 2nd-order stencil
    (phi_{i+1} - 2 phi_i + phi_{i-1}) / h^2 along each of the three axes.
    """
    def _second_difference(axis, h):
        # np.roll supplies the wrapped-around neighbours.
        ahead = np.roll(phi, -1, axis=axis)
        behind = np.roll(phi, 1, axis=axis)
        return (ahead - 2.0*phi + behind) / (h*h)

    return (_second_difference(0, dx)
            + _second_difference(1, dy)
            + _second_difference(2, dz))

def diffuse_scalar(phi, kappa, dx, dy, dz):
    """
    Diffusion tendency of a scalar: kappa * Laplacian(phi), periodic grid.

    When kappa is exactly zero the Laplacian is skipped and a zero array
    shaped like phi is returned.
    """
    if kappa != 0.0:
        return kappa * laplacian_periodic(phi, dx, dy, dz)
    return np.zeros_like(phi)

def diffuse_velocity(u, v, w, nu, dx, dy, dz):
    """
    Diffusion tendencies of the three velocity components on a periodic grid:
    (nu * Laplacian(u), nu * Laplacian(v), nu * Laplacian(w)).

    When nu is exactly zero, three independent zero arrays shaped like u are
    returned without evaluating any Laplacian.
    """
    if nu == 0.0:
        blank = np.zeros_like(u)
        return blank, blank.copy(), blank.copy()

    lap_u, lap_v, lap_w = (laplacian_periodic(comp, dx, dy, dz) for comp in (u, v, w))
    return nu * lap_u, nu * lap_v, nu * lap_w

In [38]:
def compute_buoyancy(theta_p, qv, qv_bg, qc, theta0, g=9.81):
    """
    Moist buoyancy:
      b = g * ( theta'/theta0 + 0.61*(qv - qv_bg) - qc )
    Inputs:
      theta_p: potential temperature perturbation θ' (K)
      qv:      water vapor mixing ratio (kg/kg)
      qv_bg:   background water vapor mixing ratio subtracted from qv (kg/kg)
      qc:      cloud liquid water mixing ratio (kg/kg)
      theta0:  background potential temperature θ0, broadcastable against theta_p (K)
      g:       gravity (m/s^2)
    Output:
      b: buoyancy (m/s^2), broadcast shape of the inputs
    """
    thermal = theta_p / theta0
    vapor_excess = qv - qv_bg          # only the vapor anomaly contributes
    return g * (thermal + 0.61*vapor_excess - qc)

def theta_rhs_extra(w, dtheta0_dz):
    """
    θ' tendency from vertical motion through the background stratification:
      ∂θ'/∂t |_bg = - w * dθ0/dz
    Inputs:
      w: vertical velocity (m/s)
      dtheta0_dz: vertical gradient of background θ0 (K/m)
    Output:
      rhs_theta_extra: broadcast shape of the inputs

    With dθ0/dz > 0, rising air (w > 0) gets a negative θ' tendency and
    sinking air a positive one.
    """
    downward = -w
    return downward * dtheta0_dz

# ---- Wiring helper: how to add into compute_rhs ----
def add_buoyancy_and_theta_to_rhs(state, params, rhs):
    """
    Add the buoyancy and stratification terms to an existing tendency dict.

    Updates rhs in place (nothing is returned):
      rhs['w']       += compute_buoyancy(θ', qv, qv_bg, qc, θ0, g)
      rhs['theta_p'] += theta_rhs_extra(w, dθ0/dz)
    Reads from state: 'theta_p', 'qv', 'qc', 'w'.
    Reads from params:
      theta0     : background θ0 (broadcastable to the fields)
      dtheta0_dz : background gradient dθ0/dz (broadcastable to the fields)
      qv_bg      : background water vapor mixing ratio
      g          : gravity, optional (defaults to 9.81)
    """
    gravity = params.get('g', 9.81)
    theta_bg = params['theta0']
    strat = params['dtheta0_dz']

    buoyancy = compute_buoyancy(
        state['theta_p'], state['qv'], params['qv_bg'], state['qc'],
        theta_bg, g=gravity,
    )
    # Removing the domain-mean buoyancy is intentionally left disabled here.
    rhs['w'] += buoyancy
    rhs['theta_p'] += theta_rhs_extra(state['w'], strat)

In [39]:
# --- Thermo helpers ---
def exner(p, p_ref=1.0e5, R=287.0, cp=1004.0):
    """Exner function Π = (p / p_ref) ** (R / cp); p and p_ref share the same units."""
    pressure_ratio = p / p_ref
    kappa = R / cp
    return pressure_ratio ** kappa

def saturation_vapor_pressure_water(T):
    """
    Saturation vapor pressure over liquid water, in Pa, for T in K.

    Bolton (1980)-style Magnus formula evaluated in Celsius:
      es(hPa) = 6.112 * exp(17.67 * Tc / (Tc + 243.5)),   Tc = T - 273.15
    (Tc + 243.5 is the same quantity as T - 29.65 in the Kelvin form.)
    The result is multiplied by 100 to convert hPa to Pa.
    """
    celsius = T - 273.15
    exponent = 17.67 * celsius / (celsius + 243.5)
    hectopascal = 6.112 * np.exp(exponent)
    return hectopascal * 100.0

def saturation_mixing_ratio(T, p, epsilon=0.622):
    """
    Saturation mixing ratio q_s = epsilon * es / (p - es), in kg/kg.
    T in K, p in Pa.

    es is capped at 0.99 * p first, so the denominator is never smaller
    than 1% of p.
    """
    e_sat = np.minimum(saturation_vapor_pressure_water(T), 0.99 * p)
    dry_pressure = p - e_sat
    return epsilon * e_sat / dry_pressure


In [40]:
@njit
def exner_jit(p, p_ref=1.0e5, R=287.0, cp=1004.0):
    """Numba-compiled Exner function Π = (p / p_ref) ** (R / cp)."""
    exponent = R / cp
    return (p / p_ref) ** exponent

@njit
def saturation_mixing_ratio_jit(T, p, epsilon=0.622):
    """
    Scalar, njit-friendly saturation mixing ratio q_s = epsilon * es / (p - es).

    T in K, p in Pa, result in kg/kg. es comes from the same Bolton/Magnus
    formula as saturation_vapor_pressure_water, and is capped at 0.99 * p
    exactly like saturation_mixing_ratio, so the two versions agree and the
    denominator p - es can never reach zero or turn negative.
    """
    Tc = T - 273.15
    # Bolton/Magnus-type over liquid water (hPa), converted to Pa.
    es = 6.112 * np.exp(17.67 * Tc / (Tc + 243.5)) * 100.0

    # Safety clamp, mirroring the NumPy version of this helper.
    es_cap = 0.99 * p
    if es > es_cap:
        es = es_cap

    return epsilon * es / (p - es)

@njit
def microphysics_saturation_adjust(theta_p, qv, qc, theta0, p0,
                                   Lv=2.5e6, cp=1004.0,
                                   p_ref=1.0e5, eps=0.622,
                                   tol=1e-12, max_iter=6,
                                   qc_crit=1e-4,       # autoconversion threshold (kg/kg)
                                   rain_frac=0.3):     # fraction of excess qc removed per call
    """
    Numba-jitted per-cell saturation adjustment followed by a simple rain-out.

    For every cell (i, j, k):
      1. Π = exner_jit(p0) and c0 = Lv / (cp * Π) are evaluated locally.
      2. Two quantities are held fixed during the adjustment:
           qt  = qv + qc
           θ_l = θ - c0 * qc,   with θ = theta0 + theta_p
      3. If qt <= qs(θ_l * Π, p0) + tol the cell is unsaturated:
           qc = 0, qv = qt, θ = θ_l.
      4. Otherwise up to max_iter Newton steps solve
           f(qc) = qt - qc - qs(T(qc), p0) = 0,   T = (θ_l + c0 * qc) * Π,
         using a forward finite difference (dT = 0.05) for dqs/dT and
         clamping qc to [0, qt] after each step. Then qv = qt - qc.
      5. Rain-out: if qc > qc_crit, rain_frac * (qc - qc_crit) is removed
         from qc. qv and θ are left unchanged, so qt is NOT conserved in
         cells where this branch fires.

    Inputs:
      theta_p, qv, qc : 3D arrays, all indexed [i, j, k] over theta_p.shape
      theta0, p0      : arrays indexed with the same [i, j, k]
      Lv, cp          : constants used in c0 = Lv / (cp * Π)
      p_ref           : reference pressure passed to exner_jit
      eps             : passed to saturation_mixing_ratio_jit as epsilon
      tol             : absolute slack in the unsaturated test (step 3) and
                        relative tolerance, scaled by max(qt, 1e-12), in the
                        Newton stopping test (step 4)
      max_iter        : maximum number of Newton steps per cell
      qc_crit         : qc threshold above which rain-out is applied
      rain_frac       : fraction of (qc - qc_crit) removed per call

    Returns:
      theta_p_new, qv_new, qc_new : new arrays (the inputs are not modified);
      negative qv or qc results are written back as 0.
    """

    nx, ny, nz = theta_p.shape

    # Outputs are allocated uninitialised; every element is written in the loop below.
    theta_p_new = np.empty_like(theta_p)
    qv_new      = np.empty_like(qv)
    qc_new      = np.empty_like(qc)

    # Gas constant handed to exner_jit.
    R = 287.0

    for i in range(nx):
        for j in range(ny):
            for k in range(nz):
                p0_ijk = p0[i, j, k]

                # Local Exner function and latent-heating coefficient c0 = Lv / (cp * Π)
                Pi_ijk = exner_jit(p0_ijk, p_ref, R, cp)
                c0_ijk = Lv / (cp * Pi_ijk)

                # Current state
                theta_ijk = theta0[i, j, k] + theta_p[i, j, k]
                qv_ijk = qv[i, j, k]
                qc_ijk = qc[i, j, k]

                # Quantities held fixed during the adjustment
                qt_i = qv_ijk + qc_ijk
                theta_l_i = theta_ijk - c0_ijk * qc_ijk

                # Trial state with all cloud water evaporated (qc_f = 0)
                theta_unsat = theta_l_i
                T_unsat = theta_unsat * Pi_ijk
                qs_unsat = saturation_mixing_ratio_jit(T_unsat, p0_ijk, eps)

                if qt_i <= qs_unsat + tol:
                    # Fully unsaturated: keep the trial state
                    qc_f = 0.0
                    qv_f = qt_i
                    theta_f = theta_unsat
                else:
                    # Saturated: Newton iteration on qc_f
                    # Initial guess is the excess over the unsaturated qs, clamped to [0, qt]
                    qc_f = qt_i - qs_unsat
                    if qc_f < 0.0:
                        qc_f = 0.0
                    if qc_f > qt_i:
                        qc_f = qt_i

                    for _ in range(max_iter):
                        theta_f = theta_l_i + c0_ijk * qc_f
                        T_f = theta_f * Pi_ijk
                        qs_f = saturation_mixing_ratio_jit(T_f, p0_ijk, eps)

                        # Residual of qv = qs at the current guess
                        f = qt_i - qc_f - qs_f
                        # Relative stopping test; the 1e-12 floor keeps it finite for tiny qt
                        if abs(f) < tol * (qt_i if qt_i > 1e-12 else 1e-12):
                            break

                        # dqs/dT by forward finite difference
                        dT = 0.05
                        qs_p = saturation_mixing_ratio_jit(T_f + dT, p0_ijk, eps)
                        dqs_dT = (qs_p - qs_f) / dT

                        # df/dqc = -1 - dqs/dT * dT/dqc, with dT/dqc = c0 * Π
                        df_dqc = -1.0 - dqs_dT * (c0_ijk * Pi_ijk)
                        if df_dqc == 0.0:
                            break

                        qc_newton = qc_f - f / df_dqc

                        # Keep the iterate inside [0, qt]
                        if qc_newton < 0.0:
                            qc_newton = 0.0
                        if qc_newton > qt_i:
                            qc_newton = qt_i

                        qc_f = qc_newton

                    # Final θ and qv follow from the last qc iterate
                    theta_f = theta_l_i + c0_ijk * qc_f
                    qv_f = qt_i - qc_f

                # --- Simple autoconversion + rain-out (instant fallout) ---
                # theta_f and qv_f are deliberately left untouched: the removed
                # cloud water simply leaves the cell, so qv + qc shrinks here.
                if qc_f > qc_crit:
                    excess = qc_f - qc_crit
                    rain = rain_frac * excess
                    qc_f -= rain

                # Write back, flooring qv and qc at zero
                theta_p_new[i, j, k] = theta_f - theta0[i, j, k]
                qv_new[i, j, k]      = qv_f if qv_f > 0.0 else 0.0
                qc_new[i, j, k]      = qc_f if qc_f > 0.0 else 0.0

    return theta_p_new, qv_new, qc_new


In [41]:
def compute_divergence(u, v, w, dx, dy, dz):
    """
    Discrete divergence with periodic first-order backward differences:
      (u_i - u_{i-1})/dx + (v_j - v_{j-1})/dy + (w_k - w_{k-1})/dz
    """
    def _backward(field, axis, h):
        # np.roll(+1) brings the previous cell (with wrap-around) to index i.
        return (field - np.roll(field, +1, axis=axis)) / h

    return _backward(u, 0, dx) + _backward(v, 1, dy) + _backward(w, 2, dz)


def print_divergence(label, u, v, w, dx, dy, dz):
    """
    Print a short divergence report for a velocity field.

    The divergence comes from compute_divergence; the report lists its RMS
    value ("L2 norm"), its largest absolute value and its mean, under a
    header that contains `label`.
    """
    div = compute_divergence(u, v, w, dx, dy, dz)

    report = (
        f"\n---- {label} Divergence ----",
        f"L2 norm:      {np.sqrt(np.mean(div**2)):.6e}",
        f"Max |div|:    {np.max(np.abs(div)):.6e}",
        f"Mean(div):    {np.mean(div):.6e}",
        "-----------------------------",
    )
    for line in report:
        print(line)

def kinetic_energy(u, v, w):
    """Grid-mean kinetic energy per unit mass: 0.5 * mean(u^2 + v^2 + w^2)."""
    speed_squared = u**2 + v**2 + w**2
    return 0.5 * np.mean(speed_squared)

def _wavenumbers(n, d):
    """
    Angular wavenumbers for periodic FFT derivatives.
    k = 2π * fftfreq(n, d)
    """
    return 2.0 * np.pi * np.fft.fftfreq(n, d=d)

def project_velocity(u, v, w, dx, dy, dz):
    """
    Periodic FFT-based pressure projection.

    Solves the Poisson problem ∇²ψ = ∇·u* in spectral space and returns
    u* - ∇ψ. The k = 0 mode of ψ is set to zero (zero-mean gauge).
    As a side effect it prints the divergence report of the corrected field
    (via print_divergence) and the kinetic energy before and after.

    Inputs:
      u,v,w : 3D arrays of provisional velocity (cell-centered)
      dx,dy,dz : grid spacings
    Returns:
      u_corr, v_corr, w_corr, psi  (all real-valued 3D arrays)
    """
    ke_before = kinetic_energy(u, v, w)

    u_hat, v_hat, w_hat = (np.fft.fftn(component) for component in (u, v, w))

    # Wavenumber vectors shaped to broadcast along their own axis.
    nx, ny, nz = u.shape
    kx = _wavenumbers(nx, dx)[:, None, None]
    ky = _wavenumbers(ny, dy)[None, :, None]
    kz = _wavenumbers(nz, dz)[None, None, :]

    # Spectral divergence: i * (k · U)
    div_hat = 1j*(kx*u_hat + ky*v_hat + kz*w_hat)

    # -|k|² ψ_k = div_k  =>  ψ_k = -div_k / |k|², leaving the k = 0 mode at zero.
    k_squared = kx**2 + ky**2 + kz**2
    solvable = k_squared != 0.0
    psi_hat = np.zeros_like(u_hat, dtype=complex)
    psi_hat[solvable] = -div_hat[solvable] / k_squared[solvable]

    # Subtract the spectral gradient i*k*ψ_k from each component and go back
    # to physical space, keeping the real part.
    u_corr, v_corr, w_corr = (
        np.fft.ifftn(comp_hat - 1j * k_axis * psi_hat).real
        for comp_hat, k_axis in ((u_hat, kx), (v_hat, ky), (w_hat, kz))
    )
    psi = np.fft.ifftn(psi_hat).real

    ke_after = kinetic_energy(u_corr, v_corr, w_corr)

    print_divergence("FFT", u_corr, v_corr, w_corr, dx, dy, dz)
    print('u star ke: ' + str(ke_before))
    print('u corr ke: ' + str(ke_after))

    return u_corr, v_corr, w_corr, psi

# # def project_velocity_fd(u, v, w, dx, dy, dz, iters=100):
# #     """
# #     Projection that *matches your current WebGPU implementation*:

# #       - Divergence: backward differences (periodic)
# #       - Poisson solve: Jacobi with central-difference Laplacian
# #       - Gradient: forward differences (periodic)

# #     Inputs:
# #       u, v, w : 3D arrays (nx, ny, nz), provisional velocity
# #       dx, dy, dz : grid spacings
# #       iters : number of Jacobi iterations

# #     Returns:
# #       u_corr, v_corr, w_corr, psi
# #     """
# #     nx, ny, nz = u.shape

# #     # ----------------------------------------------------------
# #     # 1) Divergence: backward differences (matches divergence.ts)
# #     #    D_x u = (u[i] - u[i-1]) / dx, periodic with roll(+1)
# #     # ----------------------------------------------------------
# #     du_dx = (u - np.roll(u, +1, axis=0)) / dx
# #     dv_dy = (v - np.roll(v, +1, axis=1)) / dy
# #     dw_dz = (w - np.roll(w, +1, axis=2)) / dz

# #     div = du_dx + dv_dy + dw_dz    # rhs = b

# #     # ----------------------------------------------------------
# #     # 2) Jacobi solve for ∇² psi = div
# #     #    Laplacian: central 7-point (matches jacobi_poisson.ts)
# #     # ----------------------------------------------------------
# #     psi = np.zeros_like(u)

# #     ax = 1.0 / (dx * dx)
# #     ay = 1.0 / (dy * dy)
# #     az = 1.0 / (dz * dz)
# #     invDen = 1.0 / (2.0 * (ax + ay + az))  # same as invDen in WGSL

# #     for _ in range(iters):
# #         psi_xp = np.roll(psi, -1, axis=0)
# #         psi_xm = np.roll(psi, +1, axis=0)
# #         psi_yp = np.roll(psi, -1, axis=1)
# #         psi_ym = np.roll(psi, +1, axis=1)
# #         psi_zp = np.roll(psi, -1, axis=2)
# #         psi_zm = np.roll(psi, +1, axis=2)

# #         sumN = (
# #             ax * (psi_xp + psi_xm) +
# #             ay * (psi_yp + psi_ym) +
# #             az * (psi_zp + psi_zm)
# #         )

# #         # Jacobi update: ψ^{n+1} = (sumN - b) / (2(ax+ay+az))
# #         psi = (sumN - div) * invDen

# #     # ----------------------------------------------------------
# #     # 3) Gradient: forward differences (matches grad_subtract.ts)
# #     #    G_x psi = (psi[i+1] - psi[i]) / dx
# #     # ----------------------------------------------------------
# #     psi_xp = np.roll(psi, -1, axis=0)
# #     psi_yp = np.roll(psi, -1, axis=1)
# #     psi_zp = np.roll(psi, -1, axis=2)

# #     dpsi_dx = (psi_xp - psi) / dx
# #     dpsi_dy = (psi_yp - psi) / dy
# #     dpsi_dz = (psi_zp - psi) / dz

# #     # ----------------------------------------------------------
# #     # 4) Velocity correction: u_new = u* - ∇psi
# #     # ----------------------------------------------------------
# #     u_corr = u - dpsi_dx
# #     v_corr = v - dpsi_dy
# #     w_corr = w - dpsi_dz

# #     print_divergence("FD-Jacobi", u_corr, v_corr, w_corr, dx, dy, dz)

# #     return u_corr, v_corr, w_corr, psi

# def project_velocity_fd_with_diag(u, v, w, dx, dy, dz, psi_init=None,
#                                  max_iters=1000, tol_div=1e-8, tol_ke=1e-8):
#     # start from previous psi if provided
#     psi = np.zeros_like(u) if psi_init is None else psi_init.copy()

#     ax = 1.0/(dx*dx); ay = 1.0/(dy*dy); az = 1.0/(dz*dz)
#     invDen = 1.0 / (2.0 * (ax + ay + az))

#     # divergence (backward)
#     du_dx = (u - np.roll(u, +1, axis=0)) / dx
#     dv_dy = (v - np.roll(v, +1, axis=1)) / dy
#     dw_dz = (w - np.roll(w, +1, axis=2)) / dz
#     div = du_dx + dv_dy + dw_dz

#     # optional: subtract mean(div) for periodic solvability
#     div = div - div.mean()

#     for it in range(max_iters):
#         psi_xp = np.roll(psi, -1, axis=0); psi_xm = np.roll(psi, +1, axis=0)
#         psi_yp = np.roll(psi, -1, axis=1); psi_ym = np.roll(psi, +1, axis=1)
#         psi_zp = np.roll(psi, -1, axis=2); psi_zm = np.roll(psi, +1, axis=2)

#         sumN = (
#             ax * (psi_xp + psi_xm) +
#             ay * (psi_yp + psi_ym) +
#             az * (psi_zp + psi_zm)
#         )
#         psi_new = (sumN - div) * invDen

#         # convergence on psi itself (crude but ok)
#         diff = np.max(np.abs(psi_new - psi))
#         psi = psi_new
#         if diff < 1e-10:
#             print('breaking on ' + it)
#             break

#     # gradient (forward)
#     psi_xp = np.roll(psi, -1, axis=0)
#     psi_yp = np.roll(psi, -1, axis=1)
#     psi_zp = np.roll(psi, -1, axis=2)
#     dpsi_dx = (psi_xp - psi) / dx
#     dpsi_dy = (psi_yp - psi) / dy
#     dpsi_dz = (psi_zp - psi) / dz

#     u_star_ke = kinetic_energy(u, v, w)
#     u_corr = u - dpsi_dx
#     v_corr = v - dpsi_dy
#     w_corr = w - dpsi_dz
#     u_corr_ke = kinetic_energy(u_corr, v_corr, w_corr)

#     print_divergence("FD-Jacobi", u_corr, v_corr, w_corr, dx, dy, dz)
#     print('u star ke: ' + str(u_star_ke))
#     print('u corr ke: ' + str(u_corr_ke))

#     return u_corr, v_corr, w_corr, psi, u_star_ke, u_corr_ke



In [42]:
# def grad4(psi, dx, dy, dz):
#     """
#     4th-order central-difference gradient, periodic.
#     Returns (gx, gy, gz).
#     """
#     # x-derivative
#     psi_xm1 = np.roll(psi, +1, axis=0)
#     psi_xm2 = np.roll(psi, +2, axis=0)
#     psi_xp1 = np.roll(psi, -1, axis=0)
#     psi_xp2 = np.roll(psi, -2, axis=0)
#     gx = (-psi_xp2 + 8*psi_xp1 - 8*psi_xm1 + psi_xm2) / (12.0 * dx)

#     # y-derivative
#     psi_ym1 = np.roll(psi, +1, axis=1)
#     psi_ym2 = np.roll(psi, +2, axis=1)
#     psi_yp1 = np.roll(psi, -1, axis=1)
#     psi_yp2 = np.roll(psi, -2, axis=1)
#     gy = (-psi_yp2 + 8*psi_yp1 - 8*psi_ym1 + psi_ym2) / (12.0 * dy)

#     # z-derivative
#     psi_zm1 = np.roll(psi, +1, axis=2)
#     psi_zm2 = np.roll(psi, +2, axis=2)
#     psi_zp1 = np.roll(psi, -1, axis=2)
#     psi_zp2 = np.roll(psi, -2, axis=2)
#     gz = (-psi_zp2 + 8*psi_zp1 - 8*psi_zm1 + psi_zm2) / (12.0 * dz)

#     return gx, gy, gz


# def div4(u, v, w, dx, dy, dz):
#     """
#     4th-order central-difference divergence, periodic.
#     """
#     # du/dx
#     u_xm1 = np.roll(u, +1, axis=0)
#     u_xm2 = np.roll(u, +2, axis=0)
#     u_xp1 = np.roll(u, -1, axis=0)
#     u_xp2 = np.roll(u, -2, axis=0)
#     du_dx = (-u_xp2 + 8*u_xp1 - 8*u_xm1 + u_xm2) / (12.0 * dx)

#     # dv/dy
#     v_ym1 = np.roll(v, +1, axis=1)
#     v_ym2 = np.roll(v, +2, axis=1)
#     v_yp1 = np.roll(v, -1, axis=1)
#     v_yp2 = np.roll(v, -2, axis=1)
#     dv_dy = (-v_yp2 + 8*v_yp1 - 8*v_ym1 + v_ym2) / (12.0 * dy)

#     # dw/dz
#     w_zm1 = np.roll(w, +1, axis=2)
#     w_zm2 = np.roll(w, +2, axis=2)
#     w_zp1 = np.roll(w, -1, axis=2)
#     w_zp2 = np.roll(w, -2, axis=2)
#     dw_dz = (-w_zp2 + 8*w_zp1 - 8*w_zm1 + w_zm2) / (12.0 * dz)

#     return du_dx + dv_dy + dw_dz
# def apply_L4(psi, dx, dy, dz):
#     """
#     4th-order Laplacian for projection: L psi = div4(grad4(psi)).
#     """
#     gx, gy, gz = grad4(psi, dx, dy, dz)
#     return div4(gx, gy, gz, dx, dy, dz)
# def cg_solve_pressure_4th(b, dx, dy, dz, tol=1e-10, maxiter=1000):
#     """
#     Solve L psi = b with L = div4(grad4(.)).
#     """
#     psi = np.zeros_like(b)
#     r = b - apply_L4(psi, dx, dy, dz)
#     p = r.copy()
#     rsold = np.vdot(r, r).real

#     for k in range(maxiter):
#         Ap = apply_L4(p, dx, dy, dz)
#         alpha = rsold / np.vdot(p, Ap).real
#         psi = psi + alpha * p
#         r = r - alpha * Ap
#         rsnew = np.vdot(r, r).real
#         if np.sqrt(rsnew) < tol:
#             break
#         p = r + (rsnew / rsold) * p
#         rsold = rsnew

#     return psi
# def kinetic_energy(u, v, w):
#     return 0.5 * np.mean(u**2 + v**2 + w**2)


# def project_velocity_fd_4th(u, v, w, dx, dy, dz,
#                             tol=1e-5, maxiter=40):
#     """
#     4th-order FD projection:
#       - div4 for divergence
#       - L = div4 ∘ grad4
#       - CG solve
#       - projection u = u* - grad4(psi)
#     """
#     KE_star = kinetic_energy(u, v, w)

#     # 1) divergence
#     div = div4(u, v, w, dx, dy, dz)
#     div = div - div.mean()    # periodic solvability

#     # 2) solve for psi
#     psi = cg_solve_pressure_4th(div, dx, dy, dz, tol=tol, maxiter=maxiter)

#     # 3) grad and projection
#     gx, gy, gz = grad4(psi, dx, dy, dz)
#     u_corr = u - gx
#     v_corr = v - gy
#     w_corr = w - gz

#     KE_corr = kinetic_energy(u_corr, v_corr, w_corr)

#     # print_divergence("FD-Jacobi 4th", u_corr, v_corr, w_corr, dx, dy, dz)
#     # print('u star ke: ' + str(KE_star))
#     # print('u corr ke: ' + str(KE_corr))
#     return u_corr, v_corr, w_corr, psi


In [43]:
def grad4(psi, dx, dy, dz):
    """
    Periodic gradient of `psi` with the 4th-order central stencil.

    Along each axis the derivative is approximated by
        (-f[i+2] + 8*f[i+1] - 8*f[i-1] + f[i-2]) / (12 * h)
    where the neighbours are obtained with np.roll, so indices wrap around.
    Axis 0 uses spacing dx, axis 1 uses dy and axis 2 uses dz.

    Returns the tuple (gx, gy, gz).
    """
    partials = []
    for axis, spacing in enumerate((dx, dy, dz)):
        # roll by +s brings the value at i-s to position i; by -s the value at i+s
        behind1 = np.roll(psi, +1, axis=axis)
        behind2 = np.roll(psi, +2, axis=axis)
        ahead1 = np.roll(psi, -1, axis=axis)
        ahead2 = np.roll(psi, -2, axis=axis)
        partials.append(
            (-ahead2 + 8*ahead1 - 8*behind1 + behind2) / (12.0 * spacing)
        )

    gx, gy, gz = partials
    return gx, gy, gz


def div4(u, v, w, dx, dy, dz):
    """
    Periodic divergence d(u)/dx + d(v)/dy + d(w)/dz.

    Each partial derivative uses the 4th-order central stencil
        (-f[i+2] + 8*f[i+1] - 8*f[i-1] + f[i-2]) / (12 * h)
    with wrap-around neighbours (np.roll): u is differenced along axis 0,
    v along axis 1 and w along axis 2.
    """
    def _central4(field, axis, spacing):
        # 4th-order periodic first derivative of `field` along `axis`
        back1 = np.roll(field, +1, axis=axis)
        back2 = np.roll(field, +2, axis=axis)
        fwd1 = np.roll(field, -1, axis=axis)
        fwd2 = np.roll(field, -2, axis=axis)
        return (-fwd2 + 8*fwd1 - 8*back1 + back2) / (12.0 * spacing)

    return _central4(u, 0, dx) + _central4(v, 1, dy) + _central4(w, 2, dz)

def apply_L4_anelastic(psi, inv_rho0, dx, dy, dz):
    """
    Apply the variable-coefficient operator
        psi -> div4(inv_rho0 * grad4(psi))
    built from the 4th-order periodic gradient and divergence.
    `inv_rho0` multiplies every gradient component before the divergence.
    """
    grad_x, grad_y, grad_z = grad4(psi, dx, dy, dz)
    flux_x = inv_rho0 * grad_x
    flux_y = inv_rho0 * grad_y
    flux_z = inv_rho0 * grad_z
    return div4(flux_x, flux_y, flux_z, dx, dy, dz)

def cg_solve_pressure_anelastic(b, inv_rho0, dx, dy, dz, tol=1e-10, maxiter=1000):
    """
    Conjugate-gradient solve of
        div4(inv_rho0 * grad4(psi)) = b
    starting from psi = 0.

    Parameters
    ----------
    b : array
        Right-hand side.
    inv_rho0 : array or scalar
        Coefficient handed to apply_L4_anelastic.
    dx, dy, dz : float
        Grid spacings.
    tol : float
        Absolute tolerance on the residual 2-norm sqrt(r.r).
    maxiter : int
        Maximum number of CG iterations.

    Returns
    -------
    psi : array
        Last iterate. All zeros when the right-hand side norm is already
        below `tol`.
    """
    psi = np.zeros_like(b)

    # The initial guess is zero and the operator is linear, so the initial
    # residual b - L(0) is just b; no operator application is needed.
    r = np.array(b)
    rsold = np.vdot(r, r)

    # A (numerically) zero right-hand side would otherwise produce
    # alpha = 0/0 = NaN in the first iteration and poison psi.
    if np.sqrt(rsold) < tol:
        return psi

    p = r.copy()

    for _ in range(maxiter):
        Ap = apply_L4_anelastic(p, inv_rho0, dx, dy, dz)
        pAp = np.vdot(p, Ap)
        if pAp == 0.0:
            # Search direction lies in the operator's null space: no step
            # length is defined, so stop with the current iterate.
            break
        alpha = rsold / pAp
        psi = psi + alpha * p
        r = r - alpha * Ap
        rsnew = np.vdot(r, r)
        if np.sqrt(rsnew) < tol:
            break
        p = r + (rsnew/rsold)*p
        rsold = rsnew

    return psi

def kinetic_energy(u, v, w):
    """Return half the mean of u**2 + v**2 + w**2 over all array elements."""
    speed_squared = u**2 + v**2 + w**2
    return 0.5 * np.mean(speed_squared)

def project_velocity_fd_4th_anelastic(u, v, w, rho0, inv_rho0,
                                      dx, dy, dz,
                                      tol=1e-5, maxiter=40):
    """
    Project (u, v, w) towards the anelastic constraint div4(rho0 * U) = 0.

    The velocity is corrected as
        U_corr = U - inv_rho0 * grad4(psi)
    so that, with inv_rho0 == 1 / rho0 (assumed; the caller supplies both),
        div4(rho0 * U_corr) = div4(rho0 * U) - div4(grad4(psi)).
    For the left-hand side to vanish psi has to satisfy the
    constant-coefficient equation
        div4(grad4(psi)) = div4(rho0 * U),
    which is why the solver is called with a unit coefficient. Solving with
    inv_rho0 as the coefficient instead would leave a residual mass
    divergence wherever rho0 differs from 1.

    Parameters
    ----------
    u, v, w : array
        Velocity components to be projected.
    rho0, inv_rho0 : array
        Background density and its reciprocal.
    dx, dy, dz : float
        Grid spacings.
    tol, maxiter
        Forwarded to cg_solve_pressure_anelastic.

    Returns
    -------
    (u_corr, v_corr, w_corr, psi)
        Corrected velocity components and the potential from the CG solve.
    """
    # 1) mass-weighted divergence: div(rho0 * U)
    div = div4(rho0 * u,
               rho0 * v,
               rho0 * w,
               dx, dy, dz)
    div = div - div.mean()    # periodic solvability

    # 2) Poisson solve consistent with the correction below:
    #    div4(grad4(psi)) = div   (unit coefficient)
    psi = cg_solve_pressure_anelastic(div, 1.0, dx, dy, dz,
                                      tol=tol, maxiter=maxiter)

    # 3) Correct velocity: rho0 * U_corr = rho0 * U - grad4(psi)
    gx, gy, gz = grad4(psi, dx, dy, dz)
    u_corr = u - inv_rho0 * gx
    v_corr = v - inv_rho0 * gy
    w_corr = w - inv_rho0 * gz

    return u_corr, v_corr, w_corr, psi



In [44]:
# RK2 time stepping for the anelastic model.
# - advect_scalar_anelastic / advect_velocity: first-order upwind advection
# - compute_rhs: assembles the tendencies (buoyancy, w damping, advection, diffusion)
# - apply_bcs: boundary-condition placeholder (currently a no-op)
# - step_rk2: orchestrates the two RK2 stages, the microphysics adjustment
#   and the anelastic velocity projection

def advect_scalar_anelastic(phi, u, v, w, rho0, dx, dy, dz):
    """
    Flux-form anelastic advection tendency with first-order upwinding:
        dphi/dt = -(1/rho0) * div(rho0 * phi * U)

    For each axis the face between cell i and i+1 (periodic, via np.roll)
    gets the averaged density and velocity of the two cells; phi is taken
    from cell i when the face velocity is > 0 and from cell i+1 otherwise.
    The flux difference across each cell is divided by the grid spacing.

    rho0 is indexed along axes 0, 1 and 2, so it needs at least three
    dimensions.
    """
    flux_differences = []
    for axis, (velocity, spacing) in enumerate(((u, dx), (v, dy), (w, dz))):
        # averages onto the face shared with the next cell along `axis`
        rho_face = 0.5 * (rho0 + np.roll(rho0, -1, axis=axis))
        vel_face = 0.5 * (velocity + np.roll(velocity, -1, axis=axis))

        # upwind (donor-cell) value of phi at that face
        donor = np.where(vel_face > 0.0, phi, np.roll(phi, -1, axis=axis))
        flux = rho_face * vel_face * donor

        # outgoing face flux minus incoming face flux
        flux_differences.append((flux - np.roll(flux, 1, axis=axis)) / spacing)

    div_x, div_y, div_z = flux_differences
    return -(div_x + div_y + div_z) / rho0


def advect_velocity(u, v, w, rho0, dx, dy, dz):
    """
    Advective tendencies of the three velocity components.

    Each of u, v and w is passed in turn as the transported field to
    advect_scalar_anelastic, with (u, v, w) as the advecting velocity.
    Returns the tuple (rhs_u, rhs_v, rhs_w).
    """
    rhs_u, rhs_v, rhs_w = (
        advect_scalar_anelastic(component, u, v, w, rho0, dx, dy, dz)
        for component in (u, v, w)
    )
    return rhs_u, rhs_v, rhs_w

# -----------------
# Stubs / placeholders
# -----------------
def compute_rhs(state, params=None):
    """
    Assemble the tendencies of all prognostic fields for one RK stage.

    Contributions, in the order they are accumulated:
      1. whatever add_buoyancy_and_theta_to_rhs writes into `rhs`
      2. linear damping of w: -w / params['tau_damp_w'] (default 300.0)
      3. upwind advection of u, v, w, theta_p, qv and qc
      4. diffusion of momentum ('nu'), theta_p ('kappa') and qv/qc ('Dq');
         each coefficient defaults to 0.0 when missing from params

    Parameters
    ----------
    state : dict
        Reads 'u', 'v', 'w', 'theta_p', 'qv' and 'qc'.
    params : dict
        Must provide 'dx', 'dy', 'dz' and 'rho0'. The None default is kept
        only for signature compatibility; it is not a usable value.

    Returns
    -------
    rhs : dict
        The dict created by _zeros_like_state(state) with the tendencies
        accumulated in 'u', 'v', 'w', 'theta_p', 'qv' and 'qc'.

    Raises
    ------
    TypeError
        If params is None.
    """
    if params is None:
        raise TypeError("compute_rhs requires a params dict (dx, dy, dz, rho0, ...)")

    rhs = _zeros_like_state(state)
    dx, dy, dz = params['dx'], params['dy'], params['dz']
    rho0 = params['rho0']
    u, v, w = state['u'], state['v'], state['w']

    # --- buoyancy / theta source, with a total-water tendency check ---
    qt_before = np.sum(rhs['qv'] + rhs['qc'])
    add_buoyancy_and_theta_to_rhs(state, params, rhs)
    # after buoyancy, bound w
    tau = params.get('tau_damp_w', 300.0)   # seconds
    rhs['w'] += -(state['w'] / tau)
    qt_after = np.sum(rhs['qv'] + rhs['qc'])
    # Report drift of either sign in the summed qv + qc tendency.
    dqt = qt_after - qt_before
    if abs(dqt) >= 1e-15:
        print("add_buoyancy_and_theta_to_rhs Δqt:", dqt)

    # --- advection ---
    du_adv, dv_adv, dw_adv = advect_velocity(u, v, w, rho0, dx, dy, dz)
    rhs['u'] += du_adv
    rhs['v'] += dv_adv
    rhs['w'] += dw_adv

    for name in ('theta_p', 'qv', 'qc'):
        rhs[name] += advect_scalar_anelastic(state[name], u, v, w, rho0, dx, dy, dz)

    # --- diffusion ---
    nu = params.get('nu', 0.0)
    kappa = params.get('kappa', 0.0)
    Dq = params.get('Dq', 0.0)

    # Only rhs['u'/'v'/'w'] change here, so no qv/qc check is needed.
    du, dv, dw = diffuse_velocity(u, v, w, nu, dx, dy, dz)
    rhs['u'] += du
    rhs['v'] += dv
    rhs['w'] += dw

    rhs['theta_p'] += diffuse_scalar(state['theta_p'], kappa, dx, dy, dz)
    rhs['qv']      += diffuse_scalar(state['qv'],      Dq,    dx, dy, dz)
    rhs['qc']      += diffuse_scalar(state['qc'],      Dq,    dx, dy, dz)

    # --- Radiative cooling (simple Newtonian on theta_p) -- currently disabled ---
    # tau_rad = params.get('tau_rad', None)
    # if tau_rad is not None and tau_rad > 0.0:
    #     # potentially remove this if the potential temperature perturbation runs away
    #     # or gets really positive; this should be radiative cooling, but if theta_p is
    #     # negative it could start adding heat -- maybe consider clamping to min(0, ...)
    #     rhs['theta_p'] += - state['theta_p'] / tau_rad

    # --- Surface moist/heat source (bottom boundary layer forcing) -- currently disabled ---
    # Nbl = params.get('Nbl', 2)  # number of bottom levels to force
    # tau_surf = params.get('tau_surf', None)

    # if tau_surf is not None and tau_surf > 0.0 and Nbl > 0:
    #     theta0   = params['theta0']
    #     theta_p  = state['theta_p']
    #     qv       = state['qv']

    #     kmax = min(Nbl, theta_p.shape[2])

    #     # Targets are already 3D: (nx, ny, kmax)
    #     theta_surf_target = params['theta_surf_target'][..., :kmax]
    #     qv_surf_target    = params['qv_surf_target'][..., :kmax]

    #     # Actual θ in bottom kmax levels
    #     theta_bl = theta0[..., :kmax] + theta_p[..., :kmax]

    #     # Relax toward targets
    #     rhs['theta_p'][..., :kmax] += (theta_surf_target - theta_bl)       / tau_surf
    #     rhs['qv'][..., :kmax]      += (qv_surf_target    - qv[..., :kmax]) / tau_surf

    return rhs

def apply_bcs(state, params=None):
    """
    Boundary-condition hook.

    Currently a no-op: `params` is ignored and the very same `state`
    object is handed back unmodified.
    """
    return state

# -----------------
# RK2 stepper
# -----------------
def _rk2_adjust_and_project(s, params, dx, dy, dz):
    """Saturation-adjust theta_p/qv/qc of `s`, then project its velocity (in place)."""
    s['theta_p'], s['qv'], s['qc'] = microphysics_saturation_adjust(
        s['theta_p'], s['qv'], s['qc'],
        params['theta0'], params['p0'],
        qc_crit=params.get('qc_crit', 1e-4),
        rain_frac=params.get('rain_frac', 0.3))

    # The potential returned by the projection is stored under 'pi'.
    s['u'], s['v'], s['w'], s['pi'] = project_velocity_fd_4th_anelastic(
        s['u'], s['v'], s['w'],
        params['rho0'], params['inv_rho0'],
        dx, dy, dz)


def step_rk2(state, dt, params=None):
    """
    Advance `state` by one RK2 (Heun) step of size `dt`.

    Sequence:
      1) copy the input state and apply BCs to the copy (s0)
      2) rhs1 = compute_rhs(s0)
      3) provisional state  s* = s0 + dt * rhs1
      4) saturation adjustment of s*
      5) anelastic projection of the s* velocity
      6) rhs2 = compute_rhs(s*)
      7) final combine  s_new = s0 + dt/2 * (rhs1 + rhs2)
      8) saturation adjustment of s_new
      9) anelastic projection of the s_new velocity

    BCs are applied to s0 only; the calls for s* and s_new are not active.

    Parameters
    ----------
    state : dict
        Fields 'u', 'v', 'w', 'theta_p', 'qv', 'qc' (copied, not modified
        through this function's own assignments).
    dt : float
        Time step.
    params : dict
        Must provide 'dx', 'dy', 'dz', 'theta0', 'p0', 'rho0' and
        'inv_rho0'; 'qc_crit' and 'rain_frac' are optional. The None
        default is kept only for signature compatibility.

    Returns
    -------
    s_new : dict
        The new state, including 'pi' from the final projection.

    Raises
    ------
    TypeError
        If params is None.
    """
    if params is None:
        raise TypeError("step_rk2 requires a params dict (dx, dy, dz, theta0, p0, rho0, ...)")

    # Take the grid spacing from params, like compute_rhs does, rather than
    # from whatever dx/dy/dz happen to exist as notebook globals.
    dx, dy, dz = params['dx'], params['dy'], params['dz']
    prognostic = ('u', 'v', 'w', 'theta_p', 'qv', 'qc')

    s0 = _copy_state(state)
    apply_bcs(s0, params)

    # Stage 1 RHS
    rhs1 = compute_rhs(s0, params)

    # Provisional state*
    s_star = _copy_state(s0)
    for name in prognostic:
        s_star[name] = s0[name] + dt * rhs1[name]
    _rk2_adjust_and_project(s_star, params, dx, dy, dz)

    # Stage 2 RHS
    rhs2 = compute_rhs(s_star, params)

    # Final combine
    s_new = _copy_state(s0)
    for name in prognostic:
        s_new[name] = s0[name] + 0.5 * dt * (rhs1[name] + rhs2[name])
    _rk2_adjust_and_project(s_new, params, dx, dy, dz)

    return s_new


In [None]:
# ------------------ Initialization that makes a small cloud ------------------
# Domain: 64 x 64 x 40 cells; the lengths below give 100 m spacing on every axis.
nx, ny, nz = 64, 64, 40
Lx, Ly, Lz = 6400.0, 6400.0, 4000.0   # 6.4 km × 6.4 km × 4 km
dx, dy, dz = Lx/nx, Ly/ny, Lz/nz

# Grid: cell-centre coordinates (hence the +0.5); X, Y, Z are 3D arrays indexed [i, j, k]
x = (np.arange(nx)+0.5)*dx
y = (np.arange(ny)+0.5)*dy
z = (np.arange(nz)+0.5)*dz
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

# --- Background thermodynamics ---
theta0_surface = 300.0  # K
# dtheta0_dz = -0.0065    # K/m  (temperature decreases upward)
dtheta0_dz = 0.003      # positive, so theta0 increases linearly with height
theta0 = theta0_surface + dtheta0_dz * Z

# Background pressure decays exponentially with height, with scale height H = R*T_ref/g.
# (An earlier version used a constant pressure:)
# p0 = 1.0e5  # Pa
R = 287.0
g = 9.81
T_ref = 300.0
H = R * T_ref / g          # ~ 8780 m

p_surf = 1.0e5             # Pa
p0 = p_surf * np.exp(-Z / H)   # shape (nx, ny, nz)

# Exner Π(z); exner() is defined elsewhere in the notebook
cp = 1004.0
p_ref = 1.0e5
Pi0 = exner(p0, p_ref=p_ref, R=R, cp=cp)

T0 = theta0 * Pi0      # treat as actual temperature here
# Saturation mixing ratio (same form as your microphysics)
# def qsat(T, p, eps=0.622):
#     es = 610.94 * np.exp(17.625*(T - 273.15)/(T - 30.11))  # Pa
#     return eps * es / (p - es)

# density from p0 = rho0 * R * T0
rho0 = p0 / (R * T0)   # shape (nx, ny, nz)

# Optional: also keep 1D version (the column at i = j = 0) for cleaner operators
rho0_z = rho0[0, 0, :].copy()
inv_rho0 = 1.0 / rho0

# qs_bg = qsat(T0, p0)
# saturation_mixing_ratio() is defined elsewhere in the notebook
qs_bg = saturation_mixing_ratio(T0, p0)

# Background relative humidity: 40% of the background saturation value
RH_bg = 0.4
qv_bg = RH_bg * qs_bg

# Earlier, spherical version of the bubble (kept for reference, not used):
# # --- Warm, moist bubble (radius ~800 m) centered near 400 m height ---
# xb, yb, zb = Lx*0.5, Ly*0.5, 400.0
# rb = 800.0
# r2 = (X - xb)**2 + (Y - yb)**2 + (Z - zb)**2
# bubble = np.exp(-r2 / (2.0 * rb**2))

# # Bubble is slightly warmer and closer to saturation
# RH_bubble = 0.9
# theta_p0 = 0.5 * bubble          # ~+0.5 K warm anomaly
# qv0 = qv_bg + (RH_bubble * qs_bg - qv_bg) * bubble
# --- Warm, moist bubble centered near 800 m, vertically compact ---
# Centre: middle of the domain horizontally, z = 800 m.
xb, yb, zb = Lx*0.5, Ly*0.5, 800.0

rb_xy = 800.0   # horizontal radius
rb_z  = 300.0   # vertical radius (tighter)

r2_xy = (X - xb)**2 + (Y - yb)**2
r2_z  = (Z - zb)**2

# Gaussian weight with separate horizontal and vertical widths; equals 1 at the centre
bubble = np.exp(-r2_xy / (2.0 * rb_xy**2) - r2_z / (2.0 * rb_z**2))


RH_bubble = 0.95
theta_amp = 1.5   # theta_p0 is theta_amp times the bubble weight
theta_p0 = theta_amp * bubble
# Blend from the background qv (weight 0) towards RH_bubble * qs_bg (weight 1)
qv0 = qv_bg + (RH_bubble * qs_bg - qv_bg) * bubble

# ------------------ Surface forcing targets (localized under bubble) ------------------
# NOTE(review): the only reader of these targets visible in this part of the
# notebook is the commented-out surface-forcing block in compute_rhs.
# Config
Nbl       = 2          # number of bottom levels to force
delta_theta_core = 1.5 # K extra at center under bubble
RH_surf_core    = 0.9  # target RH at center

kmax = min(Nbl, nz)

# Lowest-level (k = 0) slices of the background fields, each (nx, ny)
Pi_surf       = Pi0[:, :, 0]
p_surf_level  = p0[:, :, 0]
qv_bg_surf    = qv_bg[:, :, 0]

# Gaussian mask under bubble on lowest level
r2_xy_surf = (X[:, :, 0] - xb)**2 + (Y[:, :, 0] - yb)**2
rb_forcing = rb_xy
forcing_mask = np.exp(-r2_xy_surf / (2.0 * rb_forcing**2))   # (nx, ny), 0–1

# Temperature target: background profile + warm anomaly scaled by the mask
theta_surf_target = theta0[:, :, :kmax] + forcing_mask[:, :, None] * delta_theta_core
# (nx, ny, kmax)

# Moisture target:
# First define "core" qv at center using warm surface + RH_surf_core
theta_core_xy = theta0_surface + delta_theta_core           # scalar
T_core_xy     = theta_core_xy * Pi_surf                     # (nx, ny)
qs_core_xy    = saturation_mixing_ratio(T_core_xy, p_surf_level)
qv_core_xy    = RH_surf_core * qs_core_xy                   # (nx, ny)

# Blend between background qv and core qv across mask, extend through kmax levels
# (the same lowest-level increment is added to each of the kmax levels)
qv_surf_target = (
    qv_bg[:, :, :kmax]
    + forcing_mask[:, :, None] * (qv_core_xy - qv_bg_surf)[:, :, None]
)
# (nx, ny, kmax)
# --- Zero initial motion, no cloud water ---
u0 = np.zeros_like(theta0)
v0 = np.zeros_like(theta0)
w0 = np.zeros_like(theta0)
qc0 = np.zeros_like(theta0)
pi0 = np.zeros_like(theta0)

# Prognostic fields plus 'pi', which step_rk2 overwrites with the projection potential
state = {
    'u': u0, 'v': v0, 'w': w0,
    'theta_p': theta_p0,
    'qv': qv0, 'qc': qc0, 'pi': pi0
}

params = {
    'dx': dx, 'dy': dy, 'dz': dz,
    'theta0': theta0,
    'dtheta0_dz': np.full_like(theta0, dtheta0_dz),   # constant gradient as a 3D array
    'p0': p0,
    'Lv': 2.5e6, 'cp': 1004.0, 'R': 287.0,
    'p_ref': 1.0e5, 'eps': 0.622,
    'g': 9.81,
    'qv_bg': qv_bg,
    'nu': 10.0, # momentum
    'kappa': 10.0, # thermal
    'Dq': 10.0, # moisture
    'tau_damp_w': 300,   # seconds; time scale of the -w/tau damping in compute_rhs

    # radiative cooling timescale (slow)
    # NOTE(review): the radiative-cooling term in compute_rhs is commented out
    'tau_rad': 1800.0,             # 1800 s = 30 min; you can tune 1800–7200 s

    # surface layer forcing
    # NOTE(review): the surface-forcing term in compute_rhs is commented out
    'Nbl': Nbl, # number bottom layers to force
    'tau_surf': 100.0,            # 100 s relaxation toward warm/moist surface
    'theta_surf_target': theta_surf_target,
    'qv_surf_target': qv_surf_target,

    # precip
    'qc_crit': 1e-4, # threshold to start precip
    'rain_frac': 0.3, # how much of excess liquid falls out

    # density
    'rho0': rho0,
    'inv_rho0': inv_rho0,
}


# --- run a short integration ---
# compute_rhs also includes advection and diffusion, not only buoyancy.
dt = 0.01  # seconds; initial value only, the run loop recomputes dt with compute_dt every step
# steps = 2000
steps = 4000

def diag(state, step):
    """Print the maxima of qc (scaled by 1000, labelled g/kg), w and theta_p for `step`."""
    qc_max, w_max, thp_max = (state[key].max() for key in ('qc', 'w', 'theta_p'))
    print(f"step {step:4d}: max qc = {qc_max*1000:.3f} g/kg | max w = {w_max:.3f} m/s | max θ' = {thp_max:.3f} K")

print("Running... (expect qc to flip on near the bubble as it lifts, w to grow)")
plt.figure()

# Total water (sum of qv + qc over the grid) of the initial condition
qt0, _, _ = total_water(state)
print("IC qt_sum:", qt0)

for n in range(steps):
    # The time step is recomputed every iteration (compute_dt is defined elsewhere)
    dt = compute_dt(state, params, cfl=0.4)
    state = step_rk2(state, dt, params)
    # print("mean w:", state['w'].mean())
    if n == 0:
        # Total water after the first step, printed for comparison with qt0
        qt1, _, _ = total_water(state)
        print("After first step qt_sum:", qt1)
        # assert abs(qt1 - qt0) < 1e-10 * qt0


    # save_state(state, n, params)

    # Diagnostics and plots every 10th step and on the last step
    if n % 10 == 0 or n == steps-1:
        diag(state, n)
        plot_state(state, n, params)
        plot_qv(state, n, params)
        plot_qc(state, n, params)
        plot_theta_p(state, n, params)
        plot_bottom_wind(state, n, params)
        plot_pi(state, n, params)
        plot_relative_humidity(state, n, params)

        # qtSum, qtMin, qtMax = total_water(state)
        # print("theta_p max and min ", state['theta_p'].max(), state['theta_p'].min())
        # print("qv max and min ", state['qv'].max(), state['qv'].min())
        # print("qc max and min ", state['qc'].max(), state['qc'].min())
        # print("max w:", state['w'].max())
        # print("qt sum: ", qtSum)
        # print("qt min: ", qtMin)
        # print("qt max: ", qtMax)

# Provide simple summaries for quick confirmation:
# number of cells with qc > 0 and number of cells with w > 0 in the final state
qc_any = float((state['qc']>0).sum())
wpos = float((state['w']>0).sum())
print("\nSummary: cells with cloud water >", qc_any, " | cells with upward motion >", wpos)


In [None]:
# !ffmpeg -framerate 120 -pattern_type glob -i '/content/stateImages/state_*.png' \
#   -c:v libx264 -preset fast -crf 23 -pix_fmt yuv420p \
#   /content/animation.mp4
# from google.colab import files
# files.download('/content/animation.mp4')

