In [49]:

# ============  Sizing & indexing layer  ============

from dataclasses import dataclass

@dataclass
class ModelSize:
    Nn: int                 # negative solid CVs
    Np: int                 # positive solid CVs
    Ne_n: int = 1           # electrolyte CVs in negative region
    Ne_s: int = 1           # electrolyte CVs in separator
    Ne_p: int = 1           # electrolyte CVs in positive region

    @property
    def Ne_total(self): return self.Ne_n + self.Ne_s + self.Ne_p
    @property
    def Nx(self): return self.Nn + self.Np + self.Ne_total

def make_index_map(sz: ModelSize):
    # contiguous layout: [ c_n(1..Nn), c_p(1..Np), ce(1..Ne_total) ]
    cn = slice(0, sz.Nn)
    cp = slice(sz.Nn, sz.Nn + sz.Np)
    ce = slice(sz.Nn + sz.Np, sz.Nn + sz.Np + sz.Ne_total)

    # "surface" convention: separator-side solids
    # - negative electrode: separator-side is the "right" end of negative domain
    cn_surf = cn.stop - 1 if sz.Nn > 0 else None
    # - positive electrode: separator-side is the "left" end of positive domain
    cp_surf = cp.start if sz.Np > 0 else None

    # electrolyte left/right taps (paper uses ends of the chain)
    ce_left  = ce.start                       # first electrolyte node (in neg region)
    ce_right = ce.stop - 1                    # last electrolyte node (in pos region)

    IDX = dict(cn=cn, cp=cp, ce=ce,
               cn_surf=cn_surf, cp_surf=cp_surf,
               ce_left=ce_left, ce_right=ce_right)
    return IDX

def make_state_names(sz: ModelSize):
    names = []
    names += [f"cn{i+1}" for i in range(sz.Nn)]
    names += [f"cp{i+1}" for i in range(sz.Np)]
    # three-region electrolyte labels
    names += [f"ce_n{i+1}" for i in range(sz.Ne_n)]
    names += [f"ce_s{i+1}" for i in range(sz.Ne_s)]
    names += [f"ce_p{i+1}" for i in range(sz.Ne_p)]
    return names
@dataclass
class PhysParams:
    R: float = 8.314462618
    F: float = 96485.33212
    T: float = 298.15
    t_plus: float = 0.38
    kf: float = 1.0
    Re: float = 0.0
    Rf: float = 0.0
    csn_max: float = 3.1e4
    csp_max: float = 5.1e4

PP = PhysParams()  # make an instance you can still use as default


# ============  Solid blocks (generic Nn, Np)  ============

import numpy as np

def build_solid_A(N: int, D: float, R: float):
    """
    Stable 1D radial diffusion surrogate for single-particle solid with Neumann at r=0.
    For N=2 and N=4, reproduce your exact FVM stencils.
    For other N, fallback to a standard second-difference (scaled) with Neumann at both ends,
    then we'll feed flux via the B vector on the surface node.
    """
    s = D/(R**2)

    # Exact stencils you used
    if N == 2:
        A = np.array([[-8,  8],
                      [ 8, -8]], dtype=float)*s
        return A

    if N == 4:
        A = np.zeros((4,4), float)
        A[0,0], A[0,1] = -24,  24
        A[1,0], A[1,1], A[1,2] =  16, -40, 24
        A[2,1], A[2,2], A[2,3] =  16, -40, 24
        A[3,2], A[3,3] =  16, -16
        return A*s

    # Generic fallback (stable): simple Neumann-Neumann chain
    A = np.zeros((N,N), float)
    # Left Neumann
    A[0,0] = -1; A[0,1] = +1
    # Interior
    for i in range(1, N-1):
        A[i,i-1] = +1; A[i,i] = -2; A[i,i+1] = +1
    # Right Neumann (will also get flux via B)
    if N > 1:
        A[N-1,N-2] = +1; A[N-1,N-1] = -1
    return A * (8*s)  # scale to be comparable magnitude to your 2-node case

def build_solid_B(N: int, D: float, R: float, a_s: float, A_cs: float, L_region: float,
                  pp: PhysParams):
    """
    Input vector mapping pack current to surface flux for the solid.
    For N=2: matches your 4/(R*F*a*D*A*L) factor on the last node.
    For N=4: matches your 6/(...) legacy (if you use 4-node), but we’ll keep 4/(...) for consistency
    with your current 2-node form; you can flip to 6/(...) if needed.
    """
    b = np.zeros((N,1), float)
    # You used 4 for the 2-node stencil. Keep that here as the default robust choice.
    coef = 4.0/(R * pp.F * a_s * D * A_cs * L_region)
    b[-1,0] = coef
    return b


