In [1]:
# Minimal Berreman 4x4 solver (NumPy only)
# - Non-magnetic (mu_r = I), no bianisotropy (xi=zeta=0)
# - Incident/substrate half-spaces isotropic
# - Internal layers: fully anisotropic with arbitrary rotations
# - All angles in radians, wavelengths in meters

from __future__ import annotations
import numpy as np
from dataclasses import dataclass
from typing import Callable, Optional, List, Tuple
import numpy.testing as npt  # <-- needed for assert_allclose
# Physical constants (SI)
C0   = 299_792_458.0
MU0  = 4e-7*np.pi
EPS0 = 1.0/(MU0*C0*C0)

Tensor = np.ndarray
TensorFn = Callable[[float], Tensor]  # fn(wavelength)->3x3 complex


# ---------- S-matrix primitives (stable composition) ----------

def impedance_balance(F):
    Z0 = 376.730313668  # Ohms
    Fb = F.copy()
    Fb[2:,:] *= Z0      # scale Hx,Hy rows
    # normalize each column by its max magnitude to compare components fairly
    for j in range(Fb.shape[1]):
        m = np.max(np.abs(Fb[:,j]))
        if m > 0: Fb[:,j] /= m
    return Fb


def _split_forward_backward(vals, vecs):
    """
    Split Berreman eigenpairs into forward/backward sets based on the forward branch rule.
    Returns indices (ifwd, ibwd) and reorders vecs accordingly.
    """
    # forward branch: Im(kz)>=0; if ~real then Re(kz)>=0
    is_fwd = []
    for kz_over_k0 in vals:
        kz = kz_over_k0  # eigenvalues(A) are kz/k0
        fwd = True
        if np.imag(kz) < 0:
            fwd = False
        elif abs(np.imag(kz)) < 1e-14 and np.real(kz) < 0:
            fwd = False
        is_fwd.append(fwd)
    ifwd = [i for i, f in enumerate(is_fwd) if f]
    ibwd = [i for i, f in enumerate(is_fwd) if not f]
    # Expect 2 fwd + 2 bwd
    assert len(ifwd) == 2 and len(ibwd) == 2, "Berreman: not 2 forward/2 backward modes"
    return np.array(ifwd), np.array(ibwd)

def _ani_halfspace_modes(eps: np.ndarray, k0: float, kx: complex, ky: complex):
    """
    General (possibly anisotropic) half-space modal bases from Berreman A.
    Returns kz_eff (unused here), Fp (4x2 forward tangential fields), Fm (4x2 backward).
    """
    A = _A_from_diag_eps(eps, k0, kx, ky)             # 4x4
    vals, vecs = np.linalg.eig(A)                     # A Ψ = (kz/k0) Ψ
    ifwd, ibwd = _split_forward_backward(vals, vecs)  # indices of forward/backward eigenmodes

    # Columns of vecs are Ψ=[Ex,Ey,Hx,Hy]^T already (Berreman state is tangential fields)
    Fp = vecs[:, ifwd]    # 4x2
    Fm = vecs[:, ibwd]    # 4x2
    # For reference, an effective kz can be recovered as kz = k0*vals[ifwd]. We don't need it here.
    kz_eff = k0 * vals[ifwd]
    return kz_eff, Fp, Fm

def _general_halfspace_modes(material_eps, k0, kx, ky, assume_isotropic):
    if assume_isotropic:
        n = np.sqrt(material_eps[0, 0])
        kz_dimless, Fp, Fm = _iso_halfspace_modes(n, k0, kx, ky)
        kz_dimless = np.full(2, kz_dimless, dtype=complex)
        return kz_dimless, Fp, Fm

    # --- Anisotropic branch ---
    A = _A_from_diag_eps(material_eps, k0, kx, ky)
    vals, vecs = np.linalg.eig(A)
    ifwd, ibwd = _split_forward_backward(vals, vecs)
    Fp = vecs[:, ifwd]
    Fm = vecs[:, ibwd]

    if np.abs(Fp[1, 0]) < np.abs(Fp[0, 0]):  # ordinary wave check
        Fp = Fp[:, [1, 0]]
        Fm = Fm[:, [1, 0]]
        vals_ifwd = vals[ifwd][[1, 0]]
    else:
        vals_ifwd = vals[ifwd]

    for i in range(2):
        Sz = 0.5*np.real(Fp[0, i]*np.conj(Fp[3, i]) - Fp[1, i]*np.conj(Fp[2, i]))
        Fp[:, i] /= np.sqrt(abs(Sz)) * (1 if Sz > 0 else -1)
        Sz = 0.5*np.real(Fm[0, i]*np.conj(Fm[3, i]) - Fm[1, i]*np.conj(Fm[2, i]))
        Fm[:, i] /= np.sqrt(abs(Sz)) * (1 if Sz > 0 else -1)

    kz_eff = vals_ifwd
    return kz_eff, Fp, Fm


def _halfspace_modes(material_eps, k0, kx, ky, assume_isotropic):
    """
    Dispatch to isotropic or general half-space mode solver.

    Returns:
        kz_dimless : length-2 complex array of kz/k0
        Fp, Fm     : 4x2 arrays for forward/backward modal fields
    """
    if assume_isotropic:
        n = np.sqrt(material_eps[0, 0])
        return _iso_halfspace_modes(n, k0, kx, ky)
    else:
        return _general_halfspace_modes(material_eps, k0, kx, ky, assume_isotropic=False)




def redheffer_star(S_next, S_prev):
    """
    Redheffer star product: S_prev followed by S_next (left->right).
    S has 2x2 blocks: r, t, rp, tp.
    """
    r1, t1, rp1, tp1 = S_prev['r'], S_prev['t'], S_prev['rp'], S_prev['tp']
    r2, t2, rp2, tp2 = S_next['r'], S_next['t'], S_next['rp'], S_next['tp']
    I = np.eye(r1.shape[0], dtype=complex)

    X  = np.linalg.inv(I - rp1 @ r2)
    Xp = np.linalg.inv(I - r2  @ rp1)

    return {
        't' : t2 @ X @ t1,
        'r' : r1 + tp1 @ r2 @ X @ t1,
        'tp': tp1 @ Xp @ tp2,
        'rp': rp2 + t2 @ rp1 @ Xp @ tp2,
    }

import numpy as np

def smatrix_from_matching(F_left, F_right):
    """
    Construct 4×4 field-basis interface S-matrix given forward/backward
    field bases on each side.
    """
    # Pseudoinverses for stability
    FpL_pinv = np.linalg.pinv(F_left[0])
    FmL_pinv = np.linalg.pinv(F_left[1])
    FpR_pinv = np.linalg.pinv(F_right[0])
    FmR_pinv = np.linalg.pinv(F_right[1])

    # Match forward/backward modes directly in field space
    t_forward  = F_right[0] @ FpL_pinv
    t_backward = F_left[1]  @ FmR_pinv

    Z4 = np.zeros((4, 4), dtype=complex)
    return {'r': Z4, 'rp': Z4, 't': t_forward, 'tp': t_backward}

def _halfspace_modes_old(material_eps: np.ndarray, k0: float, kx: complex, ky: complex, assume_isotropic: bool):
    """
    Enhanced version with robust mode sorting and power normalization.
    """
    if assume_isotropic:
        n = np.sqrt(material_eps[0,0])
        return _iso_halfspace_modes(n, k0, kx, ky)
    
    A = _A_from_diag_eps(material_eps, k0, kx, ky)
    vals, vecs = np.linalg.eig(A)
    
    # Split forward/backward modes
    ifwd, ibwd = _split_forward_backward(vals, vecs)
    Fp = vecs[:, ifwd]
    Fm = vecs[:, ibwd]
    
    # Power normalization function
    def normalize_mode(v):
        # Calculate Sz (z-component of Poynting vector)
        Sz = 0.5 * np.real(v[0]*np.conj(v[3]) - v[1]*np.conj(v[2]))
        # Normalize to unit power (Sz = ±1)
        v_norm = v / np.sqrt(abs(Sz))
        # Ensure forward modes have Sz > 0
        return v_norm if Sz > 0 else -v_norm
    
    # Normalize all modes
    Fp = np.column_stack([normalize_mode(Fp[:,0]), normalize_mode(Fp[:,1])])
    Fm = np.column_stack([normalize_mode(Fm[:,0]), normalize_mode(Fm[:,1])])
    
    # Sort modes by polarization state
    # Ordinary wave typically has Ey dominant for OA in x-direction
    if np.abs(Fp[1,0]) < np.abs(Fp[0,0]):  # If Ex > Ey, swap
        Fp = Fp[:, [1,0]]
        Fm = Fm[:, [1,0]]
        vals_ifwd = vals[ifwd][[1,0]]
    else:
        vals_ifwd = vals[ifwd]
        
    return vals_ifwd, Fp, Fm


def interface_smatrix_eps(epsL: np.ndarray, epsR: np.ndarray, k0: float, kx: complex, ky: complex,
                          assume_iso_L: bool, assume_iso_R: bool):
    """
    Convenience: build S for a single interface given left/right permittivity tensors.
    """
    _, FpL, FmL = _halfspace_modes(epsL, k0, kx, ky, assume_iso_L)
    _, FpR, FmR = _halfspace_modes(epsR, k0, kx, ky, assume_iso_R)
    return smatrix_from_matching(FpL, FmL, FpR, FmR)




def _propagation_smatrix_mode_basis(kz_phys, d):
    """4×4 field-space propagation S-matrix (replaces old 2×2 mode version)."""
    # If no thickness, it's exactly identity in field space
    if abs(d) < 1e-15:
        I4 = np.eye(4, dtype=complex)
        Z4 = np.zeros((4, 4), dtype=complex)
        return {'r': Z4, 'rp': Z4, 't': I4, 'tp': I4}

    P_forw = np.diag(np.exp(1j * kz_phys * d))
    P_back = np.diag(np.exp(-1j * kz_phys * d))

    # Here Fp and Fm will be provided from _halfspace_modes
    raise ValueError("_propagation_smatrix_mode_basis now requires Fp/Fm, use propagation_smatrix_in_medium instead.")


def _propagation_smatrix_field_basis(kz_phys, Fp, Fm, d):
    """4×4 field-space propagation S-matrix."""
    if abs(d) < 1e-15:
        I4 = np.eye(4, dtype=complex)
        Z4 = np.zeros((4, 4), dtype=complex)
        return {'r': Z4, 'rp': Z4, 't': I4, 'tp': I4}

    P_forw = np.diag(np.exp(1j * kz_phys * d))
    P_back = np.diag(np.exp(-1j * kz_phys * d))

    # Fp/Fm are 4×2; use pseudoinverse to map modal→field safely
    Fp_pinv = np.linalg.pinv(Fp)
    Fm_pinv = np.linalg.pinv(Fm)

    t_forward = Fp @ P_forw @ Fp_pinv
    t_backward = Fm @ P_back @ Fm_pinv

    Z4 = np.zeros((4, 4), dtype=complex)
    return {'r': Z4, 'rp': Z4, 't': t_forward, 'tp': t_backward}


def propagation_smatrix_in_medium(eps, k0, kx, ky, d, assume_isotropic):
    """
    4×4 field-basis propagation S-matrix for a uniform layer.
    """
    kz_dimless, Fp, Fm = _halfspace_modes(eps, k0, kx, ky, assume_isotropic)
    kz_phys = kz_dimless * k0

    if abs(d) < 1e-15:
        I4 = np.eye(4, dtype=complex)
        Z4 = np.zeros((4, 4), dtype=complex)
        return {'r': Z4, 'rp': Z4, 't': I4, 'tp': I4}

    P_forw = np.diag(np.exp(1j * kz_phys * d))
    P_back = np.diag(np.exp(-1j * kz_phys * d))

    Fp_pinv = np.linalg.pinv(Fp)
    Fm_pinv = np.linalg.pinv(Fm)

    t_forward  = Fp @ P_forw @ Fp_pinv
    t_backward = Fm @ P_back @ Fm_pinv

    Z4 = np.zeros((4, 4), dtype=complex)
    return {'r': Z4, 'rp': Z4, 't': t_forward, 'tp': t_backward}



    
def redheffer_star(SB, SA):
    """
    Compose two 2-port S-matrices: SA followed by SB (left->right propagation).
    Each S is a dict with keys: 'r', 't', 'rp', 'tp' for 2x2 blocks (s,p) or full 2x2 field space.
    """
    rA, tA, rAp, tAp = SA['r'], SA['t'], SA['rp'], SA['tp']
    rB, tB, rBp, tBp = SB['r'], SB['t'], SB['rp'], SB['tp']

    I = np.eye(rA.shape[0], dtype=complex)
    X  = np.linalg.inv(I - rAp @ rB)
    Xp = np.linalg.inv(I - rB @ rAp)

    S = {}
    S['t']  = tB @ X @ tA
    S['r']  = rA + tAp @ rB @ X @ tA
    S['tp'] = tAp @ Xp @ tBp
    S['rp'] = rBp + tB @ rAp @ Xp @ tBp
    return S

import numpy as np

def _iso_halfspace_modes_old(n: complex, k0: float, kx: complex, ky: complex):
    """
    Isotropic half-space modal bases built from plane-wave physics:
      - Forward basis Fp columns (s,p) have Sz = +1
      - Backward basis Fm columns (s,p) have Sz = -1
    Works with complex n, arbitrary kx, ky.
    Returns: kz_fwd (scalar), Fp (4x2), Fm (4x2) with columns [Ex,Ey,Hx,Hy]^T.
    """
    # Helpers
    def Sz(P):
        Ex, Ey, Hx, Hy = P
        return 0.5 * np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx))
    def norm2(x, y):  # complex-safe 2D norm
        return np.sqrt(np.abs(x)**2 + np.abs(y)**2) + 1e-300
    omega = C0 * k0

    # Wavevectors
    k_mag = k0 * n
    kt2   = kx*kx + ky*ky
    kz    = np.lib.scimath.sqrt(k_mag*k_mag - kt2)
    # forward/decaying branch
    if np.imag(kz) < 0 or (abs(np.imag(kz)) < 1e-14 and np.real(kz) < 0):
        kz = -kz
    kz_fwd = kz

    # Unit vectors: k-hat (direction), s-hat (TE), p-hat (TM)
    # handle normal incidence (kt=0) by choosing s = x̂
    if abs(kt2) < 1e-30:
        s_hat = np.array([1+0j, 0+0j, 0+0j])
    else:
        kt = np.lib.scimath.sqrt(kt2)
        # ŝ ∝ ẑ × k̂_t → (-ky, kx, 0) normalized
        sx, sy = -ky/kt, kx/kt
        s_hat = np.array([sx, sy, 0+0j], dtype=complex)

    # Forward k and unit vectors
    kf   = np.array([kx, ky, kz], dtype=complex)
    kf_n = np.lib.scimath.sqrt(kx*kx + ky*ky + kz*kz) + 1e-300
    k_hat_f = kf / kf_n
    p_hat_f = np.cross(s_hat, k_hat_f)

    # Backward k and unit vectors
    kb   = np.array([kx, ky, -kz], dtype=complex)
    kb_n = np.lib.scimath.sqrt(kx*kx + ky*ky + (-kz)*(-kz)) + 1e-300
    k_hat_b = kb / kb_n
    p_hat_b = np.cross(s_hat, k_hat_b)   # keep same ŝ; p̂ changes with k̂

    # Plane-wave fields: choose E = ŝ or p̂; H = (1/(μ0 ω)) k × E
    def tangential(E, H):
        return np.array([E[0], E[1], H[0], H[1]], dtype=complex)

    # Forward TE
    E_s_f = s_hat
    H_s_f = np.cross(kf, E_s_f) / (MU0 * omega)
    Ps = tangential(E_s_f, H_s_f)

    # Forward TM
    E_p_f = p_hat_f
    H_p_f = np.cross(kf, E_p_f) / (MU0 * omega)
    Pp = tangential(E_p_f, H_p_f)

    # Backward TE
    E_s_b = s_hat
    H_s_b = np.cross(kb, E_s_b) / (MU0 * omega)
    Qs = tangential(E_s_b, H_s_b)

    # Backward TM
    E_p_b = p_hat_b
    H_p_b = np.cross(kb, E_p_b) / (MU0 * omega)
    Qp = tangential(E_p_b, H_p_b)

    # Power-normalize: forward Sz=+1, backward Sz=-1
    for vec in (Ps, Pp):
        s = Sz(vec)
        if s == 0:
            vec[:] = 0
        else:
            vec[:] = vec / np.sqrt(abs(s))
            if s < 0:  # ensure +1
                vec[2:] *= -1.0
    for vec in (Qs, Qp):
        s = Sz(vec)
        if s == 0:
            vec[:] = 0
        else:
            vec[:] = vec / np.sqrt(abs(s))
            if s > 0:  # ensure -1
                vec[2:] *= -1.0

    Fp = np.column_stack([Ps, Pp])  # forward (+z): Sz=+1 each
    Fm = np.column_stack([Qs, Qp])  # backward (−z): Sz=−1 each

    # Guard against accidental degeneracy
    if np.linalg.cond(np.hstack([Fp, Fm])) > 1e12:
        # tiny orthogonal tweak to keep B invertible without changing power
        eps = 1e-12
        Fm[:, 1] += eps * Fp[:, 0]

    return kz_fwd, Fp, Fm

