# dataclasses

In [1]:
# Berreman data classes and basic tests (Python 3.10, NumPy only)

from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional, Union
import numpy as np

# A scalar parameter (e.g. a refractive index) given either as a constant or
# as a callable of wavelength.
ComplexOrCallable = Union[complex, Callable[[float], complex]]
# A tensor spec: a constant 3x3 array, a scalar (promoted to scalar*I by
# _wrap_tensor), or a callable of wavelength returning either form.
TensorOrCallable  = Union[np.ndarray, complex, Callable[[float], Union[np.ndarray, complex]]]

# ---------- helpers ----------

def _wrap_tensor(arg: Optional[TensorOrCallable], *, default: np.ndarray, name: str) -> Callable[[float], np.ndarray]:
    """
    Build a provider ``wl -> complex 3x3 tensor`` from a flexible spec.

    Accepted specs:
      * ``None``         -> always returns ``default`` (as complex)
      * scalar           -> ``scalar * I``
      * 3x3 array-like   -> that constant tensor
      * callable(wl)     -> evaluated lazily; a scalar result becomes ``scalar * I``

    Raises ValueError for any other shape: immediately for constant specs,
    and when the provider is called for callable specs. ``name`` is only used
    in the error messages.
    """
    def as_tensor(value):
        # Returns (tensor, ok); ok is False when the shape is neither 0-d nor 3x3.
        tensor = np.array(value, dtype=complex)
        if tensor.ndim == 0:
            return tensor * np.eye(3, dtype=complex), True
        return tensor, tensor.shape == (3,3)

    if arg is None:
        fixed = np.array(default, dtype=complex)
        return lambda wl: fixed

    if not callable(arg):
        fixed, ok = as_tensor(arg)
        if not ok:
            raise ValueError(f"{name} must be scalar or 3x3 tensor; got shape {fixed.shape}")
        return lambda wl: fixed

    def provider(wl: float):
        tensor, ok = as_tensor(arg(wl))
        if not ok:
            raise ValueError(f"{name}(wl) must return scalar or 3x3 tensor; got shape {tensor.shape}")
        return tensor

    return provider

def _as_complex_matrix(x, shape):
    arr = np.array(x, dtype=complex)
    if arr.shape != shape:
        raise ValueError(f"Expected shape {shape}, got {arr.shape}")
    return arr


def _is_rotation_matrix(R: np.ndarray, tol: float = 1e-9) -> bool:
    """Check orthonormality and det≈+1."""
    if R.shape != (3,3):
        return False
    RtR = R.conj().T @ R
    I = np.eye(3, dtype=complex)
    return (np.allclose(RtR, I, atol=tol, rtol=0) 
            and np.isclose(np.linalg.det(R).real, 1.0, atol=1e-9))

# -----------------------------
# Data classes
# -----------------------------

@dataclass(frozen=True)
class Geometry:
    """
    Illumination geometry for one solve.

    wavelength : wavelength in meters; must be finite and > 0.
    theta      : angle of incidence in radians, in the incident medium.
    phi        : azimuth of the plane of incidence in radians (default 0).

    Python or NumPy real scalars are accepted. Booleans and non-finite values
    (nan, ±inf) are rejected with ValueError.
    """
    wavelength: float         # meters
    theta: float              # radians (in incident medium)
    phi: float = 0.0          # radians (azimuth)

    @staticmethod
    def _is_finite_real(value) -> bool:
        # bool is a subclass of int, so it must be excluded explicitly.
        return (isinstance(value, (int, float, np.integer, np.floating))
                and not isinstance(value, bool)
                and bool(np.isfinite(value)))

    def __post_init__(self):
        if not (self._is_finite_real(self.wavelength) and self.wavelength > 0):
            raise ValueError("Geometry.wavelength must be a positive float (meters).")
        if not self._is_finite_real(self.theta):
            raise ValueError("Geometry.theta must be a float (radians).")
        if not self._is_finite_real(self.phi):
            raise ValueError("Geometry.phi must be a float (radians).")

# --- Material (with its from_* factory constructors) is defined in a later cell ---




@dataclass(frozen=True)
class Layer:
    """
    A homogeneous slab of ``material`` with thickness ``d`` (meters).

    ``R`` optionally rotates the material's tensors into the lab frame
    (``None`` means identity). A supplied ``R`` is validated as a proper
    rotation and then stored as a read-only complex 3x3 ndarray copy. Array-like
    inputs (e.g. nested lists) therefore work with ``@``, and later changes
    to the caller's array cannot alter this frozen layer.
    """
    material: Material
    d: float                           # thickness [m]
    R: Optional[np.ndarray] = None     # 3x3 rotation to lab (None -> identity)

    def __post_init__(self):
        if not isinstance(self.material, Material):
            raise ValueError("Layer.material must be a Material.")
        d = self.d
        # bool is an int subclass; nan/inf are not meaningful thicknesses.
        if not (isinstance(d, (int, float, np.integer, np.floating))
                and not isinstance(d, bool)
                and np.isfinite(d) and d >= 0):
            raise ValueError("Layer.d must be a nonnegative float (meters).")
        if self.R is not None:
            R = np.array(self.R, dtype=complex)
            if R.shape != (3,3) or not _is_rotation_matrix(R):
                raise ValueError("Layer.R must be a proper 3x3 rotation matrix (orthonormal, det≈+1).")
            R.setflags(write=False)
            # Frozen dataclass: bypass __setattr__ to store the normalized copy.
            object.__setattr__(self, "R", R)

@dataclass(frozen=True)
class HalfSpace:
    """
    A semi-infinite medium (incident side or substrate).

    ``R`` optionally rotates the material's tensors into the lab frame
    (``None`` means identity). A supplied ``R`` is validated as a proper
    rotation and then stored as a read-only complex 3x3 ndarray copy.
    """
    material: Material
    R: Optional[np.ndarray] = None     # 3x3 rotation to lab (None -> identity)

    def __post_init__(self):
        if not isinstance(self.material, Material):
            raise ValueError("HalfSpace.material must be a Material.")
        if self.R is not None:
            R = np.array(self.R, dtype=complex)
            if R.shape != (3,3) or not _is_rotation_matrix(R):
                raise ValueError("HalfSpace.R must be a proper 3x3 rotation matrix (orthonormal, det≈+1).")
            R.setflags(write=False)
            # Frozen dataclass: bypass __setattr__ to store the normalized copy.
            object.__setattr__(self, "R", R)

# -----------------------------
# Convenience factory functions
# -----------------------------

def iso_material(n_of_wl: Callable[[float], complex]) -> Material:
    """Isotropic Material with eps(wl) = n(wl)**2 * I (default mu, xi, zeta)."""
    def permittivity(wl: float) -> np.ndarray:
        index_sq = complex(n_of_wl(wl)) ** 2
        return index_sq * np.eye(3, dtype=complex)

    return Material(eps_fn=permittivity)

def uniaxial_z_material(no_of_wl: Callable[[float], complex],
                        ne_of_wl: Callable[[float], complex]) -> Material:
    """Uniaxial Material, optic axis along z: eps = diag(no², no², ne²)."""
    def permittivity(wl: float) -> np.ndarray:
        ordinary = complex(no_of_wl(wl))
        extraordinary = complex(ne_of_wl(wl))
        diagonal = [ordinary**2, ordinary**2, extraordinary**2]
        return np.diag(diagonal).astype(complex)

    return Material(eps_fn=permittivity)

def R_from_euler(alpha: float, beta: float, gamma: float) -> np.ndarray:
    """
    Rotation matrix from three angles (radians), returned as complex 3x3.

    Computes ``Rz(gamma) @ Rx(beta) @ Rz(alpha)``, where ``Rz``/``Rx`` are
    the right-handed active rotations about the z and x axes. Applied to a
    column vector, this rotates by ``alpha`` about z, then by ``beta`` about
    x, then by ``gamma`` about z, all about the fixed lab axes.
    """
    def rot_z(angle):
        c, s = np.cos(angle), np.sin(angle)
        return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], float)

    cos_b, sin_b = np.cos(beta), np.sin(beta)
    rot_x = np.array([[1, 0, 0], [0, cos_b, -sin_b], [0, sin_b, cos_b]], float)
    product = rot_z(gamma) @ rot_x @ rot_z(alpha)
    return product.astype(complex)



### Material class

In [2]:
# ---------- core dataclass (strict) ----------

