2 changes: 2 additions & 0 deletions etc/environment-pytorch.yml
@@ -11,3 +11,5 @@ dependencies:
- pytorch
- matplotlib>=3.0
- dill
- mpich
- mpi4py>=3.0.0
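The two new entries add an MPI implementation (mpich) and its Python bindings (mpi4py), which the MPI tests below rely on. A quick sanity check of the resulting environment (not part of the PR, just a hypothetical smoke test):

    # Confirm mpi4py imports and sees the world communicator.
    from mpi4py import MPI

    print(MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)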
40 changes: 8 additions & 32 deletions pySDC/implementations/datatype_classes/mesh.py
@@ -1,7 +1,5 @@
import numpy as np

from pySDC.core.Errors import DataError

try:
# TODO : mpi4py cannot be imported before dolfin when using fenics mesh
# see https://github.com/Parallel-in-Time/pySDC/pull/285#discussion_r1145850590
@@ -17,10 +15,12 @@ class mesh(np.ndarray):
Can include a communicator and expects a dtype to allow complex data.

Attributes:
_comm: MPI communicator or None
comm: MPI communicator or None
"""

def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None):
comm = None
Member: Why not _comm here?

Contributor Author: I removed the getter function. I honestly don't see the point. But I can add it back in if you want.

Member: I "get" it. Let's keep it explicit, then.

Member: We can still have private "explicit" attributes denoted by the leading _, but no getter and setter (which are indeed superfluous).
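A minimal sketch of the convention the reviewers settle on (hypothetical class, not code from this PR): mark the attribute private with a leading underscore, read and write it directly, and skip the property boilerplate:

    class Example:
        def __init__(self, comm=None):
            # 'private' by convention via the leading underscore;
            # no getter/setter, attribute access is direct
            self._comm = comm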

def __new__(cls, init, val=0.0, **kwargs):
"""
Instantiates new datatype. This ensures that even when manipulating data, the result is still a mesh.

@@ -33,56 +33,32 @@ def __new__(cls, init, val=0.0, offset=0, buffer=None, strides=None, order=None)

"""
if isinstance(init, mesh):
obj = np.ndarray.__new__(
cls, shape=init.shape, dtype=init.dtype, buffer=buffer, offset=offset, strides=strides, order=order
)
obj = np.ndarray.__new__(cls, shape=init.shape, dtype=init.dtype, **kwargs)
obj[:] = init[:]
obj._comm = init._comm
elif (
isinstance(init, tuple)
and (init[1] is None or isinstance(init[1], MPI.Intracomm))
and isinstance(init[2], np.dtype)
):
obj = np.ndarray.__new__(
cls, init[0], dtype=init[2], buffer=buffer, offset=offset, strides=strides, order=order
)
obj = np.ndarray.__new__(cls, init[0], dtype=init[2], **kwargs)
obj.fill(val)
obj._comm = init[1]
cls.comm = init[1]
else:
raise NotImplementedError(type(init))
return obj

@property
def comm(self):
"""
Getter for the communicator
"""
return self._comm

def __array_finalize__(self, obj):
"""
Finalizing the datatype. Without this, new datatypes do not 'inherit' the communicator.
"""
if obj is None:
return
self._comm = getattr(obj, '_comm', None)

def __array_ufunc__(self, ufunc, method, *inputs, out=None, **kwargs):
"""
Overriding default ufunc, cf. https://numpy.org/doc/stable/user/basics.subclassing.html#array-ufunc-for-ufuncs
"""
args = []
comm = None
for _, input_ in enumerate(inputs):
if isinstance(input_, mesh):
args.append(input_.view(np.ndarray))
comm = input_.comm
else:
args.append(input_)

results = super(mesh, self).__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self))
if type(self) == type(results):
results._comm = comm
results = super().__array_ufunc__(ufunc, method, *args, **kwargs).view(type(self))
return results

def __abs__(self):
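Taken together: the communicator now lives on the class (cls.comm = init[1]) rather than on each instance, and __array_ufunc__ merely casts results back to the subclass. A short usage sketch of the new behavior, mirroring the test added below:

    import numpy as np
    from pySDC.implementations.datatype_classes.mesh import mesh

    init = ((5,), None, np.dtype('float64'))  # (shape, communicator, dtype)
    a = mesh(init, val=1.0)
    b = mesh(init, val=99.0)
    c = a + b  # __array_ufunc__ casts the ndarray result back to mesh

    assert type(c) == mesh
    assert c.comm == a.comm  # comm is a class attribute, shared by all meshes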
58 changes: 15 additions & 43 deletions pySDC/playgrounds/ML_initial_guess/tensor.py
@@ -1,8 +1,5 @@
import numpy as np
import torch

from pySDC.core.Errors import DataError

try:
from mpi4py import MPI
except ImportError:
@@ -12,14 +9,17 @@
class Tensor(torch.Tensor):
"""
Wrapper for PyTorch tensor.
Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this project goes much further!
Be aware that this is totally WIP! Should be fine to count iterations, but desperately needs cleaning up if this
project goes much further!

TODO: Have to update `torch/multiprocessing/reductions.py` in order to share this datatype across processes.

Attributes:
_comm: MPI communicator or None
comm: MPI communicator or None
"""

comm = None

@staticmethod
def __new__(cls, init, val=0.0, *args, **kwargs):
"""
@@ -33,54 +33,26 @@ def __new__(cls, init, val=0.0, *args, **kwargs):
obj of type mesh

"""
if isinstance(init, Tensor):
obj = super().__new__(cls, init)
# TODO: The cloning of tensors going in is likely slow

if isinstance(init, torch.Tensor):
obj = super().__new__(cls, init.clone())
obj[:] = init[:]
obj._comm = init._comm
elif (
isinstance(init, tuple)
# and (init[1] is None or isinstance(init[1], MPI.Intracomm))
and (init[1] is None or isinstance(init[1], MPI.Intracomm))
# and isinstance(init[2], np.dtype)
):
obj = super().__new__(cls, init[0].clone())
if isinstance(init[0][0], torch.Tensor):
obj = super().__new__(cls, init[0].clone())
else:
obj = super().__new__(cls, *init[0])
obj.fill_(val)
obj._comm = init[1]
cls.comm = init[1]
else:
raise NotImplementedError(type(init))
return obj

def __add__(self, *args, **kwargs):
res = super().__add__(*args, **kwargs)
res._comm = self.comm
return res

def __sub__(self, *args, **kwargs):
res = super().__sub__(*args, **kwargs)
res._comm = self.comm
return res

def __lmul__(self, *args, **kwargs):
res = super().__lmul__(*args, **kwargs)
res._comm = self.comm
return res

def __rmul__(self, *args, **kwargs):
res = super().__rmul__(*args, **kwargs)
res._comm = self.comm
return res

def __mul__(self, *args, **kwargs):
res = super().__mul__(*args, **kwargs)
res._comm = self.comm
return res

@property
def comm(self):
"""
Getter for the communicator
"""
return self._comm

def __abs__(self):
"""
Overloading the abs operator
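As with mesh, Tensor now keeps the communicator on the class and accepts either an existing torch.Tensor (which is cloned) or a (shape, comm, dtype) tuple; the dtype entry is currently unchecked. A sketch along the lines of the new test:

    import numpy as np
    from pySDC.playgrounds.ML_initial_guess.tensor import Tensor

    init = ((5,), None, np.dtype('f'))  # dtype entry is ignored for now
    a = Tensor(init, val=1.0)
    b = Tensor(a)  # any torch.Tensor goes through the cloning branch

    assert type(b) == Tensor
    assert b.comm == a.comm  # class attribute, shared by all Tensors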
107 changes: 107 additions & 0 deletions pySDC/tests/test_datatypes/test_datatypes.py
@@ -0,0 +1,107 @@
import pytest


def get_dtype(name):
if name == 'Tensor':
from pySDC.playgrounds.ML_initial_guess.tensor import Tensor as dtype_cls
elif name in ['mesh', 'imex_mesh']:
import pySDC.implementations.datatype_classes.mesh as mesh

dtype_cls = eval(f'mesh.{name}')
else:
raise NotImplementedError(f'Don\'t know a dtype of name {name!r}!')

return dtype_cls


def single_test(name, useMPI=False):
"""
This test checks that the communicator and datatype are maintained when generating new instances.
Also, it makes sure that you can supply different communicators.
"""
import numpy as np

dtype_cls = get_dtype(name)

shape = (5,)
comm = None
dtype = np.dtype('f')

if useMPI:
from mpi4py import MPI

comm_wd = MPI.COMM_WORLD
comm = comm_wd.Split(comm_wd.rank < comm_wd.size - 1)

expected_rank = comm_wd.rank % (comm_wd.size - 1)

init = (shape, comm, dtype)

a = dtype_cls(init, val=1.0)
b = dtype_cls(init, val=99.0)
c = dtype_cls(a)
d = a + b

for me in [a, b, c, d]:
assert type(me) == dtype_cls
assert me.comm == comm

if hasattr(me, 'shape') and not hasattr(me, 'components'):
assert me.shape == shape

if useMPI:
assert comm.rank == expected_rank
assert comm.size < comm_wd.size


def launch_test(name, useMPI, num_procs=1):
if useMPI:
import os
import subprocess

# Set python path once
my_env = os.environ.copy()
my_env['PYTHONPATH'] = '../../..:.'
my_env['COVERAGE_PROCESS_START'] = 'pyproject.toml'

cmd = f"mpirun -np {num_procs} python {__file__} --name={name} --useMPI=True"

p = subprocess.Popen(cmd.split(), env=my_env, cwd=".")

p.wait()
assert p.returncode == 0, 'ERROR: did not get return code 0, got %s with %2i processes' % (
p.returncode,
num_procs,
)
else:
single_test(name, False)


@pytest.mark.pytorch
@pytest.mark.parametrize('useMPI', [True, False])
def test_PyTorch_dtype(useMPI):
launch_test('Tensor', useMPI=useMPI, num_procs=4)


@pytest.mark.mpi4py
@pytest.mark.parametrize('name', ['mesh', 'imex_mesh'])
def test_mesh_dtypes_MPI(name):
launch_test(name, useMPI=True, num_procs=4)


@pytest.mark.base
@pytest.mark.parametrize('name', ['mesh', 'imex_mesh'])
def test_mesh_dtypes(name):
launch_test(name, useMPI=False)


if __name__ == '__main__':
str_to_bool = lambda me: False if me == 'False' else True
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, help='Name of the datatype')
parser.add_argument('--useMPI', type=str_to_bool, help='Toggle for MPI', choices=[True, False])
args = parser.parse_args()

single_test(**vars(args))
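Since the file doubles as a script, it can also be run by hand; the parallel form is the same command launch_test builds (assuming mpi4py and an MPI runtime from the updated environment):

    # Equivalent manual invocations (mirroring launch_test and __main__):
    #   python test_datatypes.py --name=mesh --useMPI=False
    #   mpirun -np 4 python test_datatypes.py --name=mesh --useMPI=True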