diff --git a/pySDC/core/ConvergenceController.py b/pySDC/core/ConvergenceController.py index a112c1f9e6..bcd2fce0ee 100644 --- a/pySDC/core/ConvergenceController.py +++ b/pySDC/core/ConvergenceController.py @@ -42,6 +42,7 @@ def __init__(self, controller, params, description, **kwargs): params (dict): The params passed for this specific convergence controller description (dict): The description object used to instantiate the controller """ + self.controller = controller self.params = Pars(self.setup(controller, params, description)) params_ok, msg = self.check_parameters(controller, params, description) assert params_ok, f'{type(self).__name__} -- {msg}' @@ -425,94 +426,43 @@ def Recv(self, comm, source, buffer, **kwargs): return data - def reset_variable(self, controller, name, MPI=False, place=None, where=None, init=None): - """ - Utility function for resetting variables. This function will call the `add_variable` function with all the same - arguments, but with `allow_overwrite = True`. - - Args: - controller (pySDC.Controller): The controller - name (str): The name of the variable - MPI (bool): Whether to use MPI controller - place (object): The object you want to reset the variable of - where (list): List of strings containing a path to where you want to reset the variable - init: Initial value of the variable - - Returns: - None - """ - self.add_variable(controller, name, MPI, place, where, init, allow_overwrite=True) + def add_status_variable_to_step(self, key, value=None): + if type(self.controller).__name__ == 'controller_MPI': + steps = [self.controller.S] + else: + steps = self.controller.MS - def add_variable(self, controller, name, MPI=False, place=None, where=None, init=None, allow_overwrite=False): - """ - Add a variable to a frozen class. + steps[0].status.add_attr(key) - This function goes through the path to the destination of the variable recursively and adds it to all instances - that are possible in the path. For example, giving `where = ["MS", "levels", "status"]` will result in adding a - variable to the status object of all levels of all steps of the controller. + if value is not None: + self.set_step_status_variable(key, value) - Part of the functionality of the frozen class is to separate initialization and setting of variables. By - enforcing this, you can make sure not to overwrite already existing variables. Since this function is called - outside of the `__init__` function of the status objects, this can otherwise lead to bugs that are hard to find. - For this reason, you need to specifically set `allow_overwrite = True` if you want to forgo the check if the - variable already exists. This can be useful when resetting variables between steps, but make sure to set it to - `allow_overwrite = False` the first time you add a variable. + def set_step_status_variable(self, key, value): + if type(self.controller).__name__ == 'controller_MPI': + steps = [self.controller.S] + else: + steps = self.controller.MS - Args: - controller (pySDC.Controller): The controller - name (str): The name of the variable - MPI (bool): Whether to use MPI controller - place (object): The object you want to add the variable to - where (list): List of strings containing a path to where you want to add the variable - init: Initial value of the variable - allow_overwrite (bool): Allow overwriting the variables if they already exist or raise an exception + for S in steps: + S.status.__dict__[key] = value - Returns: - None - """ - where = ["S" if MPI else "MS", "levels", "status"] if where is None else where - place = controller if place is None else place + def add_status_variable_to_level(self, key, value=None): + if type(self.controller).__name__ == 'controller_MPI': + steps = [self.controller.S] + else: + steps = self.controller.MS - # check if we have arrived at the end of the path to the variable - if len(where) == 0: - variable_exitsts = name in place.__dict__.keys() - # check if the variable already exists and raise an error in case we are about to introduce a bug - if not allow_overwrite and variable_exitsts: - raise ValueError(f"Key \"{name}\" already exists in {place}! Please rename the variable in {self}") - # if we allow overwriting, but the variable does not exist already, we are violating the intended purpose - # of this function, so we also raise an error if someone should be so mad as to attempt this - elif allow_overwrite and not variable_exitsts: - raise ValueError(f"Key \"{name}\" is supposed to be overwritten in {place}, but it does not exist!") + steps[0].levels[0].status.add_attr(key) - # actually add or overwrite the variable - place.__dict__[name] = init + if value is not None: + self.set_level_status_variable(key, value) - # follow the path to the final destination recursively + def set_level_status_variable(self, key, value): + if type(self.controller).__name__ == 'controller_MPI': + steps = [self.controller.S] else: - # get all possible new places to continue the path - new_places = place.__dict__[where[0]] - - # continue all possible paths - if type(new_places) == list: - # loop through all possibilities - for new_place in new_places: - self.add_variable( - controller, - name, - MPI=MPI, - place=new_place, - where=where[1:], - init=init, - allow_overwrite=allow_overwrite, - ) - else: - # go to the only possible possibility - self.add_variable( - controller, - name, - MPI=MPI, - place=new_places, - where=where[1:], - init=init, - allow_overwrite=allow_overwrite, - ) + steps = self.controller.MS + + for S in steps: + for L in S.levels: + L.status.__dict__[key] = value diff --git a/pySDC/core/Level.py b/pySDC/core/Level.py index fe12c693f2..38217f85a8 100644 --- a/pySDC/core/Level.py +++ b/pySDC/core/Level.py @@ -21,8 +21,7 @@ def __init__(self, params): class _Status(FrozenClass): """ This class carries the status of the level. All variables that the core SDC / PFASST functionality depend on are - initialized here, while the convergence controllers are allowed to add more variables in a controlled fashion - later on using the `add_variable` function. + initialized here. """ def __init__(self): diff --git a/pySDC/core/Step.py b/pySDC/core/Step.py index cd45f81aa3..f983c6fa96 100644 --- a/pySDC/core/Step.py +++ b/pySDC/core/Step.py @@ -20,8 +20,7 @@ def __init__(self, params): class _Status(FrozenClass): """ This class carries the status of the step. All variables that the core SDC / PFASST functionality depend on are - initialized here, while the convergence controllers are allowed to add more variables in a controlled fashion - later on using the `add_variable` function. + initialized here. """ def __init__(self): diff --git a/pySDC/helpers/pysdc_helper.py b/pySDC/helpers/pysdc_helper.py index babf7ffd0b..a6161e062f 100644 --- a/pySDC/helpers/pysdc_helper.py +++ b/pySDC/helpers/pysdc_helper.py @@ -6,6 +6,8 @@ class FrozenClass(object): __isfrozen: Flag to freeze a class """ + attrs = [] + __isfrozen = False def __setattr__(self, key, value): @@ -18,10 +20,33 @@ def __setattr__(self, key, value): """ # check if attribute exists and if class is frozen - if self.__isfrozen and not hasattr(self, key): - raise TypeError("%r is a frozen class" % self) + if self.__isfrozen and not (key in self.attrs or hasattr(self, key)): + raise TypeError(f'{type(self).__name__!r} is a frozen class, cannot add attribute {key!r}') + object.__setattr__(self, key, value) + def __getattr__(self, key): + """ + This is needed in case the variables have not been initialized after adding. + """ + if key in self.attrs: + return None + else: + super().__getattr__(key) + + @classmethod + def add_attr(cls, key, raise_error_if_exists=False): + """ + Add a key to the allowed attributes of this class. + + Args: + key (str): The key to add + raise_error_if_exists (bool): Raise an error if the attribute already exists in the class + """ + if key in cls.attrs and raise_error_if_exists: + raise TypeError(f'Attribute {key!r} already exists in {cls.__name__}!') + cls.attrs += [key] + def _freeze(self): """ Function to freeze the class @@ -40,3 +65,9 @@ def get(self, key, default=None): __dict__.get(key, default) """ return self.__dict__.get(key, default) + + def __dir__(self): + """ + My hope is that some editors can use this for dynamic autocompletion. + """ + return super().__dir__() + self.attrs diff --git a/pySDC/implementations/convergence_controller_classes/basic_restarting.py b/pySDC/implementations/convergence_controller_classes/basic_restarting.py index 501f49cb1b..89678385c3 100644 --- a/pySDC/implementations/convergence_controller_classes/basic_restarting.py +++ b/pySDC/implementations/convergence_controller_classes/basic_restarting.py @@ -76,36 +76,26 @@ def setup(self, controller, params, description, **kwargs): return {**defaults, **super().setup(controller, params, description, **kwargs)} - def setup_status_variables(self, controller, **kwargs): + def setup_status_variables(self, *args, **kwargs): """ Add status variables for whether to restart now and how many times the step has been restarted in a row to the Steps - Args: - controller (pySDC.Controller): The controller - reset (bool): Whether the function is called for the first time or to reset - Returns: None """ - where = ["S" if 'comm' in kwargs.keys() else "MS", "status"] - self.add_variable(controller, name='restart', where=where, init=False) - self.add_variable(controller, name='restarts_in_a_row', where=where, init=0) + self.add_status_variable_to_step('restart', False) + self.add_status_variable_to_step('restarts_in_a_row', 0) - def reset_status_variables(self, controller, reset=False, **kwargs): + def reset_status_variables(self, *args, **kwargs): """ Add status variables for whether to restart now and how many times the step has been restarted in a row to the Steps - Args: - controller (pySDC.Controller): The controller - reset (bool): Whether the function is called for the first time or to reset - Returns: None """ - where = ["S" if 'comm' in kwargs.keys() else "MS", "status"] - self.reset_variable(controller, name='restart', where=where, init=False) + self.set_step_status_variable('restart', False) def dependencies(self, controller, description, **kwargs): """ diff --git a/pySDC/implementations/convergence_controller_classes/estimate_contraction_factor.py b/pySDC/implementations/convergence_controller_classes/estimate_contraction_factor.py index f78c9c9791..7a782d6ed6 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_contraction_factor.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_contraction_factor.py @@ -39,41 +39,17 @@ def dependencies(self, controller, description, **kwargs): description=description, ) - def setup_status_variables(self, controller, **kwargs): + def setup_status_variables(self, *args, **kwargs): """ Add the embedded error, contraction factor and iterations to convergence variable to the status of the levels. - Args: - controller (pySDC.Controller): The controller - - Returns: - None - """ - if 'comm' in kwargs.keys(): - steps = [controller.S] - else: - if 'active_slots' in kwargs.keys(): - steps = [controller.MS[i] for i in kwargs['active_slots']] - else: - steps = controller.MS - where = ["levels", "status"] - for S in steps: - self.add_variable(S, name='error_embedded_estimate_last_iter', where=where, init=None) - self.add_variable(S, name='contraction_factor', where=where, init=None) - if self.params.e_tol is not None: - self.add_variable(S, name='iter_to_convergence', where=where, init=None) - - def reset_status_variables(self, controller, **kwargs): - """ - Reinitialize new status variables for the levels. - - Args: - controller (pySDC.controller): The controller - Returns: None """ - self.setup_status_variables(controller, **kwargs) + self.add_status_variable_to_level('error_embedded_estimate_last_iter') + self.add_status_variable_to_level('contraction_factor') + if self.params.e_tol is not None: + self.add_status_variable_to_level('iter_to_convergence') def post_iteration_processing(self, controller, S, **kwargs): """ diff --git a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py index a4b51b4d5d..8aec1f79c3 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py @@ -114,19 +114,7 @@ def setup_status_variables(self, controller, **kwargs): Args: controller (pySDC.Controller): The controller """ - if 'comm' in kwargs.keys(): - steps = [controller.S] - else: - if 'active_slots' in kwargs.keys(): - steps = [controller.MS[i] for i in kwargs['active_slots']] - else: - steps = controller.MS - where = ["levels", "status"] - for S in steps: - self.add_variable(S, name='error_embedded_estimate', where=where, init=None) - - def reset_status_variables(self, controller, **kwargs): - self.setup_status_variables(controller, **kwargs) + self.add_status_variable_to_level('error_embedded_estimate') def post_iteration_processing(self, controller, S, **kwargs): """ @@ -350,7 +338,7 @@ def post_iteration_processing(self, controller, step, **kwargs): max([np.finfo(float).eps, abs(self.status.u[-1] - self.status.u[-2])]), ) - def setup_status_variables(self, controller, **kwargs): + def setup_status_variables(self, *args, **kwargs): """ Add the embedded error variable to the levels and add a status variable for previous steps. @@ -361,16 +349,4 @@ def setup_status_variables(self, controller, **kwargs): self.status.u = [] # the solutions of converged collocation problems self.status.iter = [] # the iteration in which the solution converged - if 'comm' in kwargs.keys(): - steps = [controller.S] - else: - if 'active_slots' in kwargs.keys(): - steps = [controller.MS[i] for i in kwargs['active_slots']] - else: - steps = controller.MS - where = ["levels", "status"] - for S in steps: - self.add_variable(S, name='error_embedded_estimate_collocation', where=where, init=None) - - def reset_status_variables(self, controller, **kwargs): - self.setup_status_variables(controller, **kwargs) + self.add_status_variable_to_level('error_embedded_estimate_collocation') diff --git a/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py b/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py index 0052f3b420..9118fcedf5 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py @@ -84,29 +84,7 @@ def setup_status_variables(self, controller, **kwargs): self.coeff.u = [None] * self.params.n self.coeff.f = [0.0] * self.params.n - self.reset_status_variables(controller, **kwargs) - return None - - def reset_status_variables(self, controller, **kwargs): - """ - Add variable for extrapolated error - - Args: - controller (pySDC.Controller): The controller - - Returns: - None - """ - if 'comm' in kwargs.keys(): - steps = [controller.S] - else: - if 'active_slots' in kwargs.keys(): - steps = [controller.MS[i] for i in kwargs['active_slots']] - else: - steps = controller.MS - where = ["levels", "status"] - for S in steps: - self.add_variable(S, name='error_extrapolation_estimate', where=where, init=None) + self.add_status_variable_to_level('error_extrapolation_estimate') def check_parameters(self, controller, params, description, **kwargs): """ diff --git a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py index 7eb9989d96..3f8bcb2c9a 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_polynomial_error.py @@ -1,8 +1,7 @@ import numpy as np from pySDC.core.Lagrange import LagrangeApproximation -from pySDC.core.ConvergenceController import ConvergenceController, Status -from pySDC.core.Collocation import CollBase +from pySDC.core.ConvergenceController import ConvergenceController class EstimatePolynomialError(ConvergenceController): @@ -62,28 +61,15 @@ def setup(self, controller, params, description, **kwargs): return defaults - def reset_status_variables(self, controller, **kwargs): + def reset_status_variables(self, *args, **kwargs): """ Add variable for embedded error - Args: - controller (pySDC.Controller): The controller - Returns: None """ - if 'comm' in kwargs.keys(): - steps = [controller.S] - else: - if 'active_slots' in kwargs.keys(): - steps = [controller.MS[i] for i in kwargs['active_slots']] - else: - steps = controller.MS - - where = ["levels", "status"] - for S in steps: - self.add_variable(S, name='error_embedded_estimate', where=where, init=None) - self.add_variable(S, name='order_embedded_estimate', where=where, init=None) + self.add_status_variable_to_level('error_embedded_estimate') + self.add_status_variable_to_level('order_embedded_estimate') def matmul(self, A, b): """