In [72]:
import numpy as np
import numba as nb
from typing import Tuple

# ==========================================================================
# 1. Standalone Building Blocks for the New Laplacian Method
# ==========================================================================

def calculate_pTheta_placeholder(nx: int, ny: int) -> np.ndarray:
    """Placeholder: Calculates coefficient corresponding to hplus (cell-centered)."""
    # In reality, this depends on thermodynamic variables (rho, rhoY, GammaInv, Y)
    # For testing, use random positive values.
    print("INFO: Using placeholder for pTheta coefficient.")
    return np.random.rand(nx, ny) + 0.5 # Shape (nx, ny) - cell centered

def calculate_helmholtz_coeff_placeholder(nx: int, ny: int) -> np.ndarray:
    """Placeholder: Calculates Helmholtz coefficient (nodal)."""
    # In reality, depends on compressibility, rhoY, etc.
    # For testing, use random values.
    print("INFO: Using placeholder for Helmholtz coefficient.")
    # Note: C code calculates this via scattering, resulting in nodal values.
    return np.random.rand(nx + 1, ny + 1) * 0.1 # Shape (nx+1, ny+1) - nodal

@nb.jit(nopython=True, nogil=True, cache=True)
def gradient_nodal_to_cell(p_nodal: np.ndarray, dx: float, dy: float, nx: int, ny: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Computes gradient at cell centers from nodal values.
    Equivalent to the C code's Dpx, Dpy calculation inside
    correction_increments_nodes and the user's original Python gradient.
    Assumes p_nodal includes boundary nodes.
    """
    dpdx = np.zeros((nx, ny))
    dpdy = np.zeros((nx, ny))

    # This implements the 0.25 * (diff + diff + diff + diff) averaging scheme
    # For dpdx at cell (i, j), uses nodes (i,j), (i+1,j), (i,j+1), (i+1,j+1)
    for j in range(ny):
        for i in range(nx):
            n00 = p_nodal[i, j]
            n10 = p_nodal[i+1, j]
            n01 = p_nodal[i, j+1]
            n11 = p_nodal[i+1, j+1]
            dpdx[i, j] = 0.25 * (n10 - n00 + n11 - n01 + n10 - n00 + n11 - n01) / dx # Simplified from 3D version, double check averaging logic
            dpdy[i, j] = 0.25 * (n01 - n00 + n11 - n10 + n01 - n00 + n11 - n10) / dy # Simplified from 3D version, double check averaging logic

            # Let's use the exact 2D logic from the user's original gradient function if available
            # Or derive from the 3D C code: Dpx = 0.25 * oodx * (p[n001] - p[n000] + p[n011] - p[n010] + p[n101] - p[n100] + p[n111] - p[n110]);
            # In 2D: n00=p[i,j], n10=p[i+1,j], n01=p[i,j+1], n11=p[i+1,j+1]
            # Dpx = 0.5 * oodx * (p[n10] - p[n00] + p[n11] - p[n01]) equivalent to signs (-1, -1, +1, +1)
            # Dpy = 0.5 * oody * (p[n01] - p[n00] + p[n11] - p[n10]) equivalent to signs (-1, +1, -1, +1)
            dpdx[i, j] = 0.5 * (n10 - n00 + n11 - n01) / dx
            dpdy[i, j] = 0.5 * (n01 - n00 + n11 - n10) / dy

    return dpdx, dpdy

#@nb.jit(nopython=True, nogil=True, cache=True) # Numba might struggle with dicts/complex args easily
def apply_coriolis_transform_placeholder(u: np.ndarray, v: np.ndarray, w: np.ndarray, dt: float, coriolis_params: Tuple[float, float, float]):
    """
    Placeholder for the T_inverse transformation (Coriolis, buoyancy).
    Modifies u, v, w IN-PLACE. For testing, set effect to zero.
    """
    print("INFO: Using placeholder (identity) for Coriolis transform.")
    # To test non-identity, uncomment below:
    # u *= (1.0 + dt * 0.1)
    # v *= (1.0 - dt * 0.1)
    pass # No operation for identity transform

#@nb.jit(nopython=True, nogil=True, cache=True) # Numba might struggle with dicts/complex args easily
def calculate_correction_increments_nodes(p_nodal: np.ndarray, pTheta: np.ndarray, dx: float, dy: float, dt: float, nx: int, ny: int, coriolis_params: Tuple[float, float, float]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Calculates and returns the intermediate transformed flux increments.
    Mirrors C's correction_increments_nodes.
    """
    # Calculate gradient (nodal -> cell)
    dpdx, dpdy = gradient_nodal_to_cell(p_nodal, dx, dy, nx, ny)

    # Calculate initial flux increment (cell-centered)
    u = -dt * pTheta * dpdx
    v = -dt * pTheta * dpdy
    w = np.zeros_like(u) # 2D placeholder

    # Apply Coriolis/Buoyancy transform (T_inv) in-place
    apply_coriolis_transform_placeholder(u, v, w, dt, coriolis_params)

    # u, v, w now contain the transformed increments
    return u, v, w

@nb.jit(nopython=True, nogil=True, cache=True)
def divergence_cell_to_node(u_cell: np.ndarray, v_cell: np.ndarray, w_cell: np.ndarray,
                            dx: float, dy: float, dz: float, # dz needed for scaling even in 2D if C code uses it
                            nx: int, ny: int,
                            Xbot: np.ndarray, Xtop: np.ndarray, # Boundary factors
                            is_x_periodic: bool, is_y_periodic: bool) -> np.ndarray:
    """
    Calculates divergence at nodes from cell-centered vectors using scattering.
    Mirrors C's divergence_nodes_Pv. Output is on inner nodes (nx, ny).
    """
    div_nodal = np.zeros((nx + 1, ny + 1)) # Nodal grid, including boundaries
    oodx = 1.0 / dx
    oody = 1.0 / dy
    oodz = 1.0 / dz # Use non-zero dz=1.0 for 2D if C code did

    # Loop over cells (i_cell, j_cell) from 0 to nx-1, 0 to ny-1
    for j_cell in range(ny):
        for i_cell in range(nx):
            ne = i_cell + j_cell * nx # Placeholder index if needed, C code uses 2D indices

            # Get cell-centered velocity components
            u_c = u_cell[i_cell, j_cell]
            v_c = v_cell[i_cell, j_cell]
            w_c = w_cell[i_cell, j_cell] # Zero for 2D

            # Calculate flux terms to scatter
            tmpfx = 0.25 * oodx * u_c
            tmpfy = 0.25 * oody * v_c
            tmpfz = 0.25 * oodz * w_c # Scales by oodz even if w_c is 0

            # Get boundary scaling factors (assuming Xbot/Xtop are 1D arrays of size nx)
            isbot = (j_cell == 0)
            istop = (j_cell == ny - 1)
            # Ensure correct indexing for Xbot/Xtop (size nx)
            xb = Xbot[i_cell] if isbot else 1.0
            xt = Xtop[i_cell] if istop else 1.0
            # C code version DIV_HOR_SCALED_TO_VERTICAL_BDRY:
            # Multiplies tmpfx, tmpfz by Xbot/Xtop, adds tmpfy separately


            # Node indices (bottom-left node of cell is (i_cell, j_cell))
            nn00 = (i_cell)   + (j_cell)   * (nx + 1) # Node (i, j)
            nn10 = (i_cell+1) + (j_cell)   * (nx + 1) # Node (i+1, j)
            nn01 = (i_cell)   + (j_cell+1) * (nx + 1) # Node (i, j+1)
            nn11 = (i_cell+1) + (j_cell+1) * (nx + 1) # Node (i+1, j+1)

            # Scattering based on C code divergence_nodes_Pv (using Xbot/Xtop scaling)
            # Note: C code uses +/- signs differently depending on DIV_HOR_SCALED_TO_VERTICAL_BDRY
            # Let's use the one shown in divergence_nodes_Pv
            # div[nn000] += + tmpfy + (+ tmpfx + tmpfz) * Xbot; -> Node (i, j) gets from cell (i, j)
            # div[nn100] += + tmpfy + (- tmpfx + tmpfz) * Xbot; -> Node (i+1, j) gets from cell (i, j)
            # div[nn010] += - tmpfy + (+ tmpfx + tmpfz) * Xtop; -> Node (i, j+1) gets from cell (i, j)
            # div[nn110] += - tmpfy + (- tmpfx + tmpfz) * Xtop; -> Node (i+1, j+1) gets from cell (i, j)
            # And contributions from cells (i-1, j), (i, j-1), (i-1, j-1) need adding up

            # Reshape div_nodal temporarily for easy 2D indexing
            div_nodal_2d = div_nodal.reshape((ny + 1, nx + 1))

            # Scatter from cell (i_cell, j_cell) to its four surrounding nodes
            div_nodal_2d[j_cell,   i_cell]   += +tmpfy + (+tmpfx + tmpfz) * xb # nn00 from cell (i_cell, j_cell)
            div_nodal_2d[j_cell,   i_cell+1] += +tmpfy + (-tmpfx + tmpfz) * xb # nn10 from cell (i_cell, j_cell)
            div_nodal_2d[j_cell+1, i_cell]   += -tmpfy + (+tmpfx + tmpfz) * xt # nn01 from cell (i_cell, j_cell)
            div_nodal_2d[j_cell+1, i_cell+1] += -tmpfy + (-tmpfx + tmpfz) * xt # nn11 from cell (i_cell, j_cell)


    # Extract the inner nodes (nx, ny)
    # The nodal indices corresponding to inner nodes are [1:nx+1, 1:ny+1]
    # However, the C code loops imply the output size is nx*ny
    # The scattering accumulates at nodes (0..nx, 0..ny).
    # Which nodes correspond to the 'inner' result? Often node (i,j) result corresponds
    # to indices (i-igx, j-igy). Let's assume inner nodes are [igx:icx-igx, igy:icy-igy]
    # For simplicity, let igx=igy=1. Inner nodes are [1:nx+1, 1:ny+1].
    div_inner = div_nodal.reshape((ny + 1, nx + 1))[1:ny+1, 1:nx+1]

    return div_inner # Shape (ny, nx) -> transpose if needed

#@nb.jit(nopython=True, nogil=True, cache=True) # Numba might struggle with dicts/complex args easily
def calculate_new_laplacian_action(p_nodal: np.ndarray, pTheta: np.ndarray, helmholtz_coeff_nodal: np.ndarray,
                                   dx: float, dy: float, dz: float, dt: float, nx: int, ny: int,
                                   coriolis_params: Tuple[float, float, float],
                                   Xbot: np.ndarray, Xtop: np.ndarray,
                                   is_x_periodic: bool, is_y_periodic: bool) -> np.ndarray:
    """
    Calculates the action of the Laplacian operator L(p) using the new structure.
    Mirrors C code logic: increments -> divergence -> Helmholtz.
    """
    # Pad p_nodal for gradient calculation if periodic
    # Simple 1-layer padding sufficient for first-order differences
    p_padded = np.pad(p_nodal, ((1,1),(1,1)), mode='wrap' if is_x_periodic and is_y_periodic else 'constant')
    if is_x_periodic:
        p_padded[:, 0] = p_padded[:, -2]
        p_padded[:, -1] = p_padded[:, 1]
    if is_y_periodic:
        p_padded[0, :] = p_padded[-2, :]
        p_padded[-1, :] = p_padded[1, :]
        # Corners if both periodic
        if is_x_periodic:
             p_padded[0, 0] = p_padded[-2, -2]
             p_padded[0, -1] = p_padded[-2, 1]
             p_padded[-1, 0] = p_padded[1, -2]
             p_padded[-1, -1] = p_padded[1, 1]

    # Calculate Transformed Flux Increments (Cell Centered) using padded pressure
    # Note: Need to adjust nx,ny input to gradient if p_padded is used directly
    # Or, implement gradient to handle padding internally. Let's adjust nx, ny.
    # The core calculation area for gradient is still nx x ny cells.
    # It reads from p_padded which has size (nx+3, ny+3) effectively? No, pad is (nx+1)+2 = nx+3.
    # Let's stick to passing p_nodal and handle boundaries inside gradient (more complex)
    # Easier: Pass p_padded, but gradient needs modification or careful indexing.

    # Let's assume gradient works on (nx+1, ny+1) input and needs padding handled outside.
    # We already padded p_nodal. The gradient function as written accesses p_nodal[i+1, j+1] max.
    # So passing the original p_nodal (nx+1, ny+1) to gradient should be fine if it handles boundary indexing.
    # Revisit gradient_nodal_to_cell - it assumes input is (nx+1, ny+1)
    # Let's assume p_nodal already has ghost cells filled appropriately.
    u_incr, v_incr, w_incr = calculate_correction_increments_nodes(
        p_nodal, pTheta, dx, dy, dt, nx, ny, coriolis_params
    )

    # Calculate Divergence of Increments (Nodal Inner)
    divergence_term = divergence_cell_to_node(
        u_incr, v_incr, w_incr, dx, dy, dz, nx, ny, Xbot, Xtop, is_x_periodic, is_y_periodic
    ).T

    # Add Helmholtz Term
    p_inner = p_nodal[1:nx+1, 1:ny+1]
    helmholtz_coeff_inner = helmholtz_coeff_nodal[1:nx+1, 1:ny+1]

    if divergence_term.shape != helmholtz_coeff_inner.shape or divergence_term.shape != p_inner.shape:
        raise ValueError(f"Shape mismatch for Helmholtz term: Div={divergence_term.shape}, Coeff={helmholtz_coeff_inner.shape}, P={p_inner.shape}")

    helmholtz_term_result = helmholtz_coeff_inner * p_inner

    # Combine
    lap_result = divergence_term + helmholtz_term_result # Shape (nx, ny)

    # Preconditioning/Scaling (diag_inv) is omitted here
    return lap_result

# ==========================================================================
# 2. lap2D_gather Function (Copied from user prompt)
# ==========================================================================

@nb.jit(nopython=True, nogil=False, cache=True)
def lap2D_gather(p, igx, igy, iicxn, iicyn, hplusx, hplusy, hcenter, oodx, oody, x_periodic, y_periodic, x_wall, y_wall,
                 diag_inv, coriolis):
    ngnc = (iicxn) * (iicyn)
    lap = np.zeros((ngnc))
    cnt_x = 0
    cnt_y = 0

    nine_pt = 0.25 * (2.0) * 1.0
    cyy, cxx, cyx, cxy = coriolis
    oodx2 = 0.5 * oodx ** 2
    oody2 = 0.5 * oody ** 2

    for idx in range(iicxn * iicyn):
        ne_topleft = idx - iicxn - 1
        ne_topright = idx - iicxn
        ne_botleft = idx - 1
        ne_botright = idx

        # get indices of the 9pt stencil
        topleft_idx = idx - iicxn - 1
        midleft_idx = idx - 1
        botleft_idx = idx + iicxn - 1

        topmid_idx = idx - iicxn
        midmid_idx = idx
        botmid_idx = idx + iicxn

        topright_idx = idx - iicxn + 1
        midright_idx = idx + 1
        botright_idx = idx + iicxn + 1

        if cnt_x == 0:
            topleft_idx += iicxn - 1
            midleft_idx += iicxn - 1
            botleft_idx += iicxn - 1

            ne_topleft += iicxn - 1
            ne_botleft += iicxn - 1

        if cnt_x == (iicxn - 1):
            topright_idx -= iicxn - 1
            midright_idx -= iicxn - 1
            botright_idx -= iicxn - 1

            ne_topright -= iicxn - 1
            ne_botright -= iicxn - 1

        if cnt_y == 0:
            topleft_idx += ((iicxn) * (iicyn - 1))
            topmid_idx += ((iicxn) * (iicyn - 1))
            topright_idx += ((iicxn) * (iicyn - 1))

            ne_topleft += ((iicxn) * (iicyn - 1))
            ne_topright += ((iicxn) * (iicyn - 1))

        if cnt_y == (iicyn - 1):
            botleft_idx -= ((iicxn) * (iicyn - 1))
            botmid_idx -= ((iicxn) * (iicyn - 1))
            botright_idx -= ((iicxn) * (iicyn - 1))

            ne_botleft -= ((iicxn) * (iicyn - 1))
            ne_botright -= ((iicxn) * (iicyn - 1))

        topleft = p[topleft_idx]
        midleft = p[midleft_idx]
        botleft = p[botleft_idx]

        topmid = p[topmid_idx]
        midmid = p[midmid_idx]
        botmid = p[botmid_idx]

        topright = p[topright_idx]
        midright = p[midright_idx]
        botright = p[botright_idx]

        hplusx_topleft = hplusx[ne_topleft]
        hplusx_botleft = hplusx[ne_botleft]
        hplusy_topleft = hplusy[ne_topleft]
        hplusy_botleft = hplusy[ne_botleft]

        hplusx_topright = hplusx[ne_topright]
        hplusx_botright = hplusx[ne_botright]
        hplusy_topright = hplusy[ne_topright]
        hplusy_botright = hplusy[ne_botright]

        cxx_tl = cxx[ne_topleft]
        cxx_tr = cxx[ne_topright]
        cxx_bl = cxx[ne_botleft]
        cxx_br = cxx[ne_botright]

        cxy_tl = cxy[ne_topleft]
        cxy_tr = cxy[ne_topright]
        cxy_bl = cxy[ne_botleft]
        cxy_br = cxy[ne_botright]

        cyx_tl = cyx[ne_topleft]
        cyx_tr = cyx[ne_topright]
        cyx_bl = cyx[ne_botleft]
        cyx_br = cyx[ne_botright]

        cyy_tl = cyy[ne_topleft]
        cyy_tr = cyy[ne_topright]
        cyy_bl = cyy[ne_botleft]
        cyy_br = cyy[ne_botright]

        if x_wall and (cnt_x == 0):
            hplusx_topleft = 0.
            hplusy_topleft = 0.
            hplusx_botleft = 0.
            hplusy_botleft = 0.

        if x_wall and (cnt_x == (iicxn - 1)):
            hplusx_topright = 0.
            hplusy_topright = 0.
            hplusx_botright = 0.
            hplusy_botright = 0.

        if y_wall and (cnt_y == 0):
            hplusx_topleft = 0.
            hplusy_topleft = 0.
            hplusx_topright = 0.
            hplusy_topright = 0.

        if y_wall and (cnt_y == (iicyn - 1)):
            hplusx_botleft = 0.
            hplusy_botleft = 0.
            hplusx_botright = 0.
            hplusy_botright = 0.

        Dx_tl = 0.5 * (topmid - topleft + midmid - midleft) * hplusx_topleft
        Dx_tr = 0.5 * (topright - topmid + midright - midmid) * hplusx_topright
        Dx_bl = 0.5 * (botmid - botleft + midmid - midleft) * hplusx_botleft
        Dx_br = 0.5 * (botright - botmid + midright - midmid) * hplusx_botright

        Dy_tl = 0.5 * (midmid - topmid + midleft - topleft) * hplusy_topleft
        Dy_tr = 0.5 * (midright - topright + midmid - topmid) * hplusy_topright
        Dy_bl = 0.5 * (botmid - midmid + botleft - midleft) * hplusy_botleft
        Dy_br = 0.5 * (botright - midright + botmid - midmid) * hplusy_botright

        fac = 1.0
        Dxx = 0.5 * (cxx_tr * Dx_tr - cxx_tl * Dx_tl + cxx_br * Dx_br - cxx_bl * Dx_bl) * oodx * oodx * fac
        Dyy = 0.5 * (cyy_br * Dy_br - cyy_tr * Dy_tr + cyy_bl * Dy_bl - cyy_tl * Dy_tl) * oody * oody * fac
        Dyx = 0.5 * (cyx_br * Dy_br - cyx_bl * Dy_bl + cyx_tr * Dy_tr - cyx_tl * Dy_tl) * oody * oodx * fac
        Dxy = 0.5 * (cxy_br * Dx_br - cxy_tr * Dx_tr + cxy_bl * Dy_bl - cxy_tl * Dx_tl) * oodx * oody * fac

        lap[idx] = Dxx + Dyy + Dyx + Dxy + hcenter[idx] * p[idx]

        lap[idx] *= diag_inv[idx]

        cnt_x += 1
        if cnt_x % iicxn == 0:
            cnt_y += 1
            cnt_x = 0

    return lap

# ==========================================================================
# 3. Comparison Test Setup
# ==========================================================================

print("Setting up comparison test...")

# Grid parameters
nx = 20  # Inner cells x
ny = 15  # Inner cells y
igx = 2  # Ghost cells x (assuming 1 layer)
igy = 2  # Ghost cells y
dx = 0.1
dy = 0.12
dz = 1.0 # For 2D scaling consistency if needed
oodx = 1.0 / dx
oody = 1.0 / dy

# Physics/Solver parameters
dt = 0.01
coriolis_params = (0.0, 0.0, 0.0) # f_x, f_y, f_z -> Set to zero for simpler test

# Boundary Conditions
is_x_periodic = True
is_y_periodic = True
is_x_wall = not is_x_periodic
is_y_wall = not is_y_periodic

# --- Generate Consistent Data ---

# Nodal Pressure (including ghost/boundary nodes)
# p_nodal = np.random.rand(nx +1 + 2*igx, ny+1 + 2*igy)
p_nodal = np.random.rand(nx  + 2*igx, ny + 2*igy)

# Cell-centered coefficient 'pTheta' (used for hplusx, hplusy)
pTheta_cell = calculate_pTheta_placeholder(nx, ny)

# Nodal Helmholtz coefficient 'helmholtz_coeff_nodal' (used for hcenter)
helmholtz_coeff_nodal = calculate_helmholtz_coeff_placeholder(nx, ny) # Shape (nx+1, ny+1)

# Boundary factors Xbot, Xtop (placeholder, size nx)
Xbot = np.ones(nx)
Xtop = np.ones(nx)
if is_y_wall: # Simple mimic of C code logic
    Xbot *= 0.5 # Example scaling factor
    Xtop *= 0.5

# --- Prepare Inputs for New Laplacian Function ---
# (Uses data in its natural grid locations)

# --- Prepare Inputs for lap2D_gather ---
iicxn = nx
iicyn = ny
p_nodal_flat = p_nodal.flatten() # Flatten the full nodal array

# lap2D_gather coefficients (need to be flattened inner arrays, size nx*ny)
# Use pTheta for hplusx, hplusy
hplusx_flat = pTheta_cell.flatten()
hplusy_flat = pTheta_cell.flatten()

# Use inner part of helmholtz_coeff for hcenter
hcenter_flat = helmholtz_coeff_nodal[igy:ny+igy, igx:nx+igx].flatten()

# Placeholder for diag_inv (set to 1.0 for no effect)
diag_inv_flat = np.ones(nx * ny)

# Coriolis components for lap2D_gather (flattened, size nx*ny)
# Order: cyy, cxx, cyx, cxy. Set to zero consistent with coriolis_params=(0,0,0)
cyy_flat = np.zeros(nx * ny)
cxx_flat = np.zeros(nx * ny)
cyx_flat = np.zeros(nx * ny)
cxy_flat = np.zeros(nx * ny)
coriolis_gather_tuple = (cyy_flat, cxx_flat, cyx_flat, cxy_flat)

# ==========================================================================
# 4. Run Calculations
# ==========================================================================

print("\nRunning New Laplacian Method...")
result_new = calculate_new_laplacian_action(
    p_nodal,
    pTheta_cell,
    helmholtz_coeff_nodal,
    dx, dy, dz, dt, nx, ny,
    coriolis_params,
    Xbot, Xtop,
    is_x_periodic, is_y_periodic
)
print("Done.")

print("\nRunning lap2D_gather Method...")
# Ensure lap2D_gather uses consistent parameters
try:
    result_gather_flat = lap2D_gather(
        p_nodal_flat, # Pass full flattened nodal p
        igx, igy,
        iicxn, iicyn,
        hplusx_flat, hplusy_flat, hcenter_flat,
        oodx, oody,
        is_x_periodic, is_y_periodic, is_x_wall, is_y_wall,
        diag_inv_flat,
        coriolis_gather_tuple
    )
    result_gather = result_gather_flat.reshape((ny, nx)) # Reshape to (ny, nx)
    print("Done.")

    # ==========================================================================
    # 5. Compare Results
    # ==========================================================================
    print("\nComparing results (New Laplacian vs lap2D_gather):")

    # Ensure shapes match (result_new should be (nx, ny) from divergence)
    if result_new.shape == (nx, ny):
         result_new_compare = result_new.T # Transpose to match (ny, nx)
    elif result_new.shape == (ny, nx):
         result_new_compare = result_new
    else:
         print(f"ERROR: Unexpected shape for result_new: {result_new.shape}")
         result_new_compare = None


    inslice = tuple([slice(4,-4)]*2)

    if result_new_compare is not None and result_gather.shape == result_new_compare.shape:
        difference = np.abs(result_new_compare[inslice] - result_gather[inslice])
        max_abs_diff = np.max(difference)
        mean_abs_diff = np.mean(difference)
        max_rel_diff = np.max(difference / (np.abs(result_gather[inslice]) + 1e-15)) # Avoid division by zero

        print(f"Max: {np.max(result_new_compare):.6e}")
        print(f"Min:  {np.min(result_new_compare):.6e}")
        print(f"Mean: {np.mean(result_new_compare):.6e}")

        print(f"Shapes: New={result_new_compare.shape}, Gather={result_gather.shape}")
        print(f"Maximum absolute difference: {max_abs_diff:.6e}")
        print(f"Mean absolute difference:  {mean_abs_diff:.6e}")
        print(f"Maximum relative difference: {max_rel_diff:.6e}")

        # Adjust tolerance as needed
        if np.allclose(result_new_compare, result_gather, rtol=1e-6, atol=1e-8):
            print("\nResults are numerically close!")
        else:
            print("\nResults differ significantly.")
            # Optional: Print indices where difference is large
            # max_diff_idx = np.unravel_index(np.argmax(difference), difference.shape)
            # print(f"Max difference at index {max_diff_idx}:")
            # print(f"  New Laplacian: {result_new_compare[max_diff_idx]}")
            # print(f"  lap2D_gather: {result_gather[max_diff_idx]}")
    else:
        print("ERROR: Shapes do not match for comparison.")
        print(f"Shapes: New (transposed)={result_new_compare.shape if result_new_compare is not None else 'N/A'}, Gather={result_gather.shape}")

except Exception as e:
    print("\nAn error occurred during lap2D_gather execution or comparison:")
    print(e)
    import traceback
    traceback.print_exc()

Setting up comparison test...
INFO: Using placeholder for pTheta coefficient.
INFO: Using placeholder for Helmholtz coefficient.

Running New Laplacian Method...
INFO: Using placeholder (identity) for Coriolis transform.
Done.

Running lap2D_gather Method...
Done.

Comparing results (New Laplacian vs lap2D_gather):
Max: 7.113838e-01
Min:  -7.036523e-01
Mean: 2.559678e-02
Shapes: New=(15, 20), Gather=(15, 20)
Maximum absolute difference: 3.215502e+228
Mean absolute difference:  3.827979e+226
Maximum relative difference: 1.871631e+14

Results differ significantly.


In [6]:
import numpy as np
from typing import Tuple

def gradient_2d_numpy(p: np.ndarray, dx: float, dy: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate the discrete gradient of a 2D scalar field using NumPy slicing.

    The gradient is calculated at cell centers based on nodal values of p.
     mimics the calculation specified in eq. (30a) in BK19 paper for 2D.

    Dpx[i, j] = (p[i+1, j] + p[i+1, j+1] - p[i, j] - p[i, j+1]) * 0.5 / dx
    Dpy[i, j] = (p[i, j+1] + p[i+1, j+1] - p[i, j] - p[i+1, j]) * 0.5 / dy

    Parameters
    ----------
    p : np.ndarray of shape (nx+1, ny+1)
        The nodal scalar field.
    dx : float
        Grid spacing in the x-direction.
    dy : float
        Grid spacing in the y-direction.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The gradient components (Dpx, Dpy), each of shape (nx, ny).
        The gradient is defined on cells.
    """
    if p.ndim != 2:
        raise ValueError(f"Input array p must be 2D, but got ndim={p.ndim}")

    # Calculate Dpx using array slicing
    # Dpx represents the average gradient across the cell in x-direction
    # It averages (p[i+1, j] - p[i, j])/dx and (p[i+1, j+1] - p[i, j+1])/dx
    term1_px = p[1:, :-1]  # p[i+1, j]
    term2_px = p[1:, 1:]   # p[i+1, j+1]
    term3_px = p[:-1, :-1] # p[i, j]
    term4_px = p[:-1, 1:]  # p[i, j+1]
    Dpx = (term1_px + term2_px - term3_px - term4_px) * (0.5 / dx)

    # Calculate Dpy using array slicing
    # Dpy represents the average gradient across the cell in y-direction
    # It averages (p[i, j+1] - p[i, j])/dy and (p[i+1, j+1] - p[i+1, j])/dy
    term1_py = p[:-1, 1:]  # p[i, j+1]
    term2_py = p[1:, 1:]   # p[i+1, j+1]
    term3_py = p[:-1, :-1] # p[i, j]
    term4_py = p[1:, :-1]  # p[i+1, j]
    Dpy = (term1_py + term2_py - term3_py - term4_py) * (0.5 / dy)

    return Dpx, Dpy

In [10]:
import numpy as np
import numba
from typing import Tuple

@numba.njit(cache=True, fastmath=True, nogil=True)
def gradient_2d_numba(p: np.ndarray, dx: float, dy: float) -> Tuple[np.ndarray, np.ndarray]:
    """Numba kernel for 2D gradient (cell-centered)."""
    nx = p.shape[0] - 1
    ny = p.shape[1] - 1
    if nx < 0 or ny < 0:
         return (np.empty((max(0,nx), max(0,ny)), dtype=p.dtype),
                 np.empty((max(0,nx), max(0,ny)), dtype=p.dtype))
    if nx == 0 or ny == 0:
         return (np.empty((nx, ny), dtype=p.dtype),
                 np.empty((nx, ny), dtype=p.dtype))


    Dpx = np.empty((nx, ny), dtype=p.dtype)
    Dpy = np.empty((nx, ny), dtype=p.dtype)

    inv_dx_half = 0.5 / dx
    inv_dy_half = 0.5 / dy

    for i in range(nx):
        for j in range(ny):
            p_i_j     = p[i, j]
            p_ip1_j   = p[i+1, j]
            p_i_jp1   = p[i, j+1]
            p_ip1_jp1 = p[i+1, j+1]

            # Dpx = avg gradient in x across the cell (i,j)
            Dpx[i, j] = (p_ip1_j + p_ip1_jp1 - p_i_j - p_i_jp1) * inv_dx_half
            # Dpy = avg gradient in y across the cell (i,j)
            Dpy[i, j] = (p_i_jp1 + p_ip1_jp1 - p_i_j - p_ip1_j) * inv_dy_half

    return Dpx, Dpy

In [12]:
import numpy as np
import numba
import timeit
import functools # To use partial for timeit setup

# --- Include the function definitions from above ---
# def gradient_2d_numpy(p: np.ndarray, dx: float, dy: float) -> Tuple[np.ndarray, np.ndarray]: ...
# @numba.njit(cache=True, fastmath=True)
# def gradient_2d_numba(p: np.ndarray, dx: float, dy: float) -> Tuple[np.ndarray, np.ndarray]: ...
# --- Assume they are defined here ---

# --- Setup Parameters ---
nx, ny = 500, 500  # Grid size (number of cells)
dx, dy = 0.01, 0.01 # Grid spacing

# Create sample nodal data (shape nx+1, ny+1)
p_nodal = np.random.rand(nx + 1, ny + 1)

print(f"Benchmarking on a grid of size: ({nx} x {ny}) cells")
print(f"Input array shape: {p_nodal.shape}")
print("-" * 30)

# --- Correctness Check (Important!) ---
Dpx_np, Dpy_np = gradient_2d_numpy(p_nodal, dx, dy)
# Run Numba once for compilation (warm-up)
Dpx_nb, Dpy_nb = gradient_2d_numba(p_nodal, dx, dy)

try:
    np.testing.assert_allclose(Dpx_np, Dpx_nb, rtol=1e-10, atol=1e-12)
    np.testing.assert_allclose(Dpy_np, Dpy_nb, rtol=1e-10, atol=1e-12)
    print("Correctness check PASSED: NumPy and Numba results are close.")
except AssertionError as e:
    print(f"Correctness check FAILED: {e}")
    # Decide whether to proceed with benchmarking if results differ
    # exit() # Optional: stop if results don't match

print("-" * 30)

# --- Benchmarking ---
n_runs = 100 # Number of times to execute the function within one timeit measurement
n_repeat = 7 # Number of times to repeat the measurement

# Use functools.partial to pass arguments to the functions being timed
numpy_timer = timeit.Timer(functools.partial(gradient_2d_numpy, p_nodal, dx, dy))
numba_timer = timeit.Timer(functools.partial(gradient_2d_numba, p_nodal, dx, dy))

# Time the NumPy version
try:
    numpy_times = numpy_timer.repeat(repeat=n_repeat, number=n_runs)
    numpy_avg_time = np.mean(numpy_times) / n_runs
    numpy_std_time = np.std(numpy_times) / n_runs
    print(f"NumPy Version:")
    print(f"  Average time: {numpy_avg_time:.6f} seconds")
    print(f"  Std dev:      {numpy_std_time:.6f} seconds")
except Exception as e:
    print(f"Could not benchmark NumPy version: {e}")


# Time the Numba version (already compiled during warm-up)
try:
    numba_times = numba_timer.repeat(repeat=n_repeat, number=n_runs)
    numba_avg_time = np.mean(numba_times) / n_runs
    numba_std_time = np.std(numba_times) / n_runs
    print(f"Numba Version:")
    print(f"  Average time: {numba_avg_time:.6f} seconds")
    print(f"  Std dev:      {numba_std_time:.6f} seconds")
except Exception as e:
    print(f"Could not benchmark Numba version: {e}")


# --- Comparison ---
print("-" * 30)
if 'numpy_avg_time' in locals() and 'numba_avg_time' in locals():
    speedup = numpy_avg_time / numba_avg_time
    print(f"Numba speedup over NumPy: {speedup:.2f}x")
elif 'numba_avg_time' not in locals():
     print("Cannot calculate speedup because Numba benchmark failed.")
else:
     print("Cannot calculate speedup because NumPy benchmark failed.")

Benchmarking on a grid of size: (500 x 500) cells
Input array shape: (501, 501)
------------------------------
Correctness check PASSED: NumPy and Numba results are close.
------------------------------
NumPy Version:
  Average time: 0.000969 seconds
  Std dev:      0.000035 seconds
Numba Version:
  Average time: 0.000098 seconds
  Std dev:      0.000002 seconds
------------------------------
Numba speedup over NumPy: 9.88x


In [13]:
import numpy as np
from typing import List, Tuple

def derivative_np(variable: np.ndarray, axis: int, ndim: int, dxyz: List[float]):
    """Standalone NumPy version of the derivative function.

    Calculates the nodal derivative of the given cell variable in the given direction.
    Logic matches the original class method.

    Parameters
    ----------
    variable : np.ndarray
        The cell-centered values. Shape (nx, [ny], [nz]).
    axis : int
        Axis along which to differentiate (0, 1, or 2).
    ndim : int
        Number of dimensions (1, 2, or 3).
    dxyz : List[float]
        List of grid spacings [dx, dy, dz].

    Returns
    -------
    np.ndarray
        Nodal derivative. Shape (nx-1, [ny-1], [nz-1]).
    """
    if axis >= ndim:
        raise ValueError(f"Axis {axis} is out of bounds for ndim {ndim}")
    if len(dxyz) < ndim:
        raise ValueError(f"dxyz must have at least {ndim} elements for ndim={ndim}")

    # Discretization fineness
    ds = dxyz[axis]
    if ds == 0:
        # Handle potential division by zero if dx is zero
        # Returning zeros of the expected shape might be one approach
        print("Warning: ds is zero in derivative_np. Returning zeros.")
        out_shape = tuple(s - 1 for s in variable.shape)
        return np.zeros(out_shape, dtype=variable.dtype)


    # Bring the differentiation axis to the front.
    u_flat = np.moveaxis(variable, axis, 0)

    # Compute the primary difference along the first axis.
    # Use np.diff which computes x[1:] - x[:-1]
    d = np.diff(u_flat, axis=0) / ds

    # Based on the dimensionality, average over the complementary axes.
    if ndim == 1:
        # 1D case: no additional averaging required. d shape (nx-1,)
        result = d
    elif ndim == 2:
        # 2D case: average along the second axis (which was not axis 0).
        # d has shape (n-1, m) if axis=0, or (m-1, n) if axis=1.
        # Average is always along the second dimension of d.
        if d.shape[1] < 2: # Check if averaging is possible
             result = np.empty(d.shape[:-1] + (0,), dtype=d.dtype) # Resulting shape has 0 size in the averaged dim
        else:
             result = 0.5 * (d[:, :-1] + d[:, 1:]) # Result shape (n-1, m-1)
    elif ndim == 3:
        # 3D case: average along both the second and third axes.
        # d has shape (n-1, m, p) or similar permutations.
        # Average is always along the second and third dimensions of d.
        if d.shape[1] < 2 or d.shape[2] < 2: # Check if averaging is possible
             out_shape_list = list(d.shape)
             if d.shape[1] < 2: out_shape_list[1] = 0
             else: out_shape_list[1] -= 1
             if d.shape[2] < 2: out_shape_list[2] = 0
             else: out_shape_list[2] -= 1
             result = np.empty(tuple(out_shape_list), dtype=d.dtype)
        else:
            result = (
                d[:, :-1, :-1] + d[:, :-1, 1:] + d[:, 1:, :-1] + d[:, 1:, 1:]
            ) / 4.0 # Result shape (n-1, m-1, p-1)
    else:
        raise ValueError("Only 1D, 2D, or 3D arrays are supported.")

    # Move the differentiation axis back to its original location if ndim > 1.
    # If ndim=1, result is 1D, moveaxis isn't needed/meaningful here.
    if result.ndim > 0 and result.ndim == variable.ndim: # Only move if result has dims and matches original rank
         # Need to calculate the correct target axis after averaging reduction
         original_axes = list(range(variable.ndim))
         target_axis = original_axes[axis]
         return np.moveaxis(result, 0, target_axis)
    else:
        # Handle cases where dimensions were reduced (e.g. 2D input, ny=1 -> 1D output)
        # Or the 1D case where moveaxis isn't needed.
        return result


def divergence_np(vector: np.ndarray, ndim: int, dxyz: List[float]) -> np.ndarray:
    """Standalone NumPy version of the divergence function.

    Calculates the divergence of the cell-centered vector using derivative_np.
    Logic matches the original class method.

    Parameters
    ----------
    vector : np.ndarray
        The cell-centered vector field. Shape (nx, [ny], [nz], ndim).
    ndim : int
        Number of dimensions (1, 2, or 3).
    dxyz : List[float]
        List of grid spacings [dx, dy, dz].

    Returns
    -------
    np.ndarray
        Divergence evaluated at the nodes. Shape (nx-1, [ny-1], [nz-1]).
    """
    if vector.shape[-1] != ndim:
        raise ValueError(
            f"Last dim of vector ({vector.shape[-1]}) must match ndim ({ndim})."
        )
    if vector.ndim != ndim + 1:
        raise ValueError(
            f"Vector dimensions ({vector.ndim}) must be ndim+1 ({ndim+1})."
            )

    # Determine output shape: (nx-1, [ny-1], [nz-1])
    # Need to handle cases where input shape might be 1 in some dims
    output_shape = tuple(max(0, s - 1) for s in vector.shape[:-1])

    # Pre-allocation.
    Ux = np.zeros(output_shape, dtype=vector.dtype)

    # Calculate divergence by summing derivatives
    for axis in range(ndim):
        component = vector[..., axis]
        # Ensure component has the correct dimensions before passing
        if component.ndim != ndim:
             raise RuntimeError(f"Internal error: Component slice has unexpected ndim {component.ndim}")
        deriv_component = derivative_np(component, axis, ndim, dxyz)

        # Check if shapes match before adding (can happen if input dims were 1)
        if deriv_component.shape == Ux.shape:
             Ux += deriv_component
        elif Ux.size == 0 and deriv_component.size == 0:
             pass # Adding zero to zero is fine
        else:
            # This case indicates an issue, likely due to input dimensions being 1
            # which causes derivative_np to return a shape incompatible with the pre-allocated Ux
             raise RuntimeError(f"Shape mismatch adding derivative component. Ux shape: {Ux.shape}, deriv_component shape: {deriv_component.shape}. Check input dimensions.")


    return Ux

In [14]:
import numba

# Copied from previous answer for completeness
@numba.njit(cache=True, fastmath=True, nogil=True)
def _divergence_node_centered_2d_numba(
    vector_field: np.ndarray, # Shape (nx, ny, 2) - Cell-centered
    dx: float,
    dy: float
) -> np.ndarray:
    """
    Calculates 2D divergence at nodes based on cell-centered vector components.
    Output shape: (nx-1, ny-1) - Node-centered
    """
    nx, ny = vector_field.shape[0], vector_field.shape[1]
    nnx = max(0, nx - 1)
    nny = max(0, ny - 1)
    dtype = vector_field.dtype

    if nnx == 0 or nny == 0:
        return np.empty((nnx, nny), dtype=dtype)

    div_result = np.empty((nnx, nny), dtype=dtype)
    inv_dx = 1.0 / dx if dx != 0 else 0.0
    inv_dy = 1.0 / dy if dy != 0 else 0.0
    scale_xy = 0.5 # Averaging factor

    u = vector_field[:, :, 0] # Component along axis 0
    v = vector_field[:, :, 1] # Component along axis 1

    for i in range(nnx): # Node index i
        for j in range(nny): # Node index j
            # Term d(u)/dx: diff along axis 0, average along axis 1
            diff_u_j   = (u[i+1, j]   - u[i, j])
            diff_u_jp1 = (u[i+1, j+1] - u[i, j+1])
            term_x = scale_xy * (diff_u_j + diff_u_jp1) * inv_dx

            # Term d(v)/dy: diff along axis 1, average along axis 0
            diff_v_i   = (v[i,   j+1] - v[i,   j])
            diff_v_ip1 = (v[i+1, j+1] - v[i+1, j])
            term_y = scale_xy * (diff_v_i + diff_v_ip1) * inv_dy

            div_result[i, j] = term_x + term_y

    return div_result

In [15]:
import timeit
import functools

# --- Setup Parameters (2D) ---
nx, ny = 500, 500  # Grid size (number of cells)
dx, dy = 0.01, 0.02 # Grid spacing
ndim = 2
dxyz = [dx, dy]

# Create sample cell-centered vector data
vector_2d_cells = np.random.rand(nx, ny, ndim) # Shape (nx, ny, 2)

print(f"Benchmarking on a cell grid of size: ({nx} x {ny})")
print(f"Input vector shape: {vector_2d_cells.shape}")
print("-" * 30)

# --- Correctness Check ---
print("Running correctness check...")
try:
    div_np = divergence_np(vector_2d_cells, ndim, dxyz)
    # Run Numba once for compilation (warm-up)
    div_nb = _divergence_node_centered_2d_numba(vector_2d_cells, dx, dy)

    np.testing.assert_allclose(div_np, div_nb, rtol=1e-12, atol=1e-14)
    print("Correctness check PASSED: NumPy and Numba results are close.")
    print(f"Output shape: {div_np.shape}") # Should be (nx-1, ny-1)
except Exception as e:
    print(f"Correctness check FAILED: {e}")
    # Decide whether to proceed
    exit() # Stop if results don't match

print("-" * 30)

# --- Benchmarking ---
n_runs = 50 # Number of times to execute the function within one timeit measurement
n_repeat = 7 # Number of times to repeat the measurement

# Use functools.partial to pass arguments
numpy_timer = timeit.Timer(functools.partial(divergence_np, vector_2d_cells, ndim, dxyz))
# Numba function only needs vector, dx, dy
numba_timer = timeit.Timer(functools.partial(_divergence_node_centered_2d_numba, vector_2d_cells, dx, dy))

# Time the NumPy version
try:
    numpy_times = numpy_timer.repeat(repeat=n_repeat, number=n_runs)
    numpy_avg_time = np.mean(numpy_times) / n_runs
    numpy_std_time = np.std(numpy_times) / n_runs
    print(f"NumPy Standalone Version:")
    print(f"  Average time: {numpy_avg_time:.6f} seconds")
    print(f"  Std dev:      {numpy_std_time:.6f} seconds")
except Exception as e:
    print(f"Could not benchmark NumPy version: {e}")


# Time the Numba version (already compiled during warm-up)
try:
    numba_times = numba_timer.repeat(repeat=n_repeat, number=n_runs)
    numba_avg_time = np.mean(numba_times) / n_runs
    numba_std_time = np.std(numba_times) / n_runs
    print(f"Numba 2D Kernel Version:")
    print(f"  Average time: {numba_avg_time:.6f} seconds")
    print(f"  Std dev:      {numba_std_time:.6f} seconds")
except Exception as e:
    print(f"Could not benchmark Numba version: {e}")


# --- Comparison ---
print("-" * 30)
if 'numpy_avg_time' in locals() and 'numba_avg_time' in locals():
    speedup = numpy_avg_time / numba_avg_time
    print(f"Numba speedup over NumPy: {speedup:.2f}x")
elif 'numba_avg_time' not in locals():
     print("Cannot calculate speedup because Numba benchmark failed.")
else:
     print("Cannot calculate speedup because NumPy benchmark failed.")

Benchmarking on a cell grid of size: (500 x 500)
Input vector shape: (500, 500, 2)
------------------------------
Running correctness check...
Correctness check PASSED: NumPy and Numba results are close.
Output shape: (499, 499)
------------------------------
NumPy Standalone Version:
  Average time: 0.001641 seconds
  Std dev:      0.000348 seconds
Numba 2D Kernel Version:
  Average time: 0.000728 seconds
  Std dev:      0.000110 seconds
------------------------------
Numba speedup over NumPy: 2.26x
