Skip to content

Commit

Permalink
update sunstepper evolve functions to pass tret and return reason
Browse files Browse the repository at this point in the history
  • Loading branch information
balos1 committed May 9, 2024
1 parent 0d3e9b0 commit f97e71b
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 27 deletions.
25 changes: 21 additions & 4 deletions include/sundials/sundials_stepper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ extern "C" {
typedef _SUNDIALS_STRUCT_ SUNStepper_s* SUNStepper;

typedef int (*SUNStepperAdvanceFn)(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y);
sunrealtype tout, N_Vector y,
sunrealtype* tret, int* stop_reason);

typedef int (*SUNStepperOneStepFn)(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y);
sunrealtype tout, N_Vector y,
sunrealtype* tret, int* stop_reason);

typedef int (*SUNStepperTryStepFn)(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, int* ark_flag);
sunrealtype tout, N_Vector y,
sunrealtype* tret, int* stop_reason);

typedef int (*SUNStepperFullRhsFn)(SUNStepper stepper, sunrealtype t,
N_Vector y, N_Vector f, int mode);
Expand All @@ -39,7 +42,7 @@ SUNDIALS_EXPORT
SUNErrCode SUNStepper_Create(SUNContext sunctx, SUNStepper* stepper);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_Free(SUNStepper* stepper);
SUNErrCode SUNStepper_Destroy(SUNStepper* stepper);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_SetContent(SUNStepper stepper, void* content);
Expand Down Expand Up @@ -70,6 +73,20 @@ SUNErrCode SUNStepper_GetForcingData(SUNStepper stepper, sunrealtype* tshift,
sunrealtype* tscale, N_Vector** forcing,
int* nforcing);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_Advance(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_Step(SUNStepper stepper, sunrealtype t0, sunrealtype tout,
N_Vector y, sunrealtype* tret, int* stop_reason);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_TryStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason);

#ifdef __cplusplus
}
#endif
Expand Down
32 changes: 16 additions & 16 deletions src/arkode/arkode_arkstep.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arkode_arkstep_impl.h"
#include "arkode_impl.h"
#include "arkode_interp_impl.h"
#include "sundials/sundials_types.h"

#define FIXED_LIN_TOL

Expand Down Expand Up @@ -3270,10 +3271,10 @@ int ARKStepCreateSUNStepper(void* inner_arkode_mem, SUNStepper* stepper)
----------------------------------------------------------------------------*/

int arkStep_SUNStepperAdvance(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y)
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason)
{
void* arkode_mem; /* arkode memory */
sunrealtype tret; /* return time */
sunrealtype tshift, tscale; /* time normalization values */
N_Vector* forcing; /* forcing vectors */
int nforcing; /* number of forcing vectors */
Expand All @@ -3297,8 +3298,8 @@ int arkStep_SUNStepperAdvance(SUNStepper stepper, sunrealtype t0,
if (retval != ARK_SUCCESS) { return (retval); }

/* evolve inner ODE */
retval = ARKStepEvolve(arkode_mem, tout, y, &tret, ARK_NORMAL);
if (retval < 0) { return (retval); }
*stop_reason = ARKStepEvolve(arkode_mem, tout, y, tret, ARK_NORMAL);
if (*stop_reason < 0) { return (*stop_reason); }

/* disable inner forcing */
retval = arkStep_SetInnerForcing(arkode_mem, ZERO, ONE, NULL, 0);
Expand All @@ -3308,10 +3309,10 @@ int arkStep_SUNStepperAdvance(SUNStepper stepper, sunrealtype t0,
}

int arkStep_SUNStepperOneStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y)
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason)
{
void* arkode_mem; /* arkode memory */
sunrealtype tret; /* return time */
sunrealtype tshift, tscale; /* time normalization values */
N_Vector* forcing; /* forcing vectors */
int nforcing; /* number of forcing vectors */
Expand All @@ -3335,8 +3336,8 @@ int arkStep_SUNStepperOneStep(SUNStepper stepper, sunrealtype t0,
if (retval != ARK_SUCCESS) { return (retval); }

/* evolve inner ODE */
retval = ARKStepEvolve(arkode_mem, tout, y, &tret, ARK_ONE_STEP);
if (retval < 0) { return (retval); }
*stop_reason = ARKStepEvolve(arkode_mem, tout, y, tret, ARK_ONE_STEP);
if (*stop_reason < 0) { return (*stop_reason); }

/* disable inner forcing */
retval = arkStep_SetInnerForcing(arkode_mem, ZERO, ONE, NULL, 0);
Expand All @@ -3346,10 +3347,10 @@ int arkStep_SUNStepperOneStep(SUNStepper stepper, sunrealtype t0,
}

int arkStep_SUNStepperTryStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, int* ark_flag)
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason)
{
void* arkode_mem; /* arkode memory */
sunrealtype tret; /* return time */
sunrealtype tshift, tscale; /* time normalization values */
N_Vector* forcing; /* forcing vectors */
int nforcing; /* number of forcing vectors */
Expand All @@ -3373,7 +3374,7 @@ int arkStep_SUNStepperTryStep(SUNStepper stepper, sunrealtype t0,
if (retval != ARK_SUCCESS) { return (retval); }

/* try to evolve inner ODE */
retval = arkStep_TryStep(arkode_mem, t0, tout, y, ark_flag);
retval = arkStep_TryStep(arkode_mem, t0, tout, y, tret, stop_reason);
if (retval != ARK_SUCCESS) { return (retval); }

/* disable inner forcing */
Expand Down Expand Up @@ -3864,11 +3865,10 @@ int arkStep_GetOrder(ARKodeMem ark_mem)
* ---------------------------------------------------------------------------*/

int arkStep_TryStep(void* arkode_mem, sunrealtype tstart, sunrealtype tstop,
N_Vector y, int* ark_flag)
N_Vector y, sunrealtype* tret, int* ark_flag)
{
int flag; /* generic return flag */
int tmp_flag; /* evolve return flag */
sunrealtype tret; /* return time */
int flag; /* generic return flag */
int tmp_flag; /* evolve return flag */

/* Check inputs */
if (arkode_mem == NULL) { return ARK_MEM_NULL; }
Expand All @@ -3887,7 +3887,7 @@ int arkStep_TryStep(void* arkode_mem, sunrealtype tstart, sunrealtype tstop,
if (flag != ARK_SUCCESS) { return flag; }

/* Take step, check flag below */
tmp_flag = ARKStepEvolve(arkode_mem, tstop, y, &tret, ARK_ONE_STEP);
tmp_flag = ARKStepEvolve(arkode_mem, tstop, y, tret, ARK_ONE_STEP);

/* Re-enable temporal error test check */
flag = arkSetForcePass(arkode_mem, SUNFALSE);
Expand Down
11 changes: 7 additions & 4 deletions src/arkode/arkode_arkstep_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,14 @@ int arkStep_NlsConvTest(SUNNonlinearSolver NLS, N_Vector y, N_Vector del,

/* private functions for interfacing with SUNStepper */
int arkStep_SUNStepperAdvance(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y);
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason);
int arkStep_SUNStepperOneStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y);
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason);
int arkStep_SUNStepperTryStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, int* ark_flag);
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason);
int arkStep_SUNStepperFullRhs(SUNStepper stepper, sunrealtype t, N_Vector y,
N_Vector f, int mode);
int arkStep_SUNStepperReset(SUNStepper stepper, sunrealtype tR, N_Vector yR);
Expand All @@ -243,7 +246,7 @@ int arkStep_GetOrder(ARKodeMem ark_mem);

