From 8678ec85c451b9b191f08cb0813864fcae6583e7 Mon Sep 17 00:00:00 2001 From: Gandhali Kogekar Date: Tue, 7 Jun 2022 16:53:33 -0400 Subject: [PATCH] Compatibility with all Sundials versions --- include/cantera/numerics/IDAIntegrator.h | 3 +- src/numerics/IDAIntegrator.cpp | 70 +++++++++++++++++++----- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/include/cantera/numerics/IDAIntegrator.h b/include/cantera/numerics/IDAIntegrator.h index 7d3a18d0b6..6f1ceb1601 100644 --- a/include/cantera/numerics/IDAIntegrator.h +++ b/include/cantera/numerics/IDAIntegrator.h @@ -12,6 +12,7 @@ #include "cantera/numerics/Integrator.h" #include "cantera/base/ctexceptions.h" #include "sundials/sundials_nvector.h" +#include "cantera/numerics/SundialsContext.h" namespace Cantera { @@ -88,7 +89,7 @@ class IDAIntegrator : public Integrator void* m_ida_mem; //!< Pointer to the IDA memory for the problem void* m_linsol; //!< Sundials linear solver object void* m_linsol_matrix; //!< matrix used by Sundials - void * m_ctx; //!< contex object used by Sundials + SundialsContext m_sundials_ctx; //!< SUNContext object for Sundials>=6.0 FuncEval* m_func; double m_t0; double m_time; //!< The current integrator time diff --git a/src/numerics/IDAIntegrator.cpp b/src/numerics/IDAIntegrator.cpp index 021a6ad9d6..d8c715a290 100644 --- a/src/numerics/IDAIntegrator.cpp +++ b/src/numerics/IDAIntegrator.cpp @@ -9,6 +9,19 @@ #include "cantera/numerics/sundials_headers.h" using namespace std; + +namespace { + +N_Vector newNVector(size_t N, Cantera::SundialsContext& context) +{ +#if CT_SUNDIALS_VERSION >= 60 + return N_VNew_Serial(static_cast(N), context.get()); +#else + return N_VNew_Serial(static_cast(N)); +#endif +} + +} // end anonymous namespace namespace Cantera { @@ -115,7 +128,7 @@ void IDAIntegrator::setTolerances(double reltol, size_t n, double* abstol) if (m_abstol) { N_VDestroy_Serial(m_abstol); } - m_abstol = N_VNew_Serial(static_cast(n), (SUNContext) m_ctx); + m_abstol = newNVector(static_cast(n), m_sundials_ctx); } for (size_t i=0; i(m_neq), (SUNContext) m_ctx); // allocate solution vector + m_y = newNVector(static_cast(m_neq), m_sundials_ctx); N_VConst(0.0, m_y); if (m_ydot) { N_VDestroy_Serial(m_ydot); // free derivative vector if already allocated } - m_ydot = N_VNew_Serial(m_neq, (SUNContext) m_ctx); + m_ydot = newNVector(m_neq, m_sundials_ctx); N_VConst(0.0, m_ydot); // check abs tolerance array size @@ -250,7 +263,7 @@ void IDAIntegrator::initialize(double t0, FuncEval& func) if (m_constraints) { N_VDestroy_Serial(m_constraints); } - m_constraints = N_VNew_Serial(static_cast(m_neq), (SUNContext) m_ctx); + m_constraints = newNVector(static_cast(m_neq), m_sundials_ctx); // set the constraints func.getConstraints(NV_DATA_S(m_constraints)); @@ -263,7 +276,11 @@ void IDAIntegrator::initialize(double t0, FuncEval& func) } //! Create the IDA solver - m_ida_mem = IDACreate((SUNContext) m_ctx); + #if CT_SUNDIALS_VERSION >= 60 + m_ida_mem = IDACreate(m_sundials_ctx.get()); + #else + m_ida_mem = IDACreate(); + #endif if (!m_ida_mem) { throw CanteraError("IDAIntegrator::initialize", "IDACreate failed."); @@ -350,11 +367,24 @@ void IDAIntegrator::applyOptions() #if CT_SUNDIALS_VERSION >= 30 SUNLinSolFree((SUNLinearSolver) m_linsol); SUNMatDestroy((SUNMatrix) m_linsol_matrix); - m_linsol_matrix = SUNDenseMatrix(N, N, (SUNContext) m_ctx); + #if CT_SUNDIALS_VERSION >= 60 + m_linsol_matrix = SUNDenseMatrix(N, N, m_sundials_ctx.get()); + #else + m_linsol_matrix = SUNDenseMatrix(N, N); + #endif + #if CT_SUNDIALS_VERSION >= 60 + m_linsol_matrix = SUNDenseMatrix(N, N, m_sundials_ctx.get()); + #else + m_linsol_matrix = SUNDenseMatrix(N, N); + #endif #if CT_SUNDIALS_USE_LAPACK m_linsol = SUNLapackDense(m_y, (SUNMatrix) m_linsol_matrix); #else - m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix, (SUNContext) m_ctx); + #if CT_SUNDIALS_VERSION >= 60 + m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix, m_sundials_ctx.get()); + #else + m_linsol = SUNLinSol_Dense(m_y, (SUNMatrix) m_linsol_matrix); + #endif #endif IDASetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol, (SUNMatrix) m_linsol_matrix); @@ -370,7 +400,11 @@ void IDAIntegrator::applyOptions() "Cannot use a diagonal matrix with IDA."); } else if (m_type == GMRES) { #if CT_SUNDIALS_VERSION >= 30 - m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0, (SUNContext) m_ctx); + #if CT_SUNDIALS_VERSION >= 60 + m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0, m_sundials_ctx.get()); + #else + m_linsol = SUNLinSol_SPGMR(m_y, PREC_NONE, 0); + #endif IDASpilsSetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol); #else IDASpgmr(m_ida_mem, 0); @@ -382,11 +416,19 @@ void IDAIntegrator::applyOptions() #if CT_SUNDIALS_VERSION >= 30 SUNLinSolFree((SUNLinearSolver) m_linsol); SUNMatDestroy((SUNMatrix) m_linsol_matrix); - m_linsol_matrix = SUNBandMatrix(N, nu, nl, (SUNContext) m_ctx); + #if CT_SUNDIALS_VERSION >= 60 + m_linsol_matrix = SUNBandMatrix(N, nu, nl, m_sundials_ctx.get()); + #else + m_linsol_matrix = SUNBandMatrix(N, nu, nl); + #endif #if CT_SUNDIALS_USE_LAPACK m_linsol = SUNLapackBand(m_y, (SUNMatrix) m_linsol_matrix); #else - m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix, (SUNContext) m_ctx); + #if CT_SUNDIALS_VERSION >= 60 + m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix, m_sundials_ctx.get()); + #else + m_linsol = SUNLinSol_Band(m_y, (SUNMatrix) m_linsol_matrix); + #endif #endif IDASetLinearSolver(m_ida_mem, (SUNLinearSolver) m_linsol, (SUNMatrix) m_linsol_matrix); @@ -442,13 +484,13 @@ void IDAIntegrator::sensInit(double t0, FuncEval& func) m_np = func.nparams(); m_sens_ok = false; - N_Vector y = N_VNew_Serial(static_cast(func.neq()), (SUNContext) m_ctx); + N_Vector y = newNVector(static_cast(func.neq()), m_sundials_ctx); m_yS = N_VCloneVectorArray_Serial(static_cast(m_np), y); for (size_t n = 0; n < m_np; n++) { N_VConst(0.0, m_yS[n]); } N_VDestroy_Serial(y); - N_Vector ydot = N_VNew_Serial(static_cast(func.neq()), (SUNContext) m_ctx); + N_Vector ydot = newNVector(static_cast(func.neq()), m_sundials_ctx); m_ySdot = N_VCloneVectorArray_Serial(static_cast(m_np), ydot); for (size_t n = 0; n < m_np; n++) { N_VConst(0.0, m_ySdot[n]); @@ -545,8 +587,8 @@ double IDAIntegrator::sensitivity(size_t k, size_t p) string IDAIntegrator::getErrorInfo(int N) { - N_Vector errs = N_VNew_Serial(static_cast(m_neq), (SUNContext) m_ctx); - N_Vector errw = N_VNew_Serial(static_cast(m_neq), (SUNContext) m_ctx); + N_Vector errs = newNVector(static_cast(m_neq), m_sundials_ctx); + N_Vector errw = newNVector(static_cast(m_neq), m_sundials_ctx); IDAGetErrWeights(m_ida_mem, errw); IDAGetEstLocalErrors(m_ida_mem, errs);