-
Notifications
You must be signed in to change notification settings - Fork 83
/
chainrules.jl
96 lines (76 loc) · 3.37 KB
/
chainrules.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
using Test
using Random
using DFTK
using FiniteDiff
using Zygote
include("testcases.jl")
# Testing rrules needed for reverse Hellmann-Feynman stress
@testset "ChainRules" begin
    # Check that the Zygote (reverse-mode) derivative of a scalar function `f` of a
    # scalar `a` agrees with a finite-difference estimate.
    function has_consistent_derivative(f, a)
        d1 = Zygote.gradient(f, a)[1]
        d2 = FiniteDiff.finite_difference_derivative(f, a)
        return isapprox(d1, d2, atol=1e-5)
    end

    # Check that the Zygote gradient of a scalar function `f` of an array `x` agrees
    # with a finite-difference gradient.
    function has_consistent_gradient(f, x)
        g1 = Zygote.gradient(f, x)[1]
        g2 = FiniteDiff.finite_difference_gradient(f, x)
        return isapprox(g1, g2, atol=1e-4)
    end

    a = 10.26
    Si = ElementPsp(silicon.atnum, psp=load_psp(silicon.psp))
    atoms = [Si]
    positions = silicon.positions

    # Build a silicon model whose lattice depends on the scalar `a`, using only the
    # Kinetic and AtomicLocal terms, so derivatives w.r.t. `a` probe lattice dependence.
    function make_model(a)
        lattice = a / 2 * [[0. 1. 1.];
                           [1. 0. 1.];
                           [1. 1. 0.]]
        terms = [Kinetic(), AtomicLocal()]
        return Model(lattice, atoms, positions; terms, temperature=1e-3)
    end
    kgrid = [1, 1, 1]
    Ecut = 7
    make_basis(model::Model) = PlaneWaveBasis(model, Ecut; kgrid=kgrid)
    make_basis(a::Real) = make_basis(make_model(a))
    basis = make_basis(a)

    @testset "r_to_G, G_to_r" begin
        # Fixed seed so that any failure is reproducible.
        rng = MersenneTwister(1234)
        kpt = basis.kpoints[1]
        # Derive array sizes from the basis rather than hard-coding them, so the
        # test keeps working if Ecut, the lattice or the FFT size selection change.
        n_G = length(kpt.mapping)  # number of plane-wave coefficients at `kpt`
        x = rand(rng, ComplexF64, n_G)
        y = rand(rng, Float64, basis.fft_size)
        w = rand(rng, ComplexF64, basis.fft_size)
        # r_to_G w.r.t. lattice
        @test has_consistent_derivative(a -> abs2(sum(r_to_G(make_basis(a), y) .* y)), a)
        # r_to_G kpt w.r.t. lattice
        @test has_consistent_derivative(a -> abs2(sum(r_to_G(make_basis(a), kpt, w) .* x)), a)
        # G_to_r w.r.t. lattice
        @test has_consistent_derivative(a -> abs2(sum(G_to_r(make_basis(a), w) .* w)), a)
        # G_to_r kpt w.r.t. lattice
        @test has_consistent_derivative(a -> abs2(sum(G_to_r(make_basis(a), kpt, x) .* y)), a)
        # r_to_G w.r.t. f_real
        @test has_consistent_gradient(y -> abs2(sum(r_to_G(basis, y) .* w)), y)
        # r_to_G kpt w.r.t. f_real
        @test has_consistent_gradient(w -> abs2(sum(r_to_G(basis, kpt, w) .* x)), w)
        # G_to_r w.r.t. f_fourier
        @test has_consistent_gradient(w -> abs2(sum(G_to_r(basis, w) .* y)), w)
        # G_to_r kpt w.r.t. f_fourier
        @test has_consistent_gradient(x -> abs2(sum(G_to_r(basis, kpt, x) .* y)), x)
    end

    @testset "PlaneWaveBasis w.r.t. lattice" begin
        @test has_consistent_derivative(a -> make_model(a).recip_cell_volume, a)
        @test has_consistent_derivative(a -> make_basis(a).model.recip_cell_volume, a)
        @test has_consistent_derivative(a -> make_basis(a).r_to_G_normalization, a)
        @test has_consistent_derivative(a -> make_basis(a).G_to_r_normalization, a)
        @test has_consistent_derivative(a -> make_basis(a).dvol, a)
    end

    @testset "term precomputations w.r.t. lattice" begin
        # Kinetic (terms[1] in make_model)
        @test has_consistent_derivative(a -> sum(make_basis(a).terms[1].kinetic_energies[1]), a)
        # AtomicLocal (terms[2] in make_model)
        @test has_consistent_derivative(a -> make_basis(a).terms[2].potential[1], a)
    end

    @testset "compute_density w.r.t. lattice" begin
        scfres = self_consistent_field(basis, is_converged=DFTK.ScfConvergenceDensity(1e-4))
        ψ = scfres.ψ
        occupation = scfres.occupation
        # The autodiff-friendly implementation need not be bitwise identical to the
        # regular one (different operation order), so compare approximately.
        ρ_ref = compute_density(basis, ψ, occupation)
        ρ_ad  = DFTK._autodiff_compute_density(basis, ψ, occupation)
        @test ρ_ref ≈ ρ_ad
        @test has_consistent_derivative(a -> sum(compute_density(make_basis(a), ψ, occupation)), a)
    end
end