In [1]:
import numpy as np
import scipy.interpolate
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import skfem
from skfem import *
from skfem.helpers import dot, grad
import math
from numba import njit 

In [None]:
# =============================================================================
# --- Kir4.1 Channel Model (Mandge et al., 2019) ---
# =============================================================================

@njit
def m_inf_Kir41(Vm):
    """Return the Kir4.1 steady-state activation m_inf(Vm).

    Boltzmann curve from Mandge et al. (2019), Eq. 4, with half-activation
    -77.13 mV and slope 22.166 mV; Vm is in mV.
    """
    z = (Vm + 77.13) / 22.166
    return 1.0 / (1.0 + np.exp(z))

@njit
def h_inf_Kir41(Vm):
    """Return the Kir4.1 steady-state inactivation h_inf(Vm).

    From Mandge et al. (2019), Eq. 5 (Vh = -71.96 mV, k = 49.667 mV). The
    paper writes both the offset and the slope with a minus sign; the two
    sign flips cancel, so h_inf rises with depolarisation.
    """
    z = (-71.96 - Vm) / -49.667
    return 1.0 / (1.0 + np.exp(z))

@njit
def state_vars_deriv_Kir41(Vm, m, h, tau_m=1.0, tau_h=10.0):
    """
    Return (dm/dt, dh/dt) for the Kir4.1 gates (Mandge et al., 2019, Eqs. 6-7).

    Each gate relaxes linearly toward its steady state:
    dx/dt = (x_inf(Vm) - x) / tau_x.

    NOTE: tau_m and tau_h are voltage-independent constants because the paper
    does not state voltage-dependent time constants. They may need tuning or
    sourcing from the paper's cited works [8, 19, 20]. Activation (m) is
    typically fast and inactivation (h) slower, hence the defaults.
    """
    dm_dt = (m_inf_Kir41(Vm) - m) / tau_m
    dh_dt = (h_inf_Kir41(Vm) - h) / tau_h
    return dm_dt, dh_dt

@njit
def calculate_IKir_Mandge(Vm, m, h, E_K, K_o, g_bar_Kir=0.0002, a=0.5):
    """
    Kir4.1 current density from the Mandge et al. (2019) HH-type model (Eqs. 2-3).

    Conductance: g = g_bar_Kir * m * (a*h + (1 - a)), i.e. only a fraction
    `a` of the conductance is subject to inactivation.
    Current:     I = g * (Vm - E_K) * sqrt(K_o).

    Args:
        Vm (float): Membrane potential (mV).
        m (float): Activation gate.
        h (float): Inactivation gate.
        E_K (float): K+ Nernst potential (mV).
        K_o (float): Extracellular K+ (mM).
        g_bar_Kir (float): Maximal conductance (mS/cm^2); main tuning knob.
        a (float): Partial-inactivation fraction (dimensionless).

    Returns:
        float: Kir4.1 current density (mA/cm^2).

    Note: the sqrt(K_o) factor is used exactly as in the paper, without
    normalising to a reference K_o. If g_bar_Kir was fitted at a specific
    baseline K_o, a 1/sqrt(K_o_base) factor may be needed.
    """
    inactivation_factor = a * h + (1.0 - a)
    conductance = g_bar_Kir * m * inactivation_factor
    driving_force = Vm - E_K
    return conductance * driving_force * np.sqrt(K_o)
@njit
def calculate_Kir_direct_V_A(Vm, E_K, K_o, Va1=0, Va2=-15, Va3=7, g=2.083e-8):
    """
    Whole-cell Kir current in the form of Ostby et al. (2009), S1 Text, Eq. 3.

        I = g * (Vm - E_K - Va1) * sqrt(K_o) / (1 + exp((Vm - E_K - Va2) / Va3))

    Vm, E_K, Va1, Va2 and Va3 are in mV and g is in S, so the result is in mA
    (a total current, not a density).

    According to Mandge et al. (2019), resting currents are dominated by
    Kir4.1. The original notes therefore derive g as 1 / R_input with an
    input resistance of about 48.9 MOhm.

    Alternative parameter set noted by the author: Va1 = -14.83, Va2 = 34,
    Va3 = 19.23.
    """
    rectification = np.sqrt(K_o) / (1 + np.exp((Vm - E_K - Va2) / Va3))
    return g * (Vm - E_K - Va1) * rectification

# --- Delayed Rectifier K+ Channel (KDR) ---
# Kinetics from Mandge et al. (2018) Main Text, Eqs 19-20
@njit
def n_inf_KDR(Vm):
    """Steady-state activation n_inf(Vm) of the KDR channel (Mandge et al. 2018, Eq. 19)."""
    exponent = (-35.0 - Vm) / 15.4
    return 1.0 / (1.0 + np.exp(exponent))
@njit
def tau_n_KDR(Vm):
    """KDR activation time constant in ms (Mandge et al. 2018, Eq. 19).

    Gaussian bump of height 29.74 ms centred at -20 mV on top of a 7.14 ms floor.
    """
    x = (Vm + 20.0) / 17.08
    tau = 7.14 + 29.74 * np.exp(-2.0 * x**2)
    # Defensive floor; with these constants tau never drops below 7.14 ms.
    return max(tau, 1e-9)
@njit
def state_vars_deriv_KDR(Vm, n):
    """Return dn/dt = (n_inf(Vm) - n) / tau_n(Vm) for the KDR activation gate."""
    return (n_inf_KDR(Vm) - n) / tau_n_KDR(Vm)
@njit
def calculate_IKDR(Vm, n, EK, g_max_KDR=0.002688):  # previously 0.00072
    """KDR current density I = g_max * n^4 * (Vm - EK) (Mandge et al. 2018, Eq. 20).

    Units as annotated in the original: mA/cm^2 for g_max in mS/cm^2 and mV.
    """
    open_fraction = n**4
    return g_max_KDR * open_fraction * (Vm - EK)


@njit
def get_SKCa_V12(Ca_i):
    """Return the SK3 half-activation voltage (mV) at intracellular Ca_i (mM).

    Piecewise-linear interpolation over the module-level tables
    sk3_cai_data_mM -> sk3_vh_data_mV. Outside the tabulated range,
    np.interp holds the end values.
    """
    ca_grid = sk3_cai_data_mM
    vhalf_table = sk3_vh_data_mV
    return np.interp(Ca_i, ca_grid, vhalf_table)
@njit
def get_SKCa_sf(Ca_i):
    """Return the SK3 Boltzmann slope factor (mV) at intracellular Ca_i (mM).

    Piecewise-linear interpolation over sk3_cai_data_mM -> sk3_sf_data_mV,
    clamped to the end values outside the table.
    """
    ca_grid = sk3_cai_data_mM
    slope_table = sk3_sf_data_mV
    return np.interp(Ca_i, ca_grid, slope_table)
@njit
def o_SKCa(Ca_i, EC = 0.42e-3, n_hill=5.6): # EC50 seems to be in mM based on sk3_cai_data_mM units
    """
    Calcium-dependent open-probability factor of the SK channel.

    Hill equation: Ca_i**n_hill / (Ca_i**n_hill + EC**n_hill).

    Args:
        Ca_i (float): Intracellular Ca2+ in the same units as EC (mM, matching
            the SK3 lookup tables).
        EC (float): Half-activation concentration EC50 (default 0.42e-3 mM).
        n_hill (float): Hill coefficient (default 5.6).

    Returns:
        float in [0, 1]. Returns 0.0 for non-positive Ca_i, or when both
        powers underflow to zero.
    """
    if Ca_i <= 0.0:
        return 0.0
    ca_pow = Ca_i**n_hill
    denom = ca_pow + EC**n_hill
    # Guard only against an exact zero. At physiological Ca (about 1e-4 mM) both
    # terms are about 1e-19, so a fixed threshold such as 1e-12 would force the
    # result to zero for every Ca_i below roughly 7 uM.
    if denom <= 0.0:
        return 0.0
    return ca_pow / denom
@njit
def m_SKCa(Vm, sf, V12, E_K):
    """
    Voltage-dependent gating factor of the SK channel.

        m = 1 / (1 + exp((Vm - (E_K + V12)) / sf))

    The half-activation point is V12 measured relative to E_K.
    NOTE(review): confirm against the source equations that V12 is meant to
    be relative to E_K rather than absolute.

    Special cases:
      * |sf| <= 1e-12 (or sf is NaN) is treated as an infinite exponent -> 0.0.
      * Exponents beyond +/-700 short-circuit to 0.0 / 1.0 to avoid exp overflow.
    """
    if not (abs(sf) > 1e-12):
        return 0.0
    x = (Vm - (E_K + V12)) / sf
    if x > 700:
        return 0.0
    if x < -700:
        return 1.0
    return 1 / (1 + np.exp(x))
@njit
def calculate_I_SKCa(Vm, Ca_i, E_K, g = 0.0009):
    """
    SK (small-conductance Ca2+-activated K+) current density.

    I = g * o(Ca_i) * m(Vm; sf(Ca_i), V12(Ca_i), E_K) * (Vm - E_K)

    o_SKCa supplies the Ca2+ dependence. m_SKCa supplies the voltage
    dependence, with slope and half-activation interpolated from the SK3
    tables at the current Ca_i (mM). The result is in mA/cm^2, as annotated
    in the original.
    """
    ca_factor = o_SKCa(Ca_i)
    slope = get_SKCa_sf(Ca_i)
    v_half = get_SKCa_V12(Ca_i)
    volt_factor = m_SKCa(Vm, slope, v_half, E_K)
    return g * ca_factor * volt_factor * (Vm - E_K)

# --- Passive Leak Current ---
# =============================================================================

@njit
def calculate_Ipas_SGC(Vm, Epas=-78.257, gpas=0.042):
    """Ohmic leak current density of the SGC membrane: I = gpas * (Vm - Epas).

    Annotated in the original as mA/cm^2 (mV and mS/cm^2 inputs).
    """
    driving_force = Vm - Epas
    return gpas * driving_force
# SK3 (SKCa) gating lookup tables read by get_SKCa_V12 / get_SKCa_sf.
# Entries line up by index: [Ca2+]_i (mM) -> V1/2 (mV) and slope factor (mV).
sk3_cai_data_mM = np.array([2e-4, 3e-4, 5e-4, 1e-2])  # [Ca]i in mM (0.2, 0.3, 0.5, 10 uM)
sk3_vh_data_mV = np.array([24.0, 35.30399, 60.49381, 59.13068])  # V1/2 in mV
sk3_sf_data_mV = np.array([128.0, 38.80141, 48.77538, 44.82705])  # slope factor in mV

# @njit
# def steady_state_pump(Vm, Na_o, Na_i, nernst, F, g_Na = 1e-4):
#     E_Na = nernst * np.log(Na_o / Na_i)
#     J = -((g_Na / F)*(Vm - E_Na))/3
#     return J


@njit
def steady_state_pump(Vm, Na_o, Na_i, K_o, nernst, F, g_Na=0.00168, K_Na_i_Km=10, K_K_o_Km=1.5):
    """
    Maximal Na+/K+-ATPase flux that holds the baseline state steady.

    Follows Ostby et al. (2009), Supporting Text S1, Eq. S14:
      1. Passive Na+ flux at baseline: (g_Na / F) * (Vm - E_Na), with
         E_Na = nernst * ln(Na_o / Na_i) (nernst in mV).
      2. Each pump cycle exports 3 Na+, so the cycle flux needed to cancel it
         is -passive / 3.
      3. Divide by the pump's saturation factors at the given concentrations,
         fNa = Na_i^1.5 / (Na_i^1.5 + K_Na_i_Km^1.5) and
         fK = K_o / (K_o + K_K_o_Km).

    Returns 0.0 if the saturation product is below 1e-12. The original notes
    give the result in mmol/(cm^2*s).
    """
    e_na = nernst * np.log(Na_o / Na_i)
    na_leak_flux = (g_Na / F) * (Vm - e_na)
    required_cycle_flux = -na_leak_flux / 3.0

    na_pow = Na_i**1.5
    f_na = na_pow / (na_pow + K_Na_i_Km**1.5)
    f_k = K_o / (K_o + K_K_o_Km)
    saturation = f_na * f_k
    if saturation < 1e-12:
        return 0.0
    return required_cycle_flux / saturation

@njit
def KNa_pump_revised(Na_i, K_o, J_NaK_max, K_Na_i_Km=10, K_K_o_Km=1.5, F_const=96485.3321):  # C/mol
    """
    Na+ and K+ current densities carried by the Na+/K+-ATPase (Ostby et al., 2009).

    Cycle flux (article Eqs. 9-10):
        J = J_NaK_max * fNa(Na_i) * fK(K_o)
        fNa = Na_i^1.5 / (Na_i^1.5 + K_Na_i_Km^1.5)
        fK  = K_o / (K_o + K_K_o_Km)
    Each cycle moves 3 Na+ out and 2 K+ in, so the currents are
    +3*J*F (Na+, outward) and -2*J*F (K+, inward).

    Args:
        Na_i (float): Intracellular Na+ (mM).
        K_o (float): Extracellular K+ (mM).
        J_NaK_max (float): Maximal cycle flux (mmol s^-1 cm^-2). The article
            does not give a value; e.g. steady_state_pump can supply one.
        K_Na_i_Km (float): Half-saturation constant for intracellular Na+ (mM).
        K_K_o_Km (float): Half-saturation constant for extracellular K+ (mM).
        F_const (float): Faraday constant (C/mol).

    Returns:
        tuple: (Na+ current density, K+ current density). With J in
        mmol s^-1 cm^-2 these are in mA/cm^2; positive means outward.
    """
    na_pow = Na_i**1.5
    f_na = na_pow / (na_pow + K_Na_i_Km**1.5)
    f_k = K_o / (K_o + K_K_o_Km)
    cycle_flux = J_NaK_max * f_na * f_k

    i_na_out = 3 * cycle_flux * F_const
    i_k_in = -2 * cycle_flux * F_const
    return i_na_out, i_k_in

@njit
def KCC1_current(K_o, K_i, Cl_o, Cl_i, nernst, g=7e-5, g_by_F=7.2549888e-10):
    """
    K+/Cl- cotransporter (KCC1) term in the Ostby et al. (2009) form.

        J = g * nernst * ln((K_o / K_i) * (Cl_o / Cl_i))

    The log argument is floored at 1e-9. K+ and Cl- share the same value
    because the transporter is 1:1. `g_by_F` is unused. It is kept only for
    interface compatibility: the article divides g by F, while this code uses
    g directly.

    Returns:
        (K+ term, Cl- term), both equal to J.
    """
    ratio = (K_o / K_i) * (Cl_o / Cl_i)
    J = g * nernst * np.log(max(ratio, 1e-9))
    return J, J


@njit
def NKCC1_current(Na_o, Na_i, K_o, K_i, Cl_o, Cl_i, nernst, g=2e-6, g_by_F=2.0728539e-11):
    """
    Na+/K+/2Cl- cotransporter (NKCC1) term, after Ostby et al. (2009), Eq. 7.

        J = g * nernst * ln((Na_o/Na_i) * (K_o/K_i) * (Cl_o/Cl_i)^2)

    The log argument is floored at 1e-9. The article uses g/F; this code uses
    g directly, so `g_by_F` is unused and kept only for interface
    compatibility.

    Returns:
        (Na+ term = J, K+ term = -J, Cl- term = 2*J)

    NOTE(review): NKCC1 moves Na+ and K+ in the same direction, yet the K+
    term is returned as -J. Confirm this sign against how the caller combines
    the currents.
    """
    ratio = (Na_o / Na_i) * (K_o / K_i) * (Cl_o / Cl_i)**2
    J = g * nernst * np.log(max(ratio, 1e-9))
    return J, -J, 2 * J
@njit
def single_SGC(Vm, V_ext, V_glial, A, C,
               K_i, K_o, Ca_i, Ca_o, Na_o, Na_i, Cl_o, Cl_i,
               J_pump_steady_state, n_KDR, nernst, dt, F, diff_mode, neur_glial_dist, R_sgc_cm,
               K_o_base, Na_o_base, Cl_o_base, Ca_o_base, betta_o,
               K_i_base, Na_i_base, Cl_i_base, Ca_i_base, betta_i,
               m_kir, h_kir,
               IC=0):
    """
    Advance one satellite glial cell (SGC) by a single forward-Euler step.

    Computes K+, Na+ and Cl- current densities from the channels (Kir4.1,
    KDR, SK), the Na+/K+ pump, NKCC1, KCC1 and the leak. It then updates Vm,
    the KDR and Kir4.1 gating variables and the ion concentrations.

    Args:
        Vm: membrane potential (mV).
        V_ext, V_glial: extracellular and cell volumes (L, as computed by simulate()).
        A: membrane area (cm^2).
        C: membrane capacitance (F; converted to mF below).
        K_i, K_o, Ca_i, Ca_o, Na_o, Na_i, Cl_o, Cl_i: concentrations (mM).
        J_pump_steady_state: maximal pump flux passed to KNa_pump_revised.
        n_KDR, m_kir, h_kir: gating variables at the start of the step.
        nernst: RT/F in mV (as computed by simulate()).
        dt: time step (ms, consistent with dVm/dt in mV/ms below).
        F: Faraday constant (C/mol).
        diff_mode: jitted update function (simple_diff_balance or dummy_diff)
            applied to the intracellular K+/Na+/Cl-.
        neur_glial_dist: not used here; kept for interface compatibility.
        R_sgc_cm: cell radius (cm), passed to diff_mode as its length scale.
        *_base, betta_o, betta_i: baseline concentrations and relaxation rates.
        IC: injected current (assumed mA).

    Returns:
        (new_Vm, K_i, K_o, Ca_i, Ca_o, Na_o, Na_i, Cl_o, Cl_i, n_KDR, m_kir, h_kir)
    """
    # Floor all concentrations so the logarithms and ratios below stay finite.
    K_i = max(1e-12, K_i); K_o = max(1e-12, K_o)
    Ca_i = max(1e-12, Ca_i); Ca_o = max(1e-12, Ca_o)
    Cl_i = max(1e-12, Cl_i); Cl_o = max(1e-12, Cl_o)
    Na_i = max(1e-12, Na_i); Na_o = max(1e-12, Na_o)

    E_K = nernst * np.log(K_o / K_i)

    I_pas = calculate_Ipas_SGC(Vm)

    # Gating variables (forward Euler); the currents below use the updated values.
    n_KDR = n_KDR + state_vars_deriv_KDR(Vm, n_KDR) * dt
    dm_kir_dt, dh_kir_dt = state_vars_deriv_Kir41(Vm, m_kir, h_kir)
    m_kir = m_kir + dm_kir_dt * dt
    h_kir = h_kir + dh_kir_dt * dt

    # Na+ and K+ shares of the leak current (fixed fractions of I_pas).
    Na_pas = I_pas*0.03226
    K_pas = I_pas*0.6451
    Na_pump, K_pump = KNa_pump_revised(Na_i, K_o, J_pump_steady_state)
    Na_NKCC, K_NKCC, Cl_NKCC = NKCC1_current(Na_o, Na_i, K_o, K_i, Cl_o, Cl_i, nernst)
    K_KCC, Cl_KCC = KCC1_current(K_o, K_i, Cl_o, Cl_i, nernst)

    # The sum is parenthesised so the pump, cotransporter and leak terms really
    # belong to I_K. Written on a separate unbracketed line, they form an
    # independent expression statement whose value is discarded.
    # NOTE(review): the Kir4.1 current enters with a negative sign; confirm it
    # matches the sign convention of calculate_IKir_Mandge.
    I_K = (-calculate_IKir_Mandge(Vm, m_kir, h_kir, E_K, K_o)
           + calculate_IKDR(Vm, n_KDR, E_K)
           + calculate_I_SKCa(Vm, Ca_i, E_K)
           + K_pump + K_NKCC + K_KCC - K_pas)
    I_Na = Na_pump + Na_NKCC + Na_pas
    I_Cl = (Cl_NKCC - Cl_KCC)
    # NOTE(review): I_pas is added in full here although its Na+/K+ shares are
    # also folded into I_Na and I_K above; confirm this is intended.
    I_total = I_K + I_Na + I_Cl + I_pas

    # Membrane equation: C dVm/dt = -(I_mem - IC).
    I_total_mem = I_total * A  # total membrane current in mA
    dVm_dt = -(I_total_mem - IC) / (C * 1e3)  # C(F)*1e3 = C(mF); mA/mF = mV/ms
    new_Vm = Vm + dt * dVm_dt

    # Extracellular Ca2+ and Na+ relax toward baseline.
    Ca_o -= betta_o * (Ca_o - Ca_o_base) * dt
    Na_o -= betta_o * (Na_o - Na_o_base) * dt
    # Intracellular K+/Na+/Cl- via the selected diffusion model.
    K_i, Na_i, Cl_i = diff_mode(-I_K, -I_Na, I_Cl,
                                K_i, Na_i, Cl_i, dt, R_sgc_cm,
                                K_i_base, Na_i_base, Cl_i_base, betta=betta_i)
    Ca_i -= betta_i * (Ca_i - Ca_i_base) * dt

    # Current density (mA/cm^2) * A (cm^2) = mA; mA * 1e-6 / (z*F) = mol/ms (z = 1).
    Na_flux = (I_Na * 1e-6*A) / (1 * F)
    K_flux = (I_K * 1e-6*A) / (1 * F)

    # Outward flux (positive current) depletes the cell and loads the outside:
    # mol/ms * 1000 * dt / volume (L) gives the concentration change in mM.
    # NOTE(review): Na_i and K_i were already advanced by diff_mode above from
    # the same currents; confirm this second update is intended.
    dNa_o = (Na_flux * 1000 * dt) / V_ext
    dNa_i = -(Na_flux * 1000 * dt) / V_glial
    dK_o = (K_flux * 1000 * dt) / V_ext
    dK_i = -(K_flux * 1000 * dt) / V_glial
    Na_o = Na_o + dNa_o
    Na_i = Na_i + dNa_i
    K_o = K_o + dK_o
    K_i = K_i + dK_i

    # Keep concentrations strictly positive for the next step.
    Na_o = max(1e-12, Na_o); Na_i = max(1e-12, Na_i)
    K_o = max(1e-12, K_o);   K_i = max(1e-12, K_i)
    Ca_o = max(1e-12, Ca_o); Ca_i = max(1e-12, Ca_i)
    return new_Vm, K_i, K_o, Ca_i, Ca_o, Na_o, Na_i, Cl_o, Cl_i, n_KDR, m_kir, h_kir
