/
densities.jl
115 lines (100 loc) · 4.81 KB
/
densities.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Densities (and potentials) are represented by arrays
# ρ[ix, iy, iz, iσ] in real space, where iσ ∈ 1:n_spin_components.
# Warn if the smallest entry of `ρ` lies below `-tol`.
# Evaluates to `false` when nothing is logged (including for an empty `ρ`) and to
# `nothing` (the value of `@warn`) when the warning is emitted.
function _check_nonnegative(ρ::AbstractArray{T}; tol=eps(T)) where {T}
    # An empty array has no negative entries; `minimum` would throw on it.
    isempty(ρ) && return false
    # Reduce once and reuse the value for both the comparison and the log record.
    ρ_min = minimum(ρ)
    return ρ_min < -tol && @warn("Negative ρ detected", min_ρ=ρ_min)
end
"""
    compute_density(basis::PlaneWaveBasis, ψ::AbstractVector, occupation::AbstractVector;
                    occupation_threshold=zero(T))

Compute the density for a wave function `ψ` discretized on the plane-wave
grid `basis`, where the individual k-points are occupied according to `occupation`.
`ψ` should be one coefficient matrix per ``k``-point.
Bands whose occupation is smaller in magnitude than the optional keyword
`occupation_threshold` are skipped. By default all occupation numbers are considered.
"""
@views @timing function compute_density(basis::PlaneWaveBasis{T}, ψ, occupation;
                                        occupation_threshold=zero(T)) where {T}
    S = promote_type(T, real(eltype(ψ[1])))

    # Occupations are indexed element by element below, so keep them on the CPU.
    # A fresh name is bound (instead of rebinding `occupation`) so that the variable
    # captured by the spawned tasks is assigned exactly once and is not boxed.
    occ_cpu = [to_cpu(occk) for occk in occupation]

    # Per k-point: indices of the bands with |occupation| ≥ occupation_threshold.
    mask_occ = [findall(occnk -> abs(occnk) ≥ occupation_threshold, occk)
                for occk in occ_cpu]

    if all(isempty, mask_occ)  # No non-zero occupations => return zero density
        ρ = zeros_like(basis.G_vectors, S,
                       basis.fft_size..., basis.model.n_spin_components)
    else
        # We split the total iteration range (ik, n) in chunks, and parallelize over them.
        # With this chunk length there are at most Threads.nthreads() chunks.
        ik_n = [(ik, n) for ik = 1:length(basis.kpoints) for n = mask_occ[ik]]
        chunk_length = cld(length(ik_n), Threads.nthreads())

        # Chunk-local variables: one density accumulator and one real-space work
        # buffer per chunk, so that tasks never write to the same storage.
        ρ_chunklocal = map(1:Threads.nthreads()) do i
            zeros_like(basis.G_vectors, S,
                       basis.fft_size..., basis.model.n_spin_components)
        end
        ψnk_real_chunklocal = [zeros_like(basis.G_vectors, complex(S), basis.fft_size...)
                               for _ = 1:Threads.nthreads()]

        @sync for (ichunk, chunk) in enumerate(Iterators.partition(ik_n, chunk_length))
            Threads.@spawn for (ik, n) in chunk  # spawn a task per chunk
                ρ_loc = ρ_chunklocal[ichunk]
                ψnk_real = ψnk_real_chunklocal[ichunk]
                kpt = basis.kpoints[ik]

                # Band n at k-point ik on the real-space grid, then accumulate
                # occupation * k-weight * |ψnk|² into the spin channel of this k-point.
                ifft!(ψnk_real, basis, kpt, ψ[ik][:, n])
                ρ_loc[:, :, :, kpt.spin] .+= (occ_cpu[ik][n] .* basis.kweights[ik]
                                              .* abs2.(ψnk_real))

                synchronize_device(basis.architecture)
            end
        end

        # Unused accumulators (fewer chunks than threads) are all-zero and harmless.
        ρ = sum(ρ_chunklocal)
    end

    mpi_sum!(ρ, basis.comm_kpts)
    ρ = symmetrize_ρ(basis, ρ; do_lowpass=false)

    # Floating-point round-off can leave tiny negative values where the density should
    # vanish. Keep a nonzero floor on the tolerance so that the default
    # occupation_threshold == 0 does not warn about such noise.
    negtol = max(sqrt(eps(T)), 5occupation_threshold)
    _check_nonnegative(ρ; tol=negtol)

    return ρ