def _iso_halfspace_modes(n: complex, k0: float, kx: complex, ky: complex):
    """
    Isotropic half-space modal bases built from plane-wave physics.
    Returns:
        kz_dimless: length-2 array of kz/k0 for [s, p]
        Fp, Fm: forward/backward modal field matrices (4x2 each)
    """
    omega = C0 * k0
    kt2 = kx * kx + ky * ky
    k_mag = k0 * n
    kz = np.lib.scimath.sqrt(k_mag * k_mag - kt2)

    # forward/decaying branch
    if np.imag(kz) < 0 or (abs(np.imag(kz)) < 1e-14 and np.real(kz) < 0):
        kz = -kz
    kz_dimless = np.full(2, kz / k0, dtype=complex)

    # Unit vectors: s-hat, p-hat
    if abs(kt2) < 1e-30:
        s_hat = np.array([1+0j, 0+0j, 0+0j])
    else:
        kt = np.lib.scimath.sqrt(kt2)
        sx, sy = -ky/kt, kx/kt
        s_hat = np.array([sx, sy, 0+0j], dtype=complex)

    # Forward/backward k and p-hat
    kf = np.array([kx, ky, kz], dtype=complex)
    k_hat_f = kf / (np.lib.scimath.sqrt(kx*kx + ky*ky + kz*kz) + 1e-300)
    p_hat_f = np.cross(s_hat, k_hat_f)

    kb = np.array([kx, ky, -kz], dtype=complex)
    k_hat_b = kb / (np.lib.scimath.sqrt(kx*kx + ky*ky + kz*kz) + 1e-300)
    p_hat_b = np.cross(s_hat, k_hat_b)

    def tangential(E, H):
        return np.array([E[0], E[1], H[0], H[1]], dtype=complex)

    # Forward TE
    E_s_f = s_hat
    H_s_f = np.cross(kf, E_s_f) / (MU0 * omega)
    Ps = tangential(E_s_f, H_s_f)

    # Forward TM
    E_p_f = p_hat_f
    H_p_f = np.cross(kf, E_p_f) / (MU0 * omega)
    Pp = tangential(E_p_f, H_p_f)

    # Backward TE
    E_s_b = s_hat
    H_s_b = np.cross(kb, E_s_b) / (MU0 * omega)
    Qs = tangential(E_s_b, H_s_b)

    # Backward TM
    E_p_b = p_hat_b
    H_p_b = np.cross(kb, E_p_b) / (MU0 * omega)
    Qp = tangential(E_p_b, H_p_b)

    # Power normalization
    def normalize(vec, target_sz):
        s = 0.5 * np.real(vec[0]*np.conj(vec[3]) - vec[1]*np.conj(vec[2]))
        if s != 0:
            vec[:] = vec / np.sqrt(abs(s))
            if np.sign(s) != np.sign(target_sz):
                vec[2:] *= -1.0
        else:
            vec[:] = 0

    normalize(Ps, +1)
    normalize(Pp, +1)
    normalize(Qs, -1)
    normalize(Qp, -1)

    Fp = np.column_stack([Ps, Pp])
    Fm = np.column_stack([Qs, Qp])

    return kz_dimless, Fp, Fm







def _is_iso_tensor(eps):
    return np.allclose(eps, np.eye(3)*eps[0, 0], rtol=1e-12, atol=1e-15)

def interface_smatrix(eps_left, eps_right, k0, kx, ky,
                      assume_left_iso=False, assume_right_iso=False):
    """
    Field-basis interface S-matrix between two uniform half-spaces.
    """
    _, FpL, FmL = _halfspace_modes(eps_left,  k0, kx, ky, assume_left_iso)
    _, FpR, FmR = _halfspace_modes(eps_right, k0, kx, ky, assume_right_iso)

    return smatrix_from_matching((FpL, FmL), (FpR, FmR))


def layer_smatrix(A, k0, thickness):
    """
    Uniform layer S-matrix using Berreman A-matrix.
    Ensures propagation phase includes n via modal kz.
    """
    eigvals, W = np.linalg.eig(A)           # eigvals ~ kz/k0
    kz_phys = eigvals * k0                  # physical kz in 1/m

    # Forward/backward propagation diagonal blocks
    P_forw = np.diag(np.exp(1j * kz_phys * thickness))
    P_back = np.diag(np.exp(-1j * kz_phys * thickness))

    P_block = np.block([
        [P_forw, np.zeros_like(P_forw)],
        [np.zeros_like(P_back), P_back]
    ])

    W_inv = np.linalg.inv(W)
    M_layer = W @ P_block @ W_inv

    return transfer_to_smatrix(M_layer)




def solve_stack_smatrix_old(layers, wl, theta=0.0, phi=0.0, incident_power=1.0):
    """
    Solve multilayer stack using only field-basis (4×4) S-matrices.
    All interfaces and propagators are computed in field space.
    """
    k0 = 2 * np.pi / wl
    eps_inc = layers[0].material.eps(wl)
    eps_sub = layers[-1].material.eps(wl)
    nL = np.sqrt(eps_inc[0, 0])
    kx, ky = _k_components(k0, nL, theta, phi)

    def _is_iso_tensor(eps):
        return np.allclose(eps, np.eye(3) * eps[0, 0], rtol=1e-12, atol=1e-15)

    # Field-basis interface builder
    def _build_interface(epsL, epsR):
        isoL = _is_iso_tensor(epsL)
        isoR = _is_iso_tensor(epsR)
        _, FpL, FmL = _halfspace_modes(epsL, k0, kx, ky, isoL)
        _, FpR, FmR = _halfspace_modes(epsR, k0, kx, ky, isoR)
        return smatrix_from_matching(FpL, FmL, FpR, FmR)  # always 4×4

    # --- Case 1: only an interface ---
    if len(layers) == 2:
        return _solve_two_layer_field_basis(layers, wl, k0, kx, ky, _build_interface, incident_power)

    # --- Case 2: multi-layer stack ---
    eps_first = layers[1].material.eps(wl)
    S_total = _build_interface(eps_inc, eps_first)

    # Interior finite layers
    for lay in layers[1:-1]:
        eps_layer = lay.material.eps(wl)
        iso_layer = _is_iso_tensor(eps_layer)
        kz_dimless, Fp, Fm = _halfspace_modes(eps_layer, k0, kx, ky, iso_layer)
        kz_phys = kz_dimless * k0
        S_prop = _propagation_smatrix_field_basis(kz_phys, Fp, Fm, lay.thickness)
        S_total = redheffer_star(S_prop, S_total)

    # Final interface: last interior → substrate
    eps_last = layers[-2].material.eps(wl)
    S_last = _build_interface(eps_last, eps_sub)
    S_total = redheffer_star(S_last, S_total)

    # --- Extract power from field-basis S-matrix ---
    return _compute_power_dict_field_basis(S_total, eps_inc, eps_sub, k0, kx, ky, incident_power)

def solve_stack_smatrix(layers, wl, theta=0.0, phi=0.0, incident_power=1.0):
    """
    Solve multilayer stack using only field-basis (4×4) S-matrices.
    All interfaces and propagators are computed in field space.
    Returns amplitudes (t_s, t_p, r_s, r_p) and powers (R_s, R_p, T_s, T_p).
    """
    k0 = 2 * np.pi / wl
    eps_inc = layers[0].material.eps(wl)
    eps_sub = layers[-1].material.eps(wl)
    nL = np.sqrt(eps_inc[0, 0])
    kx, ky = _k_components(k0, nL, theta, phi)

    def _is_iso_tensor(eps):
        return np.allclose(eps, np.eye(3) * eps[0, 0], rtol=1e-12, atol=1e-15)

    # Field-basis interface builder
    def _build_interface(epsL, epsR):
        isoL = _is_iso_tensor(epsL)
        isoR = _is_iso_tensor(epsR)
        _, FpL, FmL = _halfspace_modes(epsL, k0, kx, ky, isoL)
        _, FpR, FmR = _halfspace_modes(epsR, k0, kx, ky, isoR)
        return smatrix_from_matching((FpL, FmL), (FpR, FmR))  # pass tuples

    # --- Case 1: only an interface ---
    if len(layers) == 2:
        return _solve_two_layer_field_basis(layers, wl, k0, kx, ky, _build_interface, incident_power)

    # --- Case 2: multi-layer stack ---
    eps_first = layers[1].material.eps(wl)
    S_total = _build_interface(eps_inc, eps_first)

    # Interior finite layers
    for lay in layers[1:-1]:
        eps_layer = lay.material.eps(wl)
        iso_layer = _is_iso_tensor(eps_layer)
        kz_dimless, Fp, Fm = _halfspace_modes(eps_layer, k0, kx, ky, iso_layer)
        kz_phys = kz_dimless * k0
        if lay.thickness is not None:
            S_prop = _propagation_smatrix_field_basis(kz_phys, Fp, Fm, lay.thickness)
            S_total = redheffer_star(S_prop, S_total)

    # Final interface: last interior → substrate
    eps_last = layers[-2].material.eps(wl)
    S_last = _build_interface(eps_last, eps_sub)
    S_total = redheffer_star(S_last, S_total)

    # --- Compute powers and amplitudes ---
    results = _compute_power_dict_field_basis(S_total, eps_inc, eps_sub, k0, kx, ky, incident_power)

    # Extract amplitudes directly from S-matrix
    r = S_total['r']
    t = S_total['t']
    results["r_s"] = r[0, 0]
    results["r_p"] = r[1, 1]
    results["t_s"] = t[0, 0]
    results["t_p"] = t[1, 1]

    return results

def _compute_power_dict_field_basis(S_total, pol='s', incident_power=1.0):
    """
    Compute reflected/transmitted powers from a 4x4 S-matrix in field basis.
    Assumes ports:
        f_s, f_p, b_s, b_p  (forward/backward)
    Forward power in each pol is proportional to |amp|^2.
    """
    # Select which pol is excited (0=s, 1=p)
    inc_vec = np.zeros((4, 1), dtype=complex)
    inc_vec[pol, 0] = np.sqrt(incident_power)

    r_vec = S_total['r'] @ inc_vec   # backward on input port
    t_vec = S_total['t'] @ inc_vec   # forward on output port

    # Powers are sum of |amp|^2 over pol components in forward direction
    R_s = np.abs(r_vec[0, 0])**2
    R_p = np.abs(r_vec[1, 0])**2
    T_s = np.abs(t_vec[0, 0])**2
    T_p = np.abs(t_vec[1, 0])**2

    return {
        "R_s": R_s, "R_p": R_p,
        "T_s": T_s, "T_p": T_p
    }


def _solve_two_layer(layers, wl, k0, kx, ky, build_interface, incident_power):
    eps_inc = layers[0].material.eps(wl)
    eps_sub = layers[1].material.eps(wl)
    S_total = build_interface(eps_inc, eps_sub)
    return _compute_power_dict(S_total, eps_inc, eps_sub, k0, kx, ky, incident_power)


def _compute_power_dict(S_total, eps_inc, eps_sub, k0, kx, ky, incident_power):
    # Detect isotropy for mode generation
    isoL = np.allclose(eps_inc, np.eye(3)*eps_inc[0,0], rtol=1e-12, atol=1e-15)
    isoR = np.allclose(eps_sub, np.eye(3)*eps_sub[0,0], rtol=1e-12, atol=1e-15)

    _, FpL, FmL = _halfspace_modes(eps_inc, k0, kx, ky, isoL)
    _, FpR, FmR = _halfspace_modes(eps_sub, k0, kx, ky, isoR)

    def _Sz(Psi4):
        Ex, Ey, Hx, Hy = Psi4
        return 0.5 * np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx))

    def _requested_P(pol):
        if isinstance(incident_power, dict):
            return float(incident_power.get(pol, 1.0))
        return float(incident_power)

    out = {}
    r, t = S_total['r'], S_total['t']
    for pol, idx in (('s', 0), ('p', 1)):
        a_vec = np.array([1+0j, 0+0j]) if pol == 's' else np.array([0+0j, 1+0j])
        Psi_inc = FpL @ a_vec
        S_inc_raw = _Sz(Psi_inc)
        P_tar = _requested_P(pol)
        scale = np.sqrt(abs(P_tar) / (abs(S_inc_raw) + 1e-30))
        if S_inc_raw < 0:
            scale = -scale
        a_inc = a_vec * scale
        S_inc = P_tar

        r_vec = r @ a_inc
        t_vec = t @ a_inc
        S_ref = _Sz(FmL @ r_vec)
        S_tr  = _Sz(FpR @ t_vec)

        # Evanescent guard
        kzR, _, _ = _halfspace_modes(eps_sub, k0, kx, ky, isoR)
        if np.all((np.imag(kzR) > 1e-12) & (np.abs(np.real(kzR)) < 1e-10)):
            S_tr = 0.0

        R = -S_ref / S_inc
        T =  S_tr  / S_inc

        out[f"R_{pol}"] = float(R)
        out[f"T_{pol}"] = float(T)
        out[f"r_{pol}"] = r_vec[idx]
        out[f"t_{pol}"] = t_vec[idx]

    return out


def _as_tensor_fn(maybe: Tensor | TensorFn) -> TensorFn:
    if callable(maybe):
        return maybe  # already a dispersive tensor fn(wl)->3x3
    T = np.array(maybe, dtype=complex)
    assert T.shape == (3,3)
    return lambda wl: T

def rot_tensor(eps: Tensor, R: Optional[Tensor]) -> Tensor:
    """Rotate a 3x3 tensor by R (active); if R is None, return eps."""
    if R is None:
        return eps
    return R @ eps @ R.T

def R_from_euler(alpha: float, beta: float, gamma: float) -> Tensor:
    """Z(α)Y(β)Z(γ) right-handed active rotation."""
    ca, sa = np.cos(alpha), np.sin(alpha)
    cb, sb = np.cos(beta),  np.sin(beta)
    cg, sg = np.cos(gamma), np.sin(gamma)
    Rz1 = np.array([[ca,-sa,0],[sa,ca,0],[0,0,1]], complex)
    Ry  = np.array([[cb,0,sb],[0,1,0],[-sb,0,cb]], complex)
    Rz2 = np.array([[cg,-sg,0],[sg,cg,0],[0,0,1]], complex)
    return Rz1 @ Ry @ Rz2

@dataclass(frozen=True)
class Material:
    """Minimal material: relative permittivity tensor (3x3) or callable(wl)->3x3."""
    eps_r: Tensor | TensorFn
    R: Optional[Tensor] = None  # optional rotation to apply to eps
    
    def eps(self, wl: float) -> Tensor:
        base = _as_tensor_fn(self.eps_r)(wl)
        return rot_tensor(base, self.R)

@dataclass(frozen=True)
class Layer:
    material: Material
    thickness: Optional[float]  # None => semi-infinite half-space

# ---------- Core Berreman bits ----------

def _k_components(k0: float, n_inc: complex, theta: float, phi: float) -> Tuple[complex, complex]:
    """Tangential k components fixed by incident medium & angles."""
    k = k0 * n_inc
    kx = k*np.sin(theta)*np.cos(phi)
    ky = k*np.sin(theta)*np.sin(phi)
    return kx, ky

def _forward_branch(kz: complex) -> complex:
    """Choose decaying/forward branch: Im(kz)>=0; if ~real, Re(kz)>=0."""
    if np.imag(kz) < 0:
        return -kz
    if abs(np.imag(kz)) < 1e-14 and np.real(kz) < 0:
        return -kz
    return kz

def _A_from_diag_eps(eps: np.ndarray, k0: float, kx: complex, ky: complex) -> np.ndarray:
    """
    Full Berreman 4x4 A-matrix with physical constants (μ0, ε0, c0), no shortcuts.
    Time dependence exp(+i ω t) assumed; ODE: dΨ/dz = i k0 A Ψ, Ψ=[Ex,Ey,Hx,Hy]^T.

    eps : 3x3 relative permittivity tensor (possibly rotated), complex
    k0  : vacuum wavenumber = 2π/λ
    kx, ky : tangential wavevector components in the incident frame
    """
    eps = np.asarray(eps, dtype=complex)
    exx, exy, exz = eps[0,0], eps[0,1], eps[0,2]
    eyx, eyy, eyz = eps[1,0], eps[1,1], eps[1,2]
    ezx, ezy, ezz = eps[2,0], eps[2,1], eps[2,2]

    # Physical constants
    omega = C0 * k0
    c = 1.0 / (MU0 * omega)         # 1/(ω μ0)
    a = 1.0 / (EPS0 * omega * ezz)  # 1/(ω ε0 ε_zz)
    bx = ezx / ezz                  # ε_zx/ε_zz
    by = ezy / ezz                  # ε_zy/ε_zz

    # Build A so that dΨ/dz = i k0 A Ψ  (⇒ A = coefficients / k0)
    A = np.zeros((4,4), dtype=complex)

    # dE_x/dz
    A[0,0] = (-kx * bx) / k0
    A[0,1] = (-kx * by) / k0
    A[0,2] = ( kx * a * ky) / k0
    A[0,3] = (-kx * a * kx + omega * MU0) / k0

    # dE_y/dz
    A[1,0] = (-ky * bx) / k0
    A[1,1] = (-ky * by) / k0
    A[1,2] = ( ky * a * ky - omega * MU0) / k0
    A[1,3] = (-ky * a * kx) / k0

    # dH_x/dz
    A[2,0] = (-kx * c * ky - omega * EPS0 * (eyx - eyz * bx)) / k0
    A[2,1] = ( kx * c * kx - omega * EPS0 * (eyy - eyz * by)) / k0
    A[2,2] = ( -omega * EPS0 * eyz * a * ky) / k0
    A[2,3] = (  omega * EPS0 * eyz * a * kx) / k0

    # dH_y/dz
    A[3,0] = (-ky * c * ky + omega * EPS0 * (exx - exz * bx)) / k0
    A[3,1] = ( ky * c * kx + omega * EPS0 * (exy - exz * by)) / k0
    A[3,2] = (  omega * EPS0 * exz * a * ky) / k0
    A[3,3] = ( -omega * EPS0 * exz * a * kx) / k0

    return A


def _propagator(A: Tensor, k0: float, d: float) -> Tensor:
    """exp(i k0 A d) via eigendecomposition (stable & tiny)."""
    vals, vecs = np.linalg.eig(A)
    return vecs @ np.diag(np.exp(1j * k0 * vals * d)) @ np.linalg.inv(vecs)





