/
train.py
executable file
· 117 lines (95 loc) · 3.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
import json
import os
import signal
import sys
import time
# SageMaker container filesystem layout — these paths are fixed by the
# SageMaker Training container spec, not chosen by this script.
input_dir = "/opt/ml/input"
model_dir = "/opt/ml/model"
output_dir = "/opt/ml/output"

# we're arbitrarily going to iterate through 5 epochs here, a real algorithm
# may choose to determine the number of epochs based on a more realistic
# convergence criteria
num_epochs = 5

# Pipe-mode channel name; the FIFO paths below are derived from it.
channel_name = "training"

# Set by the SIGTERM/SIGINT handler and polled cooperatively by the
# training loop so the process can exit gracefully.
terminated = False
def main():
    """Run the dummy training job.

    For each epoch, stream the epoch's data out of the SageMaker
    Pipe-mode FIFO (reading and discarding it), print throughput stats,
    and finally write a trivial model artifact to the model directory.
    On any failure, write the failure marker file required by the
    SageMaker spec and re-raise.
    """
    # trapping signals and responding to them appropriately is required by
    # SageMaker spec
    trap_signal()

    # writing to a failure file is also part of the spec
    failure_file = output_dir + "/failure"
    data_dir = input_dir + "/data"

    try:
        # we're allocating a byte array here to read data into, a real algo
        # may opt to prefetch the data into a memory buffer and train in
        # in parallel so that both IO and training happen simultaneously
        data = bytearray(16777216)  # 16 MiB read buffer, reused every read
        total_read = 0
        total_duration = 0
        for epoch in range(num_epochs):
            check_termination()
            epoch_bytes_read = 0
            # As per SageMaker Training spec, the FIFO's path will be based on
            # the channel name and the current epoch:
            fifo_path = "{0}/{1}_{2}".format(data_dir, channel_name, epoch)

            # Usually the fifo will already exist by the time we get here, but
            # to be safe we should wait to confirm:
            wait_till_fifo_exists(fifo_path)
            # buffering=0: unbuffered binary reads straight from the FIFO
            with open(fifo_path, "rb", buffering=0) as fifo:
                print("opened fifo: %s" % fifo_path)
                # Now simply iterate reading from the file until EOF. Again, a
                # real algorithm will actually do something with the data
                # rather than simply reading and immediately discarding like we
                # are doing here
                start = time.time()
                bytes_read = fifo.readinto(data)
                total_read += bytes_read
                epoch_bytes_read += bytes_read
                # readinto returns 0 at EOF; also bail out mid-epoch if a
                # termination signal arrived.
                while bytes_read > 0 and not terminated:
                    bytes_read = fifo.readinto(data)
                    total_read += bytes_read
                    epoch_bytes_read += bytes_read

            duration = time.time() - start
            total_duration += duration
            epoch_throughput = epoch_bytes_read / duration / 1000000
            print(
                "Completed epoch %s; read %s bytes; time: %.2fs, throughput: %.2f MB/s"
                % (epoch, epoch_bytes_read, duration, epoch_throughput)
            )

        # now write a model, again, totally meaningless contents:
        with open(model_dir + "/model.json", "w") as model:
            json.dump(
                {
                    "bytes_read": total_read,
                    "duration": total_duration,
                    "throughput_MB_per_sec": total_read / total_duration / 1000000,
                },
                model,
            )
    except Exception:
        print("Failed to train: %s" % (sys.exc_info()[0]))
        touch(failure_file)
        raise
def check_termination():
    """Exit the process cleanly (status 0) if a termination was requested."""
    if not terminated:
        return
    print("Exiting due to termination request")
    sys.exit(0)
def wait_till_fifo_exists(fname):
    """Poll (every 100 ms) until *fname* exists on disk.

    Gives up early if a termination signal arrives, in which case
    check_termination() exits the process.
    """
    print("Wait till FIFO available: %s" % (fname))
    while not (os.path.exists(fname) or terminated):
        time.sleep(0.1)
    check_termination()
def touch(fname):
    """Create *fname* if it does not already exist (like ``touch``).

    Bug fix: the original used mode ``"wa"``, which is not a valid mode
    string in Python 3 — ``open`` requires exactly one of
    create/read/write/append and raises ``ValueError`` — so the
    SageMaker failure marker was never created and the ValueError masked
    the real training error.  Append mode creates the file when missing
    and leaves any existing content intact.
    """
    open(fname, "a").close()
def on_terminate(signum, frame):
    """Signal handler: flag that shutdown was requested.

    The flag is polled by the training loop so it can finish the current
    read and exit gracefully instead of being killed mid-write.
    """
    global terminated
    terminated = True
    print("caught termination signal, exiting gracefully...")
def trap_signal():
    """Install the graceful-shutdown handler for SIGTERM and SIGINT."""
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, on_terminate)
if __name__ == "__main__":
    # As per the SageMaker container spec, the algo takes a 'train' parameter.
    # We will simply ignore this in this dummy implementation.
    main()