# ============  Electrolyte block (generic)  ============

def build_Ae_paper3(De: float, eps: float, Ln: float, Ls: float, Lp: float):
    """Exactly the paper’s 3-CV chain: ce_n1 -- ce_s1 -- ce_p1 (stable signs)."""
    K = De/eps
    w_ns = K * 4.0 / (Ln + Ls)**2
    w_sp = K * 4.0 / (Lp + Ls)**2
    Ae = np.array([[w_ns,  -w_ns,    0.0 ],
                   [ -w_ns, (w_ns+w_sp), -w_sp],
                   [  0.0 ,  -w_sp,   w_sp]], dtype=float)
    return Ae

def build_Be_paper3(Ln: float, Ls: float, Lp: float, eps: float, A_cs: float,
                    pp: PhysParams):
    b = np.zeros((3,1), float)
    s_n = (1.0 - pp.t_plus) / (0.5 * pp.F * A_cs * (Ln + Ls) * eps)
    s_p = (1.0 - pp.t_plus) / (0.5 * pp.F * A_cs * (Lp + Ls) * eps)
    b[0,0] = s_n
    b[2,0] = s_p
    return b

def build_Ae_uniform(Ne_total: int, De: float, eps: float, Ltot: float):
    """
    Simple, stable uniform-grid electrolyte chain of arbitrary length.
    Not “the paper” but robust when you want >3 nodes.
    """
    K = De/eps
    h = Ltot / Ne_total
    w = K / (h*h)

    Ae = np.zeros((Ne_total, Ne_total), float)
    for i in range(Ne_total):
        if i > 0:
            Ae[i,i-1] = +w
        if i < Ne_total-1:
            Ae[i,i+1] = +w
        # diagonal
        deg = (i>0) + (i<Ne_total-1)
        Ae[i,i] = -deg*w
    return Ae

def build_Be_uniform(sz: ModelSize, Ln: float, Ls: float, Lp: float,
                     eps: float, A_cs: float, pp: PhysParams):
    b = np.zeros((sz.Ne_total,1), float)
    k0 = 0
    k1 = k0 + sz.Ne_n
    k2 = k1 + sz.Ne_s
    k3 = k2 + sz.Ne_p

    s_n = (1.0 - pp.t_plus) / (pp.F * A_cs * Ln * eps) if sz.Ne_n>0 else 0.0
    s_p = (1.0 - pp.t_plus) / (pp.F * A_cs * Lp * eps) if sz.Ne_p>0 else 0.0
    if sz.Ne_n>0: b[k0:k1,0] = s_n / sz.Ne_n
    if sz.Ne_p>0: b[k2:k3,0] = s_p / sz.Ne_p
    return b

# ============  Assembly (size-agnostic)  ============

from scipy.linalg import block_diag
import control as ct

def assemble_system(params, sz: ModelSize, pp: PhysParams = PP):
    Ln,Ls,Lp = params["L1"], params["L2"], params["L3"]
    A_ln,A_lp = params["A_ln"], params["A_lp"]
    a_s_n,a_s_p = params["a_s_n"], params["a_s_p"]
    Dn,Rn = params["Dn"], params["Rn"]
    Dp,Rp = params["Dp"], params["Rp"]
    De  = params.get("De", params.get("D_e", 7.23e-10))
    eps = params.get("eps", 0.3)
    A_e = params.get("A_e", 1.0)

    An = build_solid_A(sz.Nn, Dn, Rn)
    Ap = build_solid_A(sz.Np, Dp, Rp)
    Bn = build_solid_B(sz.Nn, Dn, Rn, a_s_n, A_ln, Ln, pp)
    Bp = build_solid_B(sz.Np, Dp, Rp, a_s_p, A_lp, Lp, pp)

    if (sz.Ne_n, sz.Ne_s, sz.Ne_p) == (1,1,1):
        Ae = build_Ae_paper3(De, eps, Ln, Ls, Lp)
        Be = build_Be_paper3(Ln, Ls, Lp, eps, A_e, pp)
    else:
        Ltot = Ln + Ls + Lp
        Ae = build_Ae_uniform(sz.Ne_total, De, eps, Ltot)
        Be = build_Be_uniform(sz, Ln, Ls, Lp, eps, A_e, pp)

    A = block_diag(An, Ap, Ae)
    B = np.vstack([Bn, Bp, Be])
    IDX = make_index_map(sz)
    STATE_NAMES = make_state_names(sz)
    C = np.eye(A.shape[0]); D = np.zeros((A.shape[0],1))
    S = ct.ss(A,B,C,D)
    return S, A, B, IDX, STATE_NAMES


