-
Notifications
You must be signed in to change notification settings - Fork 83
/
PlaneWaveBasis.jl
150 lines (124 loc) · 5.65 KB
/
PlaneWaveBasis.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@testitem "PlaneWaveBasis: Check struct construction" setup=[TestCases] begin
    # Build a spin-collinear silicon basis and check that the stored model data,
    # cutoff, FFT size, k-point bookkeeping and per-k G vectors are consistent.
    using DFTK
    using LinearAlgebra
    silicon = TestCases.silicon

    Ecut = 3
    fft_size = [15, 15, 15]
    model = Model(silicon.lattice, silicon.atoms, silicon.positions;
                  spin_polarization=:collinear)
    basis = PlaneWaveBasis(model; Ecut, silicon.kgrid, fft_size)

    @test basis.model.lattice == silicon.lattice
    @test basis.model.recip_lattice ≈ 2π * inv(silicon.lattice)
    @test basis.model.unit_cell_volume ≈ det(silicon.lattice)
    @test basis.model.recip_cell_volume ≈ (2π)^3 * det(inv(silicon.lattice))
    @test basis.Ecut == Ecut
    @test basis.fft_size == Tuple(fft_size)

    # Per-component bounds of the integer G vectors representable on the FFT grid.
    g_start = -ceil.(Int, (Vec3(basis.fft_size) .- 1) ./ 2)
    g_stop  = floor.(Int, (Vec3(basis.fft_size) .- 1) ./ 2)
    g_all   = vec(collect(G_vectors(basis)))

    for (ik, kpt) in enumerate(basis.kpoints)
        # Fold the global k index back onto silicon.kgrid. With collinear spin the
        # k-point list presumably repeats once per spin component — TODO confirm.
        ikmod = mod1(basis.krange_thisproc_allspin[ik], length(silicon.kgrid.kcoords))
        @test kpt.coordinate == silicon.kgrid.kcoords[ikmod]
        @test basis.kweights[ik] == silicon.kgrid.kweights[ikmod]
        for G in G_vectors(basis, kpt)
            # Elementwise check: plain `<=` on vectors is a lexicographic
            # comparison and would let out-of-range later components through.
            @test all(g_start .<= G .<= g_stop)
        end
        # kpt.mapping selects this k-point's G vectors out of the full grid.
        @test g_all[kpt.mapping] == G_vectors(basis, kpt)
    end

    # The per-spin k ranges must agree with the all-spin local k range.
    for σ = 1:basis.model.n_spin_components
        for (ikσ, ik) in enumerate(DFTK.krange_spin(basis, σ))
            @test basis.krange_thisproc[σ][ikσ] == basis.krange_thisproc_allspin[ik]
        end
    end
end
@testitem "PlaneWaveBasis: Energy cutoff is respected" setup=[TestCases] begin
    # Every k+G vector kept in the basis must satisfy |k+G|² ≤ 2 Ecut,
    # for several cutoffs and FFT grid sizes on a shifted k-point grid.
    using DFTK
    silicon = TestCases.silicon

    function check_cutoff(system, Ecut, fft_size)
        model = Model(system.lattice)
        basis = PlaneWaveBasis(model; Ecut, fft_size, kgrid=(2, 5, 5), kshift=[1, 0, 0]/2)
        for kpt in basis.kpoints, G in G_vectors(basis, kpt)
            q_cart = model.recip_lattice * (kpt.coordinate + G)
            @test sum(abs2, q_cart) ≤ 2 * Ecut
        end
    end

    cases = ((4.0, [15, 15, 15]),
             (3.0, [15, 13, 13]),
             (4.0, [11, 13, 11]))
    for (Ecut, fft_size) in cases
        check_cutoff(silicon, Ecut, fft_size)
    end
end
@testitem "PlaneWaveBasis: Check cubic basis and cubic index" setup=[TestCases] begin
    # index_G_vectors(basis, G) must return the Cartesian grid position of every
    # G vector on the full FFT grid, and `nothing` for G outside that grid.
    using DFTK
    using DFTK: index_G_vectors
    silicon = TestCases.silicon

    fft_size = (15, 15, 15)
    model = Model(silicon.lattice)
    basis = PlaneWaveBasis(model; Ecut=3, fft_size, kgrid=(1, 1, 1))

    # Iterate over the actual grid shape instead of hard-coding the loop bounds,
    # so the check always covers the whole grid if fft_size changes.
    g_all = collect(G_vectors(basis))
    for I in CartesianIndices(g_all)
        @test index_G_vectors(basis, g_all[I]) == I
    end

    # A first component of ±15 is outside the grid for this fft_size.
    @test index_G_vectors(basis, [15, 1, 1]) === nothing
    @test index_G_vectors(basis, [-15, 1, 1]) === nothing