@dataclass(frozen=True)
class Material:
    """
    Callable providers for constitutive tensors (relative SI).

    Each ``*_fn`` field maps a wavelength ``wl`` to a 3x3 tensor:
      eps_fn   permittivity (required)
      mu_fn    permeability (default: identity)
      xi_fn    magnetoelectric coupling (default: zero)
      zeta_fn  magnetoelectric coupling (default: zero)

    The accessors ``eps``/``mu``/``xi``/``zeta`` evaluate the provider and
    return a new complex (3, 3) ndarray, raising ValueError for any other
    shape. The ``from_*`` classmethods build providers from constants,
    scalars, 3x3 arrays or callables of wavelength. All of them accept the same
    keyword-only ``mu``/``xi``/``zeta`` specs (handled by ``_wrap_tensor``).
    """
    eps_fn:  Callable[[float], np.ndarray]
    mu_fn:   Callable[[float], np.ndarray] = lambda wl: np.eye(3, dtype=complex)
    xi_fn:   Callable[[float], np.ndarray] = lambda wl: np.zeros((3,3), dtype=complex)
    zeta_fn: Callable[[float], np.ndarray] = lambda wl: np.zeros((3,3), dtype=complex)

    # ---------- validated accessors (always a fresh complex 3x3 copy) ----------

    def eps(self, wl: float) -> np.ndarray:  return _as_complex_matrix(self.eps_fn(wl),  (3,3))
    def mu(self, wl: float)  -> np.ndarray:  return _as_complex_matrix(self.mu_fn(wl),   (3,3))
    def xi(self, wl: float)  -> np.ndarray:  return _as_complex_matrix(self.xi_fn(wl),   (3,3))
    def zeta(self, wl: float)-> np.ndarray:  return _as_complex_matrix(self.zeta_fn(wl), (3,3))

    # ---------- helpers shared by the factories ----------

    @staticmethod
    def _aux_fns(mu, xi, zeta) -> dict:
        """Wrap the mu/xi/zeta specs (common to every factory) into provider kwargs."""
        return {
            "mu_fn":   _wrap_tensor(mu,   default=np.eye(3, dtype=complex), name="mu"),
            "xi_fn":   _wrap_tensor(xi,   default=np.zeros((3,3), complex), name="xi"),
            "zeta_fn": _wrap_tensor(zeta, default=np.zeros((3,3), complex), name="zeta"),
        }

    @staticmethod
    def _index_at(v, wl) -> complex:
        """Evaluate a constant-or-callable refractive index at wavelength ``wl``."""
        return complex(v(wl) if callable(v) else v)

    # ---------- explicit factories ----------

    @classmethod
    def from_index(cls,
                   n: ComplexOrCallable,
                   *,
                   mu: TensorOrCallable = 1.0,
                   xi: TensorOrCallable = 0.0,
                   zeta: TensorOrCallable = 0.0):
        """Isotropic material with scalar index n (constant or dispersive): eps = n²·I."""
        if callable(n):
            def eps_fn(wl: float):
                val = complex(n(wl))
                return (val**2) * np.eye(3, dtype=complex)
        else:
            E = (complex(n)**2) * np.eye(3, dtype=complex)
            eps_fn = lambda wl: E
        return cls(eps_fn=eps_fn, **cls._aux_fns(mu, xi, zeta))

    @classmethod
    def from_epsilon(cls,
                     epsilon: TensorOrCallable,
                     *,
                     mu:   TensorOrCallable = 1.0,
                     xi:   TensorOrCallable = 0.0,
                     zeta: TensorOrCallable = 0.0):
        """
        Constant or dispersive tensors.
        epsilon/mu/xi/zeta can be scalar, 3x3, or callable(wl) -> scalar or 3x3.
        Constant specs with a bad shape raise ValueError here; callables are
        checked when evaluated.
        """
        eps_fn = _wrap_tensor(epsilon, default=np.eye(3, dtype=complex), name="epsilon")
        return cls(eps_fn=eps_fn, **cls._aux_fns(mu, xi, zeta))

    @classmethod
    def from_uniaxial_z(cls,
                        no: ComplexOrCallable,
                        ne: ComplexOrCallable,
                        *,
                        mu:   TensorOrCallable = 1.0,
                        xi:   TensorOrCallable = 0.0,
                        zeta: TensorOrCallable = 0.0):
        """Uniaxial with optic axis along z (before any Layer.R): eps = diag(no², no², ne²)."""
        if callable(no) or callable(ne):
            def eps_fn(wl: float):
                n_o, n_e = cls._index_at(no, wl), cls._index_at(ne, wl)
                return np.diag([n_o**2, n_o**2, n_e**2]).astype(complex)
        else:
            E = np.diag([complex(no)**2, complex(no)**2, complex(ne)**2]).astype(complex)
            eps_fn = lambda wl: E
        return cls(eps_fn=eps_fn, **cls._aux_fns(mu, xi, zeta))

    @classmethod
    def from_biaxial(cls,
                     nx: ComplexOrCallable,
                     ny: ComplexOrCallable,
                     nz: ComplexOrCallable,
                     *,
                     mu:   TensorOrCallable = 1.0,
                     xi:   TensorOrCallable = 0.0,
                     zeta: TensorOrCallable = 0.0):
        """Biaxial (principal axes x,y,z before any Layer.R): eps = diag(nx², ny², nz²)."""
        indices = (nx, ny, nz)
        if any(callable(v) for v in indices):
            def eps_fn(wl: float):
                return np.diag([cls._index_at(v, wl)**2 for v in indices]).astype(complex)
        else:
            # All constant: precompute once, like the other factories do.
            E = np.diag([complex(v)**2 for v in indices]).astype(complex)
            eps_fn = lambda wl: E
        return cls(eps_fn=eps_fn, **cls._aux_fns(mu, xi, zeta))

# ---------- ergonomic dispatcher (for call sites/tests) ----------

def material(*, epsilon: Optional[TensorOrCallable] = None,
             n: Optional[ComplexOrCallable] = None,
             mu:   TensorOrCallable = 1.0,
             xi:   TensorOrCallable = 0.0,
             zeta: TensorOrCallable = 0.0) -> Material:
    """
    Keyword-only convenience factory.

    Exactly one of ``epsilon`` (routed to Material.from_epsilon) or ``n``
    (routed to Material.from_index) must be given; ``mu``/``xi``/``zeta`` are
    forwarded unchanged. Raises ValueError when both or neither are supplied.
    """
    has_eps = epsilon is not None
    has_index = n is not None
    if has_eps == has_index:
        raise ValueError("Provide exactly one of 'epsilon' or 'n'.")
    extras = {"mu": mu, "xi": xi, "zeta": zeta}
    if has_eps:
        return Material.from_epsilon(epsilon, **extras)
    return Material.from_index(n, **extras)

## data-class tests

In [3]:

# -----------------------------
# Tests
# -----------------------------

# ===== New data-class tests for Material =====

def _is_c33(x): 
    return isinstance(x, np.ndarray) and x.shape == (3,3) and x.dtype.kind == 'c'

def test_Material_from_index_constants():
    """A constant complex index gives eps = n²·I, mu = I and zero xi/zeta."""
    index = 1.6+0.1j
    mat = Material.from_index(index)  # constant n
    wavelength = 500e-9
    tensors = [getattr(mat, name)(wavelength) for name in ("eps", "mu", "xi", "zeta")]
    assert all(_is_c33(t) for t in tensors)
    eps, mu, xi, zeta = tensors
    assert np.allclose(eps, (index**2)*np.eye(3))
    assert np.allclose(mu, np.eye(3))
    assert np.allclose(xi, 0)
    assert np.allclose(zeta, 0)

def test_Material_from_index_callable_with_mu_xi_zeta():
    """Dispersive n(λ) plus scalar μ, ξ, ζ: each scalar is promoted to scalar·I."""
    n_of_wl = lambda wl: 1.5 + 0.05j
    scalars = {"mu": 1.02+0j, "xi": 0.001+0j, "zeta": 0.002+0j}
    mat = Material.from_index(n_of_wl, **scalars)
    wavelength = 600e-9
    values = {name: getattr(mat, name)(wavelength) for name in ("eps", "mu", "xi", "zeta")}
    assert all(_is_c33(v) for v in values.values())
    index = n_of_wl(wavelength)
    assert np.allclose(values["eps"], (index**2)*np.eye(3))
    for name, scalar in scalars.items():
        assert np.allclose(values[name], scalar * np.eye(3))

def test_Material_from_epsilon_scalar_and_tensors():
    """from_epsilon: scalar/tensor constants and callables (3x3 or scalar) are honoured."""
    # scalar epsilon, tensor mu; xi, zeta zeros
    eps_scalar = 2.3+0.2j
    mu_tensor = np.diag([1.1, 1.2, 1.3]).astype(complex)
    m = Material.from_epsilon(epsilon=eps_scalar, mu=mu_tensor)
    wl = 700e-9
    eps, mu, xi, zeta = m.eps(wl), m.mu(wl), m.xi(wl), m.zeta(wl)
    assert _is_c33(eps) and _is_c33(mu) and _is_c33(xi) and _is_c33(zeta)
    assert np.allclose(eps, eps_scalar*np.eye(3))
    assert np.allclose(mu,  mu_tensor)
    assert np.allclose(xi,  0)
    assert np.allclose(zeta,0)

    # callable epsilon that returns a 3x3 tensor + callable xi returning scalar
    eps_fn = lambda wl: np.diag([2.0+0.1j, 2.1+0.1j, 2.2+0.1j]).astype(complex)
    xi_fn  = lambda wl: 0.005+0j
    m2 = Material.from_epsilon(epsilon=eps_fn, xi=xi_fn)
    wl2 = 550e-9
    eps2, mu2, xi2, zeta2 = m2.eps(wl2), m2.mu(wl2), m2.xi(wl2), m2.zeta(wl2)
    assert _is_c33(eps2) and _is_c33(xi2)
    # The callable's tensor must pass through unchanged, not merely have the right shape.
    assert np.allclose(eps2, eps_fn(wl2))
    assert np.allclose(xi2, (0.005+0j)*np.eye(3))
    # Unspecified mu/zeta keep their defaults.
    assert np.allclose(mu2, np.eye(3))
    assert np.allclose(zeta2, 0)

def test_Material_from_uniaxial_z_constants_and_callable():
    """from_uniaxial_z puts (no², no², ne²) on the diagonal and passes zeta through."""
    # constant indices
    n_o, n_e = 1.6+0j, 1.8+0j
    eps_const = Material.from_uniaxial_z(n_o, n_e).eps(500e-9)
    assert _is_c33(eps_const)
    assert np.allclose(np.diag(eps_const), [(n_o**2), (n_o**2), (n_e**2)])

    # dispersive indices plus an explicit zeta tensor
    no_fn = lambda wl: 1.55+0.01j
    ne_fn = lambda wl: 1.75+0.02j
    zeta_tensor = np.zeros((3, 3), dtype=complex)
    zeta_tensor[2, 2] = 0.01
    mat = Material.from_uniaxial_z(no_fn, ne_fn, zeta=zeta_tensor)
    wavelength = 620e-9
    eps_disp, zeta_out = mat.eps(wavelength), mat.zeta(wavelength)
    assert _is_c33(eps_disp) and _is_c33(zeta_out)
    expected_diag = [no_fn(wavelength)**2, no_fn(wavelength)**2, ne_fn(wavelength)**2]
    assert np.allclose(np.diag(eps_disp), expected_diag)
    assert np.allclose(zeta_out, zeta_tensor)

