Some clean up on evaluators and solvers for easier output
kburns committed May 17, 2023
1 parent b224e81 commit 2f4bcbb
Showing 4 changed files with 38 additions and 34 deletions.
12 changes: 5 additions & 7 deletions dedalus/core/evaluator.py
@@ -514,7 +514,7 @@ def setup_file(self, file):
file.create_group('scales')
file['scales'].create_dataset(name='constant', data=np.zeros(1), dtype=np.float64)
file['scales']['constant'].make_scale('constant')
- for name in ['sim_time', 'timestep', 'world_time', 'wall_time']:
+ for name in ['sim_time', 'timestep', 'wall_time']:
file['scales'].create_dataset(name=name, shape=(0,), maxshape=(self.max_writes,), dtype=np.float64) # shape[0] = 0 to chunk across writes
file['scales'][name].make_scale(name)
for name in ['iteration', 'write_number']:
@@ -533,7 +533,7 @@ def setup_file(self, file):
dset.attrs['scales'] = scales
# Time scales
dset.dims[0].label = 't'
- for sn in ['sim_time', 'world_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
+ for sn in ['sim_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
dset.dims[0].attach_scale(file['scales'][sn])
# Spatial scales
rank = len(op.tensorsig)
@@ -566,10 +566,8 @@ def create_task_dataset(self, file, task):
dset = file['tasks'].create_dataset(name=task['name'], shape=shape, maxshape=maxshape, dtype=task['dtype'])
return dset

- def process(self, **kw):
+ def process(self, iteration, wall_time=0, sim_time=0, timestep=0):
"""Save task outputs to HDF5 file."""
- # HACK: fix world time and timestep inputs from solvers.py/timestepper.py
- kw['world_time'] = 0
# Update write counts
self.total_write_num += 1
self.file_write_num += 1
@@ -580,7 +578,7 @@ def process(self, **kw):
self.file_write_num = 1
# Write file metadata
file = self.get_file()
- self.write_file_metadata(file, write_number=self.total_write_num, **kw)
+ self.write_file_metadata(file, write_number=self.total_write_num, iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=timestep)
# Write tasks
for task in self.tasks:
# Transform and process data
@@ -596,7 +594,7 @@ def write_file_metadata(self, file, **kw):
# Update file metadata
file.attrs['writes'] = self.file_write_num
# Update time scales
- for name in ['sim_time', 'world_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
+ for name in ['sim_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
dset = file['scales'][name]
dset.resize(self.file_write_num, axis=0)
dset[self.file_write_num-1] = kw[name]
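Taken together, the evaluator.py hunks drop the world_time scale and replace the catch-all process(self, **kw) signature with explicit keyword arguments. A minimal sketch of the resulting HDF5 layout, written against plain h5py rather than Dedalus (the file name, task name, and sizes below are invented for illustration):

# Sketch (not part of the commit): reproduce the file layout that setup_file and
# write_file_metadata now produce; the time scales are sim_time, wall_time, timestep,
# iteration, and write_number, with no world_time.
import h5py
import numpy as np

max_writes = 10
with h5py.File('demo_s1.h5', 'w') as file:
    file.create_group('tasks')
    file.create_group('scales')
    for name in ['sim_time', 'timestep', 'wall_time']:
        file['scales'].create_dataset(name=name, shape=(0,), maxshape=(max_writes,), dtype=np.float64)
        file['scales'][name].make_scale(name)
    for name in ['iteration', 'write_number']:
        file['scales'].create_dataset(name=name, shape=(0,), maxshape=(max_writes,), dtype=np.int64)
        file['scales'][name].make_scale(name)
    # One task dataset with time on axis 0, attached to every time scale
    dset = file['tasks'].create_dataset('u', shape=(0, 16), maxshape=(max_writes, 16), dtype=np.float64)
    dset.dims[0].label = 't'
    for sn in ['sim_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
        dset.dims[0].attach_scale(file['scales'][sn])
    # One "write": resize along time and record the per-write metadata,
    # mirroring what write_file_metadata now does with its explicit keywords
    write_number = 1
    for name, value in [('sim_time', 0.5), ('wall_time', 1.2), ('timestep', 0.01),
                        ('iteration', 50), ('write_number', write_number)]:
        scale = file['scales'][name]
        scale.resize(write_number, axis=0)
        scale[write_number - 1] = value
    dset.resize(write_number, axis=0)
    dset[write_number - 1] = np.zeros(16)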
54 changes: 30 additions & 24 deletions dedalus/core/solvers.py
@@ -106,6 +106,9 @@ def __init__(self, problem, ncc_cutoff=1e-6, max_ncc_terms=None, entry_cutoff=1e
self.subsystems = subsystems.build_subsystems(self)
self.subproblems = subsystems.build_subproblems(self, self.subsystems)
self.subproblems_by_group = {sp.group: sp for sp in self.subproblems}
+ # Build evaluator
+ namespace = {}
+ self.evaluator = Evaluator(self.dist, namespace)


class EigenvalueSolver(SolverBase):
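This first hunk builds the evaluator once in SolverBase.__init__; the hunks below then delete the per-subclass copies. A toy sketch of that pattern, using invented stand-in classes rather than the real Dedalus ones:

# Toy sketch (FakeEvaluator/FakeSolverBase/FakeIVPSolver are invented stand-ins,
# not Dedalus classes): the base class owns one shared evaluator, and subclasses
# only register their handlers on it.
class FakeEvaluator:
    def __init__(self):
        self.handlers = []
    def add_system_handler(self, **kw):
        handler = dict(kw, tasks=[])
        self.handlers.append(handler)
        return handler

class FakeSolverBase:
    def __init__(self):
        self.evaluator = FakeEvaluator()     # built once, here

class FakeIVPSolver(FakeSolverBase):
    def __init__(self, eqs):
        super().__init__()
        # No per-subclass Evaluator(...) construction anymore; just add the RHS handler
        F_handler = self.evaluator.add_system_handler(iter=1, group='F')
        for eq in eqs:
            F_handler['tasks'].append(eq)

solver = FakeIVPSolver(eqs=['eq1', 'eq2'])
print(len(solver.evaluator.handlers))        # 1: a single shared evaluator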
@@ -317,8 +320,6 @@ def __init__(self, problem, **kw):
self.subproblem_matsolvers = {}
self.iteration = 0
# Create RHS handler
- namespace = {}
- self.evaluator = Evaluator(self.dist, namespace)
F_handler = self.evaluator.add_system_handler(iter=1, group='F')
for eq in problem.eqs:
F_handler.add_task(eq['F'])
@@ -425,8 +426,6 @@ def __init__(self, problem, **kw):
self.perturbations = problem.perturbations
self.iteration = 0
# Create RHS handler
- namespace = {}
- self.evaluator = Evaluator(self.dist, namespace)
F_handler = self.evaluator.add_system_handler(iter=1, group='F')
for eq in problem.eqs:
F_handler.add_task(eq['H'])
@@ -511,17 +510,18 @@ class InitialValueSolver(SolverBase):
def __init__(self, problem, timestepper, enforce_real_cadence=100, warmup_iterations=10, **kw):
logger.debug('Beginning IVP instantiation')
super().__init__(problem, **kw)
- self.enforce_real_cadence = enforce_real_cadence
- self._wall_time_array = np.zeros(1, dtype=float)
- self.init_time = self.get_wall_time()
+ if np.isrealobj(self.dtype.type()):
+ self.enforce_real_cadence = enforce_real_cadence
+ else:
+ self.enforce_real_cadence = None
+ self._bcast_array = np.zeros(1, dtype=float)
+ self.init_time = self.world_time
# Build LHS matrices
subsystems.build_subproblem_matrices(self, self.subproblems, ['M', 'L'])
# Compute total modes
local_modes = sum(ss.subproblem.pre_right.shape[1] for ss in self.subsystems)
self.total_modes = self.dist.comm.allreduce(local_modes, op=MPI.SUM)
# Create RHS handler
- namespace = {}
- self.evaluator = Evaluator(self.dist, namespace)
F_handler = self.evaluator.add_system_handler(iter=1, group='F')
for eq in problem.eqs:
F_handler.add_task(eq['F'])
@@ -548,19 +548,28 @@ def sim_time(self, t):
self._sim_time = t
self.problem.time['g'] = t

- def get_wall_time(self):
- self._wall_time_array[0] = time.time()
- comm = self.dist.comm_cart
- comm.Allreduce(MPI.IN_PLACE, self._wall_time_array, op=MPI.MAX)
- return self._wall_time_array[0]
+ @property
+ def world_time(self):
+ if self.dist.comm.size == 1:
+ return time.time()
+ else:
+ # Broadcast time from root process
+ self._bcast_array[0] = time.time()
+ self.dist.comm_cart.Bcast(self._bcast_array, root=0)
+ return self._bcast_array[0]
+
+ @property
+ def wall_time(self):
+ """Seconds ellapsed since instantiation."""
+ return self.world_time - self.init_time

@property
def proceed(self):
"""Check that current time and iteration pass stop conditions."""
if self.sim_time >= self.stop_sim_time:
logger.info('Simulation stop time reached.')
return False
- elif (self.get_wall_time() - self.init_time) >= self.stop_wall_time:
+ elif self.wall_time >= self.stop_wall_time:
logger.info('Wall stop time reached.')
return False
elif self.iteration >= self.stop_iteration:
@@ -619,19 +628,18 @@ def step(self, dt):
if not np.isfinite(dt):
raise ValueError("Invalid timestep")
# Enforce Hermitian symmetry for real variables
- if np.isrealobj(self.dtype.type()):
+ if self.enforce_real_cadence:
# Enforce for as many iterations as timestepper uses internally
if self.iteration % self.enforce_real_cadence < self.timestepper.steps:
self.enforce_hermitian_symmetry(self.state)
# Record times
- wall_time = self.get_wall_time()
+ wall_time = self.wall_time
if self.iteration == self.initial_iteration:
self.start_time = wall_time
if self.iteration == self.initial_iteration + self.warmup_iterations:
self.warmup_time = wall_time
# Advance using timestepper
- wall_elapsed = wall_time - self.init_time
- self.timestepper.step(dt, wall_elapsed)
+ self.timestepper.step(dt, wall_time)
# Update iteration
self.iteration += 1
self.dt = dt
@@ -676,16 +684,14 @@ def evaluate_handlers(self, handlers=None, dt=0):
"""Evaluate specified list of handlers (all by default)."""
if handlers is None:
handlers = self.evaluator.handlers
- wall_elapsed = self.get_wall_time() - self.init_time
- self.evaluator.evaluate_handlers(handlers, iteration=self.iteration, wall_time=wall_elapsed, sim_time=self.sim_time, timestep=dt)
+ self.evaluator.evaluate_handlers(handlers, iteration=self.iteration, wall_time=self.wall_time, sim_time=self.sim_time, timestep=dt)

def log_stats(self, format=".4g"):
"""Log timing statistics with specified string formatting (optional)."""
- log_time = self.get_wall_time()
+ log_time = self.wall_time
logger.info(f"Final iteration: {self.iteration}")
logger.info(f"Final sim time: {self.sim_time}")
- setup_time = self.start_time - self.init_time
- logger.info(f"Setup time (init - iter 0): {setup_time:{format}} sec")
+ logger.info(f"Setup time (init - iter 0): {self.start_time:{format}} sec")
if self.iteration >= self.initial_iteration + self.warmup_iterations:
warmup_time = self.warmup_time - self.start_time
run_time = log_time - self.warmup_time
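The net effect in solvers.py is that get_wall_time() is gone: world_time is now a property returning the root process's clock (broadcast so every rank agrees), wall_time is seconds since solver instantiation, and proceed, step(), evaluate_handlers(), and log_stats() read those properties directly. A self-contained sketch of the same pattern outside Dedalus (TimedRun is an invented name; mpi4py is assumed to be installed):

# Sketch, not Dedalus code: the world_time/wall_time property pattern from this diff.
import time
import numpy as np
from mpi4py import MPI

class TimedRun:
    def __init__(self, comm=MPI.COMM_WORLD):
        self.comm = comm
        self._bcast_array = np.zeros(1, dtype=float)
        self.init_time = self.world_time       # reference point at instantiation

    @property
    def world_time(self):
        if self.comm.size == 1:
            return time.time()
        # Broadcast the root process's clock so all ranks share one value
        self._bcast_array[0] = time.time()
        self.comm.Bcast(self._bcast_array, root=0)
        return self._bcast_array[0]

    @property
    def wall_time(self):
        """Seconds elapsed since instantiation."""
        return self.world_time - self.init_time

run = TimedRun()
time.sleep(0.1)
print(f"wall_time = {run.wall_time:.3f} s")    # same value on every rank

With this in place, a wall-clock stop condition reduces to a comparison like run.wall_time >= stop_wall_time, mirroring the simplified proceed property above.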
4 changes: 2 additions & 2 deletions dedalus/tests/test_output.py
@@ -44,7 +44,7 @@ def test_cartesian_output(dtype, dealias, output_scales, output_layout, parallel
output = solver.evaluator.add_file_handler(tempdir, iter=1, parallel=parallel)
for task in tasks:
output.add_task(task, layout=output_layout, name=str(task), scales=output_scales)
- solver.evaluator.evaluate_handlers([output], sim_time=0, wall_time=0, world_time=0, timestep=0, iteration=0)
+ solver.evaluate_handlers([output])
# Check solution
errors = []
with h5py.File(f'{tempdir}/{tempdir}_s1.h5', mode='r') as file:
@@ -108,7 +108,7 @@ def test_spherical_output(Nphi, Ntheta, Nr, k, dealias, dtype, basis, output_sca
output = solver.evaluator.add_file_handler(tempdir, iter=1, parallel=parallel)
for task in tasks:
output.add_task(task, layout='g', name=str(task), scales=output_scales)
- solver.evaluator.evaluate_handlers([output], sim_time=0, wall_time=0, world_time=0, timestep=0, iteration=0)
+ solver.evaluate_handlers([output])
# Check solution
errors = []
with h5py.File(f'{tempdir}/{tempdir}_s1.h5', mode='r') as file:
2 changes: 1 addition & 1 deletion dedalus/tools/post.py
@@ -284,7 +284,7 @@ def merge_setup(joint_file, proc_paths, virtual=False):
for i, proc_dim in enumerate(proc_dset.dims):
joint_dset.dims[i].label = proc_dim.label
if joint_dset.dims[i].label == 't':
- for scalename in ['sim_time', 'world_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
+ for scalename in ['sim_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
scale = joint_file['scales'][scalename]
joint_dset.dims.create_scale(scale, scalename)
joint_dset.dims[i].attach_scale(scale)
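Since merge_setup copies the same scale list, files written or merged after this commit expose five time scales. A short sketch of inspecting them (the path below is hypothetical):

# Sketch: read back the per-write metadata of a Dedalus output file after this change.
# 'analysis/analysis_s1.h5' is a made-up path; substitute a real handler output file.
import h5py

with h5py.File('analysis/analysis_s1.h5', 'r') as file:
    print('writes:', file.attrs['writes'])
    for name in ['sim_time', 'wall_time', 'timestep', 'iteration', 'write_number']:
        print(name, file['scales'][name][:])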
