This repository has been archived by the owner on Mar 1, 2023. It is now read-only.
/
Checkpoint.jl
127 lines (109 loc) · 2.71 KB
/
Checkpoint.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
module Checkpoint
export write_checkpoint, rm_checkpoint, read_checkpoint
using JLD2
using MPI
using Printf
import KernelAbstractions: CPU
using ..ODESolvers
using ..ODESolvers: AbstractODESolver
using ..MPIStateArrays
import ..MPIStateArrays: array_device
"""
write_checkpoint(solver_config, checkpoint_dir, name, mpicomm, num)
Read in the state and auxiliary arrays as well as the simulation time
stored in the checkpoint file for `name` and `num`.
"""
function write_checkpoint(
solver_config,
checkpoint_dir::String,
name::String,
mpicomm::MPI.Comm,
num::Int,
)
Q = solver_config.Q
A = solver_config.dg.state_auxiliary
odesolver = solver_config.solver
write_checkpoint(Q, A, odesolver, checkpoint_dir, name, mpicomm, num)
return nothing
end
"""
    write_checkpoint(Q, A, odesolver, checkpoint_dir, name, mpicomm, num)

Save the state array `Q`, the auxiliary array `A` and the time reported by
`ODESolvers.gettime(odesolver)` to a JLD2 file inside `checkpoint_dir`.

The file is named `<name>_checkpoint_mpirank<rank>_num<num>.jld2`, where spaces
in `name` are replaced by underscores, `<rank>` is `MPI.Comm_rank(mpicomm)` and
both `<rank>` and `<num>` are zero-padded to at least four digits. The file
holds three variables: `h_Q`, `h_aux` and `t`.

`checkpoint_dir` is created (including missing parents) if it does not exist.

Returns `nothing`.
"""
function write_checkpoint(
    Q::MPIStateArray,
    A::MPIStateArray,
    odesolver::AbstractODESolver,
    checkpoint_dir::String,
    name::String,
    mpicomm::MPI.Comm,
    num::Int,
)
    nm = replace(name, " " => "_")
    cname = @sprintf(
        "%s_checkpoint_mpirank%04d_num%04d.jld2",
        nm,
        MPI.Comm_rank(mpicomm),
        num,
    )
    cfull = joinpath(checkpoint_dir, cname)

    # Make sure the target directory exists so that `@save` does not fail on
    # a missing path. `mkpath` is a no-op for an existing directory; an empty
    # `checkpoint_dir` means "current directory" and needs nothing created.
    isempty(checkpoint_dir) || mkpath(checkpoint_dir)

    @info @sprintf(
        """
        Checkpoint
        saving to %s""",
        cfull
    )

    # Only the device of `Q` is inspected; `A` is treated the same way.
    # On a CPU device the underlying `realdata` is stored directly, otherwise
    # it is first converted with `Array` (presumably a device-to-host copy).
    if array_device(Q) isa CPU
        h_Q = Q.realdata
        h_aux = A.realdata
    else
        h_Q = Array(Q.realdata)
        h_aux = Array(A.realdata)
    end
    t = ODESolvers.gettime(odesolver)

    # The variable names below are the keys under which the data is stored in
    # the file; `read_checkpoint` loads them by the same names.
    @save cfull h_Q h_aux t

    return nothing
end
"""
rm_checkpoint(checkpoint_dir, name, mpicomm, num)
Remove the checkpoint file identified by `solver_config.name` and `num`.
"""
function rm_checkpoint(
checkpoint_dir::String,
name::String,
mpicomm::MPI.Comm,
num::Int,
)
nm = replace(name, " " => "_")
cname = @sprintf(
"%s_checkpoint_mpirank%04d_num%04d.jld2",
nm,
MPI.Comm_rank(mpicomm),
num,
)
rm(joinpath(checkpoint_dir, cname), force = true)
return nothing
end
"""
read_checkpoint(checkpoint_dir, name, mpicomm, num)
Read in the state and auxiliary arrays as well as the simulation time
stored in the checkpoint file for `name` and `num`.
"""
function read_checkpoint(
checkpoint_dir::String,
name::String,
array_type,
mpicomm::MPI.Comm,
num::Int,
)
nm = replace(name, " " => "_")
cname = @sprintf(
"%s_checkpoint_mpirank%04d_num%04d.jld2",
nm,
MPI.Comm_rank(mpicomm),
num,
)
cfull = joinpath(checkpoint_dir, cname)
if !isfile(cfull)
error("Cannot restore from checkpoint in $(cfull), file not found")
end
@load cfull h_Q h_aux t
return (array_type(h_Q), array_type(h_aux), t)
end
end # module