Skip to content

Commit

Permalink
fix lazy restep
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 6, 2018
1 parent 19f6e54 commit d3bf319
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/dense/low_order_rk_addsteps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,9 +639,9 @@ Called to add the extra k9, k10, k11 steps for the Order 5 interpolation when ne
end
f(k7,tmp,p,t+dt)
@tight_loop_macros for i in uidx
@inbounds u[i] = uprev[i]+dt*(a81*k1[i]+a83*k3[i]+a84*k4[i]+a85*k5[i]+a86*k6[i]+a87*k7[i])
@inbounds tmp[i] = uprev[i]+dt*(a81*k1[i]+a83*k3[i]+a84*k4[i]+a85*k5[i]+a86*k6[i]+a87*k7[i])
end
f(k8,u,p,t+dt)
f(k8,tmp,p,t+dt)
copyat_or_push!(k,1,k1)
copyat_or_push!(k,2,k2)
copyat_or_push!(k,3,k3)
Expand Down
4 changes: 2 additions & 2 deletions src/integrators/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ function reeval_internals_due_to_modification!(integrator)
if integrator.opts.calck
resize!(integrator.k,integrator.kshortsize) # Reset k for next step!
alg = unwrap_alg(integrator, false)
if typeof(alg) == BS5 || typeof(alg) == Vern6 || typeof(alg) == Vern7 ||
typeof(alg) == Vern8 || typeof(alg) == Vern9
if typeof(alg) <: BS5 || typeof(alg) <: Vern6 || typeof(alg) <: Vern7 ||
typeof(alg) <: Vern8 || typeof(alg) <: Vern9
ode_addsteps!(integrator,integrator.f,true,false,!alg.lazy)
else
ode_addsteps!(integrator,integrator.f,true,false)
Expand Down
4 changes: 2 additions & 2 deletions src/perform_step/low_order_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ end
integrator.u = u

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
@unpack c6,c7,c8,a91,a92,a93,a94,a95,a96,a97,a98,a101,a102,a103,a104,a105,a106,a107,a108,a109,a111,a112,a113,a114,a115,a116,a117,a118,a119,a1110 = cache
k = integrator.k
k[9] = f(uprev+dt*(a91*k[1]+a92*k[2]+a93*k[3]+a94*k[4]+a95*k[5]+a96*k[6]+a97*k[7]+a98*k[8]),p,t+c6*dt)
Expand Down Expand Up @@ -500,7 +500,7 @@ end
end

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c6,c7,c8,a91,a92,a93,a94,a95,a96,a97,a98,a101,a102,a103,a104,a105,a106,a107,a108,a109,a111,a112,a113,a114,a115,a116,a117,a118,a119,a1110 = cache.tab
@tight_loop_macros for i in uidx
Expand Down
22 changes: 11 additions & 11 deletions src/perform_step/verner_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end
integrator.k[9]=k9

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c10,a1001,a1004,a1005,a1006,a1007,a1008,a1009,c11,a1101,a1102,a1103,a1104,a1105,a1106,a1107,a1108,a1109,a1110,c12,a1201,a1202,a1203,a1204,a1205,a1206,a1207,a1208,a1209,a1210,a1211 = cache
k[10] = f(uprev+dt*(a1001*k[1]+a1004*k[4]+a1005*k[5]+a1006*k[6]+a1007*k[7]+a1008*k[8]+a1009*k[9]),p,t+c10*dt)
Expand All @@ -64,15 +64,15 @@ end
end

function initialize!(integrator, cache::Vern6Cache)
integrator.alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12)
alg = unwrap_alg(integrator, false)
alg.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12)
integrator.fsalfirst = cache.k1 ; integrator.fsallast = cache.k9
@unpack k = integrator
resize!(k, integrator.kshortsize)
k[1]=cache.k1; k[2]=cache.k2; k[3]=cache.k3;
k[4]=cache.k4; k[5]=cache.k5; k[6]=cache.k6;
k[7]=cache.k7; k[8]=cache.k8; k[9]=cache.k9 # Set the pointers

