Skip to content

Conversation

@ChrisRackauckas
Copy link
Member

@ChrisRackauckas ChrisRackauckas commented Aug 21, 2020

This is exciting because it finally works! The tests show that you can change keyword arguments around and inference is still fine, thanks to @JeffBezanson! This requires Julia v1.5 in order to pass because it needs constant prop through the keyword arguments, since for example save_idxs can change the values that are saved from u to u[1] and thus change the type of the output. This should make a lot of downstream packages infer better (I'm looking at you Pumas)!

While this is exciting, there is a trade-off that is required to make this PR work out. Notice that I had to drop the automated progress bar handling from @devmotion @tkf @c42f. Something about it is making everything not infer, even with constant propagation. Could we figure out what that is? I'd hope to keep everything like a greedy programmer. To "fake it before we make it", I tried to force inference via Core.Compiler.return_type, but that only does positional arguments. So I tried stuff like:

  function testf(_prob,args,kwargs)
    __solve(_prob,args...;kwargs)
  end
  T = Core.Compiler.return_type(testf,Tuple{typeof(_prob),map(typeof, args)...,map(typeof, kwargs)...})

But I am not sure how to typeof over the kwargs, and even if I did, I'd have to reconstruct the kwargs so this would need to become a generated function. I am assuming that would be trying too hard, and so does anyone have any ideas on how to force inference through that choice?

Fixes SciML/DifferentialEquations.jl#603

@ChrisRackauckas
Copy link
Member Author

  function f()
    __solve(_prob,args...;kwargs...)
  end

  T = Core.Compiler.return_type(f,Tuple{})

  if hasfield(typeof(_prob),:f) && hasfield(typeof(_prob.f),:f) && typeof(_prob.f.f) <: EvalFunc
    Base.invokelatest(__solve,_prob,args...; kwargs...)::T
  else
    __solve(_prob,args...;kwargs...)::T
  end

infers Any for some reason, even though it works without the ::T...?

@ChrisRackauckas ChrisRackauckas merged commit 38d71b2 into master Aug 21, 2020
@ChrisRackauckas ChrisRackauckas deleted the inference branch August 21, 2020 17:02
@mchitre
Copy link
Contributor

mchitre commented Aug 23, 2020

Hmmm ... I'm on 6.45.0, so this PR should be in. Wondering why I'm still not being able to infer through the solver. See anything I might be doing wrong?

function traceray1(model, r0, z0, θ, rmax, ds, q0)
  a = altimetry(model.env)
  b = bathymetry(model.env)
  c = z -> soundspeed(ssp(model.env), 0.0, 0.0, z)
  ∂c = z -> ForwardDiff.derivative(c, z)
  ∂²c = z -> ForwardDiff.derivative(∂c, z)
  cᵥ = c(z0)
  u0 = [r0, z0, cos(θ)/cᵥ, sin(θ)/cᵥ, 0.0, 1/cᵥ, q0]
  prob = ODEProblem(rayeqns!, u0, (0.0, model.rugocity * (rmax-r0)/cos(θ)), (c, ∂c, ∂²c))
  cb = VectorContinuousCallback(
    (out, u, s, i) -> checkray!(out, u, s, i, a, b, rmax),
    (i, ndx) -> terminate!(i), 4; rootfind=true)
  soln = solve(prob, model.solver; abstol=model.solvertol, save_everystep=false, callback=cb)
  s2 = soln[end]
  soln.t[end], s2[1], s2[2], atan(s2[4], s2[3]), s2[5], s2[7], soln.u, soln.t
end

and:

julia> @code_warntype UnderwaterAcoustics.traceray1(pm, 0.0, -5.0, 0.0, 100.0, 1.0, 0.0)
Variables
  #self#::Core.Compiler.Const(UnderwaterAcoustics.traceray1, false)
  r0::Float64
  z0::Float64
  θ::Float64
  rmax::Float64
  ds::Float64
  q0::Float64
  a::FlatSurface
  b::ConstantDepth{Float64}
  c::UnderwaterAcoustics.var"#203#208"{RaySolver}
  ∂c::UnderwaterAcoustics.var"#204#209"{UnderwaterAcoustics.var"#203#208"{RaySolver}}
  ∂²c::UnderwaterAcoustics.var"#205#210"{UnderwaterAcoustics.var"#204#209"{UnderwaterAcoustics.var"#203#208"{RaySolver}}}
  cᵥ::Float64
  u0::Array{Float64,1}
  prob::Any
  cb::DiffEqBase.VectorContinuousCallback{UnderwaterAcoustics.var"#206#211"{Float64,FlatSurface,ConstantDepth{Float64}},UnderwaterAcoustics.var"#207#212",UnderwaterAcoustics.var"#207#212",typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}
  soln::Any
  s2::Any

and for reference, the 2 callbacks, which are correctly type inferred, if I pass in all the value manually:

function rayeqns!(du, u, params, s)
  # implementation based on COA (3.161-164, 3.58-63)
  r, z, ξ, ζ, t, p, q = u
  c, ∂c, ∂²c = params
  cᵥ = c(z)
  cᵥ² = cᵥ * cᵥ
  c̄nn = ∂²c(z) * ξ * ξ
  du[1] = cᵥ * ξ
  du[2] = cᵥ * ζ
  du[3] = 0
  du[4] = -∂c(z) / cᵥ²
  du[5] = 1 / cᵥ
  du[6] = -c̄nn * q
  du[7] = cᵥ * p
end

function checkray!(out, u, s, integrator, a::Altimetry, b::Bathymetry, rmax)
  out[1] = altitude(a, u[1], 0.0) - u[2]
  out[2] = u[2] + depth(b, u[1], 0.0)
  out[3] = u[1] - rmax
  out[4] = u[3]
end

@devmotion
Copy link
Member

It seems prob::Any is not inferred, which then probably leads to soln::Any. So it might be a problem with type inference of the problem rather than the solver.

@mchitre
Copy link
Contributor

mchitre commented Aug 23, 2020

Yes, fair.

But any thoughts on what might prevent the problem from inferring?

@devmotion
Copy link
Member

Maybe it helps to specify prob = ODEProblem{true}(....) to indicate that you use an inplace function. I assume that it might be problematic/impossible to infer it otherwise.

@mchitre
Copy link
Contributor

mchitre commented Aug 23, 2020

Bingo! That fixed it!! Thanks so much!!!

@devmotion
Copy link
Member

Great that it works 😃

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

type unstable when returning EnsembleSolution object

4 participants