Skip to content
Merged
39 changes: 35 additions & 4 deletions pySDC/core/Controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import sys
import numpy as np

from pySDC.core import Hooks as hookclass
from pySDC.core.BaseTransfer import base_transfer
from pySDC.helpers.pysdc_helper import FrozenClass
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
from pySDC.implementations.hooks.default_hook import DefaultHooks


# short helper class to add params as attributes
Expand Down Expand Up @@ -41,10 +41,15 @@ def __init__(self, controller_params, description):
"""

# check if we have a hook on this list. If not, use default class.
controller_params['hook_class'] = controller_params.get('hook_class', hookclass.hooks)
self.__hooks = controller_params['hook_class']()
self.__hooks = []
hook_classes = [DefaultHooks]
user_hooks = controller_params.get('hook_class', [])
hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
[self.add_hook(hook) for hook in hook_classes]
controller_params['hook_class'] = controller_params.get('hook_class', hook_classes)

self.hooks.pre_setup(step=None, level_number=None)
for hook in self.hooks:
hook.pre_setup(step=None, level_number=None)

self.params = _Pars(controller_params)

Expand Down Expand Up @@ -101,6 +106,20 @@ def __setup_custom_logger(level=None, log_to_file=None, fname=None):
else:
pass

def add_hook(self, hook):
"""
Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
The hook is only added if a hook of the same class is not already present.

Args:
hook (pySDC.Hook): A hook class that is derived from the core hook class

Returns:
None
"""
if hook not in [type(me) for me in self.hooks]:
self.__hooks += [hook()]

def welcome_message(self):
out = (
"Welcome to the one and only, really very astonishing and 87.3% bug free"
Expand Down Expand Up @@ -308,3 +327,15 @@ def get_convergence_controllers_as_table(self, description):
out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'

return out

def return_stats(self):
"""
Return the merged stats from all hooks

Returns:
dict: Merged stats from all hooks
"""
stats = {}
for hook in self.hooks:
stats = {**stats, **hook.return_stats()}
return stats
199 changes: 6 additions & 193 deletions pySDC/core/Hooks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import time
from collections import namedtuple


Expand All @@ -8,23 +7,13 @@ class hooks(object):
"""
Hook class to contain the functions called during the controller runs (e.g. for calling user-routines)

When deriving a custom hook from this class make sure to always call the parent method using e.g.
`super().post_step(step, level_number)`. Otherwise bugs may arise when using `filer_recomputed` from the stats
helper for post processing.

