In [None]:
using Statistics

using ProgressMeter

using DifferentialEquations
using Printf
using Plots
gr(fmt="png", size=(800, 300))

include("../src/HHModel.jl")

In [None]:
# Persistent Sodium Plus Potassium Model

"""
    dual_pulse(t, param)

Return the stimulus value at time `t` for a two-pulse protocol (passed as the
stimulus function to `HHModel.simpleConductanceModel`).

`param` must provide `start`, `pulse_length`, `pulse_interval`, `pulse_one`,
`pulse_two` and `baseline`. Measured from `param.start`, the first pulse is
active on the open interval `(0, pulse_length)` and the second on
`(pulse_length + pulse_interval, 2*pulse_length + pulse_interval)`. During a
pulse its amplitude is added to `baseline`; otherwise `baseline` is returned.
Both comparisons are strict, so the pulse edges themselves return `baseline`.
"""
function dual_pulse(t, param)
    elapsed = t - param.start

    if 0 < elapsed < param.pulse_length
        return param.pulse_one + param.baseline
    end

    second_onset = param.pulse_length + param.pulse_interval
    second_offset = 2*param.pulse_length + param.pulse_interval
    if second_onset < elapsed < second_offset
        return param.pulse_two + param.baseline
    end

    return param.baseline
end

# Biophysics setup: a persistent sodium channel, a potassium channel named "ltk"
# and a leak, collected into `_model`.
# NOTE(review): the positional arguments are presumably (exponent, V_half, slope) for
# HHModel.Kinetics and (name, ion, conductance, gate 1, gate 2) for
# HHModel.SimpleIonChannel — confirm against HHModel.jl.

# Persistent sodium: instantaneous first gate, default-constructed second gate.
Na_m = HHModel.Kinetics(1, -20.0, 15.0, _type=:instantaneous)
Na = HHModel.SimpleIonChannel("Persistent Sodium", :sodium, 20.0, Na_m, HHModel.Kinetics());

# Potassium ("ltk"); an earlier variant used K_m = HHModel.Kinetics(1, -25.0, 5.0).
K_m = HHModel.Kinetics(1, -55.0, 5.0)
K = HHModel.SimpleIonChannel("ltk", :potassium, 9.0, K_m, HHModel.Kinetics());

# Leak channel.
il = HHModel.leakage(8);

_model = [Na, K, il];

In [None]:
# Initial state, integration window and parameters for a single reference run.
u0 = HHModel.setup_init(_model, -70.0)
tspan = (0.0, 30.0)

# Reversal potentials; the commented line is the alternative "ikl" configuration.
_reversal_potential = (sodium=60.0, potassium=-90.0, leak=-80.0) # ikh
# _reversal_potential = (sodium=60.0, potassium=-90.0, leak=-78.0) # ikl

_stimulus_parameter = (
    start=10,
    pulse_length=0.5,
    pulse_interval=1.5,
    pulse_one=110,
    pulse_two=150,
    baseline=0.0,
)

# Conductance-based cell model driven by the `dual_pulse` stimulus.
_model_sim = HHModel.simpleConductanceModel(_model, dual_pulse, C=1.0)

_p = (E=_reversal_potential, stim=_stimulus_parameter)

prob = ODEProblem(_model_sim, u0, tspan, _p)
sol = solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8);

# Quick look at state variable 1.
# NOTE(review): the marker at t = 11.5 is hard-coded. It equals start + 1.5 (the end of the
# first peak-search window in the next cell) but not the second-pulse onset
# (start + pulse_length + pulse_interval = 12.0 here) — confirm which was intended.
a = plot(sol, vars=(1), legend=nothing, size=(1200, 300))
vline!(a, [11.5])

In [None]:
# Compare the peak of state variable 1 (presumably the membrane potential, given the mV
# labels used later) in a fixed window after each pulse onset of the reference run `sol`.
_trace_window = 1.5  # length of each peak-search window, in the time units of `tspan`

_trace_1_tspan = _stimulus_parameter.start:0.01:(_stimulus_parameter.start + _trace_window)
_trace_2_start = (_stimulus_parameter.start + _stimulus_parameter.pulse_length + _stimulus_parameter.pulse_interval)
_trace_2_tspan = _trace_2_start:0.01:(_trace_2_start + _trace_window)