def test_Material_invalid_shapes_raise():
    """Bad epsilon shapes raise ValueError: eagerly for constants, on evaluation for callables."""
    def expect_value_error(action, message):
        try:
            action()
        except ValueError:
            return
        raise AssertionError(message)

    # wrong constant epsilon shape
    expect_value_error(lambda: Material.from_epsilon(np.ones((2,2))),
                       "Expected ValueError for bad epsilon shape")
    # callable returning wrong shape
    bad_eps_fn = lambda wl: np.ones((2,2))
    expect_value_error(lambda: Material.from_epsilon(bad_eps_fn).eps(500e-9),
                       "Expected ValueError when callable returns bad shape")

def test_Geometry():
    """A valid Geometry constructs; a negative wavelength raises ValueError."""
    assert isinstance(Geometry(wavelength=550e-9, theta=0.3, phi=0.1), Geometry)
    try:
        Geometry(wavelength=-1.0, theta=0, phi=0)
    except ValueError:
        return
    raise AssertionError("Negative wavelength should have raised ValueError.")

def test_Material_iso():
    """iso_material with n = 1 gives eps = mu = I and xi = zeta = 0, all 3x3."""
    air = iso_material(lambda wl: 1.0+0j)
    wavelength = 500e-9
    expected = {"eps": np.eye(3), "mu": np.eye(3), "xi": 0, "zeta": 0}
    values = {name: getattr(air, name)(wavelength) for name in expected}
    for name, target in expected.items():
        assert values[name].shape == (3,3) and np.allclose(values[name], target)

def test_Material_uniax():
    """uniaxial_z_material gives exactly diag(no², no², ne²): zero off-diagonals and imaginary parts."""
    mat = uniaxial_z_material(lambda wl: 1.5+0j, lambda wl: 1.7+0j)
    eps = mat.eps(600e-9)
    assert eps.shape == (3,3)
    # Compare the full complex tensor, not just the real diagonal.
    assert np.allclose(eps, np.diag([1.5**2, 1.5**2, 1.7**2]))

def test_Rotation_and_Layer():
    """R_from_euler yields a valid rotation; Layer rejects negative d and non-3x3 R."""
    rot = R_from_euler(0.2, 0.3, 0.4)
    assert _is_rotation_matrix(rot)
    air = iso_material(lambda wl: 1.0+0j)
    assert isinstance(Layer(material=air, d=100e-9, R=rot), Layer)
    bad_cases = [
        (dict(material=air, d=-1e-9),
         "Negative thickness should have raised ValueError."),
        (dict(material=air, d=10e-9, R=np.eye(2)),  # wrong shape
         "Bad rotation shape should have raised ValueError."),
    ]
    for kwargs, message in bad_cases:
        try:
            Layer(**kwargs)
        except ValueError:
            continue
        raise AssertionError(message)

def test_HalfSpace():
    """HalfSpace accepts a Material with R=None and rejects a non-Material."""
    glass = iso_material(lambda wl: 1.5+0j)
    assert isinstance(HalfSpace(material=glass, R=None), HalfSpace)
    try:
        HalfSpace(material="not a material")  # type: ignore
    except ValueError:
        return
    raise AssertionError("Non-Material should have raised ValueError.")


In [4]:
def run_all_tests():
    """Run the data-class tests, printing PASS/FAIL (plus traceback) per test and a summary."""
    import traceback

    suite = [
        ("Geometry", test_Geometry),
        ("Material_iso", test_Material_iso),
        ("Material_uniax", test_Material_uniax),
        ("Material_from_index_constants", test_Material_from_index_constants),
        ("Material_from_index_callable_with_mu_xi_zeta", test_Material_from_index_callable_with_mu_xi_zeta),
        ("Material_from_epsilon_scalar_and_tensors", test_Material_from_epsilon_scalar_and_tensors),
        ("Material_from_uniaxial_z_constants_and_callable", test_Material_from_uniaxial_z_constants_and_callable),
        ("Material_invalid_shapes_raise", test_Material_invalid_shapes_raise),
        ("Rotation_and_Layer", test_Rotation_and_Layer),
        ("HalfSpace", test_HalfSpace),
    ]

    print("Running data-class tests...\n")
    n_passed = 0
    for label, test_fn in suite:
        try:
            test_fn()
            print(f"[PASS] {label}")
            n_passed += 1
        except Exception as exc:  # report and keep going with the remaining tests
            print(f"[FAIL] {label} -> {exc}")
            print(f"--- Verbose output for {label} ---")
            traceback.print_exc()
            print("--- End verbose output ---\n")

    print(f"\n{n_passed}/{len(suite)} tests passed.")


run_all_tests()


Running data-class tests...

[PASS] Geometry
[PASS] Material_iso
[PASS] Material_uniax
[PASS] Material_from_index_constants
[PASS] Material_from_index_callable_with_mu_xi_zeta
[PASS] Material_from_epsilon_scalar_and_tensors
[PASS] Material_from_uniaxial_z_constants_and_callable
[PASS] Material_invalid_shapes_raise
[PASS] Rotation_and_Layer
[PASS] HalfSpace

10/10 tests passed.


# Solver

In [8]:
import numpy as np
from typing import Optional, Tuple

# SI physical constants used by the solver.
C0   = 299_792_458.0     # speed of light in vacuum [m/s] (exact by definition)
MU0  = 4e-7*np.pi        # vacuum permeability [H/m] (classical, pre-2019 defined value)
EPS0 = 1.0/(MU0*C0*C0)   # vacuum permittivity [F/m], from c**2 = 1/(mu0*eps0)

def _forward_kz(kz: complex) -> complex:
    """
    Choose the forward/decaying branch:
      - if Im(kz) < 0  → flip sign (decay into +z)
      - if Im(kz) ≈ 0 and Re(kz) < 0 → flip sign (propagate +z)
    """
    if np.imag(kz) < 0:
        return -kz
    if abs(np.imag(kz)) < 1e-14 and np.real(kz) < 0:
        return -kz
    return kz


def _is_iso_eps(eps: np.ndarray) -> bool:
    return np.allclose(eps, eps[0,0]*np.eye(3), atol=1e-12)

def _rotate_tensor(T: np.ndarray, R: Optional[np.ndarray]) -> np.ndarray:
    if R is None:
        return T
    return R @ T @ R.T

def _n_from_eps(eps: np.ndarray) -> complex:
    # assume isotropic (or we'll just take xx); eps, mu are relative
    return np.lib.scimath.sqrt(eps[0,0])

def _kz_iso(k0: float, eps_r: complex, kx: complex, ky: complex) -> complex:
    """Forward-branch kz = sqrt(eps_r·k0² − kx² − ky²) in an isotropic medium (branch via _forward_kz)."""
    transverse_sq = kx**2 + ky**2
    radicand = (k0**2)*eps_r - transverse_sq
    return _forward_kz(np.lib.scimath.sqrt(radicand))


def _eta_s(kz: complex, k0: float, mu_r: complex = 1.0) -> complex:
    # TE (s) reduced admittance (impedance-like scaling)
    return kz / (k0 * mu_r)

def _eta_p(kz, k0, eps_r):
    return (eps_r * k0) / kz

def _M_layer(phi: complex, eta: complex) -> np.ndarray:
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[c, 1j*s/eta],
                     [1j*eta*s, c]], dtype=complex)
def _uniax_axis_aligned_kz_eta(eps: np.ndarray, k0: float, kx: complex, ky: complex):
    """
    kz and reduced admittances of a uniaxial medium with optic axis along z.

    Uses eps[0,0] as the ordinary (transverse) permittivity and eps[2,2] as
    the extraordinary one. Returns (kz_s, kz_p, eta_s, eta_p):
      ordinary (s):      kz_s² = e_o·k0² − kt²,            eta_s = kz_s / k0
      extraordinary (p): kz_p² = e_o·k0² − (e_o/e_e)·kt²,  eta_p = e_o·k0 / kz_p
    Both kz use the forward/decaying branch from _forward_kz.
    """
    e_ord, e_ext = eps[0,0], eps[2,2]
    transverse_sq = kx*kx + ky*ky
    base = e_ord*k0*k0
    kz_ord = _forward_kz(np.lib.scimath.sqrt(base - transverse_sq))
    kz_ext = _forward_kz(np.lib.scimath.sqrt(base - (e_ord/e_ext)*transverse_sq))
    return kz_ord, kz_ext, kz_ord / k0, (e_ord * k0) / kz_ext


