-
Notifications
You must be signed in to change notification settings - Fork 6.7k
/
train.py
executable file
·112 lines (90 loc) · 3.74 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python
import json
import os
import signal
import sys
import time
# Fixed directory layout that SageMaker mounts inside a training container.
input_dir, model_dir, output_dir = (
    '/opt/ml/input',
    '/opt/ml/model',
    '/opt/ml/output',
)

# Epoch count is hard-coded for this demo; a real algorithm would more likely
# stop on some convergence criterion instead of a fixed number of passes.
num_epochs = 5

# Name of the input channel whose per-epoch FIFOs are consumed.
channel_name = 'training'

# Flipped to True by on_terminate() when SIGTERM/SIGINT is received.
terminated = False
def _throughput_mb_per_sec(num_bytes, seconds):
    """Return ``num_bytes / seconds`` in MB/s (1 MB = 10**6 bytes).

    Returns 0.0 when ``seconds`` is not positive, which happens when a read
    completes within the clock's resolution or no epochs were run. Without
    this guard the division would raise ZeroDivisionError.
    """
    if seconds <= 0:
        return 0.0
    return num_bytes / seconds / 1000000


def _drain(fifo, buf):
    """Read ``fifo`` into ``buf`` repeatedly and return the total bytes read.

    Stops at EOF (a read returning 0 bytes) or after the first successful read
    that observes the module-level ``terminated`` flag. Data is discarded;
    ``buf`` is simply overwritten on each read.
    """
    total = 0
    while True:
        n = fifo.readinto(buf)
        if not n:
            break
        total += n
        if terminated:
            break
    return total


def _write_failure(path, exc_type, exc_value):
    """Best-effort write of a one-line failure description to ``path``.

    SageMaker surfaces the contents of the failure file as the job's failure
    reason. Errors writing the file are only printed, so they never mask the
    training exception that the caller is about to re-raise.
    """
    try:
        with open(path, 'w') as failure:
            failure.write('%s: %s' % (exc_type.__name__, exc_value))
    except (IOError, OSError) as err:
        print('Could not write failure file %s: %s' % (path, err))


def main():
    """Run the dummy Pipe-mode training job.

    Installs the SIGTERM/SIGINT handlers. Then, for each of ``num_epochs``
    epochs, it:

    * waits for the channel FIFO ``<input_dir>/data/<channel_name>_<epoch>``;
    * reads that FIFO to EOF, discarding the bytes;
    * prints the epoch's byte count, duration and throughput.

    Finally it writes a JSON "model" with the aggregate figures to
    ``<model_dir>/model.json``.

    Any ``Exception`` is printed, its description is written to
    ``<output_dir>/failure``, and it is re-raised. A termination request seen
    at the start of an epoch, or while waiting for a FIFO, ends the process
    through ``check_termination`` (``sys.exit(0)``).
    """
    # trapping signals and responding to them appropriately is required by
    # SageMaker spec
    trap_signal()
    # writing to a failure file is also part of the spec
    failure_file = output_dir + '/failure'
    data_dir = input_dir + '/data'
    try:
        # Reusable 16 MiB read buffer. A real algorithm might prefetch into a
        # memory buffer and train in parallel so that IO and compute overlap.
        data = bytearray(16777216)
        total_read = 0
        total_duration = 0
        for epoch in range(num_epochs):
            check_termination()
            # As per the SageMaker Training spec, the FIFO's path is built from
            # the channel name and the current epoch.
            fifo_path = '{0}/{1}_{2}'.format(data_dir, channel_name, epoch)
            # Usually the fifo already exists by now, but wait to be safe.
            wait_till_fifo_exists(fifo_path)
            # Unbuffered so each readinto() goes straight to the pipe.
            with open(fifo_path, 'rb', buffering=0) as fifo:
                print('opened fifo: %s' % fifo_path)
                # A real algorithm would train on the data instead of
                # reading and immediately discarding it.
                start = time.time()
                epoch_bytes_read = _drain(fifo, data)
                duration = time.time() - start
            total_read += epoch_bytes_read
            total_duration += duration
            epoch_throughput = _throughput_mb_per_sec(epoch_bytes_read, duration)
            print('Completed epoch %s; read %s bytes; time: %.2fs, throughput: %.2f MB/s'
                  % (epoch, epoch_bytes_read, duration, epoch_throughput))
        # now write a model, again, totally meaningless contents:
        with open(model_dir + '/model.json', 'w') as model:
            json.dump({
                'bytes_read': total_read,
                'duration': total_duration,
                'throughput_MB_per_sec': _throughput_mb_per_sec(total_read, total_duration),
            }, model)
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        print('Failed to train: %s' % (exc_type))
        _write_failure(failure_file, exc_type, exc_value)
        raise
def check_termination():
    """Exit the process with status 0 if a termination signal was recorded.

    Does nothing while the module-level ``terminated`` flag is False.
    """
    if not terminated:
        return
    print('Exiting due to termination request')
    sys.exit(0)
def wait_till_fifo_exists(fname):
    """Block until ``fname`` exists, polling every 100 ms.

    Polling also stops if a termination signal has been recorded. In that
    case the closing ``check_termination()`` call exits the process.
    """
    print('Wait till FIFO available: %s' % (fname))
    poll_seconds = .1
    while not (os.path.exists(fname) or terminated):
        time.sleep(poll_seconds)
    check_termination()
def touch(fname):
    """Create ``fname`` if it does not exist and update its modification time.

    Existing contents are preserved.
    """
    # Append mode creates a missing file without truncating an existing one.
    # The previous mode string 'wa' is rejected by Python 3's open() with
    # ValueError.
    with open(fname, 'a'):
        pass
    os.utime(fname, None)
def on_terminate(signum, frame):
    """Signal handler that records a termination request.

    It only prints a notice and sets the module-level ``terminated`` flag.
    The training loop and ``check_termination`` act on that flag later.
    ``signum`` and ``frame`` are unused.
    """
    global terminated
    print('caught termination signal, exiting gracefully...')
    terminated = True
def trap_signal():
    """Install ``on_terminate`` as the handler for SIGTERM, then SIGINT."""
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, on_terminate)
if __name__ == '__main__':
    # The SageMaker container contract passes a 'train' argument. This dummy
    # algorithm has a single mode, so the command-line arguments are not
    # inspected.
    main()