# **AARON'S NIGHTMARE EQUATIONS**

## **Install**

### Firedrake

In [6]:
try:
    !wget "https://fem-on-colab.github.io/releases/firedrake-install-release-real.sh" -O "/tmp/firedrake-install.sh"
    !bash "/tmp/firedrake-install.sh"
    from firedrake import *  # noqa: F401
except:
    from firedrake import *  # noqa: F401

--2025-11-13 16:18:43--  https://fem-on-colab.github.io/releases/firedrake-install-release-real.sh
Resolving fem-on-colab.github.io (fem-on-colab.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to fem-on-colab.github.io (fem-on-colab.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4767 (4.7K) [application/x-sh]
Saving to: ‘/tmp/firedrake-install.sh’


2025-11-13 16:18:43 (46.7 MB/s) - ‘/tmp/firedrake-install.sh’ saved [4767/4767]

+ INSTALL_PREFIX=/usr/local
++ echo /usr/local
++ awk -F/ '{print NF-1}'
+ INSTALL_PREFIX_DEPTH=2
+ PROJECT_NAME=fem-on-colab
+ SHARE_PREFIX=/usr/local/share/fem-on-colab
+ FIREDRAKE_INSTALLED=/usr/local/share/fem-on-colab/firedrake.installed
+ [[ ! -f /usr/local/share/fem-on-colab/firedrake.installed ]]
+ set +x
























################################################################################
#     This installation is offered by FEM on Colab, an open-

### Irksome

In [7]:
try:
    !python3 -m pip install --no-dependencies git+https://github.com/firedrakeproject/Irksome.git
    from irksome import *  # noqa: F401
except:
    from irksome import *  # noqa: F401

Collecting git+https://github.com/firedrakeproject/Irksome.git
  Cloning https://github.com/firedrakeproject/Irksome.git to /tmp/pip-req-build-fh5ep0it
  Running command git clone --filter=blob:none --quiet https://github.com/firedrakeproject/Irksome.git /tmp/pip-req-build-fh5ep0it
  Resolved https://github.com/firedrakeproject/Irksome.git to commit 784ae3f83ab853b04c525d94112c491ce7defbc3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


### Legacy FET

In [8]:
try:
    import avfet_modules.cheb_fet as cheb_fet
except:
    !python3 -m pip install --no-dependencies "git+https://github.com/BorisAndrews/BorisAndrews.github.io@main#subdirectory=assets/python/avfet_modules"
    import avfet_modules.cheb_fet as cheb_fet

### Other

In [9]:
from pathlib import Path

## **Implementation attempts**

### Irksome

In [57]:
def stefan_maxwell_irksome(
    Nspec:      int = 2,
    Nx:         int = 24,
    deg:        int = 1,
    vdeg:       int = 2,
    time_deg:   int = 1,
    Nt:         int = 24,
    dt:         float = 1e-9,
    Kval:       float = 1.0e-2,
    nu:         float = 1.0e-3,
    scheme:     str = "gauss",
    output_dir: str = "output/stefan_maxwell/",
    write_qois: bool = True,
    write_vtk:  bool = True,
):
    """
    Energy- and entropy-preserving Stefan-Maxwell scheme.
    - Unknowns (time-continuous): (rho_i), u, (rho s)
    - Auxiliary (time-discontinuous): (mu_i), p, theta, m
    Returns
    - Dictionary: {"time": [...], "energy": [...], "entropy": [...]}
    """
    # Ensure output directory exists
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Convert parameters to UFL objects
    K_c = Constant(Kval)
    nu_c = Constant(nu)
    dt_c = Constant(dt)
    V_i = [0.8,0.2]

    # Mesh and coordinate (2D periodic box)
    mesh = PeriodicUnitSquareMesh(Nx, Nx)
    x, y = SpatialCoordinate(mesh)

    # Function spaces
    S = FunctionSpace(mesh, "CG", deg)  # Scalar
    V = VectorFunctionSpace(mesh, "CG", vdeg)  # Vector
    R = FunctionSpace(mesh, "R", 0)  # Real
    Z = MixedFunctionSpace(([S]*Nspec) + [V, S] + ([S]*Nspec) + [S, S, V, R])  # Mixed space: (rho_1...rho_N, u, rho_s, mu_1...mu_N, p, theta, m)

    # Solution functions
    z = Function(Z, name="state")
    z_split = split(z)
    rho = z_split[0:Nspec]; u = z_split[Nspec]; rho_s = z_split[Nspec+1]; mu = z_split[(Nspec+2):(2*Nspec+2)]; p = z_split[2*Nspec+2]; theta = z_split[2*Nspec+3]; m = z_split[2*Nspec+4]; lam = z_split[2*Nspec+5]
    z_out = z.subfunctions
    rho_out = z_out[0:Nspec]; u_out = z_out[Nspec]; rho_s_out = z_out[Nspec+1]; mu_out = z_out[(Nspec+2):(2*Nspec+2)]; p_out = z_out[2*Nspec+2]; theta_out = z_out[2*Nspec+3]; m_out = z_out[2*Nspec+4]; lam_out = z_split[2*Nspec+5]

    # Split tests (UFL)
    tests = TestFunctions(Z)
    psi = tests[0:Nspec]; v = tests[Nspec]; omega = tests[Nspec+1]; zeta = tests[(Nspec+2):(2*Nspec+2)]; q = tests[2*Nspec+2]; gamma = tests[2*Nspec+3]; w = tests[2*Nspec+4]; sigma = tests[2*Nspec+5]

    # Helpers for rho
    rho_tot = sum(rho)
    sqrt_rho = sqrt(rho_tot)

    # Free energy density (simple ideal mixture without gradient terms)
    rho_F = sum([rho[i] * ln(rho[i]/rho_tot) for i in range(Nspec)])

    # Total energy density (per volume)
    rho_e = rho_tot * exp((rho_s + rho_F)/rho_tot)
    rho_e_tot = 0.5 * rho_tot * inner(u, u) + rho_e

    # Mobility M_{ij} and related fluxes
    def M_ij(i, j):
        return (0.1*rho[i] if i == j else 0.0) - 0.1*rho[i] * rho[j] / rho_tot
    def grad_mu_over_theta(j):
        return grad(mu[j] / theta)

    # Skew-symmetric convection form C(rho u, v, w)
    def C_skw(rho_u, v_in, w_in):
        return 0.5 * (
            inner(dot(grad(v_in), rho_u), w_in)
          - inner(dot(grad(w_in), rho_u), v_in)
        )

    # Symmetric gradient
    Du = sym(grad(u))

    # Residual
    F = 0
    for i in range(Nspec):  # Mass (for each species)
        diff_flux_i = sum(M_ij(i, j) * grad_mu_over_theta(j) for j in range(Nspec))
        F += (
            inner(Dt(rho[i]), psi[i])
          - inner(rho[i] * u, grad(psi[i]))
          + inner(diff_flux_i, grad(psi[i]))
        ) * dx
    for i in range(Nspec):  # Chemical potential
        d_rho_e_d_rhoi = diff(rho_e, variable(rho[i]))
        F += (
            (mu[i] - d_rho_e_d_rhoi - V_i[i] * p) * zeta[i]
        ) * dx
    rhou = rho_tot * u  # Momentum
    F += (
        inner(sqrt_rho * Dt(m), v)
      + C_skw(rhou, u, v)
      + 2.0 * nu_c * inner(Du, sym(grad(v)))
      - inner(p, div(v))
      + sum([
            inner(rho[i] * grad(mu[i] - V_i[i] * p), v)
        for i in range(Nspec)])
      + inner(rho_s * grad(theta), v)
    ) * dx
    F += (  # Auxiliary momentum-like thing
        (inner(m, w) - inner(sqrt_rho * u, w))
    ) * dx
    F += (  # Pseudo-incompressibility
        div(u) * q
      + sum([sum([
            V_i[i] * inner(M_ij(i, j) * grad_mu_over_theta(j), grad(q))
        for j in range(Nspec)]) for i in range(Nspec)])
      + inner(lam, q)
    ) * dx
    inv_theta = 1.0 / theta  # Entropy
    F += (
        inner(Dt(rho_s), omega)
      - inner(rho_s * u, grad(omega))
      - 2.0 * nu_c * inner(Du, grad(u)) * inv_theta * omega
      - K_c * inner(grad(inv_theta), grad(omega * inv_theta))
      - sum([sum([
            inner(grad_mu_over_theta(j), grad((omega * mu[i]) * inv_theta)) * M_ij(i, j)
        for j in range(Nspec)]) for i in range(Nspec)])
    ) * dx
    F += (  # Temperature
        (theta - diff(rho_e, variable(rho_s))) * gamma
    ) * dx
    F += (  # Pressure nullspace
        inner(p, sigma)
    ) * dx

    # Time integrator
    t = Constant(0.0)
    sp = {
        # Example linear solver settings (tune as needed)
        "snes_monitor" : None,
        "snes_converged_reason" : None,
        "ksp_monitor" : None,
        "ksp_converged_reason" : None,
    }
    scheme_dict = {
        "cpg"   : ContinuousPetrovGalerkinScheme(time_deg, quadrature_degree=4*time_deg-1),  # Can up degree as needed
        "gauss" : GaussLegendre(time_deg)
    }
    stepper = TimeStepper(
        F, scheme_dict[scheme.lower()], t, dt_c, z,
        solver_parameters=sp
        # solver_parameters=sp, aux_indices=[Nspec+2+i for i in range(Nspec+3)]
    )

    # Initial conditions (Idk just trying this out)
    rho_ic = 1 + 0.2*sin(4*pi*x)*cos(2*pi*y) #0.6 + 0.2 * sin(2*pi*x) * sin(2*pi*y)
    rho_out[0].interpolate(rho_ic)
    rho_out[1].interpolate(1.0/V_i[1]*(1-V_i[0]*rho_ic))
    #for i in range(1, Nspec): rho_out[i].interpolate(rho_ic)
    theta_out.interpolate(1.1)
    rho_tot_out = sum(rho_out)
    rho_s_out.interpolate(rho_tot_out *ln(theta_out) - sum([rho_out[i] * ln(rho_out[i]/rho_tot_out) for i in range(Nspec)]))
    #for i in range(Nspec): mu_out[i].interpolate(diff(rho_e, variable(rho[i])) + V_i[i] * p)
    #theta_out.interpolate(diff(rho_e, variable(rho_s)))

    # Set up outputs
    E_form = rho_e_tot * dx
    S_form = rho_s * dx
    t_arr = []
    E_arr = []
    S_arr = []
    if write_qois:
        qoi_path = out_path / "qois.csv"
        with qoi_path.open("w", encoding="utf-8") as f:
            f.write("time,energy,entropy\n")
    def record_and_log():
        t_out = float(t)
        E_out = float(assemble(E_form))
        S_out = float(assemble(S_form))
        print(BLUE % f"Time (t) = {t_out:.6f}")
        print(GREEN % f"Energy  = {E_out:.8e}")
        print(GREEN % f"Entropy = {S_out:.8e}")
        t_arr.append(t_out)
        E_arr.append(E_out)
        S_arr.append(S_out)
        if write_qois:
            with (out_path / "qois.csv").open("a", encoding="utf-8") as f:
                f.write(f"{t_out},{E_out},{S_out}\n")
    record_and_log()
    if write_vtk:
        vtk = VTKFile(str(out_path / "u.pvd"))
        u_out.rename("Barycentric velocity (u)")
        vtk.write(rho_s_out, time=float(t))

    # Time loop
    for _ in range(Nt):
        stepper.advance()
        t.assign(float(t) + float(dt_c))
        if write_vtk: vtk.write(u_out, time=float(t))
        record_and_log()

    return {"time": t_arr, "energy": E_arr, "entropy": S_arr}

### Time is space

In [167]:
def stefan_maxwell_timeisspace(
    Nspec:      int = 2,
    Nx:         int = 24,
    deg:        int = 1,
    vdeg:       int = 2,
    time_deg:   int = 1,
    Nt:         int = 24,
    dt:         float = 1e-9,
    Kval:       float = 1.0e-2,
    nu:         float = 1.0e-3,
    output_dir: str = "output/stefan_maxwell/",
    write_qois: bool = True,
    write_vtk:  bool = True,
):
    """
    Energy- and entropy-preserving Stefan-Maxwell scheme.
    - Unknowns (time-continuous): (rho_i), u, (rho s)
    - Auxiliary (time-discontinuous): (mu_i), p, theta, m
    Returns
    - Dictionary: {"time": [...], "energy": [...], "entropy": [...]}
    """
    # Parameters
    if Nspec != 2: raise ValueError("Initial conditions currently only set up for 2 species")
    rho_ic = lambda x_, y_: 1 + 0.2*sin(4*pi*x_)*cos(2*pi*y_)
    theta_ic = 1.1  # (Just an initial guess)
    V_i = [Constant(0.8), Constant(0.2)]

    # Derivatives
    grad_ = lambda ufl_expr: as_vector([ufl_expr.dx(i) for i in range(2)])
    div_  = lambda ufl_expr: sum([ufl_expr[i].dx(i) for i in range(2)])
    Dt_   = lambda ufl_expr: ufl_expr.dx(2)

    # Ensure output directory exists
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Convert inputs to UFL objects
    K_c = Constant(Kval)
    nu_c = Constant(nu)
    dt_c = Constant(dt)

    # Mesh and coordinate (2D periodic box)
    base_mesh = PeriodicUnitSquareMesh(Nx, Nx)
    mesh = ExtrudedMesh(base_mesh, layers=1, layer_height=dt)
    x, y, t = SpatialCoordinate(mesh)

    # Function spaces
    CG = lambda deg_: FunctionSpace(
        mesh,
        TensorProductElement(FiniteElement("CG", "triangle", deg_), FiniteElement("CG", "interval", time_deg))
    )
    DG = lambda deg_: FunctionSpace(
        mesh,
        TensorProductElement(FiniteElement("CG", "triangle", deg_), FiniteElement("DG", "interval", time_deg-1))
    )
    Z = MixedFunctionSpace(
        [CG(deg)]*Nspec + [CG(vdeg)]*2 + [CG(deg)]
      + [DG(deg)]*Nspec + [DG(deg)] + [DG(deg)] + [DG(vdeg)]*2
    )

    # Solution functions
    z = Function(Z, name="state")
    z_split = split(z)
    rho = z_split[0:Nspec]; u = as_vector(z_split[Nspec:Nspec+2]); rho_s = z_split[Nspec+2]; mu = z_split[(Nspec+3):(2*Nspec+3)]; p = z_split[2*Nspec+3]; theta = z_split[2*Nspec+4]; m = as_vector(z_split[2*Nspec+5:2*Nspec+7])
    z_out = z.subfunctions
    rho_out = z_out[0:Nspec]; u_out = as_vector(z_out[Nspec:Nspec+2]); rho_s_out = z_out[Nspec+2]; mu_out = z_out[(Nspec+3):(2*Nspec+3)]; p_out = z_out[2*Nspec+3]; theta_out = z_out[2*Nspec+4]; m_out = as_vector(z_out[2*Nspec+5:2*Nspec+7])

    # Test functions
    tests = TestFunctions(Z)
    psi = [Dt_(test) for test in tests[0:Nspec]]; v = Dt_(as_vector(tests[Nspec:Nspec+2])); omega = Dt_(tests[Nspec+2]); zeta = tests[(Nspec+3):(2*Nspec+3)]; q = tests[2*Nspec+3]; gamma = tests[2*Nspec+4]; w = as_vector(tests[2*Nspec+5:2*Nspec+7])

    # Helpers for rho
    rho_tot = sum(rho)
    sqrt_rho = sqrt(rho_tot)

    # Free energy density (simple ideal mixture without gradient terms)
    rho_F = sum([rho[i] * ln(rho[i]/rho_tot) for i in range(Nspec)])

    # Total energy density (per volume)
    rho_e = rho_tot * exp((rho_s + rho_F)/rho_tot)
    rho_e_tot = 0.5 * rho_tot * inner(u, u) + rho_e

    # Mobility M_ij and related fluxes
    M_ij = lambda i, j: (0.1*rho[i] if i == j else 0.0) - 0.1 * rho[i] * rho[j] / rho_tot
    grad_mu_over_theta = lambda j: grad_(mu[j] / theta)

    # Skew-symmetric convection form C(rho u, v, w)
    C_skw = lambda rho_u, v_in, w_in: 0.5 * (
        inner(dot(grad_(v_in), rho_u), w_in)
      - inner(dot(grad_(w_in), rho_u), v_in)
    )

    # Symmetric gradient
    sym_grad_ = lambda ufl_expr: sym(grad_(ufl_expr))

    # Residual
    F = 0
    for i in range(Nspec):  # Mass (for each species)
        diff_flux_i = sum(M_ij(i, j) * grad_mu_over_theta(j) for j in range(Nspec))
        F += (
            inner(Dt_(rho[i]), psi[i])
          + inner( - rho[i] * u + diff_flux_i, grad_(psi[i]))
        ) * dx
    for i in range(Nspec):  # Chemical potential
        F += (
            inner(mu[i] - diff(rho_e, variable(rho[i])) - V_i[i] * p, zeta[i])
        ) * dx
    rhou = rho_tot * u  # Momentum
    F += (
        inner(sqrt_rho * Dt_(m), v)
      + C_skw(rhou, u, v)
      + inner(2.0 * nu_c * sym_grad_(u), sym_grad_(v))
      - inner(p, div_(v))
      + sum([
            inner(rho[i] * grad_(mu[i] - V_i[i] * p), v)
        for i in range(Nspec)])
      + inner(rho_s * grad_(theta), v)
    ) * dx
    F += (  # Auxiliary momentum-like thing
        (inner(m, w) - inner(sqrt_rho * u, w))
    ) * dx
    F += (  # Pseudo-incompressibility
        div_(u) * q
      + sum([sum([
            V_i[i] * inner(M_ij(i, j) * grad_mu_over_theta(j), grad_(q))
        for j in range(Nspec)]) for i in range(Nspec)])
    ) * dx
    inv_theta = 1.0 / theta  # Entropy
    F += (
        inner(Dt_(rho_s), omega)
      - inner(rho_s * u, grad_(omega))
      - inner(2.0 * nu_c * inner(sym_grad_(u), sym_grad_(u)) * inv_theta, omega)
      - inner(K_c * grad_(inv_theta), grad_(inv_theta * omega))
      - sum([sum([
            inner(M_ij(i, j) * grad_mu_over_theta(j), grad_(mu[i] * inv_theta * omega))
        for j in range(Nspec)]) for i in range(Nspec)])
    ) * dx
    F += (  # Temperature
        inner(theta - diff(rho_e, variable(rho_s)), gamma)
    ) * dx

    # Solver parameters
    sp = {
        # Example linear solver settings (tune as needed)
        "snes_monitor"          : None,
        "snes_converged_reason" : None,
        "ksp_monitor"           : None,
        "ksp_converged_reason"  : None,
    }

    # Initial guesses  (To help the solver)
    theta_out.interpolate(theta_ic)

    # Initial(/boundary) conditions (Idk just trying this out)
    Z_ = MixedFunctionSpace([CG(deg)]*Nspec + [CG(vdeg)]*2 + [CG(deg)])
    z_ = Function(Z_)
    z_out_ = z_.subfunctions
    rho_out_ = z_out[0:Nspec]; u_out_ = as_vector(z_out[Nspec:Nspec+2]); rho_s_out_ = z_out[Nspec+2]

    rho_out[0].interpolate(rho_ic(x, y))
    rho_out_[0].interpolate(rho_ic(x, y))
    rho_out[1].interpolate(1.0/V_i[1]*(1-V_i[0]*rho_ic(x, y)))
    rho_out_[1].interpolate(1.0/V_i[1]*(1-V_i[0]*rho_ic(x, y)))
    rho_tot_out = sum(rho_out)
    rho_s_out.interpolate(rho_tot_out * ln(theta_ic) - sum([rho_out[i] * ln(rho_out[i]/rho_tot_out) for i in range(Nspec)]))
    rho_s_out_.interpolate(rho_tot_out * ln(theta_ic) - sum([rho_out[i] * ln(rho_out[i]/rho_tot_out) for i in range(Nspec)]))

    bcs = [
        DirichletBC(Z.sub(i),       rho_out_[i], "bottom") for i in range(Nspec)
    ] + [
        DirichletBC(Z.sub(Nspec),   u_out_[0],   "bottom"),
        DirichletBC(Z.sub(Nspec+1), u_out_[1],   "bottom"),
        DirichletBC(Z.sub(Nspec+2), rho_s_out_,  "bottom")
    ]

    # Time loop
    t = Constant(0)
    for _ in range(Nt):
        print(RED % f"Solving for time t = {float(t) + float(dt_c)}...")
        solve(F==0, z, bcs=bcs, solver_parameters=sp)
        for i in range(Nspec): rho_out_[i].dat.data == np.repeat(rho_out[i].dat.data[time_deg::time_deg+1], time_deg+1)
        for i in range(2): u_out_[i].dat.data == np.repeat(u_out[i].dat.data[time_deg::time_deg+1], time_deg+1)
        rho_s_out_.dat.data == np.repeat(rho_s_out.dat.data[time_deg::time_deg+1], time_deg+1)
        t.assign(float(t) + float(dt_c))

    # return {"time": t_arr, "energy": E_arr, "entropy": S_arr}

In [168]:
_ = stefan_maxwell_timeisspace(Nx=12)

[1;37;31mSolving for time t = 1e-09...[0m
  0 SNES Function norm 9.166666666667e-11
    Residual norms for petsctools_507_ solve.
    0 KSP Residual norm 9.166666666667e-11
    1 KSP Residual norm 3.682628451123e-12
    Linear petsctools_507_ solve converged due to CONVERGED_ITS iterations 1
  1 SNES Function norm 3.686393946053e+28
  Nonlinear petsctools_507_ solve did not converge due to DIVERGED_DTOL iterations 1


ConvergenceError: Nonlinear solve failed to converge after 1 nonlinear iterations.
Reason:
   DIVERGED_DTOL

In [None]:
    # # Set up outputs
    # E_form = rho_e_tot * dx
    # S_form = rho_s * dx
    # t_arr = []
    # E_arr = []
    # S_arr = []
    # if write_qois:
    #     qoi_path = out_path / "qois.csv"
    #     with qoi_path.open("w", encoding="utf-8") as f:
    #         f.write("time,energy,entropy\n")
    # def record_and_log():
    #     t_out = float(t)
    #     E_out = float(assemble(E_form))
    #     S_out = float(assemble(S_form))
    #     print(BLUE % f"Time (t) = {t_out:.6f}")
    #     print(GREEN % f"Energy  = {E_out:.8e}")
    #     print(GREEN % f"Entropy = {S_out:.8e}")
    #     t_arr.append(t_out)
    #     E_arr.append(E_out)
    #     S_arr.append(S_out)
    #     if write_qois:
    #         with (out_path / "qois.csv").open("a", encoding="utf-8") as f:
    #             f.write(f"{t_out},{E_out},{S_out}\n")
    # record_and_log()
    # if write_vtk:
    #     vtk = VTKFile(str(out_path / "u.pvd"))
    #     u_out.rename("Barycentric velocity (u)")
    #     vtk.write(rho_s_out, time=float(t))

    # # Time loop
    # for _ in range(Nt):
    #     stepper.advance()
    #     t.assign(float(t) + float(dt_c))
    #     if write_vtk: vtk.write(u_out, time=float(t))
    #     record_and_log()

### Legacy FET

In [None]:
def stefan_maxwell_legacyfet(
    Nspec:      int = 2,
    Nx:         int = 24,
    deg:        int = 1,
    vdeg:       int = 2,
    time_deg:   int = 1,
    Nt:         int = 24,
    dt:         float = 1e-9,
    Kval:       float = 1.0e-2,
    nu:         float = 1.0e-3,
    output_dir: str = "output/stefan_maxwell/",
    write_qois: bool = True,
    write_vtk:  bool = True,
):
    """
    Energy- and entropy-preserving Stefan-Maxwell scheme.
    - Unknowns (time-continuous): (rho_i), u, (rho s)
    - Auxiliary (time-discontinuous): (mu_i), p, theta, m
    Returns
    - Dictionary: {"time": [...], "energy": [...], "entropy": [...]}
    """
    # Other parameters
    V_i = [Constant(0.8), Constant(0.2)]
    rho_ic = 1 + 0.2*sin(4*pi*x)*cos(2*pi*y)
    theta_ic = 1.1


    # Ensure output directory exists
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)


    # Convert parameters to UFL objects
    K_c = Constant(Kval)
    nu_c = Constant(nu)
    dt_c = Constant(dt)


    # Mesh and coordinates
    mesh = PeriodicUnitSquareMesh(Nx, Nx)
    x, y = SpatialCoordinate(mesh)


    # Spatial function spaces and functions
    S_ = FunctionSpace(mesh, "CG", deg)
    V_ = VectorFunctionSpace(mesh, "CG", vdeg)
    Z_ = MixedFunctionSpace(([S]*Nspec) + [V, S])

    z_ = Function(Z_, name="ICs/tracker")

    z_split_ = split(z_)
    rho_   = z_split_[0:Nspec]
    u_     = z_split_[Nspec]
    rho_s_ = z_split_[Nspec+1]

    z_out = z_.subfunctions
    rho_out   = z_out[0:Nspec]
    u_out     = z_out[Nspec]
    rho_s_out = z_out[Nspec+1]

    if Nspec != 2: raise ValueError("ICs currently only set up for 2 species")
    rho_out[0].interpolate(rho_ic)
    rho_out[1].interpolate(1.0/V_i[1]*(1-V_i[0]*rho_ic))
    rho_tot_ic = sum(rho_out)
    rho_s_out.interpolate(rho_tot_ic*ln(theta_ic) - sum([rho_out[i]*ln(rho_out[i]/rho_tot_ic) for i in range(Nspec)]))


    # Space–time function spaces and functions
    S = cheb_fet.FETFunctionSpace(mesh, "CG", deg, order=time_deg)
    V = cheb_fet.FETVectorFunctionSpace(mesh, "CG", vdeg, order=time_deg)
    R = cheb_fet.FETFunctionSpace(mesh, "R", 0, order=time_deg)
    Z = cheb_fet.FETMixedFunctionSpace(([S]*Nspec) + [V, S] + ([S]*Nspec) + [S, S, V, R])  # Mixed space: (rho_1...rho_N, u, rho_s, mu_1...mu_N, p, theta, m)

    # z = Function(Z, name="state")
    # HOTFIX: Make sure theta initial guess is positive
    z = project(
        as_vector([0 for _ in range((2*Nspec+3)*time_deg)] + [theta_ic] + [0 for _ in range(3*time_deg - 1)]),
        Z, name="state"
    )
    z_split = cheb_fet.FETsplit(z)
    rho_t   = z_split[0:Nspec]
    u_t     = z_split[Nspec]
    rho_s_t = z_split[Nspec+1]
    mu      = z_split[(Nspec+2):(2*Nspec+2)]
    p       = z_split[2*Nspec+2]
    theta   = z_split[2*Nspec+3]
    m       = z_split[2*Nspec+4]
    lam     = z_split[2*Nspec+5]

    rho   = [cheb_fet.integrate(rho_t[i], rho_[i], dt) for i in range(Nspec)]
    u     = cheb_fet.integrate(u_t, u_, dt)
    rho_s = cheb_fet.integrate(rho_s_t, rho_s_, dt)

    tests = cheb_fet.FETsplit(TestFunction(Z))
    psi   = tests[0:Nspec]
    v     = tests[Nspec]
    omega = tests[Nspec+1]
    zeta  = tests[(Nspec+2):(2*Nspec+2)]
    q     = tests[2*Nspec+2]
    gamma = tests[2*Nspec+3]
    w     = tests[2*Nspec+4]; sigma = tests[2*Nspec+5]


    # Helpers for rho
    rho_tot = [sum([rho[j][i] for j in range(Nspec)]) for i in range(time_deg)]
    # sqrt_rho  = sqrt(rho_tot)
    # rho_F     = rho1*ln(rho1/rho_tot) + rho2*ln(rho2/rho_tot)
    # rho_e     = rho_tot * exp((rho_s + rho_F)/rho_tot)
    # rho_e_tot = 0.5 * rho_tot * inner(u, u) + rho_e

    # def C_skw(rho_u, v_in, w_in):
    #     return 0.5 * (inner(dot(rho_u, nabla_grad(v_in)), w_in) - inner(dot(rho_u, nabla_grad(w_in)), v_in))


    # Residual
    F = 0

    # Mass for (for each species)
    M_ij = lambda rho__, rho_tot__, i, j: (0.1*rho__[i] if i == j else 0.0) - 0.1 * rho__[i] * rho__[j] / rho_tot__
    for i in range(Nspec):
        F = cheb_fet.residual(F, lambda rho_t_i__, psi_i__:      inner(rho_t_i__, psi_i__) * dx,         (rho_t[i], psi[i]))
        F = cheb_fet.residual(F, lambda rho_t_i__, u__, psi_i__: - inner(rho_t_i__ * u__, psi_i__) * dx, (rho_t[i], u, psi[i]))
        diff_flux = lambda rho__, rho_tot__, mu__, theta__: sum([M_ij(rho__, rho_tot__, i, j) * grad(mu__[j] / theta__) for j in range(Nspec)])  # a -> rho, b -> rho_tot, c -> mu, d -> theta
        F = cheb_fet.residual(F, lambda *args: inner(diff_flux(args[0:Nspec], args[Nspec], args[Nspec+1:2*Nspec+1], args[Nspec]), grad(args[Nspec+1])) * dx, (*rho, rho_tot, *mu, theta, psi[i]), poly=False)

    # Chemical potential
    rho_F = lambda rho__, rho_tot__: sum([rho_i__ * ln(rho_i__/rho_tot__) for rho_i__ in rho__])
    rho_e = lambda rho__, rho_tot__, rho_s__: rho_tot__ * exp((rho_s__ + rho_F(rho__, rho_tot__))/rho_tot__)
    for i in range(Nspec):
        F = cheb_fet.residual(F, lambda mu_i__, zeta_i__: inner(mu_i__, zeta_i__) * dx,         (mu[i], zeta[i]))
        F = cheb_fet.residual(F, lambda p__, zeta_i__:    - inner(V_i[i] * p__, zeta_i__) * dx, (p, zeta[i]))



    F = cheb_fet.residual(F, lambda a, b: inner(a, b) * dx, (rho1_t, psi1))
    F = cheb_fet.residual(F, lambda a, b, c: - inner(a * b, grad(c)) * dx, (rho1, u, psi1))
    # Diffusion (diagonal + off-diagonal)
    F = cheb_fet.residual(F, lambda a, b, c: inner(0.1 * a * grad(b/c), grad(c)) * dx, (rho1, mu1, theta, psi1))
    F = cheb_fet.residual(F, lambda a, b, c, d, e: inner(-0.1 * a * b / c * grad(d/e), grad(e)) * dx, (rho1, rho2, rho_tot, mu2, theta, psi1))

    # Mass for species 2
    F = cheb_fet.residual(F, lambda a, b: inner(a, b) * dx, (rho2_t, psi2))
    F = cheb_fet.residual(F, lambda a, b, c: - inner(a * b, grad(c)) * dx, (rho2, u, psi2))
    F = cheb_fet.residual(F, lambda a, b, c: inner(0.1 * a * grad(b/c), grad(c)) * dx, (rho2, mu2, theta, psi2))
    F = cheb_fet.residual(F, lambda a, b, c, d, e: inner(-0.1 * a * b / c * grad(d/e), grad(e)) * dx, (rho2, rho1, rho_tot, mu1, theta, psi2))

    # Chemical potential relations (mu_i - d rho_e / d rho_i - V_i[i] * p) = 0
    # Use quadrature to capture time dependence of rho_e dependence on (rho1, rho2, rho_s)
    F = cheb_fet.residual(
        F,
        lambda r1, r2, rs, mu_i, p_, zeta_i: (mu_i - diff((r1+r2) * exp((rs + r1*ln(r1/(r1+r2)) + r2*ln(r2/(r1+r2)))/(r1+r2)), variable(r1)) - V_i[0]*p_) * zeta_i * dx,
        (rho1, rho2, rho_s, mu1, p, zeta1),
        poly=False
    )
    F = cheb_fet.residual(
        F,
        lambda r1, r2, rs, mu_i, p_, zeta_i: (mu_i - diff((r1+r2) * exp((rs + r1*ln(r1/(r1+r2)) + r2*ln(r2/(r1+r2)))/(r1+r2)), variable(r2)) - V_i[1]*p_) * zeta_i * dx,
        (rho1, rho2, rho_s, mu2, p, zeta2),
        poly=False
    )

    # Momentum
    Du = sym(grad(u))
    rhou = rho_tot * u
    F = cheb_fet.residual(F, lambda a, b, c: inner(a * b, c) * dx, (sqrt_rho, m_t, v))
    F = cheb_fet.residual(F, lambda a, b, c: C_skw(a, b, c) * dx, (rhou, u, v), poly=False)
    F = cheb_fet.residual(F, lambda a, b: 2.0 * nu_c * inner(sym(grad(a)), sym(grad(b))) * dx, (u, v))
    F = cheb_fet.residual(F, lambda a, b: - inner(a, div(b)) * dx, (p, v))
    F = cheb_fet.residual(F, lambda a, b, c, d: inner(a * grad(b - V_i[0] * c), d) * dx, (rho1, mu1, p, v))
    F = cheb_fet.residual(F, lambda a, b, c, d: inner(a * grad(b - V_i[1] * c), d) * dx, (rho2, mu2, p, v))
    F = cheb_fet.residual(F, lambda a, b, c: inner(a * grad(b), c) * dx, (rho_s, theta, v))

    # Auxiliary momentum-like constraint: m = sqrt_rho * u
    F = cheb_fet.residual(F, lambda a, b, c: inner(a, b) * dx - inner(c * u, b) * dx, (m, w, sqrt_rho))

    # Pseudo-incompressibility and pressure nullspace
    F = cheb_fet.residual(F, lambda a: div(a) * q * dx, (u,))
    # Mobility contribution to divergence constraint
    F = cheb_fet.residual(F, lambda a, b, c, d: V_i[0] * inner(0.1*a * grad(b/d) - 0.1*a*c/(a+c) * grad(b/d), grad(q)) * dx, (rho1, mu1, rho2, theta))
    F = cheb_fet.residual(F, lambda a, b, c, d: V_i[1] * inner(0.1*a * grad(b/d) - 0.1*a*c/(a+c) * grad(b/d), grad(q)) * dx, (rho2, mu2, rho1, theta))
    F = cheb_fet.residual(F, lambda a, b: inner(a, b) * dx, (p, sigma))

    # Entropy equation
    inv_theta = 1.0/theta
    F = cheb_fet.residual(F, lambda a, b: inner(a, b) * dx, (rho_s_t, omega))
    F = cheb_fet.residual(F, lambda a, b, c: - inner(a * b, grad(c)) * dx, (rho_s, u, omega))
    F = cheb_fet.residual(F, lambda a, b, c: - 2.0 * nu_c * inner(sym(grad(a)), grad(a)) * (b) * c * dx, (u, inv_theta, omega), poly=False)
    F = cheb_fet.residual(F, lambda a, b, c: - K_c * inner(grad(a), grad(b*c)) * dx, (inv_theta, omega, inv_theta))
    # Cross-diffusion entropy production terms (diagonal + off-diagonal)
    F = cheb_fet.residual(F, lambda a, b, c, d: - inner(grad(a/b), grad((c * d) / b)) * 0.1 * a * dx, (mu1, theta, omega, mu1))
    F = cheb_fet.residual(F, lambda a, b, c, d, e: - inner(grad(d/b), grad((c * e) / b)) * (-0.1 * a * d/(a+d)) * dx, (rho1, theta, omega, rho2, mu2))
    F = cheb_fet.residual(F, lambda a, b, c, d: - inner(grad(a/b), grad((c * d) / b)) * 0.1 * a * dx, (mu2, theta, omega, mu2))
    F = cheb_fet.residual(F, lambda a, b, c, d, e: - inner(grad(d/b), grad((c * e) / b)) * (-0.1 * a * d/(a+d)) * dx, (rho2, theta, omega, rho1, mu1))

    # Temperature relation: theta = d rho_e / d rho_s
    F = cheb_fet.residual(
        F,
        lambda r1, r2, rs, th, gam: (th - diff((r1+r2) * exp((rs + r1*ln(r1/(r1+r2)) + r2*ln(r2/(r1+r2)))/(r1+r2)), variable(rs))) * gam * dx,
        (rho1, rho2, rho_s, theta, gamma),
        poly=False
    )

    # Solver parameters
    sp = {
        "snes_atol": 1e-12,
        "snes_rtol": 1e-12,
        "snes_converged_reason": None,
        "snes_monitor": None,
        "ksp_type": "preonly",
        "pc_type": "lu",
        "pc_factor_mat_solver_type": "mumps",
        "ksp_monitor_true_residual": None,
    }

    # Output setup
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    E_txt = out_path / "energy.txt"
    S_txt = out_path / "entropy.txt"
    if write_vtk:
        pvd = VTKFile(str(out_path / "fields.pvd"))
        rho1_sub = Function(Sx, name="rho1")
        rho2_sub = Function(Sx, name="rho2")
        rho_s_sub = Function(Sx, name="rho_s")
        u_sub = Function(Vx, name="u")

    # Initial QoIs
    E_form = rho_e_tot * dx
    S_form = rho_s * dx

    def record_qois(tval):
        E = assemble(E_form)
        S = assemble(S_form)
        print(GREEN % f"Energy: {E}")
        print(GREEN % f"Entropy: {S}")
        with open(E_txt, "a") as f: f.write(str(float(E)) + "\n")
        with open(S_txt, "a") as f: f.write(str(float(S)) + "\n")

    record_qois(0.0)

    # Time loop
    t = Constant(0.0)
    for _ in range(Nt):
        print(BLUE % f"Solving t = {float(t) + float(dt_c)}")
        solve(F == 0, z, solver_parameters=sp)

        # Update trackers to end of slab
        rho1_0.assign(cheb_fet.FETeval((rho1_0, None), (z, 0), dt_c, dt_c))
        rho2_0.assign(cheb_fet.FETeval((rho2_0, None), (z, 1), dt_c, dt_c))
        u_0.assign(cheb_fet.FETeval((u_0, None), (z, 2), dt_c, dt_c))
        rho_s_0.assign(cheb_fet.FETeval((rho_s_0, None), (z, 3), dt_c, dt_c))
        m_0.assign(cheb_fet.FETeval((m_0, None), (z, 8), dt_c, dt_c))

        # Update integrated forms for QoIs with new trackers
        rho1 = cheb_fet.integrate(rho1_t, rho1_0, dt_c)
        rho2 = cheb_fet.integrate(rho2_t, rho2_0, dt_c)
        rho_s = cheb_fet.integrate(rho_s_t, rho_s_0, dt_c)
        u    = cheb_fet.integrate(u_t,    u_0,    dt_c)
        rho_tot   = rho1 + rho2
        rho_F     = rho1*ln(rho1/rho_tot) + rho2*ln(rho2/rho_tot)
        rho_e     = rho_tot * exp((rho_s + rho_F)/rho_tot)
        rho_e_tot = 0.5 * rho_tot * inner(u, u) + rho_e
        E_form = rho_e_tot * dx
        S_form = rho_s * dx

        if write_vtk:
            rho1_sub.assign(rho1_0)
            rho2_sub.assign(rho2_0)
            rho_s_sub.assign(rho_s_0)
            u_sub.assign(u_0)
            pvd.write(rho1_sub, rho2_sub, rho_s_sub, u_sub)

        record_qois(float(t) + float(dt_c))
        t.assign(float(t) + float(dt_c))

    return {"time": float(t), "energy": float(assemble(E_form)), "entropy": float(assemble(S_form))}