# ---------- Tiny helpers for common materials ----------

def iso_material(n: complex) -> Material:
    eps = (n*n) * np.eye(3, dtype=complex)
    return Material(eps)

def uniaxial_material(n_o: complex, n_e: complex, R: Optional[Tensor]=None) -> Material:
    eps = np.diag([n_o*n_o, n_o*n_o, n_e*n_e]).astype(complex)
    return Material(eps, R=R)

def _Sz(Psi4: np.ndarray) -> float:
    Ex, Ey, Hx, Hy = Psi4
    return float(0.5 * np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx)))

def _is_isotropic(eps: np.ndarray, tol=1e-12) -> bool:
    # Treat as isotropic if eps ≈ scalar * I
    s = eps[0,0]
    return np.allclose(eps, s*np.eye(3, dtype=complex), atol=tol, rtol=0)

def _power_coeffs_from_S(k0: float, kx: complex, ky: complex,
                         eps_inc: np.ndarray, eps_sub: np.ndarray,
                         S: dict, incident_power=1.0) -> dict:
    """
    Convert S ('r','t','rp','tp') into power coefficients using
    *port-consistent* modal bases and strict Poynting accounting.
    """
    # Build port bases using *the ports' media*
    _, FpL, FmL = _iso_halfspace_modes(np.sqrt(eps_inc[0,0]), k0, kx, ky)
    _, FpR, FmR = _iso_halfspace_modes(np.sqrt(eps_sub[0,0]), k0, kx, ky)

    r, t = S['r'], S['t']

    # Evanescent guard on the right (rarer medium beyond TIR etc.)
    kzR, _, _ = _iso_halfspace_modes(np.sqrt(eps_sub[0,0]), k0, kx, ky)
    kzR = np.atleast_1d(kzR)
    right_all_evan = np.all((np.imag(kzR) > 1e-12) & (np.abs(np.real(kzR)) < 1e-10))

    def _P_inc_scale(Psi_inc_raw, P_target):
        S_inc_raw = _Sz(Psi_inc_raw)
        scale = np.sqrt(abs(P_target) / (abs(S_inc_raw) + 1e-300))
        # Ensure forward power is positive after scaling
        if S_inc_raw < 0:
            scale = -scale
        return scale

    out = {}
    for pol_idx, pol in enumerate(("s","p")):
        a_hat = np.array([1+0j,0+0j]) if pol_idx==0 else np.array([0+0j,1+0j])

        # Normalize so incident *power* equals incident_power (or 1.0)
        P_tar = incident_power if not isinstance(incident_power, dict) else float(incident_power.get(pol, 1.0))
        Psi_inc_raw = FpL @ a_hat
        scale = _P_inc_scale(Psi_inc_raw, P_tar)
        a_in = a_hat * scale
        S_inc = P_tar  # by construction, positive

        # Modal responses
        r_vec = r @ a_in
        t_vec = t @ a_in

        # Tangential fields at ports
        Psi_refL = FmL @ r_vec   # backward on left => Sz < 0
        Psi_trR  = FpR @ t_vec   # forward on right  => Sz > 0 (if propagating)

        S_ref = _Sz(Psi_refL)
        S_tr  = 0.0 if right_all_evan else _Sz(Psi_trR)

        # Power coefficients (note minus for backward flux)
        R = -S_ref / (S_inc + 1e-300)
        T =  S_tr  / (S_inc + 1e-300)

        # Micro‑clip tiny negatives, but surface big problems
        def _clip01(x, tol=5e-6):
            if x < -tol or x > 1+tol:
                return float(x)
            return float(min(1.0, max(0.0, x)))

        out[f"R_{pol}"] = _clip01(np.real_if_close(R))
        out[f"T_{pol}"] = _clip01(np.real_if_close(T))
        out[f"r_{pol}"] = r_vec[pol_idx]
        out[f"t_{pol}"] = t_vec[pol_idx]

    return out



# solvers

In [2]:

def solve_stack_smatrix(
    layers: list[Layer],
    wavelength: float,
    theta: float = 0.0,
    phi: float   = 0.0,
    incident_power: float | dict = 1.0,   # W per unit area
) -> dict:

    wl = wavelength
    k0 = 2*np.pi/wl
    eps_inc = layers[0].material.eps(wl)
    eps_sub = layers[-1].material.eps(wl)
    n_inc = np.sqrt(eps_inc[0,0])
    kx, ky = _k_components(k0, n_inc, theta, phi)

    def _is_iso_tensor(eps):
        return np.allclose(eps, np.eye(3)*eps[0,0], rtol=1e-12, atol=1e-15)

    # 4×4 field-basis interface
    def S_interface(eL, eR):
        isoL = _is_iso_tensor(eL)
        isoR = _is_iso_tensor(eR)
    
        # Compute forward/backward mode matrices
        _, FpL, FmL = _halfspace_modes(eL, k0, kx, ky, isoL)
        _, FpR, FmR = _halfspace_modes(eR, k0, kx, ky, isoR)
    
        # Call matcher exactly as it expects
        return smatrix_from_matching((FpL, FmL), (FpR, FmR))


    # --- Build total S-matrix ---
    if len(layers) == 2:
        S_total = S_interface(eps_inc, eps_sub)
    else:
        eps1 = layers[1].material.eps(wl)
        S_total = S_interface(eps_inc, eps1)

        for i in range(1, len(layers)-1):
            lay = layers[i]
            eps_i = lay.material.eps(wl)
            is_last_finite = (i == len(layers)-2)
            eps_next = eps_sub if is_last_finite else layers[i+1].material.eps(wl)

            if lay.thickness is not None:
                S_prop = propagation_smatrix_in_medium(eps_i, k0, kx, ky, lay.thickness, assume_isotropic=False)
                S_total = redheffer_star(S_prop, S_total)

            S_if = S_interface(eps_i, eps_next)
            S_total = redheffer_star(S_if, S_total)

    # --- Power extraction in field basis ---
    def _Sz(Psi4):
        Ex, Ey, Hx, Hy = Psi4
        return 0.5 * np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx))

    def _P_target(pol: str) -> float:
        if isinstance(incident_power, dict):
            return float(incident_power.get(pol, 1.0))
        return float(incident_power)

    _, Fp0, Fm0 = _halfspace_modes(eps_inc, k0, kx, ky, _is_iso_tensor(eps_inc))
    _, FP,  FM  = _halfspace_modes(eps_sub, k0, kx, ky, _is_iso_tensor(eps_sub))

    out = {}
    for pol_idx, pol in enumerate(('s', 'p')):
        a_vec = np.zeros((2,), dtype=complex)
        a_vec[pol_idx] = 1.0

        Psi_inc_raw = Fp0 @ a_vec
        S_inc_raw = _Sz(Psi_inc_raw)
        P_tar = _P_target(pol)
        eps_guard = 1e-30
        scale = np.sqrt(abs(P_tar) / (abs(S_inc_raw) + eps_guard))
        if S_inc_raw < 0:
            scale = -scale
        a_inc = a_vec * scale
        S_inc = P_tar

        # promote a_inc to 4×1 (field basis input port = forward modes only)
        a_inc_field = np.zeros((4,1), dtype=complex)
        a_inc_field[:2,0] = a_inc

        r_vec = S_total['r'] @ a_inc_field
        t_vec = S_total['t'] @ a_inc_field

        Psi_ref0 = r_vec
        Psi_trR  = t_vec
        S_ref = np.sum(np.abs(Psi_ref0[:2])**2)
        S_tr  = np.sum(np.abs(Psi_trR[:2])**2)

        R = float(np.real_if_close(S_ref / S_inc))
        T = float(np.real_if_close(S_tr  / S_inc))

        out[f"R_{pol}"] = max(0.0, min(1.0, R))
        out[f"T_{pol}"] = max(0.0, min(1.0, T))
        out[f"Pinc_{pol}"] = S_inc

    return out


def solve_stack_4x4_berreman_dont_use(
    layers: List[Layer],
    wavelength: float,
    theta: float = 0.0,
    phi: float   = 0.0,
) -> dict:
    """
    Returns {r_s, r_p, t_s, t_p, R_s, R_p, T_s, T_p}.
    The first and last elements of `layers` must be the incident and substrate half-spaces
    with thickness=None and *isotropic* permittivity tensors (scalar*n^2*I).
    Internal layers may be fully anisotropic & rotated.
    """
    assert len(layers) >= 2 and layers[0].thickness is None and layers[-1].thickness is None

    wl = wavelength
    k0 = 2*np.pi / wl

    # Incident / substrate refractive indices (assume isotropic tensors)
    eps_inc = layers[0].material.eps(wl); n_inc = np.sqrt(eps_inc[0,0])
    eps_sub = layers[-1].material.eps(wl); n_sub = np.sqrt(eps_sub[0,0])

    # Tangential components set by incidence in the *incident* medium
    kx, ky = _k_components(k0, n_inc, theta, phi)

    # Global 4x4 transfer across the stack interior (exclude half-spaces)
    M = np.eye(4, dtype=complex)
    for lay in layers[1:-1]:
        A = _A_from_diag_eps(lay.material.eps(wl), k0, kx, ky)
        M = _propagator(A, k0, lay.thickness) @ M

    # Build TE/TM tangential field bases in the half-spaces
    kz_inc, Fp0, Fm0 = _iso_halfspace_modes(n_inc, k0, kx, ky)   # incident side
    kz_sub, FP , FM  = _iso_halfspace_modes(n_sub, k0, kx, ky)   # substrate side

    # Boundary conditions:
    # Ψ(0) = Fp0 @ a_inc + Fm0 @ r         (r: 2×1 unknowns [r_s, r_p])
    # Ψ(L) = FP  @ t                       (t: 2×1 unknowns [t_s, t_p]; no backward in substrate)
    # and Ψ(L) = M @ Ψ(0)
    # Solve two separate 2×2 problems for s and p incidence by choosing a_inc = [1,0] or [0,1].

    results = {}

    for pol, a_inc in (("s", np.array([1+0j, 0+0j])),
                   ("p", np.array([0+0j, 1+0j]))):

        # Assemble linear system for unknown x = [r(2), t(2)]
        A_sys = np.block([[-M @ Fm0,  FP]])
        b_sys = (M @ (Fp0 @ a_inc))
    
        # Solve 4×4 system
        x = np.linalg.solve(A_sys, b_sys)
        r = x[:2]
        t = x[2:]
    
        # Save amplitudes
        results[f"r_{pol}"] = r[0] if pol == "s" else r[1]
        results[f"t_{pol}"] = t[0] if pol == "s" else t[1]
    
        # --- Power coefficients from Poynting vector ---
        def _Sz_from_tangential(Psi4):
            Ex, Ey, Hx, Hy = Psi4
            return 0.5 * np.real(Ex * np.conj(Hy) - Ey * np.conj(Hx))
        
        Psi_inc0 = Fp0 @ a_inc
        Psi_ref0 = Fm0 @ r
        Psi_trL  = FP  @ t
        
        S_inc = _Sz_from_tangential(Psi_inc0)
        S_ref = _Sz_from_tangential(Psi_ref0)
        S_tr  = _Sz_from_tangential(Psi_trL)
        
        R = abs(S_ref / S_inc)   # reflected power is positive by definition
        T = S_tr / S_inc
        
        results[f"R_{pol}"] = float(np.real_if_close(T*0 + R))  # ensure scalar real
        results[f"T_{pol}"] = float(np.real_if_close(T))


    return results

def solve_stack_4x4(layers, wavelength, theta=0.0, phi=0.0) -> dict:
    # Temporary shim: delegate to S-matrix implementation for stability
    results =  solve_stack_smatrix(layers, wavelength, theta=theta, phi=phi)
    # Corrected phase difference calculation
    def calculate_phase_diff(t_s, t_p):
        phase_s = np.angle(t_s)
        phase_p = np.angle(t_p)
        phase_diff = phase_p - phase_s
        # Proper phase unwrapping
        phase_diff = (phase_diff + np.pi) % (2 * np.pi) - np.pi
        return phase_diff
    
    # Add phase difference to results
    results['phase_diff'] = calculate_phase_diff(results['t_s'], results['t_p'])
    
    return results


In [3]:
# ---------- Example ----------
if __name__ == "__main__":
    wl = 550e-9
    air  = Layer(iso_material(1.0), None)  # incident
    glas = Layer(iso_material(1.5), 200e-9)  # thin isotropic film
    sub  = Layer(iso_material(1.45), None)   # substrate
    res = solve_stack_4x4([air, glas, sub], wl, theta=30*np.pi/180, phi=0.0)
    print(res)


KeyError: 't_s'

In [None]:
#helper
def fresnel_iso(n0, n1, theta0):
    """Return (r_s, r_p, t_s, t_p, R_s, R_p, T_s, T_p) for a single isotropic interface."""
    import numpy as np
    s0 = np.sin(theta0); c0 = np.cos(theta0)
    s1 = (n0/n1) * s0
    if np.abs(s1) > 1:     # total internal reflection
        return np.nan, np.nan, np.nan, np.nan, 1.0, 1.0, 0.0, 0.0
    c1 = np.sqrt(1 - s1**2)

    rs = (n0*c0 - n1*c1)/(n0*c0 + n1*c1)
    rp = (n1*c0 - n0*c1)/(n1*c0 + n0*c1)
    ts = 2*n0*c0/(n0*c0 + n1*c1)
    tp = 2*n0*c0/(n1*c0 + n0*c1)

    # Power transmission (Poynting z-ratio)
    Ts = (n1*c1)/(n0*c0) * np.abs(ts)**2
    Tp = (n1*c1)/(n0*c0) * np.abs(tp)**2
    Rs = np.abs(rs)**2
    Rp = np.abs(rp)**2
    return rs, rp, ts, tp, Rs, Rp, Ts, Tp


# level 1 tests

In [None]:
# ===================== VERBOSE DIAGNOSTIC SUBTESTS (Level_1) =====================

def _sz(P):
    Ex, Ey, Hx, Hy = P
    return 0.5 * np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx))

def _power_RT_from_S(FpL, FmL, FpR, r, t):
    """
    Compute R,T from field S-blocks using port bases and Poynting Sz.
    Assumes we launch unit-power in the chosen pol column.
    """
    out = {}
    for pol_idx, pol in enumerate(("s","p")):
        a = np.zeros((2,), dtype=complex); a[pol_idx] = 1.0
        Psi_inc = FpL @ a
        Sinc = _sz(Psi_inc)
        # Reflected (backward on left), Transmitted (forward on right)
        Psi_ref = FmL @ (r @ a)
        Psi_tr  = FpR @ (t @ a)
        Sref = _sz(Psi_ref)
        Str  = _sz(Psi_tr)
        R = float(np.real(-Sref/Sinc))
        T = float(np.real(Str/Sinc))
        out[f"R_{pol}"] = R
        out[f"T_{pol}"] = T
    return out

def test_debug_air_air_interface_verbose_level_1():
    """Verbose check of air|air interface blocks and power accounting."""
    print("\n=== DEBUG: air | air interface (expect r≈0, t≈I in power basis) ===")
    wl = 600e-9; k0 = 2*np.pi/wl
    n = 1.0
    eL = (n*n)*np.eye(3, dtype=complex)
    eR = (n*n)*np.eye(3, dtype=complex)
    theta = np.deg2rad(23); phi = 0.4
    kx, ky = _k_components(k0, n_inc=n, theta=theta, phi=phi)

    # Half-space modes
    _, FpL, FmL = _halfspace_modes(eL, k0, kx, ky, assume_isotropic=True)
    _, FpR, FmR = _halfspace_modes(eR, k0, kx, ky, assume_isotropic=True)

    # Interface S via field matching (independent of nested S_interface)
    S_if = smatrix_from_matching(FpL, FmL, FpR, FmR)
    r, t, rp, tp = S_if['r'], S_if['t'], S_if['rp'], S_if['tp']

    np.set_printoptions(precision=5, suppress=True)
    print("FpL:\n", FpL); print("FmL:\n", FmL)
    print("FpR:\n", FpR); print("FmR:\n", FmR)
    print("Interface r (left):\n", r)
    print("Interface t (L→R):\n", t)
    print("Interface rp (right):\n", rp)
    print("Interface tp (R→L):\n", tp)

    # Power from Poynting
    out = _power_RT_from_S(FpL, FmL, FpR, r, t)
    print("Power R/T from interface (should be ~0/1):", out)

    # Keep this test non-failing: sanity on shapes only
    assert r.shape == (2,2) and t.shape == (2,2)

def test_debug_uniform_air_slab_transfer_vs_smatrix_level_1():
    """
    Build uniform-air slab transfer T and convert to S; compare with direct diag phase.
    Prints T11,T12,T21,T22, and the derived r,t,rp,tp plus power.
    """
    print("\n=== DEBUG: air slab (expect r≈0, |t|=1) via T→S conversion ===")
    wl = 600e-9; k0 = 2*np.pi/wl
    n = 1.0
    eps = (n*n)*np.eye(3, dtype=complex)
    theta = np.deg2rad(23); phi = 0.4
    kx, ky = _k_components(k0, n_inc=n, theta=theta, phi=phi)
    d = 2.0e-6

    # Modes & propagator
    _, Fp, Fm = _halfspace_modes(eps, k0, kx, ky, assume_isotropic=True)
    A = _A_from_diag_eps(eps, k0, kx, ky)
    T = _propagator(A, k0, d)

    B = np.hstack([Fp, Fm])
    Tfb = np.linalg.inv(B) @ T @ B
    T11, T12 = Tfb[0:2,0:2], Tfb[0:2,2:4]
    T21, T22 = Tfb[2:4,0:2], Tfb[2:4,2:4]
    print("T11:\n", T11); print("T12:\n", T12)
    print("T21:\n", T21); print("T22:\n", T22)

    # Correct T->S: invert T22
    def inv_reg(M, eps=1e-14):
        try: return np.linalg.inv(M)
        except np.linalg.LinAlgError: return np.linalg.pinv(M, rcond=eps)
    Dinv = inv_reg(T22)
    r  = - Dinv @ T21
    t  =   T11  - T12 @ Dinv @ T21
    rp =   T12 @ Dinv
    tp =   Dinv

    print("Derived r:\n", r); print("Derived t:\n", t)
    print("Derived rp:\n", rp); print("Derived tp:\n", tp)

    # Power check with air ports (same medium both sides)
    out = _power_RT_from_S(Fp, Fm, Fp, r, t)
    print("Power R/T from slab T→S (expect ~0/1):", out)

    assert r.shape == (2,2) and t.shape == (2,2)

