Skip to content

Commit

Permalink
Simplify Gradient Check (#5024)
Browse files Browse the repository at this point in the history
  • Loading branch information
reyoung committed Oct 23, 2017
1 parent cdb5f29 commit 9023248
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions python/paddle/v2/framework/tests/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op


def get_gradient(scope, op, inputs, outputs, grad_name, place,
def get_gradient(scope,
op,
inputs,
outputs,
grad_names,
place,
no_grad_set=None):
ctx = core.DeviceContext.create(place)

Expand All @@ -195,8 +200,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,

backward_op.run(scope, ctx)

out = np.array(scope.find_var(grad_name).get_tensor())
return out
return [
np.array(scope.find_var(grad_name).get_tensor())
for grad_name in grad_names
]


def append_input_output(block, op_proto, np_list, is_input):
Expand Down Expand Up @@ -399,23 +406,19 @@ def check_grad(self,
]

cpu_place = core.CPUPlace()
cpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, cpu_place, no_grad_set)
for grad_name in grad_names
]
cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names, cpu_place,
no_grad_set)

self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
max_relative_error,
"Gradient Check On %s" % str(cpu_place))

if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0)
gpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, gpu_place, no_grad_set)
for grad_name in grad_names
]
gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
self.outputs, grad_names,
gpu_place, no_grad_set)

self.__assert_is_close(numeric_grads, gpu_analytic_grads,
grad_names, max_relative_error,
Expand Down

0 comments on commit 9023248

Please sign in to comment.