#@njit
def simulate(Vm, tmax, dt, T, stim_intervals, amplitude,
             K_i_base, K_o_base, Ca_i_base, Ca_o_base, Na_o_base, Na_i_base, Cl_o_base, Cl_i_base,
             Neur_glial_dist, R_sgc_cm, C,
             diff='simple_balance', betta_o=0.1, betta_i=1.3
             ):
    """
    Integrate the single-SGC model with forward Euler and return the state history.

    Args:
        Vm: initial membrane potential (mV).
        tmax, dt: simulated duration and time step (ms, as single_SGC assumes).
        T: temperature in degrees Celsius.
        stim_intervals: iterable of (start, end) times. Within each interval
            (end inclusive, clipped to the run) the injected current is `amplitude`.
        amplitude: injected current passed to single_SGC as IC (assumed mA).
        *_base: baseline concentrations (mM); the run starts from them.
        Neur_glial_dist: thickness of the extracellular shell (cm).
        R_sgc_cm: cell radius (cm).
        C: membrane capacitance (F).
        diff: 'simple_balance' selects simple_diff_balance; anything else
            selects dummy_diff (no diffusion).
        betta_o, betta_i: extracellular / intracellular relaxation rates.

    Returns:
        np.ndarray of shape (6, nt), nt = int(tmax / dt) + 1. Rows are Vm,
        K_o, Na_o, Ca_i, Ca_o and Cl_o; column 0 is the initial state.
    """
    nt = int(tmax / dt) + 1  # number of time points, including t = 0

    R_gas = 8.314462  # J/(mol*K)
    F = 96485.332  # C/mol
    T_kelvin = T + 273.15  # Celsius -> Kelvin
    nernst = 1e3 * R_gas * T_kelvin / F  # RT/F in mV

    # Geometry: a spherical cell of radius R_sgc_cm surrounded by a concentric
    # extracellular shell of thickness Neur_glial_dist (cm).
    Neur_glial_dist_cm = Neur_glial_dist
    area = 4 * np.pi * R_sgc_cm**2  # cm^2
    R_total_cm = R_sgc_cm + Neur_glial_dist_cm
    V_sgc_cm3 = (4.0/3.0) * np.pi * (R_sgc_cm**3)
    V_total_cm3 = (4.0/3.0) * np.pi * (R_total_cm**3)
    V_ext_cm3 = V_total_cm3 - V_sgc_cm3
    V_sgc_L = V_sgc_cm3 * 1e-3  # cell volume in litres
    V_ext_L = V_ext_cm3 * 1e-3  # extracellular volume in litres

    # Every gating variable starts at its steady state for the initial Vm.
    n_KDR = n_inf_KDR(Vm)
    m_kir = m_inf_Kir41(Vm)
    h_kir = h_inf_Kir41(Vm)

    # Pump strength that balances the baseline Na+ leak. Arguments follow
    # steady_state_pump's (Vm, Na_o, Na_i, K_o, nernst, F) order.
    J_pump_steady_state = steady_state_pump(Vm, Na_o_base, Na_i_base, K_o_base, nernst, F)

    state_mat = np.zeros((6, nt))
    state_mat[0, 0] = Vm
    state_mat[1, 0] = K_o_base
    state_mat[2, 0] = Na_o_base
    state_mat[3, 0] = Ca_i_base
    state_mat[4, 0] = Ca_o_base
    state_mat[5, 0] = Cl_o_base

    # Injected-current waveform, one sample per time point.
    IC_inj = np.zeros(nt)
    for interval in stim_intervals:
        start_idx = int(interval[0] / dt)
        end_idx = int(interval[1] / dt)
        actual_start_idx = max(0, start_idx)
        actual_end_idx = min(nt, end_idx + 1)  # include end_idx itself
        if actual_start_idx < actual_end_idx:
            IC_inj[actual_start_idx:actual_end_idx] = amplitude

    if diff == 'simple_balance':
        diff_mode = simple_diff_balance
    else:
        diff_mode = dummy_diff

    # Intracellular K+/Na+/Cl- are carried between steps but not recorded.
    K_i = K_i_base
    Na_i = Na_i_base
    Cl_i = Cl_i_base
    for i in tqdm(range(1, nt)):
        Vm = state_mat[0, i-1]
        K_o = state_mat[1, i-1]
        Na_o = state_mat[2, i-1]
        Ca_i = state_mat[3, i-1]
        Ca_o = state_mat[4, i-1]
        Cl_o = state_mat[5, i-1]
        (new_Vm, K_i, K_o, Ca_i, Ca_o, Na_o, Na_i, Cl_o, Cl_i,
         n_KDR, m_kir, h_kir) = single_SGC(
            Vm, V_ext_L, V_sgc_L, area, C,
            K_i, K_o, Ca_i, Ca_o, Na_o, Na_i, Cl_o, Cl_i,
            J_pump_steady_state, n_KDR, nernst, dt, F,
            diff_mode, Neur_glial_dist_cm, R_sgc_cm,
            K_o_base, Na_o_base, Cl_o_base, Ca_o_base, betta_o,
            K_i_base, Na_i_base, Cl_i_base, Ca_i_base, betta_i,
            m_kir, h_kir,
            IC=IC_inj[i])
        state_mat[0, i] = new_Vm
        state_mat[1, i] = K_o
        state_mat[2, i] = Na_o
        state_mat[3, i] = Ca_i
        state_mat[4, i] = Ca_o
        state_mat[5, i] = Cl_o

    return state_mat
# SK3 lookup tables again (same values as the definition just after
# calculate_Ipas_SGC); the two copies must be kept identical.
sk3_cai_data_mM = np.array([2e-4, 3e-4, 5e-4, 1e-2])  # [Ca]i in mM
sk3_vh_data_mV = np.array([24.0, 35.30399, 60.49381, 59.13068])  # V1/2 in mV
sk3_sf_data_mV = np.array([128.0, 38.80141, 48.77538, 44.82705])  # slope factor in mV

@njit
def simple_diff_balance(I_K, I_Na, I_Cl,
                        K_o, Na_o, Cl_o,
                        dt, neur_glial_dist,
                        K_base=5, Na_base=150, Cl_base=145,
                        betta=0.1, F=96485.332):
    """
    One forward-Euler step of a current-driven, baseline-relaxing concentration model.

    For each ion X:
        dX/dt = I_X * 1e-3 / (F * neur_glial_dist) - betta * (X - X_base)

    The parameter names say "_o", but the caller decides which compartment
    these concentrations belong to.

    Returns:
        (K, Na, Cl) after a step of length dt.
    """
    scale = F * neur_glial_dist
    dK_dt = I_K * 1e-3 / scale - betta * (K_o - K_base)
    dNa_dt = I_Na * 1e-3 / scale - betta * (Na_o - Na_base)
    dCl_dt = I_Cl * 1e-3 / scale - betta * (Cl_o - Cl_base)
    return K_o + dK_dt * dt, Na_o + dNa_dt * dt, Cl_o + dCl_dt * dt
@njit
def dummy_diff(I_K, I_Na, I_Cl,
               K_o, Na_o, Cl_o, dt, neur_glial_dist,
               K_base=5, Na_base=150, Cl_base=145,
               betta=None):
    """No-op concentration update with the same signature as simple_diff_balance.

    All currents and relaxation parameters are ignored, and the three
    concentrations are returned unchanged. Selecting this function disables
    concentration dynamics.
    """
    unchanged = (K_o, Na_o, Cl_o)
    return unchanged

def create_equally_spaced_intervals(
    total_length: float, num_intervals: int, interval_len: float,
    start_time = False
) -> np.ndarray:
    """
    Creates an array of equally spaced intervals within a total length.

    The spacing includes gaps before the first interval, between intervals,
    and after the last interval, all of which are equal.

    Args:
        total_length: The total length available for intervals and gaps.
        num_intervals: The number of intervals to create.
        interval_len: The length of each individual interval.
        start_time: Start of the first interval. ``False`` (default) or
            ``None`` places it after one gap. Any number, including ``0``, is
            used as-is. Later intervals keep the computed spacing, so a custom
            start can push the last interval past ``total_length``.

    Returns:
        np.ndarray: A 2D array where each row is [start, end] of an interval.
                    Returns an empty array of shape (0, 2) if num_intervals is 0.

    Raises:
        ValueError: If num_intervals is negative, or if the sum of
                    interval lengths exceeds total_length.
    """
    if num_intervals < 0:
        raise ValueError("Number of intervals cannot be negative.")
    if num_intervals == 0:
        return np.empty((0, 2))  # empty array with 2 columns

    total_occupied_by_intervals = num_intervals * interval_len
    if total_occupied_by_intervals > total_length:
        raise ValueError(
            f"Total length of intervals ({total_occupied_by_intervals}) "
            f"cannot exceed total_length ({total_length})."
        )

    # num_intervals + 1 equal gaps: one before, (n - 1) between, one after.
    gap_size = (total_length - total_occupied_by_intervals) / (num_intervals + 1)

    # False is the "not given" sentinel. Test identity rather than truthiness
    # so that an explicit 0 (falsy but a valid time) is honoured.
    if start_time is False or start_time is None:
        first_interval_start = gap_size
    else:
        first_interval_start = start_time

    # Consecutive starts are one interval plus one gap apart.
    step_between_starts = interval_len + gap_size
    starts = first_interval_start + np.arange(num_intervals) * step_between_starts
    ends = starts + interval_len
    return np.column_stack((starts, ends))
@njit
def calculate_GHK(Vm_mV, P, C_in_M, C_out_M, z, T, F = 96485.332, R =  8.314462):
    """
    Goldman-Hodgkin-Katz current density for a single ion species.

        I = P * z^2 F^2 V / (R T) * (C_in - C_out * exp(-u)) / (1 - exp(-u)),
        u = z F V / (R T)

    Args:
        Vm_mV (float): Membrane potential (mV).
        P (float): Permeability (cm/s).
        C_in_M (float): Intracellular concentration (mol/L).
        C_out_M (float): Extracellular concentration (mol/L).
        z (float): Ion valence.
        T (float): Temperature (K).
        F (float): Faraday constant (C/mol).
        R (float): Gas constant (J/(mol*K)).

    Returns:
        float: Current density in mA/cm^2; positive means outward. Concentrations
        are converted from mol/L to mol/cm^3 (x 1e-3), so that
        (cm/s)*(C/mol)*(mol/cm^3) gives A/cm^2, and the result is scaled by 1e3.
        For |u| < 1e-6 the removable singularity is replaced by its limit
        P * z * F * (C_in - C_out).
    """
    FRT = F / (R * T)
    u = z * (Vm_mV * 1e-3) * FRT
    c_in = C_in_M * 1e-3
    c_out = C_out_M * 1e-3
    prefactor = P * z * F

    if np.abs(u) < 1e-6:
        return (prefactor * (c_in - c_out)) * 1e3

    decay = np.exp(-u)
    denom = 1.0 - decay
    # Retained guard; |u| >= 1e-6 keeps |denom| well above this threshold.
    if abs(denom) < 1e-12:
        return 0.0
    return (prefactor * u * (c_in - c_out * decay) / denom) * 1e3

# =============================================================================
# --- Sodium (Na+) Channels ---
# =============================================================================

# --- TTX-Sensitive Na+ Channel (Nav1.?) ---
# Kinetics from Mandge et al. (2018) Main Text, Eqs 5-6
@njit
def alpha_m_TTXS(Vm):
    """TTX-sensitive Na+ activation opening rate alpha_m (ms^-1); Mandge et al. (2018), Eq. 5."""
    v = Vm
    # Retained quirk: inputs within 1e-6 mV of 5.0 are nudged by +1e-6. The
    # logistic denominator is always >= 1, so no division by zero is possible.
    if abs(Vm - 5.0) < 1e-6:
        v = Vm + 1e-6
    return 15.5 / (1.0 + np.exp((v - 5.0) / -12.08))
@njit
def beta_m_TTXS(Vm):
    """TTX-sensitive Na+ activation closing rate beta_m (ms^-1); Mandge et al. (2018), Eq. 5."""
    logistic_arg = (Vm + 72.7) / 16.7
    return 35.2 / (1.0 + np.exp(logistic_arg))
@njit
def alpha_h_TTXS(Vm):
    """TTX-sensitive Na+ inactivation recovery rate alpha_h (ms^-1); Mandge et al. (2018), Eq. 6."""
    decay_arg = -(Vm + 115.0) / 46.33
    return 0.24 * np.exp(decay_arg)
@njit
def beta_h_TTXS(Vm):
    """TTX-sensitive Na+ inactivation rate beta_h (ms^-1); Mandge et al. (2018), Eq. 6."""
    v = Vm
    # Retained quirk: inputs within 1e-6 mV of -11.8 are nudged by +1e-6. It
    # has no protective role because a plain exponential has no singularity.
    if abs(Vm - (-11.8)) < 1e-6:
        v = Vm + 1e-6
    return 4.32 * np.exp((v - (-11.8)) / -12.0)
@njit
def state_vars_deriv_TTXS(Vm, m, h):
    """Return (dm/dt, dh/dt) for the TTX-sensitive Na+ gates.

    Standard HH kinetics: dx/dt = alpha_x(Vm) * (1 - x) - beta_x(Vm) * x.
    """
    dm_dt = alpha_m_TTXS(Vm) * (1.0 - m) - beta_m_TTXS(Vm) * m
    dh_dt = alpha_h_TTXS(Vm) * (1.0 - h) - beta_h_TTXS(Vm) * h
    return dm_dt, dh_dt
@njit
def calculate_INa_TTXS(Vm, m, h, ENa, g_max_TTXS=0.0001):
    """TTX-sensitive Na+ current I = g_max * m^3 * h * (Vm - ENa); Mandge et al. (2018), Eq. 7."""
    gate_product = (m**3) * h
    return g_max_TTXS * gate_product * (Vm - ENa) if False else g_max_TTXS * (m**3) * h * (Vm - ENa)

# --- Nav1.8 Channel (TTX-Resistant) ---
# Kinetics from Mandge et al. (2018) Main Text, Eqs 10-11
@njit
def alpha_m_Nav18(Vm):
    """Nav1.8 activation opening rate (ms^-1): 7.21 - 7.21 / (1 + exp((Vm - 0.063)/7.86)).

    From Mandge et al. (2018), Eq. 10. The denominator is at least 1 for finite
    Vm; the near-zero guard is retained (a non-finite denominator yields 7.21).
    """
    denom = 1.0 + np.exp((Vm - 0.063) / 7.86)
    if abs(denom) > 1e-12:
        return 7.21 - 7.21 / denom
    return 7.21 - 0.0
@njit
def beta_m_Nav18(Vm):
    """Nav1.8 activation closing rate (ms^-1): 7.4 / (1 + exp((Vm + 53.06)/19.34)).

    From Mandge et al. (2018), Eq. 10; returns 0.0 if the denominator is not above 1e-12.
    """
    denom = 1.0 + np.exp((Vm + 53.06) / 19.34)
    if abs(denom) > 1e-12:
        return 7.4 / denom
    return 0.0
@njit
def alpha_h_Nav18(Vm):
    """Nav1.8 inactivation recovery rate (ms^-1): 0.003 + 1.63 / (1 + exp((Vm + 68.5)/10.01)).

    From Mandge et al. (2018), Eq. 11; the 1.63 term drops out if the denominator is not above 1e-12.
    """
    denom = 1.0 + np.exp((Vm + 68.5) / 10.01)
    if abs(denom) > 1e-12:
        return 0.003 + 1.63 / denom
    return 0.003 + 0.0
@njit
def beta_h_Nav18(Vm):
    """Nav1.8 inactivation rate (ms^-1): 0.81 - 0.81 / (1 + exp((Vm - 11.44)/13.12)).

    From Mandge et al. (2018), Eq. 11, written as 0.81 * (1 - 1/denom). Returns
    0.0 only when |denom| < 1e-12.
    """
    denom = 1.0 + np.exp((Vm - 11.44) / 13.12)
    if not (abs(denom) < 1e-12):
        return 0.81 * (1.0 - 1.0 / denom)
    return 0.0
@njit
def state_vars_deriv_Nav18(Vm, m, h):
    """
    Return (dm/dt, dh/dt) for the Nav1.8 gates.

    Steady states use the Boltzmann curves from Mandge et al. (2018), Eq. 9.
    Time constants come from the alpha/beta rates of Eqs. 10-11:
    tau = 1 / (alpha + beta + 1e-12), floored at 1e-12 ms. Each gate then
    relaxes as (x_inf - x) / tau.
    """
    m_inf = 1.0 / (1.0 + np.exp((-11.4 - Vm) / 8.5))
    h_inf = 1.0 / (1.0 + np.exp((Vm + 24.2) / 5.6))

    eps = 1e-12  # keeps the reciprocal finite if alpha + beta == 0
    tau_m = max(1.0 / (alpha_m_Nav18(Vm) + beta_m_Nav18(Vm) + eps), 1e-12)
    tau_h = max(1.0 / (alpha_h_Nav18(Vm) + beta_h_Nav18(Vm) + eps), 1e-12)

    return (m_inf - m) / tau_m, (h_inf - h) / tau_h
@njit
def calculate_INa_Nav18(Vm, m, h, ENa, g_max_Nav18=0.0087177):
    """Nav1.8 Na+ current I = g_max * m^3 * h * (Vm - ENa); Mandge et al. (2018), Eq. 12.

    mA/cm^2 if g_max is in mS/cm^2 (unit assumption carried over from the original).
    """
    activation = m**3
    return g_max_Nav18 * activation * h * (Vm - ENa)

# --- Nav1.9 Channel (TTX-Resistant, persistent) ---
# Kinetics from Baker (2005) via Mandge S2 Text Eq 1
@njit
def alpha_m9_Nav19(Vm):
    """Nav1.9 activation opening rate (ms^-1), Baker (2005) via Mandge S2 Text Eq. 1."""
    denom = (1.0 + np.exp((Vm - 11.01) / (-14.871)))
    if abs(denom) > 1e-12:
        return 1.548 / denom
    return 0.0
@njit
def beta_m9_Nav19(Vm):
    """Nav1.9 activation closing rate (ms^-1), Baker (2005) via Mandge S2 Text Eq. 1."""
    denom = 1.0 + np.exp((Vm + 112.4) / 22.9)
    if abs(denom) > 1e-12:
        return 8.685 / denom
    return 0.0
@njit
def alpha_h9_Nav19(Vm):
    """Nav1.9 inactivation recovery rate (ms^-1), Baker (2005) via Mandge S2 Text Eq. 1."""
    denom = (1.0 + np.exp((Vm + 63.264) / 3.7193))
    if abs(denom) > 1e-12:
        return 0.2574 / denom
    return 0.0
@njit
def beta_h9_Nav19(Vm):
    """Nav1.9 inactivation rate (ms^-1), Baker (2005) via Mandge S2 Text Eq. 1."""
    denom = (1.0 + np.exp((Vm + 0.27853) / (-9.0933)))
    if abs(denom) > 1e-12:
        return 0.53984 / denom
    return 0.0
@njit
def state_vars_deriv_Nav19(Vm, m9, h9):
    """Return (dm9/dt, dh9/dt) for the Nav1.9 gates using alpha/beta HH kinetics."""
    dm9_dt = alpha_m9_Nav19(Vm) * (1.0 - m9) - beta_m9_Nav19(Vm) * m9
    dh9_dt = alpha_h9_Nav19(Vm) * (1.0 - h9) - beta_h9_Nav19(Vm) * h9
    return dm9_dt, dh9_dt
@njit
def calculate_INa_Nav19(Vm, m9, h9, ENa, g_max_Nav19=0.00001):
    """Nav1.9 Na+ current I = g_max * m9 * h9 * (Vm - ENa) (S2 Text Eq. 2; Baker 2005, Eq. 1).

    mA/cm^2 if g_max is in mS/cm^2 (unit assumption carried over from the original).
    """
    driving_force = Vm - ENa
    return g_max_Nav19 * m9 * h9 * driving_force

# --- Na+ Channels Combined Functions ---
@njit
def get_steady_state_Na_vars(Vm):
    """
    Steady-state values of all Na+ channel gating variables at potential Vm.

    TTX-sensitive and Nav1.9 gates use x_inf = alpha / (alpha + beta), or
    0.0 if the rate sum is not above 1e-12. Nav1.8 uses the direct Boltzmann
    steady states (Mandge Eq. 9), the same curves used in
    state_vars_deriv_Nav18.

    Returns:
        np.ndarray [m_ttxs, h_ttxs, m_nav18, h_nav18, m9_nav19, h9_nav19]
    """
    # TTX-sensitive channel
    am_ttxs = alpha_m_TTXS(Vm); bm_ttxs = beta_m_TTXS(Vm)
    ah_ttxs = alpha_h_TTXS(Vm); bh_ttxs = beta_h_TTXS(Vm)
    m_inf_ttxs = am_ttxs / (am_ttxs + bm_ttxs) if (am_ttxs + bm_ttxs) > 1e-12 else 0.0
    h_inf_ttxs = ah_ttxs / (ah_ttxs + bh_ttxs) if (ah_ttxs + bh_ttxs) > 1e-12 else 0.0

    # Nav1.8: direct Boltzmann steady states; the alpha/beta rates only set
    # the time constants, so they are not evaluated here.
    m_inf_nav18 = 1 /(1+np.exp((-11.4-Vm)/8.5))
    h_inf_nav18 = 1 / (1+np.exp((Vm+24.2)/5.6))

    # Nav1.9
    am9_nav19 = alpha_m9_Nav19(Vm); bm9_nav19 = beta_m9_Nav19(Vm)
    ah9_nav19 = alpha_h9_Nav19(Vm); bh9_nav19 = beta_h9_Nav19(Vm)
    m9_inf_nav19 = am9_nav19 / (am9_nav19 + bm9_nav19) if (am9_nav19 + bm9_nav19) > 1e-12 else 0.0
    h9_inf_nav19 = ah9_nav19 / (ah9_nav19 + bh9_nav19) if (ah9_nav19 + bh9_nav19) > 1e-12 else 0.0

    return np.array([m_inf_ttxs, h_inf_ttxs, m_inf_nav18, h_inf_nav18, m9_inf_nav19, h9_inf_nav19])

@njit
def update_Na_vars(Vm, ch_vars, dt):
    """
    Advance all six Na+ gating variables by one forward-Euler step of size dt.

    ch_vars is ordered [m_ttxs, h_ttxs, m_nav18, h_nav18, m9_nav19, h9_nav19];
    a new array in the same order is returned.
    """
    mttx, httx, m8, h8, m9, h9 = ch_vars
    d_mttx, d_httx = state_vars_deriv_TTXS(Vm, mttx, httx)
    d_m8, d_h8 = state_vars_deriv_Nav18(Vm, m8, h8)
    d_m9, d_h9 = state_vars_deriv_Nav19(Vm, m9, h9)
    return np.array([
        mttx + dt * d_mttx,
        httx + dt * d_httx,
        m8 + dt * d_m8,
        h8 + dt * d_h8,
        m9 + dt * d_m9,
        h9 + dt * d_h9,
    ])