def test_debug_star_order_air_slab_air_level_1():
    """
    Compose air|air slab|air in both possible star orders to verify composition direction.
    Prints R/T after each composition step.
    """
    print("\n=== DEBUG: star product order for air | air(d) | air ===")
    wl = 600e-9; k0 = 2*np.pi/wl
    n = 1.0
    eps_air = (n*n)*np.eye(3, dtype=complex)
    theta = np.deg2rad(23); phi = 0.4
    kx, ky = _k_components(k0, n_inc=n, theta=theta, phi=phi)
    d = 2.0e-6

    # Left/right port modes
    _, FpL, FmL = _halfspace_modes(eps_air, k0, kx, ky, assume_isotropic=True)
    _, FpR, FmR = _halfspace_modes(eps_air, k0, kx, ky, assume_isotropic=True)

    # Interface S (air|air)
    S_if = smatrix_from_matching(FpL, FmL, FpR, FmR)

    # Propagation S
    S_prop = propagation_smatrix_in_medium(eps_air, k0, kx, ky, d, assume_isotropic=True)

    # Two possible composition orders (depending on your redheffer_star signature):
    # Option A: S_total = S_prop ⭑ S_if
    SA = redheffer_star(S_prop, S_if)
    # Option B: S_total = S_if ⭑ S_prop
    SB = redheffer_star(S_if, S_prop)

    # Power using left/right ports of the composed system
    def power_from_S(S):
        r, t = S['r'], S['t']
        return _power_RT_from_S(FpL, FmL, FpR, r, t)

    outA = power_from_S(SA)
    outB = power_from_S(SB)

    print("Order A (S_prop ⭑ S_if): R/T =", outA)
    print("Order B (S_if ⭑ S_prop): R/T =", outB)
    print("Note: The correct order is the one that yields R≈0, T≈1 for air|air(d)|air.")

    assert True  # purely diagnostic

def test_debug_quarter_wave_AR_stepwise_level_1():
    """
    Step-by-step AR coating at normal incidence:
    Prints interface Fresnels, slab phase, and composed R/T at each step.
    """
    print("\n=== DEBUG: quarter-wave AR stepwise (air | film | glass, normal incidence) ===")
    wl = 550e-9; k0 = 2*np.pi/wl
    n0, n2 = 1.0, 1.5
    n1 = np.sqrt(n0*n2)
    d  = wl/(4*n1)

    eps0 = (n0*n0)*np.eye(3, dtype=complex)
    eps1 = (n1*n1)*np.eye(3, dtype=complex)
    eps2 = (n2*n2)*np.eye(3, dtype=complex)

    theta = 0.0; phi = 0.0
    kx, ky = _k_components(k0, n_inc=n0, theta=theta, phi=phi)

    # Ports & interfaces
    _, Fp0, Fm0 = _halfspace_modes(eps0, k0, kx, ky, assume_isotropic=True)
    _, Fp1, Fm1 = _halfspace_modes(eps1, k0, kx, ky, assume_isotropic=True)
    _, Fp2, Fm2 = _halfspace_modes(eps2, k0, kx, ky, assume_isotropic=True)

    S01 = smatrix_from_matching(Fp0, Fm0, Fp1, Fm1)
    S12 = smatrix_from_matching(Fp1, Fm1, Fp2, Fm2)
    Sprop = propagation_smatrix_in_medium(eps1, k0, kx, ky, d, assume_isotropic=True)

    # Compose either ((Sprop ⭑ S01) ⭑ S12) or (S12 ⭑ (Sprop ⭑ S01)) etc
    S_tmp = redheffer_star(Sprop, S01)
    S_tot = redheffer_star(S12, S_tmp)

    def RT_from(S):
        r, t = S['r'], S['t']
        return _power_RT_from_S(Fp0, Fm0, Fp2, r, t)

    print("Interface 0|1 R/T:", _power_RT_from_S(Fp0, Fm0, Fp1, S01['r'], S01['t']))
    print("Slab in 1 (phase only) R/T:", _power_RT_from_S(Fp1, Fm1, Fp1, Sprop['r'], Sprop['t']))
    print("Interface 1|2 R/T:", _power_RT_from_S(Fp1, Fm1, Fp2, S12['r'], S12['t']))
    print("Composed AR stack R/T (expect ~0/1):", RT_from(S_tot))

    assert True  # diagnostic only


# ---------- Unit tests for small utilities ----------
def test_A_eigs_isotropic_match_kz_over_k0_level_1():
    wl = 633e-9
    k0 = 2*np.pi/wl
    n  = 1.7
    eps = (n*n)*np.eye(3, dtype=complex)
    theta = np.deg2rad(25)
    phi   = 0.3
    kx, ky = _k_components(k0, n_inc=n, theta=theta, phi=phi)  # same medium just to set k_t
    A = _A_from_diag_eps(eps, k0, kx, ky)
    vals = np.linalg.eigvals(A)
    # Expected eigenvalues: ±(kz/k0), each twice
    kt2 = kx*kx + ky*ky
    kz = np.lib.scimath.sqrt((k0*n)**2 - kt2)
    expect = np.array([ kz/k0, kz/k0, -kz/k0, -kz/k0], dtype=complex)
    # Compare sets (order-insensitive)
    npt.assert_allclose(np.sort_complex(vals), np.sort_complex(expect), rtol=1e-9, atol=1e-12)


def test_R_from_euler_identity_and_orthonormal_level_1():
    R = R_from_euler(0.0, 0.0, 0.0)
    npt.assert_allclose(R, np.eye(3), atol=1e-14)
    # random angles: R should be orthonormal and det ~ 1
    R = R_from_euler(0.3, -0.7, 1.2)
    npt.assert_allclose(R.T @ R, np.eye(3), atol=1e-12)
    det = np.linalg.det(R)
    npt.assert_allclose(det, 1.0, rtol=0, atol=1e-12)

def test_rot_tensor_behavior_level_1():
    T = np.diag([2.0, 3.0, 4.0]).astype(complex)
    # Identity rotation -> no change
    T2 = rot_tensor(T, np.eye(3))
    npt.assert_allclose(T2, T, atol=1e-14)
    # 90 deg about z swaps x and y axes
    Rz90 = R_from_euler(np.pi/2, 0.0, 0.0)
    Trot = rot_tensor(T, Rz90)
    expect = np.diag([3.0, 2.0, 4.0]).astype(complex)
    npt.assert_allclose(Trot, expect, atol=1e-12)

def test_as_tensor_fn_and_material_eps_level_1():
    # Constant tensor
    eps = 2.25 * np.eye(3, dtype=complex)
    m = Material(eps)
    wl = 633e-9
    npt.assert_allclose(m.eps(wl), eps)
    # Dispersive callable
    def disp(wl):
        return (2.0 + 0.5*(wl/1e-6))**2 * np.eye(3, dtype=complex)
    m2 = Material(disp)
    e1 = m2.eps(500e-9)
    e2 = m2.eps(1000e-9)
    assert e1[0,0] != e2[0,0]

def test_iso_and_uniaxial_material_helpers_level_1():
    n = 1.7 + 0.0j
    mi = iso_material(n)
    e = mi.material.eps(550e-9) if isinstance(mi, Layer) else mi.eps(550e-9)
    # iso_material returns Material; here we check directly on returned Material
    e = mi.eps(550e-9)
    npt.assert_allclose(e, (n*n)*np.eye(3), atol=1e-14)

    ne = 1.6; no = 1.5
    mu = uniaxial_material(no, ne)
    ex = mu.eps(550e-9)
    npt.assert_allclose(np.diag(ex), np.array([no*no, no*no, ne*ne]), atol=1e-14)

def test_k_components_zero_angle_level_1():
    wl = 550e-9
    k0 = 2*np.pi/wl
    kx, ky = _k_components(k0, n_inc=1.0, theta=0.0, phi=1.1)
    npt.assert_allclose([kx, ky], [0.0, 0.0], atol=1e-14)

def test_forward_branch_sign_choice_level_1():
    # Negative imaginary should flip
    kz = 1.0 - 1e-6j
    out = _forward_branch(kz)
    npt.assert_allclose(out, -kz)
    # Positive imaginary should keep
    kz2 = 0.3 + 1e-6j
    out2 = _forward_branch(kz2)
    npt.assert_allclose(out2, kz2)
    # Nearly real negative -> flip
    kz3 = -0.5 + 0j
    out3 = _forward_branch(kz3)
    npt.assert_allclose(out3, -kz3)

def test_A_from_diag_eps_isotropic_normal_incidence_level_1():
    wl = 550e-9
    k0 = 2*np.pi/wl
    n  = 1.4
    eps = (n*n)*np.eye(3, dtype=complex)

    A = _A_from_diag_eps(eps, k0, kx=0.0, ky=0.0)

    Z0 = np.sqrt(MU0 / EPS0)  # free-space impedance ≈ 376.730313...
    npt.assert_allclose(A[0,3],  Z0, rtol=1e-12, atol=1e-12)
    npt.assert_allclose(A[1,2], -Z0, rtol=1e-12, atol=1e-12)
    npt.assert_allclose(A[2,1], -(n*n)/Z0, rtol=1e-12, atol=1e-12)
    npt.assert_allclose(A[3,0],  (n*n)/Z0, rtol=1e-12, atol=1e-12)

    # everything else ~0
    mask = np.ones((4,4), dtype=bool)
    mask[0,3] = mask[1,2] = mask[2,1] = mask[3,0] = False
    npt.assert_allclose(A[mask], 0.0, atol=1e-12)


def test_propagator_zero_thickness_is_identity_level_1():
    A = np.random.default_rng(0).standard_normal((4,4)) + 1j*np.random.default_rng(1).standard_normal((4,4))
    k0 = 2*np.pi/633e-9
    P = _propagator(A, k0, d=0.0)
    npt.assert_allclose(P, np.eye(4), atol=1e-12)

def test_iso_halfspace_modes_shapes_and_kz_sign_level_1():
    """
    Checks that _iso_halfspace_modes returns correct shapes and that kz has
    positive real part for the forward branch, with negligible imaginary part.
    """
    wl = 550e-9
    k0 = 2*np.pi/wl
    kz, Fp, Fm = _iso_halfspace_modes(n=1.5, k0=k0, kx=0.0, ky=0.0)

    assert Fp.shape == (4, 2)
    assert Fm.shape == (4, 2)

    # Forward branch: kz positive real, negligible imaginary part
    assert np.all(np.real(kz) > 0)
    assert np.all(np.abs(np.imag(kz)) < 1e-14)

# ---------- Integration-style tests for the solver ----------

def test_single_interface_matches_fresnel_power_level_2(theta_deg=37.0):
    theta = np.deg2rad(theta_deg)
    wl = 550e-9
    n0 = 1.0
    n1 = 1.7

    air = Layer(iso_material(n0), None)
    sub = Layer(iso_material(n1), None)
    out = solve_stack_4x4([air, sub], wl, theta=theta, phi=0.0)

    _, _, _, _, Rs, Rp, Ts, Tp = fresnel_iso(n0, n1, theta)

    npt.assert_allclose(out["R_s"], Rs, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["R_p"], Rp, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["T_s"], Ts, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["T_p"], Tp, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["R_s"] + out["T_s"], 1.0, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["R_p"] + out["T_p"], 1.0, rtol=1e-4, atol=1e-6)

def test_thin_film_quarter_wave_AR_normal_incidence_level_2():
    wl = 550e-9
    n0 = 1.0
    n2 = 1.5
    n1 = (n0*n2)**0.5
    d  = wl/(4*n1)

    air  = Layer(iso_material(n0), None)
    film = Layer(iso_material(n1), d)
    sub  = Layer(iso_material(n2), None)

    out = solve_stack_4x4([air, film, sub], wl, theta=0.0, phi=0.0)

    print(f'{out["R_s"]=}',f'{out["R_p"]=}',f'{out["T_s"]=}',f'{out["T_p"]=}')
    assert out["R_s"] < 5e-3
    assert out["R_p"] < 5e-3
    assert out["T_s"] > 1 - 5e-3
    assert out["T_p"] > 1 - 5e-3

    



def test_anisotropic_layer_rotated_is_stable_and_reasonable_level_2():
    wl = 633e-9
    n0 = 1.0
    n_sub = 1.45

    no, ne = 1.6, 1.8
    Rmat = R_from_euler(0.3, 0.4, -0.2)
    mat = uniaxial_material(no, ne, R=Rmat)
    layer = Layer(mat, 200e-9)

    air = Layer(iso_material(n0), None)
    sub = Layer(iso_material(n_sub), None)

    out = solve_stack_4x4([air, layer, sub], wl, theta=np.deg2rad(15), phi=0.2)

    for key in ("R_s", "R_p", "T_s", "T_p"):
        val = float(np.real(out[key]))
        assert np.isfinite(val)
        assert -5e-3 <= val <= 1 + 5e-3

    Rs_Ts = float(np.real(out["R_s"] + out["T_s"]))
    Rp_Tp = float(np.real(out["R_p"] + out["T_p"]))
    assert abs(abs(out["R_s"]) + abs(out["T_s"]) - 1.0) < 2e-2
    assert abs(abs(out["R_p"]) + abs(out["T_p"]) - 1.0) < 2e-2
def test_iso_layer_phase_matches_nk0d_level_1():
    """
    Isotropic layer at normal incidence: forward-mode kz should be n*k0,
    so phase = n*k0*d.
    """
    wl = 550e-9
    n = 1.7
    d = 100e-9
    k0 = 2*np.pi / wl
    eps = (n**2) * np.eye(3)
    kx = ky = 0.0

    A = _A_from_diag_eps(eps, k0, kx, ky)
    eigvals, _ = np.linalg.eig(A)
    kz_phys = eigvals * k0
    # Only take forward-propagating modes (positive real part)
    kz_forw = [k for k in kz_phys if np.real(k) > 0]
    for kz in kz_forw:
        phase = np.real(kz) * d
        expected = n * k0 * d
        np.testing.assert_allclose(phase, expected, rtol=1e-12, atol=1e-12)


def test_uniaxial_layer_phase_matches_ne_no_level_1():
    """
    Uniaxial layer at normal incidence: kz should be n_o*k0 for ordinary,
    n_e*k0 for extraordinary mode.
    """
    wl = 550e-9
    no, ne = 1.5, 1.6
    d = 100e-9
    k0 = 2*np.pi / wl
    eps_uni = np.diag([ne**2, no**2, no**2])  # OA along x
    kx = ky = 0.0

    A = _A_from_diag_eps(eps_uni, k0, kx, ky)
    eigvals, _ = np.linalg.eig(A)
    kz_phys = eigvals * k0
    kz_forw = sorted([np.real(k) for k in kz_phys if np.real(k) > 0])
    expected = sorted([ne*k0, no*k0])
    np.testing.assert_allclose(kz_forw, expected, rtol=1e-12, atol=1e-12)

def test_iso_layer_phase_includes_n_level_1():
    """
    For an isotropic layer at normal incidence, phase = n*k0*d.
    This test ensures propagation_smatrix_in_medium uses n correctly.
    """
    wl = 550e-9
    n = 1.7
    d = 100e-9
    k0 = 2*np.pi / wl
    kx = ky = 0.0
    eps = (n**2) * np.eye(3)

    S_prop = propagation_smatrix_in_medium(eps, k0, kx, ky, d, assume_isotropic=True)
    t_phase = np.angle(S_prop['t'][0,0])
    expected_phase = (n * k0 * d) % (2*np.pi)
    np.testing.assert_allclose(t_phase, expected_phase, rtol=1e-12, atol=1e-12)

import numpy as np

def test_mode_basis_shapes_and_identity_level_1():
    kz_phys = np.array([1.0, 1.0])
    d = 0.0  # zero thickness → identity in mode basis
    S = _propagation_smatrix_mode_basis(kz_phys, d)

    assert set(S.keys()) == {'r', 'rp', 't', 'tp'}
    assert S['r'].shape == (2, 2)
    assert S['rp'].shape == (2, 2)
    assert np.allclose(S['t'], np.eye(2))
    assert np.allclose(S['tp'], np.eye(2))


def test_field_basis_shapes_and_identity_level_1():
    """
    Zero-thickness propagation in field basis should give:
    - 't' and 'tp' exactly 4×4 identity
    - 'r' and 'rp' exactly zero
    """
    wl = 550e-9
    k0 = 2*np.pi/wl
    eps = np.eye(3)  # vacuum
    kx, ky = 0.0, 0.0
    d = 0.0  # zero thickness

    S = propagation_smatrix_in_medium(eps, k0, kx, ky, d, assume_isotropic=True)

    assert S['t'].shape == (4, 4)
    assert S['tp'].shape == (4, 4)
    assert S['r'].shape == (4, 4)
    assert S['rp'].shape == (4, 4)

    assert np.allclose(S['t'], np.eye(4))
    assert np.allclose(S['tp'], np.eye(4))
    assert np.allclose(S['r'], np.zeros((4, 4)))
    assert np.allclose(S['rp'], np.zeros((4, 4)))