alg = unwrap_alg(integrator, false)
if !alg.lazy
k[10] = similar(cache.k1)
k[11] = similar(cache.k1)
Expand Down Expand Up @@ -168,7 +168,7 @@ end
end

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c10,a1001,a1004,a1005,a1006,a1007,a1008,a1009,c11,a1101,a1102,a1103,a1104,a1105,a1106,a1107,a1108,a1109,a1110,c12,a1201,a1202,a1203,a1204,a1205,a1206,a1207,a1208,a1209,a1210,a1211 = cache.tab
@unpack tmp = cache
Expand Down Expand Up @@ -233,7 +233,7 @@ end
integrator.u = u

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c11,a1101,a1104,a1105,a1106,a1107,a1108,a1109,c12,a1201,a1204,a1205,a1206,a1207,a1208,a1209,a1211,c13,a1301,a1304,a1305,a1306,a1307,a1308,a1309,a1311,a1312,c14,a1401,a1404,a1405,a1406,a1407,a1408,a1409,a1411,a1412,a1413,c15,a1501,a1504,a1505,a1506,a1507,a1508,a1509,a1511,a1512,a1513,c16,a1601,a1604,a1605,a1606,a1607,a1608,a1609,a1611,a1612,a1613 = cache
k[11] = f(uprev+dt*(a1101*k[1]+a1104*k[4]+a1105*k[5]+a1106*k[6]+a1107*k[7]+a1108*k[8]+a1109*k[9]),p,t+c11*dt)
Expand Down Expand Up @@ -362,7 +362,7 @@ end
end

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack tmp = cache
@unpack c11,a1101,a1104,a1105,a1106,a1107,a1108,a1109,c12,a1201,a1204,a1205,a1206,a1207,a1208,a1209,a1211,c13,a1301,a1304,a1305,a1306,a1307,a1308,a1309,a1311,a1312,c14,a1401,a1404,a1405,a1406,a1407,a1408,a1409,a1411,a1412,a1413,c15,a1501,a1504,a1505,a1506,a1507,a1508,a1509,a1511,a1512,a1513,c16,a1601,a1604,a1605,a1606,a1607,a1608,a1609,a1611,a1612,a1613 = cache.tab
Expand Down Expand Up @@ -445,7 +445,7 @@ end
integrator.u = u

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c14,a1401,a1406,a1407,a1408,a1409,a1410,a1411,a1412,c15,a1501,a1506,a1507,a1508,a1509,a1510,a1511,a1512,a1514,c16,a1601,a1606,a1607,a1608,a1609,a1610,a1611,a1612,a1614,a1615,c17,a1701,a1706,a1707,a1708,a1709,a1710,a1711,a1712,a1714,a1715,a1716,c18,a1801,a1806,a1807,a1808,a1809,a1810,a1811,a1812,a1814,a1815,a1816,a1817,c19,a1901,a1906,a1907,a1908,a1909,a1910,a1911,a1912,a1914,a1915,a1916,a1917,c20,a2001,a2006,a2007,a2008,a2009,a2010,a2011,a2012,a2014,a2015,a2016,a2017,c21,a2101,a2106,a2107,a2108,a2109,a2110,a2111,a2112,a2114,a2115,a2116,a2117 = cache
k[14] = f(uprev+dt*(a1401*k[1]+a1406*k[6]+a1407*k[7]+a1408*k[8]+a1409*k[9]+a1410*k[10]+a1411*k[11]+a1412*k[12]),p,t+c14*dt)
Expand Down Expand Up @@ -591,7 +591,7 @@ end
end

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c14,a1401,a1406,a1407,a1408,a1409,a1410,a1411,a1412,c15,a1501,a1506,a1507,a1508,a1509,a1510,a1511,a1512,a1514,c16,a1601,a1606,a1607,a1608,a1609,a1610,a1611,a1612,a1614,a1615,c17,a1701,a1706,a1707,a1708,a1709,a1710,a1711,a1712,a1714,a1715,a1716,c18,a1801,a1806,a1807,a1808,a1809,a1810,a1811,a1812,a1814,a1815,a1816,a1817,c19,a1901,a1906,a1907,a1908,a1909,a1910,a1911,a1912,a1914,a1915,a1916,a1917,c20,a2001,a2006,a2007,a2008,a2009,a2010,a2011,a2012,a2014,a2015,a2016,a2017,c21,a2101,a2106,a2107,a2108,a2109,a2110,a2111,a2112,a2114,a2115,a2116,a2117 = cache.tab
@unpack tmp = cache
Expand Down Expand Up @@ -686,7 +686,7 @@ end
integrator.u = u

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack c17,a1701,a1708,a1709,a1710,a1711,a1712,a1713,a1714,a1715,c18,a1801,a1808,a1809,a1810,a1811,a1812,a1813,a1814,a1815,a1817,c19,a1901,a1908,a1909,a1910,a1911,a1912,a1913,a1914,a1915,a1917,a1918,c20,a2001,a2008,a2009,a2010,a2011,a2012,a2013,a2014,a2015,a2017,a2018,a2019,c21,a2101,a2108,a2109,a2110,a2111,a2112,a2113,a2114,a2115,a2117,a2118,a2119,a2120,c22,a2201,a2208,a2209,a2210,a2211,a2212,a2213,a2214,a2215,a2217,a2218,a2219,a2220,a2221,c23,a2301,a2308,a2309,a2310,a2311,a2312,a2313,a2314,a2315,a2317,a2318,a2319,a2320,a2321,c24,a2401,a2408,a2409,a2410,a2411,a2412,a2413,a2414,a2415,a2417,a2418,a2419,a2420,a2421,c25,a2501,a2508,a2509,a2510,a2511,a2512,a2513,a2514,a2515,a2517,a2518,a2519,a2520,a2521,c26,a2601,a2608,a2609,a2610,a2611,a2612,a2613,a2614,a2615,a2617,a2618,a2619,a2620,a2621 = cache
k[17] = f(uprev+dt*(a1701*k[1]+a1708*k[8]+a1709*k[9]+a1710*k[10]+a1711*k[11]+a1712*k[12]+a1713*k[13]+a1714*k[14]+a1715*k[15]),p,t+c17*dt)
Expand All @@ -706,7 +706,7 @@ function initialize!(integrator, cache::Vern9Cache)
@unpack k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13,k14,k15,k16 = cache
@unpack k = integrator
alg = unwrap_alg(integrator, false)
integrator.alg.lazy ? (integrator.kshortsize = 16) : (integrator.kshortsize = 26)
alg.lazy ? (integrator.kshortsize = 16) : (integrator.kshortsize = 26)
resize!(k, integrator.kshortsize)
k[1]=k1;k[2]=k2;k[3]=k3;k[4]=k4;k[5]=k5;k[6]=k6;k[7]=k7;k[8]=k8;k[9]=k9;k[10]=k10;k[11]=k11;k[12]=k12;k[13]=k13;k[14]=k14;k[15]=k15;k[16]=k16 # Setup pointers

