From 7e7b5a28095acc91a6b9ff12bef7b0c808713482 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 29 May 2024 10:41:49 +0200 Subject: [PATCH 1/5] Made communicator class attribute of `mesh` --- .../implementations/datatype_classes/mesh.py | 29 ++++--------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/pySDC/implementations/datatype_classes/mesh.py b/pySDC/implementations/datatype_classes/mesh.py index 41d77b1540..e92a9fd5d4 100644 --- a/pySDC/implementations/datatype_classes/mesh.py +++ b/pySDC/implementations/datatype_classes/mesh.py @@ -1,7 +1,5 @@ import numpy as np -from pySDC.core.Errors import DataError - try: # TODO : mpi4py cannot be imported before dolfin when using fenics mesh # see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590 @@ -20,7 +18,9 @@ class mesh(np.ndarray): _comm: MPI communicator or None """ - def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None): + _comm = None + + def __new__(cls, init, val=0.0, **kwargs): """ Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh. @@ -33,19 +33,14 @@ def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None) """ if isinstance(init, mesh): - obj = np.ndarray.__new__( - cls, shape=init.shape, dtype=init.dtype, buffer=buffer, offset=offset, strides=strides, order=order - ) + obj = np.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs) obj[:] = init[:] - obj._comm = init._comm elif ( isinstance(init, tuple) and (init[1] is None or isinstance(init[1], MPI.Intracomm)) and isinstance(init[2], np.dtype) ): - obj = np.ndarray.__new__( - cls, init[0], dtype=init[2], buffer=buffer, offset=offset, strides=strides, order=order - ) + obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs) obj.fill(val) obj._comm = init[1] else: @@ -59,30 +54,18 @@ def comm(self): """ return self._comm - def __array_finalize__(self, obj): - """ - Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator. - """ - if obj is None: - return - self._comm = getattr(obj, '_comm', None) - def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs): """ Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs """ args = [] - comm = None for _, input_ in enumerate(inputs): if isinstance(input_, mesh): args.append(input_.view(np.ndarray)) - comm = input_.comm else: args.append(input_) - results = super(mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self)) - if type(self) == type(results): - results._comm = comm + results = super().__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self)) return results def __abs__(self): From f242f1f36391b16fb333b058fa41a6bd2bde8afe Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 29 May 2024 11:53:07 +0200 Subject: [PATCH 2/5] Moved to communicator as class attribute of `Tensor` as well --- .../implementations/datatype_classes/mesh.py | 27 ++-------- pySDC/playgrounds/ML_initial_guess/tensor.py | 49 +++---------------- 2 files changed, 11 insertions(+), 65 deletions(-) diff --git a/pySDC/implementations/datatype_classes/mesh.py b/pySDC/implementations/datatype_classes/mesh.py index e92a9fd5d4..207ad75c10 100644 --- a/pySDC/implementations/datatype_classes/mesh.py +++ b/pySDC/implementations/datatype_classes/mesh.py @@ -15,10 +15,10 @@ class mesh(np.ndarray): Can include a communicator and expects a dtype to allow complex data. Attributes: - _comm: MPI communicator or None + comm: MPI communicator or None """ - _comm = None + comm = None def __new__(cls, init, val=0.0, **kwargs): """ @@ -42,32 +42,11 @@ def __new__(cls, init, val=0.0, **kwargs): ): obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs) obj.fill(val) - obj._comm = init[1] + cls.comm = init[1] else: raise NotImplementedError(type(init)) return obj - @property - def comm(self): - """ - Getter for the communicator - """ - return self._comm - - def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs): - """ - Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs - """ - args = [] - for _, input_ in enumerate(inputs): - if isinstance(input_, mesh): - args.append(input_.view(np.ndarray)) - else: - args.append(input_) - - results = super().__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self)) - return results - def __abs__(self): """ Overloading the abs operator diff --git a/pySDC/playgrounds/ML_initial_guess/tensor.py b/pySDC/playgrounds/ML_initial_guess/tensor.py index d59ef4b8d2..5d7f3466e0 100644 --- a/pySDC/playgrounds/ML_initial_guess/tensor.py +++ b/pySDC/playgrounds/ML_initial_guess/tensor.py @@ -1,8 +1,5 @@ -import numpy as np import torch -from pySDC.core.Errors import DataError - try: from mpi4py import MPI except ImportError: @@ -12,14 +9,17 @@ class Tensor(torch.Tensor): """ Wrapper for PyTorch tensor. - Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this project goes much further! + Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this + project goes much further! TODO: Have to update `torch/multiprocessing/reductions.py` in order to share this datatype across processes. Attributes: - _comm: MPI communicator or None + comm: MPI communicator or None """ + comm = None + @staticmethod def __new__(cls, init, val=0.0, *args, **kwargs): """ @@ -36,51 +36,18 @@ def __new__(cls, init, val=0.0, *args, **kwargs): if isinstance(init, Tensor): obj = super().__new__(cls, init) obj[:] = init[:] - obj._comm = init._comm elif ( isinstance(init, tuple) - # and (init[1] is None or isinstance(init[1], MPI.Intracomm)) + and (init[1] is None or isinstance(init[1], MPI.Intracomm)) # and isinstance(init[2], np.dtype) ): - obj = super().__new__(cls, init[0].clone()) + obj = super().__new__(cls, *init[0]) obj.fill_(val) - obj._comm = init[1] + cls.comm = init[1] else: raise NotImplementedError(type(init)) return obj - def __add__(self, *args, **kwargs): - res = super().__add__(*args, **kwargs) - res._comm = self.comm - return res - - def __sub__(self, *args, **kwargs): - res = super().__sub__(*args, **kwargs) - res._comm = self.comm - return res - - def __lmul__(self, *args, **kwargs): - res = super().__lmul__(*args, **kwargs) - res._comm = self.comm - return res - - def __rmul__(self, *args, **kwargs): - res = super().__rmul__(*args, **kwargs) - res._comm = self.comm - return res - - def __mul__(self, *args, **kwargs): - res = super().__mul__(*args, **kwargs) - res._comm = self.comm - return res - - @property - def comm(self): - """ - Getter for the communicator - """ - return self._comm - def __abs__(self): """ Overloading the abs operator From b6699add59cd4bd5e5f350bf29e212757eb1b804 Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 29 May 2024 11:53:42 +0200 Subject: [PATCH 3/5] Forgot to add file to commit. --- pySDC/tests/test_datatypes/test_datatypes.py | 107 +++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 pySDC/tests/test_datatypes/test_datatypes.py diff --git a/pySDC/tests/test_datatypes/test_datatypes.py b/pySDC/tests/test_datatypes/test_datatypes.py new file mode 100644 index 0000000000..6f152890c2 --- /dev/null +++ b/pySDC/tests/test_datatypes/test_datatypes.py @@ -0,0 +1,107 @@ +import pytest + + +def get_dtype(name): + if name == 'Tensor': + from pySDC.playgrounds.ML_initial_guess.tensor import Tensor as dtype_cls + elif name in ['mesh', 'imex_mesh']: + import pySDC.implementations.datatype_classes.mesh as mesh + + dtype_cls = eval(f'mesh.{name}') + else: + raise NotImplementedError(f'Don\'t know a dtype of name {name!r}!') + + return dtype_cls + + +def single_test(name, useMPI=False): + """ + This test checks that the communicator and datatype are maintained when generating new instances. + Also, it makes sure that you can supply different communicators. + """ + import numpy as np + + dtype_cls = get_dtype(name) + + shape = (5,) + comm = None + dtype = np.dtype('f') + + if useMPI: + from mpi4py import MPI + + comm_wd = MPI.COMM_WORLD + comm = comm_wd.Split(comm_wd.rank < comm_wd.size - 1) + + expected_rank = comm_wd.rank % (comm_wd.size - 1) + + init = (shape, comm, dtype) + + a = dtype_cls(init, val=1.0) + b = dtype_cls(init, val=99.0) + c = dtype_cls(a) + d = a + b + + for me in [a, b, c, d]: + assert type(me) == dtype_cls + assert me.comm == comm + + if hasattr(me, 'shape') and not hasattr(me, 'components'): + assert me.shape == shape + + if useMPI: + assert comm.rank == expected_rank + assert comm.size < comm_wd.size + + +def launch_test(name, useMPI, num_procs=1): + if useMPI: + import os + import subprocess + + # Set python path once + my_env = os.environ.copy() + my_env['PYTHONPATH'] = '../../..:.' + my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml' + + cmd = f"mpirun -np {num_procs} python {__file__} --name={name} --useMPI=True" + + p = subprocess.Popen(cmd.split(), env=my_env, cwd=".") + + p.wait() + assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % ( + p.returncode, + num_procs, + ) + else: + single_test(name, False) + + +@pytest.mark.pytorch +@pytest.mark.parametrize('useMPI', [True, False]) +def test_PyTorch_dtype(useMPI): + launch_test('Tensor', useMPI=useMPI, num_procs=4) + + +@pytest.mark.mpi4py +@pytest.mark.parametrize('name', ['mesh', 'imex_mesh']) +def test_mesh_dtypes_MPI(name): + launch_test(name, useMPI=True, num_procs=4) + + +@pytest.mark.base +@pytest.mark.parametrize('name', ['mesh', 'imex_mesh']) +def test_mesh_dtypes(name): + launch_test(name, useMPI=False) + + +if __name__ == '__main__': + str_to_bool = lambda me: False if me == 'False' else True + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--name', type=str, help='Name of the datatype') + parser.add_argument('--useMPI', type=str_to_bool, help='Toggle for MPI', choices=[True, False]) + args = parser.parse_args() + + single_test(**vars(args)) From b034645d3b44b24cc90e52738fc41007182b46bb Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 29 May 2024 12:17:34 +0200 Subject: [PATCH 4/5] Fixes --- etc/environment-pytorch.yml | 1 + pySDC/implementations/datatype_classes/mesh.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml index c690be37ff..237ae0b760 100644 --- a/etc/environment-pytorch.yml +++ b/etc/environment-pytorch.yml @@ -11,3 +11,4 @@ dependencies: - pytorch - matplotlib>=3.0 - dill + - mpi4py diff --git a/pySDC/implementations/datatype_classes/mesh.py b/pySDC/implementations/datatype_classes/mesh.py index 207ad75c10..cda77da4a6 100644 --- a/pySDC/implementations/datatype_classes/mesh.py +++ b/pySDC/implementations/datatype_classes/mesh.py @@ -47,6 +47,20 @@ def __new__(cls, init, val=0.0, **kwargs): raise NotImplementedError(type(init)) return obj + def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs): + """ + Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs + """ + args = [] + for _, input_ in enumerate(inputs): + if isinstance(input_, mesh): + args.append(input_.view(np.ndarray)) + else: + args.append(input_) + + results = super().__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self)) + return results + def __abs__(self): """ Overloading the abs operator From 4039cbb923f081b187b8f26484a1cad90bbec10b Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Wed, 29 May 2024 12:37:19 +0200 Subject: [PATCH 5/5] Fixed environment --- etc/environment-pytorch.yml | 3 ++- pySDC/playgrounds/ML_initial_guess/tensor.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/etc/environment-pytorch.yml b/etc/environment-pytorch.yml index 237ae0b760..ff5ce4f24d 100644 --- a/etc/environment-pytorch.yml +++ b/etc/environment-pytorch.yml @@ -11,4 +11,5 @@ dependencies: - pytorch - matplotlib>=3.0 - dill - - mpi4py + - mpich + - mpi4py>=3.0.0 diff --git a/pySDC/playgrounds/ML_initial_guess/tensor.py b/pySDC/playgrounds/ML_initial_guess/tensor.py index 5d7f3466e0..9b58a79b9d 100644 --- a/pySDC/playgrounds/ML_initial_guess/tensor.py +++ b/pySDC/playgrounds/ML_initial_guess/tensor.py @@ -33,15 +33,20 @@ def __new__(cls, init, val=0.0, *args, **kwargs): obj of type mesh """ - if isinstance(init, Tensor): - obj = super().__new__(cls, init) + # TODO: The cloning of tensors going in is likely slow + + if isinstance(init, torch.Tensor): + obj = super().__new__(cls, init.clone()) obj[:] = init[:] elif ( isinstance(init, tuple) and (init[1] is None or isinstance(init[1], MPI.Intracomm)) # and isinstance(init[2], np.dtype) ): - obj = super().__new__(cls, *init[0]) + if isinstance(init[0][0], torch.Tensor): + obj = super().__new__(cls, init[0].clone()) + else: + obj = super().__new__(cls, *init[0]) obj.fill_(val) cls.comm = init[1] else: