MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' #1479

Open
MarcMachaczek opened this issue May 17, 2023 · 19 comments

@MarcMachaczek commented May 17, 2023

Hello,
I'm aware this might be a shot in the dark, but let's see where it goes:
I have a script that trains a custom model on the Toric Code. It works and runs without any problems using MPI with one or two processes. However, when running it with "-np 4" I get the following error:
MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated'
I've been able to narrow down the source. After reducing my model complexity the error disappeared and I was able to run it on e.g. 4 processes. Moreover, vqs.expect(...) calls don't cause the error, but vqs.expect_and_forces(...) does. So it seems that the last line in expect_forces.py
return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state
contains the MPI Allreduce call that throws the error. This is plausible to me because the error message suggests a buffer-size issue, and increasing the number of parameters in my network directly affects the size of Ō_grad, which is what gets reduced.
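
For context, the calls I'm comparing look like this (vqs is my MCState and ham my Hamiltonian; just a sketch of the call pattern):

exp = vqs.expect(ham)                     # runs fine, also with -np 4
exp, forces = vqs.expect_and_forces(ham)  # triggers MPI_ERR_TRUNCATE with -np 4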

Now, is this something I can fix, or does it need to be fixed within NetKet, mpi4jax, or my MPI installation? I'm not sure where to go from here.

Thanks!

Edit:
I'm using the following versions:
NetKet 3.7
mpi4py 3.1.4
mpi4jax 0.3.14.post1
jax 0.4.6
jaxlib 0.4.6+cuda11.cudnn82

@PhilipVinc (Member) commented May 17, 2023

I assume you are running on a recent version of jax (>0.4.3). If that's not the case, please try to update.

MPI_ERR_TRUNCATE means that the message received by MPI on a node (probably that allreduce call) is longer than what was expected.
There are a few possible reasons for that:

  • You are actually not syncing your different ranks, and some are executing more code than others. This might happen, for example, because of bugs in the loggers or in custom logic inside a callback. I will assume you minimised the problem enough that you are only computing gradients and nothing else is going on.
  • It's possible that you have a different number of samples on different ranks. This should not happen if you use plain NetKet, because I go to great lengths to make sure it doesn't, but if you are doing something particular you might be going beyond the checks in place.
  • It's possible that for some reason your dtypes on one rank end up being different from the dtypes on another rank. This happened once, again because of a bug somewhere in the user code: one rank ended up with float64 gradients while the others had float32, and that produces exactly this kind of error. A way to check for this is to keep printing dtypes to be sure (see the sketch below).
  • An out-of-memory issue? Those are hard to debug and mainly happen on GPU; I've seen them because MPI uses its own buffers, separate from those of jax. You can rule those out by decreasing the percentage of memory preallocated by jax (90% by default), as in the sketch below.
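
For the last two points, something like this would do (just a sketch; the environment variable must be set before jax initialises, and the per-rank print helper is only illustrative):

import os
# rule out the OOM scenario: lower the fraction of GPU memory preallocated by jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax
from netket.utils import mpi

def report(tree, label):
    # a plain print of shapes/dtypes on every rank is enough to spot mismatches
    print(f"r{mpi.rank} {label}:", jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), tree), flush=True)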

To find out more, try running your code with

MPI4JAX_DEBUG=1 mpirun -n 2 python script.py

to get some debug information, and paste it back here. It should show some payloads having different sizes on different ranks...

@MarcMachaczek (Author) commented May 17, 2023

First of all, thank you for all the tips. I'm using a recent version of jax (updated in the description). I'm just calculating expect_and_forces here, so there should be no sync issue(?). Also, due to the small size, I don't think there are memory issues either.

I've created a (minimal) example that reproduces the error. Unfortunately, I couldn't cut down the model definition much further without losing the error. The network is essentially a symmetric RBM with additional correlators: the first part is just that of a symmetric RBM, the first loop adds correlator terms, and the second loop is very similar (there the hidden units are only connected to correlator terms). This example makes no physical sense and is not optimized; it's just for demonstration.

import os
from mpi4py import MPI

comm = MPI.COMM_WORLD
n_ranks = comm.Get_size()
rank = MPI.COMM_WORLD.Get_rank()
os.environ["MPI4JAX_USE_CUDA_MPI"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = f"{rank}"
os.environ["JAX_PLATFORM_NAME"] = "gpu"

import jax
import jax.numpy as jnp

import netket as nk
from netket.utils import HashableArray
from netket import nn as nknn
from flax import linen as nn
from flax.linen.dtypes import promote_dtype

from typing import Union, Any, Callable, Sequence

L = 3 
N  = 2*L**2
hilbert = nk.hilbert.Spin(s=1 / 2, N=N)
ham = sum([nk.operator.spin.sigmaz(hilbert, i) for i in range(hilbert.size)])

# chosen arbitrarily, just with adequate shapes
perms = jnp.stack([(jnp.arange(int(N/2)) + i) % N for i in range(int(N/2))], axis=0)
perms = HashableArray(perms)
link_perms = jnp.stack([(jnp.arange(int(N)) + i) % N for i in range(int(N/2))], axis=0)
link_perms = HashableArray(link_perms)
plaqs = jnp.array([[1, 2, 3, 4] for _ in range(int(N/2))])
plaqs = HashableArray(plaqs)

correlators = (plaqs,)
correlator_symmetries = (perms,)

model = ToricCRBM(symmetries=link_perms,
                  correlators=correlators,
                  correlator_symmetries=correlator_symmetries,
                  param_dtype=complex)

single_rule = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert, rule=single_rule)
variational_gs = nk.vqs.MCState(sampler, model)

exp, forces = variational_gs.expect_and_forces(ham)

where the model is defined as

class ToricCRBM(nn.Module):
    # permutations of lattice sites corresponding to symmetries
    symmetries: HashableArray
    # correlators that serve as additional input for the cRBM
    correlators: Sequence[HashableArray]
    # permutations of correlators corresponding to symmetries
    correlator_symmetries: Sequence[HashableArray]
    # The dtype of the weights
    param_dtype: Any = jnp.float64
    # The nonlinear activation function
    activation: Any = nknn.log_cosh
    # feature density. Number of features equal to alpha * input.shape[-1]
    alpha: Union[float, int] = 1

    # Initializer for the Dense layer matrix
    kernel_init: Callable = jax.nn.initializers.normal(0.01)
    # Initializer for the biases
    bias_init: Callable = jax.nn.initializers.normal(0.01)

    def setup(self):
        self.n_symm, self.n_sites = self.symmetries.wrapped.shape
        self.features = int(self.alpha * self.n_sites / self.n_symm)

    @nn.compact
    def __call__(self, x):
        n_batch = x.shape[0]
        # initialize bias and kernel for the "single spin correlators" analogous to GCNN
        hidden_bias = self.param("hidden_bias", self.bias_init, (self.features,), self.param_dtype)
        symm_kernel = self.param("symm_kernel", self.kernel_init, (self.features, self.n_sites), self.param_dtype)

        # take care of possibly different dtypes (e.g. x is float while parameters are complex)
        x, hidden_bias, symm_kernel = promote_dtype(x, hidden_bias, symm_kernel, dtype=None)

        # convert kernel to dense kernel of shape (features, n_symmetries, n_sites)
        symm_kernel = jnp.take(symm_kernel, self.symmetries.wrapped, axis=1)

        # x has shape (batch, n_sites), kernel has shape (features, n_symmetries, n_sites)
        # theta has shape (batch, features, n_symmetries)
        theta = jax.lax.dot_general(x, symm_kernel, (((1,), (2,)), ((), ())))
        theta += jnp.expand_dims(hidden_bias, 1)

        # here, use two visible biases according to the two spins in one unit cell
        visible_bias = self.param("visible_bias", self.bias_init, (2,), self.param_dtype)
        bias = jnp.sum(x.reshape(n_batch, -1, 2), axis=(1,)) @ visible_bias

        for i, (correlator, correlator_symmetry) in enumerate(zip(self.correlators, self.correlator_symmetries)):
            # initialize "visible" bias and kernel matrix for correlator
            correlator = correlator.wrapped  # convert hashable array to (usable) jax.Array
            corr_bias = self.param(f"corr{i}_bias", self.bias_init, (1,), self.param_dtype)
            corr_kernel = self.param(f"corr{i}_kernel", self.kernel_init, (self.features, len(correlator)),
                                     self.param_dtype)
            
            x, corr_bias, corr_kernel = promote_dtype(x, corr_bias, corr_kernel, dtype=None)

            # convert kernel to dense kernel of shape (features, n_correlator_symmetries, n_corrs)
            corr_kernel = jnp.take(corr_kernel, correlator_symmetry.wrapped, axis=1)

            # correlator has shape (n_corrs, degree), e.g. n_corrs=L²=n_site/2 and degree=4 for plaquettes
            corr_values = jnp.take(x, correlator, axis=1).prod(axis=2)  # shape (batch, n_corrs)

            # corr_values has shape (batch, n_corrs)
            # kernel has shape (features, n_correlator_symmetries, n_corrs)
            # theta has shape (batch, features, n_symmetries)
            theta += jax.lax.dot_general(corr_values, corr_kernel, (((1,), (2,)), ((), ())))
            bias += corr_bias * jnp.sum(corr_values, axis=(1,))

        out = self.activation(theta)
        out = jnp.sum(out, axis=(1, 2))  # sum over all symmetries and features = alpha * n_sites / n_symmetries
        out += bias

        for i, (loop_corr, loop_symmetries) in enumerate(zip(self.correlators, self.correlator_symmetries)):
            # initialize "visible" bias and kernel matrix for loop correlator
            loop_corr = loop_corr.wrapped  # convert hashable array to (usable) jax.Array
            loop_hidden_bias = self.param(f"loop{i}_hidden_bias", self.bias_init, (self.features,), self.param_dtype)
            loop_kernel = self.param(f"loop{i}_kernel", self.kernel_init, (self.features, len(loop_corr)),
                                     self.param_dtype)
            
            x, loop_hidden_bias, loop_kernel = promote_dtype(x, loop_hidden_bias, loop_kernel, dtype=None)

            # convert kernel to dense kernel of shape (features, n_correlator_symmetries, n_corrs)
            loop_kernel = jnp.take(loop_kernel, loop_symmetries.wrapped, axis=1)

            # loop_corr has shape (n_corrs, degree)
            loop_values = jnp.take(x, loop_corr, axis=1).prod(axis=2)  # shape (batch, n_corrs)

            # loop_values has shape (batch, n_corrs)
            # kernel has shape (features, n_loop_symmetries, n_corrs)
            # loop_theta has shape (batch, features, n_symmetries)
            loop_theta = jax.lax.dot_general(loop_values, loop_kernel, (((1,), (2,)), ((), ())))
            loop_theta += jnp.expand_dims(loop_hidden_bias, 1)
            loop_out = jnp.sum(self.activation(loop_theta), axis=(1, 2))
            # this last line causes mpi allreduce error code 15
            out += loop_out

        return out

When running this with MPI4JAX_DEBUG=1 mpiexec -np 4 python script.py I get the following output:

r2 | fHRGuSaC | MPI_Bcast -> 0 with 2 items
r3 | bfEyaLwe | MPI_Bcast -> 0 with 2 items
r1 | KE0Bcpk8 | MPI_Bcast -> 0 with 2 items
r0 | bgyndzqp | MPI_Bcast -> 0 with 2 items
r0 | bgyndzqp | MPI_Bcast done with code 0 (5.31e-05s)
r3 | bfEyaLwe | MPI_Bcast done with code 0 (2.33e-02s)
r2 | fHRGuSaC | MPI_Bcast done with code 0 (2.36e-02s)
r1 | KE0Bcpk8 | MPI_Bcast done with code 0 (2.28e-02s)
r0 | b4etROii | MPI_Bcast -> 0 with 2 items
r0 | b4etROii | MPI_Bcast done with code 0 (1.22e-05s)
r3 | bpDhgQAc | MPI_Bcast -> 0 with 2 items
r2 | GECnL1L9 | MPI_Bcast -> 0 with 2 items
r2 | GECnL1L9 | MPI_Bcast done with code 0 (1.24e-05s)
r3 | bpDhgQAc | MPI_Bcast done with code 0 (6.81e-02s)
r1 | sMNJYWUr | MPI_Bcast -> 0 with 2 items
r1 | sMNJYWUr | MPI_Bcast done with code 0 (1.88e-05s)
r0 | AvNxMDc4 | MPI_Bcast -> 0 with 8 items
r0 | AvNxMDc4 | MPI_Bcast done with code 0 (1.19e-05s)
r2 | 19aYzDVT | MPI_Bcast -> 0 with 8 items
r2 | 19aYzDVT | MPI_Bcast done with code 0 (1.31e-05s)
r1 | yjrakZtI | MPI_Bcast -> 0 with 8 items
r1 | yjrakZtI | MPI_Bcast done with code 0 (2.08e-05s)
r3 | P5cU2kxq | MPI_Bcast -> 0 with 8 items
r3 | P5cU2kxq | MPI_Bcast done with code 0 (1.21e-05s)
r0 | Qz6UJ9yn | MPI_Allreduce with 1 items
r2 | t759NLAa | MPI_Allreduce with 1 items
r1 | k8djzAxl | MPI_Allreduce with 1 items
r3 | fKuxHacJ | MPI_Allreduce with 1 items
r2 | t759NLAa | MPI_Allreduce done with code 0 (2.17e-01s)
r0 | Qz6UJ9yn | MPI_Allreduce done with code 0 (3.86e-01s)
r3 | fKuxHacJ | MPI_Allreduce done with code 0 (5.19e-05s)
r1 | k8djzAxl | MPI_Allreduce done with code 0 (2.87e-02s)
r3 | o5VIplMj | MPI_Allreduce with 1 items
r2 | ZWH5WdGP | MPI_Allreduce with 1 items
r1 | 3rR5bm2B | MPI_Allreduce with 1 items
r0 | trcpO5de | MPI_Allreduce with 1 items
r1 | 3rR5bm2B | MPI_Allreduce done with code 0 (2.20e-05s)
r3 | o5VIplMj | MPI_Allreduce done with code 0 (3.08e-05s)
r0 | trcpO5de | MPI_Allreduce done with code 0 (1.08e-05s)
r2 | ZWH5WdGP | MPI_Allreduce done with code 0 (3.68e-05s)
r0 | fV1PM63v | MPI_Allreduce with 1 items
r3 | nCfVh97V | MPI_Allreduce with 1 items
r2 | jXdDZXxX | MPI_Allreduce with 1 items
r1 | z4K1BxYz | MPI_Allreduce with 36 items
r3 | nCfVh97V | MPI_Allreduce done with code 15 (1.59e-04s)
r0 | fV1PM63v | MPI_Allreduce done with code 15 (1.53e-04s)
r1 | z4K1BxYz | MPI_Allreduce done with code 0 (1.91e-05s)
r3 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r0 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r1 | kzIVpqUw | MPI_Allreduce with 18 items
--------------------------------------------------------------------------
MPI_ABORT was invoked on rank 0 in communicator MPI COMMUNICATOR 4 CREATE FROM 0
with errorcode 15.

NOTE: invoking MPI_ABORT causes Open MPI to kill all MPI processes.
You may or may not see output from other processes, depending on
exactly when Open MPI kills them.
--------------------------------------------------------------------------

The second loop in the forward method of the model is causing the error, I believe. If I just remove the line out += loop_out the error disappears.

@PhilipVinc (Member) commented:

Is the bug reliable? In the sense: does it show up 100% of the time you run this script?
Does it show up with a different number of samples? Or a different number of ranks?

The last lines here

r0 | fV1PM63v | MPI_Allreduce with 1 items
r3 | nCfVh97V | MPI_Allreduce with 1 items
r2 | jXdDZXxX | MPI_Allreduce with 1 items
r1 | z4K1BxYz | MPI_Allreduce with 36 items
r3 | nCfVh97V | MPI_Allreduce done with code 15 (1.59e-04s)
r0 | fV1PM63v | MPI_Allreduce done with code 15 (1.53e-04s)
r1 | z4K1BxYz | MPI_Allreduce done with code 0 (1.91e-05s)
r3 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r0 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r1 | kzIVpqUw | MPI_Allreduce with 18 items

show that you are running an Allreduce between ranks with objects of different shapes, for some reason I don't understand. Your model seems correct; I don't see you doing anything wrong in there.

I tried running this code on a CPU and it runs fine (I don't have access to GPUs at the moment).
I would try inserting some jax.debug.print calls before that line to check if by any chance you have NaNs.
Maybe you hit a NaN and jax silently avoids doing some calculations?

I would also try printing with jax.debug.print inside the code computing the gradient to see what is going on...
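
Something along these lines (a sketch: O_loc and Ō_grad are the local variables inside forces_expect_hermitian, and mpi is netket.utils.mpi):

import jax.numpy as jnp

# check for NaNs at run time (not at trace time) on every rank
jax.debug.print("r{r} NaNs in O_loc: {n}", r=mpi.rank, n=jnp.any(jnp.isnan(O_loc)))
grad_leaves = jax.tree_util.tree_leaves(Ō_grad)
nan_in_grad = jnp.any(jnp.array([jnp.any(jnp.isnan(leaf)) for leaf in grad_leaves]))
jax.debug.print("r{r} NaNs in Ō_grad: {n}", r=mpi.rank, n=nan_in_grad)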

@MarcMachaczek (Author) commented:

Running it with the flag -np 4 (or anything bigger than 2) crashed 100% of the times I tried. When running it with -np 2 it sometimes throws the error; just retrying then usually works. Running just one process, or only on CPU, works fine for me too. Changing the number of samples also does not seem to have any effect.

I'll try my luck with jax.debug.print and update after that.

@MarcMachaczek (Author) commented May 18, 2023

Alright, I found some interesting things: whenever I use jax.debug.print() to show dtypes, shapes, values etc., I only catch traced arrays. For instance, including jax.debug.print(f"{O_loc}") right before the return line in forces_expect_hermitian that contains the allreduce call,
return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state

results in

Traced<ShapedArray(complex128[1008])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(complex128[1008])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(complex128[1008])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(complex128[1008])>with<DynamicJaxprTrace(level=1/0)>
r3 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r2 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting

So, as far as I understand, the error can't come from NaNs or Infs.
Also, simply removing the above mpi_sum_jax call makes the error disappear, so this really is the problematic call.

Next, I checked the keys of Ō_grad, resulting in

frozen_dict_keys(['corr0_bias', 'corr0_kernel', 'hidden_bias', 'loop0_hidden_bias', 'loop0_kernel', 'symm_kernel', 'visible_bias'])

Also, when comparing the shapes of each individual entry (like jax.debug.print(Ō_grad["corr0_bias"])), they match. However, I found the following:
When including

p = Ō_grad["corr0_kernel"]
jax.debug.print(f"{mpi.mpi_sum_jax(p)[0]}")

I get the output below (including the debug output)

MPI4JAX_DEBUG=1 mpiexec -np 4 python Toric_MPI_Test.py 
r2 | 6HAzAxqs | MPI_Bcast -> 0 with 2 items
r1 | ubR5YkfW | MPI_Bcast -> 0 with 2 items
r3 | Kp3aSWki | MPI_Bcast -> 0 with 2 items
r0 | dnSQ1YzW | MPI_Bcast -> 0 with 2 items
r1 | ubR5YkfW | MPI_Bcast done with code 0 (2.46e-02s)
r0 | dnSQ1YzW | MPI_Bcast done with code 0 (2.33e-05s)
r3 | Kp3aSWki | MPI_Bcast done with code 0 (2.19e-02s)
r2 | 6HAzAxqs | MPI_Bcast done with code 0 (3.45e-02s)
r1 | JKuBqS8g | MPI_Bcast -> 0 with 2 items
r3 | rU0oF7Aq | MPI_Bcast -> 0 with 2 items
r0 | 80LTa8FM | MPI_Bcast -> 0 with 2 items
r1 | JKuBqS8g | MPI_Bcast done with code 0 (6.15e-02s)
r0 | 80LTa8FM | MPI_Bcast done with code 0 (2.40e-05s)
r3 | rU0oF7Aq | MPI_Bcast done with code 0 (1.55e-02s)
r2 | zbnGz3xJ | MPI_Bcast -> 0 with 2 items
r2 | zbnGz3xJ | MPI_Bcast done with code 0 (2.49e-05s)
r0 | vPPmQGpD | MPI_Bcast -> 0 with 8 items
r0 | vPPmQGpD | MPI_Bcast done with code 0 (1.28e-05s)
r3 | 8Z1DeTBy | MPI_Bcast -> 0 with 8 items
r1 | DKglLp8F | MPI_Bcast -> 0 with 8 items
r3 | 8Z1DeTBy | MPI_Bcast done with code 0 (9.81e-05s)
r1 | DKglLp8F | MPI_Bcast done with code 0 (2.12e-05s)
r2 | mB3rOYse | MPI_Bcast -> 0 with 8 items
r2 | mB3rOYse | MPI_Bcast done with code 0 (1.26e-05s)
r3 | RmHnAFfG | MPI_Allreduce with 1 items
r0 | ARYIjrij | MPI_Allreduce with 1 items
Traced<ShapedArray(complex128[2,9])>with<DynamicJaxprTrace(level=1/0)>
r2 | RT9iyCeU | MPI_Allreduce with 1 items
r1 | wvKD637X | MPI_Allreduce with 1 items
r2 | RT9iyCeU | MPI_Allreduce done with code 0 (8.65e-02s)
r1 | wvKD637X | MPI_Allreduce done with code 0 (4.17e-05s)
r0 | ARYIjrij | MPI_Allreduce done with code 0 (1.80e-01s)
r3 | RmHnAFfG | MPI_Allreduce done with code 0 (3.82e-01s)
r1 | 2gzZeihm | MPI_Allreduce with 1 items
r2 | joSee3jI | MPI_Allreduce with 1 items
r0 | SFJSwFav | MPI_Allreduce with 1 items
r3 | HsfWTDFG | MPI_Allreduce with 1 items
r0 | SFJSwFav | MPI_Allreduce done with code 0 (5.74e-05s)
r2 | joSee3jI | MPI_Allreduce done with code 0 (6.74e-05s)
r3 | HsfWTDFG | MPI_Allreduce done with code 0 (1.48e-05s)
r1 | 2gzZeihm | MPI_Allreduce done with code 0 (1.04e-04s)
r2 | TIVpNDdL | MPI_Allreduce with 1 items
r1 | nlYEDbg0 | MPI_Allreduce with 18 items
r0 | 3qLQwtcC | MPI_Allreduce with 18 items
r3 | eMu8bZXS | MPI_Allreduce with 18 items
r2 | TIVpNDdL | MPI_Allreduce done with code 15 (3.16e-04s)
r3 | eMu8bZXS | MPI_Allreduce done with code 0 (2.40e-05s)
r1 | nlYEDbg0 | MPI_Allreduce done with code 0 (5.39e-05s)
r2 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
Traced<ShapedArray(complex128[2,9])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(complex128[2,9])>with<DynamicJaxprTrace(level=1/0)>
r1 | 78TMzPQZ | MPI_Allreduce with 1 items
r3 | u0f765G3 | MPI_Allreduce with 1 items
--------------------------------------------------------------------------
MPI_ABORT was invoked on rank 2 in communicator MPI COMMUNICATOR 4 CREATE FROM 0
with errorcode 15.

Notice that Traced<ShapedArray(complex128[2,9])>with<DynamicJaxprTrace(level=1/0)> only appears three times, and the allreduce error appears on rank 2. This suggests that, for some reason, the reduction for "corr0_kernel" on rank 2 only involves a single item (a scalar), whereas it has the correct shape of 2x9 on the other ranks.

r2 | TIVpNDdL | MPI_Allreduce with 1 items
r1 | nlYEDbg0 | MPI_Allreduce with 18 items
r0 | 3qLQwtcC | MPI_Allreduce with 18 items
r3 | eMu8bZXS | MPI_Allreduce with 18 items

Now the question is why some parameters seem to have different shapes, or do not match, between ranks.
Rerunning this causes the error on different ranks, and not necessarily at this particular 18-item reduction.

Just an idea: could the naming/creation of the parameters within the loop cause this? Maybe, once compiled, the ops for different parameters run in a different order on different ranks, which would pair up differently shaped parameters under the same reduction?

@PhilipVinc (Member) commented:

That's not really how you're supposed to use jax.debug.print: you're just printing the object at trace time instead of the actual numbers you have during the simulation. Check out the docs.

In short, you should be trying with jax.debug.print("{0}", O_loc) (note: a plain format string, not an f-string).

@PhilipVinc (Member) commented:

Can you also try to insert the following lines before the sum call?

print(f"r{mpi.rank}: Parameters:{jax.tree_map(lambda x: (x.shape, x.dtype), parameters)}")
print(f"r{mpi.rank}: Ō_grad:{jax.tree_map(lambda x: (x.shape, x.dtype), Ō_grad)}")
return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state

This should confirm if your intuition is correct.

@PhilipVinc (Member) commented:

As for the loop being the source of problems: not really.
This is a standard way of doing this kind of thing in flax...
The only thing that could cause issues is if your symmetry permutation tables were different among ranks, but that should not happen (though you can check; see the sketch below).
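
Checking that is easy enough outside of jit; a sketch, assuming the perms / link_perms / plaqs HashableArrays from your script:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
for name, table in [("perms", perms), ("link_perms", link_perms), ("plaqs", plaqs)]:
    local = np.asarray(table.wrapped)
    reference = comm.bcast(local, root=0)  # compare every rank against rank 0's copy
    if not np.array_equal(local, reference):
        print(f"rank {comm.Get_rank()}: {name} differs from rank 0", flush=True)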

@MarcMachaczek (Author) commented:

Yes, you're right, my bad. I mistook the positional argument for string formatting...

Inserting the print statements I get

r0: Parameters:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r0: Ō_grad:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r2: Parameters:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r2: Ō_grad:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r3: Parameters:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r3: Ō_grad:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r1: Parameters:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r1: Ō_grad:FrozenDict({
    corr0_kernel: ((2, 9), dtype('complex128')),
    loop0_kernel: ((2, 9), dtype('complex128')),
    symm_kernel: ((2, 18), dtype('complex128')),
})
r1 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r2 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting

They all seem to match.
Using jax.debug.print properly(?) with:

jax.debug.print("r{x}: Parameters:{y}", x=mpi.rank, y=jax.tree_map(lambda x: (x.shape, x.dtype), parameters))
jax.debug.print("r{x}: Ō_grad:{y}", x=mpi.rank, y=jax.tree_map(lambda x: (x.shape, x.dtype), Ō_grad))

gives an error

Traceback (most recent call last):
  File "/home/m/M.Machaczek/geneqs/Toric_MPI_Test.py", line 190, in <module>
    variational_gs = nk.vqs.MCState(sampler, model, n_samples=4032)
  File "/home/m/M.Machaczek/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 240, in __init__
    self.init(seed, dtype=sampler.dtype)
  File "/home/m/M.Machaczek/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 275, in init
    variables = jit_evaluate(self._init_fun, {"params": key}, dummy_input)
  File "/home/m/M.Machaczek/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 93, in jit_evaluate
    return fun(*args)
  File "/home/m/M.Machaczek/geneqs/venv/lib/python3.10/site-packages/netket/vqs/mc/mc_state/state.py", line 191, in <lambda>
    lambda model, *args, **kwargs: model.init(*args, **kwargs), model
  File "/home/m/M.Machaczek/geneqs/Toric_MPI_Test.py", line 75, in __call__
    jax.debug.print("{x}", x=corr_kernel.dtype)
  File "/home/m/M.Machaczek/geneqs/venv/lib/python3.10/site-packages/jax/_src/debugging.py", line 269, in debug_print
    debug_callback(functools.partial(_format_print_callback, fmt), *args,
  File "/home/m/M.Machaczek/geneqs/venv/lib/python3.10/site-packages/jax/_src/debugging.py", line 229, in debug_callback
    return debug_callback_p.bind(*flat_args, callback=_flat_callback,
TypeError: Value dtype('complex128') with type <class 'numpy.dtype[complex128]'> is not a valid JAX type

Is that intended?
Moreover, I found that not jitting forces_expect_hermitian seems to remove the error.
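
(For anyone who wants to try the same thing without touching the NetKet sources, a rough way to run everything eagerly would be something like:)

import jax

# disable jit globally, so NetKet's internally jitted functions run eagerly
with jax.disable_jit():
    exp, forces = variational_gs.expect_and_forces(ham)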

@PhilipVinc (Member) commented May 19, 2023

You should print

jax.debug.print("pars: {0}", parameters)
jax.debug.print("pars: {0}", O_loc)

> Moreover, I found that not jitting forces_expect_hermitian seems to remove the error.

That's not surprising, but it will probably destroy the runtime...

@MarcMachaczek (Author) commented May 19, 2023

Just printing

jax.debug.print("pars: {0}", parameters)
jax.debug.print("O_loc: {0}", O_loc)
jax.debug.print("O_grad: {0}", O_grad)

within jit works. The parameters on all ranks don't contain any NaN or Inf, have the appropriate shapes, and are equal across all ranks. The output is as follows:
(I managed to reduce the model complexity further; let me know in case you want the shorter code snippet.)

rank0, pars: FrozenDict({
    corr0_kernel: array([[-0.00392577-0.00585034j, -0.00395896+0.00717666j,
             0.00482844+0.01503789j,  0.01206825+0.00596183j,
            -0.01333064-0.00178544j, -0.00534998-0.00248572j,
            -0.00133964-0.00041695j,  0.00489853+0.00136481j,
             0.00672664+0.00186884j]]),
    loop0_kernel: array([[ 0.00127958+0.0093299j ,  0.00839852-0.00014896j,
             0.00539055-0.00237713j,  0.00147185+0.00198063j,
             0.00709299-0.013578j  , -0.00166158+0.00206187j,
             0.0066111 +0.01501796j,  0.00206404+0.00841864j,
             0.00753647+0.00267465j]]),
    symm_kernel: array([[ 1.33879172e-03-0.01051146j, -2.91323357e-03+0.00396878j,
             1.43061758e-02+0.00620577j,  3.02559528e-03+0.00317619j,
             6.07562775e-04+0.00093176j, -2.34646905e-04+0.00023241j,
            -1.23231311e-02+0.00436134j,  2.04907351e-03+0.00656046j,
             1.57510688e-02+0.00933856j,  9.60645326e-03-0.01480771j,
             3.94705801e-03+0.00456447j, -2.52305742e-03+0.0007859j ,
            -4.69016011e-03+0.00249195j,  3.96973557e-03-0.00769581j,
            -4.70332972e-03+0.00154195j,  5.94012878e-05+0.00095087j,
            -1.78246866e-03-0.01250062j,  5.17321802e-04+0.00215315j]]),
})
rank3, pars: FrozenDict({
    corr0_kernel: array([[-0.00392577-0.00585034j, -0.00395896+0.00717666j,
             0.00482844+0.01503789j,  0.01206825+0.00596183j,
            -0.01333064-0.00178544j, -0.00534998-0.00248572j,
            -0.00133964-0.00041695j,  0.00489853+0.00136481j,
             0.00672664+0.00186884j]]),
    loop0_kernel: array([[ 0.00127958+0.0093299j ,  0.00839852-0.00014896j,
             0.00539055-0.00237713j,  0.00147185+0.00198063j,
             0.00709299-0.013578j  , -0.00166158+0.00206187j,
             0.0066111 +0.01501796j,  0.00206404+0.00841864j,
             0.00753647+0.00267465j]]),
    symm_kernel: array([[ 1.33879172e-03-0.01051146j, -2.91323357e-03+0.00396878j,
             1.43061758e-02+0.00620577j,  3.02559528e-03+0.00317619j,
             6.07562775e-04+0.00093176j, -2.34646905e-04+0.00023241j,
            -1.23231311e-02+0.00436134j,  2.04907351e-03+0.00656046j,
             1.57510688e-02+0.00933856j,  9.60645326e-03-0.01480771j,
             3.94705801e-03+0.00456447j, -2.52305742e-03+0.0007859j ,
            -4.69016011e-03+0.00249195j,  3.96973557e-03-0.00769581j,
            -4.70332972e-03+0.00154195j,  5.94012878e-05+0.00095087j,
            -1.78246866e-03-0.01250062j,  5.17321802e-04+0.00215315j]]),
})
rank2, pars: FrozenDict({
    corr0_kernel: array([[-0.00392577-0.00585034j, -0.00395896+0.00717666j,
             0.00482844+0.01503789j,  0.01206825+0.00596183j,
            -0.01333064-0.00178544j, -0.00534998-0.00248572j,
            -0.00133964-0.00041695j,  0.00489853+0.00136481j,
             0.00672664+0.00186884j]]),
    loop0_kernel: array([[ 0.00127958+0.0093299j ,  0.00839852-0.00014896j,
             0.00539055-0.00237713j,  0.00147185+0.00198063j,
             0.00709299-0.013578j  , -0.00166158+0.00206187j,
             0.0066111 +0.01501796j,  0.00206404+0.00841864j,
             0.00753647+0.00267465j]]),
    symm_kernel: array([[ 1.33879172e-03-0.01051146j, -2.91323357e-03+0.00396878j,
             1.43061758e-02+0.00620577j,  3.02559528e-03+0.00317619j,
             6.07562775e-04+0.00093176j, -2.34646905e-04+0.00023241j,
            -1.23231311e-02+0.00436134j,  2.04907351e-03+0.00656046j,
             1.57510688e-02+0.00933856j,  9.60645326e-03-0.01480771j,
             3.94705801e-03+0.00456447j, -2.52305742e-03+0.0007859j ,
            -4.69016011e-03+0.00249195j,  3.96973557e-03-0.00769581j,
            -4.70332972e-03+0.00154195j,  5.94012878e-05+0.00095087j,
            -1.78246866e-03-0.01250062j,  5.17321802e-04+0.00215315j]]),
})
rank1, pars: FrozenDict({
    corr0_kernel: array([[-0.00392577-0.00585034j, -0.00395896+0.00717666j,
             0.00482844+0.01503789j,  0.01206825+0.00596183j,
            -0.01333064-0.00178544j, -0.00534998-0.00248572j,
            -0.00133964-0.00041695j,  0.00489853+0.00136481j,
             0.00672664+0.00186884j]]),
    loop0_kernel: array([[ 0.00127958+0.0093299j ,  0.00839852-0.00014896j,
             0.00539055-0.00237713j,  0.00147185+0.00198063j,
             0.00709299-0.013578j  , -0.00166158+0.00206187j,
             0.0066111 +0.01501796j,  0.00206404+0.00841864j,
             0.00753647+0.00267465j]]),
    symm_kernel: array([[ 1.33879172e-03-0.01051146j, -2.91323357e-03+0.00396878j,
             1.43061758e-02+0.00620577j,  3.02559528e-03+0.00317619j,
             6.07562775e-04+0.00093176j, -2.34646905e-04+0.00023241j,
            -1.23231311e-02+0.00436134j,  2.04907351e-03+0.00656046j,
             1.57510688e-02+0.00933856j,  9.60645326e-03-0.01480771j,
             3.94705801e-03+0.00456447j, -2.52305742e-03+0.0007859j ,
            -4.69016011e-03+0.00249195j,  3.96973557e-03-0.00769581j,
            -4.70332972e-03+0.00154195j,  5.94012878e-05+0.00095087j,
            -1.78246866e-03-0.01250062j,  5.17321802e-04+0.00215315j]]),
})
rank1,O_loc: [-0.04662698+0.j -0.04662698+0.j  5.95337302+0.j ... -2.04662698+0.j
 -0.04662698+0.j  9.95337302+0.j]
rank2,O_loc: [-0.04662698+0.j  7.95337302+0.j -6.04662698+0.j ... -4.04662698+0.j
 -4.04662698+0.j -2.04662698+0.j]
rank3,O_loc: [-2.04662698+0.j -6.04662698+0.j  3.95337302+0.j ... -4.04662698+0.j
 -0.04662698+0.j -0.04662698+0.j]
rank1,O_grad: FrozenDict({
    corr0_kernel: array([[-0.0011029-0.00233805j, -0.0011029-0.00233805j,
            -0.0011029-0.00233805j, -0.0011029-0.00233805j,
            -0.0011029-0.00233805j, -0.0011029-0.00233805j,
            -0.0011029-0.00233805j, -0.0011029-0.00233805j,
            -0.0011029-0.00233805j]]),
    loop0_kernel: array([[0.00451752-0.00276236j, 0.00451752-0.00276236j,
            0.00451752-0.00276236j, 0.00451752-0.00276236j,
            0.00451752-0.00276236j, 0.00451752-0.00276236j,
            0.00451752-0.00276236j, 0.00451752-0.00276236j,
            0.00451752-0.00276236j]]),
    symm_kernel: array([[-0.00402048+0.00494992j, -0.00067356-0.00246472j,
             0.00188191-0.00089012j,  0.00658493+0.00068265j,
             0.00471181+0.00456327j, -0.00416638+0.00655894j,
            -0.00527095-0.00113496j,  0.0007786 +0.00093602j,
             0.00521797+0.00146046j,  0.00573133+0.00251158j,
            -0.00087278+0.00741686j, -0.00211068-0.00090786j,
            -0.00373106+0.00425878j,  0.00271465+0.00473184j,
             0.00313997+0.00064955j, -0.0002848 +0.00094062j,
             0.00531945-0.00149018j, -0.00013818+0.00635979j]]),
})
rank2,O_grad: FrozenDict({
    corr0_kernel: array([[0.00124857+0.00433851j, 0.00124857+0.00433851j,
            0.00124857+0.00433851j, 0.00124857+0.00433851j,
            0.00124857+0.00433851j, 0.00124857+0.00433851j,
            0.00124857+0.00433851j, 0.00124857+0.00433851j,
            0.00124857+0.00433851j]]),
    loop0_kernel: array([[-0.01048405+0.00641076j, -0.01048405+0.00641076j,
            -0.01048405+0.00641076j, -0.01048405+0.00641076j,
            -0.01048405+0.00641076j, -0.01048405+0.00641076j,
            -0.01048405+0.00641076j, -0.01048405+0.00641076j,
            -0.01048405+0.00641076j]]),
    symm_kernel: array([[ 8.04429320e-04-0.00380855j,  3.73425612e-03+0.00287264j,
            -2.54154429e-03+0.00679805j, -1.15870220e-03-0.00114791j,
             1.89584990e-03-0.00153682j,  4.46062024e-03-0.00551735j,
             7.31850524e-03+0.0004957j ,  7.65450695e-04-0.00110581j,
            -4.75231024e-03-0.00164158j, -2.50239482e-03-0.00850731j,
            -8.99609531e-04-0.0033528j ,  1.71615123e-03-0.00674836j,
             1.71294144e-03-0.00068838j,  9.92617034e-04-0.00716566j,
             4.90317941e-03-0.00152089j, -1.21894460e-03+0.00063811j,
             9.79234361e-05-0.00071282j,  1.33977792e-03-0.00312163j]]),
})
rank3,O_grad: FrozenDict({
    corr0_kernel: array([[0.00079069-0.01626548j, 0.00079069-0.01626548j,
            0.00079069-0.01626548j, 0.00079069-0.01626548j,
            0.00079069-0.01626548j, 0.00079069-0.01626548j,
            0.00079069-0.01626548j, 0.00079069-0.01626548j,
            0.00079069-0.01626548j]]),
    loop0_kernel: array([[0.02037145-0.01245667j, 0.02037145-0.01245667j,
            0.02037145-0.01245667j, 0.02037145-0.01245667j,
            0.02037145-0.01245667j, 0.02037145-0.01245667j,
            0.02037145-0.01245667j, 0.02037145-0.01245667j,
            0.02037145-0.01245667j]]),
    symm_kernel: array([[ 0.00020194+0.00473587j, -0.00214112-0.00404118j,
             0.00814897-0.00287663j,  0.00650817-0.00212251j,
             0.00375884+0.00366102j,  0.00116967-0.00273324j,
            -0.00181614-0.00477027j,  0.00163528-0.00576138j,
             0.01049396-0.00593831j,  0.0039593 +0.00920016j,
             0.00346919-0.00294757j, -0.00131917+0.00213894j,
            -0.00097301-0.00012452j,  0.00742335+0.00491303j,
             0.00056936+0.00416506j,  0.00353382-0.00419881j,
             0.00521233+0.00747299j,  0.00340546+0.00038205j]]),
})
r1 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting
r2 | MPI_Allreduce returned error code 15: b'MPI_ERR_TRUNCATE: message truncated' - aborting

What is suspicious, in my opinion, is that the gradients for the correlator kernels are the same for every entry. (This turned out to be just an artifact of my test implementation.)
Also, returning return Ō, Ō, Ō instead of return Ō, jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad), new_model_state still causes the error, so possibly it is not this particular MPI allreduce call. Maybe it is something related to the vjp?

@PhilipVinc (Member) commented:

Well, changing the return value does not prevent the error, because the error is happening within the MPI call itself.

I think the only possibility left is that, for some reason, jax is reordering the MPI operations on one of the ranks (or all of them)... but I'm not sure how we could check that.

(Maybe @dionhaefner has some ideas?)

@dionhaefner commented May 19, 2023

Well, errors like these could be from inconsistent shapes between ranks or things being re-ordered. All the evidence I've seen here points towards the latter (disappears without JIT, shapes seem to match). (If you have explicit rank-dependent branches or local value-dependent branches in netket I would triple check to rule out the first possibility.)

Operations like these seem risky to me:

jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], Ō_grad)

For this to behave properly, you need (1) Ō_grad guaranteeing order (e.g. not the case for plain Python dicts before 3.7), (2) tree_map guaranteeing order, (3) proper token dependencies between all MPI calls of the iteration. I don't see any token management here?
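
For illustration, manual token threading with the raw mpi4jax API would look roughly like this (a sketch of the idea, not what netket currently does):

import jax
import mpi4jax
from mpi4py import MPI

def tree_allreduce_ordered(tree, comm=MPI.COMM_WORLD):
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    token = None
    out = []
    for leaf in leaves:
        # threading the token forces every rank to issue the reductions in the same order
        leaf, token = mpi4jax.allreduce(leaf, op=MPI.SUM, comm=comm, token=token)
        out.append(leaf)
    return jax.tree_util.tree_unflatten(treedef, out)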

@MarcMachaczek (Author) commented May 19, 2023

> Well, changing the return value does not prevent the error, because the error is happening within the MPI call itself.
>
> I think the only possibility left is that, for some reason, jax is reordering the MPI operations on one of the ranks (or all of them)... but I'm not sure how we could check that.
>
> (Maybe @dionhaefner has some ideas?)

Is there a way to make MPI show what the expected and received message lengths are?

@dionhaefner commented:

Another point: This being a GPU-only problem also points towards improper token management, since the GPU compiler is much more aggressive with re-orderings.

@PhilipVinc (Member) commented:

@MandMarc, @inailuig identified the problem and is working on a fix. Indeed, it was due to token mismanagement in netket.

The fix will take a bit to land, but we'll get there...

@MarcMachaczek (Author) commented:

Thank you for pursuing this issue and attempting to fix it! I'll follow the pull requests; maybe I can help at some point.

@PhilipVinc (Member) commented:

The main limitation right now is that jax linear transposition does not understand tokens and crashes.
We must open an issue on the jax repository and ask them to fix it...

@inailuig (Collaborator) commented:

I was getting the same error as @MandMarc (but on CPU), which I also traced back to the mpi_allreduce calls in the tree_map being reordered, resulting in MPI messages with the wrong lengths (some ranks sending different leaves) and thus error 15.

In #1494 I proposed a possible workaround, involving a custom tree map which uses tokens to enforce the order. I had to make it a custom primitive (essentially linear_call plus a different batching rule) so that it's transposable with AD (as required by the jax solvers).

Filippo wants to come up with a more elegant solution, so we are not going to merge it; however, feel free to use it in the meantime.