class BerremanSolver:
    """
    Plane-wave reflection/transmission of a planar stack of layers.

    The incident half-space ``inc`` fills z < 0, ``layers`` follow in list
    order toward +z, and the substrate ``sub`` fills the remainder.

    Conventions of the 4x4 engine:
      * time dependence exp(-iωt), in-plane dependence exp(i(kx·x + ky·y));
      * tangential field vector Ψ = [Ex, Ey, H̃x, H̃y] with H̃ = Z0·H,
        Z0 = μ0·c, and dΨ/dz = i·k0·A·Ψ;
      * relative constitutive relations D/ε0 = ε·E + ξ·H̃ and c·B = ζ·E + μ·H̃.
        This is the convention in which Material.xi / Material.zeta enter here.

    ``solve`` chooses between two engines:
      * 2x2 characteristic matrices, used when every layer's lab-frame tensors
        have μ = I, ξ = ζ = 0 and ε isotropic or uniaxial with its optic axis
        along z. In that case s and p decouple.
      * the general 4x4 Berreman transfer matrix otherwise. This covers tilted
        optic axes at any incidence angle (including normal incidence, e.g.
        wave plates), biaxial or off-diagonal ε, μ ≠ I, and ξ/ζ ≠ 0.

    Limitations:
      * Both half-spaces are treated as isotropic and non-magnetic, with
        n = sqrt(ε_xx) taken from their lab-frame permittivity.
      * The 4x4 engine multiplies full layer propagators, so thick layers
        supporting strongly evanescent modes can lose accuracy.
    """

    # Relative tolerance used to classify lab-frame tensors (engine choice).
    _TENSOR_RTOL = 1e-12

    def __init__(self, layers: list[Layer], inc: HalfSpace, sub: HalfSpace):
        self.layers = layers
        self.inc = inc
        self.sub = sub

    # --- kinematics ---
    def kxkyk0(self, geom: Geometry, n_inc_scalar: complex) -> Tuple[complex, complex, float]:
        """In-plane wavevector (kx, ky), conserved through the stack, and k0 = 2π/λ."""
        k0 = 2.0*np.pi/geom.wavelength
        kt = k0 * n_inc_scalar * np.sin(geom.theta)
        kx = kt * np.cos(geom.phi)
        ky = kt * np.sin(geom.phi)
        return kx, ky, k0

    def tensors_lab(self, mat: Material, wl: float, R: Optional[np.ndarray]):
        """(ε, μ, ξ, ζ) of ``mat`` at ``wl``, each rotated into the lab frame as R·T·Rᵀ."""
        eps = _rotate_tensor(mat.eps(wl), R)
        mu  = _rotate_tensor(mat.mu(wl),  R)
        xi  = _rotate_tensor(mat.xi(wl),  R)
        zeta= _rotate_tensor(mat.zeta(wl),R)
        return eps, mu, xi, zeta

    def A_from_tensors(self, eps, mu, xi, zeta, k0, kx, ky):
        """
        4x4 Berreman matrix A with dΨ/dz = i·k0·A·Ψ, Ψ = [Ex, Ey, H̃x, H̃y].

        General lab-frame ε, μ, ξ, ζ (full 3x3, off-diagonals included).
        With K = k_t/k0 and ∂' = (i·k0)⁻¹·d/dz, Maxwell's curl equations read
            ∂'Ex =  B_y + Kx·Ez        ∂'Ey = −B_x + Ky·Ez
            ∂'H̃x = −D_y + Kx·H̃z        ∂'H̃y =  D_x + Ky·H̃z
            D_z  =  Ky·H̃x − Kx·H̃y      B_z  =  Kx·Ey − Ky·Ex
        where [D; B] = [[ε, ξ], [ζ, μ]]·[E; H̃]. The last two (algebraic)
        equations are solved for Ez, H̃z in terms of Ψ and substituted into
        the first four.
        """
        Kx, Ky = kx / k0, ky / k0
        # 6x6 constitutive matrix acting on F = [Ex, Ey, Ez, H̃x, H̃y, H̃z].
        M6 = np.block([[eps, xi], [zeta, mu]]).astype(complex)
        t = [0, 1, 3, 4]   # transverse components (= Ψ)
        l = [2, 5]         # longitudinal components Ez, H̃z
        # Right-hand sides of the D_z / B_z equations as rows acting on Ψ.
        Lz = np.array([[0, 0, Ky, -Kx],
                       [-Ky, Kx, 0, 0]], dtype=complex)
        # M6[l,l]·[Ez, H̃z] = (Lz − M6[l,t])·Ψ
        G = np.linalg.solve(M6[np.ix_(l, l)], Lz - M6[np.ix_(l, t)])
        T = np.zeros((6, 4), dtype=complex)   # F = T·Ψ
        T[t, :] = np.eye(4)
        T[l, :] = G
        C = M6 @ T                            # [D; B] = C·Ψ
        A = np.empty((4, 4), dtype=complex)
        A[0] = C[4] + Kx * T[2]               # ∂'Ex
        A[1] = -C[3] + Ky * T[2]              # ∂'Ey
        A[2] = -C[1] + Kx * T[5]              # ∂'H̃x
        A[3] = C[0] + Ky * T[5]               # ∂'H̃y
        return A

    def eigenmodes(self, A):
        """Return (V, nu): eigenvectors of A as columns and eigenvalues nu = kz/k0."""
        nu, V = np.linalg.eig(A)
        return V, nu

    def propagator(self, V, nu, k0, d):
        """Layer transfer matrix exp(i·k0·A·d) = V·diag(exp(i·k0·nu·d))·V⁻¹."""
        return V @ np.diag(np.exp(1j * k0 * nu * d)) @ np.linalg.inv(V)

    def halfspace_modes_iso(self, n: complex, k0: float, kx: complex, ky: complex,
                            phi: float = 0.0):
        """
        Power-normalized forward/backward TE/TM modes in an isotropic half-space.

        Returns ``kz, Fp, Fm``. ``kz`` is the forward-branch kz. ``Fp`` (+z) and
        ``Fm`` (−z) are 4x2 arrays with columns [s, p].
        Field ordering: [Ex, Ey, Hx, Hy]^T in SI units (V/m, A/m).

        * Each forward column carries a time-averaged z-flux of 1. If that flux
          is ~0 (evanescent), the column is left unscaled.
        * Each backward column uses the scale of its forward partner.
        * s has E_t along s_hat = (−ky, kx, 0)/kt; p has H_t along the same
          s_hat.
        * At normal incidence (kx = ky = 0) the plane of incidence is
          undefined, so s_hat = (−sin phi, cos phi, 0). This matches the
          oblique basis as θ → 0; phi = 0 gives ŷ.
        """
        kz = _forward_kz(np.lib.scimath.sqrt((k0*n)**2 - (kx**2 + ky**2)))

        kt2 = kx*kx + ky*ky
        tol = 1e-15 * (abs(k0) * abs(n) + 1.0)
        if abs(kt2) < tol:
            s_hat = np.array([-np.sin(phi), np.cos(phi), 0.0], dtype=complex)
        else:
            kt = np.sqrt(kt2)
            s_hat = np.array([-ky/kt, kx/kt, 0+0j])
        pH_hat = s_hat.copy()   # p: H_t ⟂ plane of incidence

        kf = np.array([kx, ky, kz], dtype=complex)   # forward wavevector

        # Maxwell relations (SI), ω = c·k0
        def H_from_E(kvec, Evec):
            return np.cross(kvec, Evec) / (MU0 * C0 * k0)                   # (1/ωμ0) k×E
        def E_from_H(kvec, Hvec, eps_r):
            return -np.cross(kvec, Hvec) / (EPS0 * eps_r * C0 * k0)         # −(1/ωε0ε_r) k×H

        def tangential(E, H):
            return np.array([E[0], E[1], H[0], H[1]], dtype=complex)
        def Pz(E, H):
            return 0.5 * np.real(E[0]*np.conj(H[1]) - E[1]*np.conj(H[0]))

        Es_f, Hs_f = s_hat, H_from_E(kf, s_hat)
        Hp_f = pH_hat
        Ep_f = E_from_H(kf, Hp_f, n**2)

        Ps, Pp = Pz(Es_f, Hs_f), Pz(Ep_f, Hp_f)
        scale_s = np.sqrt(Ps) if Ps > 1e-30 else 1.0
        scale_p = np.sqrt(Pp) if Pp > 1e-30 else 1.0

        Fp = np.column_stack([tangential(Es_f/scale_s, Hs_f/scale_s),
                              tangential(Ep_f/scale_p, Hp_f/scale_p)])

        kb = np.array([kx, ky, -kz], dtype=complex)  # backward wavevector
        Hs_b = H_from_E(kb, s_hat)/scale_s
        Ep_b = E_from_H(kb, pH_hat, n**2)/scale_p
        Fm = np.column_stack([tangential(s_hat/scale_s, Hs_b),
                              tangential(Ep_b, pH_hat/scale_p)])
        return kz, Fp, Fm

    # --- engine selection ---
    @classmethod
    def _layer_needs_4x4(cls, eps, mu, xi, zeta) -> bool:
        """False only for μ = I, ξ = ζ = 0 and ε = diag(a, a, b) (within _TENSOR_RTOL)."""
        scale = max(1.0, float(np.max(np.abs(eps))), float(np.max(np.abs(mu))))
        atol = cls._TENSOR_RTOL * scale

        def near(a, b):
            return np.allclose(a, b, rtol=0.0, atol=atol)

        eps_offdiag = eps - np.diag(np.diag(eps))
        decoupled = (near(eps_offdiag, 0.0)
                     and near(eps[0, 0], eps[1, 1])
                     and near(mu, np.eye(3))
                     and near(xi, 0.0)
                     and near(zeta, 0.0))
        return not decoupled

    def _any_anisotropic_or_rotated(self, wl: float) -> bool:
        """Return True when some layer (in the lab frame) requires the 4x4 engine."""
        return any(self._layer_needs_4x4(*self.tensors_lab(L.material, wl, L.R))
                   for L in self.layers)

    def solve(self, geom: Geometry, pol_basis: str = "sp"):
        """
        Reflection/transmission of the stack for ``geom``.

        ``pol_basis`` is accepted for API compatibility. Only the s/p basis is
        implemented, and the argument is currently ignored.

        Returns a dict with:
          * "rJ", "tJ": 2x2 complex amplitude matrices. Column j is incident
            s (0) / p (1); rows are the outgoing s / p amplitudes. The 2x2
            engine returns diagonal matrices in the tangential-field admittance
            convention. The 4x4 engine uses power-normalized mode amplitudes,
            with forward and backward p modes sharing the same H direction,
            so its p phase convention differs.
          * "Rsp", "Tsp": reflected / transmitted power fractions for s and p
            incidence, including cross-polarized power.
          * "Asp": max(0, 1 − R − T).
        """
        wl = geom.wavelength
        eps0 = self.tensors_lab(self.inc.material, wl, self.inc.R)[0]
        epsS = self.tensors_lab(self.sub.material, wl, self.sub.R)[0]
        # Half-spaces are treated as isotropic: n = sqrt(ε_xx).
        n0 = np.lib.scimath.sqrt(eps0[0, 0])
        nS = np.lib.scimath.sqrt(epsS[0, 0])
        kx, ky, k0 = self.kxkyk0(geom, n0)

        lab = [self.tensors_lab(L.material, wl, L.R) for L in self.layers]
        # Decide by the lab-frame tensors, not by whether R is given: a rotated
        # isotropic layer is still isotropic, while a tilted optic axis couples
        # s and p even at normal incidence.
        if any(self._layer_needs_4x4(*tensors) for tensors in lab):
            return self._solve_4x4(lab, k0, kx, ky, n0, nS, geom.phi)
        return self._solve_2x2(lab, k0, kx, ky, eps0[0, 0], epsS[0, 0])

    def _solve_2x2(self, lab, k0, kx, ky, eps_inc, eps_sub):
        """Decoupled s/p characteristic-matrix solution for isotropic / uniaxial-z layers."""
        def char_matrix(phase, eta):
            c, s = np.cos(phase), np.sin(phase)
            return np.array([[c, 1j*s/eta],
                             [1j*eta*s, c]], dtype=complex)

        def eta_s(kz):               # TE: n·cosθ (μ = 1)
            return kz / k0

        def eta_p(kz, eps_r):        # TM: n/cosθ
            return (eps_r * k0) / kz

        kz0, kzS = _kz_iso(k0, eps_inc, kx, ky), _kz_iso(k0, eps_sub, kx, ky)
        eta0_s, etaS_s = eta_s(kz0), eta_s(kzS)
        eta0_p, etaS_p = eta_p(kz0, eps_inc), eta_p(kzS, eps_sub)

        Ms = np.identity(2, dtype=complex)
        Mp = np.identity(2, dtype=complex)
        for (epsL, _, _, _), L in zip(lab, self.layers):
            if _is_iso_eps(epsL):
                eps_r = epsL[0, 0]
                kz_s = kz_p = _kz_iso(k0, eps_r, kx, ky)
                es, ep = eta_s(kz_s), eta_p(kz_p, eps_r)
            else:
                # Uniaxial, optic axis along z: ordinary (s) / extraordinary (p).
                kz_s, kz_p, es, ep = _uniax_axis_aligned_kz_eta(epsL, k0, kx, ky)
            Ms = Ms @ char_matrix(kz_s * L.d, es)
            Mp = Mp @ char_matrix(kz_p * L.d, ep)

        def match(M, eta0, etaS):
            # Final interface match onto the substrate admittance.
            B = M[0, 0] + M[0, 1]*etaS
            C = M[1, 0] + M[1, 1]*etaS
            den = eta0*B + C
            return (eta0*B - C)/den, (2*eta0)/den

        rs, ts = match(Ms, eta0_s, etaS_s)
        rp, tp = match(Mp, eta0_p, etaS_p)

        Rs, Rp = abs(rs)**2, abs(rp)**2
        Ts = (np.real(etaS_s)/np.real(eta0_s)) * abs(ts)**2
        Tp = (np.real(etaS_p)/np.real(eta0_p)) * abs(tp)**2
        rJ = np.array([[rs, 0], [0, rp]], dtype=complex)
        tJ = np.array([[ts, 0], [0, tp]], dtype=complex)
        return self._pack(rJ, tJ, Rs, Rp, Ts, Tp)

    def _solve_4x4(self, lab, k0, kx, ky, n0, nS, phi):
        """General Berreman transfer-matrix solution (s and p may couple)."""
        _, Fp0, Fm0 = self.halfspace_modes_iso(n0, k0, kx, ky, phi=phi)
        _, FpS, _ = self.halfspace_modes_iso(nS, k0, kx, ky, phi=phi)
        # halfspace_modes_iso works in SI (H in A/m); A_from_tensors uses H̃ = Z0·H.
        z0 = MU0 * C0
        to_reduced = np.diag([1.0, 1.0, z0, z0]).astype(complex)
        Fp0, Fm0, FpS = to_reduced @ Fp0, to_reduced @ Fm0, to_reduced @ FpS

        # Ψ(end of stack) = M·Ψ(0), layers applied in list order.
        M = np.eye(4, dtype=complex)
        for (eps, mu, xi, zeta), L in zip(lab, self.layers):
            V, nu = self.eigenmodes(self.A_from_tensors(eps, mu, xi, zeta, k0, kx, ky))
            M = self.propagator(V, nu, k0, L.d) @ M

        # Incident side: Ψ(0) = Fp0·a + Fm0·r. Substrate side: M·Ψ(0) = FpS·t.
        #   =>  [M·Fm0 | −FpS]·[r; t] = −M·Fp0·a, solved for a = s and a = p at once.
        lhs = np.concatenate([M @ Fm0, -FpS], axis=1)
        sol, *_ = np.linalg.lstsq(lhs, -M @ Fp0, rcond=None)
        rJ = sol[:2].astype(complex)
        tJ = sol[2:].astype(complex)

        def pz(psi):
            # Column-wise time-averaged z-flux of tangential fields [Ex, Ey, H̃x, H̃y].
            return 0.5 * np.real(psi[0]*np.conj(psi[3]) - psi[1]*np.conj(psi[2]))

        # Powers from actual Poynting flux: includes cross-polarized light and
        # gives zero transmission for evanescent substrate fields.
        p_inc = pz(Fp0)
        R = -pz(Fm0 @ rJ) / p_inc
        T = pz(FpS @ tJ) / p_inc
        return self._pack(rJ, tJ, R[0], R[1], T[0], T[1])

    @staticmethod
    def _pack(rJ, tJ, Rs, Rp, Ts, Tp) -> dict:
        """Assemble the solver result dictionary; absorption is max(0, 1 − R − T)."""
        return {
            "rJ": rJ, "tJ": tJ,
            "Rsp": {"s": Rs, "p": Rp},
            "Tsp": {"s": Ts, "p": Tp},
            "Asp": {"s": max(0.0, 1-(Rs+Ts)),
                    "p": max(0.0, 1-(Rp+Tp))},
        }


