-
Notifications
You must be signed in to change notification settings - Fork 8
/
train.py
236 lines (182 loc) · 7.03 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
(Bayesian) FlowNet training module in Tensorflow
"""
import os
from os.path import dirname
import argparse
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.platform import flags
from tensorflow.python.training import saver as tf_saver
import flownet
import flownet_tools
import architectures
dir_path = dirname(os.path.realpath(__file__))
# Basic model parameters as external flags.
FLAGS = flags.FLAGS
# ARCHITECTURE
# Network constructor used by main(): called as model(imgs_0, imgs_1, flows).
model = architectures.flownet_s
# DATA
# NOTE(review): the shape flags below pass list defaults to DEFINE_integer.
# This relies on the older argparse-based tf.flags, which stores defaults
# without validating them; these flags cannot be overridden from the command
# line, and absl-based flags would presumably reject the list defaults.
# Shapes are [height, width, channels] (presumably FlyingChairs frames,
# 512 wide x 384 high).
flags.DEFINE_integer('d_shape_img', [384, 512, 3],
                     'Data shape: height, width, channels')
flags.DEFINE_integer('d_shape_flow', [384, 512, 2],
                     'Data shape: height, width, channels')
flags.DEFINE_integer('img_shape', [384, 512, 3],
                     'Image shape: height, width, channels')
flags.DEFINE_integer('flow_shape', [384, 512, 2],
                     'Flow shape: height, width, 2')
# 1572876 = 384 * 512 * 2 float32 values (1572864 bytes) + 12 bytes; presumably
# one .flo file (12-byte header + flow field) per record -- confirm against the
# reader in flownet_tools.
flags.DEFINE_integer('record_bytes', 1572876,
                     'Flow record bytes reader for FlyingChairs')
# HYPERPARAMETER
flags.DEFINE_integer('batchsize', 8, 'Batch size.')
flags.DEFINE_integer('max_steps', 800000,
                     'Number of training steps.')
# Learning-rate schedule derived from max_steps (read here, at import time).
# With the default 800000 steps: boundaries 300k, 400k, ..., 700k and six
# values 1e-4, 5e-5, ..., i.e. len(values) == len(boundaries) + 1 as a
# piecewise-constant schedule presumably expects.
# Floor division keeps the range() bounds integral: true division yields a
# float under Python 3 and makes range() raise TypeError.
# NOTE(review): max_steps < 300000 yields an empty `values` list.
_lr_stages = FLAGS.max_steps // 100000
flags.DEFINE_integer('boundaries', [i * 100000 for i in range(3, _lr_stages)],
                     'boundaries for learning rate')
flags.DEFINE_integer('values', [1e-4 / (2**i) for i in range(0, _lr_stages - 2)],
                     'learning rate values')
# Float-valued hyperparameters use DEFINE_float; with DEFINE_integer an
# override such as --learning_rate=5e-5 would be parsed with int() and fail.
flags.DEFINE_float('learning_rate', 1e-4, 'learning rate values')
flags.DEFINE_float('drop_rate', 0.5, 'Dropout rate')
flags.DEFINE_boolean('batch_normalization', False, 'Batch normalization on/off')
flags.DEFINE_boolean('is_training', True, 'Training mode on/off')
# TRAINING
flags.DEFINE_integer('img_summary_num', 1, 'Number of images in summary')
flags.DEFINE_integer('max_checkpoints', 5,
                     'Maximum number of recent checkpoints to keep.')
flags.DEFINE_float('keep_checkpoint_every_n_hours', 5.0,
                   'How often checkpoints should be kept.')
flags.DEFINE_integer('save_summaries_secs', 60,
                     'How often should summaries be saved (in seconds).')
flags.DEFINE_integer('save_interval_secs', 300,
                     'How often should checkpoints be saved (in seconds).')
flags.DEFINE_integer('log_every_n_steps', 100,
                     'Logging interval for slim training loop.')
flags.DEFINE_integer('trace_every_n_steps', 1000,
                     'Logging interval for trace.')
def apply_augmentation(imgs_0, imgs_1, flows):
    """Optionally augment a batch of image pairs together with their flows.

    Augmentation runs only when ``FLAGS.augmentation`` is truthy; otherwise
    the three inputs are handed back untouched. At present only the geometric
    step (rotation / scaling / cropping via ``flownet.rotation_crop_trans``)
    is active. Chromatic augmentation (``flownet.fast_chromatic_augm``) and
    the post-augmentation image summary are disabled.

    Keyword arguments:
    imgs_0 -- first image of each image pair (one entry per batch element)
    imgs_1 -- second image of each image pair (one entry per batch element)
    flows -- ground truth optical flows between imgs_0 and imgs_1

    Returns the tuple (imgs_0, imgs_1, flows), augmented if enabled.
    """
    if not FLAGS.augmentation:
        return imgs_0, imgs_1, flows

    with tf.name_scope('Augmentation'):
        # The geometric transform takes both images and the flow at once and
        # returns all three, so the whole training sample is transformed jointly.
        warped_0, warped_1, warped_flows = flownet.rotation_crop_trans(
            imgs_0, imgs_1, flows)
    return warped_0, warped_1, warped_flows
def main(_):
    """Build the FlowNet training graph and run the TF-Slim training loop.

    The single positional argument is the leftover argv passed in by
    ``tf.app.run`` and is not used.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Input pipeline: image pairs and ground-truth flows from FLAGS.datadir.
        first_imgs, second_imgs, gt_flows = flownet_tools.get_data(
            FLAGS.datadir, True)

        # Augmentation is a no-op unless FLAGS.augmentation is set.
        first_imgs, second_imgs, gt_flows = apply_augmentation(
            first_imgs, second_imgs, gt_flows)

        # Forward pass through the selected architecture, plus a summary of
        # the predicted flows.
        predicted_flows = model(first_imgs, second_imgs, gt_flows)
        flownet.image_summary(None, None, "E_result", predicted_flows)

        step = slim.get_or_create_global_step()
        train_op = flownet.create_train_op(step)

        # Let TensorFlow grow GPU memory on demand instead of reserving it all.
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        checkpoint_saver = tf_saver.Saver(
            max_to_keep=FLAGS.max_checkpoints,
            keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

        train_logdir = FLAGS.logdir + '/train'
        merged_summaries = tf.summary.merge_all()

        slim.learning.train(
            train_op,
            logdir=train_logdir,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            summary_op=merged_summaries,
            log_every_n_steps=FLAGS.log_every_n_steps,
            trace_every_n_steps=FLAGS.trace_every_n_steps,
            session_config=session_config,
            saver=checkpoint_saver,
            number_of_steps=FLAGS.max_steps,
        )
def str2bool(value):
    """Convert a yes/no style command-line string to a bool.

    Accepts 'yes'/'true' (-> True) and 'no'/'false' (-> False),
    case-insensitively. Values that are already booleans are returned
    unchanged, so non-string argparse defaults also work.

    Raises:
        argparse.ArgumentTypeError: for any other value. When this function
        is used as an argparse ``type=`` converter, argparse reports that as a
        usage error instead of a raw traceback.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true'):
        return True
    if lowered in ('no', 'false'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


if __name__ == "__main__":
    # get arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--datadir',
        type=str,
        default='data/flying/train/',
        help='Directory containing the input data.'
    )
    parser.add_argument(
        '--logdir',
        type=str,
        default='with_data_aug',
        help='Directory where to write event logs and checkpoints'
    )
    # Boolean switches: string defaults are run through `type` by argparse,
    # so each of these ends up as a real bool in `args`.
    parser.add_argument(
        '--dropout',
        type=str2bool,
        default='true',
        help='Turn dropout on/off'
    )
    parser.add_argument(
        '--imgsummary',
        type=str2bool,
        default="true",
        help='Make image summary'
    )
    parser.add_argument(
        '--augmentation',
        type=str2bool,
        default="true",
        help='Make data augmentation'
    )
    parser.add_argument(
        '--weights_reg',
        type=float,
        default=1e-4,
        help='weights regularizer'
    )
    # parse_known_args: arguments this parser does not know (e.g. flags defined
    # via tf.flags, parsed later by tf.app.run) are not rejected here.
    args = parser.parse_known_args()[0]
    FLAGS.datadir = os.path.join(dir_path, args.datadir)
    FLAGS.logdir = os.path.join(dir_path, args.logdir)

    FLAGS.augmentation = args.augmentation
    if args.augmentation:
        print("Data augmentation on")
    FLAGS.dropout = args.dropout
    if args.dropout:
        print("Dropout on")
    FLAGS.imgsummary = args.imgsummary

    # turn weight decay on/off
    # NOTE(review): this is reported as "weight decay" but builds an L1
    # regularizer; weight decay is conventionally L2 -- confirm intent.
    if args.weights_reg != 0:
        print("Weight decay with: " + str(args.weights_reg))
        FLAGS.weights_reg = slim.l1_regularizer(args.weights_reg)
    else:
        FLAGS.weights_reg = None
    print("Using architecture: " + model.__name__)
    tf.app.run()