diff --git a/pySDC/Sweeper.py b/pySDC/Sweeper.py index fdba6341f1..cd90e1873d 100644 --- a/pySDC/Sweeper.py +++ b/pySDC/Sweeper.py @@ -31,7 +31,8 @@ def __init__(self,params): defaults = dict() defaults['do_LU'] = False - + defaults['do_coll_update'] = True + for k,v in defaults.items(): setattr(self,k,v) @@ -43,6 +44,8 @@ def __init__(self,params): coll = params['collocation_class'](params['num_nodes'],0,1) assert isinstance(coll, CollBase) + if not coll.right_is_node: + assert self.params.do_coll_update, "For nodes where the right end point is not a node, do_coll_update has to be set to True" # This will be set as soon as the sweeper is instantiated at the level self.__level = None @@ -151,4 +154,4 @@ def update_nodes(self): """ Abstract interface to node update """ - return None \ No newline at end of file + return None diff --git a/pySDC/sweeper_classes/generic_LU.py b/pySDC/sweeper_classes/generic_LU.py index 00697a7223..9f1b91f27a 100644 --- a/pySDC/sweeper_classes/generic_LU.py +++ b/pySDC/sweeper_classes/generic_LU.py @@ -135,12 +135,17 @@ def compute_end_point(self): L = self.level P = L.prob - # start with u0 and add integral over the full interval (using coll.weights) - L.uend = P.dtype_u(L.u[0]) - for m in range(self.coll.num_nodes): - L.uend += L.dt*self.coll.weights[m]*L.f[m+1] - # add up tau correction of the full interval (last entry) - if L.tau is not None: - L.uend += L.tau[-1] - - return None \ No newline at end of file + # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy + if (self.coll.right_is_node and not self.params.do_coll_update): + # a copy is sufficient + L.uend = P.dtype_u(L.u[-1]) + else: + # start with u0 and add integral over the full interval (using coll.weights) + L.uend = P.dtype_u(L.u[0]) + for m in range(self.coll.num_nodes): + L.uend += L.dt*self.coll.weights[m]*L.f[m+1] + # add up tau correction of the full interval (last entry) + if L.tau is not None: + L.uend += L.tau[-1] + + return None diff --git a/pySDC/sweeper_classes/imex_1st_order.py b/pySDC/sweeper_classes/imex_1st_order.py index d35edfc270..70abaa320c 100644 --- a/pySDC/sweeper_classes/imex_1st_order.py +++ b/pySDC/sweeper_classes/imex_1st_order.py @@ -134,19 +134,24 @@ def compute_end_point(self): """ Compute u at the right point of the interval - The value uend computed here is a full evaluation of the Picard formulation (always!) + The value uend computed here is a full evaluation of the Picard formulation unless do_full_update==False """ # get current level and problem description L = self.level P = L.prob - # start with u0 and add integral over the full interval (using coll.weights) - L.uend = P.dtype_u(L.u[0]) - for m in range(self.coll.num_nodes): - L.uend += L.dt*self.coll.weights[m]*(L.f[m+1].impl + L.f[m+1].expl) - # add up tau correction of the full interval (last entry) - if L.tau is not None: - L.uend += L.tau[-1] + # check if Mth node is equal to right point and do_coll_update is false, perform a simple copy + if (self.coll.right_is_node and not self.params.do_coll_update): + # a copy is sufficient + L.uend = P.dtype_u(L.u[-1]) + else: + # start with u0 and add integral over the full interval (using coll.weights) + L.uend = P.dtype_u(L.u[0]) + for m in range(self.coll.num_nodes): + L.uend += L.dt*self.coll.weights[m]*(L.f[m+1].impl + L.f[m+1].expl) + # add up tau correction of the full interval (last entry) + if L.tau is not None: + L.uend += L.tau[-1] return None diff --git a/tests/test_imexsweeper.py b/tests/test_imexsweeper.py index 6f0e40166a..e619ab26d9 100644 --- a/tests/test_imexsweeper.py +++ b/tests/test_imexsweeper.py @@ -204,7 +204,7 @@ def test_manysweepsequalmatrix(self): # # Make sure that update function for K sweeps computed from K-sweep matrix gives same result as K sweeps in node-to-node form plus compute_end_point # - def test_maysweepupdate(self): + def test_manysweepupdate(self): step, level, problem, nnodes = self.setupLevelStepProblem() step.levels[0].sweep.predict() @@ -229,3 +229,48 @@ def test_maysweepupdate(self): # Multiply u0 by value of update function to get end value directly uend_matrix = update*self.pparams['u0'] assert abs(uend_matrix - uend_sweep)<1e-14, "Node-to-node sweep plus update yields different result than update function computed through K-sweep matrix" + + # + # Make sure that creating a sweeper object with a collocation object with right_is_node=False and do_coll_update=False throws an exception + # + def test_norightnode_collupdate_fails(self): + self.swparams['collocation_class'] = collclass.CollGaussLegendre + self.swparams['do_coll_update'] = False + # Has to throw an exception + with self.assertRaises(AssertionError): + step, level, problem, nnodes = self.setupLevelStepProblem() + + # + # Make sure the update with do_coll_update=False reproduces last stage + # + def test_update_nocollupdate_laststage(self): + self.swparams['do_coll_update'] = False + step, level, problem, nnodes = self.setupLevelStepProblem() + level.sweep.predict() + ulaststage = np.random.rand() + level.u[nnodes].values = ulaststage + level.sweep.compute_end_point() + uend = level.uend.values + assert abs(uend-ulaststage)<1e-14, "compute_end_point with do_coll_update=False did not reproduce last stage value" + + # + # Make sure that update with do_coll_update=False is identical to update formula with q=(0,...,0,1) + # + def test_updateformula_no_coll_update(self): + self.swparams['do_coll_update'] = False + step, level, problem, nnodes = self.setupLevelStepProblem() + level.sweep.predict() + u0full = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ]) + + # Perform update step in sweeper + level.sweep.update_nodes() + ustages = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ]) + + # Compute end value through provided function + level.sweep.compute_end_point() + uend_sweep = level.uend.values + # Compute end value from matrix formulation + q = np.zeros(nnodes) + q[nnodes-1] = 1.0 + uend_mat = q.dot(ustages) + assert np.linalg.norm(uend_sweep - uend_mat, np.infty)<1e-14, "For do_coll_update=False, update formula in sweeper gives different result than matrix update formula with q=(0,..,0,1)"