# Load Packages

In [60]:
import matplotlib.pyplot as plt
from numba import njit
from numba.core import types
from numba.typed import Dict
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import minimize

# Markov Model

In [61]:
@njit(fastmath = True)
def transition_matrix(Vm: float, weights: dict) -> np.ndarray[np.float64]:
    """Goal:
        Build the rate matrix Q of the five-state Kv11 Markov model at the
        membrane voltage Vm. Q is laid out so that dP/dt = Q @ P: the
        off-diagonal entry Q[i, j] is the rate from state j into state i, and
        each diagonal entry is minus the total exit rate of its state, so
        every column of Q sums to zero.

        Adapted from the KV11.1 Markov model described in:
            Mazhari R, Greenstein JL, Winslow RL, Marbán E, Nuss HB.
            Molecular interactions between two long-QT syndrome gene products,
            HERG and KCNE2, rationalized by in vitro and in silico analysis.
            Circ Res. 2001 Jul 6;89(1):33-8. doi: 10.1161/hh1301.093633.
            PMID: 11440975.
    ---------------------------------------------------------------------------
    Input:
        Vm: Membrane voltage [mV]
        weights: numba typed Dict (str -> float64) with the rate parameters.
            For each voltage-dependent transition X in
            {C1C2, C2C1, C3O, OC3, OI, IO, C3I} the rate is
            X_1 * exp(X_2 * Vm), with X_1 in [ms^-1] and X_2 in [mV^-1].
            'C2C3' and 'C3C2' are voltage-independent rates [ms^-1].
            Reference values (X_1, X_2) used elsewhere in this notebook:
                C1C2: 0.0069, 0.0272     C2C1: 0.0227, -0.0431
                C3O:  0.0218, 0.0262     OC3:  0.0009, -0.0269
                OI:   0.0622, 0.0120     IO:   0.0059, -0.0443
                C3I:  1.29E-5, 2.71E-6   C2C3: 0.0266   C3C2: 0.1348
    ---------------------------------------------------------------------------
    Output:
        Q: Transition matrix of the Markov model [s^-1], shape (5, 5)"""
    # State index: 0 = C1, 1 = C2, 2 = C3 (closed), 3 = O (open), 4 = I
    # (I is called a "transition state" in the original notes; presumably the
    # inactivated state of the Mazhari et al. model -- confirm.)
    #
    #   C1 - C2 - C3 - O
    #              \  /
    #               I
    #
    # Rates in [ms^-1]; local names follow the symbols used in the paper.
    a0 = weights['C1C2_1'] * np.exp(weights['C1C2_2'] * Vm)   # C1 -> C2
    b0 = weights['C2C1_1'] * np.exp(weights['C2C1_2'] * Vm)   # C2 -> C1
    a1 = weights['C3O_1'] * np.exp(weights['C3O_2'] * Vm)     # C3 -> O
    b1 = weights['OC3_1'] * np.exp(weights['OC3_2'] * Vm)     # O  -> C3
    ai = weights['OI_1'] * np.exp(weights['OI_2'] * Vm)       # O  -> I
    bi = weights['IO_1'] * np.exp(weights['IO_2'] * Vm)       # I  -> O
    kf = weights['C2C3']                                      # C2 -> C3
    kb = weights['C3C2']                                      # C3 -> C2
    ai3 = weights['C3I_1'] * np.exp(weights['C3I_2'] * Vm)    # C3 -> I
    # I -> C3 is not a free parameter: it is fixed so that the products of
    # the rates around the C3-O-I loop agree in both directions
    # (a1 * ai * mu == ai3 * bi * b1).
    mu = (b1 * bi * ai3) / (a1 * ai)

    Q = np.zeros((5, 5), dtype=np.float64)

    # Off-diagonal entries: Q[destination, source]
    Q[1, 0] = a0
    Q[0, 1] = b0
    Q[2, 1] = kf
    Q[1, 2] = kb
    Q[3, 2] = a1
    Q[4, 2] = ai3
    Q[2, 3] = b1
    Q[4, 3] = ai
    Q[2, 4] = mu
    Q[3, 4] = bi

    # Diagonal entries: minus the total exit rate of each state
    Q[0, 0] = -a0
    Q[1, 1] = -(b0 + kf)
    Q[2, 2] = -(kb + a1 + ai3)
    Q[3, 3] = -(b1 + ai)
    Q[4, 4] = -(bi + mu)

    return Q * 1000  # [ms^-1] -> [s^-1]


