Skip to content

Commit

Permalink
more progress with prototyping SUNAdjointSolver and its interface to …
Browse files Browse the repository at this point in the history
…SUNStepper
  • Loading branch information
balos1 committed May 10, 2024
1 parent bba03b5 commit 9753833
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 14 deletions.
5 changes: 5 additions & 0 deletions include/sundials/sundials_stepper.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include <sundials/sundials_core.h>

#include "sundials/sundials_export.h"

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -94,6 +96,9 @@ SUNErrCode SUNStepper_TryStep(SUNStepper stepper, sunrealtype t0,
sunrealtype tout, N_Vector y, sunrealtype* tret,
int* stop_reason);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_Reset(SUNStepper stepper, sunrealtype tR, N_Vector yR);

#ifdef __cplusplus
}
#endif
Expand Down
9 changes: 9 additions & 0 deletions src/arkode/arkode.c
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,7 @@ int arkInit(ARKodeMem ark_mem, sunrealtype t0, N_Vector y0, int init_type)
/* Initial, old, and next step sizes */
ark_mem->h0u = ZERO;
ark_mem->hold = ZERO;
ark_mem->h = ZERO;
ark_mem->next_h = ZERO;

/* Tolerance scale factor */
Expand Down Expand Up @@ -2159,6 +2160,14 @@ int arkInitialSetup(ARKodeMem ark_mem, sunrealtype tout)
/* Test input tstop for legality (correct direction of integration) */
if (ark_mem->tstopset)
{
#if SUNDIALS_LOGGING_LEVEL >= SUNDIALS_LOGGING_DEBUG
SUNLogger_QueueMsg(ARK_LOGGER, SUN_LOGLEVEL_DEBUG,
"ARKODE::arkInitialSetup", "test-tstop",
"h = %" RSYM ", tcur = %" RSYM ", tout = %" RSYM
", tstop = %" RSYM,
ark_mem->h, ark_mem->tcur, tout, ark_mem->tstop);
#endif

htmp = (ark_mem->h == ZERO) ? tout - ark_mem->tcur : ark_mem->h;
if ((ark_mem->tstop - ark_mem->tcur) * htmp <= ZERO)
{
Expand Down
17 changes: 9 additions & 8 deletions src/arkode/arkode_arkstep.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arkode_arkstep_impl.h"
#include "arkode_impl.h"
#include "arkode_interp_impl.h"
#include "arkode_types_impl.h"
#include "sundials/sundials_types.h"

#define FIXED_LIN_TOL
Expand Down Expand Up @@ -3056,11 +3057,11 @@ int arkStep_ComputeSolutions_MassFixed(ARKodeMem ark_mem, sunrealtype* dsmPtr)
int ARKStepCreateSUNStepper(void* inner_arkode_mem, SUNStepper* stepper)
{
int retval;
ARKodeMem ark_mem;
ARKodeMem ark_mem = (ARKodeMem)inner_arkode_mem;
ARKodeARKStepMem step_mem;

retval = arkStep_AccessStepMem(inner_arkode_mem, "ARKStepCreateSUNStepper",
&ark_mem, &step_mem);
&step_mem);
if (retval)
{
arkProcessError(NULL, ARK_ILL_INPUT, __LINE__, __func__, __FILE__,
Expand Down Expand Up @@ -3117,11 +3118,11 @@ int arkStep_SUNStepperAdvance(SUNStepper stepper, sunrealtype t0,
if (retval != ARK_SUCCESS) { return (retval); }

/* set the stop time */
retval = ARKStepSetStopTime(arkode_mem, tout);
retval = ARKodeSetStopTime(arkode_mem, tout);
if (retval != ARK_SUCCESS) { return (retval); }

/* evolve inner ODE */
*stop_reason = ARKStepEvolve(arkode_mem, tout, y, tret, ARK_NORMAL);
*stop_reason = ARKodeEvolve(arkode_mem, tout, y, tret, ARK_NORMAL);
if (*stop_reason < 0) { return (*stop_reason); }

/* disable inner forcing */
Expand Down Expand Up @@ -3155,11 +3156,11 @@ int arkStep_SUNStepperOneStep(SUNStepper stepper, sunrealtype t0,
if (retval != ARK_SUCCESS) { return (retval); }

/* set the stop time */
retval = ARKStepSetStopTime(arkode_mem, tout);
retval = ARKodeSetStopTime(arkode_mem, tout);
if (retval != ARK_SUCCESS) { return (retval); }

/* evolve inner ODE */
*stop_reason = ARKStepEvolve(arkode_mem, tout, y, tret, ARK_ONE_STEP);
*stop_reason = ARKodeEvolve(arkode_mem, tout, y, tret, ARK_ONE_STEP);
if (*stop_reason < 0) { return (*stop_reason); }

/* disable inner forcing */
Expand Down Expand Up @@ -3193,7 +3194,7 @@ int arkStep_SUNStepperTryStep(SUNStepper stepper, sunrealtype t0,
if (retval != ARK_SUCCESS) { return (retval); }

/* set the stop time */
retval = ARKStepSetStopTime(arkode_mem, tout);
retval = ARKodeSetStopTime(arkode_mem, tout);
if (retval != ARK_SUCCESS) { return (retval); }

/* try to evolve inner ODE */
Expand Down Expand Up @@ -3243,7 +3244,7 @@ int arkStep_SUNStepperReset(SUNStepper stepper, sunrealtype tR, N_Vector yR)
retval = SUNStepper_GetContent(stepper, &arkode_mem);
if (retval != ARK_SUCCESS) { return (retval); }

return (ARKStepReset(arkode_mem, tR, yR));
return (ARKodeReset(arkode_mem, tR, yR));
}

/*---------------------------------------------------------------
Expand Down
7 changes: 7 additions & 0 deletions src/sundials/sundials_stepper.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ SUNErrCode SUNStepper_TryStep(SUNStepper stepper, sunrealtype t0,
return SUN_ERR_NOT_IMPLEMENTED;
}

SUNErrCode SUNStepper_Reset(SUNStepper stepper, sunrealtype tR, N_Vector yR)
{
SUNFunctionBegin(stepper->sunctx);
if (stepper->ops->advance) { return stepper->ops->reset(stepper, tR, yR); }
return SUN_ERR_NOT_IMPLEMENTED;
}

SUNErrCode SUNStepper_SetContent(SUNStepper stepper, void* content)
{
SUNFunctionBegin(stepper->sunctx);
Expand Down
51 changes: 45 additions & 6 deletions test/unit_tests/arkode/C_serial/ark_test_sunadjoint.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
#include "sunadjoint/sunadjoint_checkpointscheme.h"
#include "sunadjoint/sunadjoint_solver.h"
#include "sundials/sundials_nvector.h"
#include "sundials/sundials_stepper.h"
#include "sundials/sundials_types.h"
#include "sunmatrix/sunmatrix_dense.h"

int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, void* user_data)
{
Expand All @@ -36,22 +38,52 @@ int lotka_volterra(sunrealtype t, N_Vector uvec, N_Vector udotvec, void* user_da
return 0;
}

int jacobian(sunrealtype t, N_Vector uvec, N_Vector udotvec, SUNMatrix Jac,
void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3)
{
sunrealtype* p = (sunrealtype*)user_data;
sunrealtype* u = N_VGetArrayPointer(uvec);
sunrealtype* udot = N_VGetArrayPointer(udotvec);
sunrealtype* J = SUNDenseMatrix_Data(Jac);

J[0] = p[0] - p[1] * u[1];
J[1] = p[2] * u[1];
J[2] = -p[1] * u[0];
J[3] = p[2] * u[0] - p[3];

return 0;
}

int lotka_volterra_adjoint(sunrealtype t, N_Vector lvec, N_Vector ldotvec,
void* user_data)
{
sunrealtype* p = (sunrealtype*)user_data;
sunrealtype* lambda = N_VGetArrayPointer(lvec);
sunrealtype* ldot = N_VGetArrayPointer(ldotvec);

// TODO(CJB): implement
// jacobian(t, lvec, ldotvec, J, user_data, NULL, NULL, NULL);
// ldot[0] = lvec[0] *

return 0;
}

int forward_solution(SUNContext sunctx, void* arkode_mem,
SUNAdjointCheckpointScheme checkpoint_scheme,
sunrealtype t0, sunrealtype tf, N_Vector u)
{
sunrealtype params[4] = {1.5, 1.0, 3.0, 1.0};
ARKStepSetUserData(arkode_mem, (void*)params);
ARKodeSetUserData(arkode_mem, (void*)params);

ARKStepSStolerances(arkode_mem, 1e-4, 1e-10);
ARKodeSStolerances(arkode_mem, 1e-4, 1e-10);

sunrealtype t = t0;
while (t < tf)
{
int flag = ARKStepEvolve(arkode_mem, tf, u, &t, ARK_NORMAL);
int flag = ARKodeEvolve(arkode_mem, tf, u, &t, ARK_NORMAL);
if (flag < 0)
{
fprintf(stderr, ">>> ERROR: ARKStepEvolve returned %d\n", flag);
fprintf(stderr, ">>> ERROR: ARKodeEvolve returned %d\n", flag);
return -1;
}
}
Expand All @@ -72,6 +104,13 @@ int adjoint_solution(SUNContext sunctx, void* arkode_mem,
sunindextype num_params = 4;
N_Vector sf = N_VNew_Serial(neq + num_params, sunctx);

// TODO(CJB): Load sf with the sensitivity terminal conditions
N_VConst(0.0, sf);

// TODO(CJB): this block of code needs to be less complicated, should wrap it up in something like ARKStepCreateAdjointSolver()
// lotka_volterra_adjoint is J*lambda - user will provide J, we internally will create RHS that is J*lambda
ARKodeResize(arkode_mem, sf, 1.0, tf, NULL, NULL);
ARKStepReInit(arkode_mem, lotka_volterra_adjoint, NULL, tf, sf);
SUNStepper stepper = NULL;
ARKStepCreateSUNStepper(arkode_mem, &stepper);
SUNAdjointSolver adj_solver = NULL;
Expand Down Expand Up @@ -116,7 +155,7 @@ int main(int argc, char* argv[])
// Enable checkpointing during the forward solution
SUNAdjointCheckpointScheme checkpoint_scheme = NULL;
// SUNAdjointCheckpointScheme_NewEmpty(sunctx, &checkpoint_scheme);
// ARKStepSetCheckpointScheme(arkode_mem, checkpoint_scheme);
// ARKodeSetCheckpointScheme(arkode_mem, checkpoint_scheme);

//
// Compute the forward solution
Expand All @@ -136,7 +175,7 @@ int main(int argc, char* argv[])
//

N_VDestroy(u);
ARKStepFree(&arkode_mem);
ARKodeFree(&arkode_mem);

return 0;
}

0 comments on commit 9753833

Please sign in to comment.