Skip to content

Commit

Permalink
Fix SKenCarp
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Aug 19, 2018
1 parent c8147d4 commit ed13a28
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/perform_step/kencarp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ end
@unpack ea21,ea31,ea32,ea41,ea42,ea43,eb1,eb2,eb3,eb4 = cache.tab
@unpack ebtilde1,ebtilde2,ebtilde3,ebtilde4 = cache.tab
@unpack nb021,nb043 = cache.tab
nlsolve! = cache.nlsolve; nlcache = nlsolve!.cache
alg = unwrap_alg(integrator, true)
nlsolve! = cache.nlsolve; nlcache = nlsolve!.cache

# Some aliases

Expand All @@ -165,9 +165,11 @@ end
end

# precalculations

γdt = γ*dt

nlsolve! isa NLNewton && calc_W!(integrator, cache, γdt, repeat_step)
calc_W!(integrator, cache, γdt, repeat_step)
new_W = true

if !repeat_step && !integrator.last_stepfail
f(z₁, integrator.uprev, p, integrator.t)
Expand All @@ -176,9 +178,6 @@ end

##### Step 2

# initial step of Newton iteration
nlcache.c = 2*γdt

# TODO: Add a cache so this isn't overwritten near the end, so it can not repeat on fail
g(g1,uprev,p,t)

Expand All @@ -196,6 +195,8 @@ end
@. z₂ = z₁
end
nlcache.z = z₂
nlcache.c = 2γ


if typeof(integrator.f) <: SplitSDEFunction
# This assumes the implicit part is cheaper than the explicit part
Expand All @@ -210,7 +211,6 @@ end

################################## Solve Step 3

# initial step of Newton iteration
nlcache.c = c3

if typeof(integrator.f) <: SplitSDEFunction
Expand Down Expand Up @@ -328,7 +328,6 @@ end
@. dz += btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄
end
end
if alg.smooth_est # From Shampine
if has_invW(f)
mul!(vec(tmp),W,vec(dz))
Expand All @@ -338,13 +337,12 @@ end
else
tmp .= dz
end
=#

@. E₁ = z₁ + z₂ + z₃ + z₄

@tight_loop_macros for (i,atol,rtol,δ) in zip(eachindex(u),Iterators.cycle(integrator.opts.abstol),
Iterators.cycle(integrator.opts.reltol),Iterators.cycle(integrator.opts.delta))
Iterators.cycle(integrator.opts.reltol),Iterators.cycle(integrator.opts.delta))
@inbounds tmp[i] =*E₁[i]+E₂[i])/(atol + max(integrator.opts.internalnorm(uprev[i]),integrator.opts.internalnorm(u[i]))*rtol)
end
integrator.EEst = integrator.opts.internalnorm(tmp)
Expand Down

0 comments on commit ed13a28

Please sign in to comment.