# `sol(ts).u` is the vector of interpolated state vectors. Taking component 1 of each
# avoids splatting the solution into `hcat`, whose result depends on how the returned
# array type iterates (this differs between RecursiveArrayTools versions).
_trace_1 = [u[1] for u in sol(_trace_1_tspan).u];
_max_1 = maximum(_trace_1)

_trace_2 = [u[1] for u in sol(_trace_2_tspan).u];
_max_2 = maximum(_trace_2)

# Positive: the second peak is higher than the first.
_diff = _max_2 - _max_1

In [None]:
# Parameter sweep: for every (inter-pulse interval, second-pulse amplitude) pair, simulate
# the dual-pulse protocol and record max(state 1 after pulse 2) - max(state 1 after pulse 1).
_n = 5                         # repetitions per grid point
_interval_range = 1.0:0.1:2.0  # inter-pulse interval (ms, per the heatmap label)
_pulse_range = 100:10:200      # second-pulse amplitude
_pulse_one = 110               # first-pulse amplitude, also the x-axis normaliser
_trace_window = 1.5            # length of each peak-search window after a pulse onset
_max_valid_diff = 25           # larger differences are stored as NaN
_result = zeros(length(_interval_range), length(_pulse_range), _n)
_progress_bar = Progress(length(_pulse_range) * length(_interval_range) * _n)

# NOTE(review): the ODE solve below looks deterministic, so the `_n` repetitions may all be
# identical — confirm whether HHModel adds noise before relying on the averaging.
# Loop-local names (`_stim`, `_sweep_sol`, `_peak_1`, ...) keep the sweep from overwriting
# the globals of the single reference run (`sol`, `_stimulus_parameter`, `_max_1`, ...).
for _n_idx in 1:_n
    for (_interval_idx, _interval_step) in enumerate(_interval_range)
        for (_pulse_idx, _pulse_step) in enumerate(_pulse_range)
            _stim = (start=10, pulse_length=0.5, pulse_interval=_interval_step,
                pulse_one=_pulse_one, pulse_two=_pulse_step, baseline=0.0)
            _sweep_prob = ODEProblem(_model_sim, u0, tspan,
                (E=_reversal_potential, stim=_stim))
            _sweep_sol = solve(_sweep_prob, Tsit5(), reltol=1e-8, abstol=1e-8)

            _onset_1 = _stim.start
            _onset_2 = _stim.start + _stim.pulse_length + _stim.pulse_interval
            _window_1 = _onset_1:0.01:(_onset_1 + _trace_window)
            _window_2 = _onset_2:0.01:(_onset_2 + _trace_window)
            # `.u` gives the interpolated state vectors; component 1 is the traced variable.
            _peak_1 = maximum(u[1] for u in _sweep_sol(_window_1).u)
            _peak_2 = maximum(u[1] for u in _sweep_sol(_window_2).u)

            _delta = _peak_2 - _peak_1
            _result[_interval_idx, _pulse_idx, _n_idx] = _delta > _max_valid_diff ? NaN : _delta
            next!(_progress_bar)
        end
    end
end

# Average the repetitions per grid point, ignoring NaNs. When the spread is large
# (std >= 3), keep only samples strictly within one standard deviation of the mean.
# With fewer than two samples the std is undefined (NaN); those samples are kept as-is
# instead of being filtered away to an all-NaN result.
_correct_result = zeros(size(_result)[1:2])
for _i_idx in axes(_result, 1), _j_idx in axes(_result, 2)
    _a = filter(!isnan, _result[_i_idx, _j_idx, :])
    if length(_a) >= 2
        _avg, _std = mean(_a), std(_a)
        if _std >= 3
            _a = filter(x -> _avg - _std < x < _avg + _std, _a)
        end
    end
    # An empty `_a` (every repetition was NaN) yields NaN here.
    _correct_result[_i_idx, _j_idx] = mean(_a)
end

heatmap(_pulse_range ./ _pulse_one, _interval_range, _correct_result,
    xlabel="relative pulse strength", ylabel="pulse interval (ms)",
    colorbar_title="spike strength difference (mV)",
    clim=(-70, 0)
)

In [None]:
# Replace one grid cell with the average of its two neighbours along the interval axis.
# With the current ranges this is row 9 (interval 1.8 ms) and column 1 (amplitude 100).
# Then plot the first column of the averaged result.
# NOTE(review): hard-coded indices — re-check if `_interval_range`/`_pulse_range` change.
let row = 9, col = 1
    _correct_result[row, col] = (_correct_result[row - 1, col] + _correct_result[row + 1, col]) / 2
