# data_manager.py — 181 lines (140 loc), 5.54 KB.
# Source repository was archived by its owner on Jun 11, 2020 and is now read-only.
import re
import os
import numpy as np
import tensorflow as tf
from multiprocessing import Queue, Process
from utils import sparse_tuple_from, resize_image, label_to_array
from scipy.misc import imread
from trdg.generators import GeneratorFromDict
class DataManager(object):
    """Supplies training and testing batches for the OCR model.

    Data comes either from a folder of labelled example images (the label is
    the part of the filename before the first ``_``) or, when ``use_trdg`` is
    set, from on-the-fly synthetic text images produced by TRDG.

    Every batch is a tuple ``(batch_y, batch_dt, batch_x)`` where:
      - ``batch_y``  : flat array of ground-truth label strings,
      - ``batch_dt`` : sparse tuple (via ``sparse_tuple_from``) of the
                       integer-encoded labels, as expected by CTC loss,
      - ``batch_x``  : images shaped ``(batch, max_image_width, 32, 1)``,
                       scaled into [0, 1].
    """

    # Upper bound on the number of batches buffered by the TRDG worker
    # processes; keeps memory bounded while training consumes the queue.
    QUEUE_CAPACITY = 20

    def __init__(
        self,
        batch_size,
        model_path,
        examples_path,
        max_image_width,
        train_test_ratio,
        max_char_count,
        char_vector,
        use_trdg,
        language,
    ):
        """Configure the manager and eagerly build the batch sources.

        Raises:
            Exception: if ``train_test_ratio`` is outside [0, 1].
        """
        if train_test_ratio > 1.0 or train_test_ratio < 0:
            raise Exception("Incoherent ratio!")

        self.char_vector = char_vector
        self.train_test_ratio = train_test_ratio
        self.max_image_width = max_image_width
        self.batch_size = batch_size
        self.model_path = model_path
        self.current_train_offset = 0
        self.examples_path = examples_path
        self.max_char_count = max_char_count
        self.use_trdg = use_trdg
        self.language = language

        if self.use_trdg:
            # Synthetic data: two independent generator pipelines.
            self.train_batches = self.multiprocess_batch_generator()
            self.test_batches = self.multiprocess_batch_generator()
        else:
            # On-disk data: first train_test_ratio of the examples is the
            # training split, the remainder is the test split.
            self.data, self.data_len = self.load_data()
            self.test_offset = int(train_test_ratio * self.data_len)
            self.current_test_offset = self.test_offset
            self.train_batches = self.generate_all_train_batches()
            self.test_batches = self.generate_all_test_batches()

    def _prepare_batch(self, batch):
        """Turn a list of ``(image, label, encoded_label)`` triples into a
        network-ready ``(batch_y, batch_dt, batch_x)`` tuple.

        Images are assumed to be ``(32, max_image_width)`` grayscale arrays
        (as produced by ``resize_image``); they are transposed to
        width-major, normalized to [0, 1] and given a trailing channel axis.
        """
        raw_batch_x, raw_batch_y, raw_batch_la = zip(*batch)

        batch_y = np.reshape(np.array(raw_batch_y), (-1))
        batch_dt = sparse_tuple_from(np.reshape(np.array(raw_batch_la), (-1)))

        raw_batch_x = np.swapaxes(raw_batch_x, 1, 2)
        raw_batch_x = raw_batch_x / 255.0
        batch_x = np.reshape(
            np.array(raw_batch_x), (len(raw_batch_x), self.max_image_width, 32, 1)
        )
        return batch_y, batch_dt, batch_x

    def batch_generator(self, queue):
        """Worker loop: endlessly build TRDG batches and enqueue them.

        Runs in a child process. The queue is bounded (see
        ``multiprocess_batch_generator``), so ``put`` blocks when the
        consumer lags instead of generating batches that would be thrown
        away. Never returns.
        """
        generator = GeneratorFromDict(language=self.language)
        while True:
            batch = []
            while len(batch) < self.batch_size:
                # next() (iterator protocol) instead of the Py2-style .next()
                img, lbl = next(generator)
                batch.append(
                    (
                        resize_image(np.array(img.convert("L")), self.max_image_width)[
                            0
                        ],
                        lbl,
                        label_to_array(lbl, self.char_vector),
                    )
                )
            # Blocks while the queue is at capacity; no work is discarded.
            queue.put(self._prepare_batch(batch))

    def multiprocess_batch_generator(self):
        """Yield batches produced by two background worker processes.

        The queue's maxsize provides back-pressure so the workers cannot
        outrun the consumer. Workers are daemons so they do not keep the
        interpreter alive after the main process exits.
        """
        q = Queue(maxsize=self.QUEUE_CAPACITY)
        processes = []
        for _ in range(2):
            processes.append(
                Process(target=self.batch_generator, args=(q,), daemon=True)
            )
            processes[-1].start()
        while True:
            yield q.get()

    def load_data(self):
        """Load every labelled image in ``examples_path``.

        Files whose label (filename up to the first ``_``) exceeds
        ``max_char_count`` are skipped. Returns ``(examples, count)`` where
        each example is ``(image_array, label, encoded_label)``.
        """
        print("Loading data")

        examples = []
        skipped = 0
        for f in os.listdir(self.examples_path):
            label = f.split("_")[0]
            if len(label) > self.max_char_count:
                skipped += 1
                continue
            # NOTE(review): scipy.misc.imread was removed in SciPy >= 1.2;
            # this call requires an older SciPy (or a port to imageio/PIL).
            arr, initial_len = resize_image(
                imread(os.path.join(self.examples_path, f), mode="L"),
                self.max_image_width,
            )
            examples.append(
                (
                    arr,
                    label,
                    label_to_array(label, self.char_vector),
                )
            )
        print("Loaded {} examples ({} skipped)".format(len(examples), skipped))
        return examples, len(examples)

    def generate_all_train_batches(self):
        """Build and return every full training batch (remainder dropped)."""
        train_batches = []
        # Stop before a batch would cross into the test split.
        while not self.current_train_offset + self.batch_size > self.test_offset:
            old_offset = self.current_train_offset
            new_offset = self.current_train_offset + self.batch_size
            self.current_train_offset = new_offset
            train_batches.append(self._prepare_batch(self.data[old_offset:new_offset]))
        return train_batches

    def generate_all_test_batches(self):
        """Build and return every full test batch (remainder dropped)."""
        test_batches = []
        # Stop before a batch would run past the end of the data.
        while not self.current_test_offset + self.batch_size > self.data_len:
            old_offset = self.current_test_offset
            new_offset = self.current_test_offset + self.batch_size
            self.current_test_offset = new_offset
            test_batches.append(self._prepare_batch(self.data[old_offset:new_offset]))
        return test_batches