def steady_state(Vm: float, weights: dict) -> np.ndarray[np.float64]:
    """Goal:
        Return the stationary state probabilities P of the Kv11 Markov model
        at a fixed membrane voltage Vm, i.e. the solution of Q @ P = 0 that
        also satisfies sum(P) = 1.
    ---------------------------------------------------------------------------
    Input:
        Vm: Membrane voltage [mV]
        weights: numba typed Dict of rate parameters (see transition_matrix)
    ---------------------------------------------------------------------------
    Output:
        steady_state: Steady state probabilities of the Markov model, in the
            state order of transition_matrix
    """
    rate_matrix = transition_matrix(Vm, weights)
    n_states = rate_matrix.shape[0]

    # The columns of Q sum to zero, so the balance equations Q @ P = 0 are
    # linearly dependent. Keep all but the last one and replace it with the
    # normalisation condition sum(P) = 1 to obtain a uniquely solvable system.
    system = np.vstack((rate_matrix[:-1, :], np.ones((1, n_states))))
    rhs = np.zeros(n_states)
    rhs[-1] = 1.0

    return np.linalg.solve(system, rhs)

@njit(fastmath = True)
def kv11(t: np.float64, y: np.ndarray[np.float64], Vm: float, weights: dict) -> np.ndarray[np.float64]:
    """Goal:
        Right-hand side of the Kv11 Markov-model ODEs at a constant membrane
        voltage: dy/dt = Q(Vm) @ y.
    ---------------------------------------------------------------------------
    Input:
        t: Time [s]. Unused (Q depends only on Vm); it is part of the
            signature only because scipy's solve_ivp calls fun(t, y, *args).
        y: State probabilities of the Markov model
        Vm: Membrane voltage [mV]
        weights: numba typed Dict of rate parameters (see transition_matrix)
    ---------------------------------------------------------------------------
    Output:
        dydt: Derivative of the state probabilities [s^-1]
    """
    rates = transition_matrix(Vm, weights)
    dydt = rates @ y
    return dydt

# Function to simulate voltage-clamp experiments

In [62]:
def _sample_times(duration: float, dt: float) -> np.ndarray:
    """Return the sampling grid 0, dt, 2*dt, ... strictly below `duration` [s].

    With a float step, np.arange can, through rounding, emit a last point at
    or marginally beyond `duration`. Such points are dropped so that the exact
    segment end can be appended to t_eval (solve_ivp rejects t_eval values
    outside t_span and values that are not strictly increasing).
    """
    t = np.arange(0, duration, dt)
    return t[t < duration]


def _integrate_segment(y_start, Vm, duration, t_samples, weights):
    """Integrate the Kv11 ODEs at constant voltage Vm [mV] for `duration` [s].

    Returns a tuple (y_samples, y_end):
        y_samples: states at the times `t_samples`, shape
            (n_states, len(t_samples))
        y_end: state at exactly t = duration, i.e. the initial condition of
            the next protocol segment
    """
    # The segment end is appended to t_eval because t_samples never contains
    # it (it stops one step short); without it the hand-off state would be
    # taken at duration - dt. Extra t_eval points only add interpolated
    # outputs and do not change the solver's steps.
    sol = solve_ivp(
        kv11, (0, duration), np.array(y_start, dtype=np.float64),
        args=(Vm, weights), t_eval=np.append(t_samples, duration),
        method='RK45'
        )
    if not sol.success:
        raise RuntimeError(
            f"solve_ivp failed at Vm={Vm} mV over {duration} s: {sol.message}"
            )
    return sol.y[:, :-1], sol.y[:, -1]