def test_propagation_smatrix_in_medium_isotropic_level_1():
    """
    Isotropic medium propagation in field basis:
    - For zero thickness → identity transmission
    """
    wl = 550e-9
    k0 = 2*np.pi/wl
    eps = np.eye(3) * 2.25  # n=1.5 isotropic
    kx, ky = 0.0, 0.0

    # Zero thickness
    S_field = propagation_smatrix_in_medium(eps, k0, kx, ky, 0.0, assume_isotropic=True)
    assert np.allclose(S_field['t'], np.eye(4))
    assert np.allclose(S_field['tp'], np.eye(4))
    assert np.allclose(S_field['r'], np.zeros((4, 4)))
    assert np.allclose(S_field['rp'], np.zeros((4, 4)))



In [None]:
import inspect
import sys
import traceback
import math

# --------------------------- Self-runner (no pytest needed) ---------------------------

def run_all_tests(module=None, level = "level_1"):
    """
    Finds and runs all functions in this module whose name starts with 'test'.
    Prints a pass/fail summary. No external test runner needed.
    """
    if module is None:
        module = sys.modules[__name__]
    funcs = inspect.getmembers(module, inspect.isfunction)
    test_funcs = [(name, fn) for name, fn in funcs if name.lower().startswith("test") and name.lower().endswith(level)]
    print(f"Found {len(test_funcs)} test(s). Running...\n")
    passed = failed = 0
    for name, fn in test_funcs:
        try:
            fn()
            print(f"[PASS] {name}")
            passed += 1
        except Exception as e:
            print(f"[FAIL] {name}: {e}")
            traceback.print_exc()
            failed += 1
    print(f"\nSummary: {passed} passed, {failed} failed.")

def test_iso_halfspace_modes_kz_shape_and_values_level_1():
    import numpy as np
    wl = 500e-9
    k0 = 2*np.pi / wl
    n = 1.5
    eps = np.eye(3) * (n**2)
    kx, ky = 0.0, 0.0

    kz_dimless, Fp, Fm = _halfspace_modes(eps, k0, kx, ky, assume_isotropic=True)

    # Shape and equality checks
    assert kz_dimless.shape == (2,)
    assert np.allclose(kz_dimless[0], kz_dimless[1])

    # Physical value check
    expected_kz = np.sqrt((k0*n)**2 - (kx**2 + ky**2)) / k0
    assert np.allclose(kz_dimless[0], expected_kz)


## run level 1 tests

In [None]:
if __name__ == "__main__":
    run_all_tests( level = "level_1")

## failed cases

In [None]:
# ===================== VERBOSE SUBTESTS FOR FAILED CASES (Level_1) =====================

np.set_printoptions(precision=6, suppress=True)

def _dbg_Sz(Psi4):
    Ex, Ey, Hx, Hy = Psi4
    return 0.5 * np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx))

def _dbg_modes(eps, k0, kx, ky, assume_iso=True, label=""):
    kz, Fp, Fm = _halfspace_modes(eps, k0, kx, ky, assume_iso)
    B = np.hstack([Fp, Fm])
    print(f"\n[{label}] kz_fwd = {kz}")
    print(f"[{label}] cond([Fp|Fm]) = {np.linalg.cond(B):.3e}")
    for j, nm in enumerate(("Fp_s","Fp_p","Fm_s","Fm_p")):
        col = Fp[:,0] if j==0 else Fp[:,1] if j==1 else Fm[:,0] if j==2 else Fm[:,1]
        print(f"[{label}] {nm} Sz = {_dbg_Sz(col): .6e}")
    return kz, Fp, Fm

def _dbg_power_from_S(FpL, FmL, FpR, r, t, label=""):
    out = {}
    for pol_idx, pol in enumerate(("s","p")):
        a = np.zeros((2,), dtype=complex); a[pol_idx] = 1.0
        Psi_inc = FpL @ a
        Sinc = _dbg_Sz(Psi_inc)
        Psi_ref = FmL @ (r @ a)
        Psi_tr  = FpR @ (t @ a)
        Sref = _dbg_Sz(Psi_ref)
        Str  = _dbg_Sz(Psi_tr)
        R = -Sref / (Sinc + 1e-300)  # minus sign for backward power
        T =  Str / (Sinc + 1e-300)
        print(f"[{label}] pol={pol}: Sinc={Sinc: .6e}, Sref={Sref: .6e}, Str={Str: .6e}  =>  R={R: .6e}, T={T: .6e}")
        out[f"R_{pol}"] = float(np.real_if_close(R))
        out[f"T_{pol}"] = float(np.real_if_close(T))
    return out

def _dbg_interface_iso_vs_fresnel(n0, n1, k0, kx, ky, wl, theta, label="IF"):
    # numeric interface via smatrix_from_matching
    epsL = (n0*n0)*np.eye(3, dtype=complex)
    epsR = (n1*n1)*np.eye(3, dtype=complex)
    _, FpL, FmL = _halfspace_modes(epsL, k0, kx, ky, True)
    _, FpR, FmR = _halfspace_modes(epsR, k0, kx, ky, True)
    S = smatrix_from_matching(FpL, FmL, FpR, FmR)
    print(f"\n[{label}] r:\n{S['r']}\n[{label}] t:\n{S['t']}\n[{label}] rp:\n{S['rp']}\n[{label}] tp:\n{S['tp']}")
    _ = _dbg_power_from_S(FpL, FmL, FpR, S['r'], S['t'], label=f"{label} power")

    # Fresnel analytic for comparison
    rs, rp, ts, tp, Rs, Rp, Ts, Tp = fresnel_iso(n0, n1, theta)
    print(f"[{label}] Fresnel: Rs={Rs:.6e}, Rp={Rp:.6e}, Ts={Ts:.6e}, Tp={Tp:.6e}")

# ---------- 1) Absorptance >50% in thick lossy film ----------

def test_debug_absorptance_thick_lossy_layer_steps_level_1():
    """
    Air | lossy(d=10 µm) | Air @ λ=1 µm (normal incidence).
    Prints:
      - Mode bases and Sz signs in all ports
      - Interface S blocks and power
      - Slab propagation blocks and power
      - Composed stack R/T and A
    """
    wl = 1.0e-6; k0 = 2*np.pi/wl
    n0 = 1.0; nL = 2.0 + 1.5j; n2 = 1.0
    d  = 10e-6
    theta = 0.0; phi = 0.0
    kx, ky = _k_components(k0, n_inc=n0, theta=theta, phi=phi)
    # RIGHT
    eps0 = (n0*n0)*np.eye(3, dtype=complex)
    eps1 = (nL*nL)*np.eye(3, dtype=complex)
    eps2 = (n2*n2)*np.eye(3, dtype=complex)


    _, Fp0, Fm0 = _dbg_modes(eps0, k0, kx, ky, True, "port L (air)")
    _, Fp1, Fm1 = _dbg_modes(eps1, k0, kx, ky, True, "film (lossy)")
    _, Fp2, Fm2 = _dbg_modes(eps2, k0, kx, ky, True, "port R (air)")

    S01 = smatrix_from_matching(Fp0, Fm0, Fp1, Fm1)
    print("\n[IF 0|1] blocks:")
    print("r:\n", S01['r']); print("t:\n", S01['t'])
    _dbg_power_from_S(Fp0, Fm0, Fp1, S01['r'], S01['t'], "[IF 0|1]")

    Sprop = propagation_smatrix_in_medium(eps1, k0, kx, ky, d, True)
    print("\n[PROP in film] blocks:")
    print("r:\n", Sprop['r']); print("t:\n", Sprop['t'])
    _dbg_power_from_S(Fp1, Fm1, Fp1, Sprop['r'], Sprop['t'], "[PROP film]")

    S12 = smatrix_from_matching(Fp1, Fm1, Fp2, Fm2)
    print("\n[IF 1|2] blocks:")
    print("r:\n", S12['r']); print("t:\n", S12['t'])
    _dbg_power_from_S(Fp1, Fm1, Fp2, S12['r'], S12['t'], "[IF 1|2]")

    S_tmp = redheffer_star(Sprop, S01)
    S_tot = redheffer_star(S12, S_tmp)
    print("\n[STACK] Combined r,t:\n", S_tot['r'], "\n", S_tot['t'])
    out = _dbg_power_from_S(Fp0, Fm0, Fp2, S_tot['r'], S_tot['t'], "[STACK]")
    for pol in ("s","p"):
        A = 1.0 - (out[f"R_{pol}"] + out[f"T_{pol}"])
        print(f"[STACK] pol={pol}  A = {A: .6e}")

# ---------- 2) Anisotropic rotated layer “reasonable” ----------

def test_debug_anisotropic_rotated_steps_level_1():
    """
    Rotated uniaxial film between isotropic ports.
    Prints:
      - A eigenvalues and fwd/bwd split
      - Modal bases cond numbers
      - Each interface power and film propagation power
      - Final R/T and R+T
    """
    wl = 633e-9; k0 = 2*np.pi/wl
    n0, n2 = 1.0, 1.45
    no, ne = 1.6, 1.8
    Rmat = R_from_euler(0.3, 0.4, -0.2)
    film = uniaxial_material(no, ne, R=Rmat)
    d = 200e-9
    theta = np.deg2rad(15); phi = 0.2
    kx, ky = _k_components(k0, n_inc=n0, theta=theta, phi=phi)
    eps0 = (n0*n0)*np.eye(3, dtype=complex)
    eps2 = (n2*n2)*np.eye(3, dtype=complex)
    eps1 = film.eps(wl)

    # A eigenvalues
    A = _A_from_diag_eps(eps1, k0, kx, ky)
    vals = np.linalg.eigvals(A)
    print("\n[ANI] eig(A):\n", vals)

    _, Fp0, Fm0 = _dbg_modes(eps0, k0, kx, ky, True, "port L")
    _, Fp1, Fm1 = _dbg_modes(eps1, k0, kx, ky, False, "film (anisotropic)")
    _, Fp2, Fm2 = _dbg_modes(eps2, k0, kx, ky, True, "port R")

    S01 = smatrix_from_matching(Fp0, Fm0, Fp1, Fm1)
    Sprop = propagation_smatrix_in_medium(eps1, k0, kx, ky, d, False)
    S12 = smatrix_from_matching(Fp1, Fm1, Fp2, Fm2)

    print("\n[ANI IF 0|1] power:"); _dbg_power_from_S(Fp0, Fm0, Fp1, S01['r'], S01['t'], "[ANI IF 0|1]")
    print("\n[ANI PROP] power:"); _dbg_power_from_S(Fp1, Fm1, Fp1, Sprop['r'], Sprop['t'], "[ANI PROP]")
    print("\n[ANI IF 1|2] power:"); _dbg_power_from_S(Fp1, Fm1, Fp2, S12['r'], S12['t'], "[ANI IF 1|2]")

    S_tmp = redheffer_star(Sprop, S01)
    S_tot = redheffer_star(S12, S_tmp)
    print("\n[ANI STACK] r:\n", S_tot['r'], "\n[ANI STACK] t:\n", S_tot['t'])
    out = _dbg_power_from_S(Fp0, Fm0, Fp2, S_tot['r'], S_tot['t'], "[ANI STACK]")
    print(f"[ANI STACK] R_s+T_s={out['R_s']+out['T_s']:.6e}  R_p+T_p={out['R_p']+out['T_p']:.6e}")

# ---------- 3) Brewster angle p-minimum ----------

def test_debug_brewster_angle_steps_level_1():
    """
    Air→glass at θB. Prints Fresnel numbers, interface S, power R/T.
    """
    wl = 550e-9; k0 = 2*np.pi/wl
    n0, n1 = 1.0, 1.5
    thetaB = np.arctan(n1/n0); phi = 0.0
    kx, ky = _k_components(k0, n_inc=n0, theta=thetaB, phi=phi)
    _dbg_interface_iso_vs_fresnel(n0, n1, k0, kx, ky, wl, thetaB, label="Brewster IF")

# ---------- 4) Energy balance lossless stack ----------

def test_debug_energy_balance_lossless_steps_level_1():
    """
    Lossless isotropic multi-layer. Prints per-step S, power, and final sums.
    """
    wl = 700e-9; k0 = 2*np.pi/wl
    stack = [
        Layer(iso_material(1.0), None),
        Layer(iso_material(1.3), 120e-9),
        Layer(iso_material(1.8),  50e-9),
        Layer(iso_material(1.4), 220e-9),
        Layer(iso_material(1.5), None),
    ]
    theta = np.deg2rad(17); phi = 0.7
    kx, ky = _k_components(k0, n_inc=1.0, theta=theta, phi=phi)

    epsL = stack[0].material.eps(wl)
    epsR = stack[-1].material.eps(wl)
    _, FpL, FmL = _halfspace_modes(epsL, k0, kx, ky, True)
    _, FpR, FmR = _halfspace_modes(epsR, k0, kx, ky, True)

    # First interface
    _, Fp1, Fm1 = _halfspace_modes(stack[1].material.eps(wl), k0, kx, ky, True)
    S_total = smatrix_from_matching(FpL, FmL, Fp1, Fm1)
    print("\n[EB] Step 0 (IF L|1) power:"); _dbg_power_from_S(FpL, FmL, Fp1, S_total['r'], S_total['t'], "[EB 0]")

    # Walk layers
    for i in range(1, len(stack)-1):
        eps_i = stack[i].material.eps(wl)
        iso_i = True
        if stack[i].thickness:
            S_prop = propagation_smatrix_in_medium(eps_i, k0, kx, ky, stack[i].thickness, iso_i)
            S_total = redheffer_star(S_prop, S_total)
            print(f"\n[EB] Step {i} (PROP in {i}) power:")
            _, Fpi, Fmi = _halfspace_modes(eps_i, k0, kx, ky, iso_i)
            _dbg_power_from_S(Fpi, Fmi, Fpi, S_prop['r'], S_prop['t'], f"[EB PROP {i}]")
        eps_next = stack[i+1].material.eps(wl)
        iso_next = (i+1 == len(stack)-1)
        _, Fp_next, Fm_next = _halfspace_modes(eps_next, k0, kx, ky, iso_next)
        S_if = smatrix_from_matching(_halfspace_modes(eps_i, k0, kx, ky, iso_i)[1],
                                     _halfspace_modes(eps_i, k0, kx, ky, iso_i)[2],
                                     Fp_next, Fm_next)
        S_total = redheffer_star(S_if, S_total)
        print(f"[EB] Step {i} (IF {i}|{i+1}) power:")
        _dbg_power_from_S(_halfspace_modes(eps_i, k0, kx, ky, iso_i)[1],
                          _halfspace_modes(eps_i, k0, kx, ky, iso_i)[2],
                          Fp_next, S_if['r'], S_if['t'], f"[EB IF {i}|{i+1}]")

    print("\n[EB] Final r,t:\n", S_total['r'], "\n", S_total['t'])
    out = _dbg_power_from_S(FpL, FmL, FpR, S_total['r'], S_total['t'], "[EB FINAL]")
    print(f"[EB] (R_s+T_s)={out['R_s']+out['T_s']:.6e}  (R_p+T_p)={out['R_p']+out['T_p']:.6e}")

# ---------- 5) Single interface vs Fresnel (signs!) ----------

def test_debug_single_interface_steps_level_1():
    """
    Single interface (lossless, isotropic), the same as level_2 but fully verbose.
    Shows sign conventions in Poynting power.
    """
    wl = 550e-9; k0 = 2*np.pi/wl
    n0, n1 = 1.0, 1.7
    theta = np.deg2rad(37.0); phi = 0.0
    kx, ky = _k_components(k0, n_inc=n0, theta=theta, phi=phi)
    _dbg_interface_iso_vs_fresnel(n0, n1, k0, kx, ky, wl, theta, label="SINGLE-IF")

# ---------- 6) Total internal reflection ----------

def test_debug_TIR_steps_level_1():
    """
    Glass→air beyond critical angle. Prints kz in rarer medium and power guard behaviour.
    """
    wl = 550e-9; k0 = 2*np.pi/wl
    n0, n1 = 1.5, 1.0
    theta_c = np.arcsin(n1/n0)
    theta = theta_c + np.deg2rad(5.0)
    phi = 0.0
    kx, ky = _k_components(k0, n_inc=n0, theta=theta, phi=phi)

    # kz in transmitted medium
    kzR = np.lib.scimath.sqrt((k0*n1)**2 - (kx**2 + ky**2))
    print(f"\n[TIR] kz in rarer medium: kzR = {kzR}  (Im>0, Re≈0 indicates evanescent)")

    _dbg_interface_iso_vs_fresnel(n0, n1, k0, kx, ky, wl, theta, label="TIR IF")


# level 2 tests

In [None]:

# berreman_preflight_tests.py
"""
Pre-flight test suite for a full-physics Berreman 4x4 implementation.

- No hidden shortcuts: keeps μ0, ε0, c0 explicit and uses Poynting-based power.
- Each test has a docstring explaining what it checks, the assumptions, and expected behavior.
- Auto-runs without pytest: just `python berreman_preflight_tests.py`

If your solver module isn't named `minimal_berreman.py`, change the import below.
"""

import sys
import inspect
import traceback
import numpy as np
import numpy.testing as npt

# ---- Adjust this import to your solver module name (no "mb." usage in tests) ----


# --------------------------- Helpers (tests only) ---------------------------

def fresnel_iso(n0, n1, theta0):
    """
    Textbook Fresnel coefficients for a single *isotropic, non-magnetic* interface.
    Returns (r_s, r_p, t_s, t_p, R_s, R_p, T_s, T_p) with power T using z-Poynting ratios.
    Assumes SI units; independent of Berreman implementation details.
    """
    s0 = np.sin(theta0)
    c0 = np.cos(theta0)
    s1 = (n0/n1) * s0
    if np.abs(s1) > 1:  # TIR
        return np.nan, np.nan, np.nan, np.nan, 1.0, 1.0, 0.0, 0.0
    c1 = np.sqrt(1 - s1**2)

    rs = (n0*c0 - n1*c1)/(n0*c0 + n1*c1)
    rp = (n1*c0 - n0*c1)/(n1*c0 + n0*c1)
    ts = 2*n0*c0/(n0*c0 + n1*c1)
    tp = 2*n0*c0/(n1*c0 + n0*c1)

    Ts = (n1*c1)/(n0*c0) * np.abs(ts)**2
    Tp = (n1*c1)/(n0*c0) * np.abs(tp)**2
    Rs = np.abs(rs)**2
    Rp = np.abs(rp)**2
    return rs, rp, ts, tp, Rs, Rp, Ts, Tp