@njit
def update_Na_vars_rk4(Vm, ch_vars, dt):
    """
    Advance the six Na+ gating variables by one classical RK4 step.

    Vm is held constant across the step, so only the gating ODEs are
    integrated.

    Args:
        Vm (float): Membrane potential (mV).
        ch_vars (np.ndarray): [m_TTXs, h_TTXs, m_Nav18, h_Nav18, m_Nav19, h_Nav19].
        dt (float): Time step.

    Returns:
        np.ndarray: Updated gating state in the same order.
    """
    def rhs(y):
        # dy/dt of all Na+ gates for state vector y at the fixed Vm.
        a_m, a_h, b_m, b_h, c_m, c_h = y
        da_m, da_h = state_vars_deriv_TTXS(Vm, a_m, a_h)
        db_m, db_h = state_vars_deriv_Nav18(Vm, b_m, b_h)
        dc_m, dc_h = state_vars_deriv_Nav19(Vm, c_m, c_h)
        return np.array([da_m, da_h, db_m, db_h, dc_m, dc_h])

    half_dt = 0.5 * dt
    s1 = rhs(ch_vars)                  # slope at start of step
    s2 = rhs(ch_vars + half_dt * s1)   # midpoint slope using s1
    s3 = rhs(ch_vars + half_dt * s2)   # midpoint slope using s2
    s4 = rhs(ch_vars + dt * s3)        # end-of-step slope using s3

    # RK4 weights (1, 2, 2, 1) / 6
    return ch_vars + (dt / 6.0) * (s1 + 2.0 * s2 + 2.0 * s3 + s4)
@njit
def Na_current(Vm, Na_o, Na_i, ch_vars, nernst):
    """
    Total Na+ current, summed over the TTX-sensitive, Nav1.8 and Nav1.9 channels.

    Args:
        Vm (float): Membrane potential (mV).
        Na_o (float): Extracellular Na+ concentration.
        Na_i (float): Intracellular Na+ concentration (same unit as Na_o;
            only the ratio enters).
        ch_vars (array): [m_TTXs, h_TTXs, m_Nav18, h_Nav18, m_Nav19, h_Nav19].
        nernst (float): Prefactor multiplying ln(Na_o / Na_i) to give E_Na
            (presumably RT/F in mV).

    Returns:
        float: Summed current (mA/cm^2 per the original unit comment).
    """
    m_t, h_t, m_18, h_18, m_19, h_19 = ch_vars
    e_na = nernst * np.log(Na_o / Na_i)

    i_ttxs = calculate_INa_TTXS(Vm, m_t, h_t, e_na)
    i_18 = calculate_INa_Nav18(Vm, m_18, h_18, e_na)
    i_19 = calculate_INa_Nav19(Vm, m_19, h_19, e_na)
    return i_ttxs + i_18 + i_19

# =============================================================================
# --- Potassium (K+) Channels ---
# =============================================================================

# --- Slow A-type K+ Channel (KA) ---
# Kinetics from Mandge et al. (2018) Main Text, Eqs 14-16, 18
@njit
def n_inf_KA(Vm):
    """Steady-state KA activation (Eq 14): Boltzmann, V1/2 = -40.8 mV, k = 9.5 mV."""
    z = (-40.8 - Vm) / 9.5
    return 1.0 / (1.0 + np.exp(z))
@njit
def h_inf_KA(Vm):
    """
    Steady-state KA inactivation (Eq 15): V1/2 = -74.2 mV, k = 9.6 mV.

    Shared by the fast and slow inactivation gates.
    """
    z = (Vm + 74.2) / 9.6
    return 1.0 / (1.0 + np.exp(z))
@njit
def tau_n_KA(Vm):
    """KA activation time constant in ms (Eq 16), floored at 1e-9."""
    bump = np.exp(-2.0 * ((Vm + 60.0) / 45.76)**2)
    tau = 1.2 + 2.56 * bump
    return max(tau, 1e-9)
@njit
def tau_h_fast_KA(Vm):
    """KA fast-inactivation time constant in ms (Eq 16), floored at 1e-9."""
    bump = np.exp(-2.0 * ((Vm + 50.0) / 21.95)**2)
    tau = 25.46 + 67.41 * bump
    return max(tau, 1e-9)
@njit
def tau_h_slow_KA(Vm):
    """KA slow-inactivation time constant in ms (Eq 16), floored at 1e-9."""
    bump = np.exp(-(Vm / 47.77)**2)
    tau = 200.0 + 587.4 * bump
    return max(tau, 1e-9)
@njit
def state_vars_deriv_KA(Vm, n, h_fast, h_slow):
    """
    Time derivatives of the KA gates at fixed Vm.

    Both inactivation gates relax toward the same h_inf, but with different
    time constants.

    Returns:
        tuple: (dn/dt, dh_fast/dt, dh_slow/dt), in 1/ms.
    """
    h_target = h_inf_KA(Vm)
    dn = (n_inf_KA(Vm) - n) / tau_n_KA(Vm)
    dhf = (h_target - h_fast) / tau_h_fast_KA(Vm)
    dhs = (h_target - h_slow) / tau_h_slow_KA(Vm)
    return dn, dhf, dhs
@njit
def calculate_IKA(Vm, n, h_fast, h_slow, EK, g_max_KA=0.00108):
    """
    A-type K+ current (Eq 18).

    I = g_max_KA * n * (0.3*h_fast + 0.7*h_slow) * (Vm - EK), in mA/cm^2 per
    the original unit comment. Inactivation is a 30 % / 70 % mix of the fast
    and slow gates. An alternative conductance of 0.00136 was noted next to
    the default.
    """
    h_mix = 0.3 * h_fast + 0.7 * h_slow
    driving_force = Vm - EK
    return g_max_KA * n * h_mix * driving_force

# --- Delayed Rectifier K+ Channel (KDR) ---
# Kinetics from Mandge et al. (2018) Main Text, Eqs 19-20
@njit
def n_inf_KDR(Vm):
    """Steady-state KDR activation (Eq 19): Boltzmann, V1/2 = -35 mV, k = 15.4 mV."""
    z = (-35.0 - Vm) / 15.4
    return 1.0 / (1.0 + np.exp(z))
@njit
def tau_n_KDR(Vm):
    """KDR activation time constant in ms (Eq 19), floored at 1e-9."""
    bump = np.exp(-2.0 * ((Vm + 20.0) / 17.08)**2)
    tau = 7.14 + 29.74 * bump
    return max(tau, 1e-9)
@njit
def state_vars_deriv_KDR(Vm, n):
    """dn/dt for the KDR gate: relaxation toward n_inf with time constant tau_n (1/ms)."""
    return (n_inf_KDR(Vm) - n) / tau_n_KDR(Vm)
@njit
def calculate_IKDR(Vm, n, EK, g_max_KDR = 0.002688):
    """Delayed-rectifier K+ current (Eq 20): g * n^4 * (Vm - EK), in mA/cm^2."""
    gate = n**4
    return g_max_KDR * gate * (Vm - EK)

# --- KCNQ/M K+ Channel (IKM) ---
# Kinetics from Mandge S2 Text Eqs 4-5
@njit
def alpha_n_IKM(Vm):  # S2 Eq 4
    """Opening rate of the KCNQ/M gate, in 1/ms."""
    x = (Vm + 30.0) / 40.0
    return 0.00395 * np.exp(x)
@njit
def beta_n_IKM(Vm):  # S2 Eq 4
    """Closing rate of the KCNQ/M gate, in 1/ms."""
    x = -(Vm + 30.0) / 20.0
    return 0.00395 * np.exp(x)
@njit
def n_inf_IKM(Vm):  # S2 Eq 4 (derived from alpha/beta)
    """
    Steady-state M-current activation, alpha / (alpha + beta).

    Returns 0.0 if alpha + beta is at or below 1e-9.
    """
    a = alpha_n_IKM(Vm)
    total = a + beta_n_IKM(Vm)
    if total > 1e-9:
        return a / total
    return 0.0
@njit
def tau_n_IKM(Vm):  # S2 Eq 4
    """
    Activation time constant of the KCNQ/M current: tau = 1 / (alpha + beta), in ms.

    If alpha + beta is at or below 1e-9, the gate is treated as effectively
    frozen and a very LARGE time constant is returned (1 / 1e-9 = 1e9 ms).
    Because alpha and beta are both positive exponentials, this branch is
    only a safeguard.
    """
    alpha = alpha_n_IKM(Vm); beta = beta_n_IKM(Vm)
    sum_ab = alpha + beta
    # Near-zero total rate -> very slow gate (large tau), never a tiny one.
    return 1.0 / sum_ab if sum_ab > 1e-9 else 1.0 / 1e-9
@njit
def state_vars_deriv_IKM(Vm, n):
    """
    dn/dt for the KCNQ/M-current gate: (n_inf - n) / tau_n.

    tau_n is floored at 1e-12 ms only to rule out division by zero.
    """
    tau = max(1e-12, tau_n_IKM(Vm))
    return (n_inf_IKM(Vm) - n) / tau
@njit
def calculate_IKM(Vm, n, EK, g_max_IKM = 0.0001):
    """KCNQ/M K+ current (S2 Eq 5): g * n * (Vm - EK), in mA/cm^2."""
    driving_force = Vm - EK
    return g_max_IKM * n * driving_force

# --- Na+-Activated K+ Channel (IKNa) ---
# Kinetics from Mandge S2 Text Eq 6
@njit
def w_inf_IKNa(Nai_mM, EC50=38.7, n_hill=3.5):  # S2 Eq 6 - Na-dependence factor 'w_inf'
    """
    Steady-state Na+-dependent activation of the Na+-activated K+ channel.

        w_inf = 1 / (1 + (EC50 / [Na]i)^n_hill)

    Args:
        Nai_mM (float): Intracellular Na+ (mM); clamped below at 1e-12 mM.
        EC50 (float): Half-activation concentration (mM). Default 38.7.
        n_hill (float): Hill coefficient. Default 3.5.

    Returns:
        float: Value in [0, 1]. It tends to 0 as [Na]i -> 0 and to 1 for
        [Na]i >> EC50.
    """
    Nai_mM = max(1e-12, Nai_mM)  # keep the ratio finite and positive
    ratio_pow = (EC50 / Nai_mM)**n_hill
    # If ratio_pow overflows to +inf (tiny [Na]i), 1 / (1 + inf) evaluates
    # to 0.0, which is the correct limit. The denominator is always >= 1.
    return 1.0 / (1.0 + ratio_pow)
@njit
def state_vars_deriv_IKNa(Vm, w, Nai_mM):
    """
    dw/dt for the Na+-activated K+ gate (S2 Eq 6).

    The gate relaxes toward w_inf([Na]i) with a fixed 1 ms time constant.
    Vm is accepted for a uniform call signature but is not used.
    """
    tau_w = 1.0  # ms
    return (w_inf_IKNa(Nai_mM) - w) / tau_w
@njit
def calculate_IKNa(Vm, w, EK, g_max_IKNa = 1.2e-6):
    """
    Na+-activated K+ current (S2 Eq 6): g * w * (Vm - EK).

    NOTE(review): the original comment says mA/cm^2 with g in mS/cm^2. With
    Vm in mV, mA/cm^2 requires g in S/cm^2; confirm units. An alternative
    conductance of 0.00001 was noted next to the default.
    """
    driving_force = Vm - EK
    return g_max_IKNa * w * driving_force

# --- Large-Conductance Ca2+-Activated K+ Channel (BKca) ---
# Kinetics from Mandge et al. (2018) Main Text, Eqs 21-24
@njit
def pCa_BKca(Cai_M):
    """
    Return log10 of intracellular [Ca2+] expressed in molar.

    Despite the name, this is the *signed* log10 (e.g. -7 for 100 nM), i.e.
    the negative of the conventional pCa = -log10([Ca]i). V12_BKca and
    sf_BKca are written for this signed value; for example, sf_BKca peaks
    at -5.42 (about 3.8 uM).

    Args:
        Cai_M (float): Intracellular Ca2+ in molar; floored at 1e-12 M.

    Returns:
        float: log10(max(Cai_M, 1e-12)).
    """
    Cai_M = max(1e-12, Cai_M)  # prevent log10(0) or log10(negative)
    return np.log10(Cai_M)
@njit
def V12_BKca(pCa):  # Eq 21
    """
    Half-activation voltage (mV) of BKCa, linear in pCa.

    pCa is the value returned by pCa_BKca (signed log10 of [Ca]i in M), so
    V1/2 moves to more negative voltages as [Ca]i rises.
    """
    slope = -43.4
    return slope * pCa - 203.0
@njit
def sf_BKca(pCa):  # Eq 21 - Slope factor based on pCa
    """
    Slope factor (mV) of BKCa activation.

        sf = 33.88 * exp(-((pCa + 5.42) / 1.85)^2)

    The curve is Gaussian in pCa (signed log10 of [Ca]i in M, as returned by
    pCa_BKca). It is maximal (33.88 mV) at pCa = -5.42 and decays toward 0
    away from it. The result is floored at 1e-12 so callers can divide by it.

    The exponent is always <= 0, so np.exp cannot overflow. For very large
    |pCa + 5.42| it underflows to 0 and the 1e-12 floor applies, which keeps
    the curve continuous.
    """
    exponent_term = -((pCa + 5.42) / 1.85)**2
    slope = 33.88 * np.exp(exponent_term)
    return max(1e-12, slope)
@njit
def n_inf_BKca(Vm, Cai_M):  # Eqs 21-22
    """
    Steady-state BKCa activation.

    Boltzmann curve in Vm whose midpoint and slope both depend on [Ca]i
    (given in molar).
    """
    log_ca = pCa_BKca(Cai_M)
    z = (V12_BKca(log_ca) - Vm) / sf_BKca(log_ca)
    return 1.0 / (1.0 + np.exp(z))
@njit
def tau_n_BKca(Vm):  # Eq 23
    """
    BKCa activation time constant in ms (independent of Ca).

        tau = 5.55 * exp(Vm / 42.91) + 0.75 - 0.12 * Vm

    Returns 1e10 when Vm / 42.91 > 700, to avoid exp overflow. Otherwise the
    result is floored at 1e-12.
    """
    x = Vm / 42.91
    if x > 700:
        return 1e10
    return max(5.55 * np.exp(x) + 0.75 - 0.12 * Vm, 1e-12)
@njit
def state_vars_deriv_BKca(Vm, n, Cai_M):
    """dn/dt for the BKCa gate; [Ca]i is in molar and tau is floored at 1e-12 ms."""
    tau = max(1e-12, tau_n_BKca(Vm))
    return (n_inf_BKca(Vm, Cai_M) - n) / tau
@njit
def calculate_IBKca(Vm, n, EK, g_max_BKca = 0.0009):
    """BKCa current (Eq 24): g * n * (Vm - EK), in mA/cm^2."""
    driving_force = Vm - EK
    return g_max_BKca * n * driving_force

# --- Small-Conductance Ca2+-Activated K+ Channel (SKCa) ---
# Kinetics based on interpolation data (presumably from fits, e.g., SK3?)

@njit
def get_SKCa_V12(Ca_i):
    """
    SKCa V1/2 (mV), linearly interpolated from the SK3 table at Ca_i (mM).

    np.interp holds the end values outside the tabulated range.
    """
    v_half = np.interp(Ca_i, sk3_cai_data_mM, sk3_vh_data_mV)
    return v_half
@njit
def get_SKCa_sf(Ca_i):
    """
    SKCa slope factor (mV), linearly interpolated from the SK3 table at Ca_i (mM).

    np.interp holds the end values outside the tabulated range.
    """
    slope = np.interp(Ca_i, sk3_cai_data_mM, sk3_sf_data_mV)
    return slope
@njit
def o_SKCa(Ca_i, EC = 0.42e-3, n_hill = 5.6):
    """
    Ca2+-dependent open-probability factor for SKCa (Hill equation).

        o = Ca^n / (Ca^n + EC^n) = 1 / (1 + (EC / Ca)^n)

    Args:
        Ca_i (float): Intracellular Ca2+ in mM (the unit of sk3_cai_data_mM).
        EC (float): Half-activation concentration in mM; default 0.42e-3 mM.
        n_hill (float): Hill coefficient; default 5.6.

    Returns:
        float: Value in [0, 1]. Returns 0.0 for Ca_i <= 0 or NaN.

    The ratio form is used because at physiological Ca_i both Ca^n and EC^n
    are of order 1e-19 or smaller. Comparing their sum against an absolute
    threshold such as 1e-12 would force the result to 0 for every Ca_i below
    about 7 uM, including EC itself.
    """
    if not (Ca_i > 0.0):
        return 0.0
    # (EC / Ca_i)**n may overflow to inf for tiny Ca_i; 1 / (1 + inf) == 0.0.
    return 1.0 / (1.0 + (EC / Ca_i)**n_hill)
@njit
def m_BKCa(Vm, sf, V12, E_K):
    """
    Voltage-dependent gating factor used by calculate_I_SKCa.

        m = 1 / (1 + exp((Vm - (E_K + V12)) / sf))

    This is a Boltzmann curve whose midpoint is shifted by E_K. Despite
    "BKCa" in the name, the caller visible here is the SKCa current.
    NOTE(review): the E_K shift is unusual for a Boltzmann; recheck it
    against the source equations.

    An exponent argument above 700 returns 0.0 and one below -700 returns
    1.0, which avoids exp overflow. A slope with |sf| <= 1e-12 is treated as
    an infinite argument and returns 0.0.
    """
    if abs(sf) > 1e-12:
        arg = (Vm - (E_K + V12)) / sf
    else:
        arg = np.inf
    if arg > 700:
        return 0.0
    if arg < -700:
        return 1.0
    return 1 / (1 + np.exp(arg))
@njit
def calculate_I_SKCa(Vm, Ca_i, E_K, g = 0.0009):
    """
    SKCa current: g * o(Ca_i) * m(Vm, Ca_i) * (Vm - E_K), in mA/cm^2.

    Both the Ca-dependent factor o and the voltage factor m are
    instantaneous, so there is no state variable. V1/2 and the slope factor
    are interpolated from the SK3 tables at the current Ca_i (mM).
    """
    ca_factor = o_SKCa(Ca_i)
    v_half = get_SKCa_V12(Ca_i)
    slope = get_SKCa_sf(Ca_i)
    v_factor = m_BKCa(Vm, slope, v_half, E_K)
    return g * ca_factor * v_factor * (Vm - E_K)

# --- K+ Channels Combined Functions ---
@njit
def get_steady_state_K_vars(Vm, Na_i, Ca_i):
    """
    Steady-state values of the K+ gating variables at fixed Vm, [Na]i and [Ca]i.

    Args:
        Vm (float): Membrane potential (mV).
        Na_i (float): Intracellular Na+ (mM).
        Ca_i (float): Intracellular Ca2+ (mM); converted to M for BKCa.

    Returns:
        np.ndarray: [n_KA, hf_KA, hs_KA, n_KCNQ, n_KDR, w_IKNa, n_BKCa].
        The two KA inactivation gates share h_inf. SKCa is instantaneous and
        has no entry.
    """
    h_ka = h_inf_KA(Vm)
    return np.array([
        n_inf_KA(Vm),
        h_ka,
        h_ka,
        n_inf_IKM(Vm),
        n_inf_KDR(Vm),
        w_inf_IKNa(Na_i),
        n_inf_BKca(Vm, Ca_i * 1e-3),
    ])

@njit
def update_K_vars(Vm, Na_i, ch_vars, Ca_i, dt):
    """
    One forward-Euler step for the seven K+ gating variables.

    Args:
        Vm (float): Membrane potential (mV), fixed over the step.
        Na_i (float): Intracellular Na+ (mM); drives the IKNa gate.
        ch_vars (array): [n_KA, hf_KA, hs_KA, n_KCNQ, n_KDR, w_IKNa, n_BKCa].
        Ca_i (float): Intracellular Ca2+ (mM); converted to M for BKCa.
        dt (float): Time step (ms, matching the tau units).

    Returns:
        np.ndarray: Updated state in the same order.
    """
    n_a, hf_a, hs_a, n_kcnq, n_kdr, w_kna, n_bk = ch_vars
    ca_molar = Ca_i * 1e-3

    d_n_a, d_hf_a, d_hs_a = state_vars_deriv_KA(Vm, n_a, hf_a, hs_a)
    d_kcnq = state_vars_deriv_IKM(Vm, n_kcnq)
    d_kdr = state_vars_deriv_KDR(Vm, n_kdr)
    d_kna = state_vars_deriv_IKNa(Vm, w_kna, Na_i)
    d_bk = state_vars_deriv_BKca(Vm, n_bk, ca_molar)

    return np.array([
        n_a + dt * d_n_a,
        hf_a + dt * d_hf_a,
        hs_a + dt * d_hs_a,
        n_kcnq + dt * d_kcnq,
        n_kdr + dt * d_kdr,
        w_kna + dt * d_kna,
        n_bk + dt * d_bk,
    ])
@njit
def update_K_vars_rk4(Vm, Na_i, ch_vars, Ca_i, dt):
    """
    One classical RK4 step for the seven K+ gating variables.

    Vm, [Na]i and [Ca]i are held fixed over the step.

    Args:
        Vm (float): Membrane potential (mV).
        Na_i (float): Intracellular Na+ (mM); drives the IKNa gate.
        ch_vars (np.ndarray): [n_KA, hf_KA, hs_KA, n_KCNQ, n_KDR, w_IKNa, n_BKCa].
        Ca_i (float): Intracellular Ca2+ (mM); converted to M for BKCa.
        dt (float): Time step (ms, matching the tau units).

    Returns:
        np.ndarray: Updated state in the same order.
    """
    ca_molar = Ca_i * 1e-3  # constant over the step, so computed once

    def rhs(y):
        # dy/dt of all K+ gates for state vector y.
        n_a, hf_a, hs_a, n_kcnq, n_kdr, w_kna, n_bk = y
        d_n_a, d_hf_a, d_hs_a = state_vars_deriv_KA(Vm, n_a, hf_a, hs_a)
        return np.array([
            d_n_a,
            d_hf_a,
            d_hs_a,
            state_vars_deriv_IKM(Vm, n_kcnq),
            state_vars_deriv_KDR(Vm, n_kdr),
            state_vars_deriv_IKNa(Vm, w_kna, Na_i),
            state_vars_deriv_BKca(Vm, n_bk, ca_molar),
        ])

    half_dt = 0.5 * dt
    s1 = rhs(ch_vars)
    s2 = rhs(ch_vars + half_dt * s1)
    s3 = rhs(ch_vars + half_dt * s2)
    s4 = rhs(ch_vars + dt * s3)

    # RK4 weights (1, 2, 2, 1) / 6
    return ch_vars + (dt / 6.0) * (s1 + 2.0 * s2 + 2.0 * s3 + s4)
