Document the `mpi.n_nodes` rescaling factor in gradient computation with `nk.jax.expect` #1690
Comments
If I understand correctly, we have two options that lead to a backward-pass gradient with different normalisations. The two alternatives lead to gradients that differ by a factor of `mpi.n_nodes`.
I think that option 2 is more computationally sensible because it is consistent with the fact that the vjp is a sum, which only sums over rank-local values, so we need to add a sum over MPI ranks as well. However, people might think they need to average, so they will want to use `mpi_mean`. I am in favour of moving to option 2, but that is breaking (something that is not really public API, so it's OK). I would love the opinion of Luigi @inailuig |
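To make the mismatch concrete (my paraphrase of the issue description at the bottom; $R$ = `mpi.n_nodes`, $N$ = samples per rank): the forward pass computes

$$F = \frac{1}{R}\sum_{r=1}^{R}\frac{1}{N}\sum_{i=1}^{N}\Delta E_{\mathrm{loc}}\big(\sigma_i^{(r)}\big),$$

while the vjp only performs the rank-local sum over samples, leaving each rank with

$$g^{(r)} = \frac{1}{R}\cdot\frac{1}{N}\sum_{i=1}^{N}\partial_\theta \Delta E_{\mathrm{loc}}\big(\sigma_i^{(r)}\big) \;\approx\; \frac{1}{R}\,\partial_\theta F,$$

i.e. a gradient rescaled by $1/R$; an `mpi_sum` over ranks restores $\partial_\theta F$.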
I think this is a bug and should be corrected anyways? |
It's just a matter of where we put the MPI reduction: inside the AD rule of `expect`, or in user code. Right now we have the latter, even if I believe it's mathematically wrong, because the AD rule of `expect` should have the `1/n_ranks` factor. |
I agree but, actually, my proposal was to put the MPI reduction inside `_expect_bwd`, right after the backward pass. |
I'm relatively sure you can't avoid having the user put an MPI call in their code... |
To be more concrete I would propose:
|
Ok, the misunderstanding has been understood: Alessandro is talking about doing

```python
_, pb = nkvjp(f, pars, σ, *cost_args)
grad_f = pb(dL̄)
grad_f = mpi.mean(grad_f)
return grad_f
```

which indeed would get rid of the mpi call from user code. The correct rule handling non-rank-replicated arguments cannot average over MPI, because the arguments might be different and therefore the gradients should be different among ranks. We could have a special function, or some kwargs, that is easier to use and assumes replication of the differentiated inputs. |
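Such a convenience wrapper could look roughly like the sketch below. This is hypothetical code, not NetKet API: `nkvjp` and `mpi.mean` are the helper names used in this thread, and the wrapper is only valid under the stated replication assumption.

```python
import jax

# Hypothetical sketch: only valid when `pars` (and every other differentiated
# argument) holds identical values on all MPI ranks, so that averaging the
# per-rank vjp outputs is meaningful.
def vjp_replicated(f, pars, σ, *cost_args):
    y, pb = nkvjp(f, pars, σ, *cost_args)

    def pb_avg(dL̄):
        grads = pb(dL̄)
        # average the rank-local gradients; wrong if the inputs differ per rank
        return jax.tree_util.tree_map(lambda g: mpi.mean(g), grads)

    return y, pb_avg
```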
I think this is indeed a bug. What we try to compute is

$$F = \mathbb{E}_{\sigma \sim p_\theta}\!\left[\Delta E_{\mathrm{loc}}(\sigma)\right].$$

In the code we compute this expectation value by Monte Carlo sampling w.r.t. $\theta$, with the samples split across ranks, obtaining

$$F \approx \frac{1}{n_{\mathrm{ranks}}}\sum_{r=1}^{n_{\mathrm{ranks}}} \frac{1}{N_s}\sum_{i=1}^{N_s} \Delta E_{\mathrm{loc}}\big(\sigma_i^{(r)}\big).$$

(Note that here we assume fixed params for $\Delta E_{\mathrm{loc}}$.) If you do a vjp of $F$, the MPI part of the sum is not differentiated: only the rank-local sum over samples is, so each rank keeps the $1/n_{\mathrm{ranks}}$ factor but misses the sum over ranks. Here are a couple of proposals to fix it:
|
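One of the fixes discussed in this thread (also in the issue description at the bottom) is an `mpi_sum` over ranks right after the backward pass. A rough sketch of what that could look like, with a simplified signature that is not NetKet's actual `_expect_bwd` (`nkvjp` and `mpi_sum` are the helper names used in this thread):

```python
import jax

# Hedged sketch of the proposed fix; residuals and signature are simplified.
def _expect_bwd(residuals, dL̄):
    f, pars, σ, cost_args = residuals
    _, pb = nkvjp(f, pars, σ, *cost_args)
    grad_f = pb(dL̄)
    # the forward pass took an mpi_mean over ranks (a 1/n_ranks factor), so
    # the backward pass must supply the matching sum over ranks
    grad_f = jax.tree_util.tree_map(lambda g: mpi_sum(g), grad_f)
    return grad_f
```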
So, since this seems to be the easiest way to implement (also not to break existing code), I will simply update the documentation of `nk.jax.expect`. |
I don't know, I tend to prefer the idea of being mathematically consistent and adopting the first solution proposed by Clemens. |
The solutions proposed (hiding MPI) can be wrong in some cases, leading to different gradients on MPI ranks. Admittedly we never use expect that way, but I'd prefer to avoid hiding code like that within netket.expect. Alessandro will write down some notes and documentation of why that is the case and how to use the existing expect. |
I fully agree with @gcarleo that it would be better to have something that is already consistent. However, from what I have understood, the solution of adding the `mpi.mean` inside the vjp rule should already give the correct gradient. |
That's half the story. If you have a function that looks like

```python
def my_fun(pars, samples):
    # this is the same on every rank
    res = expect(local_estimator, pars, samples)
    return res
```

and you compute its gradient,

```python
E, vjp_fun = jax.vjp(my_fun, pars, samples)
E_grad = vjp_fun(1)
```

you will get the correct gradient that way, because the backward pass gets expanded to something like

```python
def vjp_fun(dE):  # dE = 1 usually
    _, pb = nkvjp(my_fun, pars, σ)
    # grad_f is different on every rank
    grad_pars, grad_σ = pb(dE)
    # we average before returning, as in the vjp rule proposed by Clemens and Adriano above
    grad_pars_avg = mpi.mean(grad_pars)
    # The proposed rule also averages the gradient wrt samples, but that is wrong
    # because the samples are different on every rank, so we should not be pooling here
    grad_σ_avg = mpi.mean(grad_σ)
    return grad_pars_avg, grad_σ_avg
```

and you have the mean and are happy. The nice thing is that the gradient wrt the parameters is correct without asking the user to add the MPI reduction himself, because we inserted it in the rule.

Case where this breaks down: now consider the case where you instead have

```python
def my_fun(pars, samples):
    # new stuff that is not the same on every rank
    new_stuff = nonlinear_fun(pars, samples)
    # this is the same on every rank
    res = expect(local_estimator, pars, samples, new_stuff)
    return res
```

and you compute its gradient,

```python
E, vjp_fun = jax.vjp(my_fun, pars, samples)
E_grad = vjp_fun(1)
```

This will be expanded into something that looks like

```python
def vjp_fun(dE):  # dE = 1 usually
    _, pb = nkvjp(my_fun, pars, σ, new_stuff)
    # grad_f is different on every rank
    grad_pars, grad_σ, grad_new_stuff = pb(dE)
    # we average before returning, as in the vjp rule proposed by Clemens and Adriano above
    grad_pars_avg = mpi.mean(grad_pars)
    # Also these. Note that all these averages are wrong.
    grad_σ_avg = mpi.mean(grad_σ)
    grad_new_stuff_avg = mpi.mean(grad_new_stuff)

    # now we have to do the reverse pass of nonlinear_fun;
    # it depends on the samples, so those values are different on every rank!
    dpars_nonlinear, dσ_nonlinear = d_nonlinear_fun(grad_new_stuff_avg)

    # Sum the different terms:
    # the _avg terms are equal on all ranks, but the terms arising from the
    # nonlinear part are different!
    grad_pars = grad_pars_avg + dpars_nonlinear
    grad_σ = grad_σ_avg + dσ_nonlinear
    return grad_pars, grad_σ
```

Now note that the gradient above is wrong, and we still need to reduce it over MPI in user code to make it consistent among ranks.

In short, even if we refine the rule proposed by the kiddos above to only do an average over the gradient wrt the parameters, that is not enough to guarantee a correct gradient in the general case. Moreover, in this way we have an extra MPI reduction inside the rule, which is wasted communication whenever the user has to reduce the final gradient anyway.

In short, I do not think we can get rid of the MPI call in user code. |
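So under the current design the explicit reduction stays on the user side. A minimal sketch of that pattern (my illustration, assuming `pars` is rank-replicated while the samples are not, and assuming `mpi_sum_jax` from `netket.utils.mpi`, which returns a `(value, token)` pair):

```python
import jax
import jax.numpy as jnp
from netket.utils import mpi

# E is rank-replicated (expect reduces over MPI in the forward pass), while
# the rank-local vjp output carries only this rank's contribution.
E, vjp_fun = jax.vjp(my_fun, pars, samples)
grad_pars, grad_samples = vjp_fun(jnp.ones_like(E))

# explicit user-side reduction: sum the per-rank parameter gradients so the
# 1/n_nodes from the forward mpi_mean is matched by a sum over ranks;
# grad_samples is deliberately left alone, since samples differ per rank
grad_pars = jax.tree_util.tree_map(lambda g: mpi.mpi_sum_jax(g)[0], grad_pars)
```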
Ok, I see. So, should I simply change the documentation of `nk.jax.expect`, then? |
This should be addressed by adding documentation and an example |
Ok, I will open a PR for it. |
just for the record: in jax they use a kwarg called |
I now think it's actually a good idea... |
Still, this is mainly a documentation issue. Regardless, it would bring us closer to not having MPI calls around in our implementations. |
Hi everybody,
I think there is a rescaling factor `mpi.n_nodes` when using the function `nk.jax.expect` for computing the gradient of an arbitrary loss function. I believe this is due to the operation `mpi_mean` in the definition of the backward rule in `nk.jax.expect`: indeed, when `vjp` is performed on `f`, it does the sum over the samples on each rank separately, but it does not also perform the sum over the ranks (and this is instead needed, since in `f` we take the mean over the ranks too). Therefore, this results in a gradient that is rescaled by a factor of 1/`mpi.n_nodes` on each rank, as shown in the following MWE for the computation of the energy gradient.
A possible solution would be to add an `mpi_sum` which sums over the ranks after `grad_f = pb(dL̄)` inside `_expect_bwd`. Another possibility would be to substitute the `mpi_mean` inside `f` with `mpi_sum` and perform the `mpi_mean` outside, but I think this is more complicated, so the previous one is preferable.
What do you think?
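A hedged sketch of the kind of MWE described above, with a toy model: the names `logpsi` and `e_loc` are placeholders, and the `nk.jax.expect` signature with its `(value, stats)` return is an assumption based on recent NetKet versions.

```python
import jax
import jax.numpy as jnp
import netket as nk
from netket.utils import mpi

def logpsi(pars, σ):
    # toy mean-field log-amplitude, one parameter per site
    return jnp.sum(pars * σ, axis=-1)

def log_pdf(pars, σ):
    # unnormalised Born probability
    return 2.0 * logpsi(pars, σ)

def e_loc(pars, σ):
    # placeholder local estimator standing in for the local energy
    return jnp.cos(jnp.sum(σ, axis=-1)) * jnp.sum(pars**2)

def energy(pars, σ):
    E, _ = nk.jax.expect(log_pdf, e_loc, pars, σ)
    return E

pars = jnp.ones(4)
σ = jnp.ones((16, 4))  # rank-local batch of samples

E, vjp_fun = jax.vjp(energy, pars, σ)
grad_pars, _ = vjp_fun(jnp.ones_like(E))
# under the current rule each rank is left with 1/mpi.n_nodes of the gradient
# estimate built from its own samples; summing over ranks restores the mean
grad_pars, _ = mpi.mpi_sum_jax(grad_pars)
```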