
Document the mpi.n_nodes rescaling factor in gradient computation with nk.jax.expect #1690

Open
alleSini99 opened this issue Jan 11, 2024 · 19 comments
Labels: docs (Documentation-related issues)

@alleSini99 (Collaborator) commented Jan 11, 2024

Hi everybody,
I think there is a spurious rescaling factor of mpi.n_nodes when using nk.jax.expect to compute the gradient of an arbitrary loss function.
I believe this comes from the mpi_mean operation in the backward rule of nk.jax.expect:

def _expect_bwd(n_chains, log_pdf, expected_fun, residuals, dout):
    pars, σ, cost_args, ΔL_σ = residuals
    dL̄, dL̄_stats = dout

    def f(pars, σ, *cost_args):
        log_p = log_pdf(pars, σ)
        term1 = jax.vmap(jnp.multiply)(ΔL_σ, log_p)
        term2 = expected_fun(pars, σ, *cost_args)
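        # mean over the samples of all ranks: this carries a 1/(mpi.n_nodes * n_samples_per_rank) = 1/N factor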
        out = mpi_mean(term1 + term2, axis=0)
        out = out.sum()
        return out

    _, pb = nkvjp(f, pars, σ, *cost_args)
    grad_f = pb(dL̄)
    return grad_f

Indeed, when the vjp of f is taken, the sum over samples is performed on each rank separately, but the corresponding sum over the ranks is never performed (and it is needed, since inside f we also take the mean over the ranks).
As a result, the gradient on each rank is rescaled by a factor of 1/mpi.n_nodes, as shown in the following MWE for the computation of the energy gradient.

# To run with MPI

import jax
import jax.numpy as jnp

import netket as nk
from netket.jax import expect, vjp
from netket.utils import mpi

L = 4
hi = nk.hilbert.Spin(s=0.5, N=L)
vstate = nk.vqs.MCState(
    sampler=nk.sampler.MetropolisLocal(hi),
    model=nk.models.RBM(alpha=1),
    n_samples=8192,
)
H = nk.operator.IsingJax(hi, graph=nk.graph.Chain(L), h=1.0, J=-1.0)

out = vstate.expect_and_grad(H)[1]
if mpi.rank == 0:
    print("standard expect: ", out["Dense"]["bias"][:8], "\n")


def expect_and_grad_jax(vstate, op):
    afun = vstate._apply_fun
    model_state = vstate.model_state
    sigma, args = nk.vqs.get_local_kernel_arguments(vstate, op)
    n_chains = sigma.shape[0]
    sigma = sigma.reshape(-1, sigma.shape[-1])
    eloc = nk.vqs.get_local_kernel(vstate, op)

    def expect_fun(params):
        log_pdf = lambda params, σ: 2 * afun({"params": params, **model_state}, σ).real
        eloc_fun = lambda params, sigma: eloc(afun, {"params": params, **model_state}, sigma, args)

        return expect(log_pdf, eloc_fun, params, sigma, n_chains=n_chains)

    loss, vjp_fun, loss_stats = vjp(expect_fun, vstate.parameters, has_aux=True)

    loss_grad = vjp_fun(jnp.ones_like(loss))[0]
    loss_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], loss_grad)

    return loss_stats, loss_grad


out = expect_and_grad_jax(vstate, H)[1]
if mpi.rank == 0:
    print("jax expect: ", out["Dense"]["bias"][:8], "\n")

n_nodes = mpi.n_nodes
if mpi.rank == 0:
    print("jax expect with mpi.n_nodes: ", n_nodes * out["Dense"]["bias"][:8])

A possible solution would be to add an mpi_sum over the ranks after grad_f = pb(dL̄) inside _expect_bwd.
Another possibility would be to replace the mpi_mean inside f with an mpi_sum and perform the mpi_mean outside, but I think this is more complicated, so the first option is preferable.
What do you think?

alleSini99 added a commit to alleSini99/netket that referenced this issue Jan 11, 2024
@PhilipVinc (Member)

If I understand correctly, we have two options that lead to a backward pass gradient with different normalisations.

The two alternatives lead to an expect function that should be differentiated as

def expect_fun(params):
    ...
    return expect(...)

loss, vjp_fun, loss_stats = vjp(expect_fun, vstate.parameters, has_aux=True)

loss_grad = vjp_fun(jnp.ones_like(loss))[0]

# option 1
loss_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], loss_grad)

# option 2
loss_grad = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], loss_grad)

I think that option 2 makes more sense computationally, because it is consistent with the fact that the vjp is a sum over rank-local values only, so we need to add a sum over MPI ranks as well.

However, people might think they need to average, so they will want to use mpi.mean.

I am in favour of moving to option 2, even though it is a breaking change (of something that is not really public API, so it's OK), but I would love the opinion of @inailuig.

@gcarleo (Member) commented Jan 12, 2024

I think this is a bug and should be corrected anyway?

@PhilipVinc (Member)

It's just a matter of where we put the 1/n_ranks factor when computing the gradient over several MPI nodes. Should the user use mpi.sum, with the 1/n_ranks factor inside the AD rule of expect, or should the user use mpi.mean, without the 1/n_ranks factor in the AD rule of expect?

Right now we have the latter, even if I believe it's mathematically wrong, because the AD rule of expect should have the 1/n_ranks factor.

@alleSini99 (Collaborator, Author)

I agree but, actually, my proposal was to put the mpi.mpi_sum inside the code of nk.jax.expect, not in the user's own code for computing the gradient, so that the user does not have to worry about it. I think this is the cleanest way.

@PhilipVinc (Member)

> I agree but, actually, my proposal was to put the mpi.mpi_sum inside the code of nk.jax.expect, not in the user's own code for computing the gradient, so that the user does not have to worry about it. I think this is the cleanest way.

I'm relatively sure you can't have the user not put an mpi call in his code...

@alleSini99 (Collaborator, Author)

To be more concrete, I would propose:

def _expect_bwd(n_chains, log_pdf, expected_fun, residuals, dout):
    pars, σ, cost_args, ΔL_σ = residuals
    dL̄, dL̄_stats = dout

    def f(pars, σ, *cost_args):
        log_p = log_pdf(pars, σ)
        term1 = jax.vmap(jnp.multiply)(ΔL_σ, log_p)
        term2 = expected_fun(pars, σ, *cost_args)
        out = mpi_mean(term1 + term2, axis=0)
        out = out.sum()
        return out

    _, pb = nkvjp(f, pars, σ, *cost_args)
    grad_f = pb(dL̄)
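    # proposed addition: sum the rank-local vjp results over the MPI ranks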
    grad_f = jax.tree_map(lambda x: mpi_sum_jax(x)[0], grad_f)

    return grad_f


_expect.defvjp(_expect_fwd, _expect_bwd)

@PhilipVinc (Member)

Ok, the misunderstanding has been cleared up:

Alessandro is talking about doing

    _, pb = nkvjp(f, pars, σ, *cost_args)
    grad_f = pb(dL̄)
    grad_f = mpi.mean(grad_f)
    return grad_f

which indeed would get rid of the MPI call from user code.
However, this code is technically wrong, and would return the correct gradient only for differentiated quantities that are identical among ranks.
So it would return the correct gradient wrt the parameters, already averaged over MPI ranks, simplifying user code.
But the gradient wrt the samples, which are different on every rank, would be wrong, because it would have been averaged as well.

The correct rule, which must handle non-rank-replicated arguments, cannot average over MPI, because the arguments might be different and therefore the gradients should be different among ranks.

We could have a special function or some kwargs that is easier to use and assumes replication of the differentiated inputs.

@inailuig (Collaborator) commented Jan 12, 2024

I think this is indeed a bug.

What we try to compute is

∇_θ ⟨H⟩ = ... = ⟨∇_θ logp(θ;x) ΔE_loc(x) + ∇_θ E_loc(θ,x)⟩ 
              = 1/N Σ_{i=1..N} ( ∇_θ logp(θ;x_i) ΔE_loc(x_i) + ∇_θ E_loc(θ, x_i) )

In the code we compute this expectation value by integrating the expression above w.r.t. θ, obtaining

F(θ) = 1/N Σ_{i=1..N} ( logp(θ;x_i) ΔE_loc(x_i) + E_loc(θ, x_i) ),

(Note that here we assume fixed params for ΔE_loc)
and taking the gradient of the resulting scalar function F with autodiff.

If you take the vjp of F, then for the MPI part of the sum Σ_{i=1..N} you 1. differentiate, which gives again the same sum, and 2. transpose in the backward pass, which gives the identity.
This means you only take the sum over the x_i on the current node, as no communication happens.
The result will be wrong, with different gradients on different nodes (not just off by a factor of n_nodes!).
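To spell out the last point (a restatement of the above, where I_r denotes the set of sample indices held by rank r and g_r the rank-local result of the vjp):

g_r = 1/N Σ_{i ∈ I_r} ( ∇_θ logp(θ;x_i) ΔE_loc(x_i) + ∇_θ E_loc(θ, x_i) ),      ∇_θ F = Σ_r g_r

The 1/N comes from the mpi_mean inside f, so recovering ∇_θ F requires summing the g_r over the ranks, i.e. exactly an MPI allreduce (sum). Averaging the g_r instead, as in the MWE above, yields ∇_θ F / n_nodes, which is the rescaling factor reported at the top of this issue.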

Here are a couple of proposals to fix it:

  • just add an mpi allreduce after the vjp (like alleSini99@6ccf51b did)
  • apply the transpose of allreduce (identity) to the params inside f, which will cause jax to insert the allreduce
  • do the sum outside of F (vjp with a vector of 1/N, and manually mpi allreduce at the end)

@alleSini99 (Collaborator, Author) commented Jan 15, 2024

So, since this seems to be the easiest option to implement (and it does not break existing code), I will simply update the documentation of nk.jax.expect, highlighting that the user needs to perform an mpi.sum in their own code right after the vjp, if everyone agrees.
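A minimal sketch of the pattern to be documented (reusing expect_fun, vstate and the imports from the MWE above; the exact wording in the docs may differ):

# perform an MPI sum (not a mean) right after the vjp of nk.jax.expect:
# this removes the spurious 1/mpi.n_nodes factor discussed above
loss, vjp_fun, loss_stats = vjp(expect_fun, vstate.parameters, has_aux=True)
loss_grad = vjp_fun(jnp.ones_like(loss))[0]
loss_grad = jax.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], loss_grad)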

@gcarleo (Member) commented Jan 15, 2024

I don't know, I tend to prefer the idea of being mathematically consistent and adopting the first solution proposed by Clemens.

@PhilipVinc (Member) commented Jan 15, 2024 via email

@alleSini99 (Collaborator, Author)

I fully agree with @gcarleo that it would be better to have something that is already consistent. However, from what I understand, the solution of adding the mpi.all_reduce after the vjp only works if the quantities we differentiate with respect to are the same on every rank, so it works, for instance, for the parameters (which are what we are interested in right now). It does not work if we differentiate with respect to quantities that are different on each rank (such as the samples). At the moment we throw away the derivatives wrt the samples, so the proposed solution will work. But if in the future we are somehow interested in the derivative wrt the samples, then we will have to find an alternative.

@PhilipVinc (Member)

That's half the story.

If you have a function that looks like

def my_fun(pars, samples):
     # this is the same on every rank
     res = expect(local_estimator, pars, samples)
     return res

and if you compute the gradient of this

E_fun, vjp_fun = jax.vjp(my_fun, pars, samples)
E_grad = vjp_fun(1)

You will get the correct gradient that way because the backward pass gets expanded to something like

def vjp_fun(dE):  # dE = 1 usually
    _, pb = nkvjp(my_fun, pars, σ)
    # the rank-local gradients are different on every rank
    grad_pars, grad_σ = pb(dE)
    # we average before returning, following the vjp rule proposed by Clemens and Alessandro above
    grad_pars_avg = mpi.mean(grad_pars)
    # the proposed rule also averages the gradient wrt the samples, but that is wrong: the samples
    # are different on every rank, so we should not be pooling here
    grad_σ_avg = mpi.mean(grad_σ)
    return grad_pars_avg, grad_σ_avg

and you have the mean and are happy.
Note that the gradient wrt the parameters is correctly averaged over ranks, but so is the one wrt the samples, and that is wrong.

The nice thing is that the gradient wrt the parameters is correct without asking the user to add the MPI reduction himself, because we inserted it in the rule.

Case where this breaks down

Now consider the case where you instead have

def my_fun(pars, samples):
     # new stuff that is not the same on every rank
     new_stuff = nonlinear_fun(pars, samples)
     
     # this is the same on every rank
     res = expect(local_estimator, pars, samples, new_stuff)
     return res

and if you compute the gradient of this

E_fun, vjp_fun = jax.vjp(my_fun, pars, samples)
E_grad = vjp_fun(1)

This will be expanded into something that looks like

def vjp_fun(dE):  # dE = 1 usually
    _, pb = nkvjp(my_fun, pars, σ, new_stuff)
    # the rank-local gradients are different on every rank
    grad_pars, grad_σ, grad_new_stuff = pb(dE)
    # we average before returning, following the vjp rule proposed by Clemens and Alessandro above
    grad_pars_avg = mpi.mean(grad_pars)

    # the same happens for these; note that all of these averages are wrong
    grad_σ_avg = mpi.mean(grad_σ)
    grad_new_stuff_avg = mpi.mean(grad_new_stuff)

    # now we have to do the backward pass of nonlinear_fun, fed with the cotangent of new_stuff;
    # it depends on the samples, so these values are different on every rank!
    dpars_nonlinear, dσ_nonlinear = d_nonlinear_fun(grad_new_stuff_avg)

    # sum the different terms:
    # the _avg terms are equal on all ranks, but the terms arising from the nonlinear function are different!
    grad_pars = grad_pars_avg + dpars_nonlinear
    grad_σ = grad_σ_avg + dσ_nonlinear
    return grad_pars, grad_σ

Now note that the gradient above is wrong, and we still need to mpi.mean the gradients of the parameters.

In short, even if we refine the rule proposed by the kiddos above to only do an mpi.mean over the gradients of the parameters, a nonlinear function applied before nk.jax.expect is not aware of that mean, and we still need to mpi.mean(vjp_fun(1)).

Moreover, in this way we have an extra mpi.mean in the middle that we could have just skipped, so we are losing performance.

In conclusion, I do not think we can get rid of the mpi.mean of the output of the vjp without actually redefining a custom vjp over MPI.

@alleSini99 (Collaborator, Author) commented Jan 17, 2024

Ok, I see. So, should I simply change the documentation of nk.jax.expect for the moment, or should we add a custom definition of vjp with MPI to NetKet? @gcarleo @PhilipVinc

@PhilipVinc PhilipVinc changed the title Unwanted mpi.n_nodes rescaling factor in gradient computation with nk.jax.expect Document the mpi.n_nodes rescaling factor in gradient computation with nk.jax.expect Jan 25, 2024
@PhilipVinc PhilipVinc added the docs Documentation-related issues label Jan 25, 2024
@PhilipVinc (Member)

This should be addressed by adding documentation and an example

@PhilipVinc PhilipVinc reopened this Jan 25, 2024
@alleSini99 (Collaborator, Author)

Ok, I will open a PR for it.

@inailuig (Collaborator) commented Jan 25, 2024

> We could have a special function or some kwargs that is easier to use and assumes replication of the differentiated inputs.

just for the record: in jax they use a kwarg called reduce_axes to specify this,
see e.g. https://jax.readthedocs.io/en/latest/_autosummary/jax.vjp.html;
we could add something similar to nkvjp, but performing an MPI sum rather than a psum ...
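For the record, a rough sketch of what such a helper could look like (hypothetical, not existing NetKet API; the names vjp_mpi_sum and mpi_reduce_argnums are made up here):

import jax
from netket.utils import mpi

def vjp_mpi_sum(fun, *primals, mpi_reduce_argnums=()):
    # plain jax.vjp; the MPI reduction is layered on top of its pullback
    out, pullback = jax.vjp(fun, *primals)

    def pullback_reduced(cotangent):
        grads = pullback(cotangent)
        # sum over MPI ranks only the cotangents of the rank-replicated arguments
        # (e.g. the parameters), leaving rank-local ones (e.g. the samples) untouched
        return tuple(
            jax.tree_util.tree_map(lambda x: mpi.mpi_sum_jax(x)[0], g)
            if i in mpi_reduce_argnums
            else g
            for i, g in enumerate(grads)
        )

    return out, pullback_reduced

With something like this, an expect-based loss could be differentiated as out, pb = vjp_mpi_sum(expect_fun, pars, mpi_reduce_argnums=(0,)), without any explicit MPI call left in user code.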

@PhilipVinc (Member)

I now think it's actually a good idea...

@PhilipVinc PhilipVinc reopened this Jan 25, 2024
@PhilipVinc (Member)

Still, this is mainly a documentation issue.

Regardless, it would bring us closer to not having mpi calls around in our implementations.
