This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Bucketing module prediction #5008
Copy link
Copy link
Closed
Description
def load_param(self, param_path):
    """Load a saved parameter file and split it into arg and aux dicts.

    The file is read with ``mx.nd.load``. Every key is expected to look
    like ``"<prefix>:<name>"``; entries prefixed ``arg`` go into the first
    dict, entries prefixed ``aux`` into the second, and entries with any
    other prefix are ignored. Both dicts are printed before returning.

    Args:
        param_path: Path of the parameter file handed to ``mx.nd.load``.

    Returns:
        A ``(arg_param, aux_param)`` tuple of dicts keyed by the name
        that follows the first ``:`` in each saved key.

    Raises:
        ValueError: If a saved key contains no ``:`` separator.
    """
    # NOTE: the original author reports that pickle.load() on the .params
    # file did not work, which is why mx.nd.load is used here.
    node_data = mx.nd.load(param_path)

    arg_param = {}
    aux_param = {}
    by_prefix = {'arg': arg_param, 'aux': aux_param}
    for key, value in node_data.items():
        # Split on the first ':' only, so names may themselves contain ':'.
        prefix, name = key.split(':', 1)
        target = by_prefix.get(prefix)
        if target is not None:
            target[name] = value

    print(arg_param)
    print(aux_param)
    return arg_param, aux_param
def __init_ocr(self):
    """Create the bucketing predictor and initialise it from saved params.

    Builds an ``mx.module.BucketingModule`` around ``sym_gen`` on GPU 0,
    loads the saved parameters from ``self.path_of_params`` via
    ``self.load_param``, binds the module for inference
    (``for_training=False``) using ``self.data_shapes`` and
    ``self.label_shapes``, and then initialises it with the loaded
    parameters. The module is stored on ``self.predictor``.
    """
    self.predictor = mx.module.BucketingModule(
        sym_gen,
        default_bucket_key=self.default_bucket_key,
        context=mx.context.gpu(0),
    )
    arg_param, aux_param = self.load_param(self.path_of_params)
    # The module is bound before init_params is called with the loaded values.
    self.predictor.bind(self.data_shapes, self.label_shapes, for_training=False)
    self.predictor.init_params(arg_params=arg_param, aux_params=aux_param)
def predict(self, data_batch):
    """Run one inference-mode forward pass and return the module outputs.

    Args:
        data_batch: Batch object passed straight to
            ``self.predictor.forward``.

    Returns:
        Whatever ``self.predictor.get_outputs()`` returns after the
        forward pass.
    """
    self.predictor.forward(data_batch, is_train=False)
    return self.predictor.get_outputs()
# sym_gen
def sym_gen(bucket_key):
    """Build the symbol for one bucket.

    Args:
        bucket_key: String of the form ``"<label_length>,<lstm_length>"``;
            both parts are converted with ``int``.

    Returns:
        A ``(symbol, data_name, label_name)`` tuple, where ``symbol`` is
        the result of ``get_symbol(num_classes, num_hidden, lstm_length,
        label_length)``.

    Raises:
        ValueError: If ``bucket_key`` does not split into exactly two
            comma-separated parts, or a part is not an integer.
    """
    # NOTE(review): num_classes, num_hidden, data_name and label_name are
    # not defined in this snippet -- presumably module-level; confirm.
    label_length, lstm_length = bucket_key.split(',')
    lstm_length = int(lstm_length)
    label_length = int(label_length)
    symbol = get_symbol(num_classes, num_hidden, lstm_length, label_length)
    return symbol, data_name, label_name


# I tried to load my bucketing ocr using bucketing module with saved model.
# The prediction predicts nothing but blank in ctc for all examples even if
# I use training data for testing.
Training itself went fine: the model converged and was saved.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels