-
Notifications
You must be signed in to change notification settings - Fork 0
/
eye_input.py
168 lines (113 loc) · 5.27 KB
/
eye_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
# Side length, in pixels, of the square image stored in each record.
IMAGE_SIZE = 200
# Side length, in pixels, of the square label map stored in each record.
IMAGE_SIZE_label = 51
# Examples per epoch for the evaluation split (used by inputs()).
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 1000

# Short aliases consumed by the record readers below.
HEIGHT = WIDTH = IMAGE_SIZE
HEIGHT_label = WIDTH_label = IMAGE_SIZE_label
NUMEVAL_SAMPLES = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
# Lower-case aliases; not referenced elsewhere in this module.
height = width = HEIGHT
def read_fiber10(filename_queue):
    """Read and parse one fixed-length fiber record from ``filename_queue``.

    Each record is the label bytes immediately followed by the image bytes.
    Both parts are stored depth-major (channel, row, column) as uint8. They
    are transposed here to (row, column, channel).

    Args:
        filename_queue: queue of file names to read from (as produced by
            ``tf.train.string_input_producer``).

    Returns:
        An object with attributes:
            height, width, depth: dimensions of the image part.
            height_label, width_label, depth_label: dimensions of the label.
            key: the record key returned by ``reader.read``.
            uint8image: uint8 tensor of shape [height, width, depth].
            uint8label: uint8 tensor of shape
                [height_label, width_label, depth_label].
    """
    class FiberRecord(object):
        pass

    result = FiberRecord()
    result.height = HEIGHT
    result.width = WIDTH
    result.depth = 2
    result.height_label = HEIGHT_label
    result.width_label = WIDTH_label
    result.depth_label = 2

    label_bytes = result.height_label * result.width_label * result.depth_label
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes

    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)

    # Convert the raw string into a uint8 vector of record_bytes length once;
    # label and image are both slices of the same vector.
    record = tf.decode_raw(value, tf.uint8)

    # Label occupies the first label_bytes bytes, stored depth-major.
    label = tf.reshape(
        tf.slice(record, [0], [label_bytes]),
        [result.depth_label, result.height_label, result.width_label])
    result.uint8label = tf.transpose(label, [1, 2, 0])

    # Image follows the label, also stored depth-major.
    image = tf.reshape(
        tf.slice(record, [label_bytes], [image_bytes]),
        [result.depth, result.height, result.width])
    result.uint8image = tf.transpose(image, [1, 2, 0])

    return result
def _generate_image_and_label_batch(image0, label0, min_queue_examples,
                                    batch_size, shuffle):
    """Group single (image, label) examples into batches through a TF queue.

    Args:
        image0: a single image tensor.
        label0: the matching single label tensor.
        min_queue_examples: minimum number of examples buffered in the queue.
            It is the base of the queue capacity in both modes and is used as
            ``min_after_dequeue`` when shuffling.
        batch_size: number of examples per batch.
        shuffle: if true, batch with ``tf.train.shuffle_batch`` using 16
            threads; otherwise use ``tf.train.batch`` with a single thread.

    Returns:
        ``(images, labels)`` batch tensors.
    """
    queue_capacity = min_queue_examples + 3 * batch_size

    if not shuffle:
        # A single enqueuing thread keeps examples in the order they are read.
        images, labels = tf.train.batch(
            [image0, label0],
            batch_size=batch_size,
            num_threads=1,
            capacity=queue_capacity)
        return images, labels

    images, labels = tf.train.shuffle_batch(
        [image0, label0],
        batch_size=batch_size,
        num_threads=16,
        capacity=queue_capacity,
        min_after_dequeue=min_queue_examples)
    return images, labels
