/
jld2io.jl
120 lines (106 loc) · 4.01 KB
/
jld2io.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import JLD2
"""
    ScfSaveCheckpoints(filename="dftk_scf_checkpoint.jld2"; keep=false, overwrite=false)

Add simplistic checkpointing to a DFTK self-consistent field calculation.

Returns a callback which, on every call that is not the `:finalize` stage, writes the
current SCF state to `filename` via `save_scfres`. The stored state is `info` with every
entry whose name starts with `ρ` dropped, plus `ρ = info.ρout` and
`ρspin = info.ρ_spin_out`.

# Keywords
- `keep`: if `false` (default) the checkpoint file is deleted when the callback sees the
  `:finalize` stage; if `true` it is left on disk.
- `overwrite`: if `false` (default) an error is raised on the first iteration when
  `filename` already exists, to avoid clobbering an earlier checkpoint.

The callback returns `info` unchanged.
"""
function ScfSaveCheckpoints(filename="dftk_scf_checkpoint.jld2"; keep=false, overwrite=false)
    # TODO Save only every 30 minutes or so
    function callback(info)
        # Refuse to clobber a pre-existing checkpoint, but only on the first *iterating*
        # call. If the SCF finishes after a single iteration, the `:finalize` call can
        # also carry `n_iter == 1`; checking there would trip over the file this very
        # callback wrote one call earlier instead of cleaning it up.
        if info.n_iter == 1 && info.stage != :finalize
            isfile(filename) && !overwrite && error(
                "Checkpoint $filename exists. Use 'overwrite=true' to force overwriting."
            )
        end
        if info.stage == :finalize
            !keep && isfile(filename) && rm(filename)  # Cleanup checkpoint
        else
            # Drop all density-like entries (names starting with ρ) and keep only the
            # output densities, under the names `save_scfres` reads (`ρ`, `ρspin`).
            scfres = (; (k => v for (k, v) in pairs(info) if !startswith(string(k), "ρ"))...)
            scfres = merge(scfres, (ρ=info.ρout, ρspin=info.ρ_spin_out))
            save_scfres(filename, scfres)
        end
        return info
    end
    return callback
end
"""
    save_scfres(jld::JLD2.JLDFile, scfres::NamedTuple)

Write `scfres` into the already-open JLD2 file `jld` and return `jld`.

The following entries are written:
- `"__propertynames"`: the property names of `scfres`, in order.
- `"ρ_real"`: the `real` field of `scfres.ρ`.
- `"ρspin_real"`: the `real` field of `scfres.ρspin`, or `nothing` if `scfres.ρspin`
  is `nothing`.
- `"basis"`: `scfres.basis`.
- One entry per remaining property, stored under its name as a string.

`ham` and `energies` are not written.

`scfres` must have the properties `ρ`, `ρspin` and `basis`; they are accessed directly.
"""
function save_scfres(jld::JLD2.JLDFile, scfres::NamedTuple)
    props = propertynames(scfres)
    jld["__propertynames"] = props
    jld["ρ_real"] = scfres.ρ.real
    spin_density = scfres.ρspin
    jld["ρspin_real"] = spin_density === nothing ? nothing : spin_density.real
    jld["basis"] = scfres.basis

    # Entries handled explicitly above, or not persisted at all (ham, energies)
    handled = (:ham, :basis, :ρ, :ρspin, :energies)
    for prop in props
        if !(prop in handled)
            jld[string(prop)] = getproperty(scfres, prop)
        end
    end
    return jld
end
"""
    save_scfres(file::AbstractString, scfres::NamedTuple)

Open `file` in write mode (`"w"`) with JLD2 and store `scfres` in it using the
`JLDFile` method of `save_scfres`. The file is closed again by `jldopen`.
"""
save_scfres(file::AbstractString, scfres::NamedTuple) =
    JLD2.jldopen(jld -> save_scfres(jld, scfres), file, "w")
