In [None]:
using Plots; gr()
using DifferentialEquations
using BenchmarkTools

In [None]:
# [Van-Pottelbergh 2018]
# T. Van Pottelbergh, G. Drion, and R. Sepulchre.
# Robust modulation of integrate-and-fire models.
# Neural Computation, 30(4):987–1011, Apr. 2018.

# Model definition

In [None]:
# Model parameters, read as globals by f!, spike and reset!
C     = 1.0      # dV/dt is scaled by 1/C in f!
taus  = 10.0     # time constant of the Vs equation in f!
tauus = 100.0    # time constant of the Vus equation in f!

Vmax = 50.0      # spike(x) changes sign when V crosses this value

# gains of the three quadratic terms of Iion in f!
gf  = -1.0       # inverse sign compared with [Van-Pottelbergh 2018]
gs  = 0.5
gus = 0.015

# offsets at which the corresponding quadratic terms vanish
V0   = -40.0
Vs0  = -38.4
Vus0 = -50.0

# values applied by reset!
Vr    = -40.0    # V is set to Vr
Vsr   = -35.0    # Vs is set to Vsr
DVusr = 3.0      # Vus is incremented by DVusr

# input: constant applied current I0, wrapped as a function of time
I0 = 5.0
Iapp(t) = I0 ;

In [None]:
# In-place right-hand side of the model: writes the time derivative of the state
# x = (V, Vs, Vus) into dx.
#   dx : output vector, overwritten
#   x  : current state
#   p  : unused; present because callers pass a parameter argument
#   t  : current time, forwarded to Iapp
# Reads the global parameters C, taus, tauus, gf, gs, gus, V0, Vs0, Vus0 and Iapp.
function f!(dx,x,p,t)
    V, Vs, Vus = x[1], x[2], x[3]

    # one quadratic term per state variable
    # (gf carries the inverse sign, see the parameter definitions)
    fast      = gf*(V - V0)^2
    slow      = gs*(Vs - Vs0)^2
    ultraslow = gus*(Vus - Vus0)^2
    Iion      = fast + slow + ultraslow

    dx[1] = 1/C * (Iapp(t) - Iion)
    # Vs and Vus relax toward V with time constants taus and tauus
    dx[2] = 1/taus * (V - Vs)
    dx[3] = 1/tauus * (V - Vus)
end

# Spike indicator: signed distance of V = x[1] to the global threshold Vmax.
# A spike occurs when this value goes from negative to positive.
spike(x) = x[1] - Vmax

# Reset map applied at a spike; mutates the state vector x in place.
# V and Vs are set to the globals Vr and Vsr, Vus is incremented by DVusr.
function reset!(x)
    x[1] = Vr
    x[2] = Vsr
    x[3] += DVusr
end

In [None]:
# initial state (V, Vs, Vus): all three components start at -40.0
x0    = fill(-40.0, 3)
# simulated time interval (start, stop)
tspan = (0.0, 500.0)

# Julia Solver (DifferentialEquations.jl)

In [None]:
# Event function handed to ContinuousCallback: an event is located where the
# returned value crosses zero. Only the state x is used; t and integrator are ignored.
condition(x,t,integrator) = spike(x)
# Action run when condition hits zero on an upcrossing (from negative to positive):
# applies the reset map in place to the integrator's current state.
affect!(integrator) = reset!(integrator.u)

# Spike-and-reset event: affect! is run when condition crosses zero.
# NOTE(review): the third positional argument is taken to be the action for
# downward crossings, so `nothing` means those are ignored — confirm against the
# ContinuousCallback documentation.
cb   = ContinuousCallback(condition,affect!,nothing)
# Initial-value problem for f! from x0 over tspan, with the callback attached to the problem.
prob = ODEProblem(f!,x0,tspan,callback=cb)

# No algorithm is named here, so the choice is left to the library's default.
# The trailing `;` suppresses the cell output.
sol  = solve(prob,dense=false);  # dense=false, avoids nonlinear interpolations between time steps when plotting (no impact on computation)

In [None]:
# Top panel: the three state variables of the library solution
# (the labels form a 1x3 row, one entry per state component).
p1 = plot(sol,label=["V" "Vs" "Vus"])
# Bottom panel: the applied current evaluated at the solver's time points.
pu = plot(sol.t,Iapp.(sol.t),label="Iapp")
# Stack the two panels vertically (2 rows, 1 column).
plot(p1, pu, layout = (2,1))

# Homemade Euler integration

In [None]:
# Fixed-step explicit Euler simulation of the spike-and-reset model, storing the
# trajectory in arrays that grow by one entry per step.
#
# Uses the globals x0 (initial state), tspan (start, stop), f! (flow),
# spike (threshold function) and reset! (jump map).
#
# Argument:
#   dt : step size
# Returns (t, x): the time stamps and the matching state vectors. Each spike adds
# one extra entry with a repeated time stamp: the state that crossed the
# threshold, followed by the state after the reset.
function solve_homemade_euler(dt)
    dx = zeros(length(x0))
    p  = []   # f! takes a parameter argument; one empty placeholder is reused for every step

    # Concretely typed containers instead of untyped `[]` (Vector{Any}).
    # float.(x0) also makes a private copy, so the returned trajectory does not
    # share storage with the global x0.
    t = [float(tspan[1])]
    x = [float.(x0)]

    while t[end] < tspan[2]

        # flow
        f!(dx, x[end], p, t[end])
        push!(x, x[end] + dt*dx)
        push!(t, t[end] + dt)

        # jump
        if spike(x[end]) > 0
            # Push a copy: pushing x[end] itself would store the same array twice,
            # and the in-place reset below would then also overwrite the pre-reset
            # sample that this duplicated time stamp is meant to keep.
            push!(x, copy(x[end]))
            push!(t, t[end])
            reset!(x[end])
        end

    end
    return t, x