def distorted_inputs(data_dir, batch_size, NUMTRAIN_SAMPLES):
    """Build a shuffled training input pipeline from fiber training records.

    Despite the name, no augmentation is applied. Images and labels are only
    converted to float32 with ``tf.image.convert_image_dtype`` before batching.

    Args:
        data_dir: directory containing ``fiber_train_data_3.bin``.
        batch_size: number of examples per batch.
        NUMTRAIN_SAMPLES: number of training examples per epoch. 40% of this
            count is kept buffered in the shuffle queue.

    Returns:
        ``(images, labels)`` float32 batch tensors with per-example shapes
        [HEIGHT, WIDTH, depth] and [HEIGHT_label, WIDTH_label, depth_label].

    Raises:
        ValueError: if a training file does not exist.
    """
    # Only shard 3 is used; widen the range to read more shards.
    filenames = [os.path.join(data_dir, 'fiber_train_data_%d.bin' % i)
                 for i in range(3, 4)]
    print(filenames)
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    filename_queue = tf.train.string_input_producer(filenames)

    with tf.name_scope('data_augmentation'):
        read_input = read_fiber10(filename_queue)
        reshaped_image = tf.image.convert_image_dtype(
            read_input.uint8image, tf.float32, saturate=False)
        reshaped_label = tf.image.convert_image_dtype(
            read_input.uint8label, tf.float32, saturate=False)
        reshaped_image.set_shape([HEIGHT, WIDTH, read_input.depth])
        # The label has its own depth; do not reuse the image depth here.
        reshaped_label.set_shape(
            [HEIGHT_label, WIDTH_label, read_input.depth_label])

        # Keep enough examples buffered for the shuffle to mix well.
        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(NUMTRAIN_SAMPLES *
                                 min_fraction_of_examples_in_queue)
        print ('Filling queue with %d Fiber images before starting to train. '
               'This will take a few minutes.' % min_queue_examples)

    return _generate_image_and_label_batch(reshaped_image, reshaped_label,
                                           min_queue_examples, batch_size,
                                           shuffle=True)
def inputs(eval_data, data_dir, batch_size, NUMTRAIN_SAMPLES=None):
    """Build an in-order (unshuffled) input pipeline for training or eval data.

    Args:
        eval_data: if falsy, read ``fiber_train_data_5.bin``; otherwise read
            ``fiber_test_data_3.bin``.
        data_dir: directory containing the ``.bin`` record files.
        batch_size: number of examples per batch.
        NUMTRAIN_SAMPLES: number of training examples per epoch. Required
            when ``eval_data`` is falsy, unless a module-level
            ``NUMTRAIN_SAMPLES`` has been set on this module. Ignored for
            eval data, which uses ``NUMEVAL_SAMPLES``.

    Returns:
        ``(images, labels)`` float32 batch tensors with per-example shapes
        [HEIGHT, WIDTH, depth] and [HEIGHT_label, WIDTH_label, depth_label].

    Raises:
        ValueError: if the training example count is unknown, or if a data
            file does not exist.
    """
    if not eval_data:
        # Only shard 5 is used; widen the range to read more shards.
        filenames = [os.path.join(data_dir, 'fiber_train_data_%d.bin' % i)
                     for i in range(5, 6)]
        num_examples_per_epoch = NUMTRAIN_SAMPLES
        if num_examples_per_epoch is None:
            # Backward compatibility: honor a module-level NUMTRAIN_SAMPLES
            # if a caller has assigned one on this module.
            num_examples_per_epoch = globals().get('NUMTRAIN_SAMPLES')
        if num_examples_per_epoch is None:
            raise ValueError(
                'NUMTRAIN_SAMPLES must be provided when eval_data is False.')
    else:
        filenames = [os.path.join(data_dir, 'fiber_test_data_3.bin')]
        num_examples_per_epoch = NUMEVAL_SAMPLES

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    with tf.name_scope('input'):
        # shuffle=False keeps the file order deterministic.
        filename_queue = tf.train.string_input_producer(filenames,
                                                        shuffle=False)
        read_input = read_fiber10(filename_queue)
        reshaped_image = tf.image.convert_image_dtype(
            read_input.uint8image, tf.float32, saturate=False)
        reshaped_label = tf.image.convert_image_dtype(
            read_input.uint8label, tf.float32, saturate=False)
        reshaped_image.set_shape([HEIGHT, WIDTH, read_input.depth])
        # The label has its own depth; do not reuse the image depth here.
        reshaped_label.set_shape(
            [HEIGHT_label, WIDTH_label, read_input.depth_label])

        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(num_examples_per_epoch *
                                 min_fraction_of_examples_in_queue)

    return _generate_image_and_label_batch(reshaped_image, reshaped_label,
                                           min_queue_examples, batch_size,
                                           shuffle=False)