diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml index 6515e8d272..b45cf0b402 100644 --- a/.github/workflows/ci_pipeline.yml +++ b/.github/workflows/ci_pipeline.yml @@ -35,23 +35,23 @@ jobs: run: | flakeheaven lint --benchmark pySDC - mirror_to_gitlab: +# mirror_to_gitlab: - runs-on: ubuntu-latest +# runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v1 +# steps: +# - name: Checkout +# uses: actions/checkout@v1 - - name: Mirror - uses: jakob-fritz/github2lab_action@main - env: - MODE: 'mirror' # Either 'mirror', 'get_status', or 'both' - GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }} - FORCE_PUSH: "true" - GITLAB_HOSTNAME: "codebase.helmholtz.cloud" - GITLAB_PROJECT_ID: "3525" - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +# - name: Mirror +# uses: jakob-fritz/github2lab_action@main +# env: +# MODE: 'mirror' # Either 'mirror', 'get_status', or 'both' +# GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }} +# FORCE_PUSH: "true" +# GITLAB_HOSTNAME: "codebase.helmholtz.cloud" +# GITLAB_PROJECT_ID: "3525" +# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} user_cpu_tests_linux: runs-on: ubuntu-latest @@ -121,31 +121,31 @@ jobs: pytest --continue-on-collection-errors -v --durations=0 pySDC/tests -m ${{ matrix.env }} - wait_for_gitlab: - runs-on: ubuntu-latest +# wait_for_gitlab: +# runs-on: ubuntu-latest - needs: - - mirror_to_gitlab +# needs: +# - mirror_to_gitlab - steps: - - name: Wait - uses: jakob-fritz/github2lab_action@main - env: - MODE: 'get_status' # Either 'mirror', 'get_status', or 'both' - GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }} - FORCE_PUSH: "true" - GITLAB_HOSTNAME: "codebase.helmholtz.cloud" - GITLAB_PROJECT_ID: "3525" - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +# steps: +# - name: Wait +# uses: jakob-fritz/github2lab_action@main +# env: +# MODE: 'get_status' # Either 'mirror', 'get_status', or 'both' +# GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }} +# FORCE_PUSH: "true" +# GITLAB_HOSTNAME: "codebase.helmholtz.cloud" +# GITLAB_PROJECT_ID: "3525" +# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -# - name: Get and prepare artifacts -# run: | -# pipeline_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/repository/commits/${{ github.head_ref || github.ref_name }}" | jq '.last_pipeline.id') -# job_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/pipelines/$pipeline_id/jobs" | jq '.[] | select( .name == "bundle" ) | select( .status == "success" ) | .id') -# curl --output artifacts.zip "https://gitlab.hzdr.de/api/v4/projects/3525/jobs/$job_id/artifacts" -# rm -rf data -# unzip artifacts.zip -# ls -ratl +# # - name: Get and prepare artifacts +# # run: | +# # pipeline_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/repository/commits/${{ github.head_ref || github.ref_name }}" | jq '.last_pipeline.id') +# # job_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/pipelines/$pipeline_id/jobs" | jq '.[] | select( .name == "bundle" ) | select( .status == "success" ) | .id') +# # curl --output artifacts.zip "https://gitlab.hzdr.de/api/v4/projects/3525/jobs/$job_id/artifacts" +# # rm -rf data +# # unzip artifacts.zip +# # ls -ratl post-processing: @@ -156,7 +156,7 @@ jobs: needs: - lint - user_cpu_tests_linux - - wait_for_gitlab +# - wait_for_gitlab defaults: run: @@ -188,7 +188,10 @@ jobs: run: | pip install genbadge[all] genbadge coverage -i coverage.xml -o htmlcov/coverage-badge.svg - + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + # - name: Generate benchmark report # uses: pancetta/github-action-benchmark@v1 # if: ${{ (!contains(github.event.head_commit.message, '[CI-no-benchmarks]')) && (github.event_name == 'push') }} diff --git a/README.rst b/README.rst index 41ce82469e..e420177c73 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,7 @@ |badge-ga| |badge-ossf| +|badge-cc| +|zenodo| Welcome to pySDC! ================= @@ -79,7 +81,9 @@ This project also received funding from the `German Federal Ministry of Educatio The project also received help from the `Helmholtz Platform for Research Software Engineering - Preparatory Study (HiRSE_PS) `_. -.. |badge-ga| image:: https://github.com/Parallel-in-Time/pySDC/actions/workflows/ci_pipeline.yml/badge.svg +.. |badge-ga| image:: https://github.com/Parallel-in-Time/pySDC/actions/workflows/ci_pipeline.yml/badge.svg?branch=master :target: https://github.com/Parallel-in-Time/pySDC/actions/workflows/ci_pipeline.yml .. |badge-ossf| image:: https://bestpractices.coreinfrastructure.org/projects/6909/badge - :target: https://bestpractices.coreinfrastructure.org/projects/6909 \ No newline at end of file + :target: https://bestpractices.coreinfrastructure.org/projects/6909 +.. |badge-cc| image:: https://codecov.io/gh/Parallel-in-Time/pySDC/branch/master/graph/badge.svg?token=hpP18dmtgS + :target: https://codecov.io/gh/Parallel-in-Time/pySDC diff --git a/pySDC/core/Controller.py b/pySDC/core/Controller.py index c5257d2316..e4d213e475 100644 --- a/pySDC/core/Controller.py +++ b/pySDC/core/Controller.py @@ -3,10 +3,10 @@ import sys import numpy as np -from pySDC.core import Hooks as hookclass from pySDC.core.BaseTransfer import base_transfer from pySDC.helpers.pysdc_helper import FrozenClass from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence +from pySDC.implementations.hooks.default_hook import DefaultHooks # short helper class to add params as attributes @@ -41,10 +41,15 @@ def __init__(self, controller_params, description): """ # check if we have a hook on this list. If not, use default class. - controller_params['hook_class'] = controller_params.get('hook_class', hookclass.hooks) - self.__hooks = controller_params['hook_class']() + self.__hooks = [] + hook_classes = [DefaultHooks] + user_hooks = controller_params.get('hook_class', []) + hook_classes += user_hooks if type(user_hooks) == list else [user_hooks] + [self.add_hook(hook) for hook in hook_classes] + controller_params['hook_class'] = controller_params.get('hook_class', hook_classes) - self.hooks.pre_setup(step=None, level_number=None) + for hook in self.hooks: + hook.pre_setup(step=None, level_number=None) self.params = _Pars(controller_params) @@ -101,6 +106,20 @@ def __setup_custom_logger(level=None, log_to_file=None, fname=None): else: pass + def add_hook(self, hook): + """ + Add a hook to the controller which will be called in addition to all other hooks whenever something happens. + The hook is only added if a hook of the same class is not already present. + + Args: + hook (pySDC.Hook): A hook class that is derived from the core hook class + + Returns: + None + """ + if hook not in [type(me) for me in self.hooks]: + self.__hooks += [hook()] + def welcome_message(self): out = ( "Welcome to the one and only, really very astonishing and 87.3% bug free" @@ -308,3 +327,15 @@ def get_convergence_controllers_as_table(self, description): out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}' return out + + def return_stats(self): + """ + Return the merged stats from all hooks + + Returns: + dict: Merged stats from all hooks + """ + stats = {} + for hook in self.hooks: + stats = {**stats, **hook.return_stats()} + return stats diff --git a/pySDC/core/Hooks.py b/pySDC/core/Hooks.py index 5268278529..dfff1045af 100644 --- a/pySDC/core/Hooks.py +++ b/pySDC/core/Hooks.py @@ -1,5 +1,4 @@ import logging -import time from collections import namedtuple @@ -8,23 +7,13 @@ class hooks(object): """ Hook class to contain the functions called during the controller runs (e.g. for calling user-routines) + When deriving a custom hook from this class make sure to always call the parent method using e.g. + `super().post_step(step, level_number)`. Otherwise bugs may arise when using `filer_recomputed` from the stats + helper for post processing. + Attributes: - __t0_setup (float): private variable to get starting time of setup - __t0_run (float): private variable to get starting time of the run - __t0_predict (float): private variable to get starting time of the predictor - __t0_step (float): private variable to get starting time of the step - __t0_iteration (float): private variable to get starting time of the iteration - __t0_sweep (float): private variable to get starting time of the sweep - __t0_comm (list): private variable to get starting time of the communication - __t1_run (float): private variable to get end time of the run - __t1_predict (float): private variable to get end time of the predictor - __t1_step (float): private variable to get end time of the step - __t1_iteration (float): private variable to get end time of the iteration - __t1_sweep (float): private variable to get end time of the sweep - __t1_setup (float): private variable to get end time of setup - __t1_comm (list): private variable to hold timing of the communication (!) - __num_restarts (int): number of restarts of the current step logger: logger instance for output + __num_restarts (int): number of restarts of the current step __stats (dict): dictionary for gathering the statistics of a run __entry (namedtuple): statistics entry containing all information to identify the value """ @@ -33,20 +22,6 @@ def __init__(self): """ Initialization routine """ - self.__t0_setup = None - self.__t0_run = None - self.__t0_predict = None - self.__t0_step = None - self.__t0_iteration = None - self.__t0_sweep = None - self.__t0_comm = [] - self.__t1_run = None - self.__t1_predict = None - self.__t1_step = None - self.__t1_iteration = None - self.__t1_sweep = None - self.__t1_setup = None - self.__t1_comm = [] self.__num_restarts = 0 self.logger = logging.getLogger('hooks') @@ -130,7 +105,6 @@ def pre_setup(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t0_setup = time.perf_counter() def pre_run(self, step, level_number): """ @@ -141,7 +115,6 @@ def pre_run(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t0_run = time.perf_counter() def pre_predict(self, step, level_number): """ @@ -151,7 +124,7 @@ def pre_predict(self, step, level_number): step (pySDC.Step.step): the current step level_number (int): the current level number """ - self.__t0_predict = time.perf_counter() + self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 def pre_step(self, step, level_number): """ @@ -162,7 +135,6 @@ def pre_step(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t0_step = time.perf_counter() def pre_iteration(self, step, level_number): """ @@ -173,7 +145,6 @@ def pre_iteration(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t0_iteration = time.perf_counter() def pre_sweep(self, step, level_number): """ @@ -184,7 +155,6 @@ def pre_sweep(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t0_sweep = time.perf_counter() def pre_comm(self, step, level_number): """ @@ -195,16 +165,6 @@ def pre_comm(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - if len(self.__t0_comm) >= level_number + 1: - self.__t0_comm[level_number] = time.perf_counter() - else: - while len(self.__t0_comm) < level_number: - self.__t0_comm.append(None) - self.__t0_comm.append(time.perf_counter()) - while len(self.__t1_comm) <= level_number: - self.__t1_comm.append(0.0) - assert len(self.__t0_comm) == level_number + 1 - assert len(self.__t1_comm) == level_number + 1 def post_comm(self, step, level_number, add_to_stats=False): """ @@ -216,22 +176,6 @@ def post_comm(self, step, level_number, add_to_stats=False): add_to_stats (bool): set if result should go to stats object """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - assert len(self.__t1_comm) >= level_number + 1 - self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number] - - if add_to_stats: - L = step.levels[level_number] - - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='timing_comm', - value=self.__t1_comm[level_number], - ) - self.__t1_comm[level_number] = 0.0 def post_sweep(self, step, level_number): """ @@ -242,39 +186,6 @@ def post_sweep(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t1_sweep = time.perf_counter() - - L = step.levels[level_number] - - self.logger.info( - 'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e', - step.status.slot, - L.time, - step.status.stage, - L.level_index, - step.status.iter, - L.status.sweep, - L.status.residual, - ) - - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='residual_post_sweep', - value=L.status.residual, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='timing_sweep', - value=self.__t1_sweep - self.__t0_sweep, - ) def post_iteration(self, step, level_number): """ @@ -286,29 +197,6 @@ def post_iteration(self, step, level_number): """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t1_iteration = time.perf_counter() - - L = step.levels[level_number] - - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=-1, - iter=step.status.iter, - sweep=L.status.sweep, - type='residual_post_iteration', - value=L.status.residual, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='timing_iteration', - value=self.__t1_iteration - self.__t0_iteration, - ) - def post_step(self, step, level_number): """ Default routine called after each step or block @@ -319,44 +207,6 @@ def post_step(self, step, level_number): """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t1_step = time.perf_counter() - - L = step.levels[level_number] - - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='timing_step', - value=self.__t1_step - self.__t0_step, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=-1, - iter=step.status.iter, - sweep=L.status.sweep, - type='niter', - value=step.status.iter, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=-1, - sweep=L.status.sweep, - type='residual_post_step', - value=L.status.residual, - ) - - # record the recomputed quantities at weird positions to make sure there is only one value for each step - for t in [L.time, L.time + L.dt]: - self.add_to_stats( - process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart') - ) - def post_predict(self, step, level_number): """ Default routine called after each predictor @@ -366,19 +216,6 @@ def post_predict(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t1_predict = time.perf_counter() - - L = step.levels[level_number] - - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='timing_predictor', - value=self.__t1_predict - self.__t0_predict, - ) def post_run(self, step, level_number): """ @@ -389,19 +226,6 @@ def post_run(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t1_run = time.perf_counter() - - L = step.levels[level_number] - - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=step.status.iter, - sweep=L.status.sweep, - type='timing_run', - value=self.__t1_run - self.__t0_run, - ) def post_setup(self, step, level_number): """ @@ -412,14 +236,3 @@ def post_setup(self, step, level_number): level_number (int): the current level number """ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0 - self.__t1_setup = time.perf_counter() - - self.add_to_stats( - process=-1, - time=-1, - level=-1, - iter=-1, - sweep=-1, - type='timing_setup', - value=self.__t1_setup - self.__t0_setup, - ) diff --git a/pySDC/implementations/controller_classes/controller_MPI.py b/pySDC/implementations/controller_classes/controller_MPI.py index 9383e87608..7dc6bcd781 100644 --- a/pySDC/implementations/controller_classes/controller_MPI.py +++ b/pySDC/implementations/controller_classes/controller_MPI.py @@ -83,7 +83,8 @@ def run(self, u0, t0, Tend): """ # reset stats to prevent double entries from old runs - self.hooks.reset_stats() + for hook in self.hooks: + hook.reset_stats() # find active processes and put into new communicator rank = self.comm.Get_rank() @@ -111,10 +112,12 @@ def run(self, u0, t0, Tend): uend = u0 # call post-setup hook - self.hooks.post_setup(step=None, level_number=None) + for hook in self.hooks: + hook.post_setup(step=None, level_number=None) # call pre-run hook - self.hooks.pre_run(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_run(step=self.S, level_number=0) comm_active.Barrier() @@ -162,11 +165,12 @@ def run(self, u0, t0, Tend): self.restart_block(num_procs, time, uend, comm=comm_active) # call post-run hook - self.hooks.post_run(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_run(step=self.S, level_number=0) comm_active.Free() - return uend, self.hooks.return_stats() + return uend, self.return_stats() def restart_block(self, size, time, u0, comm): """ @@ -243,7 +247,8 @@ def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False): level: the level number add_to_stats: a flag to end recording data in the hooks (defaults to False) """ - self.hooks.pre_comm(step=self.S, level_number=level) + for hook in self.hooks: + hook.pre_comm(step=self.S, level_number=level) if not blocking: self.wait_with_interrupt(request=self.req_send[level]) @@ -272,7 +277,8 @@ def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False): if self.S.status.force_done: return None - self.hooks.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats) + for hook in self.hooks: + hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats) def recv_full(self, comm, level=None, add_to_stats=False): """ @@ -284,7 +290,8 @@ def recv_full(self, comm, level=None, add_to_stats=False): add_to_stats: a flag to end recording data in the hooks (defaults to False) """ - self.hooks.pre_comm(step=self.S, level_number=level) + for hook in self.hooks: + hook.pre_comm(step=self.S, level_number=level) if not self.S.status.first and not self.S.status.prev_done: self.logger.debug( 'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s' @@ -299,7 +306,8 @@ def recv_full(self, comm, level=None, add_to_stats=False): ) self.recv(target=self.S.levels[level], source=self.S.prev, tag=level * 100 + self.S.status.iter, comm=comm) - self.hooks.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats) + for hook in self.hooks: + hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats) def wait_with_interrupt(self, request): """ @@ -334,7 +342,8 @@ def check_iteration_estimate(self, comm): diff_new = max(diff_new, abs(L.uold[m] - L.u[m])) # Send forward diff - self.hooks.pre_comm(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_comm(step=self.S, level_number=0) self.wait_with_interrupt(request=self.req_diff) if self.S.status.force_done: @@ -360,7 +369,8 @@ def check_iteration_estimate(self, comm): tmp = np.array(diff_new, dtype=float) self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999) - self.hooks.post_comm(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_comm(step=self.S, level_number=0) # Store values from first iteration if self.S.status.iter == 1: @@ -382,14 +392,18 @@ def check_iteration_estimate(self, comm): if np.ceil(Kest_glob) <= self.S.status.iter: if self.S.status.last: self.logger.debug(f'{self.S.status.slot} is done, broadcasting..') - self.hooks.pre_comm(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_comm(step=self.S, level_number=0) comm.Ibcast((np.array([1]), MPI.INT), root=self.S.status.slot).Wait() - self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True) + for hook in self.hooks: + hook.post_comm(step=self.S, level_number=0, add_to_stats=True) self.logger.debug(f'{self.S.status.slot} is done, broadcasting done') self.S.status.done = True else: - self.hooks.pre_comm(step=self.S, level_number=0) - self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True) + for hook in self.hooks: + hook.pre_comm(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_comm(step=self.S, level_number=0, add_to_stats=True) def pfasst(self, comm, num_procs): """ @@ -420,7 +434,8 @@ def pfasst(self, comm, num_procs): self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..') self.S.levels[0].u[1:] = self.S.levels[0].uold[1:] - self.hooks.post_iteration(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_iteration(step=self.S, level_number=0) for req in self.req_send: if req is not None and req != MPI.REQUEST_NULL: @@ -431,7 +446,8 @@ def pfasst(self, comm, num_procs): self.req_diff.Cancel() self.S.status.stage = 'DONE' - self.hooks.post_step(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_step(step=self.S, level_number=0) else: # Start cycling, if not interrupted @@ -453,7 +469,8 @@ def spread(self, comm, num_procs): """ # first stage: spread values - self.hooks.pre_step(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_step(step=self.S, level_number=0) # call predictor from sweeper self.S.levels[0].sweep.predict() @@ -476,7 +493,8 @@ def predict(self, comm, num_procs): Predictor phase """ - self.hooks.pre_predict(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_predict(step=self.S, level_number=0) if self.params.predict_type is None: pass @@ -568,7 +586,8 @@ def predict(self, comm, num_procs): else: raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type) - self.hooks.post_predict(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_predict(step=self.S, level_number=0) # update stage self.S.status.stage = 'IT_CHECK' @@ -598,7 +617,8 @@ def it_check(self, comm, num_procs): return None if self.S.status.iter > 0: - self.hooks.post_iteration(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_iteration(step=self.S, level_number=0) # decide if the step is done, needs to be restarted and other things convergence related for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: @@ -611,7 +631,8 @@ def it_check(self, comm, num_procs): # increment iteration count here (and only here) self.S.status.iter += 1 - self.hooks.pre_iteration(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_iteration(step=self.S, level_number=0) for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: C.pre_iteration_processing(self, self.S, comm=comm) @@ -648,7 +669,8 @@ def it_check(self, comm, num_procs): if self.req_diff is not None: self.req_diff.Cancel() - self.hooks.post_step(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_step(step=self.S, level_number=0) self.S.status.stage = 'DONE' def it_fine(self, comm, num_procs): @@ -675,10 +697,12 @@ def it_fine(self, comm, num_procs): if self.S.status.force_done: return None - self.hooks.pre_sweep(step=self.S, level_number=0) + for hook in self.hooks: + hook.pre_sweep(step=self.S, level_number=0) self.S.levels[0].sweep.update_nodes() self.S.levels[0].sweep.compute_residual() - self.hooks.post_sweep(step=self.S, level_number=0) + for hook in self.hooks: + hook.post_sweep(step=self.S, level_number=0) # update stage self.S.status.stage = 'IT_CHECK' @@ -705,10 +729,12 @@ def it_down(self, comm, num_procs): if self.S.status.force_done: return None - self.hooks.pre_sweep(step=self.S, level_number=l) + for hook in self.hooks: + hook.pre_sweep(step=self.S, level_number=l) self.S.levels[l].sweep.update_nodes() self.S.levels[l].sweep.compute_residual() - self.hooks.post_sweep(step=self.S, level_number=l) + for hook in self.hooks: + hook.post_sweep(step=self.S, level_number=l) # transfer further down the hierarchy self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1]) @@ -727,14 +753,16 @@ def it_coarse(self, comm, num_procs): return None # do the sweep - self.hooks.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1) + for hook in self.hooks: + hook.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1) assert self.S.levels[-1].params.nsweeps == 1, ( 'ERROR: this controller can only work with one sweep on the coarse level, got %s' % self.S.levels[-1].params.nsweeps ) self.S.levels[-1].sweep.update_nodes() self.S.levels[-1].sweep.compute_residual() - self.hooks.post_sweep(step=self.S, level_number=len(self.S.levels) - 1) + for hook in self.hooks: + hook.post_sweep(step=self.S, level_number=len(self.S.levels) - 1) self.S.levels[-1].sweep.compute_end_point() # send to next step @@ -774,10 +802,12 @@ def it_up(self, comm, num_procs): if self.S.status.force_done: return None - self.hooks.pre_sweep(step=self.S, level_number=l - 1) + for hook in self.hooks: + hook.pre_sweep(step=self.S, level_number=l - 1) self.S.levels[l - 1].sweep.update_nodes() self.S.levels[l - 1].sweep.compute_residual() - self.hooks.post_sweep(step=self.S, level_number=l - 1) + for hook in self.hooks: + hook.post_sweep(step=self.S, level_number=l - 1) # update stage self.S.status.stage = 'IT_FINE' diff --git a/pySDC/implementations/controller_classes/controller_nonMPI.py b/pySDC/implementations/controller_classes/controller_nonMPI.py index 2144717298..4cbd86260e 100644 --- a/pySDC/implementations/controller_classes/controller_nonMPI.py +++ b/pySDC/implementations/controller_classes/controller_nonMPI.py @@ -100,7 +100,8 @@ def run(self, u0, t0, Tend): # some initializations and reset of statistics uend = None num_procs = len(self.MS) - self.hooks.reset_stats() + for hook in self.hooks: + hook.reset_stats() # initial ordering of the steps: 0,1,...,Np-1 slots = list(range(num_procs)) @@ -120,11 +121,13 @@ def run(self, u0, t0, Tend): # initialize block of steps with u0 self.restart_block(active_slots, time, u0) - self.hooks.post_setup(step=None, level_number=None) + for hook in self.hooks: + hook.post_setup(step=None, level_number=None) # call pre-run hook for S in self.MS: - self.hooks.pre_run(step=S, level_number=0) + for hook in self.hooks: + hook.pre_run(step=S, level_number=0) # main loop: as long as at least one step is still active (time < Tend), do something while any(active): @@ -168,9 +171,10 @@ def run(self, u0, t0, Tend): # call post-run hook for S in self.MS: - self.hooks.post_run(step=S, level_number=0) + for hook in self.hooks: + hook.post_run(step=S, level_number=0) - return uend, self.hooks.return_stats() + return uend, self.return_stats() def restart_block(self, active_slots, time, u0): """ @@ -241,13 +245,16 @@ def send(source, tag): source.sweep.compute_end_point() source.tag = cp.deepcopy(tag) - self.hooks.pre_comm(step=S, level_number=level) + for hook in self.hooks: + hook.pre_comm(step=S, level_number=level) if not S.status.last: self.logger.debug( 'Process %2i provides data on level %2i with tag %s' % (S.status.slot, level, S.status.iter) ) send(S.levels[level], tag=(level, S.status.iter, S.status.slot)) - self.hooks.post_comm(step=S, level_number=level, add_to_stats=add_to_stats) + + for hook in self.hooks: + hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats) def recv_full(self, S, level=None, add_to_stats=False): """ @@ -276,14 +283,16 @@ def recv(target, source, tag=None): # re-evaluate f on left interval boundary target.f[0] = target.prob.eval_f(target.u[0], target.time) - self.hooks.pre_comm(step=S, level_number=level) + for hook in self.hooks: + hook.pre_comm(step=S, level_number=level) if not S.status.prev_done and not S.status.first: self.logger.debug( 'Process %2i receives from %2i on level %2i with tag %s' % (S.status.slot, S.prev.status.slot, level, S.status.iter) ) recv(S.levels[level], S.prev.levels[level], tag=(level, S.status.iter, S.prev.status.slot)) - self.hooks.post_comm(step=S, level_number=level, add_to_stats=add_to_stats) + for hook in self.hooks: + hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats) def pfasst(self, local_MS_active): """ @@ -333,7 +342,8 @@ def spread(self, local_MS_running): for S in local_MS_running: # first stage: spread values - self.hooks.pre_step(step=S, level_number=0) + for hook in self.hooks: + hook.pre_step(step=S, level_number=0) # call predictor from sweeper S.levels[0].sweep.predict() @@ -356,7 +366,8 @@ def predict(self, local_MS_running): """ for S in local_MS_running: - self.hooks.pre_predict(step=S, level_number=0) + for hook in self.hooks: + hook.pre_predict(step=S, level_number=0) if self.params.predict_type is None: pass @@ -464,7 +475,8 @@ def predict(self, local_MS_running): raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type) for S in local_MS_running: - self.hooks.post_predict(step=S, level_number=0) + for hook in self.hooks: + hook.post_predict(step=S, level_number=0) for S in local_MS_running: # update stage @@ -490,7 +502,8 @@ def it_check(self, local_MS_running): for S in local_MS_running: if S.status.iter > 0: - self.hooks.post_iteration(step=S, level_number=0) + for hook in self.hooks: + hook.post_iteration(step=S, level_number=0) # decide if the step is done, needs to be restarted and other things convergence related for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: @@ -499,20 +512,25 @@ def it_check(self, local_MS_running): for S in local_MS_running: if not S.status.first: - self.hooks.pre_comm(step=S, level_number=0) + for hook in self.hooks: + hook.pre_comm(step=S, level_number=0) S.status.prev_done = S.prev.status.done # "communicate" - self.hooks.post_comm(step=S, level_number=0, add_to_stats=True) + for hook in self.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) S.status.done = S.status.done and S.status.prev_done if self.params.all_to_done: - self.hooks.pre_comm(step=S, level_number=0) + for hook in self.hooks: + hook.pre_comm(step=S, level_number=0) S.status.done = all([T.status.done for T in local_MS_running]) - self.hooks.post_comm(step=S, level_number=0, add_to_stats=True) + for hook in self.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) if not S.status.done: # increment iteration count here (and only here) S.status.iter += 1 - self.hooks.pre_iteration(step=S, level_number=0) + for hook in self.hooks: + hook.pre_iteration(step=S, level_number=0) for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: C.pre_iteration_processing(self, S) @@ -525,7 +543,8 @@ def it_check(self, local_MS_running): S.status.stage = 'IT_COARSE' # serial MSSDC (Gauss-like) else: S.levels[0].sweep.compute_end_point() - self.hooks.post_step(step=S, level_number=0) + for hook in self.hooks: + hook.post_step(step=S, level_number=0) S.status.stage = 'DONE' for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]: @@ -555,10 +574,12 @@ def it_fine(self, local_MS_running): for S in local_MS_running: # standard sweep workflow: update nodes, compute residual, log progress - self.hooks.pre_sweep(step=S, level_number=0) + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=0) S.levels[0].sweep.update_nodes() S.levels[0].sweep.compute_residual() - self.hooks.post_sweep(step=S, level_number=0) + for hook in self.hooks: + hook.post_sweep(step=S, level_number=0) for S in local_MS_running: # update stage @@ -589,10 +610,12 @@ def it_down(self, local_MS_running): self.recv_full(S, level=l) for S in local_MS_running: - self.hooks.pre_sweep(step=S, level_number=l) + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=l) S.levels[l].sweep.update_nodes() S.levels[l].sweep.compute_residual() - self.hooks.post_sweep(step=S, level_number=l) + for hook in self.hooks: + hook.post_sweep(step=S, level_number=l) for S in local_MS_running: # transfer further down the hierarchy @@ -616,10 +639,12 @@ def it_coarse(self, local_MS_running): self.recv_full(S, level=len(S.levels) - 1) # do the sweep - self.hooks.pre_sweep(step=S, level_number=len(S.levels) - 1) + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=len(S.levels) - 1) S.levels[-1].sweep.update_nodes() S.levels[-1].sweep.compute_residual() - self.hooks.post_sweep(step=S, level_number=len(S.levels) - 1) + for hook in self.hooks: + hook.post_sweep(step=S, level_number=len(S.levels) - 1) # send to succ step self.send_full(S, level=len(S.levels) - 1, add_to_stats=True) @@ -657,10 +682,12 @@ def it_up(self, local_MS_running): self.recv_full(S, level=l - 1, add_to_stats=(k == self.nsweeps[l - 1] - 1)) for S in local_MS_running: - self.hooks.pre_sweep(step=S, level_number=l - 1) + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=l - 1) S.levels[l - 1].sweep.update_nodes() S.levels[l - 1].sweep.compute_residual() - self.hooks.post_sweep(step=S, level_number=l - 1) + for hook in self.hooks: + hook.post_sweep(step=S, level_number=l - 1) for S in local_MS_running: # update stage diff --git a/pySDC/implementations/convergence_controller_classes/adaptivity.py b/pySDC/implementations/convergence_controller_classes/adaptivity.py index 06fc27f073..c39bf5fae2 100644 --- a/pySDC/implementations/convergence_controller_classes/adaptivity.py +++ b/pySDC/implementations/convergence_controller_classes/adaptivity.py @@ -7,6 +7,7 @@ BasicRestartingNonMPI, ) from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI +from pySDC.implementations.hooks.log_step_size import LogStepSize class AdaptivityBase(ConvergenceController): @@ -35,6 +36,7 @@ def setup(self, controller, params, description, **kwargs): "control_order": -50, "beta": 0.9, } + controller.add_hook(LogStepSize) return {**defaults, **params} def dependencies(self, controller, description, **kwargs): diff --git a/pySDC/implementations/convergence_controller_classes/check_convergence.py b/pySDC/implementations/convergence_controller_classes/check_convergence.py index f969779d07..e6a1907a9a 100644 --- a/pySDC/implementations/convergence_controller_classes/check_convergence.py +++ b/pySDC/implementations/convergence_controller_classes/check_convergence.py @@ -87,13 +87,16 @@ def communicate_convergence(self, controller, S, comm): if controller.params.all_to_done: from mpi4py.MPI import LAND - controller.hooks.pre_comm(step=S, level_number=0) + for hook in controller.hooks: + hook.pre_comm(step=S, level_number=0) S.status.done = comm.allreduce(sendobj=S.status.done, op=LAND) - controller.hooks.post_comm(step=S, level_number=0, add_to_stats=True) + for hook in controller.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) else: - controller.hooks.pre_comm(step=S, level_number=0) + for hook in controller.hooks: + hook.pre_comm(step=S, level_number=0) # check if an open request of the status send is pending controller.wait_with_interrupt(request=controller.req_status) @@ -109,4 +112,5 @@ def communicate_convergence(self, controller, S, comm): if not S.status.last: self.send(comm, dest=S.status.slot + 1, data=S.status.done) - controller.hooks.post_comm(step=S, level_number=0, add_to_stats=True) + for hook in controller.hooks: + hook.post_comm(step=S, level_number=0, add_to_stats=True) diff --git a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py index 1cf5e3f41b..f33f272e29 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py @@ -2,6 +2,7 @@ from pySDC.core.ConvergenceController import ConvergenceController, Pars from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld +from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate from pySDC.implementations.sweeper_classes.Runge_Kutta import RungeKutta @@ -16,7 +17,7 @@ class EstimateEmbeddedError(ConvergenceController): def __init__(self, controller, params, description, **kwargs): """ - Initalization routine. Add the buffers for communication. + Initialisation routine. Add the buffers for communication. Args: controller (pySDC.Controller): The controller @@ -25,6 +26,7 @@ def __init__(self, controller, params, description, **kwargs): """ super(EstimateEmbeddedError, self).__init__(controller, params, description, **kwargs) self.buffers = Pars({'e_em_last': 0.0}) + controller.add_hook(LogEmbeddedErrorEstimate) @classmethod def get_implementation(cls, flavor): @@ -42,7 +44,7 @@ def get_implementation(cls, flavor): elif flavor == 'nonMPI': return EstimateEmbeddedErrorNonMPI else: - raise NotImplementedError(f'Flavor {flavor} of EmstimateEmbeddedError is not implemented!') + raise NotImplementedError(f'Flavor {flavor} of EstimateEmbeddedError is not implemented!') def setup(self, controller, params, description, **kwargs): """ @@ -123,7 +125,7 @@ def reset_status_variables(self, controller, **kwargs): class EstimateEmbeddedErrorNonMPI(EstimateEmbeddedError): def reset_buffers_nonMPI(self, controller, **kwargs): """ - Reset buffers for immitated communication. + Reset buffers for imitated communication. Args: controller (pySDC.controller): The controller diff --git a/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py b/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py index 698df89627..887761d982 100644 --- a/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py +++ b/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py @@ -4,6 +4,7 @@ from pySDC.core.ConvergenceController import ConvergenceController, Status from pySDC.core.Errors import DataError from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh +from pySDC.implementations.hooks.log_extrapolated_error_estimate import LogExtrapolationErrorEstimate class EstimateExtrapolationErrorBase(ConvergenceController): @@ -27,6 +28,7 @@ def __init__(self, controller, params, description, **kwargs): self.prev = Status(["t", "u", "f", "dt"]) # store solutions etc. of previous steps here self.coeff = Status(["u", "f", "prefactor"]) # store coefficients for extrapolation here super(EstimateExtrapolationErrorBase, self).__init__(controller, params, description) + controller.add_hook(LogExtrapolationErrorEstimate) def setup(self, controller, params, description, **kwargs): """ @@ -252,7 +254,7 @@ class EstimateExtrapolationErrorNonMPI(EstimateExtrapolationErrorBase): def setup(self, controller, params, description, **kwargs): """ - Add a no parameter 'no_storage' which decides whether the standart or the no-memory-overhead version is run, + Add a no parameter 'no_storage' which decides whether the standard or the no-memory-overhead version is run, where only values are used for extrapolation which are in memory of other processes Args: diff --git a/pySDC/implementations/hooks/default_hook.py b/pySDC/implementations/hooks/default_hook.py new file mode 100644 index 0000000000..dcdd236421 --- /dev/null +++ b/pySDC/implementations/hooks/default_hook.py @@ -0,0 +1,343 @@ +import time +from pySDC.core.Hooks import hooks + + +class DefaultHooks(hooks): + """ + Hook class to contain the functions called during the controller runs (e.g. for calling user-routines) + + Attributes: + __t0_setup (float): private variable to get starting time of setup + __t0_run (float): private variable to get starting time of the run + __t0_predict (float): private variable to get starting time of the predictor + __t0_step (float): private variable to get starting time of the step + __t0_iteration (float): private variable to get starting time of the iteration + __t0_sweep (float): private variable to get starting time of the sweep + __t0_comm (list): private variable to get starting time of the communication + __t1_run (float): private variable to get end time of the run + __t1_predict (float): private variable to get end time of the predictor + __t1_step (float): private variable to get end time of the step + __t1_iteration (float): private variable to get end time of the iteration + __t1_sweep (float): private variable to get end time of the sweep + __t1_setup (float): private variable to get end time of setup + __t1_comm (list): private variable to hold timing of the communication (!) + """ + + def __init__(self): + super().__init__() + self.__t0_setup = None + self.__t0_run = None + self.__t0_predict = None + self.__t0_step = None + self.__t0_iteration = None + self.__t0_sweep = None + self.__t0_comm = [] + self.__t1_run = None + self.__t1_predict = None + self.__t1_step = None + self.__t1_iteration = None + self.__t1_sweep = None + self.__t1_setup = None + self.__t1_comm = [] + + def pre_setup(self, step, level_number): + """ + Default routine called before setup starts + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_setup(step, level_number) + self.__t0_setup = time.perf_counter() + + def pre_run(self, step, level_number): + """ + Default routine called before time-loop starts + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_run(step, level_number) + self.__t0_run = time.perf_counter() + + def pre_predict(self, step, level_number): + """ + Default routine called before predictor starts + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_predict(step, level_number) + self.__t0_predict = time.perf_counter() + + def pre_step(self, step, level_number): + """ + Hook called before each step + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_step(step, level_number) + self.__t0_step = time.perf_counter() + + def pre_iteration(self, step, level_number): + """ + Default routine called before iteration starts + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_iteration(step, level_number) + self.__t0_iteration = time.perf_counter() + + def pre_sweep(self, step, level_number): + """ + Default routine called before sweep starts + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_sweep(step, level_number) + self.__t0_sweep = time.perf_counter() + + def pre_comm(self, step, level_number): + """ + Default routine called before communication starts + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().pre_comm(step, level_number) + if len(self.__t0_comm) >= level_number + 1: + self.__t0_comm[level_number] = time.perf_counter() + else: + while len(self.__t0_comm) < level_number: + self.__t0_comm.append(None) + self.__t0_comm.append(time.perf_counter()) + while len(self.__t1_comm) <= level_number: + self.__t1_comm.append(0.0) + assert len(self.__t0_comm) == level_number + 1 + assert len(self.__t1_comm) == level_number + 1 + + def post_comm(self, step, level_number, add_to_stats=False): + """ + Default routine called after each communication + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + add_to_stats (bool): set if result should go to stats object + """ + super().post_comm(step, level_number) + assert len(self.__t1_comm) >= level_number + 1 + self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number] + + if add_to_stats: + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='timing_comm', + value=self.__t1_comm[level_number], + ) + self.__t1_comm[level_number] = 0.0 + + def post_sweep(self, step, level_number): + """ + Default routine called after each sweep + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().post_sweep(step, level_number) + self.__t1_sweep = time.perf_counter() + + L = step.levels[level_number] + + self.logger.info( + 'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e', + step.status.slot, + L.time, + step.status.stage, + L.level_index, + step.status.iter, + L.status.sweep, + L.status.residual, + ) + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='residual_post_sweep', + value=L.status.residual, + ) + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='timing_sweep', + value=self.__t1_sweep - self.__t0_sweep, + ) + + def post_iteration(self, step, level_number): + """ + Default routine called after each iteration + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().post_iteration(step, level_number) + self.__t1_iteration = time.perf_counter() + + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=-1, + iter=step.status.iter, + sweep=L.status.sweep, + type='residual_post_iteration', + value=L.status.residual, + ) + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='timing_iteration', + value=self.__t1_iteration - self.__t0_iteration, + ) + + def post_step(self, step, level_number): + """ + Default routine called after each step or block + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().post_step(step, level_number) + self.__t1_step = time.perf_counter() + + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='timing_step', + value=self.__t1_step - self.__t0_step, + ) + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=-1, + iter=step.status.iter, + sweep=L.status.sweep, + type='niter', + value=step.status.iter, + ) + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=-1, + sweep=L.status.sweep, + type='residual_post_step', + value=L.status.residual, + ) + + # record the recomputed quantities at weird positions to make sure there is only one value for each step + for t in [L.time, L.time + L.dt]: + self.add_to_stats( + process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart') + ) + + def post_predict(self, step, level_number): + """ + Default routine called after each predictor + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().post_predict(step, level_number) + self.__t1_predict = time.perf_counter() + + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='timing_predictor', + value=self.__t1_predict - self.__t0_predict, + ) + + def post_run(self, step, level_number): + """ + Default routine called after each run + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().post_run(step, level_number) + self.__t1_run = time.perf_counter() + + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='timing_run', + value=self.__t1_run - self.__t0_run, + ) + + def post_setup(self, step, level_number): + """ + Default routine called after setup + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + """ + super().post_setup(step, level_number) + self.__t1_setup = time.perf_counter() + + self.add_to_stats( + process=-1, + time=-1, + level=-1, + iter=-1, + sweep=-1, + type='timing_setup', + value=self.__t1_setup - self.__t0_setup, + ) diff --git a/pySDC/implementations/hooks/log_embedded_error_estimate.py b/pySDC/implementations/hooks/log_embedded_error_estimate.py new file mode 100644 index 0000000000..60fa0c6ea7 --- /dev/null +++ b/pySDC/implementations/hooks/log_embedded_error_estimate.py @@ -0,0 +1,32 @@ +from pySDC.core.Hooks import hooks + + +class LogEmbeddedErrorEstimate(hooks): + """ + Store the embedded error estimate at the end of each step as "error_embedded_estimate". + """ + + def post_step(self, step, level_number): + """ + Record embedded error estimate + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + + Returns: + None + """ + super().post_step(step, level_number) + + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='error_embedded_estimate', + value=L.status.get('error_embedded_estimate'), + ) diff --git a/pySDC/implementations/hooks/log_extrapolated_error_estimate.py b/pySDC/implementations/hooks/log_extrapolated_error_estimate.py new file mode 100644 index 0000000000..1530db9e18 --- /dev/null +++ b/pySDC/implementations/hooks/log_extrapolated_error_estimate.py @@ -0,0 +1,33 @@ +from pySDC.core.Hooks import hooks + + +class LogExtrapolationErrorEstimate(hooks): + """ + Store the extrapolated error estimate at the end of each step as "error_extrapolation_estimate". + """ + + def post_step(self, step, level_number): + """ + Record extrapolated error estimate + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + + Returns: + None + """ + super().post_step(step, level_number) + + # some abbreviations + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='error_extrapolation_estimate', + value=L.status.get('error_extrapolation_estimate'), + ) diff --git a/pySDC/implementations/hooks/log_solution.py b/pySDC/implementations/hooks/log_solution.py new file mode 100644 index 0000000000..9d2d72d2e6 --- /dev/null +++ b/pySDC/implementations/hooks/log_solution.py @@ -0,0 +1,33 @@ +from pySDC.core.Hooks import hooks + + +class LogSolution(hooks): + """ + Store the solution at the end of each step as "u". + """ + + def post_step(self, step, level_number): + """ + Record solution at the end of the step + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + + Returns: + None + """ + super().post_step(step, level_number) + + L = step.levels[level_number] + L.sweep.compute_end_point() + + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='u', + value=L.uend, + ) diff --git a/pySDC/implementations/hooks/log_step_size.py b/pySDC/implementations/hooks/log_step_size.py new file mode 100644 index 0000000000..62dada9ab0 --- /dev/null +++ b/pySDC/implementations/hooks/log_step_size.py @@ -0,0 +1,32 @@ +from pySDC.core.Hooks import hooks + + +class LogStepSize(hooks): + """ + Store the step size at the end of each step as "dt". + """ + + def post_step(self, step, level_number): + """ + Record step size + + Args: + step (pySDC.Step.step): the current step + level_number (int): the current level number + + Returns: + None + """ + super().post_step(step, level_number) + + L = step.levels[level_number] + + self.add_to_stats( + process=step.status.slot, + time=L.time, + level=L.level_index, + iter=step.status.iter, + sweep=L.status.sweep, + type='dt', + value=L.dt, + ) diff --git a/pySDC/implementations/problem_classes/Battery.py b/pySDC/implementations/problem_classes/Battery.py index 673c22233e..56b3713367 100644 --- a/pySDC/implementations/problem_classes/Battery.py +++ b/pySDC/implementations/problem_classes/Battery.py @@ -8,28 +8,32 @@ class battery(ptype): """ Example implementing the battery drain model as in the description in the PinTSimE project + Attributes: A: system matrix, representing the 2 ODEs + t_switch: time point of the switch + nswitches: number of switches """ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): """ Initialization routine + Args: problem_params (dict): custom parameters for the example dtype_u: mesh data type for solution dtype_f: mesh data type for RHS """ - problem_params['nvars'] = 2 - # these parameters will be used later, so assert their existence - essential_keys = ['Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch'] + essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref'] for key in essential_keys: if key not in problem_params: msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) raise ParameterError(msg) + problem_params['nvars'] = problem_params['ncondensators'] + 1 + # invoke super init, passing number of dofs, dtype_u and dtype_f super(battery, self).__init__( init=(problem_params['nvars'], None, np.dtype('float64')), @@ -39,13 +43,17 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): ) self.A = np.zeros((2, 2)) + self.t_switch = None + self.nswitches = 0 def eval_f(self, u, t): """ Routine to evaluate the RHS + Args: u (dtype_u): current values t (float): current time + Returns: dtype_f: the RHS """ @@ -53,17 +61,10 @@ def eval_f(self, u, t): f = self.dtype_f(self.init, val=0.0) f.impl[:] = self.A.dot(u) - if u[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - f.expl[0] = self.params.Vs / self.params.L - - else: - f.expl[0] = 0 + t_switch = np.inf if self.t_switch is None else self.t_switch - else: - f.expl[0] = self.params.Vs / self.params.L + if u[1] <= self.params.V_ref or t >= t_switch: + f.expl[0] = self.params.Vs / self.params.L else: f.expl[0] = 0 @@ -73,27 +74,22 @@ def eval_f(self, u, t): def solve_system(self, rhs, factor, u0, t): """ Simple linear solver for (I-factor*A)u = rhs + Args: rhs (dtype_f): right-hand side for the linear system factor (float): abbrev. for the local stepsize (or any other factor required) u0 (dtype_u): initial guess for the iterative solver t (float): current time (e.g. for time-dependent BCs) + Returns: dtype_u: solution as mesh """ self.A = np.zeros((2, 2)) - if rhs[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + t_switch = np.inf if self.t_switch is None else self.t_switch - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + if rhs[1] <= self.params.V_ref or t >= t_switch: + self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L else: self.A[1, 1] = -1 / (self.params.C * self.params.R) @@ -105,8 +101,10 @@ def solve_system(self, rhs, factor, u0, t): def u_exact(self, t): """ Routine to compute the exact solution at time t + Args: t (float): current time + Returns: dtype_u: exact solution """ @@ -119,53 +117,60 @@ def u_exact(self, t): return me + def get_switching_info(self, u, t): + """ + Provides information about a discrete event for one subinterval. -class battery_implicit(ptype): - """ - Example implementing the battery drain model as in the description in the PinTSimE project - Attributes: - A: system matrix, representing the 2 ODEs - """ + Args: + u (dtype_u): current values + t (float): current time - def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh): + Returns: + switch_detected (bool): Indicates if a switch is found or not + m_guess (np.int): Index of collocation node inside one subinterval of where the discrete event was found + vC_switch (list): Contains function values of switching condition (for interpolation) """ - Initialization routine - Args: - problem_params (dict): custom parameters for the example - dtype_u: mesh data type for solution - dtype_f: mesh data type for RHS + + switch_detected = False + m_guess = -100 + + for m in range(len(u)): + if u[m][1] - self.params.V_ref <= 0: + switch_detected = True + m_guess = m - 1 + break + + vC_switch = [u[m][1] - self.params.V_ref for m in range(1, len(u))] if switch_detected else [] + + return switch_detected, m_guess, vC_switch + + def count_switches(self): + """ + Counts the number of switches. This function is called when a switch is found inside the range of tolerance + (in switch_estimator.py) """ - problem_params['nvars'] = 2 + self.nswitches += 1 - # these parameters will be used later, so assert their existence - essential_keys = [ - 'newton_maxiter', - 'newton_tol', - 'Vs', - 'Rs', - 'C', - 'R', - 'L', - 'alpha', - 'V_ref', - 'set_switch', - 't_switch', - ] + +class battery_implicit(battery): + def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh): + + essential_keys = ['newton_maxiter', 'newton_tol', 'ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref'] for key in essential_keys: if key not in problem_params: msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) raise ParameterError(msg) + problem_params['nvars'] = problem_params['ncondensators'] + 1 + # invoke super init, passing number of dofs, dtype_u and dtype_f super(battery_implicit, self).__init__( - init=(problem_params['nvars'], None, np.dtype('float64')), + problem_params, dtype_u=dtype_u, dtype_f=dtype_f, - params=problem_params, ) - self.A = np.zeros((2, 2)) self.newton_itercount = 0 self.lin_itercount = 0 self.newton_ncalls = 0 @@ -174,9 +179,11 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh): def eval_f(self, u, t): """ Routine to evaluate the RHS + Args: u (dtype_u): current values t (float): current time + Returns: dtype_f: the RHS """ @@ -184,37 +191,29 @@ def eval_f(self, u, t): f = self.dtype_f(self.init, val=0.0) non_f = np.zeros(2) - if u[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + t_switch = np.inf if self.t_switch is None else self.t_switch - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) - non_f[0] = 0 - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + if u[1] <= self.params.V_ref or t >= t_switch: + self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + non_f[0] = self.params.Vs else: self.A[1, 1] = -1 / (self.params.C * self.params.R) non_f[0] = 0 f[:] = self.A.dot(u) + non_f - return f def solve_system(self, rhs, factor, u0, t): """ Simple Newton solver + Args: rhs (dtype_f): right-hand side for the linear system factor (float): abbrev. for the local stepsize (or any other factor required) u0 (dtype_u): initial guess for the iterative solver t (float): current time (e.g. for time-dependent BCs) + Returns: dtype_u: solution as mesh """ @@ -223,20 +222,11 @@ def solve_system(self, rhs, factor, u0, t): non_f = np.zeros(2) self.A = np.zeros((2, 2)) - if rhs[1] <= self.params.V_ref or self.params.set_switch: - # switching need to happen on exact time point - if self.params.set_switch: - if t >= self.params.t_switch: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + t_switch = np.inf if self.t_switch is None else self.t_switch - else: - self.A[1, 1] = -1 / (self.params.C * self.params.R) - non_f[0] = 0 - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - non_f[0] = self.params.Vs / self.params.L + if rhs[1] <= self.params.V_ref or t >= t_switch: + self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L + non_f[0] = self.params.Vs else: self.A[1, 1] = -1 / (self.params.C * self.params.R) @@ -280,11 +270,131 @@ def solve_system(self, rhs, factor, u0, t): return me + +class battery_n_condensators(ptype): + """ + Example implementing the battery drain model with N capacitors, where N is an arbitrary integer greater than 0. + Attributes: + nswitches: number of switches + """ + + def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): + """ + Initialization routine + + Args: + problem_params (dict): custom parameters for the example + dtype_u: mesh data type for solution + dtype_f: mesh data type for RHS + """ + + # these parameters will be used later, so assert their existence + essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref'] + for key in essential_keys: + if key not in problem_params: + msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) + raise ParameterError(msg) + + n = problem_params['ncondensators'] + problem_params['nvars'] = n + 1 + + # invoke super init, passing number of dofs, dtype_u and dtype_f + super(battery_n_condensators, self).__init__( + init=(problem_params['nvars'], None, np.dtype('float64')), + dtype_u=dtype_u, + dtype_f=dtype_f, + params=problem_params, + ) + + v = np.zeros(n + 1) + v[0] = 1 + + self.A = np.zeros((n + 1, n + 1)) + self.switch_A = {k: np.diag(-1 / (self.params.C[k] * self.params.R) * np.roll(v, k + 1)) for k in range(n)} + self.switch_A.update({n: np.diag(-(self.params.Rs + self.params.R) / self.params.L * v)}) + self.switch_f = {k: np.zeros(n + 1) for k in range(n)} + self.switch_f.update({n: self.params.Vs / self.params.L * v}) + self.t_switch = None + self.nswitches = 0 + + def eval_f(self, u, t): + """ + Routine to evaluate the RHS. No Switch Estimator is used: For N = 3 there are N + 1 = 4 different states of the battery: + 1. u[1] > V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2] -> C1 supplies energy + 2. u[1] <= V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2] -> C2 supplies energy + 3. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] > V_ref[2] -> C3 supplies energy + 4. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] <= V_ref[2] -> Vs supplies energy + max_index is initialized to -1. List "switch" contains a True if u[k] <= V_ref[k-1] is satisfied. + - Is no True there (i.e. max_index = -1), we are in the first case. + - max_index = k >= 0 means we are in the (k+1)-th case. + So, the actual RHS has key max_index-1 in the dictionary self.switch_f. + In case of using the Switch Estimator, we count the number of switches which illustrates in which case of voltage source we are. + + Args: + u (dtype_u): current values + t (float): current time + + Returns: + dtype_f: the RHS + """ + + f = self.dtype_f(self.init, val=0.0) + f.impl[:] = self.A.dot(u) + + if self.t_switch is not None: + f.expl[:] = self.switch_f[self.nswitches] + + else: + # proof all switching conditions and find largest index where it drops below V_ref + switch = [True if u[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(u))] + max_index = max([k if switch[k] == True else -1 for k in range(len(switch))]) + + if max_index == -1: + f.expl[:] = self.switch_f[0] + + else: + f.expl[:] = self.switch_f[max_index + 1] + + return f + + def solve_system(self, rhs, factor, u0, t): + """ + Simple linear solver for (I-factor*A)u = rhs + + Args: + rhs (dtype_f): right-hand side for the linear system + factor (float): abbrev. for the local stepsize (or any other factor required) + u0 (dtype_u): initial guess for the iterative solver + t (float): current time (e.g. for time-dependent BCs) + + Returns: + dtype_u: solution as mesh + """ + + if self.t_switch is not None: + self.A = self.switch_A[self.nswitches] + + else: + # proof all switching conditions and find largest index where it drops below V_ref + switch = [True if rhs[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(rhs))] + max_index = max([k if switch[k] == True else -1 for k in range(len(switch))]) + if max_index == -1: + self.A = self.switch_A[0] + + else: + self.A = self.switch_A[max_index + 1] + + me = self.dtype_u(self.init) + me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs) + return me + def u_exact(self, t): """ Routine to compute the exact solution at time t + Args: t (float): current time + Returns: dtype_u: exact solution """ @@ -293,6 +403,49 @@ def u_exact(self, t): me = self.dtype_u(self.init) me[0] = 0.0 # cL - me[1] = self.params.alpha * self.params.V_ref # vC - + me[1:] = self.params.alpha * self.params.V_ref # vC's return me + + def get_switching_info(self, u, t): + """ + Provides information about a discrete event for one subinterval. + + Args: + u (dtype_u): current values + t (float): current time + + Returns: + switch_detected (bool): Indicates if a switch is found or not + m_guess (np.int): Index of collocation node inside one subinterval of where the discrete event was found + vC_switch (list): Contains function values of switching condition (for interpolation) + """ + + switch_detected = False + m_guess = -100 + break_flag = False + + for m in range(len(u)): + for k in range(1, self.params.nvars): + if u[m][k] - self.params.V_ref[k - 1] <= 0: + switch_detected = True + m_guess = m - 1 + k_detected = k + break_flag = True + break + + if break_flag: + break + + vC_switch = ( + [u[m][k_detected] - self.params.V_ref[k_detected - 1] for m in range(1, len(u))] if switch_detected else [] + ) + + return switch_detected, m_guess, vC_switch + + def count_switches(self): + """ + Counts the number of switches. This function is called when a switch is found inside the range of tolerance + (in switch_estimator.py) + """ + + self.nswitches += 1 diff --git a/pySDC/implementations/problem_classes/Battery_2Condensators.py b/pySDC/implementations/problem_classes/Battery_2Condensators.py deleted file mode 100644 index 1c5eb55030..0000000000 --- a/pySDC/implementations/problem_classes/Battery_2Condensators.py +++ /dev/null @@ -1,168 +0,0 @@ -import numpy as np - -from pySDC.core.Errors import ParameterError -from pySDC.core.Problem import ptype -from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh - - -class battery_2condensators(ptype): - """ - Example implementing the battery drain model using two capacitors as in the description in the PinTSimE - project - Attributes: - A: system matrix, representing the 3 ODEs - """ - - def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh): - """ - Initialization routine - Args: - problem_params (dict): custom parameters for the example - dtype_u: mesh data type for solution - dtype_f: mesh data type for RHS - """ - - problem_params['nvars'] = 3 - - # these parameters will be used later, so assert their existence - essential_keys = ['Vs', 'Rs', 'C1', 'C2', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch'] - - for key in essential_keys: - if key not in problem_params: - msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys())) - raise ParameterError(msg) - - # invoke super init, passing number of dofs, dtype_u and dtype_f - super(battery_2condensators, self).__init__( - init=(problem_params['nvars'], None, np.dtype('float64')), - dtype_u=dtype_u, - dtype_f=dtype_f, - params=problem_params, - ) - - self.A = np.zeros((3, 3)) - - def eval_f(self, u, t): - """ - Routine to evaluate the RHS - Args: - u (dtype_u): current values - t (float): current time - Returns: - dtype_f: the RHS - """ - - f = self.dtype_f(self.init, val=0.0) - f.impl[:] = self.A.dot(u) - - # switch to C2 - if ( - u[1] <= self.params.V_ref[0] - and u[2] > self.params.V_ref[1] - or self.params.set_switch[0] - and not self.params.set_switch[1] - ): - if self.params.set_switch[0]: - if t >= self.params.t_switch[0]: - f.expl[0] = 0 - - else: - f.expl[0] = 0 - - else: - f.expl[0] = 0 - - # switch to Vs - elif u[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]): - # switch to Vs - if self.params.set_switch[1]: - if t >= self.params.t_switch[1]: - f.expl[0] = self.params.Vs / self.params.L - - else: - f.expl[0] = 0 - - else: - f.expl[0] = self.params.Vs / self.params.L - - elif ( - u[1] > self.params.V_ref[0] - and u[2] > self.params.V_ref[1] - or not self.params.set_switch[0] - and not self.params.set_switch[1] - ): - # C1 supplies energy - f.expl[0] = 0 - - return f - - def solve_system(self, rhs, factor, u0, t): - """ - Simple linear solver for (I-factor*A)u = rhs - Args: - rhs (dtype_f): right-hand side for the linear system - factor (float): abbrev. for the local stepsize (or any other factor required) - u0 (dtype_u): initial guess for the iterative solver - t (float): current time (e.g. for time-dependent BCs) - Returns: - dtype_u: solution as mesh - """ - self.A = np.zeros((3, 3)) - - # switch to C2 - if ( - rhs[1] <= self.params.V_ref[0] - and rhs[2] > self.params.V_ref[1] - or self.params.set_switch[0] - and not self.params.set_switch[1] - ): - if self.params.set_switch[0]: - if t >= self.params.t_switch[0]: - self.A[2, 2] = -1 / (self.params.C2 * self.params.R) - - else: - self.A[1, 1] = -1 / (self.params.C1 * self.params.R) - else: - self.A[2, 2] = -1 / (self.params.C2 * self.params.R) - - # switch to Vs - elif rhs[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]): - if self.params.set_switch[1]: - if t >= self.params.t_switch[1]: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - - else: - self.A[2, 2] = -1 / (self.params.C2 * self.params.R) - - else: - self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L - - elif ( - rhs[1] > self.params.V_ref[0] - and rhs[2] > self.params.V_ref[1] - or not self.params.set_switch[0] - and not self.params.set_switch[1] - ): - # C1 supplies energy - self.A[1, 1] = -1 / (self.params.C1 * self.params.R) - - me = self.dtype_u(self.init) - me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs) - return me - - def u_exact(self, t): - """ - Routine to compute the exact solution at time t - Args: - t (float): current time - Returns: - dtype_u: exact solution - """ - assert t == 0, 'ERROR: u_exact only valid for t=0' - - me = self.dtype_u(self.init) - - me[0] = 0.0 # cL - me[1] = self.params.alpha * self.params.V_ref[0] # vC1 - me[2] = self.params.alpha * self.params.V_ref[1] # vC2 - return me diff --git a/pySDC/projects/PinTSimE/battery_2condensators_model.py b/pySDC/projects/PinTSimE/battery_2condensators_model.py index c41cb82712..8739a96d42 100644 --- a/pySDC/projects/PinTSimE/battery_2condensators_model.py +++ b/pySDC/projects/PinTSimE/battery_2condensators_model.py @@ -4,10 +4,10 @@ from pySDC.helpers.stats_helper import get_sorted from pySDC.core.Collocation import CollBase as Collocation -from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators +from pySDC.implementations.problem_classes.Battery import battery_n_condensators from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh +from pySDC.projects.PinTSimE.battery_model import get_recomputed from pySDC.projects.PinTSimE.piline_model import setup_mpl import pySDC.helpers.plot_helper as plt_helper from pySDC.core.Hooks import hooks @@ -52,55 +52,60 @@ def post_step(self, step, level_number): type='voltage C2', value=L.uend[2], ) - self.increment_stats( + self.add_to_stats( process=step.status.slot, time=L.time, level=L.level_index, iter=0, sweep=L.status.sweep, type='restart', - value=1, - initialize=0, + value=int(step.status.get('restart')), ) def main(use_switch_estimator=True): """ A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators + + Args: + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + + Returns: + description (dict): contains all information for a controller run """ # initialize level parameters level_params = dict() - level_params['restol'] = 1e-13 + level_params['restol'] = -1 level_params['dt'] = 1e-2 + assert level_params['dt'] == 1e-2, 'Error! Do not use the time step dt != 1e-2!' + # initialize sweeper parameters sweeper_params = dict() sweeper_params['quad_type'] = 'LOBATTO' sweeper_params['num_nodes'] = 5 - sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' + # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part + sweeper_params['initial_guess'] = 'spread' # initialize problem parameters problem_params = dict() + problem_params['ncondensators'] = 2 problem_params['Vs'] = 5.0 problem_params['Rs'] = 0.5 - problem_params['C1'] = 1.0 - problem_params['C2'] = 1.0 + problem_params['C'] = np.array([1.0, 1.0]) problem_params['R'] = 1.0 problem_params['L'] = 1.0 problem_params['alpha'] = 5.0 problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2] - problem_params['set_switch'] = np.array([False, False], dtype=bool) - problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0]) # initialize step parameters step_params = dict() - step_params['maxiter'] = 20 + step_params['maxiter'] = 4 # initialize controller parameters controller_params = dict() - controller_params['logger_level'] = 20 + controller_params['logger_level'] = 30 controller_params['hook_class'] = log_data # convergence controllers @@ -111,18 +116,17 @@ def main(use_switch_estimator=True): # fill description dictionary for easy step instantiation description = dict() - description['problem_class'] = battery_2condensators # pass problem class + description['problem_class'] = battery_n_condensators # pass problem class description['problem_params'] = problem_params # pass problem parameters description['sweeper_class'] = imex_1st_order # pass sweeper description['sweeper_params'] = sweeper_params # pass sweeper parameters description['level_params'] = level_params # pass level parameters description['step_params'] = step_params - description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class if use_switch_estimator: description['convergence_controllers'] = convergence_controllers - proof_assertions_description(description, problem_params) + proof_assertions_description(description, use_switch_estimator) # set time parameters t0 = 0.0 @@ -144,39 +148,24 @@ def main(use_switch_estimator=True): dill.dump(stats, f) f.close() - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', sortby='time') - - # compute and print statistics - min_iter = 20 - max_iter = 0 - - f = open('battery_2condensators_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - out = ' Mean number of iterations: %4.2f' % np.mean(niters) - f.write(out + '\n') - print(out) - for item in iter_counts: - out = 'Number of iterations for time %4.2f: %1i' % item - f.write(out + '\n') - # print(out) - min_iter = min(min_iter, item[1]) - max_iter = max(max_iter, item[1]) - - restarts = np.array(get_sorted(stats, type='restart', recomputed=False))[:, 1] - print("Restarts for dt: ", level_params['dt'], " -- ", np.sum(restarts)) - - assert np.mean(niters) <= 10, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() + recomputed = False + + check_solution(stats, use_switch_estimator) - plot_voltages(description, use_switch_estimator) + plot_voltages(description, recomputed, use_switch_estimator) - return np.mean(niters) + return description -def plot_voltages(description, use_switch_estimator, cwd='./'): +def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'): """ Routine to plot the numerical solution of the model + + Args: + description(dict): contains all information for a controller run + recomputed (bool): flag if the values after a restart are used or before + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + cwd: current working directory """ f = open(cwd + 'data/battery_2condensators.dat', 'rb') @@ -184,9 +173,9 @@ def plot_voltages(description, use_switch_estimator, cwd='./'): f.close() # convert filtered statistics to list of iterations count, sorted by process - cL = get_sorted(stats, type='current L', sortby='time') - vC1 = get_sorted(stats, type='voltage C1', sortby='time') - vC2 = get_sorted(stats, type='voltage C2', sortby='time') + cL = get_sorted(stats, type='current L', recomputed=recomputed, sortby='time') + vC1 = get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time') + vC2 = get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time') times = [v[0] for v in cL] @@ -197,11 +186,13 @@ def plot_voltages(description, use_switch_estimator, cwd='./'): ax.plot(times, [v[1] for v in vC2], label='$v_{C_2}$') if use_switch_estimator: - t_switch_plot = np.zeros(np.shape(description['problem_params']['t_switch'])[0]) - for i in range(np.shape(description['problem_params']['t_switch'])[0]): - t_switch_plot[i] = description['problem_params']['t_switch'][i] + switches = get_recomputed(stats, type='switch', sortby='time') + if recomputed is not None: + assert len(switches) >= 2, f"Expected at least 2 switches, got {len(switches)}!" + t_switches = [v[1] for v in switches] - ax.axvline(x=t_switch_plot[i], linestyle='--', color='k', label='Switch {}'.format(i + 1)) + for i in range(len(t_switches)): + ax.axvline(x=t_switches[i], linestyle='--', color='k', label='Switch {}'.format(i + 1)) ax.legend(frameon=False, fontsize=12, loc='upper right') @@ -212,30 +203,108 @@ def plot_voltages(description, use_switch_estimator, cwd='./'): plt_helper.plt.close(fig) -def proof_assertions_description(description, problem_params): +def check_solution(stats, use_switch_estimator): """ - Function to proof the assertions (function to get cleaner code) + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. + + Args: + stats (dict): Raw statistics from a controller run + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ + + data = get_data_dict(stats, use_switch_estimator) + + if use_switch_estimator: + msg = 'Error when using the switch estimator for battery_2condensators:' + expected = { + 'cL': 1.2065280755094876, + 'vC1': 1.0094825899806945, + 'vC2': 1.0050052828742688, + 'switch1': 1.6094379124373626, + 'switch2': 3.209437912457051, + 'restarts': 2.0, + 'sum_niters': 1568, + } + + got = { + 'cL': data['cL'][-1], + 'vC1': data['vC1'][-1], + 'vC2': data['vC2'][-1], + 'switch1': data['switch1'], + 'switch2': data['switch2'], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + for key in expected.keys(): + assert np.isclose( + expected[key], got[key], rtol=1e-4 + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + + +def get_data_dict(stats, use_switch_estimator, recomputed=False): """ + Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. + Based on @brownbaerchen's get_data function. + + Args: + stats (dict): Raw statistics from a controller run + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + recomputed (bool): flag if the values after a restart are used or before - assert problem_params['alpha'] > problem_params['V_ref'][0], 'Please set "alpha" greater than "V_ref1"' - assert problem_params['alpha'] > problem_params['V_ref'][1], 'Please set "alpha" greater than "V_ref2"' + Return: + data (dict): contains all information as the statistics dict + """ - assert problem_params['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0' - assert problem_params['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0' + data = dict() + data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1] + data['vC1'] = np.array(get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time'))[:, 1] + data['vC2'] = np.array(get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time'))[:, 1] + data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1] + data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1] + data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) + data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1]) - assert type(problem_params['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)' - assert not problem_params['set_switch'][0], 'First entry of "set_switch" needs to be False' - assert not problem_params['set_switch'][1], 'Second entry of "set_switch" needs to be False' + return data - assert not type(problem_params['t_switch']) == float, '"t_switch" has to be an array with entry zero' - assert problem_params['t_switch'][0] == 0, 'First entry of "t_switch" needs to be zero' - assert problem_params['t_switch'][1] == 0, 'Second entry of "t_switch" needs to be zero' +def proof_assertions_description(description, use_switch_estimator): + """ + Function to proof the assertions (function to get cleaner code) + + Args: + description(dict): contains all information for a controller run + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ + + assert ( + description['problem_params']['alpha'] > description['problem_params']['V_ref'][0] + ), 'Please set "alpha" greater than "V_ref1"' + assert ( + description['problem_params']['alpha'] > description['problem_params']['V_ref'][1] + ), 'Please set "alpha" greater than "V_ref2"' + + if description['problem_params']['ncondensators'] > 1: + assert ( + type(description['problem_params']['V_ref']) == np.ndarray + ), '"V_ref" needs to be an array (of type float)' + assert ( + description['problem_params']['ncondensators'] == np.shape(description['problem_params']['V_ref'])[0] + ), 'Number of reference values needs to be equal to number of condensators' + assert ( + description['problem_params']['ncondensators'] == np.shape(description['problem_params']['C'])[0] + ), 'Number of capacitance values needs to be equal to number of condensators' + + assert description['problem_params']['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0' + assert description['problem_params']['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0' assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error' assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters' assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters' + if use_switch_estimator: + assert description['level_params']['restol'] == -1, "Please set restol to -1 or omit it" + if __name__ == "__main__": main() diff --git a/pySDC/projects/PinTSimE/battery_model.py b/pySDC/projects/PinTSimE/battery_model.py index a127687eb5..a5336bc035 100644 --- a/pySDC/projects/PinTSimE/battery_model.py +++ b/pySDC/projects/PinTSimE/battery_model.py @@ -2,7 +2,7 @@ import dill from pathlib import Path -from pySDC.helpers.stats_helper import get_sorted +from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted from pySDC.core.Collocation import CollBase as Collocation from pySDC.implementations.problem_classes.Battery import battery, battery_implicit from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order @@ -62,11 +62,31 @@ def post_step(self, step, level_number): type='dt', value=L.dt, ) + self.add_to_stats( + process=step.status.slot, + time=L.time + L.dt, + level=L.level_index, + iter=0, + sweep=L.status.sweep, + type='e_embedded', + value=L.status.get('error_embedded_estimate'), + ) def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): """ A simple test program to do SDC/PFASST runs for the battery drain model + + Args: + dt (float): time step for computation + problem (problem_class.__name__): problem class that wants to be simulated + sweeper (sweeper_class.__name__): sweeper class for solving the problem class numerically + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + use_adaptivity (bool): flag if the adaptivity wants to be used or not + + Returns: + stats (dict): Raw statistics from a controller run + description (dict): contains all information for a controller run """ # initialize level parameters @@ -79,12 +99,13 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): sweeper_params['quad_type'] = 'LOBATTO' sweeper_params['num_nodes'] = 5 # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' + sweeper_params['initial_guess'] = 'spread' # initialize problem parameters problem_params = dict() problem_params['newton_maxiter'] = 200 problem_params['newton_tol'] = 1e-08 + problem_params['ncondensators'] = 1 # number of condensators problem_params['Vs'] = 5.0 problem_params['Rs'] = 0.5 problem_params['C'] = 1.0 @@ -92,8 +113,6 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): problem_params['L'] = 1.0 problem_params['alpha'] = 1.2 problem_params['V_ref'] = 1.0 - problem_params['set_switch'] = np.array([False], dtype=bool) - problem_params['t_switch'] = np.zeros(1) # initialize step parameters step_params = dict() @@ -101,7 +120,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): # initialize controller parameters controller_params = dict() - controller_params['logger_level'] = 20 + controller_params['logger_level'] = 30 controller_params['hook_class'] = log_data controller_params['mssdc_jac'] = False @@ -113,7 +132,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): if use_adaptivity: adaptivity_params = dict() - adaptivity_params['e_tol'] = 1e-12 + adaptivity_params['e_tol'] = 1e-7 convergence_controllers.update({Adaptivity: adaptivity_params}) # fill description dictionary for easy step instantiation @@ -128,12 +147,19 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): if use_switch_estimator or use_adaptivity: description['convergence_controllers'] = convergence_controllers - proof_assertions_description(description, problem_params) - # set time parameters t0 = 0.0 Tend = 0.3 + proof_assertions_description(description, use_adaptivity, use_switch_estimator) + + assert dt < Tend, "Time step is too large for the time domain!" + + assert ( + Tend == 0.3 and description['problem_params']['V_ref'] == 1.0 and description['problem_params']['alpha'] == 1.2 + ), "Error! Do not use other parameters for V_ref != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!" + assert description['level_params']['dt'] == 1e-2, "Error! Do not use another time step dt!= 1e-2!" + # instantiate controller controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) @@ -144,35 +170,13 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity): # call main function to get things done... uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time') - - # compute and print statistics - min_iter = 20 - max_iter = 0 - Path("data").mkdir(parents=True, exist_ok=True) fname = 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper.__name__, use_switch_estimator, use_adaptivity) f = open(fname, 'wb') dill.dump(stats, f) f.close() - f = open('data/battery_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - out = ' Mean number of iterations: %4.2f' % np.mean(niters) - f.write(out + '\n') - print(out) - for item in iter_counts: - out = 'Number of iterations for time %4.2f: %1i' % item - f.write(out + '\n') - print(out) - min_iter = min(min_iter, item[1]) - max_iter = max(max_iter, item[1]) - - assert np.mean(niters) <= 5, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - - return description + return stats, description def run(): @@ -184,13 +188,14 @@ def run(): dt = 1e-2 problem_classes = [battery, battery_implicit] sweeper_classes = [imex_1st_order, generic_implicit] + recomputed = False use_switch_estimator = [True] use_adaptivity = [True] for problem, sweeper in zip(problem_classes, sweeper_classes): for use_SE in use_switch_estimator: for use_A in use_adaptivity: - description = main( + stats, description = main( dt=dt, problem=problem, sweeper=sweeper, @@ -198,12 +203,23 @@ def run(): use_adaptivity=use_A, ) - plot_voltages(description, problem.__name__, sweeper.__name__, use_SE, use_A) + check_solution(stats, problem.__name__, use_adaptivity, use_switch_estimator) + + plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, use_A) -def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adaptivity, cwd='./'): +def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'): """ Routine to plot the numerical solution of the model + + Args: + description(dict): contains all information for a controller run + problem (problem_class.__name__): problem class that wants to be simulated + sweeper (sweeper_class.__name__): sweeper class for solving the problem class numerically + recomputed (bool): flag if the values after a restart are used or before + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + use_adaptivity (bool): flag if adaptivity wants to be used or not + cwd: current working directory """ f = open(cwd + 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper, use_switch_estimator, use_adaptivity), 'rb') @@ -223,12 +239,15 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt ax.plot(times, [v[1] for v in vC], label=r'$v_C$') if use_switch_estimator: - val_switch = get_sorted(stats, type='switch1', sortby='time') - t_switch = [v[1] for v in val_switch] + switches = get_recomputed(stats, type='switch', sortby='time') + + assert len(switches) >= 1, 'No switches found!' + t_switch = [v[1] for v in switches] ax.axvline(x=t_switch[-1], linestyle='--', linewidth=0.8, color='r', label='Switch') if use_adaptivity: dt = np.array(get_sorted(stats, type='dt', recomputed=False)) + dt_ax = ax.twinx() dt_ax.plot(dt[:, 0], dt[:, 1], linestyle='-', linewidth=0.8, color='k', label=r'$\Delta t$') dt_ax.set_ylabel(r'$\Delta t$', fontsize=8) @@ -245,23 +264,146 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt plt_helper.plt.close(fig) -def proof_assertions_description(description, problem_params): +def check_solution(stats, problem, use_adaptivity, use_switch_estimator): """ - Function to proof the assertions (function to get cleaner code) + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. + + Args: + stats (dict): Raw statistics from a controller run + problem (problem_class.__name__): the problem_class that is numerically solved + use_adaptivity (bool): flag if adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ + + data = get_data_dict(stats, use_adaptivity, use_switch_estimator) + + if use_switch_estimator and use_adaptivity: + if problem == 'battery': + msg = 'Error when using switch estimator and adaptivity for battery:' + expected = { + 'cL': 0.5474500710994862, + 'vC': 1.0019332967173764, + 'dt': 0.011761752270047832, + 'e_em': 8.001793672107738e-10, + 'switches': 0.18232155791181945, + 'restarts': 3.0, + 'sum_niters': 44, + } + elif problem == 'battery_implicit': + msg = 'Error when using switch estimator and adaptivity for battery_implicit:' + expected = { + 'cL': 0.5424577937840791, + 'vC': 1.0001051105894005, + 'dt': 0.01, + 'e_em': 2.220446049250313e-16, + 'switches': 0.1822923488448394, + 'restarts': 6.0, + 'sum_niters': 60, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + for key in expected.keys(): + assert np.isclose( + expected[key], got[key], rtol=1e-4 + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + + +def get_data_dict(stats, use_adaptivity=True, use_switch_estimator=True, recomputed=False): + """ + Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function. + Based on @brownbaerchen's get_data function. + + Args: + stats (dict): Raw statistics from a controller run + use_adaptivity (bool): flag if adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + recomputed (bool): flag if the values after a restart are used or before + + Return: + data (dict): contains all information as the statistics dict + """ + + data = dict() + + data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1] + data['vC'] = np.array(get_sorted(stats, type='voltage C', recomputed=recomputed, sortby='time'))[:, 1] + if use_adaptivity: + data['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed, sortby='time'))[:, 1] + data['e_em'] = np.array( + get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed, sortby='time') + )[:, 1] + if use_switch_estimator: + data['switches'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[:, 1] + if use_adaptivity or use_switch_estimator: + data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1]) + data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1]) + + return data + + +def get_recomputed(stats, type, sortby): """ + Function that filters statistics after a recomputation. It stores all value of a type before restart. If there are multiple values + with same time point, it only stores the elements with unique times. - assert problem_params['alpha'] > problem_params['V_ref'], 'Please set "alpha" greater than "V_ref"' - assert problem_params['V_ref'] > 0, 'Please set "V_ref" greater than 0' - assert type(problem_params['V_ref']) == float, '"V_ref" needs to be of type float' + Args: + stats (dict): Raw statistics from a controller run + type (str): the type the be filtered + sortby (str): string to specify which key to use for sorting - assert type(problem_params['set_switch'][0]) == np.bool_, '"set_switch" has to be an bool array' - assert type(problem_params['t_switch']) == np.ndarray, '"t_switch" has to be an array' - assert problem_params['t_switch'][0] == 0, '"t_switch" is only allowed to have entry zero' + Returns: + sorted_list (list): list of filtered statistics + """ + + sorted_nested_list = [] + times_unique = np.unique([me[0] for me in get_sorted(stats, type=type)]) + filtered_list = [ + filter_stats( + stats, + time=t_unique, + num_restarts=max([me.num_restarts for me in filter_stats(stats, type=type, time=t_unique).keys()]), + type=type, + ) + for t_unique in times_unique + ] + for item in filtered_list: + sorted_nested_list.append(sort_stats(item, sortby=sortby)) + sorted_list = [item for sub_item in sorted_nested_list for item in sub_item] + return sorted_list + + +def proof_assertions_description(description, use_adaptivity, use_switch_estimator): + """ + Function to proof the assertions (function to get cleaner code) + + Args: + description(dict): contains all information for a controller run + use_adaptivity (bool): flag if adaptivity wants to be used or not + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ + + assert ( + description['problem_params']['alpha'] > description['problem_params']['V_ref'] + ), 'Please set "alpha" greater than "V_ref"' + assert description['problem_params']['V_ref'] > 0, 'Please set "V_ref" greater than 0' + assert type(description['problem_params']['V_ref']) == float, '"V_ref" needs to be of type float' assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error' assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters' assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters' + if use_switch_estimator or use_adaptivity: + assert description['level_params']['restol'] == -1, "For adaptivity, please set restol to -1 or omit it" + if __name__ == "__main__": run() diff --git a/pySDC/projects/PinTSimE/estimation_check.py b/pySDC/projects/PinTSimE/estimation_check.py index e565bd11d3..aad507a9bd 100644 --- a/pySDC/projects/PinTSimE/estimation_check.py +++ b/pySDC/projects/PinTSimE/estimation_check.py @@ -9,17 +9,29 @@ from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI from pySDC.projects.PinTSimE.piline_model import setup_mpl -from pySDC.projects.PinTSimE.battery_model import log_data, proof_assertions_description +from pySDC.projects.PinTSimE.battery_model import get_recomputed, get_data_dict, log_data, proof_assertions_description import pySDC.helpers.plot_helper as plt_helper +from pySDC.core.Hooks import hooks from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity -from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedErrorNonMPI def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): """ A simple test program to do SDC/PFASST runs for the battery drain model + + Args: + dt (np.float): (initial) time step + problem (problem_class): the considered problem class (here: battery or battery_implicit) + sweeper (sweeper_class): the used sweeper class to solve (here: imex_1st_order or generic_implicit) + use_switch_estimator (bool): Switch estimator should be used or not + use_adaptivity (bool): Adaptivity should be used or not + V_ref (np.float): reference value for the switch + + Returns: + description (dict): contains all the information for the controller + stats (dict): includes the statistics of the solve """ # initialize level parameters @@ -32,12 +44,13 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): sweeper_params['quad_type'] = 'LOBATTO' sweeper_params['num_nodes'] = 5 # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' + sweeper_params['initial_guess'] = 'spread' # initialize problem parameters problem_params = dict() problem_params['newton_maxiter'] = 200 problem_params['newton_tol'] = 1e-08 + problem_params['ncondensators'] = 1 problem_params['Vs'] = 5.0 problem_params['Rs'] = 0.5 problem_params['C'] = 1.0 @@ -45,8 +58,6 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): problem_params['L'] = 1.0 problem_params['alpha'] = 1.2 problem_params['V_ref'] = V_ref - problem_params['set_switch'] = np.array([False], dtype=bool) - problem_params['t_switch'] = np.zeros(1) # initialize step parameters step_params = dict() @@ -54,7 +65,7 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): # initialize controller parameters controller_params = dict() - controller_params['logger_level'] = 20 + controller_params['logger_level'] = 30 controller_params['hook_class'] = log_data controller_params['mssdc_jac'] = False @@ -82,12 +93,19 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): if use_switch_estimator or use_adaptivity: description['convergence_controllers'] = convergence_controllers - proof_assertions_description(description, problem_params) + proof_assertions_description(description, use_adaptivity, use_switch_estimator) # set time parameters t0 = 0.0 Tend = 0.3 + assert dt < Tend, "Time step is too large for the time domain!" + + assert ( + Tend == 0.3 and description['problem_params']['V_ref'] == 1.0 and description['problem_params']['alpha'] == 1.2 + ), "Error! Do not use other parameters for V_ref != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!" + assert dt == 4e-2 or dt == 4e-3, "Error! Do not use another time step dt!= 1e-2!" + # instantiate controller controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description) @@ -98,37 +116,24 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref): # call main function to get things done... uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - Path("data").mkdir(parents=True, exist_ok=True) - fname = 'data/battery.dat' - f = open(fname, 'wb') - dill.dump(stats, f) - f.close() - - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time') - - # compute and print statistics - f = open('data/battery_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - - assert np.mean(niters) <= 11, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - return description, stats def check(cwd='./'): """ Routine to check the differences between using a switch estimator or not + + Args: + cwd: current working directory """ V_ref = 1.0 - dt_list = [1e-2, 1e-3] + dt_list = [4e-2, 4e-3] use_switch_estimator = [True, False] use_adaptivity = [True, False] - restarts_true = [] - restarts_false_adapt = [] - restarts_true_adapt = [] + restarts_SE = [] + restarts_adapt = [] + restarts_SE_adapt = [] problem_classes = [battery, battery_implicit] sweeper_classes = [imex_1st_order, generic_implicit] @@ -146,6 +151,14 @@ def check(cwd='./'): V_ref=V_ref, ) + if use_A or use_SE: + check_solution(stats, dt_item, problem.__name__, use_A, use_SE) + + if use_SE: + assert ( + len(get_recomputed(stats, type='switch', sortby='time')) >= 1 + ), 'No switches found for dt={}!'.format(dt_item) + fname = 'data/battery_dt{}_USE{}_USA{}_{}.dat'.format(dt_item, use_SE, use_A, sweeper.__name__) f = open(fname, 'wb') dill.dump(stats, f) @@ -153,24 +166,23 @@ def check(cwd='./'): if use_SE or use_A: restarts_sorted = np.array(get_sorted(stats, type='restart', recomputed=None))[:, 1] - print('Restarts for dt={}: {}'.format(dt_item, np.sum(restarts_sorted))) if use_SE and not use_A: - restarts_true.append(np.sum(restarts_sorted)) + restarts_SE.append(np.sum(restarts_sorted)) elif not use_SE and use_A: - restarts_false_adapt.append(np.sum(restarts_sorted)) + restarts_adapt.append(np.sum(restarts_sorted)) elif use_SE and use_A: - restarts_true_adapt.append(np.sum(restarts_sorted)) + restarts_SE_adapt.append(np.sum(restarts_sorted)) accuracy_check(dt_list, problem.__name__, sweeper.__name__, V_ref) differences_around_switch( dt_list, problem.__name__, - restarts_true, - restarts_false_adapt, - restarts_true_adapt, + restarts_SE, + restarts_adapt, + restarts_SE_adapt, sweeper.__name__, V_ref, ) @@ -179,14 +191,21 @@ def check(cwd='./'): iterations_over_time(dt_list, description['step_params']['maxiter'], problem.__name__, sweeper.__name__) - restarts_true = [] - restarts_false_adapt = [] - restarts_true_adapt = [] + restarts_SE = [] + restarts_adapt = [] + restarts_SE_adapt = [] def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): """ Routine to check accuracy for different step sizes in case of using adaptivity + + Args: + dt_list (list): list of considered (initial) step sizes + problem (problem.__name__): Problem class used to consider (the class name) + sweeper (sweeper.__name__): Sweeper used to solve (the class name) + V_ref (np.float): reference value for the switch + cwd: current working directory """ if len(dt_list) > 1: @@ -202,45 +221,49 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): count_ax = 0 for dt_item in dt_list: f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switch_adapt = t_switch_adapt[-1] + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switch_SE_adapt = t_switch_SE_adapt[-1] - dt_TT_val = get_sorted(stats_TT, type='dt', recomputed=False) - dt_FT_val = get_sorted(stats_FT, type='dt', recomputed=False) + dt_SE_adapt_val = get_sorted(stats_SE_adapt, type='dt', recomputed=False) + dt_adapt_val = get_sorted(stats_adapt, type='dt', recomputed=False) - e_emb_TT_val = get_sorted(stats_TT, type='e_embedded', recomputed=False) - e_emb_FT_val = get_sorted(stats_FT, type='e_embedded', recomputed=False) + e_emb_SE_adapt_val = get_sorted(stats_SE_adapt, type='e_embedded', recomputed=False) + e_emb_adapt_val = get_sorted(stats_adapt, type='e_embedded', recomputed=False) - times_TT = [v[0] for v in e_emb_TT_val] - times_FT = [v[0] for v in e_emb_FT_val] + times_SE_adapt = [v[0] for v in e_emb_SE_adapt_val] + times_adapt = [v[0] for v in e_emb_adapt_val] - e_emb_TT = [v[1] for v in e_emb_TT_val] - e_emb_FT = [v[1] for v in e_emb_FT_val] + e_emb_SE_adapt = [v[1] for v in e_emb_SE_adapt_val] + e_emb_adapt = [v[1] for v in e_emb_adapt_val] if len(dt_list) > 1: - ax_acc[count_ax].set_title(r'$\Delta t$={}'.format(dt_item)) + ax_acc[count_ax].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item) dt1 = ax_acc[count_ax].plot( - [v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$' + [v[0] for v in dt_SE_adapt_val], + [v[1] for v in dt_SE_adapt_val], + 'ko-', + label=r'SE+A - $\Delta t_\mathrm{adapt}$', ) dt2 = ax_acc[count_ax].plot( - [v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'g-', label=r'A - $\Delta t$' + [v[0] for v in dt_adapt_val], [v[1] for v in dt_adapt_val], 'g-', label=r'A - $\Delta t_\mathrm{adapt}$' ) - ax_acc[count_ax].axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc[count_ax].axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc[count_ax].tick_params(axis='both', which='major', labelsize=6) ax_acc[count_ax].set_xlabel('Time', fontsize=6) if count_ax == 0: - ax_acc[count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6) + ax_acc[count_ax].set_ylabel(r'$\Delta t_\mathrm{adapt}$', fontsize=6) e_ax = ax_acc[count_ax].twinx() - e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$') - e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$') + e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$') + e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$') e_ax.set_yscale('log', base=10) e_ax.set_ylim(1e-16, 1e-7) e_ax.tick_params(labelsize=6) @@ -248,26 +271,37 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): lines = dt1 + e_plt1 + dt2 + e_plt2 labels = [l.get_label() for l in lines] - ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper left') + ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper right') else: - ax_acc.set_title(r'$\Delta t$={}'.format(dt_item)) - dt1 = ax_acc.plot([v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$') - dt2 = ax_acc.plot([v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'go-', label=r'A - $\Delta t$') - ax_acc.axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item) + dt1 = ax_acc.plot( + [v[0] for v in dt_SE_adapt_val], + [v[1] for v in dt_SE_adapt_val], + 'ko-', + label=r'SE+A - $\Delta t_\mathrm{adapt}$', + ) + dt2 = ax_acc.plot( + [v[0] for v in dt_adapt_val], + [v[1] for v in dt_adapt_val], + 'go-', + label=r'A - $\Delta t_\mathrm{adapt}$', + ) + ax_acc.axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch') + ax_acc.tick_params(axis='both', which='major', labelsize=6) ax_acc.set_xlabel('Time', fontsize=6) - ax_acc.set_ylabel(r'$Delta t_{adapted}$', fontsize=6) + ax_acc.set_ylabel(r'$Delta t_\mathrm{adapt}$', fontsize=6) e_ax = ax_acc.twinx() - e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$') - e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$') + e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$') + e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$') e_ax.set_yscale('log', base=10) e_ax.tick_params(labelsize=6) lines = dt1 + e_plt1 + dt2 + e_plt2 labels = [l.get_label() for l in lines] - ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper left') + ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper right') count_ax += 1 @@ -275,11 +309,221 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'): plt_helper.plt.close(fig_acc) +def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator): + """ + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. + + Args: + stats (dict): Raw statistics from a controller run + dt (float): initial time step + problem (problem_class.__name__): the problem_class that is numerically solved + use_switch_estimator (bool): + use_adaptivity (bool): + """ + + data = get_data_dict(stats, use_adaptivity, use_switch_estimator) + + if problem == 'battery': + if use_switch_estimator and use_adaptivity: + msg = f'Error when using switch estimator and adaptivity for battery for dt={dt:.1e}:' + if dt == 4e-2: + expected = { + 'cL': 0.5525783945667581, + 'vC': 1.00001743462299, + 'dt': 0.03550610373897258, + 'e_em': 6.21240694442804e-08, + 'switches': 0.18231603298272345, + 'restarts': 4.0, + 'sum_niters': 56, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5395601429161445, + 'vC': 1.0000413761942089, + 'dt': 0.028281271825675414, + 'e_em': 2.5628611677319668e-08, + 'switches': 0.18230920573953438, + 'restarts': 3.0, + 'sum_niters': 48, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + elif use_switch_estimator and not use_adaptivity: + msg = f'Error when using switch estimator for battery for dt={dt:.1e}:' + if dt == 4e-2: + expected = { + 'cL': 0.6139093327509394, + 'vC': 1.0010140038721593, + 'switches': 0.1824302065533169, + 'restarts': 1.0, + 'sum_niters': 48, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5429509935448258, + 'vC': 1.0001158309787614, + 'switches': 0.18232183080236553, + 'restarts': 1.0, + 'sum_niters': 392, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + elif not use_switch_estimator and use_adaptivity: + msg = f'Error when using adaptivity for battery for dt={dt:.1e}:' + if dt == 4e-2: + expected = { + 'cL': 0.5966289599915113, + 'vC': 0.9923148791604984, + 'dt': 0.03564958366355817, + 'e_em': 6.210964231812e-08, + 'restarts': 1.0, + 'sum_niters': 36, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5431613774808756, + 'vC': 0.9934307674636834, + 'dt': 0.022880524075396924, + 'e_em': 1.1130212751453428e-08, + 'restarts': 3.0, + 'sum_niters': 52, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + elif problem == 'battery_implicit': + if use_switch_estimator and use_adaptivity: + msg = f'Error when using switch estimator and adaptivity for battery_implicit for dt={dt:.1e}:' + if dt == 4e-2: + expected = { + 'cL': 0.6717104472882885, + 'vC': 1.0071670698947914, + 'dt': 0.035896059229296486, + 'e_em': 6.208836400567463e-08, + 'switches': 0.18232158833761175, + 'restarts': 3.0, + 'sum_niters': 36, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5396216192241711, + 'vC': 1.0000561014463172, + 'dt': 0.009904645972832471, + 'e_em': 2.220446049250313e-16, + 'switches': 0.18230549652342606, + 'restarts': 4.0, + 'sum_niters': 44, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + elif use_switch_estimator and not use_adaptivity: + msg = f'Error when using switch estimator for battery_implicit for dt={dt:.1e}:' + if dt == 4e-2: + expected = { + 'cL': 0.613909968362315, + 'vC': 1.0010140112484431, + 'switches': 0.18243023230469263, + 'restarts': 1.0, + 'sum_niters': 48, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5429616576526073, + 'vC': 1.0001158454740509, + 'switches': 0.1823218812753008, + 'restarts': 1.0, + 'sum_niters': 392, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'switches': data['switches'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + elif not use_switch_estimator and use_adaptivity: + msg = f'Error when using adaptivity for battery_implicit for dt={dt:.1e}:' + if dt == 4e-2: + expected = { + 'cL': 0.5556563012729733, + 'vC': 0.9930947318467772, + 'dt': 0.035507110551631804, + 'e_em': 6.2098696185231e-08, + 'restarts': 6.0, + 'sum_niters': 64, + } + elif dt == 4e-3: + expected = { + 'cL': 0.5401117929618637, + 'vC': 0.9933888475391347, + 'dt': 0.03176025170463925, + 'e_em': 4.0386798239033794e-08, + 'restarts': 8.0, + 'sum_niters': 80, + } + + got = { + 'cL': data['cL'][-1], + 'vC': data['vC'][-1], + 'dt': data['dt'][-1], + 'e_em': data['e_em'][-1], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + for key in expected.keys(): + assert np.isclose( + expected[key], got[key], rtol=1e-4 + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + + def differences_around_switch( - dt_list, problem, restarts_true, restarts_false_adapt, restarts_true_adapt, sweeper, V_ref, cwd='./' + dt_list, problem, restarts_SE, restarts_adapt, restarts_SE_adapt, sweeper, V_ref, cwd='./' ): """ Routine to plot the differences before, at, and after the switch. Produces the diffs_estimation_.png file + + Args: + dt_list (list): list of considered (initial) step sizes + problem (problem.__name__): Problem class used to consider (the class name) + restarts_SE (list): Restarts for the solve only using the switch estimator + restarts_adapt (list): Restarts for the solve of only using adaptivity + restarts_SE_adapt (list): Restarts for the solve of using both, switch estimator and adaptivity + sweeper (sweeper.__name__): Sweeper used to solve (the class name) + V_ref (np.float): reference value for the switch + cwd: current working directory """ diffs_true_at = [] @@ -294,59 +538,59 @@ def differences_around_switch( diffs_false_after_adapt = [] for dt_item in dt_list: f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TF = dill.load(f1) + stats_SE = dill.load(f1) f1.close() f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FF = dill.load(f2) + stats = dill.load(f2) f2.close() f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time') - t_switch = [v[1] for v in val_switch_TF] + switches_SE = get_recomputed(stats_SE, type='switch', sortby='time') + t_switch = [v[1] for v in switches_SE] t_switch = t_switch[-1] # battery has only one single switch - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switch_adapt = t_switch_adapt[-1] + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switch_SE_adapt = t_switch_SE_adapt[-1] - vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time') - vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time') - vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time') - vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time') + vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time') + vC_adapt = get_sorted(stats_adapt, type='voltage C', recomputed=False, sortby='time') + vC_SE_adapt = get_sorted(stats_SE_adapt, type='voltage C', recomputed=False, sortby='time') + vC = get_sorted(stats, type='voltage C', sortby='time') - diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF] - times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF] + diff_SE, diff = [v[1] - V_ref for v in vC_SE], [v[1] - V_ref for v in vC] + times_SE, times = [v[0] for v in vC_SE], [v[0] for v in vC] - diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT] - times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT] + diff_adapt, diff_SE_adapt = [v[1] - V_ref for v in vC_adapt], [v[1] - V_ref for v in vC_SE_adapt] + times_adapt, times_SE_adapt = [v[0] for v in vC_adapt], [v[0] for v in vC_SE_adapt] - for m in range(len(times_TF)): - if np.round(times_TF[m], 15) == np.round(t_switch, 15): - diffs_true_at.append(diff_TF[m]) + for m in range(len(times_SE)): + if np.round(times_SE[m], 15) == np.round(t_switch, 15): + diffs_true_at.append(diff_SE[m]) - for m in range(1, len(times_FF)): - if times_FF[m - 1] <= t_switch <= times_FF[m]: - diffs_false_before.append(diff_FF[m - 1]) - diffs_false_after.append(diff_FF[m]) + for m in range(1, len(times)): + if times[m - 1] <= t_switch <= times[m]: + diffs_false_before.append(diff[m - 1]) + diffs_false_after.append(diff[m]) - for m in range(len(times_TT)): - if np.round(times_TT[m], 13) == np.round(t_switch_adapt, 13): - diffs_true_at_adapt.append(diff_TT[m]) - diffs_true_before_adapt.append(diff_TT[m - 1]) - diffs_true_after_adapt.append(diff_TT[m + 1]) + for m in range(len(times_SE_adapt)): + if np.round(times_SE_adapt[m], 13) == np.round(t_switch_SE_adapt, 13): + diffs_true_at_adapt.append(diff_SE_adapt[m]) + diffs_true_before_adapt.append(diff_SE_adapt[m - 1]) + diffs_true_after_adapt.append(diff_SE_adapt[m + 1]) - for m in range(len(times_FT)): - if times_FT[m - 1] <= t_switch <= times_FT[m]: - diffs_false_before_adapt.append(diff_FT[m - 1]) - diffs_false_after_adapt.append(diff_FT[m]) + for m in range(len(times_adapt)): + if times_adapt[m - 1] <= t_switch <= times_adapt[m]: + diffs_false_before_adapt.append(diff_adapt[m - 1]) + diffs_false_after_adapt.append(diff_adapt[m]) setup_mpl() fig_around, ax_around = plt_helper.plt.subplots(1, 3, figsize=(9, 3), sharex='col', sharey='row') @@ -356,14 +600,15 @@ def differences_around_switch( pos13 = ax_around[0].plot(dt_list, diffs_true_at, 'ko--', label='at switch') ax_around[0].set_xticks(dt_list) ax_around[0].set_xticklabels(dt_list) + ax_around[0].tick_params(axis='both', which='major', labelsize=6) ax_around[0].set_xscale('log', base=10) ax_around[0].set_yscale('symlog', linthresh=1e-8) ax_around[0].set_ylim(-1, 1) - ax_around[0].set_xlabel(r'$\Delta t$', fontsize=6) + ax_around[0].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6) ax_around[0].set_ylabel(r'$v_{C}-V_{ref}$', fontsize=6) restart_ax0 = ax_around[0].twinx() - restarts_plt0 = restart_ax0.plot(dt_list, restarts_true, 'cs--', label='Restarts') + restarts_plt0 = restart_ax0.plot(dt_list, restarts_SE, 'cs--', label='Restarts') restart_ax0.tick_params(labelsize=6) lines = pos11 + pos12 + pos13 + restarts_plt0 @@ -375,13 +620,14 @@ def differences_around_switch( pos22 = ax_around[1].plot(dt_list, diffs_false_after_adapt, 'bd--', label='after switch') ax_around[1].set_xticks(dt_list) ax_around[1].set_xticklabels(dt_list) + ax_around[1].tick_params(axis='both', which='major', labelsize=6) ax_around[1].set_xscale('log', base=10) ax_around[1].set_yscale('symlog', linthresh=1e-8) ax_around[1].set_ylim(-1, 1) - ax_around[1].set_xlabel(r'$\Delta t$', fontsize=6) + ax_around[1].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6) restart_ax1 = ax_around[1].twinx() - restarts_plt1 = restart_ax1.plot(dt_list, restarts_false_adapt, 'cs--', label='Restarts') + restarts_plt1 = restart_ax1.plot(dt_list, restarts_adapt, 'cs--', label='Restarts') restart_ax1.tick_params(labelsize=6) lines = pos21 + pos22 + restarts_plt1 @@ -394,90 +640,101 @@ def differences_around_switch( pos33 = ax_around[2].plot(dt_list, diffs_true_at_adapt, 'ko--', label='at switch') ax_around[2].set_xticks(dt_list) ax_around[2].set_xticklabels(dt_list) + ax_around[2].tick_params(axis='both', which='major', labelsize=6) ax_around[2].set_xscale('log', base=10) ax_around[2].set_yscale('symlog', linthresh=1e-8) ax_around[2].set_ylim(-1, 1) - ax_around[2].set_xlabel(r'$\Delta t$', fontsize=6) + ax_around[2].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6) restart_ax2 = ax_around[2].twinx() - restarts_plt2 = restart_ax2.plot(dt_list, restarts_true_adapt, 'cs--', label='Restarts') + restarts_plt2 = restart_ax2.plot(dt_list, restarts_SE_adapt, 'cs--', label='Restarts') restart_ax2.tick_params(labelsize=6) lines = pos31 + pos32 + pos33 + restarts_plt2 labels = [l.get_label() for l in lines] ax_around[2].legend(frameon=False, fontsize=6, loc='lower right') - fig_around.savefig('data/diffs_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') + fig_around.savefig('data/diffs_around_switch_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') plt_helper.plt.close(fig_around) def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'): """ Routine to plot the differences in time using the switch estimator or not. Produces the difference_estimation_.png file + + Args: + dt_list (list): list of considered (initial) step sizes + problem (problem.__name__): Problem class used to consider (the class name) + sweeper (sweeper.__name__): Sweeper used to solve (the class name) + V_ref (np.float): reference value for the switch + cwd: current working directory """ if len(dt_list) > 1: setup_mpl() fig_diffs, ax_diffs = plt_helper.plt.subplots( - 2, len(dt_list), figsize=(3 * len(dt_list), 4), sharex='col', sharey='row' + 2, len(dt_list), figsize=(4 * len(dt_list), 6), sharex='col', sharey='row' ) else: setup_mpl() - fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(3, 3)) + fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(4, 6)) count_ax = 0 for dt_item in dt_list: f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TF = dill.load(f1) + stats_SE = dill.load(f1) f1.close() f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FF = dill.load(f2) + stats = dill.load(f2) f2.close() f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time') - t_switch_TF = [v[1] for v in val_switch_TF] - t_switch_TF = t_switch_TF[-1] # battery has only one single switch + switches_SE = get_recomputed(stats_SE, type='switch', sortby='time') + t_switch_SE = [v[1] for v in switches_SE] + t_switch_SE = t_switch_SE[-1] # battery has only one single switch - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switch_adapt = t_switch_adapt[-1] + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switch_SE_adapt = t_switch_SE_adapt[-1] - dt_FT = np.array(get_sorted(stats_FT, type='dt', recomputed=False, sortby='time')) - dt_TT = np.array(get_sorted(stats_TT, type='dt', recomputed=False, sortby='time')) + dt_adapt = np.array(get_sorted(stats_adapt, type='dt', recomputed=False, sortby='time')) + dt_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='dt', recomputed=False, sortby='time')) - restart_FT = np.array(get_sorted(stats_FT, type='restart', recomputed=None, sortby='time')) - restart_TT = np.array(get_sorted(stats_TT, type='restart', recomputed=None, sortby='time')) + restart_adapt = np.array(get_sorted(stats_adapt, type='restart', recomputed=None, sortby='time')) + restart_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='restart', recomputed=None, sortby='time')) - vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time') - vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time') - vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time') - vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time') + vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time') + vC_adapt = get_sorted(stats_adapt, type='voltage C', recomputed=False, sortby='time') + vC_SE_adapt = get_sorted(stats_SE_adapt, type='voltage C', recomputed=False, sortby='time') + vC = get_sorted(stats, type='voltage C', sortby='time') - diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF] - times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF] + diff_SE, diff = [v[1] - V_ref for v in vC_SE], [v[1] - V_ref for v in vC] + times_SE, times = [v[0] for v in vC_SE], [v[0] for v in vC] - diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT] - times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT] + diff_adapt, diff_SE_adapt = [v[1] - V_ref for v in vC_adapt], [v[1] - V_ref for v in vC_SE_adapt] + times_adapt, times_SE_adapt = [v[0] for v in vC_adapt], [v[0] for v in vC_SE_adapt] if len(dt_list) > 1: - ax_diffs[0, count_ax].set_title(r'$\Delta t$={}'.format(dt_item)) - ax_diffs[0, count_ax].plot(times_TF, diff_TF, label='SE=True, A=False', color='#ff7f0e') - ax_diffs[0, count_ax].plot(times_FF, diff_FF, label='SE=False, A=False', color='#1f77b4') - ax_diffs[0, count_ax].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--') - ax_diffs[0, count_ax].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.') - ax_diffs[0, count_ax].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_diffs[0, count_ax].set_title(r'$\Delta t$=%s' % dt_item) + ax_diffs[0, count_ax].plot(times_SE, diff_SE, label='SE=True, A=False', color='#ff7f0e') + ax_diffs[0, count_ax].plot(times, diff, label='SE=False, A=False', color='#1f77b4') + ax_diffs[0, count_ax].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--') + ax_diffs[0, count_ax].plot( + times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.' + ) + ax_diffs[0, count_ax].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch') ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='lower left') ax_diffs[0, count_ax].set_yscale('symlog', linthresh=1e-5) + ax_diffs[0, count_ax].tick_params(axis='both', which='major', labelsize=6) if count_ax == 0: ax_diffs[0, count_ax].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6) @@ -488,110 +745,128 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'): ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='upper right') ax_diffs[1, count_ax].plot( - dt_FT[:, 0], dt_FT[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--' + dt_adapt[:, 0], dt_adapt[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--' ) ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=F, A=T', color='grey', linestyle='-.') - for i in range(len(restart_FT)): - if restart_FT[i, 1] > 0: - ax_diffs[1, count_ax].axvline(restart_FT[i, 0], color='grey', linestyle='-.') + for i in range(len(restart_adapt)): + if restart_adapt[i, 1] > 0: + ax_diffs[1, count_ax].axvline(restart_adapt[i, 0], color='grey', linestyle='-.') ax_diffs[1, count_ax].plot( - dt_TT[:, 0], dt_TT[:, 1], label=r'$ \Delta t$ - SE=T, A=T', color='limegreen', linestyle='-.' + dt_SE_adapt[:, 0], + dt_SE_adapt[:, 1], + label=r'$ \Delta t$ - SE=T, A=T', + color='limegreen', + linestyle='-.', ) ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=T, A=T', color='black', linestyle='-.') - for i in range(len(restart_TT)): - if restart_TT[i, 1] > 0: - ax_diffs[1, count_ax].axvline(restart_TT[i, 0], color='black', linestyle='-.') + for i in range(len(restart_SE_adapt)): + if restart_SE_adapt[i, 1] > 0: + ax_diffs[1, count_ax].axvline(restart_SE_adapt[i, 0], color='black', linestyle='-.') ax_diffs[1, count_ax].set_xlabel('Time', fontsize=6) + ax_diffs[1, count_ax].tick_params(axis='both', which='major', labelsize=6) if count_ax == 0: - ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6) + ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6) - ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='upper left') + ax_diffs[1, count_ax].set_yscale('log', base=10) + ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='lower left') else: - ax_diffs[0].set_title(r'$\Delta t$={}'.format(dt_item)) - ax_diffs[0].plot(times_TF, diff_TF, label='SE=True', color='#ff7f0e') - ax_diffs[0].plot(times_FF, diff_FF, label='SE=False', color='#1f77b4') - ax_diffs[0].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--') - ax_diffs[0].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.') - ax_diffs[0].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch') - ax_diffs[0].legend(frameon=False, fontsize=6, loc='lower left') + ax_diffs[0].set_title(r'$\Delta t$=%s' % dt_item) + ax_diffs[0].plot(times_SE, diff_SE, label='SE=True', color='#ff7f0e') + ax_diffs[0].plot(times, diff, label='SE=False', color='#1f77b4') + ax_diffs[0].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--') + ax_diffs[0].plot(times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.') + ax_diffs[0].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_diffs[0].tick_params(axis='both', which='major', labelsize=6) ax_diffs[0].set_yscale('symlog', linthresh=1e-5) ax_diffs[0].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6) ax_diffs[0].legend(frameon=False, fontsize=6, loc='center right') - ax_diffs[1].plot(dt_FT[:, 0], dt_FT[:, 1], label='SE=False, A=True', color='red', linestyle='--') - ax_diffs[1].plot(dt_TT[:, 0], dt_TT[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.') + ax_diffs[1].plot(dt_adapt[:, 0], dt_adapt[:, 1], label='SE=False, A=True', color='red', linestyle='--') + ax_diffs[1].plot( + dt_SE_adapt[:, 0], dt_SE_adapt[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.' + ) + ax_diffs[1].tick_params(axis='both', which='major', labelsize=6) ax_diffs[1].set_xlabel('Time', fontsize=6) - ax_diffs[1].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6) + ax_diffs[1].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6) + ax_diffs[1].set_yscale('log', base=10) ax_diffs[1].legend(frameon=False, fontsize=6, loc='upper right') count_ax += 1 plt_helper.plt.tight_layout() - fig_diffs.savefig('data/difference_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') + fig_diffs.savefig('data/diffs_over_time_{}.png'.format(sweeper), dpi=300, bbox_inches='tight') plt_helper.plt.close(fig_diffs) def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): """ Routine to plot the number of iterations over time using switch estimator or not. Produces the iters_.png file + + Args: + dt_list (list): list of considered (initial) step sizes + maxiter (np.int): maximum number of iterations + problem (problem.__name__): Problem class used to consider (the class name) + sweeper (sweeper.__name__): Sweeper used to solve (the class name) + cwd: current working directory """ - iters_time_TF = [] - iters_time_FF = [] - iters_time_TT = [] - iters_time_FT = [] - times_TF = [] - times_FF = [] - times_TT = [] - times_FT = [] - t_switches_TF = [] - t_switches_adapt = [] + iters_time_SE = [] + iters_time = [] + iters_time_SE_adapt = [] + iters_time_adapt = [] + times_SE = [] + times = [] + times_SE_adapt = [] + times_adapt = [] + t_switches_SE = [] + t_switches_SE_adapt = [] for dt_item in dt_list: f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TF = dill.load(f1) + stats_SE = dill.load(f1) f1.close() f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FF = dill.load(f2) + stats = dill.load(f2) f2.close() f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_TT = dill.load(f3) + stats_SE_adapt = dill.load(f3) f3.close() f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb') - stats_FT = dill.load(f4) + stats_adapt = dill.load(f4) f4.close() - iter_counts_TF_val = get_sorted(stats_TF, type='niter', recomputed=False, sortby='time') - iter_counts_TT_val = get_sorted(stats_TT, type='niter', recomputed=False, sortby='time') - iter_counts_FT_val = get_sorted(stats_FT, type='niter', recomputed=False, sortby='time') - iter_counts_FF_val = get_sorted(stats_FF, type='niter', recomputed=False, sortby='time') + # consider iterations before restarts to see what happens + iter_counts_SE_val = get_sorted(stats_SE, type='niter', sortby='time') + iter_counts_SE_adapt_val = get_sorted(stats_SE_adapt, type='niter', sortby='time') + iter_counts_adapt_val = get_sorted(stats_adapt, type='niter', sortby='time') + iter_counts_val = get_sorted(stats, type='niter', sortby='time') - iters_time_TF.append([v[1] for v in iter_counts_TF_val]) - iters_time_TT.append([v[1] for v in iter_counts_TT_val]) - iters_time_FT.append([v[1] for v in iter_counts_FT_val]) - iters_time_FF.append([v[1] for v in iter_counts_FF_val]) + iters_time_SE.append([v[1] for v in iter_counts_SE_val]) + iters_time_SE_adapt.append([v[1] for v in iter_counts_SE_adapt_val]) + iters_time_adapt.append([v[1] for v in iter_counts_adapt_val]) + iters_time.append([v[1] for v in iter_counts_val]) - times_TF.append([v[0] for v in iter_counts_TF_val]) - times_TT.append([v[0] for v in iter_counts_TT_val]) - times_FT.append([v[0] for v in iter_counts_FT_val]) - times_FF.append([v[0] for v in iter_counts_FF_val]) + times_SE.append([v[0] for v in iter_counts_SE_val]) + times_SE_adapt.append([v[0] for v in iter_counts_SE_adapt_val]) + times_adapt.append([v[0] for v in iter_counts_adapt_val]) + times.append([v[0] for v in iter_counts_val]) - val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time') - t_switch_TF = [v[1] for v in val_switch_TF] - t_switches_TF.append(t_switch_TF[-1]) + switches_SE = get_recomputed(stats_SE, type='switch', sortby='time') + t_switch_SE = [v[1] for v in switches_SE] + t_switches_SE.append(t_switch_SE[-1]) - val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time') - t_switch_adapt = [v[1] for v in val_switch_TT] - t_switches_adapt.append(t_switch_adapt[-1]) + switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time') + t_switch_SE_adapt = [v[1] for v in switches_SE_adapt] + t_switches_SE_adapt.append(t_switch_SE_adapt[-1]) if len(dt_list) > 1: setup_mpl() @@ -599,18 +874,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): nrows=1, ncols=len(dt_list), figsize=(2 * len(dt_list) - 1, 3), sharex='col', sharey='row' ) for col in range(len(dt_list)): - ax_iter_all[col].plot(times_FF[col], iters_time_FF[col], label='SE=F, A=F') - ax_iter_all[col].plot(times_TF[col], iters_time_TF[col], label='SE=T, A=F') - ax_iter_all[col].plot(times_TT[col], iters_time_TT[col], '--', label='SE=T, A=T') - ax_iter_all[col].plot(times_FT[col], iters_time_FT[col], '--', label='SE=F, A=T') - ax_iter_all[col].axvline(x=t_switches_TF[col], linestyle='--', linewidth=0.5, color='k', label='Switch') - if t_switches_adapt[col] != t_switches_TF[col]: - ax_iter_all[col].axvline( - x=t_switches_adapt[col], linestyle='--', linewidth=0.5, color='k', label='Switch' - ) - ax_iter_all[col].set_title('dt={}'.format(dt_list[col])) + ax_iter_all[col].plot(times[col], iters_time[col], label='SE=F, A=F') + ax_iter_all[col].plot(times_SE[col], iters_time_SE[col], label='SE=T, A=F') + ax_iter_all[col].plot(times_SE_adapt[col], iters_time_SE_adapt[col], '--', label='SE=T, A=T') + ax_iter_all[col].plot(times_adapt[col], iters_time_adapt[col], '--', label='SE=F, A=T') + ax_iter_all[col].axvline(x=t_switches_SE[col], linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_iter_all[col].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[col]) ax_iter_all[col].set_ylim(0, maxiter + 2) ax_iter_all[col].set_xlabel('Time', fontsize=6) + ax_iter_all[col].tick_params(axis='both', which='major', labelsize=6) if col == 0: ax_iter_all[col].set_ylabel('Number iterations', fontsize=6) @@ -620,16 +892,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'): setup_mpl() fig_iter_all, ax_iter_all = plt_helper.plt.subplots(nrows=1, ncols=1, figsize=(3, 3)) - ax_iter_all.plot(times_FF[0], iters_time_FF[0], label='SE=False') - ax_iter_all.plot(times_TF[0], iters_time_TF[0], label='SE=True') - ax_iter_all.plot(times_TT[0], iters_time_TT[0], '--', label='SE=T, A=T') - ax_iter_all.plot(times_FT[0], iters_time_FT[0], '--', label='SE=F, A=T') - ax_iter_all.axvline(x=t_switches_TF[0], linestyle='--', linewidth=0.5, color='k', label='Switch') - if t_switches_adapt[0] != t_switches_TF[0]: - ax_iter_all.axvline(x=t_switches_adapt[0], linestyle='--', linewidth=0.5, color='k', label='Switch') - ax_iter_all.set_title('dt={}'.format(dt_list[0])) + ax_iter_all.plot(times[0], iters_time[0], label='SE=False') + ax_iter_all.plot(times_SE[0], iters_time_SE[0], label='SE=True') + ax_iter_all.plot(times_SE_adapt[0], iters_time_SE_adapt[0], '--', label='SE=T, A=T') + ax_iter_all.plot(times_adapt[0], iters_time_adapt[0], '--', label='SE=F, A=T') + ax_iter_all.axvline(x=t_switches_SE[0], linestyle='--', linewidth=0.5, color='k', label='Switch') + ax_iter_all.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[0]) ax_iter_all.set_ylim(0, maxiter + 2) ax_iter_all.set_xlabel('Time', fontsize=6) + ax_iter_all.tick_params(axis='both', which='major', labelsize=6) ax_iter_all.set_ylabel('Number iterations', fontsize=6) ax_iter_all.legend(frameon=False, fontsize=6, loc='upper right') diff --git a/pySDC/projects/PinTSimE/estimation_check_extended.py b/pySDC/projects/PinTSimE/estimation_check_extended.py index f075fe2de4..27d80234e8 100644 --- a/pySDC/projects/PinTSimE/estimation_check_extended.py +++ b/pySDC/projects/PinTSimE/estimation_check_extended.py @@ -4,10 +4,11 @@ from pySDC.helpers.stats_helper import get_sorted from pySDC.core.Collocation import CollBase as Collocation -from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators +from pySDC.implementations.problem_classes.Battery import battery_n_condensators from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI -from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh +from pySDC.projects.PinTSimE.battery_model import get_recomputed +from pySDC.projects.PinTSimE.battery_2condensators_model import get_data_dict from pySDC.projects.PinTSimE.piline_model import setup_mpl from pySDC.projects.PinTSimE.battery_2condensators_model import log_data, proof_assertions_description import pySDC.helpers.plot_helper as plt_helper @@ -16,39 +17,52 @@ def run(dt, use_switch_estimator=True): + """ + A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators + + Args: + dt (float): time step that wants to be used for the computation + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + + Returns: + stats (dict): all statistics from a controller run + description (dict): contains all information for a controller run + """ # initialize level parameters level_params = dict() - level_params['restol'] = 1e-13 + level_params['restol'] = -1 level_params['dt'] = dt + assert ( + dt == 4e-1 or dt == 4e-2 or dt == 4e-3 + ), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!" + # initialize sweeper parameters sweeper_params = dict() sweeper_params['quad_type'] = 'LOBATTO' sweeper_params['num_nodes'] = 5 - sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part - sweeper_params['initial_guess'] = 'zero' + # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part + sweeper_params['initial_guess'] = 'spread' # initialize problem parameters problem_params = dict() + problem_params['ncondensators'] = 2 problem_params['Vs'] = 5.0 problem_params['Rs'] = 0.5 - problem_params['C1'] = 1.0 - problem_params['C2'] = 1.0 + problem_params['C'] = np.array([1.0, 1.0]) problem_params['R'] = 1.0 problem_params['L'] = 1.0 problem_params['alpha'] = 5.0 problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2] - problem_params['set_switch'] = np.array([False, False], dtype=bool) - problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0]) # initialize step parameters step_params = dict() - step_params['maxiter'] = 20 + step_params['maxiter'] = 4 # initialize controller parameters controller_params = dict() - controller_params['logger_level'] = 20 + controller_params['logger_level'] = 30 controller_params['hook_class'] = log_data # convergence controllers @@ -59,13 +73,12 @@ def run(dt, use_switch_estimator=True): # fill description dictionary for easy step instantiation description = dict() - description['problem_class'] = battery_2condensators # pass problem class + description['problem_class'] = battery_n_condensators # pass problem class description['problem_params'] = problem_params # pass problem parameters description['sweeper_class'] = imex_1st_order # pass sweeper description['sweeper_params'] = sweeper_params # pass sweeper parameters description['level_params'] = level_params # pass level parameters description['step_params'] = step_params - description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class if use_switch_estimator: description['convergence_controllers'] = convergence_controllers @@ -86,40 +99,15 @@ def run(dt, use_switch_estimator=True): # call main function to get things done... uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend) - Path("data").mkdir(parents=True, exist_ok=True) - fname = 'data/battery_2condensators.dat' - f = open(fname, 'wb') - dill.dump(stats, f) - f.close() - - # filter statistics by number of iterations - iter_counts = get_sorted(stats, type='niter', sortby='time') - - # compute and print statistics - min_iter = 20 - max_iter = 0 - - f = open('data/battery_2condensators_out.txt', 'w') - niters = np.array([item[1] for item in iter_counts]) - out = ' Mean number of iterations: %4.2f' % np.mean(niters) - f.write(out + '\n') - print(out) - for item in iter_counts: - out = 'Number of iterations for time %4.2f: %1i' % item - f.write(out + '\n') - # print(out) - min_iter = min(min_iter, item[1]) - max_iter = max(max_iter, item[1]) - - assert np.mean(niters) <= 12, "Mean number of iterations is too high, got %s" % np.mean(niters) - f.close() - return stats, description def check(cwd='./'): """ Routine to check the differences between using a switch estimator or not + + Args: + cwd: current working directory """ dt_list = [4e-1, 4e-2, 4e-3] @@ -127,15 +115,21 @@ def check(cwd='./'): restarts_all = [] restarts_dict = dict() for dt_item in dt_list: - for item in use_switch_estimator: - stats, description = run(dt=dt_item, use_switch_estimator=item) + for use_SE in use_switch_estimator: + stats, description = run(dt=dt_item, use_switch_estimator=use_SE) + + if use_SE: + switches = get_recomputed(stats, type='switch', sortby='time') + assert len(switches) >= 2, f"Expected at least 2 switches for dt: {dt_item}, got {len(switches)}!" + + check_solution(stats, dt_item, use_SE) - fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, item) + fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, use_SE) f = open(fname, 'wb') dill.dump(stats, f) f.close() - if item: + if use_SE: restarts_dict[dt_item] = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time')) restarts = restarts_dict[dt_item][:, 1] restarts_all.append(np.sum(restarts)) @@ -161,15 +155,10 @@ def check(cwd='./'): stats_false = dill.load(f2) f2.close() - val_switch1 = get_sorted(stats_true, type='switch1', sortby='time') - val_switch2 = get_sorted(stats_true, type='switch2', sortby='time') - t_switch1 = [v[1] for v in val_switch1] - t_switch2 = [v[1] for v in val_switch2] + switches = get_recomputed(stats_true, type='switch', sortby='time') + t_switch = [v[1] for v in switches] - t_switch1 = t_switch1[-1] - t_switch2 = t_switch2[-1] - - val_switch_all.append([t_switch1, t_switch2]) + val_switch_all.append([t_switch[0], t_switch[1]]) vC1_true = get_sorted(stats_true, type='voltage C1', recomputed=False, sortby='time') vC2_true = get_sorted(stats_true, type='voltage C2', recomputed=False, sortby='time') @@ -187,37 +176,37 @@ def check(cwd='./'): times_false2 = [v[0] for v in vC2_false] for m in range(len(times_true1)): - if np.round(times_true1[m], 15) == np.round(t_switch1, 15): + if np.round(times_true1[m], 15) == np.round(t_switch[0], 15): diff_true_all1.append(diff_true1[m]) for m in range(len(times_true2)): - if np.round(times_true2[m], 15) == np.round(t_switch2, 15): + if np.round(times_true2[m], 15) == np.round(t_switch[1], 15): diff_true_all2.append(diff_true2[m]) for m in range(1, len(times_false1)): - if times_false1[m - 1] < t_switch1 < times_false1[m]: + if times_false1[m - 1] < t_switch[0] < times_false1[m]: diff_false_all_before1.append(diff_false1[m - 1]) diff_false_all_after1.append(diff_false1[m]) for m in range(1, len(times_false2)): - if times_false2[m - 1] < t_switch2 < times_false2[m]: + if times_false2[m - 1] < t_switch[1] < times_false2[m]: diff_false_all_before2.append(diff_false2[m - 1]) diff_false_all_after2.append(diff_false2[m]) restarts_dt = restarts_dict[dt_item] for i in range(len(restarts_dt[:, 0])): - if round(restarts_dt[i, 0], 13) == round(t_switch1, 13): - restarts_dt_switch1.append(np.sum(restarts_dt[0:i, 1])) + if round(restarts_dt[i, 0], 13) == round(t_switch[0], 13): + restarts_dt_switch1.append(np.sum(restarts_dt[0 : i - 1, 1])) - if round(restarts_dt[i, 0], 13) == round(t_switch2, 13): - restarts_dt_switch2.append(np.sum(restarts_dt[i - 1 :, 1])) + if round(restarts_dt[i, 0], 13) == round(t_switch[1], 13): + restarts_dt_switch2.append(np.sum(restarts_dt[i - 2 :, 1])) setup_mpl() fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3)) ax1.set_title('Time evolution of $v_{C_{1}}-V_{ref1}$') ax1.plot(times_true1, diff_true1, label='SE=True', color='#ff7f0e') ax1.plot(times_false1, diff_false1, label='SE=False', color='#1f77b4') - ax1.axvline(x=t_switch1, linestyle='--', color='k', label='Switch1') + ax1.axvline(x=t_switch[0], linestyle='--', color='k', label='Switch1') ax1.legend(frameon=False, fontsize=10, loc='lower left') ax1.set_yscale('symlog', linthresh=1e-5) ax1.set_xlabel('Time') @@ -230,7 +219,7 @@ def check(cwd='./'): ax2.set_title('Time evolution of $v_{C_{2}}-V_{ref2}$') ax2.plot(times_true2, diff_true2, label='SE=True', color='#ff7f0e') ax2.plot(times_false2, diff_false2, label='SE=False', color='#1f77b4') - ax2.axvline(x=t_switch2, linestyle='--', color='k', label='Switch2') + ax2.axvline(x=t_switch[1], linestyle='--', color='k', label='Switch2') ax2.legend(frameon=False, fontsize=10, loc='lower left') ax2.set_yscale('symlog', linthresh=1e-5) ax2.set_xlabel('Time') @@ -287,5 +276,66 @@ def check(cwd='./'): plt_helper.plt.close(fig2) +def check_solution(stats, dt, use_switch_estimator): + """ + Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen. + + Args: + stats (dict): Raw statistics from a controller run + dt (float): initial time step + use_switch_estimator (bool): flag if the switch estimator wants to be used or not + """ + + data = get_data_dict(stats, use_switch_estimator) + + if use_switch_estimator: + msg = f'Error when using the switch estimator for battery_2condensators for dt={dt:.1e}:' + if dt == 4e-1: + expected = { + 'cL': 1.1842780233981391, + 'vC1': 1.0094891393319418, + 'vC2': 1.00103823232433, + 'switch1': 1.6075867934844466, + 'switch2': 3.209437912436633, + 'restarts': 2.0, + 'sum_niters': 2000, + } + elif dt == 4e-2: + expected = { + 'cL': 1.180493652021971, + 'vC1': 1.0094825917376264, + 'vC2': 1.0007713468084405, + 'switch1': 1.6094074085553605, + 'switch2': 3.209437912440314, + 'restarts': 2.0, + 'sum_niters': 2364, + } + elif dt == 4e-3: + expected = { + 'cL': 1.1537529501025199, + 'vC1': 1.001438946726028, + 'vC2': 1.0004331625246141, + 'switch1': 1.6093728710270467, + 'switch2': 3.217437912434171, + 'restarts': 2.0, + 'sum_niters': 8920, + } + + got = { + 'cL': data['cL'][-1], + 'vC1': data['vC1'][-1], + 'vC2': data['vC2'][-1], + 'switch1': data['switch1'], + 'switch2': data['switch2'], + 'restarts': data['restarts'], + 'sum_niters': data['sum_niters'], + } + + for key in expected.keys(): + assert np.isclose( + expected[key], got[key], rtol=1e-4 + ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}' + + if __name__ == "__main__": check() diff --git a/pySDC/projects/PinTSimE/switch_estimator.py b/pySDC/projects/PinTSimE/switch_estimator.py index e43fbb49b3..f31a290ea3 100644 --- a/pySDC/projects/PinTSimE/switch_estimator.py +++ b/pySDC/projects/PinTSimE/switch_estimator.py @@ -2,7 +2,7 @@ import scipy as sp from pySDC.core.Collocation import CollBase -from pySDC.core.ConvergenceController import ConvergenceController +from pySDC.core.ConvergenceController import ConvergenceController, Status class SwitchEstimator(ConvergenceController): @@ -31,13 +31,34 @@ def setup(self, controller, params, description): num_nodes=description['sweeper_params']['num_nodes'], quad_type=description['sweeper_params']['quad_type'], ) - self.coll_nodes_local = coll.nodes - self.switch_detected = False - self.switch_detected_step = False - self.t_switch = None - self.count_switches = 0 - self.dt_initial = description['level_params']['dt'] - return {'control_order': 100, **params} + + defaults = { + 'control_order': 100, + 'tol': description['level_params']['dt'], + 'coll_nodes': coll.nodes, + 'dt_initial': description['level_params']['dt'], + } + return {**defaults, **params} + + def setup_status_variables(self, controller, **kwargs): + """ + Adds switching specific variables to status variables. + + Args: + controller (pySDC.Controller): The controller + """ + + self.status = Status(['t_switch', 'switch_detected', 'switch_detected_step']) + + def reset_status_variables(self, controller, **kwargs): + """ + Resets status variables. + + Args: + controller (pySDC.Controller): The controller + """ + + self.setup_status_variables(controller, **kwargs) def get_new_step_size(self, controller, S): """ @@ -51,76 +72,62 @@ def get_new_step_size(self, controller, S): None """ - self.switch_detected = False # reset between steps - L = S.levels[0] - if not type(L.prob.params.V_ref) == int and not type(L.prob.params.V_ref) == float: - # if V_ref is not a scalar, but an (np.)array - V_ref = np.zeros(np.shape(L.prob.params.V_ref)[0], dtype=float) - for m in range(np.shape(L.prob.params.V_ref)[0]): - V_ref[m] = L.prob.params.V_ref[m] - else: - V_ref = np.array([L.prob.params.V_ref], dtype=float) + if S.status.iter == S.params.maxiter: - if S.status.iter > 0 and self.count_switches < np.shape(V_ref)[0]: - for m in range(len(L.u)): - if L.u[m][self.count_switches + 1] - V_ref[self.count_switches] <= 0: - self.switch_detected = True - m_guess = m - 1 - break + self.status.switch_detected, m_guess, vC_switch = L.prob.get_switching_info(L.u, L.time) - if self.switch_detected: - t_interp = [L.time + L.dt * self.coll_nodes_local[m] for m in range(len(self.coll_nodes_local))] - - vC_switch = [] - for m in range(1, len(L.u)): - vC_switch.append(L.u[m][self.count_switches + 1] - V_ref[self.count_switches]) + if self.status.switch_detected: + t_interp = [L.time + L.dt * self.params.coll_nodes[m] for m in range(len(self.params.coll_nodes))] # only find root if vc_switch[0], vC_switch[-1] have opposite signs (intermediate value theorem) if vC_switch[0] * vC_switch[-1] < 0: - self.t_switch = self.get_switch(t_interp, vC_switch, m_guess) + self.status.t_switch = self.get_switch(t_interp, vC_switch, m_guess) - # if the switch is not find, we need to do ... ? - if L.time < self.t_switch < L.time + L.dt: - r = 1 - tol = self.dt_initial / r + if L.time < self.status.t_switch < L.time + L.dt: - if not np.isclose(self.t_switch - L.time, L.dt, atol=tol): - dt_search = self.t_switch - L.time + dt_switch = self.status.t_switch - L.time + if not np.isclose(self.status.t_switch - L.time, L.dt, atol=self.params.tol): + self.log( + f"Located Switch at time {self.status.t_switch:.6f} is outside the range of tol={self.params.tol:.4e}", + S, + ) else: - print('Switch located at time: {}'.format(self.t_switch)) - dt_search = self.t_switch - L.time - L.prob.params.set_switch[self.count_switches] = self.switch_detected - L.prob.params.t_switch[self.count_switches] = self.t_switch - controller.hooks.add_to_stats( + self.log( + f"Switch located at time {self.status.t_switch:.6f} inside tol={self.params.tol:.4e}", S + ) + + L.prob.t_switch = self.status.t_switch + controller.hooks[0].add_to_stats( process=S.status.slot, time=L.time, level=L.level_index, iter=0, sweep=L.status.sweep, - type='switch{}'.format(self.count_switches + 1), - value=self.t_switch, + type='switch', + value=self.status.t_switch, ) - self.switch_detected_step = True + L.prob.count_switches() + self.status.switch_detected_step = True dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt # when a switch is found, time step to match with switch should be preferred - if self.switch_detected: - L.status.dt_new = dt_search + if self.status.switch_detected: + L.status.dt_new = dt_switch else: - L.status.dt_new = min([dt_planned, dt_search]) + L.status.dt_new = min([dt_planned, dt_switch]) else: - self.switch_detected = False + self.status.switch_detected = False else: - self.switch_detected = False + self.status.switch_detected = False def determine_restart(self, controller, S): """ @@ -134,8 +141,7 @@ def determine_restart(self, controller, S): None """ - if self.switch_detected: - print("Restart") + if self.status.switch_detected: S.status.restart = True S.status.force_done = True @@ -156,13 +162,9 @@ def post_step_processing(self, controller, S): L = S.levels[0] - if self.switch_detected_step: - if L.prob.params.set_switch[self.count_switches] and L.time + L.dt >= self.t_switch: - self.count_switches += 1 - self.t_switch = None - self.switch_detected_step = False - - L.status.dt_new = self.dt_initial + if self.status.switch_detected_step: + if L.time + L.dt >= self.params.t_switch: + L.status.dt_new = L.status.dt_new if L.status.dt_new is not None else L.params.dt super(SwitchEstimator, self).post_step_processing(controller, S) diff --git a/pySDC/projects/Resilience/accuracy_check.py b/pySDC/projects/Resilience/accuracy_check.py index 4c51d696c1..0bbc375ed9 100644 --- a/pySDC/projects/Resilience/accuracy_check.py +++ b/pySDC/projects/Resilience/accuracy_check.py @@ -34,24 +34,6 @@ def post_step(self, step, level_number): L.sweep.compute_end_point() - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_embedded', - value=L.status.error_embedded_estimate, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_extrapolated', - value=L.status.get('error_extrapolation_estimate'), - ) self.add_to_stats( process=step.status.slot, time=L.time, @@ -105,8 +87,8 @@ def get_results_from_stats(stats, var, val, hook_class=log_errors): } if hook_class == log_errors: - e_extrapolated = np.array(get_sorted(stats, type='e_extrapolated'))[:, 1] - e_embedded = np.array(get_sorted(stats, type='e_embedded'))[:, 1] + e_extrapolated = np.array(get_sorted(stats, type='error_extrapolation_estimate'))[:, 1] + e_embedded = np.array(get_sorted(stats, type='error_embedded_estimate'))[:, 1] e_loc = np.array(get_sorted(stats, type='e_loc'))[:, 1] if len(e_extrapolated[e_extrapolated != [None]]) > 0: diff --git a/pySDC/projects/Resilience/advection.py b/pySDC/projects/Resilience/advection.py index 6515a0931e..1591d340c2 100644 --- a/pySDC/projects/Resilience/advection.py +++ b/pySDC/projects/Resilience/advection.py @@ -6,7 +6,7 @@ from pySDC.core.Hooks import hooks from pySDC.helpers.stats_helper import get_sorted import numpy as np -from pySDC.projects.Resilience.hook import log_error_estimates +from pySDC.projects.Resilience.hook import log_data def plot_embedded(stats, ax): @@ -21,17 +21,8 @@ def plot_embedded(stats, ax): ax.legend(frameon=False) -class log_data(hooks): - def pre_run(self, step, level_number): - """ - Record los conditiones initiales - """ - super(log_data, self).pre_run(step, level_number) - L = step.levels[level_number] - self.add_to_stats(process=0, time=0, level=0, iter=0, sweep=0, type='u0', value=L.u[0]) - +class log_every_iteration(hooks): def post_iteration(self, step, level_number): - super(log_data, self).post_iteration(step, level_number) if step.status.iter == step.params.maxiter - 1: L = step.levels[level_number] L.sweep.compute_end_point() @@ -45,58 +36,12 @@ def post_iteration(self, step, level_number): value=L.uold[-1], ) - def post_step(self, step, level_number): - - super(log_data, self).post_step(step, level_number) - - # some abbreviations - L = step.levels[level_number] - - L.sweep.compute_end_point() - - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='u', - value=L.uend, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='dt', - value=L.dt, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_embedded', - value=L.status.get('error_embedded_estimate'), - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_extrapolated', - value=L.status.get('error_extrapolation_estimate'), - ) - def run_advection( custom_description=None, num_procs=1, Tend=2e-1, - hook_class=log_error_estimates, + hook_class=log_data, fault_stuff=None, custom_controller_params=None, custom_problem_params=None, diff --git a/pySDC/projects/Resilience/heat.py b/pySDC/projects/Resilience/heat.py index e4d3e55bb4..770002ffab 100644 --- a/pySDC/projects/Resilience/heat.py +++ b/pySDC/projects/Resilience/heat.py @@ -5,7 +5,7 @@ from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI from pySDC.core.Hooks import hooks from pySDC.helpers.stats_helper import get_sorted -from pySDC.projects.Resilience.hook import log_error_estimates +from pySDC.projects.Resilience.hook import log_data import numpy as np @@ -13,7 +13,7 @@ def run_heat( custom_description=None, num_procs=1, Tend=2e-1, - hook_class=log_error_estimates, + hook_class=log_data, fault_stuff=None, custom_controller_params=None, custom_problem_params=None, diff --git a/pySDC/projects/Resilience/hook.py b/pySDC/projects/Resilience/hook.py index 0ad5bdfb6f..a980e0223c 100644 --- a/pySDC/projects/Resilience/hook.py +++ b/pySDC/projects/Resilience/hook.py @@ -1,7 +1,14 @@ from pySDC.core.Hooks import hooks +from pySDC.implementations.hooks.log_solution import LogSolution +from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate +from pySDC.implementations.hooks.log_extrapolated_error_estimate import LogExtrapolationErrorEstimate +from pySDC.implementations.hooks.log_step_size import LogStepSize -class log_error_estimates(hooks): +hook_collection = [LogSolution, LogEmbeddedErrorEstimate, LogExtrapolationErrorEstimate, LogStepSize] + + +class log_data(hooks): """ Record data required for analysis of problems in the resilience project """ @@ -10,7 +17,8 @@ def pre_run(self, step, level_number): """ Record los conditiones initiales """ - super(log_error_estimates, self).pre_run(step, level_number) + super().pre_run(step, level_number) + L = step.levels[level_number] self.add_to_stats(process=0, time=0, level=0, iter=0, sweep=0, type='u0', value=L.u[0]) @@ -18,49 +26,10 @@ def post_step(self, step, level_number): """ Record final solutions as well as step size and error estimates """ - super(log_error_estimates, self).post_step(step, level_number) + super().post_step(step, level_number) - # some abbreviations L = step.levels[level_number] - L.sweep.compute_end_point() - - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='u', - value=L.uend, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='dt', - value=L.dt, - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_embedded', - value=L.status.__dict__.get('error_embedded_estimate', None), - ) - self.add_to_stats( - process=step.status.slot, - time=L.time + L.dt, - level=L.level_index, - iter=0, - sweep=L.status.sweep, - type='e_extrapolated', - value=L.status.__dict__.get('error_extrapolation_estimate', None), - ) self.add_to_stats( process=step.status.slot, time=L.time, diff --git a/pySDC/projects/Resilience/piline.py b/pySDC/projects/Resilience/piline.py index 0d301cab59..9ac4af3fbe 100644 --- a/pySDC/projects/Resilience/piline.py +++ b/pySDC/projects/Resilience/piline.py @@ -7,14 +7,14 @@ from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity from pySDC.implementations.convergence_controller_classes.hotrod import HotRod -from pySDC.projects.Resilience.hook import log_error_estimates +from pySDC.projects.Resilience.hook import log_data, hook_collection def run_piline( custom_description=None, num_procs=1, Tend=20.0, - hook_class=log_error_estimates, + hook_class=log_data, fault_stuff=None, custom_controller_params=None, custom_problem_params=None, @@ -68,7 +68,7 @@ def run_piline( # initialize controller parameters controller_params = dict() controller_params['logger_level'] = 30 - controller_params['hook_class'] = hook_class + controller_params['hook_class'] = hook_collection + [hook_class] controller_params['mssdc_jac'] = False if custom_controller_params is not None: @@ -131,8 +131,8 @@ def get_data(stats, recomputed=False): 't': np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)]), 'dt': np.array([me[1] for me in get_sorted(stats, type='dt', recomputed=recomputed)]), 't_dt': np.array([me[0] for me in get_sorted(stats, type='dt', recomputed=recomputed)]), - 'e_em': np.array(get_sorted(stats, type='e_embedded', recomputed=recomputed))[:, 1], - 'e_ex': np.array(get_sorted(stats, type='e_extrapolated', recomputed=recomputed))[:, 1], + 'e_em': np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed))[:, 1], + 'e_ex': np.array(get_sorted(stats, type='error_extrapolation_estimate', recomputed=recomputed))[:, 1], 'restarts': np.array(get_sorted(stats, type='restart', recomputed=None))[:, 1], 't_restarts': np.array(get_sorted(stats, type='restart', recomputed=None))[:, 0], 'sweeps': np.array(get_sorted(stats, type='sweeps', recomputed=None))[:, 1], diff --git a/pySDC/projects/Resilience/vdp.py b/pySDC/projects/Resilience/vdp.py index 2a712d39ce..de7de087c1 100644 --- a/pySDC/projects/Resilience/vdp.py +++ b/pySDC/projects/Resilience/vdp.py @@ -8,7 +8,7 @@ from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity from pySDC.core.Errors import ProblemError -from pySDC.projects.Resilience.hook import log_error_estimates +from pySDC.projects.Resilience.hook import log_data, hook_collection def plot_step_sizes(stats, ax): @@ -28,7 +28,7 @@ def plot_step_sizes(stats, ax): p = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')]) - e_em = np.array(get_sorted(stats, type='e_embedded', recomputed=False, sortby='time'))[:, 1] + e_em = np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=False, sortby='time'))[:, 1] dt = np.array(get_sorted(stats, type='dt', recomputed=False, sortby='time')) restart = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time')) @@ -85,7 +85,7 @@ def run_vdp( custom_description=None, num_procs=1, Tend=10.0, - hook_class=log_error_estimates, + hook_class=log_data, fault_stuff=None, custom_controller_params=None, custom_problem_params=None, @@ -138,7 +138,7 @@ def run_vdp( # initialize controller parameters controller_params = dict() controller_params['logger_level'] = 30 - controller_params['hook_class'] = hook_class + controller_params['hook_class'] = hook_collection + [hook_class] controller_params['mssdc_jac'] = False if custom_controller_params is not None: @@ -212,7 +212,7 @@ def fetch_test_data(stats, comm=None, use_MPI=False): Returns: dict: Key values to perform tests on """ - types = ['e_embedded', 'restart', 'dt', 'sweeps', 'residual_post_step'] + types = ['error_embedded_estimate', 'restart', 'dt', 'sweeps', 'residual_post_step'] data = {} for type in types: if type not in get_list_of_types(stats): diff --git a/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py b/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py index afdcfed8a0..081ca420d4 100644 --- a/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py +++ b/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py @@ -133,7 +133,8 @@ def run(self, u0, t0, Tend): # some initializations and reset of statistics uend = None num_procs = len(self.MS) - self.hooks.reset_stats() + for hook in self.hooks: + hook.reset_stats() assert ( (Tend - t0) / self.dt @@ -152,7 +153,8 @@ def run(self, u0, t0, Tend): # call pre-run hook for S in self.MS: - self.hooks.pre_run(step=S, level_number=0) + for hook in self.hooks: + hook.pre_run(step=S, level_number=0) nblocks = int((Tend - t0) / self.dt / num_procs) @@ -169,9 +171,10 @@ def run(self, u0, t0, Tend): # call post-run hook for S in self.MS: - self.hooks.post_run(step=S, level_number=0) + for hook in self.hooks: + hook.post_run(step=S, level_number=0) - return uend, self.hooks.return_stats() + return uend, self.return_stats() def build_propagation_matrix(self, niter): """ @@ -302,7 +305,8 @@ def pfasst(self, MS): MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='PRE_STEP') for S in MS: - self.hooks.pre_step(step=S, level_number=0) + for hook in self.hooks: + hook.pre_step(step=S, level_number=0) while np.linalg.norm(self.res, np.inf) > self.tol and niter < self.maxiter: @@ -310,14 +314,16 @@ def pfasst(self, MS): MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='PRE_ITERATION') for S in MS: - self.hooks.pre_iteration(step=S, level_number=0) + for hook in self.hooks: + hook.pre_iteration(step=S, level_number=0) if self.nlevels > 1: for _ in range(MS[0].levels[1].params.nsweeps): MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=1, stage='PRE_COARSE_SWEEP') for S in MS: - self.hooks.pre_sweep(step=S, level_number=1) + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=1) self.u += self.Tcf.dot(np.linalg.solve(self.Pc, self.Tfc.dot(self.res))) self.res = self.u0 - self.C.dot(self.u) @@ -326,27 +332,32 @@ def pfasst(self, MS): MS=MS, u=self.u, res=self.res, niter=niter, level=1, stage='POST_COARSE_SWEEP' ) for S in MS: - self.hooks.post_sweep(step=S, level_number=1) + for hook in self.hooks: + hook.post_sweep(step=S, level_number=1) for _ in range(MS[0].levels[0].params.nsweeps): MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='PRE_FINE_SWEEP') for S in MS: - self.hooks.pre_sweep(step=S, level_number=0) + for hook in self.hooks: + hook.pre_sweep(step=S, level_number=0) self.u += np.linalg.solve(self.P, self.res) self.res = self.u0 - self.C.dot(self.u) MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='POST_FINE_SWEEP') for S in MS: - self.hooks.post_sweep(step=S, level_number=0) + for hook in self.hooks: + hook.post_sweep(step=S, level_number=0) MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='POST_ITERATION') for S in MS: - self.hooks.post_iteration(step=S, level_number=0) + for hook in self.hooks: + hook.post_iteration(step=S, level_number=0) MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='POST_STEP') for S in MS: - self.hooks.post_step(step=S, level_number=0) + for hook in self.hooks: + hook.post_step(step=S, level_number=0) return MS diff --git a/pyproject.toml b/pyproject.toml index ede34d02c1..f8bbeb3e87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ pyflakes = [ '-E203', '-E741', '-E402', '-W504', '-W605', '-F401' ] #flake8-black = ["+*"] -flake8-bugbear = ["+*", '-B023'] +flake8-bugbear = ["+*", '-B023', '-B028'] flake8-comprehensions = ["+*", '-C408', '-C417'] [tool.black]