Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ end

# Shampine's Low-order Rosenbrocks

mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
mutable struct RosenbrockCache{uType, rateType, tabType, uNoUnitsType, JType, WType, TabType,
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
RosenbrockMutableCache
u::uType
Expand All @@ -21,6 +21,8 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
du::rateType
du1::rateType
du2::rateType
dtC::Matrix{tabType}
dtd::Vector{tabType}
ks::Vector{rateType}
fsalfirst::rateType
fsallast::rateType
Expand Down Expand Up @@ -761,6 +763,9 @@ function alg_cache(
du1 = zero(rate_prototype)
du2 = zero(rate_prototype)

dtC = similar(tab.C)
dtd = similar(tab.d)

# Initialize other variables
fsalfirst = zero(rate_prototype)
fsallast = zero(rate_prototype)
Expand Down Expand Up @@ -795,7 +800,7 @@ function alg_cache(

# Return the cache struct with vectors
RosenbrockCache(
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config, reltol, alg,
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1))
Expand Down
6 changes: 3 additions & 3 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ end

@muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false)
(; t, dt, uprev, u, f, p) = integrator
(; du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache
(; du, du1, du2, dT, dtC, dtd, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache
(; A, C, gamma, c, d, H) = cache.tab

# Assignments
Expand All @@ -1327,8 +1327,8 @@ end
mass_matrix = integrator.f.mass_matrix

# Precalculations
dtC = C .* inv(dt)
dtd = dt .* d
@. dtC = C * inv(dt)
@. dtd = dt * d
dtgamma = dt * gamma

f(cache.fsalfirst, uprev, p, t)
Expand Down
Loading