diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml
index 6515e8d272..b45cf0b402 100644
--- a/.github/workflows/ci_pipeline.yml
+++ b/.github/workflows/ci_pipeline.yml
@@ -35,23 +35,23 @@ jobs:
run: |
flakeheaven lint --benchmark pySDC
- mirror_to_gitlab:
+# mirror_to_gitlab:
- runs-on: ubuntu-latest
+# runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@v1
+# steps:
+# - name: Checkout
+# uses: actions/checkout@v1
- - name: Mirror
- uses: jakob-fritz/github2lab_action@main
- env:
- MODE: 'mirror' # Either 'mirror', 'get_status', or 'both'
- GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }}
- FORCE_PUSH: "true"
- GITLAB_HOSTNAME: "codebase.helmholtz.cloud"
- GITLAB_PROJECT_ID: "3525"
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+# - name: Mirror
+# uses: jakob-fritz/github2lab_action@main
+# env:
+# MODE: 'mirror' # Either 'mirror', 'get_status', or 'both'
+# GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }}
+# FORCE_PUSH: "true"
+# GITLAB_HOSTNAME: "codebase.helmholtz.cloud"
+# GITLAB_PROJECT_ID: "3525"
+# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
user_cpu_tests_linux:
runs-on: ubuntu-latest
@@ -121,31 +121,31 @@ jobs:
pytest --continue-on-collection-errors -v --durations=0 pySDC/tests -m ${{ matrix.env }}
- wait_for_gitlab:
- runs-on: ubuntu-latest
+# wait_for_gitlab:
+# runs-on: ubuntu-latest
- needs:
- - mirror_to_gitlab
+# needs:
+# - mirror_to_gitlab
- steps:
- - name: Wait
- uses: jakob-fritz/github2lab_action@main
- env:
- MODE: 'get_status' # Either 'mirror', 'get_status', or 'both'
- GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }}
- FORCE_PUSH: "true"
- GITLAB_HOSTNAME: "codebase.helmholtz.cloud"
- GITLAB_PROJECT_ID: "3525"
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+# steps:
+# - name: Wait
+# uses: jakob-fritz/github2lab_action@main
+# env:
+# MODE: 'get_status' # Either 'mirror', 'get_status', or 'both'
+# GITLAB_TOKEN: ${{ secrets.GITLAB_SECRET_H }}
+# FORCE_PUSH: "true"
+# GITLAB_HOSTNAME: "codebase.helmholtz.cloud"
+# GITLAB_PROJECT_ID: "3525"
+# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-# - name: Get and prepare artifacts
-# run: |
-# pipeline_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/repository/commits/${{ github.head_ref || github.ref_name }}" | jq '.last_pipeline.id')
-# job_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/pipelines/$pipeline_id/jobs" | jq '.[] | select( .name == "bundle" ) | select( .status == "success" ) | .id')
-# curl --output artifacts.zip "https://gitlab.hzdr.de/api/v4/projects/3525/jobs/$job_id/artifacts"
-# rm -rf data
-# unzip artifacts.zip
-# ls -ratl
+# # - name: Get and prepare artifacts
+# # run: |
+# # pipeline_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/repository/commits/${{ github.head_ref || github.ref_name }}" | jq '.last_pipeline.id')
+# # job_id=$(curl --header "PRIVATE-TOKEN: ${{ secrets.GITLAB_SECRET_H }}" --silent "https://gitlab.hzdr.de/api/v4/projects/3525/pipelines/$pipeline_id/jobs" | jq '.[] | select( .name == "bundle" ) | select( .status == "success" ) | .id')
+# # curl --output artifacts.zip "https://gitlab.hzdr.de/api/v4/projects/3525/jobs/$job_id/artifacts"
+# # rm -rf data
+# # unzip artifacts.zip
+# # ls -ratl
post-processing:
@@ -156,7 +156,7 @@ jobs:
needs:
- lint
- user_cpu_tests_linux
- - wait_for_gitlab
+# - wait_for_gitlab
defaults:
run:
@@ -188,7 +188,10 @@ jobs:
run: |
pip install genbadge[all]
genbadge coverage -i coverage.xml -o htmlcov/coverage-badge.svg
-
+
+ - name: Upload coverage reports to Codecov
+ uses: codecov/codecov-action@v3
+
# - name: Generate benchmark report
# uses: pancetta/github-action-benchmark@v1
# if: ${{ (!contains(github.event.head_commit.message, '[CI-no-benchmarks]')) && (github.event_name == 'push') }}
diff --git a/README.rst b/README.rst
index 41ce82469e..e420177c73 100644
--- a/README.rst
+++ b/README.rst
@@ -1,5 +1,7 @@
|badge-ga|
|badge-ossf|
+|badge-cc|
+|zenodo|
Welcome to pySDC!
=================
@@ -79,7 +81,9 @@ This project also received funding from the `German Federal Ministry of Educatio
The project also received help from the `Helmholtz Platform for Research Software Engineering - Preparatory Study (HiRSE_PS) `_.
-.. |badge-ga| image:: https://github.com/Parallel-in-Time/pySDC/actions/workflows/ci_pipeline.yml/badge.svg
+.. |badge-ga| image:: https://github.com/Parallel-in-Time/pySDC/actions/workflows/ci_pipeline.yml/badge.svg?branch=master
:target: https://github.com/Parallel-in-Time/pySDC/actions/workflows/ci_pipeline.yml
.. |badge-ossf| image:: https://bestpractices.coreinfrastructure.org/projects/6909/badge
- :target: https://bestpractices.coreinfrastructure.org/projects/6909
\ No newline at end of file
+ :target: https://bestpractices.coreinfrastructure.org/projects/6909
+.. |badge-cc| image:: https://codecov.io/gh/Parallel-in-Time/pySDC/branch/master/graph/badge.svg?token=hpP18dmtgS
+ :target: https://codecov.io/gh/Parallel-in-Time/pySDC
diff --git a/pySDC/core/Controller.py b/pySDC/core/Controller.py
index c5257d2316..e4d213e475 100644
--- a/pySDC/core/Controller.py
+++ b/pySDC/core/Controller.py
@@ -3,10 +3,10 @@
import sys
import numpy as np
-from pySDC.core import Hooks as hookclass
from pySDC.core.BaseTransfer import base_transfer
from pySDC.helpers.pysdc_helper import FrozenClass
from pySDC.implementations.convergence_controller_classes.check_convergence import CheckConvergence
+from pySDC.implementations.hooks.default_hook import DefaultHooks
# short helper class to add params as attributes
@@ -41,10 +41,15 @@ def __init__(self, controller_params, description):
"""
# check if we have a hook on this list. If not, use default class.
- controller_params['hook_class'] = controller_params.get('hook_class', hookclass.hooks)
- self.__hooks = controller_params['hook_class']()
+ self.__hooks = []
+ hook_classes = [DefaultHooks]
+ user_hooks = controller_params.get('hook_class', [])
+ hook_classes += user_hooks if type(user_hooks) == list else [user_hooks]
+ [self.add_hook(hook) for hook in hook_classes]
+ controller_params['hook_class'] = controller_params.get('hook_class', hook_classes)
- self.hooks.pre_setup(step=None, level_number=None)
+ for hook in self.hooks:
+ hook.pre_setup(step=None, level_number=None)
self.params = _Pars(controller_params)
@@ -101,6 +106,20 @@ def __setup_custom_logger(level=None, log_to_file=None, fname=None):
else:
pass
+ def add_hook(self, hook):
+ """
+ Add a hook to the controller which will be called in addition to all other hooks whenever something happens.
+ The hook is only added if a hook of the same class is not already present.
+
+ Args:
+ hook (pySDC.Hook): A hook class that is derived from the core hook class
+
+ Returns:
+ None
+ """
+ if hook not in [type(me) for me in self.hooks]:
+ self.__hooks += [hook()]
+
def welcome_message(self):
out = (
"Welcome to the one and only, really very astonishing and 87.3% bug free"
@@ -308,3 +327,15 @@ def get_convergence_controllers_as_table(self, description):
out += f'\n{user_added}|{i:3} | {C.params.control_order:5} | {type(C).__name__}'
return out
+
+ def return_stats(self):
+ """
+ Return the merged stats from all hooks
+
+ Returns:
+ dict: Merged stats from all hooks
+ """
+ stats = {}
+ for hook in self.hooks:
+ stats = {**stats, **hook.return_stats()}
+ return stats
diff --git a/pySDC/core/Hooks.py b/pySDC/core/Hooks.py
index 5268278529..dfff1045af 100644
--- a/pySDC/core/Hooks.py
+++ b/pySDC/core/Hooks.py
@@ -1,5 +1,4 @@
import logging
-import time
from collections import namedtuple
@@ -8,23 +7,13 @@ class hooks(object):
"""
Hook class to contain the functions called during the controller runs (e.g. for calling user-routines)
+ When deriving a custom hook from this class make sure to always call the parent method using e.g.
+    `super().post_step(step, level_number)`. Otherwise bugs may arise when using `filter_recomputed` from the stats
+    helper for post processing.
+
Attributes:
- __t0_setup (float): private variable to get starting time of setup
- __t0_run (float): private variable to get starting time of the run
- __t0_predict (float): private variable to get starting time of the predictor
- __t0_step (float): private variable to get starting time of the step
- __t0_iteration (float): private variable to get starting time of the iteration
- __t0_sweep (float): private variable to get starting time of the sweep
- __t0_comm (list): private variable to get starting time of the communication
- __t1_run (float): private variable to get end time of the run
- __t1_predict (float): private variable to get end time of the predictor
- __t1_step (float): private variable to get end time of the step
- __t1_iteration (float): private variable to get end time of the iteration
- __t1_sweep (float): private variable to get end time of the sweep
- __t1_setup (float): private variable to get end time of setup
- __t1_comm (list): private variable to hold timing of the communication (!)
- __num_restarts (int): number of restarts of the current step
logger: logger instance for output
+ __num_restarts (int): number of restarts of the current step
__stats (dict): dictionary for gathering the statistics of a run
__entry (namedtuple): statistics entry containing all information to identify the value
"""
@@ -33,20 +22,6 @@ def __init__(self):
"""
Initialization routine
"""
- self.__t0_setup = None
- self.__t0_run = None
- self.__t0_predict = None
- self.__t0_step = None
- self.__t0_iteration = None
- self.__t0_sweep = None
- self.__t0_comm = []
- self.__t1_run = None
- self.__t1_predict = None
- self.__t1_step = None
- self.__t1_iteration = None
- self.__t1_sweep = None
- self.__t1_setup = None
- self.__t1_comm = []
self.__num_restarts = 0
self.logger = logging.getLogger('hooks')
@@ -130,7 +105,6 @@ def pre_setup(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t0_setup = time.perf_counter()
def pre_run(self, step, level_number):
"""
@@ -141,7 +115,6 @@ def pre_run(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t0_run = time.perf_counter()
def pre_predict(self, step, level_number):
"""
@@ -151,7 +124,7 @@ def pre_predict(self, step, level_number):
step (pySDC.Step.step): the current step
level_number (int): the current level number
"""
- self.__t0_predict = time.perf_counter()
+ self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
def pre_step(self, step, level_number):
"""
@@ -162,7 +135,6 @@ def pre_step(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t0_step = time.perf_counter()
def pre_iteration(self, step, level_number):
"""
@@ -173,7 +145,6 @@ def pre_iteration(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t0_iteration = time.perf_counter()
def pre_sweep(self, step, level_number):
"""
@@ -184,7 +155,6 @@ def pre_sweep(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t0_sweep = time.perf_counter()
def pre_comm(self, step, level_number):
"""
@@ -195,16 +165,6 @@ def pre_comm(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- if len(self.__t0_comm) >= level_number + 1:
- self.__t0_comm[level_number] = time.perf_counter()
- else:
- while len(self.__t0_comm) < level_number:
- self.__t0_comm.append(None)
- self.__t0_comm.append(time.perf_counter())
- while len(self.__t1_comm) <= level_number:
- self.__t1_comm.append(0.0)
- assert len(self.__t0_comm) == level_number + 1
- assert len(self.__t1_comm) == level_number + 1
def post_comm(self, step, level_number, add_to_stats=False):
"""
@@ -216,22 +176,6 @@ def post_comm(self, step, level_number, add_to_stats=False):
add_to_stats (bool): set if result should go to stats object
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- assert len(self.__t1_comm) >= level_number + 1
- self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number]
-
- if add_to_stats:
- L = step.levels[level_number]
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='timing_comm',
- value=self.__t1_comm[level_number],
- )
- self.__t1_comm[level_number] = 0.0
def post_sweep(self, step, level_number):
"""
@@ -242,39 +186,6 @@ def post_sweep(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t1_sweep = time.perf_counter()
-
- L = step.levels[level_number]
-
- self.logger.info(
- 'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e',
- step.status.slot,
- L.time,
- step.status.stage,
- L.level_index,
- step.status.iter,
- L.status.sweep,
- L.status.residual,
- )
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='residual_post_sweep',
- value=L.status.residual,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='timing_sweep',
- value=self.__t1_sweep - self.__t0_sweep,
- )
def post_iteration(self, step, level_number):
"""
@@ -286,29 +197,6 @@ def post_iteration(self, step, level_number):
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t1_iteration = time.perf_counter()
-
- L = step.levels[level_number]
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=-1,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='residual_post_iteration',
- value=L.status.residual,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='timing_iteration',
- value=self.__t1_iteration - self.__t0_iteration,
- )
-
def post_step(self, step, level_number):
"""
Default routine called after each step or block
@@ -319,44 +207,6 @@ def post_step(self, step, level_number):
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t1_step = time.perf_counter()
-
- L = step.levels[level_number]
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='timing_step',
- value=self.__t1_step - self.__t0_step,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=-1,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='niter',
- value=step.status.iter,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=-1,
- sweep=L.status.sweep,
- type='residual_post_step',
- value=L.status.residual,
- )
-
- # record the recomputed quantities at weird positions to make sure there is only one value for each step
- for t in [L.time, L.time + L.dt]:
- self.add_to_stats(
- process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart')
- )
-
def post_predict(self, step, level_number):
"""
Default routine called after each predictor
@@ -366,19 +216,6 @@ def post_predict(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t1_predict = time.perf_counter()
-
- L = step.levels[level_number]
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='timing_predictor',
- value=self.__t1_predict - self.__t0_predict,
- )
def post_run(self, step, level_number):
"""
@@ -389,19 +226,6 @@ def post_run(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t1_run = time.perf_counter()
-
- L = step.levels[level_number]
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=step.status.iter,
- sweep=L.status.sweep,
- type='timing_run',
- value=self.__t1_run - self.__t0_run,
- )
def post_setup(self, step, level_number):
"""
@@ -412,14 +236,3 @@ def post_setup(self, step, level_number):
level_number (int): the current level number
"""
self.__num_restarts = step.status.get('restarts_in_a_row') if step is not None else 0
- self.__t1_setup = time.perf_counter()
-
- self.add_to_stats(
- process=-1,
- time=-1,
- level=-1,
- iter=-1,
- sweep=-1,
- type='timing_setup',
- value=self.__t1_setup - self.__t0_setup,
- )
diff --git a/pySDC/implementations/controller_classes/controller_MPI.py b/pySDC/implementations/controller_classes/controller_MPI.py
index 9383e87608..7dc6bcd781 100644
--- a/pySDC/implementations/controller_classes/controller_MPI.py
+++ b/pySDC/implementations/controller_classes/controller_MPI.py
@@ -83,7 +83,8 @@ def run(self, u0, t0, Tend):
"""
# reset stats to prevent double entries from old runs
- self.hooks.reset_stats()
+ for hook in self.hooks:
+ hook.reset_stats()
# find active processes and put into new communicator
rank = self.comm.Get_rank()
@@ -111,10 +112,12 @@ def run(self, u0, t0, Tend):
uend = u0
# call post-setup hook
- self.hooks.post_setup(step=None, level_number=None)
+ for hook in self.hooks:
+ hook.post_setup(step=None, level_number=None)
# call pre-run hook
- self.hooks.pre_run(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_run(step=self.S, level_number=0)
comm_active.Barrier()
@@ -162,11 +165,12 @@ def run(self, u0, t0, Tend):
self.restart_block(num_procs, time, uend, comm=comm_active)
# call post-run hook
- self.hooks.post_run(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_run(step=self.S, level_number=0)
comm_active.Free()
- return uend, self.hooks.return_stats()
+ return uend, self.return_stats()
def restart_block(self, size, time, u0, comm):
"""
@@ -243,7 +247,8 @@ def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
level: the level number
add_to_stats: a flag to end recording data in the hooks (defaults to False)
"""
- self.hooks.pre_comm(step=self.S, level_number=level)
+ for hook in self.hooks:
+ hook.pre_comm(step=self.S, level_number=level)
if not blocking:
self.wait_with_interrupt(request=self.req_send[level])
@@ -272,7 +277,8 @@ def send_full(self, comm=None, blocking=False, level=None, add_to_stats=False):
if self.S.status.force_done:
return None
- self.hooks.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
+ for hook in self.hooks:
+ hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
def recv_full(self, comm, level=None, add_to_stats=False):
"""
@@ -284,7 +290,8 @@ def recv_full(self, comm, level=None, add_to_stats=False):
add_to_stats: a flag to end recording data in the hooks (defaults to False)
"""
- self.hooks.pre_comm(step=self.S, level_number=level)
+ for hook in self.hooks:
+ hook.pre_comm(step=self.S, level_number=level)
if not self.S.status.first and not self.S.status.prev_done:
self.logger.debug(
'recv data: process %s, stage %s, time %s, source %s, tag %s, iter %s'
@@ -299,7 +306,8 @@ def recv_full(self, comm, level=None, add_to_stats=False):
)
self.recv(target=self.S.levels[level], source=self.S.prev, tag=level * 100 + self.S.status.iter, comm=comm)
- self.hooks.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
+ for hook in self.hooks:
+ hook.post_comm(step=self.S, level_number=level, add_to_stats=add_to_stats)
def wait_with_interrupt(self, request):
"""
@@ -334,7 +342,8 @@ def check_iteration_estimate(self, comm):
diff_new = max(diff_new, abs(L.uold[m] - L.u[m]))
# Send forward diff
- self.hooks.pre_comm(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_comm(step=self.S, level_number=0)
self.wait_with_interrupt(request=self.req_diff)
if self.S.status.force_done:
@@ -360,7 +369,8 @@ def check_iteration_estimate(self, comm):
tmp = np.array(diff_new, dtype=float)
self.req_diff = comm.Issend((tmp, MPI.DOUBLE), dest=self.S.next, tag=999)
- self.hooks.post_comm(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_comm(step=self.S, level_number=0)
# Store values from first iteration
if self.S.status.iter == 1:
@@ -382,14 +392,18 @@ def check_iteration_estimate(self, comm):
if np.ceil(Kest_glob) <= self.S.status.iter:
if self.S.status.last:
self.logger.debug(f'{self.S.status.slot} is done, broadcasting..')
- self.hooks.pre_comm(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_comm(step=self.S, level_number=0)
comm.Ibcast((np.array([1]), MPI.INT), root=self.S.status.slot).Wait()
- self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True)
+ for hook in self.hooks:
+ hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
self.logger.debug(f'{self.S.status.slot} is done, broadcasting done')
self.S.status.done = True
else:
- self.hooks.pre_comm(step=self.S, level_number=0)
- self.hooks.post_comm(step=self.S, level_number=0, add_to_stats=True)
+ for hook in self.hooks:
+ hook.pre_comm(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_comm(step=self.S, level_number=0, add_to_stats=True)
def pfasst(self, comm, num_procs):
"""
@@ -420,7 +434,8 @@ def pfasst(self, comm, num_procs):
self.logger.debug(f'Rewinding {self.S.status.slot} after {stage}..')
self.S.levels[0].u[1:] = self.S.levels[0].uold[1:]
- self.hooks.post_iteration(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_iteration(step=self.S, level_number=0)
for req in self.req_send:
if req is not None and req != MPI.REQUEST_NULL:
@@ -431,7 +446,8 @@ def pfasst(self, comm, num_procs):
self.req_diff.Cancel()
self.S.status.stage = 'DONE'
- self.hooks.post_step(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_step(step=self.S, level_number=0)
else:
# Start cycling, if not interrupted
@@ -453,7 +469,8 @@ def spread(self, comm, num_procs):
"""
# first stage: spread values
- self.hooks.pre_step(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_step(step=self.S, level_number=0)
# call predictor from sweeper
self.S.levels[0].sweep.predict()
@@ -476,7 +493,8 @@ def predict(self, comm, num_procs):
Predictor phase
"""
- self.hooks.pre_predict(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_predict(step=self.S, level_number=0)
if self.params.predict_type is None:
pass
@@ -568,7 +586,8 @@ def predict(self, comm, num_procs):
else:
raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)
- self.hooks.post_predict(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_predict(step=self.S, level_number=0)
# update stage
self.S.status.stage = 'IT_CHECK'
@@ -598,7 +617,8 @@ def it_check(self, comm, num_procs):
return None
if self.S.status.iter > 0:
- self.hooks.post_iteration(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_iteration(step=self.S, level_number=0)
# decide if the step is done, needs to be restarted and other things convergence related
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
@@ -611,7 +631,8 @@ def it_check(self, comm, num_procs):
# increment iteration count here (and only here)
self.S.status.iter += 1
- self.hooks.pre_iteration(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_iteration(step=self.S, level_number=0)
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
C.pre_iteration_processing(self, self.S, comm=comm)
@@ -648,7 +669,8 @@ def it_check(self, comm, num_procs):
if self.req_diff is not None:
self.req_diff.Cancel()
- self.hooks.post_step(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_step(step=self.S, level_number=0)
self.S.status.stage = 'DONE'
def it_fine(self, comm, num_procs):
@@ -675,10 +697,12 @@ def it_fine(self, comm, num_procs):
if self.S.status.force_done:
return None
- self.hooks.pre_sweep(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_sweep(step=self.S, level_number=0)
self.S.levels[0].sweep.update_nodes()
self.S.levels[0].sweep.compute_residual()
- self.hooks.post_sweep(step=self.S, level_number=0)
+ for hook in self.hooks:
+ hook.post_sweep(step=self.S, level_number=0)
# update stage
self.S.status.stage = 'IT_CHECK'
@@ -705,10 +729,12 @@ def it_down(self, comm, num_procs):
if self.S.status.force_done:
return None
- self.hooks.pre_sweep(step=self.S, level_number=l)
+ for hook in self.hooks:
+ hook.pre_sweep(step=self.S, level_number=l)
self.S.levels[l].sweep.update_nodes()
self.S.levels[l].sweep.compute_residual()
- self.hooks.post_sweep(step=self.S, level_number=l)
+ for hook in self.hooks:
+ hook.post_sweep(step=self.S, level_number=l)
# transfer further down the hierarchy
self.S.transfer(source=self.S.levels[l], target=self.S.levels[l + 1])
@@ -727,14 +753,16 @@ def it_coarse(self, comm, num_procs):
return None
# do the sweep
- self.hooks.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
+ for hook in self.hooks:
+ hook.pre_sweep(step=self.S, level_number=len(self.S.levels) - 1)
assert self.S.levels[-1].params.nsweeps == 1, (
'ERROR: this controller can only work with one sweep on the coarse level, got %s'
% self.S.levels[-1].params.nsweeps
)
self.S.levels[-1].sweep.update_nodes()
self.S.levels[-1].sweep.compute_residual()
- self.hooks.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
+ for hook in self.hooks:
+ hook.post_sweep(step=self.S, level_number=len(self.S.levels) - 1)
self.S.levels[-1].sweep.compute_end_point()
# send to next step
@@ -774,10 +802,12 @@ def it_up(self, comm, num_procs):
if self.S.status.force_done:
return None
- self.hooks.pre_sweep(step=self.S, level_number=l - 1)
+ for hook in self.hooks:
+ hook.pre_sweep(step=self.S, level_number=l - 1)
self.S.levels[l - 1].sweep.update_nodes()
self.S.levels[l - 1].sweep.compute_residual()
- self.hooks.post_sweep(step=self.S, level_number=l - 1)
+ for hook in self.hooks:
+ hook.post_sweep(step=self.S, level_number=l - 1)
# update stage
self.S.status.stage = 'IT_FINE'
diff --git a/pySDC/implementations/controller_classes/controller_nonMPI.py b/pySDC/implementations/controller_classes/controller_nonMPI.py
index 2144717298..4cbd86260e 100644
--- a/pySDC/implementations/controller_classes/controller_nonMPI.py
+++ b/pySDC/implementations/controller_classes/controller_nonMPI.py
@@ -100,7 +100,8 @@ def run(self, u0, t0, Tend):
# some initializations and reset of statistics
uend = None
num_procs = len(self.MS)
- self.hooks.reset_stats()
+ for hook in self.hooks:
+ hook.reset_stats()
# initial ordering of the steps: 0,1,...,Np-1
slots = list(range(num_procs))
@@ -120,11 +121,13 @@ def run(self, u0, t0, Tend):
# initialize block of steps with u0
self.restart_block(active_slots, time, u0)
- self.hooks.post_setup(step=None, level_number=None)
+ for hook in self.hooks:
+ hook.post_setup(step=None, level_number=None)
# call pre-run hook
for S in self.MS:
- self.hooks.pre_run(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_run(step=S, level_number=0)
# main loop: as long as at least one step is still active (time < Tend), do something
while any(active):
@@ -168,9 +171,10 @@ def run(self, u0, t0, Tend):
# call post-run hook
for S in self.MS:
- self.hooks.post_run(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_run(step=S, level_number=0)
- return uend, self.hooks.return_stats()
+ return uend, self.return_stats()
def restart_block(self, active_slots, time, u0):
"""
@@ -241,13 +245,16 @@ def send(source, tag):
source.sweep.compute_end_point()
source.tag = cp.deepcopy(tag)
- self.hooks.pre_comm(step=S, level_number=level)
+ for hook in self.hooks:
+ hook.pre_comm(step=S, level_number=level)
if not S.status.last:
self.logger.debug(
'Process %2i provides data on level %2i with tag %s' % (S.status.slot, level, S.status.iter)
)
send(S.levels[level], tag=(level, S.status.iter, S.status.slot))
- self.hooks.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)
+
+ for hook in self.hooks:
+ hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)
def recv_full(self, S, level=None, add_to_stats=False):
"""
@@ -276,14 +283,16 @@ def recv(target, source, tag=None):
# re-evaluate f on left interval boundary
target.f[0] = target.prob.eval_f(target.u[0], target.time)
- self.hooks.pre_comm(step=S, level_number=level)
+ for hook in self.hooks:
+ hook.pre_comm(step=S, level_number=level)
if not S.status.prev_done and not S.status.first:
self.logger.debug(
'Process %2i receives from %2i on level %2i with tag %s'
% (S.status.slot, S.prev.status.slot, level, S.status.iter)
)
recv(S.levels[level], S.prev.levels[level], tag=(level, S.status.iter, S.prev.status.slot))
- self.hooks.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)
+ for hook in self.hooks:
+ hook.post_comm(step=S, level_number=level, add_to_stats=add_to_stats)
def pfasst(self, local_MS_active):
"""
@@ -333,7 +342,8 @@ def spread(self, local_MS_running):
for S in local_MS_running:
# first stage: spread values
- self.hooks.pre_step(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_step(step=S, level_number=0)
# call predictor from sweeper
S.levels[0].sweep.predict()
@@ -356,7 +366,8 @@ def predict(self, local_MS_running):
"""
for S in local_MS_running:
- self.hooks.pre_predict(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_predict(step=S, level_number=0)
if self.params.predict_type is None:
pass
@@ -464,7 +475,8 @@ def predict(self, local_MS_running):
raise ControllerError('Wrong predictor type, got %s' % self.params.predict_type)
for S in local_MS_running:
- self.hooks.post_predict(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_predict(step=S, level_number=0)
for S in local_MS_running:
# update stage
@@ -490,7 +502,8 @@ def it_check(self, local_MS_running):
for S in local_MS_running:
if S.status.iter > 0:
- self.hooks.post_iteration(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_iteration(step=S, level_number=0)
# decide if the step is done, needs to be restarted and other things convergence related
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
@@ -499,20 +512,25 @@ def it_check(self, local_MS_running):
for S in local_MS_running:
if not S.status.first:
- self.hooks.pre_comm(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_comm(step=S, level_number=0)
S.status.prev_done = S.prev.status.done # "communicate"
- self.hooks.post_comm(step=S, level_number=0, add_to_stats=True)
+ for hook in self.hooks:
+ hook.post_comm(step=S, level_number=0, add_to_stats=True)
S.status.done = S.status.done and S.status.prev_done
if self.params.all_to_done:
- self.hooks.pre_comm(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_comm(step=S, level_number=0)
S.status.done = all([T.status.done for T in local_MS_running])
- self.hooks.post_comm(step=S, level_number=0, add_to_stats=True)
+ for hook in self.hooks:
+ hook.post_comm(step=S, level_number=0, add_to_stats=True)
if not S.status.done:
# increment iteration count here (and only here)
S.status.iter += 1
- self.hooks.pre_iteration(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_iteration(step=S, level_number=0)
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
C.pre_iteration_processing(self, S)
@@ -525,7 +543,8 @@ def it_check(self, local_MS_running):
S.status.stage = 'IT_COARSE' # serial MSSDC (Gauss-like)
else:
S.levels[0].sweep.compute_end_point()
- self.hooks.post_step(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_step(step=S, level_number=0)
S.status.stage = 'DONE'
for C in [self.convergence_controllers[i] for i in self.convergence_controller_order]:
@@ -555,10 +574,12 @@ def it_fine(self, local_MS_running):
for S in local_MS_running:
# standard sweep workflow: update nodes, compute residual, log progress
- self.hooks.pre_sweep(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_sweep(step=S, level_number=0)
S.levels[0].sweep.update_nodes()
S.levels[0].sweep.compute_residual()
- self.hooks.post_sweep(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_sweep(step=S, level_number=0)
for S in local_MS_running:
# update stage
@@ -589,10 +610,12 @@ def it_down(self, local_MS_running):
self.recv_full(S, level=l)
for S in local_MS_running:
- self.hooks.pre_sweep(step=S, level_number=l)
+ for hook in self.hooks:
+ hook.pre_sweep(step=S, level_number=l)
S.levels[l].sweep.update_nodes()
S.levels[l].sweep.compute_residual()
- self.hooks.post_sweep(step=S, level_number=l)
+ for hook in self.hooks:
+ hook.post_sweep(step=S, level_number=l)
for S in local_MS_running:
# transfer further down the hierarchy
@@ -616,10 +639,12 @@ def it_coarse(self, local_MS_running):
self.recv_full(S, level=len(S.levels) - 1)
# do the sweep
- self.hooks.pre_sweep(step=S, level_number=len(S.levels) - 1)
+ for hook in self.hooks:
+ hook.pre_sweep(step=S, level_number=len(S.levels) - 1)
S.levels[-1].sweep.update_nodes()
S.levels[-1].sweep.compute_residual()
- self.hooks.post_sweep(step=S, level_number=len(S.levels) - 1)
+ for hook in self.hooks:
+ hook.post_sweep(step=S, level_number=len(S.levels) - 1)
# send to succ step
self.send_full(S, level=len(S.levels) - 1, add_to_stats=True)
@@ -657,10 +682,12 @@ def it_up(self, local_MS_running):
self.recv_full(S, level=l - 1, add_to_stats=(k == self.nsweeps[l - 1] - 1))
for S in local_MS_running:
- self.hooks.pre_sweep(step=S, level_number=l - 1)
+ for hook in self.hooks:
+ hook.pre_sweep(step=S, level_number=l - 1)
S.levels[l - 1].sweep.update_nodes()
S.levels[l - 1].sweep.compute_residual()
- self.hooks.post_sweep(step=S, level_number=l - 1)
+ for hook in self.hooks:
+ hook.post_sweep(step=S, level_number=l - 1)
for S in local_MS_running:
# update stage
diff --git a/pySDC/implementations/convergence_controller_classes/adaptivity.py b/pySDC/implementations/convergence_controller_classes/adaptivity.py
index 06fc27f073..c39bf5fae2 100644
--- a/pySDC/implementations/convergence_controller_classes/adaptivity.py
+++ b/pySDC/implementations/convergence_controller_classes/adaptivity.py
@@ -7,6 +7,7 @@
BasicRestartingNonMPI,
)
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
+from pySDC.implementations.hooks.log_step_size import LogStepSize
class AdaptivityBase(ConvergenceController):
@@ -35,6 +36,7 @@ def setup(self, controller, params, description, **kwargs):
"control_order": -50,
"beta": 0.9,
}
+ controller.add_hook(LogStepSize)
return {**defaults, **params}
def dependencies(self, controller, description, **kwargs):
diff --git a/pySDC/implementations/convergence_controller_classes/check_convergence.py b/pySDC/implementations/convergence_controller_classes/check_convergence.py
index f969779d07..e6a1907a9a 100644
--- a/pySDC/implementations/convergence_controller_classes/check_convergence.py
+++ b/pySDC/implementations/convergence_controller_classes/check_convergence.py
@@ -87,13 +87,16 @@ def communicate_convergence(self, controller, S, comm):
if controller.params.all_to_done:
from mpi4py.MPI import LAND
- controller.hooks.pre_comm(step=S, level_number=0)
+ for hook in controller.hooks:
+ hook.pre_comm(step=S, level_number=0)
S.status.done = comm.allreduce(sendobj=S.status.done, op=LAND)
- controller.hooks.post_comm(step=S, level_number=0, add_to_stats=True)
+ for hook in controller.hooks:
+ hook.post_comm(step=S, level_number=0, add_to_stats=True)
else:
- controller.hooks.pre_comm(step=S, level_number=0)
+ for hook in controller.hooks:
+ hook.pre_comm(step=S, level_number=0)
# check if an open request of the status send is pending
controller.wait_with_interrupt(request=controller.req_status)
@@ -109,4 +112,5 @@ def communicate_convergence(self, controller, S, comm):
if not S.status.last:
self.send(comm, dest=S.status.slot + 1, data=S.status.done)
- controller.hooks.post_comm(step=S, level_number=0, add_to_stats=True)
+ for hook in controller.hooks:
+ hook.post_comm(step=S, level_number=0, add_to_stats=True)
diff --git a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
index 1cf5e3f41b..f33f272e29 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_embedded_error.py
@@ -2,6 +2,7 @@
from pySDC.core.ConvergenceController import ConvergenceController, Pars
from pySDC.implementations.convergence_controller_classes.store_uold import StoreUOld
+from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
from pySDC.implementations.sweeper_classes.Runge_Kutta import RungeKutta
@@ -16,7 +17,7 @@ class EstimateEmbeddedError(ConvergenceController):
def __init__(self, controller, params, description, **kwargs):
"""
- Initalization routine. Add the buffers for communication.
+        Initialization routine. Add the buffers for communication.
Args:
controller (pySDC.Controller): The controller
@@ -25,6 +26,7 @@ def __init__(self, controller, params, description, **kwargs):
"""
super(EstimateEmbeddedError, self).__init__(controller, params, description, **kwargs)
self.buffers = Pars({'e_em_last': 0.0})
+ controller.add_hook(LogEmbeddedErrorEstimate)
@classmethod
def get_implementation(cls, flavor):
@@ -42,7 +44,7 @@ def get_implementation(cls, flavor):
elif flavor == 'nonMPI':
return EstimateEmbeddedErrorNonMPI
else:
- raise NotImplementedError(f'Flavor {flavor} of EmstimateEmbeddedError is not implemented!')
+ raise NotImplementedError(f'Flavor {flavor} of EstimateEmbeddedError is not implemented!')
def setup(self, controller, params, description, **kwargs):
"""
@@ -123,7 +125,7 @@ def reset_status_variables(self, controller, **kwargs):
class EstimateEmbeddedErrorNonMPI(EstimateEmbeddedError):
def reset_buffers_nonMPI(self, controller, **kwargs):
"""
- Reset buffers for immitated communication.
+ Reset buffers for imitated communication.
Args:
controller (pySDC.controller): The controller
diff --git a/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py b/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py
index 698df89627..887761d982 100644
--- a/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py
+++ b/pySDC/implementations/convergence_controller_classes/estimate_extrapolation_error.py
@@ -4,6 +4,7 @@
from pySDC.core.ConvergenceController import ConvergenceController, Status
from pySDC.core.Errors import DataError
from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
+from pySDC.implementations.hooks.log_extrapolated_error_estimate import LogExtrapolationErrorEstimate
class EstimateExtrapolationErrorBase(ConvergenceController):
@@ -27,6 +28,7 @@ def __init__(self, controller, params, description, **kwargs):
self.prev = Status(["t", "u", "f", "dt"]) # store solutions etc. of previous steps here
self.coeff = Status(["u", "f", "prefactor"]) # store coefficients for extrapolation here
super(EstimateExtrapolationErrorBase, self).__init__(controller, params, description)
+ controller.add_hook(LogExtrapolationErrorEstimate)
def setup(self, controller, params, description, **kwargs):
"""
@@ -252,7 +254,7 @@ class EstimateExtrapolationErrorNonMPI(EstimateExtrapolationErrorBase):
def setup(self, controller, params, description, **kwargs):
"""
- Add a no parameter 'no_storage' which decides whether the standart or the no-memory-overhead version is run,
+        Add a parameter 'no_storage' which decides whether the standard or the no-memory-overhead version is run,
where only values are used for extrapolation which are in memory of other processes
Args:
diff --git a/pySDC/implementations/hooks/default_hook.py b/pySDC/implementations/hooks/default_hook.py
new file mode 100644
index 0000000000..dcdd236421
--- /dev/null
+++ b/pySDC/implementations/hooks/default_hook.py
@@ -0,0 +1,343 @@
+import time
+from pySDC.core.Hooks import hooks
+
+
+class DefaultHooks(hooks):
+ """
+ Hook class to contain the functions called during the controller runs (e.g. for calling user-routines)
+
+ Attributes:
+ __t0_setup (float): private variable to get starting time of setup
+ __t0_run (float): private variable to get starting time of the run
+ __t0_predict (float): private variable to get starting time of the predictor
+ __t0_step (float): private variable to get starting time of the step
+ __t0_iteration (float): private variable to get starting time of the iteration
+ __t0_sweep (float): private variable to get starting time of the sweep
+ __t0_comm (list): private variable to get starting time of the communication
+ __t1_run (float): private variable to get end time of the run
+ __t1_predict (float): private variable to get end time of the predictor
+ __t1_step (float): private variable to get end time of the step
+ __t1_iteration (float): private variable to get end time of the iteration
+ __t1_sweep (float): private variable to get end time of the sweep
+ __t1_setup (float): private variable to get end time of setup
+        __t1_comm (list): private variable to hold the accumulated communication timing per level
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.__t0_setup = None
+ self.__t0_run = None
+ self.__t0_predict = None
+ self.__t0_step = None
+ self.__t0_iteration = None
+ self.__t0_sweep = None
+ self.__t0_comm = []
+ self.__t1_run = None
+ self.__t1_predict = None
+ self.__t1_step = None
+ self.__t1_iteration = None
+ self.__t1_sweep = None
+ self.__t1_setup = None
+ self.__t1_comm = []
+
+ def pre_setup(self, step, level_number):
+ """
+ Default routine called before setup starts
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_setup(step, level_number)
+ self.__t0_setup = time.perf_counter()
+
+ def pre_run(self, step, level_number):
+ """
+ Default routine called before time-loop starts
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_run(step, level_number)
+ self.__t0_run = time.perf_counter()
+
+ def pre_predict(self, step, level_number):
+ """
+ Default routine called before predictor starts
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_predict(step, level_number)
+ self.__t0_predict = time.perf_counter()
+
+ def pre_step(self, step, level_number):
+ """
+        Default routine called before each step
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_step(step, level_number)
+ self.__t0_step = time.perf_counter()
+
+ def pre_iteration(self, step, level_number):
+ """
+ Default routine called before iteration starts
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_iteration(step, level_number)
+ self.__t0_iteration = time.perf_counter()
+
+ def pre_sweep(self, step, level_number):
+ """
+ Default routine called before sweep starts
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_sweep(step, level_number)
+ self.__t0_sweep = time.perf_counter()
+
+ def pre_comm(self, step, level_number):
+ """
+ Default routine called before communication starts
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().pre_comm(step, level_number)
+ if len(self.__t0_comm) >= level_number + 1:
+ self.__t0_comm[level_number] = time.perf_counter()
+ else:
+ while len(self.__t0_comm) < level_number:
+ self.__t0_comm.append(None)
+ self.__t0_comm.append(time.perf_counter())
+ while len(self.__t1_comm) <= level_number:
+ self.__t1_comm.append(0.0)
+ assert len(self.__t0_comm) == level_number + 1
+ assert len(self.__t1_comm) == level_number + 1
+
+ def post_comm(self, step, level_number, add_to_stats=False):
+ """
+ Default routine called after each communication
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ add_to_stats (bool): set if result should go to stats object
+ """
+ super().post_comm(step, level_number)
+ assert len(self.__t1_comm) >= level_number + 1
+ self.__t1_comm[level_number] += time.perf_counter() - self.__t0_comm[level_number]
+
+ if add_to_stats:
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='timing_comm',
+ value=self.__t1_comm[level_number],
+ )
+ self.__t1_comm[level_number] = 0.0
+
+ def post_sweep(self, step, level_number):
+ """
+ Default routine called after each sweep
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().post_sweep(step, level_number)
+ self.__t1_sweep = time.perf_counter()
+
+ L = step.levels[level_number]
+
+ self.logger.info(
+ 'Process %2i on time %8.6f at stage %15s: Level: %s -- Iteration: %2i -- Sweep: %2i -- ' 'residual: %12.8e',
+ step.status.slot,
+ L.time,
+ step.status.stage,
+ L.level_index,
+ step.status.iter,
+ L.status.sweep,
+ L.status.residual,
+ )
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='residual_post_sweep',
+ value=L.status.residual,
+ )
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='timing_sweep',
+ value=self.__t1_sweep - self.__t0_sweep,
+ )
+
+ def post_iteration(self, step, level_number):
+ """
+ Default routine called after each iteration
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().post_iteration(step, level_number)
+ self.__t1_iteration = time.perf_counter()
+
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=-1,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='residual_post_iteration',
+ value=L.status.residual,
+ )
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='timing_iteration',
+ value=self.__t1_iteration - self.__t0_iteration,
+ )
+
+ def post_step(self, step, level_number):
+ """
+ Default routine called after each step or block
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().post_step(step, level_number)
+ self.__t1_step = time.perf_counter()
+
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='timing_step',
+ value=self.__t1_step - self.__t0_step,
+ )
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=-1,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='niter',
+ value=step.status.iter,
+ )
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=-1,
+ sweep=L.status.sweep,
+ type='residual_post_step',
+ value=L.status.residual,
+ )
+
+ # record the recomputed quantities at weird positions to make sure there is only one value for each step
+ for t in [L.time, L.time + L.dt]:
+ self.add_to_stats(
+ process=-1, time=t, level=-1, iter=-1, sweep=-1, type='_recomputed', value=step.status.get('restart')
+ )
+
+ def post_predict(self, step, level_number):
+ """
+ Default routine called after each predictor
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().post_predict(step, level_number)
+ self.__t1_predict = time.perf_counter()
+
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='timing_predictor',
+ value=self.__t1_predict - self.__t0_predict,
+ )
+
+ def post_run(self, step, level_number):
+ """
+ Default routine called after each run
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().post_run(step, level_number)
+ self.__t1_run = time.perf_counter()
+
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='timing_run',
+ value=self.__t1_run - self.__t0_run,
+ )
+
+ def post_setup(self, step, level_number):
+ """
+ Default routine called after setup
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+ """
+ super().post_setup(step, level_number)
+ self.__t1_setup = time.perf_counter()
+
+ self.add_to_stats(
+ process=-1,
+ time=-1,
+ level=-1,
+ iter=-1,
+ sweep=-1,
+ type='timing_setup',
+ value=self.__t1_setup - self.__t0_setup,
+ )
diff --git a/pySDC/implementations/hooks/log_embedded_error_estimate.py b/pySDC/implementations/hooks/log_embedded_error_estimate.py
new file mode 100644
index 0000000000..60fa0c6ea7
--- /dev/null
+++ b/pySDC/implementations/hooks/log_embedded_error_estimate.py
@@ -0,0 +1,32 @@
+from pySDC.core.Hooks import hooks
+
+
+class LogEmbeddedErrorEstimate(hooks):
+ """
+ Store the embedded error estimate at the end of each step as "error_embedded_estimate".
+ """
+
+ def post_step(self, step, level_number):
+ """
+ Record embedded error estimate
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+
+ Returns:
+ None
+ """
+ super().post_step(step, level_number)
+
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time + L.dt,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='error_embedded_estimate',
+ value=L.status.get('error_embedded_estimate'),
+ )
diff --git a/pySDC/implementations/hooks/log_extrapolated_error_estimate.py b/pySDC/implementations/hooks/log_extrapolated_error_estimate.py
new file mode 100644
index 0000000000..1530db9e18
--- /dev/null
+++ b/pySDC/implementations/hooks/log_extrapolated_error_estimate.py
@@ -0,0 +1,33 @@
+from pySDC.core.Hooks import hooks
+
+
+class LogExtrapolationErrorEstimate(hooks):
+ """
+ Store the extrapolated error estimate at the end of each step as "error_extrapolation_estimate".
+ """
+
+ def post_step(self, step, level_number):
+ """
+ Record extrapolated error estimate
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+
+ Returns:
+ None
+ """
+ super().post_step(step, level_number)
+
+ # some abbreviations
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time + L.dt,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='error_extrapolation_estimate',
+ value=L.status.get('error_extrapolation_estimate'),
+ )
diff --git a/pySDC/implementations/hooks/log_solution.py b/pySDC/implementations/hooks/log_solution.py
new file mode 100644
index 0000000000..9d2d72d2e6
--- /dev/null
+++ b/pySDC/implementations/hooks/log_solution.py
@@ -0,0 +1,33 @@
+from pySDC.core.Hooks import hooks
+
+
+class LogSolution(hooks):
+ """
+ Store the solution at the end of each step as "u".
+ """
+
+ def post_step(self, step, level_number):
+ """
+ Record solution at the end of the step
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+
+ Returns:
+ None
+ """
+ super().post_step(step, level_number)
+
+ L = step.levels[level_number]
+ L.sweep.compute_end_point()
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time + L.dt,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='u',
+ value=L.uend,
+ )
diff --git a/pySDC/implementations/hooks/log_step_size.py b/pySDC/implementations/hooks/log_step_size.py
new file mode 100644
index 0000000000..62dada9ab0
--- /dev/null
+++ b/pySDC/implementations/hooks/log_step_size.py
@@ -0,0 +1,32 @@
+from pySDC.core.Hooks import hooks
+
+
+class LogStepSize(hooks):
+ """
+    Store the step size of each step as "dt", recorded at the step's start time.
+ """
+
+ def post_step(self, step, level_number):
+ """
+ Record step size
+
+ Args:
+ step (pySDC.Step.step): the current step
+ level_number (int): the current level number
+
+ Returns:
+ None
+ """
+ super().post_step(step, level_number)
+
+ L = step.levels[level_number]
+
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time,
+ level=L.level_index,
+ iter=step.status.iter,
+ sweep=L.status.sweep,
+ type='dt',
+ value=L.dt,
+ )
diff --git a/pySDC/implementations/problem_classes/Battery.py b/pySDC/implementations/problem_classes/Battery.py
index 673c22233e..56b3713367 100644
--- a/pySDC/implementations/problem_classes/Battery.py
+++ b/pySDC/implementations/problem_classes/Battery.py
@@ -8,28 +8,32 @@
class battery(ptype):
"""
Example implementing the battery drain model as in the description in the PinTSimE project
+
Attributes:
A: system matrix, representing the 2 ODEs
+ t_switch: time point of the switch
+ nswitches: number of switches
"""
def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
"""
Initialization routine
+
Args:
problem_params (dict): custom parameters for the example
dtype_u: mesh data type for solution
dtype_f: mesh data type for RHS
"""
- problem_params['nvars'] = 2
-
# these parameters will be used later, so assert their existence
- essential_keys = ['Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch']
+ essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
for key in essential_keys:
if key not in problem_params:
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
raise ParameterError(msg)
+ problem_params['nvars'] = problem_params['ncondensators'] + 1
+
# invoke super init, passing number of dofs, dtype_u and dtype_f
super(battery, self).__init__(
init=(problem_params['nvars'], None, np.dtype('float64')),
@@ -39,13 +43,17 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
)
self.A = np.zeros((2, 2))
+ self.t_switch = None
+ self.nswitches = 0
def eval_f(self, u, t):
"""
Routine to evaluate the RHS
+
Args:
u (dtype_u): current values
t (float): current time
+
Returns:
dtype_f: the RHS
"""
@@ -53,17 +61,10 @@ def eval_f(self, u, t):
f = self.dtype_f(self.init, val=0.0)
f.impl[:] = self.A.dot(u)
- if u[1] <= self.params.V_ref or self.params.set_switch:
- # switching need to happen on exact time point
- if self.params.set_switch:
- if t >= self.params.t_switch:
- f.expl[0] = self.params.Vs / self.params.L
-
- else:
- f.expl[0] = 0
+ t_switch = np.inf if self.t_switch is None else self.t_switch
- else:
- f.expl[0] = self.params.Vs / self.params.L
+ if u[1] <= self.params.V_ref or t >= t_switch:
+ f.expl[0] = self.params.Vs / self.params.L
else:
f.expl[0] = 0
@@ -73,27 +74,22 @@ def eval_f(self, u, t):
def solve_system(self, rhs, factor, u0, t):
"""
Simple linear solver for (I-factor*A)u = rhs
+
Args:
rhs (dtype_f): right-hand side for the linear system
factor (float): abbrev. for the local stepsize (or any other factor required)
u0 (dtype_u): initial guess for the iterative solver
t (float): current time (e.g. for time-dependent BCs)
+
Returns:
dtype_u: solution as mesh
"""
self.A = np.zeros((2, 2))
- if rhs[1] <= self.params.V_ref or self.params.set_switch:
- # switching need to happen on exact time point
- if self.params.set_switch:
- if t >= self.params.t_switch:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+ t_switch = np.inf if self.t_switch is None else self.t_switch
- else:
- self.A[1, 1] = -1 / (self.params.C * self.params.R)
-
- else:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+ if rhs[1] <= self.params.V_ref or t >= t_switch:
+ self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
else:
self.A[1, 1] = -1 / (self.params.C * self.params.R)
@@ -105,8 +101,10 @@ def solve_system(self, rhs, factor, u0, t):
def u_exact(self, t):
"""
Routine to compute the exact solution at time t
+
Args:
t (float): current time
+
Returns:
dtype_u: exact solution
"""
@@ -119,53 +117,60 @@ def u_exact(self, t):
return me
+ def get_switching_info(self, u, t):
+ """
+ Provides information about a discrete event for one subinterval.
-class battery_implicit(ptype):
- """
- Example implementing the battery drain model as in the description in the PinTSimE project
- Attributes:
- A: system matrix, representing the 2 ODEs
- """
+ Args:
+ u (dtype_u): current values
+ t (float): current time
- def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
+ Returns:
+ switch_detected (bool): Indicates if a switch is found or not
+            m_guess (int): Index of the collocation node in the subinterval where the discrete event was found
+ vC_switch (list): Contains function values of switching condition (for interpolation)
"""
- Initialization routine
- Args:
- problem_params (dict): custom parameters for the example
- dtype_u: mesh data type for solution
- dtype_f: mesh data type for RHS
+
+ switch_detected = False
+ m_guess = -100
+
+ for m in range(len(u)):
+ if u[m][1] - self.params.V_ref <= 0:
+ switch_detected = True
+ m_guess = m - 1
+ break
+
+ vC_switch = [u[m][1] - self.params.V_ref for m in range(1, len(u))] if switch_detected else []
+
+ return switch_detected, m_guess, vC_switch
+
+ def count_switches(self):
+ """
+ Counts the number of switches. This function is called when a switch is found inside the range of tolerance
+ (in switch_estimator.py)
"""
- problem_params['nvars'] = 2
+ self.nswitches += 1
- # these parameters will be used later, so assert their existence
- essential_keys = [
- 'newton_maxiter',
- 'newton_tol',
- 'Vs',
- 'Rs',
- 'C',
- 'R',
- 'L',
- 'alpha',
- 'V_ref',
- 'set_switch',
- 't_switch',
- ]
+
+class battery_implicit(battery):
+ def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
+
+ essential_keys = ['newton_maxiter', 'newton_tol', 'ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
for key in essential_keys:
if key not in problem_params:
msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
raise ParameterError(msg)
+ problem_params['nvars'] = problem_params['ncondensators'] + 1
+
# invoke super init, passing number of dofs, dtype_u and dtype_f
super(battery_implicit, self).__init__(
- init=(problem_params['nvars'], None, np.dtype('float64')),
+ problem_params,
dtype_u=dtype_u,
dtype_f=dtype_f,
- params=problem_params,
)
- self.A = np.zeros((2, 2))
self.newton_itercount = 0
self.lin_itercount = 0
self.newton_ncalls = 0
@@ -174,9 +179,11 @@ def __init__(self, problem_params, dtype_u=mesh, dtype_f=mesh):
def eval_f(self, u, t):
"""
Routine to evaluate the RHS
+
Args:
u (dtype_u): current values
t (float): current time
+
Returns:
dtype_f: the RHS
"""
@@ -184,37 +191,29 @@ def eval_f(self, u, t):
f = self.dtype_f(self.init, val=0.0)
non_f = np.zeros(2)
- if u[1] <= self.params.V_ref or self.params.set_switch:
- # switching need to happen on exact time point
- if self.params.set_switch:
- if t >= self.params.t_switch:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
- non_f[0] = self.params.Vs / self.params.L
+ t_switch = np.inf if self.t_switch is None else self.t_switch
- else:
- self.A[1, 1] = -1 / (self.params.C * self.params.R)
- non_f[0] = 0
-
- else:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
- non_f[0] = self.params.Vs / self.params.L
+ if u[1] <= self.params.V_ref or t >= t_switch:
+ self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+ non_f[0] = self.params.Vs
else:
self.A[1, 1] = -1 / (self.params.C * self.params.R)
non_f[0] = 0
f[:] = self.A.dot(u) + non_f
-
return f
def solve_system(self, rhs, factor, u0, t):
"""
Simple Newton solver
+
Args:
rhs (dtype_f): right-hand side for the linear system
factor (float): abbrev. for the local stepsize (or any other factor required)
u0 (dtype_u): initial guess for the iterative solver
t (float): current time (e.g. for time-dependent BCs)
+
Returns:
dtype_u: solution as mesh
"""
@@ -223,20 +222,11 @@ def solve_system(self, rhs, factor, u0, t):
non_f = np.zeros(2)
self.A = np.zeros((2, 2))
- if rhs[1] <= self.params.V_ref or self.params.set_switch:
- # switching need to happen on exact time point
- if self.params.set_switch:
- if t >= self.params.t_switch:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
- non_f[0] = self.params.Vs / self.params.L
+ t_switch = np.inf if self.t_switch is None else self.t_switch
- else:
- self.A[1, 1] = -1 / (self.params.C * self.params.R)
- non_f[0] = 0
-
- else:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
- non_f[0] = self.params.Vs / self.params.L
+ if rhs[1] <= self.params.V_ref or t >= t_switch:
+ self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
+ non_f[0] = self.params.Vs
else:
self.A[1, 1] = -1 / (self.params.C * self.params.R)
@@ -280,11 +270,131 @@ def solve_system(self, rhs, factor, u0, t):
return me
+
+class battery_n_condensators(ptype):
+ """
+ Example implementing the battery drain model with N capacitors, where N is an arbitrary integer greater than 0.
+ Attributes:
+ nswitches: number of switches
+ """
+
+ def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
+ """
+ Initialization routine
+
+ Args:
+ problem_params (dict): custom parameters for the example
+ dtype_u: mesh data type for solution
+ dtype_f: mesh data type for RHS
+ """
+
+ # these parameters will be used later, so assert their existence
+ essential_keys = ['ncondensators', 'Vs', 'Rs', 'C', 'R', 'L', 'alpha', 'V_ref']
+ for key in essential_keys:
+ if key not in problem_params:
+ msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
+ raise ParameterError(msg)
+
+ n = problem_params['ncondensators']
+ problem_params['nvars'] = n + 1
+
+ # invoke super init, passing number of dofs, dtype_u and dtype_f
+ super(battery_n_condensators, self).__init__(
+ init=(problem_params['nvars'], None, np.dtype('float64')),
+ dtype_u=dtype_u,
+ dtype_f=dtype_f,
+ params=problem_params,
+ )
+
+ v = np.zeros(n + 1)
+ v[0] = 1
+
+ self.A = np.zeros((n + 1, n + 1))
+ self.switch_A = {k: np.diag(-1 / (self.params.C[k] * self.params.R) * np.roll(v, k + 1)) for k in range(n)}
+ self.switch_A.update({n: np.diag(-(self.params.Rs + self.params.R) / self.params.L * v)})
+ self.switch_f = {k: np.zeros(n + 1) for k in range(n)}
+ self.switch_f.update({n: self.params.Vs / self.params.L * v})
+ self.t_switch = None
+ self.nswitches = 0
+
+ def eval_f(self, u, t):
+ """
+ Routine to evaluate the RHS. When no Switch Estimator is used: for N = 3 there are N + 1 = 4 different states of the battery:
+ 1. u[1] > V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2] -> C1 supplies energy
+ 2. u[1] <= V_ref[0] and u[2] > V_ref[1] and u[3] > V_ref[2] -> C2 supplies energy
+ 3. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] > V_ref[2] -> C3 supplies energy
+ 4. u[1] <= V_ref[0] and u[2] <= V_ref[1] and u[3] <= V_ref[2] -> Vs supplies energy
+ max_index is initialized to -1. List "switch" contains a True if u[k] <= V_ref[k-1] is satisfied.
+ - If no True is there (i.e. max_index = -1), we are in the first case.
+ - max_index = k >= 0 means we are in the (k+2)-th case.
+ So, the actual RHS has key max_index+1 in the dictionary self.switch_f.
+ When the Switch Estimator is used, we instead count the number of switches, which indicates which voltage source currently supplies energy.
+
+ Args:
+ u (dtype_u): current values
+ t (float): current time
+
+ Returns:
+ dtype_f: the RHS
+ """
+
+ f = self.dtype_f(self.init, val=0.0)
+ f.impl[:] = self.A.dot(u)
+
+ if self.t_switch is not None:
+ f.expl[:] = self.switch_f[self.nswitches]
+
+ else:
+ # check all switching conditions and find the largest index where the voltage drops below V_ref
+ switch = [True if u[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(u))]
+ max_index = max([k if switch[k] == True else -1 for k in range(len(switch))])
+
+ if max_index == -1:
+ f.expl[:] = self.switch_f[0]
+
+ else:
+ f.expl[:] = self.switch_f[max_index + 1]
+
+ return f
+
+ def solve_system(self, rhs, factor, u0, t):
+ """
+ Simple linear solver for (I-factor*A)u = rhs
+
+ Args:
+ rhs (dtype_f): right-hand side for the linear system
+ factor (float): abbrev. for the local stepsize (or any other factor required)
+ u0 (dtype_u): initial guess for the iterative solver
+ t (float): current time (e.g. for time-dependent BCs)
+
+ Returns:
+ dtype_u: solution as mesh
+ """
+
+ if self.t_switch is not None:
+ self.A = self.switch_A[self.nswitches]
+
+ else:
+ # check all switching conditions and find the largest index where the voltage drops below V_ref
+ switch = [True if rhs[k] <= self.params.V_ref[k - 1] else False for k in range(1, len(rhs))]
+ max_index = max([k if switch[k] == True else -1 for k in range(len(switch))])
+ if max_index == -1:
+ self.A = self.switch_A[0]
+
+ else:
+ self.A = self.switch_A[max_index + 1]
+
+ me = self.dtype_u(self.init)
+ me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
+ return me
+
def u_exact(self, t):
"""
Routine to compute the exact solution at time t
+
Args:
t (float): current time
+
Returns:
dtype_u: exact solution
"""
@@ -293,6 +403,49 @@ def u_exact(self, t):
me = self.dtype_u(self.init)
me[0] = 0.0 # cL
- me[1] = self.params.alpha * self.params.V_ref # vC
-
+ me[1:] = self.params.alpha * self.params.V_ref # vC's
return me
+
+ def get_switching_info(self, u, t):
+ """
+ Provides information about a discrete event for one subinterval.
+
+ Args:
+ u (dtype_u): current values
+ t (float): current time
+
+ Returns:
+ switch_detected (bool): Indicates if a switch is found or not
+ m_guess (np.int): Index of collocation node inside one subinterval of where the discrete event was found
+ vC_switch (list): Contains function values of switching condition (for interpolation)
+ """
+
+ switch_detected = False
+ m_guess = -100
+ break_flag = False
+
+ for m in range(len(u)):
+ for k in range(1, self.params.nvars):
+ if u[m][k] - self.params.V_ref[k - 1] <= 0:
+ switch_detected = True
+ m_guess = m - 1
+ k_detected = k
+ break_flag = True
+ break
+
+ if break_flag:
+ break
+
+ vC_switch = (
+ [u[m][k_detected] - self.params.V_ref[k_detected - 1] for m in range(1, len(u))] if switch_detected else []
+ )
+
+ return switch_detected, m_guess, vC_switch
+
+ def count_switches(self):
+ """
+ Counts the number of switches. This function is called when a switch is found inside the range of tolerance
+ (in switch_estimator.py)
+ """
+
+ self.nswitches += 1
diff --git a/pySDC/implementations/problem_classes/Battery_2Condensators.py b/pySDC/implementations/problem_classes/Battery_2Condensators.py
deleted file mode 100644
index 1c5eb55030..0000000000
--- a/pySDC/implementations/problem_classes/Battery_2Condensators.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import numpy as np
-
-from pySDC.core.Errors import ParameterError
-from pySDC.core.Problem import ptype
-from pySDC.implementations.datatype_classes.mesh import mesh, imex_mesh
-
-
-class battery_2condensators(ptype):
- """
- Example implementing the battery drain model using two capacitors as in the description in the PinTSimE
- project
- Attributes:
- A: system matrix, representing the 3 ODEs
- """
-
- def __init__(self, problem_params, dtype_u=mesh, dtype_f=imex_mesh):
- """
- Initialization routine
- Args:
- problem_params (dict): custom parameters for the example
- dtype_u: mesh data type for solution
- dtype_f: mesh data type for RHS
- """
-
- problem_params['nvars'] = 3
-
- # these parameters will be used later, so assert their existence
- essential_keys = ['Vs', 'Rs', 'C1', 'C2', 'R', 'L', 'alpha', 'V_ref', 'set_switch', 't_switch']
-
- for key in essential_keys:
- if key not in problem_params:
- msg = 'need %s to instantiate problem, only got %s' % (key, str(problem_params.keys()))
- raise ParameterError(msg)
-
- # invoke super init, passing number of dofs, dtype_u and dtype_f
- super(battery_2condensators, self).__init__(
- init=(problem_params['nvars'], None, np.dtype('float64')),
- dtype_u=dtype_u,
- dtype_f=dtype_f,
- params=problem_params,
- )
-
- self.A = np.zeros((3, 3))
-
- def eval_f(self, u, t):
- """
- Routine to evaluate the RHS
- Args:
- u (dtype_u): current values
- t (float): current time
- Returns:
- dtype_f: the RHS
- """
-
- f = self.dtype_f(self.init, val=0.0)
- f.impl[:] = self.A.dot(u)
-
- # switch to C2
- if (
- u[1] <= self.params.V_ref[0]
- and u[2] > self.params.V_ref[1]
- or self.params.set_switch[0]
- and not self.params.set_switch[1]
- ):
- if self.params.set_switch[0]:
- if t >= self.params.t_switch[0]:
- f.expl[0] = 0
-
- else:
- f.expl[0] = 0
-
- else:
- f.expl[0] = 0
-
- # switch to Vs
- elif u[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]):
- # switch to Vs
- if self.params.set_switch[1]:
- if t >= self.params.t_switch[1]:
- f.expl[0] = self.params.Vs / self.params.L
-
- else:
- f.expl[0] = 0
-
- else:
- f.expl[0] = self.params.Vs / self.params.L
-
- elif (
- u[1] > self.params.V_ref[0]
- and u[2] > self.params.V_ref[1]
- or not self.params.set_switch[0]
- and not self.params.set_switch[1]
- ):
- # C1 supplies energy
- f.expl[0] = 0
-
- return f
-
- def solve_system(self, rhs, factor, u0, t):
- """
- Simple linear solver for (I-factor*A)u = rhs
- Args:
- rhs (dtype_f): right-hand side for the linear system
- factor (float): abbrev. for the local stepsize (or any other factor required)
- u0 (dtype_u): initial guess for the iterative solver
- t (float): current time (e.g. for time-dependent BCs)
- Returns:
- dtype_u: solution as mesh
- """
- self.A = np.zeros((3, 3))
-
- # switch to C2
- if (
- rhs[1] <= self.params.V_ref[0]
- and rhs[2] > self.params.V_ref[1]
- or self.params.set_switch[0]
- and not self.params.set_switch[1]
- ):
- if self.params.set_switch[0]:
- if t >= self.params.t_switch[0]:
- self.A[2, 2] = -1 / (self.params.C2 * self.params.R)
-
- else:
- self.A[1, 1] = -1 / (self.params.C1 * self.params.R)
- else:
- self.A[2, 2] = -1 / (self.params.C2 * self.params.R)
-
- # switch to Vs
- elif rhs[2] <= self.params.V_ref[1] or (self.params.set_switch[0] and self.params.set_switch[1]):
- if self.params.set_switch[1]:
- if t >= self.params.t_switch[1]:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-
- else:
- self.A[2, 2] = -1 / (self.params.C2 * self.params.R)
-
- else:
- self.A[0, 0] = -(self.params.Rs + self.params.R) / self.params.L
-
- elif (
- rhs[1] > self.params.V_ref[0]
- and rhs[2] > self.params.V_ref[1]
- or not self.params.set_switch[0]
- and not self.params.set_switch[1]
- ):
- # C1 supplies energy
- self.A[1, 1] = -1 / (self.params.C1 * self.params.R)
-
- me = self.dtype_u(self.init)
- me[:] = np.linalg.solve(np.eye(self.params.nvars) - factor * self.A, rhs)
- return me
-
- def u_exact(self, t):
- """
- Routine to compute the exact solution at time t
- Args:
- t (float): current time
- Returns:
- dtype_u: exact solution
- """
- assert t == 0, 'ERROR: u_exact only valid for t=0'
-
- me = self.dtype_u(self.init)
-
- me[0] = 0.0 # cL
- me[1] = self.params.alpha * self.params.V_ref[0] # vC1
- me[2] = self.params.alpha * self.params.V_ref[1] # vC2
- return me
diff --git a/pySDC/projects/PinTSimE/battery_2condensators_model.py b/pySDC/projects/PinTSimE/battery_2condensators_model.py
index c41cb82712..8739a96d42 100644
--- a/pySDC/projects/PinTSimE/battery_2condensators_model.py
+++ b/pySDC/projects/PinTSimE/battery_2condensators_model.py
@@ -4,10 +4,10 @@
from pySDC.helpers.stats_helper import get_sorted
from pySDC.core.Collocation import CollBase as Collocation
-from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators
+from pySDC.implementations.problem_classes.Battery import battery_n_condensators
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
-from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
+from pySDC.projects.PinTSimE.battery_model import get_recomputed
from pySDC.projects.PinTSimE.piline_model import setup_mpl
import pySDC.helpers.plot_helper as plt_helper
from pySDC.core.Hooks import hooks
@@ -52,55 +52,60 @@ def post_step(self, step, level_number):
type='voltage C2',
value=L.uend[2],
)
- self.increment_stats(
+ self.add_to_stats(
process=step.status.slot,
time=L.time,
level=L.level_index,
iter=0,
sweep=L.status.sweep,
type='restart',
- value=1,
- initialize=0,
+ value=int(step.status.get('restart')),
)
def main(use_switch_estimator=True):
"""
A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators
+
+ Args:
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+
+ Returns:
+ description (dict): contains all information for a controller run
"""
# initialize level parameters
level_params = dict()
- level_params['restol'] = 1e-13
+ level_params['restol'] = -1
level_params['dt'] = 1e-2
+ assert level_params['dt'] == 1e-2, 'Error! Do not use the time step dt != 1e-2!'
+
# initialize sweeper parameters
sweeper_params = dict()
sweeper_params['quad_type'] = 'LOBATTO'
sweeper_params['num_nodes'] = 5
- sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
- sweeper_params['initial_guess'] = 'zero'
+ # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
+ sweeper_params['initial_guess'] = 'spread'
# initialize problem parameters
problem_params = dict()
+ problem_params['ncondensators'] = 2
problem_params['Vs'] = 5.0
problem_params['Rs'] = 0.5
- problem_params['C1'] = 1.0
- problem_params['C2'] = 1.0
+ problem_params['C'] = np.array([1.0, 1.0])
problem_params['R'] = 1.0
problem_params['L'] = 1.0
problem_params['alpha'] = 5.0
problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2]
- problem_params['set_switch'] = np.array([False, False], dtype=bool)
- problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0])
# initialize step parameters
step_params = dict()
- step_params['maxiter'] = 20
+ step_params['maxiter'] = 4
# initialize controller parameters
controller_params = dict()
- controller_params['logger_level'] = 20
+ controller_params['logger_level'] = 30
controller_params['hook_class'] = log_data
# convergence controllers
@@ -111,18 +116,17 @@ def main(use_switch_estimator=True):
# fill description dictionary for easy step instantiation
description = dict()
- description['problem_class'] = battery_2condensators # pass problem class
+ description['problem_class'] = battery_n_condensators # pass problem class
description['problem_params'] = problem_params # pass problem parameters
description['sweeper_class'] = imex_1st_order # pass sweeper
description['sweeper_params'] = sweeper_params # pass sweeper parameters
description['level_params'] = level_params # pass level parameters
description['step_params'] = step_params
- description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class
if use_switch_estimator:
description['convergence_controllers'] = convergence_controllers
- proof_assertions_description(description, problem_params)
+ proof_assertions_description(description, use_switch_estimator)
# set time parameters
t0 = 0.0
@@ -144,39 +148,24 @@ def main(use_switch_estimator=True):
dill.dump(stats, f)
f.close()
- # filter statistics by number of iterations
- iter_counts = get_sorted(stats, type='niter', sortby='time')
-
- # compute and print statistics
- min_iter = 20
- max_iter = 0
-
- f = open('battery_2condensators_out.txt', 'w')
- niters = np.array([item[1] for item in iter_counts])
- out = ' Mean number of iterations: %4.2f' % np.mean(niters)
- f.write(out + '\n')
- print(out)
- for item in iter_counts:
- out = 'Number of iterations for time %4.2f: %1i' % item
- f.write(out + '\n')
- # print(out)
- min_iter = min(min_iter, item[1])
- max_iter = max(max_iter, item[1])
-
- restarts = np.array(get_sorted(stats, type='restart', recomputed=False))[:, 1]
- print("Restarts for dt: ", level_params['dt'], " -- ", np.sum(restarts))
-
- assert np.mean(niters) <= 10, "Mean number of iterations is too high, got %s" % np.mean(niters)
- f.close()
+ recomputed = False
+
+ check_solution(stats, use_switch_estimator)
- plot_voltages(description, use_switch_estimator)
+ plot_voltages(description, recomputed, use_switch_estimator)
- return np.mean(niters)
+ return description
-def plot_voltages(description, use_switch_estimator, cwd='./'):
+def plot_voltages(description, recomputed, use_switch_estimator, cwd='./'):
"""
Routine to plot the numerical solution of the model
+
+ Args:
+ description(dict): contains all information for a controller run
+ recomputed (bool): flag if the values after a restart are used or before
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ cwd: current working directory
"""
f = open(cwd + 'data/battery_2condensators.dat', 'rb')
@@ -184,9 +173,9 @@ def plot_voltages(description, use_switch_estimator, cwd='./'):
f.close()
# convert filtered statistics to list of iterations count, sorted by process
- cL = get_sorted(stats, type='current L', sortby='time')
- vC1 = get_sorted(stats, type='voltage C1', sortby='time')
- vC2 = get_sorted(stats, type='voltage C2', sortby='time')
+ cL = get_sorted(stats, type='current L', recomputed=recomputed, sortby='time')
+ vC1 = get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time')
+ vC2 = get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time')
times = [v[0] for v in cL]
@@ -197,11 +186,13 @@ def plot_voltages(description, use_switch_estimator, cwd='./'):
ax.plot(times, [v[1] for v in vC2], label='$v_{C_2}$')
if use_switch_estimator:
- t_switch_plot = np.zeros(np.shape(description['problem_params']['t_switch'])[0])
- for i in range(np.shape(description['problem_params']['t_switch'])[0]):
- t_switch_plot[i] = description['problem_params']['t_switch'][i]
+ switches = get_recomputed(stats, type='switch', sortby='time')
+ if recomputed is not None:
+ assert len(switches) >= 2, f"Expected at least 2 switches, got {len(switches)}!"
+ t_switches = [v[1] for v in switches]
- ax.axvline(x=t_switch_plot[i], linestyle='--', color='k', label='Switch {}'.format(i + 1))
+ for i in range(len(t_switches)):
+ ax.axvline(x=t_switches[i], linestyle='--', color='k', label='Switch {}'.format(i + 1))
ax.legend(frameon=False, fontsize=12, loc='upper right')
@@ -212,30 +203,108 @@ def plot_voltages(description, use_switch_estimator, cwd='./'):
plt_helper.plt.close(fig)
-def proof_assertions_description(description, problem_params):
+def check_solution(stats, use_switch_estimator):
"""
- Function to proof the assertions (function to get cleaner code)
+ Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
+
+ Args:
+ stats (dict): Raw statistics from a controller run
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ """
+
+ data = get_data_dict(stats, use_switch_estimator)
+
+ if use_switch_estimator:
+ msg = 'Error when using the switch estimator for battery_2condensators:'
+ expected = {
+ 'cL': 1.2065280755094876,
+ 'vC1': 1.0094825899806945,
+ 'vC2': 1.0050052828742688,
+ 'switch1': 1.6094379124373626,
+ 'switch2': 3.209437912457051,
+ 'restarts': 2.0,
+ 'sum_niters': 1568,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC1': data['vC1'][-1],
+ 'vC2': data['vC2'][-1],
+ 'switch1': data['switch1'],
+ 'switch2': data['switch2'],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ for key in expected.keys():
+ assert np.isclose(
+ expected[key], got[key], rtol=1e-4
+ ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
+
+
+def get_data_dict(stats, use_switch_estimator, recomputed=False):
"""
+ Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function.
+ Based on @brownbaerchen's get_data function.
+
+ Args:
+ stats (dict): Raw statistics from a controller run
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ recomputed (bool): flag if the values after a restart are used or before
- assert problem_params['alpha'] > problem_params['V_ref'][0], 'Please set "alpha" greater than "V_ref1"'
- assert problem_params['alpha'] > problem_params['V_ref'][1], 'Please set "alpha" greater than "V_ref2"'
+ Return:
+ data (dict): contains all information as the statistics dict
+ """
- assert problem_params['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0'
- assert problem_params['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0'
+ data = dict()
+ data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1]
+ data['vC1'] = np.array(get_sorted(stats, type='voltage C1', recomputed=recomputed, sortby='time'))[:, 1]
+ data['vC2'] = np.array(get_sorted(stats, type='voltage C2', recomputed=recomputed, sortby='time'))[:, 1]
+ data['switch1'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[0, 1]
+ data['switch2'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[-1, 1]
+ data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])
+ data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1])
- assert type(problem_params['V_ref']) == np.ndarray, '"V_ref" needs to be an array (of type float)'
- assert not problem_params['set_switch'][0], 'First entry of "set_switch" needs to be False'
- assert not problem_params['set_switch'][1], 'Second entry of "set_switch" needs to be False'
+ return data
- assert not type(problem_params['t_switch']) == float, '"t_switch" has to be an array with entry zero'
- assert problem_params['t_switch'][0] == 0, 'First entry of "t_switch" needs to be zero'
- assert problem_params['t_switch'][1] == 0, 'Second entry of "t_switch" needs to be zero'
+def proof_assertions_description(description, use_switch_estimator):
+ """
+ Function to proof the assertions (function to get cleaner code)
+
+ Args:
+ description(dict): contains all information for a controller run
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ """
+
+ assert (
+ description['problem_params']['alpha'] > description['problem_params']['V_ref'][0]
+ ), 'Please set "alpha" greater than "V_ref1"'
+ assert (
+ description['problem_params']['alpha'] > description['problem_params']['V_ref'][1]
+ ), 'Please set "alpha" greater than "V_ref2"'
+
+ if description['problem_params']['ncondensators'] > 1:
+ assert (
+ type(description['problem_params']['V_ref']) == np.ndarray
+ ), '"V_ref" needs to be an array (of type float)'
+ assert (
+ description['problem_params']['ncondensators'] == np.shape(description['problem_params']['V_ref'])[0]
+ ), 'Number of reference values needs to be equal to number of condensators'
+ assert (
+ description['problem_params']['ncondensators'] == np.shape(description['problem_params']['C'])[0]
+ ), 'Number of capacitance values needs to be equal to number of condensators'
+
+ assert description['problem_params']['V_ref'][0] > 0, 'Please set "V_ref1" greater than 0'
+ assert description['problem_params']['V_ref'][1] > 0, 'Please set "V_ref2" greater than 0'
assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error'
assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters'
assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters'
+ if use_switch_estimator:
+ assert description['level_params']['restol'] == -1, "Please set restol to -1 or omit it"
+
if __name__ == "__main__":
main()
diff --git a/pySDC/projects/PinTSimE/battery_model.py b/pySDC/projects/PinTSimE/battery_model.py
index a127687eb5..a5336bc035 100644
--- a/pySDC/projects/PinTSimE/battery_model.py
+++ b/pySDC/projects/PinTSimE/battery_model.py
@@ -2,7 +2,7 @@
import dill
from pathlib import Path
-from pySDC.helpers.stats_helper import get_sorted
+from pySDC.helpers.stats_helper import sort_stats, filter_stats, get_sorted
from pySDC.core.Collocation import CollBase as Collocation
from pySDC.implementations.problem_classes.Battery import battery, battery_implicit
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
@@ -62,11 +62,31 @@ def post_step(self, step, level_number):
type='dt',
value=L.dt,
)
+ self.add_to_stats(
+ process=step.status.slot,
+ time=L.time + L.dt,
+ level=L.level_index,
+ iter=0,
+ sweep=L.status.sweep,
+ type='e_embedded',
+ value=L.status.get('error_embedded_estimate'),
+ )
def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
"""
A simple test program to do SDC/PFASST runs for the battery drain model
+
+ Args:
+ dt (float): time step for computation
+ problem (problem_class.__name__): problem class that wants to be simulated
+ sweeper (sweeper_class.__name__): sweeper class for solving the problem class numerically
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ use_adaptivity (bool): flag if the adaptivity wants to be used or not
+
+ Returns:
+ stats (dict): Raw statistics from a controller run
+ description (dict): contains all information for a controller run
"""
# initialize level parameters
@@ -79,12 +99,13 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
sweeper_params['quad_type'] = 'LOBATTO'
sweeper_params['num_nodes'] = 5
# sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
- sweeper_params['initial_guess'] = 'zero'
+ sweeper_params['initial_guess'] = 'spread'
# initialize problem parameters
problem_params = dict()
problem_params['newton_maxiter'] = 200
problem_params['newton_tol'] = 1e-08
+ problem_params['ncondensators'] = 1 # number of condensators
problem_params['Vs'] = 5.0
problem_params['Rs'] = 0.5
problem_params['C'] = 1.0
@@ -92,8 +113,6 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
problem_params['L'] = 1.0
problem_params['alpha'] = 1.2
problem_params['V_ref'] = 1.0
- problem_params['set_switch'] = np.array([False], dtype=bool)
- problem_params['t_switch'] = np.zeros(1)
# initialize step parameters
step_params = dict()
@@ -101,7 +120,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
# initialize controller parameters
controller_params = dict()
- controller_params['logger_level'] = 20
+ controller_params['logger_level'] = 30
controller_params['hook_class'] = log_data
controller_params['mssdc_jac'] = False
@@ -113,7 +132,7 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
if use_adaptivity:
adaptivity_params = dict()
- adaptivity_params['e_tol'] = 1e-12
+ adaptivity_params['e_tol'] = 1e-7
convergence_controllers.update({Adaptivity: adaptivity_params})
# fill description dictionary for easy step instantiation
@@ -128,12 +147,19 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
if use_switch_estimator or use_adaptivity:
description['convergence_controllers'] = convergence_controllers
- proof_assertions_description(description, problem_params)
-
# set time parameters
t0 = 0.0
Tend = 0.3
+ proof_assertions_description(description, use_adaptivity, use_switch_estimator)
+
+ assert dt < Tend, "Time step is too large for the time domain!"
+
+ assert (
+ Tend == 0.3 and description['problem_params']['V_ref'] == 1.0 and description['problem_params']['alpha'] == 1.2
+ ), "Error! Do not use other parameters for V_ref != 1.0, alpha != 1.2, Tend != 0.3 due to hardcoded reference!"
+ assert description['level_params']['dt'] == 1e-2, "Error! Do not use another time step dt!= 1e-2!"
+
# instantiate controller
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
@@ -144,35 +170,13 @@ def main(dt, problem, sweeper, use_switch_estimator, use_adaptivity):
# call main function to get things done...
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
- # filter statistics by number of iterations
- iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time')
-
- # compute and print statistics
- min_iter = 20
- max_iter = 0
-
Path("data").mkdir(parents=True, exist_ok=True)
fname = 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper.__name__, use_switch_estimator, use_adaptivity)
f = open(fname, 'wb')
dill.dump(stats, f)
f.close()
- f = open('data/battery_out.txt', 'w')
- niters = np.array([item[1] for item in iter_counts])
- out = ' Mean number of iterations: %4.2f' % np.mean(niters)
- f.write(out + '\n')
- print(out)
- for item in iter_counts:
- out = 'Number of iterations for time %4.2f: %1i' % item
- f.write(out + '\n')
- print(out)
- min_iter = min(min_iter, item[1])
- max_iter = max(max_iter, item[1])
-
- assert np.mean(niters) <= 5, "Mean number of iterations is too high, got %s" % np.mean(niters)
- f.close()
-
- return description
+ return stats, description
def run():
@@ -184,13 +188,14 @@ def run():
dt = 1e-2
problem_classes = [battery, battery_implicit]
sweeper_classes = [imex_1st_order, generic_implicit]
+ recomputed = False
use_switch_estimator = [True]
use_adaptivity = [True]
for problem, sweeper in zip(problem_classes, sweeper_classes):
for use_SE in use_switch_estimator:
for use_A in use_adaptivity:
- description = main(
+ stats, description = main(
dt=dt,
problem=problem,
sweeper=sweeper,
@@ -198,12 +203,23 @@ def run():
use_adaptivity=use_A,
)
- plot_voltages(description, problem.__name__, sweeper.__name__, use_SE, use_A)
+ check_solution(stats, problem.__name__, use_adaptivity, use_switch_estimator)
+
+ plot_voltages(description, problem.__name__, sweeper.__name__, recomputed, use_SE, use_A)
-def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adaptivity, cwd='./'):
+def plot_voltages(description, problem, sweeper, recomputed, use_switch_estimator, use_adaptivity, cwd='./'):
"""
Routine to plot the numerical solution of the model
+
+ Args:
+ description(dict): contains all information for a controller run
+ problem (problem_class.__name__): problem class that wants to be simulated
+ sweeper (sweeper_class.__name__): sweeper class for solving the problem class numerically
+ recomputed (bool): flag if the values after a restart are used or before
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ use_adaptivity (bool): flag if adaptivity wants to be used or not
+ cwd: current working directory
"""
f = open(cwd + 'data/battery_{}_USE{}_USA{}.dat'.format(sweeper, use_switch_estimator, use_adaptivity), 'rb')
@@ -223,12 +239,15 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt
ax.plot(times, [v[1] for v in vC], label=r'$v_C$')
if use_switch_estimator:
- val_switch = get_sorted(stats, type='switch1', sortby='time')
- t_switch = [v[1] for v in val_switch]
+ switches = get_recomputed(stats, type='switch', sortby='time')
+
+ assert len(switches) >= 1, 'No switches found!'
+ t_switch = [v[1] for v in switches]
ax.axvline(x=t_switch[-1], linestyle='--', linewidth=0.8, color='r', label='Switch')
if use_adaptivity:
dt = np.array(get_sorted(stats, type='dt', recomputed=False))
+
dt_ax = ax.twinx()
dt_ax.plot(dt[:, 0], dt[:, 1], linestyle='-', linewidth=0.8, color='k', label=r'$\Delta t$')
dt_ax.set_ylabel(r'$\Delta t$', fontsize=8)
@@ -245,23 +264,146 @@ def plot_voltages(description, problem, sweeper, use_switch_estimator, use_adapt
plt_helper.plt.close(fig)
-def proof_assertions_description(description, problem_params):
+def check_solution(stats, problem, use_adaptivity, use_switch_estimator):
"""
- Function to proof the assertions (function to get cleaner code)
+ Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
+
+ Args:
+ stats (dict): Raw statistics from a controller run
+ problem (problem_class.__name__): the problem_class that is numerically solved
+ use_adaptivity (bool): flag if adaptivity wants to be used or not
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ """
+
+ data = get_data_dict(stats, use_adaptivity, use_switch_estimator)
+
+ if use_switch_estimator and use_adaptivity:
+ if problem == 'battery':
+ msg = 'Error when using switch estimator and adaptivity for battery:'
+ expected = {
+ 'cL': 0.5474500710994862,
+ 'vC': 1.0019332967173764,
+ 'dt': 0.011761752270047832,
+ 'e_em': 8.001793672107738e-10,
+ 'switches': 0.18232155791181945,
+ 'restarts': 3.0,
+ 'sum_niters': 44,
+ }
+ elif problem == 'battery_implicit':
+ msg = 'Error when using switch estimator and adaptivity for battery_implicit:'
+ expected = {
+ 'cL': 0.5424577937840791,
+ 'vC': 1.0001051105894005,
+ 'dt': 0.01,
+ 'e_em': 2.220446049250313e-16,
+ 'switches': 0.1822923488448394,
+ 'restarts': 6.0,
+ 'sum_niters': 60,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'dt': data['dt'][-1],
+ 'e_em': data['e_em'][-1],
+ 'switches': data['switches'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ for key in expected.keys():
+ assert np.isclose(
+ expected[key], got[key], rtol=1e-4
+ ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
+
+
+def get_data_dict(stats, use_adaptivity=True, use_switch_estimator=True, recomputed=False):
+ """
+ Converts the statistics in a useful data dictionary so that it can be easily checked in the check_solution function.
+ Based on @brownbaerchen's get_data function.
+
+ Args:
+ stats (dict): Raw statistics from a controller run
+ use_adaptivity (bool): flag if adaptivity wants to be used or not
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ recomputed (bool): flag indicating whether to use the values recomputed after a restart (True) or the values from before the restart (False)
+
+ Returns:
+ data (dict): contains all information as the statistics dict
+ """
+
+ data = dict()
+
+ data['cL'] = np.array(get_sorted(stats, type='current L', recomputed=recomputed, sortby='time'))[:, 1]
+ data['vC'] = np.array(get_sorted(stats, type='voltage C', recomputed=recomputed, sortby='time'))[:, 1]
+ if use_adaptivity:
+ data['dt'] = np.array(get_sorted(stats, type='dt', recomputed=recomputed, sortby='time'))[:, 1]
+ data['e_em'] = np.array(
+ get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed, sortby='time')
+ )[:, 1]
+ if use_switch_estimator:
+ data['switches'] = np.array(get_recomputed(stats, type='switch', sortby='time'))[:, 1]
+ if use_adaptivity or use_switch_estimator:
+ data['restarts'] = np.sum(np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))[:, 1])
+ data['sum_niters'] = np.sum(np.array(get_sorted(stats, type='niter', recomputed=None, sortby='time'))[:, 1])
+
+ return data
+
+
+def get_recomputed(stats, type, sortby):
"""
+ Function that filters statistics after a recomputation. It stores all values of a type before a restart. If there are multiple values
+ with the same time point, it only stores the elements with unique times.
- assert problem_params['alpha'] > problem_params['V_ref'], 'Please set "alpha" greater than "V_ref"'
- assert problem_params['V_ref'] > 0, 'Please set "V_ref" greater than 0'
- assert type(problem_params['V_ref']) == float, '"V_ref" needs to be of type float'
+ Args:
+ stats (dict): Raw statistics from a controller run
+ type (str): the type to be filtered
+ sortby (str): string to specify which key to use for sorting
- assert type(problem_params['set_switch'][0]) == np.bool_, '"set_switch" has to be an bool array'
- assert type(problem_params['t_switch']) == np.ndarray, '"t_switch" has to be an array'
- assert problem_params['t_switch'][0] == 0, '"t_switch" is only allowed to have entry zero'
+ Returns:
+ sorted_list (list): list of filtered statistics
+ """
+
+ sorted_nested_list = []
+ times_unique = np.unique([me[0] for me in get_sorted(stats, type=type)])
+ filtered_list = [
+ filter_stats(
+ stats,
+ time=t_unique,
+ num_restarts=max([me.num_restarts for me in filter_stats(stats, type=type, time=t_unique).keys()]),
+ type=type,
+ )
+ for t_unique in times_unique
+ ]
+ for item in filtered_list:
+ sorted_nested_list.append(sort_stats(item, sortby=sortby))
+ sorted_list = [item for sub_item in sorted_nested_list for item in sub_item]
+ return sorted_list
+
+
+def proof_assertions_description(description, use_adaptivity, use_switch_estimator):
+ """
+ Function to proof the assertions (function to get cleaner code)
+
+ Args:
+ description (dict): contains all information for a controller run
+ use_adaptivity (bool): flag if adaptivity wants to be used or not
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ """
+
+ assert (
+ description['problem_params']['alpha'] > description['problem_params']['V_ref']
+ ), 'Please set "alpha" greater than "V_ref"'
+ assert description['problem_params']['V_ref'] > 0, 'Please set "V_ref" greater than 0'
+ assert type(description['problem_params']['V_ref']) == float, '"V_ref" needs to be of type float'
assert 'errtol' not in description['step_params'].keys(), 'No exact solution known to compute error'
assert 'alpha' in description['problem_params'].keys(), 'Please supply "alpha" in the problem parameters'
assert 'V_ref' in description['problem_params'].keys(), 'Please supply "V_ref" in the problem parameters'
+ if use_switch_estimator or use_adaptivity:
+ assert description['level_params']['restol'] == -1, "When using the switch estimator or adaptivity, please set restol to -1 or omit it"
+
if __name__ == "__main__":
run()
diff --git a/pySDC/projects/PinTSimE/estimation_check.py b/pySDC/projects/PinTSimE/estimation_check.py
index e565bd11d3..aad507a9bd 100644
--- a/pySDC/projects/PinTSimE/estimation_check.py
+++ b/pySDC/projects/PinTSimE/estimation_check.py
@@ -9,17 +9,29 @@
from pySDC.implementations.sweeper_classes.generic_implicit import generic_implicit
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.projects.PinTSimE.piline_model import setup_mpl
-from pySDC.projects.PinTSimE.battery_model import log_data, proof_assertions_description
+from pySDC.projects.PinTSimE.battery_model import get_recomputed, get_data_dict, log_data, proof_assertions_description
import pySDC.helpers.plot_helper as plt_helper
+from pySDC.core.Hooks import hooks
from pySDC.projects.PinTSimE.switch_estimator import SwitchEstimator
from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
-from pySDC.implementations.convergence_controller_classes.estimate_embedded_error import EstimateEmbeddedErrorNonMPI
def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
"""
A simple test program to do SDC/PFASST runs for the battery drain model
+
+ Args:
+ dt (np.float): (initial) time step
+ problem (problem_class): the considered problem class (here: battery or battery_implicit)
+ sweeper (sweeper_class): the used sweeper class to solve (here: imex_1st_order or generic_implicit)
+ use_switch_estimator (bool): Switch estimator should be used or not
+ use_adaptivity (bool): Adaptivity should be used or not
+ V_ref (np.float): reference value for the switch
+
+ Returns:
+ description (dict): contains all the information for the controller
+ stats (dict): includes the statistics of the solve
"""
# initialize level parameters
@@ -32,12 +44,13 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
sweeper_params['quad_type'] = 'LOBATTO'
sweeper_params['num_nodes'] = 5
# sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
- sweeper_params['initial_guess'] = 'zero'
+ sweeper_params['initial_guess'] = 'spread'
# initialize problem parameters
problem_params = dict()
problem_params['newton_maxiter'] = 200
problem_params['newton_tol'] = 1e-08
+ problem_params['ncondensators'] = 1
problem_params['Vs'] = 5.0
problem_params['Rs'] = 0.5
problem_params['C'] = 1.0
@@ -45,8 +58,6 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
problem_params['L'] = 1.0
problem_params['alpha'] = 1.2
problem_params['V_ref'] = V_ref
- problem_params['set_switch'] = np.array([False], dtype=bool)
- problem_params['t_switch'] = np.zeros(1)
# initialize step parameters
step_params = dict()
@@ -54,7 +65,7 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
# initialize controller parameters
controller_params = dict()
- controller_params['logger_level'] = 20
+ controller_params['logger_level'] = 30
controller_params['hook_class'] = log_data
controller_params['mssdc_jac'] = False
@@ -82,12 +93,19 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
if use_switch_estimator or use_adaptivity:
description['convergence_controllers'] = convergence_controllers
- proof_assertions_description(description, problem_params)
+ proof_assertions_description(description, use_adaptivity, use_switch_estimator)
# set time parameters
t0 = 0.0
Tend = 0.3
+ assert dt < Tend, "Time step is too large for the time domain!"
+
+ assert (
+ Tend == 0.3 and description['problem_params']['V_ref'] == 1.0 and description['problem_params']['alpha'] == 1.2
+ ), "Error! Do not use parameters other than V_ref=1.0, alpha=1.2, Tend=0.3 due to hardcoded references!"
+ assert dt == 4e-2 or dt == 4e-3, "Error! Do not use time steps other than dt=4e-2 or dt=4e-3 due to hardcoded references!"
+
# instantiate controller
controller = controller_nonMPI(num_procs=1, controller_params=controller_params, description=description)
@@ -98,37 +116,24 @@ def run(dt, problem, sweeper, use_switch_estimator, use_adaptivity, V_ref):
# call main function to get things done...
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
- Path("data").mkdir(parents=True, exist_ok=True)
- fname = 'data/battery.dat'
- f = open(fname, 'wb')
- dill.dump(stats, f)
- f.close()
-
- # filter statistics by number of iterations
- iter_counts = get_sorted(stats, type='niter', recomputed=False, sortby='time')
-
- # compute and print statistics
- f = open('data/battery_out.txt', 'w')
- niters = np.array([item[1] for item in iter_counts])
-
- assert np.mean(niters) <= 11, "Mean number of iterations is too high, got %s" % np.mean(niters)
- f.close()
-
return description, stats
def check(cwd='./'):
"""
Routine to check the differences between using a switch estimator or not
+
+ Args:
+ cwd: current working directory
"""
V_ref = 1.0
- dt_list = [1e-2, 1e-3]
+ dt_list = [4e-2, 4e-3]
use_switch_estimator = [True, False]
use_adaptivity = [True, False]
- restarts_true = []
- restarts_false_adapt = []
- restarts_true_adapt = []
+ restarts_SE = []
+ restarts_adapt = []
+ restarts_SE_adapt = []
problem_classes = [battery, battery_implicit]
sweeper_classes = [imex_1st_order, generic_implicit]
@@ -146,6 +151,14 @@ def check(cwd='./'):
V_ref=V_ref,
)
+ if use_A or use_SE:
+ check_solution(stats, dt_item, problem.__name__, use_A, use_SE)
+
+ if use_SE:
+ assert (
+ len(get_recomputed(stats, type='switch', sortby='time')) >= 1
+ ), 'No switches found for dt={}!'.format(dt_item)
+
fname = 'data/battery_dt{}_USE{}_USA{}_{}.dat'.format(dt_item, use_SE, use_A, sweeper.__name__)
f = open(fname, 'wb')
dill.dump(stats, f)
@@ -153,24 +166,23 @@ def check(cwd='./'):
if use_SE or use_A:
restarts_sorted = np.array(get_sorted(stats, type='restart', recomputed=None))[:, 1]
- print('Restarts for dt={}: {}'.format(dt_item, np.sum(restarts_sorted)))
if use_SE and not use_A:
- restarts_true.append(np.sum(restarts_sorted))
+ restarts_SE.append(np.sum(restarts_sorted))
elif not use_SE and use_A:
- restarts_false_adapt.append(np.sum(restarts_sorted))
+ restarts_adapt.append(np.sum(restarts_sorted))
elif use_SE and use_A:
- restarts_true_adapt.append(np.sum(restarts_sorted))
+ restarts_SE_adapt.append(np.sum(restarts_sorted))
accuracy_check(dt_list, problem.__name__, sweeper.__name__, V_ref)
differences_around_switch(
dt_list,
problem.__name__,
- restarts_true,
- restarts_false_adapt,
- restarts_true_adapt,
+ restarts_SE,
+ restarts_adapt,
+ restarts_SE_adapt,
sweeper.__name__,
V_ref,
)
@@ -179,14 +191,21 @@ def check(cwd='./'):
iterations_over_time(dt_list, description['step_params']['maxiter'], problem.__name__, sweeper.__name__)
- restarts_true = []
- restarts_false_adapt = []
- restarts_true_adapt = []
+ restarts_SE = []
+ restarts_adapt = []
+ restarts_SE_adapt = []
def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
"""
Routine to check accuracy for different step sizes in case of using adaptivity
+
+ Args:
+ dt_list (list): list of considered (initial) step sizes
+ problem (problem.__name__): Problem class used to consider (the class name)
+ sweeper (sweeper.__name__): Sweeper used to solve (the class name)
+ V_ref (np.float): reference value for the switch
+ cwd: current working directory
"""
if len(dt_list) > 1:
@@ -202,45 +221,49 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
count_ax = 0
for dt_item in dt_list:
f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TT = dill.load(f3)
+ stats_SE_adapt = dill.load(f3)
f3.close()
f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FT = dill.load(f4)
+ stats_adapt = dill.load(f4)
f4.close()
- val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
- t_switch_adapt = [v[1] for v in val_switch_TT]
- t_switch_adapt = t_switch_adapt[-1]
+ switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+ t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+ t_switch_SE_adapt = t_switch_SE_adapt[-1]
- dt_TT_val = get_sorted(stats_TT, type='dt', recomputed=False)
- dt_FT_val = get_sorted(stats_FT, type='dt', recomputed=False)
+ dt_SE_adapt_val = get_sorted(stats_SE_adapt, type='dt', recomputed=False)
+ dt_adapt_val = get_sorted(stats_adapt, type='dt', recomputed=False)
- e_emb_TT_val = get_sorted(stats_TT, type='e_embedded', recomputed=False)
- e_emb_FT_val = get_sorted(stats_FT, type='e_embedded', recomputed=False)
+ e_emb_SE_adapt_val = get_sorted(stats_SE_adapt, type='e_embedded', recomputed=False)
+ e_emb_adapt_val = get_sorted(stats_adapt, type='e_embedded', recomputed=False)
- times_TT = [v[0] for v in e_emb_TT_val]
- times_FT = [v[0] for v in e_emb_FT_val]
+ times_SE_adapt = [v[0] for v in e_emb_SE_adapt_val]
+ times_adapt = [v[0] for v in e_emb_adapt_val]
- e_emb_TT = [v[1] for v in e_emb_TT_val]
- e_emb_FT = [v[1] for v in e_emb_FT_val]
+ e_emb_SE_adapt = [v[1] for v in e_emb_SE_adapt_val]
+ e_emb_adapt = [v[1] for v in e_emb_adapt_val]
if len(dt_list) > 1:
- ax_acc[count_ax].set_title(r'$\Delta t$={}'.format(dt_item))
+ ax_acc[count_ax].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item)
dt1 = ax_acc[count_ax].plot(
- [v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$'
+ [v[0] for v in dt_SE_adapt_val],
+ [v[1] for v in dt_SE_adapt_val],
+ 'ko-',
+ label=r'SE+A - $\Delta t_\mathrm{adapt}$',
)
dt2 = ax_acc[count_ax].plot(
- [v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'g-', label=r'A - $\Delta t$'
+ [v[0] for v in dt_adapt_val], [v[1] for v in dt_adapt_val], 'g-', label=r'A - $\Delta t_\mathrm{adapt}$'
)
- ax_acc[count_ax].axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+ ax_acc[count_ax].axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+ ax_acc[count_ax].tick_params(axis='both', which='major', labelsize=6)
ax_acc[count_ax].set_xlabel('Time', fontsize=6)
if count_ax == 0:
- ax_acc[count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6)
+ ax_acc[count_ax].set_ylabel(r'$\Delta t_\mathrm{adapt}$', fontsize=6)
e_ax = ax_acc[count_ax].twinx()
- e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$')
- e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$')
+ e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$')
+ e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$')
e_ax.set_yscale('log', base=10)
e_ax.set_ylim(1e-16, 1e-7)
e_ax.tick_params(labelsize=6)
@@ -248,26 +271,37 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
lines = dt1 + e_plt1 + dt2 + e_plt2
labels = [l.get_label() for l in lines]
- ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper left')
+ ax_acc[count_ax].legend(lines, labels, frameon=False, fontsize=6, loc='upper right')
else:
- ax_acc.set_title(r'$\Delta t$={}'.format(dt_item))
- dt1 = ax_acc.plot([v[0] for v in dt_TT_val], [v[1] for v in dt_TT_val], 'ko-', label=r'SE+A - $\Delta t$')
- dt2 = ax_acc.plot([v[0] for v in dt_FT_val], [v[1] for v in dt_FT_val], 'go-', label=r'A - $\Delta t$')
- ax_acc.axvline(x=t_switch_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+ ax_acc.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_item)
+ dt1 = ax_acc.plot(
+ [v[0] for v in dt_SE_adapt_val],
+ [v[1] for v in dt_SE_adapt_val],
+ 'ko-',
+ label=r'SE+A - $\Delta t_\mathrm{adapt}$',
+ )
+ dt2 = ax_acc.plot(
+ [v[0] for v in dt_adapt_val],
+ [v[1] for v in dt_adapt_val],
+ 'go-',
+ label=r'A - $\Delta t_\mathrm{adapt}$',
+ )
+ ax_acc.axvline(x=t_switch_SE_adapt, linestyle='--', linewidth=0.5, color='r', label='Switch')
+ ax_acc.tick_params(axis='both', which='major', labelsize=6)
ax_acc.set_xlabel('Time', fontsize=6)
- ax_acc.set_ylabel(r'$Delta t_{adapted}$', fontsize=6)
+ ax_acc.set_ylabel(r'$\Delta t_\mathrm{adapt}$', fontsize=6)
e_ax = ax_acc.twinx()
- e_plt1 = e_ax.plot(times_TT, e_emb_TT, 'k--', label=r'SE+A - $\epsilon_{emb}$')
- e_plt2 = e_ax.plot(times_FT, e_emb_FT, 'g--', label=r'A - $\epsilon_{emb}$')
+ e_plt1 = e_ax.plot(times_SE_adapt, e_emb_SE_adapt, 'k--', label=r'SE+A - $\epsilon_{emb}$')
+ e_plt2 = e_ax.plot(times_adapt, e_emb_adapt, 'g--', label=r'A - $\epsilon_{emb}$')
e_ax.set_yscale('log', base=10)
e_ax.tick_params(labelsize=6)
lines = dt1 + e_plt1 + dt2 + e_plt2
labels = [l.get_label() for l in lines]
- ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper left')
+ ax_acc.legend(lines, labels, frameon=False, fontsize=6, loc='upper right')
count_ax += 1
@@ -275,11 +309,221 @@ def accuracy_check(dt_list, problem, sweeper, V_ref, cwd='./'):
plt_helper.plt.close(fig_acc)
+def check_solution(stats, dt, problem, use_adaptivity, use_switch_estimator):
+ """
+ Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
+
+ Args:
+ stats (dict): Raw statistics from a controller run
+ dt (float): initial time step
+ problem (problem_class.__name__): the problem_class that is numerically solved
+ use_adaptivity (bool): flag if adaptivity wants to be used or not
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ """
+
+ data = get_data_dict(stats, use_adaptivity, use_switch_estimator)
+
+ if problem == 'battery':
+ if use_switch_estimator and use_adaptivity:
+ msg = f'Error when using switch estimator and adaptivity for battery for dt={dt:.1e}:'
+ if dt == 4e-2:
+ expected = {
+ 'cL': 0.5525783945667581,
+ 'vC': 1.00001743462299,
+ 'dt': 0.03550610373897258,
+ 'e_em': 6.21240694442804e-08,
+ 'switches': 0.18231603298272345,
+ 'restarts': 4.0,
+ 'sum_niters': 56,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 0.5395601429161445,
+ 'vC': 1.0000413761942089,
+ 'dt': 0.028281271825675414,
+ 'e_em': 2.5628611677319668e-08,
+ 'switches': 0.18230920573953438,
+ 'restarts': 3.0,
+ 'sum_niters': 48,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'dt': data['dt'][-1],
+ 'e_em': data['e_em'][-1],
+ 'switches': data['switches'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+ elif use_switch_estimator and not use_adaptivity:
+ msg = f'Error when using switch estimator for battery for dt={dt:.1e}:'
+ if dt == 4e-2:
+ expected = {
+ 'cL': 0.6139093327509394,
+ 'vC': 1.0010140038721593,
+ 'switches': 0.1824302065533169,
+ 'restarts': 1.0,
+ 'sum_niters': 48,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 0.5429509935448258,
+ 'vC': 1.0001158309787614,
+ 'switches': 0.18232183080236553,
+ 'restarts': 1.0,
+ 'sum_niters': 392,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'switches': data['switches'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ elif not use_switch_estimator and use_adaptivity:
+ msg = f'Error when using adaptivity for battery for dt={dt:.1e}:'
+ if dt == 4e-2:
+ expected = {
+ 'cL': 0.5966289599915113,
+ 'vC': 0.9923148791604984,
+ 'dt': 0.03564958366355817,
+ 'e_em': 6.210964231812e-08,
+ 'restarts': 1.0,
+ 'sum_niters': 36,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 0.5431613774808756,
+ 'vC': 0.9934307674636834,
+ 'dt': 0.022880524075396924,
+ 'e_em': 1.1130212751453428e-08,
+ 'restarts': 3.0,
+ 'sum_niters': 52,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'dt': data['dt'][-1],
+ 'e_em': data['e_em'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ elif problem == 'battery_implicit':
+ if use_switch_estimator and use_adaptivity:
+ msg = f'Error when using switch estimator and adaptivity for battery_implicit for dt={dt:.1e}:'
+ if dt == 4e-2:
+ expected = {
+ 'cL': 0.6717104472882885,
+ 'vC': 1.0071670698947914,
+ 'dt': 0.035896059229296486,
+ 'e_em': 6.208836400567463e-08,
+ 'switches': 0.18232158833761175,
+ 'restarts': 3.0,
+ 'sum_niters': 36,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 0.5396216192241711,
+ 'vC': 1.0000561014463172,
+ 'dt': 0.009904645972832471,
+ 'e_em': 2.220446049250313e-16,
+ 'switches': 0.18230549652342606,
+ 'restarts': 4.0,
+ 'sum_niters': 44,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'dt': data['dt'][-1],
+ 'e_em': data['e_em'][-1],
+ 'switches': data['switches'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+ elif use_switch_estimator and not use_adaptivity:
+ msg = f'Error when using switch estimator for battery_implicit for dt={dt:.1e}:'
+ if dt == 4e-2:
+ expected = {
+ 'cL': 0.613909968362315,
+ 'vC': 1.0010140112484431,
+ 'switches': 0.18243023230469263,
+ 'restarts': 1.0,
+ 'sum_niters': 48,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 0.5429616576526073,
+ 'vC': 1.0001158454740509,
+ 'switches': 0.1823218812753008,
+ 'restarts': 1.0,
+ 'sum_niters': 392,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'switches': data['switches'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ elif not use_switch_estimator and use_adaptivity:
+ msg = f'Error when using adaptivity for battery_implicit for dt={dt:.1e}:'
+ if dt == 4e-2:
+ expected = {
+ 'cL': 0.5556563012729733,
+ 'vC': 0.9930947318467772,
+ 'dt': 0.035507110551631804,
+ 'e_em': 6.2098696185231e-08,
+ 'restarts': 6.0,
+ 'sum_niters': 64,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 0.5401117929618637,
+ 'vC': 0.9933888475391347,
+ 'dt': 0.03176025170463925,
+ 'e_em': 4.0386798239033794e-08,
+ 'restarts': 8.0,
+ 'sum_niters': 80,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC': data['vC'][-1],
+ 'dt': data['dt'][-1],
+ 'e_em': data['e_em'][-1],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ for key in expected.keys():
+ assert np.isclose(
+ expected[key], got[key], rtol=1e-4
+ ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
+
+
def differences_around_switch(
- dt_list, problem, restarts_true, restarts_false_adapt, restarts_true_adapt, sweeper, V_ref, cwd='./'
+ dt_list, problem, restarts_SE, restarts_adapt, restarts_SE_adapt, sweeper, V_ref, cwd='./'
):
"""
Routine to plot the differences before, at, and after the switch. Produces the diffs_estimation_.png file
+
+ Args:
+ dt_list (list): list of considered (initial) step sizes
+ problem (problem.__name__): Problem class used to consider (the class name)
+ restarts_SE (list): Restarts for the solve only using the switch estimator
+ restarts_adapt (list): Restarts for the solve of only using adaptivity
+ restarts_SE_adapt (list): Restarts for the solve of using both, switch estimator and adaptivity
+ sweeper (sweeper.__name__): Sweeper used to solve (the class name)
+ V_ref (np.float): reference value for the switch
+ cwd: current working directory
"""
diffs_true_at = []
@@ -294,59 +538,59 @@ def differences_around_switch(
diffs_false_after_adapt = []
for dt_item in dt_list:
f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TF = dill.load(f1)
+ stats_SE = dill.load(f1)
f1.close()
f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FF = dill.load(f2)
+ stats = dill.load(f2)
f2.close()
f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TT = dill.load(f3)
+ stats_SE_adapt = dill.load(f3)
f3.close()
f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FT = dill.load(f4)
+ stats_adapt = dill.load(f4)
f4.close()
- val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time')
- t_switch = [v[1] for v in val_switch_TF]
+ switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
+ t_switch = [v[1] for v in switches_SE]
t_switch = t_switch[-1] # battery has only one single switch
- val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
- t_switch_adapt = [v[1] for v in val_switch_TT]
- t_switch_adapt = t_switch_adapt[-1]
+ switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+ t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+ t_switch_SE_adapt = t_switch_SE_adapt[-1]
- vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time')
- vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time')
- vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time')
- vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time')
+ vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time')
+ vC_adapt = get_sorted(stats_adapt, type='voltage C', recomputed=False, sortby='time')
+ vC_SE_adapt = get_sorted(stats_SE_adapt, type='voltage C', recomputed=False, sortby='time')
+ vC = get_sorted(stats, type='voltage C', sortby='time')
- diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF]
- times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF]
+ diff_SE, diff = [v[1] - V_ref for v in vC_SE], [v[1] - V_ref for v in vC]
+ times_SE, times = [v[0] for v in vC_SE], [v[0] for v in vC]
- diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT]
- times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT]
+ diff_adapt, diff_SE_adapt = [v[1] - V_ref for v in vC_adapt], [v[1] - V_ref for v in vC_SE_adapt]
+ times_adapt, times_SE_adapt = [v[0] for v in vC_adapt], [v[0] for v in vC_SE_adapt]
- for m in range(len(times_TF)):
- if np.round(times_TF[m], 15) == np.round(t_switch, 15):
- diffs_true_at.append(diff_TF[m])
+ for m in range(len(times_SE)):
+ if np.round(times_SE[m], 15) == np.round(t_switch, 15):
+ diffs_true_at.append(diff_SE[m])
- for m in range(1, len(times_FF)):
- if times_FF[m - 1] <= t_switch <= times_FF[m]:
- diffs_false_before.append(diff_FF[m - 1])
- diffs_false_after.append(diff_FF[m])
+ for m in range(1, len(times)):
+ if times[m - 1] <= t_switch <= times[m]:
+ diffs_false_before.append(diff[m - 1])
+ diffs_false_after.append(diff[m])
- for m in range(len(times_TT)):
- if np.round(times_TT[m], 13) == np.round(t_switch_adapt, 13):
- diffs_true_at_adapt.append(diff_TT[m])
- diffs_true_before_adapt.append(diff_TT[m - 1])
- diffs_true_after_adapt.append(diff_TT[m + 1])
+ for m in range(len(times_SE_adapt)):
+ if np.round(times_SE_adapt[m], 13) == np.round(t_switch_SE_adapt, 13):
+ diffs_true_at_adapt.append(diff_SE_adapt[m])
+ diffs_true_before_adapt.append(diff_SE_adapt[m - 1])
+ diffs_true_after_adapt.append(diff_SE_adapt[m + 1])
- for m in range(len(times_FT)):
- if times_FT[m - 1] <= t_switch <= times_FT[m]:
- diffs_false_before_adapt.append(diff_FT[m - 1])
- diffs_false_after_adapt.append(diff_FT[m])
+ for m in range(len(times_adapt)):
+ if times_adapt[m - 1] <= t_switch <= times_adapt[m]:
+ diffs_false_before_adapt.append(diff_adapt[m - 1])
+ diffs_false_after_adapt.append(diff_adapt[m])
setup_mpl()
fig_around, ax_around = plt_helper.plt.subplots(1, 3, figsize=(9, 3), sharex='col', sharey='row')
@@ -356,14 +600,15 @@ def differences_around_switch(
pos13 = ax_around[0].plot(dt_list, diffs_true_at, 'ko--', label='at switch')
ax_around[0].set_xticks(dt_list)
ax_around[0].set_xticklabels(dt_list)
+ ax_around[0].tick_params(axis='both', which='major', labelsize=6)
ax_around[0].set_xscale('log', base=10)
ax_around[0].set_yscale('symlog', linthresh=1e-8)
ax_around[0].set_ylim(-1, 1)
- ax_around[0].set_xlabel(r'$\Delta t$', fontsize=6)
+ ax_around[0].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6)
ax_around[0].set_ylabel(r'$v_{C}-V_{ref}$', fontsize=6)
restart_ax0 = ax_around[0].twinx()
- restarts_plt0 = restart_ax0.plot(dt_list, restarts_true, 'cs--', label='Restarts')
+ restarts_plt0 = restart_ax0.plot(dt_list, restarts_SE, 'cs--', label='Restarts')
restart_ax0.tick_params(labelsize=6)
lines = pos11 + pos12 + pos13 + restarts_plt0
@@ -375,13 +620,14 @@ def differences_around_switch(
pos22 = ax_around[1].plot(dt_list, diffs_false_after_adapt, 'bd--', label='after switch')
ax_around[1].set_xticks(dt_list)
ax_around[1].set_xticklabels(dt_list)
+ ax_around[1].tick_params(axis='both', which='major', labelsize=6)
ax_around[1].set_xscale('log', base=10)
ax_around[1].set_yscale('symlog', linthresh=1e-8)
ax_around[1].set_ylim(-1, 1)
- ax_around[1].set_xlabel(r'$\Delta t$', fontsize=6)
+ ax_around[1].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6)
restart_ax1 = ax_around[1].twinx()
- restarts_plt1 = restart_ax1.plot(dt_list, restarts_false_adapt, 'cs--', label='Restarts')
+ restarts_plt1 = restart_ax1.plot(dt_list, restarts_adapt, 'cs--', label='Restarts')
restart_ax1.tick_params(labelsize=6)
lines = pos21 + pos22 + restarts_plt1
@@ -394,90 +640,101 @@ def differences_around_switch(
pos33 = ax_around[2].plot(dt_list, diffs_true_at_adapt, 'ko--', label='at switch')
ax_around[2].set_xticks(dt_list)
ax_around[2].set_xticklabels(dt_list)
+ ax_around[2].tick_params(axis='both', which='major', labelsize=6)
ax_around[2].set_xscale('log', base=10)
ax_around[2].set_yscale('symlog', linthresh=1e-8)
ax_around[2].set_ylim(-1, 1)
- ax_around[2].set_xlabel(r'$\Delta t$', fontsize=6)
+ ax_around[2].set_xlabel(r'$\Delta t_\mathrm{initial}$', fontsize=6)
restart_ax2 = ax_around[2].twinx()
- restarts_plt2 = restart_ax2.plot(dt_list, restarts_true_adapt, 'cs--', label='Restarts')
+ restarts_plt2 = restart_ax2.plot(dt_list, restarts_SE_adapt, 'cs--', label='Restarts')
restart_ax2.tick_params(labelsize=6)
lines = pos31 + pos32 + pos33 + restarts_plt2
labels = [l.get_label() for l in lines]
ax_around[2].legend(frameon=False, fontsize=6, loc='lower right')
- fig_around.savefig('data/diffs_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
+ fig_around.savefig('data/diffs_around_switch_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
plt_helper.plt.close(fig_around)
def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
"""
Routine to plot the differences in time using the switch estimator or not. Produces the difference_estimation_.png file
+
+ Args:
+ dt_list (list): list of considered (initial) step sizes
+ problem (problem.__name__): Problem class used to consider (the class name)
+ sweeper (sweeper.__name__): Sweeper used to solve (the class name)
+ V_ref (np.float): reference value for the switch
+ cwd: current working directory
"""
if len(dt_list) > 1:
setup_mpl()
fig_diffs, ax_diffs = plt_helper.plt.subplots(
- 2, len(dt_list), figsize=(3 * len(dt_list), 4), sharex='col', sharey='row'
+ 2, len(dt_list), figsize=(4 * len(dt_list), 6), sharex='col', sharey='row'
)
else:
setup_mpl()
- fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(3, 3))
+ fig_diffs, ax_diffs = plt_helper.plt.subplots(2, 1, figsize=(4, 6))
count_ax = 0
for dt_item in dt_list:
f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TF = dill.load(f1)
+ stats_SE = dill.load(f1)
f1.close()
f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FF = dill.load(f2)
+ stats = dill.load(f2)
f2.close()
f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TT = dill.load(f3)
+ stats_SE_adapt = dill.load(f3)
f3.close()
f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FT = dill.load(f4)
+ stats_adapt = dill.load(f4)
f4.close()
- val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time')
- t_switch_TF = [v[1] for v in val_switch_TF]
- t_switch_TF = t_switch_TF[-1] # battery has only one single switch
+ switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
+ t_switch_SE = [v[1] for v in switches_SE]
+ t_switch_SE = t_switch_SE[-1] # battery has only one single switch
- val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
- t_switch_adapt = [v[1] for v in val_switch_TT]
- t_switch_adapt = t_switch_adapt[-1]
+ switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+ t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+ t_switch_SE_adapt = t_switch_SE_adapt[-1]
- dt_FT = np.array(get_sorted(stats_FT, type='dt', recomputed=False, sortby='time'))
- dt_TT = np.array(get_sorted(stats_TT, type='dt', recomputed=False, sortby='time'))
+ dt_adapt = np.array(get_sorted(stats_adapt, type='dt', recomputed=False, sortby='time'))
+ dt_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='dt', recomputed=False, sortby='time'))
- restart_FT = np.array(get_sorted(stats_FT, type='restart', recomputed=None, sortby='time'))
- restart_TT = np.array(get_sorted(stats_TT, type='restart', recomputed=None, sortby='time'))
+ restart_adapt = np.array(get_sorted(stats_adapt, type='restart', recomputed=None, sortby='time'))
+ restart_SE_adapt = np.array(get_sorted(stats_SE_adapt, type='restart', recomputed=None, sortby='time'))
- vC_TF = get_sorted(stats_TF, type='voltage C', recomputed=False, sortby='time')
- vC_FT = get_sorted(stats_FT, type='voltage C', recomputed=False, sortby='time')
- vC_TT = get_sorted(stats_TT, type='voltage C', recomputed=False, sortby='time')
- vC_FF = get_sorted(stats_FF, type='voltage C', sortby='time')
+ vC_SE = get_sorted(stats_SE, type='voltage C', recomputed=False, sortby='time')
+ vC_adapt = get_sorted(stats_adapt, type='voltage C', recomputed=False, sortby='time')
+ vC_SE_adapt = get_sorted(stats_SE_adapt, type='voltage C', recomputed=False, sortby='time')
+ vC = get_sorted(stats, type='voltage C', sortby='time')
- diff_TF, diff_FF = [v[1] - V_ref for v in vC_TF], [v[1] - V_ref for v in vC_FF]
- times_TF, times_FF = [v[0] for v in vC_TF], [v[0] for v in vC_FF]
+ diff_SE, diff = [v[1] - V_ref for v in vC_SE], [v[1] - V_ref for v in vC]
+ times_SE, times = [v[0] for v in vC_SE], [v[0] for v in vC]
- diff_FT, diff_TT = [v[1] - V_ref for v in vC_FT], [v[1] - V_ref for v in vC_TT]
- times_FT, times_TT = [v[0] for v in vC_FT], [v[0] for v in vC_TT]
+ diff_adapt, diff_SE_adapt = [v[1] - V_ref for v in vC_adapt], [v[1] - V_ref for v in vC_SE_adapt]
+ times_adapt, times_SE_adapt = [v[0] for v in vC_adapt], [v[0] for v in vC_SE_adapt]
if len(dt_list) > 1:
- ax_diffs[0, count_ax].set_title(r'$\Delta t$={}'.format(dt_item))
- ax_diffs[0, count_ax].plot(times_TF, diff_TF, label='SE=True, A=False', color='#ff7f0e')
- ax_diffs[0, count_ax].plot(times_FF, diff_FF, label='SE=False, A=False', color='#1f77b4')
- ax_diffs[0, count_ax].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--')
- ax_diffs[0, count_ax].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.')
- ax_diffs[0, count_ax].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch')
+ ax_diffs[0, count_ax].set_title(r'$\Delta t$=%s' % dt_item)
+ ax_diffs[0, count_ax].plot(times_SE, diff_SE, label='SE=True, A=False', color='#ff7f0e')
+ ax_diffs[0, count_ax].plot(times, diff, label='SE=False, A=False', color='#1f77b4')
+ ax_diffs[0, count_ax].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--')
+ ax_diffs[0, count_ax].plot(
+ times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.'
+ )
+ ax_diffs[0, count_ax].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch')
ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='lower left')
ax_diffs[0, count_ax].set_yscale('symlog', linthresh=1e-5)
+ ax_diffs[0, count_ax].tick_params(axis='both', which='major', labelsize=6)
if count_ax == 0:
ax_diffs[0, count_ax].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6)
@@ -488,110 +745,128 @@ def differences_over_time(dt_list, problem, sweeper, V_ref, cwd='./'):
ax_diffs[0, count_ax].legend(frameon=False, fontsize=6, loc='upper right')
ax_diffs[1, count_ax].plot(
- dt_FT[:, 0], dt_FT[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--'
+ dt_adapt[:, 0], dt_adapt[:, 1], label=r'$\Delta t$ - SE=F, A=T', color='red', linestyle='--'
)
ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=F, A=T', color='grey', linestyle='-.')
- for i in range(len(restart_FT)):
- if restart_FT[i, 1] > 0:
- ax_diffs[1, count_ax].axvline(restart_FT[i, 0], color='grey', linestyle='-.')
+ for i in range(len(restart_adapt)):
+ if restart_adapt[i, 1] > 0:
+ ax_diffs[1, count_ax].axvline(restart_adapt[i, 0], color='grey', linestyle='-.')
ax_diffs[1, count_ax].plot(
- dt_TT[:, 0], dt_TT[:, 1], label=r'$ \Delta t$ - SE=T, A=T', color='limegreen', linestyle='-.'
+ dt_SE_adapt[:, 0],
+ dt_SE_adapt[:, 1],
+ label=r'$ \Delta t$ - SE=T, A=T',
+ color='limegreen',
+ linestyle='-.',
)
ax_diffs[1, count_ax].plot([None], [None], label='Restart - SE=T, A=T', color='black', linestyle='-.')
- for i in range(len(restart_TT)):
- if restart_TT[i, 1] > 0:
- ax_diffs[1, count_ax].axvline(restart_TT[i, 0], color='black', linestyle='-.')
+ for i in range(len(restart_SE_adapt)):
+ if restart_SE_adapt[i, 1] > 0:
+ ax_diffs[1, count_ax].axvline(restart_SE_adapt[i, 0], color='black', linestyle='-.')
ax_diffs[1, count_ax].set_xlabel('Time', fontsize=6)
+ ax_diffs[1, count_ax].tick_params(axis='both', which='major', labelsize=6)
if count_ax == 0:
- ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6)
+ ax_diffs[1, count_ax].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6)
- ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='upper left')
+ ax_diffs[1, count_ax].set_yscale('log', base=10)
+ ax_diffs[1, count_ax].legend(frameon=True, fontsize=6, loc='lower left')
else:
- ax_diffs[0].set_title(r'$\Delta t$={}'.format(dt_item))
- ax_diffs[0].plot(times_TF, diff_TF, label='SE=True', color='#ff7f0e')
- ax_diffs[0].plot(times_FF, diff_FF, label='SE=False', color='#1f77b4')
- ax_diffs[0].plot(times_FT, diff_FT, label='SE=False, A=True', color='red', linestyle='--')
- ax_diffs[0].plot(times_TT, diff_TT, label='SE=True, A=True', color='limegreen', linestyle='-.')
- ax_diffs[0].axvline(x=t_switch_TF, linestyle='--', linewidth=0.5, color='k', label='Switch')
- ax_diffs[0].legend(frameon=False, fontsize=6, loc='lower left')
+ ax_diffs[0].set_title(r'$\Delta t$=%s' % dt_item)
+ ax_diffs[0].plot(times_SE, diff_SE, label='SE=True', color='#ff7f0e')
+ ax_diffs[0].plot(times, diff, label='SE=False', color='#1f77b4')
+ ax_diffs[0].plot(times_adapt, diff_adapt, label='SE=False, A=True', color='red', linestyle='--')
+ ax_diffs[0].plot(times_SE_adapt, diff_SE_adapt, label='SE=True, A=True', color='limegreen', linestyle='-.')
+ ax_diffs[0].axvline(x=t_switch_SE, linestyle='--', linewidth=0.5, color='k', label='Switch')
+ ax_diffs[0].tick_params(axis='both', which='major', labelsize=6)
ax_diffs[0].set_yscale('symlog', linthresh=1e-5)
ax_diffs[0].set_ylabel('Difference $v_{C}-V_{ref}$', fontsize=6)
ax_diffs[0].legend(frameon=False, fontsize=6, loc='center right')
- ax_diffs[1].plot(dt_FT[:, 0], dt_FT[:, 1], label='SE=False, A=True', color='red', linestyle='--')
- ax_diffs[1].plot(dt_TT[:, 0], dt_TT[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.')
+ ax_diffs[1].plot(dt_adapt[:, 0], dt_adapt[:, 1], label='SE=False, A=True', color='red', linestyle='--')
+ ax_diffs[1].plot(
+ dt_SE_adapt[:, 0], dt_SE_adapt[:, 1], label='SE=True, A=True', color='limegreen', linestyle='-.'
+ )
+ ax_diffs[1].tick_params(axis='both', which='major', labelsize=6)
ax_diffs[1].set_xlabel('Time', fontsize=6)
- ax_diffs[1].set_ylabel(r'$\Delta t_{adapted}$', fontsize=6)
+ ax_diffs[1].set_ylabel(r'$\Delta t_\mathrm{adapted}$', fontsize=6)
+ ax_diffs[1].set_yscale('log', base=10)
ax_diffs[1].legend(frameon=False, fontsize=6, loc='upper right')
count_ax += 1
plt_helper.plt.tight_layout()
- fig_diffs.savefig('data/difference_estimation_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
+ fig_diffs.savefig('data/diffs_over_time_{}.png'.format(sweeper), dpi=300, bbox_inches='tight')
plt_helper.plt.close(fig_diffs)
def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
"""
Routine to plot the number of iterations over time using switch estimator or not. Produces the iters_.png file
+
+ Args:
+ dt_list (list): list of considered (initial) step sizes
+ maxiter (np.int): maximum number of iterations
+ problem (problem.__name__): Problem class used to consider (the class name)
+ sweeper (sweeper.__name__): Sweeper used to solve (the class name)
+ cwd: current working directory
"""
- iters_time_TF = []
- iters_time_FF = []
- iters_time_TT = []
- iters_time_FT = []
- times_TF = []
- times_FF = []
- times_TT = []
- times_FT = []
- t_switches_TF = []
- t_switches_adapt = []
+ iters_time_SE = []
+ iters_time = []
+ iters_time_SE_adapt = []
+ iters_time_adapt = []
+ times_SE = []
+ times = []
+ times_SE_adapt = []
+ times_adapt = []
+ t_switches_SE = []
+ t_switches_SE_adapt = []
for dt_item in dt_list:
f1 = open(cwd + 'data/battery_dt{}_USETrue_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TF = dill.load(f1)
+ stats_SE = dill.load(f1)
f1.close()
f2 = open(cwd + 'data/battery_dt{}_USEFalse_USAFalse_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FF = dill.load(f2)
+ stats = dill.load(f2)
f2.close()
f3 = open(cwd + 'data/battery_dt{}_USETrue_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_TT = dill.load(f3)
+ stats_SE_adapt = dill.load(f3)
f3.close()
f4 = open(cwd + 'data/battery_dt{}_USEFalse_USATrue_{}.dat'.format(dt_item, sweeper), 'rb')
- stats_FT = dill.load(f4)
+ stats_adapt = dill.load(f4)
f4.close()
- iter_counts_TF_val = get_sorted(stats_TF, type='niter', recomputed=False, sortby='time')
- iter_counts_TT_val = get_sorted(stats_TT, type='niter', recomputed=False, sortby='time')
- iter_counts_FT_val = get_sorted(stats_FT, type='niter', recomputed=False, sortby='time')
- iter_counts_FF_val = get_sorted(stats_FF, type='niter', recomputed=False, sortby='time')
+ # consider iterations before restarts to see what happens
+ iter_counts_SE_val = get_sorted(stats_SE, type='niter', sortby='time')
+ iter_counts_SE_adapt_val = get_sorted(stats_SE_adapt, type='niter', sortby='time')
+ iter_counts_adapt_val = get_sorted(stats_adapt, type='niter', sortby='time')
+ iter_counts_val = get_sorted(stats, type='niter', sortby='time')
- iters_time_TF.append([v[1] for v in iter_counts_TF_val])
- iters_time_TT.append([v[1] for v in iter_counts_TT_val])
- iters_time_FT.append([v[1] for v in iter_counts_FT_val])
- iters_time_FF.append([v[1] for v in iter_counts_FF_val])
+ iters_time_SE.append([v[1] for v in iter_counts_SE_val])
+ iters_time_SE_adapt.append([v[1] for v in iter_counts_SE_adapt_val])
+ iters_time_adapt.append([v[1] for v in iter_counts_adapt_val])
+ iters_time.append([v[1] for v in iter_counts_val])
- times_TF.append([v[0] for v in iter_counts_TF_val])
- times_TT.append([v[0] for v in iter_counts_TT_val])
- times_FT.append([v[0] for v in iter_counts_FT_val])
- times_FF.append([v[0] for v in iter_counts_FF_val])
+ times_SE.append([v[0] for v in iter_counts_SE_val])
+ times_SE_adapt.append([v[0] for v in iter_counts_SE_adapt_val])
+ times_adapt.append([v[0] for v in iter_counts_adapt_val])
+ times.append([v[0] for v in iter_counts_val])
- val_switch_TF = get_sorted(stats_TF, type='switch1', sortby='time')
- t_switch_TF = [v[1] for v in val_switch_TF]
- t_switches_TF.append(t_switch_TF[-1])
+ switches_SE = get_recomputed(stats_SE, type='switch', sortby='time')
+ t_switch_SE = [v[1] for v in switches_SE]
+ t_switches_SE.append(t_switch_SE[-1])
- val_switch_TT = get_sorted(stats_TT, type='switch1', sortby='time')
- t_switch_adapt = [v[1] for v in val_switch_TT]
- t_switches_adapt.append(t_switch_adapt[-1])
+ switches_SE_adapt = get_recomputed(stats_SE_adapt, type='switch', sortby='time')
+ t_switch_SE_adapt = [v[1] for v in switches_SE_adapt]
+ t_switches_SE_adapt.append(t_switch_SE_adapt[-1])
if len(dt_list) > 1:
setup_mpl()
@@ -599,18 +874,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
nrows=1, ncols=len(dt_list), figsize=(2 * len(dt_list) - 1, 3), sharex='col', sharey='row'
)
for col in range(len(dt_list)):
- ax_iter_all[col].plot(times_FF[col], iters_time_FF[col], label='SE=F, A=F')
- ax_iter_all[col].plot(times_TF[col], iters_time_TF[col], label='SE=T, A=F')
- ax_iter_all[col].plot(times_TT[col], iters_time_TT[col], '--', label='SE=T, A=T')
- ax_iter_all[col].plot(times_FT[col], iters_time_FT[col], '--', label='SE=F, A=T')
- ax_iter_all[col].axvline(x=t_switches_TF[col], linestyle='--', linewidth=0.5, color='k', label='Switch')
- if t_switches_adapt[col] != t_switches_TF[col]:
- ax_iter_all[col].axvline(
- x=t_switches_adapt[col], linestyle='--', linewidth=0.5, color='k', label='Switch'
- )
- ax_iter_all[col].set_title('dt={}'.format(dt_list[col]))
+ ax_iter_all[col].plot(times[col], iters_time[col], label='SE=F, A=F')
+ ax_iter_all[col].plot(times_SE[col], iters_time_SE[col], label='SE=T, A=F')
+ ax_iter_all[col].plot(times_SE_adapt[col], iters_time_SE_adapt[col], '--', label='SE=T, A=T')
+ ax_iter_all[col].plot(times_adapt[col], iters_time_adapt[col], '--', label='SE=F, A=T')
+ ax_iter_all[col].axvline(x=t_switches_SE[col], linestyle='--', linewidth=0.5, color='k', label='Switch')
+ ax_iter_all[col].set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[col])
ax_iter_all[col].set_ylim(0, maxiter + 2)
ax_iter_all[col].set_xlabel('Time', fontsize=6)
+ ax_iter_all[col].tick_params(axis='both', which='major', labelsize=6)
if col == 0:
ax_iter_all[col].set_ylabel('Number iterations', fontsize=6)
@@ -620,16 +892,15 @@ def iterations_over_time(dt_list, maxiter, problem, sweeper, cwd='./'):
setup_mpl()
fig_iter_all, ax_iter_all = plt_helper.plt.subplots(nrows=1, ncols=1, figsize=(3, 3))
- ax_iter_all.plot(times_FF[0], iters_time_FF[0], label='SE=False')
- ax_iter_all.plot(times_TF[0], iters_time_TF[0], label='SE=True')
- ax_iter_all.plot(times_TT[0], iters_time_TT[0], '--', label='SE=T, A=T')
- ax_iter_all.plot(times_FT[0], iters_time_FT[0], '--', label='SE=F, A=T')
- ax_iter_all.axvline(x=t_switches_TF[0], linestyle='--', linewidth=0.5, color='k', label='Switch')
- if t_switches_adapt[0] != t_switches_TF[0]:
- ax_iter_all.axvline(x=t_switches_adapt[0], linestyle='--', linewidth=0.5, color='k', label='Switch')
- ax_iter_all.set_title('dt={}'.format(dt_list[0]))
+ ax_iter_all.plot(times[0], iters_time[0], label='SE=False')
+ ax_iter_all.plot(times_SE[0], iters_time_SE[0], label='SE=True')
+ ax_iter_all.plot(times_SE_adapt[0], iters_time_SE_adapt[0], '--', label='SE=T, A=T')
+ ax_iter_all.plot(times_adapt[0], iters_time_adapt[0], '--', label='SE=F, A=T')
+ ax_iter_all.axvline(x=t_switches_SE[0], linestyle='--', linewidth=0.5, color='k', label='Switch')
+ ax_iter_all.set_title(r'$\Delta t_\mathrm{initial}$=%s' % dt_list[0])
ax_iter_all.set_ylim(0, maxiter + 2)
ax_iter_all.set_xlabel('Time', fontsize=6)
+ ax_iter_all.tick_params(axis='both', which='major', labelsize=6)
ax_iter_all.set_ylabel('Number iterations', fontsize=6)
ax_iter_all.legend(frameon=False, fontsize=6, loc='upper right')
diff --git a/pySDC/projects/PinTSimE/estimation_check_extended.py b/pySDC/projects/PinTSimE/estimation_check_extended.py
index f075fe2de4..27d80234e8 100644
--- a/pySDC/projects/PinTSimE/estimation_check_extended.py
+++ b/pySDC/projects/PinTSimE/estimation_check_extended.py
@@ -4,10 +4,11 @@
from pySDC.helpers.stats_helper import get_sorted
from pySDC.core.Collocation import CollBase as Collocation
-from pySDC.implementations.problem_classes.Battery_2Condensators import battery_2condensators
+from pySDC.implementations.problem_classes.Battery import battery_n_condensators
from pySDC.implementations.sweeper_classes.imex_1st_order import imex_1st_order
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
-from pySDC.implementations.transfer_classes.TransferMesh import mesh_to_mesh
+from pySDC.projects.PinTSimE.battery_model import get_recomputed
+from pySDC.projects.PinTSimE.battery_2condensators_model import get_data_dict
from pySDC.projects.PinTSimE.piline_model import setup_mpl
from pySDC.projects.PinTSimE.battery_2condensators_model import log_data, proof_assertions_description
import pySDC.helpers.plot_helper as plt_helper
@@ -16,39 +17,52 @@
def run(dt, use_switch_estimator=True):
+ """
+ A simple test program to do SDC/PFASST runs for the battery drain model using 2 condensators
+
+ Args:
+ dt (float): time step that wants to be used for the computation
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+
+ Returns:
+ stats (dict): all statistics from a controller run
+ description (dict): contains all information for a controller run
+ """
# initialize level parameters
level_params = dict()
- level_params['restol'] = 1e-13
+ level_params['restol'] = -1
level_params['dt'] = dt
+ assert (
+ dt == 4e-1 or dt == 4e-2 or dt == 4e-3
+ ), "Error! Do not use other time steps dt != 4e-1 or dt != 4e-2 or dt != 4e-3 due to hardcoded references!"
+
# initialize sweeper parameters
sweeper_params = dict()
sweeper_params['quad_type'] = 'LOBATTO'
sweeper_params['num_nodes'] = 5
- sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
- sweeper_params['initial_guess'] = 'zero'
+ # sweeper_params['QI'] = 'LU' # For the IMEX sweeper, the LU-trick can be activated for the implicit part
+ sweeper_params['initial_guess'] = 'spread'
# initialize problem parameters
problem_params = dict()
+ problem_params['ncondensators'] = 2
problem_params['Vs'] = 5.0
problem_params['Rs'] = 0.5
- problem_params['C1'] = 1.0
- problem_params['C2'] = 1.0
+ problem_params['C'] = np.array([1.0, 1.0])
problem_params['R'] = 1.0
problem_params['L'] = 1.0
problem_params['alpha'] = 5.0
problem_params['V_ref'] = np.array([1.0, 1.0]) # [V_ref1, V_ref2]
- problem_params['set_switch'] = np.array([False, False], dtype=bool)
- problem_params['t_switch'] = np.zeros(np.shape(problem_params['V_ref'])[0])
# initialize step parameters
step_params = dict()
- step_params['maxiter'] = 20
+ step_params['maxiter'] = 4
# initialize controller parameters
controller_params = dict()
- controller_params['logger_level'] = 20
+ controller_params['logger_level'] = 30
controller_params['hook_class'] = log_data
# convergence controllers
@@ -59,13 +73,12 @@ def run(dt, use_switch_estimator=True):
# fill description dictionary for easy step instantiation
description = dict()
- description['problem_class'] = battery_2condensators # pass problem class
+ description['problem_class'] = battery_n_condensators # pass problem class
description['problem_params'] = problem_params # pass problem parameters
description['sweeper_class'] = imex_1st_order # pass sweeper
description['sweeper_params'] = sweeper_params # pass sweeper parameters
description['level_params'] = level_params # pass level parameters
description['step_params'] = step_params
- description['space_transfer_class'] = mesh_to_mesh # pass spatial transfer class
if use_switch_estimator:
description['convergence_controllers'] = convergence_controllers
@@ -86,40 +99,15 @@ def run(dt, use_switch_estimator=True):
# call main function to get things done...
uend, stats = controller.run(u0=uinit, t0=t0, Tend=Tend)
- Path("data").mkdir(parents=True, exist_ok=True)
- fname = 'data/battery_2condensators.dat'
- f = open(fname, 'wb')
- dill.dump(stats, f)
- f.close()
-
- # filter statistics by number of iterations
- iter_counts = get_sorted(stats, type='niter', sortby='time')
-
- # compute and print statistics
- min_iter = 20
- max_iter = 0
-
- f = open('data/battery_2condensators_out.txt', 'w')
- niters = np.array([item[1] for item in iter_counts])
- out = ' Mean number of iterations: %4.2f' % np.mean(niters)
- f.write(out + '\n')
- print(out)
- for item in iter_counts:
- out = 'Number of iterations for time %4.2f: %1i' % item
- f.write(out + '\n')
- # print(out)
- min_iter = min(min_iter, item[1])
- max_iter = max(max_iter, item[1])
-
- assert np.mean(niters) <= 12, "Mean number of iterations is too high, got %s" % np.mean(niters)
- f.close()
-
return stats, description
def check(cwd='./'):
"""
Routine to check the differences between using a switch estimator or not
+
+ Args:
+ cwd: current working directory
"""
dt_list = [4e-1, 4e-2, 4e-3]
@@ -127,15 +115,21 @@ def check(cwd='./'):
restarts_all = []
restarts_dict = dict()
for dt_item in dt_list:
- for item in use_switch_estimator:
- stats, description = run(dt=dt_item, use_switch_estimator=item)
+ for use_SE in use_switch_estimator:
+ stats, description = run(dt=dt_item, use_switch_estimator=use_SE)
+
+ if use_SE:
+ switches = get_recomputed(stats, type='switch', sortby='time')
+ assert len(switches) >= 2, f"Expected at least 2 switches for dt: {dt_item}, got {len(switches)}!"
+
+ check_solution(stats, dt_item, use_SE)
- fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, item)
+ fname = 'data/battery_2condensators_dt{}_USE{}.dat'.format(dt_item, use_SE)
f = open(fname, 'wb')
dill.dump(stats, f)
f.close()
- if item:
+ if use_SE:
restarts_dict[dt_item] = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))
restarts = restarts_dict[dt_item][:, 1]
restarts_all.append(np.sum(restarts))
@@ -161,15 +155,10 @@ def check(cwd='./'):
stats_false = dill.load(f2)
f2.close()
- val_switch1 = get_sorted(stats_true, type='switch1', sortby='time')
- val_switch2 = get_sorted(stats_true, type='switch2', sortby='time')
- t_switch1 = [v[1] for v in val_switch1]
- t_switch2 = [v[1] for v in val_switch2]
+ switches = get_recomputed(stats_true, type='switch', sortby='time')
+ t_switch = [v[1] for v in switches]
- t_switch1 = t_switch1[-1]
- t_switch2 = t_switch2[-1]
-
- val_switch_all.append([t_switch1, t_switch2])
+ val_switch_all.append([t_switch[0], t_switch[1]])
vC1_true = get_sorted(stats_true, type='voltage C1', recomputed=False, sortby='time')
vC2_true = get_sorted(stats_true, type='voltage C2', recomputed=False, sortby='time')
@@ -187,37 +176,37 @@ def check(cwd='./'):
times_false2 = [v[0] for v in vC2_false]
for m in range(len(times_true1)):
- if np.round(times_true1[m], 15) == np.round(t_switch1, 15):
+ if np.round(times_true1[m], 15) == np.round(t_switch[0], 15):
diff_true_all1.append(diff_true1[m])
for m in range(len(times_true2)):
- if np.round(times_true2[m], 15) == np.round(t_switch2, 15):
+ if np.round(times_true2[m], 15) == np.round(t_switch[1], 15):
diff_true_all2.append(diff_true2[m])
for m in range(1, len(times_false1)):
- if times_false1[m - 1] < t_switch1 < times_false1[m]:
+ if times_false1[m - 1] < t_switch[0] < times_false1[m]:
diff_false_all_before1.append(diff_false1[m - 1])
diff_false_all_after1.append(diff_false1[m])
for m in range(1, len(times_false2)):
- if times_false2[m - 1] < t_switch2 < times_false2[m]:
+ if times_false2[m - 1] < t_switch[1] < times_false2[m]:
diff_false_all_before2.append(diff_false2[m - 1])
diff_false_all_after2.append(diff_false2[m])
restarts_dt = restarts_dict[dt_item]
for i in range(len(restarts_dt[:, 0])):
- if round(restarts_dt[i, 0], 13) == round(t_switch1, 13):
- restarts_dt_switch1.append(np.sum(restarts_dt[0:i, 1]))
+ if round(restarts_dt[i, 0], 13) == round(t_switch[0], 13):
+ restarts_dt_switch1.append(np.sum(restarts_dt[0 : i - 1, 1]))
- if round(restarts_dt[i, 0], 13) == round(t_switch2, 13):
- restarts_dt_switch2.append(np.sum(restarts_dt[i - 1 :, 1]))
+ if round(restarts_dt[i, 0], 13) == round(t_switch[1], 13):
+ restarts_dt_switch2.append(np.sum(restarts_dt[i - 2 :, 1]))
setup_mpl()
fig1, ax1 = plt_helper.plt.subplots(1, 1, figsize=(4.5, 3))
ax1.set_title('Time evolution of $v_{C_{1}}-V_{ref1}$')
ax1.plot(times_true1, diff_true1, label='SE=True', color='#ff7f0e')
ax1.plot(times_false1, diff_false1, label='SE=False', color='#1f77b4')
- ax1.axvline(x=t_switch1, linestyle='--', color='k', label='Switch1')
+ ax1.axvline(x=t_switch[0], linestyle='--', color='k', label='Switch1')
ax1.legend(frameon=False, fontsize=10, loc='lower left')
ax1.set_yscale('symlog', linthresh=1e-5)
ax1.set_xlabel('Time')
@@ -230,7 +219,7 @@ def check(cwd='./'):
ax2.set_title('Time evolution of $v_{C_{2}}-V_{ref2}$')
ax2.plot(times_true2, diff_true2, label='SE=True', color='#ff7f0e')
ax2.plot(times_false2, diff_false2, label='SE=False', color='#1f77b4')
- ax2.axvline(x=t_switch2, linestyle='--', color='k', label='Switch2')
+ ax2.axvline(x=t_switch[1], linestyle='--', color='k', label='Switch2')
ax2.legend(frameon=False, fontsize=10, loc='lower left')
ax2.set_yscale('symlog', linthresh=1e-5)
ax2.set_xlabel('Time')
@@ -287,5 +276,66 @@ def check(cwd='./'):
plt_helper.plt.close(fig2)
+def check_solution(stats, dt, use_switch_estimator):
+ """
+ Function that checks the solution based on a hardcoded reference solution. Based on check_solution function from @brownbaerchen.
+
+ Args:
+ stats (dict): Raw statistics from a controller run
+ dt (float): initial time step
+ use_switch_estimator (bool): flag if the switch estimator wants to be used or not
+ """
+
+ data = get_data_dict(stats, use_switch_estimator)
+
+ if use_switch_estimator:
+ msg = f'Error when using the switch estimator for battery_2condensators for dt={dt:.1e}:'
+ if dt == 4e-1:
+ expected = {
+ 'cL': 1.1842780233981391,
+ 'vC1': 1.0094891393319418,
+ 'vC2': 1.00103823232433,
+ 'switch1': 1.6075867934844466,
+ 'switch2': 3.209437912436633,
+ 'restarts': 2.0,
+ 'sum_niters': 2000,
+ }
+ elif dt == 4e-2:
+ expected = {
+ 'cL': 1.180493652021971,
+ 'vC1': 1.0094825917376264,
+ 'vC2': 1.0007713468084405,
+ 'switch1': 1.6094074085553605,
+ 'switch2': 3.209437912440314,
+ 'restarts': 2.0,
+ 'sum_niters': 2364,
+ }
+ elif dt == 4e-3:
+ expected = {
+ 'cL': 1.1537529501025199,
+ 'vC1': 1.001438946726028,
+ 'vC2': 1.0004331625246141,
+ 'switch1': 1.6093728710270467,
+ 'switch2': 3.217437912434171,
+ 'restarts': 2.0,
+ 'sum_niters': 8920,
+ }
+
+ got = {
+ 'cL': data['cL'][-1],
+ 'vC1': data['vC1'][-1],
+ 'vC2': data['vC2'][-1],
+ 'switch1': data['switch1'],
+ 'switch2': data['switch2'],
+ 'restarts': data['restarts'],
+ 'sum_niters': data['sum_niters'],
+ }
+
+ for key in expected.keys():
+ assert np.isclose(
+ expected[key], got[key], rtol=1e-4
+ ), f'{msg} Expected {key}={expected[key]:.4e}, got {key}={got[key]:.4e}'
+
+
if __name__ == "__main__":
check()
diff --git a/pySDC/projects/PinTSimE/switch_estimator.py b/pySDC/projects/PinTSimE/switch_estimator.py
index e43fbb49b3..f31a290ea3 100644
--- a/pySDC/projects/PinTSimE/switch_estimator.py
+++ b/pySDC/projects/PinTSimE/switch_estimator.py
@@ -2,7 +2,7 @@
import scipy as sp
from pySDC.core.Collocation import CollBase
-from pySDC.core.ConvergenceController import ConvergenceController
+from pySDC.core.ConvergenceController import ConvergenceController, Status
class SwitchEstimator(ConvergenceController):
@@ -31,13 +31,34 @@ def setup(self, controller, params, description):
num_nodes=description['sweeper_params']['num_nodes'],
quad_type=description['sweeper_params']['quad_type'],
)
- self.coll_nodes_local = coll.nodes
- self.switch_detected = False
- self.switch_detected_step = False
- self.t_switch = None
- self.count_switches = 0
- self.dt_initial = description['level_params']['dt']
- return {'control_order': 100, **params}
+
+ defaults = {
+ 'control_order': 100,
+ 'tol': description['level_params']['dt'],
+ 'coll_nodes': coll.nodes,
+ 'dt_initial': description['level_params']['dt'],
+ }
+ return {**defaults, **params}
+
+ def setup_status_variables(self, controller, **kwargs):
+ """
+ Adds switching specific variables to status variables.
+
+ Args:
+ controller (pySDC.Controller): The controller
+ """
+
+ self.status = Status(['t_switch', 'switch_detected', 'switch_detected_step'])
+
+ def reset_status_variables(self, controller, **kwargs):
+ """
+ Resets status variables.
+
+ Args:
+ controller (pySDC.Controller): The controller
+ """
+
+ self.setup_status_variables(controller, **kwargs)
def get_new_step_size(self, controller, S):
"""
@@ -51,76 +72,62 @@ def get_new_step_size(self, controller, S):
None
"""
- self.switch_detected = False # reset between steps
-
L = S.levels[0]
- if not type(L.prob.params.V_ref) == int and not type(L.prob.params.V_ref) == float:
- # if V_ref is not a scalar, but an (np.)array
- V_ref = np.zeros(np.shape(L.prob.params.V_ref)[0], dtype=float)
- for m in range(np.shape(L.prob.params.V_ref)[0]):
- V_ref[m] = L.prob.params.V_ref[m]
- else:
- V_ref = np.array([L.prob.params.V_ref], dtype=float)
+ if S.status.iter == S.params.maxiter:
- if S.status.iter > 0 and self.count_switches < np.shape(V_ref)[0]:
- for m in range(len(L.u)):
- if L.u[m][self.count_switches + 1] - V_ref[self.count_switches] <= 0:
- self.switch_detected = True
- m_guess = m - 1
- break
+ self.status.switch_detected, m_guess, vC_switch = L.prob.get_switching_info(L.u, L.time)
- if self.switch_detected:
- t_interp = [L.time + L.dt * self.coll_nodes_local[m] for m in range(len(self.coll_nodes_local))]
-
- vC_switch = []
- for m in range(1, len(L.u)):
- vC_switch.append(L.u[m][self.count_switches + 1] - V_ref[self.count_switches])
+ if self.status.switch_detected:
+ t_interp = [L.time + L.dt * self.params.coll_nodes[m] for m in range(len(self.params.coll_nodes))]
# only find root if vc_switch[0], vC_switch[-1] have opposite signs (intermediate value theorem)
if vC_switch[0] * vC_switch[-1] < 0:
- self.t_switch = self.get_switch(t_interp, vC_switch, m_guess)
+ self.status.t_switch = self.get_switch(t_interp, vC_switch, m_guess)
- # if the switch is not find, we need to do ... ?
- if L.time < self.t_switch < L.time + L.dt:
- r = 1
- tol = self.dt_initial / r
+ if L.time < self.status.t_switch < L.time + L.dt:
- if not np.isclose(self.t_switch - L.time, L.dt, atol=tol):
- dt_search = self.t_switch - L.time
+ dt_switch = self.status.t_switch - L.time
+ if not np.isclose(self.status.t_switch - L.time, L.dt, atol=self.params.tol):
+ self.log(
+ f"Located Switch at time {self.status.t_switch:.6f} is outside the range of tol={self.params.tol:.4e}",
+ S,
+ )
else:
- print('Switch located at time: {}'.format(self.t_switch))
- dt_search = self.t_switch - L.time
- L.prob.params.set_switch[self.count_switches] = self.switch_detected
- L.prob.params.t_switch[self.count_switches] = self.t_switch
- controller.hooks.add_to_stats(
+ self.log(
+ f"Switch located at time {self.status.t_switch:.6f} inside tol={self.params.tol:.4e}", S
+ )
+
+ L.prob.t_switch = self.status.t_switch
+ controller.hooks[0].add_to_stats(
process=S.status.slot,
time=L.time,
level=L.level_index,
iter=0,
sweep=L.status.sweep,
- type='switch{}'.format(self.count_switches + 1),
- value=self.t_switch,
+ type='switch',
+ value=self.status.t_switch,
)
- self.switch_detected_step = True
+ L.prob.count_switches()
+ self.status.switch_detected_step = True
dt_planned = L.status.dt_new if L.status.dt_new is not None else L.params.dt
# when a switch is found, time step to match with switch should be preferred
- if self.switch_detected:
- L.status.dt_new = dt_search
+ if self.status.switch_detected:
+ L.status.dt_new = dt_switch
else:
- L.status.dt_new = min([dt_planned, dt_search])
+ L.status.dt_new = min([dt_planned, dt_switch])
else:
- self.switch_detected = False
+ self.status.switch_detected = False
else:
- self.switch_detected = False
+ self.status.switch_detected = False
def determine_restart(self, controller, S):
"""
@@ -134,8 +141,7 @@ def determine_restart(self, controller, S):
None
"""
- if self.switch_detected:
- print("Restart")
+ if self.status.switch_detected:
S.status.restart = True
S.status.force_done = True
@@ -156,13 +162,9 @@ def post_step_processing(self, controller, S):
L = S.levels[0]
- if self.switch_detected_step:
- if L.prob.params.set_switch[self.count_switches] and L.time + L.dt >= self.t_switch:
- self.count_switches += 1
- self.t_switch = None
- self.switch_detected_step = False
-
- L.status.dt_new = self.dt_initial
+ if self.status.switch_detected_step:
+ if L.time + L.dt >= self.params.t_switch:
+ L.status.dt_new = L.status.dt_new if L.status.dt_new is not None else L.params.dt
super(SwitchEstimator, self).post_step_processing(controller, S)
diff --git a/pySDC/projects/Resilience/accuracy_check.py b/pySDC/projects/Resilience/accuracy_check.py
index 4c51d696c1..0bbc375ed9 100644
--- a/pySDC/projects/Resilience/accuracy_check.py
+++ b/pySDC/projects/Resilience/accuracy_check.py
@@ -34,24 +34,6 @@ def post_step(self, step, level_number):
L.sweep.compute_end_point()
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='e_embedded',
- value=L.status.error_embedded_estimate,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='e_extrapolated',
- value=L.status.get('error_extrapolation_estimate'),
- )
self.add_to_stats(
process=step.status.slot,
time=L.time,
@@ -105,8 +87,8 @@ def get_results_from_stats(stats, var, val, hook_class=log_errors):
}
if hook_class == log_errors:
- e_extrapolated = np.array(get_sorted(stats, type='e_extrapolated'))[:, 1]
- e_embedded = np.array(get_sorted(stats, type='e_embedded'))[:, 1]
+ e_extrapolated = np.array(get_sorted(stats, type='error_extrapolation_estimate'))[:, 1]
+ e_embedded = np.array(get_sorted(stats, type='error_embedded_estimate'))[:, 1]
e_loc = np.array(get_sorted(stats, type='e_loc'))[:, 1]
if len(e_extrapolated[e_extrapolated != [None]]) > 0:
diff --git a/pySDC/projects/Resilience/advection.py b/pySDC/projects/Resilience/advection.py
index 6515a0931e..1591d340c2 100644
--- a/pySDC/projects/Resilience/advection.py
+++ b/pySDC/projects/Resilience/advection.py
@@ -6,7 +6,7 @@
from pySDC.core.Hooks import hooks
from pySDC.helpers.stats_helper import get_sorted
import numpy as np
-from pySDC.projects.Resilience.hook import log_error_estimates
+from pySDC.projects.Resilience.hook import log_data
def plot_embedded(stats, ax):
@@ -21,17 +21,8 @@ def plot_embedded(stats, ax):
ax.legend(frameon=False)
-class log_data(hooks):
- def pre_run(self, step, level_number):
- """
- Record los conditiones initiales
- """
- super(log_data, self).pre_run(step, level_number)
- L = step.levels[level_number]
- self.add_to_stats(process=0, time=0, level=0, iter=0, sweep=0, type='u0', value=L.u[0])
-
+class log_every_iteration(hooks):
def post_iteration(self, step, level_number):
- super(log_data, self).post_iteration(step, level_number)
if step.status.iter == step.params.maxiter - 1:
L = step.levels[level_number]
L.sweep.compute_end_point()
@@ -45,58 +36,12 @@ def post_iteration(self, step, level_number):
value=L.uold[-1],
)
- def post_step(self, step, level_number):
-
- super(log_data, self).post_step(step, level_number)
-
- # some abbreviations
- L = step.levels[level_number]
-
- L.sweep.compute_end_point()
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='u',
- value=L.uend,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='dt',
- value=L.dt,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='e_embedded',
- value=L.status.get('error_embedded_estimate'),
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='e_extrapolated',
- value=L.status.get('error_extrapolation_estimate'),
- )
-
def run_advection(
custom_description=None,
num_procs=1,
Tend=2e-1,
- hook_class=log_error_estimates,
+ hook_class=log_data,
fault_stuff=None,
custom_controller_params=None,
custom_problem_params=None,
diff --git a/pySDC/projects/Resilience/heat.py b/pySDC/projects/Resilience/heat.py
index e4d3e55bb4..770002ffab 100644
--- a/pySDC/projects/Resilience/heat.py
+++ b/pySDC/projects/Resilience/heat.py
@@ -5,7 +5,7 @@
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.core.Hooks import hooks
from pySDC.helpers.stats_helper import get_sorted
-from pySDC.projects.Resilience.hook import log_error_estimates
+from pySDC.projects.Resilience.hook import log_data
import numpy as np
@@ -13,7 +13,7 @@ def run_heat(
custom_description=None,
num_procs=1,
Tend=2e-1,
- hook_class=log_error_estimates,
+ hook_class=log_data,
fault_stuff=None,
custom_controller_params=None,
custom_problem_params=None,
diff --git a/pySDC/projects/Resilience/hook.py b/pySDC/projects/Resilience/hook.py
index 0ad5bdfb6f..a980e0223c 100644
--- a/pySDC/projects/Resilience/hook.py
+++ b/pySDC/projects/Resilience/hook.py
@@ -1,7 +1,14 @@
from pySDC.core.Hooks import hooks
+from pySDC.implementations.hooks.log_solution import LogSolution
+from pySDC.implementations.hooks.log_embedded_error_estimate import LogEmbeddedErrorEstimate
+from pySDC.implementations.hooks.log_extrapolated_error_estimate import LogExtrapolationErrorEstimate
+from pySDC.implementations.hooks.log_step_size import LogStepSize
-class log_error_estimates(hooks):
+hook_collection = [LogSolution, LogEmbeddedErrorEstimate, LogExtrapolationErrorEstimate, LogStepSize]
+
+
+class log_data(hooks):
"""
Record data required for analysis of problems in the resilience project
"""
@@ -10,7 +17,8 @@ def pre_run(self, step, level_number):
"""
Record los conditiones initiales
"""
- super(log_error_estimates, self).pre_run(step, level_number)
+ super().pre_run(step, level_number)
+
L = step.levels[level_number]
self.add_to_stats(process=0, time=0, level=0, iter=0, sweep=0, type='u0', value=L.u[0])
@@ -18,49 +26,10 @@ def post_step(self, step, level_number):
"""
Record final solutions as well as step size and error estimates
"""
- super(log_error_estimates, self).post_step(step, level_number)
+ super().post_step(step, level_number)
- # some abbreviations
L = step.levels[level_number]
- L.sweep.compute_end_point()
-
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='u',
- value=L.uend,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='dt',
- value=L.dt,
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='e_embedded',
- value=L.status.__dict__.get('error_embedded_estimate', None),
- )
- self.add_to_stats(
- process=step.status.slot,
- time=L.time + L.dt,
- level=L.level_index,
- iter=0,
- sweep=L.status.sweep,
- type='e_extrapolated',
- value=L.status.__dict__.get('error_extrapolation_estimate', None),
- )
self.add_to_stats(
process=step.status.slot,
time=L.time,
diff --git a/pySDC/projects/Resilience/piline.py b/pySDC/projects/Resilience/piline.py
index 0d301cab59..9ac4af3fbe 100644
--- a/pySDC/projects/Resilience/piline.py
+++ b/pySDC/projects/Resilience/piline.py
@@ -7,14 +7,14 @@
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
from pySDC.implementations.convergence_controller_classes.hotrod import HotRod
-from pySDC.projects.Resilience.hook import log_error_estimates
+from pySDC.projects.Resilience.hook import log_data, hook_collection
def run_piline(
custom_description=None,
num_procs=1,
Tend=20.0,
- hook_class=log_error_estimates,
+ hook_class=log_data,
fault_stuff=None,
custom_controller_params=None,
custom_problem_params=None,
@@ -68,7 +68,7 @@ def run_piline(
# initialize controller parameters
controller_params = dict()
controller_params['logger_level'] = 30
- controller_params['hook_class'] = hook_class
+ controller_params['hook_class'] = hook_collection + [hook_class]
controller_params['mssdc_jac'] = False
if custom_controller_params is not None:
@@ -131,8 +131,8 @@ def get_data(stats, recomputed=False):
't': np.array([me[0] for me in get_sorted(stats, type='u', recomputed=recomputed)]),
'dt': np.array([me[1] for me in get_sorted(stats, type='dt', recomputed=recomputed)]),
't_dt': np.array([me[0] for me in get_sorted(stats, type='dt', recomputed=recomputed)]),
- 'e_em': np.array(get_sorted(stats, type='e_embedded', recomputed=recomputed))[:, 1],
- 'e_ex': np.array(get_sorted(stats, type='e_extrapolated', recomputed=recomputed))[:, 1],
+ 'e_em': np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=recomputed))[:, 1],
+ 'e_ex': np.array(get_sorted(stats, type='error_extrapolation_estimate', recomputed=recomputed))[:, 1],
'restarts': np.array(get_sorted(stats, type='restart', recomputed=None))[:, 1],
't_restarts': np.array(get_sorted(stats, type='restart', recomputed=None))[:, 0],
'sweeps': np.array(get_sorted(stats, type='sweeps', recomputed=None))[:, 1],
diff --git a/pySDC/projects/Resilience/vdp.py b/pySDC/projects/Resilience/vdp.py
index 2a712d39ce..de7de087c1 100644
--- a/pySDC/projects/Resilience/vdp.py
+++ b/pySDC/projects/Resilience/vdp.py
@@ -8,7 +8,7 @@
from pySDC.implementations.controller_classes.controller_nonMPI import controller_nonMPI
from pySDC.implementations.convergence_controller_classes.adaptivity import Adaptivity
from pySDC.core.Errors import ProblemError
-from pySDC.projects.Resilience.hook import log_error_estimates
+from pySDC.projects.Resilience.hook import log_data, hook_collection
def plot_step_sizes(stats, ax):
@@ -28,7 +28,7 @@ def plot_step_sizes(stats, ax):
p = np.array([me[1][1] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
t = np.array([me[0] for me in get_sorted(stats, type='u', recomputed=False, sortby='time')])
- e_em = np.array(get_sorted(stats, type='e_embedded', recomputed=False, sortby='time'))[:, 1]
+ e_em = np.array(get_sorted(stats, type='error_embedded_estimate', recomputed=False, sortby='time'))[:, 1]
dt = np.array(get_sorted(stats, type='dt', recomputed=False, sortby='time'))
restart = np.array(get_sorted(stats, type='restart', recomputed=None, sortby='time'))
@@ -85,7 +85,7 @@ def run_vdp(
custom_description=None,
num_procs=1,
Tend=10.0,
- hook_class=log_error_estimates,
+ hook_class=log_data,
fault_stuff=None,
custom_controller_params=None,
custom_problem_params=None,
@@ -138,7 +138,7 @@ def run_vdp(
# initialize controller parameters
controller_params = dict()
controller_params['logger_level'] = 30
- controller_params['hook_class'] = hook_class
+ controller_params['hook_class'] = hook_collection + [hook_class]
controller_params['mssdc_jac'] = False
if custom_controller_params is not None:
@@ -212,7 +212,7 @@ def fetch_test_data(stats, comm=None, use_MPI=False):
Returns:
dict: Key values to perform tests on
"""
- types = ['e_embedded', 'restart', 'dt', 'sweeps', 'residual_post_step']
+ types = ['error_embedded_estimate', 'restart', 'dt', 'sweeps', 'residual_post_step']
data = {}
for type in types:
if type not in get_list_of_types(stats):
diff --git a/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py b/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py
index afdcfed8a0..081ca420d4 100644
--- a/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py
+++ b/pySDC/projects/matrixPFASST/controller_matrix_nonMPI.py
@@ -133,7 +133,8 @@ def run(self, u0, t0, Tend):
# some initializations and reset of statistics
uend = None
num_procs = len(self.MS)
- self.hooks.reset_stats()
+ for hook in self.hooks:
+ hook.reset_stats()
assert (
(Tend - t0) / self.dt
@@ -152,7 +153,8 @@ def run(self, u0, t0, Tend):
# call pre-run hook
for S in self.MS:
- self.hooks.pre_run(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_run(step=S, level_number=0)
nblocks = int((Tend - t0) / self.dt / num_procs)
@@ -169,9 +171,10 @@ def run(self, u0, t0, Tend):
# call post-run hook
for S in self.MS:
- self.hooks.post_run(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_run(step=S, level_number=0)
- return uend, self.hooks.return_stats()
+ return uend, self.return_stats()
def build_propagation_matrix(self, niter):
"""
@@ -302,7 +305,8 @@ def pfasst(self, MS):
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='PRE_STEP')
for S in MS:
- self.hooks.pre_step(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_step(step=S, level_number=0)
while np.linalg.norm(self.res, np.inf) > self.tol and niter < self.maxiter:
@@ -310,14 +314,16 @@ def pfasst(self, MS):
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='PRE_ITERATION')
for S in MS:
- self.hooks.pre_iteration(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_iteration(step=S, level_number=0)
if self.nlevels > 1:
for _ in range(MS[0].levels[1].params.nsweeps):
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=1, stage='PRE_COARSE_SWEEP')
for S in MS:
- self.hooks.pre_sweep(step=S, level_number=1)
+ for hook in self.hooks:
+ hook.pre_sweep(step=S, level_number=1)
self.u += self.Tcf.dot(np.linalg.solve(self.Pc, self.Tfc.dot(self.res)))
self.res = self.u0 - self.C.dot(self.u)
@@ -326,27 +332,32 @@ def pfasst(self, MS):
MS=MS, u=self.u, res=self.res, niter=niter, level=1, stage='POST_COARSE_SWEEP'
)
for S in MS:
- self.hooks.post_sweep(step=S, level_number=1)
+ for hook in self.hooks:
+ hook.post_sweep(step=S, level_number=1)
for _ in range(MS[0].levels[0].params.nsweeps):
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='PRE_FINE_SWEEP')
for S in MS:
- self.hooks.pre_sweep(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.pre_sweep(step=S, level_number=0)
self.u += np.linalg.solve(self.P, self.res)
self.res = self.u0 - self.C.dot(self.u)
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='POST_FINE_SWEEP')
for S in MS:
- self.hooks.post_sweep(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_sweep(step=S, level_number=0)
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='POST_ITERATION')
for S in MS:
- self.hooks.post_iteration(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_iteration(step=S, level_number=0)
MS = self.update_data(MS=MS, u=self.u, res=self.res, niter=niter, level=0, stage='POST_STEP')
for S in MS:
- self.hooks.post_step(step=S, level_number=0)
+ for hook in self.hooks:
+ hook.post_step(step=S, level_number=0)
return MS
diff --git a/pyproject.toml b/pyproject.toml
index ede34d02c1..f8bbeb3e87 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,7 +78,7 @@ pyflakes = [
'-E203', '-E741', '-E402', '-W504', '-W605', '-F401'
]
#flake8-black = ["+*"]
-flake8-bugbear = ["+*", '-B023']
+flake8-bugbear = ["+*", '-B023', '-B028']
flake8-comprehensions = ["+*", '-C408', '-C417']
[tool.black]