def get_peak_tail_currents(
        Vmclamp: float = -80.0, 
        Vmpres: np.ndarray[np.float64] = np.arange(-80, 61, 10),
        dpres: np.ndarray[np.float64] = np.array([3.5]),
        Vmtests: np.ndarray[np.float64] = np.array([-50]), 
        dtest: float = 0.5,
        dt = 0.0001,
        Ek: float = -82.0,
        dp: float = None,
        Vmp: float = None,
        weights: dict = None
        ) -> tuple:
    """Goal:
        Simulate a voltage-clamp protocol on the Kv11 Markov model and return
        the pre-pulse currents, the tail currents and the peak tail currents.
        The membrane starts in the steady state at Vmclamp [mV]. It is then
        stepped to a pre-pulse at Vmpre [mV] for 'dpre' seconds, optionally
        followed by an intermediate pulse at Vmp [mV] for 'dp' seconds, and
        finally by a test pulse at Vmtest [mV] for 'dtest' seconds.
        Currents are computed as P(open) * (Vm - Ek), i.e. per unit
        conductance.
    ---------------------------------------------------------------------------
    Input:
        Vmclamp: Holding membrane voltage [mV]
        Vmpres: Voltage(s) of the pre-pulse [mV]
        dpres: Duration(s) of the pre-pulse [s]
        Vmtests: Voltage(s) of the test pulse [mV]
        dtest: Duration of the test pulse [s]
        dt: Sampling time step of the returned traces [s]
        Ek: Potassium reversal potential [mV]
        dp: Duration of pulse in between pre-pulse and the test-pulse [s];
            None (default) skips the intermediate pulse
        Vmp: Voltage of pulse in between pre-pulse and the test-pulse [mV];
            required when dp is given
        weights: numba typed Dict of rate parameters (see transition_matrix)
    ---------------------------------------------------------------------------
    Output:
        pre_tail_currents: Currents during the pre-pulse, shape
            (len(Vmpres), len(dpres), n_pre); n_pre is the number of samples
            of the longest pre-pulse, and shorter pre-pulses are NaN-padded.
        tail_currents: Currents during the test pulse, shape
            (len(Vmpres), len(dpres), len(Vmtests), n_test).
        peak_tail_currents: Signed value of largest magnitude of each test
            pulse trace, shape (len(Vmpres), len(dpres), len(Vmtests)).
    ---------------------------------------------------------------------------
    Raises:
        ValueError: dp is given without Vmp.
        RuntimeError: the ODE solver fails on a protocol segment.
    """
    if dp is not None and Vmp is None:
        raise ValueError("Vmp must be given when the intermediate pulse duration dp is set")

    # Sampling grids [s] of the recorded segments
    t_pre_grids = [_sample_times(dpre, dt) for dpre in dpres]
    t_test = _sample_times(dtest, dt)

    # Every sweep starts from the steady state at the holding potential
    y0 = steady_state(Vmclamp, weights)

    # Allocate outputs sized from the actual sample grids
    n_pre = max((t_pre.shape[0] for t_pre in t_pre_grids), default=0)
    pre_tail_currents = np.full(
        (Vmpres.shape[0], dpres.shape[0], n_pre), np.nan
        )
    tail_currents = np.empty(
        (Vmpres.shape[0], dpres.shape[0], Vmtests.shape[0], t_test.shape[0])
        )
    peak_tail_currents = np.empty(
        (Vmpres.shape[0], dpres.shape[0], Vmtests.shape[0])
        )

    # Perform the simulation
    for iVmpre, Vmpre in enumerate(Vmpres):
        for idpre, (dpre, t_pre) in enumerate(zip(dpres, t_pre_grids)):
            # The pre-pulse and the optional intermediate pulse do not depend
            # on the test voltage, so they are solved once per (Vmpre, dpre).
            y_pre, y_before_test = _integrate_segment(
                y0, Vmpre, dpre, t_pre, weights
                )
            pre_tail_currents[iVmpre, idpre, :t_pre.shape[0]] = (
                y_pre[3, :] * (Vmpre - Ek)
                )

            if dp is not None:
                # Only the end state of the intermediate pulse is needed
                _, y_before_test = _integrate_segment(
                    y_before_test, Vmp, dp, np.empty(0), weights
                    )

            for iVmtest, Vmtest in enumerate(Vmtests):
                y_test, _ = _integrate_segment(
                    y_before_test, Vmtest, dtest, t_test, weights
                    )
                tail_current = y_test[3, :] * (Vmtest - Ek)
                tail_currents[iVmpre, idpre, iVmtest, :] = tail_current

                # Keep the sign of the sample with the largest magnitude
                peak_tail_currents[iVmpre, idpre, iVmtest] = (
                    tail_current[np.argmax(np.abs(tail_current))]
                    )

    return pre_tail_currents, tail_currents, peak_tail_currents