# tests

In [9]:
# Tests for the BerremanSolver implementation defined in the previous cell.
# Assumes you already defined: Geometry, Material, Layer, HalfSpace,
# iso_material, uniaxial_z_material, R_from_euler (from the previous cell).

import numpy as np
from math import isclose



# ------------- Simple test helpers -------------

def _assert_close(a, b, tol=1e-9, label=""):
    if not (abs(a - b) <= tol):
        raise AssertionError(f"{label} expected {b}, got {a} (|Δ|={abs(a-b)})")

def _ensure_solver_present():
    if 'BerremanSolver' not in globals():
        raise RuntimeError("BerremanSolver class is not defined in the notebook yet. "
                           "Please implement it before running these tests.")

# ------------- Test cases -------------

def test_shapes_and_types():
    """Solver output has the documented keys, complex 2x2 Jones matrices and scalar powers."""
    _ensure_solver_present()
    air   = iso_material(lambda wl: 1.0+0j)
    glass = iso_material(lambda wl: 1.5+0j)
    solver = BerremanSolver([], HalfSpace(air), HalfSpace(glass))  # direct interface
    out = solver.solve(Geometry(wavelength=550e-9, theta=0.0))

    for key in ("rJ", "tJ", "Rsp", "Tsp", "Asp"):
        assert key in out, f"Missing key {key} in solver output"
    for jones in (out["rJ"], out["tJ"]):
        assert isinstance(jones, np.ndarray) and jones.shape == (2,2) and jones.dtype.kind == 'c'
    for pol in ("s", "p"):
        for group in ("Rsp", "Tsp", "Asp"):
            assert pol in out[group], f"Missing '{pol}' in {group}"
            value = out[group][pol]
            assert np.isscalar(value), f"{group}['{pol}'] must be a scalar"
            # loose sanity bound; tiny negative FP values are tolerated
            assert value < 10, f"{group}['{pol}'] looks unphysical (too large)"

# --- Fresnel reference helpers (analytic single-interface results) ---

def _snell_cos(n0, n1, theta0):
    s0 = np.sin(theta0)
    s1 = n0 * s0 / n1
    c1 = np.lib.scimath.sqrt(1 - s1**2)  # allows TIR
    c0 = np.cos(theta0)
    return c0, c1

def fresnel_interface_rs_rp(n0, n1, theta0):
    """Fresnel amplitude reflection coefficients (rs, rp) of a single n0 → n1 interface."""
    c0, c1 = _snell_cos(n0, n1, theta0)
    s_in, s_out = n0*c0, n1*c1
    p_a, p_b = n1*c0, n0*c1
    return (s_in - s_out) / (s_in + s_out), (p_a - p_b) / (p_a + p_b)

def fresnel_interface_Rs_Rp_Ts_Tp(n0, n1, theta0):
    """Return power R/T using *amplitude t_s, t_p* (not 1+r_p).

    Analytic reference for a single n0 → n1 interface (no layers).

    Parameters
    ----------
    n0, n1 : complex
        Refractive indices of the incidence and exit half-spaces. ``n1`` may be
        absorbing. The T normalization assumes a real (non-absorbing) ``n0``.
    theta0 : float
        Real angle of incidence in radians.

    Returns
    -------
    (Rs, Rp, Ts, Tp)
        Power reflectances and transmittances, with T measured as the Poynting
        flux normal to the interface. For a non-absorbing incidence medium,
        R + T = 1 per polarization, including lossy n1 and total internal reflection.
    """
    c0, c1 = _snell_cos(n0, n1, theta0)
    rs, rp = fresnel_interface_rs_rp(n0, n1, theta0)
    # amplitude (electric-field) transmission coefficients
    ts = (2*n0*c0) / (n0*c0 + n1*c1)
    tp = (2*n0*c0) / (n1*c0 + n0*c1)
    Rs, Rp = abs(rs)**2, abs(rp)**2
    # Normal Poynting flux: s-pol carries Re(n cosθ)|E|^2, but p-pol carries
    # Re(n conj(cosθ))|E|^2 (E_x = E cosθ, H_y ∝ n E). The two coincide for real
    # n and cosθ, but differ for an absorbing n1 at oblique incidence.
    Ts = (np.real(n1*c1) / np.real(n0*c0)) * abs(ts)**2
    Tp = (np.real(n1*np.conj(c1)) / np.real(n0*np.conj(c0))) * abs(tp)**2
    return Rs, Rp, Ts, Tp