end

In [None]:
# Fixed-step explicit Euler simulation of the spike-and-reset model on a time grid
# built up front, so the output containers are allocated once.
#
# Uses the globals x0, tspan, f!, spike and reset!.
#
# Argument:
#   dt : step size of the grid tspan[1]:dt:tspan[2]
# Returns (t, x): the grid as a vector and an n-by-1 array holding one state
# vector per grid point. A state that crosses the spike threshold is stored in
# its already-reset form.
function solve_homemade_euler2(dt)

    times  = collect(tspan[1]:dt:tspan[2])
    nsteps = length(times)

    # Every slot initially refers to the same zero vector; slot 1 is rebound to
    # x0 and slots 2..n are rebound to fresh vectors in the loop below.
    states = fill(zeros(size(x0)), nsteps, 1)
    states[1] = x0

    slope = zeros(length(x0))
    for k in 2:nsteps
        previous = states[k-1]

        # flow
        f!(slope, previous, [], times[k-1])
        candidate = previous + dt*slope

        # jump
        if spike(candidate) > 0
            reset!(candidate)
        end

        states[k] = candidate
    end
    return times, states
end

In [None]:
# Step size for the fixed-step Euler runs in this cell and in the next one.
dt   = 0.01

# Growing-array Euler variant; the trailing `;` suppresses the cell output.
t, x = solve_homemade_euler(dt);

# x is a collection of state vectors: hcat(x...) lines them up as columns and `'`
# flips the result so that each state variable becomes one column (one plotted series).
p1 = plot(t,hcat(x...)',label=["V" "Vs" "Vus"])
# Applied current evaluated on the same time stamps.
pu = plot(t,Iapp.(t),label="Iapp")
# Stack the two panels vertically (2 rows, 1 column).
plot(p1, pu, layout = (2,1))

In [None]:
# Preallocated Euler variant, run with the global step size dt;
# the trailing `;` suppresses the cell output.
t, x = solve_homemade_euler2(dt);

# hcat(x...) lines the state vectors up as columns and `'` flips the result so
# that each state variable becomes one column (one plotted series).
p1 = plot(t,hcat(x...)',label=["V" "Vs" "Vus"])
# Applied current evaluated on the same time grid.
pu = plot(t,Iapp.(t),label="Iapp")
# Stack the two panels vertically (2 rows, 1 column).
plot(p1, pu, layout = (2,1))

# Benchmarking

The benchmarking takes about 10 minutes...

In [None]:
# Global BenchmarkTools settings used by every @benchmark call below.
# The sample-count override is left disabled:
#BenchmarkTools.DEFAULT_PARAMETERS.samples = 20
# Time budget per benchmark (10*5 = 50):
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 10*5 # seconds

# Result accumulators: the benchmark cells push one entry per benchmark.
# Concrete element types replace the untyped `[]` (Vector{Any}); integer values
# pushed later are converted to Float64.
tim = Float64[]   # mean run times
mem = Float64[]   # memory estimates
# NOTE(review): the name `all` shadows Base.all for the rest of the session; it is
# kept because the later cells push to and plot this variable under that name.
all = Float64[]   # allocation counts
lab = String[];   # tick labels for the summary plots

In [None]:
# Benchmark the DifferentialEquations.jl solver and record its mean statistics
# under the label " auto ".
let
    # `$prob` interpolates the global into the benchmarked expression, so the
    # measurement is not distorted by the cost of accessing an untyped global.
    b   = @benchmark solve($prob, dense=false)
    est = mean(b)
    push!(lab, " auto ")
    push!(tim, est.time)    # raw mean time; the plotting cell divides by 1e9 to get seconds
    push!(mem, est.memory)
    push!(all, est.allocs)
end

# step sizes at which both homemade integrators are benchmarked
dt_vec = [0.01, 0.001, 0.0001]

for dt in dt_vec

    # growing-array Euler (labelled HM1)
    trial1 = @benchmark solve_homemade_euler($dt)
    est1   = mean(trial1)
    push!(lab, " HM1 $dt ")
    push!(tim, est1.time)
    push!(mem, est1.memory)
    push!(all, est1.allocs)
    println("dt = ", dt, ", ", length(trial1.times), " samples")

    # preallocated Euler (labelled HM2)
    trial2 = @benchmark solve_homemade_euler2($dt)
    est2   = mean(trial2)
    push!(lab, " HM2 $dt ")
    push!(tim, est2.time)
    push!(mem, est2.memory)
    push!(all, est2.allocs)
    println("dt = ", dt, ", ", length(trial2.times), " samples")

end

In [None]:
# Mean run time of each benchmark on a log10 scale, one tick per entry of lab.
# tim/1e9 rescales the recorded times to the seconds stated in the axis label.
scatter(log10.(tim/1e9),xticks=(collect(1:length(lab)),lab))
ylabel!("log_10 time (s)")

In [None]:
# Memory estimate of each benchmark on a log10 scale, one tick per entry of lab.
# Dividing by 1024 (not 1e3) makes the values KiB, as the axis label states —
# assuming mem holds byte counts as reported by BenchmarkTools.
scatter(log10.(mem/1024),xticks=(collect(1:length(lab)),lab))
ylabel!("log_10 memory (KiB)")

In [None]:
# Allocation count of each benchmark on a log10 scale, one tick per entry of lab.
# Here `all` is the accumulator vector defined in the benchmark setup cell, not Base.all.
scatter(log10.(all),xticks=(collect(1:length(lab)),lab))
ylabel!("log_10 alloc (-)")