end
# Variation in density corresponding to a variation `δψ` in the orbitals and
# `δoccupation` in the occupations. It is evaluated as the derivative at ε = 0 of
# `compute_density` applied to `ψ + ε δψ` and `occupation + ε δoccupation`.
@views @timing function compute_δρ(basis::PlaneWaveBasis{T}, ψ, δψ,
                                   occupation, δoccupation=zero.(occupation);
                                   occupation_threshold=zero(T)) where {T}
    # Density along the line through (ψ, occupation) in direction (δψ, δoccupation).
    function density_along_perturbation(ε)
        ψ_shifted = [ψk .+ ε .* δψk for (ψk, δψk) in zip(ψ, δψ)]
        occ_shifted = [ok .+ ε .* δok for (ok, δok) in zip(occupation, δoccupation)]
        return compute_density(basis, ψ_shifted, occ_shifted; occupation_threshold)
    end
    return ForwardDiff.derivative(density_along_perturbation, zero(T))
end
# Kinetic energy density τ[ix, iy, iz, iσ] on the real-space grid, accumulated as
#   occupation[ik][n] * kweights[ik] / 2 * |ifft(im * (G+k)[α] * ψnk)|²
# over k-points ik, bands n and the three components α of the G+k vectors
# (Cartesian, going by the name of `Gplusk_vectors_cart`). The result is summed over
# `basis.comm_kpts` and passed through `symmetrize_ρ`.
@views @timing function compute_kinetic_energy_density(basis::PlaneWaveBasis, ψ, occupation)
    T = promote_type(eltype(basis), real(eltype(ψ[1])))
    τ = similar(ψ[1], T, (basis.fft_size..., basis.model.n_spin_components))
    τ .= 0
    # Real-space work buffer, overwritten by every ifft! below.
    dαψnk_real = zeros(complex(eltype(basis)), basis.fft_size)

    for (ik, kpt) in enumerate(basis.kpoints)
        # Fetch the G+k vectors once per k-point rather than once per direction.
        # NOTE(review): assumes Gplusk_vectors_cart has no side effects — confirm.
        Gpk_cart = Gplusk_vectors_cart(basis, kpt)
        G_plus_k = [[Gk[α] for Gk in Gpk_cart] for α in 1:3]

        for n = 1:size(ψ[ik], 2)
            # Scalar prefactor shared by the three directions of this band.
            prefac = occupation[ik][n] * basis.kweights[ik] / 2
            for α = 1:3
                ifft!(dαψnk_real, basis, kpt, im .* G_plus_k[α] .* ψ[ik][:, n])
                @. τ[:, :, :, kpt.spin] += prefac * abs2(dαψnk_real)
            end
        end
    end

    mpi_sum!(τ, basis.comm_kpts)
    return symmetrize_ρ(basis, τ; do_lowpass=false)
end
# Total density: sum over the spin components (4th axis of ρ), with that axis dropped.
function total_density(ρ)
    summed = sum(ρ; dims=4)
    return dropdims(summed; dims=4)
end
# Spin density: first minus second spin component when ρ has two components along
# its 4th axis; otherwise an all-zero array shaped like `ρ[:, :, :]`.
@views function spin_density(ρ)
    size(ρ, 4) == 2 || return zero(ρ[:, :, :])
    return ρ[:, :, :, 1] - ρ[:, :, :, 2]
end
# Assemble the 4-axis density array from a total density and, optionally, a spin
# density. Without `ρspin` the result has a single component equal to `ρtot`;
# with it the two components are (ρtot + ρspin) / 2 and (ρtot - ρspin) / 2.
function ρ_from_total_and_spin(ρtot, ρspin=nothing)
    # `Val(4)` is used for the concatenation axis to ensure inferability.
    # The single-component case also goes through `cat` so that it returns a copy,
    # consistent with the two-component case.
    isnothing(ρspin) && return cat(ρtot; dims=Val(4))

    ρ_plus = (ρtot .+ ρspin) ./ 2
    ρ_minus = (ρtot .- ρspin) ./ 2
    return cat(ρ_plus, ρ_minus; dims=Val(4))
end
# Build the spin-resolved density array from a total density alone.
# For `spin_polarization` in (:none, :spinless) only the total is stored; otherwise
# a zero spin density is supplied, so that `ρ_from_total_and_spin` gives each of the
# two components half of `ρtot`.
function ρ_from_total(basis, ρtot::AbstractArray{T}) where {T}
    if basis.model.spin_polarization in (:none, :spinless)
        ρspin = nothing
    else
        # Allocate through `similar(ρtot, ...)` rather than `zeros(...)`, so the zero
        # spin density follows the array type of `ρtot` instead of always being a
        # plain `Array`; it is broadcast against `ρtot` right afterwards.
        ρspin = fill!(similar(ρtot, T, basis.fft_size), zero(T))
    end
    return ρ_from_total_and_spin(ρtot, ρspin)
end