# --------------------------- Unit tests for small utilities ---------------------------

def test_R_from_euler_identity_and_orthonormal_level_2():
    """
    Checks the rotation constructor’s core invariants:
    - R(0,0,0) = I.
    - For arbitrary angles, R^T R = I and det(R) = +1 (proper rotation).
    Independent of EM normalization; pure geometry check.
    """
    R = R_from_euler(0.0, 0.0, 0.0)
    npt.assert_allclose(R, np.eye(3), atol=1e-14)
    R = R_from_euler(0.3, -0.7, 1.2)
    npt.assert_allclose(R.T @ R, np.eye(3), atol=1e-12)
    npt.assert_allclose(np.linalg.det(R), 1.0, rtol=0, atol=1e-12)


def test_rot_tensor_behavior_level_2():
    """
    Verifies the active rotation convention for tensors: ε' = R ε R^T.
    For a 90° rotation about z, x/y diagonal entries swap; identity rotation is a no-op.
    """
    T = np.diag([2.0, 3.0, 4.0]).astype(complex)
    T2 = rot_tensor(T, np.eye(3))
    npt.assert_allclose(T2, T, atol=1e-14)
    Rz90 = R_from_euler(np.pi/2, 0.0, 0.0)
    Trot = rot_tensor(T, Rz90)
    expect = np.diag([3.0, 2.0, 4.0]).astype(complex)
    npt.assert_allclose(Trot, expect, atol=1e-12)


def test_as_tensor_fn_and_material_eps_level_2():
    """
    Ensures Material accepts both constant and dispersive (callable) ε_r inputs,
    returning a 3×3 tensor at any wavelength. Pure API sanity check.
    """
    eps = 2.25 * np.eye(3, dtype=complex)
    m = Material(eps)
    wl = 633e-9
    npt.assert_allclose(m.eps(wl), eps)

    def disp(wl):
        return (2.0 + 0.5*(wl/1e-6))**2 * np.eye(3, dtype=complex)
    m2 = Material(disp)
    e1 = m2.eps(500e-9); e2 = m2.eps(1000e-9)
    assert e1[0,0] != e2[0,0]


def test_iso_and_uniaxial_material_helpers_level_2():
    """
    Quick helpers sanity:
    - iso_material(n) → ε = n^2 I.
    - uniaxial_material(no, ne) → diag(no^2, no^2, ne^2).
    Independent of rotations/branches.
    """
    n = 1.7 + 0.0j
    mi = iso_material(n)
    e = mi.eps(550e-9)
    npt.assert_allclose(e, (n*n)*np.eye(3), atol=1e-14)

    ne = 1.6; no = 1.5
    mu = uniaxial_material(no, ne)
    ex = mu.eps(550e-9)
    npt.assert_allclose(np.diag(ex), np.array([no*no, no*no, ne*ne]), atol=1e-14)


def test_k_components_zero_angle_level_2():
    """
    Angle-to-(kx,ky) mapping: at θ=0 any φ must yield kx=ky=0.
    Confirms how the test harness computes tangential k from incidence.
    """
    wl = 550e-9
    k0 = 2*np.pi/wl
    kx, ky = _k_components(k0, n_inc=1.0, theta=0.0, phi=1.1)
    npt.assert_allclose([kx, ky], [0.0, 0.0], atol=1e-14)


def test_forward_branch_sign_choice_level_2():
    """
    Branch rule for forward/decaying kz (exp(+i ω t) convention):
    - If Im(kz)<0 → flip sign so the evanescent decays for z>0.
    - If Im(kz)≈0 and Re(kz)<0 → flip so propagation is +z.
    """
    kz = 1.0 - 1e-6j
    npt.assert_allclose(_forward_branch(kz), -kz)
    kz2 = 0.3 + 1e-6j
    npt.assert_allclose(_forward_branch(kz2), kz2)
    kz3 = -0.5 + 0j
    npt.assert_allclose(_forward_branch(kz3), -kz3)


def test_A_from_diag_eps_isotropic_normal_incidence_level_2():
    """
    Full-physics A-matrix entries at normal incidence, isotropic ε_r=n^2 I:
    With dΨ/dz = i k0 A Ψ and SI constants kept:
      A[0,3] = +Z0, A[1,2] = -Z0,
      A[2,1] = -(n^2)/Z0, A[3,0] = +(n^2)/Z0,
    where Z0 = sqrt(μ0/ε0) ≈ 376.730313 Ω. All other entries ~ 0.
    This replaces the older unitless shortcut (±1) assumption.
    """
    wl = 550e-9
    k0 = 2*np.pi/wl
    n  = 1.4
    eps = (n*n)*np.eye(3, dtype=complex)
    A = _A_from_diag_eps(eps, k0, kx=0.0, ky=0.0)
    Z0 = np.sqrt(MU0 / EPS0)
    npt.assert_allclose(A[0,3],  Z0, rtol=1e-12, atol=1e-12)
    npt.assert_allclose(A[1,2], -Z0, rtol=1e-12, atol=1e-12)
    npt.assert_allclose(A[2,1], -(n*n)/Z0, rtol=1e-12, atol=1e-12)
    npt.assert_allclose(A[3,0],  (n*n)/Z0, rtol=1e-12, atol=1e-12)
    mask = np.ones((4,4), dtype=bool)
    mask[0,3] = mask[1,2] = mask[2,1] = mask[3,0] = False
    npt.assert_allclose(A[mask], 0.0, atol=1e-12)


def test_A_eigs_isotropic_match_kz_over_k0_level_2():
    """
    Eigenvalue sanity for isotropic layers: eigenvalues(A) = {± kz/k0, ± kz/k0}.
    This is implementation- and normalization-invariant if A is built from Maxwell.
    """
    wl = 633e-9
    k0 = 2*np.pi/wl
    n  = 1.7
    eps = (n*n)*np.eye(3, dtype=complex)
    theta = np.deg2rad(25)
    phi   = 0.3
    kx, ky = _k_components(k0, n_inc=n, theta=theta, phi=phi)  # set k_t in same medium
    A = _A_from_diag_eps(eps, k0, kx, ky)
    vals = np.linalg.eigvals(A)
    kt2 = kx*kx + ky*ky
    kz = np.lib.scimath.sqrt((k0*n)**2 - kt2)
    expect = np.array([ kz/k0, kz/k0, -kz/k0, -kz/k0], dtype=complex)
    npt.assert_allclose(np.sort_complex(vals), np.sort_complex(expect), rtol=1e-9, atol=1e-12)


def test_propagator_zero_thickness_is_identity_level_1():
    """
    Matrix exponential consistency: exp(i k0 A * 0) = I for any A.
    Catches incorrect expm shortcuts or scaling mistakes.
    """
    rng = np.random.default_rng(0)
    A = rng.standard_normal((4,4)) + 1j*rng.standard_normal((4,4))
    k0 = 2*np.pi/633e-9
    P = _propagator(A, k0, d=0.0)
    npt.assert_allclose(P, np.eye(4), atol=1e-12)


def test_iso_halfspace_modes_shapes_and_kz_sign_level_1():
    """
    Half-space mode builder sanity:
    - F^+, F^- are 4×2 bases for forward/backward (s,p) tangential fields.
    - At normal incidence in a lossless medium, kz should be ≈ real and > 0 (forward).
    """
    wl = 550e-9; k0 = 2*np.pi/wl
    kz, Fp, Fm = _iso_halfspace_modes(n=1.5, k0=k0, kx=0.0, ky=0.0)
    assert Fp.shape == (4,2) and Fm.shape == (4,2)
    assert np.all(np.real(kz) > 0) and np.all(abs(np.imag(kz)) < 1e-14) and abs(np.imag(kz)) < 1e-14


# --------------------------- Integration-style tests ---------------------------

def test_single_interface_matches_fresnel_power_level_2(theta_deg=37.0):
    """
    Berreman vs Fresnel at a single interface (lossless, isotropic):
    Compare R_s, R_p, T_s, T_p to analytic Fresnel values using Poynting power ratios.
    Independent of internal modal scaling.
    """
    theta = np.deg2rad(theta_deg)
    wl = 550e-9
    n0, n1 = 1.0, 1.7
    air = Layer(iso_material(n0), None)
    sub = Layer(iso_material(n1), None)
    out = solve_stack_4x4([air, sub], wl, theta=theta, phi=0.0)
    _, _, _, _, Rs, Rp, Ts, Tp = fresnel_iso(n0, n1, theta)
    npt.assert_allclose(out["R_s"], Rs, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["R_p"], Rp, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["T_s"], Ts, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["T_p"], Tp, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["R_s"] + out["T_s"], 1.0, rtol=1e-4, atol=1e-6)
    npt.assert_allclose(out["R_p"] + out["T_p"], 1.0, rtol=1e-4, atol=1e-6)


def test_brewster_angle_p_reflection_min_level_2():
    """
    Brewster-angle check for p-polarization (non-magnetic, isotropic):
    For air→glass (n1>n0), tan θ_B = n1/n0; R_p(θ_B) ≈ 0 while R_s(θ_B) > 0.
    Verifies angle handling and p-basis correctness.
    """
    wl = 550e-9
    n0, n1 = 1.0, 1.5
    thetaB = np.arctan(n1/n0)  # Brewster angle (p) in non-magnetic isotropic media
    air = Layer(iso_material(n0), None)
    glass = Layer(iso_material(n1), None)
    out = solve_stack_4x4([air, glass], wl, theta=thetaB, phi=0.0)
    assert abs(out["R_p"]) < 1e-4
    assert out["R_s"] > 1e-3  # s is not zero at Brewster


def test_total_internal_reflection_limits_level_2():
    """
    TIR check: glass→air at θ > θ_c = asin(n1/n0).
    Expect T ≈ 0, R ≈ 1; transmitted field is evanescent with Im(kz)>0 in the rarer medium.
    Validates branch choice and energy accounting.
    """
    wl = 550e-9
    n0, n1 = 1.5, 1.0
    theta_c = np.arcsin(n1/n0)
    theta = theta_c + np.deg2rad(5)
    glass = Layer(iso_material(n0), None)
    air   = Layer(iso_material(n1), None)
    out = solve_stack_4x4([glass, air], wl, theta=theta, phi=0.0)
    assert out["T_s"] < 1e-4 and out["T_p"] < 1e-4
    assert out["R_s"] > 1 - 1e-6 and out["R_p"] > 1 - 1e-6


def test_thin_film_quarter_wave_AR_normal_incidence_level_2():
    """
    Quarter-wave AR coating at normal incidence:
    Choose n1 = sqrt(n0 * n2), d = λ/(4 n1).
    Expect very low R (<1e-3) and near-unity T (>1 - 1e-3).
    Validates that per-layer propagators and interface matching
    produce correct phases in the field basis.
    """
    wl = 550e-9
    n0, n2 = 1.0, 1.5
    n1 = np.sqrt(n0 * n2)
    d = wl / (4 * n1)

    air = Layer(iso_material(n0), None)
    film = Layer(iso_material(n1), d)
    sub = Layer(iso_material(n2), None)

    out = solve_stack_4x4([air, film, sub], wl, theta=0.0, phi=0.0)

    print("\n=== DEBUG: Quarter-wave AR coating ===")
    print(f"Wavelength (m): {wl:.3e}")
    print(f"n0: {n0}, n1 (film): {n1}, n2: {n2}")
    print(f"Film thickness (m): {d:.10e}")
    print(f"Expected optical thickness: λ/4 -> {wl/(4*n1):.10e} m")
    print("\nComputed outputs:")
    print(f"R_s: {out['R_s']}")
    print(f"T_s: {out['T_s']}")
    print(f"R_p: {out['R_p']}")
    print(f"T_p: {out['T_p']}")
    print("\nCheck expectations:")
    print(f"R_s < 1e-3? {out['R_s'] < 1e-3}")
    print(f"R_p < 1e-3? {out['R_p'] < 1e-3}")
    print(f"T_s > 1 - 1e-3? {out['T_s'] > 1 - 1e-3}")
    print(f"T_p > 1 - 1e-3? {out['T_p'] > 1 - 1e-3}")

    assert out["R_s"] < 1e-3 and out["R_p"] < 1e-3, "Reflectance too high for AR design"
    assert out["T_s"] > 1 - 1e-3 and out["T_p"] > 1 - 1e-3, "Transmittance too low for AR design"




def test_vacuum_slab_transparent_level_2():
    """
    Identity stack: air | (air, any thickness) | air.
    Expect R=0, T=1 for both polarizations. Catches interface or phase errors.
    """
    wl = 600e-9
    air  = Layer(iso_material(1.0), None)
    slab = Layer(iso_material(1.0), 2.0e-6)
    sub  = Layer(iso_material(1.0), None)
    out = solve_stack_4x4([air, slab, sub], wl, theta=np.deg2rad(23), phi=0.4)
    npt.assert_allclose(out["R_s"], 0.0, atol=1e-8)
    npt.assert_allclose(out["R_p"], 0.0, atol=1e-8)
    npt.assert_allclose(out["T_s"], 1.0, atol=1e-8)
    npt.assert_allclose(out["T_p"], 1.0, atol=1e-8)


def test_rotation_invariance_z_for_isotropic_stack_level_2():
    """
    Isotropic stacks are azimuthally invariant:
    For fixed θ, rotating φ (about z) must not change R/T.
    Validates kx,ky mapping and isotropy handling.
    """
    wl = 550e-9
    n0, n2 = 1.0, 1.7
    d = 100e-9
    air  = Layer(iso_material(n0), None)
    film = Layer(iso_material(1.3), d)
    sub  = Layer(iso_material(n2), None)
    out0 = solve_stack_4x4([air, film, sub], wl, theta=np.deg2rad(30), phi=0.0)
    out1 = solve_stack_4x4([air, film, sub], wl, theta=np.deg2rad(30), phi=1.1)
    for k in ("R_s","R_p","T_s","T_p"):
        npt.assert_allclose(out0[k], out1[k], rtol=1e-3, atol=1e-3)


def test_anisotropic_layer_rotated_is_stable_and_reasonable_level_2():
    """
    General anisotropy sanity:
    Rotated uniaxial layer between isotropic half-spaces should yield finite, bounded R/T,
    and (R+T) close to 1 in lossless case (allowing tiny numerical slack).
    Validates Berreman A, rotation of ε, and interface matching coherently.
    """
    wl = 633e-9
    n0, n_sub = 1.0, 1.45
    no, ne = 1.6, 1.8
    Rmat = R_from_euler(0.3, 0.4, -0.2)
    mat = uniaxial_material(no, ne, R=Rmat)
    layer = Layer(mat, 200e-9)
    air = Layer(iso_material(n0), None)
    sub = Layer(iso_material(n_sub), None)
    out = solve_stack_4x4([air, layer, sub], wl, theta=np.deg2rad(15), phi=0.2)
    for key in ("R_s","R_p","T_s","T_p"):
        val = float(np.real(out[key]))
        assert np.isfinite(val)
        assert -5e-3 <= val <= 1 + 5e-3
    # Energy balance (lossless): allow small slack
    assert abs(abs(out["R_s"]) + abs(out["T_s"]) - 1.0) < 2e-2
    assert abs(abs(out["R_p"]) + abs(out["T_p"]) - 1.0) < 2e-2


def test_energy_balance_lossless_level_2():
    """
    Global energy check for a random *lossless* stack:
    In absence of absorption, R+T ~ 1 for each pol. Uses Poynting power; implementation-agnostic.
    """
    wl = 700e-9
    stack = [
        Layer(iso_material(1.0), None),
        Layer(iso_material(1.3), 120e-9),
        Layer(iso_material(1.8),  50e-9),
        Layer(iso_material(1.4), 220e-9),
        Layer(iso_material(1.5), None),
    ]
    out = solve_stack_4x4(stack, wl, theta=np.deg2rad(17), phi=0.7)
    npt.assert_allclose(out["R_s"] + out["T_s"], 1.0, atol=2e-3)
    npt.assert_allclose(out["R_p"] + out["T_p"], 1.0, atol=2e-3)


def test_absorption_is_nonnegative_level_2():
    """
    Absorbing layer sanity: A = 1 - (R+T) ≥ 0 (within tiny numerical tolerance).
    Ensures branch and power definitions behave with complex n.
    """
    wl = 532e-9
    n0, n2 = 1.0, 1.5
    n_abs = 1.6 + 0.05j  # absorbing film
    air  = Layer(iso_material(n0), None)
    film = Layer(iso_material(n_abs), 200e-9)
    sub  = Layer(iso_material(n2), None)
    out = solve_stack_4x4([air, film, sub], wl, theta=np.deg2rad(20), phi=0.0)
    for pol in ("s","p"):
        A = 1.0 - (out[f"R_{pol}"] + out[f"T_{pol}"])
        assert A >= -5e-3  # allow tiny negative due to numerical roundoff