# --- Single-interface test, checked against the analytic Fresnel helpers above ---

def test_single_interface_fresnel():
    """Air|glass interface: solver R/T match analytic Fresnel and R + T = 1 per pol."""
    _ensure_solver_present()
    n_air, n_glass = 1.0+0j, 1.5+0j
    air   = iso_material(lambda wl: n_air)
    glass = iso_material(lambda wl: n_glass)
    solver = BerremanSolver([], HalfSpace(air), HalfSpace(glass))
    tol = 1e-8          # same tolerance for R and T
    tol_energy = 5e-8
    for theta_deg in (0.0, 30.0, 60.0):
        theta = np.deg2rad(theta_deg)
        out = solver.solve(Geometry(wavelength=550e-9, theta=theta, phi=0.0))
        R, T = out["Rsp"], out["Tsp"]
        Rs_ref, Rp_ref, Ts_ref, Tp_ref = fresnel_interface_Rs_Rp_Ts_Tp(n_air, n_glass, theta)
        checks = (
            (R["s"], Rs_ref, tol, f"Rs at {theta_deg}°"),
            (R["p"], Rp_ref, tol, f"Rp at {theta_deg}°"),
            (T["s"], Ts_ref, tol, f"Ts at {theta_deg}°"),
            (T["p"], Tp_ref, tol, f"Tp at {theta_deg}°"),
            (R["s"] + T["s"], 1.0, tol_energy, f"Energy (s) at {theta_deg}°"),
            (R["p"] + T["p"], 1.0, tol_energy, f"Energy (p) at {theta_deg}°"),
        )
        for got, want, tol_i, label in checks:
            _assert_close(got, want, tol_i, label)


def test_quarter_wave_AR_normal_incidence():
    """Ideal quarter-wave AR coating on glass: R ≈ 0 and T ≈ 1 at normal incidence.

    Film index sqrt(n0·nS) with physical thickness λ/(4·n_film) makes the two
    interface reflections cancel at the design wavelength.
    """
    _ensure_solver_present()
    n0, nS = 1.0+0j, 1.5+0j
    n_film = np.sqrt(n0*nS)
    wl = 550e-9
    thickness = wl / (4*np.real(n_film))
    air   = iso_material(lambda wl: n0)
    film  = iso_material(lambda wl: n_film)
    glass = iso_material(lambda wl: nS)
    stack = [Layer(material=film, d=thickness, R=None)]
    res = BerremanSolver(stack, HalfSpace(air), HalfSpace(glass)).solve(
        Geometry(wavelength=wl, theta=0.0))
    R, T = res["Rsp"], res["Tsp"]
    for pol in ("s", "p"):
        assert abs(R[pol]) < 1e-6, f"AR R{pol} too large: {R[pol]}"
    for pol in ("s", "p"):
        assert isclose(T[pol], 1.0, rel_tol=0, abs_tol=1e-6)

def test_energy_conservation_lossless_stack():
    """Lossless three-layer stack in air: R + T = 1 per polarization at 0°, 45°, 70°."""
    _ensure_solver_present()
    wl = 550e-9
    air = iso_material(lambda wl: 1.0+0j)
    n1, n2 = 1.45+0j, 2.0+0j
    low  = iso_material(lambda wl: n1)
    high = iso_material(lambda wl: n2)
    stack = [Layer(low, d=90e-9), Layer(high, d=60e-9), Layer(low, d=120e-9)]
    solver = BerremanSolver(stack, HalfSpace(air), HalfSpace(air))
    for theta_deg in (0.0, 45.0, 70.0):
        res = solver.solve(Geometry(wavelength=wl, theta=np.deg2rad(theta_deg)))
        err = {pol: abs((res["Rsp"][pol] + res["Tsp"][pol]) - 1.0) for pol in ("s", "p")}
        assert err["s"] < 5e-6, f"Energy not conserved (s) at {theta_deg}°: {err['s']}"
        assert err["p"] < 5e-6, f"Energy not conserved (p) at {theta_deg}°: {err['p']}"

def test_uniaxial_axis_aligned_no_crosspol():
    """Uniaxial slab with optic axis ∥ z at θ=45°, φ=0: no s↔p conversion, R + T = 1."""
    _ensure_solver_present()
    wl = 600e-9
    air = iso_material(lambda wl: 1.0+0j)
    no, ne = 1.6+0j, 1.7+0j
    slab = uniaxial_z_material(lambda wl: no, lambda wl: ne)
    solver = BerremanSolver([Layer(slab, d=300e-9, R=None)], HalfSpace(air), HalfSpace(air))
    res = solver.solve(Geometry(wavelength=wl, theta=np.deg2rad(45), phi=0.0))
    # Axis along z and plane of incidence x-z: s and p are decoupled eigenmodes.
    for name, J in (("rJ", res["rJ"]), ("tJ", res["tJ"])):
        assert abs(J[0,1]) < 1e-6 and abs(J[1,0]) < 1e-6, f"Unexpected cross-pol in {name}: {J}"
    # Lossless slab: reflected plus transmitted power equals incident power.
    Es, Ep = (abs((res["Rsp"][pol] + res["Tsp"][pol]) - 1.0) for pol in ("s", "p"))
    assert Es < 5e-6 and Ep < 5e-6, "Energy not conserved for uniaxial aligned layer"

def test_axis_aligned_birefringence():
    """Axis-aligned uniaxial film between air: Ts != Tp at oblique incidence, no cross-pol."""
    wl = 550e-9
    air = iso_material(lambda wl: 1.0+0j)
    n_o, n_e = 1.50+0j, 1.70+0j
    biref = uniaxial_z_material(lambda wl: n_o, lambda wl: n_e)
    solver = BerremanSolver([Layer(material=biref, d=300e-9, R=None)],
                            inc=HalfSpace(air), sub=HalfSpace(air))

    # Oblique incidence gives the p wave a field component along the optic axis (z).
    res = solver.solve(Geometry(wavelength=wl, theta=np.deg2rad(45), phi=0.0))
    Rs, Rp = res["Rsp"]["s"], res["Rsp"]["p"]
    Ts, Tp = res["Tsp"]["s"], res["Tsp"]["p"]

    for name, J in (("rJ", res["rJ"]), ("tJ", res["tJ"])):
        assert abs(J[0,1]) < 1e-8 and abs(J[1,0]) < 1e-8, f"Unexpected cross-pol in {name}: {J}"

    assert abs(Ts - Tp) > 1e-6 or abs(Rs - Rp) > 1e-6, f"No birefringence: Ts={Ts}, Tp={Tp}, Rs={Rs}, Rp={Rp}"

    for pol, R, T in (("s", Rs, Ts), ("p", Rp, Tp)):
        assert abs((R + T) - 1.0) < 5e-6, f"Energy not conserved ({pol}): {R+T}"

def test_rotated_axis_birefringence_crosspol():
    """Uniaxial slab rotated by ``R_from_euler(0, 35°, 0)``: expect s↔p coupling at θ=45°.

    Also checks R + T = 1 per input polarization (lossless slab).

    NOTE(review): cross-pol only appears if the rotation moves the optic axis
    out of the x–z plane of incidence. If the tilt stays within that plane,
    s and p remain decoupled and this test cannot pass. Confirm the
    ``R_from_euler`` angle convention.
    """
    wl = 550e-9
    air = iso_material(lambda wl: 1.0+0j)
    no, ne = 1.55+0j, 1.75+0j
    mat = uniaxial_z_material(lambda wl: no, lambda wl: ne)
    Rtilt = R_from_euler(0.0, np.deg2rad(35), 0.0)

    layers = [Layer(material=mat, d=300e-9, R=Rtilt)]
    solver = BerremanSolver(layers, HalfSpace(air), HalfSpace(air))
    out = solver.solve(Geometry(wavelength=wl, theta=np.deg2rad(45), phi=0.0))

    rJ, tJ = out["rJ"], out["tJ"]
    # cross-pol must be nonzero in at least one off-diagonal Jones element
    cross = max(abs(rJ[0,1]), abs(rJ[1,0]), abs(tJ[0,1]), abs(tJ[1,0]))
    assert cross > 1e-6, f"No cross-pol for rotated axis (max |off-diag|={cross}): rJ={rJ}, tJ={tJ}"
    # energy per input pol
    for pol in ("s", "p"):
        total = out["Rsp"][pol] + out["Tsp"][pol]
        assert abs(total - 1) < 5e-6, f"Energy not conserved ({pol}): {total}"

def test_uniaxial_isotropic_limit():
    """A uniaxial material with n_o == n_e must give the same R/T as the isotropic one."""
    wl = 600e-9
    n = 1.65+0j
    air = iso_material(lambda wl: 1.0+0j)
    uni = uniaxial_z_material(lambda wl: n, lambda wl: n)
    iso = iso_material(lambda wl: n)

    layers_uni = [Layer(uni, d=250e-9, R=None)]
    layers_iso = [Layer(iso, d=250e-9, R=None)]
    geom = Geometry(wavelength=wl, theta=np.deg2rad(40), phi=0.0)

    sol_uni = BerremanSolver(layers_uni, HalfSpace(air), HalfSpace(air)).solve(geom)
    sol_iso = BerremanSolver(layers_iso, HalfSpace(air), HalfSpace(air)).solve(geom)

    for k in ("s", "p"):
        for quantity in ("Rsp", "Tsp"):
            a, b = sol_uni[quantity][k], sol_iso[quantity][k]
            # Message included so a failure is diagnosable (bare asserts print nothing).
            assert abs(a - b) < 1e-8, f"{quantity}[{k}] differs: uniaxial={a}, isotropic={b}"