# Function to fit the model parameters to reference data

In [70]:

# Reference pre-pulse currents, one row per temperature (rows 0, 1, 2 =
# 21, 30, 35 oC). Columns are compared element-wise with the simulated current
# at the end of the pre-pulse for Vmpre = -90 ... +40 mV in 10 mV steps.
_REF_PRETAILS = np.array(
    [
        [
            -0.03454138580784605, -0.03423230818229328, -0.02780914168986559,
            0.0012080699069789702, 0.21678108465850432, 0.5636477644692519,
            0.8305041312330013, 0.9292316832283776, 0.9110572912771588,
            0.8355136990420688, 0.7025067712056803, 0.5797715754225252,
            0.4508658095820337, 0.3465670157169347
        ], # 21 oC
        [
            -0.046835458263800867, -0.036301716243728244, -0.03191030368576353,
            0.021713879886473997, 0.24551745922502644, 0.6436298369210247,
            1.0335116500299955, 1.2942539279268659, 1.3991238093843528,
            1.3850505794290275, 1.2643565511120776, 1.088315662913979,
            0.8835384001493587, 0.6890140423744917
        ], # 30 oC
        [
            -0.11053693401286058, -0.08577049836324324, -0.04832239383312231,
            0.0631248992658211, 0.38598891785093015, 0.9097021602965616,
            1.513747089129601, 2.0712770037070176, 2.48082521439012,
            2.655697525134914, 2.614928110532488, 2.3944478099694875,
            1.9984916272529718, 1.6701921313406958
        ] # 35 oC
    ]
    )

# Reference normalised peak tail currents, one row per temperature (rows
# 0, 1, 2 = 21, 30, 35 oC). The 13 columns are compared with all pre-pulse
# voltages except the last one (+40 mV).
_REF_TAILS = np.array(
    [
        [
            0.015903307888041063, 0.006361323155216647, 0.019720101781171007,
            0.08778625954198493, 0.3473282442748098, 0.6692111959287533,
            0.8695928753180664, 0.9681933842239188, 1, 1.0025445292620867,
            0.9968193384223919, 0.9834605597964379, 0.9739185750636135
        ], # 21 oC
        [
            0, 0.013358778625953693, 0.02735368956743023,
            0.08396946564885521, 0.21501272264631055, 0.47582697201017854,
            0.701017811704835, 0.8638676844783716, 0.9535623409669212,
            1.0050890585241732, 1.0095419847328246, 0.9955470737913488,
            0.9688295165394403
        ], # 30 oC
        [
            0.011365557126790549, 0.017608983692143365, 0.03020161946472677,
            0.08533440706466, 0.19633978554932674, 0.37464723176954373,
            0.5764458517787127, 0.7509437724745911, 0.9000439560637423,
            0.9672397910185134, 0.9855471652165066, 0.9702037306161748,
            1.0018430937326483
        ], # 35 oC
    ]
    )

# Order in which the entries of `params` are mapped onto the model weights
_WEIGHT_KEYS = (
    "C1C2_1", "C1C2_2", "C2C1_1", "C2C1_2",
    "C3O_1", "C3O_2", "OC3_1", "OC3_2",
    "OI_1", "OI_2", "IO_1", "IO_2",
    "C2C3", "C3C2", "C3I_1", "C3I_2",
)