Expand Down Expand Up @@ -851,7 +851,7 @@ end
end

alg = unwrap_alg(integrator, false)
if !alg.lazy && (integrator.opts.adaptive = false || integrator.EEst <= 1.0)
if !alg.lazy && (integrator.opts.adaptive == false || integrator.EEst <= 1.0)
k = integrator.k
@unpack tmp = cache
@unpack c17,a1701,a1708,a1709,a1710,a1711,a1712,a1713,a1714,a1715,c18,a1801,a1808,a1809,a1810,a1811,a1812,a1813,a1814,a1815,a1817,c19,a1901,a1908,a1909,a1910,a1911,a1912,a1913,a1914,a1915,a1917,a1918,c20,a2001,a2008,a2009,a2010,a2011,a2012,a2013,a2014,a2015,a2017,a2018,a2019,c21,a2101,a2108,a2109,a2110,a2111,a2112,a2113,a2114,a2115,a2117,a2118,a2119,a2120,c22,a2201,a2208,a2209,a2210,a2211,a2212,a2213,a2214,a2215,a2217,a2218,a2219,a2220,a2221,c23,a2301,a2308,a2309,a2310,a2311,a2312,a2313,a2314,a2315,a2317,a2318,a2319,a2320,a2321,c24,a2401,a2408,a2409,a2410,a2411,a2412,a2413,a2414,a2415,a2417,a2418,a2419,a2420,a2421,c25,a2501,a2508,a2509,a2510,a2511,a2512,a2513,a2514,a2515,a2517,a2518,a2519,a2520,a2521,c26,a2601,a2608,a2609,a2610,a2611,a2612,a2613,a2614,a2615,a2617,a2618,a2619,a2620,a2621 = cache.tab
Expand Down
2 changes: 1 addition & 1 deletion test/composite_algorithm_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using OrdinaryDiffEq, DiffEqProblemLibrary, Base.Test
choice_function(integrator) = (Int(integrator.t<0.5) + 1)
alg_double = CompositeAlgorithm((Tsit5(),Tsit5()),choice_function)
alg_double2 = CompositeAlgorithm((Vern6(),Vern6()),choice_function)
alg_switch = CompositeAlgorithm((Vern7(),Tsit5()),choice_function)
alg_switch = CompositeAlgorithm((Tsit5(),Vern7()),choice_function)

@time sol1 = solve(prob_ode_linear,alg_double)
@time sol2 = solve(prob_ode_linear,Tsit5())
Expand Down
10 changes: 5 additions & 5 deletions test/ode/ode_add_steps_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@ lazy_alg = [BS5, Vern6, Vern7, Vern8, Vern9]

nonstandard_interp_algs = union(algs,bad_algs,lazy_alg)

passed = fill(false, 2length(algs))
passed = fill(false, 2length(nonstandard_interp_algs))

cur_itr = 0
for inplace in [false,true], alg in algs
prob = ODEProblem{inplace}(test_ode, [0.], tspan, [1.])
sol = solve(prob, alg(); callback=cb,dt=0.0013)
pass = all(isapprox(sol(t)[1], test_solution(t); atol=0.05) for t in testtimes)
cur_itr += 1
@test pass
passed[cur_itr] = pass
end

Expand All @@ -42,12 +41,13 @@ for inplace in [false,true], alg in lazy_alg
sol = solve(prob, alg(); callback=cb,dt=0.0013)
fail = all(isapprox(sol(t)[1], test_solution(t); atol=0.05) for t in testtimes)

prob = ODEProblem{inplace}(test_ode, [0.], tspan, [1.])
sol = solve(prob, alg(lazy=false); callback=cb,dt=0.0013)
pass = all(isapprox(sol(t)[1], test_solution(t); atol=0.05) for t in testtimes)

cur_itr += 1
@test pass && !fail
passed[cur_itr] = pass
passed[cur_itr] = pass && !fail
end

any(.!(passed)) && warn("The following algorithms failed the continuous callback test: $(union(algs,algs)[.!(passed)])")
any(.!(passed)) && warn("The following algorithms failed the continuous callback test: $(vcat(algs,algs,lazy_alg,lazy_alg)[.!(passed)])")
@test all(passed)

0 comments on commit d3bf319

Please sign in to comment.