Send / recv proof of concept #2

Merged
merged 22 commits into from
Sep 12, 2020

Conversation

dionhaefner
Collaborator

The following script works:

from mpi4py import MPI

import numpy as onp
import jax
import jax.numpy as jnp
from mpi4jax import Send, Recv

SHAPE = (10, 10)

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
status = MPI.Status()


def send_recv(x, root=0, use_status=False):
    if rank == root:
        if use_status:
            x = Recv(x, comm=comm, status=status)
        else:
            x = Recv(x, comm=comm)
    else:
        x = Send(x, root, comm=comm)
    return x


send_recv_jit = jax.jit(send_recv, static_argnums=(1, 2))


if __name__ == '__main__':
    assert size == 2
    root = 1

    if rank == root:
        x = jnp.empty(SHAPE)
    else:
        x = jnp.ones(SHAPE)

    res = send_recv(x, root)
    expected_res = onp.ones(SHAPE, dtype='float32')
    assert onp.array_equal(res, expected_res), (rank, res)
    print('ok', rank)

    res_jit = send_recv_jit(x, root, False)
    assert onp.array_equal(res_jit, expected_res), (rank, res_jit)
    print('ok', rank)

    res_jit = send_recv_jit(x, root, True)
    assert onp.array_equal(res_jit, expected_res), (rank, res_jit)
    print('ok', rank)

    if rank == root:
        assert status.Get_source() == 1 - rank
    else:
        assert status.Get_source() == -1

There is one problem, though: you have to assign the result of the Send call to something, otherwise it gets optimized out and everything deadlocks. I.e., this doesn't work:

@jax.jit
def send_recv(x):
    if rank == 0:
        x = Recv(x, comm=comm)
    else:
        Send(x, 0, comm=comm)  # has to be x = Send(...)
    return x

Not sure if there's anything we can do about that. The whole implementation is pretty hacky with an unnecessary memcpy, but I don't think JAX / XLA accounts for custom calls that have side effects.
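
To make the failure mode concrete, here is a toy sketch of dead-code elimination in plain Python. It is only a model of the idea, not XLA's actual pass; the op tuples and the `eliminate_dead_code` helper are made up for illustration. An op survives only if the function's outputs depend on it, or if it is flagged as side-effecting.

```python
# Toy model of dead-code elimination (DCE); NOT how XLA is implemented.
# Each op is (name, inputs, has_side_effects). All names are hypothetical.

def eliminate_dead_code(ops, outputs):
    """Keep only ops that the outputs depend on, or that have side effects."""
    by_name = {name: (inputs, effect) for name, inputs, effect in ops}
    live = set()
    # Roots: requested outputs plus every op flagged as side-effecting.
    stack = list(outputs) + [name for name, _, effect in ops if effect]
    while stack:
        name = stack.pop()
        if name in live or name not in by_name:
            continue  # already visited, or a function argument such as "x"
        live.add(name)
        stack.extend(by_name[name][0])
    return [name for name, _, _ in ops if name in live]


# Send(x, ...) whose result is discarded; the function returns x itself.
discarded = [("send", ["x"], False)]
print(eliminate_dead_code(discarded, outputs=["x"]))   # prints []

# x = Send(x, ...): the returned value now depends on the send op.
assigned = [("send", ["x"], False)]
print(eliminate_dead_code(assigned, outputs=["send"]))  # prints ['send']

# Same discarded call, but the op is flagged as having side effects.
flagged = [("send", ["x"], True)]
print(eliminate_dead_code(flagged, outputs=["x"]))      # prints ['send']
```

The first case is the deadlock: the send is dropped on one rank while the other rank still waits in Recv.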

I didn't touch the gradient code (yet).

@dionhaefner
Collaborator Author

One solution could be to only implement sendrecv. Then all processes call the same function, which should be a bit more robust to being optimized away on some but not all processes. But this would still fail:

@jax.jit
def foo(x):
    y = sendrecv(x, source=0, dest=1)
    if rank == 1:
        return y
    return x

since rank 0 doesn't do anything with y.

@PhilipVinc
Member

I think this is because XLA's compiler is very aggressive.
As soon as it sees that you don't use the output value of a leaf of the computational graph, it optimises it out.
Send of course has no used leaves, so it gets rid of it.

I think we should ask the JAX people whether it's somehow possible to tag a function as 'do_not_optimise'.

@PhilipVinc
Member

By the way, if you rebase, tests should be working now.

@PhilipVinc
Member

I guess that until Jax#3370 is merged we should rather focus on collective all-to-all communications, which are not affected by the side-effect problem.
(BTW, I'll be on holiday in the next few weeks so I won't be working on this, but in case you cook up some PR I'll be quick to review)

@dionhaefner
Collaborator Author

dionhaefner commented Jul 27, 2020

Collective all-to-all operations are affected, too, though. Example:

@jax.jit
def foo(x):
    x = Allreduce(x)
    if rank == 0:
        return 0  # kaboom
    return x

So right now, it is the user’s responsibility to make sure that there is a data dependency on the return value of the MPI calls.

@PhilipVinc
Member

Allreduce is what I have implemented and, at least in my experience, is working well.

Why are you returning 0 if rank == 0?

@PhilipVinc
Member

Ah, ok, I get it.
You mean the case where the function does not return something that depends on its result.

What I meant is that all-to-all operations are (usually) used in contexts where all ranks execute the same code, so that is (sometimes) not an issue.

@dionhaefner
Collaborator Author

I agree, all-to-all are lower risk. Ultimately it's the user's responsibility not to mess up though, so we should put a warning in the readme or so :)

@PhilipVinc
Member

The omnistaging and has_side_effects stuff is not yet in a released version, right?
Should we bump the minimum JAX version?

@dionhaefner
Collaborator Author

The omnistaging and has_side_effects stuff is not yet in a released version, right?
Should we bump the minimum JAX version?

Yes, I think that would be sensible. I don't think there's a strong motivation to introduce a bunch of extra logic for JAX versions pre-omnistaging / side effect support.

@PhilipVinc
Member

What is the tradeoff of omnistaging?
We could also activate it ourselves in init.
Most of mpi4jax won't work otherwise...

@dionhaefner
Collaborator Author

Most of mpi4jax won't work otherwise...

It works fine if you don't use jit or just use all-to-all-type operations. We don't know the performance impact of omnistaging yet, and some packages like TensorFlow Probability break with it. For now, I opted for a warning whenever an MPI call is being jitted.

But yes, when / if omnistaging becomes the default in JAX I think we should just require it.

@dionhaefner
Collaborator Author

This is done from my side. Tests are failing hard until omnistaging is released, but it works on my machine™️

@dionhaefner
Collaborator Author

When using this I noticed that XLA would sometimes re-order send and recv calls, which causes deadlocks.

So unless a solution comes up in jax-ml/jax#3976, this will need some token mechanism to ensure proper order.

This principally affects all primitives, but it's easiest to run into with send and recv.
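
As a rough illustration of why a swapped send / recv pair hangs, here is a small plain-Python simulation of two ranks. It is a sketch under simplifying assumptions: sends are treated as buffered (the optimistic case; a real blocking send may itself wait for the receiver), and the `run` helper and op tuples are hypothetical, not part of mpi4jax or MPI.

```python
from collections import deque


def run(programs):
    """Step through per-rank op lists; return True if all ranks finish.

    Each op is ("send", peer) or ("recv", peer). Sends are buffered here;
    recvs block until a matching message is waiting.
    """
    pc = [0] * len(programs)  # index of the next op for each rank
    mailbox = {}              # (src, dst) -> queue of pending messages
    progressed = True
    while progressed:
        progressed = False
        for rank, ops in enumerate(programs):
            if pc[rank] == len(ops):
                continue  # this rank is done
            kind, peer = ops[pc[rank]]
            if kind == "send":
                mailbox.setdefault((rank, peer), deque()).append("msg")
            elif mailbox.get((peer, rank)):
                mailbox[(peer, rank)].popleft()
            else:
                continue  # blocked in recv, nothing has arrived yet
            pc[rank] += 1
            progressed = True
    return all(pc[r] == len(ops) for r, ops in enumerate(programs))


# Order as written by the user: rank 0 sends first, rank 1 receives first.
as_written = [[("send", 1), ("recv", 1)],
              [("recv", 0), ("send", 0)]]
print(run(as_written))  # prints True

# Rank 0's two calls swapped by the compiler: both ranks now wait in recv.
reordered = [[("recv", 1), ("send", 1)],
             [("recv", 0), ("send", 0)]]
print(run(reordered))   # prints False
```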

@PhilipVinc
Member

PhilipVinc commented Aug 7, 2020 via email

@dionhaefner
Collaborator Author

This is current JAX master. I've opened an issue already and the JAX devs confirmed the problem (jax-ml/jax#3976). If there's no fix on the horizon, I had some ideas of a "token tape" that should allow us to get around this with 1 line of extra boilerplate in user code.

@dionhaefner
Collaborator Author

I added a token mechanism which ensures that the proper order is preserved. The idea is to chain the calls by using an XLA token:

token = Send(...)  # not passing a token creates a new one
token = Send(..., token=token)  # re-use previous token
arr, token = Recv(..., token=token)
arr, token = Sendrecv(..., token=token)
arr, token = Allreduce(..., token=token)

As long as the correct token is passed, those statements should never get re-ordered (relative to each other) or optimized away.
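
The reason chaining works can be sketched with a toy scheduler in plain Python. This is only a model of the idea, not XLA's scheduler: the `schedule` helper and the op names are hypothetical, and "pick the last-listed ready op" simply stands in for a compiler that may reorder anything that has no data dependency.

```python
def schedule(ops):
    """Topologically order ops, picking the last-listed ready op first.

    `ops` maps an op name to the set of op names it depends on. Preferring
    the last ready op mimics a compiler that freely reorders independent ops.
    Assumes the dependency graph has no cycles.
    """
    remaining = dict(ops)
    order = []
    while remaining:
        # An op is ready once none of its dependencies are still pending.
        ready = [name for name, deps in remaining.items()
                 if deps.isdisjoint(remaining)]
        name = ready[-1]
        order.append(name)
        del remaining[name]
    return order


# No tokens: the three calls share no data, so any order is allowed.
no_tokens = {"send_a": set(), "recv_b": set(), "send_c": set()}
print(schedule(no_tokens))    # prints ['send_c', 'recv_b', 'send_a']

# Token chain: each call consumes the token produced by the previous one.
with_tokens = {"send_a": set(),
               "recv_b": {"send_a"},
               "send_c": {"recv_b"}}
print(schedule(with_tokens))  # prints ['send_a', 'recv_b', 'send_c']
```

With the token edges in place there is only one valid order, which is the property the chained calls above rely on.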

It sounds like the JAX people are cooking up a solution that does this token chaining automatically behind the scenes, but my feeling is that this might take a while to land.

Left to-do:

vmap and grad support is broken

@PhilipVinc
Member

Very well.
It's clumsy, but indeed we cannot do much without it until they fix this upstream.

@dionhaefner
Collaborator Author

How important are vmap and grad for you? AFAICT, grad would only be meaningful on global sums, and vmap is a bit pointless, too. I'm asking because create_token supports neither of those, so it would require some additional hacking to get this to work.

@PhilipVinc
Member

They are not essential.
I'd say we can merge this as-is and I can add vmap and grad myself in the future.
In that case, however, I think we should add a table to the readme.md listing the supported features for every operation.

--

grad for send and recv is a bit tricky to implement and would require some thinking.
vmap, on the other hand, should be simple: the operation is the same, you simply have to recompute the total buffer size, as I do with allreduce. For the token... you can keep a single token, I guess. Why would you need more than one?

@dionhaefner
Collaborator Author

vmap is easy conceptually, but JAX doesn't like it, because you cannot create tokens in vmapped functions. To get around that, we could write a thin wrapper around create_token that defines a trivial batching rule, but it requires a bit more work.

I'll patch out grad and vmap for now, and then we can merge as soon as the JAX release with omnistaging is available.

@PhilipVinc
Member

Ok, thanks for the investigation!

That's fine by me.

@PhilipVinc
Member

I'm back from holidays and starting to work again!

Where do we stand with the merging of omnistaging in JAX? Do you have any news?

@dionhaefner
Collaborator Author

dionhaefner commented Sep 9, 2020

Welcome back! Omnistaging is merged, but there is still no jaxlib release, so has_side_effects is not yet available. Maybe we could ask for a jaxlib release; it's been almost 2 months now.

There has been a jaxlib release today, so this should work now :)

@dionhaefner
Collaborator Author

Ah, it's just a tag, not a release... I asked for one.

@PhilipVinc
Member

Yay! Thanks for nudging the Google guys.

@PhilipVinc PhilipVinc marked this pull request as ready for review September 11, 2020 09:20
@PhilipVinc PhilipVinc mentioned this pull request Sep 11, 2020
@dionhaefner
Collaborator Author

Done from my side (for real this time).

@PhilipVinc
Member

All is great.
If you can just add a comment somewhere in the code about MPI_STATUS_IGNORE, then I'll merge and try to tag a new release.

@PhilipVinc PhilipVinc merged commit 989237b into mpi4jax:master Sep 12, 2020
@PhilipVinc PhilipVinc mentioned this pull request Mar 10, 2021