end
plot(_correct_result[:, 1])

In [None]:
# Redraw the heatmap of the averaged sweep result; x is the second-pulse amplitude
# relative to the first-pulse amplitude (110).
let rel_strength = _pulse_range ./ 110
    heatmap(rel_strength, _interval_range, _correct_result;
        xlabel="relative pulse strength",
        ylabel="pulse interval (ms)",
        colorbar_title="spike strength difference (mV)",
        clim=(-70, 0),
    )
end
# To export the figure: savefig("ltk_10.svg")

In [None]:
# Animation: repeat the interval × amplitude sweep for a range of potassium conductances
# `K.g` and draw one heatmap frame per conductance value.
_n = 5
_interval_range = 1.0:0.1:2.0
_pulse_range = 100:10:200
_pulse_one = 110               # first-pulse amplitude, also the x-axis normaliser
_trace_window = 1.5            # length of each peak-search window after a pulse onset
_max_valid_diff = 25           # larger differences are stored as NaN
_g_ltk_range = 8:0.05:11
_progress_bar = Progress(length(_g_ltk_range) * length(_pulse_range) *
    length(_interval_range) * _n)

# `K.g` is mutated for every frame. Restore the original value afterwards (also on error or
# interrupt) so later cells are not silently run with the last swept conductance.
_g_ltk_default = K.g

anim = try
    @animate for _g_ltk in _g_ltk_range
        # NOTE(review): assumes `_model_sim` reads `K.g` at evaluation time (it was built from
        # `_model = [Na, K, il]`) and that no `HHModel.update!(K)` is required — confirm.
        K.g = _g_ltk
        # Per-frame names so the results of the standalone sweep are not overwritten.
        _frame_result = zeros(length(_interval_range), length(_pulse_range), _n)

        for _n_idx in 1:_n
            for (_interval_idx, _interval_step) in enumerate(_interval_range)
                for (_pulse_idx, _pulse_step) in enumerate(_pulse_range)
                    _stim = (start=10, pulse_length=0.5, pulse_interval=_interval_step,
                        pulse_one=_pulse_one, pulse_two=_pulse_step, baseline=0.0)
                    _sweep_prob = ODEProblem(_model_sim, u0, tspan,
                        (E=_reversal_potential, stim=_stim))
                    _sweep_sol = solve(_sweep_prob, Tsit5(), reltol=1e-8, abstol=1e-8)

                    _onset_1 = _stim.start
                    _onset_2 = _stim.start + _stim.pulse_length + _stim.pulse_interval
                    _window_1 = _onset_1:0.01:(_onset_1 + _trace_window)
                    _window_2 = _onset_2:0.01:(_onset_2 + _trace_window)
                    # `.u` gives the interpolated state vectors; take component 1.
                    _peak_1 = maximum(u[1] for u in _sweep_sol(_window_1).u)
                    _peak_2 = maximum(u[1] for u in _sweep_sol(_window_2).u)

                    _delta = _peak_2 - _peak_1
                    _frame_result[_interval_idx, _pulse_idx, _n_idx] =
                        _delta > _max_valid_diff ? NaN : _delta
                    next!(_progress_bar)
                end
            end
        end

        # Robust average over repetitions: drop NaNs; if std >= 3, keep samples strictly
        # within one std of the mean. Fewer than two samples (std undefined) are kept as-is.
        _frame_avg = zeros(size(_frame_result)[1:2])
        for _i_idx in axes(_frame_result, 1), _j_idx in axes(_frame_result, 2)
            _a = filter(!isnan, _frame_result[_i_idx, _j_idx, :])
            if length(_a) >= 2
                _avg, _std = mean(_a), std(_a)
                if _std >= 3
                    _a = filter(x -> _avg - _std < x < _avg + _std, _a)
                end
            end
            _frame_avg[_i_idx, _j_idx] = mean(_a)
        end

        heatmap(_pulse_range ./ _pulse_one, _interval_range, _frame_avg,
            xlabel="relative pulse strength", ylabel="pulse interval (ms)",
            colorbar_title="spike strength difference (mV)",
            clim=(-70, 0), title=@sprintf("g_ltk = %.2f", _g_ltk)
        )
    end
finally
    K.g = _g_ltk_default
end

In [None]:
# Write the frames collected in `anim` to an animated GIF at 12 frames per second.
gif(anim, "test.gif"; fps=12)