Skip to content

Commit

Permalink
update to new syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 24, 2018
1 parent e1b8d82 commit 5c443ea
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 17 deletions.
4 changes: 2 additions & 2 deletions REQUIRE
@@ -1,4 +1,4 @@
julia 0.5
DiffEqBase 1.5.1
julia 0.6
DiffEqBase 3.0.0
Compat 0.17.0
Reexport
1 change: 1 addition & 0 deletions deps/.gitignore
@@ -0,0 +1 @@
daskr.so
13 changes: 6 additions & 7 deletions src/common.jl
Expand Up @@ -120,14 +120,13 @@ function solve{uType,duType,tType,isinplace,LinearSolver}(

### Fix the more general function to DASKR allowed style
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
f! = (t,u,du,out) -> (out[:] = prob.f(t,u,du); nothing)
f! = (out,du,u,p,t) -> (out[:] = prob.f(du,u,p,t); nothing)
elseif !isinplace && typeof(prob.u0)<:AbstractArray
f! = (t,u,du,out) -> (out[:] = vec(prob.f(t,reshape(u,sizeu),reshape(du,sizedu))); nothing)
f! = (out,du,u,p,t) -> (out[:] = vec(prob.f(reshape(du,sizedu),reshape(u,sizeu),p,t)); nothing)
elseif typeof(prob.u0)<:Vector{Float64}
f! = prob.f
else # Then it's an in-place function on an abstract array
f! = (t,u,du,out) -> (prob.f(t,reshape(u,sizeu),reshape(du,sizedu),out);
u = vec(u); du=vec(du); 0)
f! = (out,du,u,p,t) -> (prob.f(out,reshape(du,sizedu),reshape(u,sizeu),p,t); 0)
end

if prob.differential_vars == nothing
Expand Down Expand Up @@ -159,7 +158,7 @@ function solve{uType,duType,tType,isinplace,LinearSolver}(
lrw = Int32[N[1]^3 + 9 * N[1] + 60 + 3 * nrt[1]]
rwork = zeros(lrw[1])

liw = Int32[2*N[1] + 40]
liw = Int32[2*N[1] + 40]
iwork = zeros(Int32, liw[1])
iwork[1] = alg.jac_lower
iwork[2] = alg.jac_upper
Expand Down Expand Up @@ -205,11 +204,11 @@ function solve{uType,duType,tType,isinplace,LinearSolver}(

jroot = zeros(Int32, max(nrt[1], 1))
ipar = Int32[length(u0), nrt[1], length(u0)]
res = DASKR.res_c(f!)
res = DASKR.common_res_c(f!,prob.p)
rt = Int32[0]

if has_jac(f!)
jac = common_jac_c(f!)
jac = common_jac_c(f!,prob.p)
info[5] = 1 # Enables Jacobian
else
jac = Int32[0]
Expand Down
24 changes: 22 additions & 2 deletions src/core.jl
Expand Up @@ -18,6 +18,7 @@ function res_c(fun)
Ptr{Int32}, Ptr{Float64}, Ptr{Int32}))
end


"""
Return a C-style callback for the event-handling function `fun`. Suitable for use with `unsafe_solve`.
"""
Expand Down Expand Up @@ -57,18 +58,37 @@ function jac_c(fun)
Ptr{Float64}, Ptr{Int32}))
end

"""
Return a C-style callback for the residual function `fun`. Suitable for use with `unsafe_solve`.
"""
function common_res_c(fun,p)
newfun = function(t, y, yp, cj, delta, ires, rpar, ipar)
n = convert(Array{Int}, unsafe_wrap(Array, ipar, (3,)))
t = unsafe_wrap(Array, t, (1,))
y = unsafe_wrap(Array, y, (n[1],))
yp = unsafe_wrap(Array, yp, (n[1],))
delta = unsafe_wrap(Array, delta, (n[1],))
fun(delta,yp,y,p,first(t))
return nothing
end
cfunction(newfun, Void,
# T, Y, YPRIME, CJ, DELTA, IRES, RPAR, IPAR
(Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Ptr{Float64},
Ptr{Int32}, Ptr{Float64}, Ptr{Int32}))
end

"""
Return a C-style callback for the Jacobian function `fun`. Suitable for use with `unsafe_solve`. For a common interface passed function.
"""
function common_jac_c(fun)
function common_jac_c(fun,p)
newfun = function(t, y, yp, pd, cj, rpar, ipar)
n = convert(Array{Int}, unsafe_wrap(Array, ipar, (3,)))
_t = unsafe_wrap(Array, t, (1,))
_y = unsafe_wrap(Array, y, (n[1],))
_yp = unsafe_wrap(Array, yp, (n[1],))
_pd = unsafe_wrap(Array, pd, (n[3], n[1]))
_cj = unsafe_wrap(Array, cj, (1,))
fun(Val{:jac},first(_t), _y, _yp, first(_cj[1]), _pd)
fun(Val{:jac},_pd,_yp,_y,p,first(_cj[1]),first(_t))
return nothing
end
cfunction(newfun, Void,
Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Expand Up @@ -54,20 +54,20 @@ end


# Test the JuliaDiffEq common interface
function resrob(tres, y, yp, r)
function resrob(r,yp,y,p,tres)
r[1] = -0.04*y[1] + 1.0e4*y[2]*y[3]
r[2] = -r[1] - 3.0e7*y[2]*y[2] - yp[2]
r[1] -= yp[1]
r[3] = y[1] + y[2] + y[3] - 1.0
end

function testjac(t,u,du,res)
function testjac(res,du,u,p,t)
res[1] = du[1] - 1.5 * u[1] + 1.0 * u[1]*u[2]
res[2] = du[2] +3 * u[2] - u[1]*u[2]
end

jac_called = false
function testjac(::Type{Val{:jac}},t,u,du,gamma,J)
function testjac(::Type{Val{:jac}},J,du,u,p,gamma,t)
global jac_called
jac_called = true
J[1,1] = gamma - 1.5 + 1.0 * u[2]
Expand Down Expand Up @@ -114,7 +114,7 @@ let
@test maximum(sol[end]) < 2 #should be cyclic

# inconsistent initial conditions
function f!(t, u, du, res)
function f!(res,du,u,p,t)
res[1] = du[1]-1.01
return
end
Expand All @@ -125,12 +125,12 @@ let
sol = solve(dae_prob,daskr())

# Jacobian
function f2!(t, u, du, res)
function f2!(res,du,u,p,t)
res[1] = 1.01du[1]
return
end

function f2!(::Type{Val{:jac}},t,u,du,gamma,out)
function f2!(::Type{Val{:jac}},out,du,u,p,gamma,t)
global jac_called
jac_called = true
out[1] = 1.01
Expand Down

0 comments on commit 5c443ea

Please sign in to comment.