
[WIP] Autodiff stresses #443

Closed · 1 of 2 tasks

niklasschmitz opened this issue Jun 7, 2021 · 17 comments

Labels: discussion (Discussion thread of broader scope), feature (New feature or request)

Comments

@niklasschmitz (Collaborator) commented Jun 7, 2021

Opening this to keep track of progress on obtaining stresses via autodiff.

Goal

Calculate the stress as the total derivative of the total energy w.r.t. the lattice parameters via automatic differentiation. As this falls under the scope of the Hellmann-Feynman theorem, we do not need to differentiate through the full SCF solve, but only through a post-processing step on the final solution scfres.
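Spelled out for a single scalar parameter a (roughly; ρ and the occupations are held fixed along with ψ): the converged SCF solution makes the energy stationary w.r.t. the orbitals, so the chain-rule term through ψ drops out,

\frac{dE}{da} = \underbrace{\frac{\partial E}{\partial \psi}}_{=\,0 \text{ at the SCF solution}} \frac{\partial \psi}{\partial a} + \left.\frac{\partial E}{\partial a}\right|_{\psi,\rho} = \left.\frac{\partial E}{\partial a}\right|_{\psi,\rho}

which is why differentiating compute_energy (below) at a fixed scfres suffices.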
We start with the following minimal example of silicon with a single scalar lattice parameter a:

using DFTK
using FiniteDiff
using Test

function make_basis(a)
    lattice = a / 2 * [[0 1 1.];
                       [1 0 1.];
                       [1 1 0.]]
    Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
    atoms = [Si => [ones(3)/8, -ones(3)/8]]
    model = model_atomic(lattice, atoms, symmetries=false)
    kgrid = [1, 1, 1]  # k-point grid (Regular Monkhorst-Pack grid)
    Ecut = 15          # kinetic energy cutoff in Hartree
    PlaneWaveBasis(model, Ecut; kgrid=kgrid)
end

a = 10.26
scfres = self_consistent_field(make_basis(a), tol=1e-8)

function compute_energy(scfres_ref, a)
    basis = make_basis(a)
    energies, H = energy_hamiltonian(basis, scfres_ref.ψ, scfres_ref.occupation; ρ=scfres_ref.ρ)
    energies.total
end

function compute_stress(scfres_ref, a)
    Inf # TODO implement
end
@test compute_stress(scfres, a) ≈ FiniteDiff.finite_difference_derivative(a -> compute_energy(scfres, a), a)  # -1.411
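Once AD works end-to-end, the ForwardDiff version of compute_stress should essentially be a one-liner; a sketch of the intended end state (does not run yet, see the challenges below):

using ForwardDiff

compute_stress(scfres_ref, a) = ForwardDiff.derivative(a -> compute_energy(scfres_ref, a), a)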

Approach

We plan to try ForwardDiff.jl, ReverseDiff.jl and Zygote.jl.
For stresses only (#params < 10) we expect ForwardDiff to perform best. Going further, the reverse modes of ReverseDiff and Zygote are also interesting, as they could jointly evaluate stresses and other derivatives of the total energy (e.g. forces) more efficiently.

Expected challenges:

  • calls into non-Julia code such as FFTW (a hurdle for ForwardDiff and ReverseDiff, which don't support ChainRules; we might need to overload manually, see the sketch after this list)
  • many loops & generators (high compilation times for Zygote, unrolling for ReverseDiff)
  • scalar indexing (slow for Zygote)
  • try/catch (unsupported by Zygote)
  • value-dependent control flow (caution with caching tapes in ReverseDiff)
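To sketch the manual-overload idea for FFTs (hypothetical code; fft_dual is a made-up helper name, not an actual rule registration): the FFT is linear, so one can transform the values and each partial separately with the ordinary FFT and reassemble the result.

using AbstractFFTs, FFTW, ForwardDiff

# FFT is linear, so d(fft(x)) = fft(dx): transform the values and each of the
# N partials with the plain (Float64) FFT, then rebuild Complex{Dual} entries.
function fft_dual(x::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
    vals  = fft(ForwardDiff.value.(x))                        # Complex array of values
    parts = ntuple(n -> fft(ForwardDiff.partials.(x, n)), N)  # one Complex array per partial
    map(vals, parts...) do v, p...
        Complex(ForwardDiff.Dual{T}(real(v), map(real, p)...),
                ForwardDiff.Dual{T}(imag(v), map(imag, p)...))
    end
end

JuliaDiff/ForwardDiff.jl#495 registers rules of this flavor on the plan objects instead of on a free function.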

Progress

  • [x] calculate stress term from (simpler) kinetic energy only
    • ForwardDiff, ReverseDiff and Zygote all worked out of the box (!)
    • Zygote gave numerically correct but ComplexF64-valued results (?)
    • Zygote had much higher compile time (60 s) than ReverseDiff (5 s) and ForwardDiff (4 s)
  • [ ] calculate total stress
    • initial stack traces are here
    • ForwardDiff and ReverseDiff both fail on the LinearAlgebra.cond check
    • Zygote fails with no method matching zero(::String) (TODO: understand stack trace)

Related links

An overview of AD tools in Julia: https://juliadiff.org/
Chris Rackauckas on strengths and weaknesses of different AD packages: https://discourse.julialang.org/t/state-of-automatic-differentiation-in-julia/43083/3
Common patterns that need rules in Zygote: https://juliadiff.org/ChainRulesCore.jl/stable/writing_good_rules.html

@niklasschmitz (Collaborator, Author)

cc #107

@antoine-levitt (Member)

Cf. also #47 (in particular a trick to reduce to 6 instead of 9 DOF, although we probably don't care too much)

@antoine-levitt (Member)

I'd focus on ForwardDiff for now. We should be able to work around the errors. Michael did the work of making DFTK work for IntervalArithmetic scalar types, so it should hopefully be similar.

@mfherbst (Member) commented Jun 7, 2021

Yes, please let me know if you need any help on that. I can easily find some nice examples to get you going in case you need any.

@mfherbst added the feature (New feature or request) label Jun 8, 2021
@niklasschmitz (Collaborator, Author)

Thanks, starting with ForwardDiff sounds good to me. As I understand it, there are two options:

  1. using GenericLinearAlgebra fallbacks as for IntervalArithmetic (currently fails with a stacktrace), or
  2. directly overloading the relevant LinearAlgebra and FFTW calls on ForwardDiff dual number types (in the spirit of "Compatibility with Base linear algebra functions", JuliaDiff/ForwardDiff.jl#111 (comment), and https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files).

My guess would be that both are similar in difficulty, but 2. should be preferable for performance. What are your thoughts?

@antoine-levitt (Member)

Performance is not the foremost issue, so following IntervalArithmetic sounds good. However, we've had quite a few issues with the generic FFTs (which are buggy and not actively developed), so if https://github.com/JuliaDiff/ForwardDiff.jl/pull/495/files does the job, then great!

@mfherbst (Member) commented Jun 8, 2021

Regarding the stacktrace: only some FFT sizes work for the generic implementation we have, and unless you specify an fft_size explicitly, the PlaneWaveBasis constructor will auto-adjust it. Therefore the effective fft_size used in https://gist.github.com/niklasschmitz/e7030b3f6341bcf56538a87d0b91d5e1#file-stress-genericlinearalgebra-jl-L30 and the one used in https://gist.github.com/niklasschmitz/e7030b3f6341bcf56538a87d0b91d5e1#file-stress-genericlinearalgebra-jl-L16 don't agree. The solution is to explicitly pass a fixed fft_size to both constructors, e.g. just say PlaneWaveBasis(model, Ecut; kgrid=kgrid, fft_size=[32, 32, 32]) in both lines.
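Applied to the make_basis helper from the top post, that would e.g. read (sketch; the keyword default is an arbitrary choice):

function make_basis(a; fft_size=[32, 32, 32])
    lattice = a / 2 * [[0 1 1.];
                       [1 0 1.];
                       [1 1 0.]]
    Si = ElementPsp(:Si, psp=load_psp("hgh/lda/Si-q4"))
    atoms = [Si => [ones(3)/8, -ones(3)/8]]
    model = model_atomic(lattice, atoms, symmetries=false)
    # Fixed grid: the constructor no longer auto-adjusts, so bases built for
    # different values of a stay comparable.
    PlaneWaveBasis(model, 15; kgrid=[1, 1, 1], fft_size=fft_size)
end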

@mfherbst (Member) commented Jun 8, 2021

But I agree with Antoine. The generic FFT stuff only works "meh", so if we can avoid it, that would probably be the better solution long-term.

@antoine-levitt (Member)

Tricky bug, nice catch!

@niklasschmitz (Collaborator, Author)

Some updates on both ForwardDiff approaches:

I have iterated on the examples as discussed:

  • chose a constant fft_size = [32, 32, 32]
  • chose a very high tolerance for the SCF
  • moved the term selection from model_atomic into a make_basis helper

  1. using generic arithmetic on FourierTransforms.jl
  2. adding ForwardDiff.Dual rules on AbstractFFTs / FFTW, with additional changes: added/modified custom rules for AbstractFFTs on dual numbers (still very much WIP), disabled the lattice cond check, disabled the custom FFT normalization (will need a rule for s * ScaledPlan where s is a Dual); see https://github.com/niklasschmitz/DFTK.jl/pull/2/files

Including the AtomicNonLocal() term currently leads to NaN derivative results with ForwardDiff in both approaches, while the other terms at least seem to work without further errors.

@mfherbst (Member) commented Jun 16, 2021

Cool, that's great news! So we can actually use finite differences to debug the AtomicNonLocal term. Some ideas for debugging:

  • Try to put checks into the loops of the force computation (https://github.com/JuliaMolSim/DFTK.jl/blob/master/src/terms/nonlocal.jl#L84) to see where the NaN occurs in the duals / deltas
  • Things you can do to simplify the problem (also to better understand the printing above):
    • Reduce the number of atoms per unit cell from two to one (in this line)
    • Instead of using the silicon pseudos, use the carbon pseudos (which are a bit simpler in the non-local term), i.e. use c-q4.hgh instead of si-q4.hgh.

Regarding the stacktraces in the second PR ... it appears that, at least for ReverseDiff, this happens already in the PlaneWaveBasis setup. I don't fully get why at first glance. Let's discuss tomorrow.

@antoine-levitt (Member)

JuliaLang/julia#27705 has a snippet for yielding an error when a NaN is produced
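Something along these lines (my own variant, not necessarily the exact snippet from that issue) could be sprinkled over the nonlocal loops:

using ForwardDiff

# Look inside dual numbers: a NaN can hide in the value or in any partial.
hasnan(x::ForwardDiff.Dual) = isnan(ForwardDiff.value(x)) || any(isnan, ForwardDiff.partials(x))
hasnan(x::Real) = isnan(x)

# Error out (with a backtrace) at the first NaN in a suspicious intermediate.
checknan(xs, label) = (any(hasnan, xs) && error("NaN in $label"); xs)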

@niklasschmitz (Collaborator, Author)

We've found the NaN in AtomicNonLocal: it came in due to a bug/inconsistency of ForwardDiff on norm(a::StaticArray) of zeros, also shown here: JuliaDiff/ForwardDiff.jl#243 (comment).
The first, very ad-hoc fix I applied was to manually fall back to the norm on Vector, which gives correct directional derivatives at [0., 0., 0.]: niklasschmitz#2 (comment).

This fixes the stress of AtomicNonLocal for both ForwardDiff approaches, which also both agree with FiniteDiff.

For approach 2 I also re-enabled the FFT normalizations and added the required additional Dual rule for ScaledPlan. After this, both ForwardDiff approaches finally agree on the stress of the example system above!
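For reference, the ad-hoc fallback is essentially (type piracy, for debugging only; the exact signature in my branch may differ):

using LinearAlgebra, StaticArrays, ForwardDiff

# Route SVector-of-Dual norms through the (working) Vector implementation.
LinearAlgebra.norm(x::SVector{D,<:ForwardDiff.Dual}) where {D} = norm(Vector(x))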

@antoine-levitt (Member)

Interesting. Actually this is a structural zero, i.e. it comes about as recip_lattice * zeros(3). So norm always gets called on a vector of duals 0 + ε·0, and hence the non-differentiability of norm at zero is not an issue (at least for forward mode). Can you check it's OK with ChainRules? If yes, we might as well do a quick workaround here and wait for the next generation of forward-diff tools.
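If I understand correctly, a minimal repro of the structural-zero case would be:

using ForwardDiff, LinearAlgebra, StaticArrays

v0 = @SVector zeros(3)                          # the structural zero (the G = 0 vector)
ForwardDiff.derivative(a -> norm(a * v0), 1.0)
# NaN on ForwardDiff v0.10.18, although a -> norm(a * v0) is identically zero,
# so the exact derivative is 0.0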

@mfherbst added the discussion (Discussion thread of broader scope) label Jun 20, 2021
@niklasschmitz (Collaborator, Author)

This is the current behavior of norm at zero using (Zygote+ChainRules, ForwardDiff) × (Vector, SVector):

using Zygote
using ForwardDiff
using StaticArrays
using LinearAlgebra

x = zeros(3)
Zygote.gradient(norm, x)[1]
# 3-element Vector{Float64}:
#  0.0
#  0.0
#  0.0
ForwardDiff.gradient(norm, x)
# 3-element Vector{Float64}:
#  0.0
#  0.0
#  1.0

y = @SVector zeros(3)
Zygote.gradient(norm, y)[1]
# 3-element SVector{3, Float64} with indices SOneTo(3):
#  0.0
#  0.0
#  0.0
ForwardDiff.gradient(norm, y)
# 3-element SVector{3, Float64} with indices SOneTo(3):
#  NaN
#  NaN
#  NaN

# [f6369f11] ForwardDiff v0.10.18
# [90137ffa] StaticArrays v1.2.3
# [e88e6eb3] Zygote v0.6.13

For our use case all results are OK except the NaN, since it does not cancel out in the subsequent multiplication by zero, although I'm surprised that ForwardDiff.gradient(norm, x) gives preference to the last input dimension. Zygote picks up the dedicated rules for norm in ChainRules. Calling ChainRules directly also works (in particular, the frule gives a consistent 0.0 sensitivity for all input dims):

using ChainRules # [082447d4] ChainRules v0.8.13

ChainRules.unthunk(ChainRules.rrule(norm, x)[2](1.0)[2])
# 3-element Vector{Float64}:
#  0.0
#  0.0
#  0.0

ChainRules.unthunk(ChainRules.rrule(norm, y)[2](1.0)[2])
# 3-element SVector{3, Float64} with indices SOneTo(3):
#  0.0
#  0.0
#  0.0

function onehot(i, n)
    x = zeros(n)
    x[i] = 1.0
    x
end
ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, x) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, x) # (0.0, 0.0)

ChainRules.frule((ChainRules.NoTangent(), onehot(1,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(2,3),), norm, y) # (0.0, 0.0)
ChainRules.frule((ChainRules.NoTangent(), onehot(3,3),), norm, y) # (0.0, 0.0)

So a next-gen ForwardDiff picking up ChainRules rules should indeed fix the problem.
As for quick workarounds, I'm thinking of either

  1. converting structural-zero static vectors to Vector (as I did when debugging nonlocal.jl), or
  2. overloading norm on (static) vectors of ForwardDiff.Dual to use the corresponding frule (which might need some thinking about how it gets picked up under broadcasting of norm too); a rough sketch follows below.
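A rough sketch of option 2 (hypothetical; type piracy, and static-array dispatch, tags and broadcasting are not handled):

using ChainRules, ForwardDiff, LinearAlgebra

# Push each of the N partials through the ChainRules frule for norm,
# which yields the consistent 0.0 sensitivity at x = 0 shown above.
function LinearAlgebra.norm(x::Vector{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
    vals = ForwardDiff.value.(x)
    dys  = ntuple(N) do n
        ChainRules.frule((ChainRules.NoTangent(), ForwardDiff.partials.(x, n)), norm, vals)[2]
    end
    ForwardDiff.Dual{T}(norm(vals), dys...)
end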

@antoine-levitt (Member)

Yeah, that just looks like a ForwardDiff bug, so either work around it locally or fix it upstream.

@antoine-levitt (Member)

Done in #476.
