Big data op_test benchmark, for checking output consistency in different runs. (#10646)

* "init benchmark ops"

* "untrack outputs"

* "delete some usused code"

* "benchmark"

* "fix ci"

* "fix op test"

* "fix uint16 missing"

* "fix ci"

* "follow comments"

* "fix ci"

* "follow comments"

* "conficts. merge develop branch"

* repick

* "merge develop branch"
dzhwinter committed Jun 7, 2018
1 parent 3ff9ba0 commit f7c96f0
Showing 8 changed files with 621 additions and 352 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/framework/operator.cc
@@ -661,8 +661,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
     }
     if (t != nullptr) {
       int tmp = static_cast<int>(ToDataType(t->type()));
-      PADDLE_ENFORCE(tmp == data_type || data_type == -1,
-                     "DataType of Paddle Op %s must be the same.", Type());
+      PADDLE_ENFORCE(
+          tmp == data_type || data_type == -1,
+          "DataType of Paddle Op %s must be the same. Get %d != %d", Type(),
+          data_type, tmp);
       data_type = tmp;
     }
   }
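With the added "Get %d != %d" detail, a dtype mismatch now reports both conflicting type codes instead of only naming the op. A hypothetical Python-side repro (layer names are illustrative, and the exact integer codes depend on proto::VarType; the check fires during kernel selection when the program runs, not at graph-construction time):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[2], dtype='float32')
    y = fluid.layers.data(name='y', shape=[2], dtype='int64')
    out = fluid.layers.elementwise_add(x, y)  # operands disagree: float32 vs int64

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    # Expected to fail at kernel selection with something like:
    #   DataType of Paddle Op elementwise_add must be the same. Get ... != ...
    exe.run(feed={'x': np.ones((1, 2), 'float32'),
                  'y': np.ones((1, 2), 'int64')},
            fetch_list=[out])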
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/executor.py
@@ -170,6 +170,8 @@ def to_name_str(var):
         return var.desc.name()
     elif isinstance(var, str):
         return var
+    elif isinstance(var, basestring):
+        return str(var)
     else:
         raise TypeError(str(var) + " should be Variable or str")
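The new branch matters under Python 2, where a unicode fetch-target name is a basestring but not a str, so it previously fell through to the TypeError. A minimal illustration of the dispatch (Python 2 semantics; the variable name is hypothetical):

    name = u'mean_0.tmp_0'                # hypothetical fetch-target name
    print(isinstance(name, str))          # False in Python 2
    print(isinstance(name, basestring))   # True -> to_name_str returns str(name)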

25 changes: 16 additions & 9 deletions python/paddle/fluid/framework.py
@@ -72,6 +72,8 @@ def convert_np_dtype_to_dtype_(np_dtype):
         return core.VarDesc.VarType.INT64
     elif dtype == np.bool:
         return core.VarDesc.VarType.BOOL
+    elif dtype == np.uint16:
+        return core.VarDesc.VarType.INT16
     elif dtype == np.uint8:
         return core.VarDesc.VarType.UINT8
     else:
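This closes the "fix uint16 missing" item from the commit message: numpy uint16 arrays previously hit the unsupported-dtype branch. Note the mapping targets INT16, as VarType appears to have no unsigned 16-bit entry at this point. A quick check, assuming a build that contains this change:

    import numpy as np
    import paddle.fluid.core as core
    from paddle.fluid.framework import convert_np_dtype_to_dtype_

    t = convert_np_dtype_to_dtype_(np.uint16)
    print(t == core.VarDesc.VarType.INT16)  # True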
@@ -368,6 +370,13 @@ class Operator(object):
     Block. Users can use the build in instructions to describe their neural
     network.
     """
+    OP_WITHOUT_KERNEL_SET = {
+        'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
+        'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
+        'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
+        'ncclInit', 'channel_create', 'channel_close', 'channel_send',
+        'channel_recv', 'select'
+    }

     def __init__(self,
                  block,
@@ -504,17 +513,13 @@ def find_name(var_list, name):
         else:
             self.desc.set_attr(attr_name, self.attrs[attr_name])
         self.desc.check_attrs()
-        no_kernel_op_set = {
-            'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
-            'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
-            'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
-            'load_combine', 'ncclInit', 'channel_create', 'channel_close',
-            'channel_send', 'channel_recv', 'select', 'gen_nccl_id'
-        }
-        if type not in no_kernel_op_set:
+        if self.has_kernel(type):
             self.desc.infer_var_type(self.block.desc)
             self.desc.infer_shape(self.block.desc)

+    def has_kernel(self, op_type):
+        return op_type not in self.OP_WITHOUT_KERNEL_SET
+
     def to_string(self, throw_on_error):
         """
         To debug string.
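Hoisting the set to a class attribute and wrapping it in has_kernel makes the check reusable: callers no longer rebuild the local set on every Operator construction. A small sketch of consulting it directly (assuming this commit is installed):

    from paddle.fluid.framework import Operator

    print('fetch' in Operator.OP_WITHOUT_KERNEL_SET)  # True  -> no kernel
    print('sum' in Operator.OP_WITHOUT_KERNEL_SET)    # False -> has a kernel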
@@ -742,7 +747,9 @@ def idx(self):

     def var(self, name):
         if not isinstance(name, basestring):
-            raise TypeError()
+            raise TypeError(
+                "var require string as parameter, but get %s instead." %
+                (type(name)))
         v = self.vars.get(name, None)
         if v is None:
             raise ValueError("var %s not in this block" % name)
113 changes: 113 additions & 0 deletions python/paddle/fluid/tests/unittests/benchmark.py
@@ -0,0 +1,113 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import time
import itertools

import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest


class BenchmarkSuite(OpTest):
    def timeit_function(self, callback, iters, *args, **kwargs):
        assert iters != 0, "iters should be >= 1"
        start = time.time()
        for i in range(iters):
            callback(*args, **kwargs)
        elapse = time.time() - start
        return elapse / iters

    def _assert_cpu_gpu_same(self, cpu_outs, gpu_outs, fetch_list, atol):
        for item_cpu_out, item_gpu_out, variable in zip(cpu_outs, gpu_outs,
                                                        fetch_list):
            # The CPU version is the baseline; the GPU version is expected
            # to stay consistent with it.
            expect = item_cpu_out
            expect_t = np.array(item_cpu_out)
            actual = item_gpu_out
            actual_t = np.array(item_gpu_out)
            var_name = variable if isinstance(variable,
                                              basestring) else variable.name
            self.assertTrue(
                np.allclose(
                    actual_t, expect_t, atol=atol),
                "Output (" + var_name + ") has diff" + str(actual_t) + "\n" +
                str(expect_t))
            self.assertListEqual(actual.lod(),
                                 expect.lod(),
                                 "Output (" + var_name + ") has different lod")

    def _get_input_names(self):
        inputs = []
        for name, value in self.inputs.iteritems():
            if isinstance(value, list):
                inputs.extend([sub_name for sub_name, _ in value])
            inputs.append(name)
        return inputs

    def _get_output_names(self):
        outputs = []
        for var_name, var in self.outputs.iteritems():
            if isinstance(var, list):
                for sub_var_name, sub_var in var:
                    outputs.append(sub_var_name)
            else:
                outputs.append(var_name)
        if len(outputs) == 0:
            for out_name, out_dup in Operator.get_op_outputs(self.op_type):
                outputs.append(str(out_name))
        return outputs

    def check_output_stability(self, atol=1e-8):
        places = self._get_places()
        if len(places) < 2:
            return
        cpu_outs, fetch_list = self._calc_output(places[0])
        gpu_outs, _ = self._calc_output(places[1])
        self._assert_cpu_gpu_same(cpu_outs, gpu_outs, fetch_list, atol)

    def timeit_output_with_place(self, place, iters):
        return self.timeit_function(self.calc_output, iters, place)

    def timeit_output(self, iters=100):
        places = self._get_places()
        elapses = []
        for place in places:
            elapses.append(self.timeit_output_with_place(place, iters))
        for place, elapse in zip(places, elapses):
            print("One pass of ({2}_op) at {0} cost {1}".format(
                str(place), elapse, self.op_type))

    def timeit_grad_with_place(self, place, iters=100):
        inputs_to_check = self._get_input_names()
        output_names = self._get_output_names()
        return self.timeit_function(
            self._get_gradient,
            iters,
            inputs_to_check,
            place,
            output_names,
            no_grad_set=None)

    def timeit_grad(self, iters=100):
        places = self._get_places()
        elapses = []
        for place in places:
            elapses.append(self.timeit_grad_with_place(place, iters))
        for place, elapse in zip(places, elapses):
            print("One pass of ({2}_grad_op) at {0} cost {1}".format(
                str(place), elapse, self.op_type))
82 changes: 82 additions & 0 deletions python/paddle/fluid/tests/unittests/benchmark_sum_op.py
@@ -0,0 +1,82 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle.fluid as fluid
from benchmark import BenchmarkSuite
from op_test import OpTest

# This is a demo op test case for operator benchmarking and high-resolution
# numerical stability checking.


class TestSumOp(BenchmarkSuite):
    def setUp(self):
        self.op_type = "sum"
        self.customize_testcase()
        self.customize_fetch_list()

    def customize_fetch_list(self):
        """
        Customize the fetch list to configure the wanted variables.
        >>> self.fetch_list = ["Out"]
        """
        self.fetch_list = ["Out"]

    def customize_testcase(self):
        # a test case
        x0 = np.random.random((300, 400)).astype('float32')
        x1 = np.random.random((300, 400)).astype('float32')
        x2 = np.random.random((300, 400)).astype('float32')

        # NOTE: if the output is left empty, it will be auto-filled by
        # BenchmarkSuite. Only the output dtype is used; the shape, lod and
        # data are computed from the input.
        self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
        self.outputs = {"Out": x0 + x1 + x2}

    def test_check_output(self):
        """
        Compare the output with the customized output. In this case,
        you should set the correct output by hand.
        >>> self.outputs = {"Out": x0 + x1 + x2}
        """
        self.check_output(atol=1e-8)

    def test_output_stability(self):
        # Compare the CPU and GPU outputs at high resolution.
        self.check_output_stability()

    def test_timeit_output(self):
        """
        Time the op; the cost is averaged over iters.
        Example output:
        >>> One pass of (sum_op) at CPUPlace cost 0.000461330413818
        >>> One pass of (sum_op) at CUDAPlace(0) cost 0.000556070804596
        """
        self.timeit_output(iters=100)

    def test_timeit_grad(self):
        """
        Time the op gradient; the cost is averaged over iters.
        Example output:
        >>> One pass of (sum_grad_op) at CPUPlace cost 0.00279935121536
        >>> One pass of (sum_grad_op) at CUDAPlace(0) cost 0.00500632047653
        """
        self.timeit_grad(iters=100)


if __name__ == "__main__":
    unittest.main()
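The demo doubles as a template: benchmarking another operator is a matter of subclassing BenchmarkSuite with different inputs. A hedged sketch for a hypothetical mul case (the op choice, shapes, and test selection are illustrative, not part of this commit):

    import numpy as np
    from benchmark import BenchmarkSuite


    class TestMulOp(BenchmarkSuite):
        def setUp(self):
            self.op_type = "mul"
            x = np.random.random((128, 64)).astype('float32')
            y = np.random.random((64, 32)).astype('float32')
            self.inputs = {'X': x, 'Y': y}
            self.outputs = {'Out': np.dot(x, y)}
            self.fetch_list = ["Out"]

        def test_timeit_output(self):
            self.timeit_output(iters=100)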