-
Notifications
You must be signed in to change notification settings - Fork 82
/
state_store.py
185 lines (171 loc) · 7.61 KB
/
state_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Standard Library
import json
import os
# First Party
from smdebug.core.config_constants import (
CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR,
CHECKPOINT_DIR_KEY,
DEFAULT_CHECKPOINT_CONFIG_FILE,
LATEST_GLOBAL_STEP_SAVED,
LATEST_GLOBAL_STEP_SEEN,
LATEST_MODE_STEP,
METADATA_FILENAME,
METADATA_FILENAME_S3_UPLOADED,
TRAINING_RUN,
)
from smdebug.core.logger import get_logger
# Module-level logger; StateStore below emits its debug/info/warning messages through it.
logger = get_logger()
def _rule_for_sorting(state):
    """
    Sort key for saved states: the value stored under LATEST_GLOBAL_STEP_SEEN.

    Intended to be passed as `key=` when ordering a list of state dictionaries.
    """
    latest_seen_step = state[LATEST_GLOBAL_STEP_SEEN]
    return latest_seen_step
class StateStore:
    """
    Keep track of saved training states and of changes to the checkpoint directory.

    On construction the checkpoint directory is looked up in the checkpoint config file.
    When a directory is configured, previously saved states are loaded from the metadata
    file inside it and a snapshot (paths and sizes) of the checkpoint files is recorded.
    `is_checkpoint_updated` later compares the directory against that snapshot.
    """

    def _get_checkpoint_files_in_dir(self, cp_dir):
        """
        Return the sorted paths of all checkpoint files found under `cp_dir`.

        The directory is walked recursively. The two metadata files and every file whose
        name contains "sagemaker-uploaded" are not considered checkpoint files.
        """
        checkpoint_files = []
        for child, _, files in os.walk(cp_dir):
            for file in files:
                if file == METADATA_FILENAME or file == METADATA_FILENAME_S3_UPLOADED:
                    continue
                if "sagemaker-uploaded" in file:
                    continue
                checkpoint_files.append(os.path.join(child, file))
        return sorted(checkpoint_files)

    @staticmethod
    def _get_file_sizes(files, log):
        """
        Return the size in bytes of every path in `files`, in the same order.

        A path whose size cannot be read (for example because the file disappeared after
        the directory was listed) is recorded with size 0 and the error is handed to the
        `log` callable, so the result always has exactly one entry per path.
        """
        sizes = []
        for file in files:
            try:
                sizes.append(os.path.getsize(file))
            except OSError as e:
                sizes.append(0)
                log(e)
        return sizes

    def _remember_snapshot(self, files, sizes):
        """
        Store `files` and `sizes` as the last seen snapshot if they differ from it.

        Both cached lists are replaced together so that they always stay index-aligned.
        Returns True when the snapshot changed and False when it is identical.
        """
        if (
            files == self._last_seen_checkpoint_files
            and sizes == self._last_seen_cp_files_size
        ):
            return False
        self._last_seen_checkpoint_files = files
        self._last_seen_cp_files_size = sizes
        return True

    def __init__(self):
        self._saved_states = []
        self._states_file = None
        self._checkpoint_dir = None
        self._last_seen_checkpoint_files = []
        self._last_seen_cp_files_size = []
        self._retrieve_path_to_checkpoint()
        if self._checkpoint_dir is not None:
            # The states file lives in the same directory as the checkpoints.
            self._states_file = os.path.join(self._checkpoint_dir, METADATA_FILENAME)
            self._read_states_file()
            self._last_seen_checkpoint_files = self._get_checkpoint_files_in_dir(
                self._checkpoint_dir
            )
            self._last_seen_cp_files_size = self._get_file_sizes(
                self._last_seen_checkpoint_files, logger.debug
            )

    def _retrieve_path_to_checkpoint(self):
        """
        Retrieve the folder/path where users will store the checkpoints. This path is stored as the value
        for key 'CHECKPOINT_DIR_KEY' in the checkpoint config file, whose location is taken from the
        environment variable named by CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR (falling back to
        DEFAULT_CHECKPOINT_CONFIG_FILE).

        Returns the checkpoint directory, or None when the config file does not exist or does not
        contain the key.
        """
        if self._checkpoint_dir is not None:
            return self._checkpoint_dir
        checkpoint_config_file = os.getenv(
            CHECKPOINT_CONFIG_FILE_PATH_ENV_VAR, DEFAULT_CHECKPOINT_CONFIG_FILE
        )
        if not os.path.exists(checkpoint_config_file):
            logger.info(f"The checkpoint config file {checkpoint_config_file} does not exist.")
            return None
        with open(checkpoint_config_file) as json_data:
            parameters = json.load(json_data)
        if CHECKPOINT_DIR_KEY in parameters:
            self._checkpoint_dir = parameters[CHECKPOINT_DIR_KEY]
        return self._checkpoint_dir

    def _read_states_file(self):
        """
        Read the states from the states file and append them to the list of saved states.
        The list is then sorted based on the last seen step.
        Does nothing when the states file does not exist.
        """
        if not os.path.exists(self._states_file):
            return
        with open(self._states_file) as json_data:
            parameters = json.load(json_data)
        # Only these keys are carried over; a missing key raises KeyError.
        state_keys = (
            TRAINING_RUN,
            LATEST_GLOBAL_STEP_SAVED,
            LATEST_GLOBAL_STEP_SEEN,
            LATEST_MODE_STEP,
        )
        for param in parameters:
            self._saved_states.append({key: param[key] for key in state_keys})
        self._saved_states.sort(key=_rule_for_sorting)

    def is_checkpoint_updated(self):
        """
        Check whether new checkpoint files got added or existing checkpoint files that are
        stored got updated.

        The current list of checkpoint files and their sizes is compared with the snapshot
        taken the last time a change was seen. When they differ, the snapshot is replaced and
        True is returned. Returns False when no checkpoint directory is configured, when the
        directory holds no checkpoint files, or when nothing changed.
        """
        if self._checkpoint_dir is None:
            return False
        checkpoint_files = self._get_checkpoint_files_in_dir(self._checkpoint_dir)
        if not checkpoint_files:
            logger.debug(
                "Checkpoints not updated. There are no checkpoint files created yet, to be updated"
            )
            return False

        # Modification times are gathered only for the log line below.
        timestamps = []
        for file in checkpoint_files:
            try:
                timestamps.append(os.path.getmtime(file))
            except OSError as e:
                timestamps.append(0)
                logger.debug(e)
        logger.info(
            f"Timestamps of different checkpoint files {list(zip(checkpoint_files, timestamps))}"
        )

        # A file vanishing between listing and stat is logged quietly while the file set is
        # changing, and as a warning when the set was expected to be stable.
        count_changed = len(self._last_seen_checkpoint_files) != len(checkpoint_files)
        log = logger.debug if count_changed else logger.warning
        cp_file_sizes = self._get_file_sizes(checkpoint_files, log)

        if not self._remember_snapshot(checkpoint_files, cp_file_sizes):
            return False
        logger.info(
            f"sizes of different checkpoint files {list(zip(checkpoint_files, cp_file_sizes))}"
        )
        return True

    def get_last_saved_state(self):
        """
        Retrieve the last saved state, if any.
        Multiple states may have been loaded or recorded; only the last one in the list is
        returned. Returns None when there are no saved states.
        """
        if self._saved_states:
            return self._saved_states[-1]
        return None

    def update_state(self, ts_state):
        """
        Append the passed state to the saved states and rewrite the states file with the
        complete list as JSON.

        The states file path is only set by `__init__` when a checkpoint directory is
        configured, so this method must only be called in that case.
        """
        self._saved_states.append(ts_state)
        with open(self._states_file, "w") as out_file:
            json.dump(self._saved_states, out_file)