@njit
def K_current(Vm, K_o, K_i, Ca_i, ch_vars, nernst):
    """
    Total K+ current: KA + KCNQ/M + KDR + KNa + BKCa + SKCa, in mA/cm^2.

    Args:
        Vm (float): Membrane potential (mV).
        K_o (float): Extracellular K+ concentration.
        K_i (float): Intracellular K+ concentration (same unit as K_o;
            only the ratio is used).
        Ca_i (float): Intracellular Ca2+ (mM); used only by the
            instantaneous SKCa term.
        ch_vars (array): [n_KA, hf_KA, hs_KA, n_KCNQ, n_KDR, w_IKNa, n_BKCa].
        nernst (float): Prefactor multiplying ln(K_o / K_i) to give E_K
            (presumably RT/F in mV).

    The Na+ dependence of the KNa channel enters only through the w_IKNa
    state variable.
    """
    n_a, hf_a, hs_a, n_kcnq, n_kdr, w_kna, n_bk = ch_vars
    e_k = nernst * np.log(K_o / K_i)

    i_a = calculate_IKA(Vm, n_a, hf_a, hs_a, e_k)
    i_m = calculate_IKM(Vm, n_kcnq, e_k)
    i_dr = calculate_IKDR(Vm, n_kdr, e_k)
    i_kna = calculate_IKNa(Vm, w_kna, e_k)
    i_bk = calculate_IBKca(Vm, n_bk, e_k)
    i_sk = calculate_I_SKCa(Vm, Ca_i, e_k)
    return i_a + i_m + i_dr + i_kna + i_bk + i_sk

# =============================================================================
# --- Calcium (Ca2+) Channels ---
# =============================================================================
# Note: All Ca channel currents use the GHK equation. Ca_current takes Ca_i and Ca_o
#       in mM and converts them to molar before calling the per-channel current
#       functions, which (like calculate_GHK) expect molar concentrations.
#       The CDI factors (hca_inf_...) also expect Ca_i in molar.

# --- L-type Ca2+ Channel (CaV1) ---
# Kinetics from Mandge et al. (2018) Main Text, Fig 7, Eqs 27-30
@njit
def m_inf_L(Vm):  # Eq 27
    """Steady-state L-type activation: Boltzmann, V1/2 = 8.46 mV, k = 4.26 mV."""
    z = (8.46 - Vm) / 4.26
    return 1.0 / (1.0 + np.exp(z))
@njit
def h_inf_L(Vm):  # Eq 27
    """Steady-state L-type voltage inactivation: V1/2 = -42.52 mV, k = 7.48 mV."""
    z = (Vm + 42.52) / 7.48
    return 1.0 / (1.0 + np.exp(z))
@njit
def hca_inf_L(Cai_M):  # Eq 27 - Calcium-dependent inactivation factor 'hca'
    """
    Instantaneous Ca2+-dependent inactivation factor of the L-type channel.

        hca = 1 / (1 + (Cai / K_d)^4),   K_d = 1 uM = 1e-6 M

    This equals K_d^4 / (K_d^4 + Cai^4), the usual decreasing Hill form for
    inhibition. It is 1 at zero Ca, 0.5 at Cai = K_d, and falls toward 0 as
    Ca rises. NOTE(review): the constants should still be checked against
    source Eq 27.

    Args:
        Cai_M (float): Intracellular Ca2+ in molar.

    Returns:
        float: Inactivation factor in (0, 1]. Returns 0.0 if the denominator
        is not > 1e-12 (e.g. NaN input).
    """
    Ca_i_norm_M = Cai_M / 1e-6  # normalise by K_d = 1 uM
    hca_denom = 1.0 + (Ca_i_norm_M)**4
    return 1.0 / hca_denom if hca_denom > 1e-12 else 0.0
@njit
def tau_m_L(Vm):  # Eq 28
    """
    L-type activation time constant (ms): 2.11 + 3.86 * exp(-2((Vm + 10) / 16.02)^2).

    Returns the 2.11 baseline directly when the exponent is below -700.
    Otherwise the result is floored at 1e-12.
    """
    arg = -2.0 * ((Vm + 10.0) / 16.02)**2
    if arg < -700:
        return 2.11
    return max(2.11 + 3.86 * np.exp(arg), 1e-12)
@njit
def tau_h_L(Vm):  # Eq 28
    """
    L-type inactivation time constant (ms): 825.80 + 637.91 * exp(-2(Vm / 39.75)^2).

    Returns the 825.80 baseline directly when the exponent is below -700.
    Otherwise the result is floored at 1e-12.
    """
    arg = -2.0 * ((Vm) / 39.75)**2
    if arg < -700:
        return 825.80
    return max(825.80 + 637.91 * np.exp(arg), 1e-12)
@njit
def state_vars_deriv_L(Vm, m, h):
    """(dm/dt, dh/dt) for the L-type Ca2+ voltage gates at fixed Vm, in 1/ms."""
    dm = (m_inf_L(Vm) - m) / tau_m_L(Vm)
    dh = (h_inf_L(Vm) - h) / tau_h_L(Vm)
    return dm, dh
@njit
def calculate_ICa_L(Vm, m, h, Cai_M, Cao_M, T, Pmax_L = 2.75e-5):  # Eq 30
    """
    L-type Ca2+ current from the GHK flux equation.

    The effective permeability is Pmax_L * m * h * hca(Cai), where the
    Ca-dependent inactivation hca is evaluated instantaneously.
    Concentrations are in molar; the result is in mA/cm^2 per the original
    unit comment.
    """
    perm = Pmax_L * m * h * hca_inf_L(Cai_M)
    return calculate_GHK(Vm, perm, C_in_M=Cai_M, C_out_M=Cao_M, z=2, T=T)

# --- N-type Ca2+ Channel (CaV2.2) ---
# Kinetics from Mandge et al. (2018) Main Text, Fig 8, Eqs 31-34

@njit
def interp_tau_h_N(Vm):  # Based on Fig 8B data
    """N-type inactivation tau (ms), linearly interpolated from tables; floored at 1e-9."""
    return max(np.interp(Vm, N_tau_h_Vm, N_tau_h_ms), 1e-9)
@njit
def m_inf_N(Vm):  # Eq 31
    """Steady-state N-type activation: Boltzmann, V1/2 = -6.5 mV, k = 6.5 mV."""
    z = (-6.5 - Vm) / 6.5
    return 1.0 / (1.0 + np.exp(z))
@njit
def h_inf_N(Vm):  # Eq 31
    """Steady-state N-type voltage inactivation: V1/2 = -70 mV, k = 12.5 mV."""
    z = (Vm + 70.0) / 12.5
    return 1.0 / (1.0 + np.exp(z))
@njit
def hca_inf_N(Cai_M):  # Eq 31
    """
    Instantaneous Ca2+-dependent inactivation factor of the N-type channel.

        hca = 1 / (1 + (Cai / 1e-6)^4)

    Cai_M is in molar, and K_d = 1 uM. This equals K_d^4 / (K_d^4 + Cai^4),
    the standard decreasing Hill form; constants to be rechecked against
    Eq 31. Returns 0.0 if the denominator is not > 1e-12 (e.g. NaN input).
    """
    denom = 1.0 + (Cai_M / 1e-6)**4
    if denom > 1e-12:
        return 1.0 / denom
    return 0.0
@njit
def tau_m_N(Vm):  # Eq 32
    """
    N-type activation time constant (ms): 0.8 + 5.38 * exp(-2((Vm + 20) / 15)^2).

    Returns the 0.8 baseline directly when the exponent is below -700.
    Otherwise the result is floored at 1e-12.
    """
    arg = -2.0 * ((Vm + 20.0) / 15.0)**2
    if arg < -700:
        return 0.8
    return max(0.8 + 5.38 * np.exp(arg), 1e-12)
@njit
def state_vars_deriv_N(Vm, m, h):
    """(dm/dt, dh/dt) for the N-type gates; tau_h comes from the lookup table."""
    dm = (m_inf_N(Vm) - m) / tau_m_N(Vm)
    dh = (h_inf_N(Vm) - h) / interp_tau_h_N(Vm)
    return dm, dh
@njit
def calculate_ICa_N(Vm, m, h, Cai_M, Cao_M, T, Pmax_N = 2.8e-5, a_N = 0.7326):  # Eq 33
    """
    N-type Ca2+ current from the GHK flux equation.

    Inactivation is incomplete: a fraction a_N of the conductance inactivates
    with h and the rest (1 - a_N) does not. Ca-dependent inactivation hca is
    instantaneous. Concentrations are in molar.
    """
    h_eff = a_N * h + (1.0 - a_N)
    perm = Pmax_N * m * h_eff * hca_inf_N(Cai_M)
    return calculate_GHK(Vm, perm, C_in_M=Cai_M, C_out_M=Cao_M, z=2, T=T)

# --- P/Q-type Ca2+ Channel (CaV2.1) ---
# Kinetics from Mandge S2 Text, Fig D, Eqs 7-8
# Note: This model lacks explicit voltage-dependent inactivation (h) and CDI (hca) in provided funcs. Check S2 source.
@njit
def m_inf_PQ(Vm):  # Eq 7 (S2 Text)
    """Steady-state P/Q-type activation: Boltzmann, V1/2 = -5.1 mV, k = 3.1 mV."""
    z = (-5.1 - Vm) / 3.1
    return 1.0 / (1.0 + np.exp(z))
@njit
def tau_m_PQ(Vm):  # Eq 7 (S2 Text)
    """
    P/Q-type activation time constant (ms): 0.35 + 5.51 * exp(-2((Vm + 9.73) / 18.14)^2).

    Returns the 0.35 baseline directly when the exponent is below -700.
    Otherwise the result is floored at 1e-12.
    """
    arg = -2.0 * ((Vm + 9.73) / 18.14)**2
    if arg < -700:
        return 0.35
    return max(0.35 + 5.51 * np.exp(arg), 1e-12)
@njit
def state_vars_deriv_PQ(Vm, m):
    """dm/dt for the P/Q-type activation gate (this model has no h gate)."""
    return (m_inf_PQ(Vm) - m) / tau_m_PQ(Vm)
@njit
def calculate_ICa_PQ(Vm, m, Cai_M, Cao_M, T, P_max = 8e-6):  # Eq 8 (S2 Text - implied)
    """
    P/Q-type Ca2+ current from the GHK flux equation.

    Permeability is P_max * m; no voltage or Ca-dependent inactivation is
    modelled here (check the S2 source). Concentrations are in molar.
    """
    perm = P_max * m
    return calculate_GHK(Vm, perm, C_in_M=Cai_M, C_out_M=Cao_M, z=2, T=T)

# --- R-type Ca2+ Channel (CaV2.3) ---
# Kinetics from Mandge S2 Text, Fig E, Eqs 9-11
@njit
def interp_tau_m_R(Vm):
    """R-type activation tau (ms), linearly interpolated from tables; floored at 1e-12."""
    return max(np.interp(Vm, R_tau_m_Vm, R_tau_m_ms), 1e-12)
@njit
def interp_tau_h_fast_R(Vm):
    """R-type fast-inactivation tau (ms), interpolated from tables; floored at 1e-12."""
    return max(np.interp(Vm, R_tau_hf_Vm, R_tau_hf_ms), 1e-12)
@njit
def interp_tau_h_slow_R(Vm):
    """R-type slow-inactivation tau (ms), interpolated from tables; floored at 1e-12."""
    return max(np.interp(Vm, R_tau_hs_Vm, R_tau_hs_ms), 1e-12)
@njit
def m_inf_R(Vm):  # Eq 9 (S2 Text)
    """Steady-state R-type activation: Boltzmann, V1/2 = -5 mV, k = 5 mV."""
    z = (-5.0 - Vm) / 5.0
    return 1.0 / (1.0 + np.exp(z))
@njit
def h_inf_R(Vm):  # Eq 9 (S2 Text)
    """Steady-state R-type inactivation (shared by h_fast/h_slow): V1/2 = -51 mV, k = 12 mV."""
    z = (Vm + 51.0) / 12.0
    return 1.0 / (1.0 + np.exp(z))
@njit
def state_vars_deriv_R(Vm, m, h_fast, h_slow):
    """
    Derivatives of the R-type gates at fixed Vm.

    Both inactivation gates relax toward the same h_inf with table-derived
    time constants.

    Returns:
        tuple: (dm/dt, dh_fast/dt, dh_slow/dt).
    """
    h_target = h_inf_R(Vm)
    dm = (m_inf_R(Vm) - m) / interp_tau_m_R(Vm)
    dhf = (h_target - h_fast) / interp_tau_h_fast_R(Vm)
    dhs = (h_target - h_slow) / interp_tau_h_slow_R(Vm)
    return dm, dhf, dhs
@njit
def calculate_ICa_R(Vm, m, h_fast, h_slow, Cai_M, Cao_M, T, Pmax_R = 1e-8):  # Eq 11 (S2 Text)
    """
    R-type Ca2+ current from the GHK flux equation.

    Inactivation is a 40 % / 60 % mix of the fast and slow gates, and no CDI
    is modelled. Concentrations are in molar.
    """
    h_mix = 0.4 * h_fast + 0.6 * h_slow
    perm = Pmax_R * m * h_mix
    return calculate_GHK(Vm, perm, C_in_M=Cai_M, C_out_M=Cao_M, z=2, T=T)

# --- T-type Ca2+ Channel (CaV3) ---
# Kinetics from Mandge S2 Text, Fig F, Eqs 12-13
@njit
def interp_tau_m_T(Vm):
    """T-type activation tau (ms), linearly interpolated from tables; floored at 1e-12."""
    return max(np.interp(Vm, T_tau_m_Vm, T_tau_m_ms), 1e-12)
@njit
def interp_tau_h_T(Vm):
    """T-type inactivation tau (ms), linearly interpolated from tables; floored at 1e-12."""
    return max(np.interp(Vm, T_tau_h_Vm, T_tau_h_ms), 1e-12)
@njit
def m_inf_T(Vm):  # Eq 12 (S2 Text)
    """Steady-state T-type activation: Boltzmann, V1/2 = -55.29 mV, k = 6.38 mV."""
    z = (-55.29 - Vm) / 6.38
    return 1.0 / (1.0 + np.exp(z))
@njit
def h_inf_T(Vm):  # Eq 12 (S2 Text)
    """Steady-state T-type inactivation: V1/2 = -76.59 mV, k = 4.46 mV."""
    z = (Vm + 76.59) / 4.46
    return 1.0 / (1.0 + np.exp(z))
@njit
def state_vars_deriv_T(Vm, m, h):
    """(dm/dt, dh/dt) for the T-type gates, with table-derived time constants."""
    dm = (m_inf_T(Vm) - m) / interp_tau_m_T(Vm)
    dh = (h_inf_T(Vm) - h) / interp_tau_h_T(Vm)
    return dm, dh
@njit
def calculate_ICa_T(Vm, m, h, Cai_M, Cao_M, T, Pmax_T = 1e-8):  # Eq 13 (S2 Text)
    """
    T-type Ca2+ current from the GHK flux equation.

    Permeability is Pmax_T * m * h, and no CDI is modelled. Concentrations
    are in molar.
    """
    perm = Pmax_T * m * h
    return calculate_GHK(Vm, perm, C_in_M=Cai_M, C_out_M=Cao_M, z=2, T=T)

# --- Ca2+ Channels Combined Functions ---
@njit
def get_steady_state_Ca_vars(Vm):
    """
    Steady-state values of the Ca2+ channel voltage gates at Vm (CDI excluded).

    Returns:
        np.ndarray: [m_L, h_L, m_N, h_N, m_PQ, m_R, h_fast_R, h_slow_R, m_T, h_T].
        The R-type fast and slow inactivation gates share h_inf.
    """
    h_r = h_inf_R(Vm)
    return np.array([
        m_inf_L(Vm), h_inf_L(Vm),
        m_inf_N(Vm), h_inf_N(Vm),
        m_inf_PQ(Vm),
        m_inf_R(Vm), h_r, h_r,
        m_inf_T(Vm), h_inf_T(Vm),
    ])

@njit
def update_Ca_vars(Vm, Ca_i, ch_vars, dt):
    """
    One forward-Euler step for the ten Ca2+ channel gating variables.

    Args:
        Vm (float): Membrane potential (mV), fixed over the step.
        Ca_i (float): Intracellular Ca2+. Accepted for signature symmetry but
            unused, because CDI factors are evaluated instantaneously inside
            the current functions.
        ch_vars (array): [m_L, h_L, m_N, h_N, m_PQ, m_R, hf_R, hs_R, m_T, h_T].
        dt (float): Time step (ms, matching the tau units).

    Returns:
        np.ndarray: Updated state in the same order.
    """
    mL, hL, mN, hN, mPQ, mR, hfR, hsR, mT, hT = ch_vars

    dmL, dhL = state_vars_deriv_L(Vm, mL, hL)
    dmN, dhN = state_vars_deriv_N(Vm, mN, hN)
    dmPQ = state_vars_deriv_PQ(Vm, mPQ)
    dmR, dhfR, dhsR = state_vars_deriv_R(Vm, mR, hfR, hsR)
    dmT, dhT = state_vars_deriv_T(Vm, mT, hT)

    return np.array([
        mL + dt * dmL, hL + dt * dhL,
        mN + dt * dmN, hN + dt * dhN,
        mPQ + dt * dmPQ,
        mR + dt * dmR, hfR + dt * dhfR, hsR + dt * dhsR,
        mT + dt * dmT, hT + dt * dhT,
    ])
@njit
def update_Ca_vars_rk4(Vm, Ca_i, ch_vars, dt):
    """
    One classical RK4 step for the ten Ca2+ channel gating variables at fixed Vm.

    Args:
        Vm (float): Membrane potential (mV).
        Ca_i (float): Intracellular Ca2+. Unused here, because CDI is
            instantaneous and lives in the current functions.
        ch_vars (np.ndarray): [m_L, h_L, m_N, h_N, m_PQ, m_R, hf_R, hs_R, m_T, h_T].
        dt (float): Time step (ms, matching the tau units).

    Returns:
        np.ndarray: Updated state in the same order.
    """
    def rhs(y):
        # dy/dt of all Ca2+ channel gates for state vector y.
        mL, hL, mN, hN, mPQ, mR, hfR, hsR, mT, hT = y
        dmL, dhL = state_vars_deriv_L(Vm, mL, hL)
        dmN, dhN = state_vars_deriv_N(Vm, mN, hN)
        dmPQ = state_vars_deriv_PQ(Vm, mPQ)
        dmR, dhfR, dhsR = state_vars_deriv_R(Vm, mR, hfR, hsR)
        dmT, dhT = state_vars_deriv_T(Vm, mT, hT)
        return np.array([dmL, dhL, dmN, dhN, dmPQ, dmR, dhfR, dhsR, dmT, dhT])

    half_dt = 0.5 * dt
    s1 = rhs(ch_vars)
    s2 = rhs(ch_vars + half_dt * s1)
    s3 = rhs(ch_vars + half_dt * s2)
    s4 = rhs(ch_vars + dt * s3)

    # RK4 weights (1, 2, 2, 1) / 6
    return ch_vars + (dt / 6.0) * (s1 + 2.0 * s2 + 2.0 * s3 + s4)

@njit
def Ca_current(Vm, Ca_o, Ca_i, ch_vars, T):
    """
    Total Ca2+ current: L + N + P/Q + R + T types, in mA/cm^2.

    Ca_o and Ca_i are taken in mM and converted to molar here, before being
    passed to the per-channel GHK current functions.

    Args:
        Vm (float): Membrane potential (mV).
        Ca_o (float): Extracellular Ca2+ (mM).
        Ca_i (float): Intracellular Ca2+ (mM).
        ch_vars (array): [m_L, h_L, m_N, h_N, m_PQ, m_R, hf_R, hs_R, m_T, h_T].
        T (float): Temperature, forwarded to calculate_GHK (units as that
            function expects).
    """
    to_molar = 1e-3
    ca_out = Ca_o * to_molar
    ca_in = Ca_i * to_molar
    mL, hL, mN, hN, mPQ, mR, hfR, hsR, mT, hT = ch_vars

    i_l = calculate_ICa_L(Vm, mL, hL, ca_in, ca_out, T)
    i_n = calculate_ICa_N(Vm, mN, hN, ca_in, ca_out, T)
    i_pq = calculate_ICa_PQ(Vm, mPQ, ca_in, ca_out, T)
    i_r = calculate_ICa_R(Vm, mR, hfR, hsR, ca_in, ca_out, T)
    i_t = calculate_ICa_T(Vm, mT, hT, ca_in, ca_out, T)
    return i_l + i_n + i_pq + i_r + i_t

# =============================================================================
# --- Hyperpolarization-Activated Cation Channel (HCN / Ih) ---
# =============================================================================
# Kinetics based on Mandge S2 Text Eq 14 and related text/figures for time constants

@njit
def calculate_HCN_inf(Vm):
    """
    Steady-state HCN activation (S2 Eq 14).

    This is a Boltzmann curve that increases with hyperpolarization:
    V1/2 = -87.2 mV, k = 9.7 mV.
    """
    z = (Vm + 87.2) / 9.7
    return 1 / (1 + np.exp(z))
@njit
def get_HCN_taus(Vm):
    """
    Fast and slow HCN activation time constants (ms), defined piecewise in Vm.

    For Vm < -70 mV:  tau_f = 250 + 12*e,  tau_s = 2500 + 100*e,  e = exp((Vm + 240)/50)
    Otherwise:        tau_f = 140 + 50*e,  tau_s = 300 + 542*e,   e = exp((Vm + 25)/-20)

    Each value is floored at 1e-9.

    Returns:
        tuple: (tau_fast, tau_slow).
    """
    if Vm < -70:
        e = np.exp((Vm + 240) / 50)
        fast = 250 + 12 * e
        slow = 2500 + 100 * e
    else:
        e = np.exp((Vm + 25) / -20)
        fast = 140 + 50 * e
        slow = 300 + 542 * e
    return max(1e-9, fast), max(1e-9, slow)
@njit
def get_steady_state_HCN_vars(Vm):
    """Steady-state HCN gates [mf, ms]; both components share m_inf(Vm)."""
    m_ss = calculate_HCN_inf(Vm)
    return np.array([m_ss, m_ss])
@njit
def update_HCN_vars(Vm, ch_vars, dt):
    """
    One forward-Euler step for the HCN fast/slow activation gates [mf, ms].

    Both gates relax toward the same m_inf(Vm), each with its own time
    constant (ms).
    """
    mf, ms = ch_vars
    target = calculate_HCN_inf(Vm)
    tau_f, tau_s = get_HCN_taus(Vm)
    return np.array([
        mf + (target - mf) / tau_f * dt,
        ms + (target - ms) / tau_s * dt,
    ])