end
@testitem "PlaneWaveBasis: Check index for kpoints" setup=[TestCases] begin
    # index_G_vectors(basis, kpt, G) must return the position of G within the
    # k-point's G-vector list, and `nothing` when G is not part of that list.
    using DFTK
    using DFTK: index_G_vectors
    silicon = TestCases.silicon
    model = Model(silicon.lattice, silicon.atoms, silicon.positions)
    basis = PlaneWaveBasis(model; Ecut=3, silicon.kgrid, fft_size=[7, 9, 11])

    g_all = collect(G_vectors(basis))
    for kpt in basis.kpoints
        # Round trip: the full-grid G at kpt.mapping[iball] must index back to iball.
        for (iball, ifull) in enumerate(kpt.mapping)
            @test index_G_vectors(basis, kpt, g_all[ifull]) == iball
        end

        # [-2, -3, -1] is expected only in the list of k = [1/3, 1/3, 0], at the
        # hard-coded position 62. k coordinates are floats, so compare with ≈.
        if kpt.coordinate ≈ [1/3, 1/3, 0]
            @test index_G_vectors(basis, kpt, [-2, -3, -1]) == 62
        else
            @test index_G_vectors(basis, kpt, [-2, -3, -1]) === nothing
        end

        # A first component of ±15 lies outside the [7, 9, 11] FFT grid entirely.
        @test index_G_vectors(basis, kpt, [15, 1, 1]) === nothing
        @test index_G_vectors(basis, kpt, [-15, 1, 1]) === nothing
    end
end
@testitem "PlaneWaveBasis: kpoint mapping" setup=[TestCases] begin
    # kpt.mapping and kpt.mapping_inv must translate consistently between the
    # full FFT grid of G vectors and each k-point's own G-vector list.
    using DFTK
    silicon = TestCases.silicon
    model = Model(silicon.lattice, silicon.atoms, silicon.positions)
    basis = PlaneWaveBasis(model; Ecut=3, kgrid=(2, 2, 2), fft_size=[7, 9, 11],
                           kshift=ones(3)/2)

    # The full-grid G vectors do not depend on the k-point: build them once.
    Gs_basis = collect(G_vectors(basis))
    for kpt in basis.kpoints
        Gs_kpt = collect(G_vectors(basis, kpt))
        # mapping: position in the k-point list => index into the full grid
        for (iball, ifull) in pairs(kpt.mapping)
            @test Gs_basis[ifull] == Gs_kpt[iball]
        end
        # mapping_inv: index into the full grid => position in the k-point list
        for (ifull, iball) in pairs(kpt.mapping_inv)
            @test Gs_basis[ifull] == Gs_kpt[iball]
        end
    end
end
@testitem "PlaneWaveBasis: Check G_vector-like accessor functions" setup=[TestCases] begin
    # Cross-check the reduced/Cartesian G, r and k+G accessors against each other.
    # The reference values convert reduced vectors with the (reciprocal) lattice.
    using DFTK
    silicon = TestCases.silicon

    fft_size = [15, 15, 15]
    model = Model(silicon.lattice, silicon.atoms, silicon.positions)
    basis = PlaneWaveBasis(model; Ecut=3, kgrid=(3, 3, 3), fft_size)

    recip_to_cart(q) = model.recip_lattice * q
    real_to_cart(r)  = model.lattice * r

    # `isapprox` and not `==` because of https://github.com/JuliaLang/julia/issues/46849
    atol = 20eps(eltype(basis))

    # Whole-grid accessors
    Gs_grid = G_vectors(fft_size)
    rs_grid = r_vectors(basis)
    @test length(Gs_grid) == prod(fft_size)
    @test length(rs_grid) == prod(fft_size)
    @test G_vectors(basis) ≈ Gs_grid atol=atol
    @test G_vectors_cart(basis) ≈ map(recip_to_cart, Gs_grid) atol=atol
    @test r_vectors_cart(basis) ≈ map(real_to_cart, rs_grid) atol=atol

    # Per-k-point accessors
    for kpt in basis.kpoints
        Gs_k = G_vectors(basis, kpt)
        shift_by_k(G) = G + kpt.coordinate
        @test length(Gs_k) == length(kpt.mapping)
        @test G_vectors_cart(basis, kpt) ≈ map(recip_to_cart, Gs_k) atol=atol
        @test Gplusk_vectors(basis, kpt) ≈ map(shift_by_k, Gs_k) atol=atol
        @test Gplusk_vectors_cart(basis, kpt) ≈ map(recip_to_cart,
                                                    Gplusk_vectors(basis, kpt)) atol=atol
    end
end