Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pySDC/Sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(self,params):

defaults = dict()
defaults['do_LU'] = False

defaults['do_coll_update'] = True

for k,v in defaults.items():
setattr(self,k,v)

Expand All @@ -43,6 +44,8 @@ def __init__(self,params):

coll = params['collocation_class'](params['num_nodes'],0,1)
assert isinstance(coll, CollBase)
if not coll.right_is_node:
assert self.params.do_coll_update, "For nodes where the right end point is not a node, do_coll_update has to be set to True"

# This will be set as soon as the sweeper is instantiated at the level
self.__level = None
Expand Down Expand Up @@ -151,4 +154,4 @@ def update_nodes(self):
"""
Abstract interface to node update
"""
return None
return None
23 changes: 14 additions & 9 deletions pySDC/sweeper_classes/generic_LU.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,17 @@ def compute_end_point(self):
L = self.level
P = L.prob

# start with u0 and add integral over the full interval (using coll.weights)
L.uend = P.dtype_u(L.u[0])
for m in range(self.coll.num_nodes):
L.uend += L.dt*self.coll.weights[m]*L.f[m+1]
# add up tau correction of the full interval (last entry)
if L.tau is not None:
L.uend += L.tau[-1]

return None
# check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
if (self.coll.right_is_node and not self.params.do_coll_update):
# a copy is sufficient
L.uend = P.dtype_u(L.u[-1])
else:
# start with u0 and add integral over the full interval (using coll.weights)
L.uend = P.dtype_u(L.u[0])
for m in range(self.coll.num_nodes):
L.uend += L.dt*self.coll.weights[m]*L.f[m+1]
# add up tau correction of the full interval (last entry)
if L.tau is not None:
L.uend += L.tau[-1]

return None
21 changes: 13 additions & 8 deletions pySDC/sweeper_classes/imex_1st_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,24 @@ def compute_end_point(self):
"""
Compute u at the right point of the interval

The value uend computed here is a full evaluation of the Picard formulation (always!)
The value uend computed here is a full evaluation of the Picard formulation unless do_full_update==False
"""

# get current level and problem description
L = self.level
P = L.prob

# start with u0 and add integral over the full interval (using coll.weights)
L.uend = P.dtype_u(L.u[0])
for m in range(self.coll.num_nodes):
L.uend += L.dt*self.coll.weights[m]*(L.f[m+1].impl + L.f[m+1].expl)
# add up tau correction of the full interval (last entry)
if L.tau is not None:
L.uend += L.tau[-1]
# check if Mth node is equal to right point and do_coll_update is false, perform a simple copy
if (self.coll.right_is_node and not self.params.do_coll_update):
# a copy is sufficient
L.uend = P.dtype_u(L.u[-1])
else:
# start with u0 and add integral over the full interval (using coll.weights)
L.uend = P.dtype_u(L.u[0])
for m in range(self.coll.num_nodes):
L.uend += L.dt*self.coll.weights[m]*(L.f[m+1].impl + L.f[m+1].expl)
# add up tau correction of the full interval (last entry)
if L.tau is not None:
L.uend += L.tau[-1]

return None
47 changes: 46 additions & 1 deletion tests/test_imexsweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_manysweepsequalmatrix(self):
#
# Make sure that update function for K sweeps computed from K-sweep matrix gives same result as K sweeps in node-to-node form plus compute_end_point
#
def test_maysweepupdate(self):
def test_manysweepupdate(self):

step, level, problem, nnodes = self.setupLevelStepProblem()
step.levels[0].sweep.predict()
Expand All @@ -229,3 +229,48 @@ def test_maysweepupdate(self):
# Multiply u0 by value of update function to get end value directly
uend_matrix = update*self.pparams['u0']
assert abs(uend_matrix - uend_sweep)<1e-14, "Node-to-node sweep plus update yields different result than update function computed through K-sweep matrix"

#
# Make sure that creating a sweeper object with a collocation object with right_is_node=False and do_coll_update=False throws an exception
#
def test_norightnode_collupdate_fails(self):
self.swparams['collocation_class'] = collclass.CollGaussLegendre
self.swparams['do_coll_update'] = False
# Has to throw an exception
with self.assertRaises(AssertionError):
step, level, problem, nnodes = self.setupLevelStepProblem()

#
# Make sure the update with do_coll_update=False reproduces last stage
#
def test_update_nocollupdate_laststage(self):
self.swparams['do_coll_update'] = False
step, level, problem, nnodes = self.setupLevelStepProblem()
level.sweep.predict()
ulaststage = np.random.rand()
level.u[nnodes].values = ulaststage
level.sweep.compute_end_point()
uend = level.uend.values
assert abs(uend-ulaststage)<1e-14, "compute_end_point with do_coll_update=False did not reproduce last stage value"

#
# Make sure that update with do_coll_update=False is identical to update formula with q=(0,...,0,1)
#
def test_updateformula_no_coll_update(self):
self.swparams['do_coll_update'] = False
step, level, problem, nnodes = self.setupLevelStepProblem()
level.sweep.predict()
u0full = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])

# Perform update step in sweeper
level.sweep.update_nodes()
ustages = np.array([ level.u[l].values.flatten() for l in range(1,nnodes+1) ])

# Compute end value through provided function
level.sweep.compute_end_point()
uend_sweep = level.uend.values
# Compute end value from matrix formulation
q = np.zeros(nnodes)
q[nnodes-1] = 1.0
uend_mat = q.dot(ustages)
assert np.linalg.norm(uend_sweep - uend_mat, np.infty)<1e-14, "For do_coll_update=False, update formula in sweeper gives different result than matrix update formula with q=(0,..,0,1)"