diff --git a/doc/design/evaluator.md b/doc/design/evaluator.md new file mode 100644 index 0000000000000..a62d75ffef149 --- /dev/null +++ b/doc/design/evaluator.md @@ -0,0 +1,58 @@ +## Evaluator Design + +### The Problem + +During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted. + +### Evaluator Design +Currently, every operation is expressed in the graph. we divide the evaluator process into three steps. + +1. Initialize the metric state and add it into the block. + +2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once. + + +3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices. + +### Implementation +This design is shown in python API. +Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass. + + +```python +class Evaluator(object): + """ + Evaluator Base class. + """ + def __init__(self, name, **kwargs): + """ + Different evaluator may has different metric states. E.g, Accuracy need two variables, total and right sample counts. + Auc need four variables, `true_positives`, + `true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append to main_program + + The initialization of Evaluator should be responsible for: + create metric states and append to the main_program + """ + pass + + def _update_ops(self, input, label, **kwargs) + """ + Add mini-batch evaluator caculate operators to the main_program. + Add increment operator to accumulate the metric states. + """ + + + def reset(self, executor, reset_program=None): + """ + Reset metric states at the begin of each pass/user specified batch number. + Execute the reset_program to reset the states. + """ + + + def eval(self, executor, eval_program=None): + """ + Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. + Execute the eval_program and return the result. + """ + return eval_result +``` diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 03c2fa945d94a..2785a8c6fb625 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -30,6 +30,10 @@ class AccuracyOp : public framework::OperatorWithKernel { "Input (Label) of accuracy op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Accuracy"), "Output (Accuracy) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Correct"), + "Output (Correct) of AccuracyOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Total"), + "Output (Total) of AccuracyOp should not be null."); auto inference_dim = ctx->GetInputDim("Out"); auto label_dim = ctx->GetInputDim("Label"); @@ -43,6 +47,8 @@ class AccuracyOp : public framework::OperatorWithKernel { " the same as label."); ctx->SetOutputDim("Accuracy", {1}); + ctx->SetOutputDim("Correct", {1}); + ctx->SetOutputDim("Total", {1}); ctx->ShareLoD("Out", /*->*/ "Accuracy"); } @@ -66,6 +72,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Label", "Label of the training data"); // TODO(typhoonzero): AddInput("Weight", ... AddOutput("Accuracy", "The accuracy of current batch"); + AddOutput("Correct", "The correct samples count of current batch"); + AddOutput("Total", "The samples count of current batch"); AddComment(R"DOC( Accuracy Operator. diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 1776f33105367..b575c682f0d30 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS; template __global__ void AccuracyCudaKernel(const int N, const int D, const int64_t* Xdata, - const int64_t* labeldata, float* accuracy) { + const int64_t* labeldata, int* correct_data, + float* accuracy) { int count = 0; __shared__ int total[BlockSize]; @@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, // reduce the count with init value 0, and output accuracy. int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); if (threadIdx.x == 0) { + *correct_data = result; *accuracy = static_cast(result) / static_cast(N); } } @@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { auto* inference = ctx.Input("Out"); auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); + auto* accuracy = ctx.Output("Accuracy"); + auto* correct = ctx.Output("Correct"); + auto* total = ctx.Output("Total"); // FIXME(typhoonzero): only support indices currently // if add support for output values, how to detect the data type? const int64_t* indices_data = indices->data(); const int64_t* label_data = label->data(); + + int* correct_data = correct->mutable_data(ctx.GetPlace()); + int* total_data = total->mutable_data(ctx.GetPlace()); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); - size_t num_samples = inference->dims()[0]; + int num_samples = static_cast(inference->dims()[0]); size_t infer_width = inference->dims()[1]; PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float))); + // cudaMemset((void**)&correct_data, 0, sizeof(float)); if (num_samples == 0) { return; } + cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice); AccuracyCudaKernel<<< 1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>( - num_samples, infer_width, indices_data, label_data, accuracy_data); + num_samples, infer_width, indices_data, label_data, correct_data, + accuracy_data); + + int d_num_samples, d_num_correct; + float d_accuracy; + cudaMemcpy(&d_num_correct, correct_data, sizeof(int), + cudaMemcpyDeviceToHost); + cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost); + cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float), + cudaMemcpyDeviceToHost); } }; } // namespace operators } // namespace paddle -// FIXME(typhoonzero): types of T is for infernece data. -// label data is always int +// FIXME(typhoonzero): types of T is for inference data. +// label data is always int64 REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel, paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h index 28dbc77f64842..d060e6edddb31 100644 --- a/paddle/operators/accuracy_op.h +++ b/paddle/operators/accuracy_op.h @@ -29,7 +29,11 @@ class AccuracyKernel : public framework::OpKernel { auto* indices = ctx.Input("Indices"); auto* label = ctx.Input("Label"); auto* accuracy = ctx.Output("Accuracy"); + auto* correct = ctx.Output("Correct"); + auto* total = ctx.Output("Total"); + int* correct_data = correct->mutable_data(ctx.GetPlace()); + int* total_data = total->mutable_data(ctx.GetPlace()); float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); const int64_t* indices_data = indices->data(); @@ -55,7 +59,8 @@ class AccuracyKernel : public framework::OpKernel { } } - // FIXME(typhoonzero): we don't accumulate the accuracy for now. + *correct_data = num_correct; + *total_data = num_samples; *accuracy_data = static_cast(num_correct) / static_cast(num_samples); } diff --git a/paddle/operators/elementwise_add_op.cc b/paddle/operators/elementwise_add_op.cc index ebe1de90c7d24..432b9ba6f72f8 100644 --- a/paddle/operators/elementwise_add_op.cc +++ b/paddle/operators/elementwise_add_op.cc @@ -34,7 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker, elementwise_add_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_add, - ops::ElementwiseAddKernel); + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel, + ops::ElementwiseAddKernel); REGISTER_OP_CPU_KERNEL( elementwise_add_grad, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel, + ops::ElementwiseAddGradKernel); diff --git a/paddle/operators/elementwise_div_op.cc b/paddle/operators/elementwise_div_op.cc index de75816a24900..7a325199bd07e 100644 --- a/paddle/operators/elementwise_div_op.cc +++ b/paddle/operators/elementwise_div_op.cc @@ -35,7 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker, elementwise_div_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_div, - ops::ElementwiseDivKernel); + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel, + ops::ElementwiseDivKernel); REGISTER_OP_CPU_KERNEL( elementwise_div_grad, - ops::ElementwiseDivGradKernel); + ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel, + ops::ElementwiseDivGradKernel); diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc index ffa10486f1239..8851267a524f5 100644 --- a/paddle/operators/elementwise_mul_op.cc +++ b/paddle/operators/elementwise_mul_op.cc @@ -37,8 +37,12 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker, REGISTER_OP_CPU_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel, + ops::ElementwiseMulKernel); REGISTER_OP_CPU_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel); + ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel, + ops::ElementwiseMulGradKernel); diff --git a/paddle/operators/elementwise_sub_op.cc b/paddle/operators/elementwise_sub_op.cc index 39702dad0ee61..95d7979e39bfe 100644 --- a/paddle/operators/elementwise_sub_op.cc +++ b/paddle/operators/elementwise_sub_op.cc @@ -34,7 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker, elementwise_sub_grad, ops::ElementwiseOpGrad); REGISTER_OP_CPU_KERNEL( elementwise_sub, - ops::ElementwiseSubKernel); + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel, + ops::ElementwiseSubKernel); REGISTER_OP_CPU_KERNEL( elementwise_sub_grad, - ops::ElementwiseSubGradKernel); + ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel, + ops::ElementwiseSubGradKernel); diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py index 180d0135ffe8f..3a8f1831cf2c4 100644 --- a/python/paddle/v2/fluid/evaluator.py +++ b/python/paddle/v2/fluid/evaluator.py @@ -1,59 +1,187 @@ -import paddle.v2.fluid.op as op import numpy as np +from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable import paddle.v2.fluid.core as core -def avg_accumulate(accumulated_var, per_eval, num_batches, place): - t = np.array(accumulated_var.get_tensor()) - t[0] += per_eval[0] - accumulated_var.get_tensor().set([t[0] / float(num_batches)], place) +def _clone_var_in_block_(block, var): + assert isinstance(var, Variable) + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.data_type, + type=var.type, + lod_level=var.lod_level, + persistable=True) class Evaluator(object): - def __init__(self, - scope, - operator='accuracy', - input='Inference', - label='Label', - output='Output', - place=core.CPUPlace()): - """ - create an evaluator for evaluating the inference. - NOTE: default run on CPUPlace(), running on GPUPlace doesn't improve performance much. - - :param scope: the scope instance contains the input. - :type scope: paddle.v2.fluid.core.scope - :param operator: operator name for caculating the evaluation for each mini-batch. - :type operator: string - :param input: output variable name of forward network. - :type input: string - :param label: variable name of label - :type label: string - """ - self.scope = scope - self.place = place - self.output_name = output - self.num_batches = 0 - # create variable to store accumulated evaluator output - eval_name = ''.join([operator, "@Eval"]) - if scope.find_var(eval_name): - raise Exception("evaluator already exist in scope: %s" % eval_name) - self.accumulated_var = scope.var(eval_name) - t = self.accumulated_var.get_tensor() - t.set_dims((1, )) - t.set([0.0], place) - # self.accumulated_var = block.create_var(block, name=eval_name, shape=(1,)) - # self.accumulated_var.get_tensor().set([0.0]) - # create operator of evaluation - var_map = dict() # var name -> variable - var_map[input] = [input] - var_map[label] = [label] - var_map[output] = [output] - self.op = op.Operator(operator, **var_map) - - def evaluate(self, ctx, accumulator=avg_accumulate): - self.op.run(self.scope, ctx) - per_eval = np.array(self.scope.find_var(self.output_name).get_tensor()) - self.num_batches += 1 - accumulator(self.accumulated_var, per_eval, self.num_batches, - self.place) + """ + Evalutor Base class. + + create metric states + add mini-batch evaluator caculate operator + add increment operator to accumulate the metric states + """ + + def __init__(self, name, **kwargs): + """ + init the global states + """ + self._states = {} + if kwargs.has_key("main_program"): + self._main_program = kwargs.get("main_program") + else: + self._main_program = g_main_program + + def _update_ops(self, *args, **kwargs): + """ + append update ops to the global states + """ + raise NotImplementedError() + + def reset(self, executor, reset_program=None): + """ + Clear metric states at the begin of each pass/user specified batch + """ + if reset_program == None: + reset_program = Program() + else: + reset_program = program + block = reset_program.global_block() + for k, var in self._states.iteritems(): + g_var = _clone_var_in_block_(block, var) + zeros = block.create_var(dtype="float32", persistable=True) + block.append_op( + type="fill_constant", + outputs={"Out": [zeros]}, + attrs={ + "shape": g_var.shape, + "value": .0, + "data_type": 5, + }) + block.append_op( + type="scale", inputs={"X": zeros}, outputs={"Out": g_var}) + executor.run(reset_program, fetch_list=self._states.values()) + + def eval(self, executor, eval_program=None): + """ + Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. + """ + raise NotImplementedError() + + +class Accuracy(Evaluator): + """ + Accuracy need two state variable Total, Correct + """ + + def __init__(self, *args, **kwargs): + super(Accuracy, self).__init__("accuracy", **kwargs) + block = self._main_program.global_block() + g_total = block.create_var( + name=unique_name("Total"), + persistable=True, + dtype="int64", + shape=[1]) + g_correct = block.create_var( + name=unique_name("Correct"), + persistable=True, + dtype="int64", + shape=[1]) + self._states["Total"] = g_total + self._states["Correct"] = g_correct + + def _update_ops(self, input, label, k=1, **kwargs): + block = self._main_program.global_block() + topk_out = block.create_var(dtype=input.data_type) + topk_indices = block.create_var(dtype="int64") + block.append_op( + type="top_k", + inputs={"X": [input]}, + outputs={"Out": [topk_out], + "Indices": [topk_indices]}, + attrs={"k": k}) + acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32")) + correct = block.create_var(dtype="int64", persistable=True) + total = block.create_var(dtype="int64", persistable=True) + block.append_op( + type="accuracy", + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) + + block.append_op( + type="cast", + inputs={"X": [self._states["Total"]]}, + outputs={"Out": [self._states["Total"]]}, + attrs={ + "in_data_type": 5, # float32 + "out_data_type": 2, #int32 + }) + block.append_op( + type="cast", + inputs={"X": [self._states["Correct"]]}, + outputs={"Out": [self._states["Correct"]]}, + attrs={ + "in_data_type": 5, + "out_data_type": 2, + }) + + block.append_op( + type="elementwise_add", + inputs={"X": [self._states["Total"]], + "Y": [total]}, + outputs={"Out": [self._states["Total"]]}) + block.append_op( + type="elementwise_add", + inputs={"X": [self._states["Correct"]], + "Y": [correct]}, + outputs={"Out": [self._states["Correct"]]}) + + return acc_out + + def eval(self, executor, eval_program=None): + if eval_program != None: + eval_program = eval_program + else: + eval_program = Program() + block = eval_program.global_block() + eval_out = block.create_var(dtype=self._states["Total"].data_type) + e_total = _clone_var_in_block_(block, self._states["Total"]) + e_correct = _clone_var_in_block_(block, self._states["Correct"]) + block.append_op( + type="cast", + inputs={"X": [e_total]}, + outputs={"Out": [e_total]}, + attrs={ + "in_data_type": 2, #int32 + "out_data_type": 5, #float32 + }) + block.append_op( + type="cast", + inputs={"X": [e_correct]}, + outputs={"Out": [e_correct]}, + attrs={ + "in_data_type": 2, + "out_data_type": 5, + }) + block.append_op( + type="elementwise_div", + inputs={"X": e_correct, + "Y": e_total}, + outputs={"Out": eval_out}) + out = executor.run(eval_program, fetch_list=[eval_out]) + return np.array(out[0]) + + +def accuracy(*args, **kwargs): + cls = Accuracy(*args, **kwargs) + out = cls._update_ops(*args, **kwargs) + return cls, out diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index 8a1aa1c42d5a0..b582f2ef6df4c 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -574,7 +574,9 @@ def accuracy(input, label, k=1, **kwargs): "Indices": [topk_indices]}, attrs={"k": k}) acc_out_dtype = kwargs.get("out_dtype", "float32") - acc_out = helper.create_tmp_variable(dtype=acc_out_dtype) + acc_out = helper.create_tmp_variable(dtype="float32") + correct = helper.create_tmp_variable(dtype="int64") + total = helper.create_tmp_variable(dtype="int64") helper.append_op( type="accuracy", inputs={ @@ -582,7 +584,11 @@ def accuracy(input, label, k=1, **kwargs): "Indices": [topk_indices], "Label": [label] }, - outputs={"Accuracy": [acc_out]}) + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) return acc_out diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py index 2b723125412c1..a10530bd823b5 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py @@ -3,6 +3,7 @@ import paddle.v2.fluid.nets as nets import paddle.v2.fluid.core as core import paddle.v2.fluid.optimizer as optimizer +import paddle.v2.fluid.evaluator as evaluator from paddle.v2.fluid.framework import Program from paddle.v2.fluid.executor import Executor @@ -54,17 +55,15 @@ main_program=main_program, startup_program=startup_program) avg_cost = layers.mean(x=cost, main_program=main_program) -accuracy = layers.accuracy( +optimizer = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999) +opts = optimizer.minimize(avg_cost, startup_program) + +accuracy, acc_out = evaluator.accuracy( input=predict, label=label, main_program=main_program, startup_program=startup_program) -# optimizer = optimizer.MomentumOptimizer(learning_rate=0.1 / 128.0, -# momentum=0.9) -optimizer = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999) -opts = optimizer.minimize(avg_cost, startup_program) - BATCH_SIZE = 50 PASS_NUM = 3 train_reader = paddle.batch( @@ -79,6 +78,7 @@ for pass_id in range(PASS_NUM): count = 0 + accuracy.reset(exe) for data in train_reader(): img_data = np.array(map(lambda x: x[0].reshape([1, 28, 28]), data)).astype("float32") @@ -93,11 +93,17 @@ outs = exe.run(main_program, feed={"pixel": tensor_img, "label": tensor_y}, - fetch_list=[avg_cost, accuracy]) + fetch_list=[avg_cost, acc_out]) loss = np.array(outs[0]) acc = np.array(outs[1]) - + pass_acc = accuracy.eval(exe) + print "pass id : ", pass_id, pass_acc + # print loss, acc if loss < 10.0 and acc > 0.9: # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. exit(0) + + pass_acc = accuracy.eval(exe) + print "pass id : ", pass_id, pass_acc + exit(1) diff --git a/python/paddle/v2/fluid/tests/test_accuracy_op.py b/python/paddle/v2/fluid/tests/test_accuracy_op.py index 6536c297e8e55..6f72918b7178b 100644 --- a/python/paddle/v2/fluid/tests/test_accuracy_op.py +++ b/python/paddle/v2/fluid/tests/test_accuracy_op.py @@ -18,7 +18,9 @@ def setUp(self): num_correct += 1 break self.outputs = { - 'Accuracy': np.array([num_correct / float(n)]).astype("float32") + 'Accuracy': np.array([num_correct / float(n)]).astype("float32"), + 'Correct': np.array([num_correct]).astype("int32"), + 'Total': np.array([n]).astype("int32") } def test_check_output(self): diff --git a/python/paddle/v2/fluid/tests/test_evaluator.py b/python/paddle/v2/fluid/tests/test_evaluator.py deleted file mode 100644 index 1d51205b703f8..0000000000000 --- a/python/paddle/v2/fluid/tests/test_evaluator.py +++ /dev/null @@ -1,64 +0,0 @@ -from paddle.v2.fluid.evaluator import Evaluator -from paddle.v2.fluid.op import Operator -import paddle.v2.fluid.core as core -import unittest -import op_test -import numpy as np - - -class TestEvaluator(unittest.TestCase): - def setup(self, scope, inputs, outputs): - def __create_var__(var_name, arr): - np_arr = np.array(arr) - scope.var(var_name) - # tensor = var.get_tensor() - # tensor.set_dims(np_arr.shape) - - for var_name, arr in inputs.iteritems(): - __create_var__(var_name, arr) - - for var_name, arr in outputs.iteritems(): - __create_var__(var_name, arr) - - def test_evaluator(self): - - inputs = { - 'Inference': np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 1]]).T, - 'Label': np.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]) - } - outputs = {'Accuracy': np.array([0.9])} - out_name = 'Accuracy' - - places = [core.CPUPlace()] - if core.is_compile_gpu(): - places.append(core.GPUPlace(0)) - - for place in places: - scope = core.Scope() - self.setup(scope, inputs, outputs) - - evaluator = Evaluator( - scope, - operator='accuracy', - input='Inference', - label='Label', - output=out_name, - place=place) - op_test.set_input(scope, evaluator.op, inputs, place) - ctx = core.DeviceContext.create(place) - - for i in range(10): # simulate 10 mini-batches - evaluator.evaluate(ctx) - - actual = np.array(scope.find_var(out_name).get_tensor()) - print actual - - self.assertTrue( - np.allclose( - actual, outputs[out_name], atol=1e-5), - "output name: " + out_name + " has diff.") - - -if __name__ == '__main__': - exit(0) - unittest.main() diff --git a/python/paddle/v2/framework/math_ops.py b/python/paddle/v2/framework/math_ops.py new file mode 100644 index 0000000000000..408656a75d676 --- /dev/null +++ b/python/paddle/v2/framework/math_ops.py @@ -0,0 +1,3 @@ +import paddle.v2.framework.core as core +from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \ + Operator