def test_brewster_angle_isotropic():
    """At Brewster's angle, tan θ_B = n1/n0, the air|glass p reflectance vanishes."""
    n0, n1 = 1.0, 1.5
    wl = 550e-9
    air   = iso_material(lambda wl: n0+0j)
    glass = iso_material(lambda wl: n1+0j)
    solver = BerremanSolver([], HalfSpace(air), HalfSpace(glass))

    theta_B = np.arctan(n1/n0)  # radians
    Rp = solver.solve(Geometry(wavelength=wl, theta=theta_B, phi=0.0))["Rsp"]["p"]
    assert Rp < 1e-10, f"Rp at Brewster not ~0: {Rp}"

def test_total_internal_reflection():
    """Glass → air beyond the critical angle: R ≈ 1 and T ≈ 0 for both polarizations."""
    wl = 550e-9
    n_hi, n_lo = 1.5+0j, 1.0+0j
    inc = iso_material(lambda wl: n_hi)
    sub = iso_material(lambda wl: n_lo)
    solver = BerremanSolver([], HalfSpace(inc), HalfSpace(sub))
    # Critical angle arcsin(n_lo/n_hi) ≈ 41.8°; probe 5° above it.
    theta_crit = np.arcsin(n_lo.real / n_hi.real)
    theta = theta_crit + np.deg2rad(5)

    out = solver.solve(Geometry(wavelength=wl, theta=theta, phi=0.0))
    R, T = out["Rsp"], out["Tsp"]
    # near-perfect reflection
    for pol in ("s", "p"):
        assert 1 - R[pol] < 1e-10, f"R{pol} not ~1 under TIR: {R[pol]}"
    # transmittance nearly zero (only an evanescent field in the low-index side)
    for pol in ("s", "p"):
        assert T[pol] < 1e-12, f"T{pol} not ~0 under TIR: {T[pol]}"


def test_thin_film_limit():
    """A vanishingly thin (1 pm) film must reproduce the bare air|glass interface R/T."""
    wl = 550e-9
    air = iso_material(lambda wl: 1.0+0j)
    glass = iso_material(lambda wl: 1.5+0j)
    film = iso_material(lambda wl: 2.0+0j)

    # very thin film: phase thickness ~ 2π·2.0·1e-12/550e-9 ≈ 2e-5 rad
    d = 1e-12
    layers = [Layer(film, d=d)]
    geom = Geometry(wavelength=wl, theta=np.deg2rad(35), phi=0.0)

    sol_stack = BerremanSolver(layers, HalfSpace(air), HalfSpace(glass)).solve(geom)
    sol_iface = BerremanSolver([], HalfSpace(air), HalfSpace(glass)).solve(geom)

    for k in ("s", "p"):
        for quantity in ("Rsp", "Tsp"):
            a, b = sol_stack[quantity][k], sol_iface[quantity][k]
            assert abs(a - b) < 1e-8, f"{quantity}[{k}] thin-film={a} vs bare interface={b}"

def test_reciprocity_isotropic_stack():
    """Reversing a lossless isotropic stack (air on both sides) leaves unpolarized T unchanged."""
    wl = 550e-9
    air = iso_material(lambda wl: 1.0+0j)
    n1, n2 = 1.45+0j, 2.0+0j
    L1, L2 = iso_material(lambda wl: n1), iso_material(lambda wl: n2)
    layers = [Layer(L1, 90e-9), Layer(L2, 60e-9), Layer(L1, 120e-9)]
    geom = Geometry(wavelength=wl, theta=np.deg2rad(25), phi=0.0)

    # forward (air -> air)
    fwd = BerremanSolver(layers, HalfSpace(air), HalfSpace(air)).solve(geom)
    # reverse (reverse stack)
    rev_layers = list(reversed(layers))
    rev = BerremanSolver(rev_layers, HalfSpace(air), HalfSpace(air)).solve(geom)

    # total (unpolarized) transmittance same both directions
    T_fwd = 0.5*(fwd["Tsp"]["s"] + fwd["Tsp"]["p"])
    T_rev = 0.5*(rev["Tsp"]["s"] + rev["Tsp"]["p"])
    assert abs(T_fwd - T_rev) < 1e-8, f"Unpolarized T differs: forward={T_fwd}, reversed={T_rev}"

def test_lossy_film_absorption(verbose: bool = False):
    """Check that energy is conserved with loss: R + T + A ≈ 1, with A>0.

    The absorbing film is specified directly as ε = (1.5 + 0.1i)² · I (100 nm thick),
    surrounded by air given through a scalar index, at normal incidence.
    """
    wl = 500e-9
    n_film = 1.5 + 0.1j
    lossy_mat = material(epsilon=n_film**2 * np.eye(3, dtype=complex))
    air = material(n=1.0)

    solver = BerremanSolver([Layer(lossy_mat, 100e-9)], HalfSpace(air), HalfSpace(air))
    res = solver.solve(Geometry(wavelength=wl, theta=0.0, phi=0.0))

    for pol in ("s", "p"):
        R, T, A = (res[key][pol] for key in ("Rsp", "Tsp", "Asp"))
        total = R + T + A

        if verbose:
            print(f"--- lossy_film_absorption ({pol}) ---")
            for label, value in (("R", R), ("T", T), ("A", A), ("R+T+A", total)):
                print(f"{label} = {value:.6f}")

        # Energy balance, and the film must actually absorb something.
        assert abs(total - 1.0) < 5e-6, f"Energy not conserved with loss: {total}"
        assert A > 1e-6, f"Absorption should be positive, got {A}"


def test_normal_incidence_orientation_invariance(verbose: bool = False):
    """At normal incidence, rotation around z should not change R/T; cross-pol ~ 0.

    NOTE: this notebook defines a second function with this same name further
    down. That later ``def`` rebinds the name, so this version is shadowed and
    never runs. It is kept physically correct anyway. Invariance under a rotation
    about z only holds when ε is itself invariant under that rotation, i.e. the
    optic axis is along z. For ε = diag(2.0, 1.5, 1.5) (axis along x), rotating
    about z *does* change R/T at θ = 0.
    """
    wl = 500e-9
    air = material(n=1.0)

    # Uniaxial constant epsilon with the optic axis along z: diag(a, a, b) is
    # mapped onto itself by any rotation about z.
    eps_uniax = np.diag([1.5, 1.5, 2.0]).astype(complex)
    mat = material(epsilon=eps_uniax)

    inc = HalfSpace(air)
    sub = HalfSpace(air)

    L0 = Layer(mat, 200e-9, R=None)
    geom0 = Geometry(wavelength=wl, theta=0.0, phi=0.0)

    res0 = BerremanSolver([L0], inc, sub).solve(geom0)

    # Rotate around z-axis by 45° (only z-rotation)
    Rz45 = R_from_euler(np.deg2rad(45), 0.0, 0.0)
    L1 = Layer(mat, 200e-9, R=Rz45)
    res1 = BerremanSolver([L1], inc, sub).solve(geom0)

    if verbose:
        print("--- normal_incidence_orientation_invariance ---")
        print("rJ (no rot):\n", res0["rJ"])
        print("tJ (no rot):\n", res0["tJ"])
        print("rJ (z rot):\n",  res1["rJ"])
        print("tJ (z rot):\n",  res1["tJ"])

    # R/T invariance
    for k in ("s", "p"):
        assert abs(res0["Rsp"][k] - res1["Rsp"][k]) < 1e-8, f"Rsp[{k}] changed at normal incidence"
        assert abs(res0["Tsp"][k] - res1["Tsp"][k]) < 1e-8, f"Tsp[{k}] changed at normal incidence"

    # Cross-pol ~ 0 at normal incidence
    for rJ, tJ in ((res0["rJ"], res0["tJ"]), (res1["rJ"], res1["tJ"])):
        assert abs(rJ[0,1]) < 1e-8 and abs(rJ[1,0]) < 1e-8, "cross-pol in rJ at normal incidence"
        assert abs(tJ[0,1]) < 1e-8 and abs(tJ[1,0]) < 1e-8, "cross-pol in tJ at normal incidence"

def test_lossy_interface_only_diagnostic(verbose: bool = False):
    """Air | lossy half-space (no film): solver must match Fresnel with complex n.

    Normal incidence on n = 1.5 + 0.1i. The solver's R/T are compared with
    ``fresnel_interface_Rs_Rp_Ts_Tp`` to 1e-8.
    """
    wl = 500e-9
    n_loss = 1.5 + 0.1j
    air = iso_material(lambda wl: 1.0+0j)
    lossy = iso_material(lambda wl: n_loss)

    # Bare interface: empty layer list.
    solver = BerremanSolver([], HalfSpace(air), HalfSpace(lossy))
    res = solver.solve(Geometry(wavelength=wl, theta=0.0, phi=0.0))
    R, T = res["Rsp"], res["Tsp"]
    Rs, Rp, Ts, Tp = R["s"], R["p"], T["s"], T["p"]

    Rs_ref, Rp_ref, Ts_ref, Tp_ref = fresnel_interface_Rs_Rp_Ts_Tp(1.0+0j, n_loss, 0.0)

    if verbose:
        print("--- lossy_interface_only_diagnostic ---")
        print(f"Solver:   Rs={Rs:.6f} Ts={Ts:.6f} | Rp={Rp:.6f} Tp={Tp:.6f}")
        print(f"Fresnel:  Rs={Rs_ref:.6f} Ts={Ts_ref:.6f} | Rp={Rp_ref:.6f} Tp={Tp_ref:.6f}")
        print(f"Energy s: R+T={Rs+Ts:.6f} | p: {Rp+Tp:.6f}")

    # Tight agreement with the analytic single-interface result.
    assert abs(Rs - Rs_ref) < 1e-8 and abs(Rp - Rp_ref) < 1e-8
    assert abs(Ts - Ts_ref) < 1e-8 and abs(Tp - Tp_ref) < 1e-8