def optimize(params, ref_index: int = 2) -> float:
    """Goal:
        Objective function for fitting the 16 rate parameters of the Kv11
        Markov model (e.g. with scipy.optimize.minimize). It simulates a
        pre-pulse / tail-pulse protocol and returns the summed squared error
        of two normalised curves against the reference data.
    ---------------------------------------------------------------------------
    Input:
        params: Sequence of 16 floats in the order of _WEIGHT_KEYS
            (prefactor/exponent pairs of the voltage-dependent rates, then
            the constant rates C2C3 and C3C2, then the C3I pair)
        ref_index: Row of the reference data to fit
            (0: 21 oC, 1: 30 oC, 2: 35 oC; default 2)
    ---------------------------------------------------------------------------
    Output:
        error: error1 + error2. Returns np.inf instead when the sum is not
            finite, so the optimizer treats such parameter sets as worst.
    ---------------------------------------------------------------------------
    Raises:
        ValueError: params does not contain exactly 16 values.
    """
    if len(params) != len(_WEIGHT_KEYS):
        raise ValueError(
            f"expected {len(_WEIGHT_KEYS)} parameters, got {len(params)}"
            )

    # ----- Protocol parameters --> see get_peak_tail_currents for details
    Vmclamp = -80.0 # [mV]
    Vmpre = np.arange(-90, 41, 10)  # [mV]
    dpre = np.array([4.0]) # [s]
    Vmtest = np.array([-100]) # [mV]
    dtest = 4 # [s]
    dt = 0.0001 # [s]
    Ek = -82.0 # [mV]

    # numba typed dict, so it can be passed into the njit-compiled model
    weights = Dict.empty(
        key_type=types.string,
        value_type=types.float64
    )
    for key, value in zip(_WEIGHT_KEYS, params):
        weights[key] = value

    pretail_currents, _, peak_tail_currents = get_peak_tail_currents(
        Vmclamp = Vmclamp, Vmpres = Vmpre, dpres = dpre, Vmtests = Vmtest,
        dtest = dtest, dt = dt, Ek = Ek, weights = weights
        )

    # error1: current at the end of the pre-pulse, normalised by its maximum
    ref_pretail = _REF_PRETAILS[ref_index, :]
    ref_norm = ref_pretail / np.max(ref_pretail)
    pretail_end = pretail_currents[:, 0, -1]
    pretail_currents_norm = pretail_end / np.max(pretail_end)
    error1 = np.sum((ref_norm - pretail_currents_norm)**2)

    # error2: peak tail currents, normalised by the signed value of largest
    # magnitude; the last pre-pulse voltage (+40 mV) has no reference point
    ref_tail = _REF_TAILS[ref_index, :]
    peaks = peak_tail_currents[:-1, 0, 0]
    tail_norm = peaks / peaks[np.argmax(np.abs(peaks))]
    error2 = np.sum((ref_tail - tail_norm)**2)

    error = error1 + error2
    print(error)
    # A NaN/inf objective (e.g. from a degenerate normalisation) would
    # corrupt the simplex comparisons; report it as the worst possible value.
    if not np.isfinite(error):
        return np.inf
    return error

# Initial guess: the reference rate constants of the adapted Mazhari et al.
# (2001) model, in the order in which optimize() maps them onto the weights.
params0 = [
    0.0069, 0.0272,     # C1C2_1, C1C2_2
    0.0227, -0.0431,    # C2C1_1, C2C1_2
    0.0218, 0.0262,     # C3O_1, C3O_2
    0.0009, -0.0269,    # OC3_1, OC3_2
    0.0622, 0.0120,     # OI_1, OI_2
    0.0059, -0.0443,    # IO_1, IO_2
    0.0266,             # C2C3
    0.1348,             # C3C2
    1.29E-5, 2.71E-6,   # C3I_1, C3I_2
]

minimize(optimize, x0=params0, method="Nelder-Mead")


# -80
# 0.5 s +80
# 0.25 +80:-120:10
# +80

1.5992474286679275
1.595396318098035
1.6204821987363707
1.6239841994794975
1.6330463306732623
1.5957425133481307
1.6200882095380131
1.6176713177179722
1.6141808180124655
1.5980720014980072
1.6221825757483028
1.6109720794074827
1.6839001149792994
1.5964032718967898
1.6237053017717882
1.5992076064701255
1.5992474401897598
1.5491974337586703
1.4941581578102423
1.5672025748085028
1.5791022565761152
1.576939907044395
1.5527337500389375
1.5543713633877396
1.5475291548566288
1.5429102639281806
1.5323045207237038
1.5349836829890708
1.5276511376324005
1.5185881903176777
1.5084061987444262
1.4979723095632713
1.488179460577477
1.4356591302536077
1.4684273603411646
1.4527667245856741
1.4531702632624395
1.442482521085028
1.4333106419914836
1.3651192934306156
1.4232442845910895
1.4083562900755102
1.3965203346500248
1.3842410909138059
1.3872730110348166
1.3625199940319361
1.287671215077165
1.3442139166938505
1.337832721262935
1.3339398497012738
1.3168870706513736
1.3116663679205032
1.287992586613138


KeyboardInterrupt: 