# ============  Helpers that adapt to size  ============

def make_x0(sz: ModelSize, pp: PhysParams = PP, theta_n0=0.2, theta_p0=0.9, ce0=1000.0):
    x0 = np.zeros(sz.Nx, float)
    IDX = make_index_map(sz)
    x0[IDX["cn"]] = theta_n0 * pp.csn_max
    x0[IDX["cp"]] = theta_p0 * pp.csp_max
    x0[IDX["ce"]] = ce0
    return x0

def safe_log_ratio(a, b, eps=1e-9):
    return np.log(np.maximum(a, eps) / np.maximum(b, eps))

# placeholders — keep or replace with your fits
def ocp_p(theta: float) -> float: return 4.0 - 0.1*np.tanh(8*(theta-0.5))
def ocp_n(theta: float) -> float: return 0.1 + 0.8*np.tanh(8*(theta-0.5))
def eta_block(u: float, x: np.ndarray, kind="linear", Rct=0.0, alpha=0.5, I0=10.0):
    if kind=="none": return 0.0
    if kind=="linear": return Rct*u
    if kind=="bv": return (2*PP.R*PP.T/(alpha*PP.F))*np.arcsinh(u/(2*I0))
    raise ValueError("eta kind unknown")

def output_voltage(x: np.ndarray, u: float, IDX, pp: PhysParams = PP,
                   eta_kind="linear", Rct=0.0, I0=10.0):
    cn_surf = x[IDX["cn_surf"]]
    cp_surf = x[IDX["cp_surf"]]
    ce_left  = x[IDX["ce_left"]]
    ce_right = x[IDX["ce_right"]]

    theta_p = np.clip(cp_surf/pp.csp_max, 1e-6, 1-1e-6)
    theta_n = np.clip(cn_surf/pp.csn_max, 1e-6, 1-1e-6)

    U = ocp_p(theta_p) - ocp_n(theta_n)
    log_term = -(2*pp.R*pp.T*(1-pp.t_plus)*pp.kf/pp.F) * safe_log_ratio(ce_left, ce_right)
    eta = eta_block(u, x, kind=eta_kind, Rct=Rct, I0=I0)
    return float(U + log_term + eta + (pp.Re + pp.Rf)*u)


# ============  Convenience / sanity checks  ============

# ---------- Matrix viewer (dimension-agnostic) ----------
def show_matrices(A_=None, B_=None, IDX_=None, STATE_NAMES_=None):
    """
    Pretty-print An,Ap,Ae and Bn,Bp,Be plus global A,B using the current
    A,B,IDX,STATE_NAMES by default. You can also call with explicit args:
      show_matrices(A, B, IDX, STATE_NAMES)
    """
    import numpy as np, pandas as pd
    try:
        from IPython.display import display
    except Exception:
        display = print  # fallback

    # Allow zero-arg call by reading globals
    if A_ is None: A_ = globals().get('A')
    if B_ is None: B_ = globals().get('B')
    if IDX_ is None: IDX_ = globals().get('IDX')
    if STATE_NAMES_ is None: STATE_NAMES_ = globals().get('STATE_NAMES')

    if any(z is None for z in (A_, B_, IDX_, STATE_NAMES_)):
        raise ValueError("Provide A,B,IDX,STATE_NAMES or define them in globals before calling show_matrices().")

    # Build index arrays from slices
    def idx_arr(slc): 
        return np.arange(slc.start, slc.stop) if isinstance(slc, slice) else np.atleast_1d(slc)

    cn_idx = idx_arr(IDX_["cn"])
    cp_idx = idx_arr(IDX_["cp"])
    ce_idx = idx_arr(IDX_["ce"])

    # Names directly from your STATE_NAMES (no assumptions about ce_ prefix)
    cn_names = [STATE_NAMES_[i] for i in cn_idx]
    cp_names = [STATE_NAMES_[i] for i in cp_idx]
    ce_names = [STATE_NAMES_[i] for i in ce_idx]

    # Extract sub-blocks
    An = A_[np.ix_(cn_idx, cn_idx)]
    Ap = A_[np.ix_(cp_idx, cp_idx)]
    Ae = A_[np.ix_(ce_idx, ce_idx)]

    Bn = B_[cn_idx, :]
    Bp = B_[cp_idx, :]
    Be = B_[ce_idx, :]

    # Helper to frame/print
    def df(M, rnames, cnames):
        return pd.DataFrame(M, index=rnames, columns=cnames)

    # Show blocks
    display(df(An, cn_names, cn_names)); display(df(Bn, cn_names, ['u']))
    display(df(Ap, cp_names, cp_names)); display(df(Bp, cp_names, ['u']))
    display(df(Ae, ce_names, ce_names)); display(df(Be, ce_names, ['u']))

    # Show global
    display(df(A_, STATE_NAMES_, STATE_NAMES_))
    display(df(B_, STATE_NAMES_, ['u']))
    print("A shape:", A_.shape, "  B shape:", B_.shape)


