MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' #1479
Comments
I assume you are running on a recent version of jax (>0.4.3). If that's not the case, please update. `MPI_ERR_TRUNCATE` means that the message received by MPI on a node (probably in that allreduce call) is longer than what was expected.

To find out more, try running your code with

```bash
MPI4JAX_DEBUG=1 mpirun -n 2 python script.py
```

to get some debug information, and paste it back here. It should show some payloads having different sizes on different ranks.
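To see why mismatched sizes trigger error 15, here is a toy two-rank repro, independent of jax/netket. It's only a sketch: the sizes are invented, and whether you get `MPI_ERR_TRUNCATE` or a hang is implementation-dependent.

```python
# truncate_demo.py -- hypothetical illustration; run with: mpirun -n 2 python truncate_demo.py
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD

# The two ranks disagree about the buffer size, mimicking ranks that are
# reducing different pytree leaves at the same time.
n = 1 if comm.Get_rank() == 0 else 36
buf = np.ones(n)
out = np.empty_like(buf)

# Allreduce requires the same count on every rank; with Open MPI this
# typically aborts with MPI_ERR_TRUNCATE on the rank expecting fewer items.
comm.Allreduce(buf, out, op=MPI.SUM)
```

---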
First of all, thank you for all the tips. I'm using a recent version of jax (updated in the description), and I'm just calculating `expect_and_forces` here, so no sync issue(?). Also, due to the small size, I think there are no memory issues either. I've created a (minimal) example that reproduces the error. Unfortunately I couldn't cut the model definition down much further without losing the error. The network is essentially a symmetric RBM with additional correlators, so the first part is just that of a symmetric RBM; the first loop adds correlator terms, and the second loop is very similar (there, hidden units are only connected to correlator terms). This example by no means makes any physical sense and is not optimized; it's just for demonstration.

```python
import os
from mpi4py import MPI
comm = MPI.COMM_WORLD
n_ranks = comm.Get_size()
rank = comm.Get_rank()
os.environ["MPI4JAX_USE_CUDA_MPI"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
import jax
import jax.numpy as jnp
import netket as nk
from netket.utils import HashableArray
from netket import nn as nknn
from flax import linen as nn
from flax.linen.dtypes import promote_dtype
from typing import Union, Any, Callable, Sequence
L = 3
N = 2*L**2
hilbert = nk.hilbert.Spin(s=1 / 2, N=N)
ham = sum([nk.operator.spin.sigmaz(hilbert, i) for i in range(hilbert.size)])
# chosen arbitrarily, with adequate shapes
perms = jnp.stack([(jnp.arange(int(N/2)) + i) % N for i in range(int(N/2))], axis=0)
perms = HashableArray(perms)
link_perms = jnp.stack([(jnp.arange(int(N)) + i) % N for i in range(int(N/2))], axis=0)
link_perms = HashableArray(link_perms)
plaqs = jnp.array([[1, 2, 3, 4] for _ in range(int(N/2))])
plaqs = HashableArray(plaqs)
correlators = (plaqs,)
correlator_symmetries = (perms,)
model = ToricCRBM(symmetries=link_perms,
correlators=correlators,
correlator_symmetries=correlator_symmetries,
param_dtype=complex)
single_rule = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert, rule=single_rule)
variational_gs = nk.vqs.MCState(sampler, model)
exp, forces = variational_gs.expect_and_forces(ham)
```

where the model is defined as

```python
class ToricCRBM(nn.Module):
    # permutations of lattice sites corresponding to symmetries
    symmetries: HashableArray
    # correlators that serve as additional input for the cRBM
    correlators: Sequence[HashableArray]
    # permutations of correlators corresponding to symmetries
    correlator_symmetries: Sequence[HashableArray]
    # the dtype of the weights
    param_dtype: Any = jnp.float64
    # the nonlinear activation function
    activation: Any = nknn.log_cosh
    # feature density; number of features = alpha * input.shape[-1]
    alpha: Union[float, int] = 1
    # initializer for the Dense layer matrix
    kernel_init: Callable = jax.nn.initializers.normal(0.01)
    # initializer for the biases
    bias_init: Callable = jax.nn.initializers.normal(0.01)

    def setup(self):
        self.n_symm, self.n_sites = self.symmetries.wrapped.shape
        self.features = int(self.alpha * self.n_sites / self.n_symm)

    @nn.compact
    def __call__(self, x):
        n_batch = x.shape[0]

        # initialize bias and kernel for the "single spin correlators", analogous to a GCNN
        hidden_bias = self.param("hidden_bias", self.bias_init, (self.features,), self.param_dtype)
        symm_kernel = self.param("symm_kernel", self.kernel_init, (self.features, self.n_sites), self.param_dtype)

        # take care of possibly different dtypes (e.g. x is float while parameters are complex)
        x, hidden_bias, symm_kernel = promote_dtype(x, hidden_bias, symm_kernel, dtype=None)

        # convert kernel to dense kernel of shape (features, n_symmetries, n_sites)
        symm_kernel = jnp.take(symm_kernel, self.symmetries.wrapped, axis=1)

        # x has shape (batch, n_sites), kernel has shape (features, n_symmetries, n_sites)
        # theta has shape (batch, features, n_symmetries)
        theta = jax.lax.dot_general(x, symm_kernel, (((1,), (2,)), ((), ())))
        theta += jnp.expand_dims(hidden_bias, 1)

        # here, use two visible biases according to the two spins in one unit cell
        visible_bias = self.param("visible_bias", self.bias_init, (2,), self.param_dtype)
        bias = jnp.sum(x.reshape(n_batch, -1, 2), axis=(1,)) @ visible_bias

        for i, (correlator, correlator_symmetry) in enumerate(zip(self.correlators, self.correlator_symmetries)):
            # initialize "visible" bias and kernel matrix for correlator
            correlator = correlator.wrapped  # convert HashableArray to (usable) jax.Array
            corr_bias = self.param(f"corr{i}_bias", self.bias_init, (1,), self.param_dtype)
            corr_kernel = self.param(f"corr{i}_kernel", self.kernel_init, (self.features, len(correlator)),
                                     self.param_dtype)
            x, corr_bias, corr_kernel = promote_dtype(x, corr_bias, corr_kernel, dtype=None)

            # convert kernel to dense kernel of shape (features, n_correlator_symmetries, n_corrs)
            corr_kernel = jnp.take(corr_kernel, correlator_symmetry.wrapped, axis=1)

            # correlator has shape (n_corrs, degree), e.g. n_corrs = L² = n_sites/2 and degree = 4 for plaquettes
            corr_values = jnp.take(x, correlator, axis=1).prod(axis=2)  # shape (batch, n_corrs)

            # corr_values has shape (batch, n_corrs)
            # kernel has shape (features, n_correlator_symmetries, n_corrs)
            # theta has shape (batch, features, n_symmetries)
            theta += jax.lax.dot_general(corr_values, corr_kernel, (((1,), (2,)), ((), ())))
            bias += corr_bias * jnp.sum(corr_values, axis=(1,))

        out = self.activation(theta)
        out = jnp.sum(out, axis=(1, 2))  # sum over all symmetries and features = alpha * n_sites / n_symmetries
        out += bias

        for i, (loop_corr, loop_symmetries) in enumerate(zip(self.correlators, self.correlator_symmetries)):
            # initialize hidden bias and kernel matrix for loop correlator
            loop_corr = loop_corr.wrapped  # convert HashableArray to (usable) jax.Array
            loop_hidden_bias = self.param(f"loop{i}_hidden_bias", self.bias_init, (self.features,), self.param_dtype)
            loop_kernel = self.param(f"loop{i}_kernel", self.kernel_init, (self.features, len(loop_corr)),
                                     self.param_dtype)
            x, loop_hidden_bias, loop_kernel = promote_dtype(x, loop_hidden_bias, loop_kernel, dtype=None)

            # convert kernel to dense kernel of shape (features, n_correlator_symmetries, n_corrs)
            loop_kernel = jnp.take(loop_kernel, loop_symmetries.wrapped, axis=1)

            # loop_corr has shape (n_corrs, degree)
            loop_values = jnp.take(x, loop_corr, axis=1).prod(axis=2)  # shape (batch, n_corrs)

            # loop_values has shape (batch, n_corrs)
            # kernel has shape (features, n_loop_symmetries, n_corrs)
            # loop_theta has shape (batch, features, n_symmetries)
            loop_theta = jax.lax.dot_general(loop_values, loop_kernel, (((1,), (2,)), ((), ())))
            loop_theta += jnp.expand_dims(loop_hidden_bias, 1)
            loop_out = jnp.sum(self.activation(loop_theta), axis=(1, 2))

            # this last line causes the MPI allreduce error code 15
            out += loop_out
        return out
```

When running this with `mpirun -np 4`, I get the `MPI_ERR_TRUNCATE` error. The second loop in the forward method of the model is causing the error, I believe. If I just remove the line `out += loop_out`, the error disappears.

---
Is the bug reliable? In the sense: does it show up 100% of the time you run this script? The last lines here

```
r0 | fV1PM63v | MPI_Allreduce with 1 items
r3 | nCfVh97V | MPI_Allreduce with 1 items
r2 | jXdDZXxX | MPI_Allreduce with 1 items
r1 | z4K1BxYz | MPI_Allreduce with 36 items
r3 | nCfVh97V | MPI_Allreduce done with code 15 (1.59e-04s)
r0 | fV1PM63v | MPI_Allreduce done with code 15 (1.53e-04s)
r1 | z4K1BxYz | MPI_Allreduce done with code 0 (1.91e-05s)
r3 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r0 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r1 | kzIVpqUw | MPI_Allreduce with 18 items
```

say that you are running an Allreduce between ranks with objects of different shapes, for some reason I don't understand. Your model seems correct, in that you don't do anything wrong here. I tried running this code on a CPU and it runs fine (I don't have access to GPUs at the moment). I would also try printing with `jax.debug.print`.

---
Running it with the flag `-np 4` (or anything bigger than 2) crashed 100% of the times I tried. When running it with `-np 2`, it sometimes throws the error; just retrying then usually works. Running just one process or only on CPU works fine for me too. Changing the number of samples also does not seem to have any effect. I'll try my luck with `jax.debug.print` and update after that.

---
Alright, I found some interesting things. Checking the values for NaNs and Infs turns up none, so there can't be an error with NaNs or Infs as far as I understand. Next, I checked the keys of Ō_grad, and then compared the shapes of each individual entry between ranks; I get the output below (including the debug output). Notice that some entries do not seem to match across ranks. Now the question is why some parameters seem to have different shapes or do not match between ranks.

---
That's not really how you're supposed to use `jax.debug.print`. In short, you should pass the values as arguments to the format string (e.g. `jax.debug.print("x: {x}", x=x)`) rather than formatting them yourself.

---
Can you also try to insert the following lines before the sum call?

```python
print(f"r{mpi.rank}: Parameters:{jax.tree_map(lambda x: (x.shape, x.dtype), parameters)}")
print(f"r{mpi.rank}: Ō_grad:{jax.tree_map(lambda x: (x.shape, x.dtype), Ō_grad)}")
return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
```

This should confirm whether your intuition is correct.

---
As for the loop being the source of the problems: not really.

---
Yes, you're right, my bad. I mistook the positional argument for string formatting... Inserting the print statements, I get output in which the shapes and dtypes all seem to match. However,

```python
jax.debug.print("r{x}: Parameters:{y}", x=mpi.rank, y=jax.tree_map(lambda x: (x.shape, x.dtype), parameters))
jax.debug.print("r{x}: Ō_grad:{y}", x=mpi.rank, y=jax.tree_map(lambda x: (x.shape, x.dtype), Ō_grad))
```

gives an error. Is that intended?

---
You should print the actual values:

```python
jax.debug.print("pars: {0}", parameters)
jax.debug.print("O_loc: {0}", O_loc)
```

That's not surprising. But it will probably destroy the runtime...

---
Just printing

```python
jax.debug.print("pars: {0}", parameters)
jax.debug.print("O_loc: {0}", O_loc)
jax.debug.print("O_grad: {0}", O_grad)
```

within jit works. The parameters on all ranks don't contain any NaN or Inf and have the appropriate shapes (and are equal across all ranks).

---
Well, changing the return value is not preventing the error, because the error is happening within this MPI call. I think the only possibility left is that for some reason jax is reordering the MPI operations on one of the ranks (or all of them)... but I'm not sure how we could check that. (Maybe @dionhaefner has some ideas?)
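One way we might check it (a sketch only; `f` and `args` are placeholders for the jitted function containing the reductions and example inputs) is to dump the optimized HLO on each rank and compare the order in which mpi4jax's custom calls appear:

```python
import jax

# Lower and compile the function, then inspect the optimized HLO,
# i.e. the program after XLA has applied its reorderings.
lowered = jax.jit(f).lower(*args)
hlo = lowered.compile().as_text()

# mpi4jax operations should show up as custom calls; print them in
# program order so the sequences can be diffed across ranks.
for line in hlo.splitlines():
    if "custom-call" in line and "mpi" in line.lower():
        print(line)
```

---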
Well, errors like these could come from inconsistent shapes between ranks or from things being re-ordered. All the evidence I've seen here points towards the latter (disappears without JIT, shapes seem to match). (If you have explicit rank-dependent branches or local value-dependent branches in netket, I would triple-check to rule out the first possibility.) Operations like these seem risky to me: `jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad)`. For this to behave properly, you need (1) the tree to have the same structure (and hence the same leaf order) on every rank, and (2) the resulting allreduce calls not to be re-ordered relative to each other by the compiler.

---
Is there a way to make MPI show what the expected and received message lengths are?

---
Another point: this being a GPU-only problem also points towards improper token management, since the GPU compiler is much more aggressive with re-orderings.

---
@MandMarc, @inailuig identified the problem and is working on a fix. Indeed, it was due to token mismanagement in netket. The fix will take a bit to land, but we'll get there...

---
Thank you for pursuing this issue and attempting to fix it! I'll follow the pull requests; maybe I can help at some point.

---
The main limitation right now is that jax linear transposition does not understand tokens and crashes.

---
I was getting the same error as @MandMarc (but on CPU), which I also traced back to the `mpi_allreduce` calls in the `tree_map` being reordered, resulting in MPI messages with the wrong lengths (some ranks sending different leaves) and causing error 15. In #1494 I proposed a possible workaround, involving a custom tree map which uses tokens to enforce the order. I had to make it a custom primitive (essentially `linear_call` plus a different batching rule) so that it's transposable with AD (as required by the jax solvers). Filippo wants to come up with a more elegant solution, therefore we are not going to merge it; however, feel free to use it in the meantime.
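A minimal sketch of the token-threading idea (not the actual #1494 code; `ordered_tree_allreduce` is just an illustrative name, and unlike the real workaround this plain version is not transposable with AD):

```python
import jax
import mpi4jax
from mpi4py import MPI

def ordered_tree_allreduce(tree, comm=MPI.COMM_WORLD):
    # Flatten the pytree so every rank visits the leaves in the same,
    # deterministic order.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    token = None
    summed = []
    for leaf in leaves:
        # Threading the token through successive calls creates a data
        # dependency, so XLA cannot reorder the allreduces against each other.
        leaf, token = mpi4jax.allreduce(leaf, op=MPI.SUM, comm=comm, token=token)
        summed.append(leaf)
    return jax.tree_util.tree_unflatten(treedef, summed)
```

With something like this, the risky `jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad)` would become `ordered_tree_allreduce(Ō_grad)`.

---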
Hello,

I'm aware this might be a shot in the dark, but let's see where it goes.

I have a script that trains a custom model on the toric code. The script works and runs without any problems using MPI with one or two processes. However, when running it with `-np 4` I get the following error:

```
MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated'
```

I've been able to narrow down the source. After reducing my model's complexity the error disappeared, and I was able to run it on e.g. 4 processes. Moreover, `vqs.expect(...)` calls don't cause the error, but `vqs.expect_and_forces(...)` does. So it seems that the last line in `expect_forces.py`,

```python
return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
```

contains the MPI Allreduce call that throws the error. This is plausible to me because the error message suggests buffer issues, and increasing the number of parameters in my network directly affects the size of Ō_grad, which gets reduced.
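For reference, and assuming I read the netket/mpi4jax glue correctly: `mpi.mpi_sum_jax` returns a `(value, token)` pair, which is why the `[0]` appears above. Roughly, each leaf's reduction boils down to something like this (my reading, not the literal netket source):

```python
import mpi4jax
from mpi4py import MPI

def mpi_sum_jax(x, token=None):
    # returns (summed_x, new_token); the tree_map above keeps only element [0]
    return mpi4jax.allreduce(x, op=MPI.SUM, comm=MPI.COMM_WORLD, token=token)
```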
Now, is this something I can fix, or does it need to be fixed within NetKet, mpi4jax, or my MPI installation? I'm not sure where to go from here.

Thanks!

Edit: I'm using the following versions:

- NetKet 3.7
- mpi4py 3.1.4
- mpi4jax 0.3.14.post1
- jax 0.4.6
- jaxlib 0.4.6+cuda11.cudnn82