2
2
3
3
import sys
4
4
import os
5
- import pickle
5
+ from dict_pb2 import Dict as DictPB
6
6
import tensorflow as tf
7
7
8
8
class Data :
@@ -11,12 +11,16 @@ def __init__(self, dict_file, continuous_fields, sparse_fields, linear_fields):
11
11
self .ParseFields (continuous_fields , sparse_fields , linear_fields )
12
12
13
13
# load fieldid and its featureid dict
14
- # field : {featureid : sortid, featureid : sortid, 'miss' : sortid, 'num': feature_num}
15
14
def LoadDict (self , dict_file ):
16
- self .field_feature_dict = pickle .load (open (dict_file , 'rb' ))
17
- print ('load field num: ' + str (len (self .field_feature_dict )))
18
- for fieldid in self .field_feature_dict :
19
- print ('field: ' + str (fieldid ) + ' feature num: ' + str (self .field_feature_dict [fieldid ]['num' ]))
15
+ f = open (dict_file , 'rb' )
16
+ dict = DictPB ()
17
+ dict .ParseFromString (f .read ())
18
+ self .field_feature_dict = dict
19
+ f .close ()
20
+ print ('load field num: ' + str (len (self .field_feature_dict .field2missid )))
21
+ for fieldid , feanum in self .field_feature_dict .field2feanum .items ():
22
+ print ('field: ' + str (fieldid ) + ' feature num: ' + str (feanum ))
23
+
20
24
21
25
def Dict (self ):
22
26
assert self .field_feature_dict is not None
@@ -67,10 +71,10 @@ def StringToRecord(self, input_file, output_file):
67
71
if fieldid in field2feature :
68
72
for featureid in field2feature [fieldid ]:
69
73
value = field2feature [fieldid ][featureid ]
70
- feature_id_list .append (self .field_feature_dict [ fieldid ] [featureid ])
74
+ feature_id_list .append (self .field_feature_dict . featureid2sortid [featureid ])
71
75
feature_val_list .append (value )
72
76
else :
73
- feature_id_list .append (self .field_dict [fieldid ][ 'miss' ])
77
+ feature_id_list .append (self .field_feature_dict . field2missid [fieldid ])
74
78
feature_val_list .append (0.0 )
75
79
feature ['sparse_id_in_field_' + str (fieldid )] = tf .train .Feature (int64_list = tf .train .Int64List (value = feature_id_list ))
76
80
feature ['sparse_val_in_field_' + str (fieldid )] = tf .train .Feature (float_list = tf .train .FloatList (value = feature_val_list ))
@@ -80,10 +84,10 @@ def StringToRecord(self, input_file, output_file):
80
84
if fieldid in field2feature :
81
85
for featureid in field2feature [fieldid ]:
82
86
value = field2feature [fieldid ][featureid ]
83
- feature_id_list .append (self .field_feature_dict [ fieldid ] [featureid ])
87
+ feature_id_list .append (self .field_feature_dict . featureid2sortid [featureid ])
84
88
feature_val_list .append (value )
85
89
else :
86
- feature_id_list .append (self .field_dict [fieldid ][ 'miss' ])
90
+ feature_id_list .append (self .field_feature_dict . field2missid [fieldid ])
87
91
feature_val_list .append (0.0 )
88
92
feature ['linear_id_in_field_' + str (fieldid )] = tf .train .Feature (int64_list = tf .train .Int64List (value = feature_id_list ))
89
93
feature ['linear_val_in_field_' + str (fieldid )] = tf .train .Feature (float_list = tf .train .FloatList (value = feature_val_list ))
@@ -101,14 +105,17 @@ def StringToRecord(self, input_file, output_file):
101
105
writer .write (example .SerializeToString ())
102
106
103
107
writer .close ()
104
- print ('Successfully convert {} to {}' .format (input_file , output_file ))
108
+ print ('Successfully convert {} to {}' .format (input_file , output_file ))
105
109
106
110
def Decode (self , filename_queue ):
107
111
reader = tf .TFRecordReader ()
108
112
_ , serialized_example = reader .read (filename_queue )
109
113
return serialized_example
110
114
111
115
def ReadBatch (self , file_name , max_epoch , batch_size , thread_num , min_after_dequeue ):
116
+ '''
117
+ Return Tensor and SparseTensor parsed from tfrecord
118
+ '''
112
119
with tf .name_scope ('input' ):
113
120
filename_queue = tf .train .string_input_producer (
114
121
tf .train .match_filenames_once (file_name ), num_epochs = max_epoch )
@@ -144,6 +151,45 @@ def ReadBatch(self, file_name, max_epoch, batch_size, thread_num, min_after_dequ
144
151
continuous_val = instance ['continuous_val' ]
145
152
return instance ['label' ], sparse_id , sparse_val , linear_id , linear_val , continuous_val
146
153
154
+ def ReadBatchPlaceholder (self ):
155
+ '''
156
+ Return placeholder
157
+ '''
158
+ with tf .name_scope ('input' ):
159
+ sparse_id = []
160
+ sparse_val = []
161
+ for fieldid in self .sparse_field :
162
+ with tf .variable_scope ('sparse_' + str (fieldid )):
163
+ with tf .variable_scope ('index' ):
164
+ self .sparse_index = tf .placeholder (tf .int64 )
165
+ with tf .variable_scope ('id' ):
166
+ self .sparse_ids = tf .placeholder (tf .int64 )
167
+ with tf .variable_scope ('value' ):
168
+ self .sparse_vals = tf .placeholder (tf .float32 )
169
+ with tf .variable_scope ('shape' ):
170
+ self .sparse_shape = tf .placeholder (tf .int64 )
171
+ sparse_id .append (tf .SparseTensor (self .sparse_index , self .sparse_ids , self .sparse_shape ))
172
+ sparse_val .append (tf .SparseTensor (self .sparse_index , self .sparse_vals , self .sparse_shape ))
173
+ linear_id = []
174
+ linear_val = []
175
+ for fieldid in self .linear_field :
176
+ with tf .variable_scope ('linear_' + str (fieldid )):
177
+ with tf .variable_scope ('index' ):
178
+ self .linear_index = tf .placeholder (tf .int64 )
179
+ with tf .variable_scope ('id' ):
180
+ self .linear_ids = tf .placeholder (tf .int64 )
181
+ with tf .variable_scope ('value' ):
182
+ self .linear_vals = tf .placeholder (tf .float32 )
183
+ with tf .variable_scope ('shape' ):
184
+ self .linear_shape = tf .placeholder (tf .int64 )
185
+ linear_id .append (tf .SparseTensor (self .linear_index , self .linear_ids , self .linear_shape ))
186
+ linear_val .append (tf .SparseTensor (self .linear_index , self .linear_vals , self .linear_shape ))
187
+ with tf .variable_scope ('label' ):
188
+ self .label = tf .placeholder (tf .float32 )
189
+ with tf .variable_scope ('continuous' ):
190
+ self .continuous_val = tf .placeholder (tf .float32 )
191
+ return self .label , sparse_id , sparse_val , linear_id , linear_val , self .continuous_val
192
+
147
193
if __name__ == '__main__' :
148
194
if len (sys .argv ) != 7 :
149
195
print ('''
@@ -157,4 +203,4 @@ def ReadBatch(self, file_name, max_epoch, batch_size, thread_num, min_after_dequ
157
203
''' )
158
204
exit (1 )
159
205
data = Data (sys .argv [1 ], sys .argv [2 ], sys .argv [3 ], sys .argv [4 ])
160
- data .StringToRecord (sys .argv [5 ], sys .argv [6 ])
206
+ data .StringToRecord (sys .argv [5 ], sys .argv [6 ])
0 commit comments