def simulate(S, A, B, I_of_t, t, x0, IDX, eta_kind="linear", Rct=0.0, I0=10.0, pp: PhysParams = PP):
    U = I_of_t(t)
    tout, y_states, x_states = ct.forced_response(S, T=t, U=U, X0=x0, return_x=True)
    V = np.array([output_voltage(x_states[:,i], U[i], IDX, pp=pp, eta_kind=eta_kind, Rct=Rct, I0=I0)
                  for i in range(x_states.shape[1])])
    return tout, x_states, V

# pick your model size
sz = ModelSize(Nn=2, Np=2, Ne_n=1, Ne_s=1, Ne_p=1)  # 7-state paper model

params = dict(
    Dn=1e-14, Rn=5e-6, Dp=1e-14, Rp=5e-6,
    a_s_n=1.0e6, a_s_p=1.0e6, A_ln=1.0, A_lp=1.0,
    L1=25e-6, L2=20e-6, L3=25e-6,
    De=7.23e-10, eps=0.30, A_e=1.0,
)

S, A, B, IDX, STATE_NAMES = assemble_system(params, sz, pp=PP)
x0 = make_x0(sz, pp=PP, theta_n0=0.2, theta_p0=0.9, ce0=1000.0)
t = np.linspace(0, 2000, 1001)
I_of_t = lambda tt: np.ones_like(tt)*3.2
tout, X, V = simulate(S, A, B, I_of_t, t, x0, IDX, eta_kind="linear", Rct=0.002, pp=PP)
print("A shape:", A.shape, "B shape:", B.shape)

evals = np.linalg.eigvals(A)
print("max Re(λ(A)) =", np.max(np.real(evals)))  # should be <= ~0


STATE_NAMES = ['cn1','cn2','cp1','cp2','ce1','ce2','ce3']


show_matrices()


A shape: (7, 7) B shape: (7, 1)
max Re(λ(A)) = 14.281481481481489


  xout[:, i] = (Ad @ xout[:, i-1]
  xout[:, i] = (Ad @ xout[:, i-1]
  yout = C @ xout + D @ U
  return np.log(np.maximum(a, eps) / np.maximum(b, eps))


Unnamed: 0,cn1,cn2
cn1,-0.0032,0.0032
cn2,0.0032,-0.0032


Unnamed: 0,u
cn1,0.0
cn2,33165660000000.0


Unnamed: 0,cp1,cp2
cp1,-0.0032,0.0032
cp2,0.0032,-0.0032


Unnamed: 0,u
cp1,0.0
cp2,33165660000000.0


Unnamed: 0,ce1,ce2,ce3
ce1,4.760494,-4.760494,0.0
ce2,-4.760494,9.520988,-4.760494
ce3,0.0,-4.760494,4.760494


Unnamed: 0,u
ce1,0.951977
ce2,0.0
ce3,0.951977


Unnamed: 0,cn1,cn2,cp1,cp2,ce1,ce2,ce3
cn1,-0.0032,0.0032,0.0,0.0,0.0,0.0,0.0
cn2,0.0032,-0.0032,0.0,0.0,0.0,0.0,0.0
cp1,0.0,0.0,-0.0032,0.0032,0.0,0.0,0.0
cp2,0.0,0.0,0.0032,-0.0032,0.0,0.0,0.0
ce1,0.0,0.0,0.0,0.0,4.760494,-4.760494,0.0
ce2,0.0,0.0,0.0,0.0,-4.760494,9.520988,-4.760494
ce3,0.0,0.0,0.0,0.0,0.0,-4.760494,4.760494


Unnamed: 0,u
cn1,0.0
cn2,33165660000000.0
cp1,0.0
cp2,33165660000000.0
ce1,0.9519774
ce2,0.0
ce3,0.9519774


A shape: (7, 7)   B shape: (7, 1)


In [None]:
# sz = ModelSize(Nn=4, Np=4, Ne_n=2, Ne_s=2, Ne_p=2)  # 14 total
# # same params; assemble_system picks "uniform" electrolyte automatically
# S, A, B, IDX, STATE_NAMES = assemble_system(params, sz)


In [None]:
# sz = ModelSize(Nn=4, Np=4, Ne_n=3, Ne_s=3, Ne_p=3)  # 17 total
# S, A, B, IDX, STATE_NAMES = assemble_system(params, sz)