def test_absorptance_over_50pct_thick_lossy_layer_level_2():
    """
    Air | lossy film | air, normal incidence.
    Film: n = 2.0 + 1.5i, thickness = 10 µm, wavelength = 1 µm.
    Expect strong absorption: A_s > 0.5 and A_p > 0.5.
    """
    wl = 1.0e-6              # 1 µm
    d  = 10.0e-6             # 10 µm
    n_lossy = 2.0 + 1.5j     # strong absorber, modest front-face reflection

    stack = [
        Layer(iso_material(1.0), None),      # incident half-space (air)
        Layer(iso_material(n_lossy), d),     # lossy film
        Layer(iso_material(1.0), None),      # substrate half-space (air)
    ]

    out = solve_stack_smatrix(stack, wl, theta=0.0, phi=0.0)  # or solve_stack_4x4 if that's your entry
    for pol in ("s", "p"):
        R = out[f"R_{pol}"]
        T = out[f"T_{pol}"]
        A = 1.0 - (R + T)
        assert A > 0.5, f"Absorptance not >50% for {pol}-pol: R={R:.4f}, T={T:.4f}, A={A:.4f}"






## run level 2 tests

In [None]:

if __name__ == "__main__":
    run_all_tests( level = "level_2")


# Level 3 tests

In [None]:
# water_cases_spectrum.py
"""
Spectral T/R/A plots for air↔water interface and water films using your Berreman solver.

Scenarios:
1) Air → Water interface
2) Water film of 1 nm in air
3) Water film of 1 µm in air
4) Water film of 1 cm in air
5) Water film of 10 cm in air

Each test:
- Sweeps wavelengths from 250 nm to 1000 nm (linear spacing).
- Computes R, T, A using your solver with Poynting power.
- Plots the *unpolarized average* curves (T_avg, R_avg, A_avg).
- Saves a PNG and (optionally) displays the figure.

Toggle behavior at the top via SHOW_PLOTS / SAVE_PLOTS.
"""

import sys, inspect, traceback
import numpy as np
import numpy.testing as npt
import matplotlib.pyplot as plt

# ---------- Config ----------
SHOW_PLOTS = True     # Set False if running headless; True shows plots in Jupyter/desktop
SAVE_PLOTS = True     # Set False if you don't want PNGs
NUM_POINTS = 300      # # of wavelength samples between 250–1000 nm (inclusive)


# ---------- Reusable helpers ----------

def simulate_water(thickness_m, wavelength_m, theta_rad=0.0, phi_rad=0.0,
                   n_air=1.0+0j, n_water=1.333+0j):
    """
    Build the air↔water case and run your Berreman solver.

    - thickness_m = None  → pure interface: Air | Water (half-spaces)
    - thickness_m = d     → slab: Air | Water(d) | Air

    Returns dict with s/p and avg:
      {
        "R": {"s": Rs, "p": Rp, "avg": (Rs+Rp)/2},
        "T": {"s": Ts, "p": Tp, "avg": (Ts+Tp)/2},
        "A": {"s": 1-(Rs+Ts), "p": 1-(Rp+Tp), "avg": ...},
        "raw": original solver dict
      }
    """
    if thickness_m is None:
        stack = [Layer(iso_material(n_air), None),
                 Layer(iso_material(n_water), None)]
    else:
        stack = [Layer(iso_material(n_air), None),
                 Layer(iso_material(n_water), float(thickness_m)),
                 Layer(iso_material(n_air), None)]
    out = solve_stack_4x4(stack, wavelength_m, theta=theta_rad, phi=phi_rad)

    Rs, Rp = float(out["R_s"]), float(out["R_p"])
    Ts, Tp = float(out["T_s"]), float(out["T_p"])
    As, Ap = 1.0 - (Rs + Ts), 1.0 - (Rp + Tp)

    return {
        "R": {"s": Rs, "p": Rp, "avg": 0.5*(Rs+Rp)},
        "T": {"s": Ts, "p": Tp, "avg": 0.5*(Ts+Tp)},
        "A": {"s": As, "p": Ap, "avg": 0.5*(As+Ap)},
        "raw": out,
    }


def spectrum_water(thickness_m, wl_min_nm=250.0, wl_max_nm=1000.0, num=NUM_POINTS,
                   theta_deg=0.0, phi_deg=0.0, n_air=1.0+0j, n_water=1.333+0j):
    """
    Sweep wavelength and return arrays for T/R/A (avg, s, p).

    Returns:
      wl_nm:  (num,) array of wavelengths in nm
      spec:   dict with keys "avg","s","p", each containing subdicts "T","R","A" -> arrays
    """
    wl_nm = np.linspace(wl_min_nm, wl_max_nm, int(num))
    wl_m  = wl_nm * 1e-9
    theta = np.deg2rad(theta_deg)
    phi   = np.deg2rad(phi_deg)

    T_avg = np.empty_like(wl_m, dtype=float); R_avg = np.empty_like(wl_m, dtype=float); A_avg = np.empty_like(wl_m, dtype=float)
    T_s   = np.empty_like(wl_m, dtype=float); R_s   = np.empty_like(wl_m, dtype=float); A_s   = np.empty_like(wl_m, dtype=float)
    T_p   = np.empty_like(wl_m, dtype=float); R_p   = np.empty_like(wl_m, dtype=float); A_p   = np.empty_like(wl_m, dtype=float)

    for i, wl in enumerate(wl_m):
        res = simulate_water(thickness_m, wavelength_m=wl, theta_rad=theta, phi_rad=phi,
                             n_air=n_air, n_water=n_water)
        T_avg[i], R_avg[i], A_avg[i] = res["T"]["avg"], res["R"]["avg"], res["A"]["avg"]
        T_s[i],   R_s[i],   A_s[i]   = res["T"]["s"],   res["R"]["s"],   res["A"]["s"]
        T_p[i],   R_p[i],   A_p[i]   = res["T"]["p"],   res["R"]["p"],   res["A"]["p"]

    spec = {
        "avg": {"T": T_avg, "R": R_avg, "A": A_avg},
        "s":   {"T": T_s,   "R": R_s,   "A": A_s},
        "p":   {"T": T_p,   "R": R_p,   "A": A_p},
    }
    return wl_nm, spec


def plot_tra(wl_nm, T, R, A, title, save_path=None, show=SHOW_PLOTS):
    """
    Make a single plot of T, R, A vs wavelength (nm). Saves and/or shows.

    Parameters
    ----------
    wl_nm : array
    T, R, A : arrays (same length as wl_nm)
    title : str
    save_path : str or None
    show : bool
    """
    fig = plt.figure(figsize=(7.5, 4.5))
    ax = fig.add_subplot(111)
    ax.plot(wl_nm, T, label="T (avg)")
    ax.plot(wl_nm, R, label="R (avg)")
    ax.plot(wl_nm, A, label="A (avg)")
    ax.set_xlabel("Wavelength (nm)")
    ax.set_ylabel("Power coefficient")
    ax.set_title(title)
    ax.set_xlim(wl_nm.min(), wl_nm.max())
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")
    fig.tight_layout()

    if SAVE_PLOTS and save_path:
        fig.savefig(save_path, dpi=130)
        print(f"[saved] {save_path}")
    if show:
        plt.show()
    else:
        plt.close(fig)


# ---------- Tests that also plot ----------

def test_plot_air_to_water_interface_spectrum_level_3():
    """Air → water interface at normal incidence: R~2%, T~98%, A~0 across the band (lossless)."""
    wl_nm, spec = spectrum_water(None)
    # Energy & bounds sanity (unpolarized averages)
    npt.assert_allclose(spec["avg"]["T"] + spec["avg"]["R"] + spec["avg"]["A"], 1.0, atol=5e-4)
    assert np.all((spec["avg"]["T"] >= -1e-4) & (spec["avg"]["T"] <= 1+1e-4))
    assert np.all((spec["avg"]["R"] >= -1e-4) & (spec["avg"]["R"] <= 1+1e-4))
    plot_tra(wl_nm, spec["avg"]["T"], spec["avg"]["R"], spec["avg"]["A"],
             title="Air → Water Interface (normal incidence)",
             save_path="air_to_water_interface_TRA.png")


def test_plot_water_film_1nm_spectrum_level_3():
    """Water slab 1 nm in air at normal incidence: symmetric → R≈0, T≈1, A≈0; minor oscillations possible."""
    wl_nm, spec = spectrum_water(1e-9)
    npt.assert_allclose(spec["avg"]["T"] + spec["avg"]["R"] + spec["avg"]["A"], 1.0, atol=5e-4)
    plot_tra(wl_nm, spec["avg"]["T"], spec["avg"]["R"], spec["avg"]["A"],
             title="Water Film (1 nm) in Air — TRA vs λ",
             save_path="water_film_1nm_TRA.png")


def test_plot_water_film_1um_spectrum_level_3():
    """Water slab 1 µm in air: clear Fabry–Perot fringes; enforce physical bounds + energy conservation."""
    wl_nm, spec = spectrum_water(1e-6)
    npt.assert_allclose(spec["avg"]["T"] + spec["avg"]["R"] + spec["avg"]["A"], 1.0, atol=2e-3)
    assert np.all((spec["avg"]["T"] >= -1e-3) & (spec["avg"]["T"] <= 1+1e-3))
    assert np.all((spec["avg"]["R"] >= -1e-3) & (spec["avg"]["R"] <= 1+1e-3))
    plot_tra(wl_nm, spec["avg"]["T"], spec["avg"]["R"], spec["avg"]["A"],
             title="Water Film (1 µm) in Air — TRA vs λ",
             save_path="water_film_1um_TRA.png")


def test_plot_water_film_1cm_spectrum_level_3():
    """Water slab 1 cm in air: many fringes; check bounds + near energy balance (lossless)."""
    wl_nm, spec = spectrum_water(1e-2)
    npt.assert_allclose(spec["avg"]["T"] + spec["avg"]["R"] + spec["avg"]["A"], 1.0, atol=2e-3)
    plot_tra(wl_nm, spec["avg"]["T"], spec["avg"]["R"], spec["avg"]["A"],
             title="Water Film (1 cm) in Air — TRA vs λ",
             save_path="water_film_1cm_TRA.png")


def test_plot_water_film_10cm_spectrum_level_3():
    """Water slab 10 cm in air: extremely thick; still bounded and energy-conserving."""
    wl_nm, spec = spectrum_water(1e-1)
    npt.assert_allclose(spec["avg"]["T"] + spec["avg"]["R"] + spec["avg"]["A"], 1.0, atol=2e-3)
    plot_tra(wl_nm, spec["avg"]["T"], spec["avg"]["R"], spec["avg"]["A"],
             title="Water Film (10 cm) in Air — TRA vs λ",
             save_path="water_film_10cm_TRA.png")








In [None]:
def test_equivalence_transfer_vs_smatrix_level_3():
    wl = 633e-9; theta = np.deg2rad(17); phi = 0.3
    stack = [
        Layer(iso_material(1.0), None),
        Layer(uniaxial_material(1.55, 1.65, R=R_from_euler(0.2, -0.3, 0.1)), 120e-9),
        Layer(iso_material(1.3+0.02j), 90e-9),
        Layer(iso_material(1.7), None),
    ]
    outS = solve_stack_smatrix(stack, wl, theta, phi)
    # If you kept the original transfer solver as solve_stack_4x4_orig:
    # outT = solve_stack_4x4_orig(stack, wl, theta, phi)
    # npt.assert_allclose(outS["R_s"], outT["R_s"], rtol=1e-6, atol=1e-8)
    # npt.assert_allclose(outS["R_p"], outT["R_p"], rtol=1e-6, atol=1e-8)
    assert np.isfinite(outS["R_s"]) and np.isfinite(outS["R_p"])

def test_stability_many_layers_absorbing_level_3():
    wl = 550e-9; theta = np.deg2rad(20); phi = 0.0
    rng = np.random.default_rng(0)
    layers = [Layer(iso_material(1.0), None)]
    for _ in range(120):
        n = 1.4 + 0.2*rng.random() + 1j*(0.05*rng.random())
        d = (5e-9 + 15e-9*rng.random())
        layers.append(Layer(iso_material(n), d))
    layers.append(Layer(iso_material(1.5), None))
    out = solve_stack_smatrix(layers, wl, theta, phi)
    for pol in ("s","p"):
        R, T = out[f"R_{pol}"], out[f"T_{pol}"]
        A = 1.0 - (R + T)
        assert -1e-6 <= R <= 1+1e-6
        assert -1e-6 <= T <= 1+1e-6
        assert A >= -1e-4



## run level 3 tests

In [None]:
if __name__ == "__main__":
    run_all_tests( level = "level_3")

# birrefringence tests

In [None]:
# ===================== BIREFRINGENT SLAB TEST SUITE =====================
import numpy as np
from dataclasses import dataclass
from scipy.optimize import minimize_scalar

@dataclass
class BirefringentSlabResult:
    """Container for birefringent slab test results"""
    description: str
    theta: float
    phi: float
    R_s: float
    R_p: float
    T_s: float
    T_p: float
    phase_s: float
    phase_p: float

def test_birefringent_slab_normal_incidence_level_4():
    """
    Test A: Normal incidence on birefringent slab with optical axis in x-direction.
    Validates phase difference between ordinary and extraordinary rays.
    """
    wl = 500e-9
    d = 5.5e-6  # Quarter-wave thickness for Δn=0.2 at 500nm
    no, ne = 1.5, 1.7
    delta_n = ne - no
    print(f"\n=== DEBUG: Birefringent slab normal incidence (d={d*1e6:.2f}μm) ===")
    print(f"n_o={no}, n_e={ne}, Δn={delta_n}")
    
    # Rotation to align optical axis with x-axis
    R_mat = R_from_euler(0, np.pi/2, 0)
    print(f"Rotation matrix:\n{R_mat}")
    birefringent = uniaxial_material(no, ne, R=R_mat)
    eps = birefringent.eps(wl)
    print(f"Effective permittivity tensor:\n{eps}")
    
    stack = [
        Layer(iso_material(1.0), None),       # Air incident
        Layer(birefringent, d),               # Birefringent slab
        Layer(iso_material(1.0), None)        # Air substrate
    ]
    
    # Run simulation at normal incidence
    print("\n[Solving stack]")
    res = solve_stack_4x4(stack, wl, theta=0, phi=0)
    
    # Calculate expected phase difference
    k0 = 2*np.pi/wl
    delta_expected = delta_n * k0 * d
    print(f"\nExpected phase difference: {delta_expected:.6f} rad ({np.rad2deg(delta_expected):.2f}°)")
    
    # Extract complex transmission coefficients
    t_s = res['t_s']
    t_p = res['t_p']
    print(f"\nTransmission coefficients:")
    print(f"t_s = {t_s:.6f} (mag={np.abs(t_s):.4f}, phase={np.angle(t_s):.6f} rad)")
    print(f"t_p = {t_p:.6f} (mag={np.abs(t_p):.4f}, phase={np.angle(t_p):.6f} rad)")
    
    # Calculate magnitude and phase
    mag_s, phase_s = np.abs(t_s), np.angle(t_s)
    mag_p, phase_p = np.abs(t_p), np.angle(t_p)
    
    # Calculate phase difference with proper unwrapping
    phase_diff = phase_p - phase_s
    print(f"Raw phase difference: {phase_diff:.6f} rad")
    
    # Apply phase unwrapping
    if phase_diff < -np.pi:
        phase_diff += 2*np.pi
        print("Adjusted phase_diff: +2π")
    elif phase_diff > np.pi:
        phase_diff -= 2*np.pi
        print("Adjusted phase_diff: -2π")
    
    print(f"Final phase difference: {phase_diff:.6f} rad ({np.rad2deg(phase_diff):.2f}°)")
    
    # Verify within 90° ± 10°
    try:
        npt.assert_allclose(phase_diff, delta_expected, rtol=0.1, atol=0.1,
                            err_msg="Phase difference doesn't match birefringence")
        print("[PASS] Phase difference matches expected value")
    except AssertionError as e:
        print(f"[FAIL] {str(e)}")
        # Calculate error metrics
        abs_error = abs(phase_diff - delta_expected)
        rel_error = abs_error / abs(delta_expected)
        print(f"Absolute error: {abs_error:.6f} rad")
        print(f"Relative error: {rel_error:.4f}")
        raise
    
    # Verify retardance is quarter wave (90°)
    phase_diff_deg = np.rad2deg(phase_diff)
    print(f"Retardance: {abs(phase_diff_deg):.2f}°")
    try:
        assert 80 < abs(phase_diff_deg) < 100, "Not quarter-wave retardance"
        print("[PASS] Quarter-wave retardance confirmed")
    except AssertionError as e:
        print(f"[FAIL] {str(e)}")
        raise
    
    return BirefringentSlabResult(
        "Normal incidence (optical axis in plane)",
        theta=0,
        phi=0,
        R_s=res["R_s"],
        R_p=res["R_p"],
        T_s=res["T_s"],
        T_p=res["T_p"],
        phase_s=phase_s,
        phase_p=phase_p
    )