@njit
def update_HCN_vars_rk4(Vm, ch_vars, dt):
    """
    One classical RK4 step for the HCN fast/slow activation gates [mf, ms].

    m_inf and both time constants depend only on Vm, so they are computed
    once per step and reused by every RK4 stage.
    """
    target = calculate_HCN_inf(Vm)
    tau_f, tau_s = get_HCN_taus(Vm)

    def rhs(y):
        # Linear relaxation of each gate toward the shared target.
        y_f, y_s = y
        return np.array([(target - y_f) / tau_f, (target - y_s) / tau_s])

    half_dt = 0.5 * dt
    s1 = rhs(ch_vars)
    s2 = rhs(ch_vars + half_dt * s1)
    s3 = rhs(ch_vars + half_dt * s2)
    s4 = rhs(ch_vars + dt * s3)

    # RK4 weights (1, 2, 2, 1) / 6
    return ch_vars + (dt / 6.0) * (s1 + 2.0 * s2 + 2.0 * s3 + s4)
@njit
def HCN_current(Vm, ch_vars, Eh = -30, gf = 1.24e-4, gs = 6.7615e-5):  # alternative gf = 1.352e-5
    """
    HCN (Ih) current from the fast and slow gates.

    I_h = (gf * mf + gs * ms) * (Vm - Eh)

    Args:
        Vm (float): Membrane potential (mV).
        ch_vars (array-like): Gate values [mf, ms].
        Eh (float): Reversal potential of Ih (mV).
        gf, gs (float): Maximal conductances of the fast / slow components.

    Returns:
        float: Current density. NOTE(review): the result is mA/cm^2 only if
        gf/gs are in S/cm^2 (as in calculate_trpm8_current). An earlier
        comment said mS/cm^2, which would give uA/cm^2. Confirm the unit
        of the defaults.
    """
    mf, ms = ch_vars
    conductance = gf * mf + gs * ms
    return conductance * (Vm - Eh)

# =============================================================================
# --- Chloride (Cl-) Channels ---
# =============================================================================

# --- Calcium-Activated Chloride Channel (CaCC / Ano) ---
# Kinetics based on custom implementation (check source for equations)
@njit
def n_inf_CaCC(Vm, Ca_i):
    """
    Steady-state activation of the Ca2+-activated Cl- channel (CaCC).

    Hill function of Ca_i whose coefficient and EC50 depend on Vm:
      * Vm < -100 mV : Hill = 1.012, EC50 = 5.798
      * otherwise    : Hill = 2.086 - 0.3126 * exp(-Vm / 81.02)
                       EC50 = 0.468 + 0.39175 * exp(-Vm / 38.307)
    The two branches meet (approximately) at -100 mV.

    Args:
        Vm (float): Membrane potential (mV).
        Ca_i (float): Intracellular Ca2+ (assumed mM, matching the EC50 scale).

    Returns:
        float: n_inf = 1 / (1 + (EC50 / Ca_i)**Hill).
    """
    if Vm < -100:
        hill = 1.012
        ec50 = 5.798
    else:
        arg_hill = -Vm / 81.02
        arg_ec50 = -Vm / 38.307
        # Exponent floor of -700 mirrors a guard against denormal exp results.
        hill = -0.3126 * (np.exp(arg_hill) if arg_hill > -700 else 0.0) + 2.086
        ec50 = 0.39175 * (np.exp(arg_ec50) if arg_ec50 > -700 else 0.0) + 0.468

    # Floors keep the ratio finite and positive.
    ec50 = max(1e-12, ec50)
    ca = max(1e-12, Ca_i)
    ratio = ec50 / ca
    denom = 1 + ratio**hill

    # denom >= 1 for finite inputs; only overflow (inf) or NaN needs a fallback.
    if np.isinf(denom) or np.isnan(denom):
        return 1.0 if ratio < 1e-6 else 0.0
    return 1 / denom
@njit
def state_vars_deriv_CaCC(Vm, n, Ca_i, tau=1.0):
    """
    Time derivative of the CaCC gating variable n.

        dn/dt = (n_inf(Vm, Ca_i) - n) / tau

    Args:
        Vm (float): Membrane potential (mV).
        n (float): Current gate value.
        Ca_i (float): Intracellular Ca2+ (assumed mM, as n_inf_CaCC expects).
        tau (float, optional): Gating time constant (ms). Defaults to 1.0 ms,
            the value previously hard-coded here. It is exposed so it can be
            tuned like the Kir4.1 tau_m / tau_h.

    Returns:
        float: dn/dt (1/ms).
    """
    return (n_inf_CaCC(Vm, Ca_i) - n) / tau
@njit
def get_steady_state_Cl_vars(Vm, Ca_i):
    """
    Steady-state Cl- channel gating vector (CaCC gate only).

    Args:
        Vm (float): Membrane potential (mV).
        Ca_i (float): Intracellular Ca2+ (assumed mM).

    Returns:
        np.ndarray: One-element array [n_CaCC].
    """
    return np.full(1, n_inf_CaCC(Vm, Ca_i))
@njit
def update_Cl_vars(Vm, ch_vars, Ca_i, dt):
    """
    Advance the CaCC gate n by one forward-Euler step.

    Args:
        Vm (float): Membrane potential (mV).
        ch_vars (array-like): [n]; only element 0 is read.
        Ca_i (float): Intracellular Ca2+ (assumed mM).
        dt (float): Time step (ms).

    Returns:
        np.ndarray: One-element array with the updated n.
    """
    n_old = ch_vars[0]
    slope = state_vars_deriv_CaCC(Vm, n_old, Ca_i)
    return np.array([n_old + slope * dt])

@njit
def update_Cl_vars_rk4(Vm, ch_vars, Ca_i, dt):
    """
    Advance the CaCC gate n by one classical RK4 step.

    Only ch_vars[0] drives the derivative. The stages are therefore computed on
    that scalar and the combined increment is added to ch_vars as an array.

    Args:
        Vm (float): Membrane potential (mV).
        ch_vars (np.ndarray): [n].
        Ca_i (float): Intracellular Ca2+ (assumed mM).
        dt (float): Time step (ms).

    Returns:
        np.ndarray: Updated [n].
    """
    n0 = ch_vars[0]
    half_dt = 0.5 * dt

    s1 = state_vars_deriv_CaCC(Vm, n0, Ca_i)
    s2 = state_vars_deriv_CaCC(Vm, n0 + half_dt * s1, Ca_i)
    s3 = state_vars_deriv_CaCC(Vm, n0 + half_dt * s2, Ca_i)
    s4 = state_vars_deriv_CaCC(Vm, n0 + dt * s3, Ca_i)

    increment = (dt / 6.0) * (s1 + 2.0 * s2 + 2.0 * s3 + s4)
    return ch_vars + np.array([increment])
@njit
def calculate_ICaCC(Vm, n, Ca_i, E_Cl, g = 1e-6):
    """
    Ca2+-activated Cl- current: I = g * n * (Vm - E_Cl).

    Ca_i is accepted for interface symmetry but not used here; its effect
    enters through the gate n.

    Args:
        Vm (float): Membrane potential (mV).
        n (float): CaCC gate value.
        Ca_i (float): Unused.
        E_Cl (float): Cl- reversal potential (mV).
        g (float): Maximal conductance. NOTE(review): mA/cm^2 output requires
            S/cm^2; confirm the unit of the default.

    Returns:
        float: Current density.
    """
    gated = g * n
    return gated * (Vm - E_Cl)
@njit
def Cl_current(Vm, Ca_i, ch_vars, E_Cl = -32.7):
    """
    Total Cl- current. Only the CaCC component is modelled.

    Args:
        Vm (float): Membrane potential (mV).
        Ca_i (float): Intracellular Ca2+ (assumed mM).
        ch_vars (array-like): [n_CaCC].
        E_Cl (float): Cl- reversal potential (mV).

    Returns:
        float: Cl- current density (same units as calculate_ICaCC).
    """
    return calculate_ICaCC(Vm, ch_vars[0], Ca_i, E_Cl)

# =============================================================================
# --- Passive Leak Current ---
# =============================================================================

@njit
def calculate_Ipas(Vm, Epas = -43, gpas =0.0001):
    """
    Passive (ohmic) leak current: I = gpas * (Vm - Epas).

    Args:
        Vm (float): Membrane potential (mV).
        Epas (float): Leak reversal potential (mV).
        gpas (float): Leak conductance.

    Returns:
        float: Leak current density (treated as mA/cm^2 by callers).
    """
    driving_force = Vm - Epas
    return gpas * driving_force

# =============================================================================
# --- Ion Pumps and Exchangers ---
# =============================================================================

@njit
def KNa_pump(Na_i, C, area, imaxh = 1.62, imaxl = 0.99, a = 0.9475):
    """
    Na+/K+-ATPase current, split into its Na+ and K+ components.

    The pump rate is the sum of two Na_i-dependent saturating terms
    (half-saturation 6.7 mM and 67.6 mM, cubic Hill), expressed in pA/pF.

    Unit conversion: pA/pF == A/F, so multiplying by C [F] gives A. Then
    * 1e3 converts to mA and / area [cm^2] gives mA/cm^2.

    Args:
        Na_i (float): Intracellular Na+ (mM); floored at 1e-12.
        C (float): Membrane capacitance (F).
        area (float): Membrane area (cm^2).
        imaxh, imaxl (float): Maximal rates of the two terms (pA/pF).
        a (float): Dimensionless scale applied to the first term.

    Returns:
        tuple: (I_Na, I_K) in mA/cm^2. Positive = outward. The 3 Na+ out /
        2 K+ in stoichiometry gives (+3, -2) times the pump rate.
    """
    density_per_specific = C * 1e3 / area
    na = max(1e-12, Na_i)

    term_h = a * (imaxh / (1 + (6.7 / na)**3))
    term_l = imaxl / (1 + (67.6 / na)**3)
    pump_density = density_per_specific * (term_h + term_l)

    return 3 * pump_density, -2 * pump_density

@njit
def NCX(Vm, Na_i, Na_o, Ca_i, Ca_o, nernst, K_Na = 87.5, K_Ca = 1.38, Imax = 1.1e-5):
    """
    Na+/Ca2+ exchanger current (S2 Text, Eqs. 31-33), split by ion.

        drive = Ca_o * exp(0.35 Vm/nernst) * Na_i^3
                - exp(-0.65 Vm/nernst) * Ca_i * Na_o^3
        I     = Imax * drive / ((K_Na^3 + Na_o^3) (K_Ca + Ca_o)
                               (1 + 0.1 exp(-0.65 Vm/nernst)))

    Sign convention: drive grows with Ca_o and Na_i, so I > 0 corresponds to
    Na+ leaving and Ca2+ entering. The returned values are *charge* current
    densities. 3 Na+ (charge +3) cross one way while 1 Ca2+ (charge +2)
    crosses the other, hence (+3 I, -2 I). Converting each with
    flux = I / (z F) yields the expected 3:1 ion stoichiometry.

    NOTE(review): the (1 + 0.1 * exp(-0.65 Vm/nernst)) factor and the unit of
    Imax should be checked against S2 Text Eq. 33.

    Args:
        Vm (float): Membrane potential (mV).
        Na_i, Na_o, Ca_i, Ca_o (float): Concentrations (mM); floored at 1e-12.
        nernst (float): RT/F (mV).
        K_Na, K_Ca (float): Saturation constants (mM).
        Imax (float): Maximal exchange scale.

    Returns:
        tuple: (I_Na_NCX, I_Ca_NCX), positive = outward.
    """
    na_in = max(1e-12, Na_i)
    na_out = max(1e-12, Na_o)
    ca_in = max(1e-12, Ca_i)
    ca_out = max(1e-12, Ca_o)

    exp_fwd = np.exp(0.35 * Vm / nernst)
    exp_bwd = np.exp(-0.65 * Vm / nernst)

    drive = ca_out * exp_fwd * na_in**3 - exp_bwd * ca_in * na_out**3
    saturation = (K_Na**3 + na_out**3) * (K_Ca + ca_out) * (1 + 0.1 * exp_bwd)

    if abs(saturation) > 1e-12:
        exchange = Imax * drive / saturation
    else:
        exchange = 0.0

    return 3 * exchange, -2 * exchange

@njit
def pmca_explicit_euler_step(
    pump_bound_current, # Current density of Ca-bound pump (mol/cm^2)
    Cai,                # Current intracellular Ca concentration (mM)
    Cao,                # Current extracellular Ca concentration (mM)
    dt,                 # Time step (ms) - MUST BE SMALL ENOUGH FOR STABILITY
    # --- Parameters ---
    K1 = 3.74e7,        # /mM-s
    K2 = 2.5e5,         # /s
    K3 = 500.0,         # /s
    K4 = 5.0,           # /mM-s
    pump0 = 4.232e-13,  # Total pump density (mol/cm^2)
    F = 96485.33,       # Faraday's constant (C/mol)
    z = 2               # Valence of Ca2+
):
    """
    One forward-Euler step of the PMCA bound-pump state, plus the pump
    current evaluated from the state *before* the update.

        d(bound)/dt = (K1*Cai + K4*Cao) * free - (K2 + K3) * bound
        free        = pump0 - bound
        J_PMCA      = K3 * bound - K4 * free * Cao          (mol/cm^2/s)

    Forward Euler is only stable when dt_s * (K1*Cai + K4*Cao + K2 + K3) < 2.
    With the default K2 = 2.5e5 /s this means dt below roughly 0.008 ms.
    pmca_implicit_euler_step has no such limit.

    Args:
        pump_bound_current (float): Bound pump density at step start (mol/cm^2).
        Cai (float): Intracellular Ca2+ at step start (mM).
        Cao (float): Extracellular Ca2+ at step start (mM).
        dt (float): Time step (ms). It is converted to seconds internally
            because the rate constants are per second.
        K1, K2, K3, K4 (float): Rate constants.
        pump0 (float): Total pump density (mol/cm^2).
        F (float): Faraday's constant (C/mol).
        z (int): Ion valence.

    Returns:
        tuple:
            - pump_bound_new (float): Updated bound density, clamped to [0, pump0].
            - ipmca_mA_per_cm2 (float): Current density (mA/cm^2) from the
              pre-update state. Positive = outward charge movement.
    """
    dt_s = dt * 1e-3  # ms -> s

    # Clamp the incoming state to the physical range so "free" is never negative.
    bound = min(pump0, max(0.0, pump_bound_current))
    free = pump0 - bound

    # Instantaneous flux (mol/cm^2/s) -> current (A/cm^2) -> mA/cm^2.
    JPMCA = K3 * bound - K4 * free * Cao
    ipmca_mA_per_cm2 = z * F * JPMCA * 1000.0

    # Forward-Euler update of the bound state.
    dbound_dt = (K1 * Cai + K4 * Cao) * free - (K2 + K3) * bound
    pump_bound_new = bound + dbound_dt * dt_s

    # Guard against overshoot when dt is near the stability limit.
    pump_bound_new = max(0.0, min(pump0, pump_bound_new))

    return pump_bound_new, ipmca_mA_per_cm2


@njit
def pmca_implicit_euler_step(
    pump_bound_current, # Current density of Ca-bound pump (mol/cm^2)
    Cai,                # Current intracellular Ca concentration (mM)
    Cao,                # Current extracellular Ca concentration (mM)
    dt,                 # Time step (ms)
    # --- Parameters ---
    K1 = 3.74e7,        # /mM-s
    K2 = 2.5e5,         # /s
    K3 = 500.0,         # /s
    K4 = 5.0,           # /mM-s
    pump0 = 4.232e-13,  # Total pump density (mol/cm^2)
    F = 96485.33,       # Faraday's constant (C/mol)
    z = 2               # Valence of Ca2+
):
    """
    One backward-Euler step of the PMCA bound-pump state.

    The state obeys the linear ODE y' = A - B*y, with
        A = (K1*Cai + K4*Cao) * pump0
        B = (K1*Cai + K4*Cao) + K2 + K3
    whose implicit update y_new = (y + dt*A) / (1 + dt*B) is stable for any dt.

    Args:
        pump_bound_current (float): Bound pump density (mol/cm^2).
        Cai, Cao (float): Intra/extracellular Ca2+ (mM).
        dt (float): Time step (ms); converted to seconds internally.
        K1..K4, pump0, F, z: Model parameters (see inline units).

    Returns:
        tuple: (pump_bound_new clamped to [0, pump0],
                pump current density in mA/cm^2 from the pre-update state).
    """
    seconds = dt * 1e-3

    unbound = pump0 - pump_bound_current
    flux = K3 * pump_bound_current - K4 * unbound * Cao
    current_mA = z * F * flux * 1000.0

    binding_rate = K1 * Cai + K4 * Cao
    source = binding_rate * pump0
    sink = binding_rate + K2 + K3
    bound_next = (pump_bound_current + seconds * source) / (1.0 + seconds * sink)

    return max(0.0, min(pump0, bound_next)), current_mA

@njit
def pmca_rk4_step(
    pump_bound_current, # Current density of Ca-bound pump (mol/cm^2)
    Cai,                # Current intracellular Ca concentration (mM)
    Cao,                # Current extracellular Ca concentration (mM)
    dt,                 # Time step (ms)
    # --- Parameters ---
    K1 = 3.74e7,        # /mM-s
    K2 = 2.5e5,         # /s
    K3 = 500.0,         # /s
    K4 = 5.0,           # /mM-s
    pump0 = 4.232e-13,  # Total pump density (mol/cm^2)
    F = 96485.33,       # Faraday's constant (C/mol)
    z = 2               # Valence of Ca2+
):
    """
    One classical RK4 step of the PMCA bound-pump state, plus the pump
    current evaluated from the pre-update state.

        d(bound)/dt = A - B * bound
        A = (K1*Cai + K4*Cao) * pump0
        B = (K1*Cai + K4*Cao) + K2 + K3

    Cai and Cao are frozen over the step, so A and B are computed once rather
    than in every stage.

    Stability: for this linear decay RK4 is stable only while
    dt_s * B <~ 2.79. With the default K2 = 2.5e5 /s, that is dt <~ 0.011 ms.
    Larger steps are silently caught by the final clamp; use
    pmca_implicit_euler_step for large dt.

    Args:
        pump_bound_current (float): Bound pump density (mol/cm^2).
        Cai, Cao (float): Intra/extracellular Ca2+ (mM).
        dt (float): Time step (ms); converted to seconds internally.
        K1..K4, pump0, F, z: Model parameters (see inline units).

    Returns:
        tuple: (pump_bound_new clamped to [0, pump0],
                pump current density in mA/cm^2, positive = outward).
    """
    dt_s = dt * 1e-3  # ms -> s

    # Current from the state at the start of the step.
    pump_free_current = pump0 - pump_bound_current
    JPMCA = K3 * pump_bound_current - K4 * pump_free_current * Cao
    ipmca_mA_per_cm2 = z * F * JPMCA * 1000.0

    # Loop-invariant coefficients of the linear ODE.
    binding_rate = K1 * Cai + K4 * Cao
    A = binding_rate * pump0
    B = binding_rate + K2 + K3

    k1 = dt_s * (A - B * pump_bound_current)
    k2 = dt_s * (A - B * (pump_bound_current + 0.5 * k1))
    k3 = dt_s * (A - B * (pump_bound_current + 0.5 * k2))
    k4 = dt_s * (A - B * (pump_bound_current + k3))

    pump_bound_new = pump_bound_current + (k1 + 2*k2 + 2*k3 + k4) / 6.0

    # Safeguard: keep the state physical.
    pump_bound_new = max(0.0, min(pump0, pump_bound_new))

    return pump_bound_new, ipmca_mA_per_cm2

@njit
def calculate_concentric_shell_volumes(outer_radius,ER_part = 0.12, Mt_part = 0.08, num_shells = 12):
    """
    Split a sphere into equal-width concentric shells and partition each
    shell's volume into cytosol, ER and mitochondria.

    Args:
        outer_radius (float): Radius of the outermost shell boundary.
        ER_part (float): Fraction of every shell occupied by ER.
        Mt_part (float): Fraction of every shell occupied by mitochondria.
        num_shells (int): Number of shells (default 12).

    Returns:
        tuple of np.ndarray: (cytosolic, ER, mitochondrial) volumes. Each has
        length num_shells, ordered innermost -> outermost, and is multiplied
        by 1e-3. NOTE(review): that factor is presumably cm^3 -> L when
        outer_radius is in cm; confirm against callers.
    """
    shell_width = outer_radius / num_shells

    # Shell boundaries 0, w, 2w, ..., num_shells*w.
    boundaries = np.arange(num_shells + 1) * shell_width
    inner = boundaries[:-1]
    outer = boundaries[1:]
    volumes = (4/3) * np.pi * (outer**3 - inner**3)

    volumes_ER = volumes * ER_part
    volumes_Mt = volumes * Mt_part
    volumes_cyt = volumes - (volumes_ER + volumes_Mt)
    return volumes_cyt * 1e-3, volumes_ER * 1e-3, volumes_Mt * 1e-3
@njit
def shell_flux(Ca_cyt_mat, Ca_ER_mat, Ca_Mt_mat,Na_i,
               dt,
               V_shell, V_ER, V_Mt,IP3R_vars,
               K_B_ER = 0.5, B_tot_ER = 10,K_B_Mt = 1e-5,B_tot_Mt= 0.065, buffer_cyt = 0.00269541779,
               IP3_0 = 1.6e-4,):
    """
    Forward-Euler update of Ca2+ in every shell's cytosol, ER and mitochondria.

    For each shell k:
      * shell_Ca_current gives net uptake into the ER and into mitochondria.
      * Cytosolic Ca falls by the total uptake, scaled by the precomputed
        buffer factor buffer_cyt (1 / (1 + beta + beta_dye)).
      * ER and mitochondrial Ca rise by their uptake, scaled by a
        rapid-buffer factor 1 / (1 + K_B*B_tot / (K_B + Ca)^2).
      * The IP3R gate h is advanced using the *updated* cytosolic Ca.

    IP3 is held constant at IP3_0; there is no production mechanism.

    Arrays Ca_cyt_mat, Ca_ER_mat, Ca_Mt_mat and IP3R_vars are modified in
    place and also returned.

    NOTE(review): uptake terms are divided by shell volumes, so they are
    presumably amount/ms rather than mM/ms. Confirm units with
    calculate_concentric_shell_volumes.

    Returns:
        tuple: (Ca_cyt_mat, Ca_ER_mat, Ca_Mt_mat, IP3R_vars).
    """
    for k in range(len(Ca_cyt_mat)):
        ca_c = Ca_cyt_mat[k]
        ca_er = Ca_ER_mat[k]
        ca_mt = Ca_Mt_mat[k]

        j_er, j_mt = shell_Ca_current(ca_c, ca_er, ca_mt, IP3_0, IP3R_vars[k], Na_i)

        er_beta = (K_B_ER * B_tot_ER) / (K_B_ER + ca_er)**2
        mt_beta = (K_B_Mt * B_tot_Mt) / (K_B_Mt + ca_mt)**2

        Ca_cyt_mat[k] = ca_c + buffer_cyt*(-((j_er + j_mt)*dt) / V_shell[k])
        Ca_ER_mat[k] = ca_er + (1/(1+er_beta)) * (j_er*dt) / V_ER[k]
        Ca_Mt_mat[k] = ca_mt + (1/(1+mt_beta)) * (j_mt*dt) / V_Mt[k]

        IP3R_vars[k] = update_IP3R_h(Ca_cyt_mat[k], IP3R_vars[k], dt)

    return Ca_cyt_mat, Ca_ER_mat, Ca_Mt_mat, IP3R_vars