"""
    load_scfres(jld::JLD2.JLDFile)

Rebuild an SCF result `NamedTuple` from a JLD2 file written by `save_scfres`.

Reconstruction proceeds as follows:
- `ρ` is rebuilt with `from_real` from `"ρ_real"`.
- `ρspin` is rebuilt the same way from `"ρspin_real"`, unless that entry is `nothing`.
- `ψ`, `occupation` and `eigenvalues` are sliced with `basis.krange_thisproc`, to split
  them over MPI processes.
- `energies` and `ham` are recomputed with `energy_hamiltonian`.

The file must therefore contain `ψ`, `occupation`, `eigenvalues` and `εF`. The result
has exactly the properties listed in `"__propertynames"`, in that order.
"""
function load_scfres(jld::JLD2.JLDFile)
    basis = jld["basis"]
    props = jld["__propertynames"]
    spin_real = jld["ρspin_real"]

    scfdict = Dict{Symbol, Any}(
        :basis => basis,
        :ρ     => from_real(basis, jld["ρ_real"]),
        :ρspin => isnothing(spin_real) ? nothing : from_real(basis, spin_real),
    )

    # k-point resolved data: only the k-points of this process are kept
    kpt_properties = (:ψ, :occupation, :eigenvalues)
    for prop in kpt_properties
        scfdict[prop] = jld[string(prop)][basis.krange_thisproc]
    end

    special = (:ham, :basis, :ρ, :ρspin, :energies)
    for prop in props
        (prop in special || prop in kpt_properties) && continue
        scfdict[prop] = jld[string(prop)]
    end

    # ham and energies are not stored on disk; recompute them from the loaded state
    energies, ham = energy_hamiltonian(basis, scfdict[:ψ], scfdict[:occupation];
                                       ρ=scfdict[:ρ], ρspin=scfdict[:ρspin],
                                       eigenvalues=scfdict[:eigenvalues],
                                       εF=scfdict[:εF])
    scfdict[:energies] = energies
    scfdict[:ham] = ham

    return (; (prop => scfdict[prop] for prop in props)...)
end
"""
    load_scfres(file::AbstractString)

Open `file` read-only (`"r"`) with JLD2 and return the result of the `JLDFile` method
of `load_scfres`. The file is closed again by `jldopen`.
"""
function load_scfres(file::AbstractString)
    JLD2.jldopen(file, "r") do jld
        load_scfres(jld)
    end
end
#
# Custom serialisations
#
# Storable, reduced form of a PlaneWaveBasis{T}: only the inputs needed to rebuild it
# with the PlaneWaveBasis constructor (see the `Base.convert` methods below).
# NOTE: field order defines the positional constructor used by `convert` and the layout
# JLD2 writes to disk, so reordering fields would break existing checkpoint files.
struct PlaneWaveBasisSerialisation{T <: Real}
    model::Model{T}
    Ecut::T
    kcoords::Vector{Vec3{T}}        # coordinates of the k-points of the first spin component only
    kweights::Vector{T}             # stored, but not passed back when rebuilding the basis
    ksymops::Vector{Vector{SymOp}}  # one entry per stored k-point
    kgrid::Union{Nothing,Vec3{Int}}  # may be `nothing`
    kshift::Union{Nothing,Vec3{T}}   # may be `nothing`
    fft_size::Tuple{Int, Int, Int}
    symmetries::Vector{SymOp}
end
# Tell JLD2 to store any PlaneWaveBasis{T} as the lighter PlaneWaveBasisSerialisation{T};
# the conversions in both directions are the `Base.convert` methods below.
function JLD2.writeas(::Type{PlaneWaveBasis{T}}) where {T}
    return PlaneWaveBasisSerialisation{T}
end
# Reduce a PlaneWaveBasis to the data needed to rebuild it. Only the first
# `length(basis.kpoints) ÷ n_spin_components` k-points (those of the first spin
# component) are kept. Errors under MPI with more than one process.
function Base.convert(::Type{PlaneWaveBasisSerialisation{T}},
                      basis::PlaneWaveBasis{T}) where {T}
    mpi_nprocs() > 1 && error(
        "JLD2 serialisation for PlaneWaveBasis only implemented for non-MPI calculations for now."
    )
    kidx = 1:div(length(basis.kpoints), basis.model.n_spin_components)
    kcoords = [kpt.coordinate for kpt in basis.kpoints[kidx]]
    return PlaneWaveBasisSerialisation{T}(
        basis.model, basis.Ecut,
        kcoords, basis.kweights[kidx], basis.ksymops[kidx],
        basis.kgrid, basis.kshift,
        basis.fft_size, basis.symmetries,
    )
end
# Rebuild a full PlaneWaveBasis from its serialised form (presumably invoked by JLD2
# when reading a stored basis back — verify against the JLD2 version in use).
# The stored `kweights` are not passed on; the constructor determines them itself.
function Base.convert(::Type{PlaneWaveBasis{T}},
                      serial::PlaneWaveBasisSerialisation{T}) where {T}
    s = serial
    return PlaneWaveBasis(s.model, s.Ecut, s.kcoords, s.ksymops, s.symmetries;
                          kgrid=s.kgrid, kshift=s.kshift, fft_size=s.fft_size)
end