def test_birefringent_slab_incidence_parallel_slow_axis_level_4():
    """
    Test B: Incidence in plane parallel to slow axis (optical axis=x).
    Validates different refraction angles for ordinary and extraordinary rays.
    """
    wl = 500e-9
    d = 1e-6
    no, ne = 1.5, 1.7
    theta_inc = np.deg2rad(30)
    
    # Rotation to align optical axis with x-axis
    R_mat = R_from_euler(0, np.pi/2, 0)
    birefringent = uniaxial_material(no, ne, R=R_mat)
    
    stack = [
        Layer(iso_material(1.0), None),
        Layer(birefringent, d),
        Layer(iso_material(1.0), None)
    ]
    
    # Incidence plane: xz (parallel to optical axis)
    res = solve_stack_4x4(stack, wl, theta=theta_inc, phi=0)
    
    # Energy conservation check (lossless material)
    npt.assert_allclose(res["R_s"] + res["T_s"], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for s-pol")
    npt.assert_allclose(res["R_p"] + res["T_p"], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for p-pol")
    
    return BirefringentSlabResult(
        "Incidence parallel to slow axis (xz-plane)",
        theta=np.rad2deg(theta_inc),
        phi=0,
        R_s=res["R_s"],
        R_p=res["R_p"],
        T_s=res["T_s"],
        T_p=res["T_p"],
        phase_s=np.angle(res['t_s']),
        phase_p=np.angle(res['t_p'])
    )

def test_birefringent_slab_incidence_parallel_fast_axis_level_4():
    """
    Test C: Incidence in plane parallel to fast axis (y-direction).
    Validates symmetric behavior with Test B but different polarization response.
    """
    wl = 500e-9
    d = 1e-6
    no, ne = 1.5, 1.7
    theta_inc = np.deg2rad(30)
    
    # Rotation to align optical axis with x-axis
    R_mat = R_from_euler(0, np.pi/2, 0)
    birefringent = uniaxial_material(no, ne, R=R_mat)
    
    stack = [
        Layer(iso_material(1.0), None),
        Layer(birefringent, d),
        Layer(iso_material(1.0), None)
    ]
    
    # Incidence plane: yz (perpendicular to optical axis)
    res = solve_stack_4x4(stack, wl, theta=theta_inc, phi=np.pi/2)
    
    # Energy conservation check
    npt.assert_allclose(res["R_s"] + res["T_s"], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for s-pol")
    npt.assert_allclose(res["R_p"] + res["T_p"], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for p-pol")
    
    return BirefringentSlabResult(
        "Incidence parallel to fast axis (yz-plane)",
        theta=np.rad2deg(theta_inc),
        phi=90,
        R_s=res["R_s"],
        R_p=res["R_p"],
        T_s=res["T_s"],
        T_p=res["T_p"],
        phase_s=np.angle(res['t_s']),
        phase_p=np.angle(res['t_p'])
    )

def test_birefringent_slab_oblique_incidence_level_4():
    """
    Test D: Oblique incidence relative to optical axis.
    Validates general case with polarization mixing.
    """
    wl = 500e-9
    d = 1e-6
    no, ne = 1.5, 1.7
    theta_inc = np.deg2rad(30)
    phi_inc = np.deg2rad(45)  # Oblique to both axes
    
    # Rotation to align optical axis with x-axis
    R_mat = R_from_euler(0, np.pi/2, 0)
    birefringent = uniaxial_material(no, ne, R=R_mat)
    
    stack = [
        Layer(iso_material(1.0), None),
        Layer(birefringent, d),
        Layer(iso_material(1.0), None)
    ]
    
    res = solve_stack_4x4(stack, wl, theta=theta_inc, phi=phi_inc)
    
    # Energy conservation check
    npt.assert_allclose(res["R_s"] + res["T_s"], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for s-pol")
    npt.assert_allclose(res["R_p"] + res["T_p"], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for p-pol")
    
    return BirefringentSlabResult(
        "Oblique incidence (45° to optical axis)",
        theta=np.rad2deg(theta_inc),
        phi=45,
        R_s=res["R_s"],
        R_p=res["R_p"],
        T_s=res["T_s"],
        T_p=res["T_p"],
        phase_s=np.angle(res['t_s']),
        phase_p=np.angle(res['t_p'])
    )

def run_birefringent_tests_level_4():
    """Execute all birefringent slab tests and print results"""
    tests = [
        test_birefringent_slab_normal_incidence_level_4,
        test_birefringent_slab_incidence_parallel_slow_axis_level_4,
        test_birefringent_slab_incidence_parallel_fast_axis_level_4,
        test_birefringent_slab_oblique_incidence_level_4,
        test_birefringent_interface_normal_incidence_level_4
    ]
    
    print("Running birefringent slab tests:\n" + "-"*40)
    results = []
    for test in tests:
        try:
            results.append(test())
            print(f"[PASS] {test.__name__}")
        except Exception as e:
            print(f"[FAIL] {test.__name__}: {str(e)}")
    
    # Print formatted results
    print("\nTest Results Summary:")
    print("="*85)
    print(f"{'Description':<45} | {'θ':>4} | {'φ':>4} | {'R_s':>8} | {'R_p':>8} | {'T_s':>8} | {'T_p':>8} | {'Phase Diff':>12}")
    print("-"*85)
    for res in results:
        # Calculate phase difference in degrees
        phase_diff_deg = np.rad2deg(res.phase_p - res.phase_s)
        
        # Normalize to [-180, 180] range
        phase_diff_deg = (phase_diff_deg + 180) % 360 - 180
        
        print(f"{res.description:<45} | {res.theta:4.1f} | {res.phi:4.0f} | "
              f"{res.R_s:8.5f} | {res.R_p:8.5f} | {res.T_s:8.5f} | {res.T_p:8.5f} | "
              f"{phase_diff_deg:9.2f}°")
    print("="*85)

# ===================== Birefringence sanity tests (verbose, Level_1) =====================

def _wrap_phase_diff(a, b):
    """Return |a-b| wrapped to [0, π]."""
    import numpy as np
    d = np.mod(abs(a - b), 2*np.pi)
    return float(2*np.pi - d if d > np.pi else d)

def test_debug_waveplate_retardance_normal_incidence_level_4():
    """
    Uniaxial film acts as a waveplate when the optic axis is *in-plane*.
    Here we put the OA along +x (β = +π/2 about y) at normal incidence, so:
        Ex (TM) → extraordinary (n_e)
        Ey (TE) → ordinary      (n_o)
    For thickness d = λ / [4 (n_e - n_o)] the retardance Δ = φ_e - φ_o ≈ π/2.

    We:
      1) Build eps for a rotated uniaxial medium with OA → +x.
      2) Read the forward eigen k_z from _halfspace_modes and check they match ~{n_e, n_o}.
      3) Build the propagation S for thickness d and check the two t-phases differ by ~π/2.
    """
    import numpy as np
    print("\n=== DEBUG: Birefringent waveplate (OA in-plane, normal incidence) ===")
    wl = 550e-9
    k0 = 2*np.pi/wl
    no, ne = 1.50, 1.60
    beta = np.pi/2  # rotate OA (originally z) → +x via Y(+π/2)
    R_oa_x = R_from_euler(0.0, beta, 0.0)
    film   = uniaxial_material(no, ne, R=R_oa_x)
    eps    = film.eps(wl)

    # Quarter-wave thickness
    d = wl / (4.0*(ne - no))
    print(f"n_o={no:.5f}, n_e={ne:.5f}, Δn={ne-no:.5f}, λ={wl:.3e} m, d={d:.6e} m (≈ quarter-wave)")

    
        
    # Modes @ normal incidence
    kx = ky = 0.0

    # Add diagnostic output before _halfspace_modes call
    print("\n=== Pre-mode calculation diagnostics ===")
    print(f"kx={kx}, ky={ky}, k0={k0}")
    print("Epsilon tensor eigenvalues:", np.linalg.eigvals(eps))
    
    # Calculate modes
    kz, Fp, Fm = _halfspace_modes(eps, k0, kx, ky, assume_isotropic=False)
    
    # Add more diagnostics
    print("\n=== Mode calculation results ===")
    print("Raw kz values:", kz)
    print("kz/k0:", kz/k0)
    print("Field profiles Fp:")
    for i, mode in enumerate(Fp.T):
        print(f"Mode {i}:", mode)
    #######################################################
    
    kz, Fp, Fm = _halfspace_modes(eps, k0, kx, ky, assume_isotropic=False)
    kz_over_k0 = np.real(kz/k0)
    print("Forward kz/k0 (unordered) ≈", kz_over_k0, "  (expect {~n_e, ~n_o})")
    np.testing.assert_allclose(np.sort(kz_over_k0), np.sort([ne, no]), rtol=0, atol=5e-6)

    # Retardance from eigen k_z
    phi = np.real(kz) * d
    Delta_kz = _wrap_phase_diff(phi[0], phi[1])
    print(f"Retardance via kz: Δ = {Delta_kz:.6f} rad  (target π/2 = {np.pi/2:.6f})")
    np.testing.assert_allclose(Delta_kz, np.pi/2, rtol=0, atol=5e-3)

    # Retardance from the propagation S-matrix (diagonal in the film's eigen-basis)
    Sprop = propagation_smatrix_in_medium(eps, k0, kx, ky, d, assume_isotropic=False)
    t = Sprop['t']
    print("Propagation S (expect diagonal with two unit-magnitude phase factors):\n", t)
    ph = np.unwrap(np.angle(np.diag(t)))
    Delta_t = _wrap_phase_diff(ph[0], ph[1])
    print(f"Retardance via S_prop: Δ = {Delta_t:.6f} rad")
    np.testing.assert_allclose(Delta_t, np.pi/2, rtol=0, atol=5e-3)

def test_birefringent_interface_normal_incidence_level_4():
    """
    Tests ONLY the air→birefringent interface at normal incidence.
    Validates that reflection coefficients match Fresnel expectations for
    ordinary and extraordinary rays.
    """
    wl = 500e-9
    no, ne = 1.5, 1.7
    
    # Rotation to align optical axis with x-axis
    R_mat = R_from_euler(0, np.pi/2, 0)
    birefringent = uniaxial_material(no, ne, R=R_mat)
    
    # Single interface: Air → Birefringent (half-space)
    stack = [
        Layer(iso_material(1.0), None),       # Air incident
        Layer(birefringent, None)             # Birefringent substrate
    ]
    
    # Run simulation at normal incidence
    res = solve_stack_4x4(stack, wl, theta=0, phi=0)
    
    # Expected Fresnel coefficients (normal incidence)
    r_s_expected = (1 - no) / (1 + no)
    r_p_expected = (1 - ne) / (1 + ne)
    
    print("\n=== Birefringent Interface Test ===")
    print(f"Expected s-pol reflection: {r_s_expected:.6f}")
    print(f"Actual s-pol reflection:   {res['r_s']:.6f}")
    print(f"Expected p-pol reflection: {r_p_expected:.6f}")
    print(f"Actual p-pol reflection:   {res['r_p']:.6f}")
    
    # Verify reflection coefficients
    npt.assert_allclose(res['r_s'], r_s_expected, rtol=0.01, atol=0.01,
                        err_msg="s-pol reflection doesn't match ordinary ray")
    npt.assert_allclose(res['r_p'], r_p_expected, rtol=0.01, atol=0.01,
                        err_msg="p-pol reflection doesn't match extraordinary ray")
    
    # Verify power conservation
    npt.assert_allclose(res['R_s'] + res['T_s'], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for s-pol")
    npt.assert_allclose(res['R_p'] + res['T_p'], 1.0, atol=1e-5,
                        err_msg="Energy not conserved for p-pol")
    
    return {
        "r_s_actual": res['r_s'],
        "r_s_expected": r_s_expected,
        "r_p_actual": res['r_p'],
        "r_p_expected": r_p_expected
    }

def test_atomic_birefringent_phase_retardation_level_4():
    """Atomic test for phase retardation calculation"""
    wl = 550e-9
    no, ne = 1.5, 1.6
    d = wl/(4*(ne-no))  # Quarter-wave plate
    
    # Build rotated uniaxial material (OA along x)
    eps = np.diag([ne**2, no**2, no**2])
    
    # Expected phase difference
    delta_expected = (ne - no) * (2*np.pi/wl) * d  # π/2
    
    # Calculate using S-matrix
    k0 = 2*np.pi/wl
    _, Fp, Fm = _halfspace_modes(eps, k0, 0, 0, False)
    S_prop = propagation_smatrix_in_medium(eps, k0, 0, 0, d, False)
    
    # Sub-test 1: Verify diagonal propagation matrix
    assert np.allclose(np.diag(np.abs(S_prop['t'])), [1,1], atol=1e-6)
    
    # Sub-test 2: Check phase difference
    phases = np.angle(np.diag(S_prop['t']))
    phase_diff = (phases[1] - phases[0] + np.pi) % (2*np.pi) - np.pi
    assert abs(phase_diff - delta_expected) < 1e-2
    
    # Sub-test 3: Verify eigenmodes
    assert abs(np.linalg.norm(Fp[:,0]) - 1) < 1e-6  # s-pol
    assert abs(np.linalg.norm(Fp[:,1]) - 1) < 1e-6  # p-pol
    
def test_atomic_birefringent_interface_reflection_level_4():
    """Atomic test for interface reflection coefficients"""
    wl = 500e-9
    no, ne = 1.5, 1.7
    
    # Test 1: Verify s-pol reflection (ordinary ray)
    eps_air = np.eye(3) * (1.0**2)
    eps_biref = np.diag([ne**2, no**2, no**2])  # OA along x
    
    # Calculate using Fresnel
    r_s_expected = (1 - no)/(1 + no)  # -0.2
    r_p_expected = (1 - ne)/(1 + ne)  # -0.259259
    
    # Calculate using S-matrix
    k0 = 2*np.pi/wl
    _, Fp_air, Fm_air = _halfspace_modes(eps_air, k0, 0, 0, True)
    _, Fp_biref, Fm_biref = _halfspace_modes(eps_biref, k0, 0, 0, False)
    
    S = smatrix_from_matching(Fp_air, Fm_air, Fp_biref, Fm_biref)
    r_s_calc = S['r'][0,0]  # s-pol reflection
    r_p_calc = S['r'][1,1]  # p-pol reflection
    
    # Sub-test 1: s-pol magnitude
    assert abs(abs(r_s_calc) - abs(r_s_expected)) < 1e-6
    
    # Sub-test 2: p-pol magnitude 
    assert abs(abs(r_p_calc) - abs(r_p_expected)) < 1e-6
    
    # Sub-test 3: s-pol sign
    assert np.sign(r_s_calc.real) == np.sign(r_s_expected)
    
    # Sub-test 4: p-pol sign
    assert np.sign(r_p_calc.real) == np.sign(r_p_expected)
def test_debug_birefringent_no_mixing_when_axis_aligned_level_4():
    """
    With OA in-plane **aligned** to +x (same setup as above) and normal incidence,
    the slab should NOT mix polarizations in the global (x,y) basis: t, r are ~diagonal.
    This checks that your interface matching + film propagation respect the symmetry.
    """
    import numpy as np
    print("\n=== DEBUG: No s↔p mixing when OA || x, normal incidence ===")
    wl = 550e-9; k0 = 2*np.pi/wl
    n0 = 1.0
    no, ne = 1.50, 1.60
    beta = np.pi/2
    R_oa_x = R_from_euler(0.0, beta, 0.0)
    eps_air = (n0*n0)*np.eye(3, dtype=complex)
    eps_film = uniaxial_material(no, ne, R=R_oa_x).eps(wl)
    d = wl/(7.0*(ne-no))  # arbitrary thin-ish plate

    kx = ky = 0.0
    # Port/film mode bases
    _, FpL, FmL = _halfspace_modes(eps_air,  k0, kx, ky, True)
    _, FpF, FmF = _halfspace_modes(eps_film, k0, kx, ky, False)
    _, FpR, FmR = _halfspace_modes(eps_air,  k0, kx, ky, True)

    # Compose: air|film(d)|air
    S_if_L = smatrix_from_matching(FpL, FmL, FpF, FmF)
    S_pr   = propagation_smatrix_in_medium(eps_film, k0, kx, ky, d, False)
    S_if_R = smatrix_from_matching(FpF, FmF, FpR, FmR)
    S_tmp  = redheffer_star(S_pr,  S_if_L)
    S_tot  = redheffer_star(S_if_R, S_tmp)

    r, t = S_tot['r'], S_tot['t']
    print("t (global TE/TM basis):\n", t)
    print("r (global TE/TM basis):\n", r)

    # Off-diagonals ~ 0
    tol = 5e-6
    assert abs(t[0,1]) < tol and abs(t[1,0]) < tol, "Expected no mixing in transmission when OA || x"
    assert abs(r[0,1]) < tol and abs(r[1,0]) < tol, "Expected no mixing in reflection when OA || x"

    # Energy check
    power = _power_RT_from_S(FpL, FmL, FpR, r, t)
    print("Power (should sum to ~1 for each pol):", power)
    for pol in ("s","p"):
        np.testing.assert_allclose(power[f"R_{pol}"]+power[f"T_{pol}"], 1.0, atol=5e-3)

def test_atomic_birefringent_coordinate_system_level_4():
    """Verify coordinate system alignment"""
    wl = 500e-9
    no, ne = 1.5, 1.7
    
    # Test configuration with OA along x
    eps = np.diag([ne**2, no**2, no**2])
    k0 = 2*np.pi/wl
    
    # Sub-test 1: Normal incidence modes
    _, Fp, _ = _halfspace_modes(eps, k0, 0, 0, False)
    
    # First column should be ordinary (Ey, Hx) mode
    assert abs(Fp[1,0]) > 0.9  # Ey dominant
    assert abs(Fp[2,0]) > 0.9  # Hx dominant
    
    # Second column should be extraordinary (Ex, Hy) mode
    assert abs(Fp[0,1]) > 0.9  # Ex dominant
    assert abs(Fp[3,1]) > 0.9  # Hy dominant
    
    # Sub-test 2: Power flow direction
    def power_flow(Psi):
        Ex, Ey, Hx, Hy = Psi
        return np.real(Ex*np.conj(Hy) - Ey*np.conj(Hx))
    
    assert power_flow(Fp[:,0]) > 0  # Forward ordinary
    assert power_flow(Fp[:,1]) > 0  # Forward extraordinary


# Fix the test case:
def test_fixed_phase_retardation_level_4():
    wl = 550e-9
    no, ne = 1.5, 1.6
    d = wl/(4*(ne-no))
    stack = [
        Layer(iso_material(1.0), None),
        Layer(uniaxial_material(no, ne, R=R_from_euler(0, np.pi/2, 0)), d),
        Layer(iso_material(1.0), None)
    ]
    res = solve_stack_4x4(stack, wl, 0, 0)
    assert np.isclose(res['phase_diff'], np.pi/2, rtol=0.01)
# Execute the tests when this script is run
if __name__ == "__main__":
    run_birefringent_tests_level_4()

## run level 4 tests

In [None]:

if __name__ == "__main__":
    run_all_tests( level = "level_4")

In [None]:
# Next