@njit
def shell_Ca_current(Ca_cyt, Ca_ER, Ca_Mt,IP3R,IP3R_vars,Na_i):
    """
    Net Ca2+ uptake into the ER and into mitochondria for one shell.

    ER:   SERCA uptake minus RyR release, passive leak and IP3R release.
    Mito: MCU uptake minus mitochondrial Na+/Ca2+ exchange.

    Args:
        Ca_cyt, Ca_ER, Ca_Mt (float): Ca2+ in each compartment (mM).
        IP3R (float): IP3 concentration passed to IP3R_current.
        IP3R_vars (float): IP3R inactivation gate h.
        Na_i (float): Cytosolic Na+ (mM).

    Returns:
        tuple: (net ER uptake, net mitochondrial uptake).
    """
    j_serca = SERCA_current(Ca_cyt)
    j_ryr = RYR_current(Ca_cyt, Ca_ER)
    j_leak = ER_leak(Ca_cyt, Ca_ER)
    j_ip3r = IP3R_current(IP3R, Ca_cyt, Ca_ER, h = IP3R_vars)
    net_er_uptake = j_serca - j_ryr - j_leak - j_ip3r

    net_mt_uptake = MCU_current(Ca_cyt) - MNCX_current(Na_i, Ca_Mt)
    return net_er_uptake, net_mt_uptake
@njit
def SERCA_current(Ca_cyt, Vserca = 3.75e-6, Kpsr = 27e-5):
    """
    SERCA pump uptake: Vserca * Ca^2 / (Ca^2 + Kpsr^2) (mM/ms per the
    original annotation).
    """
    ca_sq = Ca_cyt**2
    return Vserca * (ca_sq / (ca_sq + Kpsr**2))

@njit
def RYR_current(Ca_cyt, Ca_ER, VCICR = 5e-7, KCICR = 1.98e-3,KT = 0.0006):
    """
    Ryanodine-receptor (CICR) release from ER to cytosol.

    Zero while Ca_cyt <= KT. Above the threshold:
        J = (Ca_ER - Ca_cyt) * VCICR * Ca_cyt / (Ca_cyt + KCICR)

    Note that J jumps from 0 to a finite value at Ca_cyt = KT (the threshold
    is a hard switch, not a smooth activation).

    Args:
        Ca_cyt, Ca_ER (float): Cytosolic / ER Ca2+ (mM).
        VCICR, KCICR, KT (float): Rate, half-activation and threshold.

    Returns:
        float: Release flux (always a float, also in the sub-threshold branch).
    """
    if Ca_cyt <= KT:
        return 0.0
    return (Ca_ER - Ca_cyt) * VCICR * (Ca_cyt/(Ca_cyt+KCICR))

@njit
def ER_leak(Ca_cyt, Ca_ER, L_ER =  7.7180688e-7):
    """
    Passive ER leak: L_ER * (1 - Ca_cyt / Ca_ER). Ca_ER must be non-zero.
    """
    gradient_term = 1 - (Ca_cyt / Ca_ER)
    return L_ER * gradient_term

@njit
def IP3R_current(IP3, Ca_cyt,Ca_ER,h,Jmax = 3.5e-6, KIP3 = 8e-4, Kactip3 = 3e-4):
    """
    IP3-receptor release from ER:
        Jmax * (1 - Ca_cyt/Ca_ER) * (IP3/(IP3+KIP3) * Ca_cyt/(Ca_cyt+Kactip3) * h)^3
    Returned in mM/ms (per the original annotation).
    """
    ip3_term = IP3 / (IP3 + KIP3)
    ca_term = Ca_cyt / (Ca_cyt + Kactip3)
    p_open = (ip3_term * ca_term * h)**3
    return Jmax * (1 - (Ca_cyt / Ca_ER)) * p_open
@njit
def update_IP3R_h(Ca_cyt, h,dt,Konip3 = 2.7, Kinhip3 = 2e-4):
    """
    Forward-Euler step of the IP3R inactivation gate h.

    dh/dt = Konip3 * (Kinhip3 - (Ca_cyt + Kinhip3) * h), which relaxes h
    toward Kinhip3 / (Ca_cyt + Kinhip3) (see IP3R_h_inf).
    """
    rate = Konip3*(Kinhip3 - (Ca_cyt + Kinhip3)*h)
    return h + rate*dt

@njit
def IP3R_h_inf(Ca_cyt, Kinhip3 = 2e-4):
    """Steady state of the IP3R inactivation gate: Kinhip3 / (Ca_cyt + Kinhip3)."""
    denom = Ca_cyt + Kinhip3
    return Kinhip3 / denom
@njit
def MCU_current(Ca_cyt, V_MCU = 1.4468e-6, K_MCU_pow = 3.97820387e-8):
    """
    Mitochondrial Ca2+ uniporter uptake (Hill coefficient 2.3).

    K_MCU_pow is K_MCU**2.3 precomputed (K_MCU = 6.06e-4) to avoid an extra
    power per call.
    """
    ca_hill = Ca_cyt**2.3
    return V_MCU * (ca_hill / (ca_hill + K_MCU_pow))
@njit
def MNCX_current(Na_i, Ca_Mt, V_MNCX = 6e-5, K_Na_pow = 512, K_Ca = 0.0035):
    """
    Mitochondrial Na+/Ca2+ exchanger efflux.

    K_Na_pow is K_Na**3 precomputed (K_Na = 8 mM -> 512).
    """
    na_cubed = Na_i**3
    na_term = na_cubed / (na_cubed + K_Na_pow)
    ca_term = Ca_Mt / (Ca_Mt + K_Ca)
    return V_MNCX * na_term * ca_term
@njit
def update_trpm8_m(m_current, Cai_mM, dt,
                   mmin = 0, mmax = 200, Kca = 0.0005, taum = 80000):
    """
    Forward-Euler step of the TRPM8 Ca2+-dependent shift variable m.

    m relaxes with time constant taum toward
        m_inf = mmin + (mmax - mmin) * Ca / (Kca + Ca),
    where negative Ca is treated as 0.
    """
    ca = 0 if Cai_mM < 0 else Cai_mM
    m_target = mmin + (mmax - mmin) * ca / (Kca + ca)
    return m_current + ((m_target - m_current) / taum) * dt

@njit
def update_trpm8_m_rk4(m_current, Cai_mM, dt,
                       mmin=0, mmax=200, Kca=0.0005, taum=80000):
    """
    Classical RK4 step of the TRPM8 shift variable m.

    Ca is frozen over the step (negative values treated as 0), so m_inf is
    computed once and every stage evaluates (m_inf - m) / taum.
    """
    ca = 0 if Cai_mM < 0 else Cai_mM
    m_target = mmin + (mmax - mmin) * ca / (Kca + ca)

    half_dt = 0.5 * dt
    s1 = (m_target - m_current) / taum
    s2 = (m_target - (m_current + half_dt * s1)) / taum
    s3 = (m_target - (m_current + half_dt * s2)) / taum
    s4 = (m_target - (m_current + dt * s3)) / taum

    return m_current + (dt / 6.0) * (s1 + 2.0 * s2 + 2.0 * s3 + s4)
@njit
def calculate_trpm8_current(Vm, m, Cai_mM, T,
                            gbar = 1e-7, dE = 9e3, C = 67, z = 0.65, em8 = 0 , p_ca = 0.01,F = 96485,R = 8.314):
    """
    TRPM8 current, split into a Ca2+ part and the remaining part.

        V_half = 1000 * (C*R*T_C - dE) / (z*F) + m        (mV, T_C in deg C)
        a_m    = 1 / (1 + exp(-z*F*(Vm - V_half) / (R*T*1000)))   (T in K)
        I      = gbar * a_m * (Vm - em8)

    The half-activation fit (C, dE) is expressed against temperature in
    degrees Celsius. The Boltzmann slope RT/(zF), however, requires absolute
    temperature. The previous version used Celsius there too, which gave a
    ~4.9 mV slope at 37 C instead of ~41 mV, divided by zero at 0 C and
    inverted gating below 0 C.
    NOTE(review): confirm against the reference trpm8.mod.

    Args:
        Vm (float): Membrane potential (mV).
        m (float): Ca2+-dependent V_half shift (mV).
        Cai_mM (float): Unused here; Ca2+ acts through m.
        T (float): Temperature (K).
        gbar (float): Maximal conductance (S/cm^2 -> mA/cm^2 with mV).
        dE, C, z, em8, p_ca, F, R: Model constants; p_ca is the Ca2+ share.

    Returns:
        tuple: (I_TRPM, I_Ca_TRPM) in mA/cm^2.
    """
    celsius = T - 273.15

    # Temperature-dependent half-activation (mV), fitted against deg C.
    vh_temp = 1000.0 * (C * R * celsius - dE) / (z * F)
    vhalf_full = vh_temp + m

    # Boltzmann exponent; RT uses absolute temperature. Very large exponents
    # overflow exp() to inf, which correctly yields am = 0.
    exponent = -z * F * (Vm - vhalf_full) / (R * T * 1000.0)
    am = 1.0 / (1.0 + np.exp(exponent))

    g_activated = gbar * am
    driving_force = Vm - em8
    I_Ca_TRPM = p_ca * g_activated * driving_force
    I_TRPM = (1.0 - p_ca) * g_activated * driving_force

    return I_TRPM, I_Ca_TRPM


@njit
def get_steady_state_TRPM(Cai_mM,
                    mmin=0.0, mmax=200.0, Kca=0.0005):
    """
    Steady-state TRPM8 shift variable m_inf for a given Ca2+ level.

        m_inf = mmin + (mmax - mmin) * Ca / (Kca + Ca)

    (PROCEDURE rate in trpm8.mod). Negative Ca is treated as 0, consistent
    with update_trpm8_m / update_trpm8_m_rk4, so the steady state used for
    initialisation matches the fixed point of those updates.

    Args:
        Cai_mM (float): Intracellular Ca2+ (mM).
        mmin, mmax (float): Range of m (mV).
        Kca (float): Half-saturation Ca2+ (mM).

    Returns:
        float: m_inf (mV).
    """
    ca = 0.0 if Cai_mM < 0 else Cai_mM
    return mmin + (mmax - mmin) * ca / (Kca + ca)
# =============================================================================
# --- Simulation Step ---
# =============================================================================
@njit
def calculate_Epas( E_Na, E_K, E_Cl,
                   g_Na = 0.01, g_K = 0.2, g_Cl = 0.1):
    """
    Passive reversal potential as the conductance-weighted mean of the
    Na+, K+ and Cl- Nernst potentials.

        E_pas = (g_Na*E_Na + g_K*E_K + g_Cl*E_Cl) / (g_Na + g_K + g_Cl)

    Only conductance ratios matter, so any consistent unit works.

    Args:
        E_Na, E_K, E_Cl (float): Nernst potentials (mV).
        g_Na, g_K, g_Cl (float): Relative conductances.

    Returns:
        float: E_pas in mV, or 0.0 if all conductances are zero.
    """
    conductance_sum = g_Na + g_K + g_Cl
    if conductance_sum == 0:
        return 0.0
    weighted = g_Na * E_Na + g_K * E_K + g_Cl * E_Cl
    return weighted / conductance_sum

@njit
def single_neuron(Vm, Na_i, Na_o, K_i, K_o, Ca_o, Ca_i, # Cl_o, Cl_i, # Chloride concentrations not used as state vars here
                  area, dt, C, # C in Farads
                  Na_ch_vars, K_ch_vars, Ca_ch_vars, HCN_ch_vars, Cl_ch_vars,pump_bound_current,
                  nernst, # RT/F in mV
                  V_ext, # Extracellular volume (liters? cm^3?) - units needed for flux calc
                  V_neur, # Neuronal volume (liters? cm^3?) - units needed for flux calc
                  IC, # Injected current (Amperes?) - units needed for dVm/dt
                  T, # Temperature in Kelvin for GHK,
                  K_base,
                  TRPM_m,enforce_stability = False,
                  ion_specificity = False,
                  dynamic_ions = True, F = 96485.332 # Faraday constant C/mol
                 ):
    """
    Performs a single time step update for the neuron model.
    """
    # --- Calculate currents from pumps and exchangers (Result assumed mA/cm^2) ---

    # Ensure concentrations are positive
    Na_i = max(1e-12, Na_i); Na_o = max(1e-12, Na_o)
    K_i = max(1e-12, K_i); K_o = max(1e-12, K_o)
    Ca_i = max(1e-12, Ca_i); Ca_o = max(1e-12, Ca_o)
    I_Na_pump_dens, I_K_pump_dens = KNa_pump(Na_i, C, area)
    # I_Na_pump_dens = 0
    # I_K_pump_dens = 0
    I_Na_NCX_dens, I_Ca_NCX_dens = NCX(Vm, Na_i, Na_o, Ca_i, Ca_o, nernst)
    #pump_bound_current, I_Ca_PMCA_dens = pmca_explicit_euler_step(pump_bound_current, Ca_i, Ca_o, dt)
    #pump_bound_current, I_Ca_PMCA_dens = pmca_implicit_euler_step(pump_bound_current, Ca_i, Ca_o, dt)
    pump_bound_current, I_Ca_PMCA_dens = pmca_rk4_step(pump_bound_current, Ca_i, Ca_o, dt)
    #I_Ca_PMCA_dens = calculate_IPMCA(Ca_i,Ca_o)

    # --- Calculate total ionic currents through channels (Result mA/cm^2) ---
    # I_Na_ch_dens = Na_current(Vm, Na_o, Na_i, Na_ch_vars, nernst)
    # I_K_ch_dens = K_current(Vm, K_o, K_i, Ca_i, K_ch_vars, nernst) # Ca_i assumed mM
    I_Ca_ch_dens = Ca_current(Vm, Ca_o, Ca_i, Ca_ch_vars, T) # Ca_i, Ca_o assumed mM
    # I_Cl_ch_dens = Cl_current(Vm, Ca_i, Cl_ch_vars) # Ca_i assumed mM

    # --- Calculate other currents (Result mA/cm^2) ---
    # I_HCN_dens = HCN_current(Vm, HCN_ch_vars)
    # I_pas_dens = calculate_Ipas(Vm)

    E_K = nernst * np.log(K_o / K_i)
    E_Na = nernst * np.log(Na_o / Na_i)
    
    # --- Calculate total ionic currents through channels ---
    I_Na_ch_dens = Na_current(Vm, Na_o, Na_i, Na_ch_vars, nernst) 
    I_K_ch_dens = K_current(Vm, K_o, K_i, Ca_i, K_ch_vars, nernst)
    I_Cl_ch_dens = Cl_current(Vm, Ca_i, Cl_ch_vars)
    I_HCN_dens = HCN_current(Vm, HCN_ch_vars)

    
    E_pas = calculate_Epas(E_Na, E_K, E_Cl = 2.713723)
    I_pas_dens = calculate_Ipas(Vm, ) #Epas = E_pas
    I_TRPM, I_Ca_TRPM = calculate_trpm8_current(Vm,TRPM_m,Ca_i, T)

    # I_Na_ch_dens = I_Na_ch_dens_uA / 1000.0   # Convert to mA/cm^2
    # I_K_ch_dens = I_K_ch_dens_uA / 1000.0     # Convert to mA/cm^2
    # I_Cl_ch_dens = I_Cl_ch_dens_uA / 1000.0   # Convert to mA/cm^2
    # I_HCN_dens = I_HCN_dens_uA / 1000.0       # Convert to mA/cm^2
    # I_pas_dens = I_pas_dens_uA / 1000.0       # Convert to mA/cm^2
    if ion_specificity:
        E_K = nernst * np.log(K_o / K_i)
        E_Na = nernst * np.log(Na_o/Na_i)
        wK = np.abs(Vm - E_K)/(1 +  np.abs(Vm - E_K) +  np.abs(Vm - E_Na))
        wNa = np.abs(Vm - E_Na)/(1 +  np.abs(Vm - E_K) +  np.abs(Vm - E_Na))
        I_K_ch_dens += (I_pas_dens + I_TRPM + I_HCN_dens)*wK
        I_Na_ch_dens += (I_pas_dens + I_TRPM + I_HCN_dens)*wNa
        I_pas_dens = 0
        I_TRPM = 0
        I_HCN = 0
                

    
    # --- Sum current densities for ion flux calculation ---
    # Total current density for each ion type
    I_Na_total_dens = I_Na_ch_dens + I_Na_pump_dens + I_Na_NCX_dens
    I_K_total_dens = I_K_ch_dens + I_K_pump_dens
    I_Ca_total_dens = I_Ca_ch_dens + I_Ca_PMCA_dens + I_Ca_NCX_dens + I_Ca_TRPM
    # Note: Cl current handled separately or assumed balanced if not tracked.

    # --- Update ion concentrations (if dynamic_ions is True) ---
    if dynamic_ions:
        # Convert current densities (mA/cm^2) to total currents (mA)
        I_Na_total = I_Na_total_dens * area # mA
        I_K_total = I_K_total_dens * area   # mA
        I_Ca_total = I_Ca_total_dens * area # mA

        # Convert total currents (A = mC/ms) to flux (mol/ms)
        # Flux = I / (z * F) where I is in Amperes (C/s).
        # I (mA) = I * 1e-3 (A). Flux = (I * 1e-3) / (z * F) (mol/s)
        # Flux (mol/ms) = Flux (mol/s) * 1e-3 = (I * 1e-6) / (z * F)
        Na_flux = (I_Na_total * 1e-6) / (1 * F) # mol/ms (z=1)
        K_flux = (I_K_total * 1e-6) / (1 * F)   # mol/ms (z=1)
        Ca_flux = (I_Ca_total * 1e-6) / (2 * F) # mol/ms (z=2).
        
        # Update concentrations using Euler method
        # Outward flux (positive I) decreases intracellular, increases extracellular.
        # Flux units: mol/ms
        # dC units: mM
        dNa_o = (Na_flux * 1000 * dt) / V_ext
        dNa_i = -(Na_flux * 1000 * dt) / V_neur
        dK_o = (K_flux * 1000 * dt) / V_ext
        dK_i = -(K_flux * 1000 * dt) / V_neur
        dCa_o = (Ca_flux * 1000 * dt) / V_ext
        dCa_i = -(Ca_flux * 1000 * dt) / V_neur

        # Update concentrations
        Na_o = Na_o + dNa_o
        Na_i = Na_i + dNa_i # Update Na_i (not tracked as state var, but needed for next step if passed back)
        K_o = K_o + dK_o
        K_i = K_i + dK_i   # Update K_i (not tracked as state var)
        Ca_o = Ca_o + dCa_o
        Ca_i = Ca_i + dCa_i

        # Ensure concentrations don't go below a minimum positive value
        Na_o = max(1e-9, Na_o); Na_i = max(1e-9, Na_i)
        K_o = max(1e-9, K_o);   K_i = max(1e-9, K_i)
        Ca_o = max(1e-9, Ca_o); Ca_i = max(1e-9, Ca_i)
    # --- Calculate total membrane current density ---
    I_mem_total_dens = I_Na_total_dens + I_K_total_dens+ I_pas_dens + I_Cl_ch_dens + I_HCN_dens  + I_Ca_total_dens + I_TRPM
    # --- Update membrane potential ---
    # dVm/dt = -I_total / C_total.
    # C_total = C (Farads)
    # dVm/dt (mV/ms) = - (I_mem_total_dens * area - IC) / (C * 1000) # (mA / F = V/s = mV/ms)
    I_total_mem = I_mem_total_dens * area # Total membrane current in mA
    # Ensure injected current IC has the correct sign convention (positive inward?)
    # If IC is inward current, it depolarizes, so it should subtract from outward membrane current.
    # Assuming IC is provided in mA.
    dVm_dt = -(I_total_mem - IC) / (C * 1e3) # C(F)*1e3 = C(mF). mA/mF = mV/ms.
    new_Vm = Vm + dt * dVm_dt

    # --- Update gating variables ---
    Na_ch_vars = update_Na_vars_rk4(Vm, Na_ch_vars, dt)
    K_ch_vars = update_K_vars_rk4(Vm, Na_i, K_ch_vars, Ca_i, dt) # Use updated Na_i, Ca_i? Or old? Using old Vm, current Na_i, Ca_i.
    HCN_ch_vars = update_HCN_vars_rk4(Vm, HCN_ch_vars, dt)
    Ca_ch_vars = update_Ca_vars_rk4(Vm, Ca_i, Ca_ch_vars, dt) # Using old Vm, current Ca_i.
    Cl_ch_vars = update_Cl_vars_rk4(Vm, Cl_ch_vars, Ca_i, dt) # Using old Vm, current Ca_i.
    new_m_TRPM = update_trpm8_m_rk4(TRPM_m,Ca_i,dt)
    # --- Package new state ---
    # Create the state vector for the next time step
    new_state_mat = np.zeros((34,))
    new_state_mat[0] = new_Vm
    new_state_mat[1] = K_o   # Updated K_o
    new_state_mat[2] = Na_o   # Updated Na_o
    new_state_mat[3] = Ca_i   # Updated Ca_i
    new_state_mat[4] = Ca_o   # Updated Ca_o
    new_state_mat[5:11] = Na_ch_vars # Updated Na gating vars
    new_state_mat[11:18] = K_ch_vars  # Updated K gating vars
    new_state_mat[18:20] = HCN_ch_vars # Updated HCN gating vars
    new_state_mat[20:30] = Ca_ch_vars # Updated Ca gating vars
    new_state_mat[30] = Cl_ch_vars[0] # Updated Cl gating var (unpack from array)
    new_state_mat[31] = new_m_TRPM
    new_state_mat[32] = Na_i
    new_state_mat[33] = K_i


    if abs(Vm)>2e4 and not enforce_stability:
        print('Numerical instability ecnountered in single_neuron:')
        print('K_o = ', np.round(K_o, 4))
        print('I_mem_total_dens = ', np.round(I_mem_total_dens, 4))
        print('dNa_o, dK_o, dCa_o = ', np.round(dNa_o, 4), np.round(dK_o, 4), np.round(dCa_o, 4))
        print('dNa_i, dK_i, dCa_i = ', np.round(dNa_i, 4), np.round(dK_i, 4), np.round(dCa_i, 4))
        print('Vm = ', np.round(Vm, 4))
        raise ValueError('Numerical Instability error')
    if enforce_stability:
        Vm = min(Vm, 2e2)
        Vm = max(Vm, -2e2)
    return new_state_mat, pump_bound_current #, Na_i, K_i # Optionally return updated internal concentrations