def test_lossy_thick_film_diagnostic(verbose: bool = False):
    """
    Air | (thick lossy film) | Air: with large d, multiple reflections die out,
    so result should approach a single lossy interface (front) with T ~ 0.
    """
    wl = 500e-9
    n_loss = 1.5 + 0.1j
    air = iso_material(lambda wl: 1.0+0j)
    lossy = iso_material(lambda wl: n_loss)

    # Reference front-interface Fresnel reflectance (air -> lossy); T refs unused.
    Rs_ref, Rp_ref, _, _ = fresnel_interface_Rs_Rp_Ts_Tp(1.0+0j, n_loss, 0.0)

    # Intensity decay length λ/(4π·Im n) ≈ 0.4 µm at 500 nm, so one pass through
    # 5 µm is attenuated by exp(-4π·0.1·d/λ) ≈ 3.5e-6: the film is effectively opaque.
    d = 5e-6
    layers = [Layer(lossy, d=d)]
    out = BerremanSolver(layers, HalfSpace(air), HalfSpace(air)).solve(
        Geometry(wavelength=wl, theta=0.0, phi=0.0))

    Rs, Rp = out["Rsp"]["s"], out["Rsp"]["p"]
    Ts, Tp = out["Tsp"]["s"], out["Tsp"]["p"]

    if verbose:
        print("--- lossy_thick_film_diagnostic ---")
        print(f"Thick film: Rs={Rs:.6f} Ts={Ts:.6f} | Rp={Rp:.6f} Tp={Tp:.6f}")
        print(f"Front iface Rs_ref={Rs_ref:.6f} | Rp_ref={Rp_ref:.6f}")

    # Expect T ~ 0, and R close to the *front interface* reflectance
    assert Ts < 1e-3 and Tp < 1e-3, f"T not small for thick lossy film: Ts={Ts}, Tp={Tp}"
    assert abs(Rs - Rs_ref) < 5e-3 and abs(Rp - Rp_ref) < 5e-3, (
        f"R doesn't approach front-interface value for thick lossy film: "
        f"Rs={Rs} (ref {Rs_ref}), Rp={Rp} (ref {Rp_ref})")

def test_normal_incidence_orientation_invariance(verbose: bool = False):
    """At normal incidence, rotation around z should not change R/T; cross-pol ~ 0.

    Uses ε = diag(n_o², n_o², n_e²) (optic axis ∥ z). Any rotation about z maps
    this tensor onto itself, so a 45° z-rotation must leave every output unchanged.
    """
    wl = 500e-9
    air = material(n=1.0)

    n_o2, n_e2 = 1.5**2, 2.0**2
    mat = material(epsilon=np.diag([n_o2, n_o2, n_e2]).astype(complex))

    inc, sub = HalfSpace(air), HalfSpace(air)
    geom0 = Geometry(wavelength=wl, theta=0.0, phi=0.0)

    res0 = BerremanSolver([Layer(mat, 200e-9, R=None)], inc, sub).solve(geom0)
    Rz45 = R_from_euler(np.deg2rad(45), 0.0, 0.0)   # pure rotation about z
    res1 = BerremanSolver([Layer(mat, 200e-9, R=Rz45)], inc, sub).solve(geom0)

    if verbose:
        print("--- normal_incidence_orientation_invariance ---")
        for tag, res in (("no rot", res0), ("z rot", res1)):
            print(f"rJ ({tag}):\n", res["rJ"])
            print(f"tJ ({tag}):\n", res["tJ"])

    for k in ("s", "p"):
        assert abs(res0["Rsp"][k] - res1["Rsp"][k]) < 1e-8, f"Rsp[{k}] changed at normal incidence"
        assert abs(res0["Tsp"][k] - res1["Tsp"][k]) < 1e-8, f"Tsp[{k}] changed at normal incidence"

    for res in (res0, res1):
        rJ, tJ = res["rJ"], res["tJ"]
        assert abs(rJ[0,1]) < 1e-8 and abs(rJ[1,0]) < 1e-8
        assert abs(tJ[0,1]) < 1e-8 and abs(tJ[1,0]) < 1e-8

def test_normal_incidence_xy_rotation_changes(verbose: bool = False):
    """If axis ⟂ z (e.g., along x), rotating around z *will* change R/T at θ=0.

    With ε = diag(2.0, 1.5, 1.5), the two in-plane polarizations see different
    indices. A 45° rotation about z mixes them, so the s-polarized R or T must move.
    The ``verbose`` flag (used by the runner on failure) prints both results.
    """
    wl = 500e-9
    air = material(n=1.0)

    # Axis along x: two equal in y,z
    eps_axis_x = np.diag([2.0, 1.5, 1.5]).astype(complex)
    mat = material(epsilon=eps_axis_x)

    out0 = BerremanSolver([Layer(mat, 200e-9, R=None)], HalfSpace(air), HalfSpace(air)).solve(
        Geometry(wavelength=wl, theta=0.0, phi=0.0))

    Rz45 = R_from_euler(np.deg2rad(45), 0.0, 0.0)
    out1 = BerremanSolver([Layer(mat, 200e-9, R=Rz45)], HalfSpace(air), HalfSpace(air)).solve(
        Geometry(wavelength=wl, theta=0.0, phi=0.0))

    dRs = abs(out0["Rsp"]["s"] - out1["Rsp"]["s"])
    dTs = abs(out0["Tsp"]["s"] - out1["Tsp"]["s"])

    if verbose:
        print("--- normal_incidence_xy_rotation_changes ---")
        print(f"no rot: Rs={out0['Rsp']['s']:.6f} Ts={out0['Tsp']['s']:.6f}")
        print(f"z rot:  Rs={out1['Rsp']['s']:.6f} Ts={out1['Tsp']['s']:.6f}")
        print("tJ (z rot):\n", out1["tJ"])

    # Expect *difference* now
    assert dRs > 1e-6 or dTs > 1e-6, \
        f"z-rotation of in-plane axis did not change s-pol R/T: |ΔRs|={dRs}, |ΔTs|={dTs}"


# Run tests

In [11]:
def run_berreman_solver_tests():
    """Run the BerremanSolver test functions in order and print a PASS/FAIL summary.

    Each test is called without arguments, and any exception counts as a failure.
    Tests named in ``verbose_on_fail`` accept ``verbose=True``; on failure they
    are re-run once with it so that their diagnostic prints appear in the log.
    """
    tests = [
        ("shapes_and_types",                        test_shapes_and_types),
        ("single_interface_fresnel",                test_single_interface_fresnel),
        ("quarter_wave_AR_normal_incidence",        test_quarter_wave_AR_normal_incidence),
        ("energy_conservation_lossless_stack",      test_energy_conservation_lossless_stack),
        ("uniaxial_axis_aligned_no_crosspol",       test_uniaxial_axis_aligned_no_crosspol),
        ("axis_aligned_birefringence",              test_axis_aligned_birefringence),
        ("uniaxial_isotropic_limit",                test_uniaxial_isotropic_limit),
        ("brewster_angle_isotropic",                test_brewster_angle_isotropic),
        ("total_internal_reflection",               test_total_internal_reflection),
        ("thin_film_limit",                         test_thin_film_limit),
        ("reciprocity_isotropic_stack",             test_reciprocity_isotropic_stack),
        # ---- diagnostics & tricky ones moved later ----
        ("lossy_interface_only_diagnostic",         test_lossy_interface_only_diagnostic),
        ("lossy_thick_film_diagnostic",             test_lossy_thick_film_diagnostic),
        ("normal_incidence_xy_rotation_changes",    test_normal_incidence_xy_rotation_changes),  # optional, keeps insight
        ("lossy_film_absorption",                   test_lossy_film_absorption),
        ("normal_incidence_orientation_invariance", test_normal_incidence_orientation_invariance),
        # ("rotated_axis_birefringence_crosspol",   test_rotated_axis_birefringence_crosspol),
    ]

    verbose_on_fail = {
        "lossy_interface_only_diagnostic",
        "lossy_thick_film_diagnostic",
        "lossy_film_absorption",
        "normal_incidence_orientation_invariance",
        "normal_incidence_xy_rotation_changes",
    }

    passed = 0
    print("Running BerremanSolver tests...\n")
    for name, fn in tests:
        try:
            fn()
        except Exception as e:  # top-level harness: report every failure and keep going
            # Include the exception type: bare ``assert``s and NameErrors would
            # otherwise print an empty or ambiguous reason.
            print(f"[FAIL] {name} -> {type(e).__name__}: {e}")
            if name in verbose_on_fail:
                print(f"--- Verbose output for {name} ---")
                try:
                    fn(verbose=True)
                except Exception as e2:
                    print(f"(Verbose) Still failing: {type(e2).__name__}: {e2}")
        else:
            print(f"[PASS] {name}")
            passed += 1
    print(f"\n{passed}/{len(tests)} tests passed.")
run_berreman_solver_tests()

Running BerremanSolver tests...

[PASS] shapes_and_types
[PASS] single_interface_fresnel
[PASS] quarter_wave_AR_normal_incidence
[PASS] energy_conservation_lossless_stack
[PASS] uniaxial_axis_aligned_no_crosspol
[PASS] axis_aligned_birefringence
[PASS] uniaxial_isotropic_limit
[PASS] brewster_angle_isotropic
[PASS] total_internal_reflection
[PASS] thin_film_limit
[PASS] reciprocity_isotropic_stack
[PASS] lossy_interface_only_diagnostic
[FAIL] lossy_thick_film_diagnostic -> T not small for thick lossy film: Ts=0.001865716528057891, Tp=0.001865716528057894
--- Verbose output for lossy_thick_film_diagnostic ---
--- lossy_thick_film_diagnostic ---
Thick film: Rs=24.080600 Ts=0.001866 | Rp=24.080600 Tp=0.001866
Front iface Rs_ref=0.041534 | Rp_ref=0.041534
(Verbose) Still failing: T not small for thick lossy film: Ts=0.001865716528057891, Tp=0.001865716528057894
[PASS] normal_incidence_xy_rotation_changes
[FAIL] lossy_film_absorption -> Energy not conserved with loss: 1.2522030928767793
---