In [None]:
"""
    uh         = u (m/s) at scalar points
    vh         = v (m/s) at scalar points
    mh         =
    rmh        =
    mf         =
    rmf        =
    pi0        = basic-state array
    thv0       = basic-state array
    rho0       = basic-state array (density)
    rf0        = basic-state array
    dum1       =
    dum2       =
    dum3       =
    def_array  =
    divx       =
    ppi        = perturbation nondimensional pressure ("Exner function")
    uten       =
    vten       =
    wten       =
    cfb        =
    cfa        =
    cfc        =
    d1         =
    d2         =
    pdt        =
    lgbth      =
    lgbph      =
    rhs        =
    trans      =
    dttmp      =
"""

In [None]:
import numpy as np
from scipy.fft import fftn, ifftn

def poiss(uh, vh, mh, rmh, mf, rmf, pi0, thv0, rho0, rf0,
          dum1, dum2, dum3, def_array, divx, ppi, uten, vten, wten,
          cfb, cfa, cfc, d1, d2, pdt, lgbth, lgbph, rhs, trans, dttmp,
          axisymm=False, imirror=False, jmirror=False, timestats=False,
          rdx=1.0, rdy=1.0, rdz=1.0):
    """Solve for the perturbation pressure ``ppi`` with FFTs and a tridiagonal sweep.

    For every vertical level the routine builds a forcing term from the
    tendencies and ``divx``, Fourier-transforms it horizontally, solves a
    tridiagonal system in the vertical for each horizontal mode (coefficients
    ``cfa``/``cfb``/``cfc``), and transforms the result back.  All results are
    written in place; the function returns ``None``.

    Arrays are indexed ``[i, j, k]``; the physical extent ``(ni, nj, nk)`` is
    taken from ``mh.shape``.  The spectral extent ``(nx, ny)`` is taken from
    ``cfb.shape[:2]``; ``rhs``, ``trans``, ``pdt``, ``lgbth`` and ``lgbph``
    must be at least that large horizontally.

    Parameters
    ----------
    uh, vh : 1-D arrays
        Factors multiplying the x- and y-differences of ``uten``/``vten``.
    mh : array (ni, nj, nk)
        Factor multiplying the vertical flux difference.
    rho0 : array
        Only ``rho0[0, 0, k]`` is read.
    rf0, wten : arrays with ``nk + 1`` levels
        Their product is differenced in the vertical.
    uten, vten : arrays
        Need ``ni + 1`` rows / ``nj + 1`` columns respectively.
    def_array : array, output
        Receives the forcing term at every physical point.
    divx : array
        Added to the forcing after division by ``dttmp``.
    ppi : array, output
        Receives ``real(ifftn(pdt[:, :, k])) * d2[k]`` on the physical points.
    cfa, cfc : 1-D arrays
        Sub- and super-diagonal coefficients per level.
    cfb : array (nx, ny, nk)
        Diagonal coefficients per horizontal mode and level.
    d1, d2 : 1-D arrays
        Per-level scale factors applied before the forward and after the
        inverse transform.
    pdt, lgbth, lgbph : complex arrays, output/work
        Spectral solution and forward-elimination work arrays.
    rhs, trans : complex 2-D work arrays
        Input and output buffers of the transforms.
    dttmp : float
        Divisor applied to ``divx``.
    rmh, mf, rmf, pi0, thv0, dum1, dum2, dum3
        Unused; accepted so the call signature stays unchanged.
    axisymm : bool
        If true, a ``ValueError`` is raised.
    imirror, jmirror : bool
        Copy a reflected image of the physical forcing into the far end of
        ``rhs`` along x and/or y before transforming.  The work arrays must
        then be at least twice the physical size in that direction.
    timestats : bool
        If true, prints a notice that timing is not implemented.
    rdx, rdy, rdz : float
        Multipliers applied to the x-, y- and z-differences (default 1.0,
        matching the previously hard-coded values).

    Raises
    ------
    ValueError
        If ``axisymm`` is true, or if mirroring is requested and the work
        arrays are too small to hold the reflected copy.
    """
    if axisymm:
        raise ValueError("The anelastic/incompressible solver cannot be used with the axisymmetric model.")

    # Every array below is indexed [i, j, k], so the shape unpacks in that order.
    ni, nj, nk = mh.shape
    nx, ny = cfb.shape[0], cfb.shape[1]

    if imirror and nx < 2 * ni:
        raise ValueError("imirror requires work arrays with at least 2*ni rows.")
    if jmirror and ny < 2 * nj:
        raise ValueError("jmirror requires work arrays with at least 2*nj columns.")

    # Kept in its own name: it must stay 1/dttmp for every level and not be
    # reused as scratch by the tridiagonal sweep.
    inv_dt = 1.0 / dttmp
    uh_col = np.asarray(uh)[:ni, None]
    vh_row = np.asarray(vh)[None, :nj]

    r1 = np.zeros(nk)
    deft2 = np.zeros(nk)

    for k in range(nk):
        # Forcing term on the physical points of this level.
        def_array[:ni, :nj, k] = (
            rho0[0, 0, k] * (
                (uten[1:ni + 1, :nj, k] - uten[:ni, :nj, k]) * rdx * uh_col
                + (vten[:ni, 1:nj + 1, k] - vten[:ni, :nj, k]) * rdy * vh_row
            )
            + (rf0[:ni, :nj, k + 1] * wten[:ni, :nj, k + 1]
               - rf0[:ni, :nj, k] * wten[:ni, :nj, k]) * rdz * mh[:, :, k]
            + divx[:ni, :nj, k] * inv_dt
        )
        rhs[:ni, :nj] = def_array[:ni, :nj, k] * d1[k]

        # Reflect the physical block into the far end of the work array
        # (target index nx-1-i / ny-1-j) without touching the source block.
        if imirror:
            rhs[nx - ni:nx, :nj] = rhs[ni - 1::-1, :nj]
        if jmirror:
            rhs[:ni, ny - nj:ny] = rhs[:ni, nj - 1::-1]
        if imirror and jmirror:
            rhs[nx - ni:nx, ny - nj:ny] = rhs[ni - 1::-1, nj - 1::-1]

        trans[:, :] = fftn(rhs)

        # Forward elimination of the tridiagonal system, all modes at once.
        if k == 0:
            inv_diag = 1.0 / cfb[:, :, 0]
            lgbth[:nx, :ny, 0] = -cfc[0] * inv_diag
            lgbph[:nx, :ny, 0] = trans[:nx, :ny] * inv_diag
        else:
            inv_diag = 1.0 / (cfa[k] * lgbth[:nx, :ny, k - 1] + cfb[:, :, k])
            lgbth[:nx, :ny, k] = -cfc[k] * inv_diag
            lgbph[:nx, :ny, k] = (
                trans[:nx, :ny] - cfa[k] * lgbph[:nx, :ny, k - 1]
            ) * inv_diag

        if k == nk - 1:
            # Last row: the solution is the eliminated right-hand side.
            pdt[:nx, :ny, k] = lgbph[:nx, :ny, k]

        deft2[k] = np.real(trans[0, 0])

    # Back substitution.
    for k in range(nk - 2, -1, -1):
        pdt[:nx, :ny, k] = (
            lgbth[:nx, :ny, k] * pdt[:nx, :ny, k + 1] + lgbph[:nx, :ny, k]
        )

    # The (0, 0) mode is replaced by a forward recursion started from
    # r1[0] = 0; level 0 must carry that starting value as well.
    pdt[0, 0, 0] = complex(r1[0], 0.0)
    if nk > 1:
        r1[1] = (deft2[0] - cfb[0, 0, 0] * r1[0]) / cfc[0]
        pdt[0, 0, 1] = complex(r1[1], 0.0)
    for k in range(1, nk - 1):
        r1[k + 1] = (deft2[k] - cfa[k] * r1[k - 1] - cfb[0, 0, k] * r1[k]) / cfc[k]
        pdt[0, 0, k + 1] = complex(r1[k + 1], 0.0)

    # Back to physical space, level by level.
    for k in range(nk):
        rhs[:nx, :ny] = pdt[:nx, :ny, k]
        trans[:, :] = ifftn(rhs)
        ppi[:ni, :nj, k] = np.real(trans[:ni, :nj]) * d2[k]

    if timestats:
        print("Timing information not implemented.")