# =============================================================================
# --- Tabulated channel kinetics (raw lookup data) ---
# =============================================================================
# The arrays below are data points only; nothing in this section interpolates
# them. NOTE(review): presumably consumed by interpolation helpers elsewhere in
# the file (scipy.interpolate is imported at the top) — confirm.

# --- SK3 channel: Ca2+-dependent activation parameters ---
# Half-activation voltage (V1/2) and slope factor tabulated at four
# intracellular Ca2+ levels: 0.2 uM, 0.3 uM, 0.5 uM and 10 uM (stored in mM).
sk3_cai_data_mM = np.array([0.2e-3, 0.3e-3, 0.5e-3, 10e-3]) # [Ca]i in mM
sk3_vh_data_mV = np.array([24.0, 35.30399, 60.49381, 59.13068]) # V1/2 in mV
sk3_sf_data_mV = np.array([128.0, 38.80141, 48.77538, 44.82705]) # Slope factor in mV

# --- CaN channel: inactivation time constant vs. voltage ---
# Corresponds to the //CaN section of tautables.txt (NEURON hoc syntax below).
# vecv2_CaN.indgen(-30, 20, 10) -> generates: -30, -20, -10, 0, 10, 20 (6 elements)
# vechtau_CaN.append(60.00, 50.00, 44.91, 21.14, 22.26, 26.9)

N_tau_h_Vm = np.array([-30., -20., -10., 0., 10., 20.]) # mV
N_tau_h_ms = np.array([60.00, 50.00, 44.91, 21.14, 22.26, 26.9]) # ms

# --- CaR channel: activation and (fast/slow) inactivation time constants ---
# Corresponds to the //CaR section of tautables.txt.
# vecv_CaR.indgen(-20, 20, 10) -> generates: -20, -10, 0, 10, 20 (5 elements)
# This vecv_CaR voltage grid is shared by vecmtau_CaR (m), vechtau_CaR (slow h)
# and vechtau2_CaR (fast h), hence the three identical *_Vm arrays below.

# vecmtau_CaR.append(0.745,2.087864,1.044556,1.5,0.796224)
R_tau_m_Vm = np.array([-20., -10., 0., 10., 20.]) # mV
R_tau_m_ms = np.array([0.745, 2.087864, 1.044556, 1.5, 0.796224]) # ms

# vechtau2_CaR.append(20.85107,17.19,12.63647,17.82774,16.75605) (corresponds to R_tau_hf, fast inactivation)
R_tau_hf_Vm = np.array([-20., -10., 0., 10., 20.]) # mV
R_tau_hf_ms = np.array([20.85107, 17.19, 12.63647, 17.82774, 16.75605]) # ms

# vechtau_CaR.append(828.96717,662.18,430.74661,513.46502,473.16638) (corresponds to R_tau_hs, slow inactivation)
R_tau_hs_Vm = np.array([-20., -10., 0., 10., 20.]) # mV
R_tau_hs_ms = np.array([828.96717, 662.18, 430.74661, 513.46502, 473.16638]) # ms

# --- T-type Ca2+ Channel (CaV3) ---
# Kinetics from Mandge S2 Text, Fig F, Eqs 12-13.
# Data for tau interpolations: activation (m) and inactivation (h) time
# constants sampled at the listed membrane potentials. The values were read
# off a figure, so they are approximate (non-round voltages in T_tau_h_Vm).
T_tau_m_Vm = np.array([
    -65., -60., -55., -50., -45., -40., -35., -30., -25., -20.,
    -15., -10.,  -5.,   0.,   5.,  10.
]) # mV
T_tau_m_ms = np.array([
    6.73, 6.73, 4.37, 2.9,  2.27, 1.87, 1.47, 1.27, 1.2,  1.1,
    1.2,  1.2,  1.1,  1.1,  1.1,  1.1
])      # ms (Approx from Fig Fb)
T_tau_h_Vm = np.array([
    -70.18405, -60.     , -49.93865, -40.     , -29.93865, -19.8773 ,
    -9.81595 ,   0.1227 ,  10.18405,  20.2454 ,  30.18405,  40.2454
]) # mV
T_tau_h_ms = np.array([
    167.54378,  61.65778,  44.39809,  31.05863,  25.83774,  21.80703,
    17.62502,  14.5562 ,  15.25575,  12.60002,  12.92211,  13.54313
])   # ms (Approx from Fig Fb)
def diffusion_1D(u_prev, dt, D_coef, mesh=None, A=None, M=None,):
    """
    Advance a 1D diffusion problem by one implicit (backward-Euler) step.

    Solves ``(M + D_coef*dt*K) u_next = M u_prev`` with linear (P1) finite
    elements, where M is the mass matrix and K the stiffness matrix. No
    Dirichlet conditions are imposed, so both ends of the domain act as
    natural zero-flux (homogeneous Neumann) boundaries.

    On the first call (``A is None``) a uniform mesh with unit node spacing and
    one node per entry of ``u_prev`` is built, and M and A are assembled. Pass
    the returned ``A``, ``M`` and ``mesh`` back in on later calls to reuse them.
    ``dt`` and ``D_coef`` are baked into ``A`` at assembly time, so they have
    no effect once ``A`` is supplied. A ``mesh`` passed without ``A`` is ignored.

    Args:
        u_prev (np.ndarray): Nodal values at the previous step (1D).
        dt (float): Time step used when assembling A.
        D_coef (float): Diffusion coefficient used when assembling A.
        mesh (skfem.MeshLine, optional): Mesh from a previous call (returned as-is).
        A (sparse matrix, optional): Pre-assembled system matrix ``M + D_coef*dt*K``.
        M (sparse matrix, optional): Pre-assembled mass matrix; required whenever A is given.

    Returns:
        tuple: ``(u_next, A, M, mesh)``.

    Raises:
        ValueError: If ``A`` is None and ``u_prev`` has fewer than 2 entries
            (a 1D mesh needs at least one element).
    """
    if A is None:
        # One mesh node per unknown, unit spacing. For len(u_prev) == 12 this
        # is exactly the former hard-coded np.linspace(0, 11, 12).
        n_nodes = len(u_prev)
        if n_nodes < 2:
            raise ValueError("u_prev needs at least 2 entries to build a 1D mesh")
        mesh = MeshLine(np.linspace(0, n_nodes - 1, n_nodes))
        basis = Basis(mesh, ElementLineP1())

        # Stiffness (Laplacian) bilinear form: integral of u' v'
        @BilinearForm
        def stiffness_1D(u, v, w):
            return u.grad[0] * v.grad[0]

        # Mass bilinear form (time-derivative term): integral of u v
        @BilinearForm
        def mass(u, v, w):
            return u * v

        K = stiffness_1D.assemble(basis)
        M = mass.assemble(basis)

        # Backward-Euler system matrix. No boundary DOFs are condensed out, so
        # the endpoints are natural zero-flux boundaries.
        A = M + D_coef * dt * K

    u_next = solve(A, M @ u_prev)
    return u_next, A, M, mesh

def find_nearest_node_1D(mesh, point):
    """
    Return the index of the mesh node whose x-coordinate is closest to ``point``.

    Ties are resolved in favour of the lowest node index.

    Args:
        mesh: A 1D mesh object (e.g. skfem.MeshLine) whose ``p`` attribute holds
            node coordinates with x in row 0.
        point (float): Target coordinate.

    Returns:
        int: Index of the nearest node.
    """
    node_x = mesh.p[0, :]
    distance = np.abs(node_x - point)
    nearest = np.argmin(distance)
    return nearest

def diffusion_exc(neur_loc, size, dt, D_coef, K_base, K_o, mesh=None, u_prev=None, A=None, M=None, neur_ind=None, refined=6):
    """
    Simulate one step of 1D extracellular K+ diffusion with the finite element method.

    The unknown is the K+ concentration *relative to* ``K_base``. Each call
    does three things:

    1. It sets the value at the mesh node nearest each neuron to that neuron's
       current excess K+ (``K_o - K_base``). The value is overwritten, not added.
    2. It takes one backward-Euler step ``(M + D_coef*dt*K) u_next = M u``
       using P1 elements.
    3. It reads the result back at the neuron nodes.

    No Dirichlet conditions are enforced, so the domain ends act as natural
    zero-flux boundaries.

    On the first call (``A is None``) the mesh, mass matrix, system matrix and
    neuron-to-node map are built. Any ``mesh``/``u_prev``/``neur_ind`` passed
    in are ignored on that call. On later calls, pass back the returned
    ``A``, ``M``, ``mesh``, ``neur_ind`` and ``u_next`` (as ``u_prev``).
    ``dt`` and ``D_coef`` are baked into ``A``, so they have no effect once
    ``A`` is supplied.

    The input arrays ``K_o`` and ``u_prev`` are not modified.

    Args:
        neur_loc (np.ndarray): (n_neurons,) array of neuron locations.
        size (float): Size of the 1D domain. Nodes are mapped to
            ``(x - 0.5) * size``; assuming MeshLine() is the unit interval,
            the domain is [-size/2, size/2].
        dt (float): Time step (ms) used when assembling A.
        D_coef (float): Diffusion coefficient used when assembling A.
        K_base (float): Baseline extracellular potassium concentration (mM).
        K_o (np.ndarray): (n_neurons,) current extracellular K+ at the neurons (mM).
        mesh (skfem.MeshLine, optional): Mesh from a previous call.
        u_prev (np.ndarray, optional): Relative nodal solution from the previous call.
        A (scipy.sparse matrix, optional): Pre-assembled system matrix ``M + D_coef*dt*K``.
        M (scipy.sparse matrix, optional): Pre-assembled mass matrix.
        neur_ind (np.ndarray, optional): (n_neurons, 2) array. Column 0 holds node
            indices (stored as floats); column 1 holds the relative K+ values
            captured on the first call only.
        refined (int, optional): Number of uniform refinements of the base mesh. Defaults to 6.

    Returns:
        tuple: ``(u_next, K_o_updated, A, M, mesh, neur_ind)``, where:
            - u_next (np.ndarray): relative K+ at every mesh node.
            - K_o_updated (np.ndarray): absolute K+ at the neuron nodes
              (``u_next`` there + ``K_base``).
            - A, M, mesh, neur_ind: state to pass back on the next call.
    """
    # Source term: K+ concentration relative to baseline. This is a new array,
    # so the caller's K_o is never written to.
    K_o_source = np.asarray(K_o, dtype=float) - K_base

    if A is None:
        # Unit-interval mesh, uniformly refined, then centred and scaled.
        mesh = MeshLine().refined(refined)
        mesh.p[0, :] = (mesh.p[0, :] - 0.5) * size
        basis = Basis(mesh, ElementLineP1())

        # Stiffness (Laplacian) bilinear form: integral of u' v'
        @BilinearForm
        def stiffness_1D(u, v, w):
            return u.grad[0] * v.grad[0]

        # Mass bilinear form (time-derivative term): integral of u v
        @BilinearForm
        def mass(u, v, w):
            return u * v

        K = stiffness_1D.assemble(basis)
        M = mass.assemble(basis)

        # Map each neuron to its nearest mesh node; column 1 keeps the initial source values.
        node_idx = np.array([find_nearest_node_1D(mesh, x) for x in neur_loc])
        neur_ind = np.column_stack((node_idx, K_o_source))

        # Backward-Euler system matrix (no boundary DOFs removed -> zero-flux ends).
        A = M + D_coef * dt * K

        u_current = np.zeros(basis.N)
    else:
        # Work on a copy so the caller's previous solution is left intact.
        u_current = np.array(u_prev, dtype=float)

    # Overwrite (not add) the value at each neuron node with that neuron's current
    # excess K+. A plain loop keeps "last neuron wins" if two share a node.
    for count, (idx, _) in enumerate(neur_ind):
        u_current[int(idx)] = K_o_source[count]

    u_next = solve(A, M @ u_current)

    # Read the diffused relative concentration back at the neuron nodes.
    node_idx = neur_ind[:, 0].astype(int)
    K_o_updated = u_next[node_idx] + K_base
    return u_next, K_o_updated, A, M, mesh, neur_ind