Attributes:
__t0_setup (float): private variable to get starting time of setup
__t0_run (float): private variable to get starting time of the run
__t0_predict (float): private variable to get starting time of the predictor
__t0_step (float): private variable to get starting time of the step
__t0_iteration (float): private variable to get starting time of the iteration
__t0_sweep (float): private variable to get starting time of the sweep
__t0_comm (list): private variable to get starting time of the communication
__t1_run (float): private variable to get end time of the run
__t1_predict (float): private variable to get end time of the predictor
__t1_step (float): private variable to get end time of the step
__t1_iteration (float): private variable to get end time of the iteration
__t1_sweep (float): private variable to get end time of the sweep
__t1_setup (float): private variable to get end time of setup
__t1_comm (list): private variable to hold timing of the communication (!)
__num_restarts (int): number of restarts of the current step
logger: logger instance for output
__num_restarts (int): number of restarts of the current step
__stats (dict): dictionary for gathering the statistics of a run
__entry (namedtuple): statistics entry containing all information to identify the value
"""
Expand All @@ -33,20 +22,6 @@ def __init__(self):
"""
Initialization routine
"""
self.__t0_setup = None
self.__t0_run = None
self.__t0_predict = None
self.__t0_step = None
self.__t0_iteration = None
self.__t0_sweep = None
self.__t0_comm = []
self.__t1_run = None
self.__t1_predict = None
self.__t1_step = None
self.__t1_iteration = None
self.__t1_sweep = None
self.__t1_setup = None
self.__t1_comm = []
self.__num_restarts = 0

self.logger = logging.getLogger('hooks')
Expand Down Expand Up @@ -130,7 +105,6 @@ def pre_setup(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t0_setup = time.perf_counter()

def pre_run(self, step, level_number):
"""
Expand All @@ -141,7 +115,6 @@ def pre_run(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t0_run = time.perf_counter()

def pre_predict(self, step, level_number):
"""
Expand All @@ -151,7 +124,7 @@ def pre_predict(self, step, level_number):
step (pySDC.Step.step): the current step
level_number (int): the current level number
"""
self.__t0_predict = time.perf_counter()
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0

def pre_step(self, step, level_number):
"""
Expand All @@ -162,7 +135,6 @@ def pre_step(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t0_step = time.perf_counter()

def pre_iteration(self, step, level_number):
"""
Expand All @@ -173,7 +145,6 @@ def pre_iteration(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t0_iteration = time.perf_counter()

def pre_sweep(self, step, level_number):
"""
Expand All @@ -184,7 +155,6 @@ def pre_sweep(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t0_sweep = time.perf_counter()

def pre_comm(self, step, level_number):
"""
Expand All @@ -195,16 +165,6 @@ def pre_comm(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
if len(self.__t0_comm) >= level_number + 1:
self.__t0_comm[level_number] = time.perf_counter()
else:
while len(self.__t0_comm) < level_number:
self.__t0_comm.append(None)
self.__t0_comm.append(time.perf_counter())
while len(self.__t1_comm) <= level_number:
self.__t1_comm.append(0.0)
assert len(self.__t0_comm) == level_number + 1
assert len(self.__t1_comm) == level_number + 1

def post_comm(self, step, level_number, add_to_stats=False):
"""
Expand All @@ -216,22 +176,6 @@ def post_comm(self, step, level_number, add_to_stats=False):
add_to_stats (bool): set if result should go to stats object
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
assert len(self.__t1_comm) >= level_number + 1
self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number]

if add_to_stats:
L = step.levels[level_number]

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='timing_comm',
value=self.__t1_comm[level_number],
)
self.__t1_comm[level_number] = 0.0

def post_sweep(self, step, level_number):
"""
Expand All @@ -242,39 +186,6 @@ def post_sweep(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t1_sweep = time.perf_counter()

L = step.levels[level_number]

self.logger.info(
'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e',
step.status.slot,
L.time,
step.status.stage,
L.level_index,
step.status.iter,
L.status.sweep,
L.status.residual,
)

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='residual_post_sweep',
value=L.status.residual,
)
self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='timing_sweep',
value=self.__t1_sweep - self.__t0_sweep,
)

def post_iteration(self, step, level_number):
"""
Expand All @@ -286,29 +197,6 @@ def post_iteration(self, step, level_number):
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0

self.__t1_iteration = time.perf_counter()

L = step.levels[level_number]

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=-1,
iter=step.status.iter,
sweep=L.status.sweep,
type='residual_post_iteration',
value=L.status.residual,
)
self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='timing_iteration',
value=self.__t1_iteration - self.__t0_iteration,
)

def post_step(self, step, level_number):
"""
Default routine called after each step or block
Expand All @@ -319,44 +207,6 @@ def post_step(self, step, level_number):
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0

self.__t1_step = time.perf_counter()

L = step.levels[level_number]

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='timing_step',
value=self.__t1_step - self.__t0_step,
)
self.add_to_stats(
process=step.status.slot,
time=L.time,
level=-1,
iter=step.status.iter,
sweep=L.status.sweep,
type='niter',
value=step.status.iter,
)
self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=-1,
sweep=L.status.sweep,
type='residual_post_step',
value=L.status.residual,
)

# record the recomputed quantities at weird positions to make sure there is only one value for each step
for t in [L.time, L.time + L.dt]:
self.add_to_stats(
process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart')
)

def post_predict(self, step, level_number):
"""
Default routine called after each predictor
Expand All @@ -366,19 +216,6 @@ def post_predict(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t1_predict = time.perf_counter()

L = step.levels[level_number]

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='timing_predictor',
value=self.__t1_predict - self.__t0_predict,
)

def post_run(self, step, level_number):
"""
Expand All @@ -389,19 +226,6 @@ def post_run(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t1_run = time.perf_counter()

L = step.levels[level_number]

self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=step.status.iter,
sweep=L.status.sweep,
type='timing_run',
value=self.__t1_run - self.__t0_run,
)

def post_setup(self, step, level_number):
"""
Expand All @@ -412,14 +236,3 @@ def post_setup(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
self.__t1_setup = time.perf_counter()

self.add_to_stats(
process=-1,
time=-1,
level=-1,
iter=-1,
sweep=-1,
type='timing_setup',
value=self.__t1_setup - self.__t0_setup,
)
Loading