/* private utility functions */
int arkStep_TryStep(void* arkode_mem, sunrealtype tstart, sunrealtype tstop,
N_Vector y, int* ark_flag);
N_Vector y, sunrealtype* tret, int* ark_flag);

/*===============================================================
Reusable ARKStep Error Messages
Expand Down
5 changes: 4 additions & 1 deletion src/arkode/arkode_mristep.c
Original file line number Diff line number Diff line change
Expand Up @@ -2941,7 +2941,10 @@ int mriStepInnerStepper_EvolveSUNStepper(MRIStepInnerStepper stepper,
N_Vector y)
{
SUNStepper sunstepper = (SUNStepper)stepper->content;
stepper->last_flag = sunstepper->ops->advance(sunstepper, t0, tout, y);
sunrealtype tret;
int stop_reason;
stepper->last_flag = sunstepper->ops->advance(sunstepper, t0, tout, y, &tret,
&stop_reason);
return stepper->last_flag;
}

Expand Down
4 changes: 3 additions & 1 deletion src/arkode/xbraid/arkode_xbraid.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "arkode_arkstep_impl.h"
#include "arkode_xbraid_impl.h"
#include "sundials/sundials_math.h"
#include "sundials/sundials_types.h"

#define ONE SUN_RCONST(1.0)

Expand Down Expand Up @@ -438,5 +439,6 @@ int ARKBraid_Access(braid_App app, braid_Vector u, braid_AccessStatus astatus)
int ARKBraid_TakeStep(void* arkode_mem, sunrealtype tstart, sunrealtype tstop,
N_Vector y, int* ark_flag)
{
return arkStep_TryStep(arkode_mem, tstart, tstop, y, ark_flag);
sunrealtype tret;
return arkStep_TryStep(arkode_mem, tstart, tstop, y, &tret, ark_flag);
}
37 changes: 36 additions & 1 deletion src/sundials/sundials_stepper.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ SUNErrCode SUNStepper_Create(SUNContext sunctx, SUNStepper* stepper)
return SUN_SUCCESS;
}

SUNErrCode SUNStepper_Free(SUNStepper* stepper_ptr)
SUNErrCode SUNStepper_Destroy(SUNStepper* stepper_ptr)
{
SUNFunctionBegin((*stepper_ptr)->sunctx);

Expand All @@ -52,6 +52,41 @@ SUNErrCode SUNStepper_Free(SUNStepper* stepper_ptr)
return SUN_SUCCESS;
}

SUNErrCode SUNStepper_Advance(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason)
{
SUNFunctionBegin(stepper->sunctx);
if (stepper->ops->advance)
{
return stepper->ops->advance(stepper, t0, tout, y, tret, stop_reason);
}
return SUN_ERR_NOT_IMPLEMENTED;
}

SUNErrCode SUNStepper_Step(SUNStepper stepper, sunrealtype t0, sunrealtype tout,
N_Vector y, sunrealtype* tret, int* stop_reason)
{
SUNFunctionBegin(stepper->sunctx);
if (stepper->ops->onestep)
{
return stepper->ops->onestep(stepper, t0, tout, y, tret, stop_reason);
}
return SUN_ERR_NOT_IMPLEMENTED;
}

SUNErrCode SUNStepper_TryStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason)
{
SUNFunctionBegin(stepper->sunctx);
if (stepper->ops->trystep)
{
return stepper->ops->trystep(stepper, t0, tout, y, tret, stop_reason);
}
return SUN_ERR_NOT_IMPLEMENTED;
}

SUNErrCode SUNStepper_SetContent(SUNStepper stepper, void* content)
{
SUNFunctionBegin(stepper->sunctx);
Expand Down

0 comments on commit f97e71b

Please sign in to comment.