def simulate_modules_diff(sgc_params,
                          neur_params,
                          common_params,
                          num_modules,
                          size,
                          neur_positions,
                          refined=6,
                          indices_to_stimulate=None,
                          glia=True,
                          D_coef=1e-6,
                          diff=False,
                          alpha=None,
                          enforce_stability=False,
                          noise=1e-5,
                          ):
    """
    Simulate ``num_modules`` neuron-SGC modules, optionally coupled by K+ diffusion.

    Each module has its own copy of every state variable (membrane potentials,
    ion concentrations, gating variables, Ca2+ shell profiles). Per-module
    arrays carry the module index on their first axis. At every time step:

    1. Each module is advanced independently: the SGC first (if ``glia``),
       whose updated extracellular concentrations feed that module's neuron.
    2. If ``diff`` is True, every ``skip_interval_diffusion`` steps the
       extracellular K+ of all modules is diffused together on a 1D mesh
       via ``diff_dirichet_cond``.

    Args:
        sgc_params (tuple): The 15 SGC parameters, unpacked in the order
            (Vm_SGC_init, K_o_init, K_i_SGC_init, Na_o_init, Na_i_SGC_init,
            Ca_i_SGC_init, Ca_o_init, Cl_o_init, Cl_i_SGC_init,
            Neur_glial_dist_cm, R_sgc_cm, C_SGC, diff_method_str, betta_o, betta_i).
        neur_params (tuple): The 12 neuron parameters, unpacked in the order
            (Vm_neur_init, Na_i_neur_init, K_i_neur_init, Ca_i_neur_init,
            Ca_ER_init, Ca_Mt_init, R_neur_cm, C_neur, dynamic_ions,
            skip_interval_diffusion, pump_initial, ion_specificity).
        common_params (tuple): (tmax, dt, T_celsius, stim_intervals, amplitude).
        num_modules (int): Number of modules to simulate.
        size (float): Size of the 1D K+ diffusion domain.
        neur_positions (np.ndarray): Positions of the neurons in the 1D domain.
        refined (int): Mesh refinement level for the K+ diffusion mesh.
        indices_to_stimulate (int | sequence of int | np.ndarray | None):
            Module indices that receive the stimulus pattern built from
            ``stim_intervals`` and ``amplitude``. None or empty means no stimulation.
        glia (bool): If True, simulate the SGC. If False, the SGC state is
            held constant and extracellular concentrations pass straight
            to the neuron.
        D_coef (float): Diffusion coefficient for K+.
        diff (bool): If True, apply inter-module K+ diffusion.
        alpha: If truthy, each module's injected-current trace is replaced by
            ``double_expest(trace, alpha)``.
        enforce_stability (bool): Forwarded to ``single_neuron``.
        noise (float): Amplitude of the uniform noise added to the neuron's
            previous Vm at every step. Samples lie in [-noise/2, noise/2).
            The same sample is added to every module at a given step.
            A falsy value disables noise.

    Returns:
        tuple: ``(state_mat_SGC, state_mat_neuron_Vm, state_mat_extracellular, Ca_cyt_mat, IC_inj)``:
            - state_mat_SGC: (num_modules, 2, nt), rows [Vm_SGC, Ca_i_SGC].
            - state_mat_neuron_Vm: (num_modules, 1, nt), neuron Vm.
            - state_mat_extracellular: (num_modules, 4, nt), rows [K_o, Na_o, Ca_o, Cl_o].
            - Ca_cyt_mat: (num_modules, 12), cytosolic Ca2+ shell profile at the
              final step only.
            - IC_inj: (num_modules, nt), injected current.
    """

    # =============================================================================
    # --- Common ---
    # =============================================================================
    tmax, dt, T_celsius, stim_intervals, amplitude = common_params
    nt = int(tmax / dt) + 1  # Number of time steps, including the initial state at index 0
    R_gas = 8.314462  # J/(mol*K)
    F = 96485.332  # C/mol
    T_kelvin = T_celsius + 273.15  # Convert Celsius to Kelvin for GHK etc.
    nernst = 1e3 * R_gas * T_kelvin / F  # RT/F expressed in mV

    # =============================================================================
    # --- SGC Parameters & Initialization ---
    # =============================================================================
    (Vm_SGC_init, K_o_init, K_i_SGC_init, Na_o_init, Na_i_SGC_init, Ca_i_SGC_init, Ca_o_init, Cl_o_init, Cl_i_SGC_init,
     Neur_glial_dist_cm, R_sgc_cm, C_SGC, diff_method_str, betta_o, betta_i) = sgc_params

    # Per-module SGC internal state: element j belongs to module j.
    K_i_SGC = np.full(num_modules, K_i_SGC_init)
    Na_i_SGC = np.full(num_modules, Na_i_SGC_init)
    Ca_i_SGC = np.full(num_modules, Ca_i_SGC_init)
    Cl_i_SGC = np.full(num_modules, Cl_i_SGC_init)
    Vm_SGC = np.full(num_modules, Vm_SGC_init)
    # Baseline SGC internal concentrations are shared scalars.
    K_i_base_SGC, Na_i_base_SGC, Ca_i_base_SGC, Cl_i_base_SGC = K_i_SGC_init, Na_i_SGC_init, Ca_i_SGC_init, Cl_i_SGC_init
    Ca_o_base = Ca_o_init

    # Spherical SGC geometry, identical for every module.
    area_SGC = 4 * np.pi * R_sgc_cm**2
    V_sgc_cm3 = (4.0/3.0) * np.pi * (R_sgc_cm**3)
    V_sgc_L = V_sgc_cm3 * 1e-3  # cm^3 -> L

    # Per-module SGC gating variables, all starting at their steady state.
    n_KDR_SGC = np.full(num_modules, n_inf_KDR(Vm_SGC_init))
    # Computed once from the initial state and shared by all modules.
    J_pump_steady_state = steady_state_pump(Vm_SGC_init, Na_o_init, Na_i_SGC_init, K_o_init, nernst, F)
    m_kir = m_inf_Kir41(Vm_SGC)*np.ones((num_modules,))
    h_kir = h_inf_Kir41(Vm_SGC)*np.ones((num_modules,))

    # SGC history: (num_modules, [Vm, Ca_i], time)
    state_mat_SGC = np.zeros((num_modules, 2, nt))
    state_mat_SGC[:, 0, 0] = Vm_SGC_init
    state_mat_SGC[:, 1, 0] = Ca_i_SGC_init

    if diff_method_str == 'simple_balance':
        diff_mode = simple_diff_balance
    else:
        diff_mode = dummy_diff

    # =============================================================================
    # --- NEURON Parameters & Initialization ---
    # =============================================================================
    (Vm_neur_init, Na_i_neur_init, K_i_neur_init, Ca_i_neur_init, Ca_ER_init, Ca_Mt_init,
     R_neur_cm, C_neur, dynamic_ions, skip_interval_diffusion, pump_initial, ion_specificity) = neur_params

    # Per-module neuron internal concentrations.
    Na_i_neur = np.full(num_modules, Na_i_neur_init)
    K_i_neur = np.full(num_modules, K_i_neur_init)

    # Spherical neuron inside an extracellular shell of thickness
    # Neur_glial_dist_cm; identical for every module.
    area_neur = 4 * np.pi * R_neur_cm**2
    V_neur_cm3 = (4.0/3.0) * np.pi * (R_neur_cm**3)
    V_neur_L = V_neur_cm3 * 1e-3
    R_outer_shell_cm = R_neur_cm + Neur_glial_dist_cm
    V_total_shell_cm3 = (4.0/3.0) * np.pi * (R_outer_shell_cm**3)
    V_ext_cm3 = V_total_shell_cm3 - V_neur_cm3
    V_ext_L = V_ext_cm3 * 1e-3
    V_shells_L, V_ER_L, V_Mt_L = calculate_concentric_shell_volumes(R_neur_cm)

    # Per-module gating variables: (num_modules, n_vars). The steady state is
    # computed once and tiled across modules.
    Na_ch_vars = np.tile(get_steady_state_Na_vars(Vm_neur_init), (num_modules, 1))
    K_ch_vars = np.tile(get_steady_state_K_vars(Vm_neur_init, Na_i_neur_init, Ca_i_neur_init), (num_modules, 1))
    HCN_ch_vars = np.tile(get_steady_state_HCN_vars(Vm_neur_init), (num_modules, 1))
    Ca_ch_vars = np.tile(get_steady_state_Ca_vars(Vm_neur_init), (num_modules, 1))
    Cl_ch_vars = np.tile(get_steady_state_Cl_vars(Vm_neur_init, Ca_i_neur_init), (num_modules, 1))
    TRPM_m = np.full(num_modules, get_steady_state_TRPM(Ca_i_neur_init))
    pump_bound_current = np.full(num_modules, pump_initial)

    # Neuron Vm history: (num_modules, 1, time)
    state_mat_neuron_Vm = np.zeros((num_modules, 1, nt))
    state_mat_neuron_Vm[:, 0, 0] = Vm_neur_init
    # NOTE(review): currently unused. It is only referenced by the commented-out
    # J_NaK_max argument of single_neuron below.
    J_pump_steady_state_neuron = steady_state_pump(Vm_neur_init, Na_o_init, Na_i_neur_init, K_o_init, nernst, F)
    # Per-module Ca2+ stores over 12 concentric shells: (num_modules, 12).
    Ca_cyt_mat = np.full((num_modules, 12), Ca_i_neur_init)
    Ca_ER_mat = np.full((num_modules, 12), Ca_ER_init)
    Ca_Mt_mat = np.full((num_modules, 12), Ca_Mt_init)
    IP3R_vars = np.full((num_modules, 12), IP3R_h_inf(Ca_i_neur_init))

    # =============================================================================
    # --- Common Extracellular State Initialization ---
    # =============================================================================
    # Working extracellular concentrations, one entry per module.
    K_o = np.full(num_modules, K_o_init)
    Na_o = np.full(num_modules, Na_o_init)
    Ca_o = np.full(num_modules, Ca_o_init)
    Cl_o = np.full(num_modules, Cl_o_init)

    # Extracellular history: (num_modules, [K_o, Na_o, Ca_o, Cl_o], time)
    state_mat_extracellular = np.zeros((num_modules, 4, nt))
    state_mat_extracellular[:, 0, 0] = K_o_init
    state_mat_extracellular[:, 1, 0] = Na_o_init
    state_mat_extracellular[:, 2, 0] = Ca_o_init
    state_mat_extracellular[:, 3, 0] = Cl_o_init

    # Baseline extracellular concentrations are shared scalars.
    K_o_base, Na_o_base, Cl_o_base = K_o_init, Na_o_init, Cl_o_init

    # --- Setup Injected Current for Neuron ---
    # IC_inj is 2D: (num_modules, nt). Unstimulated modules keep a zero row.
    IC_inj = np.zeros((num_modules, nt))
    # Normalise the selection to a 1D index array. A bare truthiness test would
    # skip np.array([0]) or the int 0 and raise for multi-element arrays.
    if indices_to_stimulate is None:
        stim_rows = np.empty(0, dtype=int)
    else:
        stim_rows = np.atleast_1d(np.asarray(indices_to_stimulate))
    if stim_rows.size > 0:
        # Build the 1D stimulus pattern once.
        stim_pattern = np.zeros(nt)
        for interval in stim_intervals:
            start_idx = int(interval[0] / dt)
            end_idx = int(interval[1] / dt)
            # Clip to [0, nt) so intervals outside the simulated window are harmless.
            actual_start_idx = max(0, start_idx)
            actual_end_idx = min(nt, end_idx + 1)
            if actual_start_idx < actual_end_idx:
                stim_pattern[actual_start_idx:actual_end_idx] = amplitude
        # Broadcast the pattern onto every selected module row.
        IC_inj[stim_rows, :] = stim_pattern
    if alpha:
        for row in range(IC_inj.shape[0]):
            IC_inj[row] = double_expest(IC_inj[row], alpha)

    # Per-module state for intracellular Ca2+ shell diffusion (filled on first use).
    mesh_list = [None] * num_modules
    A_list = [None] * num_modules
    M_list = [None] * num_modules

    # State for inter-module K+ diffusion (filled on first use).
    K_diff_mesh = None
    K_diff_A = None
    K_diff_M = None
    K_diff_u_prev = None
    K_diff_neur_ind = None
    K_mat = None
    basis = None

    # Uniform Vm noise in [-noise/2, noise/2), one sample per time step shared by
    # all modules. Kept as division by 1/noise so seeded runs reproduce earlier
    # results bit-for-bit.
    if noise:
        noise_trace = (np.random.random(nt) - 0.5) / (1 / noise)
    else:
        noise_trace = np.zeros((nt,))

    # =============================================================================
    # --- MAIN SIMULATION LOOP ---
    # =============================================================================
    # i starts at 1, so every "i % skip_interval_diffusion == 0" check first fires
    # at i == skip_interval_diffusion.
    for i in tqdm(range(1, nt), desc="Simulating Modules"):
        # Each module is advanced separately for this time step.
        for j in range(num_modules):
            # --- Load extracellular state of module j from step i-1 ---
            K_o_current = state_mat_extracellular[j, 0, i-1]
            Na_o_current = state_mat_extracellular[j, 1, i-1]
            Ca_o_current = state_mat_extracellular[j, 2, i-1]
            Cl_o_current = state_mat_extracellular[j, 3, i-1]

            if glia:
                # =============================================================================
                # --- SGC (Module j) ---
                # =============================================================================
                Vm_SGC_prev = state_mat_SGC[j, 0, i-1]
                Ca_i_SGC_prev = state_mat_SGC[j, 1, i-1]

                (new_Vm_SGC, K_i_SGC_updated, K_o_after_SGC, Ca_i_SGC_updated, Ca_o_after_SGC,
                 Na_o_after_SGC, Na_i_SGC_updated, Cl_o_after_SGC, Cl_i_SGC_updated,
                 n_KDR_SGC_updated, m_kir_updated, h_kir_updated) = single_SGC(
                    Vm_SGC_prev, V_ext_L, V_sgc_L, area_SGC, C_SGC,
                    K_i_SGC[j], K_o_current, Ca_i_SGC_prev, Ca_o_current, Na_o_current, Na_i_SGC[j], Cl_o_current, Cl_i_SGC[j], J_pump_steady_state,
                    n_KDR_SGC[j], nernst, dt, F, diff_mode, Neur_glial_dist_cm, R_sgc_cm,
                    K_o_base, Na_o_base, Cl_o_base, Ca_o_base, betta_o, K_i_base_SGC, Na_i_base_SGC, Cl_i_base_SGC, Ca_i_base_SGC, betta_i,
                    m_kir=m_kir[j], h_kir=h_kir[j],
                )

                # Record SGC state of module j at step i.
                state_mat_SGC[j, 0, i] = new_Vm_SGC
                state_mat_SGC[j, 1, i] = Ca_i_SGC_updated
                m_kir[j] = m_kir_updated
                h_kir[j] = h_kir_updated

                # Carry SGC internal state to the next step.
                K_i_SGC[j] = K_i_SGC_updated
                Na_i_SGC[j] = Na_i_SGC_updated
                Cl_i_SGC[j] = Cl_i_SGC_updated
                n_KDR_SGC[j] = n_KDR_SGC_updated
                Ca_i_SGC[j] = Ca_i_SGC_updated

                # The SGC-modified extracellular space is the neuron's input.
                K_o_for_neuron = K_o_after_SGC
                Na_o_for_neuron = Na_o_after_SGC
                Ca_o_for_neuron = Ca_o_after_SGC
                Cl_o_for_neuron = Cl_o_after_SGC
            else:
                # No glia: hold the SGC state and pass extracellular values through.
                state_mat_SGC[j, 0, i] = state_mat_SGC[j, 0, i-1]
                state_mat_SGC[j, 1, i] = state_mat_SGC[j, 1, i-1]
                K_o_for_neuron, Na_o_for_neuron, Ca_o_for_neuron, Cl_o_for_neuron = K_o_current, Na_o_current, Ca_o_current, Cl_o_current

            # =============================================================================
            # --- NEURON (Module j) ---
            # =============================================================================
            Vm_neur_prev = state_mat_neuron_Vm[j, 0, i-1] + noise_trace[i]

            # Ca2+ exchange between cytosol, ER and mitochondria shells of module j.
            Ca_cyt_mat[j], Ca_ER_mat[j], Ca_Mt_mat[j], IP3R_vars[j] = shell_flux(
                Ca_cyt_mat=Ca_cyt_mat[j], Ca_ER_mat=Ca_ER_mat[j], Ca_Mt_mat=Ca_Mt_mat[j], Na_i=Na_i_neur[j],
                dt=dt, V_shell=V_shells_L, V_ER=V_ER_L, V_Mt=V_Mt_L, IP3R_vars=IP3R_vars[j],
            )
            if i % skip_interval_diffusion == 0:
                # Radial Ca2+ diffusion across shells, taking one step of
                # skip_interval_diffusion * dt. Matrices are cached per module.
                Ca_cyt_mat[j], A_list[j], M_list[j], mesh_list[j] = diffusion_1D(
                    Ca_cyt_mat[j], skip_interval_diffusion * dt, 6e-5, mesh_list[j], A_list[j], M_list[j]
                )

            # Shell 0 is the one exchanging ions with the membrane.
            Ca_i_neur_current_for_flux = Ca_cyt_mat[j, 0]

            # NOTE(review): K_base receives the neuron's *initial intracellular* K+
            # (K_i_neur_init); confirm this is what single_neuron expects.
            neuron_full_new_state, pump_bound_current_updated = single_neuron(
                Vm=Vm_neur_prev, Na_i=Na_i_neur[j], Na_o=Na_o_for_neuron, K_i=K_i_neur[j], K_o=K_o_for_neuron,
                Ca_o=Ca_o_for_neuron, Ca_i=Ca_i_neur_current_for_flux, area=area_neur, dt=dt, C=C_neur,
                Na_ch_vars=Na_ch_vars[j], K_ch_vars=K_ch_vars[j], Ca_ch_vars=Ca_ch_vars[j],
                HCN_ch_vars=HCN_ch_vars[j], Cl_ch_vars=Cl_ch_vars[j], pump_bound_current=pump_bound_current[j],
                nernst=nernst, V_ext=V_ext_L, V_neur=V_neur_L, IC=IC_inj[j, i], T=T_kelvin,
                dynamic_ions=dynamic_ions, F=F, K_base=K_i_neur_init, TRPM_m=TRPM_m[j], ion_specificity=ion_specificity, enforce_stability=enforce_stability,
            )  # J_NaK_max=J_pump_steady_state_neuron

            # Unpack the packed neuron state vector for module j.
            state_mat_neuron_Vm[j, 0, i] = neuron_full_new_state[0]
            K_o_final = neuron_full_new_state[1]
            Na_o_final = neuron_full_new_state[2]
            Ca_i_neur_after_fluxes = neuron_full_new_state[3]
            Ca_o_final = neuron_full_new_state[4]
            Ca_cyt_mat[j, 0] = Ca_i_neur_after_fluxes
            pump_bound_current[j] = pump_bound_current_updated

            Na_ch_vars[j, :] = neuron_full_new_state[5:11]
            K_ch_vars[j, :] = neuron_full_new_state[11:18]
            HCN_ch_vars[j, :] = neuron_full_new_state[18:20]
            Ca_ch_vars[j, :] = neuron_full_new_state[20:30]
            Cl_ch_vars[j, 0] = neuron_full_new_state[30]
            TRPM_m[j] = neuron_full_new_state[31]
            Na_i_neur[j] = neuron_full_new_state[32]
            K_i_neur[j] = neuron_full_new_state[33]

            # --- Stage extracellular concentrations before K+ diffusion ---
            K_o[j] = K_o_final
            Na_o[j] = Na_o_final
            Ca_o[j] = Ca_o_final
            Cl_o[j] = Cl_o_for_neuron  # Assuming neuron doesn't change Cl_o

        # Inter-module K+ diffusion, applied after every module has been advanced.
        # (An earlier version called diffusion_exc here with the same arguments
        # minus basis/K.)
        if diff and i % skip_interval_diffusion == 0:
            K_diff_u_prev, K_o_diffused, K_diff_A, K_diff_M, K_diff_mesh, K_diff_neur_ind, basis, K_mat = diff_dirichet_cond(
                neur_loc=neur_positions,
                size=size,
                dt=dt*skip_interval_diffusion,
                D_coef=D_coef,
                K_base=K_o_base,
                K_o=K_o.copy(),  # Use the updated K_o from all modules
                mesh=K_diff_mesh,
                u_prev=K_diff_u_prev,
                A=K_diff_A,
                M=K_diff_M,
                neur_ind=K_diff_neur_ind,
                refined=refined,
                basis=basis,
                K=K_mat,
            )
            K_o = K_o_diffused

        # --- Record final extracellular concentrations for all modules at step i ---
        for j in range(num_modules):
            state_mat_extracellular[j, 0, i] = K_o[j]  # includes diffusion, if applied
            state_mat_extracellular[j, 1, i] = Na_o[j]
            state_mat_extracellular[j, 2, i] = Ca_o[j]
            state_mat_extracellular[j, 3, i] = Cl_o[j]

    return (state_mat_SGC, state_mat_neuron_Vm, state_mat_extracellular,
            Ca_cyt_mat, IC_inj)
def create_pulse_train_intervals(
    num_pulses: int,
    start_time_ms: float,
    frequency_hz: float,
    pulse_duration_ms: float
) -> np.ndarray:
    """
    Build [start, end] times (ms) for a train of equally spaced rectangular pulses.

    Pulse k (0-based) begins at ``start_time_ms + k * period`` and lasts
    ``pulse_duration_ms``, where ``period = 1000 / frequency_hz``. A pulse may
    last the full period (back-to-back pulses) but not longer.

    Args:
        num_pulses (int): Number of pulses in the train (e.g., 500).
        start_time_ms (float): Onset of the first pulse, in ms.
        frequency_hz (float): Pulse repetition rate in Hz (e.g., 20).
        pulse_duration_ms (float): Width of each pulse, in ms (e.g., 1).

    Returns:
        np.ndarray: Array of shape (num_pulses, 2); row k is [onset_k, offset_k].
                    For ``num_pulses == 0`` an empty (0, 2) array is returned
                    before any other argument is checked.

    Raises:
        ValueError: If num_pulses is negative, frequency_hz or
                    pulse_duration_ms is non-positive, or the pulse is
                    longer than the period.
    """
    # Guard clauses, in the same order as the contract above.
    if num_pulses < 0:
        raise ValueError("Number of pulses cannot be negative.")
    if num_pulses == 0:
        return np.empty((0, 2))
    if frequency_hz <= 0:
        raise ValueError("Frequency must be positive.")
    if pulse_duration_ms <= 0:
        raise ValueError("Pulse duration must be positive.")

    period_ms = 1000.0 / frequency_hz  # onset-to-onset spacing
    if pulse_duration_ms > period_ms:
        raise ValueError(
            f"Pulse duration ({pulse_duration_ms} ms) cannot exceed the "
            f"pulse period ({period_ms:.2f} ms) derived from the frequency."
        )

    onsets = start_time_ms + np.arange(num_pulses) * period_ms
    # Broadcast each onset against the offsets [0, duration] -> (num_pulses, 2).
    edges = np.array([0.0, pulse_duration_ms])
    return onsets[:, None] + edges
def diff_dirichet_cond(neur_loc, size, dt, D_coef, K_base, K_o,
                       mesh=None, u_prev=None, A=None, M=None, neur_ind=None, refined=6, basis=None, K=None):
    """
    Advance 1D extracellular K+ diffusion by one implicit (backward-Euler) step.

    The unknown field is the K+ concentration *relative* to ``K_base``. On each
    call the field value at every neuron node is overwritten with that neuron's
    current relative concentration (``K_o - K_base``). The function then solves
    ``(M + D_coef * dt * K) u_next = M @ u_current`` with the relative
    concentration held at 0 on the domain boundary DOFs (Dirichlet condition).
    Finally it samples ``u_next`` back at the neuron nodes. The same boundary
    treatment is applied on the first call and on every later call.

    On the first call (``A is None``) the mesh, the P1 basis, the mass and
    stiffness matrices and the neuron node table are built, and the field
    starts from zero. On later calls, pass back the objects returned by the
    previous call.

    Args:
        neur_loc: Iterable of neuron positions, each passed to
            ``find_nearest_node_1D``. Only read on the first call.
        size: Domain length; the unit line mesh is mapped to
            [-size/2, size/2]. Only read on the first call.
        dt: Time step.
        D_coef: Diffusion coefficient. Units must be consistent with ``size``
            and ``dt``; this is not checked here.
        K_base: Baseline extracellular K+ concentration.
        K_o: Current extracellular K+ concentration at each neuron, in the
            same order as ``neur_loc``. This array is NOT modified.
        mesh, u_prev, A, M, neur_ind, basis, K: State returned by the previous
            call; leave as None on the first call. ``A`` only serves to tell
            the first call apart from later ones. It is rebuilt every call,
            so ``dt`` and ``D_coef`` may change between calls.
        refined: Number of refinements passed to ``MeshLine().refined``.
            Only used on the first call.

    Returns:
        tuple: ``(u_next, K_o_updated, A, M, mesh, neur_ind, basis, K)``.
        ``u_next`` is the relative concentration at every node.
        ``K_o_updated`` is the absolute concentration at each neuron node
        after diffusion. ``A`` is the boundary-enforced system matrix.
        ``neur_ind`` has column 0 = node index and column 1 = the source value
        used at initialisation.
    """
    # Relative (baseline-subtracted) concentrations are injected as sources.
    # Work on a fresh float array so the caller's K_o is never overwritten.
    K_o_source = np.atleast_1d(np.asarray(K_o, dtype=float)) - K_base

    if A is None:
        # First call: build the mesh, FE basis and time-independent matrices.
        mesh = MeshLine().refined(refined)
        mesh.p[0, :] = (mesh.p[0, :] - 0.5) * size  # center and scale the mesh
        basis = Basis(mesh, ElementLineP1())

        @BilinearForm
        def stiffness_1D(u, v, w):
            return u.grad[0] * v.grad[0]

        @BilinearForm
        def mass(u, v, w):
            return u * v

        K = stiffness_1D.assemble(basis)
        M = mass.assemble(basis)

        # Column 0: mesh node nearest each neuron; column 1: initial source value.
        nearest = np.array([find_nearest_node_1D(mesh, x) for x in neur_loc])
        neur_ind = np.column_stack((nearest, K_o_source))

        u_current = np.zeros(basis.N)
    else:
        # Copy so the caller's previous solution vector is left untouched.
        u_current = np.array(u_prev, dtype=float)

    node_idx = neur_ind[:, 0].astype(int)

    # Overwrite the field at each neuron node with its current source value
    # (sequential loop keeps "last neuron wins" if two share a node).
    for node, value in zip(node_idx, K_o_source):
        u_current[node] = value

    # Backward-Euler system with relative concentration fixed to 0 on the
    # boundary DOFs; enforce() returns modified copies, so M and K stay intact.
    boundary_dofs = basis.get_dofs()
    A, F = enforce(M + D_coef * dt * K, M @ u_current, D=boundary_dofs)
    u_next = solve(A, F)

    # Sample the diffused field at the neuron nodes and restore the baseline.
    K_o_updated = u_next[node_idx] + K_base
    return u_next, K_o_updated, A, M, mesh, neur_ind, basis, K
    

In [None]:
@njit
def expo_estimation(meas, alpha):
    """
    Causal first-order exponential smoothing of a 1D sequence.

    ``out[0] = meas[0]`` and ``out[k] = out[k-1] + alpha * (meas[k] - out[k-1])``
    for k >= 1. The output buffer is ``np.copy(meas)``, so its dtype follows
    ``meas``; pass a floating-point array to avoid integer storage.
    """
    smoothed = np.copy(meas)
    n_samples = len(smoothed)
    for k in range(1, n_samples):
        previous = smoothed[k - 1]
        smoothed[k] = previous + alpha * (meas[k] - previous)
    return smoothed
@njit
def double_expest(meas, alpha):
    """
    Forward-backward exponential smoothing.

    Runs ``expo_estimation`` forward over ``meas``, then runs it again over the
    reversed result, and reverses that output back into the original order.
    The backward pass offsets the lag introduced by the forward pass.
    """
    forward_pass = expo_estimation(meas, alpha)
    backward_pass_reversed = expo_estimation(np.flip(forward_pass), alpha)
    return np.flip(backward_pass_reversed)

In [None]:
if __name__ == "__main__":
    # Entry-point guard must compare against the string "__main__";
    # the unquoted name raised NameError.

    # --- Shared simulation settings and SGC parameters ---
    Vm_sgc = -84.7  # Membrane potential of the SGC
    T = 37  # Temperature
    tmax = 500  # Simulation length
    dt = 5e-3
    # Pulse-train intervals. Argument order presumably is
    # (num_pulses, start_time_ms, frequency_hz, pulse_duration_ms) -- confirm
    # against the signature. If so, 100 pulses at 20 Hz extend well past tmax.
    stim_intervals = create_pulse_train_intervals(100, 10, 20, 10)
    stim_amplitude = 5e-7  # Amplitude of stimulation (in A/cm2)
    K_o = 5.00455  # External K concentration
    K_i = 131.0  # Internal K concentration for SGC
    Na_o = 150.0  # External Na concentration
    Na_i = 10.0  # Internal Na concentration for SGC
    Ca_i = 5e-5  # Internal Ca concentration for SGC
    Ca_o = 2.0  # External Ca concentration
    Cl_o = 145.0  # External Cl concentration
    Cl_i = 40.0  # Internal Cl concentration for SGC
    # Must be defined before SGC_params is built (was previously assigned
    # further down, causing a NameError).
    Neur_glial_dist = 2e-6  # Distance between glial envelope and neuron, in cm
    R_sgc_cm = 0.0003  # Radius of glial sheath
    C = 13.5e-12  # Capacitance of SGC membrane, in F
    diff = 'simple_balance'  # How the SGC affects the concentrations
    betta_i = 1.3  # Beta value from Mandge et al. 2019
    betta_o = 0.01
    SGC_params = [Vm_sgc, K_o, K_i, Na_o, Na_i, Ca_i, Ca_o, Cl_o, Cl_i, Neur_glial_dist, R_sgc_cm, C, diff, betta_o, betta_i]
    common_params = [tmax, dt, T, stim_intervals, stim_amplitude]

    # --- Neuron parameters ---
    Vm_neur_init = -53.73926  # Membrane potential of the neuron
    K_i_neur_init = 140.0  # K concentration inside neuron
    Na_i_neur_init = 10.0  # Na concentration inside neuron
    Ca_i_neur_init = 1.36e-4  # Ca concentration inside neuron
    R_neur_cm = 0.002  # Neuron radius, in cm

    Ca_ER_init = 0.4  # Ca concentration in the neuron's ER
    Ca_Mt_init = 2e-4  # Ca concentration in the neuron's mitochondria
    pump_initital = 8.437e-15  # Initial pump state value
    C_neur = 28e-12  # Capacitance of neuron membrane
    dynamic_ions = True  # Whether to update ion concentrations dynamically
    skip_interval_diffusion = 8  # Diffusion is computed every i-th step to speed up calculations
    ion_specificity = True  # Passive currents as a sum of Na, K, Ca, Cl (True) or one abstract current as in basic HH (False)
    neur_params = [Vm_neur_init, Na_i_neur_init, K_i_neur_init, Ca_i_neur_init, Ca_ER_init, Ca_Mt_init, R_neur_cm, C_neur, dynamic_ions, skip_interval_diffusion, pump_initital, ion_specificity]

    # --- Extracellular diffusion parameters ---
    D_K = 9e-9  # Diffusion coefficient for K in extracellular space
    size = 0.1  # Size of the simulated diffusion space
    refined = 5  # Number of times the diffusion space is refined
    neur_loc1D = np.array([[0]])  # Location of the neuron in that space
    diff_params = [D_K, size, refined, skip_interval_diffusion]

    state_mat_SGC, state_mat_neur, exc_st, ca, IC = simulate_modules_diff(num_modules=1, indices_to_stimulate=[0], size=size, neur_positions=neur_loc1D, refined=refined, D_coef=D_K,
                                                                         sgc_params=SGC_params, neur_params=neur_params, common_params=common_params,
                                                                         glia=True, diff=True, alpha=dt / 2, enforce_stability=False, noise=1e-5)
    # NOTE(review): `decimate` is not imported in the visible header; it must be
    # defined elsewhere in this file or imported (e.g. scipy.signal.decimate) -- confirm.
    stimulated = decimate(exc_st[0][0], 250)
    np.save('concentrations.npy', stimulated)